From 5fec03eabb0889db684bb6883e0533d63c236ec5 Mon Sep 17 00:00:00 2001
From: zxorange_321
Date: Fri, 22 Aug 2025 16:19:15 +0800
Subject: [PATCH 1/2] Add feature filter framework implementation

---
 .../feature_filter/evict_feature_record.cpp | 33 ++
 .../feature_filter/evict_feature_record.h | 35 ++
 .../csrc/feature_filter/feature_filter.cpp | 139 ++++++
 .../csrc/feature_filter/feature_filter.h | 62 +++
 .../csrc/utils/string_tools.h | 71 +++
 .../sparse/jagged_tensor_with_timestamp.py | 449 ++++++++++++++++++
 6 files changed, 789 insertions(+)
 create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.cpp
 create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.h
 create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.cpp
 create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.h
 create mode 100644 torchrec/torchrec_embcache/src/torchrec_embcache/csrc/utils/string_tools.h

diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.cpp b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.cpp
new file mode 100644
index 00000000..8c82cbb3
--- /dev/null
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) huawei Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "evict_feature_record.h"
+
+namespace Embcache {
+
+// Returns true only when embUpdateCount is exactly the swap count stored by SetSwapCount().
+bool EvictFeatureRecord::CanRemoveFromEmbTable(uint64_t embUpdateCount) const
+{
+    return embUpdateCount == executeSwapCount_;
+}
+
+// Stores the swap count that CanRemoveFromEmbTable() will later compare against.
+void EvictFeatureRecord::SetSwapCount(uint64_t swapCount)
+{
+    executeSwapCount_ = swapCount;
+}
+
+// Resets the stored swap count to 0 and drops all pending evict keys.
+void EvictFeatureRecord::ClearEvictInfo()
+{
+    executeSwapCount_ = 0;
+    evictKeys_.clear();
+}
+
+// Returns the pending evict keys by mutable reference; callers append to and read from it directly.
+std::vector& EvictFeatureRecord::GetEvictKeys()
+{
+    return evictKeys_;
+}
+
+} // namespace Embcache
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.h
new file mode 100644
index 00000000..bb690a63
--- /dev/null
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/evict_feature_record.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) huawei Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#ifndef EVICT_FEATURE_RECORD_H
+#define EVICT_FEATURE_RECORD_H
+
+#include
+#include
+#include
+#include
+
+namespace Embcache {
+// Holds the keys waiting to be removed from an embedding table, together with the
+// swap count at which the removal becomes permissible.
+class EvictFeatureRecord {
+public:
+    EvictFeatureRecord() = default;
+    // True when embUpdateCount equals the recorded swap count.
+    bool CanRemoveFromEmbTable(uint64_t embUpdateCount) const;
+    // Resets the swap count to 0 and clears the pending keys.
+    void ClearEvictInfo();
+    // Records the swap count associated with the pending keys.
+    void SetSwapCount(uint64_t swapCount);
+    // Mutable access to the pending keys.
+    std::vector& GetEvictKeys();
+
+private:
+    // Number of ComputeSwapInfo steps executed when eviction was triggered; used to decide
+    // when the embTable removal interface may be called.
+    uint64_t executeSwapCount_ = 0;
+
+    // ComputeSwapInfo and EmbeddingUpdate do not execute at the same time, so the keys that
+    // still have to be removed from embTable are recorded here in the meantime.
+    std::vector evictKeys_;
+};
+
+} // namespace Embcache
+
+#endif // EVICT_FEATURE_RECORD_H
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.cpp b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.cpp
new file mode 100644
index 00000000..cd327ed8
--- /dev/null
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) huawei Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "feature_filter.h"
+
+#include
+#include
+#include
+
+#include "common/constants.h"
+#include "utils/logger.h"
+
+namespace Embcache {
+
+// Stores the per-table admission / eviction configuration; no other work is done here.
+FeatureFilter::FeatureFilter(const std::string& tableName, int32_t admitThreshold,
+                             uint64_t evictThreshold, uint64_t evictStepInterval)
+    : tableName_(tableName), admitThreshold_(admitThreshold),
+      evictThreshold_(evictThreshold), evictStepInterval_(evictStepInterval)
+{
+}
+
+// Records the timestamp of every feature in [startIndex, endIndex) and tracks the newest
+// timestamp seen so far. Both pointers are indexed with the same i, so featureDataPtr[i]
+// and timestampDataPtr[i] must describe the same element.
+// From the second call onwards, every evictStepInterval_-th call also runs FeatureEvict().
+void FeatureFilter::RecordTimestamp(const int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex,
+                                    const int64_t* timestampDataPtr)
+{
+    auto beforeRecordSize = timestampRecordMap_.size();
+    for (int64_t i = startIndex; i < endIndex; ++i) {
+        auto feature = *(featureDataPtr + i);
+        auto timestampData = *(timestampDataPtr + i);
+        auto timestamp = static_cast(timestampData);
+        timestampRecordMap_.insert_or_assign(feature, timestamp);
+        latestTimestamp_ = std::max(latestTimestamp_, timestamp);
+    }
+    auto afterRecordSize = timestampRecordMap_.size();
+    LOG_DEBUG("Enter RecordTimestamp, beforeRecordSize: {}, afterRecordSize: {}", beforeRecordSize, afterRecordSize);
+
+    // Recording timestamps and computing swap info are a different number of steps apart,
+    // so the keys to evict are collected here, while the timestamps are being recorded.
+    // evictStepInterval_ defaults to 0; an interval of 0 is treated as "no periodic eviction"
+    // because taking the remainder by zero is undefined behaviour.
+    if (evictStepInterval_ != 0 && recordTsBatchId_ > 0 && (recordTsBatchId_ + 1) % evictStepInterval_ == 0) {
+        FeatureEvict();
+    }
+    recordTsBatchId_++;
+}
+
+// Appends to evictFeatureRecord_ every feature whose timestamp is more than evictThreshold_
+// behind latestTimestamp_, then erases all keys held in the evict list from the timestamp map
+// (and from the count map when admission is enabled). Does nothing when evictThreshold_ is 0.
+// The evict list is not cleared here, so keys stay in it until ClearEvictInfo() is called.
+void FeatureFilter::FeatureEvict()
+{
+    std::vector& evictKeys = evictFeatureRecord_.GetEvictKeys();
+    if (evictThreshold_ == 0) {
+        LOG_DEBUG("Current table evictThreshold is 0, will skip.");
+        return;
+    }
+
+    LOG_DEBUG("The latestTimestamp for current table: {}, evictThreshold: {}", latestTimestamp_, evictThreshold_);
+    auto tempEvictThreshold = static_cast(evictThreshold_);
+    for (const auto& iter : timestampRecordMap_) {
+        auto feature = iter.first;
+        // -1 is skipped; CountFilter's original comment describes it as the value given to
+        // features that were not admitted.
+        if (feature == -1) {
+            continue;
+        }
+        if (latestTimestamp_ - iter.second > tempEvictThreshold) {
+            evictKeys.emplace_back(feature);
+        }
+    }
+    // Evicted keys are removed from timestampRecordMap_.
+    bool isAdmitEnabled = admitThreshold_ != -1;
+    for (const auto& feature : evictKeys) {
+        timestampRecordMap_.erase(feature);
+        if (isAdmitEnabled) {
+            // With admission enabled the key is removed from the admission map as well.
+            featureRecordMap_.erase(feature);
+        }
+    }
+    LOG_DEBUG("The table name: {}, get evict keys size: {}", tableName_, evictKeys.size());
+}
+
+// Read-only view of the per-feature occurrence counts.
+const std::unordered_map& FeatureFilter::GetFeatureCountMap()
+{
+    return featureRecordMap_;
+}
+
+// Read-only view of the per-feature timestamps.
+const std::unordered_map& FeatureFilter::GetFeatureTimestampMap()
+{
+    return timestampRecordMap_;
+}
+
+// Overwrites the stored count of keys[i] with counts[i], inserting the key if needed.
+// Throws std::runtime_error when the two vectors differ in length.
+void FeatureFilter::LoadFeatureRecords(const std::vector& keys, std::vector& counts)
+{
+    if (keys.size() != counts.size()) {
+        throw std::runtime_error("Failed to load key count info, vector size is not same between keys and counts.");
+    }
+    for (size_t i = 0; i < keys.size(); ++i) {
+        featureRecordMap_[keys[i]].count = counts[i];
+    }
+}
+
+// Overwrites the stored timestamp of keys[i] with timestamps[i], inserting the key if needed.
+// Throws std::runtime_error when the two vectors differ in length.
+void FeatureFilter::LoadTimestampRecords(const std::vector& keys, std::vector& timestamps)
+{
+    if (keys.size() != timestamps.size()) {
+        throw std::runtime_error("Failed to load timestamp info, vector size is not same between keys and timestamps.");
+    }
+    for (size_t i = 0; i < keys.size(); ++i) {
+        timestampRecordMap_[keys[i]] = static_cast(timestamps[i]);
+    }
+}
+
+// Adds the occurrence count of every feature in [startIndex, endIndex) to featureRecordMap_.
+// countDataPtr is not dereferenced when isCountDataEmpty is true; each feature then counts as 1.
+void FeatureFilter::StatisticsKeyCount(const int64_t* featureDataPtr, const int64_t* countDataPtr, int64_t startIndex,
+                                       int64_t endIndex, bool isCountDataEmpty)
+{
+    for (int64_t i = startIndex; i < endIndex; ++i) {
+        auto feature = *(featureDataPtr + i);
+        auto count = isCountDataEmpty ? 1 : *(countDataPtr + i);
+        auto iter = featureRecordMap_.find(feature);
+        if (iter != featureRecordMap_.end()) {
+            iter->second.count += count;
+        } else {
+            FeatureRecord featureRecord = {count};
+            featureRecordMap_[feature] = featureRecord;
+        }
+    }
+}
+
+// Overwrites, in place, every feature in [startIndex, endIndex) whose recorded count is below
+// the admission threshold with INVALID_KEY. Features with no record at all are left untouched.
+void FeatureFilter::CountFilter(int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex)
+{
+    // Admission check: features that are not admitted are replaced by INVALID_KEY
+    // (the original comment describes this value as -1).
+    auto thresholdCount = static_cast(admitThreshold_);
+    for (int64_t i = startIndex; i < endIndex; ++i) {
+        auto feature = *(featureDataPtr + i);
+        auto iter = featureRecordMap_.find(feature);
+        if (iter != featureRecordMap_.end() && iter->second.count < thresholdCount) {
+            LOG_DEBUG("Feature filtered out due to insufficient count. TableName: {}, Feature: {}, Count: {}, "
+                      "Threshold: {}", tableName_, feature, iter->second.count, thresholdCount);
+            *(featureDataPtr + i) = INVALID_KEY;
+        }
+    }
+}
+
+} // namespace Embcache
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.h
new file mode 100644
index 00000000..1a1a0dfc
--- /dev/null
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/feature_filter/feature_filter.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) huawei Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#ifndef FEATURE_FILTER_H
+#define FEATURE_FILTER_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include "evict_feature_record.h"
+
+namespace Embcache {
+
+// Per-feature admission bookkeeping: how many times the feature has been counted.
+struct FeatureRecord {
+    uint64_t count;
+};
+
+// Per-table feature filter: counts feature occurrences for admission and records feature
+// timestamps for eviction.
+class FeatureFilter {
+public:
+    FeatureFilter(const std::string& tableName, int32_t admitThreshold, uint64_t evictThreshold,
+                  uint64_t evictStepInterval);
+    // Accumulates occurrence counts for the features in [startIndex, endIndex).
+    void StatisticsKeyCount(const int64_t* featureDataPtr, const int64_t* countDataPtr, int64_t startIndex,
+                            int64_t endIndex, bool isCountDataEmpty);
+    // Rewrites features in [startIndex, endIndex) in place (featureDataPtr is non-const).
+    void CountFilter(int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex);
+    // Records timestamps for the features in [startIndex, endIndex).
+    void RecordTimestamp(const int64_t* featureDataPtr, int64_t startIndex, int64_t endIndex,
+                         const int64_t* timestampDataPtr);
+    void FeatureEvict();
+
+    // Keys that have to be removed from embTable; the embeddings for these keys are removed
+    // once lookup has reached the same step count as GetSwapInfo.
+    EvictFeatureRecord evictFeatureRecord_;
+
+    const std::unordered_map& GetFeatureCountMap();
+    const std::unordered_map& GetFeatureTimestampMap();
+
+    // Loaders that fill the count / timestamp maps from parallel key and value vectors.
+    void LoadFeatureRecords(const std::vector& keys, std::vector& counts);
+    void LoadTimestampRecords(const std::vector& keys, std::vector& timestamps);
+
+private:
+    std::string tableName_;
+
+    // Admission configuration
+    int32_t admitThreshold_ = -1;  // admission threshold; the default value means admission is disabled
+    std::unordered_map featureRecordMap_;  // admission: number of times each key was seen
+
+    // Eviction configuration
+    uint64_t evictThreshold_ = 0;  // unit: second
+    uint64_t evictStepInterval_ = 0;  // number of steps between evictions
+    uint64_t recordTsBatchId_ = 0;
+    std::time_t latestTimestamp_ = 0;  // newest timestamp seen for this table, used to decide eviction
+    std::unordered_map timestampRecordMap_;  // eviction: timestamp recorded for each key
+};
+
+} // namespace Embcache
+
+#endif // FEATURE_FILTER_H
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/utils/string_tools.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/utils/string_tools.h
new file mode 100644
index 00000000..a99af176
--- /dev/null
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/utils/string_tools.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) huawei Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#ifndef EMBEDDING_CACHE_UTILS_STRING_TOOLS_H
+#define EMBEDDING_CACHE_UTILS_STRING_TOOLS_H
+
+#include <cstddef>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace Embcache {
+
+// Helpers that render sequences as "[a, b, c]" strings, e.g. for log messages.
+// Element types only need to be streamable with operator<<.
+class StringTools {
+public:
+    // Renders a vector of pairs as "[k1:v1, k2:v2]"; an empty vector gives "[]".
+    // Takes the vector by const reference so that const vectors select this overload
+    // instead of the generic one below, which cannot stream a std::pair.
+    template <typename K, typename V>
+    static std::string ToString(const std::vector<std::pair<K, V>>& items)
+    {
+        std::stringstream ss;
+        ss << "[";
+        for (size_t i = 0; i < items.size(); i++) {
+            if (i != 0) {
+                ss << ", ";
+            }
+            ss << items[i].first << ":" << items[i].second;
+        }
+        ss << "]";
+        return ss.str();
+    }
+
+    // Renders a vector as "[a, b, c]"; an empty vector gives "[]".
+    template <typename T>
+    static std::string ToString(const std::vector<T>& items)
+    {
+        std::stringstream ss;
+        ss << "[";
+        for (size_t i = 0; i < items.size(); i++) {
+            if (i != 0) {
+                ss << ", ";
+            }
+            ss << items[i];
+        }
+        ss << "]";
+        return ss.str();
+    }
+
+    // Renders the first `size` elements of a raw array as "[a, b, c]".
+    // A null pointer or a size of 0 gives "[]".
+    template <typename T>
+    static std::string ToString(const T* data, size_t size)
+    {
+        if (data == nullptr || size == 0) {
+            return "[]";
+        }
+
+        std::stringstream ss;
+        ss << "[";
+        for (size_t i = 0; i < size; ++i) {
+            if (i != 0) {
+                ss << ", ";
+            }
+            ss << data[i];
+        }
+        ss << "]";
+
+        return ss.str();
+    }
+};
+
+} // namespace Embcache
+#endif // EMBEDDING_CACHE_UTILS_STRING_TOOLS_H
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/sparse/jagged_tensor_with_timestamp.py b/torchrec/torchrec_embcache/src/torchrec_embcache/sparse/jagged_tensor_with_timestamp.py
index 80207477..f0d1be70 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/sparse/jagged_tensor_with_timestamp.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/sparse/jagged_tensor_with_timestamp.py
@@ -5,15 +5,464 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Dict, List, Tuple
+
+import torch
+
 from torchrec.sparse.jagged_tensor import (
     JaggedTensor,
     KeyedJaggedTensor,
+    _pin_and_move,
+    _permute_tensor_by_segments,
 )
