From 14ebdddb3e85f312a5c9fdeb6b3f72a83c7a2542 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Mon, 14 Jul 2025 21:51:05 +0800 Subject: [PATCH 1/2] style cleancode 0714 --- .../multiStream_multiGraph_demo.cpp | 5 +- .../op_demo/all_gather/all_gather_demo.cpp | 3 +- example/op_demo/mla_preprocess/mlapo_demo.cpp | 46 ++++---- .../self_attention_encoder_demo.cpp | 3 +- scripts/build.sh | 2 +- src/atb/core/node_impl/mki_node_implement.cpp | 1 - .../device_tiling_buffer_pool.cpp | 4 +- src/atb/operation/operation_base.cpp | 2 +- src/atb/runner/ops_runner.cpp | 2 +- src/cinterface/atb_acl_mla.cpp | 4 +- src/cinterface/atb_acl_util.cpp | 3 +- .../atb/core/node_impl/mki_node_implement.h | 2 +- .../block_copy/block_copy_operation.cpp | 4 +- src/ops_infer/fill/fill_ops_runner.h | 26 +++-- .../linear_parallel_operation.cpp | 4 +- .../paged_attention_operation.cpp | 105 +++++++++--------- .../paged_cache_load_operation.cpp | 24 ++-- .../topk_topp_sampling_operation.cpp | 44 ++++---- src/torch_atb/resource/utils.h | 2 +- 19 files changed, 151 insertions(+), 135 deletions(-) diff --git a/example/multiStream/multiStream_multiGraph_demo.cpp b/example/multiStream/multiStream_multiGraph_demo.cpp index 67ef86a5..cc732b76 100644 --- a/example/multiStream/multiStream_multiGraph_demo.cpp +++ b/example/multiStream/multiStream_multiGraph_demo.cpp @@ -85,7 +85,7 @@ static void CreateMiniGraphOperation(atb::GraphParam &opGraph, atb::Operation ** opGraph.outTensorNum = GRAPH_OUT_TENSOR_NUM; opGraph.internalTensorNum = GRAPH_INTERNAL_TENSOR_NUM; const int GRAPH_NODE_NUM = 3; - opGraph.nodes.resize(3); + opGraph.nodes.resize(GRAPH_NODE_NUM); size_t nodeId = 0; atb::Node &addNode = opGraph.nodes.at(nodeId++); @@ -257,8 +257,7 @@ int main() packRW.outTensors.resize(outTensorNum); operationWR->InferShape(intensorDescs, outtensorDescs); - aclError ret; - ret = CreateInTensors(packWR.inTensors, intensorDescs); + aclError ret = CreateInTensors(packWR.inTensors, intensorDescs); if (ret != 0) { exit(ret); } diff --git a/example/op_demo/all_gather/all_gather_demo.cpp b/example/op_demo/all_gather/all_gather_demo.cpp index 9d2a43f2..0dc0865d 100644 --- a/example/op_demo/all_gather/all_gather_demo.cpp +++ b/example/op_demo/all_gather/all_gather_demo.cpp @@ -14,6 +14,7 @@ namespace { const int64_t INTPUT_DIM_NUM = 2; +const int64_t OUTPUT_DIM_NUM = 3; const int64_t DIM2 = 2; const int64_t DIM3 = 3; const int64_t DIM5 = 5; @@ -59,7 +60,7 @@ atb::Status AllGatherSample(int rank, int rankSize) atb::Tensor output; output.desc.dtype = ACL_FLOAT16; output.desc.format = ACL_FORMAT_ND; - output.desc.shape.dimNum = 3; + output.desc.shape.dimNum = OUTPUT_DIM_NUM; output.desc.shape.dims[IDX0] = DIM2; output.desc.shape.dims[IDX1] = DIM3; output.desc.shape.dims[IDX2] = DIM5; diff --git a/example/op_demo/mla_preprocess/mlapo_demo.cpp b/example/op_demo/mla_preprocess/mlapo_demo.cpp index 744ba3ec..7cd6640a 100644 --- a/example/op_demo/mla_preprocess/mlapo_demo.cpp +++ b/example/op_demo/mla_preprocess/mlapo_demo.cpp @@ -34,7 +34,7 @@ const int32_t MATMUL_DIM2112 = 2112; const int32_t MATMUL_DIM32 = 32; const int32_t MATMUL_DIM192 = 192; const int32_t MATMUL_DIM48 = 48; -} +} // namespace /** * @brief Prepare the input tensors in atb::VariantPack @@ -48,8 +48,9 @@ atb::Status PrepareInTensor1(atb::Context *contextPtr, aclrtStream stream, aclDa { // Create the input tensor with shape [tokenNum, 7168] atb::Tensor input; - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * RMSNORM_QUANT_DIM7168, 0), dtype, - aclFormat::ACL_FORMAT_ND,
{tokenNum, RMSNORM_QUANT_DIM7168}, input, dtype)); + CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * RMSNORM_QUANT_DIM7168, 0), + dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, RMSNORM_QUANT_DIM7168}, input, + dtype)); // Create the gamma0 input tensor with shape [7168] atb::Tensor gamma0; CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(RMSNORM_QUANT_DIM7168, 0), dtype, @@ -68,16 +69,18 @@ atb::Status PrepareInTensor1(atb::Context *contextPtr, aclrtStream stream, aclDa aclFormat::ACL_FORMAT_ND, {1}, quantOffset0)); // Create the wdqkv input tensor with shape [1,224,2112,32] atb::Tensor wdqkv; - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(MATMUL_DIM224 * MATMUL_DIM2112 * MATMUL_DIM32, 1), ACL_INT8, - aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM224, MATMUL_DIM2112, MATMUL_DIM32}, wdqkv)); + CHECK_STATUS(CreateTensorFromVector( + contextPtr, stream, std::vector(MATMUL_DIM224 * MATMUL_DIM2112 * MATMUL_DIM32, 1), ACL_INT8, + aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM224, MATMUL_DIM2112, MATMUL_DIM32}, wdqkv)); // Create the deScale0 input tensor with shape [2112] atb::Tensor deScale0; if (dtype == ACL_BF16) { CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(MATMUL_DIM2112, 1), ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, deScale0)); } else { - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(MATMUL_DIM2112, 10), ACL_INT64, - aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, deScale0)); + int64_t deScale0Value = 10; + CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(MATMUL_DIM2112, deScale0Value), + ACL_INT64, aclFormat::ACL_FORMAT_ND, {MATMUL_DIM2112}, deScale0)); } // Create the bias0 input tensor with shape [2112] atb::Tensor bias0; @@ -116,16 +119,17 @@ atb::Status PrepareInTensor2(atb::Context *contextPtr, aclrtStream stream, aclDa { // Create the wuq input tensor with shape [1,48,headNum*192,32] atb::Tensor wuq; - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(MATMUL_DIM48 * headNum * MATMUL_DIM192 * MATMUL_DIM32, 1), ACL_INT8, - aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM48, headNum * MATMUL_DIM192, MATMUL_DIM32}, wuq)); + CHECK_STATUS(CreateTensorFromVector( + contextPtr, stream, std::vector(MATMUL_DIM48 * headNum * MATMUL_DIM192 * MATMUL_DIM32, 1), ACL_INT8, + aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, MATMUL_DIM48, headNum * MATMUL_DIM192, MATMUL_DIM32}, wuq)); // Create the deScale1 input tensor with shape [headNum*192] atb::Tensor deScale1; if (dtype == ACL_BF16) { - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(headNum * MATMUL_DIM192, 1), ACL_FLOAT, - aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1)); + CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(headNum * MATMUL_DIM192, 1), + ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1)); } else { - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(headNum * MATMUL_DIM192, 1), ACL_INT64, - aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1)); + CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector(headNum * MATMUL_DIM192, 1), + ACL_INT64, aclFormat::ACL_FORMAT_ND, {headNum * MATMUL_DIM192}, deScale1)); } // Create the bias1 input tensor with shape [headNum*192] atb::Tensor bias1; @@ -141,22 +145,26 @@ atb::Status PrepareInTensor2(atb::Context *contextPtr, aclrtStream stream, aclDa aclFormat::ACL_FORMAT_ND, {tokenNum, ROPE_DIM64}, cos, dtype)); // Create the sin input tensor with shape [tokenNum,64] atb::Tensor sin; - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream,
std::vector<__fp16>(tokenNum * ROPE_DIM64, 0.5), dtype, + __fp16 sinValue = 0.5; + CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(tokenNum * ROPE_DIM64, sinValue), dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, ROPE_DIM64}, sin, dtype)); // Create the wuk input tensor with shape [headNum,32,128,16] atb::Tensor wuk; - CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<__fp16>(headNum * MATMUL_DIM32 * BLOCK_SIZE * NZ_DIM16, 0), dtype, - aclFormat::ACL_FORMAT_FRACTAL_NZ, {headNum, MATMUL_DIM32, BLOCK_SIZE, NZ_DIM16}, wuk, dtype)); + CHECK_STATUS(CreateTensorFromVector( + contextPtr, stream, std::vector<__fp16>(headNum * MATMUL_DIM32 * BLOCK_SIZE * NZ_DIM16, 0), dtype, + aclFormat::ACL_FORMAT_FRACTAL_NZ, {headNum, MATMUL_DIM32, BLOCK_SIZE, NZ_DIM16}, wuk, dtype)); // Create the kvCache input tensor with shape [BLOCK_NUM, headNum*512/32,block_size, 32] atb::Tensor kvCache; CHECK_STATUS(CreateTensorFromVector( contextPtr, stream, std::vector(BLOCK_NUM * headNum * NOPE_DIM512 * BLOCK_SIZE, 1), ACL_INT8, - aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * NOPE_DIM512 / MATMUL_DIM32, BLOCK_SIZE, MATMUL_DIM32}, kvCache)); + aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * NOPE_DIM512 / MATMUL_DIM32, BLOCK_SIZE, MATMUL_DIM32}, + kvCache)); // Create the kvCacheRope input tensor with shape [BLOCK_NUM, headNum*64/16 ,block_size, 16] atb::Tensor kvCacheRope; CHECK_STATUS(CreateTensorFromVector( - contextPtr, stream, std::vector<__fp16>(BLOCK_NUM * headNum * ROPE_DIM64 / NZ_DIM16 * BLOCK_SIZE * NZ_DIM16, 0), dtype, - aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * ROPE_DIM64 / NZ_DIM16, BLOCK_SIZE, NZ_DIM16}, kvCacheRope, dtype)); + contextPtr, stream, std::vector<__fp16>(BLOCK_NUM * headNum * ROPE_DIM64 / NZ_DIM16 * BLOCK_SIZE * NZ_DIM16, 0), + dtype, aclFormat::ACL_FORMAT_FRACTAL_NZ, {BLOCK_NUM, headNum * ROPE_DIM64 / NZ_DIM16, BLOCK_SIZE, NZ_DIM16}, + kvCacheRope, dtype)); auto slotmappingHost = std::vector(1, tokenNum); for (size_t i = 0; i < slotmappingHost.size(); i++) slotmappingHost[i] = static_cast(i); diff --git a/example/op_demo/self_attention/self_attention_encoder_demo.cpp b/example/op_demo/self_attention/self_attention_encoder_demo.cpp index 2f5c4b22..003c7ece 100644 --- a/example/op_demo/self_attention/self_attention_encoder_demo.cpp +++ b/example/op_demo/self_attention/self_attention_encoder_demo.cpp @@ -68,7 +68,8 @@ atb::Status PrepareInTensor(atb::Context *contextPtr, aclrtStream stream, std::v } // Create tokenOffset, a host-side tensor atb::Tensor tensorTokenOffset; - atb::Tensor tensorSeqLen atb::Tensor tensorLayerId; + atb::Tensor tensorSeqLen; + atb::Tensor tensorLayerId; CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {BATCH_SIZE}, tensorTokenOffset)); tensorTokenOffset.hostData = tokenOffsetHost.data(); // host-side tensor, copy the values // Create seqLen, a host-side tensor diff --git a/scripts/build.sh b/scripts/build.sh index a7b67fed..1397d437 100644 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -468,7 +468,7 @@ function fn_build() fi fn_build_3rdparty_for_compile cd $CACHE_DIR - if [ "$CMAKE_CXX_COMPILER_LAUNCHER" == "" ] && [ command -v ccache &> /dev/null ] ; then + if [ "$CMAKE_CXX_COMPILER_LAUNCHER" == "" ] && command -v ccache &> /dev/null; then COMPILE_OPTIONS="${COMPILE_OPTIONS} -DCMAKE_CXX_COMPILER_LAUNCHER=ccache" fi echo "COMPILE_OPTIONS:$COMPILE_OPTIONS" diff --git a/src/atb/core/node_impl/mki_node_implement.cpp b/src/atb/core/node_impl/mki_node_implement.cpp index fc9396f5..eacb23b7 100644 --- a/src/atb/core/node_impl/mki_node_implement.cpp +++
b/src/atb/core/node_impl/mki_node_implement.cpp @@ -15,7 +15,6 @@ #include "atb/utils/tensor_util.h" #include "atb/utils/statistic.h" #include "atb/utils/store_util.h" -#include "atb/utils/singleton.h" #include "atb/utils/probe.h" namespace atb { diff --git a/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp b/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp index c9df0409..5f086f38 100644 --- a/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp +++ b/src/atb/core/tiling_buffer_pool/device_tiling_buffer_pool.cpp @@ -11,7 +11,9 @@ #include "atb/utils/log.h" namespace atb { -DeviceTilingBufferPool::DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, const std::function& alloc, const std::function& dealloc) +DeviceTilingBufferPool::DeviceTilingBufferPool(uint64_t blockNum, uint64_t blockSize, + const std::function &alloc, + const std::function &dealloc) : TilingBufferPool(blockNum, blockSize), allocateFunc_(alloc), deallocateFunc_(dealloc) { } diff --git a/src/atb/operation/operation_base.cpp b/src/atb/operation/operation_base.cpp index cc907115..67544272 100644 --- a/src/atb/operation/operation_base.cpp +++ b/src/atb/operation/operation_base.cpp @@ -1080,7 +1080,7 @@ Status OperationBase::Execute(const VariantPack &variantPack, uint8_t *workspace (executeType == EXECUTE_PRELAUNCH ? OPERATION_PRELAUNCH : OPERATION_LAUNCH); std::shared_ptr mstxMemRegister; mstxMemRegister = std::make_shared(); - if (workspaceSize) { + if (workspaceSize != 0) { mstxMemRegister->MstxHeapRegister(workspace, workspaceSize); if (mstxMemRegister && mstxMemRegister->IsValid()) { runnerVariantPack_.mstxMemRegister = mstxMemRegister.get(); diff --git a/src/atb/runner/ops_runner.cpp b/src/atb/runner/ops_runner.cpp index 767ce33d..d5f022d4 100644 --- a/src/atb/runner/ops_runner.cpp +++ b/src/atb/runner/ops_runner.cpp @@ -617,7 +617,7 @@ Status OpsRunner::RunAllKernel(RunnerVariantPack &runnerVariantPack) KernelGraphNode &node = kernelGraph_.nodes.at(nodeId); if (runnerVariantPack.mstxMemRegister != nullptr) { runnerVariantPack.mstxMemRegister->ClearMstxMemRegions(); - if (runnerVariantPack.workspaceBufferSize) { + if (runnerVariantPack.workspaceBufferSize != 0) { runnerVariantPack.workspaceBufferSize = static_cast(TensorUtil::AlignInt(runnerVariantPack.workspaceBufferSize, ALIGN_INT)); runnerVariantPack.mstxMemRegister->AddTensorMemRegions(runnerVariantPack.workspaceBuffer, diff --git a/src/cinterface/atb_acl_mla.cpp b/src/cinterface/atb_acl_mla.cpp index 9ea8a6aa..57c6d91f 100644 --- a/src/cinterface/atb_acl_mla.cpp +++ b/src/cinterface/atb_acl_mla.cpp @@ -114,10 +114,10 @@ atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRop return atb::NO_ERROR; } -atb::Status AtbMLA(void *workSpcace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context) +atb::Status AtbMLA(void *workspcace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context) { atb::VariantPack pack; - atb::Status st = op->Execute(pack, (uint8_t *)(workSpcace), workspaceSize, context); + atb::Status st = op->Execute(pack, (uint8_t *)(workspcace), workspaceSize, context); ATB_CHECK(st == atb::NO_ERROR, "AtbMLA Execute failed!", return st); return st; } diff --git a/src/cinterface/atb_acl_util.cpp b/src/cinterface/atb_acl_util.cpp index 9d1eb39a..cf269672 100644 --- a/src/cinterface/atb_acl_util.cpp +++ b/src/cinterface/atb_acl_util.cpp @@ -51,7 +51,8 @@ atb::Status aclTensorToAtbTensor(const aclTensor *aclTensorSrc, atb::Tensor *atb atbTensorDst->desc = desc;
atbTensorDst->deviceData = aclTensorSrc->GetData(); atbTensorDst->hostData = nullptr; - atbTensorDst->dataSize = GetTensorSize(aclTensorSrc) * aclDataTypeSize(dataType); + atbTensorDst->dataSize = + static_cast(GetTensorSize(aclTensorSrc) * static_cast(aclDataTypeSize(dataType))); return atb::NO_ERROR; } diff --git a/src/include/atb/core/node_impl/mki_node_implement.h b/src/include/atb/core/node_impl/mki_node_implement.h index 66e7adb9..c44772d3 100644 --- a/src/include/atb/core/node_impl/mki_node_implement.h +++ b/src/include/atb/core/node_impl/mki_node_implement.h @@ -76,7 +76,7 @@ private: void *argsHostBuffer_ = nullptr; }; -static const std::unordered_map InitAtbMkiErrorHash() noexcept +static inline const std::unordered_map InitAtbMkiErrorHash() noexcept { return {{Mki::ErrorType::NO_ERROR, atb::ErrorType::NO_ERROR}, {Mki::ErrorType::ERROR_INVALID_VALUE, atb::ErrorType::ERROR_INVALID_PARAM}, diff --git a/src/ops_infer/block_copy/block_copy_operation.cpp b/src/ops_infer/block_copy/block_copy_operation.cpp index 4671e10f..0ac8cb88 100644 --- a/src/ops_infer/block_copy/block_copy_operation.cpp +++ b/src/ops_infer/block_copy/block_copy_operation.cpp @@ -161,8 +161,8 @@ Status BlockCopyOperation::SetupDimCheck310P(const SVector &inTenso inTensors.at(INPUT_V_CACHE).desc.format == aclFormat::ACL_FORMAT_FRACTAL_NZ) { if ((inTensors.at(INPUT_K_CACHE).desc.shape.dims[INPUT_DST_BLOCK] != NZBLOCKSIZE) || (inTensors.at(INPUT_V_CACHE).desc.shape.dims[INPUT_DST_BLOCK] != NZBLOCKSIZE) || - (inTensors.at(INPUT_K_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0) || - (inTensors.at(INPUT_V_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0)) { // 2: dim + (inTensors.at(INPUT_K_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0) || // 2: blockSize dim + (inTensors.at(INPUT_V_CACHE).desc.shape.dims[2] % NZBLOCKSIZE != 0)) { // 2: blockSize dim ATB_LOG(ERROR) << GetLogPrefix() << "NZ format tensor dim should be aligned to 16"; return ERROR_INVALID_TENSOR_DIM; } diff --git a/src/ops_infer/fill/fill_ops_runner.h b/src/ops_infer/fill/fill_ops_runner.h index 9df2e690..fb2fbb45 100644 --- a/src/ops_infer/fill/fill_ops_runner.h +++ b/src/ops_infer/fill/fill_ops_runner.h @@ -27,20 +27,22 @@ private: }; namespace infer { +inline bool IsFloatSVectorEqual(const SVector &v1, const SVector &v2) +{ + if (v1.size() != v2.size()) { + return false; + } + for (size_t i = 0; i < v1.size(); ++i) { + if (!UtilsInternal::IsFloatEqual(v1[i], v2[i])) { + return false; + } + } + return true; +} + inline bool operator==(const FillParam &left, const FillParam &right) { - return left.withMask == right.withMask && - [](const SVector &v1, const SVector &v2) { - if (v1.size() != v2.size()) { - return false; - } - for (size_t i = 0; i < v1.size(); ++i) { - if (!UtilsInternal::IsFloatEqual(v1[i], v2[i])) { - return false; - } - } - return true; - }(left.value, right.value) && + return left.withMask == right.withMask && IsFloatSVectorEqual(left.value, right.value) && left.outDim == right.outDim; } } // namespace infer diff --git a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp index 95eabcfc..5151eaa9 100644 --- a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp +++ b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp @@ -108,7 +108,7 @@ template <> Status CreateOperation(const infer::LinearParallelParam &opParam, Op ATB_LOG(ERROR) << "LinearParallelOperation DistributedInitCheck failed."; return ERROR_INVALID_PARAM; } - int rankSize = opParam.rankSize; +
uint32_t rankSize = static_cast(opParam.rankSize); if (opParam.rankSize <= 0 || (rankSize & (rankSize - 1)) != 0) { ATB_LOG(ERROR) << "LinearParallel rankSize support power of 2 but got [" << opParam.rankSize << "]"; return ERROR_INVALID_PARAM; @@ -335,7 +335,7 @@ Status LinearParallelOperation::InferShapeCheckLinearAllReduce(const SVector infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT && param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX; - if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) { + if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) { // 3: deqScale return ERROR_INVALID_TENSOR_INI_MATCH; } return CheckResidual(inTensorDescs); diff --git a/src/ops_infer/paged_attention/paged_attention_operation.cpp b/src/ops_infer/paged_attention/paged_attention_operation.cpp index d1fec979..37f78af8 100644 --- a/src/ops_infer/paged_attention/paged_attention_operation.cpp +++ b/src/ops_infer/paged_attention/paged_attention_operation.cpp @@ -35,7 +35,10 @@ static const int LOGN_BIT = 0x00040; static const int QKVQUANTOFFLINE_BIT = 0x00040; static const int QKVQUANTONLINE_BIT = 0x00080; static const int BLOCK_SIZE_DIM128 = 128; -static const int DIM4 = 4; +static const int DIM0 = 0; +static const int DIM1 = 1; +static const int DIM2 = 2; +static const int DIM3 = 3; static const int IN_MASK_IDX = 5; static const int MAX_BLOCK_SIZE = 256; } // namespace @@ -653,20 +656,21 @@ Status PagedAttentionOperation::MaskFreeInferShapeCheck310P(const SVector().Is310P()) { - if (inTensorDescs.at(IN_MASK_IDX).shape.dimNum != 4) { + if (inTensorDescs.at(IN_MASK_IDX).shape.dimNum != 4) { // 4: PA MASK_TYPE_MASK_FREE dimNum ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dim num should be 4"; return ERROR_INVALID_TENSOR_DIM; } - if (inTensorDescs.at(IN_MASK_IDX).shape.dims[0] != 1 || inTensorDescs.at(IN_MASK_IDX).shape.dims[1] != 8 || - inTensorDescs.at(IN_MASK_IDX).shape.dims[2] != BLOCK_SIZE_DIM128 || - inTensorDescs.at(IN_MASK_IDX).shape.dims[3] != 16) { + if (inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM0] != 1 || // 1: mask dims [1,8,128,16] + inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM1] != 8 || // 8: mask dims [1,8,128,16] + inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM2] != BLOCK_SIZE_DIM128 || + inTensorDescs.at(IN_MASK_IDX).shape.dims[DIM3] != DIM_ALIGN_16_NZ) { ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dims should " "be [1,8,128,16]"; return ERROR_INVALID_TENSOR_DIM; } - size_t kBlockSize = inTensorDescs.at(1).shape.dims[2]; - size_t vBlockSize = inTensorDescs.at(2).shape.dims[2]; + size_t kBlockSize = inTensorDescs.at(DIM1).shape.dims[2]; // 1: k, 2: blockSize + size_t vBlockSize = inTensorDescs.at(DIM2).shape.dims[2]; // 2: v, 2: blockSize if (kBlockSize != BLOCK_SIZE_DIM128 || vBlockSize != BLOCK_SIZE_DIM128) { ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128."; return ERROR_INVALID_PARAM; @@ -681,55 +685,54 @@ Status PagedAttentionOperation::MaskFreeInferShapeCheck310P(const SVector &inTensor) const { - if (param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) { - if (GetSingleton().Is310P()) { - if (GetSingleton().Is310P() && - param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) { - if (inTensor.at(IN_MASK_IDX).desc.shape.dimNum != 4) { - ATB_LOG(ERROR) - << "When maskType is mask free on Altas 300I Duo inference products,
mask dim num should be 4"; - return ERROR_INVALID_TENSOR_DIM; - } - if (inTensor.at(IN_MASK_IDX).desc.shape.dims[0] != 1 || - inTensor.at(IN_MASK_IDX).desc.shape.dims[1] != 8 || - inTensor.at(IN_MASK_IDX).desc.shape.dims[2] != BLOCK_SIZE_DIM128 || - inTensor.at(IN_MASK_IDX).desc.shape.dims[3] != DIM_ALIGN_16_NZ) { - ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dims " - "should be [1,8,128,16]"; - return ERROR_INVALID_TENSOR_DIM; - } - } - if (inTensor.at(DIM4).desc.shape.dimNum == 1) { - size_t batch = inTensor.at(DIM4).desc.shape.dims[0]; - int *kSeqlenList = static_cast(inTensor[DIM4].hostData); - int *qSeqlenList = static_cast(inTensor[6].hostData); - - for (size_t i = 0; i < batch; i++) { - if (kSeqlenList[i] < qSeqlenList[i]) { - ATB_LOG(ERROR) << "PagedAttentionOperation intensor4[i] should bigger than intensor6[i]."; - return ERROR_INVALID_PARAM; - } - if ((kSeqlenList[i] - qSeqlenList[i]) % BLOCK_SIZE_DIM128 != 0) { - ATB_LOG(ERROR) - << "PagedAttentionOperation (intensor4[i] - item in intensor6[i]) % 128 should be 0. "; - return ERROR_INVALID_PARAM; - } - } - } else { - ATB_LOG(ERROR) << "PagedAttentionOperation kSeqlenList dims should be 1."; + if (param_.maskType != atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) { + return NO_ERROR; + } + if (!GetSingleton().Is310P()) { + ATB_LOG(ERROR) << "Only Altas 300I Duo inference products support mask free"; + return ERROR_INVALID_TENSOR_DIM; + } + if (GetSingleton().Is310P() && param_.maskType == atb::infer::PagedAttentionParam::MASK_TYPE_MASK_FREE) { + if (inTensor.at(IN_MASK_IDX).desc.shape.dimNum != 4) { // 4: PA MASK_TYPE_MASK_FREE dimNum + ATB_LOG(ERROR) + << "When maskType is mask free on Altas 300I Duo inference products, mask dim num should be 4"; + return ERROR_INVALID_TENSOR_DIM; + } + if (inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM0] != 1 || // 1: mask dims [1,8,128,16] + inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM1] != 8 || // 8: mask dims [1,8,128,16] + inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM2] != BLOCK_SIZE_DIM128 || + inTensor.at(IN_MASK_IDX).desc.shape.dims[DIM3] != DIM_ALIGN_16_NZ) { + ATB_LOG(ERROR) << "When maskType is mask free on Altas 300I Duo inference products, mask dims " + "should be [1,8,128,16]"; + return ERROR_INVALID_TENSOR_DIM; + } + } + static const int KSEQLEN_INDEX4 = 4; + if (inTensor.at(KSEQLEN_INDEX4).desc.shape.dimNum == 1) { + size_t batch = inTensor.at(KSEQLEN_INDEX4).desc.shape.dims[0]; + int *kSeqlenList = static_cast(inTensor[KSEQLEN_INDEX4].hostData); + int *qSeqlenList = static_cast(inTensor[6].hostData); // 6: qSeqlen + + for (size_t i = 0; i < batch; i++) { + if (kSeqlenList[i] < qSeqlenList[i]) { + ATB_LOG(ERROR) << "PagedAttentionOperation intensor4[i] should bigger than intensor6[i]."; return ERROR_INVALID_PARAM; } - - size_t kBlockSize = inTensor.at(1).desc.shape.dims[2]; - size_t vBlockSize = inTensor.at(2).desc.shape.dims[2]; - if (kBlockSize != BLOCK_SIZE_DIM128 || vBlockSize != BLOCK_SIZE_DIM128) { - ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128."; + if ((kSeqlenList[i] - qSeqlenList[i]) % BLOCK_SIZE_DIM128 != 0) { + ATB_LOG(ERROR) << "PagedAttentionOperation (intensor4[i] - item in intensor6[i]) % 128 should be 0.
"; return ERROR_INVALID_PARAM; } - } else { - ATB_LOG(ERROR) << "Only Altas 300I Duo inference products support mask free"; - return ERROR_INVALID_TENSOR_DIM; } + } else { + ATB_LOG(ERROR) << "PagedAttentionOperation kSeqlenList dims should be 1."; + return ERROR_INVALID_PARAM; + } + + size_t kBlockSize = inTensor.at(1).desc.shape.dims[2]; // 1: k, 2: blockSize dim + size_t vBlockSize = inTensor.at(2).desc.shape.dims[2]; // 2: v, 2: blockSize dim + if (kBlockSize != BLOCK_SIZE_DIM128 || vBlockSize != BLOCK_SIZE_DIM128) { + ATB_LOG(ERROR) << "PagedAttentionOperation intensor1 and intensor2 dim2 should be 128."; + return ERROR_INVALID_PARAM; } return NO_ERROR; } diff --git a/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp b/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp index 4b7f8d61..63014e4d 100644 --- a/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp +++ b/src/ops_infer/paged_cache_load/paged_cache_load_operation.cpp @@ -129,12 +129,12 @@ Status PagedCacheLoadOperation::SetupCheckImpl(const SVector &inTensors, inTensorDescs.push_back(inTensors.at(i).desc); } if (param_.kvCacheCfg == infer::PagedCacheLoadParam::KvCacheCfg::K_CACHE_V_CACHE_NZ) { // NZ - int64_t AlignCacheK = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * + int64_t alignCacheK = inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM3]; - int64_t AlignCacheV = inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * + int64_t alignCacheV = inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM3]; - if (outTensors.at(IN_TENSOR_0_KEYCACHE).desc.shape.dims[DIM1] != AlignCacheK || - outTensors.at(IN_TENSOR_1_VALUECACHE).desc.shape.dims[DIM1] != AlignCacheV) { + if (outTensors.at(IN_TENSOR_0_KEYCACHE).desc.shape.dims[DIM1] != alignCacheK || + outTensors.at(IN_TENSOR_1_VALUECACHE).desc.shape.dims[DIM1] != alignCacheV) { ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of outTensors needs to remain aligned"; return ERROR_INVALID_TENSOR_DIM; } @@ -235,22 +235,22 @@ Status PagedCacheLoadOperation::KVCacheDimCheck910BNZ(const SVector return ERROR_INVALID_TENSOR_DIM_NUM; } if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).dtype == ACL_INT8) { - if (THIRTYTWO != inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] || - THIRTYTWO != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM]) { // 1: valueCache + if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] != THIRTYTWO || + inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM] != THIRTYTWO) { // 1: valueCache ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of keycache and valuecache must be 32"; return ERROR_INVALID_TENSOR_DIM; } - if (MAX_K < inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * THIRTYTWO || - MAX_V < inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * THIRTYTWO) { + if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * THIRTYTWO > MAX_K || + inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * THIRTYTWO > MAX_V) { ATB_LOG(ERROR) << GetLogPrefix() << "The scend dimension of blocktables must be less than 147456"; return ERROR_INVALID_TENSOR_DIM; } - } else if (SIXTEEN != inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] || - SIXTEEN != inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM]) { // 1: valueCache + } else if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[OUT_DIM] != SIXTEEN || + 
inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[OUT_DIM] != SIXTEEN) { // 1: valueCache ATB_LOG(ERROR) << GetLogPrefix() << "The last dimension of keycache and valuecache must be 16"; return ERROR_INVALID_TENSOR_DIM; - } else if (MAX_K < inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * SIXTEEN || - MAX_V < inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * SIXTEEN) { + } else if (inTensorDescs.at(IN_TENSOR_0_KEYCACHE).shape.dims[DIM1] * SIXTEEN > MAX_K || + inTensorDescs.at(IN_TENSOR_1_VALUECACHE).shape.dims[DIM1] * SIXTEEN > MAX_V) { ATB_LOG(ERROR) << GetLogPrefix() << "The scend dimension of blocktables must be less than 147456"; return ERROR_INVALID_TENSOR_DIM; } diff --git a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp index 873e9531..5903ac49 100644 --- a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp +++ b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp @@ -41,7 +41,7 @@ static const uint32_t LOG_PROBS_OUT_TENSOR_INDEX = 2; static const uint32_t LOG_PROBS_OUT_TENSOR_DIM = 2; static const uint32_t LAST_DIM = 1; -using atbInferTopkToppSamplingType = atb::infer::TopkToppSamplingParam::TopkToppSamplingType; +using AtbInferTopkToppSamplingType = atb::infer::TopkToppSamplingParam::TopkToppSamplingType; bool ParamCheck(const atb::infer::TopkToppSamplingParam &opParam) { @@ -60,15 +60,15 @@ OPERATION_PARAM_FUNCS(TopkToppSamplingOperation, infer::TopkToppSamplingParam) static Mki::OperationIr *GetOperationIrForTopkToppSampling(const infer::TopkToppSamplingParam &param) { switch (param.topkToppSamplingType) { - case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING: return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKExpOperation"); - case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKLogProbsExpOperation"); - case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING: return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKMultiOperation"); - case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: return GetSingleton().GetOperationIr("TopkToppSamplingBatchTopKLogProbsMultiOperation"); - case atbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING: + case AtbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING: return GetSingleton().GetOperationIr("TopkToppSamplingSingleTopKOperation"); default: ATB_LOG(ERROR) << "UnSupported TopkToppSamplingType: " << param.topkToppSamplingType; @@ -89,15 +89,15 @@ TopkToppSamplingOperation::~TopkToppSamplingOperation() {} uint32_t TopkToppSamplingOperation::GetInputNum() const { switch (param_.topkToppSamplingType) { - case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING: return BATCH_TOPK_EXP_IN_TENSOR_NUM; - case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING: return BATCH_TOPK_MULTI_IN_TENSOR_NUM; - case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: + case
AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: return BATCH_TOPK_EXP_IN_TENSOR_NUM; - case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: return BATCH_TOPK_MULTI_LOGPROBS_IN_TENSOR_NUM; - case atbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING: + case AtbInferTopkToppSamplingType::SINGLE_TOPK_SAMPLING: return SINGLE_TOPK_IN_TENSOR_NUM; default: ATB_LOG(ERROR) << GetLogPrefix() << "UnSupported TopkToppSamplingType: " << param_.topkToppSamplingType; @@ -108,9 +108,9 @@ uint32_t TopkToppSamplingOperation::GetInputNum() const uint32_t TopkToppSamplingOperation::GetOutputNum() const { switch (param_.topkToppSamplingType) { - case atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING: return LOG_PROBS_OUT_TENSOR_NUM; - case atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: + case AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING: return LOG_PROBS_OUT_TENSOR_NUM; default: return OUT_TENSOR_NUM; @@ -128,8 +128,8 @@ Status TopkToppSamplingOperation::InferShapeImpl(const SVector &inTe outTensorDescs.at(1) = inTensorDescs.at(0); outTensorDescs.at(1).shape.dims[dimNum - 1] = 1; - if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || - param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { + if (param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || + param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { outTensorDescs.at(OUT_TENSOR_LOGPROBS) = inTensorDescs.at(0); outTensorDescs.at(OUT_TENSOR_LOGPROBS).dtype = ACL_FLOAT; outTensorDescs.at(OUT_TENSOR_LOGPROBS).shape.dims[dimNum - 1] = param_.logProbsSize; @@ -250,15 +250,15 @@ Status TopkToppSamplingOperation::CheckIntensorAndParam(const SVector &inTensor return ERROR_INVALID_TENSOR_DIM; } } - if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || - param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { + if (param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || + param_.topkToppSamplingType == AtbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { Status logProbsOutTensorCheckRes = TopkToppLogProbsOutTensorCheck(outTensorDescs); if (logProbsOutTensorCheckRes != NO_ERROR) { return logProbsOutTensorCheckRes; diff --git a/src/torch_atb/resource/utils.h b/src/torch_atb/resource/utils.h index 3be3a691..ef23ef83 100644 --- a/src/torch_atb/resource/utils.h +++ b/src/torch_atb/resource/utils.h @@ -26,4 +26,4 @@ aclrtStream GetCurrentStream(); } // namespace Utils } // namespace TorchAtb -#endif // TORCH_ATB_UTILS_H \ No newline at end of file +#endif // TORCH_ATB_UTILS_H -- Gitee From 4a2fada96566ce6629e8769729124ab91cb1d510 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Tue, 15 Jul 2025 14:14:41 +0800 Subject: [PATCH 2/2] style delete redundant code --- docs/conf.py | 3 ++- src/ops_infer/ring_mla/ring_mla_ops_runner.cpp | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1305c2b0..6ec0a9ca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -18,7 +18,8 @@ import subprocess
-branch = subprocess.check_output(["/bin/bash", "-c", "git symbolic-ref -q --short HEAD || git describe --tags --exact-match 2> /dev/null || git rev-parse HEAD"]).strip().decode() +branch = subprocess.check_output(["/bin/bash", "-c", "git symbolic-ref -q --short HEAD || git describe --tags \ + --exact-match 2> /dev/null || git rev-parse HEAD"]).strip().decode() project = 'Ascend Transformer Boost Guidebook' author = 'Ascend' copyright = '2024, Ascend. This work is licensed under a Creative Commons Attribution 4.0 International License' diff --git a/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp b/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp index 61b4076e..94a8ae40 100644 --- a/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp +++ b/src/ops_infer/ring_mla/ring_mla_ops_runner.cpp @@ -63,8 +63,6 @@ RingMLAOpsRunner::RingMLAOpsRunner(const infer::RingMLAParam ¶m) ringMLANode.opDesc = {0, "RINGMLAOperation", ringMLAParam}; - // flashAttentionEncoderNode.inTensors = {&querySplit1, querySplit2, &keySplit1, keySplit2, value, - // mask, slopes, qkDescale, qkOffset, vpvDescale, vpvOffset, pScale, logN, prevOut, prevLse}; ringMLANode.inTensors = {querySplit1, querySplit2, keySplit1, keySplit2, value, mask, &nullTensor_, &nullTensor_, &nullTensor_, &nullTensor_, &nullTensor_, &nullTensor_, &nullTensor_, prevOut, prevLse}; -- Gitee