From 39c2e8fd2d597feac36778c807f6fdddad80640a Mon Sep 17 00:00:00 2001
From: zhanghanLeo
Date: Wed, 27 Aug 2025 20:20:26 +0800
Subject: [PATCH 1/2] PagedAttention Supported.

---
 .../paged_attention/paged_attention_common.h  |  52 ++++
 .../paged_attention/paged_attention_graph.cc  | 260 ++++++++++++++++++
 .../paged_attention_pynative.cc               | 120 ++++++++
 yaml/doc/paged_attention_op.yaml              |  45 +++
 .../paged_attention_op.yaml                   |  60 ++++
 5 files changed, 537 insertions(+)
 create mode 100644 ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h
 create mode 100644 ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
 create mode 100644 ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
 create mode 100644 yaml/doc/paged_attention_op.yaml
 create mode 100644 yaml/ms_kernels_internal/paged_attention_op.yaml

diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h
new file mode 100644
index 000000000..70561a881
--- /dev/null
+++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MS_CUSTOM_OPS_CCSRC_OPS_MS_KERNELS_INTERNAL_PAGED_ATTENTION_H__
+#define __MS_CUSTOM_OPS_CCSRC_OPS_MS_KERNELS_INTERNAL_PAGED_ATTENTION_H__
+
+#include <cstdint>
+
+namespace ms_custom_ops {
+enum PagedAttentionInputIndex : size_t {
+  kPagedAttentionInputQueryIndex = 0,
+  kPagedAttentionInputKeyCacheIndex,
+  kPagedAttentionInputValueCacheIndex,
+  kPagedAttentionInputBlockTablesIndex,
+  kPagedAttentionInputContextLensIndex,
+  kPagedAttentionInputAntiquantScaleIndex,
+  kPagedAttentionInputAntiquantOffsetIndex,
+  kPagedAttentionInputAttnMaskIndex,
+  kPagedAttentionInputQueryLensIndex,
+  kPagedAttentionInputAlibiMaskIndex,
+  kPagedAttentionInputNumHeadIndex,
+  kPagedAttentionInputScaleValueIndex,
+  kPagedAttentionInputNumKVHeadIndex,
+  kPagedAttentionInputKVCacheQuantModeIndex,
+  kPagedAttentionInputMaskModeIndex,
+  kPagedAttentionInputMlaVDimIndex,
+  kPagedAttentionInputsNum
+};
+
+enum MlaMaskMode : int8_t {
+  kMaskNone = 0,
+  kMaskNorm,
+  kMaskAlibi,
+  kMaskSpec,
+  kMaskFree,
+};
+} // namespace ms_custom_ops
+
+#endif // __MS_CUSTOM_OPS_CCSRC_OPS_MS_KERNELS_INTERNAL_PAGED_ATTENTION_H__
\ No newline at end of file
diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
new file mode 100644
index 000000000..37626a6e3
--- /dev/null
+++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
@@ -0,0 +1,260 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h"
+
+#include
+#include
+#include
+#include
+
+#include "ccsrc/ops/ms_kernels_internal/utils/attention_utils.h"
+#include "ccsrc/base/ms_kernels_internal/graphmode/internal_kernel_mod.h"
+#include "mindspore/core/include/ir/tensor.h"
+#include "mindspore/ops/kernel/ascend/acl_ir/acl_convert.h"
+#include "mindspore/ops/ops_utils/op_utils.h"
+#include "mindspore/ccsrc/ms_extension/api.h"
+#include "mindspore/core/include/ops/base_operator.h"
+#include "mindspore/core/include/ops/ops_func_impl/op_func_impl.h"
+#include "mindspore/core/include/ops/ops_func_impl/simple_infer.h"
+#include "mindspore/ccsrc/runtime/device/kernel_runtime.h"
+#include "mindspore/core/include/utils/check_convert_utils.h"
+
+namespace ms_custom_ops {
+
+class OPS_API PagedAttentionFuncImpl : public OpFuncImpl {
+ public:
+  ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    if (input_infos.size() != kPagedAttentionInputsNum) {
+      MS_LOG(EXCEPTION) << "Paged Attention input args should be equal to " << kPagedAttentionInputsNum
+                        << ", but got " << input_infos.size();
+    }
+
+    const InferInfoPtr &query_ptr = input_infos[kPagedAttentionInputQueryIndex];
+    const InferInfoPtr &key_cache_ptr = input_infos[kPagedAttentionInputKeyCacheIndex];
+    const InferInfoPtr &value_cache_ptr = input_infos[kPagedAttentionInputValueCacheIndex];
+
+    ShapeVector query_shape = query_ptr->GetShape();
+    ShapeVector key_shape = key_cache_ptr->GetShape();
+    ShapeVector value_shape = value_cache_ptr->GetShape();
+    auto q_shape_len = query_shape.size();
+    if (IsDynamicRank(query_shape) || IsDynamicRank(key_shape) || IsDynamicRank(value_shape)) {
+      return {ShapeVector{-2}};
+    }
+
+    if (IsDynamicShape(query_shape) || IsDynamicShape(key_shape) || IsDynamicShape(value_shape)) {
+      query_shape[q_shape_len - 1] = abstract::Shape::kShapeDimAny;
+      return {query_shape};
+    }
+
+    auto d_qk = key_shape[key_shape.size() - 1];
+    auto mla_v_dim = input_infos[kPagedAttentionInputMlaVDimIndex]->GetScalarValueWithCheck<int64_t>();
+    if (mla_v_dim > 0) {
+      query_shape[q_shape_len - 1] = query_shape[q_shape_len - 1] / d_qk * mla_v_dim;
+      return {query_shape};
+    }
+    // DimV differs from DimQK in MLA
+    auto d_vo = value_shape[value_shape.size() - 1];
+    query_shape[q_shape_len - 1] = query_shape[q_shape_len - 1] / d_qk * d_vo;
+    return {query_shape};
+  }
+
+  std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto ms_context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(ms_context);
+    bool enable_infer_boost = ms_context->IsEnableInferBoost();
+    auto op_name = primitive->name();
+    std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeBFloat16};
+    auto query_types = input_infos[kPagedAttentionInputQueryIndex]->GetType();
+    auto key_types = input_infos[kPagedAttentionInputKeyCacheIndex]->GetType();
+    auto value_types = input_infos[kPagedAttentionInputValueCacheIndex]->GetType();
+    if (query_types.empty() || key_types.empty() || value_types.empty()) {
+      MS_LOG(EXCEPTION) << "Query, KeyCache or ValueCache must have types";
+    }
+    bool kvcache_quant = (key_types[0] == kNumberTypeInt8);
+    if (kvcache_quant && enable_infer_boost) {
+      std::set<TypeId> kvcache_types = {kNumberTypeInt8};
+      CheckAndConvertUtils::CheckTypeIdValid("key_cache", key_types[0], kvcache_types, op_name);
+      CheckAndConvertUtils::CheckTypeIdValid("value_cache", value_types[0], kvcache_types, op_name);
+    } else {
+      // q, k, v should have the same types, fp16 or bf16;
+      CheckAndConvertUtils::CheckTypeIdValid("key_cache", key_types[0], valid_types, op_name);
+      CheckAndConvertUtils::CheckTypeIdValid("query", query_types[0], valid_types, op_name);
+      auto mla_v_dim = input_infos[kPagedAttentionInputMlaVDimIndex]->GetScalarValueWithCheck<int64_t>();
+      if (mla_v_dim == 0) {
+        CheckAndConvertUtils::CheckTypeIdValid("value_cache", value_types[0], valid_types, op_name);
+      }
+    }
+
+    // check alibi_mask dtype equal to other inputs when alibi_mask is NOT None and infer_boost is ON
+    if (!input_infos[kPagedAttentionInputAlibiMaskIndex]->IsNone()) {
+      if (enable_infer_boost) {
+        if (input_infos[kPagedAttentionInputAlibiMaskIndex]->GetType().empty()) {
+          MS_LOG(EXCEPTION) << "Alibi Mask should have types";
+        }
+        CheckAndConvertUtils::CheckTypeIdValid(
+          "alibi_mask", input_infos[kPagedAttentionInputAlibiMaskIndex]->GetType()[0], valid_types, op_name);
+      } else {
+        MS_LOG(EXCEPTION) << "alibi_mask is not supported when infer_boost is disabled.";
+      }
+    }
+
+    // check antiquant scale and offset dtypes when they are not None.
+    if (enable_infer_boost && !input_infos[kPagedAttentionInputAntiquantScaleIndex]->IsNone() &&
+        !input_infos[kPagedAttentionInputAntiquantOffsetIndex]->IsNone()) {
+      bool valid_flag = false;
+      auto scale_type = input_infos[kPagedAttentionInputAntiquantScaleIndex]->GetType();
+      auto offset_type = input_infos[kPagedAttentionInputAntiquantOffsetIndex]->GetType();
+      if (scale_type.empty() || offset_type.empty()) {
+        MS_LOG(EXCEPTION) << "Antiquant scale and offset should have types";
+      }
+      auto scale_type_id = scale_type[0];
+      auto offset_type_id = offset_type[0];
+      if ((scale_type_id == kNumberTypeFloat16 && offset_type_id == kNumberTypeFloat16) ||
+          (scale_type_id == kNumberTypeInt64 && offset_type_id == kNumberTypeInt32) ||
+          (scale_type_id == kNumberTypeFloat32 && offset_type_id == kNumberTypeInt32)) {
+        valid_flag = true;
+      }
+      if (!valid_flag) {
+        MS_LOG(EXCEPTION) << "types of antiquant_scale and antiquant_offset are not supported: " << scale_type_id
+                          << " & " << offset_type_id;
+      }
+    }
+
+    std::set<TypeId> block_tables_valid_types = {kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt64};
+    auto context_valid_types = block_tables_valid_types;
+    auto block_tables_types = input_infos[kPagedAttentionInputBlockTablesIndex]->GetType();
+    auto context_lens_types = input_infos[kPagedAttentionInputContextLensIndex]->GetType();
+    if (block_tables_types.empty() || context_lens_types.empty()) {
+      MS_LOG(EXCEPTION) << "block_tables_types or context_lens_types should have types, but now empty.";
+    }
+    CheckAndConvertUtils::CheckTypeIdValid("block_tables", block_tables_types[0], block_tables_valid_types, op_name);
+    CheckAndConvertUtils::CheckTypeIdValid("context_lens", context_lens_types[0], context_valid_types, op_name);
+    return {input_infos[0]->GetType()};
+  }
+
+  bool GeneralInferRegistered() const override { return true; }
+
+  std::set<int64_t> GetValueDependArgIndices() const override {
+    return {kPagedAttentionInputContextLensIndex, kPagedAttentionInputQueryLensIndex};
+  };
+};
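[Note] A quick sanity sketch of the output-shape rule that InferShape above implements: the result keeps the query shape except for the last dim, which is rescaled from d_qk to the value head dim (mla_v_dim when positive, otherwise value_cache's last dim). The helper below is illustrative only; its name and the example shapes are not part of the patch.

    # Python sketch of PagedAttentionFuncImpl::InferShape (illustrative, not shipped).
    def infer_paged_attention_out_shape(query_shape, key_cache_shape, value_cache_shape, mla_v_dim=0):
        d_qk = key_cache_shape[-1]                  # head dim shared by query and key
        d_v = mla_v_dim if mla_v_dim > 0 else value_cache_shape[-1]
        out = list(query_shape)
        out[-1] = query_shape[-1] // d_qk * d_v     # hidden = num_heads * d_v
        return out

    # MLA-style example with hypothetical sizes: 32 heads, d_qk = 576, mla_v_dim = 512.
    assert infer_paged_attention_out_shape([4, 1, 32 * 576], [1024, 128, 1, 576],
                                           [1024, 128, 1, 576], 512) == [4, 1, 32 * 512]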
+
+class PagedAttention : public InternalKernelMod {
+ public:
+  PagedAttention() : InternalKernelMod() {}
+  ~PagedAttention() = default;
+
+ protected:
+  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
+    auto &llm_manager = LLMManager::GetInstance();
+    llm_manager.add_force_resize_kernel(kernel_name_);
+    MS_LOG(INFO) << "Force op '" << kernel_name_ << "' to be resized to update op param 'seq_len'";
+    return InternalKernelMod::Init(inputs, outputs);
+  }
+  internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs,
+                                       const internal::OutputsImmutableInfoList &outputs,
+                                       const std::vector<KernelTensor *> &ms_inputs,
+                                       const std::vector<KernelTensor *> &ms_outputs) override {
+    auto last_input_index = kIndex15;
+    if (ms_inputs.size() <= last_input_index) {
+      MS_LOG(EXCEPTION) << "For op " << kernel_name_ << ", inputs number should be larger than " << last_input_index
+                        << ", but got " << ms_inputs.size();
+    }
+    param_.head_num = static_cast<int32_t>(ms_inputs[kIndex10]->GetValueWithCheck<int64_t>());
+    param_.tor = ms_inputs[kIndex11]->GetValueWithCheck<float>();
+    param_.kv_head_num = static_cast<int32_t>(ms_inputs[kIndex12]->GetValueWithCheck<int64_t>());
+    param_.kv_cache_quant_mode = ms_inputs[kIndex13]->GetValueWithCheck<int64_t>();
+    param_.mask_mode =
+      static_cast<int32_t>(ms_inputs[kIndex14]->GetValueWithCheck<int64_t>());
+    param_.mla_v_dim = static_cast<int32_t>(ms_inputs[kIndex15]->GetValueWithCheck<int64_t>());
+    has_attn_mask_ = (!(ms_inputs[kIndex7]->GetType()->isa<TypeNone>()));
+    has_alibi_mask_ = (!(ms_inputs[kIndex9]->GetType()->isa<TypeNone>()));
+
+    param_.has_q_seq_lens = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, &param_.q_seq_len);
+    (void)GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"batch_valid_length"}, &param_.kv_seq_len);
+
+    CheckMask();
+
+    created_flag_ = true;
+    return internal::CreatePagedAttentionOp(inputs, outputs, param_, internal::kInternalPagedAttentionOpName);
+  }
+
+  bool UpdateParam(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
+    if (created_flag_) {
+      // the q_seq_len and batch_valid_length are inited in CreateKernel, so there is no need to load them again
+      created_flag_ = false;
+      return true;
+    }
+
+    bool q_need_recreate = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, &param_.q_seq_len);
+    bool kv_need_recreate = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"batch_valid_length"}, &param_.kv_seq_len);
+    if (q_need_recreate || kv_need_recreate) {
+      CheckMask();
+      auto ret = internal_op_->UpdateParam(&param_);
+      if (ret != internal::kInternalOk) {
+        MS_LOG(ERROR) << "InternalPagedAttention UpdateParam failed, kernel_name: " << kernel_name_;
+        return false;
+      }
+      return true;
+    }
+
+    return true;
+  }
+
+  uint64_t GenerateTilingKey(const std::vector<KernelTensor *> &inputs) override {
+    return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len,
+                                            param_.has_q_seq_lens, param_.mla_v_dim);
+  }
+
+  void InitKernelInputsOutputsIndex() override {
+    kernel_inputs_index_ = {kMlaInputQnopeIndex,      kMlaInputQropeIndex,       kMlaInputKvCacheIndex,
+                            kMlaInputKropeIndex,      kMlaInputBlockTablesIndex, kMlaInputAttnMaskIndex,
+                            kMlaInputDeqScaleQkIndex, kMlaInputDeqScalePvIndex};
+    kernel_outputs_index_ = {0, 1};
+  }
+
+ private:
+  inline void CheckMask() {
+    param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeNone;
+    auto enable_lookahead =
+      std::any_of(param_.q_seq_len.begin(), param_.q_seq_len.end(), [](int32_t seq_len) { return seq_len > 1; });
+    if (enable_lookahead) {
+      if (has_attn_mask_) {
+        param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeLookAhead;
+      }
+    } else {
+      param_.q_seq_len.clear();
+    }
+
+    if (has_alibi_mask_) {
+      if (param_.mask_type == internal::PagedAttentionParam::MaskType::kMaskTypeLookAhead) {
+        MS_LOG(EXCEPTION) << "For op " << kernel_name_ << ", lookahead cannot be enabled when alibi_mask exists.";
+      } else {
+        param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeAlibi;
+      }
+    }
+  }
+
+  bool created_flag_{false};
+  bool has_attn_mask_{false};
+  bool has_alibi_mask_{false};
+  internal::PagedAttentionParam param_;
+};
+} // namespace ms_custom_ops
+
+REG_GRAPH_MODE_OP(PagedAttention, ms_custom_ops::PagedAttentionFuncImpl, ms_custom_ops::PagedAttention);
diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
new file mode 100644
index 000000000..8088b7d0c
--- /dev/null
+++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2025 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include "ccsrc/ops/ms_kernels_internal/mla/mla_common.h"
+#include "mindspore/ccsrc/ms_extension/api.h"
+#include "lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal.h"
+#include "ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h"
+#include "ccsrc/utils/utils.h"
+
+namespace ms_custom_ops {
+class MlaRunner : public InternalPyboostRunner {
+ public:
+  MlaRunner(const std::string &op_name) : InternalPyboostRunner(op_name) {}
+  ~MlaRunner() = default;
+
+  void UpdateParam(int32_t head_size, float tor, int32_t kv_head, mindspore::internal::MLAParam::MaskType mask_type,
+                   int32_t is_ring, const std::vector<int32_t> &q_seq_len, const std::vector<int32_t> &kv_seq_len) {
+    param_.type = mindspore::internal::MLAParam::kSplitCache;
+    param_.head_size = head_size;
+    param_.tor = tor;
+    param_.kv_head = kv_head;
+    param_.mask_type = mask_type;
+    param_.is_ring = is_ring;
+    param_.q_seq_len = q_seq_len;
+    param_.kv_seq_len = kv_seq_len;
+  }
+
+ protected:
+  internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs,
+                                       const internal::OutputsImmutableInfoList &outputs) override {
+    return mindspore::internal::CreateMLAOp(inputs, outputs, param_, internal::kInternalMLAOpName);
+  }
+
+ private:
+  mindspore::internal::MLAParam param_;
+};
+
+std::vector<ms::Tensor> mla_atb(const ms::Tensor &q_nope, const ms::Tensor &q_rope, const ms::Tensor &ctkv,
+                                const ms::Tensor &k_rope, const ms::Tensor &block_tables,
+                                const std::optional<ms::Tensor> &attn_mask,
+                                const std::optional<ms::Tensor> &deq_scale_qk,
+                                const std::optional<ms::Tensor> &deq_scale_pv,
+                                const std::optional<ms::Tensor> &q_seq_lens,
+                                const std::optional<ms::Tensor> &context_lens, int64_t head_num, double scale_value,
+                                int64_t kv_head_num, int64_t mask_mode, int64_t is_ring) {
+  static auto op_name = "Mla";
+  auto runner = std::make_shared<MlaRunner>(op_name);
+  MS_EXCEPTION_IF_NULL(runner);
+
+  if (!q_seq_lens.has_value() || !context_lens.has_value()) {
+    MS_LOG(EXCEPTION) << "For " << op_name
+                      << ", the q_seq_lens and context_lens can not be None, but got q_seq_lens.has_value(): "
+                      << q_seq_lens.has_value() << ", context_lens.has_value(): " << context_lens.has_value();
+  }
+
+  auto q_seq_lens_value = GetValueFromTensor<std::vector<int32_t>>(q_seq_lens.value(), op_name, "q_seq_lens");
+  auto context_lens_value = GetValueFromTensor<std::vector<int32_t>>(context_lens.value(), op_name, "context_lens");
+  runner->UpdateParam(static_cast<int32_t>(head_num), static_cast<float>(scale_value),
+                      static_cast<int32_t>(kv_head_num),
+                      static_cast<mindspore::internal::MLAParam::MaskType>(mask_mode), static_cast<int32_t>(is_ring),
+                      q_seq_lens_value, context_lens_value);
+
+  // Setup the runner with all parameters (including hash calculation)
+  runner->Setup(op_name, q_nope, q_rope, ctkv, k_rope, block_tables, attn_mask, deq_scale_qk, deq_scale_pv, q_seq_lens,
+                context_lens, head_num, scale_value, kv_head_num, mask_mode, is_ring);
+
+  auto attn_out = ms::Tensor(q_nope.data_type(), q_nope.shape());
+  auto lse_out = ms::Tensor(q_nope.data_type(), {0});
+
+  std::vector<ms::Tensor> inputs = {q_nope,
+                                    q_rope,
+                                    ctkv,
+                                    k_rope,
+                                    block_tables,
+                                    GetTensorOrEmpty(attn_mask),
+                                    GetTensorOrEmpty(deq_scale_qk),
+                                    GetTensorOrEmpty(deq_scale_pv)};
+  std::vector<ms::Tensor> outputs = {attn_out, lse_out};
+  runner->GetOrCreateKernel(inputs, outputs);
+  runner->Run(inputs, outputs);
+  return outputs;
+}
+
+auto pyboost_mla(const ms::Tensor &q_nope, const ms::Tensor &q_rope, const ms::Tensor &ctkv, const ms::Tensor &k_rope,
+                 const ms::Tensor &block_tables, const std::optional<ms::Tensor> &attn_mask,
+                 const std::optional<ms::Tensor> &deq_scale_qk, const std::optional<ms::Tensor> &deq_scale_pv,
+                 const std::optional<ms::Tensor> &q_seq_lens, const std::optional<ms::Tensor> &context_lens,
+                 int64_t head_num, double scale_value, int64_t kv_head_num, int64_t mask_mode, int64_t is_ring) {
+  return ms::pynative::PyboostRunner::Call<2>(mla_atb, q_nope, q_rope, ctkv, k_rope, block_tables, attn_mask,
+                                              deq_scale_qk, deq_scale_pv, q_seq_lens, context_lens, head_num,
+                                              scale_value, kv_head_num, mask_mode, is_ring);
+}
+} // namespace ms_custom_ops
+
+MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
+  m.def("mla", &ms_custom_ops::pyboost_mla, "Multi-head Latent Attention", pybind11::arg("q_nope"),
+        pybind11::arg("q_rope"), pybind11::arg("ctkv"), pybind11::arg("k_rope"), pybind11::arg("block_tables"),
+        pybind11::arg("attn_mask") = std::nullopt, pybind11::arg("deq_scale_qk") = std::nullopt,
+        pybind11::arg("deq_scale_pv") = std::nullopt, pybind11::arg("q_seq_lens") = std::nullopt,
+        pybind11::arg("context_lens") = std::nullopt, pybind11::arg("head_num") = 32,
+        pybind11::arg("scale_value") = 0.0, pybind11::arg("kv_head_num") = 1, pybind11::arg("mask_mode") = 0,
+        pybind11::arg("is_ring") = 0);
+}
diff --git a/yaml/doc/paged_attention_op.yaml b/yaml/doc/paged_attention_op.yaml
new file mode 100644
index 000000000..4645266ac
--- /dev/null
+++ b/yaml/doc/paged_attention_op.yaml
@@ -0,0 +1,45 @@
+asd_paged_attention:
+  description: |
+    Calculate attention scores based on the paged attention mechanism for KV cache pagination management.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion. This API is only supported in Atlas A2
+        training series for now.
+
+    Args:
+        query (Tensor): Query of PagedAttention. Supported data types: float16, bfloat16, int8.
+        key_cache (Tensor): Key cache of PagedAttention. Supported data types: float16, bfloat16, int8.
+        value_cache (Tensor): Value cache of PagedAttention. Supported data types: float16, bfloat16, int8.
+        block_tables (Tensor): Block table of KV-Cache. Supported data types: int32.
+        context_lens (Tensor): KV sequence lengths of PagedAttention. Supported data types: int32.
+        antiquant_scale (Tensor): KV Cache anti-quant scale parameter. Supported data types: None, float16, int64. Only supported in quantization scenario.
+        antiquant_offset (Tensor): KV Cache anti-quant offset parameter. Supported data types: None, float16, int32. Only supported in quantization scenario.
+        attn_mask (Tensor): Mask of attention. Supported data types: float16.
+        q_seq_lens (Tensor): Each batch's corresponding seqLen is required in parallel decoding scenarios. Supported data types: int32.
+        alibi_mask (Tensor): Alibi mask. Supported data types: float16.
+        head_num (int): The head num of query.
+        scale_value (float): scale_value = 1/sqrt(head_dim).
+        kv_head_num (int): The head num of KV.
+        kv_cache_quant_mode (string): The quant mode of KV Cache.
+            'DEFAULT': perchannel
+            'PERTOKEN': pertoken
+        mask_mode (string): The format of mask.
+            "MASK_DEFAULT" or "TRAPEZOIDAL"
+        mla_v_dim (int): Default as 0; in the MLA scenario its value should be 512.
+
+    Returns:
+        attention_out - Tensor, denotes the attention result. Supported data types: float16 or bfloat16.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import ms_custom_ops
diff --git a/yaml/ms_kernels_internal/paged_attention_op.yaml b/yaml/ms_kernels_internal/paged_attention_op.yaml
new file mode 100644
index 000000000..a20fddd16
--- /dev/null
+++ b/yaml/ms_kernels_internal/paged_attention_op.yaml
@@ -0,0 +1,60 @@
+#operator paged_attention
+paged_attention:
+  args:
+    query:
+      dtype: tensor
+    key_cache:
+      dtype: tensor
+    value_cache:
+      dtype: tensor
+      default: None
+    block_tables:
+      dtype: tensor
+      default: None
+    context_lens:
+      dtype: tensor
+      default: None
+    antiquant_scale:
+      dtype: tensor
+      default: None
+    antiquant_offset:
+      dtype: tensor
+      default: None
+    attn_mask:
+      dtype: tensor
+      default: None
+    q_seq_lens:
+      dtype: tensor
+      default: None
+    alibi_mask:
+      dtype: tensor
+      default: None
+    head_num:
+      dtype: int
+      prim_init: True
+    scale_value:
+      dtype: float
+      prim_init: True
+    kv_head_num:
+      dtype: int
+      prim_init: True
+    kv_cache_quant_mode:
+      dtype: int
+      default: "'DEFAULT'"
+      prim_init: True
+      arg_handler: str_to_enum
+    mask_mode:
+      dtype: int
+      default: "'MASK_DEFAULT'"
+      prim_init: True
+      arg_handler: str_to_enum
+    mla_v_dim:
+      dtype: int
+      default: 0
+      prim_init: True
+
+  returns:
+    attention_out:
+      dtype: tensor
+  class:
+    name: PagedAttention
\ No newline at end of file
-- 
Gitee

From 553d06ed7ca8d2ffadf195b020aafa3dc0160ad9 Mon Sep 17 00:00:00 2001
From: zhanghanLeo
Date: Thu, 11 Sep 2025 10:09:09 +0800
Subject: [PATCH 2/2] Support PagedAttention.
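For reference, a hedged usage sketch of the op registered in yaml/ms_kernels_internal/paged_attention_op.yaml from PATCH 1/2. The import path follows the doc yaml's example stub; shapes, dtypes, and the block-table contents are illustrative assumptions, not values fixed by this patch.

    # Illustrative call of the paged_attention primitive (hypothetical shapes).
    import numpy as np
    import mindspore as ms
    import ms_custom_ops

    num_blocks, block_size, kv_heads, head_dim, q_heads = 16, 128, 4, 128, 32
    query = ms.Tensor(np.random.randn(1, 1, q_heads * head_dim).astype(np.float16))
    key_cache = ms.Tensor(np.random.randn(num_blocks, block_size, kv_heads, head_dim).astype(np.float16))
    value_cache = ms.Tensor(np.random.randn(num_blocks, block_size, kv_heads, head_dim).astype(np.float16))
    block_tables = ms.Tensor(np.zeros((1, 2), dtype=np.int32))  # one sequence mapped to two cache blocks
    context_lens = ms.Tensor(np.array([200], dtype=np.int32))   # 200 cached tokens -> ceil(200/128) = 2 blocks

    out = ms_custom_ops.paged_attention(query, key_cache, value_cache, block_tables, context_lens,
                                        head_num=q_heads, scale_value=1.0 / (head_dim ** 0.5),
                                        kv_head_num=kv_heads)
    print(out.shape)  # matches the query shape when mla_v_dim == 0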
--- .../paged_attention/paged_attention_common.h | 104 +++-- .../paged_attention/paged_attention_graph.cc | 409 +++++++++++------- .../paged_attention_pynative.cc | 190 +++++--- .../paged_attention_op.yaml | 79 ++-- 4 files changed, 524 insertions(+), 258 deletions(-) diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h index 70561a881..3cd17ac50 100644 --- a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h +++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h @@ -21,32 +21,86 @@ namespace ms_custom_ops { enum PagedAttentionInputIndex : size_t { - kPagedAttentionInputQueryIndex = 0, - kPagedAttentionInputKeyCacheIndex, - kPagedAttentionInputValueCacheIndex, - kPagedAttentionInputBlockTablesIndex, - kPagedAttentionInputContextLensIndex, - kPagedAttentionInputAntiquantScaleIndex, - kPagedAttentionInputAntiquantOffsetIndex, - kPagedAttentionInputAttnMaskIndex, - kPagedAttentionInputQueryLensIndex, - kPagedAttentionInputAlibiMaskIndex, - kPagedAttentionInputNumHeadIndex, - kPagedAttentionInputScaleValueIndex, - kPagedAttentionInputNumKVHeadIndex, - kPagedAttentionInputKVCacheQuantModeIndex, - kPagedAttentionInputMaskModeIndex, - kPagedAttentionInputMlaVDimIndex, - kPagedAttentionInputsNum -}; - -enum MlaMaskMode : int8_t { - kMaskNone = 0, - kMaskNorm, - kMaskAlibi, - kMaskSpec, - kMaskFree, + kPagedAttentionInputQueryIndex = 0, // 0 + kPagedAttentionInputKeyCacheIndex, // 1 + kPagedAttentionInputValueCacheIndex, // 2 + kPagedAttentionInputBlockTablesIndex, // 3 + kPagedAttentionInputContextLensIndex, // 4 + kPagedAttentionInputAttnMaskIndex, // 5 + kPagedAttentionInputQSeqLenIndex, // 6 + kPagedAttentionInputBatchRunStatusIndex, // 7 + kPagedAttentionInputKDescalekIndex, // 8 + kPagedAttentionInputKOffsetIndex, // 9 + kPagedAttentionInputVDescaleIndex, // 10 + kPagedAttentionInputVOffsetIndex, // 11 + kPagedAttentionInputRazorOffsetIndex, // 12 + kPagedAttentionInputPScaleIndex, // 13 + kPagedAttentionInputLogNIndex, // 14 + kPagedAttentionInputQHeadNumIndex, // 15 + kPagedAttentionInputQKScaleIndex, // 16 + kPagedAttentionInputKVHeadNumIndex, // 17 + kPagedAttentionInputMaskTypeIndex, // 18 + kPagedAttentionInputBatchRunStatusEnableIndex, // 19 + kPagedAttentionInputQuantTypeIndex, // 20 + kPagedAttentionInputOutDataTypeIndex, // 21 + kPagedAttentionInputHasQuantOffsetIndex, // 22 + kPagedAttentionInputCompressTypeIndex, // 23 + kPagedAttentionInputCalcTypeIndex, // 24 + kPagedAttentionInputScaleTypeIndex, // 25 + kPagedAttentionInputInputLayoutIndex, // 26 + kPagedAttentionInputMlaVDimHeadSizeIndex, // 27 + kPagedAttentionInputInputFormatIndex, // 28 + kPagedAttentionInputsNum // 29 }; + +enum PAOutputIndex : size_t { + kPagedAttentionOutputIndex = 0, + kPagedAttentionOutputNum, +}; + +enum PAMaskType : int32_t { + kPA_MASK_UNDEFINED = 0, + kPA_MASK_TYPE_NORM, + kPA_MASK_TYPE_ALIBI, + kPA_MASK_TYPE_SPEC, + kPA_MASK_TYPE_MASK_FREE, +}; + +enum PAQuantType : int32_t { + kPA_TYPE_QUANT_UNDEFINED = 0, + kPA_TYPE_QUANT_UNQUANT = 0, + kPA_TYPE_DEQUANT_FUSION, + kPA_TYPE_QUANT_QKV_OFFLINE, + kPA_TYPE_QUANT_QKV_ONLINE, +}; + +enum PACompressType : int32_t { + kPA_COMPRESS_TYPE_UNDEFINED = 0, + kPA_COMPRESS_TYPE_KVHEAD, + kPA_COMPRESS_TYPE_KVHEAD_ROPE, + kPA_COMPRESS_TYPE_MAX, +}; + +enum PACalcType : int32_t { + kPA_CALC_TYPE_UNDEFINED = 0, + kPA_CALC_TYPE_SPEC, +}; + +enum PAOutDataType : int32_t { + kPA_ACL_DT_UNDEFINED = -1, + kPA_ACL_FLOAT16 = 1, + 
kPA_ACL_BF16 = 27,
+};
+
+enum PAInputLayout : int32_t {
+  kPA_INPUT_LAYOUT_BSND = 0,
+  kPA_INPUT_LAYOUT_BNSD = 1,
+};
+
+enum PAScaleType : int32_t { kPA_SCALE_TYPE_TOR = 0, kPA_SCALE_TYPE_LOGN, kPA_SCALE_TYPE_MAX };
+
+enum PAInputFormat : int8_t { kKVFormatND = 0, kKVFormatNZ };
+
 } // namespace ms_custom_ops
 
 #endif // __MS_CUSTOM_OPS_CCSRC_OPS_MS_KERNELS_INTERNAL_PAGED_ATTENTION_H__
\ No newline at end of file
diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
index 37626a6e3..747496f33 100644
--- a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
+++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_graph.cc
@@ -30,11 +30,169 @@
 #include "mindspore/core/include/ops/base_operator.h"
 #include "mindspore/core/include/ops/ops_func_impl/op_func_impl.h"
 #include "mindspore/core/include/ops/ops_func_impl/simple_infer.h"
-#include "mindspore/ccsrc/runtime/device/kernel_runtime.h"
 #include "mindspore/core/include/utils/check_convert_utils.h"
+#include "mindspore/core/include/utils/ms_context.h"
+#include "mindspore/core/include/abstract/dshape.h"
 
 namespace ms_custom_ops {
 
+static constexpr auto kPAQShapeRank = 3;
+static constexpr auto kPAKVCacheRank = 4;
+static constexpr auto kPAKVCacheRankAtlas = 3;
+static constexpr auto kPABlockTableRank = 2;
+static constexpr auto kPAContextLenRank = 1;
+
+static void CheckParam(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) {
+  auto q_head_num = input_infos[kPagedAttentionInputQHeadNumIndex]->GetScalarValueWithCheck<int64_t>();
+  MS_CHECK_VALUE(q_head_num != 0,
+                 CheckAndConvertUtils::FormatCommMsg("For PagedAttention the q_head_num should not be 0, but got 0."));
+  auto scale_type = input_infos[kPagedAttentionInputScaleTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  auto quant_type = input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  if (scale_type == PAScaleType::kPA_SCALE_TYPE_LOGN) {
+    MS_CHECK_VALUE(
+      ((quant_type != PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE) &&
+       (quant_type != PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE)),
+      CheckAndConvertUtils::FormatCommMsg(
+        "In PA scale type logn mode, quant_type 2(Quant_QKV_OFFLINE) or 3(Quant_QKV_ONLINE) is not supported."));
+  }
+  auto mla_v_head_size = input_infos[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetScalarValueWithCheck<int64_t>();
+  MS_CHECK_VALUE(((mla_v_head_size >= 0 && mla_v_head_size <= 576)),
+                 CheckAndConvertUtils::FormatCommMsg(
+                   "For PagedAttention(MLA mode) the value head size should be in [0, 576], but got ",
+                   mla_v_head_size));
+
+  auto input_layout = input_infos[kPagedAttentionInputInputLayoutIndex]->GetScalarValueWithCheck<int64_t>();
+  MS_CHECK_VALUE(
+    ((input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BNSD || input_layout == PAInputLayout::kPA_INPUT_LAYOUT_BSND)),
+    CheckAndConvertUtils::FormatCommMsg("For PagedAttention the input layout should be 0(BSND)/1(BNSD), but got ",
+                                        input_layout));
+  auto calc_type = input_infos[kPagedAttentionInputCalcTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  MS_CHECK_VALUE(((calc_type == PACalcType::kPA_CALC_TYPE_SPEC || calc_type == PACalcType::kPA_CALC_TYPE_UNDEFINED)),
+                 CheckAndConvertUtils::FormatCommMsg(
+                   "For PagedAttention the calc_type should be 0(disable MTP)/1(enable MTP), but got ", calc_type));
+  if (calc_type == PACalcType::kPA_CALC_TYPE_SPEC) {
+    MS_CHECK_VALUE((quant_type == PAQuantType::kPA_TYPE_QUANT_UNDEFINED),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "In PA MTP scene, quant mode should be 0(TYPE_QUANT_UNQUANT/TYPE_QUANT_UNDEFINED), but now got ",
+                     quant_type));
+  }
+  auto compress_type = input_infos[kPagedAttentionInputCompressTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  MS_CHECK_VALUE(
+    (compress_type != PACompressType::kPA_COMPRESS_TYPE_MAX),
+    CheckAndConvertUtils::FormatCommMsg("In PA compress scene, compress type should not be 3(kPA_COMPRESS_TYPE_MAX)."));
+  if (compress_type == PACompressType::kPA_COMPRESS_TYPE_KVHEAD ||
+      compress_type == PACompressType::kPA_COMPRESS_TYPE_KVHEAD_ROPE) {
+    MS_CHECK_VALUE(
+      ((quant_type != PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE) &&
+       (quant_type != PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE)),
+      CheckAndConvertUtils::FormatCommMsg("In PA compress scene, quant mode should not be "
                                          "2(kPA_TYPE_QUANT_QKV_OFFLINE)/3(kPA_TYPE_QUANT_QKV_ONLINE), but now got ",
+                                          quant_type));
+  }
+  auto mask_type = input_infos[kPagedAttentionInputMaskTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  if (compress_type == PACompressType::kPA_COMPRESS_TYPE_KVHEAD_ROPE) {
+    MS_CHECK_VALUE((mask_type != PAMaskType::kPA_MASK_UNDEFINED),
+                   CheckAndConvertUtils::FormatCommMsg("In PA COMPRESS_TYPE_KVHEAD_ROPE scene, mask type should not be "
                                                       "0(PA_MASK_UNDEFINED), but now got ",
+                                                       mask_type));
+  }
+}
+
+static void CheckShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) {
+  auto query_shape = input_infos[kPagedAttentionInputQueryIndex]->GetShape();
+  auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape();
+  auto value_cache_shape = input_infos[kPagedAttentionInputValueCacheIndex]->GetShape();
+  auto block_tables_shape = input_infos[kPagedAttentionInputBlockTablesIndex]->GetShape();
+  auto context_len_shape = input_infos[kPagedAttentionInputContextLensIndex]->GetShape();
+  if (!input_infos[kPagedAttentionInputQueryIndex]->IsDynamic()) {
+    MS_CHECK_VALUE(query_shape.size() == kPAQShapeRank,
+                   CheckAndConvertUtils::FormatCommMsg("For PA, the rank of query must be ", kPAQShapeRank,
+                                                       ", but got shape: ", query_shape));
+  }
+
+  if (!input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic()) {
+    MS_CHECK_VALUE(key_cache_shape.size() == kPAKVCacheRank,
+                   CheckAndConvertUtils::FormatCommMsg("For PA, the rank of key_cache must be ", kPAKVCacheRank,
+                                                       ", but got shape: ", key_cache_shape));
+  }
+
+  if (!input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) {
+    MS_CHECK_VALUE(value_cache_shape.size() == kPAKVCacheRank,
+                   CheckAndConvertUtils::FormatCommMsg("For PA, the rank of value_cache must be ", kPAKVCacheRank,
+                                                       ", but got shape: ", value_cache_shape));
+  }
+
+  if (!input_infos[kPagedAttentionInputBlockTablesIndex]->IsDynamic()) {
+    MS_CHECK_VALUE(block_tables_shape.size() == kPABlockTableRank,
+                   CheckAndConvertUtils::FormatCommMsg("For PA, the rank of block table must be ", kPABlockTableRank,
+                                                       ", but got shape: ", block_tables_shape));
+  }
+
+  if (!input_infos[kPagedAttentionInputContextLensIndex]->IsDynamic()) {
+    MS_CHECK_VALUE(context_len_shape.size() == kPAContextLenRank,
+                   CheckAndConvertUtils::FormatCommMsg("For PA, the rank of context len must be ", kPAContextLenRank,
+                                                       ", but got shape: ", context_len_shape));
+  }
+}
+
+static void CheckQuantShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) {
+  auto quant_type = input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape();
+  auto value_cache_shape = input_infos[kPagedAttentionInputValueCacheIndex]->GetShape();
+  if (!input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic() &&
+      !input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) {
+    auto key_head_dim = key_cache_shape[key_cache_shape.size() - 1];
+    auto value_head_dim = value_cache_shape[value_cache_shape.size() - 1];
+    MS_CHECK_VALUE(
+      ((key_head_dim == value_head_dim) && ((key_head_dim > 0) && (key_head_dim <= 256) && (value_head_dim > 0) &&
+                                            (value_head_dim <= 256) && (value_head_dim * key_head_dim <= 128 * 128))),
+      CheckAndConvertUtils::FormatCommMsg(
+        "For PA key cache and value cache must be equal in head_dim, "
+        "k_head_dim/v_head_dim must be in (0, 256], and k_head_dim * v_head_dim <= 128*128. But got k_head_dim: ",
+        key_head_dim, ", v_head_dim: ", value_head_dim));
+  }
+}
+
+static void CheckType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) {
+  auto quant_type = input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck<int64_t>();
+  auto query_dtype = input_infos[kPagedAttentionInputQueryIndex]->GetType();
+  auto key_cache_dtype = input_infos[kPagedAttentionInputKeyCacheIndex]->GetType();
+  auto value_cache_dtype = input_infos[kPagedAttentionInputValueCacheIndex]->GetType();
+  if (quant_type == PAQuantType::kPA_TYPE_QUANT_UNQUANT) {
+    MS_CHECK_VALUE(((query_dtype == kNumberTypeFloat16) || (query_dtype == kNumberTypeBFloat16)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in unquant mode, query dtype must be float16/bfloat16, but got type: ", query_dtype));
+    MS_CHECK_VALUE(
+      ((key_cache_dtype == kNumberTypeFloat16) || (key_cache_dtype == kNumberTypeBFloat16)),
+      CheckAndConvertUtils::FormatCommMsg(
+        "For PA in unquant mode, key cache dtype must be float16/bfloat16, but got type: ", key_cache_dtype));
+    MS_CHECK_VALUE(
+      ((value_cache_dtype == kNumberTypeFloat16) || (value_cache_dtype == kNumberTypeBFloat16)),
+      CheckAndConvertUtils::FormatCommMsg(
+        "For PA in unquant mode, value cache dtype must be float16/bfloat16, but got type: ", value_cache_dtype));
+  } else if (quant_type == PAQuantType::kPA_TYPE_DEQUANT_FUSION) {
+    MS_CHECK_VALUE(((query_dtype == kNumberTypeFloat16) || (query_dtype == kNumberTypeBFloat16)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in dequant mode, query dtype must be float16/bfloat16, but got type: ", query_dtype));
+    MS_CHECK_VALUE(((key_cache_dtype == kNumberTypeInt8)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in dequant mode, key cache dtype must be int8, but got type: ", key_cache_dtype));
+    MS_CHECK_VALUE(((value_cache_dtype == kNumberTypeInt8)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in dequant mode, value cache dtype must be int8, but got type: ", value_cache_dtype));
+  } else if (quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE ||
+             quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE) {
+    MS_CHECK_VALUE(((query_dtype == kNumberTypeInt8)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in full quant mode, query dtype must be int8, but got type: ", query_dtype));
+    MS_CHECK_VALUE(((key_cache_dtype == kNumberTypeInt8)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in full quant mode, key cache dtype must be int8, but got type: ", key_cache_dtype));
+    MS_CHECK_VALUE(((value_cache_dtype == kNumberTypeInt8)),
+                   CheckAndConvertUtils::FormatCommMsg(
+                     "For PA in full quant mode, value cache dtype must be int8, but got type: ", value_cache_dtype));
+  }
+}
+
 class OPS_API PagedAttentionFuncImpl : public OpFuncImpl {
  public:
   ShapeArray InferShape(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
@@ -43,115 +201,79 @@ class OPS_API PagedAttentionFuncImpl : public OpFuncImpl {
                         << ", but got " << input_infos.size();
     }
 
-    const InferInfoPtr &query_ptr = input_infos[kPagedAttentionInputQueryIndex];
-    const InferInfoPtr &key_cache_ptr = input_infos[kPagedAttentionInputKeyCacheIndex];
-    const InferInfoPtr &value_cache_ptr = input_infos[kPagedAttentionInputValueCacheIndex];
+    auto &query_info = input_infos[kPagedAttentionInputQueryIndex];
+    auto query_shape = query_info->GetShape();
+
+    CheckShape(primitive, input_infos);
+    CheckParam(primitive, input_infos);
+    CheckQuantShape(primitive, input_infos);
 
-    ShapeVector query_shape = query_ptr->GetShape();
-    ShapeVector key_shape = key_cache_ptr->GetShape();
-    ShapeVector value_shape = value_cache_ptr->GetShape();
-    auto q_shape_len = query_shape.size();
-    if (IsDynamicRank(query_shape) || IsDynamicRank(key_shape) || IsDynamicRank(value_shape)) {
-      return {ShapeVector{-2}};
+    if (query_info->IsDynamic() || input_infos[kPagedAttentionInputKeyCacheIndex]->IsDynamic() ||
+        input_infos[kPagedAttentionInputValueCacheIndex]->IsDynamic()) {
+      return {ShapeVector{abstract::Shape::kShapeRankAny}};
     }
 
-    if (IsDynamicShape(query_shape) || IsDynamicShape(key_shape) || IsDynamicShape(value_shape)) {
-      query_shape[q_shape_len - 1] = abstract::Shape::kShapeDimAny;
+    auto mla_v_dim = input_infos[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetScalarValueWithCheck<int64_t>();
+    if (mla_v_dim == 0) {
       return {query_shape};
     }
 
-    auto d_qk = key_shape[key_shape.size() - 1];
-    auto mla_v_dim = input_infos[kPagedAttentionInputMlaVDimIndex]->GetScalarValueWithCheck<int64_t>();
-    if (mla_v_dim > 0) {
-      query_shape[q_shape_len - 1] = query_shape[q_shape_len - 1] / d_qk * mla_v_dim;
-      return {query_shape};
+    auto key_cache_shape = input_infos[kPagedAttentionInputKeyCacheIndex]->GetShape();
+    auto k_head_dim = key_cache_shape[key_cache_shape.size() - 1];
+    if ((k_head_dim == abstract::Shape::kShapeDimAny) ||
+        (query_shape[query_shape.size() - 1] == abstract::Shape::kShapeDimAny)) {
+      query_shape[query_shape.size() - 1] = abstract::Shape::kShapeDimAny;
+    } else {
+      query_shape[query_shape.size() - 1] = query_shape[query_shape.size() - 1] / k_head_dim * mla_v_dim;
     }
-    // DimV differs from DimQK in MLA
-    auto d_vo = value_shape[value_shape.size() - 1];
-    query_shape[q_shape_len - 1] = query_shape[q_shape_len - 1] / d_qk * d_vo;
+
     return {query_shape};
   }
 
   std::vector<TypeId> InferType(const PrimitivePtr &primitive, const InferInfoPtrList &input_infos) const override {
-    MS_EXCEPTION_IF_NULL(primitive);
-    auto ms_context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(ms_context);
-    bool enable_infer_boost = ms_context->IsEnableInferBoost();
-    auto op_name = primitive->name();
-    std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeBFloat16};
-    auto query_types = input_infos[kPagedAttentionInputQueryIndex]->GetType();
-    auto key_types = input_infos[kPagedAttentionInputKeyCacheIndex]->GetType();
-    auto value_types = input_infos[kPagedAttentionInputValueCacheIndex]->GetType();
-    if (query_types.empty() || key_types.empty() || value_types.empty()) {
-      MS_LOG(EXCEPTION) << "Query, KeyCache or ValueCache must have types";
-    }
-    bool kvcache_quant = (key_types[0] == kNumberTypeInt8);
-    if (kvcache_quant && enable_infer_boost) {
-      std::set<TypeId> kvcache_types = {kNumberTypeInt8};
-      CheckAndConvertUtils::CheckTypeIdValid("key_cache", key_types[0], kvcache_types, op_name);
-      CheckAndConvertUtils::CheckTypeIdValid("value_cache", value_types[0], kvcache_types, op_name);
-    } else {
-      // q, k, v should have the same types, fp16 or bf16;
-      CheckAndConvertUtils::CheckTypeIdValid("key_cache", key_types[0], valid_types, op_name);
-      CheckAndConvertUtils::CheckTypeIdValid("query", query_types[0], valid_types, op_name);
-      auto mla_v_dim = input_infos[kPagedAttentionInputMlaVDimIndex]->GetScalarValueWithCheck<int64_t>();
-      if (mla_v_dim == 0) {
-        CheckAndConvertUtils::CheckTypeIdValid("value_cache", value_types[0], valid_types, op_name);
-      }
-    }
-
-    // check alibi_mask dtype equal to other inputs when alibi_mask is NOT None and infer_boost is ON
-    if (!input_infos[kPagedAttentionInputAlibiMaskIndex]->IsNone()) {
-      if (enable_infer_boost) {
-        if (input_infos[kPagedAttentionInputAlibiMaskIndex]->GetType().empty()) {
-          MS_LOG(EXCEPTION) << "Alibi Mask should have types";
-        }
-        CheckAndConvertUtils::CheckTypeIdValid(
-          "alibi_mask", input_infos[kPagedAttentionInputAlibiMaskIndex]->GetType()[0], valid_types, op_name);
-      } else {
-        MS_LOG(EXCEPTION) << "alibi_mask is not supported when infer_boost is disabled.";
-      }
-    }
+    CheckType(primitive, input_infos);
+    auto quant_type = input_infos[kPagedAttentionInputQuantTypeIndex]->GetScalarValueWithCheck<int64_t>();
+    auto out_data_type_ptr = input_infos[kPagedAttentionInputOutDataTypeIndex];
+    auto query_ptr = input_infos[kPagedAttentionInputQueryIndex];
+    if ((quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_ONLINE) ||
+        (quant_type == PAQuantType::kPA_TYPE_QUANT_QKV_OFFLINE)) {
+      auto out_data_type = out_data_type_ptr->GetScalarValueWithCheck<int64_t>();
 
-    // check antiquant scale and offset dtypes when they are not None.
-    if (enable_infer_boost && !input_infos[kPagedAttentionInputAntiquantScaleIndex]->IsNone() &&
-        !input_infos[kPagedAttentionInputAntiquantOffsetIndex]->IsNone()) {
-      bool valid_flag = false;
-      auto scale_type = input_infos[kPagedAttentionInputAntiquantScaleIndex]->GetType();
-      auto offset_type = input_infos[kPagedAttentionInputAntiquantOffsetIndex]->GetType();
-      if (scale_type.empty() || offset_type.empty()) {
-        MS_LOG(EXCEPTION) << "Antiquant scale and offset should have types";
-      }
-      auto scale_type_id = scale_type[0];
-      auto offset_type_id = offset_type[0];
-      if ((scale_type_id == kNumberTypeFloat16 && offset_type_id == kNumberTypeFloat16) ||
-          (scale_type_id == kNumberTypeInt64 && offset_type_id == kNumberTypeInt32) ||
-          (scale_type_id == kNumberTypeFloat32 && offset_type_id == kNumberTypeInt32)) {
-        valid_flag = true;
-      }
-      if (!valid_flag) {
-        MS_LOG(EXCEPTION) << "types of antiquant_scale and antiquant_offset are not supported: " << scale_type_id
-                          << " & " << offset_type_id;
-      }
-    }
+      switch (out_data_type) {
+        case PAOutDataType::kPA_ACL_FLOAT16:
+          return {kNumberTypeFloat16};
+        case PAOutDataType::kPA_ACL_BF16:
+          return {kNumberTypeBFloat16};
+        default:
+          MS_LOG(EXCEPTION) << "In PA full quant scene, we should set the output data type: 1(Float16) or 27(BFloat16)";
+      }
+    }
 
-    std::set<TypeId> block_tables_valid_types = {kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt64};
-    auto context_valid_types = block_tables_valid_types;
-    auto block_tables_types = input_infos[kPagedAttentionInputBlockTablesIndex]->GetType();
-    auto context_lens_types = input_infos[kPagedAttentionInputContextLensIndex]->GetType();
-    if (block_tables_types.empty() || context_lens_types.empty()) {
-      MS_LOG(EXCEPTION) << "block_tables_types or context_lens_types should have types, but now empty.";
-    }
-    CheckAndConvertUtils::CheckTypeIdValid("block_tables", block_tables_types[0], block_tables_valid_types, op_name);
-    CheckAndConvertUtils::CheckTypeIdValid("context_lens", context_lens_types[0], context_valid_types, op_name);
-    return {input_infos[0]->GetType()};
+    return {query_ptr->GetType()};
   }
 
   bool GeneralInferRegistered() const override { return true; }
 
   std::set<int64_t> GetValueDependArgIndices() const override {
-    return {kPagedAttentionInputContextLensIndex, kPagedAttentionInputQueryLensIndex};
-  };
+    return {
+      kPagedAttentionInputContextLensIndex,
+      kPagedAttentionInputQSeqLenIndex,
+      kPagedAttentionInputQHeadNumIndex,              // 15
+      kPagedAttentionInputQKScaleIndex,               // 16
+      kPagedAttentionInputKVHeadNumIndex,             // 17
+      kPagedAttentionInputMaskTypeIndex,              // 18
+      kPagedAttentionInputBatchRunStatusEnableIndex,  // 19
+      kPagedAttentionInputQuantTypeIndex,             // 20
+      kPagedAttentionInputOutDataTypeIndex,           // 21
+      kPagedAttentionInputHasQuantOffsetIndex,        // 22
+      kPagedAttentionInputCompressTypeIndex,          // 23
+      kPagedAttentionInputCalcTypeIndex,              // 24
+      kPagedAttentionInputScaleTypeIndex,             // 25
+      kPagedAttentionInputInputLayoutIndex,           // 26
+      kPagedAttentionInputMlaVDimHeadSizeIndex,       // 27
+    };
+  }
 };
 
 class PagedAttention : public InternalKernelMod {
@@ -160,38 +282,42 @@ class PagedAttention : public InternalKernelMod {
   ~PagedAttention() = default;
 
  protected:
-  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
-    auto &llm_manager = LLMManager::GetInstance();
-    llm_manager.add_force_resize_kernel(kernel_name_);
-    MS_LOG(INFO) << "Force op '" << kernel_name_ << "' to be resized to update op param 'seq_len'";
-    return InternalKernelMod::Init(inputs, outputs);
-  }
   internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs,
                                        const internal::OutputsImmutableInfoList &outputs,
                                        const std::vector<KernelTensor *> &ms_inputs,
                                        const std::vector<KernelTensor *> &ms_outputs) override {
-    auto last_input_index = kIndex15;
-    if (ms_inputs.size() <= last_input_index) {
-      MS_LOG(EXCEPTION) << "For op " << kernel_name_ << ", inputs number should be larger than " << last_input_index
-                        << ", but got " << ms_inputs.size();
-    }
-    param_.head_num = static_cast<int32_t>(ms_inputs[kIndex10]->GetValueWithCheck<int64_t>());
-    param_.tor = ms_inputs[kIndex11]->GetValueWithCheck<float>();
-    param_.kv_head_num = static_cast<int32_t>(ms_inputs[kIndex12]->GetValueWithCheck<int64_t>());
-    param_.kv_cache_quant_mode = ms_inputs[kIndex13]->GetValueWithCheck<int64_t>();
-    param_.mask_mode =
-      static_cast<int32_t>(ms_inputs[kIndex14]->GetValueWithCheck<int64_t>());
-    param_.mla_v_dim = static_cast<int32_t>(ms_inputs[kIndex15]->GetValueWithCheck<int64_t>());
-    has_attn_mask_ = (!(ms_inputs[kIndex7]->GetType()->isa<TypeNone>()));
-    has_alibi_mask_ = (!(ms_inputs[kIndex9]->GetType()->isa<TypeNone>()));
-
-    param_.has_q_seq_lens = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, &param_.q_seq_len);
-    (void)GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"batch_valid_length"}, &param_.kv_seq_len);
-
-    CheckMask();
-
+    param_.head_num = static_cast<int32_t>(ms_inputs[kPagedAttentionInputQHeadNumIndex]->GetValueWithCheck<int64_t>());
+    param_.tor = ms_inputs[kPagedAttentionInputQKScaleIndex]->GetValueWithCheck<float>();
+    param_.kv_head_num =
+      static_cast<int32_t>(ms_inputs[kPagedAttentionInputKVHeadNumIndex]->GetValueWithCheck<int64_t>());
+    param_.mask_type = static_cast<int32_t>(
+      ms_inputs[kPagedAttentionInputMaskTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.batch_run_status_enable =
+      ms_inputs[kPagedAttentionInputBatchRunStatusEnableIndex]->GetValueWithCheck<bool>();
+    param_.quant_type = static_cast<int32_t>(ms_inputs[kPagedAttentionInputQuantTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.out_data_type =
+      static_cast<int32_t>(ms_inputs[kPagedAttentionInputOutDataTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.has_quant_offset = ms_inputs[kPagedAttentionInputHasQuantOffsetIndex]->GetValueWithCheck<bool>();
+    param_.compress_type =
+      static_cast<int32_t>(ms_inputs[kPagedAttentionInputCompressTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.calc_type = static_cast<int32_t>(ms_inputs[kPagedAttentionInputCalcTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.scale_type = static_cast<int32_t>(ms_inputs[kPagedAttentionInputScaleTypeIndex]->GetValueWithCheck<int64_t>());
+    param_.input_layout =
+      static_cast<int32_t>(ms_inputs[kPagedAttentionInputInputLayoutIndex]->GetValueWithCheck<int64_t>());
+    param_.mla_v_dim =
+      static_cast<uint32_t>(ms_inputs[kPagedAttentionInputMlaVDimHeadSizeIndex]->GetValueWithCheck<int64_t>());
+    param_.q_seq_len = ms_inputs[kPagedAttentionInputQSeqLenIndex]->GetValueWithCheck<std::vector<int32_t>>();
+    param_.kv_seq_len = ms_inputs[kPagedAttentionInputContextLensIndex]->GetValueWithCheck<std::vector<int32_t>>();
     created_flag_ = true;
-    return internal::CreatePagedAttentionOp(inputs, outputs, param_, internal::kInternalPagedAttentionOpName);
+    auto input_format = ms_inputs[kPagedAttentionInputInputFormatIndex]->GetValueWithCheck<int64_t>();
+
+    if (input_format == PAInputFormat::kKVFormatNZ) {
+      auto inputs_new = inputs;
+      inputs_new[kPagedAttentionInputKeyCacheIndex].SetFormat(internal::kFormatFRACTAL_NZ);
+      inputs_new[kPagedAttentionInputValueCacheIndex].SetFormat(internal::kFormatFRACTAL_NZ);
+      inputs_new[kPagedAttentionInputAttnMaskIndex].SetFormat(internal::kFormatFRACTAL_NZ);
+      return internal::CreateASDPagedAttentionOp(inputs_new, outputs, param_,
+                                                 internal::kInternalASDPagedAttentionOpName);
+    }
+    return internal::CreateASDPagedAttentionOp(inputs, outputs, param_, internal::kInternalASDPagedAttentionOpName);
   }
 
   bool UpdateParam(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
@@ -201,59 +327,40 @@ class PagedAttention : public InternalKernelMod {
       return true;
     }
 
-    bool q_need_recreate = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"q_seq_lens"}, &param_.q_seq_len);
-    bool kv_need_recreate = GetSeqLenFromGraphAndCheckUpadate(kernel_name_, {"batch_valid_length"}, &param_.kv_seq_len);
+    auto q_need_recreate = GetSeqLenAndCheckUpdate(inputs[kPagedAttentionInputQSeqLenIndex], &param_.q_seq_len);
+    auto kv_need_recreate = GetSeqLenAndCheckUpdate(inputs[kPagedAttentionInputContextLensIndex], &param_.kv_seq_len);
     if (q_need_recreate || kv_need_recreate) {
-      CheckMask();
       auto ret = internal_op_->UpdateParam(&param_);
       if (ret != internal::kInternalOk) {
-        MS_LOG(ERROR) << "InternalPagedAttention UpdateParam failed, kernel_name: " << kernel_name_;
+        MS_LOG(ERROR) << "ASD PagedAttention UpdateParam failed, kernel_name: " << kernel_name_;
         return false;
       }
       return true;
    }
-
    return true;
  }
 
   uint64_t GenerateTilingKey(const std::vector<KernelTensor *> &inputs) override {
+    // User defined CacheKey, the inputs should include all the factors which
+    // will affect tiling result.
     return InternalTilingCache::GenerateKey(kernel_name_, inputs, param_.q_seq_len, param_.kv_seq_len,
-                                            param_.has_q_seq_lens, param_.mla_v_dim);
+                                            param_.mla_v_dim);
   }
 
   void InitKernelInputsOutputsIndex() override {
-    kernel_inputs_index_ = {kMlaInputQnopeIndex,      kMlaInputQropeIndex,       kMlaInputKvCacheIndex,
-                            kMlaInputKropeIndex,      kMlaInputBlockTablesIndex, kMlaInputAttnMaskIndex,
-                            kMlaInputDeqScaleQkIndex, kMlaInputDeqScalePvIndex};
-    kernel_outputs_index_ = {0, 1};
+    kernel_inputs_index_ = {
+      kPagedAttentionInputQueryIndex,       kPagedAttentionInputKeyCacheIndex,    kPagedAttentionInputValueCacheIndex,
+      kPagedAttentionInputBlockTablesIndex, kPagedAttentionInputAttnMaskIndex,    kPagedAttentionInputBatchRunStatusIndex,
+      kPagedAttentionInputKDescalekIndex,   kPagedAttentionInputKOffsetIndex,     kPagedAttentionInputVDescaleIndex,
+      kPagedAttentionInputVOffsetIndex,     kPagedAttentionInputRazorOffsetIndex, kPagedAttentionInputPScaleIndex,
+      kPagedAttentionInputLogNIndex,
+    };
+    kernel_outputs_index_ = {kPagedAttentionOutputIndex};
   }
 
  private:
-  inline void CheckMask() {
-    param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeNone;
-    auto enable_lookahead =
-      std::any_of(param_.q_seq_len.begin(), param_.q_seq_len.end(), [](int32_t seq_len) { return seq_len > 1; });
-    if (enable_lookahead) {
-      if (has_attn_mask_) {
-        param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeLookAhead;
-      }
-    } else {
-      param_.q_seq_len.clear();
-    }
-
-    if (has_alibi_mask_) {
-      if (param_.mask_type == internal::PagedAttentionParam::MaskType::kMaskTypeLookAhead) {
-        MS_LOG(EXCEPTION) << "For op " << kernel_name_ << ", lookahead cannot be enabled when alibi_mask exists.";
-      } else {
-        param_.mask_type = internal::PagedAttentionParam::MaskType::kMaskTypeAlibi;
-      }
-    }
-  }
-
   bool created_flag_{false};
-  bool has_attn_mask_{false};
-  bool has_alibi_mask_{false};
-  internal::PagedAttentionParam param_;
+  internal::ASDPagedAttentionParam param_;
 };
 } // namespace ms_custom_ops
 
diff --git a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
index 8088b7d0c..2698f0339 100644
--- a/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
+++ b/ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_pynative.cc
@@ -18,50 +18,93 @@
 #include
 #include
 #include
-#include "ccsrc/ops/ms_kernels_internal/mla/mla_common.h"
+#include "ccsrc/ops/ms_kernels_internal/paged_attention/paged_attention_common.h"
 #include "mindspore/ccsrc/ms_extension/api.h"
 #include "lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/internal.h"
 #include "ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h"
 #include "ccsrc/utils/utils.h"
 
 namespace ms_custom_ops {
-class MlaRunner : public InternalPyboostRunner {
+class PagedAttentionRunner : public InternalPyboostRunner {
  public:
-  MlaRunner(const std::string &op_name) : InternalPyboostRunner(op_name) {}
-  ~MlaRunner() = default;
+  PagedAttentionRunner(const std::string &op_name) : InternalPyboostRunner(op_name) {}
+  ~PagedAttentionRunner() = default;
 
-  void UpdateParam(int32_t head_size, float tor, int32_t kv_head, mindspore::internal::MLAParam::MaskType mask_type,
-                   int32_t is_ring, const std::vector<int32_t> &q_seq_len, const std::vector<int32_t> &kv_seq_len) {
-    param_.type = mindspore::internal::MLAParam::kSplitCache;
-    param_.head_size = head_size;
-    param_.tor = tor;
-    param_.kv_head = kv_head;
-    param_.mask_type = mask_type;
-    param_.is_ring = is_ring;
-    param_.q_seq_len = q_seq_len;
q_seq_len; - param_.kv_seq_len = kv_seq_len; + void SetParam(int32_t q_head_num, float qk_scale, int32_t kv_head_num, int32_t mask_type, + bool batch_run_status_enable, int32_t quant_type, int32_t out_data_type, bool has_quant_offset, + int32_t compress_type, int32_t calc_type, int32_t scale_type, int32_t input_layout, uint32_t mla_v_dim, + const std::vector &q_seq_len, const std::vector &kv_seq_len) { + param_.head_num = q_head_num; + param_.tor = qk_scale; + param_.kv_head_num = kv_head_num; + param_.mask_type = static_cast(mask_type); + param_.batch_run_status_enable = batch_run_status_enable; + param_.quant_type = quant_type; + param_.out_data_type = out_data_type; + param_.has_quant_offset = has_quant_offset; + param_.compress_type = compress_type; + param_.calc_type = calc_type; + param_.scale_type = scale_type; + param_.input_layout = input_layout; + param_.mla_v_dim = mla_v_dim; + auto is_q_changed = CheckAndUpdate(q_seq_len, ¶m_.q_seq_len); + auto is_kv_changed = CheckAndUpdate(kv_seq_len, ¶m_.kv_seq_len); + need_update_param_ = is_q_changed | is_kv_changed; } + void SetInputFormat(MlaInputFormat input_format) { input_format_ = input_format; } + protected: + bool UpdateParam() override { + if (created_flag_) { + // the q_seq_len and kv_seq_len are inited in CreatedKernel, so there is no need to load them again + created_flag_ = false; + } + + if (need_update_param_) { + auto ret = internal_op_->UpdateParam(¶m_); + if (ret != internal::kInternalOk) { + MS_LOG(ERROR) << "ASD PagedAttention UpdateParam failed in MlaRunner."; + return false; + } + return true; + } + } + internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, const internal::OutputsImmutableInfoList &outputs) override { - return mindspore::internal::CreateMLAOp(inputs, outputs, param_, internal::kInternalMLAOpName); + create_flag_ = true; + if (input_format_ == PAInputFormat::kKVFormatNZ) { + auto inputs_new = inputs; + inputs_new[kPagedAttentionInputKeyCacheIndex].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_new[kPagedAttentionInputValueCacheIndex].SetFormat(internal::kFormatFRACTAL_NZ); + inputs_new[kPagedAttntionInputAttnMaskIndex].SetFormat(internal::kFormatFRACTAL_NZ); + return internal::CreateASDPagedAttentionOp(inputs_new, outputs, param_, + internal::kInternalASDPagedAttentionOpName); + } + return mindspore::internal::CreateASDPagedAttentionOp(inputs, outputs, param_, + internal::kInternalASDPagedAttentionOpName); } private: - mindspore::internal::MLAParam param_; + mindspore::internal::ASDPagedAttentionParam param_; + bool created_flag_{true}; + bool need_update_param_{false}; + PAInputFormat input_format_{kKVFormatND}; }; -std::vector mla_atb(const ms::Tensor &q_nope, const ms::Tensor &q_rope, const ms::Tensor &ctkv, - const ms::Tensor &k_rope, const ms::Tensor &block_tables, - const std::optional &attn_mask, - const std::optional &deq_scale_qk, - const std::optional &deq_scale_pv, - const std::optional &q_seq_lens, - const std::optional &context_lens, int64_t head_num, double scale_value, - int64_t kv_head_num, int64_t mask_mode, int64_t is_ring) { - static auto op_name = "Mla"; - auto runner = std::make_shared(op_name); +std::vector paged_attention_atb( + cconst ms::Tensor &query, const ms::Tensor &key_cache, const ms::Tensor &value_cache, const ms::Tensor &block_tables, + const std::optional &context_lens, const std::optional &attn_mask, + const std::optional &q_seq_lens, const std::optional &batch_run_status, + const std::optional &k_descale, const std::optional 
 
-std::vector<ms::Tensor> mla_atb(const ms::Tensor &q_nope, const ms::Tensor &q_rope, const ms::Tensor &ctkv,
-                                const ms::Tensor &k_rope, const ms::Tensor &block_tables,
-                                const std::optional<ms::Tensor> &attn_mask,
-                                const std::optional<ms::Tensor> &deq_scale_qk,
-                                const std::optional<ms::Tensor> &deq_scale_pv,
-                                const std::optional<ms::Tensor> &q_seq_lens,
-                                const std::optional<ms::Tensor> &context_lens, int64_t head_num, double scale_value,
-                                int64_t kv_head_num, int64_t mask_mode, int64_t is_ring) {
-  static auto op_name = "Mla";
-  auto runner = std::make_shared<MlaRunner>(op_name);
+std::vector<ms::Tensor> paged_attention_atb(
+    const ms::Tensor &query, const ms::Tensor &key_cache, const ms::Tensor &value_cache, const ms::Tensor &block_tables,
+    const std::optional<ms::Tensor> &context_lens, const std::optional<ms::Tensor> &attn_mask,
+    const std::optional<ms::Tensor> &q_seq_lens, const std::optional<ms::Tensor> &batch_run_status,
+    const std::optional<ms::Tensor> &k_descale, const std::optional<ms::Tensor> &k_offset,
+    const std::optional<ms::Tensor> &v_descale, const std::optional<ms::Tensor> &v_offset,
+    const std::optional<ms::Tensor> &razor_offset, const std::optional<ms::Tensor> &p_scale,
+    const std::optional<ms::Tensor> &log_n, int64_t q_head_num, double qk_scale, int64_t kv_head_num, int64_t mask_type,
+    bool batch_run_status_enable, int64_t quant_type, int64_t out_data_type, bool has_quant_offset, int64_t compress_type,
+    int64_t calc_type, int64_t scale_type, int64_t input_layout, int64_t mla_v_dim, int64_t input_format) {
+  static auto op_name = "PagedAttention";
+  auto runner = std::make_shared<PagedAttentionRunner>(op_name);
   MS_EXCEPTION_IF_NULL(runner);
 
   if (!q_seq_lens.has_value() || !context_lens.has_value()) {
@@ -69,52 +112,85 @@ std::vector<ms::Tensor> mla_atb(const ms::Tensor &q_nope, const ms::Tensor &q_ro
     MS_LOG(EXCEPTION) << "For " << op_name
                       << ", the q_seq_lens and context_lens can not be None, but got q_seq_lens.has_value(): "
                       << q_seq_lens.has_value() << ", context_lens.has_value(): " << context_lens.has_value();
   }
-
+  if (input_format != PAInputFormat::kKVFormatND && input_format != PAInputFormat::kKVFormatNZ) {
+    MS_LOG(EXCEPTION) << "For " << op_name << ", the input_format is invalid: " << input_format;
+  }
   auto q_seq_lens_value = GetValueFromTensor<std::vector<int32_t>>(q_seq_lens.value(), op_name, "q_seq_lens");
   auto context_lens_value = GetValueFromTensor<std::vector<int32_t>>(context_lens.value(), op_name, "context_lens");
-  runner->UpdateParam(static_cast<int32_t>(head_num), static_cast<float>(scale_value),
-                      static_cast<int32_t>(kv_head_num),
-                      static_cast<mindspore::internal::MLAParam::MaskType>(mask_mode), static_cast<int32_t>(is_ring),
-                      q_seq_lens_value, context_lens_value);
+  runner->SetInputFormat(static_cast<PAInputFormat>(input_format));
+  runner->SetParam(static_cast<int32_t>(q_head_num), static_cast<float>(qk_scale), static_cast<int32_t>(kv_head_num),
+                   static_cast<int32_t>(mask_type), batch_run_status_enable, static_cast<int32_t>(quant_type),
+                   static_cast<int32_t>(out_data_type), has_quant_offset, static_cast<int32_t>(compress_type),
+                   static_cast<int32_t>(calc_type), static_cast<int32_t>(scale_type),
+                   static_cast<int32_t>(input_layout), static_cast<uint32_t>(mla_v_dim), q_seq_lens_value,
+                   context_lens_value);
 
   // Setup the runner with all parameters (including hash calculation)
-  runner->Setup(op_name, q_nope, q_rope, ctkv, k_rope, block_tables, attn_mask, deq_scale_qk, deq_scale_pv, q_seq_lens,
-                context_lens, head_num, scale_value, kv_head_num, mask_mode, is_ring);
+  runner->Setup(op_name, query, key_cache, value_cache, block_tables, context_lens, attn_mask, q_seq_lens,
+                batch_run_status, k_descale, k_offset, v_descale, v_offset, razor_offset, p_scale, log_n, q_head_num,
+                qk_scale, kv_head_num, mask_type, batch_run_status_enable, quant_type, out_data_type, has_quant_offset,
+                compress_type, calc_type, scale_type, input_layout, mla_v_dim);
 
-  auto attn_out = ms::Tensor(q_nope.data_type(), q_nope.shape());
-  auto lse_out = ms::Tensor(q_nope.data_type(), {0});
+  auto output_data_type = query.data_type();
+  if (query.data_type() == kNumberTypeInt8 && out_data_type != PAOutDataType::kPA_ACL_DT_UNDEFINED) {
+    if (out_data_type == PAOutDataType::kPA_ACL_FLOAT16) {
+      output_data_type = kNumberTypeFloat16;
+    } else if (out_data_type == PAOutDataType::kPA_ACL_BF16) {
+      output_data_type = kNumberTypeBFloat16;
+    }
+  }
+  auto attn_out = ms::Tensor(output_data_type, query.shape());
 
-  std::vector<ms::Tensor> inputs = {q_nope,
-                                    q_rope,
-                                    ctkv,
-                                    k_rope,
+  std::vector<ms::Tensor> inputs = {query,
+                                    key_cache,
+                                    value_cache,
                                     block_tables,
                                     GetTensorOrEmpty(attn_mask),
-                                    GetTensorOrEmpty(deq_scale_qk),
-                                    GetTensorOrEmpty(deq_scale_pv)};
-  std::vector<ms::Tensor> outputs = {attn_out, lse_out};
+                                    GetTensorOrEmpty(batch_run_status),
+                                    GetTensorOrEmpty(k_descale),
+                                    GetTensorOrEmpty(k_offset),
+                                    GetTensorOrEmpty(v_descale),
+                                    GetTensorOrEmpty(v_offset),
+                                    GetTensorOrEmpty(razor_offset),
+                                    GetTensorOrEmpty(p_scale),
+                                    GetTensorOrEmpty(log_n)};
+  std::vector<ms::Tensor> outputs = {attn_out};
   runner->GetOrCreateKernel(inputs, outputs);
   runner->Run(inputs, outputs);
   return outputs;
 }
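When the query is int8 (quantized path), the kernel cannot derive a floating-point output dtype from the input, so out_data_type names it explicitly; otherwise the output simply follows the query dtype. A self-contained sketch of that selection rule, using stand-in enum values rather than the framework's PAOutDataType/TypeId constants:

    // Stand-in enums; the real constants come from the framework headers.
    #include <cstdint>

    enum class OutDataType : int32_t { kUndefined = -1, kFloat16 = 1, kBFloat16 = 27 };
    enum class TypeId : int32_t { kInt8, kFloat16, kBFloat16 };

    TypeId SelectOutputDtype(TypeId query_dtype, OutDataType requested) {
      // Non-quantized queries keep their own dtype; int8 queries must request
      // a floating-point output type explicitly.
      if (query_dtype != TypeId::kInt8 || requested == OutDataType::kUndefined) {
        return query_dtype;
      }
      return requested == OutDataType::kFloat16 ? TypeId::kFloat16 : TypeId::kBFloat16;
    }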
 
-auto pyboost_mla(const ms::Tensor &q_nope, const ms::Tensor &q_rope, const ms::Tensor &ctkv, const ms::Tensor &k_rope,
-                 const ms::Tensor &block_tables, const std::optional<ms::Tensor> &attn_mask,
-                 const std::optional<ms::Tensor> &deq_scale_qk, const std::optional<ms::Tensor> &deq_scale_pv,
-                 const std::optional<ms::Tensor> &q_seq_lens, const std::optional<ms::Tensor> &context_lens,
-                 int64_t head_num, double scale_value, int64_t kv_head_num, int64_t mask_mode, int64_t is_ring) {
-  return ms::pynative::PyboostRunner::Call<2>(mla_atb, q_nope, q_rope, ctkv, k_rope, block_tables, attn_mask,
-                                              deq_scale_qk, deq_scale_pv, q_seq_lens, context_lens, head_num,
-                                              scale_value, kv_head_num, mask_mode, is_ring);
+auto pyboost_paged_attention(const ms::Tensor &query, const ms::Tensor &key_cache, const ms::Tensor &value_cache,
+                             const ms::Tensor &block_tables, const std::optional<ms::Tensor> &context_lens,
+                             const std::optional<ms::Tensor> &attn_mask, const std::optional<ms::Tensor> &q_seq_lens,
+                             const std::optional<ms::Tensor> &batch_run_status,
+                             const std::optional<ms::Tensor> &k_descale, const std::optional<ms::Tensor> &k_offset,
+                             const std::optional<ms::Tensor> &v_descale, const std::optional<ms::Tensor> &v_offset,
+                             const std::optional<ms::Tensor> &razor_offset, const std::optional<ms::Tensor> &p_scale,
+                             const std::optional<ms::Tensor> &log_n, int64_t q_head_num, double qk_scale,
+                             int64_t kv_head_num, int64_t mask_type, bool batch_run_status_enable, int64_t quant_type,
+                             int64_t out_data_type, bool has_quant_offset, int64_t compress_type, int64_t calc_type,
+                             int64_t scale_type, int64_t input_layout, int64_t mla_v_dim, int64_t input_format) {
+  return ms::pynative::PyboostRunner::Call<1>(
+      paged_attention_atb, query, key_cache, value_cache, block_tables, context_lens, attn_mask, q_seq_lens,
+      batch_run_status, k_descale, k_offset, v_descale, v_offset, razor_offset, p_scale, log_n, q_head_num, qk_scale,
+      kv_head_num, mask_type, batch_run_status_enable, quant_type, out_data_type, has_quant_offset, compress_type,
+      calc_type, scale_type, input_layout, mla_v_dim, input_format);
 }
 }  // namespace ms_custom_ops
 
 MS_CUSTOM_OPS_EXTENSION_MODULE(m) {
-  m.def("mla", &ms_custom_ops::pyboost_mla, "Multi-head Latent Attention", pybind11::arg("q_nope"),
-        pybind11::arg("q_rope"), pybind11::arg("ctkv"), pybind11::arg("k_rope"), pybind11::arg("block_tables"),
-        pybind11::arg("attn_mask") = std::nullopt, pybind11::arg("deq_scale_qk") = std::nullopt,
-        pybind11::arg("deq_scale_pv") = std::nullopt, pybind11::arg("q_seq_lens") = std::nullopt,
-        pybind11::arg("context_lens") = std::nullopt, pybind11::arg("head_num") = 32,
-        pybind11::arg("scale_value") = 0.0, pybind11::arg("kv_head_num") = 1, pybind11::arg("mask_mode") = 0,
-        pybind11::arg("is_ring") = 0);
+  m.def("paged_attention", &ms_custom_ops::pyboost_paged_attention, "PagedAttention", pybind11::arg("query"),
+        pybind11::arg("key_cache"), pybind11::arg("value_cache"), pybind11::arg("block_tables"),
+        pybind11::arg("context_lens") = std::nullopt, pybind11::arg("attn_mask") = std::nullopt,
+        pybind11::arg("q_seq_lens") = std::nullopt, pybind11::arg("batch_run_status") = std::nullopt,
+        pybind11::arg("k_descale") = std::nullopt, pybind11::arg("k_offset") = std::nullopt,
+        pybind11::arg("v_descale") = std::nullopt, pybind11::arg("v_offset") = std::nullopt,
+        pybind11::arg("razor_offset") = std::nullopt, pybind11::arg("p_scale") = std::nullopt,
+        pybind11::arg("log_n") = std::nullopt, pybind11::arg("q_head_num") = 0, pybind11::arg("qk_scale") = 1.0,
+        pybind11::arg("kv_head_num") = 0, pybind11::arg("mask_type") = 0,
+        pybind11::arg("batch_run_status_enable") = false, pybind11::arg("quant_type") = 0,
+        pybind11::arg("out_data_type") = -1, pybind11::arg("has_quant_offset") = false,
+        pybind11::arg("compress_type") = 0, pybind11::arg("calc_type") = 0, pybind11::arg("scale_type") = 0,
+        pybind11::arg("input_layout") = 0, pybind11::arg("mla_v_dim") = 0, pybind11::arg("input_format") = 0);
 }
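The binding exposes every optional tensor with a std::nullopt default, so Python callers can omit the quantization inputs they do not use. A minimal, self-contained pybind11 pattern showing the same optional-argument mapping (module and function names here are illustrative only):

    // std::optional arguments need pybind11/stl.h so None maps to nullopt.
    #include <optional>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>

    int demo(int required, std::optional<int> maybe) {
      // Fall back to the required argument when the optional one is None.
      return maybe.value_or(required);
    }

    PYBIND11_MODULE(pa_demo, m) {
      m.def("demo", &demo, pybind11::arg("required"), pybind11::arg("maybe") = std::nullopt);
    }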
pybind11::arg("kv_head_num") = 0, pybind11::arg("mask_type") = 0, + pybind11::arg("batch_run_status_enable") = false, pybind11::arg("quant_type") = 0, + pybind11::arg("out_data_type") = -1, pybind11::arg("has_quant_offset") = false, + pybind11::arg("compress_type") = 0, pybind11::arg("calc_type") = 0, pybind11::arg("scale_type") = 0, + pybind11::arg("input_layout") = 0, pybind11::arg("mla_v_dim") = 0, pybind11::arg("input_format") = 0); } diff --git a/yaml/ms_kernels_internal/paged_attention_op.yaml b/yaml/ms_kernels_internal/paged_attention_op.yaml index a20fddd16..ddb2999be 100644 --- a/yaml/ms_kernels_internal/paged_attention_op.yaml +++ b/yaml/ms_kernels_internal/paged_attention_op.yaml @@ -7,54 +7,83 @@ paged_attention: dtype: tensor value_cache: dtype: tensor - default: None block_tables: dtype: tensor - default: None context_lens: dtype: tensor default: None - antiquant_scale: + attn_mask: dtype: tensor default: None - antiquant_offset: + q_seq_lens: dtype: tensor default: None - attn_mask: + btach_run_status: dtype: tensor default: None - q_seq_lens: + k_descale: + dtype: tensor + default: None + k_offset: + dtype: tensor + default: None + v_descale: + dtype: tensor + default: None + v_offset: + dtype: tensor + default: None + razor_offset: dtype: tensor default: None - alibi_mask: + p_scale: dtype: tensor default: None - head_num: + log_n: + dtype: tensor + default: None + q_head_num: dtype: int - prim_init: True - scale_value: + default: 0 + qk_scale: dtype: float - prim_init: True + default: 1.0 kv_head_num: dtype: int - prim_init: True - kv_cache_quant_mode: + default: 0 + mask_type: + dtype: int + default: 0 + batch_run_status_enable: + dtype: bool + default: False + quant_type: + dtype: int + default: 0 + out_data_type: + dtype: int + default: -1 + has_quant_offset: + dtype: bool + default: False + compressType: + dtype: int + default: 0 + calc_type: + dtype: int + default: 0 + scale_type: + dtype: int + default: 0 + input_layout: dtype: int - default: "'DEFAULT'" - prim_init: True - arg_handler: str_to_enum - mask_mode: + default: 0 # 0(BSND)/1(BNSD) + mla_v_head_size: dtype: int - default: "'MASK_DEFAULT'" - prim_init: True - arg_handler: str_to_enum - mla_v_dim: + default: 0 # [0, 576] + input_format: dtype: int default: 0 - prim_init: True - returns: attention_out: dtype: tensor - class: - name: PagedAttention \ No newline at end of file -- Gitee