diff --git a/src/kernels/mixkernels/toppsample/op_kernel/toppsample.cpp b/src/kernels/mixkernels/toppsample/op_kernel/toppsample.cpp index 7735712684c27008c83d358be580eb7a70390704..fbc9eda7c03d259f81015930d627bc2e69aa9905 100644 --- a/src/kernels/mixkernels/toppsample/op_kernel/toppsample.cpp +++ b/src/kernels/mixkernels/toppsample/op_kernel/toppsample.cpp @@ -20,7 +20,6 @@ static constexpr uint32_t DEFAULT_STRIDE = 8; static constexpr uint32_t FP32_PER_REPEAT = 64; static constexpr uint32_t FP16_PER_REPEAT = 128; static constexpr uint32_t FP16_PER_BLOCK = 16; -static constexpr uint32_t MAX_BATCH = 1024; static constexpr uint32_t NUM_4 = 4; using AscendC::HardEvent; @@ -45,6 +44,7 @@ public: nlCoreRun_ = (firstDim_ + realCore_ - 1) / realCore_; lCoreRun_ = firstDim_ - (realCore_ - 1) * nlCoreRun_; dynamicRound_ = (blockIdx_ == realCore_ - 1) ? lCoreRun_ : nlCoreRun_; + maxBatch_ = (firstDim_ + FP16_PER_BLOCK - 1) / FP16_PER_BLOCK * FP16_PER_BLOCK; xGm_.SetGlobalBuffer((__gm__ T *)cumsumed_probs); yGm_.SetGlobalBuffer((__gm__ T *)topp); // batch,num_samples @@ -54,8 +54,8 @@ public: pipe_.InitBuffer(inputBuf_, tempUbEleAligened_ * DATA_BYTE); pipe_.InitBuffer(tempBuf_, tempUbEleAligened_ * DATA_BYTE * DATA_BYTE); pipe_.InitBuffer(fp32Buf_, tempUbEleAligened_ * DATA_BYTE * DATA_BYTE); - pipe_.InitBuffer(yBuf_, MAX_BATCH * DATA_BYTE); // topp - pipe_.InitBuffer(yF32Buf_, MAX_BATCH * DATA_BYTE * DATA_BYTE); // toppfp32 + pipe_.InitBuffer(yBuf_, maxBatch_ * DATA_BYTE); // topp + pipe_.InitBuffer(yF32Buf_, maxBatch_ * DATA_BYTE * DATA_BYTE); // toppfp32 pipe_.InitBuffer(int8Buf_, tempUbEleAligened_ / DEFAULT_STRIDE); // compare pipe_.InitBuffer(blockBuf_, BLK_SIZE); // 存下标 pipe_.InitBuffer(int32Buf_, MAX_CORE_NUM * DATA_BYTE * DATA_BYTE); // 每个核做几个batch @@ -65,7 +65,7 @@ public: __aicore__ inline void PickUpRand() { AscendC::LocalTensor buf = yBuf_.Get(); - DataCopy(buf, yGm_, MAX_BATCH); + DataCopy(buf, yGm_, maxBatch_); } __aicore__ inline void FirstPick(uint32_t cid, uint32_t offset) @@ -127,7 +127,7 @@ public: Duplicate(uint32Buf_, uint32_t(0), tempUbEleAligened_ / BLK_SIZE); // 截断数可能是batch个,也可能是1个 // 每个batch往后取一个随机数。(*(tilingUb_ + batchOffset)) - Cast(toppBufF32_, toppBuf_, AscendC::RoundMode::CAST_NONE, MAX_BATCH); + Cast(toppBufF32_, toppBuf_, AscendC::RoundMode::CAST_NONE, maxBatch_); for (int cid = 0; cid < dynamicRound_; cid++) { // 每个核做多少次 absIdx_ = 0; uint32_t batchOffset = (blockIdx_ * nlCoreRun_ + cid) % MAX_CORE_NUM; @@ -336,6 +336,7 @@ private: uint32_t expandLastDim_{0}; uint32_t numSamplesMax_{0}; uint32_t firstDim_{0}; + uint32_t maxBatch_{0}; float maxNum_{0}; float tempValue_{0}; uint32_t perCoreRunNum_{0};