diff --git a/README_TORCH.md b/README_TORCH.md
index cb10dda4ae755065d86c4bf2b9b5947128241321..7310713c57952047536baeba705a072758443c6a 100644
--- a/README_TORCH.md
+++ b/README_TORCH.md
@@ -20,7 +20,7 @@
 ### Build environment
 Refer to torchrec/docker/README.md
 
-### Ascend-mindxsdk-torchrec-1.1.0-npu-linux-*.tar.gz
+### Build Ascend-mindxsdk-torchrec-1.1.0-npu-linux-*.tar.gz
 Refer to RecSDK/torchrec/README.md
 
@@ -33,7 +33,7 @@
 tar zxvf Ascend-mindxsdk-torchrec1.1.0-npu-linux-*.tar.gz
 pip3 install torchrec-1.1.0+npu-py3-none-linux_*.whl
 ```
 
-### Ascend-mindxsdk-hybrid-torchrec-1.1.0-linux-*.tar.gz
+### Build Ascend-mindxsdk-hybrid-torchrec-1.1.0-linux-*.tar.gz
 Refer to RecSDK/torchrec/hybrid_torchrec/README.md
 
@@ -46,7 +46,7 @@
 tar zxvf Ascend-mindxsdk-hybrid-torchrec1.1.0-linux-*.tar.gz
 pip3 install hybrid_torchrec-1.1.0-py3-none-linux_*.whl
 ```
 
-### Ascend-mindxsdk-mxrec-add-ons-linux-*.tar.gz
+### Build Ascend-mindxsdk-mxrec-add-ons-linux-*.tar.gz
 ```
 cd RecSDK/mxrec_add_ons/build
 bash build.sh
@@ -65,7 +65,7 @@
 bash mxrec_opp_permute2d_sparse_data.run
 bash mxrec_opp_split_embedding_codegen_forward_unweighted.run
 ```
 
