diff --git a/torchrec/hybrid_torchrec/src/ids_process/bucketize.cpp b/torchrec/hybrid_torchrec/src/ids_process/bucketize.cpp
index 44526df01cb04a16c4bba0e4677d0b583ce9bf66..f657d8fef97233b3a27feaf2e6e9920aefc025c3 100644
--- a/torchrec/hybrid_torchrec/src/ids_process/bucketize.cpp
+++ b/torchrec/hybrid_torchrec/src/ids_process/bucketize.cpp
@@ -199,7 +199,7 @@ void BlockBucketizeSparseFeaturesCpuKernel(const at::Tensor& lengths, const at::
 
     // 去重逻辑 (需要时启用)
     if constexpr (DoUnique) {
-        auto* idsCountData = GetSafeDataPtr(idsCounts, "idsCounts");
+        auto* idsCountData = ReturnCount ? GetSafeDataPtr(idsCounts, "idsCounts") : nullptr;
         int64_t uniqueSize = Deduplicate(
             newLengthsData, newOffsetsData, offsetsData, indicesData, newIndicesData,
             unbucketizePermuteData, numFeatures, batchSize, bucketSize, idsCountData);
diff --git a/torchrec/hybrid_torchrec/test/dt/test_bucketize_kjt.py b/torchrec/hybrid_torchrec/test/dt/test_bucketize_kjt.py
index 6256ab8fd868efc29ef02010fabdffea89d25263..1f00d35390210ee8759070db916355884a88c3ef 100644
--- a/torchrec/hybrid_torchrec/test/dt/test_bucketize_kjt.py
+++ b/torchrec/hybrid_torchrec/test/dt/test_bucketize_kjt.py
@@ -26,3 +26,29 @@ class TestBucketizeKJTBeforeAll2All:
         block_sizes = [kjt.lengths()[i * length: i * length + length].sum() for i in range(len(kjt.keys()))]
         block_sizes = torch.tensor(block_sizes)
         bucketize_kjt_before_all2all(kjt, world_size, block_sizes)
+
+    @pytest.mark.parametrize("world_size", [1, 2])
+    def test_do_unique_without_admit(self, world_size):
+        kjt = self.create_kjt()
+        length = len(kjt.lengths()) // len(kjt.keys())
+        block_sizes = [kjt.lengths()[i * length: i * length + length].sum() for i in range(len(kjt.keys()))]
+        block_sizes = torch.tensor(block_sizes)
+
+        bucketized, _ = bucketize_kjt_before_all2all(
+            kjt, world_size, block_sizes, output_permute=True, do_unique=True,
+        )
+
+        assert isinstance(bucketized, KeyedJaggedTensor)
+        # keys 应在分桶后按 bucket 重复
+        assert len(bucketized.keys()) == len(kjt.keys()) * world_size
+
+        # 校验:分桶且去重后,每个 key 段内不应存在重复值
+        b_vals = bucketized.values()
+        b_lens = bucketized.lengths().view(-1).tolist()
+        pos = 0
+        for length in b_lens:
+            if length == 0:
+                continue
+            seg = b_vals[pos:pos + length]
+            assert torch.unique(seg).numel() == seg.numel()
+            pos += length
diff --git a/torchrec/hybrid_torchrec/test/dt/test_ids_process.py b/torchrec/hybrid_torchrec/test/dt/test_ids_process.py
index 1b9a2bc1658dac894260cce157a4f5271b29c023..6d91a4a7ded56bcff7c4aa60c120f6e103b5661d 100644
--- a/torchrec/hybrid_torchrec/test/dt/test_ids_process.py
+++ b/torchrec/hybrid_torchrec/test/dt/test_ids_process.py
@@ -151,8 +151,10 @@ def test_ids2indices_out(input_size, pin_memory, num_mapper):
         unique_start = unique_offset[i].item()
         unique_end = unique_offset[i + 1].item()
         unique_this = unique[unique_start:unique_end]
+        unique_ids_this = unique_ids[unique_start:unique_end]
         unique_inverse_this = unique_inverse[start:end]
         verify_unique(indices, unique_this, unique_inverse_this)
+        assert torch.unique(unique_ids_this).numel() == unique_ids_this.numel(), "unique_ids_this is not unique"
 
 
 @pytest.mark.parametrize("input_size", [10000])
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding.py b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding.py
index 46597a728c9297305f64e5e83e87326f948d538e..15f9237bd2cb1621fdb25d1fdc504eb09e42d9f2 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/embedding.py
@@ -323,10 +323,12 @@ class EmbCacheShardedEmbeddingCollection(ShardedEmbeddingCollection):
             module.embedding_configs()
         )
         self._table_names: List[str] = [
-            config.name for config in self._embedding_configs
+            config.name
+            for config in self._embedding_configs
         ]
         self._table_name_to_config: Dict[str, EmbCacheEmbeddingConfig] = {
-            config.name: config for config in self._embedding_configs
+            config.name: config
+            for config in self._embedding_configs
         }
         self.module_sharding_plan: EmbeddingModuleShardingPlan = cast(
             EmbeddingModuleShardingPlan,
@@ -503,7 +505,8 @@ class EmbCacheShardedEmbeddingCollection(ShardedEmbeddingCollection):
                 for emb_table in lookup.grouped_configs[0].embedding_tables
             ]
             emb_names: List[str] = [
-                emb_table.name for emb_table in lookup.grouped_configs[0].embedding_tables
+                emb_table.name
+                for emb_table in lookup.grouped_configs[0].embedding_tables
             ]
             emb_not_admitted_default_value: List[float] = []
             for emb_name in emb_names:
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sequence_sharding.py b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sequence_sharding.py
index b0b268dd7abc3c12b517ed4997547486bcd1c385..b6b040bf457b044f0a8624a387a143a796478866 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sequence_sharding.py
@@ -70,6 +70,7 @@ class EmbCacheRwSequenceEmbeddingSharding(RwSequenceEmbeddingSharding):
             has_feature_processor=self._has_feature_processor,
             need_pos=self._need_pos,
             enable_admit=self._enable_admit,
+            is_ec=True,
         )
 
     def create_lookup(
diff --git a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sharding.py b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sharding.py
index 3df6adf78739b95b2e92f398acdb0e2ca6e7b45e..fe3ef5ee2067edf86356449b6dc36d6c2246dde9 100644
--- a/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sharding.py
+++ b/torchrec/torchrec_embcache/src/torchrec_embcache/distributed/sharding/rw_sharding.py
@@ -89,6 +89,7 @@ class EmbCacheRwSparseFeaturesDist(RwSparseFeaturesDist):
         need_pos: bool = False,
         keep_original_indices: bool = False,
         enable_admit: bool = False,
+        is_ec: bool = False,
     ) -> None:
         super().__init__(
             pg,
@@ -105,8 +106,7 @@ class EmbCacheRwSparseFeaturesDist(RwSparseFeaturesDist):
         )
 
         # local unique只可用于EC(Embedding Collection / Sequence Embedding)
         yes_str = ("true", "1", "yes")
-        self._do_unique = os.environ.get("DO_EC_LOCAL_UNIQUE", "False").lower() in yes_str and \
-            os.environ.get("USE_EC", "False").lower() in yes_str
+        self._do_unique = os.environ.get("DO_EC_LOCAL_UNIQUE", "False").lower() in yes_str and is_ec
         self._enable_admit = enable_admit
diff --git a/torchrec/torchrec_embcache/tests/acc_test/run_test.sh b/torchrec/torchrec_embcache/tests/acc_test/run_test.sh
index 9d046b90f3b7367d9d4149886fff11ba61ef61ea..653c0dca629101bc28a264f107453007e53730cd 100644
--- a/torchrec/torchrec_embcache/tests/acc_test/run_test.sh
+++ b/torchrec/torchrec_embcache/tests/acc_test/run_test.sh
@@ -48,5 +48,5 @@ export ASCEND_RT_VISIBLE_DEVICES=6,7
 pytest ./test_embedding_cache_pipeline.py
 pytest ./test_embedding_ec_cache_pipeline.py
 
-# export DO_EC_LOCAL_UNIQUE=1
-# pytest ./test_feature_filter.py
+export DO_EC_LOCAL_UNIQUE=1
+pytest ./test_embedding_ec_cache_pipeline.py