From 6469ca2479cf2dc9a90ab51783e524128a1f8729 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Sat, 21 Jun 2025 18:38:32 +0800 Subject: [PATCH 1/4] =?UTF-8?q?NV=E7=9A=84GR=E6=A8=A1=E5=9E=8B=E9=80=82?= =?UTF-8?q?=E9=85=8DNPU=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gr_nv/gr_nv2npu.patch | 1639 +++++++++++++++++ torch_examples_benchmark/gr_nv/run.sh | 96 + 2 files changed, 1735 insertions(+) create mode 100644 torch_examples_benchmark/gr_nv/gr_nv2npu.patch create mode 100644 torch_examples_benchmark/gr_nv/run.sh diff --git a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch new file mode 100644 index 00000000..e8df1cd8 --- /dev/null +++ b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch @@ -0,0 +1,1639 @@ +diff --git a/examples/commons/utils/distributed_utils.py b/examples/commons/utils/distributed_utils.py +index b0b19fe..cff32f1 100644 +--- a/examples/commons/utils/distributed_utils.py ++++ b/examples/commons/utils/distributed_utils.py +@@ -18,7 +18,7 @@ import torch + def collective_assert( + flag: bool, err_msg: str = "", group: torch.distributed.ProcessGroup = None + ): +- flag_tensor = torch.tensor(flag, dtype=torch.bool).cuda() ++ flag_tensor = torch.tensor(flag, dtype=torch.bool).npu() + torch.distributed.all_reduce( + flag_tensor, op=torch.distributed.ReduceOp.MIN, group=group + ) +diff --git a/examples/commons/utils/initialize.py b/examples/commons/utils/initialize.py +index c56f05c..6684b7e 100644 +--- a/examples/commons/utils/initialize.py ++++ b/examples/commons/utils/initialize.py +@@ -16,6 +16,7 @@ import gc + import os + + import torch ++import torch_npu + from megatron.core import parallel_state, tensor_parallel + + +@@ -24,9 +25,9 @@ def initialize_single_rank(): + return + torch.set_printoptions(precision=6, sci_mode=False) + rank = 0 +- device: torch.device = torch.device(f"cuda:{rank}") +- backend = "nccl" +- 
torch.cuda.set_device(device) ++ device: torch.device = torch.device(f"npu:{rank}") ++ backend = "hccl" ++ torch_npu.npu.set_device(device) + torch.distributed.init_process_group( + backend=backend, init_method="tcp://127.0.0.1:12345", rank=rank, world_size=1 + ) +@@ -37,9 +38,9 @@ def initialize_distributed(): + return + torch.set_printoptions(precision=6, sci_mode=False) + rank = int(os.environ["LOCAL_RANK"]) +- device: torch.device = torch.device(f"cuda:{rank}") +- backend = "nccl" +- torch.cuda.set_device(device) ++ device: torch.device = torch.device(f"npu:{rank}") ++ backend = "hccl" ++ torch_npu.npu.set_device(device) + torch.distributed.init_process_group(backend=backend) + + +@@ -49,11 +50,11 @@ def initialize_model_parallel(tensor_model_parallel_size=1): + parallel_state.initialize_model_parallel( + tensor_model_parallel_size, + ) +- torch.distributed.barrier(device_ids=[torch.cuda.current_device()]) ++ torch.distributed.barrier(device_ids=[torch_npu.npu.current_device()]) + + + def destroy_global_state(): +- torch.distributed.barrier(device_ids=[torch.cuda.current_device()]) ++ torch.distributed.barrier(device_ids=[torch_npu.npu.current_device()]) + torch.distributed.destroy_process_group( + group=parallel_state.get_data_parallel_group(with_context_parallel=True) + ) +@@ -63,7 +64,7 @@ def destroy_global_state(): + ) + parallel_state.destroy_model_parallel() + gc.collect() +- torch.cuda.empty_cache() ++ torch_npu.npu.empty_cache() + + + def set_random_seed(seed_): +@@ -84,16 +85,17 @@ def set_random_seed(seed_): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) +- if torch.cuda.device_count() > 0: +- tensor_parallel.model_parallel_cuda_manual_seed(seed) ++ if torch_npu.npu.device_count() > 0: ++ torch_npu.npu.manual_seed(seed) ++ torch_npu.npu.manual_seed_all(seed) + + # We must maintain an rng state for torchrec, because with different world size, the state evolution differ + # guarantee randomness across DPxTPxCPxPP for 
embedding-group +- seed = seed + 1234 +- seed = seed + (1000 * parallel_state.get_context_parallel_rank()) +- seed = seed + (10000 * parallel_state.get_tensor_model_parallel_rank()) +- rng_tracker = tensor_parallel.get_cuda_rng_tracker() +- rng_tracker.add("sharded-embedding-group-seed", seed) ++ # seed = seed + 1234 ++ # seed = seed + (1000 * parallel_state.get_context_parallel_rank()) ++ # seed = seed + (10000 * parallel_state.get_tensor_model_parallel_rank()) ++ # rng_tracker = tensor_parallel.get_cuda_rng_tracker() ++ # rng_tracker.add("sharded-embedding-group-seed", seed) + + else: +- raise ValueError("Seed ({}) should be a positive integer.".format(seed)) ++ raise ValueError("Seed ({}) should be a positive integer.".format(seed_)) +diff --git a/examples/hstu/configs/hstu_config.py b/examples/hstu/configs/hstu_config.py +index fee4520..dc16147 100644 +--- a/examples/hstu/configs/hstu_config.py ++++ b/examples/hstu/configs/hstu_config.py +@@ -49,6 +49,7 @@ class KernelBackend(Enum): + TRITON = "TRITON" + PYTORCH = "PYTORCH" + CUTLASS = "CUTLASS" ++ NPU_FUSED = "NPU_FUSED" + + + @dataclass +diff --git a/examples/hstu/dataset/utils.py b/examples/hstu/dataset/utils.py +index ed1b9bc..78f93f0 100644 +--- a/examples/hstu/dataset/utils.py ++++ b/examples/hstu/dataset/utils.py +@@ -16,6 +16,7 @@ from dataclasses import dataclass + from typing import Dict, List, Optional, Union + + import torch ++import torch_npu + from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +@@ -163,7 +164,7 @@ class RankingBatch(Batch): + RankingBatch: The batch on the specified device. 
+ """ + return RankingBatch( +- features=self.features.to(device=device, non_blocking=non_blocking), ++ features=self.features, + batch_size=self.batch_size, + feature_to_max_seqlen=self.feature_to_max_seqlen, + contextual_feature_names=self.contextual_feature_names, +@@ -178,6 +179,10 @@ class RankingBatch(Batch): + labels=self.labels.to(device=device, non_blocking=non_blocking), + ) + ++ def record_stream(self, stream: torch_npu.npu.streams.Stream) -> None: ++ self.num_candidates.record_stream(stream) ++ self.labels.record_stream(stream) ++ + def pin_memory(self) -> "RankingBatch": + """ + Pin the memory of the batch. +diff --git a/examples/hstu/distributed/sharding.py b/examples/hstu/distributed/sharding.py +index 4389a45..ce39753 100644 +--- a/examples/hstu/distributed/sharding.py ++++ b/examples/hstu/distributed/sharding.py +@@ -14,22 +14,33 @@ + # limitations under the License. + + # pyre-strict +-from typing import Any, Dict, List, Set, Tuple, Type, Union ++from typing import Any, Dict, List, Set, Tuple, Type, Union, Optional + ++import os + import torch ++import torch_npu + import torch.distributed as dist + import torchrec + from configs.task_config import OptimizerParam +-from dynamicemb import DynamicEmbTableOptions +-from dynamicemb.planner import DynamicEmbeddingEnumerator +-from dynamicemb.planner import ( +- DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, +-) +-from dynamicemb.planner import DynamicEmbParameterConstraints +-from dynamicemb.shard import ( +- DynamicEmbeddingBagCollectionSharder, +- DynamicEmbeddingCollectionSharder, ++from torchrec.distributed.planner.types import ParamterConstrains ++from torchrec.optim.keyed import CombineOptimizer, KeyeOptimizerWrapper ++from torchrec_embcache.distributed.embedding import EmbCacheEmbeddingCollection ++from torchrec_embcache.distributed.embedding_bag import EmbCacheEmbeddingBagCollection ++# from dynamicemb import DynamicEmbTableOptions ++# from dynamicemb.planner import 
DynamicEmbeddingEnumerator ++# from dynamicemb.planner import ( ++# DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, ++# ) ++# from dynamicemb.planner import DynamicEmbParameterConstraints ++# from dynamicemb.shard import ( ++# DynamicEmbeddingBagCollectionSharder, ++# DynamicEmbeddingCollectionSharder, ++# ) ++from torchrec_embcache.distributed.sharding.embedding_sharder import ( ++ EmbCacheEmbeddingBagCollectionSharder, ++ EmbCacheEmbeddingCollectionSharder, + ) ++ + from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType + from megatron.core import parallel_state, tensor_parallel + from megatron.core.distributed import DistributedDataParallel as DDP +@@ -59,7 +70,7 @@ from torchrec.distributed.fbgemm_qcomm_codec import ( + get_qcomm_codecs_registry, + ) + from torchrec.distributed.model_parallel import DistributedModelParallel +-from torchrec.distributed.planner import Topology ++from torchrec.distributed.planner import Topology, EmbeddingShardingPlanner + from torchrec.distributed.planner.storage_reservations import ( + HeuristicalStorageReservation, + ) +@@ -69,16 +80,23 @@ from torchrec.distributed.types import ( + ShardingEnv, + ShardingType, + ) +-from torchrec.modules.embedding_configs import EmbeddingConfig + from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, + ) + from torchrec.optim.optimizers import in_backward_optimizer_filter ++from hybrid_torchrec import HashEmbeddingCollection, EmbeddingConfig ++from hybrid_torchrec.modules.hash_embeddingbag import HashEmbeddingBagCollection ++ ++from torch_examples_benchmark.dlrm.dlrm_main import with_embcache + + TORCHREC_TYPES: Set[Type[Union[EmbeddingBagCollection, EmbeddingCollection]]] = { + EmbeddingBagCollection, + EmbeddingCollection, ++ EmbCacheEmbeddingBagCollection, ++ EmbCacheEmbeddingCollection, ++ HashEmbeddingCollection, ++ HashEmbeddingBagCollection, + } + + DATA_PARALLEL_EMBEDDING_MODULE_NAME = 
"_data_parallel_embedding_collection" +@@ -147,66 +165,27 @@ def apply_megatron_ddp( + def get_planner( + eb_configs: List[EmbeddingConfig], + data_parallel_embedding_table_names: Set[str], +- dynamicemb_options_dict: Dict[str, DynamicEmbTableOptions], + device: torch.device, +-): +- constraints = {} +- for config in eb_configs: +- if config.name in data_parallel_embedding_table_names: +- constraint = DynamicEmbParameterConstraints( +- sharding_types=[ +- ShardingType.DATA_PARALLEL.value, +- ], +- bounds_check_mode=BoundsCheckMode.NONE, +- use_dynamicemb=False, +- ) +- elif config.name in dynamicemb_options_dict: +- dynamicemb_options = dynamicemb_options_dict[config.name] +- constraint = DynamicEmbParameterConstraints( +- sharding_types=[ShardingType.ROW_WISE.value], +- bounds_check_mode=BoundsCheckMode.NONE, # dynamic embedding has no bounding! +- enforce_hbm=True, +- use_dynamicemb=True, +- dynamicemb_options=dynamicemb_options, +- ) ++) -> EmbeddingShardingPlanner: ++ constraints: Dict[str, List[str]] = {} ++ for cfg in eb_configs: ++ if cfg.name in data_parallel_embedding_table_names: ++ constraints[cfg.name] = ParamterConstrains(sharding_type=[ShardingType.DATA_PARALLEL.value]) + else: +- constraint = DynamicEmbParameterConstraints( +- sharding_types=[ +- ShardingType.ROW_WISE.value, +- ShardingType.TABLE_WISE.value, +- ShardingType.TABLE_ROW_WISE.value, +- ], +- bounds_check_mode=BoundsCheckMode.NONE, +- use_dynamicemb=False, +- ) +- constraints.update({config.name: constraint}) +- hbm_cap = torch.cuda.get_device_properties(0).total_memory +- ddr_cap = 512 * 1024 * 1024 * 1024 # Assume a Node have 512GB memory +- intra_host_bw = 450e9 # Nvlink bandwidth +- inter_host_bw = 25e9 # NIC bandwidth ++ constraints[cfg.name] = ParamterConstrains(sharding_type=[ShardingType.ROW_WISE.value]) + + topology = Topology( + local_world_size=get_local_size(), + world_size=dist.get_world_size(), + compute_device=device.type, +- hbm_cap=hbm_cap, +- ddr_cap=ddr_cap, # For HVK , 
if we need to put embedding vector into Host memory , it is important set ddr capacity +- intra_host_bw=intra_host_bw, +- inter_host_bw=inter_host_bw, + ) +- enumerator = DynamicEmbeddingEnumerator( ++ print("constraints:", constraints) ++ return EmbeddingShardingPlanner( + topology=topology, + constraints=constraints, +- ) +- return DynamicEmbeddingShardingPlanner( +- eb_configs=eb_configs, +- topology=topology, +- constraints=constraints, +- enumerator=enumerator, + storage_reservation=HeuristicalStorageReservation(percentage=0.05), + ) + +- + _optimizer_str_to_optim_type = { + "adam": EmbOptimType.ADAM, + "sgd": EmbOptimType.EXACT_SGD, +@@ -252,7 +231,6 @@ def sparse_optimizer_factory_and_class( + + def apply_dmp( + model: torch.nn.Module, +- dynamicemb_options_dict: Dict[str, DynamicEmbTableOptions], + sparse_optimizer_param: OptimizerParam, + pg: torch.distributed.ProcessGroup, + device: torch.device, +@@ -279,7 +257,6 @@ def apply_dmp( + "beta1": sparse_optimizer_param.adam_beta1, + "beta2": sparse_optimizer_param.adam_beta2, + "eps": sparse_optimizer_param.adam_eps, +- # 'weight_decay_mode' : WeightDecayMode.NONE, + "output_dtype": SparseType.FP32, + } + eb_configs = [] +@@ -301,7 +278,6 @@ def apply_dmp( + planner = get_planner( + eb_configs, + set(data_parallel_embedding_table_names), +- dynamicemb_options_dict, + device, + ) + qcomm_codecs_registry = get_qcomm_codecs_registry( +@@ -310,17 +286,36 @@ def apply_dmp( + backward_precision=CommType.FP32, + ) + ) +- sharders = [ +- DynamicEmbeddingBagCollectionSharder( +- qcomm_codecs_registry=qcomm_codecs_registry, +- fused_params=fused_params, +- ), +- DynamicEmbeddingCollectionSharder( +- qcomm_codecs_registry=qcomm_codecs_registry, +- use_index_dedup=True, +- fused_params=fused_params, +- ), +- ] ++ ++ with_embcache = os.getenv("WITH_EMBCACHE", "1") == "1" ++ cpu_device = torch.device("cpu") ++ cpu_pg = dist.new_group(backend="gloo") ++ cpu_env = ShardingEnv.from_process_group(cpu_pg) ++ npu_env = 
ShardingEnv.from_process_group(pg) ++ if with_embcache: ++ sharders = [ ++ EmbCacheEmbeddingBagCollectionSharder( ++ cpu_device=cpu_device, ++ cpu_env=cpu_env, ++ npu_device=device, ++ npu_env=npu_env, ++ fused_params=fused_params, ++ qcomm_codecs_registry=qcomm_codecs_registry, ++ ), ++ EmbCacheEmbeddingCollectionSharder( ++ cpu_device=cpu_device, ++ cpu_env=cpu_env, ++ npu_device=device, ++ npu_env=npu_env, ++ fused_params=fused_params, ++ qcomm_codecs_registry=qcomm_codecs_registry, ++ ), ++ ] ++ else: ++ from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders ++ sharders = get_default_hybrid_sharders(cpu_env) ++ ++ + plan = planner.collective_plan(model, sharders, pg) + data_parallel_sharding_plans = [] + for data_parallel_embedding_module_name in data_parallel_embedding_module_names: +@@ -328,15 +323,14 @@ def apply_dmp( + plan.plan.pop(data_parallel_embedding_module_name, None) + ) + # Shard model +- with tensor_parallel.get_cuda_rng_tracker().fork("sharded-embedding-group-seed"): +- model = DistributedModelParallel( +- module=model, +- env=ShardingEnv.from_process_group(pg), +- device=device, +- sharders=sharders, +- plan=plan, +- init_data_parallel=False, +- ) ++ model = DistributedModelParallel( ++ module=model, ++ env=ShardingEnv.from_process_group(pg), ++ device=device, ++ sharders=sharders, ++ plan=plan, ++ init_data_parallel=False, ++ ) + + # Create keyed optimizer + non_fused_sparse_params = {} +@@ -379,17 +373,16 @@ def make_optimizer_and_shard( + config: TransformerConfig, + sparse_optimizer_param: OptimizerParam, + dense_optimizer_param: OptimizerParam, +- dynamicemb_options_dict: Dict[str, DynamicEmbTableOptions] = {}, + device: torch.device = None, + pg: torch.distributed.ProcessGroup = None, + ) -> Tuple[DistributedModelParallel, torch.optim.Optimizer]: + if device is None: +- device = torch.device("cuda", torch.cuda.current_device()) ++ device = torch.device("npu", torch_npu.npu.current_device()) + if pg is None: + pg 
= dist.group.WORLD + + model = apply_dmp( +- model, dynamicemb_options_dict, sparse_optimizer_param, pg, device ++ model, sparse_optimizer_param, pg, device + ) + model, dense_optimizer = apply_megatron_ddp( + model, config, dense_optimizer_param, device +diff --git a/examples/hstu/model/__init__.py b/examples/hstu/model/__init__.py +index e4fba4b..1ca3e9f 100644 +--- a/examples/hstu/model/__init__.py ++++ b/examples/hstu/model/__init__.py +@@ -1,8 +1,8 @@ + from configs import HSTUConfig, RankingConfig, RetrievalConfig + from model.ranking_gr import RankingGR +-from model.retrieval_gr import RetrievalGR ++# from model.retrieval_gr import RetrievalGR + +-from . import ranking_gr, retrieval_gr ++from . import ranking_gr + + __all__ = ["ranking_gr", "retrieval_gr"] + +@@ -25,19 +25,19 @@ def get_ranking_model( + return RankingGR(hstu_config=hstu_config, task_config=task_config) + + +-def get_retrieval_model( +- hstu_config: HSTUConfig, +- task_config: RetrievalConfig, +-) -> RetrievalGR: +- """ +- Get a retrieval model. +- +- Args: +- hstu_config (HSTUConfig): The HSTU configuration. +- task_config (RetrievalConfig): The retrieval task configuration. +- +- Returns: +- RetrievalGR: The retrieval model. +- """ +- assert isinstance(task_config, RetrievalConfig), "please provide a retrieval config" +- return RetrievalGR(hstu_config=hstu_config, task_config=task_config) ++# def get_retrieval_model( ++# hstu_config: HSTUConfig, ++# task_config: RetrievalConfig, ++# ) -> RetrievalGR: ++# """ ++# Get a retrieval model. ++# ++# Args: ++# hstu_config (HSTUConfig): The HSTU configuration. ++# task_config (RetrievalConfig): The retrieval task configuration. ++# ++# Returns: ++# RetrievalGR: The retrieval model. 
++# """ ++# assert isinstance(task_config, RetrievalConfig), "please provide a retrieval config" ++# return RetrievalGR(hstu_config=hstu_config, task_config=task_config) +diff --git a/examples/hstu/model/ranking_gr.py b/examples/hstu/model/ranking_gr.py +index bf1f9ac..151b3df 100644 +--- a/examples/hstu/model/ranking_gr.py ++++ b/examples/hstu/model/ranking_gr.py +@@ -13,10 +13,12 @@ + # See the License for the specific language governing permissions and + # limitations under the License. + from collections import OrderedDict +-from typing import Tuple ++import os ++from typing import List, Tuple + + import torch +-from commons.utils.nvtx_op import output_nvtx_hook ++import torch_npu ++# from commons.utils.nvtx_op import output_nvtx_hook + from configs import HSTUConfig, RankingConfig + from dataset.utils import RankingBatch + from megatron.core import parallel_state +@@ -26,6 +28,16 @@ from modules.hstu_block import HSTUBlock + from modules.metrics import get_multi_event_metric_module + from modules.mlp import MLP + from modules.multi_task_loss_module import MultiTaskLossModule ++from torchrec_embcache.distributed.configs import ( ++ EmbCacheEmbeddingConfig, ++ InintializerType, ++ AdmitAndEvictConfig, ++) ++from torchrec_embcache.distributed.embedding import EmbCacheEmbeddingCollection ++from torchrec.modules.embedding_configs import EmbeddingConfig ++from hybrid_torchrec import HashEmbeddingCollection ++ ++from torch_examples_benchmark.model_zoo.aliccp.step7_gen_spec import multi_hot_fields + + + class RankingGR(BaseModel): +@@ -49,7 +61,7 @@ class RankingGR(BaseModel): + assert ( + self._tp_size == 1 + ), "RankingGR does not support tensor model parallel for now" +- self._device = torch.device("cuda", torch.cuda.current_device()) ++ self._device = torch.device("npu", torch_npu.npu.current_device()) + self._hstu_config = hstu_config + self._task_config = task_config + +@@ -59,7 +71,47 @@ class RankingGR(BaseModel): + ebc_config.dim == self._embedding_dim + ), 
"hstu layer hidden size should equal to embedding dim" + +- self._embedding_collection = ShardedEmbedding(task_config.embedding_configs) ++ use_embcache = os.getenv("WITH_EMBCACHE", "1") == "1" ++ if use_embcache: ++ cache_cfgs = [ ++ EmbCacheEmbeddingConfig( ++ name=cfg.table_name, ++ feature_names=cfg.feature_names, ++ num_embeddings=cfg.vocab_size, ++ embedding_dim=cfg.dim, ++ initializer_type=InintializerType.UNIFORM, ++ weight_init_mean=0.0, ++ weight_init_stddev=0.05, ++ admit_and_evict_config=AdmitAndEvictConfig(), ++ ) ++ for cfg in task_config.embedding_configs ++ ] ++ world_size = parallel_state.get_data_parallel_world_size() ++ local_batch = int(os.getenv("LOCAL_BATCH_SIZE", "128")) ++ multi_hot_sizes = [1 for _ in cache_cfgs] ++ ++ self._embedding_collection = EmbCacheEmbeddingCollection( ++ tables=cache_cfgs, ++ world_size=world_size, ++ batch_size=local_batch, ++ multi_hot_sizes=multi_hot_sizes, ++ need_indices=False, ++ device=torch.device('meta'), ++ ) ++ else: ++ ec_configs = [ ++ EmbeddingConfig( ++ name=cfg.table_name, ++ feature_names=cfg.feature_names, ++ embedding_dim=cfg.dim, ++ num_embeddings=cfg.vocab_size, ++ ) ++ for cfg in task_config.embedding_configs ++ ] ++ self._embedding_collection = HashEmbeddingCollection( ++ tables=ec_configs, ++ device=torch.device('meta'), ++ ) + + self._hstu_block = HSTUBlock(hstu_config) + self._mlp = MLP( +@@ -125,7 +177,6 @@ class RankingGR(BaseModel): + + return self._mlp(hidden_states.values), batch.labels + +- @output_nvtx_hook(nvtx_tag="RankingModel", backward=False) + def forward( # type: ignore[override] + self, + batch: RankingBatch, +diff --git a/examples/hstu/modules/embedding.py b/examples/hstu/modules/embedding.py +index 2cfd36a..fa19a95 100644 +--- a/examples/hstu/modules/embedding.py ++++ b/examples/hstu/modules/embedding.py +@@ -22,9 +22,6 @@ import torch + import torch.distributed as dist + import torch.nn as nn + from configs.task_config import ShardedEmbeddingConfig +-from 
dynamicemb.planner import ( +- DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, +-) + from torch import distributed as dist + from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo + from torchrec.distributed.embedding_types import EmbeddingComputeKernel +diff --git a/examples/hstu/modules/hstu_attention.py b/examples/hstu/modules/hstu_attention.py +index 8aa6894..3be7e77 100644 +--- a/examples/hstu/modules/hstu_attention.py ++++ b/examples/hstu/modules/hstu_attention.py +@@ -13,7 +13,7 @@ + # See the License for the specific language governing permissions and + # limitations under the License. + import abc +-from typing import Optional, Union ++from typing import Optional, Union, List + + import torch + from commons.utils.nvtx_op import output_nvtx_hook +@@ -132,6 +132,43 @@ class TorchHSTUAttention(HSTUAttention): + target_group_size=target_group_size, + ).view(-1, self.num_heads * self.linear_dim) + ++class NpuFusedHSTUAttention(torch.nn.Module): ++ def __init__( ++ self, ++ num_heads: int, ++ attention_dim: int, ++ is_causal:bool, ++ ): ++ super().__init__() ++ self.num_heads = num_heads ++ self.attention_dim = attention_dim ++ self.is_causal = is_causal ++ ++ def forward(self, ++ tq: torch.Tensor, ++ tk: torch.Tensor, ++ tv: torch.Tensor, ++ mask_type: int, ++ max_seq_len: int, ++ silu_value: float, ++ seq_offset: Optional[List[int]], ++ attn_bias: Optional[torch.Tensor] = None, ++ mask: Optional[torch.Tensor] = None, ++ ) -> torch.Tensor: ++ import sysconfig ++ torch.npu.config.allow_internal_format = False ++ torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") ++ return torch.ops.mxrec.hstu_dense( ++ tq.view(-1, self.num_heads, self.attention_dim), ++ tk.view(-1, self.num_heads, self.attention_dim), ++ tv.view(-1, self.num_heads, self.linear_dim), ++ mask, ++ attn_bias, ++ 0 if self.is_causal else mask_type, ++ max_seq_len, ++ silu_value, ++ "jagged", ++ seq_offset).view(-1, self.num_heads * 
self.attention_dim) + + class TritonHSTUAttention(HSTUAttention): + """ +@@ -414,7 +451,9 @@ def create_hstu_attention( + Raises: + ValueError: If the kernel backend is not supported. + """ +- if kernel_backend == KernelBackend.CUTLASS: ++ if kernel_backend == KernelBackend.NPU_FUSED: ++ return NpuFusedHSTUAttention(num_heads, attention_dim, is_causal) ++ elif kernel_backend == KernelBackend.CUTLASS: + sm_major_version = torch.cuda.get_device_properties(0).major + sm_minor_version = torch.cuda.get_device_properties(0).minor + if sm_major_version == 9 and sm_minor_version == 0: +diff --git a/examples/hstu/modules/hstu_block.py b/examples/hstu/modules/hstu_block.py +index 2a2b580..07c72d3 100644 +--- a/examples/hstu/modules/hstu_block.py ++++ b/examples/hstu/modules/hstu_block.py +@@ -3,19 +3,25 @@ + from typing import Dict, Optional, Union + + import torch +-from commons.utils.nvtx_op import output_nvtx_hook ++import torch_npu ++from Tools.scripts.make_ctype import values ++# from commons.utils.nvtx_op import output_nvtx_hook + from configs.hstu_config import HSTUConfig, HSTULayerType + from dataset.utils import RankingBatch, RetrievalBatch + from megatron.core.transformer.module import MegatronModule +-from modules.fused_hstu_layer import FusedHSTULayer ++# from modules.fused_hstu_layer import FusedHSTULayer + from modules.jagged_data import JaggedData + from modules.native_hstu_layer import HSTULayer + from modules.position_encoder import HSTUPositionalEncoder + from ops.jagged_tensor_op import concat_2D_jagged_tensors + from ops.length_to_offsets import length_to_complete_offsets +-from ops.triton_ops.triton_jagged import ( # type: ignore[attr-defined] +- triton_concat_2D_jagged, +- triton_split_2D_jagged, ++# from ops.triton_ops.triton_jagged import ( # type: ignore[attr-defined] ++# triton_concat_2D_jagged, ++# triton_split_2D_jagged, ++# ) ++from ops.pt_ops.pt_jagged_tensors import ( # type: ignore[attr-defined] ++ pytorch_concat_2D_jagged, ++ 
pytorch_split_2D_jagged, + ) + from torchrec.sparse.jagged_tensor import JaggedTensor + +@@ -49,16 +55,11 @@ class HSTUBlock(MegatronModule): + use_time_encoding=config.position_encoding_config.use_time_encoding, + training_dtype=self._training_dtype, + ) +- HSTULayerImpl = ( +- FusedHSTULayer +- if config.hstu_layer_type == HSTULayerType.FUSED +- else HSTULayer +- ) ++ HSTULayerImpl = HSTULayer + self._attention_layers = torch.nn.ModuleList( + [HSTULayerImpl(config) for l in range(self.config.num_layers)] + ) + +- @output_nvtx_hook(nvtx_tag="hstu_preprocess") + def hstu_preprocess( + self, embeddings: Dict[str, JaggedTensor], batch: RankingBatch + ) -> JaggedData: +@@ -125,12 +126,21 @@ class HSTUBlock(MegatronModule): + contextual_seqlen + ) + +- sequence_embeddings = triton_concat_2D_jagged( +- max_seq_len=contextual_max_seqlen + sequence_max_seqlen, +- values_a=contextual_embedding, +- values_b=sequence_embeddings, +- offsets_a=contextual_seqlen_offsets, +- offsets_b=sequence_embeddings_lengths_offsets, ++ # sequence_embeddings = triton_concat_2D_jagged( ++ # max_seq_len=contextual_max_seqlen + sequence_max_seqlen, ++ # values_a=contextual_embedding, ++ # values_b=sequence_embeddings, ++ # offsets_a=contextual_seqlen_offsets, ++ # offsets_b=sequence_embeddings_lengths_offsets, ++ # ) ++ ++ sequence_embeddings = pytorch_concat_2D_jagged( ++ values_left=contextual_embedding, ++ values_right=sequence_embeddings, ++ max_len_left=None, ++ max_len_right=None, ++ offsets_left=contextual_seqlen_offsets, ++ offsets_right=sequence_embeddings_lengths_offsets, + ) + + sequence_embeddings_lengths = ( +@@ -182,7 +192,6 @@ class HSTUBlock(MegatronModule): + has_interleaved_action=batch.action_feature_name is not None, + ) + +- @output_nvtx_hook(nvtx_tag="hstu_postprocess") + def hstu_postprocess(self, jd: JaggedData) -> JaggedData: + """ + Postprocess the output from the HSTU architecture. 
+@@ -202,20 +211,24 @@ class HSTUBlock(MegatronModule): + if jd.max_num_candidates > 0: + seqlen_offsets = jd.num_candidates_offsets + max_seqlen = jd.max_num_candidates +- _, sequence_embeddings = triton_split_2D_jagged( +- jd.values, ++ _, sequence_embeddings = pytorch_split_2D_jagged( + jd.max_seqlen, +- offsets_a=jd.seqlen_offsets - jd.num_candidates_offsets, +- offsets_b=seqlen_offsets, ++ jd.values, ++ max_len_left=None, ++ max_len_right=None, ++ offsets_left=jd.seqlen_offsets - jd.num_candidates_offsets, ++ offsets_right=seqlen_offsets, + ) + elif jd.contextual_max_seqlen > 0: + seqlen_offsets = jd.seqlen_offsets - jd.contextual_seqlen_offsets + max_seqlen = jd.max_seqlen - jd.contextual_max_seqlen +- _, sequence_embeddings = triton_split_2D_jagged( +- jd.values, ++ _, sequence_embeddings = pytorch_split_2D_jagged( + jd.max_seqlen, +- offsets_a=jd.contextual_seqlen_offsets, +- offsets_b=seqlen_offsets, ++ jd.values, ++ max_len_left=None, ++ max_len_right=None, ++ offsets_left=jd.contextual_seqlen_offsets, ++ offsets_right=seqlen_offsets, + ) + else: + sequence_embeddings = jd.values +@@ -239,7 +252,6 @@ class HSTUBlock(MegatronModule): + has_interleaved_action=False, + ) + +- @output_nvtx_hook(nvtx_tag="HSTUBlock", hook_tensor_attr_name="values") + def forward( + self, + embeddings: Dict[str, JaggedTensor], +diff --git a/examples/hstu/modules/metrics/metric_modules.py b/examples/hstu/modules/metrics/metric_modules.py +index 398341b..f21f3e0 100644 +--- a/examples/hstu/modules/metrics/metric_modules.py ++++ b/examples/hstu/modules/metrics/metric_modules.py +@@ -23,10 +23,6 @@ from typing import Dict, List, Optional, Tuple + import numpy as np + import torch + import torchmetrics.classification as classification_metrics +-from commons.utils.nvtx_op import output_nvtx_hook +-from dynamicemb.planner import ( +- DynamicEmbeddingShardingPlanner as DynamicEmbeddingShardingPlanner, +-) + from megatron.core import parallel_state + from ops.collective_ops import 
grouped_allgatherv_tensor_list + +@@ -192,7 +188,6 @@ class MultiClassificationTaskMetric(BaseTaskMetric): + self.training = False + + # return a +- @output_nvtx_hook("ranking metrics", backward=False) + def forward(self, multi_task_logits, targets): + """ + Forward one eval batch, this forward returns None object. +@@ -273,7 +268,6 @@ class RetrievalTaskMetricWithSampling(BaseTaskMetric): + self._cache_target_ids: List[torch.Tensor] = [] + self._chunk_size = 512 + +- @output_nvtx_hook("retrieval metrics", backward=False) + def forward( + self, + query_embeddings: torch.Tensor, # preds, dense embedding tensor +diff --git a/examples/hstu/modules/multi_task_loss_module.py b/examples/hstu/modules/multi_task_loss_module.py +index a22de46..0439a48 100644 +--- a/examples/hstu/modules/multi_task_loss_module.py ++++ b/examples/hstu/modules/multi_task_loss_module.py +@@ -14,7 +14,6 @@ + # limitations under the License. + + import torch +-from commons.utils.nvtx_op import output_nvtx_hook + + + class MultiTaskLossModule(torch.nn.Module): +@@ -44,7 +43,6 @@ class MultiTaskLossModule(torch.nn.Module): + ), "num_tasks should be 1 for multi-class classification" + self._loss_modules.append(torch.nn.CrossEntropyLoss(reduction=reduction)) + +- @output_nvtx_hook(nvtx_tag="loss computation") + def forward(self, merged_logits, labels): + """ + Forward pass of the MultiTaskLossModule. +diff --git a/examples/hstu/modules/native_hstu_layer.py b/examples/hstu/modules/native_hstu_layer.py +index 47d115f..eb6e225 100644 +--- a/examples/hstu/modules/native_hstu_layer.py ++++ b/examples/hstu/modules/native_hstu_layer.py +@@ -13,10 +13,9 @@ + # See the License for the specific language governing permissions and + # limitations under the License. 
+ +-import nvtx + import torch ++import torch_npu + import torch.nn.functional as F +-from commons.utils.nvtx_op import output_nvtx_hook + from configs import HSTUConfig + from configs.hstu_config import HSTULayerType + from megatron.core.transformer.module import MegatronModule +@@ -58,7 +57,7 @@ class HSTULayer(MegatronModule): + self._attention_dim_per_head * self._num_heads, + ] + self._residual = config.residual +- device = torch.cuda.current_device() ++ device = torch_npu.npu.current_device() + if config.learnable_input_layernorm: + self._input_layernorm_weight = torch.nn.Parameter( + torch.ones(self._embedding_dim, device=device) +@@ -127,7 +126,6 @@ class HSTULayer(MegatronModule): + del mixed_uvqk + return user, value, query, key + +- @output_nvtx_hook(nvtx_tag="HSTULayer", hook_tensor_attr_name="values") + def forward(self, jd: JaggedData) -> JaggedData: + """ + Forward pass of the HSTULayer +@@ -140,41 +138,41 @@ class HSTULayer(MegatronModule): + """ + # input is [*, h] + x = jd.values +- with nvtx.annotate("hstu ln+linear_bias+silu fwd", color="RED"): +- normed_x = F.layer_norm( +- x, +- normalized_shape=[self._embedding_dim], +- weight=self._input_layernorm_weight, +- bias=self._input_layernorm_bias, +- eps=self._eps, +- ) +- tu, tv, tq, tk = self.get_user_value_query_key_tensors(normed_x) ++ # with nvtx.annotate("hstu ln+linear_bias+silu fwd", color="RED"): ++ normed_x = F.layer_norm( ++ x, ++ normalized_shape=[self._embedding_dim], ++ weight=self._input_layernorm_weight, ++ bias=self._input_layernorm_bias, ++ eps=self._eps, ++ ) ++ tu, tv, tq, tk = self.get_user_value_query_key_tensors(normed_x) + # TODO: remove contiguous once cutlass backend is ready +- with nvtx.annotate("hstu attn fwd", color="BLUE"): +- jagged_attn_output = self._attn_func( +- tq, +- tk, +- tv, +- jd.seqlen_offsets, +- num_contextuals=jd.contextual_seqlen, +- num_candidates=jd.num_candidates, +- max_seqlen=jd.max_seqlen, +- target_group_size=self._target_group_size, +- ) +- 
with nvtx.annotate("hstu norm mul dropout fwd", color="GREEN"): +- parallel_input = pytorch_norm_mul_dropout( +- jagged_attn_output, +- tu, +- self._output_layernorm_weight, +- self._output_layernorm_bias, +- self._eps, +- self._dropout_ratio, +- self.training, +- ) +- with nvtx.annotate("hstu linear_residual fwd", color="YELLOW"): +- output = self._linear_proj(parallel_input) +- if self._residual: +- output = output + x ++ # with nvtx.annotate("hstu attn fwd", color="BLUE"): ++ jagged_attn_output = self._attn_func( ++ tq, ++ tk, ++ tv, ++ jd.seqlen_offsets, ++ num_contextuals=jd.contextual_seqlen, ++ num_candidates=jd.num_candidates, ++ max_seqlen=jd.max_seqlen, ++ target_group_size=self._target_group_size, ++ ) ++ # with nvtx.annotate("hstu norm mul dropout fwd", color="GREEN"): ++ parallel_input = pytorch_norm_mul_dropout( ++ jagged_attn_output, ++ tu, ++ self._output_layernorm_weight, ++ self._output_layernorm_bias, ++ self._eps, ++ self._dropout_ratio, ++ self.training, ++ ) ++ # with nvtx.annotate("hstu linear_residual fwd", color="YELLOW"): ++ output = self._linear_proj(parallel_input) ++ if self._residual: ++ output = output + x + return JaggedData( + values=output, + seqlen=jd.seqlen, +diff --git a/examples/hstu/modules/position_encoder.py b/examples/hstu/modules/position_encoder.py +index 122016f..69bfe06 100644 +--- a/examples/hstu/modules/position_encoder.py ++++ b/examples/hstu/modules/position_encoder.py +@@ -34,9 +34,13 @@ from math import sqrt + from typing import Optional + + import torch +-from ops.triton_ops.triton_position import ( # type: ignore[attr-defined] +- triton_add_position_embeddings, +- triton_add_timestamp_positional_embeddings, ++# from ops.triton_ops.triton_position import ( # type: ignore[attr-defined] ++# triton_add_position_embeddings, ++# triton_add_timestamp_positional_embeddings, ++# ) ++from ops.pt_ops.pt_position import ( # type: ignore[attr-defined] ++ pytorch_add_position_embeddings, ++ 
pytorch_add_timestamp_positional_embeddings, + ) + from torch.fx._symbolic_trace import is_fx_tracing + +@@ -99,7 +103,7 @@ class HSTUPositionalEncoder(torch.nn.Module): + alpha = self._embedding_dim**0.5 + if self._use_time_encoding: + seq_embeddings = seq_embeddings * alpha +- seq_embeddings = triton_add_timestamp_positional_embeddings( ++ seq_embeddings = pytorch_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=self._position_embeddings_weight, +@@ -125,7 +129,7 @@ class HSTUPositionalEncoder(torch.nn.Module): + _, D2 = self._position_embeddings_weight.shape + torch._assert(D2 == D, "wrong dense shape[1]") + +- seq_embeddings = triton_add_position_embeddings( ++ seq_embeddings=pytorch_add_position_embeddings( + jagged=seq_embeddings, + jagged_offsets=seq_offsets, + high_inds=high_inds, +diff --git a/examples/hstu/pretrain_gr_ranking.py b/examples/hstu/pretrain_gr_ranking.py +index 1715109..f1a1a1c 100644 +--- a/examples/hstu/pretrain_gr_ranking.py ++++ b/examples/hstu/pretrain_gr_ranking.py +@@ -12,7 +12,13 @@ + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. 
++import os + import warnings ++from pickletools import optimize ++ ++from Tools.scripts.objgraph import ignore ++ ++from torch_examples_benchmark.din.TorchEasyRec.tzrec.main import lib_fbgemm_npu_api_so_path + + # Ignore all FutureWarnings + warnings.filterwarnings("ignore", category=FutureWarning) +@@ -25,6 +31,8 @@ from typing import List, Tuple, cast + import commons.utils.initialize as init + import gin + import torch # pylint: disable-unused-import ++import torch_npu ++import mindspeed.megatron_adaptor #新增代码行 + from commons.utils.logging import print_rank_0 + from configs import RankingConfig + from distributed.sharding import make_optimizer_and_shard +@@ -34,15 +42,15 @@ from utils import ( + OptimizerArgs, + TensorModelParallelArgs, + TrainerArgs, +- create_dynamic_optitons_dict, + create_embedding_config, + create_hstu_config, + create_optimizer_params, + get_data_loader, + get_dataset_and_embedding_args, +- maybe_load_ckpts, + train, + ) ++from megatron.training.initialize import initialize_megatron ++from megatron.training import get_args + + + @gin.configurable +@@ -64,21 +72,7 @@ class RankingArgs: + "gelu", + ], "prediction_head_act_type should be in ['relu', 'gelu']" + +- +-parser = argparse.ArgumentParser( +- description="Distributed GR Arguments", allow_abbrev=False +-) +-parser.add_argument("--gin-config-file", type=str) +-args = parser.parse_args() +-gin.parse_config_file(args.gin_config_file) +-trainer_args = TrainerArgs() +-dataset_args, embedding_args = get_dataset_and_embedding_args() +-network_args = NetworkArgs() +-optimizer_args = OptimizerArgs() +-tp_args = TensorModelParallelArgs() +- +- +-def create_ranking_config() -> RankingConfig: ++def create_ranking_config(network_args, embedding_args) -> RankingConfig: + ranking_args = RankingArgs() + + return RankingConfig( +@@ -93,43 +87,55 @@ def create_ranking_config() -> RankingConfig: + eval_metrics=ranking_args.eval_metrics, + ) + ++def extra_init_args(parser): ++ group = 
parser.add_argument_group('Distributed GR Arguments') ++ group.add_argument("--gin-config-file", type=str) ++ return parser ++ + + def main(): ++ lib_fbgemm_npu_api_so_path = os.getenv('LIB_FBGEMM_NPU_API_SO_PATH') ++ torch.ops.load_library(lib_fbgemm_npu_api_so_path) ++ ++ initialize_megatron(args_defaults={}, extra_args_provider=extra_init_args, ignore_unknown_args=True) ++ args = get_args() ++ ++ gin.parse_config_file(args.gin_config_file) ++ trainer_args = TrainerArgs() ++ dataset_args, embedding_args = get_dataset_and_embedding_args() ++ network_args = NetworkArgs() ++ optimizer_args = OptimizerArgs() ++ tp_args = TensorModelParallelArgs() ++ + init.initialize_distributed() + init.initialize_model_parallel( + tensor_model_parallel_size=tp_args.tensor_model_parallel_size + ) + init.set_random_seed(trainer_args.seed) +- free_memory, total_memory = torch.cuda.mem_get_info() ++ free_memory, total_memory = torch_npu.npu.mem_get_info() + print_rank_0( + f"distributed env initialization done. 
Free cuda memory: {free_memory / (1024 ** 2):.2f} MB" + ) + hstu_config = create_hstu_config(network_args) +- task_config = create_ranking_config() ++ task_config = create_ranking_config(network_args, embedding_args) + model = get_ranking_model(hstu_config=hstu_config, task_config=task_config) + +- dynamic_options_dict = create_dynamic_optitons_dict( +- embedding_args, network_args.hidden_size +- ) +- + optimizer_param = create_optimizer_params(optimizer_args) ++ + model_train, dense_optimizer = make_optimizer_and_shard( + model, + config=hstu_config, + sparse_optimizer_param=optimizer_param, + dense_optimizer_param=optimizer_param, +- dynamicemb_options_dict=dynamic_options_dict, + ) + train_dataloader, test_dataloader = get_data_loader( + "ranking", dataset_args, trainer_args, task_config.num_tasks + ) +- free_memory, total_memory = torch.cuda.mem_get_info() ++ free_memory, total_memory = torch_npu.npu.mem_get_info() + print_rank_0( + f"model initialization done, start training. Free cuda memory: {free_memory / (1024 ** 2):.2f} MB" + ) + +- maybe_load_ckpts(trainer_args.ckpt_load_dir, model, dense_optimizer) +- + train( + model_train, + trainer_args, +diff --git a/examples/hstu/utils.py b/examples/hstu/utils.py +index 09b1a55..10aa03c 100644 +--- a/examples/hstu/utils.py ++++ b/examples/hstu/utils.py +@@ -17,17 +17,19 @@ import sys + from dataclasses import dataclass + from functools import partial # pylint: disable-unused-import + from itertools import islice +-from typing import Dict, List, Optional, Tuple, Union ++import time ++from typing import Any, Dict, List, Optional, Tuple, Union + +-import commons.checkpoint as checkpoint ++# import commons.checkpoint as checkpoint + import configs + import dataset + import gin + import torch # pylint: disable-unused-import + import torch.distributed as dist +-from commons.checkpoint import get_unwrapped_module ++import torch_npu ++# from commons.checkpoint import get_unwrapped_module + from 
commons.utils.distributed_utils import collective_assert +-from commons.utils.gpu_timer import GPUTimer ++# from commons.utils.gpu_timer import GPUTimer + from commons.utils.logging import print_rank_0 + from commons.utils.stringify import stringify_dict + from configs import ( +@@ -37,13 +39,31 @@ from configs import ( + PositionEncodingConfig, + get_hstu_config, + ) +-from dynamicemb import DynamicEmbTableOptions ++# from dynamicemb import DynamicEmbTableOptions + from megatron.core import parallel_state + from megatron.core.distributed import finalize_model_grads + from model import RankingGR, RetrievalGR + from modules.embedding import ShardedEmbeddingConfig + from torchrec.distributed.model_parallel import DistributedModelParallel +- ++from torchrec_embcache.distributed.train_pipline import EmbCacheTrainPiplelineSparseDist ++import os ++from typing import Optional ++ ++from megatron.core.distributed import DistributedDateParllel ++from megatron.core.transformer.module import Float16Module ++from torch import nn ++ ++def get_unwrapped_module(module: nn.Module) -> nn.Module: ++ while( ++ isinstance(module, DistributedModelParallel) ++ or isinstance(module, Float16Module) ++ or isinstance(module, DistributedDateParllel) ++ ): ++ if isinstance(module, DistributedModelParallel): ++ module = module._dmp_wrapped_module ++ else: ++ module = module.module ++ return module + + @gin.configurable + @dataclass +@@ -176,7 +196,7 @@ class NetworkArgs: + "float16", + ], "Only support bfloat16 and float16 precision for Network." 
+ +- assert self.kernel_backend.lower() in ["cutlass", "triton", "pytorch"] ++ assert self.kernel_backend.lower() in ["cutlass", "triton", "pytorch", "npu_fused"] + assert self.layer_type.lower() in ["fused", "native"] + + +@@ -211,6 +231,8 @@ def create_hstu_config(network_args: NetworkArgs): + kernel_backend = KernelBackend.TRITON + elif network_args.kernel_backend == "pytorch": + kernel_backend = KernelBackend.PYTORCH ++ elif network_args.kernel_backend == "npu_fused": ++ kernel_backend = KernelBackend.NPU_FUSED + else: + raise ValueError( + f"Kernel backend {network_args.kernel_backend} is not supported." +@@ -346,39 +368,38 @@ def create_embedding_config( + ) + + +-def create_dynamic_optitons_dict( +- embedding_args_list: List[Union[EmbeddingArgs, DynamicEmbeddingArgs]], +- hidden_size: int, +-) -> Dict[str, DynamicEmbTableOptions]: +- dynamic_options_dict: Dict[str, DynamicEmbTableOptions] = {} +- for embedding_args in embedding_args_list: +- if isinstance(embedding_args, DynamicEmbeddingArgs): +- from dynamicemb import DynamicEmbCheckMode, DynamicEmbEvictStrategy +- +- embedding_args.calculate_and_reset_global_hbm_for_values(hidden_size) +- dynamic_options_dict[embedding_args.table_name] = DynamicEmbTableOptions( +- global_hbm_for_values=embedding_args.global_hbm_for_values, +- evict_strategy=DynamicEmbEvictStrategy.LRU +- if embedding_args.evict_strategy == "lru" +- else DynamicEmbEvictStrategy.LFU, +- safe_check_mode=DynamicEmbCheckMode.IGNORE, +- bucket_capacity=128, +- ) +- return dynamic_options_dict ++# def create_dynamic_optitons_dict( ++# embedding_args_list: List[Union[EmbeddingArgs, DynamicEmbeddingArgs]], ++# hidden_size: int, ++# ) -> Dict[str, DynamicEmbTableOptions]: ++# dynamic_options_dict: Dict[str, DynamicEmbTableOptions] = {} ++# for embedding_args in embedding_args_list: ++# if isinstance(embedding_args, DynamicEmbeddingArgs): ++# from dynamicemb import DynamicEmbCheckMode, DynamicEmbEvictStrategy ++# ++# 
embedding_args.calculate_and_reset_global_hbm_for_values(hidden_size) ++# dynamic_options_dict[embedding_args.table_name] = DynamicEmbTableOptions( ++# global_hbm_for_values=embedding_args.global_hbm_for_values, ++# evict_strategy=DynamicEmbEvictStrategy.LRU ++# if embedding_args.evict_strategy == "lru" ++# else DynamicEmbEvictStrategy.LFU, ++# safe_check_mode=DynamicEmbCheckMode.IGNORE, ++# bucket_capacity=128, ++# ) ++# return dynamic_options_dict + + + def evaluate( +- model: Union[RankingGR, RetrievalGR], ++ model: RankingGR, + trainer_args: TrainerArgs, + eval_loader: torch.utils.data.DataLoader, + max_eval_iters: Optional[int] = None, + ): + eval_iter = 0 +- torch.cuda.nvtx.range_push(f"#evaluate") + with torch.no_grad(): + # drop last batch + for batch in islice(eval_loader, len(eval_loader)): +- batch = batch.to("cuda") ++ batch = batch.to("npu") + eval_iter += 1 + model.evaluate_one_batch(batch) + if max_eval_iters is not None and eval_iter == max_eval_iters: +@@ -391,44 +412,43 @@ def evaluate( + f"[eval] [eval {eval_iter * dp_size * trainer_args.eval_batch_size} samples]:\n " + + stringify_dict(eval_metric_dict, prefix="Metrics", sep="\n ") + ) +- torch.cuda.nvtx.range_pop() +- +- +-def maybe_load_ckpts( +- ckpt_load_dir: str, +- model: Union[RankingGR, RetrievalGR], +- dense_optimizer: Optional[torch.optim.Optimizer] = None, +-): +- if ckpt_load_dir == "": +- return +- +- assert os.path.exists( +- ckpt_load_dir +- ), f"ckpt_load_dir {ckpt_load_dir} does not exist" +- +- print_rank_0(f"Loading checkpoints from {ckpt_load_dir}") +- checkpoint.load(ckpt_load_dir, model, dense_optimizer=dense_optimizer) +- print_rank_0(f"Checkpoints loaded!!") + + +-def save_ckpts( +- ckpt_save_dir: str, +- model: Union[RankingGR, RetrievalGR], +- dense_optimizer: Optional[torch.optim.Optimizer] = None, +-): +- print_rank_0(f"Saving checkpoints to {ckpt_save_dir}") +- import shutil +- +- if dist.get_rank() == 0: +- if os.path.exists(ckpt_save_dir): +- 
shutil.rmtree(ckpt_save_dir) +- try: +- os.makedirs(ckpt_save_dir, exist_ok=True) +- except Exception as e: +- raise Exception("can't build path:", ckpt_save_dir) from e +- dist.barrier(device_ids=[torch.cuda.current_device()]) +- checkpoint.save(ckpt_save_dir, model, dense_optimizer=dense_optimizer) +- print_rank_0(f"Checkpoints saved!!") ++# def maybe_load_ckpts( ++# ckpt_load_dir: str, ++# model: Union[RankingGR, RetrievalGR], ++# dense_optimizer: Optional[torch.optim.Optimizer] = None, ++# ): ++# if ckpt_load_dir == "": ++# return ++# ++# assert os.path.exists( ++# ckpt_load_dir ++# ), f"ckpt_load_dir {ckpt_load_dir} does not exist" ++# ++# print_rank_0(f"Loading checkpoints from {ckpt_load_dir}") ++# checkpoint.load(ckpt_load_dir, model, dense_optimizer=dense_optimizer) ++# print_rank_0(f"Checkpoints loaded!!") ++# ++# ++# def save_ckpts( ++# ckpt_save_dir: str, ++# model: Union[RankingGR, RetrievalGR], ++# dense_optimizer: Optional[torch.optim.Optimizer] = None, ++# ): ++# print_rank_0(f"Saving checkpoints to {ckpt_save_dir}") ++# import shutil ++# ++# if dist.get_rank() == 0: ++# if os.path.exists(ckpt_save_dir): ++# shutil.rmtree(ckpt_save_dir) ++# try: ++# os.makedirs(ckpt_save_dir, exist_ok=True) ++# except Exception as e: ++# raise Exception("can't build path:", ckpt_save_dir) from e ++# dist.barrier(device_ids=[torch.cuda.current_device()]) ++# checkpoint.save(ckpt_save_dir, model, dense_optimizer=dense_optimizer) ++# print_rank_0(f"Checkpoints saved!!") + + + def train( +@@ -438,26 +458,13 @@ def train( + eval_loader: torch.utils.data.DataLoader, + dense_optimizer: torch.optim.Optimizer, + ): +- gpu_timer = GPUTimer() + max_train_iters = trainer_args.max_train_iters or len(train_loader) + dp_size = parallel_state.get_data_parallel_world_size() * 1.0 +- gpu_timer.start() +- last_td = 0 ++ last_td = time.time() + # using a tensor on gpu to avoid d2h copy +- tokens_logged = torch.zeros(1).cuda().float() ++ tokens_logged = torch.zeros(1).npu().float() + 
train_loader_iter = iter(train_loader) + for train_iter in range(max_train_iters): +- if trainer_args.profile and train_iter == trainer_args.profile_step_start: +- torch.cuda.profiler.start() +- if ( +- train_iter * trainer_args.ckpt_save_interval > 0 +- and train_iter % trainer_args.ckpt_save_interval == 0 +- ): +- save_path = os.path.join(trainer_args.ckpt_save_dir, f"iter{train_iter}") +- save_ckpts(save_path, model, dense_optimizer) +- +- torch.cuda.nvtx.range_push(f"step {train_iter}") +- + try: + batch = next(train_loader_iter) + except StopIteration: +@@ -465,7 +472,7 @@ def train( + train_loader_iter = iter(train_loader) + batch = next(train_loader_iter) + +- batch = batch.to("cuda") ++ batch = batch.to("npu") + model.module.zero_grad_buffer() + dense_optimizer.zero_grad() + +@@ -473,7 +480,7 @@ def train( + losses, (_, logits, labels) = model(batch) + collective_assert(not torch.isnan(losses).any(), "loss has nan value") + jagged_size = logits.size(0) +- local_tokens = torch.tensor(jagged_size).cuda().float() ++ local_tokens = torch.tensor(jagged_size).npu().float() + + losses = torch.sum(losses, dim=0) + local_loss = torch.cat([torch.sum(losses).view(1), local_tokens.view(1)]) +@@ -484,10 +491,9 @@ def train( + ) + tokens_logged += reporting_loss[1] + if train_iter >= 0 and train_iter % trainer_args.log_interval == 0: +- gpu_timer.stop() +- cur_td = gpu_timer.elapsed_time() - last_td ++ cur_td = time.time() - last_td + print_rank_0( +- f"[train] [iter {train_iter}, tokens {int(tokens_logged.item())}, elapsed_time {cur_td:.2f} ms]: loss {reporting_loss[0] / reporting_loss[1]:.6f}" ++ f"[train] [iter {train_iter}, tokens {int(tokens_logged.item())}, elapsed_time {cur_td:.2f} s]: loss {reporting_loss[0] / reporting_loss[1]:.6f}" + ) + last_td = cur_td + last_td + tokens_logged.zero_() +@@ -498,9 +504,7 @@ def train( + + # dense gradient allreduce + finalize_model_grads([model.module], None) +- torch.cuda.nvtx.range_push(f"#dense opt") + 
dense_optimizer.step() +- torch.cuda.nvtx.range_pop() + if train_iter > 0 and train_iter % trainer_args.eval_interval == 0: + model.eval() + evaluate( +@@ -510,10 +514,6 @@ def train( + max_eval_iters=None, + ) + model.train() +- torch.cuda.nvtx.range_pop() +- +- if trainer_args.profile and train_iter == trainer_args.profile_step_end: +- torch.cuda.profiler.stop() + + + def get_dataset_and_embedding_args() -> ( +@@ -531,41 +531,41 @@ def get_dataset_and_embedding_args() -> ( + HASH_SIZE = 10_000_000 + if dataset_args.dataset_name == "kuairand-pure": + return dataset_args, [ +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["user_active_degree"], + table_name="user_active_degree", + item_vocab_size_or_capacity=10, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["follow_user_num_range"], + table_name="follow_user_num_range", + item_vocab_size_or_capacity=9, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["fans_user_num_range"], + table_name="fans_user_num_range", + item_vocab_size_or_capacity=10, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["friend_user_num_range"], + table_name="friend_user_num_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["register_days_range"], + table_name="register_days_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["action_weights"], + table_name="action_weights", + item_vocab_size_or_capacity=226, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), + DynamicEmbeddingArgs( + 
feature_names=["video_id"], +@@ -582,41 +582,41 @@ def get_dataset_and_embedding_args() -> ( + ] + elif dataset_args.dataset_name == "kuairand-1k": + return dataset_args, [ +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["user_active_degree"], + table_name="user_active_degree", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["follow_user_num_range"], + table_name="follow_user_num_range", + item_vocab_size_or_capacity=9, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["fans_user_num_range"], + table_name="fans_user_num_range", + item_vocab_size_or_capacity=9, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["friend_user_num_range"], + table_name="friend_user_num_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["register_days_range"], + table_name="register_days_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["action_weights"], + table_name="action_weights", + item_vocab_size_or_capacity=233, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), + DynamicEmbeddingArgs( + feature_names=["video_id"], +@@ -633,41 +633,41 @@ def get_dataset_and_embedding_args() -> ( + ] + elif dataset_args.dataset_name == "kuairand-27k": + return dataset_args, [ +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["user_active_degree"], + table_name="user_active_degree", + item_vocab_size_or_capacity=10, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( 
++ DynamicEmbeddingArgs( + feature_names=["follow_user_num_range"], + table_name="follow_user_num_range", + item_vocab_size_or_capacity=9, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["fans_user_num_range"], + table_name="fans_user_num_range", + item_vocab_size_or_capacity=10, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["friend_user_num_range"], + table_name="friend_user_num_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["register_days_range"], + table_name="register_days_range", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["action_weights"], + table_name="action_weights", + item_vocab_size_or_capacity=246, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), + DynamicEmbeddingArgs( + feature_names=["video_id"], +@@ -684,35 +684,35 @@ def get_dataset_and_embedding_args() -> ( + ] + elif dataset_args.dataset_name == "ml-1m": + return dataset_args, [ +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["sex"], + table_name="sex", + item_vocab_size_or_capacity=3, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["age_group"], + table_name="age_group", + item_vocab_size_or_capacity=8, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["occupation"], + table_name="occupation", + item_vocab_size_or_capacity=22, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + 
feature_names=["zip_code"], + table_name="zip_code", + item_vocab_size_or_capacity=3440, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["rating"], + table_name="action_weights", + item_vocab_size_or_capacity=11, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), + DynamicEmbeddingArgs( + feature_names=["movie_id"], +@@ -729,11 +729,11 @@ def get_dataset_and_embedding_args() -> ( + ] + elif dataset_args.dataset_name == "ml-20m": + return dataset_args, [ +- EmbeddingArgs( ++ DynamicEmbeddingArgs( + feature_names=["rating"], + table_name="action_weights", + item_vocab_size_or_capacity=11, +- sharding_type="data_parallel", ++ item_vocab_gpu_capacity_ratio=1.0, + ), + DynamicEmbeddingArgs( + feature_names=["movie_id"], diff --git a/torch_examples_benchmark/gr_nv/run.sh b/torch_examples_benchmark/gr_nv/run.sh new file mode 100644 index 00000000..20a17a33 --- /dev/null +++ b/torch_examples_benchmark/gr_nv/run.sh @@ -0,0 +1,96 @@ +#!/bin/bash +set -e------------------------------------------- +# lib relate +#--------------------------------------------- +source /usr/local/Ascend/ascend-toolkit/set_env.sh +export LD_PRELOAD=/usr/lib64/libgomp.so.1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export CUDA_DEVICE_MAX_CONNECTIONS=1 + + +RECSYS_DIR=$(realpath ../) +HSTU_DIR=$RECSYS_DIR/hstu +MEGATRON_DIR=$RECSYS_DIR/../../../Megatron-LM/ +MINDSPEED_DIR=$RECSYS_DIR/../../../MindSpeed/ +export PYTHONPATH=${PYTHONPATH}:${HSTU_DIR}:${MEGATRON_DIR}:{MINDSPEED_DIR} + +#在当前路径下建立tmp_data: mkdir tmp_data +#将/home/common_user/GR_data下预处理过的数据集软链接到tmp_data: ln -s /home/common_user/GR_data* tmp_data +export LIB_FBGEMM_NPU_API_SO_PATH="/path/to/libfbgemm_npu_api.so" #根据实际情况修改 +#--------------------------------------------- +# speedup +#--------------------------------------------- +export TASK_QUEUE_ENABLE=2 + +# cpu-binding +NPU_NUM=$(npu-smi info|grep 910B|wc -1) 
+CPU_CORES=$(nproc --all) +CORES_PER_NPU=$((CPU_CORES / NPU_NUM)) +CPU_AFFINITY_CONF_TMP=1 +if [ "$NPU_NUM" -gt 0]; then + for (( i=0; i&1 |tee temp_$(date '+%Y%m%d_%H%M%S').log \ No newline at end of file -- Gitee From 9aa85c7b7b54b5a55571b8ba2698b14e7227a3d7 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 23 Jun 2025 21:50:02 +0800 Subject: [PATCH 2/4] =?UTF-8?q?NV=E7=9A=84GR=E6=A8=A1=E5=9E=8B=E9=80=82?= =?UTF-8?q?=E9=85=8DNPU=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gr_nv/gr_nv2npu.patch | 70 ++++++++++++++----- torch_examples_benchmark/gr_nv/run.sh | 32 +++++---- 2 files changed, 74 insertions(+), 28 deletions(-) diff --git a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch index e8df1cd8..0a3e4a94 100644 --- a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch +++ b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch @@ -110,6 +110,20 @@ index fee4520..dc16147 100644 @dataclass +diff --git a/examples/hstu/dataset/sequence_dataset.py b/examples/hstu/dataset/sequence_dataset.py +index bc07416..2ca7751 100644 +--- a/examples/hstu/dataset/sequence_dataset.py ++++ b/examples/hstu/dataset/sequence_dataset.py +@@ -142,7 +142,8 @@ class SequenceDataset(IterableDataset[Batch]): + + # We do batching in our own + def __len__(self) -> int: +- return math.ceil(self._num_samples / self._global_batch_size) ++ return self._num_samples // self._global_batch_size ++ #return math.ceil(self._num_samples / self._global_batch_size) + + def __iter__(self) -> Iterator[Batch]: + for i in range(len(self)): diff --git a/examples/hstu/dataset/utils.py b/examples/hstu/dataset/utils.py index ed1b9bc..78f93f0 100644 --- a/examples/hstu/dataset/utils.py @@ -143,7 +157,7 @@ index ed1b9bc..78f93f0 100644 """ Pin the memory of the batch. 
diff --git a/examples/hstu/distributed/sharding.py b/examples/hstu/distributed/sharding.py -index 4389a45..ce39753 100644 +index 4389a45..03828d7 100644 --- a/examples/hstu/distributed/sharding.py +++ b/examples/hstu/distributed/sharding.py @@ -14,22 +14,33 @@ @@ -168,7 +182,7 @@ index 4389a45..ce39753 100644 -from dynamicemb.shard import ( - DynamicEmbeddingBagCollectionSharder, - DynamicEmbeddingCollectionSharder, -+from torchrec.distributed.planner.types import ParamterConstrains ++from torchrec.distributed.planner.types import ParameterConstrains +from torchrec.optim.keyed import CombineOptimizer, KeyeOptimizerWrapper +from torchrec_embcache.distributed.embedding import EmbCacheEmbeddingCollection +from torchrec_embcache.distributed.embedding_bag import EmbCacheEmbeddingBagCollection @@ -254,7 +268,7 @@ index 4389a45..ce39753 100644 + constraints: Dict[str, List[str]] = {} + for cfg in eb_configs: + if cfg.name in data_parallel_embedding_table_names: -+ constraints[cfg.name] = ParamterConstrains(sharding_type=[ShardingType.DATA_PARALLEL.value]) ++ constraints[cfg.name] = ParameterConstrains(sharding_type=[ShardingType.DATA_PARALLEL.value]) else: - constraint = DynamicEmbParameterConstraints( - sharding_types=[ @@ -270,7 +284,7 @@ index 4389a45..ce39753 100644 - ddr_cap = 512 * 1024 * 1024 * 1024 # Assume a Node have 512GB memory - intra_host_bw = 450e9 # Nvlink bandwidth - inter_host_bw = 25e9 # NIC bandwidth -+ constraints[cfg.name] = ParamterConstrains(sharding_type=[ShardingType.ROW_WISE.value]) ++ constraints[cfg.name] = ParameterConstrains(sharding_type=[ShardingType.ROW_WISE.value]) topology = Topology( local_world_size=get_local_size(), @@ -415,6 +429,20 @@ index 4389a45..ce39753 100644 ) model, dense_optimizer = apply_megatron_ddp( model, config, dense_optimizer_param, device +diff --git a/examples/hstu/kuairand_27k_ranking.gin b/examples/hstu/kuairand_27k_ranking.gin +index d6fe667..a1cc8b5 100644 +--- a/examples/hstu/kuairand_27k_ranking.gin ++++ 
b/examples/hstu/kuairand_27k_ranking.gin +@@ -12,6 +12,9 @@ NetworkArgs.num_layers = 1 + NetworkArgs.num_attention_heads = 4 + NetworkArgs.hidden_size = 128 + NetworkArgs.kv_channels = 128 ++NetworkArgs.kernel_backend = 'npu_fused' ++NetworkArgs.layer_type = 'native' ++ + + RankingArgs.prediction_head_arch = [ + 512, 8 diff --git a/examples/hstu/model/__init__.py b/examples/hstu/model/__init__.py index e4fba4b..1ca3e9f 100644 --- a/examples/hstu/model/__init__.py @@ -979,25 +1007,35 @@ index 122016f..69bfe06 100644 jagged=seq_embeddings, jagged_offsets=seq_offsets, high_inds=high_inds, +diff --git a/examples/hstu/movielen_ranking.gin b/examples/hstu/movielen_ranking.gin +index 314a4e5..a10319b 100644 +--- a/examples/hstu/movielen_ranking.gin ++++ b/examples/hstu/movielen_ranking.gin +@@ -18,8 +18,10 @@ NetworkArgs.num_attention_heads = 4 + NetworkArgs.hidden_size = 128 + NetworkArgs.kv_channels = 128 + NetworkArgs.target_group_size = 1 ++NetworkArgs.kernel_backend = 'npu_fused' ++NetworkArgs.layer_type = 'native' + +-# ratings 0-5 ++# ratings 0-551 + RankingArgs.prediction_head_arch = [512, 10] + RankingArgs.prediction_head_bias = True + RankingArgs.num_tasks = 1 diff --git a/examples/hstu/pretrain_gr_ranking.py b/examples/hstu/pretrain_gr_ranking.py -index 1715109..f1a1a1c 100644 +index 1715109..b95eed7 100644 --- a/examples/hstu/pretrain_gr_ranking.py +++ b/examples/hstu/pretrain_gr_ranking.py -@@ -12,7 +12,13 @@ +@@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os import warnings -+from pickletools import optimize -+ -+from Tools.scripts.objgraph import ignore -+ -+from torch_examples_benchmark.din.TorchEasyRec.tzrec.main import lib_fbgemm_npu_api_so_path # Ignore all FutureWarnings - warnings.filterwarnings("ignore", category=FutureWarning) -@@ -25,6 +31,8 @@ from typing import List, Tuple, cast +@@ -25,6 +26,8 @@ from typing import List, Tuple, cast import commons.utils.initialize as init import gin import torch # pylint: disable-unused-import @@ -1006,7 +1044,7 @@ index 1715109..f1a1a1c 100644 from commons.utils.logging import print_rank_0 from configs import RankingConfig from distributed.sharding import make_optimizer_and_shard -@@ -34,15 +42,15 @@ from utils import ( +@@ -34,15 +37,15 @@ from utils import ( OptimizerArgs, TensorModelParallelArgs, TrainerArgs, @@ -1024,7 +1062,7 @@ index 1715109..f1a1a1c 100644 @gin.configurable -@@ -64,21 +72,7 @@ class RankingArgs: +@@ -64,21 +67,7 @@ class RankingArgs: "gelu", ], "prediction_head_act_type should be in ['relu', 'gelu']" @@ -1047,7 +1085,7 @@ index 1715109..f1a1a1c 100644 ranking_args = RankingArgs() return RankingConfig( -@@ -93,43 +87,55 @@ def create_ranking_config() -> RankingConfig: +@@ -93,43 +82,55 @@ def create_ranking_config() -> RankingConfig: eval_metrics=ranking_args.eval_metrics, ) diff --git a/torch_examples_benchmark/gr_nv/run.sh b/torch_examples_benchmark/gr_nv/run.sh index 20a17a33..05f8105d 100644 --- a/torch_examples_benchmark/gr_nv/run.sh +++ b/torch_examples_benchmark/gr_nv/run.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -e------------------------------------------- +# -------------------------------------------- # lib relate #--------------------------------------------- source /usr/local/Ascend/ascend-toolkit/set_env.sh @@ -10,24 +10,24 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 RECSYS_DIR=$(realpath ../) HSTU_DIR=$RECSYS_DIR/hstu -MEGATRON_DIR=$RECSYS_DIR/../../../Megatron-LM/ -MINDSPEED_DIR=$RECSYS_DIR/../../../MindSpeed/ -export 
PYTHONPATH=${PYTHONPATH}:${HSTU_DIR}:${MEGATRON_DIR}:{MINDSPEED_DIR} +# 根据实际情况设置python引用路径 +MEGATRON_DIR=$RECSYS_DIR/megatron-lm/ +MINDSPEED_DIR=$RECSYS_DIR/MindSpeed/ +export PYTHONPATH=${PYTHONPATH}:${HSTU_DIR}:${MEGATRON_DIR}:${MINDSPEED_DIR} -#在当前路径下建立tmp_data: mkdir tmp_data -#将/home/common_user/GR_data下预处理过的数据集软链接到tmp_data: ln -s /home/common_user/GR_data* tmp_data -export LIB_FBGEMM_NPU_API_SO_PATH="/path/to/libfbgemm_npu_api.so" #根据实际情况修改 +#根据实际情况设置算子适配so文件 +export LIB_FBGEMM_NPU_API_SO_PATH="/path/to/libfbgemm_npu_api.so" #--------------------------------------------- # speedup #--------------------------------------------- export TASK_QUEUE_ENABLE=2 # cpu-binding -NPU_NUM=$(npu-smi info|grep 910B|wc -1) +NPU_NUM=$(npu-smi info|grep 910B|wc -l) CPU_CORES=$(nproc --all) CORES_PER_NPU=$((CPU_CORES / NPU_NUM)) CPU_AFFINITY_CONF_TMP=1 -if [ "$NPU_NUM" -gt 0]; then +if [ "$NPU_NUM" -gt 0 ]; then for (( i=0; i Date: Tue, 24 Jun 2025 11:39:37 +0800 Subject: [PATCH 3/4] =?UTF-8?q?NV=E7=9A=84GR=E6=A8=A1=E5=9E=8B=E9=80=82?= =?UTF-8?q?=E9=85=8DNPU=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gr_nv/gr_nv2npu.patch | 1124 ++++++++++++++++- torch_examples_benchmark/gr_nv/run.sh | 4 +- 2 files changed, 1102 insertions(+), 26 deletions(-) diff --git a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch index 0a3e4a94..e2f70f83 100644 --- a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch +++ b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch @@ -157,10 +157,10 @@ index ed1b9bc..78f93f0 100644 """ Pin the memory of the batch. diff --git a/examples/hstu/distributed/sharding.py b/examples/hstu/distributed/sharding.py -index 4389a45..03828d7 100644 +index 4389a45..b8ea3cf 100644 --- a/examples/hstu/distributed/sharding.py +++ b/examples/hstu/distributed/sharding.py -@@ -14,22 +14,33 @@ +@@ -14,22 +14,32 @@ # limitations under the License. 
# pyre-strict @@ -182,8 +182,7 @@ index 4389a45..03828d7 100644 -from dynamicemb.shard import ( - DynamicEmbeddingBagCollectionSharder, - DynamicEmbeddingCollectionSharder, -+from torchrec.distributed.planner.types import ParameterConstrains -+from torchrec.optim.keyed import CombineOptimizer, KeyeOptimizerWrapper ++from torchrec.distributed.planner.types import ParameterConstraints +from torchrec_embcache.distributed.embedding import EmbCacheEmbeddingCollection +from torchrec_embcache.distributed.embedding_bag import EmbCacheEmbeddingBagCollection +# from dynamicemb import DynamicEmbTableOptions @@ -204,7 +203,7 @@ index 4389a45..03828d7 100644 from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType from megatron.core import parallel_state, tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP -@@ -59,7 +70,7 @@ from torchrec.distributed.fbgemm_qcomm_codec import ( +@@ -59,7 +69,7 @@ from torchrec.distributed.fbgemm_qcomm_codec import ( get_qcomm_codecs_registry, ) from torchrec.distributed.model_parallel import DistributedModelParallel @@ -213,7 +212,7 @@ index 4389a45..03828d7 100644 from torchrec.distributed.planner.storage_reservations import ( HeuristicalStorageReservation, ) -@@ -69,16 +80,23 @@ from torchrec.distributed.types import ( +@@ -69,16 +79,22 @@ from torchrec.distributed.types import ( ShardingEnv, ShardingType, ) @@ -226,7 +225,6 @@ index 4389a45..03828d7 100644 +from hybrid_torchrec import HashEmbeddingCollection, EmbeddingConfig +from hybrid_torchrec.modules.hash_embeddingbag import HashEmbeddingBagCollection + -+from torch_examples_benchmark.dlrm.dlrm_main import with_embcache TORCHREC_TYPES: Set[Type[Union[EmbeddingBagCollection, EmbeddingCollection]]] = { EmbeddingBagCollection, @@ -238,7 +236,7 @@ index 4389a45..03828d7 100644 } DATA_PARALLEL_EMBEDDING_MODULE_NAME = "_data_parallel_embedding_collection" -@@ -147,66 +165,27 @@ def apply_megatron_ddp( +@@ -147,66 +163,27 @@ def 
apply_megatron_ddp( def get_planner( eb_configs: List[EmbeddingConfig], data_parallel_embedding_table_names: Set[str], @@ -268,7 +266,7 @@ index 4389a45..03828d7 100644 + constraints: Dict[str, List[str]] = {} + for cfg in eb_configs: + if cfg.name in data_parallel_embedding_table_names: -+ constraints[cfg.name] = ParameterConstrains(sharding_type=[ShardingType.DATA_PARALLEL.value]) ++ constraints[cfg.name] = ParameterConstraints(sharding_type=[ShardingType.DATA_PARALLEL.value]) else: - constraint = DynamicEmbParameterConstraints( - sharding_types=[ @@ -284,7 +282,7 @@ index 4389a45..03828d7 100644 - ddr_cap = 512 * 1024 * 1024 * 1024 # Assume a Node have 512GB memory - intra_host_bw = 450e9 # Nvlink bandwidth - inter_host_bw = 25e9 # NIC bandwidth -+ constraints[cfg.name] = ParameterConstrains(sharding_type=[ShardingType.ROW_WISE.value]) ++ constraints[cfg.name] = ParameterConstraints(sharding_type=[ShardingType.ROW_WISE.value]) topology = Topology( local_world_size=get_local_size(), @@ -313,7 +311,7 @@ index 4389a45..03828d7 100644 _optimizer_str_to_optim_type = { "adam": EmbOptimType.ADAM, "sgd": EmbOptimType.EXACT_SGD, -@@ -252,7 +231,6 @@ def sparse_optimizer_factory_and_class( +@@ -252,7 +229,6 @@ def sparse_optimizer_factory_and_class( def apply_dmp( model: torch.nn.Module, @@ -321,7 +319,7 @@ index 4389a45..03828d7 100644 sparse_optimizer_param: OptimizerParam, pg: torch.distributed.ProcessGroup, device: torch.device, -@@ -279,7 +257,6 @@ def apply_dmp( +@@ -279,7 +255,6 @@ def apply_dmp( "beta1": sparse_optimizer_param.adam_beta1, "beta2": sparse_optimizer_param.adam_beta2, "eps": sparse_optimizer_param.adam_eps, @@ -329,7 +327,7 @@ index 4389a45..03828d7 100644 "output_dtype": SparseType.FP32, } eb_configs = [] -@@ -301,7 +278,6 @@ def apply_dmp( +@@ -301,7 +276,6 @@ def apply_dmp( planner = get_planner( eb_configs, set(data_parallel_embedding_table_names), @@ -337,7 +335,7 @@ index 4389a45..03828d7 100644 device, ) qcomm_codecs_registry = 
get_qcomm_codecs_registry( -@@ -310,17 +286,36 @@ def apply_dmp( +@@ -310,17 +284,36 @@ def apply_dmp( backward_precision=CommType.FP32, ) ) @@ -385,7 +383,7 @@ index 4389a45..03828d7 100644 plan = planner.collective_plan(model, sharders, pg) data_parallel_sharding_plans = [] for data_parallel_embedding_module_name in data_parallel_embedding_module_names: -@@ -328,15 +323,14 @@ def apply_dmp( +@@ -328,15 +321,14 @@ def apply_dmp( plan.plan.pop(data_parallel_embedding_module_name, None) ) # Shard model @@ -409,7 +407,7 @@ index 4389a45..03828d7 100644 # Create keyed optimizer non_fused_sparse_params = {} -@@ -379,17 +373,16 @@ def make_optimizer_and_shard( +@@ -379,17 +371,16 @@ def make_optimizer_and_shard( config: TransformerConfig, sparse_optimizer_param: OptimizerParam, dense_optimizer_param: OptimizerParam, @@ -679,16 +677,15 @@ index 8aa6894..3be7e77 100644 sm_minor_version = torch.cuda.get_device_properties(0).minor if sm_major_version == 9 and sm_minor_version == 0: diff --git a/examples/hstu/modules/hstu_block.py b/examples/hstu/modules/hstu_block.py -index 2a2b580..07c72d3 100644 +index 2a2b580..9a502ee 100644 --- a/examples/hstu/modules/hstu_block.py +++ b/examples/hstu/modules/hstu_block.py -@@ -3,19 +3,25 @@ +@@ -3,19 +3,24 @@ from typing import Dict, Optional, Union import torch -from commons.utils.nvtx_op import output_nvtx_hook +import torch_npu -+from Tools.scripts.make_ctype import values +# from commons.utils.nvtx_op import output_nvtx_hook from configs.hstu_config import HSTUConfig, HSTULayerType from dataset.utils import RankingBatch, RetrievalBatch @@ -713,7 +710,7 @@ index 2a2b580..07c72d3 100644 ) from torchrec.sparse.jagged_tensor import JaggedTensor -@@ -49,16 +55,11 @@ class HSTUBlock(MegatronModule): +@@ -49,16 +54,11 @@ class HSTUBlock(MegatronModule): use_time_encoding=config.position_encoding_config.use_time_encoding, training_dtype=self._training_dtype, ) @@ -731,7 +728,7 @@ index 2a2b580..07c72d3 100644 def hstu_preprocess( self, 
embeddings: Dict[str, JaggedTensor], batch: RankingBatch ) -> JaggedData: -@@ -125,12 +126,21 @@ class HSTUBlock(MegatronModule): +@@ -125,12 +125,21 @@ class HSTUBlock(MegatronModule): contextual_seqlen ) @@ -759,7 +756,7 @@ index 2a2b580..07c72d3 100644 ) sequence_embeddings_lengths = ( -@@ -182,7 +192,6 @@ class HSTUBlock(MegatronModule): +@@ -182,7 +191,6 @@ class HSTUBlock(MegatronModule): has_interleaved_action=batch.action_feature_name is not None, ) @@ -767,7 +764,7 @@ index 2a2b580..07c72d3 100644 def hstu_postprocess(self, jd: JaggedData) -> JaggedData: """ Postprocess the output from the HSTU architecture. -@@ -202,20 +211,24 @@ class HSTUBlock(MegatronModule): +@@ -202,20 +210,24 @@ class HSTUBlock(MegatronModule): if jd.max_num_candidates > 0: seqlen_offsets = jd.num_candidates_offsets max_seqlen = jd.max_num_candidates @@ -800,7 +797,7 @@ index 2a2b580..07c72d3 100644 ) else: sequence_embeddings = jd.values -@@ -239,7 +252,6 @@ class HSTUBlock(MegatronModule): +@@ -239,7 +251,6 @@ class HSTUBlock(MegatronModule): has_interleaved_action=False, ) @@ -1023,6 +1020,1085 @@ index 314a4e5..a10319b 100644 RankingArgs.prediction_head_arch = [512, 10] RankingArgs.prediction_head_bias = True RankingArgs.num_tasks = 1 +diff --git a/examples/hstu/new_utils.py b/examples/hstu/new_utils.py +new file mode 100644 +index 0000000..951617c +--- /dev/null ++++ b/examples/hstu/new_utils.py +@@ -0,0 +1,591 @@ ++ ++ ++ ++from dataclasses import dataclass ++ ++import configs ++import dataset ++import gin ++import torch # pylint: disable-unused-import ++import torch.distributed as dist ++import os ++import sys ++from functools import partial ++from itertools import islice ++from typing import Dict, List, Optional, Tuple, Union ++from configs import ( ++ HSTULayerType, ++ KernelBackend, ++ OptimizerParam, ++ PositionEncodingConfig, ++ get_hstu_config, ++) ++ ++ ++@dataclass ++class BaseEmbeddingArgs: ++ # for dynamic emb, it serves as capacity, while for static emb, it serves 
as vocab size ++ feature_names: List[str] ++ table_name: str ++ item_vocab_size_or_capacity: int ++ ++ ++@gin.configurable ++@dataclass ++class EmbeddingArgs(BaseEmbeddingArgs): ++ sharding_type: str = "None" ++ ++ def __post_init__(self): ++ assert self.sharding_type.lower() in [ ++ "data_parallel", ++ "model_parallel", ++ ] ++ ++ ++@gin.configurable ++@dataclass ++class DynamicEmbeddingArgs(EmbeddingArgs): ++ # the precedence is global_hbm_for_values > item_vocab_gpu_capacity > item_vocab_gpu_capacity_ratio ++ # without optimizer consideration ++ global_hbm_for_values: Optional[int] = None ++ item_vocab_gpu_capacity: Optional[float] = None ++ item_vocab_gpu_capacity_ratio: Optional[float] = None ++ ++ evict_strategy: str = "lru" ++ ++ def __post_init__(self): ++ assert self.evict_strategy.lower() in ["lru", "lfu"] ++ ++ def calculate_and_reset_global_hbm_for_values(self, hidden_size): ++ if self.global_hbm_for_values is not None: ++ return ++ assert ( ++ self.item_vocab_gpu_capacity_ratio is not None ++ or self.item_vocab_gpu_capacity is not None ++ ), "Please provide either item_vocab_gpu_capacity_ratio or item_vocab_gpu_capacity" ++ if self.item_vocab_gpu_capacity is None: ++ self.item_vocab_gpu_capacity = int( ++ self.item_vocab_size_or_capacity * self.item_vocab_gpu_capacity_ratio ++ ) ++ self.global_hbm_for_values = self.item_vocab_gpu_capacity * hidden_size * 4 ++ ++ ++@gin.configurable ++@dataclass ++class TrainerArgs: ++ # below batchsize is batchsize_per_gpu ++ train_batch_size: int ++ eval_batch_size: int ++ ++ eval_interval: int = 100 ++ log_interval: int = 100 ++ ++ seed: int = 1234 ++ # ==nsys args== ++ profile: bool = False ++ profile_step_start: int = 100 ++ profile_step_end: int = 200 ++ # ==nsys args== ++ max_train_iters: Optional[int] = None ++ max_eval_iters: Optional[int] = None ++ ++ # ckpt args ++ ckpt_save_interval: int = -1 # -1 means not save ckpt ++ ckpt_save_dir: str = "./checkpoints" ++ ckpt_load_dir: str = "" ++ ++ def 
__post_init__(self): ++ if isinstance(self.max_train_iters, str): ++ self.max_train_iters = int(self.max_train_iters) ++ ++ ++@gin.configurable ++@dataclass ++class DatasetArgs: ++ dataset_name: str ++ max_sequence_length: int ++ max_num_candidates: int = 0 ++ shuffle: bool = False ++ ++ ++@gin.configurable ++@dataclass ++class FeatureArgs: ++ feature_names: List[str] ++ max_sequence_length: int ++ is_jagged: bool = False ++ ++ ++@gin.configurable ++@dataclass ++class BenchmarkDatasetArgs: ++ feature_args: List[FeatureArgs] ++ embedding_args: List[Union[EmbeddingArgs, DynamicEmbeddingArgs]] ++ item_feature_name: str ++ contextual_feature_names: List[str] ++ action_feature_name: Optional[str] = None ++ max_num_candidates: int = 0 ++ ++ ++@gin.configurable ++@dataclass ++class NetworkArgs: ++ num_layers: int ++ hidden_size: int ++ num_attention_heads: int ++ kv_channels: int ++ ++ hidden_dropout: float = 0.2 ++ norm_epsilon: float = 1e-5 ++ is_causal: bool = True ++ ++ dtype_str: str = "bfloat16" ++ ++ kernel_backend: str = "cutlass" ++ layer_type: str = "fused" ++ target_group_size: int = 1 ++ ++ num_position_buckets: int = 8192 ++ ++ def __post_init__(self): ++ assert self.dtype_str in [ ++ "bfloat16", ++ "float16", ++ ], "Only support bfloat16 and float16 precision for Network." ++ ++ assert self.kernel_backend.lower() in ["cutlass", "triton", "pytorch", "npu_fused"] ++ assert self.layer_type.lower() in ["fused", "native"] ++ ++ ++@gin.configurable ++@dataclass ++class OptimizerArgs: ++ optimizer_str: str ++ learning_rate: float ++ adam_beta1: float = 0.9 ++ adam_beta2: float = 0.999 ++ adam_eps: float = 1e-8 ++ ++ ++@gin.configurable ++@dataclass ++class TensorModelParallelArgs: ++ tensor_model_parallel_size: int = 1 ++ ++@dataclass ++class ShardedEmbeddingConfig: ++ """ ++ Configuration for sharded embeddings with sharding type. Inherits from BaseShardedEmbeddingConfig. ++ ++ Args: ++ config (EmbeddingConfig): The embedding configuration. 
++ sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. ++ """ ++ ++ """ ++ Base configuration for sharded embeddings. ++ ++ Args: ++ feature_names (List[str]): The name of the features in this embedding. ++ table_name (str): The name of the table. ++ vocab_size (int): The size of the vocabulary. ++ dim (int): The dimension size of the embeddings. ++ sharding_type (str): The type of sharding, ``'data_parallel'`` | ``'model_parallel'``. ++ """ ++ ++ feature_names: List[str] ++ table_name: str ++ vocab_size: int ++ dim: int ++ sharding_type: str ++ ++ def __post_init__(self): ++ assert self.sharding_type in [ ++ "data_parallel", ++ "model_parallel", ++ ], "sharding type should be data_parallel or model_parallel" ++ ++def create_embedding_config( ++ hidden_size: int, embedding_args: EmbeddingArgs ++) -> ShardedEmbeddingConfig: ++ if isinstance(embedding_args, DynamicEmbeddingArgs): ++ return configs.ShardedEmbeddingConfig( ++ feature_names=embedding_args.feature_names, ++ table_name=embedding_args.table_name, ++ vocab_size=embedding_args.item_vocab_size_or_capacity, ++ dim=hidden_size, ++ sharding_type="model_parallel", ++ ) ++ return configs.ShardedEmbeddingConfig( ++ feature_names=embedding_args.feature_names, ++ table_name=embedding_args.table_name, ++ vocab_size=embedding_args.item_vocab_size_or_capacity, ++ dim=hidden_size, ++ sharding_type=embedding_args.sharding_type, ++ ) ++ ++def get_dataset_and_embedding_args() -> ( ++ Tuple[ ++ Union[DatasetArgs, BenchmarkDatasetArgs], ++ List[Union[DynamicEmbeddingArgs, EmbeddingArgs]], ++ ] ++): ++ try: ++ dataset_args = DatasetArgs() # type: ignore[call-arg] ++ except: ++ benchmark_dataset_args = BenchmarkDatasetArgs() # type: ignore[call-arg] ++ return benchmark_dataset_args, benchmark_dataset_args.embedding_args ++ assert isinstance(dataset_args, DatasetArgs) ++ HASH_SIZE = 10_000_000 ++ if dataset_args.dataset_name == "kuairand-pure": ++ return dataset_args, [ ++ DynamicEmbeddingArgs( 
++ feature_names=["user_active_degree"], ++ table_name="user_active_degree", ++ item_vocab_size_or_capacity=10, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["follow_user_num_range"], ++ table_name="follow_user_num_range", ++ item_vocab_size_or_capacity=9, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["fans_user_num_range"], ++ table_name="fans_user_num_range", ++ item_vocab_size_or_capacity=10, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["friend_user_num_range"], ++ table_name="friend_user_num_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["register_days_range"], ++ table_name="register_days_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["action_weights"], ++ table_name="action_weights", ++ item_vocab_size_or_capacity=226, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["video_id"], ++ table_name="video_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["user_id"], ++ table_name="user_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ ] ++ elif dataset_args.dataset_name == "kuairand-1k": ++ return dataset_args, [ ++ DynamicEmbeddingArgs( ++ feature_names=["user_active_degree"], ++ table_name="user_active_degree", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["follow_user_num_range"], ++ table_name="follow_user_num_range", ++ item_vocab_size_or_capacity=9, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["fans_user_num_range"], ++ table_name="fans_user_num_range", ++ item_vocab_size_or_capacity=9, ++ 
item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["friend_user_num_range"], ++ table_name="friend_user_num_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["register_days_range"], ++ table_name="register_days_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["action_weights"], ++ table_name="action_weights", ++ item_vocab_size_or_capacity=233, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["video_id"], ++ table_name="video_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=0.5, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["user_id"], ++ table_name="user_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=0.5, ++ ), ++ ] ++ elif dataset_args.dataset_name == "kuairand-27k": ++ return dataset_args, [ ++ DynamicEmbeddingArgs( ++ feature_names=["user_active_degree"], ++ table_name="user_active_degree", ++ item_vocab_size_or_capacity=10, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["follow_user_num_range"], ++ table_name="follow_user_num_range", ++ item_vocab_size_or_capacity=9, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["fans_user_num_range"], ++ table_name="fans_user_num_range", ++ item_vocab_size_or_capacity=10, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["friend_user_num_range"], ++ table_name="friend_user_num_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["register_days_range"], ++ table_name="register_days_range", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["action_weights"], ++ 
table_name="action_weights", ++ item_vocab_size_or_capacity=246, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["video_id"], ++ table_name="video_id", ++ item_vocab_size_or_capacity=32038725, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["user_id"], ++ table_name="user_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ ] ++ elif dataset_args.dataset_name == "ml-1m": ++ return dataset_args, [ ++ DynamicEmbeddingArgs( ++ feature_names=["sex"], ++ table_name="sex", ++ item_vocab_size_or_capacity=3, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["age_group"], ++ table_name="age_group", ++ item_vocab_size_or_capacity=8, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["occupation"], ++ table_name="occupation", ++ item_vocab_size_or_capacity=22, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["zip_code"], ++ table_name="zip_code", ++ item_vocab_size_or_capacity=3440, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["rating"], ++ table_name="action_weights", ++ item_vocab_size_or_capacity=11, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["movie_id"], ++ table_name="movie_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["user_id"], ++ table_name="user_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ ] ++ elif dataset_args.dataset_name == "ml-20m": ++ return dataset_args, [ ++ DynamicEmbeddingArgs( ++ feature_names=["rating"], ++ table_name="action_weights", ++ item_vocab_size_or_capacity=11, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["movie_id"], ++ table_name="movie_id", ++ 
item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ DynamicEmbeddingArgs( ++ feature_names=["user_id"], ++ table_name="user_id", ++ item_vocab_size_or_capacity=HASH_SIZE, ++ item_vocab_gpu_capacity_ratio=1.0, ++ ), ++ ] ++ else: ++ raise ValueError(f"dataset {dataset_args.dataset_name} is not supported") ++ ++def create_hstu_config(network_args: NetworkArgs): ++ dtype = None ++ if network_args.dtype_str == "bfloat16": ++ dtype = torch.bfloat16 ++ if network_args.dtype_str == "float16": ++ dtype = torch.float16 ++ assert dtype is not None, "dtype not selected. Check your input." ++ ++ kernel_backend = None ++ if network_args.kernel_backend == "cutlass": ++ kernel_backend = KernelBackend.CUTLASS ++ elif network_args.kernel_backend == "triton": ++ kernel_backend = KernelBackend.TRITON ++ elif network_args.kernel_backend == "pytorch": ++ kernel_backend = KernelBackend.PYTORCH ++ elif network_args.kernel_backend == "npu_fused": ++ kernel_backend = KernelBackend.NPU_FUSED ++ else: ++ raise ValueError( ++ f"Kernel backend {network_args.kernel_backend} is not supported." 
++ ) ++ layer_type = None ++ if network_args.layer_type == "fused": ++ layer_type = HSTULayerType.FUSED ++ elif network_args.layer_type == "native": ++ layer_type = HSTULayerType.NATIVE ++ else: ++ raise ValueError(f"Layer type {network_args.layer_type} is not supported.") ++ position_encoding_config = PositionEncodingConfig( ++ num_position_buckets=network_args.num_position_buckets, ++ num_time_buckets=2048, ++ use_time_encoding=False, ++ ) ++ return get_hstu_config( ++ hidden_size=network_args.hidden_size, ++ kv_channels=network_args.kv_channels, ++ num_attention_heads=network_args.num_attention_heads, ++ num_layers=network_args.num_layers, ++ hidden_dropout=network_args.hidden_dropout, ++ norm_epsilon=network_args.norm_epsilon, ++ is_causal=network_args.is_causal, ++ dtype=dtype, ++ kernel_backend=kernel_backend, ++ position_encoding_config=position_encoding_config, ++ target_group_size=network_args.target_group_size, ++ hstu_layer_type=layer_type, ++ ) ++ ++ ++def create_optimizer_params(optimizer_args: OptimizerArgs): ++ return OptimizerParam( ++ optimizer_str=optimizer_args.optimizer_str, ++ learning_rate=optimizer_args.learning_rate, ++ adam_beta1=optimizer_args.adam_beta1, ++ adam_beta2=optimizer_args.adam_beta2, ++ adam_eps=optimizer_args.adam_eps, ++ ) ++ ++ ++def get_data_loader( ++ task_type: str, ++ dataset_args: Union[DatasetArgs, BenchmarkDatasetArgs], ++ trainer_args: TrainerArgs, ++ num_tasks: int, ++): ++ assert task_type in [ ++ "ranking", ++ "retrieval", ++ ], f"task type should be ranking or retrieval not {task_type}" ++ if isinstance(dataset_args, BenchmarkDatasetArgs): ++ from dataset.utils import FeatureConfig ++ ++ assert ( ++ trainer_args.max_train_iters is not None ++ and trainer_args.max_eval_iters is not None ++ ), "Benchmark dataset expects max_train_iters and max_eval_iters as num_batches" ++ feature_name_to_max_item_id = {} ++ for e in dataset_args.embedding_args: ++ for feature_name in e.feature_names: ++ 
feature_name_to_max_item_id[feature_name] = ( ++ sys.maxsize ++ if isinstance(e, DynamicEmbeddingArgs) ++ else e.item_vocab_size_or_capacity ++ ) ++ feature_configs = [] ++ for f in dataset_args.feature_args: ++ feature_configs.append( ++ FeatureConfig( ++ feature_names=f.feature_names, ++ max_item_ids=[ ++ feature_name_to_max_item_id[n] for n in f.feature_names ++ ], ++ max_sequence_length=f.max_sequence_length, ++ is_jagged=f.is_jagged, ++ ) ++ ) ++ ++ kwargs = dict( ++ feature_configs=feature_configs, ++ item_feature_name=dataset_args.item_feature_name, ++ contextual_feature_names=dataset_args.contextual_feature_names, ++ action_feature_name=dataset_args.action_feature_name, ++ max_num_candidates=dataset_args.max_num_candidates, ++ num_generated_batches=100, ++ num_tasks=num_tasks, ++ ) ++ train_dataset = dataset.dummy_dataset.DummySequenceDataset( ++ batch_size=trainer_args.train_batch_size, **kwargs ++ ) ++ test_dataset = dataset.dummy_dataset.DummySequenceDataset( ++ batch_size=trainer_args.eval_batch_size, **kwargs ++ ) ++ else: ++ assert isinstance(dataset_args, DatasetArgs) ++ ( ++ train_dataset, ++ test_dataset, ++ ) = dataset.sequence_dataset.get_dataset( ++ dataset_name=dataset_args.dataset_name, ++ max_sequence_length=dataset_args.max_sequence_length, ++ max_num_candidates=dataset_args.max_num_candidates, ++ num_tasks=num_tasks, ++ batch_size=trainer_args.train_batch_size, ++ rank=dist.get_rank(), ++ world_size=dist.get_world_size(), ++ shuffle=dataset_args.shuffle, ++ random_seed=trainer_args.seed, ++ eval_batch_size=trainer_args.eval_batch_size, ++ ) ++ return dataset.get_data_loader(train_dataset), dataset.get_data_loader(test_dataset) # type: ignore[attr-defined] +diff --git a/examples/hstu/ops/pt_ops/pt_add_position_embeddings.py b/examples/hstu/ops/pt_ops/pt_add_position_embeddings.py +new file mode 100644 +index 0000000..2490317 +--- /dev/null ++++ b/examples/hstu/ops/pt_ops/pt_add_position_embeddings.py +@@ -0,0 +1,16 @@ ++import torch ++from 
typing import Optional
++
++def add_position_embeddings(jagged, jagged_offsets, high_inds, max_seq_len, dense, scale=1.0):
++    L, D = jagged.shape
++    B = high_inds.shape[0]
++    out = torch.empty_like(jagged)
++    for b in range(B):
++        start = jagged_offsets[b].item()
++        end = jagged_offsets[b + 1].item()
++        seq_len = end - start
++
++        pos_ids = torch.arange(seq_len, device=jagged.device) + high_inds[b].item()
++        pos_emb = dense[pos_ids]
++        out[start:end, :] = jagged[start:end, :] + pos_emb * scale
++    return out
+\ No newline at end of file
+diff --git a/examples/hstu/ops/pt_ops/pt_concat_2d_jagged.py b/examples/hstu/ops/pt_ops/pt_concat_2d_jagged.py
+new file mode 100644
+index 0000000..0061576
+--- /dev/null
++++ b/examples/hstu/ops/pt_ops/pt_concat_2d_jagged.py
+@@ -0,0 +1,24 @@
++import torch
++from typing import Optional, Tuple
++
++
++def concat_2D_jagged(values_a: torch.Tensor,
++                     offsets_a: torch.Tensor,
++                     values_b: torch.Tensor,
++                     offsets_b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
++    B = offsets_a.shape[0] - 1
++    D = values_a.shape[1]
++    out_values = []
++    out_offsets = [0]
++    for i in range(B):
++        a_start, a_end = offsets_a[i].item(), offsets_a[i + 1].item()
++        b_start, b_end = offsets_b[i].item(), offsets_b[i + 1].item()
++        a_slice = values_a[a_start:a_end]
++        b_slice = values_b[b_start:b_end]
++        out_slice = torch.cat([a_slice, b_slice], dim=0)
++        out_values.append(out_slice)
++        out_offsets.append(out_offsets[-1] + out_slice.shape[0])
++    out_values = torch.cat(out_values, dim=0)
++    out_offsets = torch.tensor(out_offsets, device=values_a.device, dtype=offsets_a.dtype)
++    return out_values, out_offsets
++
+diff --git a/examples/hstu/ops/pt_ops/pt_jagged_tensors.py b/examples/hstu/ops/pt_ops/pt_jagged_tensors.py
+new file mode 100644
+index 0000000..11df57b
+--- /dev/null
++++ b/examples/hstu/ops/pt_ops/pt_jagged_tensors.py
+@@ -0,0 +1,241 @@
++# Copyright (c) Meta Platforms, Inc. and affiliates.
++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++#!/usr/bin/env python3 ++ ++# pyre-strict ++ ++from typing import Optional, Tuple ++ ++import torch ++ ++@torch.fx.wrap ++def fx_arange(len: int, device: torch.device) -> torch.Tensor: ++ return torch.arange(len, device=device) ++ ++def _concat_2D_jagged_jagged( ++ values_left: torch.Tensor, ++ values_right: torch.Tensor, ++ max_len_left: int, ++ max_len_right: int, ++ offsets_left: torch.Tensor, ++ offsets_right: torch.Tensor, ++) -> torch.Tensor: ++ max_seq_len = max_len_left + max_len_right ++ lengths_left = offsets_left[1:] - offsets_left[:-1] ++ lengths_right = offsets_right[1:] - offsets_right[:-1] ++ padded_left = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=values_left, ++ offsets=[offsets_left], ++ max_lengths=[max_len_left], ++ padding_value=0.0, ++ ) ++ padded_right = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=values_right, ++ offsets=[offsets_right], ++ max_lengths=[max_len_right], ++ padding_value=0.0, ++ ) ++ concatted_dense = torch.cat([padded_left, padded_right], dim=1) ++ mask = fx_arange(max_seq_len, device=offsets_left.device).view(1, -1) ++ mask = torch.logical_or( ++ mask < lengths_left.view(-1, 1), ++ torch.logical_and( ++ mask >= max_len_left, ++ mask < max_len_left + lengths_right.view(-1, 1), ++ ), ++ ) ++ return concatted_dense.flatten(0, 1)[mask.view(-1), :] ++ ++ ++@torch.fx.wrap ++def pytorch_concat_2D_jagged( ++ values_left: torch.Tensor, ++ values_right: 
torch.Tensor, ++ max_len_left: Optional[int], ++ max_len_right: Optional[int], ++ offsets_left: Optional[torch.Tensor], ++ offsets_right: Optional[torch.Tensor], ++) -> torch.Tensor: ++ if offsets_left is None: ++ assert max_len_left is not None ++ B = values_left.shape[0] // max_len_left ++ offsets_left_non_optional = max_len_left * torch.arange( ++ B + 1, device=values_left.device ++ ) ++ else: ++ offsets_left_non_optional = offsets_left ++ if offsets_right is None: ++ assert max_len_right is not None ++ B = values_right.shape[0] // max_len_right ++ offsets_right_non_optional = max_len_right * torch.arange( ++ B + 1, device=values_left.device ++ ) ++ else: ++ offsets_right_non_optional = offsets_right ++ max_len_left = ( ++ int( ++ (offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) ++ .max() ++ .item() ++ ) ++ if max_len_left is None ++ else max_len_left ++ ) ++ max_len_right = ( ++ int( ++ (offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) ++ .max() ++ .item() ++ ) ++ if max_len_right is None ++ else max_len_right ++ ) ++ return _concat_2D_jagged_jagged( ++ values_left=values_left, ++ values_right=values_right, ++ max_len_left=max_len_left, ++ max_len_right=max_len_right, ++ offsets_left=offsets_left_non_optional, ++ offsets_right=offsets_right_non_optional, ++ ) ++ ++def _split_2D_jagged_jagged( ++ max_seq_len: int, ++ values: torch.Tensor, ++ offsets_left: torch.Tensor, ++ offsets_right: torch.Tensor, ++) -> Tuple[torch.Tensor, torch.Tensor]: ++ offsets = offsets_left + offsets_right ++ padded_values = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=values, ++ offsets=[offsets], ++ max_lengths=[max_seq_len], ++ padding_value=0.0, ++ ).flatten(0, 1) ++ lengths_left = offsets_left[1:] - offsets_left[:-1] ++ lengths_right = offsets_right[1:] - offsets_right[:-1] ++ mask = fx_arange(max_seq_len, device=values.device).view(1, -1) ++ mask_left = mask < lengths_left.view(-1, 1) ++ mask_right = torch.logical_and( ++ mask >= 
lengths_left.view(-1, 1), ++ mask < (lengths_left + lengths_right).view(-1, 1), ++ ) ++ return padded_values[mask_left.view(-1), :], padded_values[mask_right.view(-1), :] ++ ++ ++@torch.fx.wrap ++def pytorch_split_2D_jagged( ++ max_seq_len: int, ++ values: torch.Tensor, ++ max_len_left: Optional[int], ++ max_len_right: Optional[int], ++ offsets_left: Optional[torch.Tensor], ++ offsets_right: Optional[torch.Tensor], ++) -> Tuple[torch.Tensor, torch.Tensor]: ++ if offsets_left is None: ++ assert max_len_left is not None ++ assert offsets_right is not None ++ offsets_left_non_optional = max_len_left * torch.arange( ++ offsets_right.shape[0], device=values.device ++ ) ++ else: ++ offsets_left_non_optional = offsets_left ++ if offsets_right is None: ++ assert max_len_right is not None ++ assert offsets_left is not None ++ offsets_right_non_optional = max_len_right * torch.arange( ++ offsets_left.shape[0], device=values.device ++ ) ++ else: ++ offsets_right_non_optional = offsets_right ++ return _split_2D_jagged_jagged( ++ max_seq_len=max_seq_len, ++ values=values, ++ offsets_left=offsets_left_non_optional, ++ offsets_right=offsets_right_non_optional, ++ ) ++ ++ ++def pytorch_hstu_split_l2_embeddings( ++ max_seq_len: int, ++ x: torch.Tensor, ++ prefix_offsets: torch.Tensor, ++ l2_offsets: torch.Tensor, ++ contextual_seq_len: int, ++) -> Tuple[torch.Tensor, torch.Tensor]: ++ x_offsets = prefix_offsets + l2_offsets ++ x_lengths = x_offsets[1:] - x_offsets[:-1] ++ padded_x = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=x, ++ offsets=[x_offsets], ++ max_lengths=[max_seq_len], ++ padding_value=0.0, ++ ).flatten(0, 1) ++ prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] ++ mask = fx_arange(max_seq_len, device=x_offsets.device).view(1, -1) ++ mask_prefix = torch.logical_and( ++ mask >= contextual_seq_len, ++ mask < prefix_lengths.view(-1, 1) + contextual_seq_len, ++ ) ++ mask_l2 = torch.logical_or( ++ mask < contextual_seq_len, ++ torch.logical_and( ++ mask >= 
prefix_lengths.view(-1, 1) + contextual_seq_len, ++ mask < x_lengths.view(-1, 1), ++ ), ++ ) ++ return padded_x[mask_prefix.view(-1), :], padded_x[mask_l2.view(-1), :] ++ ++ ++def pytorch_hstu_concat_l2_embeddings( ++ max_prefix_len: int, ++ prefix_x: torch.Tensor, ++ prefix_offsets: torch.Tensor, ++ max_l2_len: int, ++ l2_x: torch.Tensor, ++ l2_offsets: torch.Tensor, ++ contextual_seq_len: int, ++) -> torch.Tensor: ++ padded_prefix_x = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=prefix_x, ++ offsets=[prefix_offsets], ++ max_lengths=[max_prefix_len], ++ padding_value=0.0, ++ ) ++ padded_l2_x = torch.ops.fbgemm.jagged_to_padded_dense( ++ values=l2_x, ++ offsets=[l2_offsets], ++ max_lengths=[max_l2_len], ++ padding_value=0.0, ++ ) ++ padded_x = torch.cat( ++ [ ++ padded_l2_x[:, 0:contextual_seq_len, :], ++ padded_prefix_x, ++ padded_l2_x[:, contextual_seq_len:, :], ++ ], ++ dim=1, ++ ) ++ mask = fx_arange(max_prefix_len + max_l2_len, device=prefix_x.device).view(1, -1) ++ prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] ++ l2_lengths = l2_offsets[1:] - l2_offsets[:-1] ++ mask = torch.logical_or( ++ mask < prefix_lengths.view(-1, 1) + contextual_seq_len, ++ torch.logical_and( ++ mask >= max_prefix_len + contextual_seq_len, ++ mask < max_prefix_len + l2_lengths.view(-1, 1), ++ ), ++ ) ++ return padded_x.flatten(0, 1)[mask.view(-1), :] +diff --git a/examples/hstu/ops/pt_ops/pt_position.py b/examples/hstu/ops/pt_ops/pt_position.py +new file mode 100644 +index 0000000..5bd0134 +--- /dev/null ++++ b/examples/hstu/ops/pt_ops/pt_position.py +@@ -0,0 +1,146 @@ ++# Copyright (c) Meta Platforms, Inc. and affiliates. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++#!/usr/bin/env python3 ++ ++# pyre-strict ++ ++from typing import Optional ++ ++import torch ++ ++@torch.fx.wrap ++def fx_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor: ++ assert optional is not None, "Expected optional to be non-None Tensor" ++ return optional ++ ++@torch.fx.wrap ++def torch_arange(end: int, device: torch.device) -> torch.Tensor: ++ return torch.arange(end, device=device) ++ ++def pytorch_add_position_embeddings( ++ jagged: torch.Tensor, ++ jagged_offsets: torch.Tensor, ++ high_inds: torch.Tensor, ++ max_seq_len: int, ++ dense: torch.Tensor, ++ scale: float=1.0 ++) -> torch.Tensor: ++ jagged = jagged * scale ++ B = high_inds.shape[0] ++ col_indices = torch_arange(max_seq_len, device=high_inds.device).expend( ++ B, max_seq_len ++ ) ++ col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1)) ++ dense_values = torch.index_select(dense, 0, col_indices.reshape(-1)).view( ++ B, max_seq_len, -1 ++ ) ++ return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( ++ jagged, ++ [jagged_offsets], ++ dense_values, ++ )[0] ++ ++ ++@torch.fx.wrap ++def _get_col_indices( ++ max_seq_len: int, ++ max_contextual_seq_len: int, ++ max_pos_ind: int, ++ seq_lengths: torch.Tensor, ++ num_targets: Optional[torch.Tensor], ++ interleave_targets: bool, ++) -> torch.Tensor: ++ B = seq_lengths.size(0) ++ col_indices = torch.arange(max_seq_len, device=seq_lengths.device).expand( ++ B, max_seq_len ++ ) ++ if num_targets is not None: ++ if interleave_targets: ++ high_inds = seq_lengths - 
fx_unwrap_optional_tensor(num_targets) * 2 ++ else: ++ high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) ++ col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1)) ++ col_indices = high_inds.view(-1, 1) - col_indices ++ else: ++ col_indices = seq_lengths.view(-1, 1) - col_indices ++ col_indices = col_indices + max_contextual_seq_len ++ col_indices = torch.clamp(col_indices, max=max_pos_ind - 1) ++ if max_contextual_seq_len > 0: ++ col_indices[:, :max_contextual_seq_len] = torch.arange( ++ 0, ++ max_contextual_seq_len, ++ device=col_indices.device, ++ dtype=col_indices.dtype, ++ ).view(1, -1) ++ return col_indices ++ ++ ++def pytorch_add_timestamp_positional_embeddings( ++ seq_embeddings: torch.Tensor, ++ seq_offsets: torch.Tensor, ++ pos_embeddings: torch.Tensor, ++ ts_embeddings: torch.Tensor, ++ timestamps: torch.Tensor, ++ max_seq_len: int, ++ max_contextual_seq_len: int, ++ seq_lengths: torch.Tensor, ++ num_targets: Optional[torch.Tensor], ++ interleave_targets: bool, ++ time_bucket_fn: str, ++) -> torch.Tensor: ++ max_pos_ind = pos_embeddings.size(0) ++ # position encoding ++ pos_inds = _get_col_indices( ++ max_seq_len=max_seq_len, ++ max_contextual_seq_len=max_contextual_seq_len, ++ max_pos_ind=max_pos_ind, ++ seq_lengths=seq_lengths, ++ num_targets=num_targets, ++ interleave_targets=interleave_targets, ++ ) ++ B, _ = pos_inds.shape ++ # timestamp encoding ++ num_time_buckets = ts_embeddings.size(1) - 1 ++ time_bucket_increments = 60.0 ++ time_bucket_divisor = 1.0 ++ time_delta = 0 ++ timestamps = timestamps[:, :max_seq_len] ++ query_time = torch.gather( ++ timestamps, dim=1, index=(seq_lengths - 1).unsqueeze(1).clamp(min=0) ++ ) ++ ts = query_time - timestamps ++ ts = ts + time_delta ++ ts = ts.clamp(min=1e-6) / time_bucket_increments ++ if time_bucket_fn == "log": ++ ts = torch.log(ts) ++ else: ++ ts = torch.sqrt(ts) ++ ts = (ts / time_bucket_divisor).clamp(min=0).int() ++ ts = torch.clamp( ++ ts, ++ min=0, ++ max=num_time_buckets, 
++ ) ++ position_embeddings = torch.index_select( ++ pos_embeddings, 0, pos_inds.reshape(-1) ++ ).view(B, max_seq_len, -1) ++ time_embeddings = torch.index_select(ts_embeddings, 0, ts.reshape(-1)).view( ++ B, max_seq_len, -1 ++ ) ++ return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( ++ seq_embeddings, ++ [seq_offsets], ++ (time_embeddings + position_embeddings).to(seq_embeddings.dtype), ++ )[0] +diff --git a/examples/hstu/ops/pt_ops/pt_split_2d_jagged.py b/examples/hstu/ops/pt_ops/pt_split_2d_jagged.py +new file mode 100644 +index 0000000..8c1d396 +--- /dev/null ++++ b/examples/hstu/ops/pt_ops/pt_split_2d_jagged.py +@@ -0,0 +1,24 @@ ++import torch ++from typing import Optional, Tuple ++ ++ ++def split_2D_jagged( ++ values: torch.Tensor, ++ offsets_a: torch.Tensor, ++ offsets_b: torch.Tensor ++) -> Tuple[torch.Tensor, torch.Tensor]: ++ L, D = values.shape ++ B = offsets_a.shape[0] - 1 ++ values_a_list = [] ++ values_b_list = [] ++ ++ for i in range(B): ++ a_start = offsets_a[i].item() ++ a_end = offsets_a[i + 1].item() ++ b_start = offsets_b[i].item() ++ b_end = offsets_b[i + 1].item() ++ values_a_list.append(values[a_start:a_end]) ++ values_b_list.append(values[b_start:b_end]) ++ values_a = torch.cat(values_a_list, dim=0) if values_a_list else torch.empty(0, D, device=values.device) ++ values_b = torch.cat(values_b_list, dim=0) if values_b_list else torch.empty(0, D, device=values.device) ++ return values_a, values_b diff --git a/examples/hstu/pretrain_gr_ranking.py b/examples/hstu/pretrain_gr_ranking.py index 1715109..b95eed7 100644 --- a/examples/hstu/pretrain_gr_ranking.py diff --git a/torch_examples_benchmark/gr_nv/run.sh b/torch_examples_benchmark/gr_nv/run.sh index 05f8105d..73342b66 100644 --- a/torch_examples_benchmark/gr_nv/run.sh +++ b/torch_examples_benchmark/gr_nv/run.sh @@ -11,8 +11,8 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 RECSYS_DIR=$(realpath ../) HSTU_DIR=$RECSYS_DIR/hstu # 根据实际情况设置python引用路径 -MEGATRON_DIR=$RECSYS_DIR/megatron-lm/ 
-MINDSPEED_DIR=$RECSYS_DIR/MindSpeed/ +MEGATRON_DIR=$RECSYS_DIR/../../megatron-lm/ +MINDSPEED_DIR=$RECSYS_DIR/../../MindSpeed/ export PYTHONPATH=${PYTHONPATH}:${HSTU_DIR}:${MEGATRON_DIR}:${MINDSPEED_DIR} #根据实际情况设置算子适配so文件 -- Gitee From df428297d036ca82bd6090751c091d66f10197eb Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Tue, 24 Jun 2025 22:06:46 +0800 Subject: [PATCH 4/4] =?UTF-8?q?NV=E7=9A=84GR=E6=A8=A1=E5=9E=8B=E9=80=82?= =?UTF-8?q?=E9=85=8DNPU=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../gr_nv/gr_nv2npu.patch | 39 +++++++++---------- torch_examples_benchmark/gr_nv/run.sh | 7 ++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch index e2f70f83..ef90ff5e 100644 --- a/torch_examples_benchmark/gr_nv/gr_nv2npu.patch +++ b/torch_examples_benchmark/gr_nv/gr_nv2npu.patch @@ -157,7 +157,7 @@ index ed1b9bc..78f93f0 100644 """ Pin the memory of the batch. 
diff --git a/examples/hstu/distributed/sharding.py b/examples/hstu/distributed/sharding.py -index 4389a45..b8ea3cf 100644 +index 4389a45..9c37d67 100644 --- a/examples/hstu/distributed/sharding.py +++ b/examples/hstu/distributed/sharding.py @@ -14,22 +14,32 @@ @@ -266,7 +266,7 @@ index 4389a45..b8ea3cf 100644 + constraints: Dict[str, List[str]] = {} + for cfg in eb_configs: + if cfg.name in data_parallel_embedding_table_names: -+ constraints[cfg.name] = ParameterConstraints(sharding_type=[ShardingType.DATA_PARALLEL.value]) ++ constraints[cfg.name] = ParameterConstraints(sharding_types=[ShardingType.DATA_PARALLEL.value]) else: - constraint = DynamicEmbParameterConstraints( - sharding_types=[ @@ -282,7 +282,7 @@ index 4389a45..b8ea3cf 100644 - ddr_cap = 512 * 1024 * 1024 * 1024 # Assume a Node have 512GB memory - intra_host_bw = 450e9 # Nvlink bandwidth - inter_host_bw = 25e9 # NIC bandwidth -+ constraints[cfg.name] = ParameterConstraints(sharding_type=[ShardingType.ROW_WISE.value]) ++ constraints[cfg.name] = ParameterConstraints(sharding_types=[ShardingType.ROW_WISE.value]) topology = Topology( local_world_size=get_local_size(), @@ -493,7 +493,7 @@ index e4fba4b..1ca3e9f 100644 +# assert isinstance(task_config, RetrievalConfig), "please provide a retrieval config" +# return RetrievalGR(hstu_config=hstu_config, task_config=task_config) diff --git a/examples/hstu/model/ranking_gr.py b/examples/hstu/model/ranking_gr.py -index bf1f9ac..151b3df 100644 +index bf1f9ac..c302d2e 100644 --- a/examples/hstu/model/ranking_gr.py +++ b/examples/hstu/model/ranking_gr.py @@ -13,10 +13,12 @@ @@ -511,24 +511,22 @@ index bf1f9ac..151b3df 100644 from configs import HSTUConfig, RankingConfig from dataset.utils import RankingBatch from megatron.core import parallel_state -@@ -26,6 +28,16 @@ from modules.hstu_block import HSTUBlock +@@ -26,6 +28,14 @@ from modules.hstu_block import HSTUBlock from modules.metrics import get_multi_event_metric_module from modules.mlp import MLP from 
modules.multi_task_loss_module import MultiTaskLossModule +from torchrec_embcache.distributed.configs import ( + EmbCacheEmbeddingConfig, -+ InintializerType, ++ InitializerType, + AdmitAndEvictConfig, +) +from torchrec_embcache.distributed.embedding import EmbCacheEmbeddingCollection +from torchrec.modules.embedding_configs import EmbeddingConfig +from hybrid_torchrec import HashEmbeddingCollection -+ -+from torch_examples_benchmark.model_zoo.aliccp.step7_gen_spec import multi_hot_fields class RankingGR(BaseModel): -@@ -49,7 +61,7 @@ class RankingGR(BaseModel): +@@ -49,7 +59,7 @@ class RankingGR(BaseModel): assert ( self._tp_size == 1 ), "RankingGR does not support tensor model parallel for now" @@ -537,7 +535,7 @@ index bf1f9ac..151b3df 100644 self._hstu_config = hstu_config self._task_config = task_config -@@ -59,7 +71,47 @@ class RankingGR(BaseModel): +@@ -59,7 +69,47 @@ class RankingGR(BaseModel): ebc_config.dim == self._embedding_dim ), "hstu layer hidden size should equal to embedding dim" @@ -586,7 +584,7 @@ index bf1f9ac..151b3df 100644 self._hstu_block = HSTUBlock(hstu_config) self._mlp = MLP( -@@ -125,7 +177,6 @@ class RankingGR(BaseModel): +@@ -125,7 +175,6 @@ class RankingGR(BaseModel): return self._mlp(hidden_states.values), batch.labels @@ -1005,14 +1003,14 @@ index 122016f..69bfe06 100644 jagged_offsets=seq_offsets, high_inds=high_inds, diff --git a/examples/hstu/movielen_ranking.gin b/examples/hstu/movielen_ranking.gin -index 314a4e5..a10319b 100644 +index 314a4e5..fc0c30d 100644 --- a/examples/hstu/movielen_ranking.gin +++ b/examples/hstu/movielen_ranking.gin @@ -18,8 +18,10 @@ NetworkArgs.num_attention_heads = 4 NetworkArgs.hidden_size = 128 NetworkArgs.kv_channels = 128 NetworkArgs.target_group_size = 1 -+NetworkArgs.kernel_backend = 'npu_fused' ++NetworkArgs.kernel_backend = 'pytorch' +NetworkArgs.layer_type = 'native' -# ratings 0-5 @@ -1919,7 +1917,7 @@ index 0000000..11df57b + return padded_x.flatten(0, 1)[mask.view(-1), :] diff --git 
a/examples/hstu/ops/pt_ops/pt_position.py b/examples/hstu/ops/pt_ops/pt_position.py new file mode 100644 -index 0000000..5bd0134 +index 0000000..91833ea --- /dev/null +++ b/examples/hstu/ops/pt_ops/pt_position.py @@ -0,0 +1,146 @@ @@ -1964,7 +1962,7 @@ index 0000000..5bd0134 +) -> torch.Tensor: + jagged = jagged * scale + B = high_inds.shape[0] -+ col_indices = torch_arange(max_seq_len, device=high_inds.device).expend( ++ col_indices = torch_arange(max_seq_len, device=high_inds.device).expand( + B, max_seq_len + ) + col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1)) @@ -2228,7 +2226,7 @@ index 1715109..b95eed7 100644 model_train, trainer_args, diff --git a/examples/hstu/utils.py b/examples/hstu/utils.py -index 09b1a55..10aa03c 100644 +index 09b1a55..5b5e3e9 100644 --- a/examples/hstu/utils.py +++ b/examples/hstu/utils.py @@ -17,17 +17,19 @@ import sys @@ -2263,15 +2261,16 @@ index 09b1a55..10aa03c 100644 +# from dynamicemb import DynamicEmbTableOptions from megatron.core import parallel_state from megatron.core.distributed import finalize_model_grads - from model import RankingGR, RetrievalGR +-from model import RankingGR, RetrievalGR ++from model import RankingGR from modules.embedding import ShardedEmbeddingConfig from torchrec.distributed.model_parallel import DistributedModelParallel - -+from torchrec_embcache.distributed.train_pipline import EmbCacheTrainPiplelineSparseDist ++from torchrec_embcache.distributed.train_pipeline import EmbCacheTrainPipelineSparseDist +import os +from typing import Optional + -+from megatron.core.distributed import DistributedDateParllel ++from megatron.core.distributed import DistributedDataParallel +from megatron.core.transformer.module import Float16Module +from torch import nn + @@ -2279,7 +2278,7 @@ index 09b1a55..10aa03c 100644 + while( + isinstance(module, DistributedModelParallel) + or isinstance(module, Float16Module) -+ or isinstance(module, DistributedDateParllel) ++ or isinstance(module, 
DistributedDataParallel) + ): + if isinstance(module, DistributedModelParallel): + module = module._dmp_wrapped_module diff --git a/torch_examples_benchmark/gr_nv/run.sh b/torch_examples_benchmark/gr_nv/run.sh index 73342b66..eb100009 100644 --- a/torch_examples_benchmark/gr_nv/run.sh +++ b/torch_examples_benchmark/gr_nv/run.sh @@ -80,12 +80,13 @@ TOKENIZER_MODEL=$RECSYS_DIR/llama2-tokenizer.model export WORLD_SIZE=4 export ASCEND_RT_VISIBLE_DEVICE=4,5,6,7 -MICRO_BATCH_SIIZE=8 -GLOBAL_BATCH_SIZE=$((MICRO_BATCH_SIIZE * WORLD_SIZE)) +MICRO_BATCH_SIZE=8 +GLOBAL_BATCH_SIZE=$((MICRO_BATCH_SIZE * WORLD_SIZE)) GPT_ARGS=" - --mico-batch-size ${MICRO_BATCH_SIIZE} \ + --micro-batch-size ${MICRO_BATCH_SIZE} \ --global-batch-size ${GLOBAL_BATCH_SIZE} \ + --num-layers 1 \ --hidden-size 128 \ --num-attention-heads 4 \ --seq-length 8000 \ -- Gitee