diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py index ea47cc96a2da5a123411c3ddf00b1afd7a8f27d0..6914fd428a041662f3c8639fbfe247e8bae39e1d 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py @@ -83,12 +83,7 @@ def execute(rank: int, config: ExecuteConfig): logging.info("this test %s", os.path.basename(__file__)) embedding_config = generate_base_config(embedding_dims, num_embeddings) - dataset = RandomRecDataset( - batch_size=BATCH_NUM, - lookup_len=lookup_len, - num_lookups=[num_embedding // 2 for num_embedding in num_embeddings], - num_tables=table_num, - ) + dataset = RandomRecDataset(BATCH_NUM, lookup_len, [num_embedding // 2 for num_embedding in num_embeddings], table_num) gloden_dataset_loader = DataLoader( dataset, batch_size=None, @@ -241,6 +236,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [3], "embedding_dims": [[32, 32, 32]], "num_embeddings": [[400, 4000, 400]], diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py index cb819ede9fe6a03b484c99293120868fe67486bf..f3df95fefcf2a3ad79ea87d3da09e57f104b3831 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py @@ -249,6 +249,7 @@ class TestModel: return results params = { + "world_size": [WORLD_SIZE], "table_num": [3], "embedding_dims": [[32, 32, 32]], "num_embeddings": [[400, 4000, 400]], @@ -263,26 +264,14 @@ params = { ExecuteConfig(*v) for v in itertools.product(*params.values()) ]) def test_hybrid_embedding_bag(config: ExecuteConfig): - table_num = config.table_num - embedding_dims = config.embedding_dims - num_embeddings = config.num_embeddings - pool_type = config.pool_type sharding_type = config.sharding_type - lookup_len = config.lookup_len device = config.device if device == "cpu" and sharding_type == "row_wise": return mp.spawn( execute, args=( - WORLD_SIZE, - table_num, - embedding_dims, - num_embeddings, - pool_type, - sharding_type, - lookup_len, - device, + config, ), nprocs=WORLD_SIZE, join=True, diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py index 969980be688a56c04a5759cb746d81c5c8be6ec9..c020fc947ee61ab70e294382595edc88e83a2ca9 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py @@ -231,6 +231,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [3], "embedding_dims": [[32, 32, 32]], "num_embeddings": [[400, 4000, 400]], diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py index 7c8c35218cc80c6d8fcd2d3b60bdaed9f93e0aee..24a87b78bb4e6527582f0f300f866a81fe1dd81d 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py @@ -249,6 +249,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [3], "embedding_dims": [[32, 32, 32]], "num_embeddings": [[400, 4000, 400]], diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py index 8f48ee8008e20313db3a2975103602e95d13f9bb..295f14419fd639bd33499660808b5ebb7badf01b 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py @@ -239,6 +239,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [3], "embedding_dims": [[32, 32, 32]], "num_embeddings": [[400, 4000, 400]], diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py index 40283b9c527a88cf84fa0e0fa3410c62c038022d..1f3d6b03bef65ed4cfc3c6e3e7895ae7791c70c8 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py @@ -244,6 +244,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [2], "embedding_dims": [[32, 64, 128]], "num_embeddings": [[400, 4000, 400]], diff --git a/torchrec/hybrid_torchrec/test/test_train_and_eval.py b/torchrec/hybrid_torchrec/test/test_train_and_eval.py index 3c33557f5188f320170c41472855b7ae3e7b72d8..d1dfe0cac3d3146f83648898cd7cb35608f31782 100644 --- a/torchrec/hybrid_torchrec/test/test_train_and_eval.py +++ b/torchrec/hybrid_torchrec/test/test_train_and_eval.py @@ -207,6 +207,7 @@ class TestModel: params = { + "world_size": [WORLD_SIZE], "table_num": [2], "embedding_dims": [[32, 64, 128]], "num_embeddings": [[400, 4000, 400]],