From da692e3e07852a15057680cee90f39df234099f7 Mon Sep 17 00:00:00 2001
From: liqiang
Date: Thu, 21 Aug 2025 14:20:29 +0800
Subject: [PATCH] Optimize the hstu_dense_backward operator: 1. split the
 qGrad zero-initialization across cores; 2. split the final result copy
 across cores; 3. for bf16, change the default number of consecutive matmul
 prefetch blocks to 2.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: liqiang
---
 .../op_host/hstu_dense_backward.cpp                | 12 +++++++++
 .../hstu_dense_backward_jagged_kernel.h            |  4 +--
 .../op_kernel/hstu_dense_backward_kernel.h         | 27 +++++++++++++++++--
 3 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
index 5dc0f296..4095d31b 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_host/hstu_dense_backward.cpp
@@ -132,6 +132,18 @@ static ge::graphStatus TilingCommonFunc(gert::TilingContext *context, HstuDenseB
         vGradMatmul.GetTiling(tiling.vGradMatmul) == -1) {
         return ge::GRAPH_FAILED;
     }
+    // Set the matmul prefetch block count (default: 6). GR: BLOCKHEIGHT=256, bf16, saves 4 ms
+    if (gradType == ge::DataType::DT_BF16) {
+        int64_t depth = 2;
+        tiling.qkMatmul.set_depthA1(depth);
+        tiling.qkMatmul.set_depthB1(depth);
+        tiling.qGradMatmul.set_depthA1(depth);
+        tiling.qGradMatmul.set_depthB1(depth);
+        tiling.kGradMatmul.set_depthA1(depth);
+        tiling.kGradMatmul.set_depthB1(depth);
+        tiling.vGradMatmul.set_depthA1(depth);
+        tiling.vGradMatmul.set_depthB1(depth);
+    }
 
     context->SetBlockDim(coreNum);
     tiling.set_aivNum(vecCoreNum);
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
index f11cfe55..56ee6916 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_jagged_kernel.h
@@ -570,9 +570,7 @@ public:
     {
         SyncAll();
 
-        if (GetBlockIdx() == 0) {
-            this->DoCopyQGrad(backwardTilingData->seqOffset);
-        }
+        this->DoCopyQGrad(backwardTilingData->seqOffset);
     }
 
 protected:
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
index d4e50f98..ccef1cea 100644
--- a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_backward/op_kernel/hstu_dense_backward_kernel.h
@@ -156,7 +156,21 @@ public:
 
         // All cores share one region of global memory and accumulate into it, so the memory must be cleared on every run to keep residual data from the previous run from corrupting this run's result
         // SyncAll must be called after the multi-core pass to keep all cores properly synchronized
-        InitGlobalMemory(qGradAccumTemp, qGradAccumTempSpace, static_cast<float>(0));
+        // Per-core clearing logic:
+        // 1. Split the region to be cleared into batchSize chunks; qGradAccumTempSpace is evenly divisible by batchSize
+        // 2. All cores take part in clearing: when the core count exceeds batchSize the extra cores stay idle; when it is below batchSize each core clears its fixed chunks at a stride of blockNum
+        // PS: saves 8 ms in the GR scenario. By liqiang, 2025.08
+        int64_t blockNum = GetBlockNum();
+        uint64_t unitClear = qGradAccumTempSpace / batchSize;
+        int64_t batchIdx = GetBlockIdx();
+        while (batchIdx < batchSize) {
+            GlobalTensor<float> thisBlockQGrad;
+            thisBlockQGrad.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(
+                reinterpret_cast<__gm__ uint8_t *>(workspace) + aivNum * totalTempSpaceForOneVec + batchIdx * unitClear * sizeof(float)), unitClear);
+            InitGlobalMemory(thisBlockQGrad, unitClear, static_cast<float>(0));
+            batchIdx += blockNum;
+        }
+
         SyncAll();
     }
@@ -852,7 +866,14 @@ public:
 
     __aicore__ inline void DoCopyQGrad(const uint32_t *seqOffset)
     {
-        for (int64_t batchIdx = 0; batchIdx < batchSize; batchIdx++) {
+        // Per-core copy logic:
+        // 1. Work is partitioned at batchSize granularity; each core claims one batch at a time
+        // 2. When the core count exceeds batchSize, the extra cores have nothing to do
+        // 3. When the core count is below batchSize, the remaining batches are still claimed one per core until all are done
+        // PS: saves 15 ms here in GR, 13% of the whole operator. By liqiang, 2025.08
+        int64_t blockNum = GetBlockNum();
+        int64_t batchIdx = GetBlockIdx();
+        while (batchIdx < batchSize) {
             int64_t curSeqLen = static_cast<int64_t>(seqOffset[batchIdx + 1] - seqOffset[batchIdx]);
             for (int64_t headIdx = 0; headIdx < headNum; headIdx++) {
                 int64_t totalLen = curSeqLen * headDim;
@@ -894,6 +915,8 @@ public:
                 remain = remain - thisLen;
             }
         }
+        // After the stride-by-core partitioning, each core handles its own position in every round; the loop exits once batchIdx reaches batchSize
+        batchIdx += blockNum;
     }
 }
 
-- 
Gitee
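
Reviewer note: the two hot-path changes above share one work-distribution pattern, a strided (block-cyclic) loop in which core GetBlockIdx() claims batch indices GetBlockIdx(), GetBlockIdx() + GetBlockNum(), and so on, until batchSize is reached. Below is a minimal host-side C++ sketch of that assignment, not AscendC kernel code: blockNum and batchSize mirror the kernel's names, while BatchesForCore and the concrete sizes (8 cores, 20 batches) are hypothetical and exist only for illustration.

// Host-side illustration of the strided batch assignment used by the
// qGrad clearing loop and DoCopyQGrad. Not part of the patch.
#include <cstdint>
#include <iostream>
#include <vector>

// Batches claimed by one core: start at the core's own index and
// advance by the core count, stopping once batchSize is reached.
std::vector<int64_t> BatchesForCore(int64_t blockIdx, int64_t blockNum, int64_t batchSize)
{
    std::vector<int64_t> batches;
    for (int64_t batchIdx = blockIdx; batchIdx < batchSize; batchIdx += blockNum) {
        batches.push_back(batchIdx);
    }
    return batches;
}

int main()
{
    const int64_t blockNum = 8;    // stand-in for GetBlockNum(); assumed value
    const int64_t batchSize = 20;  // stand-in for the kernel's batchSize; assumed value
    for (int64_t core = 0; core < blockNum; ++core) {
        std::cout << "core " << core << ":";
        // A core whose index were >= batchSize would claim nothing,
        // matching "extra cores stay idle" in the patch comments.
        for (int64_t b : BatchesForCore(core, blockNum, batchSize)) {
            std::cout << " " << b;
        }
        std::cout << "\n";
    }
    return 0;
}

With these numbers, core 0 processes batches 0, 8, 16 and core 1 processes 1, 9, 17, and so on; the clearing loop applies the same stride to unitClear-sized chunks of the shared workspace instead of batch copies, which is why both sites can replace the old single-core loop without changing which data gets touched.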