diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward.cpp b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5ded2f30fd8f648741da8eb086251c99f9963b45
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward.cpp
@@ -0,0 +1,81 @@
+/* Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+
+#ifdef SUPPORT_V200
+#include "hstu_dense_forward_kernel_v200.h"
+
+template <typename T>
+__aicore__ inline void InvokeHstuOpImpl(const HstuDenseForward::Args &args)
+{
+    TPipe tPipe;
+    T op;
+    GET_TILING_DATA(tilingData, args.tiling);
+    const HstuDenseForwardTilingData *__restrict tilingDataPtr = &tilingData;
+    REGIST_MATMUL_OBJ(&tPipe, GetSysWorkSpacePtr(), op.qkMatmul, &tilingDataPtr->qkMatmul, op.svMatmul,
+                      &tilingDataPtr->svMatmul);
+    op.Init(args, tilingDataPtr, &tPipe);
+    op.Compute(tilingDataPtr);
+}
+
+#else
+#include "hstu_dense_forward_jagged_kernel.h"
+#include "hstu_dense_forward_kernel.h"
+
+template <typename T>
+__aicore__ inline void InvokeHstuOpImpl(const HstuDenseForward::Args &args)
+{
+    TPipe tPipe;
+    T op;
+    GET_TILING_DATA(tilingData, args.tiling);
+    const HstuDenseForwardTilingData *__restrict tilingDataPtr = &tilingData;
+    REGIST_MATMUL_OBJ(&tPipe, GetSysWorkSpacePtr(), op.qkMatmul, &tilingDataPtr->qkMatmul, op.svMatmul,
+                      &tilingDataPtr->svMatmul);
+    uint64_t tilingPtr = reinterpret_cast<uint64_t>(args.tiling);
+    op.qkMatmul.SetUserDefInfo(tilingPtr);
+    op.svMatmul.SetUserDefInfo(tilingPtr);
+    op.Init(args, tilingDataPtr, &tPipe);
+    op.Compute(tilingDataPtr);
+}
+
+#endif
+
+#include "kernel_operator.h"
+
+extern "C" __global__ __aicore__ void hstu_dense_forward(GM_ADDR q, GM_ADDR k, GM_ADDR v, GM_ADDR mask,
+                                                         GM_ADDR attnBias, GM_ADDR attnOutput, GM_ADDR workspace,
+                                                         GM_ADDR tiling)
+{
+    HstuDenseForward::Args args{q, k, v, attnBias, mask, attnOutput, workspace, tiling};
+#ifdef SUPPORT_V200
+    if (TILING_KEY_IS(0)) {
+        InvokeHstuOpImpl<HstuDenseForwardKernelV200<half>>(args);
+    }
+#else
+    if (TILING_KEY_IS(0)) {
+        InvokeHstuOpImpl<HstuDenseForwardKernel<half>>(args);
+    } else if (TILING_KEY_IS(1)) {
+        InvokeHstuOpImpl<HstuDenseForwardKernel<float>>(args);
+    } else if (TILING_KEY_IS(2)) {
+        InvokeHstuOpImpl<HstuDenseForwardKernel<bfloat16_t>>(args);
+    } else if (TILING_KEY_IS(3)) {
+        InvokeHstuOpImpl<HstuDenseForwardJaggedKernel<half>>(args);
+    } else if (TILING_KEY_IS(4)) {
+        InvokeHstuOpImpl<HstuDenseForwardJaggedKernel<float>>(args);
+    } else if (TILING_KEY_IS(5)) {
+        InvokeHstuOpImpl<HstuDenseForwardJaggedKernel<bfloat16_t>>(args);
+    }
+#endif
+}
\ No newline at end of file
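
The entry point above dispatches on the tiling key produced by the host-side tiling function, which is not part of this diff. The template-argument lists in the branches follow the key layout sketched below; treat the dtype-by-layout mapping, and the V200 class name HstuDenseForwardKernelV200, as assumptions rather than the confirmed scheme.

    #include <cstdint>

    // Sketch only: assumed mapping between tiling keys and the kernel
    // instantiations dispatched above (keys 0-2 dense, keys 3-5 jagged, each
    // over half / float / bfloat16_t). The authoritative mapping lives in the
    // host-side tiling code, which this diff does not include.
    enum class HstuTilingKey : uint64_t {
        kDenseFp16 = 0,   // HstuDenseForwardKernel<half>
        kDenseFp32 = 1,   // HstuDenseForwardKernel<float>
        kDenseBf16 = 2,   // HstuDenseForwardKernel<bfloat16_t>
        kJaggedFp16 = 3,  // HstuDenseForwardJaggedKernel<half>
        kJaggedFp32 = 4,  // HstuDenseForwardJaggedKernel<float>
        kJaggedBf16 = 5   // HstuDenseForwardJaggedKernel<bfloat16_t>
    };
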
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_jagged_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_jagged_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd73088352ec64ec20882751316eacda9b61c940
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_jagged_kernel.h
@@ -0,0 +1,390 @@
+/* Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+#ifndef HSTU_DENSE_FORWARD_JAGGED_KERNEL_FUN_H
+#define HSTU_DENSE_FORWARD_JAGGED_KERNEL_FUN_H
+
+
+#include "hstu_dense_forward_kernel_patten_bsnd.h"
+
+using namespace AscendC;
+
+namespace HstuDenseForward {
+
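+// Jagged layout: all batches' sequences are packed back to back, with
+// seqOffsets[] holding prefix-sum row offsets (seqOffsets[b] is the first row
+// of batch b, seqOffsets[batchSize] the total row count). Q/K/V and
+// attnOutput are addressed as [total_rows, headNum, headDim], so an element
+// offset is row * headNum * headDim + head * headDim. Work is enumerated in a
+// flat global-sequence-offset space that walks the rows of every
+// (batch, head) pair in order; each basic block covers up to blockHeight
+// consecutive rows of one head.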
+struct JaggedTaskArgs {
+    uint32_t batchId = 0;         // batch this basic block belongs to
+    uint32_t headId = 0;          // head this basic block belongs to
+    uint32_t qSeqId = 0;          // index of the Query sequence block this basic block belongs to; one block is 256 sequence rows
+    uint32_t kSeqId = 0;          // index of the Key sequence block this basic block belongs to; one block is 256 sequence rows
+    uint32_t actualSeqLen = 0;    // actual sequence length of this basic block
+    uint32_t kSeqNum = 0;         // number of matmul steps this basic block needs along the K axis
+    uint32_t causalMask = 0;      // whether this basic block needs a causal mask
+    uint32_t transTaskId = 0;     // transpose-task id of this basic block
+    uint32_t computeASeqLen = 0;  // sequence length of the left matrix in this block's matmul
+    uint32_t computeBSeqLen = 0;  // sequence length of the right matrix in this block's matmul
+    float scale = 0.0f;           // siluScale of this basic block
+    int64_t seqGlobalOffset = 0;  // global sequence offset of this basic block
+    int64_t batchOffset = 0;      // batch offset of this basic block
+    int64_t headSeqLimit = 0;     // upper bound of this block's head offset; beyond it the head id must advance
+    int64_t kvOffset = 0;         // key/value compute offset of this basic block
+    int64_t ioOffset = 0;         // query/attnOutput compute offset of this basic block
+};
+
+template <typename T>
+class HstuDenseForwardJaggedKernel : public HstuDenseForwardKernelPattenBsnd<T> {
+public:
+    __aicore__ inline HstuDenseForwardJaggedKernel() {}
+
+    __aicore__ inline void Compute(const HstuDenseForwardTilingData *__restrict tilingDataPtr);
+
+    __aicore__ inline void ComputeAllBlock();
+
+private:
+    __aicore__ inline int PreInit(const HstuDenseForwardTilingData *__restrict tilingDataPtr);
+
+    __aicore__ inline void GetTaskInfo(uint32_t sBlkId);
+
+    __aicore__ inline void UpdateTaskInfo(uint32_t taskId);
+
+    __aicore__ inline void FillTaskInfo(uint32_t batchId, uint32_t headId, int64_t seqGlobalOffset, uint32_t taskId);
+
+    __aicore__ inline void ComputeQkMatmul(uint32_t taskId);
+
+    __aicore__ inline void ComputeVecScore(uint32_t taskId);
+
+    __aicore__ inline void ComputeSvMatmul(uint32_t taskId);
+
+    __aicore__ inline void TransResult(uint32_t transtaskId);
+
+    uint32_t seqOffsets[MAX_BATCH_SIZE + 1];
+    uint32_t sBlkId {0};
+    uint32_t eBlkId {0};
+    uint32_t maxSeqLen {0};
+
+    uint32_t batchSize {0};
+    uint32_t seqLen {0};
+    uint32_t headNum {0};
+    uint32_t headDim {0};
+
+    JaggedTaskArgs computeTaskInfo[COMPUTE_PIPE_NUM];
+    JaggedTaskArgs transTaskInfo[TRANS_PIPE_NUM];
+};
+
+template <typename T>
+__aicore__ inline void
+HstuDenseForwardJaggedKernel<T>::Compute(const HstuDenseForwardTilingData *__restrict tilingDataPtr)
+{
+    int ret = PreInit(tilingDataPtr);
+    if (ret == -1) {
+        return; // no task
+    }
+    ComputeAllBlock();
+}
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::ComputeSvMatmul(uint32_t taskId)
+{
+    int isAtomic = 1;
+    if (computeTaskInfo[taskId].kSeqId == 0) {
+        isAtomic = 0;
+    }
+
+    this->DoSvMatmulImpl(computeTaskInfo[taskId].kvOffset, taskId, computeTaskInfo[taskId].transTaskId, isAtomic,
+                         computeTaskInfo[taskId].computeASeqLen, this->headDim,
+                         computeTaskInfo[taskId].computeBSeqLen);
+}
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::ComputeQkMatmul(uint32_t taskId)
+{
+    this->DoQkMatmulImpl(computeTaskInfo[taskId].ioOffset, computeTaskInfo[taskId].kvOffset, taskId,
+                         computeTaskInfo[taskId].computeASeqLen, computeTaskInfo[taskId].computeBSeqLen,
+                         this->headDim);
+}
+
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::ComputeVecScore(uint32_t taskId)
+{
+    int64_t biasOffset = computeTaskInfo[taskId].batchId * this->headNum * this->maxSeqLen * this->maxSeqLen + \
+                         computeTaskInfo[taskId].headId * this->maxSeqLen * this->maxSeqLen + \
+                         computeTaskInfo[taskId].qSeqId * this->maxSeqLen * this->blockHeight + \
+                         computeTaskInfo[taskId].kSeqId * this->blockHeight;
+
+    int64_t maskOffset = biasOffset;
+
+    this->VecScoreImpl(taskId, biasOffset, maskOffset, computeTaskInfo[taskId].scale,
+                       computeTaskInfo[taskId].causalMask, computeTaskInfo[taskId].computeASeqLen,
+                       computeTaskInfo[taskId].computeBSeqLen);
+}
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::TransResult(uint32_t transtaskId)
+{
+    this->DoTransSvImpl(transtaskId, transTaskInfo[transtaskId].ioOffset, transTaskInfo[transtaskId].computeASeqLen);
+}
+
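+// ComputeAllBlock runs a software pipeline over (q-block, k-block) tasks,
+// keeping three tasks in flight: task t issues its QK matmul while task t-1
+// runs the vector stage (bias, mask, SiLU scale) and task t-2 issues its SV
+// matmul. Per-task arguments live in computeTaskInfo[], a ring buffer of
+// COMPUTE_PIPE_NUM entries; output transposes are likewise deferred two
+// q-blocks through transTaskInfo[]. The branches after the main loop drain
+// whatever is still in flight.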
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::ComputeAllBlock()
+{
+    GetTaskInfo(this->sBlkId);
+
+    uint32_t taskId = 0;
+    uint32_t transtaskId = 0;
+
+    uint32_t currentTaskId = 0;
+    uint32_t preTaskId = 0;
+    uint32_t prePreTaskId = 0;
+    uint32_t nextTaskId = 0;
+
+    for (auto blkId = sBlkId; blkId < eBlkId; blkId++) {
+        auto kSeqNum = computeTaskInfo[taskId % COMPUTE_PIPE_NUM].kSeqNum;
+        for (auto kSeqId = 0; kSeqId < kSeqNum; kSeqId++) {
+            uint32_t causalMask = 0;
+
+            if ((this->maskType == CausalMaskT::MASK_TRIL) &&
+                kSeqId > computeTaskInfo[taskId % COMPUTE_PIPE_NUM].qSeqId) {
+                continue;
+            }
+
+            currentTaskId = taskId % COMPUTE_PIPE_NUM;
+            preTaskId = (taskId - 1) % COMPUTE_PIPE_NUM;
+            prePreTaskId = (taskId - 2) % COMPUTE_PIPE_NUM;
+            nextTaskId = (taskId + 1) % COMPUTE_PIPE_NUM;
+
+            if ((this->maskType == CausalMaskT::MASK_TRIL) &&
+                kSeqId == computeTaskInfo[currentTaskId].qSeqId) {
+                causalMask = 1;
+            }
+
+            this->computeTaskInfo[currentTaskId].transTaskId = transtaskId % TRANS_PIPE_NUM;
+            this->computeTaskInfo[currentTaskId].causalMask = causalMask;
+            this->computeTaskInfo[currentTaskId].kSeqId = kSeqId;
+            this->computeTaskInfo[currentTaskId].computeBSeqLen =
+                (kSeqId != (kSeqNum - 1)) ?
+                (this->blockHeight) :
+                (this->computeTaskInfo[currentTaskId].actualSeqLen - kSeqId * this->blockHeight);
+            this->computeTaskInfo[currentTaskId].kvOffset = \
+                this->computeTaskInfo[currentTaskId].batchOffset * this->headDim * this->headNum + \
+                this->computeTaskInfo[currentTaskId].kSeqId * this->blockHeight * this->headNum * this->headDim + \
+                this->computeTaskInfo[currentTaskId].headId * this->headDim;
+
+            // matmul qk
+            this->ComputeQkMatmul(currentTaskId);
+
+            // matmul sv
+            if (taskId > 1) {
+                this->ComputeSvMatmul(prePreTaskId);
+            }
+
+            // VecScore
+            if (taskId > 0) {
+                this->ComputeVecScore(preTaskId);
+            }
+
+            // wait qk
+            this->WaitQkMatmul();
+
+            // wait sv
+            if (taskId > 1) {
+                this->WaitSvMatmul();
+            }
+
+            computeTaskInfo[nextTaskId] = computeTaskInfo[currentTaskId];
+            taskId++;
+        }
+
+        this->transTaskInfo[transtaskId % TRANS_PIPE_NUM] = this->computeTaskInfo[currentTaskId];
+        if (transtaskId > 1) {
+            this->TransResult((transtaskId - 2) % TRANS_PIPE_NUM);
+        }
+        transtaskId++;
+
+        this->UpdateTaskInfo(taskId % COMPUTE_PIPE_NUM);
+    }
+
+    if (taskId == 0) {
+        return;
+    }
+
+    if (taskId == 1) {
+        this->ComputeVecScore(currentTaskId);
+        pipe_barrier(PIPE_ALL);
+
+        this->ComputeSvMatmul(currentTaskId);
+        this->WaitSvMatmul();
+
+        this->TransResult((transtaskId - 1) % TRANS_PIPE_NUM);
+        return;
+    }
+
+    if (transtaskId == 1) {
+        this->ComputeSvMatmul(preTaskId);
+        this->WaitSvMatmul();
+
+        this->ComputeVecScore(currentTaskId);
+        pipe_barrier(PIPE_ALL);
+
+        this->ComputeSvMatmul(currentTaskId);
+        this->WaitSvMatmul();
+        this->TransResult((transtaskId - 1) % TRANS_PIPE_NUM);
+        return;
+    }
+
+    this->ComputeSvMatmul(preTaskId);
+    this->WaitSvMatmul();
+
+    this->ComputeVecScore(currentTaskId);
+    pipe_barrier(PIPE_ALL);
+
+    this->ComputeSvMatmul(currentTaskId);
+    this->WaitSvMatmul();
+
+    this->TransResult((transtaskId - 2) % TRANS_PIPE_NUM);
+    this->TransResult((transtaskId - 1) % TRANS_PIPE_NUM);
+}
+
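+// FillTaskInfo seeds a task from (batchId, headId, seqGlobalOffset).
+// headSeqLimit is the exclusive end of this (batch, head) segment in the
+// global-offset space: batchOffset * headNum rows precede the batch, plus
+// (headId + 1) full copies of actualSeqLen. computeASeqLen is clamped so the
+// last block of a segment does not run past that limit, and UpdateTaskInfo
+// advances seqGlobalOffset block by block, rolling over to the next head (and
+// next batch) once the limit is reached.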
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::FillTaskInfo(uint32_t batchId, uint32_t headId,
+                                                                     int64_t seqGlobalOffset, uint32_t taskId)
+{
+    if (batchId >= this->batchSize) {
+        return;
+    }
+
+    taskId = taskId % COMPUTE_PIPE_NUM;
+
+    auto nextBatchSeqOffset = this->seqOffsets[batchId + 1];
+    auto currentBatchSeqOffset = this->seqOffsets[batchId];
+
+    computeTaskInfo[taskId].seqGlobalOffset = seqGlobalOffset;
+    computeTaskInfo[taskId].batchId = batchId;
+    computeTaskInfo[taskId].actualSeqLen = nextBatchSeqOffset - currentBatchSeqOffset;
+    computeTaskInfo[taskId].scale = this->siluScale;
+    computeTaskInfo[taskId].batchOffset = currentBatchSeqOffset;
+    computeTaskInfo[taskId].headSeqLimit =
+        computeTaskInfo[taskId].batchOffset * this->headNum + computeTaskInfo[taskId].actualSeqLen * (headId + 1);
+
+    auto batchInnerOffset = seqGlobalOffset - (computeTaskInfo[taskId].batchOffset * this->headNum);
+    computeTaskInfo[taskId].headId = headId;
+    computeTaskInfo[taskId].qSeqId =
+        (batchInnerOffset - computeTaskInfo[taskId].headId * computeTaskInfo[taskId].actualSeqLen) / this->blockHeight;
+    computeTaskInfo[taskId].kSeqNum =
+        (computeTaskInfo[taskId].actualSeqLen + this->blockHeight - 1) / this->blockHeight;
+
+    computeTaskInfo[taskId].ioOffset =
+        computeTaskInfo[taskId].batchOffset * this->headDim * this->headNum + \
+        computeTaskInfo[taskId].qSeqId * this->blockHeight * this->headNum * this->headDim + \
+        computeTaskInfo[taskId].headId * this->headDim;
+
+    if ((computeTaskInfo[taskId].headSeqLimit - seqGlobalOffset) >= this->blockHeight) {
+        computeTaskInfo[taskId].computeASeqLen = this->blockHeight;
+    } else {
+        computeTaskInfo[taskId].computeASeqLen = computeTaskInfo[taskId].headSeqLimit - seqGlobalOffset;
+    }
+}
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::UpdateTaskInfo(uint32_t taskId)
+{
+    auto batchId = computeTaskInfo[taskId].batchId;
+    auto headId = computeTaskInfo[taskId].headId;
+
+    int64_t seqGlobalOffset = computeTaskInfo[taskId].seqGlobalOffset;
+    int64_t gap = computeTaskInfo[taskId].headSeqLimit - seqGlobalOffset;
+
+    if (gap <= this->blockHeight) {
+        headId++;
+        if (headId >= this->headNum) {
+            batchId++;
+        }
+
+        if (batchId >= this->batchSize) {
+            return;
+        }
+
+        seqGlobalOffset = seqGlobalOffset + gap;
+        headId = headId % this->headNum;
+        this->FillTaskInfo(batchId, headId, seqGlobalOffset, taskId);
+    } else {
+        computeTaskInfo[taskId].seqGlobalOffset = seqGlobalOffset + this->blockHeight;
+
+        uint32_t computeASeqLen = this->blockHeight;
+        if ((computeTaskInfo[taskId].seqGlobalOffset + this->blockHeight) > computeTaskInfo[taskId].headSeqLimit) {
+            computeASeqLen = computeTaskInfo[taskId].headSeqLimit - computeTaskInfo[taskId].seqGlobalOffset;
+        }
+
+        auto batchInnerOffset =
+            computeTaskInfo[taskId].seqGlobalOffset - (computeTaskInfo[taskId].batchOffset * this->headNum);
+        computeTaskInfo[taskId].qSeqId =
+            (batchInnerOffset - computeTaskInfo[taskId].headId * computeTaskInfo[taskId].actualSeqLen) /
+            this->blockHeight;
+        computeTaskInfo[taskId].ioOffset =
+            computeTaskInfo[taskId].batchOffset * this->headDim * this->headNum + \
+            computeTaskInfo[taskId].qSeqId * this->blockHeight * this->headNum * this->headDim + \
+            computeTaskInfo[taskId].headId * this->headDim;
+        computeTaskInfo[taskId].computeASeqLen = computeASeqLen;
+    }
+}
+
+template <typename T>
+__aicore__ inline void HstuDenseForwardJaggedKernel<T>::GetTaskInfo(uint32_t sBlkId)
+{
+    uint32_t offsetOfBlk = 0;
+    int64_t offsetOfSeq = 0;
+    int64_t seqGlobalOffset = 0;
+
+    for (auto index = 0; index < this->batchSize * this->headNum; index++) {
+        uint32_t batchId = index / this->headNum;
+        uint32_t headId = index % this->headNum;
+
+        uint32_t batchSeqSize = this->seqOffsets[batchId + 1] - this->seqOffsets[batchId];
+        uint32_t batchBlkSize = (batchSeqSize + this->blockHeight - 1) / this->blockHeight;
+
+        if (sBlkId < (offsetOfBlk + batchBlkSize)) {
+            uint32_t innerBlkId = sBlkId - offsetOfBlk;
+            seqGlobalOffset = seqGlobalOffset + innerBlkId * this->blockHeight;
+            this->FillTaskInfo(batchId, headId, seqGlobalOffset, 0);
+            return;
+        }
+
+        offsetOfBlk += batchBlkSize;
+        seqGlobalOffset += batchSeqSize;
+    }
+}
+
+template <typename T>
+__aicore__ inline int
+HstuDenseForwardJaggedKernel<T>::PreInit(const HstuDenseForwardTilingData *__restrict tilingDataPtr)
+{
+    this->maxSeqLen = tilingDataPtr->maxSeqLen;
+    this->sBlkId = tilingDataPtr->eachCoreStartBlockId[GetBlockIdx()];
+    this->eBlkId = tilingDataPtr->eachCoreEndBlockId[GetBlockIdx()];
+
+    if (this->sBlkId == this->eBlkId && this->eBlkId == 0) {
+        return -1;
+    }
+
+    for (auto i = 0; i < this->xDim0 + 1; i++) {
+        this->seqOffsets[i] = tilingDataPtr->seqOffset[i];
+    }
+
+    this->batchSize = this->xDim0;
+    this->seqLen = this->xDim1;
+    this->headNum = this->xDim2;
+    this->headDim = this->xDim3;
+    return 0;
+}
+
+}
+
+#endif
\ No newline at end of file
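
GetTaskInfo above walks every (batch, head) pair to translate a flat block id into a task seed. Below is a host-side sketch of that walk, convenient for unit-testing a tiling split; the function and type names are illustrative, not part of the operator.

    #include <cstdint>
    #include <vector>

    // Host-side sketch of the jagged block walk in GetTaskInfo. seqOffsets is
    // the prefix-sum of per-batch sequence lengths (seqOffsets[0] == 0) and
    // blockHeight is the basic-block row count.
    struct BlockPos { uint32_t batchId, headId, innerBlkId; int64_t seqGlobalOffset; };

    bool LocateBlock(const std::vector<uint32_t>& seqOffsets, uint32_t headNum,
                     uint32_t blockHeight, uint32_t blkId, BlockPos& out)
    {
        uint32_t offsetOfBlk = 0;
        int64_t seqGlobalOffset = 0;
        uint32_t batchSize = static_cast<uint32_t>(seqOffsets.size()) - 1;
        for (uint32_t index = 0; index < batchSize * headNum; index++) {
            uint32_t batchId = index / headNum;
            uint32_t headId = index % headNum;
            uint32_t batchSeqSize = seqOffsets[batchId + 1] - seqOffsets[batchId];
            uint32_t batchBlkSize = (batchSeqSize + blockHeight - 1) / blockHeight;  // ceil-div
            if (blkId < offsetOfBlk + batchBlkSize) {
                uint32_t innerBlkId = blkId - offsetOfBlk;
                out = {batchId, headId, innerBlkId,
                       seqGlobalOffset + static_cast<int64_t>(innerBlkId) * blockHeight};
                return true;
            }
            offsetOfBlk += batchBlkSize;
            seqGlobalOffset += batchSeqSize;  // each (batch, head) contributes its row count
        }
        return false;  // blkId is past the last block
    }

Feeding it blkId = eachCoreStartBlockId[core] reproduces the seed that FillTaskInfo receives on that core.
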
diff --git a/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_kernel.h b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..1fa954dc34079c120bd11228e44470dea090db46
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/hstu_dense_forward/op_kernel/hstu_dense_forward_kernel.h
@@ -0,0 +1,272 @@
+/* Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
+#ifndef HSTU_DENSE_FORWARD_KERNEL_FUN_H
+#define HSTU_DENSE_FORWARD_KERNEL_FUN_H
+#ifdef SUPPORT_V200
+    #include "hstu_dense_forward_kernel_patten_bsnd_v200.h"
+#else
+    #include "hstu_dense_forward_kernel_patten_bsnd.h"
+#endif
+namespace HstuDenseForward {
+
+struct QkMatmulArgs {
+    int64_t taskId = INVALID_TASK_ID;
+    int64_t qkBlockId;
+    int64_t batchId;
+    int64_t headId;
+    int64_t qSeqId;
+    int64_t kSeqId;
+};
+
+struct ScoreVectorArgs {
+    int64_t taskId = INVALID_TASK_ID;
+    int64_t scoreBlockId;
+    int64_t batchId;
+    int64_t headId;
+    int64_t qSeqId;
+    int64_t kSeqId;
+};
+
+struct SvMatmulArgs {
+    int64_t transTaskId = INVALID_TASK_ID;
+    int64_t taskId = INVALID_TASK_ID;
+    int64_t scoreBlockId;
+    int64_t batchId;
+    int64_t headId;
+    int64_t qSeqId;
+    int64_t kSeqId;
+    int64_t vSeqId;
+};
+
+struct SVTransArgs {
+    int64_t transTaskId = INVALID_TASK_ID;
+    int64_t scoreBlockId;
+    int64_t batchId;
+    int64_t headId;
+    int64_t qSeqId;
+};
+
+template <typename T>
+class HstuDenseForwardKernel : public HstuDenseForwardKernelPattenBsnd<T> {
+public:
+    __aicore__ inline HstuDenseForwardKernel() {}
+
+    __aicore__ inline void PreInit(const HstuDenseForwardTilingData *__restrict tilingDataPtr)
+    {
+        seqBlockNumQk = DivCeil(this->xDim1, this->blockHeight); // a partial block is still computed as a full block
+        qkTotalBlock = this->xDim0 * this->xDim2 * seqBlockNumQk;
+    }
+
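+    // The vector stage adds attnBias, applies the optional causal mask, and
+    // applies the SiLU scale to one QK tile. attnBias is addressed as
+    // [batch, head, seq, seq]; on V200 the mask is shared across heads and
+    // addressed as [batch, seq, seq], hence the separate maskOffset below.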
+    __aicore__ inline void VecScore(ScoreVectorArgs& scoreArgs)
+    {
+        if (scoreArgs.taskId == INVALID_TASK_ID) {
+            return;
+        }
+
+        int64_t attnBiasOffset = scoreArgs.batchId * this->xDim2 * this->xDim1 * this->xDim1 + \
+                                 scoreArgs.headId * this->xDim1 * this->xDim1 + \
+                                 scoreArgs.qSeqId * this->blockHeight * this->xDim1 + \
+                                 scoreArgs.kSeqId * this->blockHeight;
+
+        int64_t maskOffset = attnBiasOffset;
+#ifdef SUPPORT_V200
+        maskOffset = scoreArgs.batchId * this->xDim1 * this->xDim1 + \
+                     scoreArgs.qSeqId * this->blockHeight * this->xDim1 + \
+                     scoreArgs.kSeqId * this->blockHeight;
+#endif
+        int causalMask = ((scoreArgs.qSeqId == scoreArgs.kSeqId) &&
+                          (this->maskType == CausalMaskT::MASK_TRIL)) ? 1 : 0;
+
+        int64_t m = (scoreArgs.qSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - scoreArgs.qSeqId * this->blockHeight);
+        int64_t n = (scoreArgs.kSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - scoreArgs.kSeqId * this->blockHeight);
+
+        this->VecScoreImpl(scoreArgs.taskId, attnBiasOffset, maskOffset, this->siluScale, causalMask, m, n);
+    }
+
+    __aicore__ inline void DoQkMatmul(QkMatmulArgs& qkPosArgs)
+    {
+        if (qkPosArgs.taskId == INVALID_TASK_ID) {
+            return;
+        }
+
+        int64_t qOffset = qkPosArgs.batchId * this->xDim1 * this->xDim2 * this->xDim3 + \
+                          qkPosArgs.qSeqId * this->blockHeight * this->xDim2 * this->xDim3 + \
+                          qkPosArgs.headId * this->xDim3;
+        int64_t kOffset = qkPosArgs.batchId * this->xDim1 * this->xDim2 * this->xDim3 + \
+                          qkPosArgs.kSeqId * this->blockHeight * this->xDim2 * this->xDim3 + \
+                          qkPosArgs.headId * this->xDim3;
+
+        int64_t m = (qkPosArgs.qSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - qkPosArgs.qSeqId * this->blockHeight);
+        int64_t n = (qkPosArgs.kSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - qkPosArgs.kSeqId * this->blockHeight);
+
+        this->DoQkMatmulImpl(qOffset, kOffset, qkPosArgs.taskId, m, n, this->xDim3);
+    }
+
+    __aicore__ inline void DoSvMatmul(SvMatmulArgs& svArgs)
+    {
+        if (svArgs.taskId == INVALID_TASK_ID) {
+            return;
+        }
+
+        int64_t vOffset = svArgs.batchId * this->xDim1 * this->xDim2 * this->xDim3 +
+                          svArgs.vSeqId * this->blockHeight * this->xDim2 * this->xDim3 +
+                          svArgs.headId * this->xDim3;
+
+        int64_t m = (svArgs.qSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - svArgs.qSeqId * this->blockHeight);
+        int64_t n = (svArgs.vSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - svArgs.vSeqId * this->blockHeight);
+
+        if (svArgs.vSeqId == 0) {
+            // Override
+            this->DoSvMatmulImpl(vOffset, svArgs.taskId, svArgs.transTaskId, 0, m, this->xDim3, n);
+        } else {
+            // Atomic add
+            this->DoSvMatmulImpl(vOffset, svArgs.taskId, svArgs.transTaskId, 1, m, this->xDim3, n);
+        }
+    }
+
+    __aicore__ inline void DoTransSv(SVTransArgs& args)
+    {
+        if (args.transTaskId == INVALID_TASK_ID) {
+            return;
+        }
+
+        int64_t outStartOffset = args.batchId * this->xDim1 * this->xDim2 * this->xDim3 + \
+                                 args.qSeqId * this->blockHeight * this->xDim2 * this->xDim3 + \
+                                 args.headId * this->xDim3;
+
+        int64_t m = (args.qSeqId != (seqBlockNumQk - 1)) ? this->blockHeight :
+                    (this->xDim1 - args.qSeqId * this->blockHeight);
+
+        this->DoTransSvImpl(args.transTaskId, outStartOffset, m);
+    }
+
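+    // Work split across cores: the qkTotalBlock q-blocks are divided evenly
+    // over GetBlockNum() cores, with the first (qkTotalBlock % GetBlockNum())
+    // cores taking one extra block. Cores in the same SPLIT_CORE group share
+    // a contiguous q-block range, and (headId + qSeqId) % SPLIT_CORE selects
+    // which sibling core processes a given block. The k-loop then runs the
+    // same staggered QK-matmul / vector-score / SV-matmul pipeline as the
+    // jagged kernel, draining the last one or two tasks after the loop.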
+    __aicore__ inline void Compute(const HstuDenseForwardTilingData *__restrict tilingDataPtr)
+    {
+        PreInit(tilingDataPtr);
+        int64_t taskId = 0;
+        int64_t transTaskId = 0;
+
+        int64_t cubeCoreLen = this->qkTotalBlock / GetBlockNum();
+        int64_t cubeCoreSplitId = this->qkTotalBlock % GetBlockNum();
+
+        int64_t blockNumOfOneBatch = this->xDim2 * this->seqBlockNumQk;
+        int64_t blockNumOfOneHead = this->seqBlockNumQk;
+
+        int64_t lenOfThisCore;
+        int64_t offsetOfThisCore;
+
+        if (GetBlockIdx() / SPLIT_CORE >= cubeCoreSplitId) {
+            lenOfThisCore = cubeCoreLen;
+            offsetOfThisCore =
+                cubeCoreSplitId * (cubeCoreLen + 1) + (GetBlockIdx() / SPLIT_CORE - cubeCoreSplitId) * cubeCoreLen;
+        } else {
+            lenOfThisCore = cubeCoreLen + 1;
+            offsetOfThisCore = GetBlockIdx() / SPLIT_CORE * (cubeCoreLen + 1);
+        }
+
+        SVTransArgs lastSvTrans;
+        SVTransArgs lastLastSvTrans;
+
+        ScoreVectorArgs lastVectorScore;
+        SvMatmulArgs lastSvMatmulArgs;
+        SvMatmulArgs lastLastSvMatmulArgs;
+        for (int64_t qBlockId = offsetOfThisCore; qBlockId < offsetOfThisCore + lenOfThisCore; qBlockId++) {
+            int64_t batchId = qBlockId / blockNumOfOneBatch;
+            int64_t batchRemain = qBlockId % blockNumOfOneBatch;
+
+            int64_t headId = batchRemain / blockNumOfOneHead;
+            int64_t headRemain = batchRemain % blockNumOfOneHead;
+
+            int64_t qSeqId = headRemain;
+
+            if ((headId + qSeqId) % SPLIT_CORE != GetBlockIdx() % SPLIT_CORE) {
+                continue;
+            }
+            for (int64_t kSeqId = 0; kSeqId < this->seqBlockNumQk; kSeqId++) {
+                if ((this->maskType == CausalMaskT::MASK_TRIL) && (kSeqId > qSeqId)) {
+                    continue;
+                }
+
+                int64_t qkBlockId = qBlockId * this->seqBlockNumQk + kSeqId;
+                QkMatmulArgs qkArgs = {taskId, qkBlockId, batchId, headId, qSeqId, kSeqId};
+                this->DoQkMatmul(qkArgs);
+                if (taskId > 1) {
+                    this->DoSvMatmul(lastLastSvMatmulArgs);
+                }
+                ScoreVectorArgs scoreArgs = {taskId, qkBlockId, batchId, headId, qSeqId, kSeqId};
+                if (taskId > 0) {
+                    this->VecScore(lastVectorScore);
+                }
+                lastVectorScore = scoreArgs;
+
+                int64_t vSeqId = kSeqId;
+                SvMatmulArgs svMatmulArgs = {transTaskId, taskId, qkBlockId, batchId, headId, qSeqId, kSeqId, vSeqId};
+                lastLastSvMatmulArgs = lastSvMatmulArgs;
+                lastSvMatmulArgs = svMatmulArgs;
+
+                this->WaitQkMatmul();
+                if (taskId > 1) {
+                    this->WaitSvMatmul();
+                }
+
+                taskId += 1;
+            }
+
+            SVTransArgs svTransArgs = {transTaskId, qBlockId * this->seqBlockNumQk, batchId, headId, qSeqId};
+            if (transTaskId > 1) {
+                this->DoTransSv(lastLastSvTrans);
+            }
+            lastLastSvTrans = lastSvTrans;
+            lastSvTrans = svTransArgs;
+            transTaskId += 1;
+        }
+        if (taskId == 0) {
+            return;
+        }
+
+        if (taskId == 1) {
+            this->VecScore(lastVectorScore);
+            pipe_barrier(PIPE_ALL);
+            this->DoSvMatmul(lastSvMatmulArgs);
+            this->WaitSvMatmul();
+            this->DoTransSv(lastSvTrans);
+            return;
+        }
+
+        this->DoSvMatmul(lastLastSvMatmulArgs);
+        this->VecScore(lastVectorScore);
+        pipe_barrier(PIPE_ALL);
+        this->DoSvMatmul(lastSvMatmulArgs);
+        this->WaitSvMatmul();
+        this->DoTransSv(lastLastSvTrans);
+        this->WaitSvMatmul();
+        this->DoTransSv(lastSvTrans);
+    }
+
+private:
+    int64_t seqBlockNumQk;
+    int64_t qkTotalBlock;
+};
+
+}
+
+#endif
\ No newline at end of file
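
All of the dense offsets above follow one BSND pattern; the helper below restates it for reference. This is a minimal sketch under the assumption, implied by the kernel's formulas, that tensors are laid out [batch, seq, head, dim] with xDim1/xDim2/xDim3 = sequence length / head count / head dim; the function name is illustrative.

    #include <cstdint>

    // Element offset of a (batch, seqBlock, head) tile in a BSND tensor, as
    // computed by DoQkMatmul/DoSvMatmul/DoTransSv above. blockHeight is the
    // tile row count.
    inline int64_t BsndTileOffset(int64_t batchId, int64_t seqBlockId, int64_t headId,
                                  int64_t xDim1, int64_t xDim2, int64_t xDim3,
                                  int64_t blockHeight)
    {
        return batchId * xDim1 * xDim2 * xDim3            // skip whole batches
             + seqBlockId * blockHeight * xDim2 * xDim3   // skip rows above the tile
             + headId * xDim3;                            // column range of this head
    }

The m/n ternaries in those methods then clamp the final tile of a sequence to xDim1 - seqId * blockHeight rows, matching the "partial block counts as a full block" rounding in PreInit.
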