From 3f27037d6351baccfd5817adc3ea0d8f5639c76d Mon Sep 17 00:00:00 2001 From: wuyangjian Date: Thu, 12 Jun 2025 19:17:20 +0800 Subject: [PATCH 1/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?= =?UTF-8?q?t2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/test_hybrid_embedding.py | 253 +++++++++++++++++ .../test/test_hybrid_hash_embedding.py | 247 ++++++++++++++++ .../test_hybrid_pipeline_hash_embedding.py | 266 ++++++++++++++++++ .../test/test_kjt_with_count.py | 64 +++++ 4 files changed, 830 insertions(+) create mode 100644 torchrec/hybrid_torchrec/test/test_hybrid_embedding.py create mode 100644 torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py create mode 100644 torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py create mode 100644 torchrec/hybrid_torchrec/test/test_kjt_with_count.py diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py new file mode 100644 index 00000000..b44f08ae --- /dev/null +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +# Copyright (c) Huawei Platforms, Inc. and affiliates. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import torch +from typing import List +import torch_npu +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel as DDP +import torchrec +import pytest +import logging +from torchrec import ( + EmbeddingConfig, + EmbeddingCollection, +) +import torchrec.distributed +import torchrec.distributed.shard +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + Topology, + ParameterConstraints, +) +from torchrec.distributed.types import ShardingEnv +from torchrec.optim.keyed import CombinedOptimizer +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from model import Model +from dataset import RandomRecDataset, Batch +from util import setup_logging + +LOOP_TIMES = 8 +BATCH_NUM = 32 +WORLD_SIZE = 2 + + +def generate_base_config(embedding_dims: List[int], + num_embeddings: List[int]) -> List[EmbeddingConfig]: + test_table_configs: List[EmbeddingConfig] = [] + for i, (table_dim, num_embedding) in enumerate(zip(embedding_dims, num_embeddings)): + config = EmbeddingConfig( + name=f"table{i}", + embedding_dim=table_dim, + num_embeddings=num_embedding, + feature_names=[f"feat{i}"], + init_fn=weight_init, + ) + test_table_configs.append(config) + return test_table_configs + + +def execute( + rank, + world_size, + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, +): + setup_logging(rank) + logging.info("this test %s", os.path.basename(__file__)) + embeding_config = generate_base_config(embedding_dims, num_embeddings) + + dataset = RandomRecDataset(BATCH_NUM, lookup_len, [num_embedding//2 for num_embedding in num_embeddings], table_num) + gloden_dataset_loader = DataLoader( + dataset, + batch_size=None, + num_workers=1, + ) + data_loader = DataLoader( + dataset, + batch_size=None, + pin_memory=True, + pin_memory_device="npu", + num_workers=1, + ) + + test_model = TestModel(rank, world_size, 
device) + + gloden_results = test_model.cpu_gloden_loss( + embeding_config, gloden_dataset_loader, sharding_type + ) + test_results = test_model.test_loss(embeding_config, data_loader, sharding_type) + for gloden, result in zip(gloden_results, test_results): + logging.debug("") + logging.debug("===========================") + logging.debug("result test %s", gloden) + logging.debug("gloden test %s", result) + assert torch.allclose( + gloden, result, rtol=1e-04, atol=1e-04 + ), "gloden and result is not closed" + + +def weight_init(param: torch.nn.Parameter): + if len(param.shape) != 2: + return + torch.manual_seed(param.shape[1]) + result = torch.randn(1, param.shape[1]).repeat(param.shape[0], 1) + param.data.copy_(result) + + +class TestModel: + def __init__(self, rank, world_size, device): + self.rank = rank + self.world_size = world_size + self.device = device + self.pg_method = "hccl" if device == "npu" else "gloo" + if device == "npu": + torch_npu.npu.set_device(rank) + self.setup(rank=rank, world_size=world_size) + + @staticmethod + def cpu_gloden_loss( + embeding_config: List[EmbeddingConfig], + dataloader: DataLoader[Batch], + sharding_type: str, + ): + pg = dist.new_group(backend="gloo") + table_num = len(embeding_config) + ec = EmbeddingCollection(device="cpu", tables=embeding_config) + + num_features = sum([c.num_features() for c in embeding_config]) + ec = Model(ec, num_features) + model = DDP(ec, device_ids=None, process_group=pg) + + opt = torch.optim.Adagrad(ec.parameters(), lr=0.02, eps=1e-8) + results = [] + batch: Batch + iter_ = iter(dataloader) + for _ in range(LOOP_TIMES): + batch = next(iter_) + opt.zero_grad() + loss, output = model(batch) + results.append(loss.detach().cpu()) + results.append(output.detach().cpu()) + loss.backward() + opt.step() + + for i in range(table_num): + logging.debug( + "single table%d weight %s", + i, + ec.ec.embeddings[f"table{i}"].weight, + ) + return results + + def setup(self, rank: int, world_size: int): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "6000" + dist.init_process_group(self.pg_method, rank=rank, world_size=world_size) + + def test_loss( + self, + embeding_config: List[EmbeddingConfig], + dataloader: DataLoader[Batch], + sharding_type: str, + ): + num_features = sum([c.num_features() for c in embeding_config]) + rank, world_size = self.rank, self.world_size + host_gp = dist.new_group(backend="gloo") + host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) + # Shard + table_num = len(embeding_config) + ec = EmbeddingCollection(device="meta", tables=embeding_config) + ec = Model(ec, num_features) + apply_optimizer_in_backward( + optimizer_class=torch.optim.Adagrad, + params=ec.parameters(), + optimizer_kwargs={"lr": 0.02}, + ) + # Shard + constrans = { + f"table{i}": ParameterConstraints(sharding_types=[sharding_type]) + for i in range(table_num) + } + planner = EmbeddingShardingPlanner( + topology=Topology(world_size=self.world_size, compute_device=self.device), + constraints=constrans, + ) + plan = planner.collective_plan( + ec, get_default_hybrid_sharders(host_env), dist.GroupMember.WORLD + ) + if self.rank == 0: + logging.debug(plan) + + ddpModel = torchrec.distributed.DistributedModelParallel( + ec, + sharders=get_default_hybrid_sharders(host_env), + device=torch.device(self.device), + plan=plan, + ) + logging.debug(ddpModel) + # Optimizer + optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + results = [] + batch: Batch + iter_ = iter(dataloader) + for _ in 
range(LOOP_TIMES): + batch = next(iter_).to(self.device) + optimizer.zero_grad() + loss, output = ec(batch) + results.append(loss.detach().cpu()) + results.append(output.detach().cpu()) + loss.backward() + optimizer.step() + + for i in range(table_num): + logging.debug( + "shard table%d weight %s", + i, + ddpModel.module.ec.embeddings[f"table{i}"].weight, + ) + return results + + +@pytest.mark.parametrize("table_num", [3]) +@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) +@pytest.mark.parametrize("num_embeddings", [[400,4000, 400]]) +@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) +@pytest.mark.parametrize("lookup_len", [1024]) +@pytest.mark.parametrize("device", ["npu"]) +def test_hybrid_embedding( + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, +): + if device == "cpu" and sharding_type == "row_wise": + return + mp.spawn( + execute, + args=( + WORLD_SIZE, + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, + ), + nprocs=WORLD_SIZE, + join=True, + ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py new file mode 100644 index 00000000..707a3fb2 --- /dev/null +++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# Copyright (c) Huawei Platforms, Inc. and affiliates. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +import torch +from typing import List +import torch_npu +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel as DDP +import torchrec +import pytest +import logging +from torchrec import ( + EmbeddingConfig, +) +import torchrec.distributed +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + Topology, + ParameterConstraints, +) +from torchrec.distributed.types import ShardingEnv +from torchrec.optim.keyed import CombinedOptimizer +from hybrid_torchrec import HashEmbeddingCollection, EmbeddingConfig +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from model import Model +from dataset import RandomRecDataset, Batch +from util import setup_logging + +LOOP_TIMES = 8 +BATCH_NUM = 32 +WORLD_SIZE = 2 + +def generate_hash_config(embedding_dims: List[int], + num_embeddings: List[int]) -> List[EmbeddingConfig]: + test_table_configs: List[HashEmbeddingCollection] = [] + for i, (table_dim, num_embedding) in enumerate(zip(embedding_dims, num_embeddings)): + config = EmbeddingConfig( + name=f"table{i}", + embedding_dim=table_dim, + num_embeddings=num_embedding, + feature_names=[f"feat{i}"], + init_fn=weight_init, + ) + test_table_configs.append(config) + return test_table_configs + + +def execute( + rank, + world_size, + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, +): + setup_logging(rank) + logging.info("this test %s", os.path.basename(__file__)) + embeding_config = generate_hash_config(embedding_dims, num_embeddings) + + dataset = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num) + gloden_dataset_loader = DataLoader( + dataset, + batch_size=None, + num_workers=1, 
+ ) + data_loader = DataLoader( + dataset, + batch_size=None, + pin_memory=True, + pin_memory_device="npu", + num_workers=1, + ) + + test_model = TestModel(rank, world_size, device) + + gloden_results = test_model.cpu_gloden_loss(embeding_config, gloden_dataset_loader) + test_results = test_model.test_loss(embeding_config, data_loader, sharding_type) + for gloden, result in zip(gloden_results, test_results): + logging.debug("") + logging.debug("===========================") + logging.debug("result test %s", gloden) + logging.debug("gloden test %s", result) + assert torch.allclose( + gloden, result, rtol=1e-04, atol=1e-04 + ), "gloden and result is not closed" + + +def weight_init(param: torch.nn.Parameter): + if len(param.shape) != 2: + return + torch.manual_seed(param.shape[1]) + result = torch.randn((1, param.shape[1])).repeat(param.shape[0], 1) + param.data.copy_(result) + + +class TestModel: + def __init__(self, rank, world_size, device): + self.rank = rank + self.world_size = world_size + self.device = device + self.pg_method = "hccl" if device == "npu" else "gloo" + if device == "npu": + torch_npu.npu.set_device(rank) + self.setup(rank=rank, world_size=world_size) + + @staticmethod + def cpu_gloden_loss( + embeding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch] + ): + pg = dist.new_group(backend="gloo") + table_num = len(embeding_config) + ec = HashEmbeddingCollection(device="cpu", tables=embeding_config) + + num_features = sum([c.num_features() for c in embeding_config]) + ec = Model(ec, num_features) + model = DDP(ec, device_ids=None, process_group=pg) + + opt = torch.optim.Adagrad(ec.parameters(), lr=0.02, eps=1e-8) + results = [] + batch: Batch + iter_ = iter(dataloader) + for _ in range(LOOP_TIMES): + batch = next(iter_) + opt.zero_grad() + loss, out = model(batch) + results.append(loss.detach().cpu()) + results.append(out.detach().cpu()) + loss.backward() + opt.step() + + for i in range(table_num): + logging.debug( + "single table%d weight %s", + i, + ec.ec.embeddings[f"table{i}"].weight, + ) + return results + + def setup(self, rank: int, world_size: int): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "6000" + dist.init_process_group(self.pg_method, rank=rank, world_size=world_size) + + def test_loss( + self, + embeding_config: List[EmbeddingConfig], + dataloader: DataLoader[Batch], + sharding_type: str, + ): + num_features = sum([c.num_features() for c in embeding_config]) + rank, world_size = self.rank, self.world_size + host_gp = dist.new_group(backend="gloo") + host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) + # Shard + table_num = len(embeding_config) + ec = HashEmbeddingCollection(device=torch.device("meta"), tables=embeding_config) + ec = Model(ec, num_features) + apply_optimizer_in_backward( + optimizer_class=torch.optim.Adagrad, + params=ec.parameters(), + optimizer_kwargs={"lr": 0.02}, + ) + # Shard + constrans = { + f"table{i}": ParameterConstraints(sharding_types=[sharding_type]) + for i in range(table_num) + } + planner = EmbeddingShardingPlanner( + topology=Topology(world_size=self.world_size, compute_device=self.device), + constraints=constrans, + ) + plan = planner.collective_plan( + ec, get_default_hybrid_sharders(host_env), dist.GroupMember.WORLD + ) + if self.rank == 0: + logging.debug(plan) + + ddpModel = torchrec.distributed.DistributedModelParallel( + ec, + sharders=get_default_hybrid_sharders(host_env), + device=torch.device(self.device), + plan=plan, + ) + logging.debug(ddpModel) + # Optimizer 
+ optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + results = [] + batch: Batch + iter_ = iter(dataloader) + for _ in range(LOOP_TIMES): + batch = next(iter_).to(self.device) + optimizer.zero_grad() + loss, out = ec(batch) + results.append(loss.detach().cpu()) + results.append(out.detach().cpu()) + loss.backward() + optimizer.step() + + for i in range(table_num): + logging.debug( + "shard table%d weight %s", + i, + ddpModel.module.ec.embeddings[f"table{i}"].weight, + ) + return results + + +@pytest.mark.parametrize("table_num", [3]) +@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) +@pytest.mark.parametrize("num_embeddings", [[400, 4000, 400]]) +@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) +@pytest.mark.parametrize("lookup_len", [1024]) +@pytest.mark.parametrize("device", ["npu"]) +def test_hybrid_hash_embedding( + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, +): + if device == "cpu" and sharding_type == "row_wise": + return + mp.spawn( + execute, + args=( + WORLD_SIZE, + table_num, + embedding_dims, + num_embeddings, + sharding_type, + lookup_len, + device, + ), + nprocs=WORLD_SIZE, + join=True, + ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py new file mode 100644 index 00000000..46e13b0c --- /dev/null +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +# Copyright (c) Huawei Platforms, Inc. and affiliates. +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import os +import pytz +import torch +from typing import List +import torch_npu +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.optim import Adam, Adagrad +import torchrec +import pytest +import logging +import random +from torchrec import EmbeddingConfig, EmbeddingCollection +import torchrec.distributed +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionAwaitable +from torchrec.distributed.planner import ( + EmbeddingShardingPlanner, + Topology, + ParameterConstraints, +) +from torchrec.distributed.types import ShardingEnv +from torchrec.optim.keyed import CombinedOptimizer +from hybrid_torchrec import HashEmbeddingBagCollection, HashEmbeddingBagConfig, HashEmbeddingCollection +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from hybrid_torchrec.distributed.hybrid_train_pipeline import ( + HybridTrainPipelineSparseDist, +) +from dataset import RandomRecDataset, Batch +from model import Model +from util import setup_logging + +OPTIMIZER_PARAM = { + Adam: dict(lr=0.02), + Adagrad: dict(lr=0.02, eps=1.0e-8), +} + +WORLD_SIZE = 2 +LOOP_TIMES = 8 +BATCH_NUM = 32 + + +def execute( + rank, + world_size, + table_num, + embedding_dims, + num_embeddings, + pool_type, + sharding_type, + lookup_len, + device, + optim, +): + setup_logging(rank) + logging.info("this test %s", os.path.basename(__file__)) + # , batch_num, lookup_lens, num_embeddings, table_num + dataset_gloden = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num) + dataset = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num) + dataset_loader_gloden = DataLoader( + dataset_gloden, + batch_size=None, + batch_sampler=None, + pin_memory=True, + ) + data_loader = DataLoader( + dataset, + batch_size=None, + batch_sampler=None, + pin_memory=True, + pin_memory_device="npu", + num_workers=1, + ) + embeding_config = [] + for i in range(table_num): + ec_config = EmbeddingConfig( + name=f"table{i}", + embedding_dim=embedding_dims[i], + num_embeddings=num_embeddings[i], + feature_names=[f"feat{i}"], + init_fn=weight_init, + ) + embeding_config.append(ec_config) + + test_model = TestModel(rank, world_size, device) + gloden_results = test_model.cpu_gloden_loss(embeding_config, dataset_loader_gloden, optim) + test_results = test_model.test_loss(embeding_config, data_loader, sharding_type, optim) + for gloden, result in zip(gloden_results, test_results): + logging.debug("") + logging.debug("===========================") + logging.debug("result test %s", gloden) + logging.debug("gloden test %s", result) + assert torch.allclose( + gloden, result, rtol=1e-04, atol=1e-04 + ), "gloden and result is not closed" + + +def weight_init(param: torch.nn.Parameter): + if len(param.shape) != 2: + return + torch.manual_seed(param.shape[1]) + result = torch.randn((1, param.shape[1])).repeat(param.shape[0], 1) + param.data.copy_(result) + + +class TestModel: + def __init__(self, rank, world_size, device): + self.rank = rank + self.world_size = world_size + self.device = device + self.pg_method = "hccl" if device == "npu" else "gloo" + if device == "npu": + torch_npu.npu.set_device(rank) + self.setup(rank=rank, world_size=world_size) + + @staticmethod + def cpu_gloden_loss( + embeding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], optim + ): + pg = 
dist.new_group(backend="gloo") + table_num = len(embeding_config) + ec = EmbeddingCollection(device="cpu", tables=embeding_config) + + num_features = sum([c.num_features() for c in embeding_config]) + ec = Model(ec, num_features) + model = DDP(ec, device_ids=None, process_group=pg) + + opt = optim(ec.parameters(), **OPTIMIZER_PARAM[optim]) + results = [] + batch: Batch + iter_ = iter(dataloader) + for _ in range(LOOP_TIMES): + batch = next(iter_) + opt.zero_grad() + loss, output = model(batch) + results.append(loss.detach().cpu()) + results.append(output.detach().cpu()) + loss.backward() + opt.step() + + for i in range(table_num): + logging.debug( + "single table%d weight %s", + i, + ec.ec.embeddings[f"table{i}"].weight, + ) + return results + + def setup(self, rank: int, world_size: int): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "6000" + dist.init_process_group(self.pg_method, rank=rank, world_size=world_size) + os.environ["LOCAL_RANK"] = f"{rank}" + + def test_loss( + self, + embeding_config: List[EmbeddingConfig], + dataloader: DataLoader[Batch], + sharding_type: str, + optim, + ): + rank, world_size = self.rank, self.world_size + host_gp = dist.new_group(backend="gloo") + host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) + + table_num = len(embeding_config) + ec = HashEmbeddingCollection(device=self.device, tables=embeding_config) + num_features = sum([c.num_features() for c in embeding_config]) + ec = Model(ec, num_features) + apply_optimizer_in_backward( + optimizer_class=optim, + params=ec.parameters(), + optimizer_kwargs=OPTIMIZER_PARAM[optim], + ) + # Shard + constrans = { + f"table{i}": ParameterConstraints(sharding_types=[sharding_type]) + for i in range(table_num) + } + planner = EmbeddingShardingPlanner( + topology=Topology(world_size=self.world_size, compute_device=self.device), + constraints=constrans, + ) + plan = planner.collective_plan( + ec, get_default_hybrid_sharders(host_env), dist.GroupMember.WORLD + ) + if self.rank == 0: + logging.debug(plan) + + ddpModel = torchrec.distributed.DistributedModelParallel( + ec, + sharders=get_default_hybrid_sharders(host_env), + device=torch.device(self.device), + plan=plan, + ) + logging.debug(ddpModel) + # Optimizer + optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + results = [] + iter_ = iter(dataloader) + ddpModel.train() + pipe = HybridTrainPipelineSparseDist( + ddpModel, + optimizer=optimizer, + device=torch.device(self.device), + return_loss=True, + ) + for _ in range(LOOP_TIMES): + out, loss = pipe.progress(iter_) + results.append(loss.detach().cpu()) + results.append(out.detach().cpu()) + + for i in range(table_num): + logging.debug( + "shard table%d weight %s", + i, + ddpModel.module.ec.embeddings[f"table{i}"].weight, + ) + return results + + +@pytest.mark.parametrize("table_num", [3]) +@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) +@pytest.mark.parametrize("num_embeddings", [[400, 4000, 400]]) +@pytest.mark.parametrize("pool_type", [torchrec.PoolingType.MEAN]) +@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) +@pytest.mark.parametrize("lookup_len", [1024]) +@pytest.mark.parametrize("device", ["npu"]) +@pytest.mark.parametrize("optim", [Adagrad]) +def test_hybrid_pipeline_hash_embedding( + table_num, + embedding_dims, + num_embeddings, + pool_type, + sharding_type, + lookup_len, + device, + optim, +): + if device == "cpu" and (sharding_type == "row_wise" or optim == Adam): + return + mp.spawn( + execute, + args=( + WORLD_SIZE, + 
table_num,
+            embedding_dims,
+            num_embeddings,
+            pool_type,
+            sharding_type,
+            lookup_len,
+            device,
+            optim,
+        ),
+        nprocs=WORLD_SIZE,
+        join=True,
+    )
diff --git a/torchrec/hybrid_torchrec/test/test_kjt_with_count.py b/torchrec/hybrid_torchrec/test/test_kjt_with_count.py
new file mode 100644
index 00000000..677b444c
--- /dev/null
+++ b/torchrec/hybrid_torchrec/test/test_kjt_with_count.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+# Copyright (c) Huawei Platforms, Inc. and affiliates.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+import logging
+from hybrid_torchrec.sparse import (
+    JaggedTensorWithCount,
+    KeyedJaggedTensorWithCount)
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+@pytest.mark.parametrize("table_num", [3])
+@pytest.mark.parametrize("feature_names", [[1, 1, 2]])
+@pytest.mark.parametrize("input_size", [10])
+def test_kjt_with_count(table_num, feature_names, input_size):
+    input_dict = {}
+    feature_len = sum(feature_names)
+    for ind in range(feature_len - 1, -1, -1):  # reverse the feature order so permute and split can be verified later
+        name = f"feat{ind}"
+        id_range = input_size
+        ids = torch.randint(0, id_range, (input_size,))
+        lengths = torch.ones(input_size).long()
+        counts = torch.clone(ids)  # generate counts identical to the keys
+        input_dict[name] = JaggedTensorWithCount(values=ids, lengths=lengths, counts=counts)
+
+    kjt_with_count = KeyedJaggedTensorWithCount.from_jt_dict(input_dict)
+    logging.info("kjt_with_count:%s", kjt_with_count)
+    print("kjt_with_count:", kjt_with_count)
+
+    # permute
+    feature_names_for_sharding = [f"feat{ind}" for ind in range(feature_len)]
+    input_feature_names = kjt_with_count.keys()
+    features_order_index = []
+    for f in feature_names_for_sharding:
+        features_order_index.append(input_feature_names.index(f))
+    kjt_permuted = kjt_with_count.permute(features_order_index)
+    logging.info("kjt_with_count keys after permute:%s", kjt_permuted.keys())
+
+    offset_per_key = kjt_permuted.offset_per_key()
+    for i in range(len(offset_per_key) - 1):
+        start = offset_per_key[i]
+        end = offset_per_key[i + 1]
+        keys = kjt_permuted.values()[start:end]
+        key_with_count = kjt_permuted.counts[start:end]
+        assert torch.all(key_with_count == keys), "key_with_count is not equal to keys after kjt permute."
+
+    # split: check that the counts after splitting match expectations
+    feature_splits = [1, 1, 2]
+    kjt_list = kjt_permuted.split(feature_splits)
+    for index, kjt in enumerate(kjt_list):
+        key_with_counts = kjt.counts
+        keys = kjt.values()
+        assert torch.all(key_with_counts == keys), "key_with_counts is not equal to keys after kjt split."
+ + +if __name__ == '__main__': + test_kjt_with_count(3, [1, 1, 2], 10) -- Gitee From 64d23d5a6de05a05b07996af8b42aa99aecddf1d Mon Sep 17 00:00:00 2001 From: wuyangjian Date: Fri, 13 Jun 2025 12:00:39 +0800 Subject: [PATCH 2/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?= =?UTF-8?q?t2,=20cleancode=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/test_hybrid_embedding.py | 111 ++++++++------- .../test/test_hybrid_hash_embedding.py | 115 +++++++++------- .../test_hybrid_pipeline_hash_embedding.py | 127 +++++++++--------- .../test/test_kjt_with_count.py | 4 +- 4 files changed, 196 insertions(+), 161 deletions(-) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py index b44f08ae..37cbfc71 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py @@ -5,23 +5,30 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +import logging import os -import torch from typing import List -import torch_npu -import torch.multiprocessing as mp + +import pytest +import torch import torch.distributed as dist -from torch.utils.data import DataLoader +import torch.multiprocessing as mp +import torch_npu +from dataset import RandomRecDataset, Batch +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from model import Model from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from util import setup_logging + import torchrec -import pytest -import logging +import torchrec.distributed +import torchrec.distributed.shard from torchrec import ( EmbeddingConfig, EmbeddingCollection, ) -import torchrec.distributed -import torchrec.distributed.shard from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.distributed.planner import ( EmbeddingShardingPlanner, @@ -30,10 +37,7 @@ from torchrec.distributed.planner import ( ) from torchrec.distributed.types import ShardingEnv from torchrec.optim.keyed import CombinedOptimizer -from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders -from model import Model -from dataset import RandomRecDataset, Batch -from util import setup_logging + LOOP_TIMES = 8 BATCH_NUM = 32 @@ -55,21 +59,35 @@ def generate_base_config(embedding_dims: List[int], return test_table_configs -def execute( - rank, - world_size, - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, -): +@dataclass +class ExecuteParams: + world_size: int + table_num: int + embedding_dims: List[int] + num_embeddings: List[int] + sharding_type: str + lookup_len: int + device: str + + +def execute(rank: int, params: ExecuteParams): + world_size = params.world_size + table_num = params.table_num + embedding_dims = params.embedding_dims + num_embeddings = params.num_embeddings + sharding_type = params.sharding_type + lookup_len = params.lookup_len + device = params.device setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) embeding_config = generate_base_config(embedding_dims, num_embeddings) - dataset = RandomRecDataset(BATCH_NUM, lookup_len, [num_embedding//2 for num_embedding in num_embeddings], table_num) + dataset = RandomRecDataset( + batch_size=BATCH_NUM, + lookup_len=lookup_len, + 
num_lookups=[num_embedding // 2 for num_embedding in num_embeddings], + num_tables=table_num, + ) gloden_dataset_loader = DataLoader( dataset, batch_size=None, @@ -221,33 +239,32 @@ class TestModel: return results -@pytest.mark.parametrize("table_num", [3]) -@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) -@pytest.mark.parametrize("num_embeddings", [[400,4000, 400]]) -@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) -@pytest.mark.parametrize("lookup_len", [1024]) -@pytest.mark.parametrize("device", ["npu"]) -def test_hybrid_embedding( - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, -): - if device == "cpu" and sharding_type == "row_wise": +@pytest.mark.parametrize("params", [ + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="npu", + ), + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + lookup_len=1024, + device="npu", + ), +]) +def test_hybrid_embedding(params: ExecuteParams): + if params.device == "cpu" and params.sharding_type == "row_wise": return mp.spawn( execute, - args=( - WORLD_SIZE, - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, - ), + args=(params,), nprocs=WORLD_SIZE, join=True, ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py index 707a3fb2..b323731a 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py @@ -5,39 +5,44 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +import logging import os -import torch from typing import List -import torch_npu -import torch.multiprocessing as mp + +import pytest +import torch import torch.distributed as dist -from torch.utils.data import DataLoader +import torch.multiprocessing as mp +import torch_npu +from dataset import RandomRecDataset, Batch +from hybrid_torchrec import HashEmbeddingCollection, EmbeddingConfig +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from model import Model from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from util import setup_logging + import torchrec -import pytest -import logging +import torchrec.distributed from torchrec import ( EmbeddingConfig, ) -import torchrec.distributed -from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.distributed.planner import ( EmbeddingShardingPlanner, Topology, ParameterConstraints, ) from torchrec.distributed.types import ShardingEnv +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.optim.keyed import CombinedOptimizer -from hybrid_torchrec import HashEmbeddingCollection, EmbeddingConfig -from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders -from model import Model -from dataset import RandomRecDataset, Batch -from util import setup_logging + LOOP_TIMES = 8 BATCH_NUM = 32 WORLD_SIZE = 2 + def generate_hash_config(embedding_dims: List[int], num_embeddings: List[int]) -> List[EmbeddingConfig]: test_table_configs: List[HashEmbeddingCollection] = [] @@ -53,16 +58,25 @@ def generate_hash_config(embedding_dims: List[int], return test_table_configs -def execute( - rank, - world_size, - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, -): +@dataclass +class ExecuteParams: + world_size: int + table_num: int + embedding_dims: List[int] + num_embeddings: List[int] + sharding_type: str + lookup_len: int + device: str + + +def execute(rank: int, params: ExecuteParams): + world_size = params.world_size + table_num = params.table_num + embedding_dims = params.embedding_dims + num_embeddings = params.num_embeddings + sharding_type = params.sharding_type + lookup_len = params.lookup_len + device = params.device setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) embeding_config = generate_hash_config(embedding_dims, num_embeddings) @@ -185,15 +199,15 @@ class TestModel: if self.rank == 0: logging.debug(plan) - ddpModel = torchrec.distributed.DistributedModelParallel( + ddp_model = torchrec.distributed.DistributedModelParallel( ec, sharders=get_default_hybrid_sharders(host_env), device=torch.device(self.device), plan=plan, ) - logging.debug(ddpModel) + logging.debug(ddp_model) # Optimizer - optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + optimizer = CombinedOptimizer([ddp_model.fused_optimizer]) results = [] batch: Batch iter_ = iter(dataloader) @@ -210,38 +224,37 @@ class TestModel: logging.debug( "shard table%d weight %s", i, - ddpModel.module.ec.embeddings[f"table{i}"].weight, + ddp_model.module.ec.embeddings[f"table{i}"].weight, ) return results -@pytest.mark.parametrize("table_num", [3]) -@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) -@pytest.mark.parametrize("num_embeddings", [[400, 4000, 400]]) -@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) -@pytest.mark.parametrize("lookup_len", [1024]) 
-@pytest.mark.parametrize("device", ["npu"]) -def test_hybrid_hash_embedding( - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, -): - if device == "cpu" and sharding_type == "row_wise": +@pytest.mark.parametrize("params", [ + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="npu" + ), + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + lookup_len=1024, + device="npu" + ), +]) +def test_hybrid_hash_embedding(params: ExecuteParams): + if params.device == "cpu" and params.sharding_type == "row_wise": return mp.spawn( execute, - args=( - WORLD_SIZE, - table_num, - embedding_dims, - num_embeddings, - sharding_type, - lookup_len, - device, - ), + args=(params,), nprocs=WORLD_SIZE, join=True, ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py index 46e13b0c..aa1f835b 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py @@ -5,23 +5,31 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging import os -import pytz -import torch +from dataclasses import dataclass from typing import List -import torch_npu -import torch.multiprocessing as mp + +import pytest +import torch import torch.distributed as dist +import torch.multiprocessing as mp +import torch_npu +from dataset import RandomRecDataset, Batch +from hybrid_torchrec import HashEmbeddingBagCollection, HashEmbeddingBagConfig, HashEmbeddingCollection +from hybrid_torchrec.distributed.hybrid_train_pipeline import ( + HybridTrainPipelineSparseDist, +) +from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders +from model import Model +from torch.optim import Adam, Adagrad from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader -from torch.optim import Adam, Adagrad +from util import setup_logging + import torchrec -import pytest -import logging -import random -from torchrec import EmbeddingConfig, EmbeddingCollection import torchrec.distributed -from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward +from torchrec import EmbeddingConfig, EmbeddingCollection from torchrec.distributed.embeddingbag import EmbeddingBagCollectionAwaitable from torchrec.distributed.planner import ( EmbeddingShardingPlanner, @@ -29,15 +37,8 @@ from torchrec.distributed.planner import ( ParameterConstraints, ) from torchrec.distributed.types import ShardingEnv +from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward from torchrec.optim.keyed import CombinedOptimizer -from hybrid_torchrec import HashEmbeddingBagCollection, HashEmbeddingBagConfig, HashEmbeddingCollection -from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders -from hybrid_torchrec.distributed.hybrid_train_pipeline import ( - HybridTrainPipelineSparseDist, -) -from dataset import RandomRecDataset, Batch -from model import Model -from util import setup_logging OPTIMIZER_PARAM = { Adam: dict(lr=0.02), @@ -49,18 +50,27 @@ LOOP_TIMES = 8 BATCH_NUM = 32 -def execute( - rank, - world_size, - table_num, - 
embedding_dims, - num_embeddings, - pool_type, - sharding_type, - lookup_len, - device, - optim, -): +@dataclass +class ExecuteParams: + world_size: int + table_num: int + embedding_dims: List[int] + num_embeddings: List[int] + sharding_type: str + lookup_len: int + device: str + optim: type + + +def execute(rank: int, params: ExecuteParams): + world_size = params.world_size + table_num = params.table_num + embedding_dims = params.embedding_dims + num_embeddings = params.num_embeddings + sharding_type = params.sharding_type + lookup_len = params.lookup_len + device = params.device + optim = params.optim setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) # , batch_num, lookup_lens, num_embeddings, table_num @@ -228,39 +238,34 @@ class TestModel: return results -@pytest.mark.parametrize("table_num", [3]) -@pytest.mark.parametrize("embedding_dims", [[32, 32, 32]]) -@pytest.mark.parametrize("num_embeddings", [[400, 4000, 400]]) -@pytest.mark.parametrize("pool_type", [torchrec.PoolingType.MEAN]) -@pytest.mark.parametrize("sharding_type", ["table_wise", "row_wise"]) -@pytest.mark.parametrize("lookup_len", [1024]) -@pytest.mark.parametrize("device", ["npu"]) -@pytest.mark.parametrize("optim", [Adagrad]) -def test_hybrid_pipeline_hash_embedding( - table_num, - embedding_dims, - num_embeddings, - pool_type, - sharding_type, - lookup_len, - device, - optim, -): - if device == "cpu" and (sharding_type == "row_wise" or optim == Adam): +@pytest.mark.parametrize("params", [ + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="npu", + optim=Adagrad, + ), + ExecuteParams( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + lookup_len=1024, + device="npu", + optim=Adagrad, + ), +]) +def test_hybrid_pipeline_hash_embedding(params: ExecuteParams): + if params.device == "cpu" and (params.sharding_type == "row_wise" or params.optim == Adam): return mp.spawn( execute, - args=( - WORLD_SIZE, - table_num, - embedding_dims, - num_embeddings, - pool_type, - sharding_type, - lookup_len, - device, - optim, - ), - nprocs=WORLD_SIZE, + args=(params,), + nprocs=params.world_size, join=True, ) diff --git a/torchrec/hybrid_torchrec/test/test_kjt_with_count.py b/torchrec/hybrid_torchrec/test/test_kjt_with_count.py index 677b444c..eec07fb8 100644 --- a/torchrec/hybrid_torchrec/test/test_kjt_with_count.py +++ b/torchrec/hybrid_torchrec/test/test_kjt_with_count.py @@ -6,9 +6,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import logging + import pytest import torch -import logging from hybrid_torchrec.sparse import ( JaggedTensorWithCount, KeyedJaggedTensorWithCount) @@ -32,7 +33,6 @@ def test_kjt_with_count(table_num, feature_names, input_size): kjt_with_count = KeyedJaggedTensorWithCount.from_jt_dict(input_dict) logging.info("kjt_with_count:%s", kjt_with_count) - print("kjt_with_count:", kjt_with_count) # permute feature_names_for_sharding = [f"feat{ind}" for ind in range(feature_len)] -- Gitee From 49a3d070a215a66b4717591b3068675acea55d62 Mon Sep 17 00:00:00 2001 From: wuyangjian Date: Fri, 13 Jun 2025 12:05:19 +0800 Subject: [PATCH 3/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?= =?UTF-8?q?t2,=20cleancode=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../hybrid_torchrec/test/test_hybrid_embedding.py | 8 ++++---- .../test/test_hybrid_pipeline_hash_embedding.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py index 37cbfc71..a082bc71 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py @@ -209,15 +209,15 @@ class TestModel: if self.rank == 0: logging.debug(plan) - ddpModel = torchrec.distributed.DistributedModelParallel( + ddp_model = torchrec.distributed.DistributedModelParallel( ec, sharders=get_default_hybrid_sharders(host_env), device=torch.device(self.device), plan=plan, ) - logging.debug(ddpModel) + logging.debug(ddp_model) # Optimizer - optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + optimizer = CombinedOptimizer([ddp_model.fused_optimizer]) results = [] batch: Batch iter_ = iter(dataloader) @@ -234,7 +234,7 @@ class TestModel: logging.debug( "shard table%d weight %s", i, - ddpModel.module.ec.embeddings[f"table{i}"].weight, + ddp_model.module.ec.embeddings[f"table{i}"].weight, ) return results diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py index aa1f835b..87d99992 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py @@ -206,20 +206,20 @@ class TestModel: if self.rank == 0: logging.debug(plan) - ddpModel = torchrec.distributed.DistributedModelParallel( + ddp_model = torchrec.distributed.DistributedModelParallel( ec, sharders=get_default_hybrid_sharders(host_env), device=torch.device(self.device), plan=plan, ) - logging.debug(ddpModel) + logging.debug(ddp_model) # Optimizer - optimizer = CombinedOptimizer([ddpModel.fused_optimizer]) + optimizer = CombinedOptimizer([ddp_model.fused_optimizer]) results = [] iter_ = iter(dataloader) - ddpModel.train() + ddp_model.train() pipe = HybridTrainPipelineSparseDist( - ddpModel, + ddp_model, optimizer=optimizer, device=torch.device(self.device), return_loss=True, @@ -233,7 +233,7 @@ class TestModel: logging.debug( "shard table%d weight %s", i, - ddpModel.module.ec.embeddings[f"table{i}"].weight, + ddp_model.module.ec.embeddings[f"table{i}"].weight, ) return results -- Gitee From 3f96ba03547101657ff52e4b79826bca36421d59 Mon Sep 17 00:00:00 2001 From: wuyangjian Date: Fri, 13 Jun 2025 15:55:52 +0800 Subject: [PATCH 4/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?= =?UTF-8?q?t2,=20cleancode=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/test_hybrid_embedding.py | 70 ++++++++++------- .../test/test_hybrid_hash_embedding.py | 70 ++++++++++------- .../test_hybrid_pipeline_hash_embedding.py | 78 ++++++++++++------- 3 files changed, 137 insertions(+), 81 deletions(-) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py index a082bc71..3a5ac785 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py @@ -60,7 +60,7 @@ def generate_base_config(embedding_dims: List[int], @dataclass -class ExecuteParams: +class ExecuteConfig: world_size: int table_num: int embedding_dims: List[int] @@ -70,17 +70,17 @@ class ExecuteParams: device: str -def execute(rank: int, params: ExecuteParams): - world_size = params.world_size - table_num = params.table_num - embedding_dims = params.embedding_dims - num_embeddings = params.num_embeddings - sharding_type = params.sharding_type - lookup_len = params.lookup_len - device = params.device +def execute(rank: int, config: ExecuteConfig): + world_size = config.world_size + table_num = config.table_num + embedding_dims = config.embedding_dims + num_embeddings = config.num_embeddings + sharding_type = config.sharding_type + lookup_len = config.lookup_len + device = config.device setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) - embeding_config = generate_base_config(embedding_dims, num_embeddings) + embedding_config = generate_base_config(embedding_dims, num_embeddings) dataset = RandomRecDataset( batch_size=BATCH_NUM, @@ -104,9 +104,9 @@ def execute(rank: int, params: ExecuteParams): test_model = TestModel(rank, world_size, device) gloden_results = test_model.cpu_gloden_loss( - embeding_config, gloden_dataset_loader, sharding_type + embedding_config, gloden_dataset_loader, sharding_type ) - test_results = test_model.test_loss(embeding_config, data_loader, sharding_type) + test_results = test_model.test_loss(embedding_config, data_loader, sharding_type) for gloden, result in zip(gloden_results, test_results): logging.debug("") logging.debug("===========================") @@ -137,15 +137,15 @@ class TestModel: @staticmethod def cpu_gloden_loss( - embeding_config: List[EmbeddingConfig], + embedding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], sharding_type: str, ): pg = dist.new_group(backend="gloo") - table_num = len(embeding_config) - ec = EmbeddingCollection(device="cpu", tables=embeding_config) + table_num = len(embedding_config) + ec = EmbeddingCollection(device="cpu", tables=embedding_config) - num_features = sum([c.num_features() for c in embeding_config]) + num_features = sum([c.num_features() for c in embedding_config]) ec = Model(ec, num_features) model = DDP(ec, device_ids=None, process_group=pg) @@ -177,17 +177,17 @@ class TestModel: def test_loss( self, - embeding_config: List[EmbeddingConfig], + embedding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], sharding_type: str, ): - num_features = sum([c.num_features() for c in embeding_config]) + num_features = sum([c.num_features() for c in embedding_config]) rank, world_size = self.rank, self.world_size host_gp = dist.new_group(backend="gloo") host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) # Shard - table_num = len(embeding_config) - ec = EmbeddingCollection(device="meta", tables=embeding_config) + table_num = len(embedding_config) + ec = 
EmbeddingCollection(device="meta", tables=embedding_config) ec = Model(ec, num_features) apply_optimizer_in_backward( optimizer_class=torch.optim.Adagrad, @@ -239,8 +239,8 @@ class TestModel: return results -@pytest.mark.parametrize("params", [ - ExecuteParams( +@pytest.mark.parametrize("config", [ + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -249,7 +249,7 @@ class TestModel: lookup_len=1024, device="npu", ), - ExecuteParams( + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -258,13 +258,31 @@ class TestModel: lookup_len=1024, device="npu", ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="cpu", + ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + lookup_len=1024, + device="cpu", + ), ]) -def test_hybrid_embedding(params: ExecuteParams): - if params.device == "cpu" and params.sharding_type == "row_wise": +def test_hybrid_embedding(config: ExecuteConfig): + if config.device == "cpu" and config.sharding_type == "row_wise": return mp.spawn( execute, - args=(params,), + args=(config,), nprocs=WORLD_SIZE, join=True, ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py index b323731a..ea8d4bb9 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py @@ -59,7 +59,7 @@ def generate_hash_config(embedding_dims: List[int], @dataclass -class ExecuteParams: +class ExecuteConfig: world_size: int table_num: int embedding_dims: List[int] @@ -69,17 +69,17 @@ class ExecuteParams: device: str -def execute(rank: int, params: ExecuteParams): - world_size = params.world_size - table_num = params.table_num - embedding_dims = params.embedding_dims - num_embeddings = params.num_embeddings - sharding_type = params.sharding_type - lookup_len = params.lookup_len - device = params.device +def execute(rank: int, config: ExecuteConfig): + world_size = config.world_size + table_num = config.table_num + embedding_dims = config.embedding_dims + num_embeddings = config.num_embeddings + sharding_type = config.sharding_type + lookup_len = config.lookup_len + device = config.device setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) - embeding_config = generate_hash_config(embedding_dims, num_embeddings) + embedding_config = generate_hash_config(embedding_dims, num_embeddings) dataset = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num) gloden_dataset_loader = DataLoader( @@ -97,8 +97,8 @@ def execute(rank: int, params: ExecuteParams): test_model = TestModel(rank, world_size, device) - gloden_results = test_model.cpu_gloden_loss(embeding_config, gloden_dataset_loader) - test_results = test_model.test_loss(embeding_config, data_loader, sharding_type) + gloden_results = test_model.cpu_gloden_loss(embedding_config, gloden_dataset_loader) + test_results = test_model.test_loss(embedding_config, data_loader, sharding_type) for gloden, result in zip(gloden_results, test_results): logging.debug("") logging.debug("===========================") @@ -129,13 +129,13 @@ class TestModel: @staticmethod def cpu_gloden_loss( - embeding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch] + embedding_config: 
List[EmbeddingConfig], dataloader: DataLoader[Batch] ): pg = dist.new_group(backend="gloo") - table_num = len(embeding_config) - ec = HashEmbeddingCollection(device="cpu", tables=embeding_config) + table_num = len(embedding_config) + ec = HashEmbeddingCollection(device="cpu", tables=embedding_config) - num_features = sum([c.num_features() for c in embeding_config]) + num_features = sum([c.num_features() for c in embedding_config]) ec = Model(ec, num_features) model = DDP(ec, device_ids=None, process_group=pg) @@ -167,17 +167,17 @@ class TestModel: def test_loss( self, - embeding_config: List[EmbeddingConfig], + embedding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], sharding_type: str, ): - num_features = sum([c.num_features() for c in embeding_config]) + num_features = sum([c.num_features() for c in embedding_config]) rank, world_size = self.rank, self.world_size host_gp = dist.new_group(backend="gloo") host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) # Shard - table_num = len(embeding_config) - ec = HashEmbeddingCollection(device=torch.device("meta"), tables=embeding_config) + table_num = len(embedding_config) + ec = HashEmbeddingCollection(device=torch.device("meta"), tables=embedding_config) ec = Model(ec, num_features) apply_optimizer_in_backward( optimizer_class=torch.optim.Adagrad, @@ -229,8 +229,8 @@ class TestModel: return results -@pytest.mark.parametrize("params", [ - ExecuteParams( +@pytest.mark.parametrize("config", [ + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -239,7 +239,7 @@ class TestModel: lookup_len=1024, device="npu" ), - ExecuteParams( + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -248,13 +248,31 @@ class TestModel: lookup_len=1024, device="npu" ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="cpu" + ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + lookup_len=1024, + device="cpu" + ), ]) -def test_hybrid_hash_embedding(params: ExecuteParams): - if params.device == "cpu" and params.sharding_type == "row_wise": +def test_hybrid_hash_embedding(config: ExecuteConfig): + if config.device == "cpu" and config.sharding_type == "row_wise": return mp.spawn( execute, - args=(params,), + args=(config,), nprocs=WORLD_SIZE, join=True, ) diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py index 87d99992..c4bff033 100644 --- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py +++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py @@ -51,7 +51,7 @@ BATCH_NUM = 32 @dataclass -class ExecuteParams: +class ExecuteConfig: world_size: int table_num: int embedding_dims: List[int] @@ -62,15 +62,15 @@ class ExecuteParams: optim: type -def execute(rank: int, params: ExecuteParams): - world_size = params.world_size - table_num = params.table_num - embedding_dims = params.embedding_dims - num_embeddings = params.num_embeddings - sharding_type = params.sharding_type - lookup_len = params.lookup_len - device = params.device - optim = params.optim +def execute(rank: int, config: ExecuteConfig): + world_size = config.world_size + table_num = config.table_num + embedding_dims = config.embedding_dims + 
num_embeddings = config.num_embeddings + sharding_type = config.sharding_type + lookup_len = config.lookup_len + device = config.device + optim = config.optim setup_logging(rank) logging.info("this test %s", os.path.basename(__file__)) # , batch_num, lookup_lens, num_embeddings, table_num @@ -90,7 +90,7 @@ def execute(rank: int, params: ExecuteParams): pin_memory_device="npu", num_workers=1, ) - embeding_config = [] + embedding_config = [] for i in range(table_num): ec_config = EmbeddingConfig( name=f"table{i}", @@ -99,11 +99,11 @@ def execute(rank: int, params: ExecuteParams): feature_names=[f"feat{i}"], init_fn=weight_init, ) - embeding_config.append(ec_config) + embedding_config.append(ec_config) test_model = TestModel(rank, world_size, device) - gloden_results = test_model.cpu_gloden_loss(embeding_config, dataset_loader_gloden, optim) - test_results = test_model.test_loss(embeding_config, data_loader, sharding_type, optim) + gloden_results = test_model.cpu_gloden_loss(embedding_config, dataset_loader_gloden, optim) + test_results = test_model.test_loss(embedding_config, data_loader, sharding_type, optim) for gloden, result in zip(gloden_results, test_results): logging.debug("") logging.debug("===========================") @@ -134,13 +134,13 @@ class TestModel: @staticmethod def cpu_gloden_loss( - embeding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], optim + embedding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], optim ): pg = dist.new_group(backend="gloo") - table_num = len(embeding_config) - ec = EmbeddingCollection(device="cpu", tables=embeding_config) + table_num = len(embedding_config) + ec = EmbeddingCollection(device="cpu", tables=embedding_config) - num_features = sum([c.num_features() for c in embeding_config]) + num_features = sum([c.num_features() for c in embedding_config]) ec = Model(ec, num_features) model = DDP(ec, device_ids=None, process_group=pg) @@ -173,7 +173,7 @@ class TestModel: def test_loss( self, - embeding_config: List[EmbeddingConfig], + embedding_config: List[EmbeddingConfig], dataloader: DataLoader[Batch], sharding_type: str, optim, @@ -182,9 +182,9 @@ class TestModel: host_gp = dist.new_group(backend="gloo") host_env = ShardingEnv(world_size=world_size, rank=rank, pg=host_gp) - table_num = len(embeding_config) - ec = HashEmbeddingCollection(device=self.device, tables=embeding_config) - num_features = sum([c.num_features() for c in embeding_config]) + table_num = len(embedding_config) + ec = HashEmbeddingCollection(device=self.device, tables=embedding_config) + num_features = sum([c.num_features() for c in embedding_config]) ec = Model(ec, num_features) apply_optimizer_in_backward( optimizer_class=optim, @@ -238,8 +238,8 @@ class TestModel: return results -@pytest.mark.parametrize("params", [ - ExecuteParams( +@pytest.mark.parametrize("config", [ + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -249,7 +249,7 @@ class TestModel: device="npu", optim=Adagrad, ), - ExecuteParams( + ExecuteConfig( world_size=WORLD_SIZE, table_num=3, embedding_dims=[32, 32, 32], @@ -259,13 +259,33 @@ class TestModel: device="npu", optim=Adagrad, ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="table_wise", + lookup_len=1024, + device="cpu", + optim=Adagrad, + ), + ExecuteConfig( + world_size=WORLD_SIZE, + table_num=3, + embedding_dims=[32, 32, 32], + num_embeddings=[400, 4000, 400], + sharding_type="row_wise", + 
+        lookup_len=1024,
+        device="cpu",
+        optim=Adagrad,
+    ),
 ])
-def test_hybrid_pipeline_hash_embedding(params: ExecuteParams):
-    if params.device == "cpu" and (params.sharding_type == "row_wise" or params.optim == Adam):
+def test_hybrid_pipeline_hash_embedding(config: ExecuteConfig):
+    if config.device == "cpu" and (config.sharding_type == "row_wise" or config.optim == Adam):
         return
     mp.spawn(
         execute,
-        args=(params,),
-        nprocs=params.world_size,
+        args=(config,),
+        nprocs=config.world_size,
         join=True,
     )
-- Gitee

From 7dd25dcfd25357a4ac33e38ecd8c7639f67015aa Mon Sep 17 00:00:00 2001
From: wuyangjian
Date: Fri, 13 Jun 2025 17:16:32 +0800
Subject: [PATCH 5/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?=
 =?UTF-8?q?t2,=20=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86=E6=84=8F=E8=A7=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../test/test_hybrid_embedding.py          | 50 +++++------------
 .../test/test_hybrid_hash_embedding.py     | 50 +++++------------
 .../test_hybrid_pipeline_hash_embedding.py | 54 +++++--------------
 3 files changed, 40 insertions(+), 114 deletions(-)

diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
index 3a5ac785..731d6791 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
@@ -5,9 +5,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
+import itertools
 import logging
 import os
+from dataclasses import dataclass
 from typing import List
 import pytest
@@ -239,43 +240,18 @@ class TestModel:
         return results
+params = {
+    "table_num": [3],
+    "embedding_dims": [[32, 32, 32]],
+    "num_embeddings": [[400, 4000, 400]],
+    "sharding_type": ["table_wise", "row_wise"],
+    "lookup_len": [1024],
+    "device": ["npu"]
+}
+
+
 @pytest.mark.parametrize("config", [
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="npu",
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="npu",
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="cpu",
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="cpu",
-    ),
+    ExecuteConfig(*v) for v in itertools.product(*params.values())
 ])
 def test_hybrid_embedding(config: ExecuteConfig):
     if config.device == "cpu" and config.sharding_type == "row_wise":
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
index ea8d4bb9..efb06586 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
@@ -5,9 +5,10 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
+import itertools
 import logging
 import os
+from dataclasses import dataclass
 from typing import List
 import pytest
@@ -229,43 +230,18 @@ class TestModel:
         return results
+params = {
+    "table_num": [3],
+    "embedding_dims": [[32, 32, 32]],
+    "num_embeddings": [[400, 4000, 400]],
+    "sharding_type": ["table_wise", "row_wise"],
+    "lookup_len": [1024],
+    "device": ["npu"]
+}
+
+
 @pytest.mark.parametrize("config", [
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="npu"
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="npu"
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="cpu"
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="cpu"
-    ),
+    ExecuteConfig(*v) for v in itertools.product(*params.values())
 ])
 def test_hybrid_hash_embedding(config: ExecuteConfig):
     if config.device == "cpu" and config.sharding_type == "row_wise":
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
index c4bff033..7ac075ed 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
@@ -5,6 +5,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import itertools
 import logging
 import os
 from dataclasses import dataclass
@@ -238,47 +239,20 @@ class TestModel:
         return results
+params = {
+    "table_num": [3],
+    "embedding_dims": [[32, 32, 32]],
+    "num_embeddings": [[400, 4000, 400]],
+    "pool_type": [torchrec.PoolingType.MEAN],
+    "sharding_type": ["table_wise", "row_wise"],
+    "lookup_len": [1024],
+    "device": ["npu"],
+    "optim": [Adagrad, Adam],
+}
+
+
 @pytest.mark.parametrize("config", [
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="npu",
-        optim=Adagrad,
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="npu",
-        optim=Adagrad,
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="table_wise",
-        lookup_len=1024,
-        device="cpu",
-        optim=Adagrad,
-    ),
-    ExecuteConfig(
-        world_size=WORLD_SIZE,
-        table_num=3,
-        embedding_dims=[32, 32, 32],
-        num_embeddings=[400, 4000, 400],
-        sharding_type="row_wise",
-        lookup_len=1024,
-        device="cpu",
-        optim=Adagrad,
-    ),
+    ExecuteConfig(*v) for v in itertools.product(*params.values())
 ])
 def test_hybrid_pipeline_hash_embedding(config: ExecuteConfig):
     if config.device == "cpu" and (config.sharding_type == "row_wise" or config.optim == Adam):
-- Gitee

From 8419305b8947daa25d0df45363070120acccce33 Mon Sep 17 00:00:00 2001
From: wuyangjian
Date: Mon, 16 Jun 2025 10:53:12 +0800
Subject: [PATCH 6/7] =?UTF-8?q?test=E4=BB=A3=E7=A0=81=E5=90=88=E5=85=A5par?=
 =?UTF-8?q?t2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torchrec/hybrid_torchrec/test/test_hybrid_embedding.py      | 4 ++--
 torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py | 4 ++--
 .../test/test_hybrid_pipeline_hash_embedding.py             | 5 ++---
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
index 731d6791..ea47cc96 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
@@ -111,8 +111,8 @@ def execute(rank: int, config: ExecuteConfig):
     for gloden, result in zip(gloden_results, test_results):
         logging.debug("")
         logging.debug("===========================")
-        logging.debug("result test %s", gloden)
-        logging.debug("gloden test %s", result)
+        logging.debug("result test %s", result)
+        logging.debug("gloden test %s", gloden)
         assert torch.allclose(
             gloden, result, rtol=1e-04, atol=1e-04
         ), "gloden and result is not closed"
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
index efb06586..969980be 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
@@ -103,8 +103,8 @@ def execute(rank: int, config: ExecuteConfig):
     for gloden, result in zip(gloden_results, test_results):
         logging.debug("")
         logging.debug("===========================")
-        logging.debug("result test %s", gloden)
-        logging.debug("gloden test %s", result)
+        logging.debug("result test %s", result)
+        logging.debug("gloden test %s", gloden)
         assert torch.allclose(
             gloden, result, rtol=1e-04, atol=1e-04
         ), "gloden and result is not closed"
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
index 7ac075ed..8f48ee80 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
@@ -74,7 +74,6 @@ def execute(rank: int, config: ExecuteConfig):
     optim = config.optim
     setup_logging(rank)
     logging.info("this test %s", os.path.basename(__file__))
-    # , batch_num, lookup_lens, num_embeddings, table_num
     dataset_gloden = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num)
     dataset = RandomRecDataset(BATCH_NUM, lookup_len, num_embeddings, table_num)
     dataset_loader_gloden = DataLoader(
@@ -108,8 +107,8 @@ def execute(rank: int, config: ExecuteConfig):
     for gloden, result in zip(gloden_results, test_results):
         logging.debug("")
         logging.debug("===========================")
-        logging.debug("result test %s", gloden)
-        logging.debug("gloden test %s", result)
+        logging.debug("result test %s", result)
+        logging.debug("gloden test %s", gloden)
         assert torch.allclose(
             gloden, result, rtol=1e-04, atol=1e-04
         ), "gloden and result is not closed"
-- Gitee

From 6254ea72785e6b1be11aa938b342bf29acb68b87 Mon Sep 17 00:00:00 2001
From: wuyangjian
Date: Thu, 19 Jun 2025 14:48:10 +0800
Subject: [PATCH 7/7] =?UTF-8?q?test=E6=96=87=E4=BB=B6=E8=81=94=E8=B0=83?=
 =?UTF-8?q?=E9=97=AE=E9=A2=98=E4=BF=AE=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../hybrid_torchrec/test/test_hybrid_embedding.py |  8 ++------
 .../test/test_hybrid_embeddingbag.py              | 15 ++-------------
 .../test/test_hybrid_hash_embedding.py            |  1 +
 .../test/test_hybrid_hash_embeddingbag.py         |  1 +
 .../test/test_hybrid_pipeline_hash_embedding.py   |  1 +
 .../test_hybrid_pipeline_hash_embeddingbag.py     |  1 +
 .../hybrid_torchrec/test/test_train_and_eval.py   |  1 +
 7 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
index ea47cc96..6914fd42 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_embedding.py
@@ -83,12 +83,7 @@ def execute(rank: int, config: ExecuteConfig):
     logging.info("this test %s", os.path.basename(__file__))
     embedding_config = generate_base_config(embedding_dims, num_embeddings)
-    dataset = RandomRecDataset(
-        batch_size=BATCH_NUM,
-        lookup_len=lookup_len,
-        num_lookups=[num_embedding // 2 for num_embedding in num_embeddings],
-        num_tables=table_num,
-    )
+    dataset = RandomRecDataset(BATCH_NUM, lookup_len, [num_embedding // 2 for num_embedding in num_embeddings], table_num)
     gloden_dataset_loader = DataLoader(
         dataset,
         batch_size=None,
@@ -241,6 +236,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [3],
     "embedding_dims": [[32, 32, 32]],
     "num_embeddings": [[400, 4000, 400]],
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py
index cb819ede..f3df95fe 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_embeddingbag.py
@@ -249,6 +249,7 @@ class TestModel:
         return results
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [3],
     "embedding_dims": [[32, 32, 32]],
     "num_embeddings": [[400, 4000, 400]],
@@ -263,26 +264,14 @@ params = {
     ExecuteConfig(*v) for v in itertools.product(*params.values())
 ])
 def test_hybrid_embedding_bag(config: ExecuteConfig):
-    table_num = config.table_num
-    embedding_dims = config.embedding_dims
-    num_embeddings = config.num_embeddings
-    pool_type = config.pool_type
     sharding_type = config.sharding_type
-    lookup_len = config.lookup_len
     device = config.device
     if device == "cpu" and sharding_type == "row_wise":
         return
     mp.spawn(
         execute,
         args=(
-            WORLD_SIZE,
-            table_num,
-            embedding_dims,
-            num_embeddings,
-            pool_type,
-            sharding_type,
-            lookup_len,
-            device,
+            config,
         ),
         nprocs=WORLD_SIZE,
         join=True,
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
index 969980be..c020fc94 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embedding.py
@@ -231,6 +231,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [3],
     "embedding_dims": [[32, 32, 32]],
     "num_embeddings": [[400, 4000, 400]],
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py
index 7c8c3521..24a87b78 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_hash_embeddingbag.py
@@ -249,6 +249,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [3],
     "embedding_dims": [[32, 32, 32]],
     "num_embeddings": [[400, 4000, 400]],
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
index 8f48ee80..295f1441 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embedding.py
@@ -239,6 +239,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [3],
     "embedding_dims": [[32, 32, 32]],
     "num_embeddings": [[400, 4000, 400]],
diff --git a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py
index 40283b9c..1f3d6b03 100644
--- a/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py
+++ b/torchrec/hybrid_torchrec/test/test_hybrid_pipeline_hash_embeddingbag.py
@@ -244,6 +244,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [2],
     "embedding_dims": [[32, 64, 128]],
     "num_embeddings": [[400, 4000, 400]],
diff --git a/torchrec/hybrid_torchrec/test/test_train_and_eval.py b/torchrec/hybrid_torchrec/test/test_train_and_eval.py
index 3c33557f..d1dfe0ca 100644
--- a/torchrec/hybrid_torchrec/test/test_train_and_eval.py
+++ b/torchrec/hybrid_torchrec/test/test_train_and_eval.py
@@ -207,6 +207,7 @@ class TestModel:
 params = {
+    "world_size": [WORLD_SIZE],
     "table_num": [2],
     "embedding_dims": [[32, 64, 128]],
     "num_embeddings": [[400, 4000, 400]],
-- Gitee
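
A note on the parametrization pattern introduced in PATCH 5/7 through 7/7: `ExecuteConfig(*v) for v in itertools.product(*params.values())` builds each config positionally, so it is only correct when the insertion order of the `params` dict exactly matches the field order of the `ExecuteConfig` dataclass (which is why PATCH 7/7 prepends `"world_size"` to every `params` dict). The sketch below is not part of the patches; it assumes a simplified, hypothetical `ExecuteConfig` with a subset of the real fields, and shows a keyword-based variant that fails loudly on a mismatched or missing key instead of silently shifting values into the wrong fields.

```python
import itertools
from dataclasses import dataclass
from typing import List

# Simplified stand-in for the test dataclass; the field list here is an
# assumption for illustration, not the exact fields of every test file.
@dataclass
class ExecuteConfig:
    world_size: int
    table_num: int
    embedding_dims: List[int]
    num_embeddings: List[int]
    sharding_type: str
    lookup_len: int
    device: str

params = {
    "world_size": [2],
    "table_num": [3],
    "embedding_dims": [[32, 32, 32]],
    "num_embeddings": [[400, 4000, 400]],
    "sharding_type": ["table_wise", "row_wise"],
    "lookup_len": [1024],
    "device": ["npu"],
}

# Keyword-based construction: a misnamed or missing key raises a clear
# TypeError, whereas positional unpacking would mis-assign the values.
configs = [
    ExecuteConfig(**dict(zip(params.keys(), combo)))
    for combo in itertools.product(*params.values())
]
assert len(configs) == 2  # one table_wise case and one row_wise case
```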