-### libfbgemm_npu_api.so
+### Build libfbgemm_npu_api.so
 ```
 cd RecSDK/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/common
diff --git a/torchrec/docker/Dockerfile b/torchrec/docker/Dockerfile
index 2c9ccdb48555dd6e5dbc59b059a97e0b79bf1e26..179df080c1e3f618efd22be02fd5232addce77ae 100644
--- a/torchrec/docker/Dockerfile
+++ b/torchrec/docker/Dockerfile
@@ -88,9 +88,9 @@ RUN pip3 install -U pip && \
     pip3 install Cython && \
     pip3 install absl-py && \
     pip3 install gin-config && \
-    pip3 install torch-npu==2.6.0.rc1 && \
     pip3 install torch-2.6.0+cpu-*.whl && \
     pip3 install fbgemm_gpu-1.1.0+cpu-*.whl && \
+    pip3 install torch-npu==2.6.0.rc1 && \
    pip3 install tensorboard==2.18.0 && \
     pip3 install black && \
     pip3 install cmake && \
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/__init__.py b/torchrec/hybrid_torchrec/hybrid_torchrec/__init__.py
index fe107a6c6ca3dc83c7e2e5c85fe80308d414963e..782fca5f06087c8bf56f2e17638633d6c3b73e4d 100644
--- a/torchrec/hybrid_torchrec/hybrid_torchrec/__init__.py
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/__init__.py
@@ -9,8 +9,12 @@
 import os
 import sysconfig
 
 import torch
-from hybrid_torchrec.modules.hash_embeddingbag import HashEmbeddingBagCollection, HashEmbeddingBagConfig, \
-    HybridHashTable
+from hybrid_torchrec.modules.hash_embeddingbag import (
+    HashEmbeddingBagCollection,
+    HashEmbeddingBagConfig,
+    HybridHashTable,
+)
+__all__ = ["HashEmbeddingBagCollection", "HashEmbeddingBagConfig"]
 
-torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
\ No newline at end of file
+torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/constants.py b/torchrec/hybrid_torchrec/hybrid_torchrec/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f0a18576e3def2751da79d7a191d5ce5f07f320
--- /dev/null
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/constants.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+# Copyright (c) Huawei Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+MAX_NUM_EMBEDDINGS = 1000000000
+MAX_EMBEDDINGS_DIM = 8192
+EMBEDDINGS_DIM_ALIGNMENT = 8
\ No newline at end of file
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/__init__.py b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/__init__.py
index bc78cdeaf1df66d822add45cccab74ee06ec07fb..c320ba0c8fff59dde129ee01d402909fb150bc49 100644
--- a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/__init__.py
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/__init__.py
@@ -5,3 +5,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+from hybrid_torchrec.distributed.hybrid_train_pipeline import HybridTrainPipelineSparseDist
+from hybrid_torchrec.distributed.sharding_plan import get_default_hybrid_sharders
+
+__all__ = ["HybridTrainPipelineSparseDist", "get_default_hybrid_sharders"]
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/hybrid_train_pipeline.py b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/hybrid_train_pipeline.py
index d10f4d8ce02eda89f551aa83fa3ef3768bb7d433..0e4652cd2da1a790f014748722ee6c7aa24791ad 100644
--- a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/hybrid_train_pipeline.py
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/hybrid_train_pipeline.py
@@ -248,7 +248,7 @@ class HybridTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         return_loss: bool = False,
         pipe_n_batch: int = 6,
     ) -> None:
-        self.param_check(model, device, pipe_n_batch)
+        self.param_check(model, device, pipe_n_batch, apply_jit, execute_all_batches)
         super().__init__(model, optimizer, device, execute_all_batches, apply_jit)
         self._return_loss = return_loss
         self._contexts = [[] for _ in range(pipe_n_batch)]
@@ -262,6 +262,8 @@ class HybridTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         model: torch.nn.Module,
         device: torch.device,
         pipe_n_batch,
+        apply_jit: bool,
+        execute_all_batches: bool,
     ):
         if pipe_n_batch <= 0 or pipe_n_batch > MAX_PIPE_N_BATCH:
             raise ValueError(f"pipe_n_batch must be in range in [1, {MAX_PIPE_N_BATCH}], \
@@ -280,6 +282,12 @@
 
         if model.device != device:
             raise ValueError(f"model device is {model.device}, but input device is {device}.")
+
+        if apply_jit:
+            raise ValueError("apply_jit is not supported.")
+
+        if not execute_all_batches:
+            raise ValueError("execute_all_batches cannot be False.")
 
     def enque_context(self, line_id, context: HybridTrainPipelineContext):
         self._contexts[line_id].append(context)
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/sharding_plan.py b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/sharding_plan.py
index d6510890f26d59127ee0de68cbe1c5599fc5123a..516e9b32c2ffd35df13b32ffab458d9c7bdcacc5 100644
--- a/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/sharding_plan.py
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/distributed/sharding_plan.py
@@ -16,6 +16,11 @@ from torchrec.distributed.types import ModuleSharder, ShardingEnv
 
 
 def get_default_hybrid_sharders(host_env: ShardingEnv) -> List[ModuleSharder[nn.Module]]:
+    if host_env.process_group is None:
+        raise RuntimeError("process_group should not be None.")
+    if host_env.process_group._get_backend_name() != "gloo":
+        raise RuntimeError("Rec SDK Torch only supports host dist with the gloo backend.")
+
     return [
         cast(ModuleSharder[nn.Module], HybridEmbeddingBagCollectionSharder(host_env)),
         cast(ModuleSharder[nn.Module], HybridHashEmbeddingBagCollectionSharder(host_env)),
diff --git a/torchrec/hybrid_torchrec/hybrid_torchrec/modules/hash_embeddingbag.py b/torchrec/hybrid_torchrec/hybrid_torchrec/modules/hash_embeddingbag.py
index cea2955226d86dddee22e51e8f0fbc7b1f1f9969..2b6fd712d3244e6963e7117ecac244806fb978fc 100644
--- a/torchrec/hybrid_torchrec/hybrid_torchrec/modules/hash_embeddingbag.py
+++ b/torchrec/hybrid_torchrec/hybrid_torchrec/modules/hash_embeddingbag.py
@@ -14,16 +14,23 @@ from torch import nn
 
 from hybrid_torchrec.modules.embedding_config import HYBRID_SUPPORT_DEVICE
 from hybrid_torchrec.modules.ids_process import IdsMapper
+from hybrid_torchrec.constants import (
+    MAX_EMBEDDINGS_DIM,
+    MAX_NUM_EMBEDDINGS,
+    EMBEDDINGS_DIM_ALIGNMENT,
+)
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
     pooling_type_to_str,
+    PoolingType,
 )
 from torchrec.modules.embedding_modules import (
     EmbeddingBagCollectionInterface,
     get_embedding_names_by_table,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.types import DataType
 
 
 @torch.fx.wrap
@@ -81,6 +88,66 @@ class HashEmbeddingBagConfig(EmbeddingBagConfig):
     pass
 
 
+def is_valid_feat_name(feat_name):
+    for char in feat_name:
+        if not (char.isalnum() or char == "_"):
+            return False
+    return True
+
+
+def check_embedding_config_valid(config: HashEmbeddingBagConfig):
+    if config.embedding_dim % EMBEDDINGS_DIM_ALIGNMENT != 0:
+        raise ValueError(
+            f"The embedding dim should be a multiple of {EMBEDDINGS_DIM_ALIGNMENT}, but is {config.embedding_dim}"
+        )
+    if (
+        config.embedding_dim < EMBEDDINGS_DIM_ALIGNMENT
+        or config.embedding_dim > MAX_EMBEDDINGS_DIM
+    ):
+        raise ValueError(
+            f"The embedding dim should be in [{EMBEDDINGS_DIM_ALIGNMENT}, "
+            f"{MAX_EMBEDDINGS_DIM}], but is {config.embedding_dim}"
+        )
+    if config.num_embeddings < 1 or config.num_embeddings > MAX_NUM_EMBEDDINGS:
+        raise ValueError(
+            f"The num_embeddings should be in [1, {MAX_NUM_EMBEDDINGS}], but is {config.num_embeddings}"
+        )
+    if config.data_type != DataType.FP32:
+        raise ValueError(f"The data_type should be FP32, but is {config.data_type}")
+    if config.feature_names is None or len(config.feature_names) == 0:
+        raise ValueError(
+            f"The feature_names should not be empty, but is {config.feature_names}"
+        )
+    for feat_name in config.feature_names:
+        if not is_valid_feat_name(feat_name):
+            raise ValueError(
+                f"The feature_name should only contain letters, digits and underscores, but is {feat_name}"
+            )
+    if config.weight_init_max is not None:
+        raise ValueError(
+            f"The config.weight_init_max should be None, but is {config.weight_init_max}"
+        )
+    if config.weight_init_min is not None:
+        raise ValueError(
+            f"The config.weight_init_min should be None, but is {config.weight_init_min}"
+        )
+    if config.num_embeddings_post_pruning is not None:
+        raise ValueError(
+            f"The config.num_embeddings_post_pruning should be None, but is {config.num_embeddings_post_pruning}"
+        )
+    if config.init_fn is not None and not callable(config.init_fn):
+        raise ValueError(
+            f"The config.init_fn should be callable, but is {config.init_fn}"
+        )
+    if config.pooling is not None and config.pooling not in [
+        PoolingType.SUM,
+        PoolingType.MEAN,
+    ]:
+        raise ValueError(
+            f"The config.pooling should be in [PoolingType.SUM, PoolingType.MEAN], but is {config.pooling}"
+        )
+
+
 class HashEmbeddingBag(torch.nn.Module):
     def __init__(self, config: HashEmbeddingBagConfig, device: torch.device):
         pass
@@ -94,7 +161,9 @@
     ):
         return NotImplemented
 
-    def forward(self, input_tensor: torch.Tensor, offsets: Optional[torch.Tensor] = None):
+    def forward(
+        self, input_tensor: torch.Tensor, offsets: Optional[torch.Tensor] = None
+    ):
         return NotImplemented
 
 
@@ -146,6 +215,8 @@
         table_names = set()
 
         for embedding_config in tables:
+            check_embedding_config_valid(embedding_config)
+
             if embedding_config.name in table_names:
                 raise ValueError(f"Duplicate table name {embedding_config.name}")
             table_names.add(embedding_config.name)
@@ -154,8 +225,9 @@
                 if embedding_config.data_type == DataType.FP32
                 else torch.float16
             )
-            is_hybrid_device = (isinstance(device, str) and device in HYBRID_SUPPORT_DEVICE
-                ) or (hasattr(device, 'type') and device.type in HYBRID_SUPPORT_DEVICE)
+            is_hybrid_device = (
+                isinstance(device, str) and device in HYBRID_SUPPORT_DEVICE
+            ) or (hasattr(device, "type") and device.type in HYBRID_SUPPORT_DEVICE)
             if is_hybrid_device:
                 self.embedding_bags[embedding_config.name] = HybridHashTable(
                     config=embedding_config,