diff --git a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp index f95e8a33d2d570e78aa85ac234e8fd682defd2d4..39d5adde8b6ea5961ee029d0114543360fa64c57 100644 --- a/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/permute2d_sparse_data/op_host/permute2d_sparse_data.cpp @@ -174,15 +174,15 @@ public: .FormatList({ge::FORMAT_ND}); this->Output("permuted_lengths") .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT64, ge::DT_INT32}) + .Follow("lengths", FollowType::DTYPE) .FormatList({ge::FORMAT_ND}); this->Output("permuted_values") .ParamType(REQUIRED) - .DataTypeList({ge::DT_INT64, ge::DT_INT32, ge::DT_FLOAT}) + .Follow("values", FollowType::DTYPE) .FormatList({ge::FORMAT_ND}); this->Output("permuted_weights") .ParamType(OPTIONAL) - .DataTypeList({ge::DT_FLOAT}) + .Follow("weights", FollowType::DTYPE) .FormatList({ge::FORMAT_ND}); this->Attr("permuted_sum").Int(0); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp index 4a7dd6deec7ac081da2062768666d34b4ce176c0..b676ea3490fcf236c55a718c237ff1ec66696972 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/permute2d_sparse_data/permute2d_sparse_data.cpp @@ -40,8 +40,8 @@ tuple> permute2d_sparse_data_impl_npu( } at::Tensor outLengths = at::empty({T, B}, lengthsConti.options()); - at::Tensor outValues = at::empty({outValuesLen}, valuesConti.options()); - at::Tensor outWeights = weights.has_value() ? at::empty({outValuesLen}, weightsConti.options()) : at::Tensor(); + at::Tensor outValues = at::zeros({outValuesLen}, valuesConti.options()); + at::Tensor outWeights = weights.has_value() ? at::zeros({outValuesLen}, weightsConti.options()) : at::Tensor(); EXEC_NPU_CMD(aclnnPermute2dSparseData, permuteConti, lengthsConti, valuesConti, weightsConti, outValuesLen, outLengths, outValues, outWeights);