From 671919c1879b94b8f439c94eb38d1431bdcbacf7 Mon Sep 17 00:00:00 2001 From: xiangpx Date: Wed, 20 Aug 2025 16:54:18 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E4=BF=AE=E5=A4=8DIdsMapper=E4=B8=AD?= =?UTF-8?q?=E7=9A=84hidx=E6=A3=80=E6=9F=A5=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E5=B9=B6=E5=A2=9E=E5=8A=A0do=5Funique=5Fhash=5Fout=E7=9A=84?= =?UTF-8?q?=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/ids_process/ids_mapper.cpp | 3 +-- .../test/dt/test_post_input_dist.py | 17 +++++++++++++++++ .../tests/acc_test/run_test.sh | 13 ++++++++++--- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/torchrec/hybrid_torchrec/src/ids_process/ids_mapper.cpp b/torchrec/hybrid_torchrec/src/ids_process/ids_mapper.cpp index 6a4c5fd3..d32089e5 100644 --- a/torchrec/hybrid_torchrec/src/ids_process/ids_mapper.cpp +++ b/torchrec/hybrid_torchrec/src/ids_process/ids_mapper.cpp @@ -201,9 +201,8 @@ size_t IdsMapper::ProcessIds2Indices(IdsMapper& mapper, std::vector& un } int64_t hidx = hashIdxPtr[i]; - if (hidx >= static_cast(fullMap->size())) { - TORCH_CHECK(hidx > (INT64_MAX - 1) / 2, "hidx is too large: ", hidx); + TORCH_CHECK(hidx < (INT64_MAX - 1) / 2, "hidx is too large: ", hidx, ">=", (INT64_MAX - 1) / 2); fullMap->resize(hidx * EXPAND_CAPACITY_RATE + 1, -1); bitmap = fullMap->data(); } diff --git a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py index 5faf7233..25f67f0d 100644 --- a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py +++ b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py @@ -5,15 +5,18 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os from unittest.mock import patch import torch from hybrid_torchrec.distributed.sharding.post_input_dist import ( split_keys_offset, do_unique_hash, + do_unique_hash_out, HashMapBase, KeyedJaggedTensorWithLookHelper ) +from hybrid_torchrec.modules.ids_process import IdsMapper from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -85,6 +88,20 @@ class TestDoUniqueHash: assert len(result.unique_indices) == 9 +class TestDoUniqueHashOut: + @staticmethod + def test_do_unique_hash_out_with_parallel(): + with patch.dict(os.environ, {"ENABLE_PARALLEL_GLOBAL_UNIQUE": "1"}): + kjt = KeyedJaggedTensor( + keys=["f1"], + values=torch.tensor([1, 1, 2]), + lengths=torch.tensor([3]), + offsets=torch.tensor([0, 3]) + ) + result = do_unique_hash_out(kjt, [1], [IdsMapper(128)]) + assert len(result.unique_indices) == 2 + + class TestSplitKeysOffset: @staticmethod def test_split_key_offset_f0(): diff --git a/torchrec/torchrec_embcache/tests/acc_test/run_test.sh b/torchrec/torchrec_embcache/tests/acc_test/run_test.sh index 653c0dca..ca92a18f 100644 --- a/torchrec/torchrec_embcache/tests/acc_test/run_test.sh +++ b/torchrec/torchrec_embcache/tests/acc_test/run_test.sh @@ -46,7 +46,14 @@ export WORLD_SIZE=2 export ASCEND_RT_VISIBLE_DEVICES=6,7 pytest ./test_embedding_cache_pipeline.py -pytest ./test_embedding_ec_cache_pipeline.py -export DO_EC_LOCAL_UNIQUE=1 -pytest ./test_embedding_ec_cache_pipeline.py +( + export ENABLE_PARALLEL_GLOBAL_UNIQUE=1 + pytest ./test_embedding_ec_cache_pipeline.py +) + +( + export DO_EC_LOCAL_UNIQUE=1 + export LOCAL_UNIQUE_PARALLEL_BATCH_NUM=4 + pytest ./test_embedding_ec_cache_pipeline.py +) -- Gitee From 7746fe4182694b26e3c631a76db31c36f07075cf Mon Sep 17 00:00:00 2001 From: xiangpx Date: Wed, 20 Aug 2025 17:03:36 +0800 Subject: [PATCH 2/4] fix dt --- torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py index 25f67f0d..d97bd353 100644 --- a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py +++ b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py @@ -90,6 +90,7 @@ class TestDoUniqueHash: class TestDoUniqueHashOut: @staticmethod + @patch("torch.Tensor.pin_memory", new=lambda self, *args, **kwargs: self) def test_do_unique_hash_out_with_parallel(): with patch.dict(os.environ, {"ENABLE_PARALLEL_GLOBAL_UNIQUE": "1"}): kjt = KeyedJaggedTensor( -- Gitee From d48482b9939419e2a0f51479b3539a862a4b66ee Mon Sep 17 00:00:00 2001 From: xiangpx Date: Wed, 20 Aug 2025 17:48:01 +0800 Subject: [PATCH 3/4] fix dt --- torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py index d97bd353..2a5e2012 100644 --- a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py +++ b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py @@ -9,6 +9,7 @@ import os from unittest.mock import patch import torch +from torch import empty_like from hybrid_torchrec.distributed.sharding.post_input_dist import ( split_keys_offset, do_unique_hash, @@ -31,6 +32,11 @@ class MockHashMap(HashMapBase): pass +def empty_like_without_pin_memory(*args, **kwargs): + kwargs['pin_memory'] = False + return empty_like(*args, **kwargs) + + class TestDoUniqueHash: @staticmethod @patch("torch.Tensor.pin_memory", new=lambda self, *args, **kwargs: self) @@ -90,7 +96,7 @@ class TestDoUniqueHash: class TestDoUniqueHashOut: @staticmethod - @patch("torch.Tensor.pin_memory", new=lambda self, *args, **kwargs: self) + @patch("torch.empty_like", new=empty_like_without_pin_memory) def test_do_unique_hash_out_with_parallel(): with patch.dict(os.environ, {"ENABLE_PARALLEL_GLOBAL_UNIQUE": "1"}): kjt = KeyedJaggedTensor( -- Gitee From a53314bc80badd0e6656d1e96bece6fe133710c8 Mon Sep 17 00:00:00 2001 From: xiangpx Date: Wed, 20 Aug 2025 18:14:29 +0800 Subject: [PATCH 4/4] fix dt --- torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py index 2a5e2012..8f28c430 100644 --- a/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py +++ b/torchrec/hybrid_torchrec/test/dt/test_post_input_dist.py @@ -97,6 +97,7 @@ class TestDoUniqueHash: class TestDoUniqueHashOut: @staticmethod @patch("torch.empty_like", new=empty_like_without_pin_memory) + @patch("torch.Tensor.pin_memory", new=lambda self, *args, **kwargs: self) def test_do_unique_hash_out_with_parallel(): with patch.dict(os.environ, {"ENABLE_PARALLEL_GLOBAL_UNIQUE": "1"}): kjt = KeyedJaggedTensor( -- Gitee