+from torchrec.pt2.checks import is_torchdynamo_compiling, is_non_strict_exporting
 
 
 class JaggedTensorWithTimestamp(JaggedTensor):
     _fields = ["_timestamps"]
 
+    def __init__(
+        self,
+        values: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
+        lengths: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        timestamps: Optional[torch.Tensor] = None,
+    ) -> None:
+        """
+        JaggedTensor that additionally carries one optional timestamp per value.
+
+        Args:
+            values, weights, lengths, offsets: forwarded unchanged to JaggedTensor.
+            timestamps (Optional[torch.Tensor]): timestamps matching `values`.
+
+        Raises:
+            ValueError: if `timestamps` is given and its size differs from `values`.
+        """
+        if timestamps is not None and values.size() != timestamps.size():
+            raise ValueError(
+                f"timestamps size must same with values, but got timestamp size:{timestamps.size()},"
+                f" values size:{values.size()}."
+            )
+
+        super().__init__(values, weights, lengths, offsets)
+
+        # Timestamps that correspond to `values` (same size as `values`); only used before input dist.
+        self._timestamps = timestamps
+
+    @property
+    def timestamps(self) -> Optional[torch.Tensor]:
+        return self._timestamps
+
 
 class KeyedJaggedTensorWithTimestamp(KeyedJaggedTensor):
     _fields = ["_timestamps"]
