diff --git a/src/atb/kernel_cache/aclnn_executor_cache.cpp b/src/atb/kernel_cache/aclnn_executor_cache.cpp
index 8bacdec9ec3f0ee88f8737323bbd21ca6fc5cff2..bf6bf2a50a1bcc87ba83cfbaf9033f2050d9e82a 100644
--- a/src/atb/kernel_cache/aclnn_executor_cache.cpp
+++ b/src/atb/kernel_cache/aclnn_executor_cache.cpp
@@ -81,10 +81,11 @@ Status AclnnExecutorCache::AddCacheSlot(const std::string &opNameStr, const Runn
     if (slotVecSize < cacheCapacity_) {
         slotVec.emplace_back(aclnnCacheKey, inAclnnCacheSlot);
         ATB_LOG(INFO) << "ATB aclnn executor cache add op: " << opNameStr << " at index[" << slotVecSize << "]";
+        return NO_ERROR;
     }
     // Eviction policy: fixed-length vector with FIFO replacement
-    ATB_LOG(INFO) << "ATB aclnn executor cache full for op: " << opNameStr << "update index [" << nextUpdateIndex_
+    ATB_LOG(INFO) << "ATB aclnn executor cache full for op: " << opNameStr << ", update index [" << nextUpdateIndex_
                   << "]";
     cachePool_[opNameStr][nextUpdateIndex_] = std::make_pair(aclnnCacheKey, inAclnnCacheSlot);
     nextUpdateIndex_ = (nextUpdateIndex_ + 1) % cacheCapacity_;
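The added `return NO_ERROR;` is the functional fix in this hunk: without it, a cache that still had free slots would append the new entry and then fall through and overwrite the slot at `nextUpdateIndex_` as well. A minimal, self-contained sketch of the same fixed-capacity FIFO scheme (the `FifoSlotCache` name and key/value types are illustrative; ATB's real cache keys executors by op name plus tensor layout):

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Sketch of a fixed-capacity slot cache with FIFO replacement, assuming a
// generic key/value pair; not ATB's actual implementation.
template <typename Key, typename Value>
class FifoSlotCache {
public:
    explicit FifoSlotCache(size_t capacity) : capacity_(capacity) {}

    void Add(const Key &key, const Value &value)
    {
        if (slots_.size() < capacity_) {  // still room: append
            slots_.emplace_back(key, value);
            return;  // the early return the original hunk adds
        }
        slots_[nextUpdateIndex_] = {key, value};                // full: overwrite the oldest slot
        nextUpdateIndex_ = (nextUpdateIndex_ + 1) % capacity_;  // advance the FIFO cursor
    }

    const Value *Find(const Key &key) const
    {
        for (const auto &slot : slots_) {
            if (slot.first == key) {
                return &slot.second;
            }
        }
        return nullptr;
    }

private:
    size_t capacity_;
    size_t nextUpdateIndex_ = 0;
    std::vector<std::pair<Key, Value>> slots_;
};
```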
diff --git a/src/atb/runner/aclnn_runner.cpp b/src/atb/runner/aclnn_runner.cpp
index 10c9e64e34b0fa7dfbd49694888c846629cae09b..268e8f19fe60c2b7e0d79ddf149d6fba3b8fd696 100644
--- a/src/atb/runner/aclnn_runner.cpp
+++ b/src/atb/runner/aclnn_runner.cpp
@@ -58,9 +58,9 @@ Status AclnnRunner::SetupImpl(RunnerVariantPack &runnerVariantPack)
         ATB_LOG(ERROR) << GetLogPrefix() << "Atb aclnn op set workspace failed with return value: " << aclnnRet;
         return ERROR_CANN_ERROR;
     }
-    ATB_LOG(INFO) << GetLogPrefix()
-                  << "getWorkspace success, workspaceSize: " << this->atbVariantPack_.workspaceBufferSize
-                  << ", workspace addr: " << this->atbVariantPack_.workspaceBuffer;
+    ATB_LOG(INFO) << GetLogPrefix() << "getWorkspace success, workspace addr: "
+                  << reinterpret_cast<void *>(this->atbVariantPack_.workspaceBuffer)
+                  << ", workspaceSize: " << this->atbVariantPack_.workspaceBufferSize;
     aclnnRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get());
     if (aclnnRet != 0) {
         // Failed to mark the operator as repeatable; flag the cached executor as non-reusable
@@ -90,6 +90,10 @@ Status AclnnRunner::PreExecuteImpl(RunnerVariantPack &runnerVariantPack)
 {
     ATB_LOG(INFO) << GetLogPrefix() << "AclNNOpCacheUpdateAclNNVariantPack";
     for (size_t i = 0; i < this->aclnnVariantPack_.aclInTensors.size(); ++i) {
+        // Some aclnn APIs pad trailing optional inputs with empty placeholder tensors that runnerVariantPack does not store; skip them
+        if (i >= runnerVariantPack.inTensors.size()) {
+            break;
+        }
         int ret = -1;
         if (!this->aclnnVariantPack_.aclInTensors[i]->needUpdateTensorDataPtr) {
             continue;
@@ -113,6 +117,9 @@
     }
 
     for (size_t i = 0; i < this->aclnnVariantPack_.aclOutTensors.size(); ++i) {
+        if (i >= runnerVariantPack.outTensors.size()) {
+            break;
+        }
         int ret = -1;
         if (!this->aclnnVariantPack_.aclOutTensors[i]->needUpdateTensorDataPtr) {
             continue;
@@ -148,6 +155,7 @@ void AclnnRunner::UpdateWorkspace(const RunnerVariantPack &runnerVariantPack)
 
 Status AclnnRunner::ExecuteImpl(RunnerVariantPack &runnerVariantPack)
 {
+    ATB_LOG(INFO) << GetLogPrefix() << "AclnnRunner::ExecuteImpl";
     UpdateWorkspace(runnerVariantPack);
     return LaunchAclnnKernel();
 }
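The two `i >= ...size()` guards added to `PreExecuteImpl` make the pointer-update loops walk only the tensors both packs actually share: the aclnn-side pack can be longer because trailing optional arguments are padded with empty placeholder tensors (see the MlaPreprocess runner below). A sketch of the same zip-to-the-shorter-pack pattern, with stand-in types that are purely illustrative:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical tensor views for illustration only.
struct AclSideTensor { void *deviceData = nullptr; };
struct AtbSideTensor { void *deviceData = nullptr; };

// Update device pointers for the first min(n, m) tensors; indices past
// atbTensors.size() keep their empty placeholders untouched.
void UpdateDevicePtrs(std::vector<AclSideTensor> &aclTensors, const std::vector<AtbSideTensor> &atbTensors)
{
    const size_t n = std::min(aclTensors.size(), atbTensors.size());
    for (size_t i = 0; i < n; ++i) {
        aclTensors[i].deviceData = atbTensors[i].deviceData;
    }
}
```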
diff --git a/src/atb/utils/aclnn_util.cpp b/src/atb/utils/aclnn_util.cpp
index dcbd1ac7d42609b5acc27ba887dd6c988e7f9d23..f4ef020450c59d11fb27a3c81b8928b740a1cd2b 100644
--- a/src/atb/utils/aclnn_util.cpp
+++ b/src/atb/utils/aclnn_util.cpp
@@ -21,12 +21,11 @@ const int DIM0 = 0;
 const int DIM1 = 1;
 const int DIM2 = 2;
 const int DIM3 = 3;
-} // namespace
+}  // namespace
 
 namespace atb {
 
-template <typename T, typename U>
-typename std::common_type<T, U>::type CheckIntMulOverFlow(const T a, const U b)
+template <typename T, typename U> typename std::common_type<T, U>::type CheckIntMulOverFlow(const T a, const U b)
 {
     if (std::is_signed<T>::value != std::is_signed<U>::value) {
         throw std::runtime_error("Multiplication between signed and unsigned integer not supported, it's not safe");
@@ -69,7 +68,7 @@ typename std::common_type<T, U>::type CheckIntMulOverFlow(const T a, const U b)
 atb::SVector<int64_t> GetCopyTensorStride(atb::Dims &tensorDims)
 {
     atb::SVector<int64_t> tmpStrides(tensorDims.dimNum, 1);
-    if (tensorDims.dimNum > 8) { // 8: maximum number of tensor dimensions
+    if (tensorDims.dimNum > 8) {  // 8: maximum number of tensor dimensions
         ATB_LOG(ERROR) << "Tensor's dimNum is larger than 8, `GetCopyTensorStride` failed.";
         return tmpStrides;
     }
@@ -83,26 +82,21 @@ atb::SVector<int64_t> GetTransposeTensorStride(atb::Dims &tensorDims)
 {
     atb::SVector<int64_t> tmptransposeStrides(tensorDims.dimNum, 1);
     tmptransposeStrides[tensorDims.dimNum - 1] = tensorDims.dims[tensorDims.dimNum - 1];
-    if (tensorDims.dimNum == 3) { // 3: dimension count
-        tmptransposeStrides[0] = CheckIntMulOverFlow( // 0: dim 0
-            tensorDims.dims[1],
-            tensorDims.dims[2]); // 1, 2: sizes of dim 1 and dim 2
+    if (tensorDims.dimNum == 3) {  // 3: dimension count
+        tmptransposeStrides[0] = CheckIntMulOverFlow(tensorDims.dims[1], tensorDims.dims[2]);  // 1, 2: sizes of dim 1 and dim 2
     }
     return tmptransposeStrides;
 }
 
-atb::Status CallAclCreateTensor(
-    atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor, std::shared_ptr<AclNNTensor> aclnnTensor)
+atb::Status CallAclCreateTensor(atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor,
+                                std::shared_ptr<AclNNTensor> aclnnTensor, aclDataType dataType)
 {
-    aclnnTensor->tensor = aclCreateTensor(viewDims.dims,
-                                          viewDims.dimNum,
-                                          atbTensor.desc.dtype,
-                                          aclnnTensor->strides.data(),
-                                          0,
-                                          atbTensor.desc.format,
-                                          storageDims.dims,
-                                          storageDims.dimNum,
-                                          atbTensor.deviceData);
+    if (dataType == ACL_DT_UNDEFINED) {
+        dataType = atbTensor.desc.dtype;
+    }
+    aclnnTensor->tensor =
+        aclCreateTensor(viewDims.dims, viewDims.dimNum, dataType, aclnnTensor->strides.data(), 0, atbTensor.desc.format,
+                        storageDims.dims, storageDims.dimNum, atbTensor.deviceData);
     if (aclnnTensor->tensor == nullptr) {
         return atb::ERROR_INTERNAL_ERROR;
     }
@@ -129,7 +123,7 @@ std::string PrintAclNNVariankPack(const AclNNVariantPack &aclnnVariantPack)
         ss << "index " << i << " dtype " << tensorDesc.dtype << " format " << tensorDesc.format << " dimNum "
            << tensorDesc.shape.dimNum;
         for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast<uint64_t>(8));
-             j++) { // 8: maximum number of tensor dimensions
+             j++) {  // 8: maximum number of tensor dimensions
             ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " ";
         }
     }
@@ -145,7 +139,7 @@ std::string PrintATBVariankPack(const atb::VariantPack &atbVariantPack)
         ss << "index " << i << " dtype " << tensorDesc.dtype << " format " << tensorDesc.format << " dimNum "
            << tensorDesc.shape.dimNum;
         for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast<uint64_t>(8));
-             j++) { // 8: maximum number of tensor dimensions
+             j++) {  // 8: maximum number of tensor dimensions
             ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " ";
         }
     }
@@ -165,7 +159,7 @@ bool IsHostDataEqual(const std::shared_ptr<AclNNTensor> tensorA, const atb::Tensor &tensorB, int tensorIdx)
         return false;
     }
     if (tensorA->intArrayHostData.intArray != nullptr && tensorB.hostData != nullptr) {
-        if (tensorA->intArrayHostData.dataOri.size() * 4 != tensorB.dataSize) { // 8: int64_t in bytes
+        if (tensorA->intArrayHostData.dataOri.size() * 4 != tensorB.dataSize) {  // 4: element size in bytes
             ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx << " dataSize not equal";
             return false;
         }
@@ -192,7 +186,7 @@ bool IsTensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc &tensorDescB, int tensorIdx)
         return false;
     }
     if (tensorDescA.shape.dimNum != tensorDescB.shape.dimNum || tensorDescA.shape.dimNum > 8 ||
-        tensorDescA.shape.dimNum <= 0) { // 8: maximum number of tensor dimensions
+        tensorDescA.shape.dimNum <= 0) {  // 8: maximum number of tensor dimensions
         ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx
                        << " dimNum not equal, aclnnVariantPack dimNum " << tensorDescA.shape.dimNum
                        << " atbVariantPack dimNum " << tensorDescB.shape.dimNum;
@@ -209,8 +203,8 @@ bool IsTensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc
     return true;
 }
 
-bool AreTensorVectorsEqual(
-    const atb::SVector<std::shared_ptr<AclNNTensor>> &aclnnTensors, const atb::SVector<atb::Tensor> &atbTensors)
+bool AreTensorVectorsEqual(const atb::SVector<std::shared_ptr<AclNNTensor>> &aclnnTensors,
+                           const atb::SVector<atb::Tensor> &atbTensors)
 {
     // Check the size of two vectors
     if (aclnnTensors.size() != atbTensors.size()) {
@@ -275,18 +269,18 @@ std::shared_ptr<AclNNTensor> CreateTensor(atb::Tensor atbTensor, int tensorIdx)
     aclnnTensor->atbTensor = atbTensor;
     aclnnTensor->tensorIdx = tensorIdx;
     aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape);
-    CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor);
+    CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor, atbTensor.desc.dtype);
     return aclnnTensor;
 }
 
 int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths)
 {
     static std::vector<int64_t> seqLenCache;
-    size_t dataSize = tensor.dataSize / 8; // 8: int64 size
+    size_t dataSize = tensor.dataSize / 8;  // 8: int64 size
     if (seqLenCache.size() < dataSize) {
         seqLenCache.resize(dataSize);
     }
-    if (memcpy_s(seqLenCache.data(), dataSize * 8, tensor.hostData, dataSize * 8) != 0) { // 8: int64 size
+    if (memcpy_s(seqLenCache.data(), dataSize * 8, tensor.hostData, dataSize * 8) != 0) {  // 8: int64 size
         ATB_LOG(ERROR) << "memcpy_s failed!";
         return atb::ERROR_INTERNAL_ERROR;
     }
@@ -297,4 +291,4 @@ int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths)
     actualSeqLengths = aclCreateIntArray(static_cast<int64_t *>(seqLenCache.data()), dataSize);
     return atb::NO_ERROR;
 }
-} // namespace atb
+}  // namespace atb
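`CheckIntMulOverFlow`, whose mixed-sign guard is visible in the first hunk above, is what keeps the stride computation in `GetTransposeTensorStride` from silently wrapping. A behaviorally similar sketch, not ATB's exact implementation, using the GCC/Clang `__builtin_mul_overflow` builtin:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <type_traits>

// Overflow-checked integer multiply; throws instead of wrapping.
template <typename T, typename U>
typename std::common_type<T, U>::type CheckedMul(const T a, const U b)
{
    static_assert(std::is_integral<T>::value && std::is_integral<U>::value, "integers only");
    if (std::is_signed<T>::value != std::is_signed<U>::value) {
        throw std::runtime_error("mixed signed/unsigned multiply is not safe");
    }
    typename std::common_type<T, U>::type result = 0;
    if (__builtin_mul_overflow(a, b, &result)) {  // returns true on overflow
        throw std::overflow_error("integer multiply overflowed");
    }
    return result;
}

int main()
{
    // Stride of dim 0 for a transposed [b, s, h] view, as in GetTransposeTensorStride.
    int64_t dims[3] = {4, 1024, 512};
    std::cout << CheckedMul(dims[1], dims[2]) << "\n";  // 524288
    return 0;
}
```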
diff --git a/src/atb/utils/aclnn_util.h b/src/atb/utils/aclnn_util.h
index 9806541ee379cb87d9fbb869ac21184c31bd74d2..fbcd3f0aa5ebaf9a713e8dd7d12196c308be3b76 100644
--- a/src/atb/utils/aclnn_util.h
+++ b/src/atb/utils/aclnn_util.h
@@ -42,9 +42,10 @@ atb::SVector<int64_t> GetTransposeTensorStride(atb::Dims &tensorDims);
 /// \param atbTensor The tensor passed through ATB framework.
 /// \param aclnnTensor A pointer to an `AclNNTensor` object whose `tensor` attribute is updated
 /// using the return value of `aclCreateTensor`.
+/// \param dataType The data type of the created tensor; `ACL_DT_UNDEFINED` (the default) falls back to `atbTensor`'s dtype. Lets empty placeholder tensors specify a concrete dtype.
 /// \return A status code that indicates whether `aclTensor` has been created.
 atb::Status CallAclCreateTensor(atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor,
-                                std::shared_ptr<AclNNTensor> aclnnTensor);
+                                std::shared_ptr<AclNNTensor> aclnnTensor, aclDataType dataType = ACL_DT_UNDEFINED);
 
 /// Reshape a tensor by squeezing batch size axis and seq len axis if the tensor's shape has two dimensions.
 ///
@@ -72,7 +73,8 @@ bool IsAclnnAtbVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const atb::VariantPack &atbVariantPack);
 /// \param atbVariantPack An `atb::VariantPack` object containing tensor info passed through ATB framework.
 /// \return A boolean value that indicates whether `aclnnVariantPack` and `atbVariantPack` are the same,
 /// except for tensors' device data.
-bool IsAclnnRunnerVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const RunnerVariantPack &runnerVariantPack);
+bool IsAclnnRunnerVariankPackEqual(const AclNNVariantPack &aclnnVariantPack,
+                                   const RunnerVariantPack &runnerVariantPack);
 
 /// Create a pointer to `AclNNTensor` by configuring it with tensor information extracted from `atbTensor`.
 ///
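The new `dataType` parameter is deliberately defaulted so that every existing call site stays source-compatible, while callers that build empty placeholder tensors can override the (possibly meaningless) dtype recorded in the ATB descriptor. A self-contained sketch of this sentinel-default pattern; `ResolveDtype` and the enum are illustrative, with the values 1 and 27 mirroring `ACL_FLOAT16`/`ACL_BF16`:

```cpp
#include <iostream>

// Illustrative stand-ins; real code uses aclDataType and ACL_DT_UNDEFINED.
enum DataType { DT_UNDEFINED = -1, DT_FLOAT16 = 1, DT_BF16 = 27 };

struct TensorDesc {
    DataType dtype = DT_UNDEFINED;
};

// Sentinel default: callers that pass nothing get the descriptor's dtype;
// callers repairing an empty placeholder pass an explicit override.
DataType ResolveDtype(const TensorDesc &desc, DataType override_ = DT_UNDEFINED)
{
    return override_ == DT_UNDEFINED ? desc.dtype : override_;
}

int main()
{
    TensorDesc real{DT_FLOAT16};
    TensorDesc placeholder{};  // empty tensor, dtype undefined
    std::cout << ResolveDtype(real) << "\n";                  // 1: taken from the descriptor
    std::cout << ResolveDtype(placeholder, DT_BF16) << "\n";  // 27: borrowed from a real input
    return 0;
}
```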
diff --git a/src/ops/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp b/src/ops/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp
index 4c5a102ecab9cd14e073add8b9def7beb4ef0abc..be6f1124c66fc3dfa21fe312f8960581903db173 100644
--- a/src/ops/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp
+++ b/src/ops/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp
@@ -16,9 +16,13 @@ namespace {
 static const uint32_t IN_TENSOR_NUM = 24;
 static const uint32_t OUT_TENSOR_NUM = 4;
+static const uint32_t INPUT_INDEX = 0;
 static const uint32_t Q_OUT1_INDEX = 2;
 static const uint32_t KV_CACHE_OUT1_INDEX = 3;
 static const uint32_t KV_CACHE_ROPE_INDEX = 20;
+static const uint32_t CTKV_SCALE_INDEX = 22;
+static const uint32_t Q_NOPE_SCALE_INDEX = 23;
+
 } // namespace
 
 namespace atb {
 
@@ -51,22 +55,35 @@ MlaPreprocessAclnnRunner::~MlaPreprocessAclnnRunner() {}
 Status MlaPreprocessAclnnRunner::BuildAclnnVariantPack(const RunnerVariantPack &runnerVariantPack)
 {
     ATB_LOG(INFO) << GetLogPrefix() << "BuildAclnnVariantPack";
+    ATB_LOG(INFO) << GetLogPrefix() << "variantPack: " << runnerVariantPack.ToString();
     this->atbVariantPack_ = runnerVariantPack;
     Status ret = NO_ERROR;
     bool isRopeCache = param_.cacheMode != infer::MlaPreprocessParam::CacheMode::KVCACHE;
     this->aclnnVariantPack_.aclInTensors.reserve(IN_TENSOR_NUM);
     this->aclnnVariantPack_.aclInTensors.resize(IN_TENSOR_NUM);
+    aclDataType inputDataType = runnerVariantPack.inTensors.at(INPUT_INDEX).desc.dtype;
     for (size_t i = 0; i < this->aclnnVariantPack_.aclInTensors.size(); ++i) {
+        ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::BuildAclnnVariantPack inTensor index: " << i;
         std::shared_ptr<AclNNTensor> aclnnTensorPtr = std::make_shared<AclNNTensor>();
+        atb::Tensor atbTensor = runnerVariantPack.inTensors.at(i);
         if (i == KV_CACHE_ROPE_INDEX && !isRopeCache) {
-            // When kvCache carries no rope transpose, kvCacheRope is nullptr
-            this->aclnnVariantPack_.aclInTensors[i] = aclnnTensorPtr;
-            continue;
+            // When kvCache carries no rope transpose, kvCacheRope is an empty tensor
+            TensorDesc desc = {};
+            desc.dtype = runnerVariantPack.inTensors.at(INPUT_INDEX).desc.dtype;
+            desc.format = runnerVariantPack.inTensors.at(INPUT_INDEX).desc.format;
+            atbTensor.desc = desc;
         }
-        atb::Tensor atbTensor = runnerVariantPack.inTensors.at(i);
         aclnnTensorPtr->atbTensor = atbTensor;
         aclnnTensorPtr->strides = GetCopyTensorStride(atbTensor.desc.shape);
-        ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr);
+        if (param_.cacheMode != infer::MlaPreprocessParam::CacheMode::INT8_NZCACHE &&
+            (i == CTKV_SCALE_INDEX || i == Q_NOPE_SCALE_INDEX)) {
+            // Outside the KROPE_CTKV scenario ctkvScale and qNopeScale are not passed in; use inputDataType to get the empty tensors past the dtype check
+            ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr,
+                                      inputDataType);
+        } else {
+            ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr,
+                                      atbTensor.desc.dtype);
+        }
         if (ret != NO_ERROR) {
             ATB_LOG(ERROR) << GetLogPrefix() << "create aclTensor by aclCreateTensor failed!";
             return ret;
@@ -80,15 +97,22 @@ Status MlaPreprocessAclnnRunner::BuildAclnnVariantPack(const RunnerVariantPack &
     this->aclnnVariantPack_.aclOutTensors.resize(OUT_TENSOR_NUM);
     for (size_t i = 0; i < this->aclnnVariantPack_.aclOutTensors.size(); ++i) {
         std::shared_ptr<AclNNTensor> aclnnTensorPtr = std::make_shared<AclNNTensor>();
-        if ((i == Q_OUT1_INDEX || i == KV_CACHE_OUT1_INDEX) && !isRopeCache) {
+        ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::BuildAclnnVariantPack outTensor index: " << i;
+        atb::Tensor atbTensor = {};
+        if ((i != Q_OUT1_INDEX && i != KV_CACHE_OUT1_INDEX) || isRopeCache) {
+            atbTensor = runnerVariantPack.outTensors.at(i);
+        } else {
             // The two rope components are not produced when kvCache carries no rope transpose
-            this->aclnnVariantPack_.aclOutTensors[i] = aclnnTensorPtr;
-            continue;
+            // Fill the empty tensor with the input's dtype and format
+            TensorDesc desc = {};
+            desc.dtype = runnerVariantPack.inTensors.at(INPUT_INDEX).desc.dtype;
+            desc.format = runnerVariantPack.inTensors.at(INPUT_INDEX).desc.format;
+            atbTensor.desc = desc;
         }
-        atb::Tensor atbTensor = runnerVariantPack.outTensors.at(i);
         aclnnTensorPtr->atbTensor = atbTensor;
         aclnnTensorPtr->strides = GetCopyTensorStride(atbTensor.desc.shape);
-        ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr);
+        ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr,
+                                  atbTensor.desc.dtype);
         if (ret != NO_ERROR) {
             ATB_LOG(ERROR) << GetLogPrefix() << "create aclTensor by aclCreateTensor failed!";
             return ret;
@@ -135,9 +159,6 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor()
     aclTensor *wuk = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
     aclTensor *kvCache = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
     aclTensor *kRope = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
-    if (!isRopeCache) {
-        kRope = nullptr;
-    }
     aclTensor *slotmapping = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
     aclTensor *ctkvScale = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
     aclTensor *qNopeScale = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor;
@@ -146,10 +167,8 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor()
     aclTensor *kvCacheOut0 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor;
     aclTensor *qOut1 = nullptr;
     aclTensor *kvCacheOut1 = nullptr;
-    if (isRopeCache) {
-        qOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor;
-        kvCacheOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor;
-    }
+    qOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor;
+    kvCacheOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor;
 
     aclOpExecutor *raw_executor_ptr = this->aclnnExecutor_.get();
     ATB_LOG(INFO) << GetLogPrefix() << "&(this->aclnnExecutor_): " << &(this->aclnnExecutor_)
@@ -178,7 +197,7 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor()
 
 Status MlaPreprocessAclnnRunner::LaunchAclnnKernel()
 {
-    ATB_LOG(INFO) << GetLogPrefix() << " execute start.";
+    ATB_LOG(INFO) << GetLogPrefix() << "LaunchAclnnKernel execute start.";
     Status status = MlaPreprocessAclnnRunner::LoadMethod();
     if (status != NO_ERROR) {
         ATB_LOG(ERROR) << GetLogPrefix()
@@ -193,7 +212,7 @@ Status MlaPreprocessAclnnRunner::LaunchAclnnKernel()
         ATB_LOG(ERROR) << GetLogPrefix() << "Atb aclnn op kernel launch failed with return value: " << ret;
         return ERROR_CANN_ERROR;
     }
-    ATB_LOG(INFO) << GetLogPrefix() << " execute success.";
+    ATB_LOG(INFO) << GetLogPrefix() << "LaunchAclnnKernel execute success.";
     return NO_ERROR;
 }
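With this change the runner materializes trailing optional inputs and outputs as empty `aclTensor`s that borrow a real input's dtype and format, instead of passing `nullptr` and branching on `isRopeCache` at executor-build time, so the repeatable executor always sees a stable argument list. A hedged sketch of building such a placeholder with the CANN `aclCreateTensor` API; the header paths and the one-dim zero-size shape are assumptions, and ATB's actual empty-tensor descriptor may differ:

```cpp
#include <acl/acl_base.h>      // aclDataType, aclFormat (path may vary by CANN version)
#include <aclnn/acl_meta.h>    // aclTensor, aclCreateTensor (path may vary by CANN version)

// Build an empty placeholder for an optional argument: zero elements and no
// device data, but a dtype/format borrowed from a real input so the aclnn
// dtype check accepts it.
aclTensor *CreateEmptyPlaceholder(aclDataType borrowedDtype, aclFormat borrowedFormat)
{
    static const int64_t dims[1] = {0};     // one dimension of size 0: zero-element view
    static const int64_t strides[1] = {1};  // trivial stride for the empty view
    return aclCreateTensor(dims, 1, borrowedDtype, strides, 0, borrowedFormat, dims, 1, nullptr);
}
```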
diff --git a/src/ops/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
index 84020070494783142e6cfcc945b4635cadc8dbef..d70d2bd9a217b857f46b1dc155f08087878c61e2 100644
--- a/src/ops/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
+++ b/src/ops/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
@@ -428,6 +428,12 @@ Status MlaPreprocessOperation::CheckAclnnKernel(const SVector<TensorDesc> &inTen
         return NO_ERROR;
     }
     useAclnnKernel_ = true;
+    if (param_.quantMode != infer::MlaPreprocessParam::QuantMode::PER_TENSOR_QUANT_ASYMM) {
+        ATB_LOG(INFO) << GetLogPrefix()
+                      << "aclnn mlaPreprocess only supports quantMode PER_TENSOR_QUANT_ASYMM, but got: "
+                      << param_.quantMode;
+        return ERROR_INVALID_PARAM;
+    }
     Status ret = MlaPreprocessAclnnRunner::LoadMethod();
     ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::LoadMethod() ret: " << ret;
     if (ret != NO_ERROR) {
@@ -444,7 +450,8 @@ Status MlaPreprocessOperation::CheckAclnnKernel(const SVector<TensorDesc> &inTen
             return ret;
         }
     }
-    ATB_LOG(INFO) << GetLogPrefix() << "aclnn kernel is required and usable, generalizedHiddenSize: " << generalizedHiddenSize
+    ATB_LOG(INFO) << GetLogPrefix()
+                  << "aclnn kernel is required and usable, generalizedHiddenSize: " << generalizedHiddenSize
                   << ", doRmsNorm: " << doRmsNorm_;
     return NO_ERROR;
 }
diff --git a/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py
index 75aab78edd4ce7e42662cf3575efd40a89eb0b76..f0edd19d6867be40c47f98f81bff06cf7f7ee2c5 100644
--- a/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py
+++ b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py
@@ -205,6 +205,9 @@ class TestMLAPrepross(operation_test.OperationTest):
         self.headNum = headNum
         self.epsilon = 1e-6
         self.dtype = data_type
+        self.out_data_type = 1  # ACL_FLOAT16
+        if data_type == torch.bfloat16:
+            self.out_data_type = 27  # ACL_BF16
 
         self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, hiddenStrate))).to(data_type)#
         self.gamma1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(hiddenStrate))).to(data_type)
@@ -267,7 +270,7 @@ class TestMLAPrepross(operation_test.OperationTest):
         out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors]
 
         op_name = "LinearOperation"
-        op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 1})
+        op_param = json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": self.out_data_type})
         self.operation = torch.classes.OperationTorch.OperationTorch(op_name)
         self.op_param = op_param
         self.operation.set_param(op_param)
@@ -319,261 +322,7 @@ class TestMLAPrepross(operation_test.OperationTest):
         mm2out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors]
 
         op_name = "LinearOperation"
-        op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 1})
-        self.operation = torch.classes.OperationTorch.OperationTorch(op_name)
-        self.op_param = op_param
-        self.operation.set_param(op_param)
-        self.operation.execute_out(in_tensors_npu, mm2out_tensors_npu)
-        self.mm2Out_npu = mm2out_tensors_npu[0].cpu().clone()
-        print(self.mm2Out_npu.size())
-
-        # # ##SplitV
-        print("====================SplitV1====================")
-        splitSize = [128, 64]
-        splitVDim = 2
-        in_tensors = mm2out_tensors_npu[0].reshape(N, headNum, 192)
-        in_tensors_npu =
in_tensors.npu() - Split2_out_tensors_npu = [] - shape = in_tensors_npu.shape - in_tensors_npu = [in_tensors_npu] - for size in splitSize: - slice_shape = list(shape) - slice_shape[splitVDim] = size - Split2_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) - op_name = "SplitOperation" - op_param =json.dumps({"splitNum": 2, "splitDim": splitVDim, "splitSizes": splitSize}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, Split2_out_tensors_npu) - print(Split2_out_tensors_npu[0].size()) - - #EinSum - print("====================EinSum====================") - # EinSumInput = torch.transpose(Split2_out_tensors_npu[0], 0, 1) - self.trans_A, self.trans_B = False, False - bsize, msize, ksize, nsize = headNum, N, 128, 512 - in_tensors = [Split2_out_tensors_npu[0], self.wuk] - Einsumout_tensors = [torch.zeros((N, headNum, 512), dtype=data_type)] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - Einsumout_tensors_npu = [tensor.npu() for tensor in Einsumout_tensors] - op_name = "LinearOperation" - op_param =json.dumps({"matmulType": 1, "transposeA": False, "transposeB": False, "hasBias": False, "outDataType": -1}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, Einsumout_tensors_npu) - - # Einsumout_tensors_npu = [torch.transpose(Einsumout_tensors_npu[0], 0, 1)] - self.einsumOut = Einsumout_tensors_npu[0].cpu().clone() - print(self.einsumOut.size()) - # RopeQConcat - print("====================RopeQConcat====================") - self.qOut_npu = torch.zeros((N, headNum, 576), dtype=data_type) - Split2_out_tensors_npu[1] = Split2_out_tensors_npu[1].cpu().reshape(N, headNum * 64) - in_tensors = [Split2_out_tensors_npu[1], self.cos2, self.sin2, Einsumout_tensors_npu[0]] - qOut_tensors = [self.qOut_npu] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - qout_tensors_npu = [qOut_tensors[i] if isinstance(i, int) else i.npu() - for i in qOut_tensors] - op_name = "RopeQConcatOperation" - op_param =json.dumps({}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, qout_tensors_npu) - qout_tensors_npu[0] = torch.cat([qout_tensors_npu[0].cpu()[:,:,64:],qout_tensors_npu[0].cpu()[:,:,:64]], dim = 2) - self.qOut_npu = qout_tensors_npu[0].cpu().clone() - print(self.qOut_npu.size()) - # #Rope - print("====================Rope====================") - rotaryCoeff = 2 - self.RopeOut0 = torch.zeros((N, 64), dtype=data_type) - self.RopeOut1 = torch.zeros((N, 64), dtype=data_type) - seqlen = torch.randint(1, 2, (N,), dtype=torch.int32) - in_tensors = [Split1_out_tensors_npu[1].reshape(N, 1 * 64), Split1_out_tensors_npu[1].reshape(N, 1 * 64), self.cos1, self.sin1, seqlen] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - Ropeout_tensors_npu = [self.RopeOut0.npu(), self.RopeOut1.npu()] - - op_name = "RopeOperation" - op_param =json.dumps({"rotaryCoeff": rotaryCoeff}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, Ropeout_tensors_npu) - print(Ropeout_tensors_npu[0].size()) - - #RmsNorm - print("====================RmsNorm2====================") - 
in_tensors = [Split1_out_tensors_npu[0].reshape(N, 1, 512), self.gamma3] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - self.RmsNormOut = torch.zeros((N, 1, 512), dtype=data_type).npu() - RmsNormOut_tensors_npu = [self.RmsNormOut] - - op_name = "RmsNormOperation" - op_param =json.dumps({"layerType":1, "normParam":{"epsilon": 1e-6}}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, RmsNormOut_tensors_npu) - - out_tensors_npu = RmsNormOut_tensors_npu - #Concat - print("====================Concat====================") - in_tensors = [RmsNormOut_tensors_npu[0], Ropeout_tensors_npu[0].cpu().reshape(N, 1, 64)] - ConCat2out_tensors = [torch.zeros((N, 1, 576), dtype=data_type)] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - ConCat2out_tensors_npu = [tensor.npu() for tensor in ConCat2out_tensors] - - op_name = "ConcatOperation" - op_param =json.dumps({"concatDim": 2}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, ConCat2out_tensors_npu) - - #Reshape&Cache - print("====================Reshape&Cache====================") - in_tensors = [ConCat2out_tensors_npu[0], self.keyOutTensor, self.slotMapping] - out_tensors = [1] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - out_tensors_npu = [in_tensors_npu[i] if isinstance(i, int) else i.npu() - for i in out_tensors] - - op_name = "ReshapeAndCacheOperation" - op_param =json.dumps({"kvCacheCfg": 1}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, out_tensors_npu) - self.keyout_npu = out_tensors_npu[0].cpu().clone() - - def calc_vec_mm_atb_data_bf16(self, N, headNum, data_type, hidden_size): - hiddenStrate = hidden_size - blockNum = 192 - blockSize = 128 - headdim = 576 - self.input_token_num = N - self.rms_hidden_size = 512 - self.rope_hidden_size = 64 - self.headNum = headNum - self.epsilon = 1e-6 - self.dtype = data_type - - self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, hiddenStrate))).to(data_type)# - self.gamma1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(hiddenStrate))).to(data_type) - self.quantScale1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) - self.quantOffset1 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) - self.wdqkv = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(2112, hiddenStrate))).to(torch.int8)# - - self.deScale1 = torch.rand((2112), dtype=torch.float32) / 1000 - self.gamma2 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(1536))).to(data_type) - self.quantScale2 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) - self.quantOffset2 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) - - self.wuq = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum * 192, 1536))).to(torch.int8)# - - self.deScale2 = torch.rand((headNum * 192), dtype=torch.float32) / 1000 - - self.gamma3 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(512))).to(data_type) - self.sin1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) - self.cos1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) - self.keyCache = 
torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(blockNum, blockSize, 1, headdim))).to(data_type) - - self.slotMapping = torch.from_numpy(np.random.choice(192 * 128, N, replace=False).astype(np.int32)).to(torch.int32) - - self.wuk = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum, 128, 512))).to(data_type)# - self.sin2 = self.sin1 - self.cos2 = self.cos1 - - self.bias1 = torch.from_numpy(np.random.randint(-10, 10, (1, 2112)).astype(np.int32)).to(torch.int32) - self.bias2 = torch.from_numpy(np.random.randint(-10, 10, (1, headNum * 192)).astype(np.int32)).to(torch.int32) - - self.beta1 = torch.from_numpy(np.random.randint(-2, 2, (hiddenStrate)).astype(np.float16)).to(data_type) - self.beta2 = torch.from_numpy(np.random.randint(-2, 2, (1536)).astype(np.float16)).to(data_type) - - self.calc_vec_mm_data(N, headNum, data_type) - - ## RmsNorm - print("====================RmsNorm0====================") - self.rms1Out1_npu = torch.zeros((N, hiddenStrate), dtype=torch.int8) - npu_device = self.__get_npu_device() - torch_npu.npu.set_device(npu_device) - in_tensors = [self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1] - out_tensors = [self.rms1Out1_npu] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() - for i in out_tensors] - op_name = "RmsNormOperation" - op_param =json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, out_tensors_npu) - self.rms1Out1_npu = out_tensors_npu[0].cpu().clone() - - ##Ppmatmul - print("====================Ppmatmul0====================") - self.mm1Out1_npu = torch.zeros((N, 2112), dtype=data_type) - in_tensors = [out_tensors_npu[0], self.wdqkv, self.bias1, self.deScale1] - out_tensors = [self.mm1Out1_npu] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() - for i in out_tensors] - op_name = "LinearOperation" - op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 27}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, out_tensors_npu) - self.mm1Out1_npu = out_tensors_npu[0].cpu().clone() - print(self.mm1Out1_npu.size()) - - ##SplitV - print("====================SplitV0====================") - splitSize = [512, 64, 1536] - splitVDim = 1 - in_tensors_npu = [out_tensors_npu[0]] - Split1_out_tensors_npu = [] - shape = in_tensors_npu[0].shape - for size in splitSize: - slice_shape = list(shape) - slice_shape[splitVDim] = size - Split1_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) - op_name = "SplitOperation" - op_param =json.dumps({"splitNum": 3, "splitDim": splitVDim, "splitSizes": splitSize}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, Split1_out_tensors_npu) - print(Split1_out_tensors_npu[0].size()) - ## RmsNorm - print("====================RmsNorm1====================") - self.rms2Out_npu = torch.zeros((N, 1536), dtype=torch.int8) - in_tensors = [Split1_out_tensors_npu[2], self.gamma2, self.beta2, self.quantScale2, self.quantOffset2] - out_tensors = 
[self.rms2Out_npu] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() - for i in out_tensors] - op_name = "RmsNormOperation" - op_param =json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) - self.operation = torch.classes.OperationTorch.OperationTorch(op_name) - self.op_param = op_param - self.operation.set_param(op_param) - self.operation.execute_out(in_tensors_npu, out_tensors_npu) - self.rms2Out_npu = out_tensors_npu[0].cpu().clone() - print(self.rms2Out_npu.size()) - - # ##Ppmatmul - print("====================Ppmatmul1====================") - self.mm2Out_npu = torch.zeros((N, headNum * 192), dtype=data_type) - in_tensors = [out_tensors_npu[0], self.wuq, self.bias2, self.deScale2] - out_tensors = [self.mm2Out_npu] - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - mm2out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() - for i in out_tensors] - op_name = "LinearOperation" - op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 27}) + op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": self.out_data_type}) self.operation = torch.classes.OperationTorch.OperationTorch(op_name) self.op_param = op_param self.operation.set_param(op_param) @@ -744,108 +493,6 @@ class TestMLAPrepross(operation_test.OperationTest): return [self.qOut_npu[..., 0:512], self.keyout_npu[..., 0:512], self.qOut_npu[..., 512:576], self.keyout_npu[..., 512:576]] else: return [self.qOut_npu, self.keyout_npu] - - def __test_mlapo_impl( - self, - data_type: torch.dtype, - num_tokens: int, - num_heads: int, - cache_mode: int, - quant_mode: int, - weight_format: int, - ) -> None: - self.calc_vec_mm_atb_data(num_tokens, num_heads, data_type, cache_mode, quant_mode) - self.set_param( - "MlaPreprocessOperation", - {"N": num_tokens, "headNum": num_heads, "cacheMode": cache_mode, "quantMode": quant_mode}, - ) - self.set_input_formats( - [ - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nz, # self.wdqkv, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nz, # self.wuq, - self.format_nd, - weight_format, # self.wuk - self.format_nd, - self.format_nd, - self.format_nd, - self.format_nd, - ] - ) - in_tensors = [ - self.input1, - self.gamma1, - self.beta1, - self.quantScale1, - self.quantOffset1, - transdata(self.wdqkv, (16, 32)), # self.wdqkv, - self.bias1, - self.gamma2, - self.beta2, - self.quantScale2, - self.quantOffset2, - self.gamma3, - self.sin1, - self.cos1, - self.sin2, - self.cos2, - self.keyCache, - self.slotMapping, - transdata(self.wuq, (16, 32)), # self.wuq, - self.bias2, - self.wuk if weight_format == self.format_nd else transdata_3d(self.wuk), - self.deScale1, - self.deScale2, - self.quantScale3, - self.qNopeScale, - ] - if cache_mode == 0: - out_tensors = [ - torch.zeros_like(self.qOut, dtype=data_type), - self.keyOutTensor, - torch.tensor([]), - torch.tensor([]), - ] - elif cache_mode == 1: - out_tensors = [ - torch.zeros_like(self.qOut[..., :512], dtype=data_type), - self.keyOutTensor[..., :512], - torch.zeros_like(self.qOut[..., 512:], dtype=data_type), - self.keyOutTensor[..., 512:], - ] - elif cache_mode == 2: - out_tensors = [ - torch.zeros_like(self.qOut[..., :512], 
dtype=torch.int8), - self.I8keyCache1, - torch.zeros_like(self.qOut[..., 512:], dtype=data_type), - self.keyCache2, - ] - else: - out_tensors = [ - torch.zeros_like(self.qOut[..., :512], dtype=data_type), - self.FPkeyCache1, - torch.zeros_like(self.qOut[..., 512:], dtype=data_type), - self.keyCache2, - ] - - self.execute(in_tensors, out_tensors) - return def compare_data(self, tensor1, tensor2): out = tensor1.flatten() @@ -870,11 +517,10 @@ class TestMLAPrepross(operation_test.OperationTest): print("accuracy is correct: %r", (float(strict_error_count) / out_len) <= 0.001) def golden_compare(self, out_tensors, golden_tensors): - # self.compare_data(out_tensors.npu(), golden_tensors.npu()) + logging.info(f"qOut_npu npu max {torch.max(self.qOut_npu.clone().cpu())} min {torch.min(self.qOut_npu.clone().cpu())}") + logging.info(f"qout max {torch.max(self.qOut.clone().cpu())} min {torch.min(self.qOut.clone().cpu())}") + logging.info(f"out_tensors max {torch.max(out_tensors.clone().cpu())} min {torch.min(out_tensors.clone().cpu())}") if self.cacheMode == 1: - print(f"qOut_npu npu max {torch.max(self.qOut_npu.clone().detach()[..., 0:512].cpu())} min {torch.min(self.qOut_npu.clone().detach()[..., 0:512].cpu())}") - print(f"qout max {torch.max(self.qOut.clone()[..., 0:512].cpu())} min {torch.min(self.qOut.clone()[..., 0:512].cpu())}") - print(f"out_tensors max {torch.max(out_tensors[..., 0:512].cpu())} min {torch.min(out_tensors[..., 0:512].cpu())}") if self.compare_count == 0: self.compare_count += 1 return compare_cv(self.qOut_npu[..., 0:512].npu(), self.qOut[..., 0:512].npu(), out_tensors.npu()) @@ -893,13 +539,12 @@ class TestMLAPrepross(operation_test.OperationTest): else: return compare_cv(self.keyout_npu.npu(), self.keyOut1.npu(), out_tensors.npu()) - def test_mla_preprocess(self): + def test_mla_preprocess_cache1(self): if not operation_test.get_soc_version() == 'Ascend910B': print("this testcase only supports Ascend910B") return self.compare_count = 0 self.cacheMode = 1 - headNum = 32 data_type = torch.float16 hidden_size = 8000 N = 32 @@ -907,14 +552,12 @@ class TestMLAPrepross(operation_test.OperationTest): data_type = torch.float16 OP_NAME = "MlaPreprocessOperation" PARAM = json.dumps({"cacheMode":self.cacheMode}) - print("test_mla_preprocess") - self.calc_vec_mm_atb_data(N,headNum,data_type, hidden_size) + self.calc_vec_mm_atb_data(N, headNum, data_type, hidden_size) self.keyCache = self.keyCache.npu() qOut: torch.Tensor = torch.zeros((N, headNum, 512), dtype=data_type).npu() # float16 kvCacheOut: torch.Tensor = self.keyCache[..., 0:512].npu() # float16 qRopeOut: torch.Tensor = torch.zeros((N, headNum, 64), dtype=data_type).npu() # float16 krCacheOut: torch.Tensor = self.keyCache[..., 512:576].npu() - print("type -------------- [qOut, kvCacheOut, qRopeOut, krCacheOut]", type([qOut, kvCacheOut, qRopeOut, krCacheOut])) self.execute_out(OP_NAME, PARAM, [self.input1.npu(), # float16 self.gamma1.npu(), # float16 @@ -941,9 +584,52 @@ class TestMLAPrepross(operation_test.OperationTest): torch.tensor([]).to(data_type).npu(), torch.tensor([]).to(data_type).npu()], [qOut, kvCacheOut, qRopeOut, krCacheOut]) # float16 - - - def test_mla_preprocess_no_rms_norm(self): + + def test_mla_preprocess_split(self): + if not operation_test.get_soc_version() == 'Ascend910B': + print("this testcase only supports Ascend910B") + return + self.compare_count = 0 + self.cacheMode = 1 + N = 56 + headNum = 77 + data_type = torch.float16 + OP_NAME = "MlaPreprocessOperation" + PARAM = 
json.dumps({"cacheMode":self.cacheMode}) + hidden_size = 6144 + self.calc_vec_mm_atb_data(N, headNum, data_type, hidden_size) + self.keyCache = self.keyCache.npu() + self.execute_out(OP_NAME, PARAM, + [self.input1.npu(), + self.gamma1.npu(), + self.beta1.npu(), + self.quantScale1.npu(), + self.quantOffset1.npu(), + torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().npu(), 29), + self.deScale1.npu(), + self.bias1.npu(), + self.gamma2.npu(), + self.beta2.npu(), + self.quantScale2.npu(), + self.quantOffset2.npu(), + torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), + self.deScale2.npu(), + self.bias2.npu(), + self.gamma3.npu(), + self.cos1.npu(), + self.sin1.npu(), + self.wuk.npu(), + self.keyCache[..., 0:512].npu(), + self.keyCache[..., 512:576].npu(), + self.slotMapping.npu(), + torch.tensor([]).npu(), + torch.tensor([]).npu()], + [torch.zeros((N, headNum, 512), dtype=data_type).npu(), + self.keyCache[..., 0:512].npu(), + torch.zeros((N, headNum, 64), dtype=data_type).npu(), + self.keyCache[..., 512:576].npu()]) + + def test_mla_preprocess_split_32_32_4096(self): if not operation_test.get_soc_version() == 'Ascend910B': print("this testcase only supports Ascend910B") return @@ -951,51 +637,82 @@ class TestMLAPrepross(operation_test.OperationTest): self.cacheMode = 1 N = 32 headNum = 32 - data_type = torch.float16 - hidden_size = 8000 + data_type = torch.bfloat16 OP_NAME = "MlaPreprocessOperation" PARAM = json.dumps({"cacheMode":self.cacheMode}) - self.calc_vec_mm_atb_data(N,headNum,data_type, hidden_size) + hidden_size = 4096 + self.calc_vec_mm_atb_data(N, headNum, data_type, hidden_size) + self.keyCache = self.keyCache.npu() + self.execute_out(OP_NAME, PARAM, + [self.input1.npu(), + self.gamma1.npu(), + self.beta1.npu(), + self.quantScale1.npu(), + self.quantOffset1.npu(), + torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().npu(), 29), + self.deScale1.npu(), + self.bias1.npu(), + self.gamma2.npu(), + self.beta2.npu(), + self.quantScale2.npu(), + self.quantOffset2.npu(), + torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), + self.deScale2.npu(), + self.bias2.npu(), + self.gamma3.npu(), + self.cos1.npu(), + self.sin1.npu(), + self.wuk.npu(), + self.keyCache[..., 0:512].npu(), + self.keyCache[..., 512:576].npu(), + self.slotMapping.npu(), + torch.tensor([]).npu(), + torch.tensor([]).npu()], + [torch.zeros((N, headNum, 512), dtype=data_type).npu(), + self.keyCache[..., 0:512].npu(), + torch.zeros((N, headNum, 64), dtype=data_type).npu(), + self.keyCache[..., 512:576].npu()]) + + def test_mla_preprocess_cache0(self): + if not operation_test.get_soc_version() == 'Ascend910B': + print("this testcase only supports Ascend910B") + return + self.compare_count = 0 + self.cacheMode = 0 + N = 31 + headNum = 33 + data_type = torch.bfloat16 + OP_NAME = "MlaPreprocessOperation" + PARAM = json.dumps({}) + self.calc_vec_mm_atb_data(N, headNum, data_type, 5120) self.keyCache = self.keyCache.npu() - print("test_mla_preprocess_no_rms_norm") - qOut: torch.Tensor = torch.zeros((N, headNum, 512), dtype=data_type).npu() # float16 - kvCacheOut: torch.Tensor = self.keyCache[..., 0:512].npu() # float16 - qRopeOut: torch.Tensor = torch.zeros((N, headNum, 64), dtype=data_type).npu() # float16 - krCacheOut: torch.Tensor = self.keyCache[..., 512:576].npu() - print("type -------------- [qOut, kvCacheOut, qRopeOut, krCacheOut]", type([qOut, kvCacheOut, qRopeOut, krCacheOut])) self.execute_out(OP_NAME, PARAM, - 
[self.input1.npu(), # float16 - self.gamma1.npu(), # float16 - self.beta1.npu(), # float16 - self.quantScale1.npu(), # float16 - self.quantOffset1.npu(), # int8 - torch_npu.npu_format_cast(transdata(self.wdqkv.to(torch.float16), (16, 32)).contiguous().npu(), 29), # float16,nz - self.deScale1.to(torch.int64).npu(), # int64 - self.bias1.npu(), # int32 - self.gamma2.npu(), # float16 - self.beta2.npu(), # float16 - self.quantScale2.npu(), # float16 - self.quantOffset2.npu(), # int8 - torch_npu.npu_format_cast(transdata(self.wuq.to(torch.float16), (16, 32)).contiguous().npu(), 29), # float16,nz - self.deScale2.to(torch.int64).npu(), # int64 - self.bias2.npu(), # int32 - self.gamma3.npu(), # float16 - self.cos1.npu(), # float16 - self.sin1.npu(), # float16 - self.wuk.npu(), # float16 - self.keyCache[..., 0:512].npu(), # float16 - self.keyCache[..., 512:576].npu(), # float16 - self.slotMapping.npu(), # int32 - torch.tensor([]).to(data_type).npu(), - torch.tensor([]).to(data_type).npu()], - [qOut, kvCacheOut, qRopeOut, krCacheOut]) # float16 - -def suite(): - suite = unittest.TestSuite() - for _ in range(1): - suite.addTest(TestMLAPrepross('test_mla_preprocess')) - return suite - -if __name__ == '__main__': - runner = unittest.TextTestRunner() - runner.run(suite()) + [self.input1.npu(), + self.gamma1.npu(), + self.beta1.npu(), + self.quantScale1.npu(), + self.quantOffset1.npu(), + torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().npu(), 29), + self.deScale1.npu(), + self.bias1.npu(), + self.gamma2.npu(), + self.beta2.npu(), + self.quantScale2.npu(), + self.quantOffset2.npu(), + torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), + self.deScale2.npu(), + self.bias2.npu(), + self.gamma3.npu(), + self.cos1.npu(), + self.sin1.npu(), + self.wuk.npu(), + self.keyCache.npu(), + torch.tensor([]).npu(), + self.slotMapping.npu(), + torch.tensor([]).npu(), + torch.tensor([]).npu()], + [torch.zeros((N, headNum, 576), dtype=data_type).npu(), + self.keyCache.npu()]) + +if __name__ == "__main__": + unittest.main()
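For reference, the `out_data_type` codes the updated test passes to `LinearOperation` are raw `aclDataType` enum values. A small sketch of that mapping (the enum values mirror CANN's `acl_base.h`; `OutDataTypeFor` is a hypothetical helper standing in for the test's `data_type == torch.bfloat16` branch):

```cpp
#include <cstdint>
#include <iostream>

// aclDataType codes used by the test's LinearOperation params.
enum AclDataTypeCode : int32_t {
    kAclFloat16 = 1,  // float16 outputs
    kAclBf16 = 27,    // bfloat16 outputs
};

// Mirror of the Python logic: default to float16, switch to bf16 when the
// test data type is bfloat16.
int32_t OutDataTypeFor(bool isBf16)
{
    return isBf16 ? kAclBf16 : kAclFloat16;
}

int main()
{
    std::cout << OutDataTypeFor(false) << "\n";  // 1
    std::cout << OutDataTypeFor(true) << "\n";   // 27
    return 0;
}
```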