From 3d02bc379d8e29caafa46ad333cdd01b51b909ca Mon Sep 17 00:00:00 2001 From: zhoucy Date: Sat, 19 Jul 2025 16:05:17 +0800 Subject: [PATCH 01/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82values=E6=94=AF=E6=8C=81int64=E3=80=81int32?= =?UTF-8?q?=E3=80=81fp32=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 248 +++++++++--------- .../test_permute2d_sparse_data.py | 26 +- 2 files changed, 137 insertions(+), 137 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 07059da9..552ff067 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -20,126 +20,126 @@ See the License for the specific language governing permissions and #include "../../../common/ops_log.h" namespace optiling { -constexpr int GM_ALIGN = 64; -constexpr int RESERVER_UB_SIZE = 20 * 1024; -constexpr int DATA_TYPE_INT64 = 8; -constexpr int DATA_TYPE_INT32 = 4; -constexpr int DATA_TYPE_FLOAT32 = 4; -constexpr int NUM_QUEUE = 4; -constexpr int UB_ALIGN = 32; -constexpr int SUPPORT_EMBEDDING_DIM_NUM = 2; -constexpr int PERMUTE_INDEX = 0; -constexpr int LENGTH_INDEX = 1; -constexpr int VALUES_INDEX = 2; - -static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) -{ - // check tensor is nullptr - OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); - // permute: InputTensor(0), support int32 - int64_t permuteDataType = 0; - ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); - if (permuteDataTypeGe == ge::DataType::DT_INT32) { - permuteDataType = DATA_TYPE_INT32; - } - - // lengths: InputTensor(1), support int64、int32 - int64_t lengthsDataType = 0; - ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); - if (lengthsDataTypeGe == ge::DataType::DT_INT64) { - lengthsDataType = DATA_TYPE_INT64; - } else { - lengthsDataType = DATA_TYPE_INT32; - } - - // value: InputTensor(2), support int64、int32、fp32 - int64_t valueDataType = 0; - ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); - if (dataType == ge::DataType::DT_INT32) { - valueDataType = DATA_TYPE_INT32; - } else if (dataType == ge::DataType::DT_INT64) { - valueDataType = DATA_TYPE_INT64; - } else { - valueDataType = DATA_TYPE_FLOAT32; - } - - tiling.set_valueDataType(valueDataType); - tiling.set_permuteDataType(permuteDataType); - tiling.set_lengthsDataType(lengthsDataType); - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus TilingFunc(gert::TilingContext* context) -{ - OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); - - Permute2dSparseDataTilingData tiling; - auto 
ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - - auto permuteShape = context->GetInputShape(0)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); - auto valuesShape = context->GetInputShape(2)->GetStorageShape(); - - // set ub - uint64_t ubCanUsed; - ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); - ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; - tiling.set_ubCanUsed(ubCanUsed); - - // datatype check - if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || - (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { - printf("[ERROR]permute shape or lengths shape is error."); - return ge::GRAPH_FAILED; + constexpr int GM_ALIGN = 64; + constexpr int RESERVER_UB_SIZE = 20 * 1024; + constexpr int DATA_TYPE_INT64 = 8; + constexpr int DATA_TYPE_INT32 = 4; + constexpr int DATA_TYPE_FLOAT32 = 4; + constexpr int NUM_QUEUE = 4; + constexpr int UB_ALIGN = 32; + constexpr int SUPPORT_EMBEDDING_DIM_NUM = 2; + constexpr int PERMUTE_INDEX = 0; + constexpr int LENGTH_INDEX = 1; + constexpr int VALUES_INDEX = 2; + + static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) + { + // check tensor is nullptr + OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); + // permute: InputTensor(0), support int32 + int64_t permuteDataType = 0; + ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); + if (permuteDataTypeGe == ge::DataType::DT_INT32) { + permuteDataType = DATA_TYPE_INT32; + } + + // lengths: InputTensor(1), support int64、int32 + int64_t lengthsDataType = 0; + ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); + if (lengthsDataTypeGe == ge::DataType::DT_INT64) { + lengthsDataType = DATA_TYPE_INT64; + } else { + lengthsDataType = DATA_TYPE_INT32; + } + + // value: InputTensor(2), support int64、int32、fp32 + int64_t valueDataType = 0; + ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); + if (dataType == ge::DataType::DT_INT32) { + valueDataType = DATA_TYPE_INT32; + } else if (dataType == ge::DataType::DT_INT64) { + valueDataType = DATA_TYPE_INT64; + } else { + valueDataType = DATA_TYPE_FLOAT32; + } + + tiling.set_valueDataType(valueDataType); + tiling.set_permuteDataType(permuteDataType); + tiling.set_lengthsDataType(lengthsDataType); + return ge::GRAPH_SUCCESS; } - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { - return ge::GRAPH_FAILED; + static ge::graphStatus TilingFunc(gert::TilingContext* context) + { + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); + + Permute2dSparseDataTilingData tiling; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + + auto permuteShape = context->GetInputShape(0)->GetStorageShape(); + auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); + auto valuesShape = 
context->GetInputShape(2)->GetStorageShape(); + + // set ub + uint64_t ubCanUsed; + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); + ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; + tiling.set_ubCanUsed(ubCanUsed); + + // datatype check + if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || + (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { + printf("[ERROR]permute shape or lengths shape is error."); + return ge::GRAPH_FAILED; + } + + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + + // tiling core + int64_t totalBatch = permuteShape.GetDim(0); + tiling.set_totalBatch(totalBatch); + int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; + tiling.set_baseBatchLen(baseBatchLen); + int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; + tiling.set_tailSplitIndex(tailSplitIndex); + + // set data dim + int64_t permuteDim0 = permuteShape.GetDim(0); + tiling.set_permuteDim0(permuteDim0); + int64_t lengthsT = lengthsShape.GetDim(0); + tiling.set_lengthsT(lengthsT); + int64_t lengthsB = lengthsShape.GetDim(1); + tiling.set_lengthsB(lengthsB); + int64_t valuesDim = valuesShape.GetDim(0); + tiling.set_valuesDim(valuesDim); + + // apply workspace + size_t* currentWorkspace = context->GetWorkspaceSizes(1); + size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); + currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; + + OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, + "SetTypeTiling Failed."); + + context->SetBlockDim(coreNum); + + OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; } - tiling.set_coreNum(coreNum); - - // tiling core - int64_t totalBatch = permuteShape.GetDim(0); - tiling.set_totalBatch(totalBatch); - int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; - tiling.set_baseBatchLen(baseBatchLen); - int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; - tiling.set_tailSplitIndex(tailSplitIndex); - - // set data dim - int64_t permuteDim0 = permuteShape.GetDim(0); - tiling.set_permuteDim0(permuteDim0); - int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(lengthsT); - int64_t lengthsB = lengthsShape.GetDim(1); - tiling.set_lengthsB(lengthsB); - int64_t valuesDim = valuesShape.GetDim(0); - tiling.set_valuesDim(valuesDim); - - // apply workspace - size_t* currentWorkspace = context->GetWorkspaceSizes(1); - size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); - currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; - - OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, - "SetTypeTiling Failed."); - - context->SetBlockDim(coreNum); - - OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); - - tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); - context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); - - return ge::GRAPH_SUCCESS; -} } // namespace optiling 
namespace ge { @@ -178,37 +178,37 @@ public: { this->Input("permute") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .DataTypeList({ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("lengths") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_INT64, ge::DT_INT32}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("values") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("weights") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_lengths") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_INT64, ge::DT_INT32}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_values") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("permuted_weights") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index c6f6120d..4315fb59 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -14,18 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - import sysconfig + import pytest import torch import torch_npu import fbgemm_gpu import numpy as np +DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -lengths_type = [np.int64, np.int32, np.int64, np.int32] -values_type = [np.int64, np.int32, np.float32, np.float32] +lengths_type = [np.int64, np.int32] +values_type = [np.int64, np.int32, np.float32] def get_result(permute, lengths, values): @@ -44,11 +45,11 @@ def get_result(permute, lengths, values): return permuted_lengths.cpu(), permuted_values.cpu() -def get_result_npu(permute, lengths, values, device="npu:0"): - torch.npu.set_device(device) - input_permute_torch = torch.from_numpy(permute).to(device) - input_lengths_torch = torch.from_numpy(lengths).to(device) - input_values_torch = torch.from_numpy(values).to(device) +def get_result_npu(permute, lengths, values): + torch.npu.set_device(DEVICE) + input_permute_torch = torch.from_numpy(permute).to(DEVICE) + input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) + input_values_torch = torch.from_numpy(values).to(DEVICE) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( @@ -59,19 +60,18 @@ def get_result_npu(permute, lengths, values, device="npu:0"): return permuted_lengths.cpu(), permuted_values.cpu() -@pytest.mark.parametrize("device", ["npu:0", "npu:5"]) -@pytest.mark.parametrize("type_list", zip(lengths_type, values_type)) +@pytest.mark.parametrize("ltype", lengths_type) +@pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(type_list, device, permute_dim, lengths): - ltype, vtype = type_list +def test_permute2d_sparse_data(permute_dim, lengths, ltype, vtype): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) input_lengths = np.ones((permute_dim, lengths), dtype=ltype) input_values = np.arange(0, permute_dim * lengths).astype(vtype) golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values, device) + result = get_result_npu(input_permute, input_lengths, input_values) assert torch.allclose(golden[0], result[0], atol=1e-5) assert torch.allclose(golden[1], result[1], atol=1e-5) -- Gitee From 5095a680856a3f35f8423d0e508108a1b37218ef Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 08:57:45 +0800 Subject: [PATCH 02/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82permute.size(0)=20<=3D=20lengths.size(0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 6 +++--- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 7 ++++--- .../2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 552ff067..126cb8ee 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -93,7 +93,7 @@ namespace optiling 
{ // datatype check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || - (permuteShape.GetDim(0) != lengthsShape.GetDim(0))) { + (permuteShape.GetDim(0) > lengthsShape.GetDim(0))) { printf("[ERROR]permute shape or lengths shape is error."); return ge::GRAPH_FAILED; } @@ -116,8 +116,8 @@ namespace optiling { // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); tiling.set_permuteDim0(permuteDim0); - int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(lengthsT); +// int64_t lengthsT = lengthsShape.GetDim(0); + tiling.set_lengthsT(permuteDim0); int64_t lengthsB = lengthsShape.GetDim(1); tiling.set_lengthsB(lengthsB); int64_t valuesDim = valuesShape.GetDim(0); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 4315fb59..5be5b643 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -63,12 +63,13 @@ def get_result_npu(permute, lengths, values): @pytest.mark.parametrize("ltype", lengths_type) @pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) +@pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, lengths, ltype, vtype): +def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) - input_lengths = np.ones((permute_dim, lengths), dtype=ltype) - input_values = np.arange(0, permute_dim * lengths).astype(vtype) + input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) + input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) golden = get_result(input_permute, input_lengths, input_values) result = get_result_npu(input_permute, input_lengths, input_values) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index f24d0534..40114c81 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -33,8 +33,8 @@ tuple> permute2d_sparse_data_impl_npu( const auto T = lengths.size(0); const auto B = lengths.size(1); - at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); - at::Tensor outValues = at::empty({valuesConti.size(0)}, valuesConti.options()); + at::Tensor outLengths = lengthsConti.clone(); + at::Tensor outValues = valuesConti.clone(); at::Tensor outWeights = at::empty({1}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, -- Gitee From 4daa8b04832838e6cce9d837c22ffbaa9a8629a4 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 14:31:33 +0800 Subject: [PATCH 03/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82permute.size(0)=20<=3D=20lengths.size(0)?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 4 ++-- .../permute2d_sparse_data.cpp | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 126cb8ee..f78556fa 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -116,8 +116,8 @@ namespace optiling { // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); tiling.set_permuteDim0(permuteDim0); -// int64_t lengthsT = lengthsShape.GetDim(0); - tiling.set_lengthsT(permuteDim0); + int64_t lengthsT = permuteDim0; + tiling.set_lengthsT(lengthsT); int64_t lengthsB = lengthsShape.GetDim(1); tiling.set_lengthsB(lengthsB); int64_t valuesDim = valuesShape.GetDim(0); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 40114c81..c7d79fcb 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -30,11 +30,25 @@ tuple> permute2d_sparse_data_impl_npu( // weight暂不支持 at::Tensor weightsConti = at::empty({1}, lengths.options()); - const auto T = lengths.size(0); + const auto T = permute.size(0); const auto B = lengths.size(1); - at::Tensor outLengths = lengthsConti.clone(); - at::Tensor outValues = valuesConti.clone(); + int outValuesLen; + if (permute.size(0) == lengths.size(0)) { + outValuesLen = valuesConti.size(0); + } else if (permute.size(0) > lengths.size(0)) { + throw std::runtime_error( + "permute.size(0) must be less than or equal to lengths.size(0). 
" + "Got permute.size(0): " + std::to_string(permute.size(0)) + + ", lengths.size(0): " + std::to_string(lengths.size(0))); + } else if (permuted_lengths_sum && permuted_lengths_sum > 0) { + outValuesLen = permuted_lengths_sum; + } else { + outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); + } + + at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); + at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); at::Tensor outWeights = at::empty({1}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, -- Gitee From aea35649565a596920a7c168c79dce0a2460b5d1 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 21 Jul 2025 15:11:05 +0800 Subject: [PATCH 04/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82permute.size(0)=20<=3D=20lengths.size(0)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_permute2d_sparse_data.py | 10 ++++++---- .../permute2d_sparse_data.cpp | 19 +++++++++---------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 5be5b643..f5db3bc2 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -45,7 +45,7 @@ def get_result(permute, lengths, values): return permuted_lengths.cpu(), permuted_values.cpu() -def get_result_npu(permute, lengths, values): +def get_result_npu(permute, lengths, values, permuted_lengths_sum=-1): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) @@ -53,7 +53,7 @@ def get_result_npu(permute, lengths, values): (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, input_lengths_torch, input_values_torch, None + input_permute_torch, input_lengths_torch, input_values_torch, None, permuted_lengths_sum ) ) torch.npu.synchronize() @@ -64,15 +64,17 @@ def get_result_npu(permute, lengths, values): @pytest.mark.parametrize("vtype", values_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) +@pytest.mark.parametrize("permuted_lengths_sum", [None, -1, 0, 1]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype): +def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype, permuted_lengths_sum): input_permute = np.arange(permute_dim).astype(np.int32) np.random.shuffle(input_permute) input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) + permuted_lengths_sum = permuted_lengths_sum if permuted_lengths_sum != 1 else input_lengths[:permute_dim].sum() golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values) + result = get_result_npu(input_permute, input_lengths, input_values, permuted_lengths_sum) assert torch.allclose(golden[0], 
result[0], atol=1e-5) assert torch.allclose(golden[1], result[1], atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index c7d79fcb..bae5287e 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -37,12 +37,11 @@ tuple> permute2d_sparse_data_impl_npu( if (permute.size(0) == lengths.size(0)) { outValuesLen = valuesConti.size(0); } else if (permute.size(0) > lengths.size(0)) { - throw std::runtime_error( - "permute.size(0) must be less than or equal to lengths.size(0). " - "Got permute.size(0): " + std::to_string(permute.size(0)) + - ", lengths.size(0): " + std::to_string(lengths.size(0))); - } else if (permuted_lengths_sum && permuted_lengths_sum > 0) { - outValuesLen = permuted_lengths_sum; + throw std::runtime_error("permute.size(0) must be less than or equal to lengths.size(0). " + "Got permute.size(0): " + std::to_string(permute.size(0)) + + ", lengths.size(0): " + std::to_string(lengths.size(0))); + } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.has_value() > 0) { + outValuesLen = static_cast(permuted_lengths_sum.value()); } else { outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); } @@ -60,10 +59,10 @@ tuple> permute2d_sparse_data_impl_npu( TORCH_LIBRARY(mxrec, m) { m.def("permute_2D_sparse_data(Tensor permute, " - "Tensor lengths, " - "Tensor values, " - "Tensor? weights=None, " - "SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); + " Tensor lengths, " + " Tensor values, " + " Tensor? weights=None, " + " SymInt? 
permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)"); } TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) -- Gitee From 79ded207b320a21654b929e30877f9e9ebaa65de Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 22 Jul 2025 11:01:17 +0800 Subject: [PATCH 05/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 103 ++++++---------- .../op_host/permute2d_sparse_data_tilling.h | 5 +- .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 113 ++++++++++++------ .../test_permute2d_sparse_data.py | 49 +++++--- .../permute2d_sparse_data.cpp | 8 +- 6 files changed, 155 insertions(+), 125 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index f78556fa..d979869d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -31,87 +31,40 @@ namespace optiling { constexpr int PERMUTE_INDEX = 0; constexpr int LENGTH_INDEX = 1; constexpr int VALUES_INDEX = 2; + constexpr int WEIGHTS_INDEX = 3; - static ge::graphStatus SetTypeTiling(gert::TilingContext* context, Permute2dSparseDataTilingData& tiling) + static ge::graphStatus TilingFunc(gert::TilingContext* context) { - // check tensor is nullptr - OPS_LOG_E_IF_NULL("permute", context->GetInputTensor(PERMUTE_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("length", context->GetInputTensor(LENGTH_INDEX), return ge::GRAPH_FAILED); - OPS_LOG_E_IF_NULL("value", context->GetInputTensor(VALUES_INDEX), return ge::GRAPH_FAILED); - // permute: InputTensor(0), support int32 - int64_t permuteDataType = 0; - ge::DataType permuteDataTypeGe = context->GetInputTensor(0)->GetDataType(); - if (permuteDataTypeGe == ge::DataType::DT_INT32) { - permuteDataType = DATA_TYPE_INT32; - } - - // lengths: InputTensor(1), support int64、int32 - int64_t lengthsDataType = 0; - ge::DataType lengthsDataTypeGe = context->GetInputTensor(1)->GetDataType(); - if (lengthsDataTypeGe == ge::DataType::DT_INT64) { - lengthsDataType = DATA_TYPE_INT64; - } else { - lengthsDataType = DATA_TYPE_INT32; - } - - // value: InputTensor(2), support int64、int32、fp32 - int64_t valueDataType = 0; - ge::DataType dataType = context->GetInputTensor(2)->GetDataType(); - if (dataType == ge::DataType::DT_INT32) { - valueDataType = DATA_TYPE_INT32; - } else if (dataType == ge::DataType::DT_INT64) { - valueDataType = DATA_TYPE_INT64; - } else { - valueDataType = DATA_TYPE_FLOAT32; - } + Permute2dSparseDataTilingData tiling; - tiling.set_valueDataType(valueDataType); - tiling.set_permuteDataType(permuteDataType); - tiling.set_lengthsDataType(lengthsDataType); - return ge::GRAPH_SUCCESS; - } + bool enableWeights = (context->GetOptionalInputTensor(WEIGHTS_INDEX) != nullptr); + tiling.set_enableWeights(enableWeights); - static ge::graphStatus TilingFunc(gert::TilingContext* context) - { OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); 
OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); + if (enableWeights) { + OPS_LOG_E_IF_NULL("weightsShape", context->GetInputShape(WEIGHTS_INDEX), return ge::GRAPH_FAILED); + } - Permute2dSparseDataTilingData tiling; auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - auto permuteShape = context->GetInputShape(0)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(1)->GetStorageShape(); - auto valuesShape = context->GetInputShape(2)->GetStorageShape(); - - // set ub - uint64_t ubCanUsed; - ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); - ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; - tiling.set_ubCanUsed(ubCanUsed); + auto permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); + auto lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); + auto valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); + auto weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); - // datatype check + // shape check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || (permuteShape.GetDim(0) > lengthsShape.GetDim(0))) { - printf("[ERROR]permute shape or lengths shape is error."); + OPS_LOG_E("", "[ERROR]permute shape or lengths shape is error. "); return ge::GRAPH_FAILED; } - - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { + if (enableWeights && valuesShape != weightsShape) { + OPS_LOG_E("", "[ERROR]values shape or weights shape is error. "); return ge::GRAPH_FAILED; } - tiling.set_coreNum(coreNum); - - // tiling core - int64_t totalBatch = permuteShape.GetDim(0); - tiling.set_totalBatch(totalBatch); - int64_t baseBatchLen = (permuteShape.GetDim(0)) / coreNum; - tiling.set_baseBatchLen(baseBatchLen); - int64_t tailSplitIndex = (permuteShape.GetDim(0)) % coreNum; - tiling.set_tailSplitIndex(tailSplitIndex); // set data dim int64_t permuteDim0 = permuteShape.GetDim(0); @@ -123,14 +76,32 @@ namespace optiling { int64_t valuesDim = valuesShape.GetDim(0); tiling.set_valuesDim(valuesDim); + // tiling core + int64_t totalBatch = permuteDim0; + tiling.set_totalBatch(totalBatch); + int64_t baseBatchLen = permuteDim0 / coreNum; + tiling.set_baseBatchLen(baseBatchLen); + int64_t tailSplitIndex = permuteDim0 % coreNum; + tiling.set_tailSplitIndex(tailSplitIndex); + + // set ub + uint64_t ubCanUsed; + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubCanUsed); + ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; + tiling.set_ubCanUsed(ubCanUsed); + + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + // apply workspace size_t* currentWorkspace = context->GetWorkspaceSizes(1); size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; - OPS_LOG_E_IF(SetTypeTiling(context, tiling) == ge::GRAPH_FAILED, context, return ge::GRAPH_FAILED, - "SetTypeTiling Failed."); - context->SetBlockDim(coreNum); OPS_LOG_E_IF_NULL("raw tilingData", context->GetRawTilingData(), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h 
b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h index 97115458..12a22507 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h @@ -27,10 +27,7 @@ namespace optiling { TILING_DATA_FIELD_DEF(int64_t, lengthsT); TILING_DATA_FIELD_DEF(int64_t, lengthsB); TILING_DATA_FIELD_DEF(int64_t, valuesDim); - - TILING_DATA_FIELD_DEF(int64_t, valueDataType); - TILING_DATA_FIELD_DEF(int64_t, permuteDataType); - TILING_DATA_FIELD_DEF(int64_t, lengthsDataType); + TILING_DATA_FIELD_DEF(bool, enableWeights); TILING_DATA_FIELD_DEF(int64_t, totalBatch); TILING_DATA_FIELD_DEF(int64_t, baseBatchLen); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp index 014bda5b..dae264ed 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp @@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_ { Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights, workspace, tiling}; - Permute2dSparseData::Permute2dSparseDataKernel kernel(args); + Permute2dSparseData::Permute2dSparseDataKernel kernel(args); kernel.Compute(); } \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 03b661c8..77646c9b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -42,6 +42,7 @@ struct Args { GM_ADDR tiling; }; +template class Permute2dSparseDataKernel { public: __aicore__ inline Permute2dSparseDataKernel(Args args) @@ -55,10 +56,6 @@ public: lengthsB = tilingData.lengthsB; valuesDim = tilingData.valuesDim; - valueDataType = tilingData.valueDataType; - permuteDataType = tilingData.permuteDataType; - lengthsDataType = tilingData.lengthsDataType; - totalBatch = tilingData.totalBatch; baseBatchLen = tilingData.baseBatchLen; tailSplitIndex = tilingData.tailSplitIndex; @@ -70,6 +67,8 @@ public: values = args.values; weights = args.weights; + enableWeights = args.enableWeights; + outLengths = args.out_lengths; outIndices = args.out_indices; outWeights = args.out_weights; @@ -84,13 +83,17 @@ public: tOffsetOfThisCore = tailSplitIndex * (baseBatchLen + 1) + (GetBlockIdx() - tailSplitIndex) * baseBatchLen; } - permuteGT.SetGlobalBuffer(permute, permuteDim0 * permuteDataType); - lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * lengthsDataType); - valuesGT.SetGlobalBuffer(values, valuesDim * valueDataType); + permuteGT.SetGlobalBuffer(permute, permuteDim0 * sizeof(int32_t)); + lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * sizeof(LType)); + valuesGT.SetGlobalBuffer(values, valuesDim * sizeof(VType)); - outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * lengthsDataType); - outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * valueDataType); + 
outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * sizeof(LType)); + outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * sizeof(VType)); + if (enableWeights) { + weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(WType)); + outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(WType)); + } // Init pipe pipe.InitBuffer(inQueueX, USE_QUEUE_NUM, ubCanUsed / USE_QUEUE_NUM); blockLen = ubCanUsed / USE_QUEUE_NUM; @@ -139,25 +142,15 @@ public: offsetPtr = (__gm__ int64_t*)workspace; GlobalTensor offsetGT; offsetGT.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); - if (lengthsDataType == DATA_TYPE_INT64) { - __gm__ int64_t* lengthsPtr = (__gm__ int64_t*)lengths; - for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { - int64_t offsetT = 0; - for (int64_t j = 0; j < lengthsB; j++) { - offsetT += *(lengthsPtr + i * lengthsB + j); - } - offsetGT.SetValue(i * UB_ALIGN, offsetT); - } - } else { - __gm__ int32_t* lengthsPtr = (__gm__ int32_t*)lengths; - for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { - int64_t offsetT = 0; - for (int64_t j = 0; j < lengthsB; j++) { - offsetT += *(lengthsPtr + i * lengthsB + j); - } - offsetGT.SetValue(i * UB_ALIGN, offsetT); + __gm__ LType* lengthsPtr = (__gm__ LType*)lengths; + for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { + int64_t offsetT = 0; + for (int64_t j = 0; j < lengthsB; j++) { + offsetT += *(lengthsPtr + i * lengthsB + j); } + offsetGT.SetValue(i * UB_ALIGN, offsetT); } + AscendC::DataCacheCleanAndInvalid(offsetGT); } @@ -169,11 +162,11 @@ public: int64_t ToffsetThisIndex = *(permutePtr + i); int64_t ToffsetNextIndex = *(permutePtr + i) + 1; - int64_t lengthsStartIndex = ToffsetThisIndex * lengthsB * lengthsDataType; - int64_t lengthsEndIndex = ToffsetNextIndex * lengthsB * lengthsDataType; + int64_t lengthsStartIndex = ToffsetThisIndex * lengthsB * sizeof(LType); + int64_t lengthsEndIndex = ToffsetNextIndex * lengthsB * sizeof(LType); - int64_t outStartIndex = i * lengthsB * lengthsDataType; - int64_t outEndIndex = (i + 1) * lengthsB * lengthsDataType; + int64_t outStartIndex = i * lengthsB * sizeof(LType); + int64_t outEndIndex = (i + 1) * lengthsB * sizeof(LType); int64_t totalLen = lengthsEndIndex - lengthsStartIndex; int64_t remainLen = totalLen; while (remainLen > 0) { @@ -245,6 +238,54 @@ public: } } + __aicore__ void PermuteWeights() + { + int64_t outWeightOffset = 0; + int64_t currentT = 0; + for (int64_t i = 0; i < permuteDim0; i++) { + currentT = *(permutePtr + i); + int64_t tLen = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN) - *(totalOffsetPtr + currentT * UB_ALIGN); + int64_t baseCoreLen = tLen / coreNum; + int64_t tailLen = tLen % coreNum; + + // calculate current core permute weights offset + if (GetBlockIdx() < tailLen) { + weightLenOfThisCore = baseCoreLen + 1; + offsetOfThisCore = GetBlockIdx() * (baseCoreLen + 1); + } else { + weightLenOfThisCore = baseCoreLen; + offsetOfThisCore = tailLen * (baseCoreLen + 1) + (GetBlockIdx() - tailLen) * baseCoreLen; + } + + int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); + int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); + + int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(WType); + int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(WType); + + int64_t remainLen = weightLenOfThisCore * sizeof(WType); + while (remainLen > 0) { + int64_t thisLen = blockLen; + if (remainLen < 
blockLen) { + thisLen = remainLen; + } + LocalTensor inputTensor = inQueueX.AllocTensor(); + + CpGm2Local(inputTensor, weightsGT[weightsStartIndex], thisLen); + inQueueX.EnQue(inputTensor); + LocalTensor outPutTensor = inQueueX.DeQue(); + + CpLocal2Gm(outWeightsGT[outWeightStartIndex], outPutTensor, thisLen); + + outWeightStartIndex += thisLen; + weightsStartIndex += thisLen; + inQueueX.FreeTensor(outPutTensor); + remainLen = remainLen - thisLen; + } + outWeightOffset += tLen; + } + } + __aicore__ void Compute() { CalculateOffsets(); @@ -264,6 +305,9 @@ public: } PermuteLengths(); PermuteValues(); + if (enableWeights) { + PermuteWeights(); + } } private: @@ -282,11 +326,7 @@ private: int64_t lengthsT; int64_t lengthsB; int64_t valuesDim; - - // DataType - int64_t valueDataType; - int64_t permuteDataType; - int64_t lengthsDataType; + bool enableWeights; // Tiling int64_t totalBatch; @@ -304,6 +344,7 @@ private: // ThisCoreLen for B int64_t valueLenOfThisCore; + int64_t weightLenOfThisCore; int64_t offsetOfThisCore; // Tpipe @@ -314,8 +355,10 @@ private: GlobalTensor permuteGT; GlobalTensor lengthsGT; GlobalTensor valuesGT; + GlobalTensor weightsGT; GlobalTensor outLengthsGT; GlobalTensor outIndicesGT; + GlobalTensor outWeightsGT; __gm__ int64_t* offsetPtr; __gm__ int32_t* permutePtr; diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index f5db3bc2..72d2c395 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -27,54 +27,71 @@ torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") lengths_type = [np.int64, np.int32] values_type = [np.int64, np.int32, np.float32] +weights_type = [None, np.int64, np.int32, np.float32] -def get_result(permute, lengths, values): +def get_result(permute, lengths, values, weights, permuted_lengths_sum): input_permute_torch = torch.from_numpy(permute) input_lengths_torch = torch.from_numpy(lengths) input_values_torch = torch.from_numpy(values) + input_weights_torch = torch.from_numpy(weights) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( input_permute_torch, input_lengths_torch, input_values_torch, + input_weights_torch, + permuted_lengths_sum ) ) - return permuted_lengths.cpu(), permuted_values.cpu() + return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -def get_result_npu(permute, lengths, values, permuted_lengths_sum=-1): +def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=-1): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) input_values_torch = torch.from_numpy(values).to(DEVICE) + input_weights_torch = torch.from_numpy(weights).to(DEVICE) (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, input_lengths_torch, input_values_torch, None, permuted_lengths_sum + input_permute_torch, + input_lengths_torch, + input_values_torch, + input_weights_torch, + permuted_lengths_sum ) ) torch.npu.synchronize() - return permuted_lengths.cpu(), permuted_values.cpu() + return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() 
@pytest.mark.parametrize("ltype", lengths_type) @pytest.mark.parametrize("vtype", values_type) +@pytest.mark.parametrize("wtype", weights_type) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) -@pytest.mark.parametrize("permuted_lengths_sum", [None, -1, 0, 1]) +@pytest.mark.parametrize("permuted_lengths_sum", [True, False]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(permute_dim, extra_permute_dim, lengths, ltype, vtype, permuted_lengths_sum): - input_permute = np.arange(permute_dim).astype(np.int32) - np.random.shuffle(input_permute) - input_lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - input_values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) - permuted_lengths_sum = permuted_lengths_sum if permuted_lengths_sum != 1 else input_lengths[:permute_dim].sum() +def test_permute2d_sparse_data(ltype, + vtype, + wtype, + permute_dim, + extra_permute_dim, + permuted_lengths_sum, + lengths): + permute = np.arange(permute_dim).astype(np.int32) + np.random.shuffle(permute) + lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) + values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) + weights = None if wtype is None else np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(wtype) + permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None - golden = get_result(input_permute, input_lengths, input_values) - result = get_result_npu(input_permute, input_lengths, input_values, permuted_lengths_sum) + golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) + result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) - assert torch.allclose(golden[0], result[0], atol=1e-5) - assert torch.allclose(golden[1], result[1], atol=1e-5) + for gt, pred in zip(golden, result): + assert torch.allclose(gt, pred, atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index bae5287e..7df2f6ed 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,8 +27,10 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - // weight暂不支持 at::Tensor weightsConti = at::empty({1}, lengths.options()); + if (weights.has_value()) { + weightsConti = weights.value().contiguous(); + } const auto T = permute.size(0); const auto B = lengths.size(1); @@ -40,7 +42,7 @@ tuple> permute2d_sparse_data_impl_npu( throw std::runtime_error("permute.size(0) must be less than or equal to lengths.size(0). 
" "Got permute.size(0): " + std::to_string(permute.size(0)) + ", lengths.size(0): " + std::to_string(lengths.size(0))); - } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.has_value() > 0) { + } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.value() > 0) { outValuesLen = static_cast(permuted_lengths_sum.value()); } else { outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); @@ -48,7 +50,7 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::empty({1}, weightsConti.options()); + at::Tensor outWeights = at::empty({outValuesLen}, weightsConti.options()); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, outLengths, outValues, outWeights); -- Gitee From 76986db6bf4f019c1fd651eddc0aea9eb17cdce6 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 09:57:26 +0800 Subject: [PATCH 06/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 14 +++++++------- .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 8 ++++---- .../test_permute2d_sparse_data.py | 12 +++++++----- .../permute2d_sparse_data.cpp | 10 ++++++---- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index d979869d..0f7ce5ee 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -76,6 +76,13 @@ namespace optiling { int64_t valuesDim = valuesShape.GetDim(0); tiling.set_valuesDim(valuesDim); + // set coreNUm + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + if (coreNum == 0) { + return ge::GRAPH_FAILED; + } + tiling.set_coreNum(coreNum); + // tiling core int64_t totalBatch = permuteDim0; tiling.set_totalBatch(totalBatch); @@ -90,13 +97,6 @@ namespace optiling { ubCanUsed = (ubCanUsed - RESERVER_UB_SIZE) / UB_ALIGN / NUM_QUEUE * UB_ALIGN * NUM_QUEUE; tiling.set_ubCanUsed(ubCanUsed); - // set coreNUm - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - if (coreNum == 0) { - return ge::GRAPH_FAILED; - } - tiling.set_coreNum(coreNum); - // apply workspace size_t* currentWorkspace = context->GetWorkspaceSizes(1); size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp index dae264ed..067a8483 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp @@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_ { Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights, workspace, tiling}; - Permute2dSparseData::Permute2dSparseDataKernel 
kernel(args); + Permute2dSparseData::Permute2dSparseDataKernel kernel(args); kernel.Compute(); } \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 77646c9b..00acba8d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -67,7 +67,7 @@ public: values = args.values; weights = args.weights; - enableWeights = args.enableWeights; + enableWeights = tilingData.enableWeights; outLengths = args.out_lengths; outIndices = args.out_indices; @@ -212,10 +212,10 @@ public: int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * valueDataType; - int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * valueDataType; + int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * sizeof(VType); + int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * sizeof(VType); - int64_t remainLen = valueLenOfThisCore * valueDataType; + int64_t remainLen = valueLenOfThisCore * sizeof(VType); while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 72d2c395..292428a3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -49,7 +49,7 @@ def get_result(permute, lengths, values, weights, permuted_lengths_sum): return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=-1): +def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None): torch.npu.set_device(DEVICE) input_permute_torch = torch.from_numpy(permute).to(DEVICE) input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) @@ -83,15 +83,17 @@ def test_permute2d_sparse_data(ltype, extra_permute_dim, permuted_lengths_sum, lengths): - permute = np.arange(permute_dim).astype(np.int32) + permute = np.arange(permute_dim, dtype=np.int32) np.random.shuffle(permute) + values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) + weights = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=wtype) if wtype else None lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - values = np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(vtype) - weights = None if wtype is None else np.arange(0, (permute_dim + extra_permute_dim) * lengths).astype(wtype) permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) for gt, pred in zip(golden, result): - assert torch.allclose(gt, pred, atol=1e-5) + assert type(gt) is type(pred) + if isinstance(gt, torch.Tensor) and isinstance(pred, torch.Tensor): 
+ assert torch.allclose(gt, pred, atol=1e-5) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 7df2f6ed..95ca13b5 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -50,12 +50,14 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::empty({outValuesLen}, weightsConti.options()); - - EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, T, + at::Tensor outWeights = at::Tensor(); + if (weights.has_value()) { + outWeights = at::empty({outValuesLen}, weightsConti.options()); + } + EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen, outLengths, outValues, outWeights); - return make_tuple(outLengths, outValues, at::Tensor()); + return make_tuple(outLengths, outValues, outWeights); } TORCH_LIBRARY(mxrec, m) -- Gitee From 03ba8e8f1479789845bc7e517bc0730a491a5bd0 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 15:05:52 +0800 Subject: [PATCH 07/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 25 +++++++------------ .../op_kernel/permute2d_sparse_data.cpp | 2 +- .../op_kernel/permute2d_sparse_data_kernel.h | 12 ++++----- .../test_permute2d_sparse_data.py | 11 +++----- .../permute2d_sparse_data.cpp | 10 ++------ 5 files changed, 21 insertions(+), 39 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 0f7ce5ee..262dcc40 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -150,38 +150,31 @@ public: this->Input("permute") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("lengths") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("values") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Input("weights") .ParamType(OPTIONAL) - .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, 
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .DataTypeList({ge::DT_FLOAT}) + .FormatList({ge::FORMAT_ND}); this->Output("permuted_lengths") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Output("permuted_values") .ParamType(REQUIRED) .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .FormatList({ge::FORMAT_ND}); this->Output("permuted_weights") .ParamType(OPTIONAL) - .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + .DataTypeList({ge::DT_FLOAT}) + .FormatList({ge::FORMAT_ND}); this->Attr("permuted_sum").Int(0); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp index 067a8483..7d226eab 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data.cpp @@ -24,6 +24,6 @@ extern "C" __global__ __aicore__ void permute2d_sparse_data(GM_ADDR permute, GM_ { Permute2dSparseData::Args args{permute, lengths, values, weights, out_lengths, out_indices, out_weights, workspace, tiling}; - Permute2dSparseData::Permute2dSparseDataKernel kernel(args); + Permute2dSparseData::Permute2dSparseDataKernel kernel(args); kernel.Compute(); } \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 00acba8d..6858fda1 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -42,7 +42,7 @@ struct Args { GM_ADDR tiling; }; -template +template class Permute2dSparseDataKernel { public: __aicore__ inline Permute2dSparseDataKernel(Args args) @@ -91,8 +91,8 @@ public: outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * sizeof(VType)); if (enableWeights) { - weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(WType)); - outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(WType)); + weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(float)); + outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(float)); } // Init pipe pipe.InitBuffer(inQueueX, USE_QUEUE_NUM, ubCanUsed / USE_QUEUE_NUM); @@ -260,10 +260,10 @@ public: int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(WType); - int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(WType); + int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(float); + int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(float); - int64_t remainLen = weightLenOfThisCore * 
sizeof(WType); + int64_t remainLen = weightLenOfThisCore * sizeof(float); while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 292428a3..8e231e2a 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,11 +25,6 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -lengths_type = [np.int64, np.int32] -values_type = [np.int64, np.int32, np.float32] -weights_type = [None, np.int64, np.int32, np.float32] - - def get_result(permute, lengths, values, weights, permuted_lengths_sum): input_permute_torch = torch.from_numpy(permute) input_lengths_torch = torch.from_numpy(lengths) @@ -69,9 +64,9 @@ def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None) return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() -@pytest.mark.parametrize("ltype", lengths_type) -@pytest.mark.parametrize("vtype", values_type) -@pytest.mark.parametrize("wtype", weights_type) +@pytest.mark.parametrize("ltype", [np.int64, np.int32]) +@pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) +@pytest.mark.parametrize("wtype", [None, np.float32]) @pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("permuted_lengths_sum", [True, False]) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 95ca13b5..f9c967ae 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,10 +27,7 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - at::Tensor weightsConti = at::empty({1}, lengths.options()); - if (weights.has_value()) { - weightsConti = weights.value().contiguous(); - } + auto weightsConti = weigths.value_or(at::Tensor()).contiguous(); const auto T = permute.size(0); const auto B = lengths.size(1); @@ -50,10 +47,7 @@ tuple> permute2d_sparse_data_impl_npu( at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = at::Tensor(); - if (weights.has_value()) { - outWeights = at::empty({outValuesLen}, weightsConti.options()); - } + at::Tensor outWeights = weights.has_value() ? 
at::empty({outValuesLen}, weightsConti.options()) : at::Tensor(); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen, outLengths, outValues, outWeights); -- Gitee From f22f303fca3cbe3dc1f4b145a4f38a1eb29c6faf Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 23 Jul 2025 16:16:29 +0800 Subject: [PATCH 08/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82=E6=94=AF=E6=8C=81weights=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 11 ++-- .../test_permute2d_sparse_data.py | 60 ++++++------------- .../permute2d_sparse_data.cpp | 2 +- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 262dcc40..e732785b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -50,10 +50,13 @@ namespace optiling { auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - auto permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); - auto lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); - auto valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); - auto weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); + gert::Shape permuteShape = context->GetInputShape(PERMUTE_INDEX)->GetStorageShape(); + gert::Shape lengthsShape = context->GetInputShape(LENGTH_INDEX)->GetStorageShape(); + gert::Shape valuesShape = context->GetInputShape(VALUES_INDEX)->GetStorageShape(); + gert::Shape weightsShape; + if (enableWeights) { + weightsShape = context->GetInputShape(WEIGHTS_INDEX)->GetStorageShape(); + } // shape check if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) || diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 8e231e2a..61801fb3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,44 +25,26 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") -def get_result(permute, lengths, values, weights, permuted_lengths_sum): - input_permute_torch = torch.from_numpy(permute) - input_lengths_torch = torch.from_numpy(lengths) - input_values_torch = torch.from_numpy(values) - input_weights_torch = torch.from_numpy(weights) - - (permuted_lengths, permuted_values, permuted_weights) = ( - torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, - input_lengths_torch, - input_values_torch, - input_weights_torch, - permuted_lengths_sum - ) - ) - - return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() +def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'): + tensors = { + 'permute': torch.from_numpy(permute), + 'lengths': torch.from_numpy(lengths), + 'values': 
torch.from_numpy(values), + 'weights': torch.from_numpy(weights) if weights is not None else None + } + if device and device.startswith('npu'): + torch.npu.set_device(device) + tensors = {k: v.to(device) if v is not None else None for k, v in tensors.items()} -def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None): - torch.npu.set_device(DEVICE) - input_permute_torch = torch.from_numpy(permute).to(DEVICE) - input_lengths_torch = torch.from_numpy(lengths).to(DEVICE) - input_values_torch = torch.from_numpy(values).to(DEVICE) - input_weights_torch = torch.from_numpy(weights).to(DEVICE) - - (permuted_lengths, permuted_values, permuted_weights) = ( - torch.ops.fbgemm.permute_2D_sparse_data( - input_permute_torch, - input_lengths_torch, - input_values_torch, - input_weights_torch, - permuted_lengths_sum - ) + results = torch.ops.fbgemm.permute_2D_sparse_data( + permuted_lengths_sum=permuted_lengths_sum, **tensors ) - torch.npu.synchronize() - return permuted_lengths.cpu(), permuted_values.cpu(), permuted_weights.cpu() + if device: + torch.npu.synchronize() + + return tuple(r.cpu() if isinstance(r, torch.Tensor) else r for r in results) @pytest.mark.parametrize("ltype", [np.int64, np.int32]) @pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) @@ -71,13 +53,7 @@ def get_result_npu(permute, lengths, values, weights, permuted_lengths_sum=None) @pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) @pytest.mark.parametrize("permuted_lengths_sum", [True, False]) @pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(ltype, - vtype, - wtype, - permute_dim, - extra_permute_dim, - permuted_lengths_sum, - lengths): +def test_permute2d_sparse_data(ltype, vtype, wtype, permute_dim, extra_permute_dim, permuted_lengths_sum, lengths): permute = np.arange(permute_dim, dtype=np.int32) np.random.shuffle(permute) values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) @@ -86,7 +62,7 @@ def test_permute2d_sparse_data(ltype, permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) - result = get_result_npu(permute, lengths, values, weights, permuted_lengths_sum) + result = get_result(permute, lengths, values, weights, permuted_lengths_sum, DEVICE) for gt, pred in zip(golden, result): assert type(gt) is type(pred) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index f9c967ae..a6b29e4f 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -27,7 +27,7 @@ tuple> permute2d_sparse_data_impl_npu( auto permuteConti = permute.contiguous(); auto lengthsConti = lengths.contiguous(); auto valuesConti = values.contiguous(); - auto weightsConti = weigths.value_or(at::Tensor()).contiguous(); + auto weightsConti = weights.value_or(at::Tensor()).contiguous(); const auto T = permute.size(0); const auto B = lengths.size(1); -- Gitee From f3c94ffc2c9c191822bef8e590e6c4b59b6b29e3 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 09:17:06 +0800 Subject: [PATCH 09/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= 
=?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 66 +++---------------- .../test_permute2d_sparse_data.py | 10 ++- 2 files changed, 13 insertions(+), 63 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 6858fda1..83cbfa1b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -190,7 +190,7 @@ public: } } - __aicore__ void PermuteValues() + __aicore__ void PermuteData(GlobalTensor srcGT, GlobalTensor dstGT, uint8_t datasize) { int64_t outValueOffset = 0; int64_t currentT = 0; @@ -212,10 +212,10 @@ public: int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * sizeof(VType); - int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * sizeof(VType); + int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * datasize; + int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * datasize; - int64_t remainLen = valueLenOfThisCore * sizeof(VType); + int64_t remainLen = valueLenOfThisCore * datasize; while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { @@ -223,66 +223,18 @@ public: } LocalTensor inputTensor = inQueueX.AllocTensor(); - CpGm2Local(inputTensor, valuesGT[valuesStartIndex], thisLen); + CpGm2Local(inputTensor, srcGT[valuesStartIndex], thisLen); inQueueX.EnQue(inputTensor); LocalTensor outPutTensor = inQueueX.DeQue(); - CpLocal2Gm(outIndicesGT[outValueStartIndex], outPutTensor, thisLen); + CpLocal2Gm(dstGT[outValueStartIndex], outPutTensor, thisLen); outValueStartIndex += thisLen; valuesStartIndex += thisLen; inQueueX.FreeTensor(outPutTensor); remainLen = remainLen - thisLen; } - outValueOffset+=tLen; - } - } - - __aicore__ void PermuteWeights() - { - int64_t outWeightOffset = 0; - int64_t currentT = 0; - for (int64_t i = 0; i < permuteDim0; i++) { - currentT = *(permutePtr + i); - int64_t tLen = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN) - *(totalOffsetPtr + currentT * UB_ALIGN); - int64_t baseCoreLen = tLen / coreNum; - int64_t tailLen = tLen % coreNum; - - // calculate current core permute weights offset - if (GetBlockIdx() < tailLen) { - weightLenOfThisCore = baseCoreLen + 1; - offsetOfThisCore = GetBlockIdx() * (baseCoreLen + 1); - } else { - weightLenOfThisCore = baseCoreLen; - offsetOfThisCore = tailLen * (baseCoreLen + 1) + (GetBlockIdx() - tailLen) * baseCoreLen; - } - - int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); - int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - - int64_t weightsStartIndex = (startIndex + offsetOfThisCore) * sizeof(float); - int64_t outWeightStartIndex = (outWeightOffset + offsetOfThisCore) * sizeof(float); - - int64_t remainLen = weightLenOfThisCore * sizeof(float); - while (remainLen > 0) { - int64_t thisLen = blockLen; - if (remainLen < blockLen) { - thisLen = remainLen; - } - LocalTensor inputTensor = inQueueX.AllocTensor(); - - CpGm2Local(inputTensor, weightsGT[weightsStartIndex], 
thisLen); - inQueueX.EnQue(inputTensor); - LocalTensor outPutTensor = inQueueX.DeQue(); - - CpLocal2Gm(outWeightsGT[outWeightStartIndex], outPutTensor, thisLen); - - outWeightStartIndex += thisLen; - weightsStartIndex += thisLen; - inQueueX.FreeTensor(outPutTensor); - remainLen = remainLen - thisLen; - } - outWeightOffset += tLen; + outValueOffset += tLen; } } @@ -304,9 +256,9 @@ public: offsetGt.GetValue((i - 1) * UB_ALIGN); } PermuteLengths(); - PermuteValues(); + PermuteData(valuesGT, outIndicesGT, sizeof(VType)); if (enableWeights) { - PermuteWeights(); + PermuteData(weightsGT, outWeightsGT, sizeof(float)); } } diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 61801fb3..36a1ec21 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -25,26 +25,24 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'): tensors = { 'permute': torch.from_numpy(permute), 'lengths': torch.from_numpy(lengths), 'values': torch.from_numpy(values), - 'weights': torch.from_numpy(weights) if weights is not None else None + 'weights': torch.from_numpy(weights) if isinstance(weights, torch.Tensor) else None } if device and device.startswith('npu'): torch.npu.set_device(device) - tensors = {k: v.to(device) if v is not None else None for k, v in tensors.items()} + tensors = {k: v.to(device) if isinstance(v, torch.Tensor) else None for k, v in tensors.items()} results = torch.ops.fbgemm.permute_2D_sparse_data( permuted_lengths_sum=permuted_lengths_sum, **tensors ) + return tuple(result.cpu() if isinstance(result, torch.Tensor) else result for result in results) - if device: - torch.npu.synchronize() - - return tuple(r.cpu() if isinstance(r, torch.Tensor) else r for r in results) @pytest.mark.parametrize("ltype", [np.int64, np.int32]) @pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) -- Gitee From ec51f7895696c866ee31fab6d5ba73de65df41f5 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 10:16:47 +0800 Subject: [PATCH 10/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 83cbfa1b..9fa0cd1e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -190,7 +190,7 @@ public: } } - __aicore__ void PermuteData(GlobalTensor srcGT, GlobalTensor dstGT, uint8_t datasize) + __aicore__ void PermuteData(GlobalTensor dstGT, GlobalTensor srcGT, uint8_t datasize) { int64_t outValueOffset = 0; int64_t currentT = 
0; @@ -256,9 +256,9 @@ public: offsetGt.GetValue((i - 1) * UB_ALIGN); } PermuteLengths(); - PermuteData(valuesGT, outIndicesGT, sizeof(VType)); + PermuteData(outIndicesGT, valuesGT, sizeof(VType)); if (enableWeights) { - PermuteData(weightsGT, outWeightsGT, sizeof(float)); + PermuteData(outWeightsGT, weightsGT, sizeof(float)); } } -- Gitee From 372a0746f2d87b96f2336bbab32dd296064b96b8 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 11:31:54 +0800 Subject: [PATCH 11/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_permute2d_sparse_data.py | 77 ++++++++++++------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 36a1ec21..1c4cc995 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import itertools import sysconfig +from typing import Iterable, Callable import pytest import torch @@ -25,42 +27,63 @@ import numpy as np DEVICE = "npu:7" torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") +PTYPE = [np.int32] +LTYPE = [np.int64, np.int32] +VTYPE = [np.int64, np.int32, np.float32] +WTYPE = [None, np.float32] +TYPE_LIST = itertools.product(PTYPE, LTYPE, VTYPE, WTYPE) + +T = np.random.randint(2, 30, 4) +EXTRA_T = [0, 3, 8] +B = [2048, 20480, 204800] +SHAPE_LIST = itertools.product(T, EXTRA_T, B) + +def tensors_apply(data: Iterable, func: Callable): + return (func(value) if isinstance(value, torch.Tensor) else value for value in data) -def get_result(permute, lengths, values, weights, permuted_lengths_sum, device: str = 'cpu'): - tensors = { - 'permute': torch.from_numpy(permute), - 'lengths': torch.from_numpy(lengths), - 'values': torch.from_numpy(values), - 'weights': torch.from_numpy(weights) if isinstance(weights, torch.Tensor) else None - } + +def get_result(tensors: dict, device: str = 'cpu'): + tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, torch.from_numpy(y)))) if device and device.startswith('npu'): torch.npu.set_device(device) - tensors = {k: v.to(device) if isinstance(v, torch.Tensor) else None for k, v in tensors.items()} + tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, y.to(device)))) + + results = torch.ops.fbgemm.permute_2D_sparse_data(**tensors) + return tuple(tensors_apply(results, lambda x: x.cpu())) - results = torch.ops.fbgemm.permute_2D_sparse_data( - permuted_lengths_sum=permuted_lengths_sum, **tensors - ) - return tuple(result.cpu() if isinstance(result, torch.Tensor) else result for result in results) +@pytest.mark.parametrize("types", TYPE_LIST) +@pytest.mark.parametrize("shapes", SHAPE_LIST) +@pytest.mark.parametrize("enable_permuted_sum", [True, False]) +def test_permute2d_sparse_data(types, shapes, enable_permuted_sum): + """ + Params: + permute: (T) dtype=int32 + lengths: (T + T', B) 
dtype=ltype + L = lengths[:T].sum() + values: (L) dtype=vtype + weights: (L) dtype=fp32 + """ + ptype, ltype, vtype, wtype = types + t, extra_t, b = shapes -@pytest.mark.parametrize("ltype", [np.int64, np.int32]) -@pytest.mark.parametrize("vtype", [np.int64, np.int32, np.float32]) -@pytest.mark.parametrize("wtype", [None, np.float32]) -@pytest.mark.parametrize("permute_dim", np.random.randint(2, 30, 4).tolist()) -@pytest.mark.parametrize("extra_permute_dim", [0, 3, 8]) -@pytest.mark.parametrize("permuted_lengths_sum", [True, False]) -@pytest.mark.parametrize("lengths", [2048, 20480, 204800]) -def test_permute2d_sparse_data(ltype, vtype, wtype, permute_dim, extra_permute_dim, permuted_lengths_sum, lengths): - permute = np.arange(permute_dim, dtype=np.int32) + permute = np.arange(t, dtype=ptype) np.random.shuffle(permute) - values = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=vtype) - weights = np.arange(0, (permute_dim + extra_permute_dim) * lengths, dtype=wtype) if wtype else None - lengths = np.ones((permute_dim + extra_permute_dim, lengths), dtype=ltype) - permuted_lengths_sum = lengths[:permute_dim].sum() if permuted_lengths_sum else None + lengths = np.ones((t + extra_t, b), dtype=ltype) + values = np.arange(0, (t + extra_t) * b, dtype=vtype) + weights = np.arange(0, (t + extra_t) * b, dtype=wtype) if wtype else None + permuted_lengths_sum = lengths[:t].sum() if enable_permuted_sum else None + params = { + 'permute': permute, + 'lengths': lengths, + 'values': values, + 'weights': weights, + 'permuted_lengths_sum': permuted_lengths_sum + } - golden = get_result(permute, lengths, values, weights, permuted_lengths_sum) - result = get_result(permute, lengths, values, weights, permuted_lengths_sum, DEVICE) + golden = get_result(params) + result = get_result(params, DEVICE) for gt, pred in zip(golden, result): assert type(gt) is type(pred) -- Gitee From a8e99bcce7ed379e987636339f10f460ec9bd392 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 24 Jul 2025 11:55:59 +0800 Subject: [PATCH 12/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 1c4cc995..90aa66b1 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -38,19 +38,16 @@ EXTRA_T = [0, 3, 8] B = [2048, 20480, 204800] SHAPE_LIST = itertools.product(T, EXTRA_T, B) -def tensors_apply(data: Iterable, func: Callable): - return (func(value) if isinstance(value, torch.Tensor) else value for value in data) - def get_result(tensors: dict, device: str = 'cpu'): - tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, torch.from_numpy(y)))) + tensors = {k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v for k, v in tensors.items()} if device and device.startswith('npu'): torch.npu.set_device(device) - tensors = dict(tensors_apply(tensors.items(), lambda x, y: (x, y.to(device)))) + tensors = 
{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tensors.items()} results = torch.ops.fbgemm.permute_2D_sparse_data(**tensors) - return tuple(tensors_apply(results, lambda x: x.cpu())) + return [x.cpu() if isinstance(x, torch.Tensor) else x for x in results] @pytest.mark.parametrize("types", TYPE_LIST) -- Gitee From 37d0bfa440ca6321c8d9390aedeec9427883f1c3 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 25 Jul 2025 09:44:43 +0800 Subject: [PATCH 13/31] =?UTF-8?q?[feat]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82clean=20code=E9=99=8D=E4=BD=8E=E9=87=8D?= =?UTF-8?q?=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../permute2d_sparse_data/op_host/permute2d_sparse_data.cpp | 2 +- .../permute2d_sparse_data/test_permute2d_sparse_data.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index e732785b..710e044e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -36,11 +36,11 @@ namespace optiling { static ge::graphStatus TilingFunc(gert::TilingContext* context) { Permute2dSparseDataTilingData tiling; + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); bool enableWeights = (context->GetOptionalInputTensor(WEIGHTS_INDEX) != nullptr); tiling.set_enableWeights(enableWeights); - OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("permuteShape", context->GetInputShape(PERMUTE_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("lengthsShape", context->GetInputShape(LENGTH_INDEX), return ge::GRAPH_FAILED); OPS_LOG_E_IF_NULL("valuesShape", context->GetInputShape(VALUES_INDEX), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 90aa66b1..72a0dc66 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -16,7 +16,6 @@ # ============================================================================== import itertools import sysconfig -from typing import Iterable, Callable import pytest import torch -- Gitee From 25b443e45894bd2d43bb9b86e2d72af4e23465c9 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 28 Jul 2025 11:16:22 +0800 Subject: [PATCH 14/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index ca06eca9..eef83078 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -48,14 +48,14 @@ source 
/usr/local/Ascend/ascend-toolkit/set_env.sh a) 算子的主要功能是实现fbgemm的permute2d_sparse_data, 实现了对二维稀疏数据进行重排。 b) 算子参数说明: -* permute: 重排的顺序参数tensor; -* lengths: 待重排长度参数; -* values: 待重排值参数; -* weights: 暂不支持使用 -* permute_sum: 暂不支持使用 +* permute: 重排的顺序参数tensor; +* lengths: 待重排长度参数; +* values: 待重排值参数; +* weights: 可选待重排值参数,与values操作完全相同; +* permute_sum: values/weights有效长度; * permuted_lengths: 输出, 重排后长度tensor; -* permuted_values: 输出,重排后值tensor; -* permuted_weights: 输出, 暂不支持 +* permuted_values: 输出,重排后的values; +* permuted_weights: 输出,重排后的weights; c) 算子约束说明: -- Gitee From 3b403fcf9dae9749d96b8685590fb80b500c2994 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 28 Jul 2025 11:17:16 +0800 Subject: [PATCH 15/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rec_for_torch/operators/permute2d_sparse_data/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index eef83078..2cb601e7 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -68,12 +68,13 @@ c) 算子约束说明: ## 算子逻辑 ``` -import numpy as np +import torch +import fbgemm_gpu def permute2d_sparse_data(permute, lengths, values): (permuted_lengths, permuted_values, permuted_weights) = ( torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values) ) - return permuted_lengths, permuted_values + return permuted_lengths, permuted_values, permuted_weights ``` \ No newline at end of file -- Gitee From a7c8ae8c3cd4909ec09ed94d70e8c14960c0576b Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 28 Jul 2025 11:22:44 +0800 Subject: [PATCH 16/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index 2cb601e7..0362f2e0 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -60,10 +60,16 @@ b) 算子参数说明: c) 算子约束说明: * 支持的型号:Atlas A2系列产品; -* 支持的CANN版本:8.2.RC1.alpha001及之后版本; -* 支持的输入数据类型:permute: int32, lengths: int64/int32, values: int64/int32/float; -* permute为1维tensor,lengths为二维tensor,且permute的第一维长度与lengths的第一维长度相等; -* values长度为lengths中所有数据长度之和 +* 支持的CANN版本:8.2.RC1.alpha001及之后版本; +* 支持的输入数据类型: + * permute: int32 + * lengths: int64/int32 + * values: int64/int32/fp32 + * weights: fp32 + * permute_sum: int(标量); +* permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; +* 未指定permute_sum时,values/weights长度为lengths中所有数据长度之和; +* 指定permute_sum时,values/weights长度为permute_sum; * 算子参数均会在NPU显存中存放,请根据显存大小合理设置参数长度。 ## 算子逻辑 -- Gitee From 00ced895049c5c366cf0e7c28e59c2c22e2b4397 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 29 Jul 2025 09:03:57 +0800 Subject: [PATCH 17/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
.../rec_for_torch/operators/permute2d_sparse_data/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index 0362f2e0..abaa4d3a 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -70,6 +70,7 @@ c) 算子约束说明: * permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; * 未指定permute_sum时,values/weights长度为lengths中所有数据长度之和; * 指定permute_sum时,values/weights长度为permute_sum; +* weights和values长度相同; * 算子参数均会在NPU显存中存放,请根据显存大小合理设置参数长度。 ## 算子逻辑 -- Gitee From b4e54b19182c6e2fe0f89dbf635f6b760e7a10b4 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 30 Jul 2025 20:11:47 +0800 Subject: [PATCH 18/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index abaa4d3a..c2e6a5c2 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -52,7 +52,7 @@ b) 算子参数说明: * lengths: 待重排长度参数; * values: 待重排值参数; * weights: 可选待重排值参数,与values操作完全相同; -* permute_sum: values/weights有效长度; +* permuted_lengths_sum: values/weights有效长度; * permuted_lengths: 输出, 重排后长度tensor; * permuted_values: 输出,重排后的values; * permuted_weights: 输出,重排后的weights; @@ -67,10 +67,11 @@ c) 算子约束说明: * values: int64/int32/fp32 * weights: fp32 * permute_sum: int(标量); -* permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; -* 未指定permute_sum时,values/weights长度为lengths中所有数据长度之和; -* 指定permute_sum时,values/weights长度为permute_sum; -* weights和values长度相同; +* permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; +同时permute中的每个值均满足: >= 0 且 < `lengths.shape[0]` +* 未指定permuted_lengths_sum时,values/weights长度为lengths中所有数据长度之和; +* 指定permuted_lengths_sum时,values/weights长度为permuted_lengths_sum; +* weights和values长度相同,均等于`lengths.sum()`; * 算子参数均会在NPU显存中存放,请根据显存大小合理设置参数长度。 ## 算子逻辑 -- Gitee From 73eb422d7f29069567d4d89776391b854831fa02 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 31 Jul 2025 10:02:52 +0800 Subject: [PATCH 19/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index c2e6a5c2..3367c9b8 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -50,9 +50,9 @@ b) 算子参数说明: * permute: 重排的顺序参数tensor; * lengths: 待重排长度参数; -* values: 待重排值参数; -* weights: 可选待重排值参数,与values操作完全相同; -* permuted_lengths_sum: values/weights有效长度; +* values: 待重排序的1D-tensor; +* weights: 可选入参,待重排序的1D-tensor。与values执行相同操作; +* permuted_lengths_sum: 可选入参,values/weights有效长度; * permuted_lengths: 输出, 重排后长度tensor; * permuted_values: 
输出,重排后的values; * permuted_weights: 输出,重排后的weights; @@ -62,10 +62,10 @@ c) 算子约束说明: * 支持的型号:Atlas A2系列产品; * 支持的CANN版本:8.2.RC1.alpha001及之后版本; * 支持的输入数据类型: - * permute: int32 - * lengths: int64/int32 - * values: int64/int32/fp32 - * weights: fp32 + * permute: int32; + * lengths: int64/int32; + * values: int64/int32/fp32; + * weights: fp32; * permute_sum: int(标量); * permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; 同时permute中的每个值均满足: >= 0 且 < `lengths.shape[0]` -- Gitee From bbc3045fad9113bb58e320f9573f3106f76f5d99 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 31 Jul 2025 14:28:15 +0800 Subject: [PATCH 20/31] =?UTF-8?q?[docx]torch.ops.fbgemm.permute=5F2D=5Fspa?= =?UTF-8?q?rse=5Fdata=E3=80=82README?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../operators/permute2d_sparse_data/README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md index 3367c9b8..82fd4338 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md @@ -66,9 +66,8 @@ c) 算子约束说明: * lengths: int64/int32; * values: int64/int32/fp32; * weights: fp32; - * permute_sum: int(标量); -* permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度; -同时permute中的每个值均满足: >= 0 且 < `lengths.shape[0]` + * permuted_lengths_sum: int(标量); +* permute为1维tensor,lengths为二维tensor,且permute的第一维长度小于等于lengths的第一维长度。同时permute中的每个值均满足: >= 0 且 < `lengths.shape[0]`; * 未指定permuted_lengths_sum时,values/weights长度为lengths中所有数据长度之和; * 指定permuted_lengths_sum时,values/weights长度为permuted_lengths_sum; * weights和values长度相同,均等于`lengths.sum()`; @@ -78,9 +77,9 @@ c) 算子约束说明: ``` import torch import fbgemm_gpu -def permute2d_sparse_data(permute, lengths, values): +def permute2d_sparse_data(permute, lengths, values, weights, permuted_lengths_sum): (permuted_lengths, permuted_values, permuted_weights) = ( - torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values) + torch.ops.fbgemm.permute_2D_sparse_data(permute, lengths, values, weights, permuted_lengths_sum) ) return permuted_lengths, permuted_values, permuted_weights -- Gitee From aa44fb875a3b911514d699c8f29fd2f0ffe7e5fc Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 31 Jul 2025 15:48:49 +0800 Subject: [PATCH 21/31] =?UTF-8?q?[fix]torch.ops.fbgemm.permute=5F2D=5Fspar?= =?UTF-8?q?se=5Fdata=E3=80=82permute=E3=80=81lengths=E4=B8=8D=E7=AD=89?= =?UTF-8?q?=E9=95=BF=E6=97=B6=E5=86=85=E5=AD=98=E8=B6=8A=E7=95=8Cdebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/permute2d_sparse_data.cpp | 15 ++-- .../op_host/permute2d_sparse_data_tilling.h | 1 + .../op_kernel/permute2d_sparse_data_kernel.h | 85 ++++++++++++------- .../test_permute2d_sparse_data.py | 10 ++- .../permute2d_sparse_data.cpp | 10 +-- 5 files changed, 72 insertions(+), 49 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index 710e044e..7d58f510 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -37,6 +37,7 @@ namespace optiling { { 
Permute2dSparseDataTilingData tiling; OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("context->GetAttrs", context->GetAttrs(), return ge::GRAPH_FAILED); bool enableWeights = (context->GetOptionalInputTensor(WEIGHTS_INDEX) != nullptr); tiling.set_enableWeights(enableWeights); @@ -70,14 +71,16 @@ namespace optiling { } // set data dim - int64_t permuteDim0 = permuteShape.GetDim(0); + int64_t permuteDim0 = permuteShape.GetDim(0); // permute[T] tiling.set_permuteDim0(permuteDim0); - int64_t lengthsT = permuteDim0; + int64_t lengthsT = lengthsShape.GetDim(0); // lengths[T + T', B] tiling.set_lengthsT(lengthsT); - int64_t lengthsB = lengthsShape.GetDim(1); + int64_t lengthsB = lengthsShape.GetDim(1); // lengths[T + T', B] tiling.set_lengthsB(lengthsB); - int64_t valuesDim = valuesShape.GetDim(0); + int64_t valuesDim = valuesShape.GetDim(0); // values[L] tiling.set_valuesDim(valuesDim); + int64_t valuesOutDim = *context->GetAttrs()->GetInt(0); + tiling.set_valuesOutDim(valuesOutDim); // set coreNUm size_t coreNum = ascendPlatform.GetCoreNumAiv(); @@ -103,7 +106,9 @@ namespace optiling { // apply workspace size_t* currentWorkspace = context->GetWorkspaceSizes(1); size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize(); - currentWorkspace[0] = systemWorkspacesSize + (lengthsT + 1) * GM_ALIGN + (lengthsT + 1) * GM_ALIGN * coreNum; + // 使用workspace共享lengths.sum(dim=1)和offsets计算结果, 因此为两份内存 + size_t userWorkspacesSize = 2 * (lengthsT + 1) * sizeof(int64_t); + currentWorkspace[0] = systemWorkspacesSize + userWorkspacesSize; context->SetBlockDim(coreNum); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h index 12a22507..90d1dbac 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data_tilling.h @@ -27,6 +27,7 @@ namespace optiling { TILING_DATA_FIELD_DEF(int64_t, lengthsT); TILING_DATA_FIELD_DEF(int64_t, lengthsB); TILING_DATA_FIELD_DEF(int64_t, valuesDim); + TILING_DATA_FIELD_DEF(int64_t, valuesOutDim); TILING_DATA_FIELD_DEF(bool, enableWeights); TILING_DATA_FIELD_DEF(int64_t, totalBatch); diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 9fa0cd1e..0955ab82 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -55,6 +55,7 @@ public: lengthsT = tilingData.lengthsT; lengthsB = tilingData.lengthsB; valuesDim = tilingData.valuesDim; + valuesOutDim = tilingData.valuesOutDim; totalBatch = tilingData.totalBatch; baseBatchLen = tilingData.baseBatchLen; @@ -74,7 +75,7 @@ public: outWeights = args.out_weights; workspace = args.workspace; - // Calculate current core's tOffset. 
+ // 计算分核 if (GetBlockIdx() < tailSplitIndex) { lenOfThisCore = baseBatchLen + 1; tOffsetOfThisCore = GetBlockIdx() * (baseBatchLen + 1); @@ -86,15 +87,17 @@ public: permuteGT.SetGlobalBuffer(permute, permuteDim0 * sizeof(int32_t)); lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * sizeof(LType)); valuesGT.SetGlobalBuffer(values, valuesDim * sizeof(VType)); + // 使用workspace共享lengths.sum(dim=1)和offsets计算结果, 因此为两份内存 + offsetGT.SetGlobalBuffer(workspace, 2 * (lengthsT + 1) * sizeof(VType)); outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * sizeof(LType)); - outIndicesGT.SetGlobalBuffer(outIndices, valuesDim * sizeof(VType)); + outIndicesGT.SetGlobalBuffer(outIndices, valuesOutDim * sizeof(VType)); if (enableWeights) { weightsGT.SetGlobalBuffer(weights, valuesDim * sizeof(float)); - outWeightsGT.SetGlobalBuffer(outWeights, valuesDim * sizeof(float)); + outWeightsGT.SetGlobalBuffer(outWeights, valuesOutDim * sizeof(float)); } - // Init pipe + pipe.InitBuffer(inQueueX, USE_QUEUE_NUM, ubCanUsed / USE_QUEUE_NUM); blockLen = ubCanUsed / USE_QUEUE_NUM; @@ -137,22 +140,51 @@ public: } } - __aicore__ void CalculateOffsets() + __aicore__ void CalculateLengthSum() { - offsetPtr = (__gm__ int64_t*)workspace; - GlobalTensor offsetGT; - offsetGT.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); + // 使用ub保存计算的行总长度, 以避免Scalar单元写GM数据时产生Cache一致性问题 + LocalTensor lengthsUb = inQueueX.AllocTensor (); __gm__ LType* lengthsPtr = (__gm__ LType*)lengths; - for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) { - int64_t offsetT = 0; + // 计算分核信息, 当前core计算lengths[T, B]中的哪几行之和 + int64_t rows; + int64_t start; + + if (GetBlockIdx() < (lengthsT % coreNum)) { + rows = lengthsT / coreNum + 1; + start = GetBlockIdx() * rows; + } else { + rows = lengthsT / coreNum; + start = tailSplitIndex * (rows + 1) + (GetBlockIdx() - tailSplitIndex) * rows; + } + // 暂不考虑lengthsT过长情况, 默认UB可以装下lengthsT * sizeof(int64_t) + for (int64_t i = start; i < start + rows; i++) { + int64_t lineSum = 0; for (int64_t j = 0; j < lengthsB; j++) { - offsetT += *(lengthsPtr + i * lengthsB + j); + lineSum += *(lengthsPtr + i * lengthsB + j); } - offsetGT.SetValue(i * UB_ALIGN, offsetT); + lengthsUb.SetValue(i - start, lineSum); } + CpLocal2Gm(offsetGT[start], lengthsUb, rows); + pipe_barrier(PIPE_ALL); + SyncAll(); + inQueueX.FreeTensor(lengthsUb); + } + + __aicore__ void CalculateOffsets() + { + totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1) * UB_ALIGN + + GetBlockIdx() * (lengthsT + 1) * UB_ALIGN; + *(totalOffsetPtr) = 0; + GlobalTensor offsetGt; + offsetGt.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); AscendC::DataCacheCleanAndInvalid(offsetGT); + AscendC::DcciDst::CACHELINE_OUT>(offsetGt); + + for (int64_t i = 1; i < lengthsT + 1; i++) { + *(totalOffsetPtr + i * UB_ALIGN) = *(totalOffsetPtr + (i - 1) * UB_ALIGN) + + offsetGt.GetValue((i - 1) * UB_ALIGN); + } } __aicore__ void PermuteLengths() @@ -196,7 +228,10 @@ public: int64_t currentT = 0; for (int64_t i = 0; i < permuteDim0; i++) { currentT = *(permutePtr + i); - int64_t tLen = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN) - *(totalOffsetPtr + currentT * UB_ALIGN); + int64_t startIndex = *(totalOffsetPtr + currentT); + int64_t endIndex = *(totalOffsetPtr + currentT + 1); + + int64_t tLen = endIndex - startIndex; int64_t baseCoreLen = tLen / coreNum; int64_t tailLen = tLen % coreNum; @@ -209,13 +244,10 @@ public: offsetOfThisCore = tailLen * (baseCoreLen + 1) + 
(GetBlockIdx() - tailLen) * baseCoreLen; } - int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN); - int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN); - int64_t valuesStartIndex = (startIndex + offsetOfThisCore) * datasize; int64_t outValueStartIndex = (outValueOffset + offsetOfThisCore) * datasize; - int64_t remainLen = valueLenOfThisCore * datasize; + int64_t remainLen = valueLenOfThisCore * datasize; while (remainLen > 0) { int64_t thisLen = blockLen; if (remainLen < blockLen) { @@ -240,21 +272,8 @@ public: __aicore__ void Compute() { + CalculateLengthSum(); CalculateOffsets(); - pipe_barrier(PIPE_ALL); - SyncAll(); - totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1) * UB_ALIGN + - GetBlockIdx() * (lengthsT + 1) * UB_ALIGN; - *(totalOffsetPtr) = 0; - GlobalTensor offsetGt; - offsetGt.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); - AscendC::DataCacheCleanAndInvalid(offsetGt); - - for (int64_t i = 1; i < lengthsT + 1; i++) { - *(totalOffsetPtr + i * UB_ALIGN) = *(totalOffsetPtr + (i - 1) * UB_ALIGN) + - offsetGt.GetValue((i - 1) * UB_ALIGN); - } PermuteLengths(); PermuteData(outIndicesGT, valuesGT, sizeof(VType)); if (enableWeights) { @@ -278,6 +297,7 @@ private: int64_t lengthsT; int64_t lengthsB; int64_t valuesDim; + int64_t valuesOutDim; bool enableWeights; // Tiling @@ -308,6 +328,7 @@ private: GlobalTensor lengthsGT; GlobalTensor valuesGT; GlobalTensor weightsGT; + GlobalTensor offsetGT; GlobalTensor outLengthsGT; GlobalTensor outIndicesGT; GlobalTensor outWeightsGT; diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py index 72a0dc66..296fec34 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py @@ -30,12 +30,12 @@ PTYPE = [np.int32] LTYPE = [np.int64, np.int32] VTYPE = [np.int64, np.int32, np.float32] WTYPE = [None, np.float32] -TYPE_LIST = itertools.product(PTYPE, LTYPE, VTYPE, WTYPE) +TYPE_LIST = list(itertools.product(PTYPE, LTYPE, VTYPE, WTYPE)) T = np.random.randint(2, 30, 4) EXTRA_T = [0, 3, 8] B = [2048, 20480, 204800] -SHAPE_LIST = itertools.product(T, EXTRA_T, B) +SHAPE_LIST = list(itertools.product(T, EXTRA_T, B)) def get_result(tensors: dict, device: str = 'cpu'): @@ -64,12 +64,14 @@ def test_permute2d_sparse_data(types, shapes, enable_permuted_sum): ptype, ltype, vtype, wtype = types t, extra_t, b = shapes - permute = np.arange(t, dtype=ptype) + permute = np.arange(t + extra_t, dtype=ptype) np.random.shuffle(permute) + permute = permute[:t] + lengths = np.ones((t + extra_t, b), dtype=ltype) values = np.arange(0, (t + extra_t) * b, dtype=vtype) weights = np.arange(0, (t + extra_t) * b, dtype=wtype) if wtype else None - permuted_lengths_sum = lengths[:t].sum() if enable_permuted_sum else None + permuted_lengths_sum = lengths[permute].sum() if enable_permuted_sum else None params = { 'permute': permute, 'lengths': lengths, diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index a6b29e4f..4a7dd6de 100644 --- 
a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -33,16 +33,10 @@ tuple> permute2d_sparse_data_impl_npu( const auto B = lengths.size(1); int outValuesLen; - if (permute.size(0) == lengths.size(0)) { - outValuesLen = valuesConti.size(0); - } else if (permute.size(0) > lengths.size(0)) { - throw std::runtime_error("permute.size(0) must be less than or equal to lengths.size(0). " - "Got permute.size(0): " + std::to_string(permute.size(0)) + - ", lengths.size(0): " + std::to_string(lengths.size(0))); - } else if (permuted_lengths_sum.has_value() && permuted_lengths_sum.value() > 0) { + if (permuted_lengths_sum.has_value() && permuted_lengths_sum.value() > 0) { outValuesLen = static_cast(permuted_lengths_sum.value()); } else { - outValuesLen = lengthsConti.narrow(0, 0, T).sum().item(); + outValuesLen = lengthsConti.index_select(0, permuteConti).sum().item(); } at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); -- Gitee From 7dd2072a07d47461be5dec84378cbd740950367f Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 31 Jul 2025 19:09:44 +0800 Subject: [PATCH 22/31] =?UTF-8?q?[fix]torch.ops.fbgemm.permute=5F2D=5Fspar?= =?UTF-8?q?se=5Fdata=E3=80=82permute=E3=80=81lengths=E4=B8=8D=E7=AD=89?= =?UTF-8?q?=E9=95=BF=E6=97=B6=E5=86=85=E5=AD=98=E8=B6=8A=E7=95=8Cdebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 0955ab82..21794307 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -88,7 +88,7 @@ public: lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * sizeof(LType)); valuesGT.SetGlobalBuffer(values, valuesDim * sizeof(VType)); // 使用workspace共享lengths.sum(dim=1)和offsets计算结果, 因此为两份内存 - offsetGT.SetGlobalBuffer(workspace, 2 * (lengthsT + 1) * sizeof(VType)); + offsetGT.SetGlobalBuffer((__gm__ int64_t*)workspace, 2 * (lengthsT + 1) * sizeof(int64_t)); outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * sizeof(LType)); outIndicesGT.SetGlobalBuffer(outIndices, valuesOutDim * sizeof(VType)); @@ -173,17 +173,13 @@ public: __aicore__ void CalculateOffsets() { - totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1) * UB_ALIGN + - GetBlockIdx() * (lengthsT + 1) * UB_ALIGN; + totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1); // 第一段内存保存lengths.sum值 *(totalOffsetPtr) = 0; - GlobalTensor offsetGt; - offsetGt.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN); AscendC::DataCacheCleanAndInvalid(offsetGt); for (int64_t i = 1; i < lengthsT + 1; i++) { - *(totalOffsetPtr + i * UB_ALIGN) = *(totalOffsetPtr + (i - 1) * UB_ALIGN) + - offsetGt.GetValue((i - 1) * UB_ALIGN); + *(totalOffsetPtr + i) = *(totalOffsetPtr + (i - 1)) + offsetGt.GetValue((i - 1)); } } -- Gitee From e2ba4516615e1e5c7337f61afcb8659b61833921 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 31 Jul 2025 19:30:46 +0800 Subject: [PATCH 23/31] 
=?UTF-8?q?[fix]torch.ops.fbgemm.permute=5F2D=5Fspar?= =?UTF-8?q?se=5Fdata=E3=80=82permute=E3=80=81lengths=E4=B8=8D=E7=AD=89?= =?UTF-8?q?=E9=95=BF=E6=97=B6=E5=86=85=E5=AD=98=E8=B6=8A=E7=95=8Cdebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 21794307..36017b76 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -176,10 +176,10 @@ public: totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1); // 第一段内存保存lengths.sum值 *(totalOffsetPtr) = 0; AscendC::DataCacheCleanAndInvalid(offsetGt); + AscendC::DcciDst::CACHELINE_OUT>(offsetGT); for (int64_t i = 1; i < lengthsT + 1; i++) { - *(totalOffsetPtr + i) = *(totalOffsetPtr + (i - 1)) + offsetGt.GetValue((i - 1)); + *(totalOffsetPtr + i) = *(totalOffsetPtr + (i - 1)) + offsetGT.GetValue((i - 1)); } } -- Gitee From be7385174b9bf3f92b1f7ee8572a723880f1a4df Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 1 Aug 2025 09:02:53 +0800 Subject: [PATCH 24/31] =?UTF-8?q?[fix]torch.ops.fbgemm.permute=5F2D=5Fspar?= =?UTF-8?q?se=5Fdata=E3=80=82permute=E3=80=81lengths=E4=B8=8D=E7=AD=89?= =?UTF-8?q?=E9=95=BF=E6=97=B6=E5=86=85=E5=AD=98=E8=B6=8A=E7=95=8Cdebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 36017b76..6b27ad31 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h @@ -179,7 +179,7 @@ public: AscendC::DcciDst::CACHELINE_OUT>(offsetGT); for (int64_t i = 1; i < lengthsT + 1; i++) { - *(totalOffsetPtr + i) = *(totalOffsetPtr + (i - 1)) + offsetGT.GetValue((i - 1)); + *(totalOffsetPtr + i) = *(totalOffsetPtr + i - 1) + offsetGT.GetValue(i - 1); } } -- Gitee From 4866c432e37c9f1e7e86e8500f8225fb01e946d0 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 1 Aug 2025 16:02:47 +0800 Subject: [PATCH 25/31] =?UTF-8?q?[fix]torch.ops.fbgemm.permute=5F2D=5Fspar?= =?UTF-8?q?se=5Fdata=E3=80=82permute=E3=80=81lengths=E4=B8=8D=E7=AD=89?= =?UTF-8?q?=E9=95=BF=E6=97=B6=E5=86=85=E5=AD=98=E8=B6=8A=E7=95=8Cdebug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/permute2d_sparse_data_kernel.h | 17 ++++++++--------- .../test_permute2d_sparse_data.py | 4 +++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h index 6b27ad31..6f2835b5 100644 --- 
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
@@ -148,19 +148,21 @@ public:
         // Compute the core split: which rows of lengths[T, B] this core sums
         int64_t rows;
         int64_t start;
+        int64_t tailIndex = (lengthsT % coreNum);
-        if (GetBlockIdx() < (lengthsT % coreNum)) {
+        if (GetBlockIdx() < tailIndex) {
             rows = lengthsT / coreNum + 1;
             start = GetBlockIdx() * rows;
         } else {
             rows = lengthsT / coreNum;
-            start = tailSplitIndex * (rows + 1) + (GetBlockIdx() - tailSplitIndex) * rows;
+            start = tailIndex * (rows + 1) + (GetBlockIdx() - tailIndex) * rows;
         }
         // Very large lengthsT is not handled yet; UB is assumed to hold lengthsT * sizeof(int64_t)
         for (int64_t i = start; i < start + rows; i++) {
             int64_t lineSum = 0;
+            int64_t offset = i * lengthsB;
             for (int64_t j = 0; j < lengthsB; j++) {
-                lineSum += *(lengthsPtr + i * lengthsB + j);
+                lineSum += *(lengthsPtr + offset + j);
             }
             lengthsUb.SetValue(i - start, lineSum);
         }
@@ -186,16 +188,13 @@ public:
     __aicore__ void PermuteLengths()
    {
         permutePtr = (__gm__ int32_t*)permute;
+        int64_t totalLen = lengthsB;
+
         for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) {
             int64_t ToffsetThisIndex = *(permutePtr + i);
-            int64_t ToffsetNextIndex = *(permutePtr + i) + 1;
-            int64_t lengthsStartIndex = ToffsetThisIndex * lengthsB * sizeof(LType);
-            int64_t lengthsEndIndex = ToffsetNextIndex * lengthsB * sizeof(LType);
-            int64_t outStartIndex = i * lengthsB * sizeof(LType);
-            int64_t outEndIndex = (i + 1) * lengthsB * sizeof(LType);
-            int64_t totalLen = lengthsEndIndex - lengthsStartIndex;
+
             int64_t remainLen = totalLen;
             while (remainLen > 0) {
                 int64_t thisLen = blockLen;
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
index 296fec34..f99a4888 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 # ==============================================================================
 import itertools
+import random
 import sysconfig
 
 import pytest
@@ -33,7 +34,7 @@ WTYPE = [None, np.float32]
 TYPE_LIST = list(itertools.product(PTYPE, LTYPE, VTYPE, WTYPE))
 
 T = np.random.randint(2, 30, 4)
-EXTRA_T = [0, 3, 8]
+EXTRA_T = [True, False]
 B = [2048, 20480, 204800]
 SHAPE_LIST = list(itertools.product(T, EXTRA_T, B))
@@ -63,6 +64,7 @@ def test_permute2d_sparse_data(types, shapes, enable_permuted_sum):
     """
     ptype, ltype, vtype, wtype = types
     t, extra_t, b = shapes
+    extra_t = random.randint(1, t) if extra_t else 0
 
     permute = np.arange(t + extra_t, dtype=ptype)
     np.random.shuffle(permute)
--
Gitee


From c79efbdda6c1236526ad231d8f2f262bc72c08e9 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Fri, 1 Aug 2025 16:32:28 +0800
Subject: [PATCH 26/31] [fix] torch.ops.fbgemm.permute_2D_sparse_data: fix an out-of-bounds memory access when permute and lengths differ in length
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../op_kernel/permute2d_sparse_data_kernel.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
index 6f2835b5..c543bd83 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
@@ -188,7 +188,7 @@ public:
     __aicore__ void PermuteLengths()
     {
         permutePtr = (__gm__ int32_t*)permute;
-        int64_t totalLen = lengthsB;
+        int64_t totalLen = lengthsB * sizeof(LType);
 
         for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) {
             int64_t ToffsetThisIndex = *(permutePtr + i);
--
Gitee


From 1351adfe67ffe715ea70827ca62ee7ab8b6afc3d Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Fri, 1 Aug 2025 16:43:49 +0800
Subject: [PATCH 27/31] [fix] torch.ops.fbgemm.permute_2D_sparse_data: fix an out-of-bounds memory access when permute and lengths differ in length
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../op_kernel/permute2d_sparse_data_kernel.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
index 6f2835b5..c543bd83 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
@@ -188,7 +188,7 @@ public:
     __aicore__ void PermuteLengths()
     {
         permutePtr = (__gm__ int32_t*)permute;
-        int64_t totalLen = lengthsB;
+        int64_t totalLen = lengthsB * sizeof(LType);
 
         for (int64_t i = tOffsetOfThisCore; i < lenOfThisCore + tOffsetOfThisCore; i++) {
             int64_t ToffsetThisIndex = *(permutePtr + i);
--
Gitee


From 2ec2c0a09646f87d8d0e36d0856f5cff668449e0 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Tue, 5 Aug 2025 19:13:51 +0800
Subject: [PATCH 28/31] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support the len(permute) > len(lengths) case
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../operators/permute2d_sparse_data/README.md      | 6 +++---
 .../op_host/permute2d_sparse_data.cpp              | 8 ++++----
 .../permute2d_sparse_data/test_permute2d_sparse_data.py | 9 +++------
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
index 82fd4338..2e6cc967 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/README.md
@@ -67,9 +67,9 @@ c) Operator constraints:
 * values: int64/int32/fp32;
 * weights: fp32;
 * permuted_lengths_sum: int (scalar);
-* permute is a 1-D tensor and lengths is a 2-D tensor, and the first dimension of permute must be no longer than the first dimension of lengths. Every value in permute satisfies >= 0 and < `lengths.shape[0]`;
-* When permuted_lengths_sum is not given, the values/weights length is the sum of all entries in lengths;
-* When permuted_lengths_sum is given, the values/weights length is permuted_lengths_sum;
+* permute is a 1-D tensor and lengths is a 2-D tensor; every value in permute satisfies >= 0 and < `lengths.shape[0]`;
+* When permuted_lengths_sum is given, the permuted_values/permuted_weights length is permuted_lengths_sum, and the caller must guarantee the value is correct;
+* When permuted_lengths_sum is not given, the operator computes permuted_lengths_sum itself;
 * weights and values have the same length, equal to `lengths.sum()`;
 * All operator arguments are stored in NPU device memory; size them according to the available memory.

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
index 7d58f510..46667179 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
@@ -60,13 +60,13 @@ namespace optiling {
         }
 
         // shape check
-        if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM) ||
-            (permuteShape.GetDim(0) > lengthsShape.GetDim(0))) {
+        if ((permuteShape.GetDimNum() != 1) || (lengthsShape.GetDimNum() != SUPPORT_EMBEDDING_DIM_NUM)) {
            OPS_LOG_E("", "[ERROR]permute shape or lengths shape is error. ");
            return ge::GRAPH_FAILED;
        }
-        if (enableWeights && valuesShape != weightsShape) {
-            OPS_LOG_E("", "[ERROR]values shape or weights shape is error. ");
+        if (enableWeights && (valuesShape != weightsShape || valuesShape.GetDimNum() != 1)) {
+            OPS_LOG_E("", "[ERROR]values shape or weights shape is error. values.size() = %d, weights.size() = %d\n",
+                      valuesShape.GetDim(0), weights.GetDim(0));
             return ge::GRAPH_FAILED;
         }
 
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
index f99a4888..5968a824 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/permute2d_sparse_data/test_permute2d_sparse_data.py
@@ -34,7 +34,7 @@ WTYPE = [None, np.float32]
 TYPE_LIST = list(itertools.product(PTYPE, LTYPE, VTYPE, WTYPE))
 
 T = np.random.randint(2, 30, 4)
-EXTRA_T = [True, False]
+EXTRA_T = [1, 0, -1]
 B = [2048, 20480, 204800]
 SHAPE_LIST = list(itertools.product(T, EXTRA_T, B))
@@ -64,12 +64,9 @@ def test_permute2d_sparse_data(types, shapes, enable_permuted_sum):
     """
     ptype, ltype, vtype, wtype = types
     t, extra_t, b = shapes
-    extra_t = random.randint(1, t) if extra_t else 0
-
-    permute = np.arange(t + extra_t, dtype=ptype)
-    np.random.shuffle(permute)
-    permute = permute[:t]
+    extra_t = random.randint(1, t - 1) * extra_t
+    permute = np.random.choice(t + extra_t, t).astype(dtype=np.int32)
     lengths = np.ones((t + extra_t, b), dtype=ltype)
     values = np.arange(0, (t + extra_t) * b, dtype=vtype)
     weights = np.arange(0, (t + extra_t) * b, dtype=wtype) if wtype else None
--
Gitee


From 8911063eabea89ab5b198e46b94f9109438cb7d4 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Tue, 5 Aug 2025 20:10:44 +0800
Subject: [PATCH 29/31] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support the len(permute) > len(lengths) case
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../permute2d_sparse_data/op_host/permute2d_sparse_data.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
index 46667179..4d9aa146 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
@@ -66,7 +66,7 @@ namespace optiling {
         }
         if (enableWeights && (valuesShape != weightsShape || valuesShape.GetDimNum() != 1)) {
             OPS_LOG_E("", "[ERROR]values shape or weights shape is error. values.size() = %d, weights.size() = %d\n",
-                      valuesShape.GetDim(0), weights.GetDim(0));
+                      valuesShape.GetDim(0), weightsShape.GetDim(0));
             return ge::GRAPH_FAILED;
         }
--
Gitee


From c5dd00b5ec0384496d73756042bd72c9706d4a46 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Wed, 6 Aug 2025 15:51:38 +0800
Subject: [PATCH 30/31] [feat] torch.ops.fbgemm.permute_2D_sparse_data: support the len(permute) > len(lengths) case
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../op_host/permute2d_sparse_data.cpp        |  5 +--
 .../op_kernel/permute2d_sparse_data_kernel.h | 31 ++++++++++---------
 2 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
index 4d9aa146..cdb515c1 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
@@ -106,8 +106,9 @@ namespace optiling {
         // apply workspace
         size_t* currentWorkspace = context->GetWorkspaceSizes(1);
         size_t systemWorkspacesSize = ascendPlatform.GetLibApiWorkSpaceSize();
-        // The workspace shares lengths.sum(dim=1) and the offsets results, hence two segments
-        size_t userWorkspacesSize = 2 * (lengthsT + 1) * sizeof(int64_t);
+        // The workspace shares lengths.sum(dim=1) plus the offsets computed by each core.
+        // The base addresses must stay 32-byte aligned for workspace synchronization to succeed, hence the factor of 64
+        size_t userWorkspacesSize = (lengthsT + 1) * GM_ALIGN * (coreNum + 1);
         currentWorkspace[0] = systemWorkspacesSize + userWorkspacesSize;
 
         context->SetBlockDim(coreNum);
diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
index c543bd83..53d8336e 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_kernel/permute2d_sparse_data_kernel.h
@@ -87,8 +87,6 @@ public:
         permuteGT.SetGlobalBuffer(permute, permuteDim0 * sizeof(int32_t));
         lengthsGT.SetGlobalBuffer(lengths, lengthsT * lengthsB * sizeof(LType));
         valuesGT.SetGlobalBuffer(values, valuesDim * sizeof(VType));
-        // The workspace shares lengths.sum(dim=1) and the offsets results, hence two segments
-        offsetGT.SetGlobalBuffer((__gm__ int64_t*)workspace, 2 * (lengthsT + 1) * sizeof(int64_t));
         outLengthsGT.SetGlobalBuffer(outLengths, lengthsT * lengthsB * sizeof(LType));
         outIndicesGT.SetGlobalBuffer(outIndices, valuesOutDim * sizeof(VType));
@@ -142,8 +140,10 @@ public:
 
     __aicore__ void CalculateLengthSum()
     {
-        // Keep the per-row sums in UB to avoid cache-coherency issues when the Scalar unit writes GM data
-        LocalTensor<int64_t> lengthsUb = inQueueX.AllocTensor<int64_t>();
+        offsetPtr = (__gm__ int64_t*)workspace;
+        GlobalTensor<int64_t> offsetGT;
+        // Create a workspace view of shape [T+1, UB_ALIGN]
+        offsetGT.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * sizeof(int64_t));
         __gm__ LType* lengthsPtr = (__gm__ LType*)lengths;
         // Compute the core split: which rows of lengths[T, B] this core sums
         int64_t rows;
@@ -164,24 +164,27 @@ public:
             for (int64_t j = 0; j < lengthsB; j++) {
                 lineSum += *(lengthsPtr + offset + j);
             }
-            lengthsUb.SetValue(i - start, lineSum);
+            // Write column-wise so every GT entry starts 32-byte aligned; otherwise DataCacheCleanAndInvalid fails to synchronize
+            offsetGT.SetValue(i * UB_ALIGN, lineSum);
         }
-        CpLocal2Gm(offsetGT[start], lengthsUb, rows);
+        AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::ENTIRE_DATA_CACHE,
+                                          AscendC::DcciDst::CACHELINE_OUT>(offsetGT);
         pipe_barrier(PIPE_ALL);
         SyncAll();
-
-        inQueueX.FreeTensor(lengthsUb);
     }
 
     __aicore__ void CalculateOffsets()
     {
-        totalOffsetPtr = (__gm__ int64_t*)workspace + (lengthsT + 1); // the first segment holds the lengths.sum values
+        totalOffsetPtr = (__gm__ int64_t*)workspace + (1 + GetBlockIdx()) * (lengthsT + 1) * UB_ALIGN;
         *(totalOffsetPtr) = 0;
+        GlobalTensor<int64_t> offsetGt;
+        offsetGt.SetGlobalBuffer((__gm__ int64_t*)offsetPtr, (lengthsT + 1) * UB_ALIGN * UB_ALIGN);
         AscendC::DataCacheCleanAndInvalid<int64_t, AscendC::CacheLine::ENTIRE_DATA_CACHE,
-                                          AscendC::DcciDst::CACHELINE_OUT>(offsetGT);
+                                          AscendC::DcciDst::CACHELINE_OUT>(offsetGt);
         for (int64_t i = 1; i < lengthsT + 1; i++) {
-            *(totalOffsetPtr + i) = *(totalOffsetPtr + i - 1) + offsetGT.GetValue(i - 1);
+            *(totalOffsetPtr + i * UB_ALIGN) = *(totalOffsetPtr + (i - 1) * UB_ALIGN) +
+                                               offsetGt.GetValue((i - 1) * UB_ALIGN);
         }
     }
@@ -223,14 +226,14 @@ public:
         int64_t currentT = 0;
         for (int64_t i = 0; i < permuteDim0; i++) {
             currentT = *(permutePtr + i);
-            int64_t startIndex = *(totalOffsetPtr + currentT);
-            int64_t endIndex = *(totalOffsetPtr + currentT + 1);
+            int64_t startIndex = *(totalOffsetPtr + currentT * UB_ALIGN);
+            int64_t endIndex = *(totalOffsetPtr + (currentT + 1) * UB_ALIGN);
 
             int64_t tLen = endIndex - startIndex;
             int64_t baseCoreLen = tLen / coreNum;
             int64_t tailLen = tLen % coreNum;
-            // calculate current core permute values offset
+            // Compute the start position and count of the values handled on this core
             if (GetBlockIdx() < tailLen) {
                 valueLenOfThisCore = baseCoreLen + 1;
                 offsetOfThisCore = GetBlockIdx() * (baseCoreLen + 1);
--
Gitee


From 6defd539d9bf2d3662ef854506614089b9023aa7 Mon Sep 17 00:00:00 2001
From: zhoucy
Date: Sat, 9 Aug 2025 08:38:56 +0800
Subject: [PATCH 31/31] [feat] torch.ops.fbgemm.permute_2D_sparse_data: clean up the operator prototype definition
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../permute2d_sparse_data/op_host/permute2d_sparse_data.cpp | 6 +++---
 .../2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp   | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
index f95e8a33..39d5adde 100644
--- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp
@@ -174,15 +174,15 @@ public:
             .FormatList({ge::FORMAT_ND});
         this->Output("permuted_lengths")
             .ParamType(REQUIRED)
-            .DataTypeList({ge::DT_INT64, ge::DT_INT32})
+            .Follow("lengths", FollowType::DTYPE)
             .FormatList({ge::FORMAT_ND});
         this->Output("permuted_values")
             .ParamType(REQUIRED)
-            .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT})
+            .Follow("values", FollowType::DTYPE)
             .FormatList({ge::FORMAT_ND});
         this->Output("permuted_weights")
             .ParamType(OPTIONAL)
-            .DataTypeList({ge::DT_FLOAT})
+            .Follow("weights", FollowType::DTYPE)
             .FormatList({ge::FORMAT_ND});
 
         this->Attr("permuted_sum").Int(0);

diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp
index 4a7dd6de..b676ea34 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp
@@ -40,8 +40,8 @@ tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> permute2d_sparse_data_impl_npu(
     }
 
     at::Tensor outLengths = at::empty({T, B}, lengthsConti.options());
-    at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options());
-    at::Tensor outWeights = weights.has_value() ? at::empty({outValuesLen}, weightsConti.options()) : at::Tensor();
+    at::Tensor outValues = at::zeros({outValuesLen}, valuesConti.options());
+    at::Tensor outWeights = weights.has_value() ? at::zeros({outValuesLen}, weightsConti.options()) : at::Tensor();
 
     EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen,
                  outLengths, outValues, outWeights);
--
Gitee
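
For review convenience, here is a minimal NumPy sketch of the semantics the patched tests exercise, including the len(permute) != len(lengths) cases enabled by PATCH 28-30. The function name ref_permute_2d_sparse_data and its standalone layout are illustrative assumptions, not code from this series.

    # Reference model (assumed, for review only): permutes rows of a 2-D ragged
    # batch. lengths is [T, B]; values is the flat concatenation of all rows.
    import numpy as np


    def ref_permute_2d_sparse_data(permute, lengths, values, weights=None):
        # Row r of lengths owns lengths[r].sum() consecutive elements of values.
        row_sizes = lengths.sum(axis=1)
        offsets = np.concatenate(([0], np.cumsum(row_sizes)))
        out_lengths = lengths[permute]
        out_values = np.concatenate(
            [values[offsets[r]:offsets[r + 1]] for r in permute])
        out_weights = None
        if weights is not None:
            out_weights = np.concatenate(
                [weights[offsets[r]:offsets[r + 1]] for r in permute])
        return out_lengths, out_values, out_weights

Under this model, len(out_values) equals lengths.index_select(0, permute).sum(), which is exactly the index_select-based fallback the host wrapper switches to at the top of this series.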
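A second sketch, of the strided workspace layout introduced in PATCH 30 (again an illustrative model under stated assumptions, not repository code): each int64 is written into its own UB_ALIGN-element slot so that every value occupies a distinct cache line and DataCacheCleanAndInvalid can flush it independently; segment 0 holds the per-row sums, while segments 1..coreNum hold each core's private prefix sums.

    # Workspace slot arithmetic (UB_ALIGN = 32 int64 elements per slot is an
    # assumption mirroring the kernel constants).
    UB_ALIGN = 32


    def slot_index(segment, entry, lengths_t):
        # Flat int64 index of `entry` within `segment` of the workspace:
        # segment 0 = per-row lengths sums, segment 1 + k = core k's prefix sums.
        return segment * (lengths_t + 1) * UB_ALIGN + entry * UB_ALIGN


    def prefix_sums(row_sums):
        # What CalculateOffsets computes per core: out[i] = sum(row_sums[:i]).
        out = [0]
        for s in row_sums:
            out.append(out[-1] + s)
        return out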