From 3d02bc379d8e29caafa46ad333cdd01b51b909ca Mon Sep 17 00:00:00 2001 From: zhoucy Date: Sat, 19 Jul 2025 16:05:17 +0800 Subject: [PATCH 01/20] [feat] torch.ops.fbgemm.permute_2D_sparse_data: values supports int64, int32 and fp32 types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 248 +++++++++--------- .../test_permute2d_sparse_data.py | 26 +- 2 files changed, 137 insertions(+), 137 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 07059da9..552ff067 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -20,126 +20,126 @@ See the License for the specific language governing permissions and #include "../../../common/ops_log.h" namespace optiling { -constexpr int GM_ALIGN = 64; -constexpr int RESERVER_UB_SIZE = 20 * 1024; -constexpr int DATA_TYPE_INT64 = 8; -constexpr int DATA_TYPE_INT32 = 4; -constexpr int DATA_TYPE_FLOAT32 = 4; -constexpr int NUM_QUEUE = 4; -constexpr int UB_ALIGN = 32; -constexpr int SUPPORT_EMBEDDING_DIM_NUM = 2; -constexpr int PERMUTE_INDEX = 0; -constexpr int LENGTH_INDEX = 1; -constexpr int VALUES_INDEX = 2; - -static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) -{ - // check tensor is nullptr - OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); - // permute: InputTensor(0), support int32 - int64_t permuteDataType = 0; - ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); - if (permuteDataTypeGe == ge::DataType::DT_INT32) { - permuteDataType = DATA_TYPE_INT32; - } - - // lengths: InputTensor(1), support int64、int32 - int64_t lengthsDataType = 0; - ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); - if (lengthsDataTypeGe == ge::DataType::DT_INT64) { - lengthsDataType = DATA_TYPE_INT64; - } else { - lengthsDataType = DATA_TYPE_INT32; - } - - // value: InputTensor(2), support int64、int32、fp32 - int64_t valueDataType = 0; - ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); - if (dataType == ge::DataType::DT_INT32) { - valueDataType = DATA_TYPE_INT32; - } else if (dataType == ge::DataType::DT_INT64) { - valueDataType = DATA_TYPE_INT64; - } else { - valueDataType = DATA_TYPE_FLOAT32; - } - - tiling.set_valueDataType(valueDataType); - tiling.set_permuteDataType(permuteDataType); - tiling.set_lengthsDataType(lengthsDataType); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus TilingFunc(gert::TilingContext* context) -{ - OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); - - Permute2dSparseDataTilingData tiling; - auto 
ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - - auto permuteShape = context->GetInputShape(0)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); - auto valuesShape = context->GetInputShape(2)->GetStorageShape(); - - // set ub - uint64_t ubCanUsed; - ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); - ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; - tiling.set_ubCanUsed(ubCanUsed); - - // datatype check - if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || - (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { - printf("[ERROR]permute shape or lengths shape is error."); - return ge::GRAPH_FAILED; + constexpr int GM_ALIGN = 64; + constexpr int RESERVER_UB_SIZE = 20 * 1024; + constexpr int DATA_TYPE_INT64 = 8; + constexpr int DATA_TYPE_INT32 = 4; + constexpr int DATA_TYPE_FLOAT32 = 4; + constexpr int NUM_QUEUE = 4; + constexpr int UB_ALIGN = 32; + constexpr int SUPPORT_EMBEDDING_DIM_NUM = 2; + constexpr int PERMUTE_INDEX = 0; + constexpr int LENGTH_INDEX = 1; + constexpr int VALUES_INDEX = 2; + + static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) + { + // check tensor is nullptr + OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); + // permute: InputTensor(0), support int32 + int64_t permuteDataType = 0; + ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); + if (permuteDataTypeGe == ge::DataType::DT_INT32) { + permuteDataType = DATA_TYPE_INT32; + } + + // lengths: InputTensor(1), support int64、int32 + int64_t lengthsDataType = 0; + ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); + if (lengthsDataTypeGe == ge::DataType::DT_INT64) { + lengthsDataType = DATA_TYPE_INT64; + } else { + lengthsDataType = DATA_TYPE_INT32; + } + + // value: InputTensor(2), support int64、int32、fp32 + int64_t valueDataType = 0; + ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); + if (dataType == ge::DataType::DT_INT32) { + valueDataType = DATA_TYPE_INT32; + } else if (dataType == ge::DataType::DT_INT64) { + valueDataType = DATA_TYPE_INT64; + } else { + valueDataType = DATA_TYPE_FLOAT32; + } + + tiling.set_valueDataType(valueDataType); + tiling.set_permuteDataType(permuteDataType); + tiling.set_lengthsDataType(lengthsDataType); + return ge::GRAPH_SUCCESS; } - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { - return ge::GRAPH_FAILED; + static ge::graphStatus TilingFunc(gert::TilingContext* context) + { + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); + + Permute2dSparseDataTilingData tiling; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + + auto permuteShape = context->GetInputShape(0)->GetStorageShape(); + auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); + auto valuesShape = 
context->GetInputShape(2)->GetStorageShape(); + + // set ub + uint64_t ubCanUsed; + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); + ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; + tiling.set_ubCanUsed(ubCanUsed); + + // datatype check + if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || + (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { + printf("[ERROR]permute shape or lengths shape is error."); + return ge::GRAPH_FAILED; + } + + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + + // tiling core + int64_t totalBatch = permuteShape.GetDim(0); + tiling.set_totalBatch(totalBatch); + int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; + tiling.set_baseBatchLen(baseBatchLen); + int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; + tiling.set_tailSplitIndex(tailSplitIndex); + + // set data dim + int64_t permuteDim0 = permuteShape.GetDim(0); + tiling.set_permuteDim0(permuteDim0); + int64_t lengthsT = lengthsShape.GetDim(0); + tiling.set_lengthsT(lengthsT); + int64_t lengthsB = lengthsShape.GetDim(1); + tiling.set_lengthsB(lengthsB); + int64_t valuesDim = valuesShape.GetDim(0); + tiling.set_valuesDim(valuesDim); + + // apply workspace + size_t* currentWorkspace = context->GetWorkspaceSizes(1); + size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); + currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; + + OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, + "SetTypeTiling Failed."); + + context->SetBlockDim(coreNum); + + OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; } - tiling.set_coreNum(coreNum); - - // tiling core - int64_t totalBatch = permuteShape.GetDim(0); - tiling.set_totalBatch(totalBatch); - int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; - tiling.set_baseBatchLen(baseBatchLen); - int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; - tiling.set_tailSplitIndex(tailSplitIndex); - - // set data dim - int64_t permuteDim0 = permuteShape.GetDim(0); - tiling.set_permuteDim0(permuteDim0); - int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(lengthsT); - int64_t lengthsB = lengthsShape.GetDim(1); - tiling.set_lengthsB(lengthsB); - int64_t valuesDim = valuesShape.GetDim(0); - tiling.set_valuesDim(valuesDim); - - // apply workspace - size_t* currentWorkspace = context->GetWorkspaceSizes(1); - size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); - currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; - - OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, - "SetTypeTiling Failed."); - - context->SetBlockDim(coreNum); - - OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); - - tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); - - return ge::GRAPH_SUCCESS; -} } // namespace optiling 
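[Editor's note, not part of the patch] The tiling above splits permuteDim0 rows across coreNum AIV cores, with the first tailSplitIndex cores taking one extra row; the kernel's Init() reproduces the same split per block. A minimal sketch of that arithmetic (the example values permuteDim0=10 and coreNum=4 are assumptions, not from the patch):

def tile_rows(permute_dim0=10, core_num=4):
    base = permute_dim0 // core_num          # baseBatchLen
    tail = permute_dim0 % core_num           # tailSplitIndex
    spans = []
    for block_idx in range(core_num):
        if block_idx < tail:                 # first `tail` cores take one extra row
            length, offset = base + 1, block_idx * (base + 1)
        else:
            length, offset = base, tail * (base + 1) + (block_idx - tail) * base
        spans.append((offset, offset + length))
    return spans                             # [(0, 3), (3, 6), (6, 8), (8, 10)]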
namespace ge { @@ -178,37 +178,37 @@ public: { this->Input("permute") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .DataTypeList({ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("lengths") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_INT64, ge::DT_INT32}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("values") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("weights") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_lengths") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_INT64, ge::DT_INT32}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_values") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_weights") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index c6f6120d..4315fb59 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -14,18 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
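[Editor's note, not part of the patch] For reference, the semantics the golden CPU path in this test file checks against: lengths has shape (T, B), values concatenates each row's sum(lengths[t]) elements, and the op reorders both by permute. A minimal NumPy sketch (the function name is the editor's, not part of fbgemm):

import numpy as np

def permute_2d_sparse_ref(permute, lengths, values):
    row_sizes = lengths.sum(axis=1)                        # values owned by each row t
    offsets = np.concatenate(([0], np.cumsum(row_sizes)))  # start of each row's segment
    permuted_lengths = lengths[permute]                    # rows of lengths reordered
    permuted_values = np.concatenate(                      # segments regrouped in permute order
        [values[offsets[t]:offsets[t + 1]] for t in permute])
    return permuted_lengths, permuted_values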
# ============================================================================== - import sysconfig + import pytest import torch import torch_npu import fbgemm_gpu import numpy as np +DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -lengths_type = [np.int64, np.int32, np.int64, np.int32] -values_type = [np.int64, np.int32, np.float32, np.float32] +lengths_type = [np.int64, np.int32] +values_type = [np.int64, np.int32, np.float32] def get_result(permute, lengths, values): @@ -44,11 +45,11 @@ def get_result(permute, lengths, values): return permuted_lengths.cpu(), permuted_values.cpu() -def get_result_npu(permute, lengths, values, device="npu:0"): - torch.npu.set_device(device) - input_permute_torch = torch.from_numpy(permute).to(device) - input_lengths_torch = torch.from_numpy(lengths).to(device) - input_values_torch = torch.from_numpy(values).to(device) +def get_result_npu(permute, lengths, values): + torch.npu.set_device(DEVICE) + input_permute_torch = torch.from_numpy(permute).to(DEVICE) + input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) + input_values_torch = torch.from_numpy(values).to(DEVICE) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( @@ -59,19 +60,18 @@ def get_result_npu(permute, lengths, values): return permuted_lengths.cpu(), permuted_values.cpu() -@pytest.mark.parametrize("device", ["npu:0", "npu:5"]) -@pytest.mark.parametrize("type_list", zip(lengths_type, values_type)) +@pytest.mark.parametrize("ltype", lengths_type) +@pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(type_list, device, permute_dim, lengths): - ltype, vtype = type_list +def test_permute2d_sparse_data(permute_dim, lengths, ltype, vtype): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) input_lengths = np.ones((permute_dim, lengths), dtype=ltype) input_values = np.arange(0, permute_dim * lengths).astype(vtype) golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values, device) + result = get_result_npu(input_permute, input_lengths, input_values) assert torch.allclose(golden[0], result[0], atol=1e-5) assert torch.allclose(golden[1], result[1], atol=1e-5) -- Gitee From 5095a680856a3f35f8423d0e508108a1b37218ef Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 08:57:45 +0800 Subject: [PATCH 02/20] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support permute.size(0) <= lengths.size(0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 6 +++--- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 7 ++++--- .../2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 552ff067..126cb8ee 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -93,7 +93,7 @@ namespace optiling 
{ // datatype check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || - (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { + (permuteShape.GetDim(0) > lengthsShape.GetDim(0))) { printf("[ERROR]permute shape or lengths shape is error."); return ge::GRAPH_FAILED; } @@ -116,8 +116,8 @@ namespace optiling { // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); tiling.set_permuteDim0(permuteDim0); - int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(lengthsT); +// int64_t lengthsT = lengthsShape.GetDim(0); + tiling.set_lengthsT(permuteDim0); int64_t lengthsB = lengthsShape.GetDim(1); tiling.set_lengthsB(lengthsB); int64_t valuesDim = valuesShape.GetDim(0); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 4315fb59..5be5b643 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -63,12 +63,13 @@ def get_result_npu(permute, lengths, values): @pytest.mark.parametrize("ltype", lengths_type) @pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) +@pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, lengths, ltype, vtype): +def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) - input_lengths = np.ones((permute_dim, lengths), dtype=ltype) - input_values = np.arange(0, permute_dim * lengths).astype(vtype) + input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) + input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) golden = get_result(input_permute, input_lengths, input_values) result = get_result_npu(input_permute, input_lengths, input_values) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index f24d0534..40114c81 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -33,8 +33,8 @@ tuple> permute2d_sparse_data_impl_npu( const auto T = lengths.size(0); const auto B = lengths.size(1); - at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); - at::Tensor outValues = at::empty({valuesConti.size(0)}, valuesConti.options()); + at::Tensor outLengths = lengthsConti.clone(); + at::Tensor outValues = valuesConti.clone(); at::Tensor outWeights = at::empty({1}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, -- Gitee From 4daa8b04832838e6cce9d837c22ffbaa9a8629a4 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 14:31:33 +0800 Subject: [PATCH 03/20] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support permute.size(0) <= lengths.size(0) 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 4 ++-- .../permute2d_sparse_data.cpp | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 126cb8ee..f78556fa 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -116,8 +116,8 @@ namespace optiling { // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); tiling.set_permuteDim0(permuteDim0); -// int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(permuteDim0); + int64_t lengthsT = permuteDim0; + tiling.set_lengthsT(lengthsT); int64_t lengthsB = lengthsShape.GetDim(1); tiling.set_lengthsB(lengthsB); int64_t valuesDim = valuesShape.GetDim(0); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 40114c81..c7d79fcb 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -30,11 +30,25 @@ tuple> permute2d_sparse_data_impl_npu( // weights not supported yet at::Tensor weightsConti = at::empty({1}, lengths.options()); - const auto T = lengths.size(0); + const auto T = permute.size(0); const auto B = lengths.size(1); - at::Tensor outLengths = lengthsConti.clone(); - at::Tensor outValues = valuesConti.clone(); + int outValuesLen; + if (permute.size(0) == lengths.size(0)) { + outValuesLen = valuesConti.size(0); + } else if (permute.size(0) > lengths.size(0)) { + throw std::runtime_error( + "permute.size(0) must be less than or equal to lengths.size(0). 
" + "Got permute.size(0): " + std::to_string(permute.size(0)) + + ", lengths.size(0): " + std::to_string(lengths.size(0))); + } else if (permuted_lengths_sum && permuted_lengths_sum > 0) { + outValuesLen = permuted_lengths_sum; + } else { + outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); + } + + at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); + at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); at::Tensor outWeights = at::empty({1}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, -- Gitee From aea35649565a596920a7c168c79dce0a2460b5d1 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 15:11:05 +0800 Subject: [PATCH 04/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82permute.size(0)=20<=3D=20lengths.size(0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_permute2d_sparse_data.py | 10 ++++++---- .../permute2d_sparse_data.cpp | 19 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 5be5b643..f5db3bc2 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -45,7 +45,7 @@ def get_result(permute, lengths, values): return permuted_lengths.cpu(), permuted_values.cpu() -def get_result_npu(permute, lengths, values): +def get_result_npu(permute, lengths, values, permuted_lengths_sum=-1): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) @@ -53,7 +53,7 @@ def get_result_npu(permute, lengths, values): (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, input_lengths_torch, input_values_torch, None + input_permute_torch, input_lengths_torch, input_values_torch, None, permuted_lengths_sum ) ) torch.npu.synchronize() @@ -64,15 +64,17 @@ def get_result_npu(permute, lengths, values): @pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) +@pytest.mark.parametrize("permuted_lengths_sum", [None, -1, 0, 1]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype): +def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype, permuted_lengths_sum): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) + permuted_lengths_sum = permuted_lengths_sum if permuted_lengths_sum != 1 else input_lengths[:permute_dim].sum() golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values) + result = get_result_npu(input_permute, input_lengths, input_values, permuted_lengths_sum) assert torch.allclose(golden[0], 
result[0], atol=1e-5) assert torch.allclose(golden[1], result[1], atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index c7d79fcb..bae5287e 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -37,12 +37,11 @@ tuple> permute2d_sparse_data_impl_npu( if (permute.size(0) == lengths.size(0)) { outValuesLen = valuesConti.size(0); } else if (permute.size(0) > lengths.size(0)) { - throw std::runtime_error( - "permute.size(0) must be less than or equal to lengths.size(0). " - "Got permute.size(0): " + std::to_string(permute.size(0)) + - ", lengths.size(0): " + std::to_string(lengths.size(0))); - } else if (permuted_lengths_sum && permuted_lengths_sum > 0) { - outValuesLen = permuted_lengths_sum; + throw std::runtime_error("permute.size(0) must be less than or equal to lengths.size(0). " + "Got permute.size(0): " + std::to_string(permute.size(0)) + + ", lengths.size(0): " + std::to_string(lengths.size(0))); + } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.has_value() > 0) { + outValuesLen = static_cast(permuted_lengths_sum.value()); } else { outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); } @@ -60,10 +59,10 @@ tuple> permute2d_sparse_data_impl_npu( TORCH_LIBRARY(mxrec, m) { m.def("permute_2D_sparse_data(Tensor permute, " - "Tensor lengths, " - "Tensor values, " - "Tensor? weights=None, " - "SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); + " Tensor lengths, " + " Tensor values, " + " Tensor? weights=None, " + " SymInt? 
permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); } TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) -- Gitee From 79ded207b320a21654b929e30877f9e9ebaa65de Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 22 Jul 2025 11:01:17 +0800 Subject: [PATCH 05/20] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support the weights parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 103 ++++++---------- .../op_host/permute2d_sparse_data_tilling.h | 5 +- .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 113 ++++++++++++------ .../test_permute2d_sparse_data.py | 49 +++++--- .../permute2d_sparse_data.cpp | 8 +- 6 files changed, 155 insertions(+), 125 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index f78556fa..d979869d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -31,87 +31,40 @@ namespace optiling { constexpr int PERMUTE_INDEX = 0; constexpr int LENGTH_INDEX = 1; constexpr int VALUES_INDEX = 2; + constexpr int WEIGHTS_INDEX = 3; - static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) + static ge::graphStatus TilingFunc(gert::TilingContext* context) { - // check tensor is nullptr - OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); - // permute: InputTensor(0), support int32 - int64_t permuteDataType = 0; - ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); - if (permuteDataTypeGe == ge::DataType::DT_INT32) { - permuteDataType = DATA_TYPE_INT32; - } - - // lengths: InputTensor(1), support int64、int32 - int64_t lengthsDataType = 0; - ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); - if (lengthsDataTypeGe == ge::DataType::DT_INT64) { - lengthsDataType = DATA_TYPE_INT64; - } else { - lengthsDataType = DATA_TYPE_INT32; - } - - // value: InputTensor(2), support int64、int32、fp32 - int64_t valueDataType = 0; - ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); - if (dataType == ge::DataType::DT_INT32) { - valueDataType = DATA_TYPE_INT32; - } else if (dataType == ge::DataType::DT_INT64) { - valueDataType = DATA_TYPE_INT64; - } else { - valueDataType = DATA_TYPE_FLOAT32; - } + Permute2dSparseDataTilingData tiling; - tiling.set_valueDataType(valueDataType); - tiling.set_permuteDataType(permuteDataType); - tiling.set_lengthsDataType(lengthsDataType); - return ge::GRAPH_SUCCESS; - } + bool enableWeights = (context->GetOptionalInputTensor(WEIGHTS_INDEX) != nullptr); + tiling.set_enableWeights(enableWeights); - static ge::graphStatus TilingFunc(gert::TilingContext* context) - { OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); 
OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); + if (enableWeights) { + OPS_LOG_E_IF_NULL("weightsShape", context->GetInputShape(WEIGHTS_INDEX), return ge::GRAPH_FAILED); + } - Permute2dSparseDataTilingData tiling; auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - auto permuteShape = context->GetInputShape(0)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); - auto valuesShape = context->GetInputShape(2)->GetStorageShape(); - - // set ub - uint64_t ubCanUsed; - ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); - ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; - tiling.set_ubCanUsed(ubCanUsed); + auto permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); + auto lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); + auto valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); + auto weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); - // datatype check + // shape check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || (permuteShape.GetDim(0) > lengthsShape.GetDim(0))) { - printf("[ERROR]permute shape or lengths shape is error."); + OPS_LOG_E("", "[ERROR]permute shape or lengths shape is error. "); return ge::GRAPH_FAILED; } - - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { + if (enableWeights && valuesShape != weightsShape) { + OPS_LOG_E("", "[ERROR]values shape or weights shape is error. "); return ge::GRAPH_FAILED; } - tiling.set_coreNum(coreNum); - - // tiling core - int64_t totalBatch = permuteShape.GetDim(0); - tiling.set_totalBatch(totalBatch); - int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; - tiling.set_baseBatchLen(baseBatchLen); - int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; - tiling.set_tailSplitIndex(tailSplitIndex); // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); @@ -123,14 +76,32 @@ namespace optiling { int64_t valuesDim = valuesShape.GetDim(0); tiling.set_valuesDim(valuesDim); + // tiling core + int64_t totalBatch = permuteDim0; + tiling.set_totalBatch(totalBatch); + int64_t baseBatchLen = permuteDim0 / coreNum; + tiling.set_baseBatchLen(baseBatchLen); + int64_t tailSplitIndex = permuteDim0 % coreNum; + tiling.set_tailSplitIndex(tailSplitIndex); + + // set ub + uint64_t ubCanUsed; + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); + ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; + tiling.set_ubCanUsed(ubCanUsed); + + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + // apply workspace size_t* currentWorkspace = context->GetWorkspaceSizes(1); size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; - OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, - "SetTypeTiling Failed."); - context->SetBlockDim(coreNum); OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h 
b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h index 97115458..12a22507 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h @@ -27,10 +27,7 @@ namespace optiling { TILING_DATA_FIELD_DEF(int64_t, lengthsT); TILING_DATA_FIELD_DEF(int64_t, lengthsB); TILING_DATA_FIELD_DEF(int64_t, valuesDim); - - TILING_DATA_FIELD_DEF(int64_t, valueDataType); - TILING_DATA_FIELD_DEF(int64_t, permuteDataType); - TILING_DATA_FIELD_DEF(int64_t, lengthsDataType); + TILING_DATA_FIELD_DEF(bool, enableWeights); TILING_DATA_FIELD_DEF(int64_t, totalBatch); TILING_DATA_FIELD_DEF(int64_t, baseBatchLen); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp index 014bda5b..dae264ed 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp @@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_ { Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights, workspace, tiling}; - Permute2dSparseData::Permute2dSparseDataKernel kernel(args); + Permute2dSparseData::Permute2dSparseDataKernel kernel(args); kernel.Compute(); } \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 03b661c8..77646c9b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -42,6 +42,7 @@ struct Args { GM_ADDR tiling; }; +template class Permute2dSparseDataKernel { public: __aicore__ inline Permute2dSparseDataKernel(Args args) @@ -55,10 +56,6 @@ public: lengthsB = tilingData.lengthsB; valuesDim = tilingData.valuesDim; - valueDataType = tilingData.valueDataType; - permuteDataType = tilingData.permuteDataType; - lengthsDataType = tilingData.lengthsDataType; - totalBatch = tilingData.totalBatch; baseBatchLen = tilingData.baseBatchLen; tailSplitIndex = tilingData.tailSplitIndex; @@ -70,6 +67,8 @@ public: values = args.values; weights = args.weights; + enableWeights = args.enableWeights; + outLengths = args.out_lengths; outIndices = args.out_indices; outWeights = args.out_weights; @@ -84,13 +83,17 @@ public: tOffsetOfThisCore = tailSplitIndex * (baseBatchLen + 1) + (GetBlockIdx() - tailSplitIndex) * baseBatchLen; } - permuteGT.SetGlobalBuffer(permute, permuteDim0 * permuteDataType); - lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * lengthsDataType); - valuesGT.SetGlobalBuffer(values, valuesDim * valueDataType); + permuteGT.SetGlobalBuffer(permute, permuteDim0 * sizeof(int32_t)); + lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * sizeof(LType)); + valuesGT.SetGlobalBuffer(values, valuesDim * sizeof(VType)); - outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * lengthsDataType); - outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * valueDataType); + 
outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * sizeof(LType)); + outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * sizeof(VType)); + if (enableWeights) { + weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(WType)); + outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(WType)); + } // Init pipe pipe.InitBuffer(inQueueX, USE_QUEUE_NUM, ubCanUsed / USE_QUEUE_NUM); blockLen = ubCanUsed / USE_QUEUE_NUM; @@ -139,25 +142,15 @@ public: offsetPtr = (__gm__ int64_t*)workspace; GlobalTensor offsetGT; offsetGT.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); - if (lengthsDataType == DATA_TYPE_INT64) { - __gm__ int64_t* lengthsPtr = (__gm__ int64_t*)lengths; - for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { - int64_t offsetT = 0; - for (int64_t j = 0; j < lengthsB; j++) { - offsetT += *(lengthsPtr + i * lengthsB + j); - } - offsetGT.SetValue(i * UB_ALIGN, offsetT); - } - } else { - __gm__ int32_t* lengthsPtr = (__gm__ int32_t*)lengths; - for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { - int64_t offsetT = 0; - for (int64_t j = 0; j < lengthsB; j++) { - offsetT += *(lengthsPtr + i * lengthsB + j); - } - offsetGT.SetValue(i * UB_ALIGN, offsetT); + __gm__ LType* lengthsPtr = (__gm__ LType*)lengths; + for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { + int64_t offsetT = 0; + for (int64_t j = 0; j < lengthsB; j++) { + offsetT += *(lengthsPtr + i * lengthsB + j); } + offsetGT.SetValue(i * UB_ALIGN, offsetT); } + AscendC::DataCacheCleanAndInvalid(offsetGT); } @@ -169,11 +162,11 @@ public: int64_t ToffsetThisIndex = *(permutePtr + i); int64_t ToffsetNextIndex = *(permutePtr + i) + 1; - int64_t lengthsStartIndex = ToffsetThisIndex * lengthsB * lengthsDataType; - int64_t lengthsEndIndex = ToffsetNextIndex * lengthsB * lengthsDataType; + int64_t lengthsStartIndex = ToffsetThisIndex * lengthsB * sizeof(LType); + int64_t lengthsEndIndex = ToffsetNextIndex * lengthsB * sizeof(LType); - int64_t outStartIndex = i * lengthsB * lengthsDataType; - int64_t outEndIndex = (i + 1) * lengthsB * lengthsDataType; + int64_t outStartIndex = i * lengthsB * sizeof(LType); + int64_t outEndIndex = (i + 1) * lengthsB * sizeof(LType); int64_t totalLen = lengthsEndIndex - lengthsStartIndex; int64_t remainLen = totalLen; while (remainLen > 0) { @@ -245,6 +238,54 @@ public: } } + __aicore__ void PermuteWeights() + { + int64_t outWeightOffset = 0; + int64_t currentT = 0; + for (int64_t i = 0; i < permuteDim0; i++) { + currentT = *(permutePtr + i); + int64_t tLen = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN) - *(totalOffsetPtr + currentT * UB_ALIGN); + int64_t baseCoreLen = tLen / coreNum; + int64_t tailLen = tLen % coreNum; + + // calculate current core permute weights offset + if (GetBlockIdx() < tailLen) { + weightLenOfThisCore = baseCoreLen + 1; + offsetOfThisCore = GetBlockIdx() * (baseCoreLen + 1); + } else { + weightLenOfThisCore = baseCoreLen; + offsetOfThisCore = tailLen * (baseCoreLen + 1) + (GetBlockIdx() - tailLen) * baseCoreLen; + } + + int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); + int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); + + int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(WType); + int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(WType); + + int64_t remainLen = weightLenOfThisCore * sizeof(WType); + while (remainLen > 0) { + int64_t thisLen = blockLen; + if (remainLen < 
blockLen) { + thisLen = remainLen; + } + LocalTensor inputTensor = inQueueX.AllocTensor(); + + CpGm2Local(inputTensor, weightsGT[weightsStartIndex], thisLen); + inQueueX.EnQue(inputTensor); + LocalTensor outPutTensor = inQueueX.DeQue(); + + CpLocal2Gm(outWeightsGT[outWeightStartIndex], outPutTensor, thisLen); + + outWeightStartIndex += thisLen; + weightsStartIndex += thisLen; + inQueueX.FreeTensor(outPutTensor); + remainLen = remainLen - thisLen; + } + outWeightOffset += tLen; + } + } + __aicore__ void Compute() { CalculateOffsets(); @@ -264,6 +305,9 @@ public: } PermuteLengths(); PermuteValues(); + if (enableWeights) { + PermuteWeights(); + } } private: @@ -282,11 +326,7 @@ private: int64_t lengthsT; int64_t lengthsB; int64_t valuesDim; - - // DataType - int64_t valueDataType; - int64_t permuteDataType; - int64_t lengthsDataType; + bool enableWeights; // Tiling int64_t totalBatch; @@ -304,6 +344,7 @@ private: // ThisCoreLen for B int64_t valueLenOfThisCore; + int64_t weightLenOfThisCore; int64_t offsetOfThisCore; // Tpipe @@ -314,8 +355,10 @@ private: GlobalTensor permuteGT; GlobalTensor lengthsGT; GlobalTensor valuesGT; + GlobalTensor weightsGT; GlobalTensor outLengthsGT; GlobalTensor outIndicesGT; + GlobalTensor outWeightsGT; __gm__ int64_t* offsetPtr; __gm__ int32_t* permutePtr; diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index f5db3bc2..72d2c395 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -27,54 +27,71 @@ torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") lengths_type = [np.int64, np.int32] values_type = [np.int64, np.int32, np.float32] +weights_type = [None, np.int64, np.int32, np.float32] -def get_result(permute, lengths, values): +def get_result(permute, lengths, values, weights, permuted_lengths_sum): input_permute_torch = torch.from_numpy(permute) input_lengths_torch = torch.from_numpy(lengths) input_values_torch = torch.from_numpy(values) + input_weights_torch = torch.from_numpy(weights) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( input_permute_torch, input_lengths_torch, input_values_torch, + input_weights_torch, + permuted_lengths_sum ) ) - return permuted_lengths.cpu(), permuted_values.cpu() + return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -def get_result_npu(permute, lengths, values, permuted_lengths_sum=-1): +def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=-1): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) input_values_torch = torch.from_numpy(values).to(DEVICE) + input_weights_torch = torch.from_numpy(weights).to(DEVICE) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, input_lengths_torch, input_values_torch, None, permuted_lengths_sum + input_permute_torch, + input_lengths_torch, + input_values_torch, + input_weights_torch, + permuted_lengths_sum ) ) torch.npu.synchronize() - return permuted_lengths.cpu(), permuted_values.cpu() + return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() 
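[Editor's note, not part of the patch] With this change the op also carries an optional weights tensor through the same permutation as values, and permuted_lengths_sum can pre-size the outputs. A hypothetical call sketch (assumes libfbgemm_npu_api.so is loaded and the DEVICE defined above is reachable; the shapes are example values):

def demo_weights_call():
    permute = torch.tensor([2, 0, 1], dtype=torch.int32, device=DEVICE)
    lengths = torch.ones((3, 4), dtype=torch.int64, device=DEVICE)   # (T, B), 12 values total
    values = torch.arange(12, dtype=torch.float32, device=DEVICE)
    weights = torch.rand(12, dtype=torch.float32, device=DEVICE)     # one weight per value
    out_lengths, out_values, out_weights = torch.ops.fbgemm.permute_2D_sparse_data(
        permute, lengths, values, weights, 12)                       # 12 == permuted lengths sum
    return out_lengths.cpu(), out_values.cpu(), out_weights.cpu()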
@pytest.mark.parametrize("ltype", lengths_type) @pytest.mark.parametrize("vtype", values_type) +@pytest.mark.parametrize("wtype", weights_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) -@pytest.mark.parametrize("permuted_lengths_sum", [None, -1, 0, 1]) +@pytest.mark.parametrize("permuted_lengths_sum", [True, False]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype, permuted_lengths_sum): - input_permute = np.arange(permute_dim).astype(np.int32) - np.random.shuffle(input_permute) - input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) - permuted_lengths_sum = permuted_lengths_sum if permuted_lengths_sum != 1 else input_lengths[:permute_dim].sum() +def test_permute2d_sparse_data(ltype, + vtype, + wtype, + permute_dim, + extra_permute_dim, + permuted_lengths_sum, + lengths): + permute = np.arange(permute_dim).astype(np.int32) + np.random.shuffle(permute) + lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) + values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) + weights = None if wtype is None else np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(wtype) + permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None - golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values, permuted_lengths_sum) + golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) + result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) - assert torch.allclose(golden[0], result[0], atol=1e-5) - assert torch.allclose(golden[1], result[1], atol=1e-5) + for gt, pred in zip(golden, result): + assert torch.allclose(gt, pred, atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index bae5287e..7df2f6ed 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,8 +27,10 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - // weight暂不支持 at::Tensor weightsConti = at::empty({1}, lengths.options()); + if (weights.has_value()) { + weightsConti = weights.value().contiguous(); + } const auto T = permute.size(0); const auto B = lengths.size(1); @@ -40,7 +42,7 @@ tuple> permute2d_sparse_data_impl_npu( throw std::runtime_error("permute.size(0) must be less than or equal to lengths.size(0). 
" "Got permute.size(0): " + std::to_string(permute.size(0)) + ", lengths.size(0): " + std::to_string(lengths.size(0))); - } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.has_value() > 0) { + } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.value() > 0) { outValuesLen = static_cast(permuted_lengths_sum.value()); } else { outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); @@ -48,7 +50,7 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::empty({1}, weightsConti.options()); + at::Tensor outWeights = at::empty({outValuesLen}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, outLengths, outValues, outWeights); -- Gitee From 76986db6bf4f019c1fd651eddc0aea9eb17cdce6 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 09:57:26 +0800 Subject: [PATCH 06/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 14 +++++++------- .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 8 ++++---- .../test_permute2d_sparse_data.py | 12 +++++++----- .../permute2d_sparse_data.cpp | 10 ++++++---- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index d979869d..0f7ce5ee 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -76,6 +76,13 @@ namespace optiling { int64_t valuesDim = valuesShape.GetDim(0); tiling.set_valuesDim(valuesDim); + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + // tiling core int64_t totalBatch = permuteDim0; tiling.set_totalBatch(totalBatch); @@ -90,13 +97,6 @@ namespace optiling { ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; tiling.set_ubCanUsed(ubCanUsed); - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { - return ge::GRAPH_FAILED; - } - tiling.set_coreNum(coreNum); - // apply workspace size_t* currentWorkspace = context->GetWorkspaceSizes(1); size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp index dae264ed..067a8483 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp @@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_ { Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights, workspace, tiling}; - Permute2dSparseData::Permute2dSparseDataKernel 
kernel(args); + Permute2dSparseData::Permute2dSparseDataKernel kernel(args); kernel.Compute(); } \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 77646c9b..00acba8d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -67,7 +67,7 @@ public: values = args.values; weights = args.weights; - enableWeights = args.enableWeights; + enableWeights = tilingData.enableWeights; outLengths = args.out_lengths; outIndices = args.out_indices; @@ -212,10 +212,10 @@ public: int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * valueDataType; - int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * valueDataType; + int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * sizeof(VType); + int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * sizeof(VType); - int64_t remainLen = valueLenOfThisCore * valueDataType; + int64_t remainLen = valueLenOfThisCore * sizeof(VType); while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 72d2c395..292428a3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -49,7 +49,7 @@ def get_result(permute, lengths, values, weights, permuted_lengths_sum): return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=-1): +def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) @@ -83,15 +83,17 @@ def test_permute2d_sparse_data(ltype, extra_permute_dim, permuted_lengths_sum, lengths): - permute = np.arange(permute_dim).astype(np.int32) + permute = np.arange(permute_dim, dtype=np.int32) np.random.shuffle(permute) + values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) + weights = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=wtype) if wtype else None lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) - weights = None if wtype is None else np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(wtype) permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) for gt, pred in zip(golden, result): - assert torch.allclose(gt, pred, atol=1e-5) + assert type(gt) is type(pred) + if isinstance(gt, torch.Tensor) and isinstance(pred, torch.Tensor): 
+ assert torch.allclose(gt, pred, atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 7df2f6ed..95ca13b5 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -50,12 +50,14 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::empty({outValuesLen}, weightsConti.options()); - - EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, + at::Tensor outWeights = at::Tensor(); + if (weights.has_value()) { + outWeights = at::empty({outValuesLen}, weightsConti.options()); + } + EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen, outLengths, outValues, outWeights); - return make_tuple(outLengths, outValues, at::Tensor()); + return make_tuple(outLengths, outValues, outWeights); } TORCH_LIBRARY(mxrec, m) -- Gitee From 03ba8e8f1479789845bc7e517bc0730a491a5bd0 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 15:05:52 +0800 Subject: [PATCH 07/20] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support the weights parameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 25 +++++++------------ .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 12 ++++----- .../test_permute2d_sparse_data.py | 11 +++----- .../permute2d_sparse_data.cpp | 10 ++------ 5 files changed, 21 insertions(+), 39 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 0f7ce5ee..262dcc40 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -150,38 +150,31 @@ public: this->Input("permute") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("lengths") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("values") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("weights") .ParamType(OPTIONAL) - .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, 
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            .DataTypeList({ge::DT_FLOAT})
+            .FormatList({ge::FORMAT_ND});
         this->Output("permuted_lengths")
             .ParamType(REQUIRED)
             .DataTypeList({ge::DT_INT64, ge::DT_INT32})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            .FormatList({ge::FORMAT_ND});
         this->Output("permuted_values")
             .ParamType(REQUIRED)
             .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            .FormatList({ge::FORMAT_ND});
         this->Output("permuted_weights")
             .ParamType(OPTIONAL)
-            .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            .DataTypeList({ge::DT_FLOAT})
+            .FormatList({ge::FORMAT_ND});
 
         this->Attr("permuted_sum").Int(0);
 
diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp
index 067a8483..7d226eab 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp
@@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_
 {
     Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights,
                                    workspace, tiling};
-    Permute2dSparseData::Permute2dSparseDataKernel kernel(args);
+    Permute2dSparseData::Permute2dSparseDataKernel kernel(args);
     kernel.Compute();
 }
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
index 00acba8d..6858fda1 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
@@ -42,7 +42,7 @@ struct Args {
     GM_ADDR tiling;
 };
 
-template <typename VType, typename WType>
+template <typename VType>
 class Permute2dSparseDataKernel {
 public:
     __aicore__ inline Permute2dSparseDataKernel(Args args)
@@ -91,8 +91,8 @@ public:
         outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * sizeof(VType));
 
         if (enableWeights) {
-            weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(WType));
-            outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(WType));
+            weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(float));
+            outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(float));
         }
         // Init pipe
         pipe.InitBuffer(inQueueX, USE_QUEUE_NUM, ubCanUsed / USE_QUEUE_NUM);
@@ -260,10 +260,10 @@ public:
             int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN);
             int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN);
 
-            int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(WType);
-            int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(WType);
+            int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(float);
+            int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(float);
 
-            int64_t remainLen = weightLenOfThisCore *
sizeof(WType); + int64_t remainLen = weightLenOfThisCore * sizeof(float); while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 292428a3..8e231e2a 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,11 +25,6 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -lengths_type = [np.int64, np.int32] -values_type = [np.int64, np.int32, np.float32] -weights_type = [None, np.int64, np.int32, np.float32] - - def get_result(permute, lengths, values, weights, permuted_lengths_sum): input_permute_torch = torch.from_numpy(permute) input_lengths_torch = torch.from_numpy(lengths) @@ -69,9 +64,9 @@ def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None) return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -@pytest.mark.parametrize("ltype", lengths_type) -@pytest.mark.parametrize("vtype", values_type) -@pytest.mark.parametrize("wtype", weights_type) +@pytest.mark.parametrize("ltype", [np.int64, np.int32]) +@pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) +@pytest.mark.parametrize("wtype", [None, np.float32]) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("permuted_lengths_sum", [True, False]) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 95ca13b5..f9c967ae 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,10 +27,7 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - at::Tensor weightsConti = at::empty({1}, lengths.options()); - if (weights.has_value()) { - weightsConti = weights.value().contiguous(); - } + auto weightsConti = weigths.value_or(at::Tensor()).contiguous(); const auto T = permute.size(0); const auto B = lengths.size(1); @@ -50,10 +47,7 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::Tensor(); - if (weights.has_value()) { - outWeights = at::empty({outValuesLen}, weightsConti.options()); - } + at::Tensor outWeights = weights.has_value() ? 
at::empty({outValuesLen}, weightsConti.options()) : at::Tensor(); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen, outLengths, outValues, outWeights); -- Gitee From f22f303fca3cbe3dc1f4b145a4f38a1eb29c6faf Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 16:16:29 +0800 Subject: [PATCH 08/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 11 ++-- .../test_permute2d_sparse_data.py | 60 ++++++------------- .../permute2d_sparse_data.cpp | 2 +- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 262dcc40..e732785b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -50,10 +50,13 @@ namespace optiling { auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - auto permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); - auto valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); - auto weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); + gert::Shape permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); + gert::Shape lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); + gert::Shape valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); + gert::Shape weightsShape; + if (enableWeights) { + weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); + } // shape check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 8e231e2a..61801fb3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,44 +25,26 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -def get_result(permute, lengths, values, weights, permuted_lengths_sum): - input_permute_torch = torch.from_numpy(permute) - input_lengths_torch = torch.from_numpy(lengths) - input_values_torch = torch.from_numpy(values) - input_weights_torch = torch.from_numpy(weights) - - (permuted_lengths, permuted_values, permuted_weights) = ( - torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, - input_lengths_torch, - input_values_torch, - input_weights_torch, - permuted_lengths_sum - ) - ) - - return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() +def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'): + tensors = { + 'permute': torch.from_numpy(permute), + 'lengths': torch.from_numpy(lengths), + 'values': 
torch.from_numpy(values), + 'weights': torch.from_numpy(weights) if weights is not None else None + } + if device and device.startswith('npu'): + torch.npu.set_device(device) + tensors = {k: v.to(device) if v is not None else None for k, v in tensors.items()} -def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None): - torch.npu.set_device(DEVICE) - input_permute_torch = torch.from_numpy(permute).to(DEVICE) - input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) - input_values_torch = torch.from_numpy(values).to(DEVICE) - input_weights_torch = torch.from_numpy(weights).to(DEVICE) - - (permuted_lengths, permuted_values, permuted_weights) = ( - torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, - input_lengths_torch, - input_values_torch, - input_weights_torch, - permuted_lengths_sum - ) + results = torch.ops.fbgemm.permute_2D_sparse_data( + permuted_lengths_sum=permuted_lengths_sum, **tensors ) - torch.npu.synchronize() - return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() + if device: + torch.npu.synchronize() + + return tuple(r.cpu() if isinstance(r, torch.Tensor) else r for r in results) @pytest.mark.parametrize("ltype", [np.int64, np.int32]) @pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) @@ -71,13 +53,7 @@ def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("permuted_lengths_sum", [True, False]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(ltype, - vtype, - wtype, - permute_dim, - extra_permute_dim, - permuted_lengths_sum, - lengths): +def test_permute2d_sparse_data(ltype, vtype, wtype, permute_dim, extra_permute_dim, permuted_lengths_sum, lengths): permute = np.arange(permute_dim, dtype=np.int32) np.random.shuffle(permute) values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) @@ -86,7 +62,7 @@ def test_permute2d_sparse_data(ltype, permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) - result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) + result = get_result(permute, lengths, values, weights, permuted_lengths_sum, DEVICE) for gt, pred in zip(golden, result): assert type(gt) is type(pred) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index f9c967ae..a6b29e4f 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,7 +27,7 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - auto weightsConti = weigths.value_or(at::Tensor()).contiguous(); + auto weightsConti = weights.value_or(at::Tensor()).contiguous(); const auto T = permute.size(0); const auto B = lengths.size(1); -- Gitee From f3c94ffc2c9c191822bef8e590e6c4b59b6b29e3 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 09:17:06 +0800 Subject: [PATCH 09/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= 
=?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 66 +++---------------- .../test_permute2d_sparse_data.py | 10 ++- 2 files changed, 13 insertions(+), 63 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 6858fda1..83cbfa1b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -190,7 +190,7 @@ public: } } - __aicore__ void PermuteValues() + __aicore__ void PermuteData(GlobalTensor srcGT, GlobalTensor dstGT, uint8_t datasize) { int64_t outValueOffset = 0; int64_t currentT = 0; @@ -212,10 +212,10 @@ public: int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * sizeof(VType); - int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * sizeof(VType); + int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * datasize; + int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * datasize; - int64_t remainLen = valueLenOfThisCore * sizeof(VType); + int64_t remainLen = valueLenOfThisCore * datasize; while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { @@ -223,66 +223,18 @@ public: } LocalTensor inputTensor = inQueueX.AllocTensor(); - CpGm2Local(inputTensor, valuesGT[valuesStartIndex], thisLen); + CpGm2Local(inputTensor, srcGT[valuesStartIndex], thisLen); inQueueX.EnQue(inputTensor); LocalTensor outPutTensor = inQueueX.DeQue(); - CpLocal2Gm(outIndicesGT[outValueStartIndex], outPutTensor, thisLen); + CpLocal2Gm(dstGT[outValueStartIndex], outPutTensor, thisLen); outValueStartIndex += thisLen; valuesStartIndex += thisLen; inQueueX.FreeTensor(outPutTensor); remainLen = remainLen - thisLen; } - outValueOffset+=tLen; - } - } - - __aicore__ void PermuteWeights() - { - int64_t outWeightOffset = 0; - int64_t currentT = 0; - for (int64_t i = 0; i < permuteDim0; i++) { - currentT = *(permutePtr + i); - int64_t tLen = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN) - *(totalOffsetPtr + currentT * UB_ALIGN); - int64_t baseCoreLen = tLen / coreNum; - int64_t tailLen = tLen % coreNum; - - // calculate current core permute weights offset - if (GetBlockIdx() < tailLen) { - weightLenOfThisCore = baseCoreLen + 1; - offsetOfThisCore = GetBlockIdx() * (baseCoreLen + 1); - } else { - weightLenOfThisCore = baseCoreLen; - offsetOfThisCore = tailLen * (baseCoreLen + 1) + (GetBlockIdx() - tailLen) * baseCoreLen; - } - - int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); - int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - - int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(float); - int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(float); - - int64_t remainLen = weightLenOfThisCore * sizeof(float); - while (remainLen > 0) { - int64_t thisLen = blockLen; - if (remainLen < blockLen) { - thisLen = remainLen; - } - LocalTensor inputTensor = inQueueX.AllocTensor(); - - CpGm2Local(inputTensor, weightsGT[weightsStartIndex], 
thisLen); - inQueueX.EnQue(inputTensor); - LocalTensor outPutTensor = inQueueX.DeQue(); - - CpLocal2Gm(outWeightsGT[outWeightStartIndex], outPutTensor, thisLen); - - outWeightStartIndex += thisLen; - weightsStartIndex += thisLen; - inQueueX.FreeTensor(outPutTensor); - remainLen = remainLen - thisLen; - } - outWeightOffset += tLen; + outValueOffset += tLen; } } @@ -304,9 +256,9 @@ public: offsetGt.GetValue((i - 1) * UB_ALIGN); } PermuteLengths(); - PermuteValues(); + PermuteData(valuesGT, outIndicesGT, sizeof(VType)); if (enableWeights) { - PermuteWeights(); + PermuteData(weightsGT, outWeightsGT, sizeof(float)); } } diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 61801fb3..36a1ec21 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,26 +25,24 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'): tensors = { 'permute': torch.from_numpy(permute), 'lengths': torch.from_numpy(lengths), 'values': torch.from_numpy(values), - 'weights': torch.from_numpy(weights) if weights is not None else None + 'weights': torch.from_numpy(weights) if isinstance(weights, torch.Tensor) else None } if device and device.startswith('npu'): torch.npu.set_device(device) - tensors = {k: v.to(device) if v is not None else None for k, v in tensors.items()} + tensors = {k: v.to(device) if isinstance(v, torch.Tensor) else None for k, v in tensors.items()} results = torch.ops.fbgemm.permute_2D_sparse_data( permuted_lengths_sum=permuted_lengths_sum, **tensors ) + return tuple(result.cpu() if isinstance(result, torch.Tensor) else result for result in results) - if device: - torch.npu.synchronize() - - return tuple(r.cpu() if isinstance(r, torch.Tensor) else r for r in results) @pytest.mark.parametrize("ltype", [np.int64, np.int32]) @pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) -- Gitee From ec51f7895696c866ee31fab6d5ba73de65df41f5 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 10:16:47 +0800 Subject: [PATCH 10/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 83cbfa1b..9fa0cd1e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -190,7 +190,7 @@ public: } } - __aicore__ void PermuteData(GlobalTensor srcGT, GlobalTensor dstGT, uint8_t datasize) + __aicore__ void PermuteData(GlobalTensor dstGT, GlobalTensor srcGT, uint8_t datasize) { int64_t outValueOffset = 0; int64_t currentT = 
0;
@@ -256,9 +256,9 @@ public:
             offsetGt.GetValue((i - 1) * UB_ALIGN);
         }
         PermuteLengths();
-        PermuteData(valuesGT, outIndicesGT, sizeof(VType));
+        PermuteData(outIndicesGT, valuesGT, sizeof(VType));
         if (enableWeights) {
-            PermuteData(weightsGT, outWeightsGT, sizeof(float));
+            PermuteData(outWeightsGT, weightsGT, sizeof(float));
         }
     }
-- Gitee

From 372a0746f2d87b96f2336bbab32dd296064b96b8 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Thu, 24 Jul 2025 11:31:54 +0800
Subject: [PATCH 11/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?=
 =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../test_permute2d_sparse_data.py             | 77 ++++++++++++-------
 1 file changed, 50 insertions(+), 27 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
index 36a1ec21..1c4cc995 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import itertools
 import sysconfig
+from typing import Iterable, Callable
 
 import pytest
 import torch
@@ -25,42 +27,63 @@ import numpy as np
 
 DEVICE = "npu:7"
 torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
 
+PTYPE = [np.int32]
+LTYPE = [np.int64, np.int32]
+VTYPE = [np.int64, np.int32, np.float32]
+WTYPE = [None, np.float32]
+TYPE_LIST = itertools.product(PTYPE, LTYPE, VTYPE, WTYPE)
+
+T = np.random.randint(2, 30, 4)
+EXTRA_T = [0, 3, 8]
+B = [2048, 20480, 204800]
+SHAPE_LIST = itertools.product(T, EXTRA_T, B)
+
+def tensors_apply(data: Iterable, func: Callable):
+    return (func(value) if isinstance(value, torch.Tensor) else value for value in data)
 
-def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'):
-    tensors = {
-        'permute': torch.from_numpy(permute),
-        'lengths': torch.from_numpy(lengths),
-        'values': torch.from_numpy(values),
-        'weights': torch.from_numpy(weights) if isinstance(weights, torch.Tensor) else None
-    }
+
+def get_result(tensors: dict, device: str = 'cpu'):
+    tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, torch.from_numpy(y))))
     if device and device.startswith('npu'):
         torch.npu.set_device(device)
-        tensors = {k: v.to(device) if isinstance(v, torch.Tensor) else None for k, v in tensors.items()}
+        tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, y.to(device))))
+
+    results = torch.ops.fbgemm.permute_2D_sparse_data(**tensors)
+    return tuple(tensors_apply(results, lambda x: x.cpu()))
 
-    results = torch.ops.fbgemm.permute_2D_sparse_data(
-        permuted_lengths_sum=permuted_lengths_sum, **tensors
-    )
-    return tuple(result.cpu() if isinstance(result, torch.Tensor) else result for result in results)
 
+@pytest.mark.parametrize("types", TYPE_LIST)
+@pytest.mark.parametrize("shapes", SHAPE_LIST)
+@pytest.mark.parametrize("enable_permuted_sum", [True, False])
+def test_permute2d_sparse_data(types, shapes, enable_permuted_sum):
+    """
+    Params:
+        permute: (T) dtype=int32
+        lengths: (T + T', B)
dtype=ltype + L = lengths[:T].sum() + values: (L) dtype=vtype + weights: (L) dtype=fp32 + """ + ptype, ltype, vtype, wtype = types + t, extra_t, b = shapes -@pytest.mark.parametrize("ltype", [np.int64, np.int32]) -@pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) -@pytest.mark.parametrize("wtype", [None, np.float32]) -@pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) -@pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) -@pytest.mark.parametrize("permuted_lengths_sum", [True, False]) -@pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(ltype, vtype, wtype, permute_dim, extra_permute_dim, permuted_lengths_sum, lengths): - permute = np.arange(permute_dim, dtype=np.int32) + permute = np.arange(t, dtype=ptype) np.random.shuffle(permute) - values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) - weights = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=wtype) if wtype else None - lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None + lengths = np.ones((t + extra_t, b), dtype=ltype) + values = np.arange(0, (t + extra_t) * b, dtype=vtype) + weights = np.arange(0, (t + extra_t) * b, dtype=wtype) if wtype else None + permuted_lengths_sum = lengths[:t].sum() if enable_permuted_sum else None + params = { + 'permute': permute, + 'lengths': lengths, + 'values': values, + 'weights': weights, + 'permuted_lengths_sum': permuted_lengths_sum + } - golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) - result = get_result(permute, lengths, values, weights, permuted_lengths_sum, DEVICE) + golden = get_result(params) + result = get_result(params, DEVICE) for gt, pred in zip(golden, result): assert type(gt) is type(pred) -- Gitee From a8e99bcce7ed379e987636339f10f460ec9bd392 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 11:55:59 +0800 Subject: [PATCH 12/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 1c4cc995..90aa66b1 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -38,19 +38,16 @@ EXTRA_T = [0, 3, 8] B = [2048, 20480, 204800] SHAPE_LIST = itertools.product(T, EXTRA_T, B) -def tensors_apply(data: Iterable, func: Callable): - return (func(value) if isinstance(value, torch.Tensor) else value for value in data) - def get_result(tensors: dict, device: str = 'cpu'): - tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, torch.from_numpy(y)))) + tensors = {k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v for k, v in tensors.items()} if device and device.startswith('npu'): torch.npu.set_device(device) - tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, y.to(device)))) + tensors = 
{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tensors.items()} results = torch.ops.fbgemm.permute_2D_sparse_data(**tensors) - return tuple(tensors_apply(results, lambda x: x.cpu())) + return [x.cpu() if isinstance(x, torch.Tensor) else x for x in results] @pytest.mark.parametrize("types", TYPE_LIST) -- Gitee From 37d0bfa440ca6321c8d9390aedeec9427883f1c3 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 25 Jul 2025 09:44:43 +0800 Subject: [PATCH 13/20] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../permute2d_sparse_data/op_host/permute2d_sparse_data.cpp | 2 +- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index e732785b..710e044e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -36,11 +36,11 @@ namespace optiling { static ge::graphStatus TilingFunc(gert::TilingContext* context) { Permute2dSparseDataTilingData tiling; + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); bool enableWeights = (context->GetOptionalInputTensor(WEIGHTS_INDEX) != nullptr); tiling.set_enableWeights(enableWeights); - OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 90aa66b1..72a0dc66 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -16,7 +16,6 @@ # ============================================================================== import itertools import sysconfig -from typing import Iterable, Callable import pytest import torch -- Gitee From 25b443e45894bd2d43bb9b86e2d72af4e23465c9 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 28 Jul 2025 11:16:22 +0800 Subject: [PATCH 14/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index ca06eca9..eef83078 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -48,14 +48,14 @@ source 
/usr/local/Ascend/ascend-toolkit/set_env.sh
 a) The operator implements fbgemm's permute2d_sparse_data, which permutes 2-D sparse data.
 
 b) Operator parameters:
-* permute: tensor specifying the permutation order;
-* lengths: the lengths to be permuted;
-* values: the values to be permuted;
-* weights: not supported yet
-* permute_sum: not supported yet
+* permute: tensor specifying the permutation order;
+* lengths: the lengths to be permuted;
+* values: the values to be permuted;
+* weights: optional values to be permuted, handled exactly like values;
+* permute_sum: the valid length of values/weights;
 * permuted_lengths: output, tensor of the permuted lengths;
-* permuted_values: output, tensor of the permuted values;
-* permuted_weights: output, not supported yet
+* permuted_values: output, the permuted values;
+* permuted_weights: output, the permuted weights;
 
 c) Operator constraints:
 
-- Gitee

From 3b403fcf9dae9749d96b8685590fb80b500c2994 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Mon, 28 Jul 2025 11:17:16 +0800
Subject: [PATCH 15/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../rec_for_torch/operators/permute2d_sparse_data/README.md  | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index eef83078..2cb601e7 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -68,12 +68,13 @@ c) Operator constraints:
 
 ## Operator logic
 ```
-import numpy as np
+import torch
+import fbgemm_gpu
 def permute2d_sparse_data(permute, lengths, values):
     (permuted_lengths, permuted_values, permuted_weights) = (
         torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values)
     )
 
-    return permuted_lengths, permuted_values
+    return permuted_lengths, permuted_values, permuted_weights
 ```
\ No newline at end of file
-- Gitee

From a7c8ae8c3cd4909ec09ed94d70e8c14960c0576b Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Mon, 28 Jul 2025 11:22:44 +0800
Subject: [PATCH 16/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operators/permute2d_sparse_data/README.md | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index 2cb601e7..0362f2e0 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -60,10 +60,16 @@ b) Operator parameters:
 
 c) Operator constraints:
 
 * Supported hardware: Atlas A2 series products;
-* Supported CANN versions: 8.2.RC1.alpha001 and later;
-* Supported input data types: permute: int32, lengths: int64/int32, values: int64/int32/float;
-* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute equals the first dimension of lengths;
-* the length of values is the sum of all entries in lengths
+* Supported CANN versions: 8.2.RC1.alpha001 and later;
+* Supported input data types:
+  * permute: int32
+  * lengths: int64/int32
+  * values: int64/int32/fp32
+  * weights: fp32
+  * permute_sum: int (scalar);
+* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
+* when permute_sum is not specified, the length of values/weights is the sum of all entries in lengths;
+* when permute_sum is specified, the length of values/weights equals permute_sum;
 * All operator parameters are stored in NPU device memory; set parameter lengths appropriately for the available memory.
 
 ## Operator logic
-- Gitee

From 00ced895049c5c366cf0e7c28e59c2c22e2b4397 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Tue, 29 Jul 2025 09:03:57 +0800
Subject: [PATCH 17/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../rec_for_torch/operators/permute2d_sparse_data/README.md  | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index 0362f2e0..abaa4d3a 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -70,6 +70,7 @@ c) Operator constraints:
 * permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
 * when permute_sum is not specified, the length of values/weights is the sum of all entries in lengths;
 * when permute_sum is specified, the length of values/weights equals permute_sum;
+* weights and values have the same length;
 * All operator parameters are stored in NPU device memory; set parameter lengths appropriately for the available memory.
 
 ## Operator logic
-- Gitee

From b4e54b19182c6e2fe0f89dbf635f6b760e7a10b4 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Wed, 30 Jul 2025 20:11:47 +0800
Subject: [PATCH 18/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operators/permute2d_sparse_data/README.md | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index abaa4d3a..c2e6a5c2 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -52,7 +52,7 @@ b) Operator parameters:
 * lengths: the lengths to be permuted;
 * values: the values to be permuted;
 * weights: optional values to be permuted, handled exactly like values;
-* permute_sum: the valid length of values/weights;
+* permuted_lengths_sum: the valid length of values/weights;
 * permuted_lengths: output, tensor of the permuted lengths;
 * permuted_values: output, the permuted values;
 * permuted_weights: output, the permuted weights;
@@ -67,10 +67,11 @@ c) Operator constraints:
   * values: int64/int32/fp32
   * weights: fp32
   * permute_sum: int (scalar);
-* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
-* when permute_sum is not specified, the length of values/weights is the sum of all entries in lengths;
-* when permute_sum is specified, the length of values/weights equals permute_sum;
-* weights and values have the same length;
+* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
+and every value in permute satisfies: >= 0 and < `lengths.shape[0]`
+* when permuted_lengths_sum is not specified, the length of values/weights is the sum of all entries in lengths;
+* when permuted_lengths_sum is specified, the length of values/weights equals permuted_lengths_sum;
+* weights and values have the same length, both equal to `lengths.sum()`;
 * All operator parameters are stored in NPU device memory; set parameter lengths appropriately for the available memory.
-- Gitee

From 73eb422d7f29069567d4d89776391b854831fa02 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Thu, 31 Jul 2025 10:02:52 +0800
Subject: [PATCH 19/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operators/permute2d_sparse_data/README.md | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index c2e6a5c2..3367c9b8 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -50,9 +50,9 @@ b) Operator parameters:
 
 * permute: tensor specifying the permutation order;
 * lengths: the lengths to be permuted;
-* values: the values to be permuted;
-* weights: optional values to be permuted, handled exactly like values;
-* permuted_lengths_sum: the valid length of values/weights;
+* values: 1-D tensor to be permuted;
+* weights: optional input, a 1-D tensor to be permuted; handled the same way as values;
+* permuted_lengths_sum: optional input, the valid length of values/weights;
 * permuted_lengths: output, tensor of the permuted lengths;
 * permuted_values:
output, the permuted values;
 * permuted_weights: output, the permuted weights;
@@ -62,10 +62,10 @@ c) Operator constraints:
 * Supported hardware: Atlas A2 series products;
 * Supported CANN versions: 8.2.RC1.alpha001 and later;
 * Supported input data types:
-  * permute: int32
-  * lengths: int64/int32
-  * values: int64/int32/fp32
-  * weights: fp32
+  * permute: int32;
+  * lengths: int64/int32;
+  * values: int64/int32/fp32;
+  * weights: fp32;
   * permute_sum: int (scalar);
 * permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
 and every value in permute satisfies: >= 0 and < `lengths.shape[0]`
-- Gitee

From bbc3045fad9113bb58e320f9573f3106f76f5d99 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Thu, 31 Jul 2025 14:28:15 +0800
Subject: [PATCH 20/20] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?=
 =?UTF-8?q?rse=5Fdata=E3=80=82README?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operators/permute2d_sparse_data/README.md | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index 3367c9b8..82fd4338 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -66,9 +66,8 @@ c) Operator constraints:
   * lengths: int64/int32;
   * values: int64/int32/fp32;
   * weights: fp32;
-  * permute_sum: int (scalar);
-* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths;
-and every value in permute satisfies: >= 0 and < `lengths.shape[0]`
+  * permuted_lengths_sum: int (scalar);
+* permute is a 1-D tensor and lengths is a 2-D tensor; the first dimension of permute must not exceed the first dimension of lengths. Also, every value in permute satisfies: >= 0 and < `lengths.shape[0]`;
 * when permuted_lengths_sum is not specified, the length of values/weights is the sum of all entries in lengths;
 * when permuted_lengths_sum is specified, the length of values/weights equals permuted_lengths_sum;
 * weights and values have the same length, both equal to `lengths.sum()`;
@@ -78,9 +77,9 @@
 ```
 import torch
 import fbgemm_gpu
-def permute2d_sparse_data(permute, lengths, values):
+def permute2d_sparse_data(permute, lengths, values, weights, permuted_lengths_sum):
     (permuted_lengths, permuted_values, permuted_weights) = (
-        torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values)
+        torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values, weights, permuted_lengths_sum)
     )
 
     return permuted_lengths, permuted_values, permuted_weights
-- Gitee
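
For reference, the row-block gather that the README describes can be sketched in NumPy. This is a minimal illustration only, not part of the patch series: the helper name `permute_2d_sparse_data_ref` and the pure-NumPy framing are assumptions made for this sketch, and the authoritative behavior remains `torch.ops.fbgemm.permute_2D_sparse_data` as exercised by test_permute2d_sparse_data.py. Because values is flattened row-major over the (row, batch) segments of lengths, all of a row's entries are contiguous, so the permutation reduces to gathering whole row blocks in permute order:

```
import numpy as np

def permute_2d_sparse_data_ref(permute, lengths, values, weights=None, permuted_lengths_sum=None):
    """Hypothetical NumPy reference for the README's description of the op."""
    # Row t owns the contiguous slice values[row_off[t]:row_off[t + 1]],
    # because the (t, b) segments are laid out row-major.
    row_len = lengths.sum(axis=1)
    row_off = np.concatenate(([0], np.cumsum(row_len)))

    permuted_lengths = lengths[permute]
    out_len = int(permuted_lengths.sum()) if permuted_lengths_sum is None else int(permuted_lengths_sum)

    permuted_values = np.empty(out_len, dtype=values.dtype)
    permuted_weights = np.empty(out_len, dtype=weights.dtype) if weights is not None else None

    pos = 0
    for t in permute:
        n = int(row_off[t + 1] - row_off[t])
        permuted_values[pos:pos + n] = values[row_off[t]:row_off[t + 1]]
        if weights is not None:
            permuted_weights[pos:pos + n] = weights[row_off[t]:row_off[t + 1]]
        pos += n
    return permuted_lengths, permuted_values, permuted_weights

# Tiny worked example: 3 source rows (T + T' = 3), batch size B = 2.
lengths = np.array([[2, 1], [0, 3], [1, 1]], dtype=np.int64)
values = np.arange(8, dtype=np.int64)  # 2+1+0+3+1+1 = 8 entries, row-major
out_lengths, out_values, _ = permute_2d_sparse_data_ref(np.array([2, 0]), lengths, values)
# out_lengths == [[1, 1], [2, 1]]; out_values == [6, 7, 0, 1, 2]
```

Under these assumptions the sketch mirrors the kernel's structure: compute per-row prefix-sum offsets once (the totalOffsetPtr values in the kernel), then copy each selected row's contiguous block, with weights reusing the same copy path at a different element size.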