+
+    def __init__(
+        self,
+        keys: List[str],
+        values: torch.Tensor,
+        timestamps: Optional[torch.Tensor] = None,
+        weights: Optional[torch.Tensor] = None,
+        lengths: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        stride: Optional[int] = None,
+        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        # Below exposed to ensure torch.script-able
+        stride_per_key: Optional[List[int]] = None,
+        length_per_key: Optional[List[int]] = None,
+        lengths_offset_per_key: Optional[List[int]] = None,
+        offset_per_key: Optional[List[int]] = None,
+        index_per_key: Optional[Dict[str, int]] = None,
+        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
+        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
+    ) -> None:
+        """
+        KeyedJaggedTensor that additionally carries an optional `timestamps` tensor.
+
+        All arguments except `timestamps` are forwarded positionally to
+        KeyedJaggedTensor; `timestamps` is stored as-is and may be None.
+        """
+        super().__init__(
+            keys,
+            values,
+            weights,
+            lengths,
+            offsets,
+            stride,
+            stride_per_key_per_rank,
+            stride_per_key,
+            length_per_key,
+            lengths_offset_per_key,
+            offset_per_key,
+            index_per_key,
+            jt_dict,
+            inverse_indices,
+        )
+
+        self._timestamps: Optional[torch.Tensor] = timestamps
+
+    @property
+    def timestamps(self) -> Optional[torch.Tensor]:
+        return self._timestamps
+
+    @staticmethod
+    def from_jt_dict(
+        jt_dict: Dict[str, JaggedTensorWithTimestamp],
+    ) -> "KeyedJaggedTensorWithTimestamp":
+        """
+        Constructs a KeyedJaggedTensor from a dictionary of JaggedTensorWithTimestamps.
+        Automatically calls `kjt.sync()` on newly created KJT.
+
+        Args:
+            jt_dict (Dict[str, JaggedTensor]): dictionary of JaggedTensors.
+
+        Returns:
+            KeyedJaggedTensorWithTimestamp: constructed KeyedJaggedTensorWithTimestamp.
+            Its timestamps are None when none of the inputs carries timestamps.
+
+        Raises:
+            ValueError: if only some of the inputs carry timestamps.
+        """
+        kjt_keys = list(jt_dict.keys())
+        kjt_vals_list: List[torch.Tensor] = []
+        kjt_timestamps_list: List[torch.Tensor] = []
+        kjt_lens_list: List[torch.Tensor] = []
+        kjt_weights_list: List[torch.Tensor] = []
+        stride_per_key: List[int] = []
+        for jt in jt_dict.values():
+            stride_per_key.append(len(jt.lengths()))
+            kjt_vals_list.append(jt.values())
+            # timestamps are optional on each JaggedTensorWithTimestamp; a None entry
+            # must not reach torch.concat.
+            jt_timestamps = jt.timestamps
+            if jt_timestamps is not None:
+                kjt_timestamps_list.append(jt_timestamps)
+            kjt_lens_list.append(jt.lengths())
+            weight = jt.weights_or_none()
+            if weight is not None:
+                kjt_weights_list.append(weight)
+        kjt_vals = torch.concat(kjt_vals_list)
+        kjt_lens = torch.concat(kjt_lens_list)
+
+        # handle custom attribute: timestamps
+        if 0 < len(kjt_timestamps_list) < len(kjt_keys):
+            # A partial set could not be aligned with the concatenated values.
+            raise ValueError(
+                "timestamps must be set on all JaggedTensorWithTimestamps or on none of them."
+            )
+        kjt_timestamps = (
+            torch.concat(kjt_timestamps_list) if len(kjt_timestamps_list) > 0 else None
+        )
+
+        kjt_weights = (
+            torch.concat(kjt_weights_list) if len(kjt_weights_list) > 0 else None
+        )
+        kjt_stride, kjt_stride_per_key_per_rank = (
+            (stride_per_key[0], None)
+            if all(s == stride_per_key[0] for s in stride_per_key)
+            else (None, [[stride] for stride in stride_per_key])
+        )
+        kjt = KeyedJaggedTensorWithTimestamp(
+            keys=kjt_keys,
+            values=kjt_vals,
+            timestamps=kjt_timestamps,
+            weights=kjt_weights,
+            lengths=kjt_lens,
+            stride=kjt_stride,
+            stride_per_key_per_rank=kjt_stride_per_key_per_rank,
+        ).sync()
+        return kjt
+
+    def split(self, segments: List[int]) -> List["KeyedJaggedTensorWithTimestamp"]:
+        """
+        Splits this tensor into consecutive groups of keys.
+
+        Args:
+            segments (List[int]): number of keys in each resulting tensor, in order.
+
+        Returns:
+            List[KeyedJaggedTensorWithTimestamp]: one tensor per segment; timestamps are
+            sliced with the same offsets as values, or stay None when this tensor has none.
+        """
+        split_list: List[KeyedJaggedTensorWithTimestamp] = []
+        start = 0
+        start_offset = 0
+        _length_per_key = self.length_per_key()
+        _offset_per_key = self.offset_per_key()
+        for segment in segments:
+            end = start + segment
+            end_offset = _offset_per_key[end]
+            keys: List[str] = self._keys[start:end]
+
+            stride, stride_per_key_per_rank = (
+                (None, self.stride_per_key_per_rank()[start:end])
+                if self.variable_stride_per_key()
+                else (self._stride, None)
+            )
+            if segment == len(self._keys):
+                # no torch slicing required
+                split_list.append(
+                    KeyedJaggedTensorWithTimestamp(
+                        keys=self._keys,
+                        values=self._values,
+                        timestamps=self._timestamps,
+                        weights=self.weights_or_none(),
+                        lengths=self._lengths,
+                        offsets=self._offsets,
+                        stride=stride,
+                        stride_per_key_per_rank=stride_per_key_per_rank,
+                        length_per_key=self._length_per_key,
+                        offset_per_key=self._offset_per_key,
+                        index_per_key=self._index_per_key,
+                        jt_dict=self._jt_dict,
+                    )
+                )
+            elif segment == 0:
+                empty_int_list: List[int] = torch.jit.annotate(List[int], [])
+                split_list.append(
+                    KeyedJaggedTensorWithTimestamp(
+                        keys=keys,
+                        values=torch.tensor(
+                            empty_int_list,
+                            device=self.device(),
+                            dtype=self._values.dtype,
+                        ),
+                        # timestamps are optional: only build an empty tensor when there
+                        # is a dtype to copy from.
+                        timestamps=(
+                            None
+                            if self._timestamps is None
+                            else torch.tensor(
+                                empty_int_list,
+                                device=self.device(),
+                                dtype=self._timestamps.dtype,
+                            )
+                        ),
+                        weights=(
+                            None
+                            if self.weights_or_none() is None
+                            else torch.tensor(
+                                empty_int_list,
+                                device=self.device(),
+                                dtype=self.weights().dtype,
+                            )
+                        ),
+                        lengths=torch.tensor(
+                            empty_int_list, device=self.device(), dtype=torch.int
+                        ),
+                        offsets=torch.tensor(
+                            empty_int_list, device=self.device(), dtype=torch.int
+                        ),
+                        stride=stride,
+                        stride_per_key_per_rank=stride_per_key_per_rank,
+                        length_per_key=None,
+                        offset_per_key=None,
+                        index_per_key=None,
+                        jt_dict=None,
+                    )
+                )
+            else:
+                split_length_per_key = _length_per_key[start:end]
+                split_list.append(
+                    KeyedJaggedTensorWithTimestamp(
+                        keys=keys,
+                        values=self._values[start_offset:end_offset],
+                        timestamps=(
+                            self._timestamps[start_offset:end_offset]
+                            if self._timestamps is not None
+                            else None
+                        ),
+                        weights=(
+                            None
+                            if self.weights_or_none() is None
+                            else self.weights()[start_offset:end_offset]
+                        ),
+                        lengths=self.lengths()[
+                            self.lengths_offset_per_key()[
+                                start
+                            ]: self.lengths_offset_per_key()[end]
+                        ],
+                        offsets=None,
+                        stride=stride,
+                        stride_per_key_per_rank=stride_per_key_per_rank,
+                        length_per_key=split_length_per_key,
+                        offset_per_key=None,
+                        index_per_key=None,
+                        jt_dict=None,
+                    )
+                )
+            start = end
+            start_offset = end_offset
+        return split_list
+
+    def permute(
+        self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None
+    ) -> "KeyedJaggedTensorWithTimestamp":
+        """
+        Permutes the KeyedJaggedTensorWithTimestamp.
+
+        Timestamps are permuted with the same indices and segment lengths as values.
+        When this tensor has no timestamps the result has none either.
+
+        Args:
+            indices (List[int]): list of indices.
+            indices_tensor (Optional[torch.Tensor]): tensor of indices.
+
+        Returns:
+            KeyedJaggedTensorWithTimestamp: permuted KeyedJaggedTensorWithTimestamp.
+        """
+        if indices_tensor is None:
+            indices_tensor = torch.tensor(
+                indices, dtype=torch.int, device=self.device()
+            )
+
+        length_per_key = self.length_per_key()
+        permuted_keys: List[str] = []
+        permuted_stride_per_key_per_rank: List[List[int]] = []
+        permuted_length_per_key: List[int] = []
+        permuted_length_per_key_sum = 0
+        for index in indices:
+            key = self.keys()[index]
+            permuted_keys.append(key)
+            permuted_length_per_key.append(length_per_key[index])
+            if self.variable_stride_per_key():
+                permuted_stride_per_key_per_rank.append(
+                    self.stride_per_key_per_rank()[index]
+                )
+
+        permuted_length_per_key_sum = sum(permuted_length_per_key)
+        if not torch.jit.is_scripting() and is_non_strict_exporting():
+            torch._check_is_size(permuted_length_per_key_sum)
+            torch._check(permuted_length_per_key_sum != -1)
+            torch._check(permuted_length_per_key_sum != 0)
+
+        # timestamps are optional; each branch below only permutes them when present.
+        # They are permuted without weights: the weights are already permuted together
+        # with the values, so passing them again would only repeat that work.
+        timestamps = self.timestamps
+        permuted_timestamps: Optional[torch.Tensor] = None
+
+        if self.variable_stride_per_key():
+            length_per_key_tensor = _pin_and_move(
+                torch.tensor(self.length_per_key()), self.device()
+            )
+            stride_per_key_tensor = _pin_and_move(
+                torch.tensor(self.stride_per_key()), self.device()
+            )
+            permuted_lengths, _ = _permute_tensor_by_segments(
+                self.lengths(),
+                stride_per_key_tensor,
+                indices_tensor,
+                None,
+            )
+            permuted_values, permuted_weights = _permute_tensor_by_segments(
+                self.values(),
+                length_per_key_tensor,
+                indices_tensor,
+                self.weights_or_none(),
+            )
+            if timestamps is not None:
+                permuted_timestamps, _ = _permute_tensor_by_segments(
+                    timestamps,
+                    length_per_key_tensor,
+                    indices_tensor,
+                    None,
+                )
+        elif is_torchdynamo_compiling() and not torch.jit.is_scripting():
+            (
+                permuted_lengths,
+                permuted_values,
+                permuted_weights,
+            ) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
+                indices_tensor,
+                self.lengths(),
+                self.values(),
+                self.stride(),
+                self.weights_or_none(),
+                permuted_length_per_key_sum,
+            )
+            if timestamps is not None:
+                _, permuted_timestamps, _ = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
+                    indices_tensor,
+                    self.lengths(),
+                    timestamps,
+                    self.stride(),
+                    None,
+                    permuted_length_per_key_sum,
+                )
+        else:
+            (
+                permuted_lengths,
+                permuted_values,
+                permuted_weights,
+            ) = torch.ops.fbgemm.permute_2D_sparse_data(
+                indices_tensor,
+                self.lengths().view(len(self._keys), -1),
+                self.values(),
+                self.weights_or_none(),
+                permuted_length_per_key_sum,
+            )
+            if timestamps is not None:
+                _, permuted_timestamps, _ = torch.ops.fbgemm.permute_2D_sparse_data(
+                    indices_tensor,
+                    self.lengths().view(len(self._keys), -1),
+                    timestamps,
+                    None,
+                    permuted_length_per_key_sum,
+                )
+        stride_per_key_per_rank = (
+            permuted_stride_per_key_per_rank if self.variable_stride_per_key() else None
+        )
+        kjt = KeyedJaggedTensorWithTimestamp(
+            keys=permuted_keys,
+            values=permuted_values,
+            timestamps=permuted_timestamps,
+            weights=permuted_weights,
+            lengths=permuted_lengths.view(-1),
+            offsets=None,
+            stride=self._stride,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=None,
+            length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
+            lengths_offset_per_key=None,
+            offset_per_key=None,
+            index_per_key=None,
+            jt_dict=None,
+            inverse_indices=None,
+        )
+        return kjt
+
+    def pin_memory(self) -> "KeyedJaggedTensorWithTimestamp":
+        """
+        Returns a copy whose values, timestamps, weights, lengths and offsets are
+        pinned (each only when present). The cached jt_dict is not carried over.
+        """
+        weights = self._weights
+        lengths = self._lengths
+        offsets = self._offsets
+        stride, stride_per_key_per_rank = (
+            (None, self._stride_per_key_per_rank)
+            if self.variable_stride_per_key()
+            else (self._stride, None)
+        )
+
+        return KeyedJaggedTensorWithTimestamp(
+            keys=self._keys,
+            values=self._values.pin_memory(),
+            timestamps=(
+                self._timestamps.pin_memory() if self._timestamps is not None else None
+            ),
+            weights=weights.pin_memory() if weights is not None else None,
+            lengths=lengths.pin_memory() if lengths is not None else None,
+            offsets=offsets.pin_memory() if offsets is not None else None,
+            stride=stride,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            length_per_key=self._length_per_key,
+            offset_per_key=self._offset_per_key,
+            index_per_key=self._index_per_key,
+            jt_dict=None,
+        )
+
+    def to(
+        self, device: torch.device, non_blocking: bool = False
+    ) -> "KeyedJaggedTensorWithTimestamp":
+        """
+        Returns a copy whose values, timestamps, weights, lengths and offsets are moved
+        to `device` (each only when present); cached per-key metadata is passed through.
+        """
+        weights = self._weights
+        lengths = self._lengths
+        offsets = self._offsets
+        stride, stride_per_key_per_rank = (
+            (None, self._stride_per_key_per_rank)
+            if self.variable_stride_per_key()
+            else (self._stride, None)
+        )
+        length_per_key = self._length_per_key
+        offset_per_key = self._offset_per_key
+        index_per_key = self._index_per_key
+        jt_dict = self._jt_dict
+
+        return KeyedJaggedTensorWithTimestamp(
+            keys=self._keys,
+            values=self._values.to(device, non_blocking=non_blocking),
+            timestamps=(
+                self._timestamps.to(device, non_blocking=non_blocking)
+                if self._timestamps is not None
+                else None
+            ),
+            weights=(
+                weights.to(device, non_blocking=non_blocking)
+                if weights is not None
+                else None
+            ),
+            lengths=(
+                lengths.to(device, non_blocking=non_blocking)
+                if lengths is not None
+                else None
+            ),
+            offsets=(
+                offsets.to(device, non_blocking=non_blocking)
+                if offsets is not None
+                else None
+            ),
+            stride=stride,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            length_per_key=length_per_key,
+            offset_per_key=offset_per_key,
+            index_per_key=index_per_key,
+            jt_dict=jt_dict,
+        )
+
+    @torch.jit.unused
+    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
+        # The base class handles its own tensors; timestamps are the only extra tensor here.
+        super().record_stream(stream)
+        if self._timestamps is not None:
+            self._timestamps.record_stream(stream)
+
+    def to_dict(self) -> Dict[str, JaggedTensor]:
+        # invoke base class's method, and will discard timestamp data.
+        return super().to_dict()
+
+    def dist_splits(self, key_splits: List[int]) -> List[List[int]]:
+        # Not supported for the timestamped variant. NotImplemented is the sentinel for
+        # binary special methods, so an unsupported regular method raises instead.
+        raise NotImplementedError(
+            "KeyedJaggedTensorWithTimestamp does not support dist_splits."
+        )
+
+    def dist_tensors(self) -> List[torch.Tensor]:
+        # Not supported for the timestamped variant; see dist_splits.
+        raise NotImplementedError(
+            "KeyedJaggedTensorWithTimestamp does not support dist_tensors."
+        )
-- 
Gitee

