From 625ed64d63553ccb3b5c3afc7ca38901cf0737c0 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 16:54:56 +0800 Subject: [PATCH 01/18] cleancode --- .../op_host/hstu_dense_backward.cpp | 184 ++++++++---------- .../op_kernel/hstu_dense_backward_kernel.h | 10 +- .../2.6.0/hstu/HstuDenseNpuFusion.cpp | 18 +- 3 files changed, 90 insertions(+), 122 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp index fd616945..be384988 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp @@ -22,13 +22,36 @@ See the License for the specific language governing permissions and #include "hstu_dense_backward_jagged_tiling.h" namespace optiling { -static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseBackwardTilingData &tiling) + +static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul) { - int64_t headDim = tiling.get_headDim(); - int64_t blockHeight = tiling.get_blockHeight(); - int64_t dataTypeLength = tiling.get_dataTypeLength(); + matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + matmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + matmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul_tiling::DataType dataType; + matmul.SetOrgShape(blockHeight, headDim, blockHeight); + matmul.SetShape(blockHeight, headDim, blockHeight); + matmul.SetBias(false); + matmul.SetBufferSpace(-1, -1, -1); +} + +static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul) +{ + matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + matmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + matmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + + matmul.SetOrgShape(blockHeight, blockHeight, headDim); + matmul.SetShape(blockHeight, blockHeight, headDim); + matmul.SetBias(false); + matmul.SetBufferSpace(-1, -1, -1);} +} + +static ge::graphStatus GetDataType(gert::TilingContext *context, matmul_tiling::DataType &dataType) +{ ge::DataType gradType = context->GetInputTensor(INDEX_T::INDEX_0)->GetDataType(); if (gradType == ge::DataType::DT_FLOAT) { dataType = matmul_tiling::DataType::DT_FLOAT; @@ -40,6 +63,17 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB OPS_LOG_E("", "invalid datatype, only support float/fp16/bf16\n"); return ge::GRAPH_FAILED; } +} + +static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseBackwardTilingData &tiling) +{ + int64_t headDim = tiling.get_headDim(); + int64_t blockHeight = tiling.get_blockHeight(); + int64_t dataTypeLength = tiling.get_dataTypeLength(); + + matmul_tiling::DataType dataType; + OPS_CHECK(GetDataType(context, dataType) == ge::GRAPH_FAILED, OPS_LOG_E("", "GetDataType failed\n"), + return ge::GRAPH_FAILED); auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); size_t coreNum = ascendPlatform.GetCoreNumAic(); @@ -55,8 +89,7 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB int64_t maskTempSpace = blockHeight * blockHeight; - int64_t totalTempSpaceForOneVec = - MID_USE_TIMES * + int64_t totalTempSpaceForOneVec = MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * dataTypeLength) + maskTempSpace * dataTypeLength; @@ -68,56 +101,19 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB currentWorkspace[0] = workspaceSize + systemWorkspaceSize; matmul_tiling::MatmulApiTiling qkMatmul(ascendPlatform); - qkMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - qkMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - qkMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - qkMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - qkMatmul.SetOrgShape(blockHeight, blockHeight, headDim); - qkMatmul.SetShape(blockHeight, blockHeight, headDim); - qkMatmul.SetBias(false); - qkMatmul.SetBufferSpace(-1, -1, -1); + SetQKMatmul(qkMatmul); matmul_tiling::MatmulApiTiling qGradMatmul(ascendPlatform); - qGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - qGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - qGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, - matmul_tiling::DataType::DT_FLOAT); - qGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - qGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); - qGradMatmul.SetShape(blockHeight, headDim, blockHeight); - qGradMatmul.SetBias(false); - qGradMatmul.SetBufferSpace(-1, -1, -1); + SetQKVGrad(qGradMatmul); matmul_tiling::MatmulApiTiling kGradMatmul(ascendPlatform); - kGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - kGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - kGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, - matmul_tiling::DataType::DT_FLOAT); - kGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - kGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); - kGradMatmul.SetShape(blockHeight, headDim, blockHeight); - kGradMatmul.SetBias(false); - kGradMatmul.SetBufferSpace(-1, -1, -1); + SetQKVGrad(kGradMatmul); matmul_tiling::MatmulApiTiling vGradMatmul(ascendPlatform); - vGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - vGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - vGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, - matmul_tiling::DataType::DT_FLOAT); - vGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - vGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); - vGradMatmul.SetShape(blockHeight, headDim, blockHeight); - vGradMatmul.SetBias(false); - vGradMatmul.SetBufferSpace(-1, -1, -1); + SetQKVGrad(vGradMatmul); - if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || - qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || - kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || - vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { + if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || + kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { return ge::GRAPH_FAILED; } @@ -190,57 +186,37 @@ class HstuDenseBackward : public OpDef { public: explicit HstuDenseBackward(const char *name) : OpDef(name) { - this->Input("grad") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("q") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("k") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("v") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("mask") - .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("attn_bias") - .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Output("q_grad") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("k_grad") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("v_grad") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("attn_bias_grad") - .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("q").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("k").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("v").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("mask").ParamType(OPTIONAL).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("attn_bias").ParamType(OPTIONAL).DataType + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("q_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("k_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("v_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("attn_bias_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("layout").String("normal"); this->Attr("mask_type").Int(); @@ -249,10 +225,8 @@ public: this->Attr("seq_offsets").AttrType(OPTIONAL).ListInt(); OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") - .ExtendCfgInfo("coreType.value", "AiCore") - .ExtendCfgInfo("prebuildPattern.value", "Opaque"); + aicore_config.DynamicCompileStaticFlag(true).ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore").ExtendCfgInfo("prebuildPattern.value", "Opaque"); this->SetInferShape(ge::InferShape); this->SetInferDataType(ge::InferDtype); diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index ae25b5bc..14653dd1 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -57,7 +57,7 @@ public: ComputeSecond(); } - __aicore__ inline void Init(Args &args) + __aicore__inline void InitGlobalBuffer(Args &args) { GET_TILING_DATA(tilingData, args.tiling); @@ -101,7 +101,10 @@ public: kGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.kGrad), totalElementOfQ); vGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.vGrad), totalElementOfQ); attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); + } + __aicore__ inline void InitPipe() + { int64_t qkMatmulTempSpace = blockHeight * blockHeight; int64_t gvMatmulTempSpace = blockHeight * blockHeight; int64_t vGradAccumTempSpace = blockHeight * headDim; @@ -144,7 +147,12 @@ public: pipe.InitBuffer(queueOutputScore, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); pipe.InitBuffer(queueOutputBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + } + __aicore__ inline void Init(Args &args) + { + InitGlobalBuffer(args); + InitPipe(); CreateMask(); } diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/hstu/HstuDenseNpuFusion.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/hstu/HstuDenseNpuFusion.cpp index fdeb47a0..258358d3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/hstu/HstuDenseNpuFusion.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/hstu/HstuDenseNpuFusion.cpp @@ -305,22 +305,8 @@ std::tuple hstu_dense_jagged_bac } const char *layout = "jagged"; - EXEC_NPU_CMD(aclnnHstuDenseBackward, - denseGrad, - denseQ, - denseK, - denseV, - denseMask, - denseAttnBias, - layout, - maskType, - maxSeqLen, - realSiluScale, - acSeqOffset, - qGradOutput, - kGradOutput, - vGradOutput, - attnBiasGradOutput); + EXEC_NPU_CMD(aclnnHstuDenseBackward, denseGrad, denseQ, denseK, denseV, denseMask, denseAttnBias, layout, maskType, + maxSeqLen, realSiluScale, acSeqOffset, qGradOutput, kGradOutput, vGradOutput, attnBiasGradOutput); return std::make_tuple(qGradOutput, kGradOutput, vGradOutput, attnBiasGradOutput); } -- Gitee From 7a69c4c94d8f040c6e16185734d3ce274989a9f9 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 17:12:38 +0800 Subject: [PATCH 02/18] cleancode --- .../op_host/hstu_dense_backward.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp index be384988..76a66f15 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp @@ -23,7 +23,7 @@ See the License for the specific language governing permissions and namespace optiling { -static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul) +static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType) { matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); @@ -37,7 +37,7 @@ static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul) matmul.SetBufferSpace(-1, -1, -1); } -static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul) +static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType) { matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); @@ -101,16 +101,16 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB currentWorkspace[0] = workspaceSize + systemWorkspaceSize; matmul_tiling::MatmulApiTiling qkMatmul(ascendPlatform); - SetQKMatmul(qkMatmul); + SetQKMatmul(qkMatmul, dataType); matmul_tiling::MatmulApiTiling qGradMatmul(ascendPlatform); - SetQKVGrad(qGradMatmul); + SetQKVGrad(qGradMatmul, dataType); matmul_tiling::MatmulApiTiling kGradMatmul(ascendPlatform); - SetQKVGrad(kGradMatmul); + SetQKVGrad(kGradMatmul, dataType); matmul_tiling::MatmulApiTiling vGradMatmul(ascendPlatform); - SetQKVGrad(vGradMatmul); + SetQKVGrad(vGradMatmul, dataType); if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { -- Gitee From f9cbfffeff0a32eb11d012f2ca443c429ee4a8a3 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 17:14:28 +0800 Subject: [PATCH 03/18] cleancode --- .../op_host/hstu_dense_backward.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp index 76a66f15..d595ce6f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp @@ -23,7 +23,8 @@ See the License for the specific language governing permissions and namespace optiling { -static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType) +static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType, + int blockHeight, int headDim) { matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); @@ -37,7 +38,8 @@ static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::Da matmul.SetBufferSpace(-1, -1, -1); } -static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType) +static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType, + int blockHeight, int headDim) { matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); @@ -101,16 +103,16 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB currentWorkspace[0] = workspaceSize + systemWorkspaceSize; matmul_tiling::MatmulApiTiling qkMatmul(ascendPlatform); - SetQKMatmul(qkMatmul, dataType); + SetQKMatmul(qkMatmul, dataType, blockHeight, headDim); matmul_tiling::MatmulApiTiling qGradMatmul(ascendPlatform); - SetQKVGrad(qGradMatmul, dataType); + SetQKVGrad(qGradMatmul, dataType, blockHeight, headDim); matmul_tiling::MatmulApiTiling kGradMatmul(ascendPlatform); - SetQKVGrad(kGradMatmul, dataType); + SetQKVGrad(kGradMatmul, dataType, blockHeight, headDim); matmul_tiling::MatmulApiTiling vGradMatmul(ascendPlatform); - SetQKVGrad(vGradMatmul, dataType); + SetQKVGrad(vGradMatmul, dataType, blockHeight, headDim); if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { -- Gitee From 715de046ece1673eeda549b2c9506ee3078de5f0 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 17:19:51 +0800 Subject: [PATCH 04/18] cleancode --- .../hstu_dense_backward/op_host/hstu_dense_backward.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp index d595ce6f..da2ed3f8 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp @@ -49,7 +49,7 @@ static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::D matmul.SetOrgShape(blockHeight, blockHeight, headDim); matmul.SetShape(blockHeight, blockHeight, headDim); matmul.SetBias(false); - matmul.SetBufferSpace(-1, -1, -1);} + matmul.SetBufferSpace(-1, -1, -1); } static ge::graphStatus GetDataType(gert::TilingContext *context, matmul_tiling::DataType &dataType) @@ -203,7 +203,7 @@ public: this->Input("mask").ParamType(OPTIONAL).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("attn_bias").ParamType(OPTIONAL).DataType + this->Input("attn_bias").ParamType(OPTIONAL).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); -- Gitee From 4b18f7616776c48d5c3e80b3b6b819aca890d0ec Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 19:25:26 +0800 Subject: [PATCH 05/18] cleancode --- .../hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index 14653dd1..b7cdcf97 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -57,7 +57,7 @@ public: ComputeSecond(); } - __aicore__inline void InitGlobalBuffer(Args &args) + __aicore__ inline void InitGlobalBuffer(Args &args) { GET_TILING_DATA(tilingData, args.tiling); -- Gitee From 5fa4d51a8ae6c4bd309e685f3be80d05ca91b34c Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 19:28:38 +0800 Subject: [PATCH 06/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index b7cdcf97..e97f738e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -61,8 +61,6 @@ public: { GET_TILING_DATA(tilingData, args.tiling); - GM_ADDR workspace = args.workspace; - batchSize = tilingData.batchSize; seqLen = tilingData.seqLen; headNum = tilingData.headNum; @@ -103,8 +101,9 @@ public: attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); } - __aicore__ inline void InitPipe() + __aicore__ inline void InitPipe(Args &args) { + GM_ADDR workspace = args.workspace; int64_t qkMatmulTempSpace = blockHeight * blockHeight; int64_t gvMatmulTempSpace = blockHeight * blockHeight; int64_t vGradAccumTempSpace = blockHeight * headDim; @@ -152,7 +151,7 @@ public: __aicore__ inline void Init(Args &args) { InitGlobalBuffer(args); - InitPipe(); + InitPipe(args); CreateMask(); } -- Gitee From 0e5eea337e7e3eedae47a40391804f3304a9e71a Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Thu, 26 Jun 2025 19:39:54 +0800 Subject: [PATCH 07/18] cleancode --- .../op_host/hstu_dense_backward.cpp | 186 ++++++++++-------- 1 file changed, 105 insertions(+), 81 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp index da2ed3f8..fd616945 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp @@ -22,38 +22,13 @@ See the License for the specific language governing permissions and #include "hstu_dense_backward_jagged_tiling.h" namespace optiling { - -static void SetQKVGrad(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType, - int blockHeight, int headDim) -{ - matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, - matmul_tiling::DataType::DT_FLOAT); - matmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - matmul.SetOrgShape(blockHeight, headDim, blockHeight); - matmul.SetShape(blockHeight, headDim, blockHeight); - matmul.SetBias(false); - matmul.SetBufferSpace(-1, -1, -1); -} - -static void SetQKMatmul(matmul_tiling::MatmulApiTiling &matmul, matmul_tiling::DataType dataType, - int blockHeight, int headDim) +static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseBackwardTilingData &tiling) { - matmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - matmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); - - matmul.SetOrgShape(blockHeight, blockHeight, headDim); - matmul.SetShape(blockHeight, blockHeight, headDim); - matmul.SetBias(false); - matmul.SetBufferSpace(-1, -1, -1); -} + int64_t headDim = tiling.get_headDim(); + int64_t blockHeight = tiling.get_blockHeight(); + int64_t dataTypeLength = tiling.get_dataTypeLength(); -static ge::graphStatus GetDataType(gert::TilingContext *context, matmul_tiling::DataType &dataType) -{ + matmul_tiling::DataType dataType; ge::DataType gradType = context->GetInputTensor(INDEX_T::INDEX_0)->GetDataType(); if (gradType == ge::DataType::DT_FLOAT) { dataType = matmul_tiling::DataType::DT_FLOAT; @@ -65,17 +40,6 @@ static ge::graphStatus GetDataType(gert::TilingContext *context, matmul_tiling:: OPS_LOG_E("", "invalid datatype, only support float/fp16/bf16\n"); return ge::GRAPH_FAILED; } -} - -static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseBackwardTilingData &tiling) -{ - int64_t headDim = tiling.get_headDim(); - int64_t blockHeight = tiling.get_blockHeight(); - int64_t dataTypeLength = tiling.get_dataTypeLength(); - - matmul_tiling::DataType dataType; - OPS_CHECK(GetDataType(context, dataType) == ge::GRAPH_FAILED, OPS_LOG_E("", "GetDataType failed\n"), - return ge::GRAPH_FAILED); auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); size_t coreNum = ascendPlatform.GetCoreNumAic(); @@ -91,7 +55,8 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB int64_t maskTempSpace = blockHeight * blockHeight; - int64_t totalTempSpaceForOneVec = MID_USE_TIMES * + int64_t totalTempSpaceForOneVec = + MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * dataTypeLength) + maskTempSpace * dataTypeLength; @@ -103,19 +68,56 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB currentWorkspace[0] = workspaceSize + systemWorkspaceSize; matmul_tiling::MatmulApiTiling qkMatmul(ascendPlatform); - SetQKMatmul(qkMatmul, dataType, blockHeight, headDim); + qkMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + qkMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + qkMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + qkMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + + qkMatmul.SetOrgShape(blockHeight, blockHeight, headDim); + qkMatmul.SetShape(blockHeight, blockHeight, headDim); + qkMatmul.SetBias(false); + qkMatmul.SetBufferSpace(-1, -1, -1); matmul_tiling::MatmulApiTiling qGradMatmul(ascendPlatform); - SetQKVGrad(qGradMatmul, dataType, blockHeight, headDim); + qGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + qGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + qGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + qGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + + qGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); + qGradMatmul.SetShape(blockHeight, headDim, blockHeight); + qGradMatmul.SetBias(false); + qGradMatmul.SetBufferSpace(-1, -1, -1); matmul_tiling::MatmulApiTiling kGradMatmul(ascendPlatform); - SetQKVGrad(kGradMatmul, dataType, blockHeight, headDim); + kGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + kGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + kGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + kGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + + kGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); + kGradMatmul.SetShape(blockHeight, headDim, blockHeight); + kGradMatmul.SetBias(false); + kGradMatmul.SetBufferSpace(-1, -1, -1); matmul_tiling::MatmulApiTiling vGradMatmul(ascendPlatform); - SetQKVGrad(vGradMatmul, dataType, blockHeight, headDim); + vGradMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + vGradMatmul.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + vGradMatmul.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, + matmul_tiling::DataType::DT_FLOAT); + vGradMatmul.SetBiasType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, dataType); + + vGradMatmul.SetOrgShape(blockHeight, headDim, blockHeight); + vGradMatmul.SetShape(blockHeight, headDim, blockHeight); + vGradMatmul.SetBias(false); + vGradMatmul.SetBufferSpace(-1, -1, -1); - if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || - kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { + if (qkMatmul.GetTiling(tiling.qkMatmul) == -1 || + qGradMatmul.GetTiling(tiling.qGradMatmul) == -1 || + kGradMatmul.GetTiling(tiling.kGradMatmul) == -1 || + vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) { return ge::GRAPH_FAILED; } @@ -188,37 +190,57 @@ class HstuDenseBackward : public OpDef { public: explicit HstuDenseBackward(const char *name) : OpDef(name) { - this->Input("grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("q").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("k").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("v").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("mask").ParamType(OPTIONAL).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Input("attn_bias").ParamType(OPTIONAL).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - - this->Output("q_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("k_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("v_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); - this->Output("attn_bias_grad").ParamType(REQUIRED).DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("q") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("k") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("v") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("mask") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("attn_bias") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + + this->Output("q_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("k_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("v_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("attn_bias_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("layout").String("normal"); this->Attr("mask_type").Int(); @@ -227,8 +249,10 @@ public: this->Attr("seq_offsets").AttrType(OPTIONAL).ListInt(); OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true).ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") - .ExtendCfgInfo("coreType.value", "AiCore").ExtendCfgInfo("prebuildPattern.value", "Opaque"); + aicore_config.DynamicCompileStaticFlag(true) + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("prebuildPattern.value", "Opaque"); this->SetInferShape(ge::InferShape); this->SetInferDataType(ge::InferDtype); -- Gitee From 9d8fe810da4bcb382e9fbd3b7e15b54cb875c89f Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 09:30:42 +0800 Subject: [PATCH 08/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel.h | 722 ++++++------------ .../hstu_dense_backward_kernel_common.h | 255 +++++++ 2 files changed, 502 insertions(+), 475 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index e97f738e..1445d3bb 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -36,154 +36,56 @@ struct BlockInfo { }; template -class HstuDenseBackwardKernel { +class HstuDenseBackwardKernel : public HstuDenseBackwardKernelInterface { public: __aicore__ inline HstuDenseBackwardKernel() {} __aicore__ inline void Compute(Args &args) { GET_TILING_DATA(tilingData, args.tiling); - REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), qkMatmul, &tilingData.qkMatmul, qGradMatmul, - &tilingData.qGradMatmul, kGradMatmul, &tilingData.kGradMatmul, vGradMatmul, - &tilingData.vGradMatmul); + REGIST_MATMUL_OBJ(&this->pipe, GetSysWorkSpacePtr(), this->qkMatmul, &tilingData.->qkMatmul, this->qGradMatmul, + &tilingData.->qGradMatmul, this->kGradMatmul, &tilingData.->kGradMatmul, this->vGradMatmul, + &tilingData.->vGradMatmul); uint64_t tilingPtr = reinterpret_cast(args.tiling); - qkMatmul.SetUserDefInfo(tilingPtr); - qGradMatmul.SetUserDefInfo(tilingPtr); - kGradMatmul.SetUserDefInfo(tilingPtr); - vGradMatmul.SetUserDefInfo(tilingPtr); - - Init(args); - ComputeFirst(); - ComputeSecond(); - } - - __aicore__ inline void InitGlobalBuffer(Args &args) - { - GET_TILING_DATA(tilingData, args.tiling); - - batchSize = tilingData.batchSize; - seqLen = tilingData.seqLen; - headNum = tilingData.headNum; - headDim = tilingData.headDim; - - maxSeqLen = tilingData.maxSeqLen; - biasGradSeqLen = tilingData.biasGradSeqLen; - siluScale = tilingData.siluScale; - - blockHeight = tilingData.blockHeight; - - maskType = tilingData.maskType; - enableBias = tilingData.enableBias; - - rowBlockNum = (seqLen + blockHeight - 1) / blockHeight; - colBlockNum = (seqLen + blockHeight - 1) / blockHeight; - totalRowBlockNum = batchSize * headNum * rowBlockNum; - totalColBlockNum = batchSize * headNum * colBlockNum; - totalBlockNum = totalRowBlockNum * colBlockNum; - - int64_t totalElementOfQ = batchSize * maxSeqLen * headNum * headDim; - int64_t totalElementOfAttnBias = batchSize * headNum * biasGradSeqLen * biasGradSeqLen; - - grad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.grad), totalElementOfQ); - q.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.q), totalElementOfQ); - k.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.k), totalElementOfQ); - v.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.v), totalElementOfQ); - if (enableBias) { - attnBias.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBias), totalElementOfAttnBias); - } - if (IfMask(maskType, MaskType::MASK_CUSTOM)) { - mask.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.mask), totalElementOfAttnBias); - } - - qGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.qGrad), totalElementOfQ); - kGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.kGrad), totalElementOfQ); - vGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.vGrad), totalElementOfQ); - attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); - } - - __aicore__ inline void InitPipe(Args &args) - { - GM_ADDR workspace = args.workspace; - int64_t qkMatmulTempSpace = blockHeight * blockHeight; - int64_t gvMatmulTempSpace = blockHeight * blockHeight; - int64_t vGradAccumTempSpace = blockHeight * headDim; - int64_t kGradAccumTempSpace = blockHeight * headDim; - int64_t scoreTempSpace = blockHeight * blockHeight; - int64_t maskTempSpace = blockHeight * blockHeight; - - int64_t totalTempSpaceForOneVec = - MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + - (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * sizeof(qType)) + - maskTempSpace * sizeof(qType); - - curAICWorkspace = reinterpret_cast<__gm__ uint8_t *>(workspace) + GetBlockIdx() * totalTempSpaceForOneVec; - - qkTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), qkMatmulTempSpace * MID_USE_TIMES); - curAICWorkspace += qkMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; - - gvTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), gvMatmulTempSpace * MID_USE_TIMES); - curAICWorkspace += gvMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; - - scoreTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), scoreTempSpace * MID_USE_TIMES); - curAICWorkspace += scoreTempSpace * sizeof(qType) * MID_USE_TIMES; - - vGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), - vGradAccumTempSpace * MID_USE_TIMES); - curAICWorkspace += vGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; - - kGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), - kGradAccumTempSpace * MID_USE_TIMES); - curAICWorkspace += kGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; - - maskTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), maskTempSpace); - - vecOnceDataNum = DATA_ALIGN_BYTES / sizeof(float) * blockHeight; - pipe.InitBuffer(queueVecScoreQK, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreGV, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreMask, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - - pipe.InitBuffer(queueOutputScore, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - pipe.InitBuffer(queueOutputBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - } - - __aicore__ inline void Init(Args &args) - { - InitGlobalBuffer(args); - InitPipe(args); - CreateMask(); + this->qkMatmul.SetUserDefInfo(tilingPtr); + this->qGradMatmul.SetUserDefInfo(tilingPtr); + this->kGradMatmul.SetUserDefInfo(tilingPtr); + this->vGradMatmul.SetUserDefInfo(tilingPtr); + + this->Init(args); + this->ComputeFirst(); + this->ComputeSecond(); } __aicore__ inline void CalcBaseOffsets(int64_t curTaskId, bool isCol = true) { - taskInfo[curTaskId].qkLeftOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].rowId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; - taskInfo[curTaskId].qkRightOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].colId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; - taskInfo[curTaskId].kGradLeftOffset = taskInfo[curTaskId].batchId * headNum * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].headId * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].rowId * blockHeight * biasGradSeqLen + - taskInfo[curTaskId].colId * blockHeight; + this->taskInfo[curTaskId].qkLeftOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].qkRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].colId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].kGradLeftOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight; if (isCol) { - taskInfo[curTaskId].vGradRightOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].rowId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; + this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; - taskInfo[curTaskId].rowLine = seqLen - taskInfo[curTaskId].rowId * blockHeight; - if (taskInfo[curTaskId].rowLine > blockHeight) { - taskInfo[curTaskId].rowLine = blockHeight; + this->taskInfo[curTaskId].rowLine = this->seqLen - this->taskInfo[curTaskId].rowId * this->blockHeight; + if (this->taskInfo[curTaskId].rowLine > this->blockHeight) { + this->taskInfo[curTaskId].rowLine = this->blockHeight; } } else { - taskInfo[curTaskId].vGradRightOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].colId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; + this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].colId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; - taskInfo[curTaskId].colLine = seqLen - taskInfo[curTaskId].colId * blockHeight; - if (taskInfo[curTaskId].colLine > blockHeight) { - taskInfo[curTaskId].colLine = blockHeight; + this->taskInfo[curTaskId].colLine = this->seqLen - this->taskInfo[curTaskId].colId * this->blockHeight; + if (this->taskInfo[curTaskId].colLine > this->blockHeight) { + this->taskInfo[curTaskId].colLine = this->blockHeight; } } } @@ -192,86 +94,86 @@ public: { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; int64_t midResultIdx = taskId % MID_USE_TIMES; - int64_t outOffset = midResultIdx * blockHeight * blockHeight; + int64_t outOffset = midResultIdx * this->blockHeight * this->blockHeight; - qkMatmul.SetTail(taskInfo[curTaskId].rowLine, taskInfo[curTaskId].colLine, headDim); - DoQKMatmulImpl(taskInfo[curTaskId].qkLeftOffset, taskInfo[curTaskId].qkRightOffset, outOffset); + this->qkMatmul.SetTail(this->taskInfo[curTaskId].rowLine, this->taskInfo[curTaskId].colLine, this->headDim); + DoQKMatmulImpl(this->taskInfo[curTaskId].qkLeftOffset, this->taskInfo[curTaskId].qkRightOffset, outOffset); } __aicore__ inline void DoQKMatmulImpl(int64_t left, int64_t right, int64_t out) { - qkMatmul.SetTensorA(q[left]); - qkMatmul.SetTensorB(k[right], true); + this->qkMatmul.SetTensorA(this->q[left]); + this->qkMatmul.SetTensorB(this->k[right], true); - qkMatmul.template IterateAll(qkTemp[out], 0, false, true); + this->qkMatmul.template IterateAll(this->qkTemp[out], 0, false, true); } __aicore__ inline void DoGVMatmul(int64_t taskId) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; int64_t midResultIdx = taskId % MID_USE_TIMES; - int64_t outOffset = midResultIdx * blockHeight * blockHeight; + int64_t outOffset = midResultIdx * this->blockHeight * this->blockHeight; - qkMatmul.SetTail(taskInfo[curTaskId].rowLine, taskInfo[curTaskId].colLine, headDim); - DoGVMatmulImpl(taskInfo[curTaskId].qkLeftOffset, taskInfo[curTaskId].qkRightOffset, outOffset); + this->qkMatmul.SetTail(this->taskInfo[curTaskId].rowLine, this->taskInfo[curTaskId].colLine, this->headDim); + DoGVMatmulImpl(this->taskInfo[curTaskId].qkLeftOffset, this->taskInfo[curTaskId].qkRightOffset, outOffset); } __aicore__ inline void DoGVMatmulImpl(int64_t left, int64_t right, int64_t out) { - qkMatmul.SetTensorA(grad[left]); - qkMatmul.SetTensorB(v[right], true); + this->qkMatmul.SetTensorA(this->grad[left]); + this->qkMatmul.SetTensorB(this->v[right], true); - qkMatmul.template IterateAll(gvTemp[out], 0, false, true); + this->qkMatmul.template IterateAll(this->gvTemp[out], 0, false, true); } __aicore__ inline void DoQGradMatmul(int64_t taskId) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - int64_t midAccumIdx = taskInfo[curTaskId].accumId % MID_USE_TIMES; - int64_t outOffset = midAccumIdx * blockHeight * headDim; + int64_t midAccumIdx = this->taskInfo[curTaskId].accumId % MID_USE_TIMES; + int64_t outOffset = midAccumIdx * this->blockHeight * this->headDim; - bool isNew = taskInfo[curTaskId].colId == 0; + bool isNew = this->taskInfo[curTaskId].colId == 0; - qGradMatmul.SetTail(taskInfo[curTaskId].rowLine, headDim, taskInfo[curTaskId].colLine); - DoQGradMatmulImpl(taskInfo[curTaskId].kGradLeftOffset, taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); + this->qGradMatmul.SetTail(this->taskInfo[curTaskId].rowLine, this->headDim, this->taskInfo[curTaskId].colLine); + DoQGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); } __aicore__ inline void DoQGradMatmulImpl(int64_t left, int64_t right, int64_t out, bool isNew) { - qGradMatmul.SetTensorA(attnBiasGrad[left]); - qGradMatmul.SetTensorB(k[right]); + this->qGradMatmul.SetTensorA(this->attnBiasGrad[left]); + this->qGradMatmul.SetTensorB(this->k[right]); if (isNew) { - qGradMatmul.template IterateAll(kGradAccumTemp[out], 0, false, true); + this->qGradMatmul.template IterateAll(this->kGradAccumTemp[out], 0, false, true); } else { - qGradMatmul.template IterateAll(kGradAccumTemp[out], 1, false, true); + this->qGradMatmul.template IterateAll(this->kGradAccumTemp[out], 1, false, true); } } __aicore__ inline void DoKGradMatmul(int64_t taskId) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - int64_t midAccumIdx = taskInfo[curTaskId].accumId % MID_USE_TIMES; - int64_t outOffset = midAccumIdx * blockHeight * headDim; + int64_t midAccumIdx = this->taskInfo[curTaskId].accumId % MID_USE_TIMES; + int64_t outOffset = midAccumIdx * this->blockHeight * this->headDim; bool isNew = false; - if (IfMask(maskType, MaskType::MASK_TRIL)) { - isNew = taskInfo[curTaskId].rowId == taskInfo[curTaskId].colId; + if (IfMask(this->maskType, MaskType::MASK_TRIL)) { + isNew = this->taskInfo[curTaskId].rowId == this->taskInfo[curTaskId].colId; } else { - isNew = taskInfo[curTaskId].rowId == 0; + isNew = this->taskInfo[curTaskId].rowId == 0; } - kGradMatmul.SetTail(taskInfo[curTaskId].colLine, headDim, taskInfo[curTaskId].rowLine); - DoKGradMatmulImpl(taskInfo[curTaskId].kGradLeftOffset, taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); + this->kGradMatmul.SetTail(this->taskInfo[curTaskId].colLine, this->headDim, this->taskInfo[curTaskId].rowLine); + DoKGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); } __aicore__ inline void DoKGradMatmulImpl(int64_t left, int64_t right, int64_t out, bool isNew) { - kGradMatmul.SetTensorA(attnBiasGrad[left], true); - kGradMatmul.SetTensorB(q[right]); + this->kGradMatmul.SetTensorA(this->attnBiasGrad[left], true); + this->kGradMatmul.SetTensorB(this->q[right]); if (isNew) { - kGradMatmul.template IterateAll(kGradAccumTemp[out], 0, false, true); + this->kGradMatmul.template IterateAll(this->kGradAccumTemp[out], 0, false, true); } else { - kGradMatmul.template IterateAll(kGradAccumTemp[out], 1, false, true); + this->kGradMatmul.template IterateAll(this->kGradAccumTemp[out], 1, false, true); } } @@ -279,75 +181,34 @@ public: { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; int64_t midResultIdx = taskId % MID_USE_TIMES; - int64_t midAccumIdx = taskInfo[curTaskId].accumId % MID_USE_TIMES; + int64_t midAccumIdx = this->taskInfo[curTaskId].accumId % MID_USE_TIMES; - int64_t scoreTempOffset = midResultIdx * blockHeight * blockHeight; - int64_t outOffset = midAccumIdx * blockHeight * headDim; + int64_t scoreTempOffset = midResultIdx * this->blockHeight * this->blockHeight; + int64_t outOffset = midAccumIdx * this->blockHeight * this->headDim; bool isNew = false; - if (IfMask(maskType, MaskType::MASK_TRIL)) { - isNew = taskInfo[curTaskId].rowId == taskInfo[curTaskId].colId; + if (IfMask(this->maskType, MaskType::MASK_TRIL)) { + isNew = this->taskInfo[curTaskId].rowId == this->taskInfo[curTaskId].colId; } else { - isNew = taskInfo[curTaskId].rowId == 0; + isNew = this->taskInfo[curTaskId].rowId == 0; } - vGradMatmul.SetTail(taskInfo[curTaskId].colLine, headDim, taskInfo[curTaskId].rowLine); - DoVGradMatmulImpl(scoreTempOffset, taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); + this->vGradMatmul.SetTail(this->taskInfo[curTaskId].colLine, this->headDim, this->taskInfo[curTaskId].rowLine); + DoVGradMatmulImpl(scoreTempOffset, this->taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); } __aicore__ inline void DoVGradMatmulImpl(int64_t left, int64_t right, int64_t out, bool isNew) { - vGradMatmul.SetTensorA(scoreTemp[left], true); - vGradMatmul.SetTensorB(grad[right]); + this->vGradMatmul.SetTensorA(this->scoreTemp[left], true); + this->vGradMatmul.SetTensorB(this->grad[right]); if (isNew) { - vGradMatmul.template IterateAll(vGradAccumTemp[out], 0, false, true); + this->vGradMatmul.template IterateAll(this->vGradAccumTemp[out], 0, false, true); } else { - vGradMatmul.template IterateAll(vGradAccumTemp[out], 1, false, true); + this->vGradMatmul.template IterateAll(this->vGradAccumTemp[out], 1, false, true); } } - __aicore__ inline void CreateMask() - { - if (IfMask(maskType, MaskType::MASK_TRIL)) { - // create lower triangular - int64_t total = blockHeight * blockHeight; - int64_t remain = total; - int64_t thisLen = vecOnceDataNum; - while (remain > 0) { - if (remain < thisLen) { - thisLen = remain; - } - - int64_t baseOffset = total - remain; - int32_t validNums = 1 + baseOffset / blockHeight; - - LocalTensor input = queueVecScoreMask.AllocTensor(); - Duplicate(input, 0, thisLen); - for (int i = 0; i < thisLen / blockHeight; i++) { - if (validNums + i >= blockHeight) { - Duplicate(input[i * blockHeight], 1, blockHeight); - } else { - Duplicate(input[i * blockHeight], 1, validNums + i); - } - } - queueVecScoreMask.EnQue(input); - - LocalTensor newInput = queueVecScoreMask.DeQue(); - LocalTensor output = queueOutputTemp.AllocTensor(); - DataCopy(output, newInput, thisLen); - queueOutputTemp.EnQue(output); - queueVecScoreMask.FreeTensor(newInput); - - output = queueOutputTemp.DeQue(); - DataCopy(maskTemp[baseOffset], output, thisLen); - queueOutputTemp.FreeTensor(output); - remain -= thisLen; - } - - pipe_barrier(PIPE_ALL); - } - } __aicore__ inline void CastQType2Float(LocalTensor dstTensor, LocalTensor srcTensor, LocalTensor midTensor, int64_t len) @@ -360,52 +221,52 @@ public: LocalTensor &inputMask, LocalTensor &inputBias, int64_t thisLen, bool useMask) { - LocalTensor outputMidTemp = queueOutputTemp.AllocTensor(); + LocalTensor outputMidTemp = this->queueOutputTemp.template AllocTensor(); if (!std::is_same::value) { CastQType2Float(inputQK, inputQK.template ReinterpretCast(), outputMidTemp, thisLen); CastQType2Float(inputGV, inputGV.template ReinterpretCast(), outputMidTemp, thisLen); if (useMask) { CastQType2Float(inputMask, inputMask.template ReinterpretCast(), outputMidTemp, thisLen); } - if (enableBias) { + if (this->enableBias) { CastQType2Float(inputBias, inputBias.template ReinterpretCast(), outputMidTemp, thisLen); } } - queueOutputTemp.FreeTensor(outputMidTemp); + this->queueOutputTemp.template FreeTensor(outputMidTemp); } __aicore__ inline void CalcuScoreWithFloat32(int64_t thisLen, bool useMask) { - auto inputQK = queueVecScoreQK.DeQue(); - auto inputGV = queueVecScoreGV.DeQue(); - LocalTensor inputMask = useMask ? queueVecScoreMask.DeQue() : - queueVecScoreMask.AllocTensor(); - LocalTensor inputBias = enableBias ? queueVecScoreBias.DeQue() : - queueVecScoreBias.AllocTensor(); + auto inputQK = this->queueVecScoreQK.template DeQue(); + auto inputGV = this->queueVecScoreGV.template DeQue(); + LocalTensor inputMask = useMask ? this->queueVecScoreMask.template DeQue() : + this->queueVecScoreMask.template AllocTensor(); + LocalTensor inputBias = this->enableBias ? this->queueVecScoreBias.template DeQue() : + this->queueVecScoreBias.template AllocTensor(); CastInputData(inputQK, inputGV, inputMask, inputBias, thisLen, useMask); - if (enableBias) { + if (this->enableBias) { // qkb = qk + attn_bias Add(inputQK, inputQK, inputBias, thisLen); } - // score = F.silu(qkb) * siluScale * mask + // score = F.silu(qkb) * this->siluScale * this->mask Silu(inputBias, inputQK, thisLen); - Muls(inputBias, inputBias, siluScale, thisLen); + Muls(inputBias, inputBias, this->siluScale, thisLen); if (useMask) { Mul(inputBias, inputBias, inputMask, thisLen); } - // score_grad = gv * siluScale * mask - Muls(inputGV, inputGV, siluScale, thisLen); + // score_grad = gv * this->siluScale * this->mask + Muls(inputGV, inputGV, this->siluScale, thisLen); if (useMask) { Mul(inputGV, inputGV, inputMask, thisLen); } // bias_grad = (F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb)))) * score_grad // F.sigmoid(qkb) - LocalTensor sigmoidBuffer = queueOutputTemp.AllocTensor(); + LocalTensor sigmoidBuffer = this->queueOutputTemp.template AllocTensor(); Sigmoid(inputMask, inputQK, sigmoidBuffer, thisLen); // qkb * F.sigmoid(qkb) LocalTensor tmpBuffer = sigmoidBuffer.template ReinterpretCast(); @@ -417,15 +278,15 @@ public: // F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb))) Mul(inputQK, inputMask, tmpBuffer, thisLen); - queueVecScoreMask.FreeTensor(inputMask); - queueOutputTemp.FreeTensor(sigmoidBuffer); + this->queueVecScoreMask.template FreeTensor(inputMask); + this->queueOutputTemp.template FreeTensor(sigmoidBuffer); // (F.sigmoid(qkb) * (1 + qkb * (1 - F.sigmoid(qkb)))) * score_grad Mul(inputQK, inputQK, inputGV, thisLen); - queueVecScoreGV.FreeTensor(inputGV); + this->queueVecScoreGV.template FreeTensor(inputGV); - LocalTensor outputScore = queueOutputScore.AllocTensor(); - LocalTensor outputBias = queueOutputBias.AllocTensor(); + LocalTensor outputScore = this->queueOutputScore.template AllocTensor(); + LocalTensor outputBias = this->queueOutputBias.template AllocTensor(); if (!std::is_same::value) { Cast(outputScore, inputBias, RoundMode::CAST_RINT, thisLen); Cast(outputBias, inputQK, RoundMode::CAST_RINT, thisLen); @@ -436,41 +297,41 @@ public: LocalTensor newOutputBias = outputBias.template ReinterpretCast(); DataCopy(newOutputBias, inputQK, thisLen); } - queueVecScoreQK.FreeTensor(inputQK); - queueVecScoreBias.FreeTensor(inputBias); + this->queueVecScoreQK.template FreeTensor(inputQK); + this->queueVecScoreBias.template FreeTensor(inputBias); - queueOutputScore.EnQue(outputScore); - queueOutputBias.EnQue(outputBias); + this->queueOutputScore.template EnQue(outputScore); + this->queueOutputBias.template EnQue(outputBias); } __aicore__ inline void VecScore(int64_t taskId) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - int64_t attnBiasOffset = taskInfo[curTaskId].batchId * headNum * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].headId * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].rowId * blockHeight * biasGradSeqLen + - taskInfo[curTaskId].colId * blockHeight; - int64_t attnBiasDiagonalOffset = taskInfo[curTaskId].batchId * headNum * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].headId * biasGradSeqLen * biasGradSeqLen + - taskInfo[curTaskId].colId * blockHeight * biasGradSeqLen + - taskInfo[curTaskId].rowId * blockHeight; + int64_t attnBiasOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight; + int64_t attnBiasDiagonalOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight; int64_t maskOffset = 0; - if (IfMask(maskType, MaskType::MASK_CUSTOM)) { - maskOffset = taskInfo[curTaskId].batchId * headNum * maxSeqLen * maxSeqLen + - taskInfo[curTaskId].headId * maxSeqLen * maxSeqLen + - taskInfo[curTaskId].rowId * blockHeight * maxSeqLen + taskInfo[curTaskId].colId * blockHeight; + if (IfMask(this->maskType, MaskType::MASK_CUSTOM)) { + maskOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->maxSeqLen * this->maxSeqLen + + this->taskInfo[curTaskId].headId * this->maxSeqLen * this->maxSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->maxSeqLen + this->taskInfo[curTaskId].colId * this->blockHeight; } bool useMask = false; - if (IfMask(maskType, MaskType::MASK_TRIL)) { - useMask = taskInfo[curTaskId].rowId == taskInfo[curTaskId].colId; - } else if (IfMask(maskType, MaskType::MASK_CUSTOM)) { + if (IfMask(this->maskType, MaskType::MASK_TRIL)) { + useMask = this->taskInfo[curTaskId].rowId == this->taskInfo[curTaskId].colId; + } else if (IfMask(this->maskType, MaskType::MASK_CUSTOM)) { useMask = true; } - VecScoreImpl(taskId, attnBiasOffset, attnBiasDiagonalOffset, maskOffset, taskInfo[curTaskId].rowLine, - taskInfo[curTaskId].colLine, useMask); + VecScoreImpl(taskId, attnBiasOffset, attnBiasDiagonalOffset, maskOffset, this->taskInfo[curTaskId].rowLine, + this->taskInfo[curTaskId].colLine, useMask); } __aicore__ inline void CopyInPadding(LocalTensor dstTensor, GlobalTensor srcTensor, int64_t rowNum, @@ -479,8 +340,8 @@ public: uint16_t blockCount = rowNum; uint32_t blockLen = colNum * sizeof(qType); uint32_t srcStride = (seqLen - colNum) * sizeof(qType); - uint32_t dstStride = (blockHeight - colNum) / (DATA_ALIGN_BYTES / sizeof(qType)); - uint8_t rightPadding = (blockHeight - colNum) % (DATA_ALIGN_BYTES / sizeof(qType)); + uint32_t dstStride = (this->blockHeight - colNum) / (DATA_ALIGN_BYTES / sizeof(qType)); + uint8_t rightPadding = (this->blockHeight - colNum) % (DATA_ALIGN_BYTES / sizeof(qType)); DataCopyExtParams copyParams{blockCount, blockLen, srcStride, dstStride, 0}; DataCopyPadExtParams padParams{true, 0, rightPadding, 0}; @@ -492,7 +353,7 @@ public: { uint16_t blockCount = rowNum; uint32_t blockLen = colNum * sizeof(qType); - uint32_t srcStride = (blockHeight - colNum) / (DATA_ALIGN_BYTES / sizeof(qType)); + uint32_t srcStride = (this->blockHeight - colNum) / (DATA_ALIGN_BYTES / sizeof(qType)); uint32_t dstStride = (seqLen - colNum) * sizeof(qType); DataCopyExtParams copyParams{blockCount, blockLen, srcStride, dstStride, 0}; @@ -504,39 +365,39 @@ public: { int64_t gvOffset = qkOffset; int64_t scoreTempOffset = qkOffset; - LocalTensor inputQK = queueVecScoreQK.AllocTensor(); - DataCopy(inputQK.template ReinterpretCast(), qkTemp[qkOffset], thisLen); - queueVecScoreQK.EnQue(inputQK); + LocalTensor inputQK = this->queueVecScoreQK.template AllocTensor(); + DataCopy(inputQK.template ReinterpretCast(), this->qkTemp[qkOffset], thisLen); + this->queueVecScoreQK.template EnQue(inputQK); - LocalTensor inputGV = queueVecScoreGV.AllocTensor(); - DataCopy(inputGV.template ReinterpretCast(), gvTemp[gvOffset], thisLen); - queueVecScoreGV.EnQue(inputGV); + LocalTensor inputGV = this->queueVecScoreGV.template AllocTensor(); + DataCopy(inputGV.template ReinterpretCast(), this->gvTemp[gvOffset], thisLen); + this->queueVecScoreGV.template EnQue(inputGV); if (useMask) { - LocalTensor inputMask = queueVecScoreMask.AllocTensor(); - if (IfMask(maskType, MaskType::MASK_TRIL)) { - DataCopy(inputMask.template ReinterpretCast(), maskTemp[curMaskOffset], thisLen); + LocalTensor inputMask = this->queueVecScoreMask.template AllocTensor(); + if (IfMask(this->maskType, MaskType::MASK_TRIL)) { + DataCopy(inputMask.template ReinterpretCast(), this->maskTemp[curMaskOffset], thisLen); } - if (IfMask(maskType, MaskType::MASK_CUSTOM)) { - CopyInPadding(inputMask.template ReinterpretCast(), mask[curMaskOffset], validRowNum, - totalColNum, maxSeqLen); + if (IfMask(this->maskType, MaskType::MASK_CUSTOM)) { + CopyInPadding(inputMask.template ReinterpretCast(), this->mask[curMaskOffset], validRowNum, + totalColNum, this->maxSeqLen); } - queueVecScoreMask.EnQue(inputMask); + this->queueVecScoreMask.template EnQue(inputMask); } - if (enableBias) { - LocalTensor inputBias = queueVecScoreBias.AllocTensor(); - CopyInPadding(inputBias.template ReinterpretCast(), attnBias[curAttnBiasOffset], validRowNum, - totalColNum, biasGradSeqLen); - queueVecScoreBias.EnQue(inputBias); + if (this->enableBias) { + LocalTensor inputBias = this->queueVecScoreBias.template AllocTensor(); + CopyInPadding(inputBias.template ReinterpretCast(), this->attnBias[curAttnBiasOffset], validRowNum, + totalColNum, this->biasGradSeqLen); + this->queueVecScoreBias.template EnQue(inputBias); } CalcuScoreWithFloat32(thisLen, useMask); - LocalTensor outputScore = queueOutputScore.DeQue(); - LocalTensor outputBias = queueOutputBias.DeQue(); - DataCopy(scoreTemp[scoreTempOffset], outputScore, thisLen); - CopyOutPadding(attnBiasGrad[curAttnBiasOffset], outputBias, validRowNum, totalColNum, biasGradSeqLen); - queueOutputScore.FreeTensor(outputScore); - queueOutputBias.FreeTensor(outputBias); + LocalTensor outputScore = this->queueOutputScore.template DeQue(); + LocalTensor outputBias = this->queueOutputBias.template DeQue(); + DataCopy(this->scoreTemp[scoreTempOffset], outputScore, thisLen); + CopyOutPadding(this->attnBiasGrad[curAttnBiasOffset], outputBias, validRowNum, totalColNum, this->biasGradSeqLen); + this->queueOutputScore.template FreeTensor(outputScore); + this->queueOutputBias.template FreeTensor(outputBias); } __aicore__ inline void VecScoreImpl(int64_t taskId, int64_t attnBiasOffset, int64_t attnBiasDiagonalOffset, @@ -545,9 +406,9 @@ public: int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; int64_t midResultIdx = taskId % MID_USE_TIMES; - int64_t total = blockHeight * blockHeight; + int64_t total = this->blockHeight * this->blockHeight; int64_t remain = total; - int64_t thisLen = vecOnceDataNum; + int64_t thisLen = this->vecOnceDataNum; while (remain > 0) { if (remain < thisLen) { thisLen = remain; @@ -555,35 +416,35 @@ public: int64_t baseOffset = total - remain; - int64_t startRowNum = baseOffset / blockHeight; - int64_t thisRowNum = thisLen / blockHeight; + int64_t startRowNum = baseOffset / this->blockHeight; + int64_t thisRowNum = thisLen / this->blockHeight; int64_t validRowNum = totalRowNum - startRowNum; validRowNum = validRowNum > thisRowNum ? thisRowNum : validRowNum; validRowNum = validRowNum < 0 ? 0 : validRowNum; - int64_t qkOffset = midResultIdx * blockHeight * blockHeight + baseOffset; - int64_t curAttnBiasOffset = attnBiasOffset + startRowNum * biasGradSeqLen; + int64_t qkOffset = midResultIdx * this->blockHeight * this->blockHeight + baseOffset; + int64_t curAttnBiasOffset = attnBiasOffset + startRowNum * this->biasGradSeqLen; int64_t curMaskOffset = 0; - if (IfMask(maskType, MaskType::MASK_TRIL)) { + if (IfMask(this->maskType, MaskType::MASK_TRIL)) { curMaskOffset = maskOffset + baseOffset; - } else if (IfMask(maskType, MaskType::MASK_CUSTOM)) { - curMaskOffset = maskOffset + startRowNum * maxSeqLen; + } else if (IfMask(this->maskType, MaskType::MASK_CUSTOM)) { + curMaskOffset = maskOffset + startRowNum * this->maxSeqLen; } if (validRowNum > 0) { ValidVecScore(thisLen, validRowNum, totalColNum, qkOffset, curMaskOffset, curAttnBiasOffset, useMask); } - if (enableBias && IfMask(maskType, MaskType::MASK_TRIL) && !useMask) { - LocalTensor outputTempTensor = queueOutputTemp.AllocTensor(); + if (this->enableBias && IfMask(this->maskType, MaskType::MASK_TRIL) && !useMask) { + LocalTensor outputTempTensor = this->queueOutputTemp.template AllocTensor(); Duplicate(outputTempTensor, 0, thisLen); - queueOutputTemp.EnQue(outputTempTensor); + this->queueOutputTemp.template EnQue(outputTempTensor); - int64_t curAttnBiasDiagonalOffset = attnBiasDiagonalOffset + startRowNum * biasGradSeqLen; - outputTempTensor = queueOutputTemp.DeQue(); - CopyOutPadding(attnBiasGrad[curAttnBiasDiagonalOffset], outputTempTensor, thisRowNum, totalRowNum, - biasGradSeqLen); - queueOutputTemp.FreeTensor(outputTempTensor); + int64_t curAttnBiasDiagonalOffset = attnBiasDiagonalOffset + startRowNum * this->biasGradSeqLen; + outputTempTensor = this->queueOutputTemp.template DeQue(); + CopyOutPadding(this->attnBiasGrad[curAttnBiasDiagonalOffset], outputTempTensor, thisRowNum, totalRowNum, + this->biasGradSeqLen); + this->queueOutputTemp.template FreeTensor(outputTempTensor); } remain = remain - thisLen; @@ -593,20 +454,20 @@ public: __aicore__ inline void DoTrans(int64_t taskId, GlobalTensor from, GlobalTensor to, bool isCol = true) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - int64_t midResultIdx = taskInfo[curTaskId].accumId % MID_USE_TIMES; - int64_t fromOffset = midResultIdx * blockHeight * headDim; + int64_t midResultIdx = this->taskInfo[curTaskId].accumId % MID_USE_TIMES; + int64_t fromOffset = midResultIdx * this->blockHeight * this->headDim; int64_t toOffset = 0; int64_t total = 0; if (isCol) { - toOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].colId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; - total = taskInfo[curTaskId].colLine * headDim; + toOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].colId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; + total = this->taskInfo[curTaskId].colLine * this->headDim; } else { - toOffset = taskInfo[curTaskId].batchId * seqLen * headNum * headDim + - taskInfo[curTaskId].rowId * blockHeight * headNum * headDim + - taskInfo[curTaskId].headId * headDim; - total = taskInfo[curTaskId].rowLine * headDim; + toOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; + total = this->taskInfo[curTaskId].rowLine * this->headDim; } DoTransImpl(from, to, fromOffset, toOffset, total); @@ -616,36 +477,36 @@ public: int64_t toOffset, int64_t total = 0) { int64_t remain = total; - int64_t thisLen = vecOnceDataNum; + int64_t thisLen = this->vecOnceDataNum; while (remain > 0) { if (thisLen > remain) { thisLen = remain; } int64_t curFromOffset = total - remain; - int64_t curToOffset = curFromOffset * headNum; + int64_t curToOffset = curFromOffset * this->headNum; - LocalTensor input = queueVecScoreQK.AllocTensor(); + LocalTensor input = this->queueVecScoreQK.template AllocTensor(); DataCopy(input, from[fromOffset + curFromOffset], thisLen); - queueVecScoreQK.EnQue(input); + this->queueVecScoreQK.template EnQue(input); - LocalTensor newInput = queueVecScoreQK.DeQue(); - LocalTensor output = queueOutputTemp.AllocTensor(); + LocalTensor newInput = this->queueVecScoreQK.template DeQue(); + LocalTensor output = this->queueOutputTemp.template AllocTensor(); if (std::is_same::value) { DataCopy(output.template ReinterpretCast(), newInput, thisLen); } else { Cast(output, newInput, RoundMode::CAST_RINT, thisLen); } - queueOutputTemp.EnQue(output); - queueVecScoreQK.FreeTensor(newInput); + this->queueOutputTemp.template EnQue(output); + this->queueVecScoreQK.template FreeTensor(newInput); - LocalTensor newOutput = queueOutputTemp.DeQue(); - uint16_t blockCount = thisLen / headDim; - uint16_t blockLen = headDim * sizeof(qType) / DATA_ALIGN_BYTES; - uint16_t dstStride = (headNum * headDim - headDim) * sizeof(qType) / DATA_ALIGN_BYTES; + LocalTensor newOutput = this->queueOutputTemp.template DeQue(); + uint16_t blockCount = thisLen / this->headDim; + uint16_t blockLen = this->headDim * sizeof(qType) / DATA_ALIGN_BYTES; + uint16_t dstStride = (this->headNum * this->headDim - this->headDim) * sizeof(qType) / DATA_ALIGN_BYTES; DataCopyParams copyParams{blockCount, blockLen, 0, dstStride}; DataCopy(to[toOffset + curToOffset], newOutput, copyParams); - queueOutputTemp.FreeTensor(newOutput); + this->queueOutputTemp.template FreeTensor(newOutput); remain = remain - thisLen; } @@ -663,19 +524,19 @@ public: VecScore(taskId - 1); } - qkMatmul.WaitIterateAll(); - qkMatmul.End(); - qkMatmul.WaitIterateAll(); - qkMatmul.End(); + this->qkMatmul.WaitIterateAll(); + this->qkMatmul.End(); + this->qkMatmul.WaitIterateAll(); + this->qkMatmul.End(); if (taskId > 1) { - vGradMatmul.WaitIterateAll(); - vGradMatmul.End(); - kGradMatmul.WaitIterateAll(); - kGradMatmul.End(); - if (taskInfo[(taskId - TWO) % COMPUTE_PIPE_NUM].accumId != - taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId) { - DoTrans(taskId - TWO, vGradAccumTemp, vGrad); - DoTrans(taskId - TWO, kGradAccumTemp, kGrad); + this->vGradMatmul.WaitIterateAll(); + this->vGradMatmul.End(); + this->kGradMatmul.WaitIterateAll(); + this->kGradMatmul.End(); + if (this->taskInfo[(taskId - TWO) % COMPUTE_PIPE_NUM].accumId != + this->taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId) { + DoTrans(taskId - TWO, this->vGradAccumTemp, this->vGrad); + DoTrans(taskId - TWO, this->kGradAccumTemp, this->kGrad); } } } @@ -686,24 +547,24 @@ public: DoVGradMatmul(taskId - TWO); DoKGradMatmul(taskId - TWO); VecScore(taskId - 1); - vGradMatmul.WaitIterateAll(); - vGradMatmul.End(); - kGradMatmul.WaitIterateAll(); - kGradMatmul.End(); - if (taskInfo[(taskId - TWO) % COMPUTE_PIPE_NUM].accumId != - taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId) { - DoTrans(taskId - TWO, vGradAccumTemp, vGrad); - DoTrans(taskId - TWO, kGradAccumTemp, kGrad); + this->vGradMatmul.WaitIterateAll(); + this->vGradMatmul.End(); + this->kGradMatmul.WaitIterateAll(); + this->kGradMatmul.End(); + if (this->taskInfo[(taskId - TWO) % COMPUTE_PIPE_NUM].accumId != + this->taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId) { + DoTrans(taskId - TWO, this->vGradAccumTemp, this->vGrad); + DoTrans(taskId - TWO, this->kGradAccumTemp, this->kGrad); } DoVGradMatmul(taskId - 1); DoKGradMatmul(taskId - 1); - vGradMatmul.WaitIterateAll(); - vGradMatmul.End(); - kGradMatmul.WaitIterateAll(); - kGradMatmul.End(); - DoTrans(taskId - 1, vGradAccumTemp, vGrad); - DoTrans(taskId - 1, kGradAccumTemp, kGrad); + this->vGradMatmul.WaitIterateAll(); + this->vGradMatmul.End(); + this->kGradMatmul.WaitIterateAll(); + this->kGradMatmul.End(); + DoTrans(taskId - 1, this->vGradAccumTemp, this->vGrad); + DoTrans(taskId - 1, this->kGradAccumTemp, this->kGrad); } if (taskId == 1) { @@ -711,16 +572,16 @@ public: DoVGradMatmul(taskId - 1); DoKGradMatmul(taskId - 1); - vGradMatmul.WaitIterateAll(); - vGradMatmul.End(); - kGradMatmul.WaitIterateAll(); - kGradMatmul.End(); - DoTrans(taskId - 1, vGradAccumTemp, vGrad); - DoTrans(taskId - 1, kGradAccumTemp, kGrad); + this->vGradMatmul.WaitIterateAll(); + this->vGradMatmul.End(); + this->kGradMatmul.WaitIterateAll(); + this->kGradMatmul.End(); + DoTrans(taskId - 1, this->vGradAccumTemp, this->vGrad); + DoTrans(taskId - 1, this->kGradAccumTemp, this->kGrad); } } - __aicore__ inline void ComputeFirst() + __aicore__ inline void this->ComputeFirst() { int64_t taskId = 0; int64_t accumId = 0; @@ -729,22 +590,22 @@ public: int64_t startId = GetBlockIdx(); int64_t nextCol = totalAivNum * TWO - GetBlockIdx() * TWO - 1; - for (int64_t gColId = startId; gColId < totalColBlockNum;) { - int64_t batchId = gColId / (headNum * colBlockNum); - int64_t colIdInBatch = gColId % (headNum * colBlockNum); - int64_t headId = colIdInBatch / colBlockNum; - int64_t colId = colIdInBatch % colBlockNum; - int64_t colLine = seqLen - colId * blockHeight; - colLine = colLine > blockHeight ? blockHeight : colLine; + for (int64_t gColId = startId; gColId < this->totalColBlockNum;) { + int64_t batchId = gColId / (this->headNum * this->colBlockNum); + int64_t colIdInBatch = gColId % (this->headNum * this->colBlockNum); + int64_t headId = colIdInBatch / this->colBlockNum; + int64_t colId = colIdInBatch % this->colBlockNum; + int64_t colLine = this->seqLen - colId * this->blockHeight; + colLine = colLine > this->blockHeight ? this->blockHeight : colLine; - for (int64_t rowId = 0; rowId < rowBlockNum; rowId++) { - if (IfMask(maskType, MaskType::MASK_TRIL) && rowId < colId) { + for (int64_t rowId = 0; rowId < this->rowBlockNum; rowId++) { + if (IfMask(this->maskType, MaskType::MASK_TRIL) && rowId < colId) { continue; } int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - taskInfo[curTaskId] = BlockInfo{taskId, batchId, headId, rowId, colId, accumId}; - taskInfo[curTaskId].colLine = colLine; + this->taskInfo[curTaskId] = BlockInfo{taskId, batchId, headId, rowId, colId, accumId}; + this->taskInfo[curTaskId].colLine = colLine; CalcBaseOffsets(curTaskId); FirstStagePipeline(taskId); @@ -763,15 +624,15 @@ public: { DoQGradMatmul(taskId); if (taskId > 0) { - if (taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId != taskInfo[taskId % COMPUTE_PIPE_NUM].accumId) { - DoTrans(taskId - 1, kGradAccumTemp, qGrad, 0); + if (this->taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId != this->taskInfo[taskId % COMPUTE_PIPE_NUM].accumId) { + DoTrans(taskId - 1, this->kGradAccumTemp, this->qGrad, 0); } } - qGradMatmul.WaitIterateAll(); - qGradMatmul.End(); + this->qGradMatmul.WaitIterateAll(); + this->qGradMatmul.End(); } - __aicore__ inline void ComputeSecond() + __aicore__ inline void this->ComputeSecond() { SyncAll(); @@ -782,22 +643,22 @@ public: int64_t startId = GetBlockIdx(); int64_t nextRow = totalAivNum * TWO - GetBlockIdx() * TWO - 1; - for (int64_t gRowId = startId; gRowId < totalRowBlockNum;) { - int64_t batchId = gRowId / (headNum * rowBlockNum); - int64_t rowIdInBatch = gRowId % (headNum * rowBlockNum); - int64_t headId = rowIdInBatch / rowBlockNum; - int64_t rowId = rowIdInBatch % rowBlockNum; - int64_t rowLine = seqLen - rowId * blockHeight; - rowLine = rowLine > blockHeight ? blockHeight : rowLine; + for (int64_t gRowId = startId; gRowId < this->totalRowBlockNum;) { + int64_t batchId = gRowId / (this->headNum * this->rowBlockNum); + int64_t rowIdInBatch = gRowId % (this->headNum * this->rowBlockNum); + int64_t headId = rowIdInBatch / this->rowBlockNum; + int64_t rowId = rowIdInBatch % this->rowBlockNum; + int64_t rowLine = this->seqLen - rowId * this->blockHeight; + rowLine = rowLine > this->blockHeight ? this->blockHeight : rowLine; - for (int64_t colId = 0; colId < colBlockNum; colId++) { - if (IfMask(maskType, MaskType::MASK_TRIL) && rowId < colId) { + for (int64_t colId = 0; colId < this->colBlockNum; colId++) { + if (IfMask(this->maskType, MaskType::MASK_TRIL) && rowId < colId) { continue; } int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - taskInfo[curTaskId] = BlockInfo{taskId, batchId, headId, rowId, colId, accumId}; - taskInfo[curTaskId].rowLine = rowLine; + this->taskInfo[curTaskId] = BlockInfo{taskId, batchId, headId, rowId, colId, accumId}; + this->taskInfo[curTaskId].rowLine = rowLine; CalcBaseOffsets(curTaskId, false); SecondStagePipeline(taskId); @@ -811,98 +672,9 @@ public: } if (taskId > 0) { - DoTrans(taskId - 1, kGradAccumTemp, qGrad, 0); + DoTrans(taskId - 1, this->kGradAccumTemp, this->qGrad, 0); } } - - GM_ADDR curAICWorkspace; - - // Shape - int64_t batchSize; - int64_t seqLen; - int64_t headNum; - int64_t headDim; - int64_t maxSeqLen; - int64_t biasGradSeqLen; - int64_t blockHeight; - - // Attr - int32_t maskType; - int32_t enableBias; - float siluScale; - - // Tiling - int64_t rowBlockNum; - int64_t colBlockNum; - int64_t totalRowBlockNum; - int64_t totalColBlockNum; - int64_t totalBlockNum; - - // task - BlockInfo taskInfo[COMPUTE_PIPE_NUM]; - - // Tpipe - TPipe pipe; - - // vec score - int64_t vecOnceDataNum; - TQue queueVecScoreQK; - TQue queueVecScoreGV; - TQue queueVecScoreMask; - TQue queueVecScoreBias; - - TQue queueOutputScore; - TQue queueOutputBias; - TQue queueOutputTemp; - - // Gt - GlobalTensor grad; - GlobalTensor q; - GlobalTensor k; - GlobalTensor v; - GlobalTensor attnBias; - GlobalTensor mask; - - GlobalTensor qGrad; - GlobalTensor kGrad; - GlobalTensor vGrad; - GlobalTensor attnBiasGrad; - - GlobalTensor qkTemp; - GlobalTensor gvTemp; - GlobalTensor scoreTemp; - GlobalTensor kGradAccumTemp; // qGrad share temp space with kGrad - GlobalTensor vGradAccumTemp; - GlobalTensor maskTemp; - - // Matmul - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyQKB1>> - qkMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyVGradB1>> - qGradMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyVGradB1>> - kGradMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc>> - vGradMatmul; }; } // namespace HstuDenseBackward #endif diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h index 1455a79c..4b62d23f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h @@ -183,6 +183,261 @@ __aicore__ inline void CopyVGradB1(const LocalTensor &bMatrix, const __g int64_t startIdx = row * baseK * headNum * headDim + col * baseN; DataCopy(bMatrix.ReinterpretCast(), globalGt[startIdx], param); }; + + + +struct BlockInfo { + int64_t taskId; + int64_t batchId; + int64_t headId; + int64_t rowId; + int64_t colId; + int64_t accumId; + int64_t qkLeftOffset; + int64_t qkRightOffset; + int64_t kGradLeftOffset; + int64_t vGradRightOffset; + int64_t rowLine; + int64_t colLine; +}; + +template +class HstuDenseBackwardKernelInterface { +public: + __aicore__ inline HstuDenseBackwardKernelInterface() {} + + __aicore__ inline void InitGlobalBuffer(Args &args) + { + GET_TILING_DATA(tilingData, args.tiling); + + batchSize = tilingData.batchSize; + seqLen = tilingData.seqLen; + headNum = tilingData.headNum; + headDim = tilingData.headDim; + + maxSeqLen = tilingData.maxSeqLen; + biasGradSeqLen = tilingData.biasGradSeqLen; + siluScale = tilingData.siluScale; + + blockHeight = tilingData.blockHeight; + + maskType = tilingData.maskType; + enableBias = tilingData.enableBias; + + rowBlockNum = (seqLen + blockHeight - 1) / blockHeight; + colBlockNum = (seqLen + blockHeight - 1) / blockHeight; + totalRowBlockNum = batchSize * headNum * rowBlockNum; + totalColBlockNum = batchSize * headNum * colBlockNum; + totalBlockNum = totalRowBlockNum * colBlockNum; + + int64_t totalElementOfQ = batchSize * maxSeqLen * headNum * headDim; + int64_t totalElementOfAttnBias = batchSize * headNum * biasGradSeqLen * biasGradSeqLen; + + grad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.grad), totalElementOfQ); + q.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.q), totalElementOfQ); + k.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.k), totalElementOfQ); + v.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.v), totalElementOfQ); + if (enableBias) { + attnBias.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBias), totalElementOfAttnBias); + } + if (IfMask(maskType, MaskType::MASK_CUSTOM)) { + mask.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.mask), totalElementOfAttnBias); + } + + qGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.qGrad), totalElementOfQ); + kGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.kGrad), totalElementOfQ); + vGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.vGrad), totalElementOfQ); + attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); + } + + __aicore__ inline void InitPipe(Args &args) + { + GM_ADDR workspace = args.workspace; + int64_t qkMatmulTempSpace = blockHeight * blockHeight; + int64_t gvMatmulTempSpace = blockHeight * blockHeight; + int64_t vGradAccumTempSpace = blockHeight * headDim; + int64_t kGradAccumTempSpace = blockHeight * headDim; + int64_t scoreTempSpace = blockHeight * blockHeight; + int64_t maskTempSpace = blockHeight * blockHeight; + + int64_t totalTempSpaceForOneVec = + MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + + (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * sizeof(qType)) + + maskTempSpace * sizeof(qType); + + curAICWorkspace = reinterpret_cast<__gm__ uint8_t *>(workspace) + GetBlockIdx() * totalTempSpaceForOneVec; + + qkTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), qkMatmulTempSpace * MID_USE_TIMES); + curAICWorkspace += qkMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; + + gvTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), gvMatmulTempSpace * MID_USE_TIMES); + curAICWorkspace += gvMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; + + scoreTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), scoreTempSpace * MID_USE_TIMES); + curAICWorkspace += scoreTempSpace * sizeof(qType) * MID_USE_TIMES; + + vGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), + vGradAccumTempSpace * MID_USE_TIMES); + curAICWorkspace += vGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; + + kGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), + kGradAccumTempSpace * MID_USE_TIMES); + curAICWorkspace += kGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; + + maskTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), maskTempSpace); + + vecOnceDataNum = DATA_ALIGN_BYTES / sizeof(float) * blockHeight; + pipe.InitBuffer(queueVecScoreQK, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreGV, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreMask, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + + pipe.InitBuffer(queueOutputScore, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + pipe.InitBuffer(queueOutputBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + } + + __aicore__ inline void Init(Args &args) + { + InitGlobalBuffer(args); + InitPipe(args); + CreateMask(); + } + + __aicore__ inline void CreateMask() + { + if (IfMask(maskType, MaskType::MASK_TRIL)) { + // create lower triangular + int64_t total = blockHeight * blockHeight; + int64_t remain = total; + int64_t thisLen = vecOnceDataNum; + while (remain > 0) { + if (remain < thisLen) { + thisLen = remain; + } + + int64_t baseOffset = total - remain; + int32_t validNums = 1 + baseOffset / blockHeight; + + LocalTensor input = queueVecScoreMask.AllocTensor(); + Duplicate(input, 0, thisLen); + for (int i = 0; i < thisLen / blockHeight; i++) { + if (validNums + i >= blockHeight) { + Duplicate(input[i * blockHeight], 1, blockHeight); + } else { + Duplicate(input[i * blockHeight], 1, validNums + i); + } + } + queueVecScoreMask.EnQue(input); + + LocalTensor newInput = queueVecScoreMask.DeQue(); + LocalTensor output = queueOutputTemp.AllocTensor(); + DataCopy(output, newInput, thisLen); + queueOutputTemp.EnQue(output); + queueVecScoreMask.FreeTensor(newInput); + + output = queueOutputTemp.DeQue(); + DataCopy(maskTemp[baseOffset], output, thisLen); + queueOutputTemp.FreeTensor(output); + + remain -= thisLen; + } + + pipe_barrier(PIPE_ALL); + } + } + + + GM_ADDR curAICWorkspace; + + // Shape + int64_t batchSize; + int64_t seqLen; + int64_t headNum; + int64_t headDim; + int64_t maxSeqLen; + int64_t biasGradSeqLen; + int64_t blockHeight; + + // Attr + int32_t maskType; + int32_t enableBias; + float siluScale; + + // Tiling + int64_t rowBlockNum; + int64_t colBlockNum; + int64_t totalRowBlockNum; + int64_t totalColBlockNum; + int64_t totalBlockNum; + + // task + BlockInfo taskInfo[COMPUTE_PIPE_NUM]; + + // Tpipe + TPipe pipe; + + // vec score + int64_t vecOnceDataNum; + TQue queueVecScoreQK; + TQue queueVecScoreGV; + TQue queueVecScoreMask; + TQue queueVecScoreBias; + + TQue queueOutputScore; + TQue queueOutputBias; + TQue queueOutputTemp; + + // Gt + GlobalTensor grad; + GlobalTensor q; + GlobalTensor k; + GlobalTensor v; + GlobalTensor attnBias; + GlobalTensor mask; + + GlobalTensor qGrad; + GlobalTensor kGrad; + GlobalTensor vGrad; + GlobalTensor attnBiasGrad; + + GlobalTensor qkTemp; + GlobalTensor gvTemp; + GlobalTensor scoreTemp; + GlobalTensor kGradAccumTemp; // qGrad share temp space with kGrad + GlobalTensor vGradAccumTemp; + GlobalTensor maskTemp; + + // Matmul + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyQKB1>> + qkMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyVGradB1>> + qGradMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyVGradB1>> + kGradMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc>> + vGradMatmul; +}; + } // namespace HstuDenseBackward #endif \ No newline at end of file -- Gitee From a4827f7fb853769254819a18ceb0d6bf6eec15ac Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 09:48:48 +0800 Subject: [PATCH 09/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel.h | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index 1445d3bb..cd3519a6 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -20,32 +20,17 @@ See the License for the specific language governing permissions and namespace HstuDenseBackward { -struct BlockInfo { - int64_t taskId; - int64_t batchId; - int64_t headId; - int64_t rowId; - int64_t colId; - int64_t accumId; - int64_t qkLeftOffset; - int64_t qkRightOffset; - int64_t kGradLeftOffset; - int64_t vGradRightOffset; - int64_t rowLine; - int64_t colLine; -}; - template -class HstuDenseBackwardKernel : public HstuDenseBackwardKernelInterface { +class HstuDenseBackwardKernel : public HstuDenseBackwardKernelInterface { public: __aicore__ inline HstuDenseBackwardKernel() {} __aicore__ inline void Compute(Args &args) { GET_TILING_DATA(tilingData, args.tiling); - REGIST_MATMUL_OBJ(&this->pipe, GetSysWorkSpacePtr(), this->qkMatmul, &tilingData.->qkMatmul, this->qGradMatmul, - &tilingData.->qGradMatmul, this->kGradMatmul, &tilingData.->kGradMatmul, this->vGradMatmul, - &tilingData.->vGradMatmul); + REGIST_MATMUL_OBJ(&this->pipe, GetSysWorkSpacePtr(), this->qkMatmul, &tilingData.qkMatmul, this->qGradMatmul, + &tilingData.qGradMatmul, this->kGradMatmul, &tilingData.kGradMatmul, this->vGradMatmul, + &tilingData.vGradMatmul); uint64_t tilingPtr = reinterpret_cast(args.tiling); this->qkMatmul.SetUserDefInfo(tilingPtr); this->qGradMatmul.SetUserDefInfo(tilingPtr); -- Gitee From 6fc77a8c3e8c896abe52dbcf865e54450000159e Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 09:55:15 +0800 Subject: [PATCH 10/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index cd3519a6..dee3cb42 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -566,7 +566,7 @@ public: } } - __aicore__ inline void this->ComputeFirst() + __aicore__ inline void ComputeFirst() { int64_t taskId = 0; int64_t accumId = 0; @@ -617,7 +617,7 @@ public: this->qGradMatmul.End(); } - __aicore__ inline void this->ComputeSecond() + __aicore__ inline void ComputeSecond() { SyncAll(); -- Gitee From fabbb39460ba77c7cdb81fffab2f5803d49372cf Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:00:52 +0800 Subject: [PATCH 11/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel_common.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h index 4b62d23f..7d462810 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h @@ -297,13 +297,6 @@ public: pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); } - __aicore__ inline void Init(Args &args) - { - InitGlobalBuffer(args); - InitPipe(args); - CreateMask(); - } - __aicore__ inline void CreateMask() { if (IfMask(maskType, MaskType::MASK_TRIL)) { @@ -347,6 +340,13 @@ public: } } + __aicore__ inline void Init(Args &args) + { + InitGlobalBuffer(args); + InitPipe(args); + CreateMask(); + } + GM_ADDR curAICWorkspace; -- Gitee From 8a549109289d7c81c61f5c8a2775e16865c3a7d6 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:30:32 +0800 Subject: [PATCH 12/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel.h | 65 ++++++++++--------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index dee3cb42..bedcca3c 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -44,29 +44,29 @@ public: __aicore__ inline void CalcBaseOffsets(int64_t curTaskId, bool isCol = true) { - this->taskInfo[curTaskId].qkLeftOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + - this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + - this->taskInfo[curTaskId].headId * this->headDim; - this->taskInfo[curTaskId].qkRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + - this->taskInfo[curTaskId].colId * this->blockHeight * this->headNum * this->headDim + - this->taskInfo[curTaskId].headId * this->headDim; - this->taskInfo[curTaskId].kGradLeftOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + - this->taskInfo[curTaskId].colId * this->blockHeight; + this->taskInfo[curTaskId].qkLeftOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * + this->headDim + this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + + this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].qkRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * + this->headDim + this->taskInfo[curTaskId].colId * this->blockHeight * + this->headNum * this->headDim + this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].kGradLeftOffset = this->taskInfo[curTaskId].batchId * this->headNum * + this->biasGradSeqLen * this->biasGradSeqLen + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * + this->biasGradSeqLen + this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight; if (isCol) { - this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + - this->taskInfo[curTaskId].rowId * this->blockHeight * this->headNum * this->headDim + - this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * + this->headNum * this->headDim + this->taskInfo[curTaskId].rowId * this->blockHeight * + this->headNum * this->headDim + this->taskInfo[curTaskId].headId * this->headDim; this->taskInfo[curTaskId].rowLine = this->seqLen - this->taskInfo[curTaskId].rowId * this->blockHeight; if (this->taskInfo[curTaskId].rowLine > this->blockHeight) { this->taskInfo[curTaskId].rowLine = this->blockHeight; } } else { - this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * this->headNum * this->headDim + - this->taskInfo[curTaskId].colId * this->blockHeight * this->headNum * this->headDim + - this->taskInfo[curTaskId].headId * this->headDim; + this->taskInfo[curTaskId].vGradRightOffset = this->taskInfo[curTaskId].batchId * this->seqLen * + this->headNum * this->headDim + this->taskInfo[curTaskId].colId * this->blockHeight * + this->headNum * this->headDim + this->taskInfo[curTaskId].headId * this->headDim; this->taskInfo[curTaskId].colLine = this->seqLen - this->taskInfo[curTaskId].colId * this->blockHeight; if (this->taskInfo[curTaskId].colLine > this->blockHeight) { @@ -120,7 +120,8 @@ public: bool isNew = this->taskInfo[curTaskId].colId == 0; this->qGradMatmul.SetTail(this->taskInfo[curTaskId].rowLine, this->headDim, this->taskInfo[curTaskId].colLine); - DoQGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); + DoQGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, + outOffset, isNew); } __aicore__ inline void DoQGradMatmulImpl(int64_t left, int64_t right, int64_t out, bool isNew) @@ -148,7 +149,8 @@ public: } this->kGradMatmul.SetTail(this->taskInfo[curTaskId].colLine, this->headDim, this->taskInfo[curTaskId].rowLine); - DoKGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, outOffset, isNew); + DoKGradMatmulImpl(this->taskInfo[curTaskId].kGradLeftOffset, this->taskInfo[curTaskId].vGradRightOffset, + outOffset, isNew); } __aicore__ inline void DoKGradMatmulImpl(int64_t left, int64_t right, int64_t out, bool isNew) @@ -292,20 +294,21 @@ public: __aicore__ inline void VecScore(int64_t taskId) { int64_t curTaskId = taskId % COMPUTE_PIPE_NUM; - int64_t attnBiasOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + - this->taskInfo[curTaskId].colId * this->blockHeight; - int64_t attnBiasDiagonalOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + - this->taskInfo[curTaskId].colId * this->blockHeight * this->biasGradSeqLen + - this->taskInfo[curTaskId].rowId * this->blockHeight; + int64_t attnBiasOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * + this->biasGradSeqLen + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight; + int64_t attnBiasDiagonalOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->biasGradSeqLen * + this->biasGradSeqLen + this->taskInfo[curTaskId].headId * this->biasGradSeqLen * this->biasGradSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight * this->biasGradSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight; int64_t maskOffset = 0; if (IfMask(this->maskType, MaskType::MASK_CUSTOM)) { maskOffset = this->taskInfo[curTaskId].batchId * this->headNum * this->maxSeqLen * this->maxSeqLen + - this->taskInfo[curTaskId].headId * this->maxSeqLen * this->maxSeqLen + - this->taskInfo[curTaskId].rowId * this->blockHeight * this->maxSeqLen + this->taskInfo[curTaskId].colId * this->blockHeight; + this->taskInfo[curTaskId].headId * this->maxSeqLen * this->maxSeqLen + + this->taskInfo[curTaskId].rowId * this->blockHeight * this->maxSeqLen + + this->taskInfo[curTaskId].colId * this->blockHeight; } bool useMask = false; @@ -380,7 +383,8 @@ public: LocalTensor outputScore = this->queueOutputScore.template DeQue(); LocalTensor outputBias = this->queueOutputBias.template DeQue(); DataCopy(this->scoreTemp[scoreTempOffset], outputScore, thisLen); - CopyOutPadding(this->attnBiasGrad[curAttnBiasOffset], outputBias, validRowNum, totalColNum, this->biasGradSeqLen); + CopyOutPadding(this->attnBiasGrad[curAttnBiasOffset], outputBias, validRowNum, totalColNum, + this->biasGradSeqLen); this->queueOutputScore.template FreeTensor(outputScore); this->queueOutputBias.template FreeTensor(outputBias); } @@ -609,7 +613,8 @@ public: { DoQGradMatmul(taskId); if (taskId > 0) { - if (this->taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId != this->taskInfo[taskId % COMPUTE_PIPE_NUM].accumId) { + if (this->taskInfo[(taskId - 1) % COMPUTE_PIPE_NUM].accumId != + this->taskInfo[taskId % COMPUTE_PIPE_NUM].accumId) { DoTrans(taskId - 1, this->kGradAccumTemp, this->qGrad, 0); } } -- Gitee From 8d8b074b4d86a97c4c2c8b4b2a7cb0d3c2ac973f Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:32:31 +0800 Subject: [PATCH 13/18] cleancode --- .../hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h | 2 -- .../op_kernel/hstu_dense_backward_kernel_common.h | 3 --- 2 files changed, 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index bedcca3c..c2955f6b 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -195,8 +195,6 @@ public: } } - - __aicore__ inline void CastQType2Float(LocalTensor dstTensor, LocalTensor srcTensor, LocalTensor midTensor, int64_t len) { diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h index 7d462810..3ef975d6 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h @@ -184,8 +184,6 @@ __aicore__ inline void CopyVGradB1(const LocalTensor &bMatrix, const __g DataCopy(bMatrix.ReinterpretCast(), globalGt[startIdx], param); }; - - struct BlockInfo { int64_t taskId; int64_t batchId; @@ -347,7 +345,6 @@ public: CreateMask(); } - GM_ADDR curAICWorkspace; // Shape -- Gitee From 08d89c7c07ca68b6777bcfba6a921012ad2353ca Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:44:04 +0800 Subject: [PATCH 14/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel_common.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h index 3ef975d6..6becc77d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h @@ -295,6 +295,13 @@ public: pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); } + __aicore__ inline void Init(Args &args) + { + InitGlobalBuffer(args); + InitPipe(args); + CreateMask(); + } + __aicore__ inline void CreateMask() { if (IfMask(maskType, MaskType::MASK_TRIL)) { @@ -338,13 +345,6 @@ public: } } - __aicore__ inline void Init(Args &args) - { - InitGlobalBuffer(args); - InitPipe(args); - CreateMask(); - } - GM_ADDR curAICWorkspace; // Shape -- Gitee From 7956ee3f0a7cc3470460f298542725bdd66ec796 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:48:57 +0800 Subject: [PATCH 15/18] cleancode --- .../hstu_dense_backward_kernel_common.h | 237 ---------------- .../hstu_dense_backward_kernel_interface.h | 258 ++++++++++++++++++ 2 files changed, 258 insertions(+), 237 deletions(-) create mode 100644 mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h index 6becc77d..73e2e754 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_common.h @@ -198,243 +198,6 @@ struct BlockInfo { int64_t rowLine; int64_t colLine; }; - -template -class HstuDenseBackwardKernelInterface { -public: - __aicore__ inline HstuDenseBackwardKernelInterface() {} - - __aicore__ inline void InitGlobalBuffer(Args &args) - { - GET_TILING_DATA(tilingData, args.tiling); - - batchSize = tilingData.batchSize; - seqLen = tilingData.seqLen; - headNum = tilingData.headNum; - headDim = tilingData.headDim; - - maxSeqLen = tilingData.maxSeqLen; - biasGradSeqLen = tilingData.biasGradSeqLen; - siluScale = tilingData.siluScale; - - blockHeight = tilingData.blockHeight; - - maskType = tilingData.maskType; - enableBias = tilingData.enableBias; - - rowBlockNum = (seqLen + blockHeight - 1) / blockHeight; - colBlockNum = (seqLen + blockHeight - 1) / blockHeight; - totalRowBlockNum = batchSize * headNum * rowBlockNum; - totalColBlockNum = batchSize * headNum * colBlockNum; - totalBlockNum = totalRowBlockNum * colBlockNum; - - int64_t totalElementOfQ = batchSize * maxSeqLen * headNum * headDim; - int64_t totalElementOfAttnBias = batchSize * headNum * biasGradSeqLen * biasGradSeqLen; - - grad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.grad), totalElementOfQ); - q.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.q), totalElementOfQ); - k.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.k), totalElementOfQ); - v.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.v), totalElementOfQ); - if (enableBias) { - attnBias.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBias), totalElementOfAttnBias); - } - if (IfMask(maskType, MaskType::MASK_CUSTOM)) { - mask.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.mask), totalElementOfAttnBias); - } - - qGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.qGrad), totalElementOfQ); - kGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.kGrad), totalElementOfQ); - vGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.vGrad), totalElementOfQ); - attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); - } - - __aicore__ inline void InitPipe(Args &args) - { - GM_ADDR workspace = args.workspace; - int64_t qkMatmulTempSpace = blockHeight * blockHeight; - int64_t gvMatmulTempSpace = blockHeight * blockHeight; - int64_t vGradAccumTempSpace = blockHeight * headDim; - int64_t kGradAccumTempSpace = blockHeight * headDim; - int64_t scoreTempSpace = blockHeight * blockHeight; - int64_t maskTempSpace = blockHeight * blockHeight; - - int64_t totalTempSpaceForOneVec = - MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + - (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * sizeof(qType)) + - maskTempSpace * sizeof(qType); - - curAICWorkspace = reinterpret_cast<__gm__ uint8_t *>(workspace) + GetBlockIdx() * totalTempSpaceForOneVec; - - qkTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), qkMatmulTempSpace * MID_USE_TIMES); - curAICWorkspace += qkMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; - - gvTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), gvMatmulTempSpace * MID_USE_TIMES); - curAICWorkspace += gvMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; - - scoreTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), scoreTempSpace * MID_USE_TIMES); - curAICWorkspace += scoreTempSpace * sizeof(qType) * MID_USE_TIMES; - - vGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), - vGradAccumTempSpace * MID_USE_TIMES); - curAICWorkspace += vGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; - - kGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), - kGradAccumTempSpace * MID_USE_TIMES); - curAICWorkspace += kGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; - - maskTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), maskTempSpace); - - vecOnceDataNum = DATA_ALIGN_BYTES / sizeof(float) * blockHeight; - pipe.InitBuffer(queueVecScoreQK, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreGV, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreMask, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - pipe.InitBuffer(queueVecScoreBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); - - pipe.InitBuffer(queueOutputScore, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - pipe.InitBuffer(queueOutputBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); - } - - __aicore__ inline void Init(Args &args) - { - InitGlobalBuffer(args); - InitPipe(args); - CreateMask(); - } - - __aicore__ inline void CreateMask() - { - if (IfMask(maskType, MaskType::MASK_TRIL)) { - // create lower triangular - int64_t total = blockHeight * blockHeight; - int64_t remain = total; - int64_t thisLen = vecOnceDataNum; - while (remain > 0) { - if (remain < thisLen) { - thisLen = remain; - } - - int64_t baseOffset = total - remain; - int32_t validNums = 1 + baseOffset / blockHeight; - - LocalTensor input = queueVecScoreMask.AllocTensor(); - Duplicate(input, 0, thisLen); - for (int i = 0; i < thisLen / blockHeight; i++) { - if (validNums + i >= blockHeight) { - Duplicate(input[i * blockHeight], 1, blockHeight); - } else { - Duplicate(input[i * blockHeight], 1, validNums + i); - } - } - queueVecScoreMask.EnQue(input); - - LocalTensor newInput = queueVecScoreMask.DeQue(); - LocalTensor output = queueOutputTemp.AllocTensor(); - DataCopy(output, newInput, thisLen); - queueOutputTemp.EnQue(output); - queueVecScoreMask.FreeTensor(newInput); - - output = queueOutputTemp.DeQue(); - DataCopy(maskTemp[baseOffset], output, thisLen); - queueOutputTemp.FreeTensor(output); - - remain -= thisLen; - } - - pipe_barrier(PIPE_ALL); - } - } - - GM_ADDR curAICWorkspace; - - // Shape - int64_t batchSize; - int64_t seqLen; - int64_t headNum; - int64_t headDim; - int64_t maxSeqLen; - int64_t biasGradSeqLen; - int64_t blockHeight; - - // Attr - int32_t maskType; - int32_t enableBias; - float siluScale; - - // Tiling - int64_t rowBlockNum; - int64_t colBlockNum; - int64_t totalRowBlockNum; - int64_t totalColBlockNum; - int64_t totalBlockNum; - - // task - BlockInfo taskInfo[COMPUTE_PIPE_NUM]; - - // Tpipe - TPipe pipe; - - // vec score - int64_t vecOnceDataNum; - TQue queueVecScoreQK; - TQue queueVecScoreGV; - TQue queueVecScoreMask; - TQue queueVecScoreBias; - - TQue queueOutputScore; - TQue queueOutputBias; - TQue queueOutputTemp; - - // Gt - GlobalTensor grad; - GlobalTensor q; - GlobalTensor k; - GlobalTensor v; - GlobalTensor attnBias; - GlobalTensor mask; - - GlobalTensor qGrad; - GlobalTensor kGrad; - GlobalTensor vGrad; - GlobalTensor attnBiasGrad; - - GlobalTensor qkTemp; - GlobalTensor gvTemp; - GlobalTensor scoreTemp; - GlobalTensor kGradAccumTemp; // qGrad share temp space with kGrad - GlobalTensor vGradAccumTemp; - GlobalTensor maskTemp; - - // Matmul - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyQKB1>> - qkMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyVGradB1>> - qGradMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc, CopyVGradB1>> - kGradMatmul; - - matmul::Matmul, - matmul::MatmulType, - matmul::MatmulType, - matmul::MatmulType, CFG_NORM, - matmul::MatmulCallBackFunc>> - vGradMatmul; -}; - } // namespace HstuDenseBackward #endif \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h new file mode 100644 index 00000000..c481c2f5 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h @@ -0,0 +1,258 @@ +/* Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ +#ifndef HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H +#define HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H + + +namespace HstuDenseBackward { + +template +class HstuDenseBackwardKernelInterface { +public: + __aicore__ inline HstuDenseBackwardKernelInterface() {} + + __aicore__ inline void InitGlobalBuffer(Args &args) + { + GET_TILING_DATA(tilingData, args.tiling); + + batchSize = tilingData.batchSize; + seqLen = tilingData.seqLen; + headNum = tilingData.headNum; + headDim = tilingData.headDim; + + maxSeqLen = tilingData.maxSeqLen; + biasGradSeqLen = tilingData.biasGradSeqLen; + siluScale = tilingData.siluScale; + + blockHeight = tilingData.blockHeight; + + maskType = tilingData.maskType; + enableBias = tilingData.enableBias; + + rowBlockNum = (seqLen + blockHeight - 1) / blockHeight; + colBlockNum = (seqLen + blockHeight - 1) / blockHeight; + totalRowBlockNum = batchSize * headNum * rowBlockNum; + totalColBlockNum = batchSize * headNum * colBlockNum; + totalBlockNum = totalRowBlockNum * colBlockNum; + + int64_t totalElementOfQ = batchSize * maxSeqLen * headNum * headDim; + int64_t totalElementOfAttnBias = batchSize * headNum * biasGradSeqLen * biasGradSeqLen; + + grad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.grad), totalElementOfQ); + q.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.q), totalElementOfQ); + k.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.k), totalElementOfQ); + v.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.v), totalElementOfQ); + if (enableBias) { + attnBias.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBias), totalElementOfAttnBias); + } + if (IfMask(maskType, MaskType::MASK_CUSTOM)) { + mask.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.mask), totalElementOfAttnBias); + } + + qGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.qGrad), totalElementOfQ); + kGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.kGrad), totalElementOfQ); + vGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.vGrad), totalElementOfQ); + attnBiasGrad.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(args.attnBiasGrad), totalElementOfAttnBias); + } + + __aicore__ inline void InitPipe(Args &args) + { + GM_ADDR workspace = args.workspace; + int64_t qkMatmulTempSpace = blockHeight * blockHeight; + int64_t gvMatmulTempSpace = blockHeight * blockHeight; + int64_t vGradAccumTempSpace = blockHeight * headDim; + int64_t kGradAccumTempSpace = blockHeight * headDim; + int64_t scoreTempSpace = blockHeight * blockHeight; + int64_t maskTempSpace = blockHeight * blockHeight; + + int64_t totalTempSpaceForOneVec = + MID_USE_TIMES * ((vGradAccumTempSpace + kGradAccumTempSpace) * sizeof(float) + + (qkMatmulTempSpace + gvMatmulTempSpace + scoreTempSpace) * sizeof(qType)) + + maskTempSpace * sizeof(qType); + + curAICWorkspace = reinterpret_cast<__gm__ uint8_t *>(workspace) + GetBlockIdx() * totalTempSpaceForOneVec; + + qkTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), qkMatmulTempSpace * MID_USE_TIMES); + curAICWorkspace += qkMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; + + gvTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), gvMatmulTempSpace * MID_USE_TIMES); + curAICWorkspace += gvMatmulTempSpace * sizeof(qType) * MID_USE_TIMES; + + scoreTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), scoreTempSpace * MID_USE_TIMES); + curAICWorkspace += scoreTempSpace * sizeof(qType) * MID_USE_TIMES; + + vGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), + vGradAccumTempSpace * MID_USE_TIMES); + curAICWorkspace += vGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; + + kGradAccumTemp.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(curAICWorkspace), + kGradAccumTempSpace * MID_USE_TIMES); + curAICWorkspace += kGradAccumTempSpace * sizeof(float) * MID_USE_TIMES; + + maskTemp.SetGlobalBuffer(reinterpret_cast<__gm__ qType *>(curAICWorkspace), maskTempSpace); + + vecOnceDataNum = DATA_ALIGN_BYTES / sizeof(float) * blockHeight; + pipe.InitBuffer(queueVecScoreQK, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreGV, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreMask, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + pipe.InitBuffer(queueVecScoreBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(float)); + + pipe.InitBuffer(queueOutputScore, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + pipe.InitBuffer(queueOutputBias, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + pipe.InitBuffer(queueOutputTemp, USE_BUFFER_NUM, vecOnceDataNum * sizeof(qType)); + } + + __aicore__ inline void Init(Args &args) + { + InitGlobalBuffer(args); + InitPipe(args); + CreateMask(); + } + + __aicore__ inline void CreateMask() + { + if (IfMask(maskType, MaskType::MASK_TRIL)) { + // create lower triangular + int64_t total = blockHeight * blockHeight; + int64_t remain = total; + int64_t thisLen = vecOnceDataNum; + while (remain > 0) { + if (remain < thisLen) { + thisLen = remain; + } + + int64_t baseOffset = total - remain; + int32_t validNums = 1 + baseOffset / blockHeight; + + LocalTensor input = queueVecScoreMask.AllocTensor(); + Duplicate(input, 0, thisLen); + for (int i = 0; i < thisLen / blockHeight; i++) { + if (validNums + i >= blockHeight) { + Duplicate(input[i * blockHeight], 1, blockHeight); + } else { + Duplicate(input[i * blockHeight], 1, validNums + i); + } + } + queueVecScoreMask.EnQue(input); + + LocalTensor newInput = queueVecScoreMask.DeQue(); + LocalTensor output = queueOutputTemp.AllocTensor(); + DataCopy(output, newInput, thisLen); + queueOutputTemp.EnQue(output); + queueVecScoreMask.FreeTensor(newInput); + + output = queueOutputTemp.DeQue(); + DataCopy(maskTemp[baseOffset], output, thisLen); + queueOutputTemp.FreeTensor(output); + + remain -= thisLen; + } + + pipe_barrier(PIPE_ALL); + } + } + + GM_ADDR curAICWorkspace; + + // Shape + int64_t batchSize; + int64_t seqLen; + int64_t headNum; + int64_t headDim; + int64_t maxSeqLen; + int64_t biasGradSeqLen; + int64_t blockHeight; + + // Attr + int32_t maskType; + int32_t enableBias; + float siluScale; + + // Tiling + int64_t rowBlockNum; + int64_t colBlockNum; + int64_t totalRowBlockNum; + int64_t totalColBlockNum; + int64_t totalBlockNum; + + // task + BlockInfo taskInfo[COMPUTE_PIPE_NUM]; + + // Tpipe + TPipe pipe; + + // vec score + int64_t vecOnceDataNum; + TQue queueVecScoreQK; + TQue queueVecScoreGV; + TQue queueVecScoreMask; + TQue queueVecScoreBias; + + TQue queueOutputScore; + TQue queueOutputBias; + TQue queueOutputTemp; + + // Gt + GlobalTensor grad; + GlobalTensor q; + GlobalTensor k; + GlobalTensor v; + GlobalTensor attnBias; + GlobalTensor mask; + + GlobalTensor qGrad; + GlobalTensor kGrad; + GlobalTensor vGrad; + GlobalTensor attnBiasGrad; + + GlobalTensor qkTemp; + GlobalTensor gvTemp; + GlobalTensor scoreTemp; + GlobalTensor kGradAccumTemp; // qGrad share temp space with kGrad + GlobalTensor vGradAccumTemp; + GlobalTensor maskTemp; + + // Matmul + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyQKB1>> + qkMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyVGradB1>> + qGradMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc, CopyVGradB1>> + kGradMatmul; + + matmul::Matmul, + matmul::MatmulType, + matmul::MatmulType, + matmul::MatmulType, CFG_NORM, + matmul::MatmulCallBackFunc>> + vGradMatmul; +}; +} + +#endif \ No newline at end of file -- Gitee From 6ab1c2f03fcebf88cf08635a7c7d7bc3a2cb341a Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 10:50:51 +0800 Subject: [PATCH 16/18] cleancode --- .../hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h | 2 +- .../op_kernel/hstu_dense_backward_kernel_interface.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h index c2955f6b..0fad8317 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h @@ -16,7 +16,7 @@ See the License for the specific language governing permissions and #ifndef HSTU_DENSE_BACKWARD_KERNEL_H #define HSTU_DENSE_BACKWARD_KERNEL_H -#include "hstu_dense_backward_kernel_common.h" +#include "hstu_dense_backward_kernel_interface.h" namespace HstuDenseBackward { diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h index c481c2f5..d01288ee 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h @@ -15,6 +15,7 @@ See the License for the specific language governing permissions and #ifndef HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H #define HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H +#include "hstu_dense_backward_kernel_common.h" namespace HstuDenseBackward { -- Gitee From bef2551cb690dcae7ccd24a998c9db9ff80f3f7e Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 11:08:27 +0800 Subject: [PATCH 17/18] cleancode --- .../hstu_dense_backward_kernel_interface.h | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h index d01288ee..99c5bd59 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + + #ifndef HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H #define HSTU_DENSE_BACKWARD_KERNEL_INTERFACE_H @@ -122,6 +124,18 @@ public: CreateMask(); } + __aicore__ inline void DuplicateInput(LocalTensor &input, int32_t validNums) + { + Duplicate(input, 0, thisLen); + for (int i = 0; i < thisLen / blockHeight; i++) { + if (validNums + i >= blockHeight) { + Duplicate(input[i * blockHeight], 1, blockHeight); + } else { + Duplicate(input[i * blockHeight], 1, validNums + i); + } + } + } + __aicore__ inline void CreateMask() { if (IfMask(maskType, MaskType::MASK_TRIL)) { @@ -138,14 +152,7 @@ public: int32_t validNums = 1 + baseOffset / blockHeight; LocalTensor input = queueVecScoreMask.AllocTensor(); - Duplicate(input, 0, thisLen); - for (int i = 0; i < thisLen / blockHeight; i++) { - if (validNums + i >= blockHeight) { - Duplicate(input[i * blockHeight], 1, blockHeight); - } else { - Duplicate(input[i * blockHeight], 1, validNums + i); - } - } + DuplicateInput(input, validNums); queueVecScoreMask.EnQue(input); LocalTensor newInput = queueVecScoreMask.DeQue(); -- Gitee From bc43c68a5a799035aed85cf3774109bcbfd375c6 Mon Sep 17 00:00:00 2001 From: zxorange_321 Date: Fri, 27 Jun 2025 11:09:58 +0800 Subject: [PATCH 18/18] cleancode --- .../op_kernel/hstu_dense_backward_kernel_interface.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h index 99c5bd59..4571b385 100644 --- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h +++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel_interface.h @@ -124,7 +124,7 @@ public: CreateMask(); } - __aicore__ inline void DuplicateInput(LocalTensor &input, int32_t validNums) + __aicore__ inline void DuplicateInput(LocalTensor &input, int64_t thisLen, int32_t validNums) { Duplicate(input, 0, thisLen); for (int i = 0; i < thisLen / blockHeight; i++) { @@ -152,7 +152,7 @@ public: int32_t validNums = 1 + baseOffset / blockHeight; LocalTensor input = queueVecScoreMask.AllocTensor(); - DuplicateInput(input, validNums); + DuplicateInput(input, thisLen, validNums); queueVecScoreMask.EnQue(input); LocalTensor newInput = queueVecScoreMask.DeQue(); -- Gitee