From 406343d455700b80f9cc27490b97b4a3e142cb47 Mon Sep 17 00:00:00 2001
From: zxorange_321
Date: Fri, 22 Aug 2025 16:51:00 +0800
Subject: [PATCH 2/2] Rename featureFilters to featureFilters_ with trailing underscore to follow naming convention

---
 .../csrc/embedding_cache/embcache_manager.cpp | 87 ++++++++++++++++++-
 .../csrc/embedding_cache/embcache_manager.h | 2 +
 2 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp
index 52c8d2f4..60d44769 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.cpp
@@ -18,6 +18,7 @@
 
 #include "utils/logger.h"
 #include "utils/time_cost.h"
+#include "utils/string_tools.h"
 
 using namespace Embcache;
 
@@ -49,7 +50,9 @@ EmbcacheManager::EmbcacheManager(const std::vector& embConfigs, bool
         }
 
-        if (embConfigs[i].admitAndEvictConfig.IsFeatureFilterEnabled()) {
-            // 待补充 feature filter 初始化
-        }
+        // Every table gets a filter, whether or not filtering is enabled for it: featureFilters_
+        // is indexed by table index in the rest of this file, so leaving out disabled tables
+        // would misalign the indices or read past the end of the vector.
+        const auto& aaeConfig = embConfigs[i].admitAndEvictConfig;
+        featureFilters_.emplace_back(embConfigs[i].tableName, aaeConfig.admitThreshold,
+                                     aaeConfig.evictThreshold, aaeConfig.evictStepInterval);
     }
     TORCH_CHECK(embConfigs.size() > 0, "ERROR, Size of embConfigs must > 0")
@@ -99,7 +102,7 @@ SwapInfo EmbcacheManager::ComputeSwapInfo(const at::Tensor& batchKeys, const std
     for (int64_t i = 0; i < curTableIndices.size(); i++) {
         int64_t idx = curTableIndices[i];
         if (embConfigs_[idx].admitAndEvictConfig.IsAdmitEnabled()) {
-            // 待补充 feature filter 统计
+            featureFilters_[idx].CountFilter(keyPtr, offsetPerKey[i], offsetPerKey[i + 1]);
         }
 
         // 取出每个表的 key
@@ -272,10 +275,44 @@ void EmbcacheManager::EmbeddingUpdate(const std::vector>& s
 void EmbcacheManager::RecordTimestamp(const at::Tensor& batchKeys, const std::vector& offsetPerKey,
                                       const at::Tensor& timestamps, const std::vector& tableIndices)
 {
+    // Forwards each table's slice of batchKeys / timestamps to that table's feature filter.
+    // Slice i is [offsetPerKey[i], offsetPerKey[i + 1]) and belongs to table curTableIndices[i].
+    LOG_INFO("Start invoke mgmt RecordTimestamp");
+    TimeCost recordTimestampTC;
+    const auto* keyPtr = batchKeys.data_ptr();
+    const auto* timestampsPtr = timestamps.data_ptr();
+    const std::vector& curTableIndices = tableIndices.empty() ? embTableIndies_ : tableIndices;
+    TORCH_CHECK(curTableIndices.size() + 1 == offsetPerKey.size(),
+        "tableIndices size+1 must be equal to offsetPerKey size");
+
+    // Iterate over the tables that were actually passed in: curTableIndices may hold fewer
+    // than embNum_ entries, and offsetPerKey is sized to match it (checked above).
+    for (size_t i = 0; i < curTableIndices.size(); ++i) {
+        const int64_t idx = curTableIndices[i];
+        if (embConfigs_[idx].admitAndEvictConfig.IsEvictEnabled()) {
+            featureFilters_[idx].RecordTimestamp(keyPtr, offsetPerKey[i], offsetPerKey[i + 1], timestampsPtr);
+        }
+    }
+    LOG_INFO("RecordTimestamp execution time: {} ms", recordTimestampTC.ElapsedMS());
 }
 
 void EmbcacheManager::EvictFeatures()
 {
+    // For every table with eviction enabled: removes the pending evict keys from the swap
+    // manager and stamps the record with the current swapCount_.
+    LOG_INFO("Start invoke EvictFeatures method, ComputeSwapInfo execute times: {}", swapCount_);
+    TimeCost evictFeaturesTC;
+    size_t evictKeyCount = 0;
+    for (int32_t i = 0; i < embNum_; ++i) {
+        if (!embConfigs_[i].admitAndEvictConfig.IsEvictEnabled()) {
+            LOG_INFO("The table: {} doesn't enable evict, skip feature evict.", embConfigs_[i].tableName);
+            continue;
+        }
+
+        // Keys the current table wants to evict.
+        const std::vector& evictFeatures = featureFilters_[i].evictFeatureRecord_.GetEvictKeys();
+        // Remove the mapping info through swapManager now. The embeddings themselves are removed
+        // from embeddingTables only after the swap-out embedding update of the matching step is done.
+        swapManagers_[i].RemoveKeys(evictFeatures);
+        featureFilters_[i].evictFeatureRecord_.SetSwapCount(swapCount_);
+        evictKeyCount += evictFeatures.size();
+    }
+    LOG_INFO("EvictFeatures execution time : {} ms, all table evictKeyCount : {}", evictFeaturesTC.ElapsedMS(),
+        evictKeyCount);
 }
 
 void EmbcacheManager::RecordEmbeddingUpdateTimes()
@@ -297,14 +334,60 @@ AsyncTask EmbcacheManager::EmbeddingUpdateAsync(const SwapInfo& swapInfo,
 }
 bool EmbcacheManager::NeedEvictEmbeddingTable()
 {
+    // True as soon as one eviction-enabled table has pending keys that may be removed now.
+    for (int32_t i = 0; i < embNum_; ++i) {
+        // Only tables with eviction enabled are considered.
+        if (!embConfigs_[i].admitAndEvictConfig.IsEvictEnabled()) {
+            continue;
+        }
+        // Pending keys are non-empty and embUpdateCount_ matches the swap count recorded
+        // by EvictFeatures().
+        if (!featureFilters_[i].evictFeatureRecord_.GetEvictKeys().empty() &&
+            featureFilters_[i].evictFeatureRecord_.CanRemoveFromEmbTable(embUpdateCount_)) {
+            return true;
+        }
+    }
     return false;
 }
 
 void EmbcacheManager::RemoveEmbeddingTableInfo()
 {
+    // Removes the pending evict keys of every table from its embedding table, then clears
+    // that table's evict record. Tables without pending keys are skipped.
+    LOG_INFO("Start invoke RemoveEmbeddingTableInfo, embUpdateCount_: {}", embUpdateCount_);
+    TimeCost removeEmbeddingTableTC;
+    for (int32_t i = 0; i < embNum_; ++i) {
+        auto& keys = featureFilters_[i].evictFeatureRecord_.GetEvictKeys();
+        if (keys.empty()) {
+            LOG_INFO("Feature keys list is empty, skip to remove embedding from table: {}", embConfigs_[i].tableName);
+            continue;
+        }
+
+        embeddingTables_[i]->RemoveEmbedding(keys);
+        LOG_INFO("Remove table embedding info, table : {}, remove key size : {}, detail keys : {}",
+            embConfigs_[i].tableName, keys.size(), StringTools::ToString(keys));
+        featureFilters_[i].evictFeatureRecord_.ClearEvictInfo();
+    }
+    LOG_INFO("RemoveEmbeddingTableInfo execution time: {} ms", removeEmbeddingTableTC.ElapsedMS());
 }
 
 void EmbcacheManager::StatisticsKeyCount(const at::Tensor& batchKeys, const torch::Tensor& offset,
                                          const at::Tensor& batchKeyCounts, int64_t tableIndex)
 {
+    // Feeds the keys of one table, batchKeys[offset[tableIndex], offset[tableIndex + 1]), into
+    // that table's admission counter. Does nothing when admission is disabled for the table.
+    // tableIndex indexes embConfigs_ and featureFilters_, so it is validated before first use.
+    TORCH_CHECK(tableIndex >= 0 && static_cast<size_t>(tableIndex) < embConfigs_.size(),
+        "param error, tableIndex is out of range of embConfigs.");
+    LOG_INFO("StatisticsKeyCount, tableIndex : {}, isAdmit : {}",
+        tableIndex, embConfigs_[tableIndex].admitAndEvictConfig.IsAdmitEnabled());
+    if (!embConfigs_[tableIndex].admitAndEvictConfig.IsAdmitEnabled()) {
+        return;
+    }
+    // offset[tableIndex + 1] is read below, so offset needs at least tableIndex + 2 entries.
+    TORCH_CHECK(offset.numel() > tableIndex + 1, "param error, tableIndex + 1 need be smaller than offset length,"
+        " but got equal or greater than offset length.");
+    // Without local unique, counts is an empty tensor and every key is counted as 1.
+    bool isCountDataEmpty = batchKeyCounts.numel() == 0;
+    if (!isCountDataEmpty) {
+        TORCH_CHECK(batchKeys.numel() == batchKeyCounts.numel(),
+            "batchKeys length should equal with batchKeyCounts length when batchKeyCounts is not empty.")
+    }
+    auto* featureDataPtr = batchKeys.data_ptr();
+    auto* countDataPtr = batchKeyCounts.data_ptr();
+    auto* offsetDataPtr = offset.data_ptr();
+    int64_t start = offsetDataPtr[tableIndex];
+    int64_t end = offsetDataPtr[tableIndex + 1];
+    // Both bounds index into batchKeys, so a negative or descending range must be rejected.
+    TORCH_CHECK(start >= 0 && start <= end, "param error, offset of table must be non-negative and ascending.");
+    TORCH_CHECK(end <= batchKeys.numel())
+    featureFilters_[tableIndex].StatisticsKeyCount(featureDataPtr, countDataPtr, start, end, isCountDataEmpty);
 }
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h
index 1d74e608..b85a62ad 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/csrc/embedding_cache/embcache_manager.h
@@ -16,6 +16,7 @@
 
 #include "common/common.h"
 #include "emb_table/emb_table.h"
+#include "feature_filter/feature_filter.h"
 #include "swap_manager.h"
 #include "utils/async_task.h"
 #include "utils/thread_pool.h"
@@ -129,6 +130,7 @@ private:
     std::vector embConfigs_;
     std::vector swapManagers_;
     std::vector> embeddingTables_;
+    std::vector featureFilters_;
 
     uint64_t swapCount_ = 0; // ComputeSwapInfo 执行次数
     uint64_t embUpdateCount_ = 0; // EmbeddingUpdate 执行次数
-- 
Gitee