diff --git a/mxrec_add_ons/rec_for_torch/operators/disetangle_attention/READMD.md b/mxrec_add_ons/rec_for_torch/operators/disetangle_attention/READMD.md
index f8a43f0a1579798625a313bb5d40dbaf23af2c34..fabdb712107f5efa483fed17fa0eec7a6d2f2e2b 100644
--- a/mxrec_add_ons/rec_for_torch/operators/disetangle_attention/READMD.md
+++ b/mxrec_add_ons/rec_for_torch/operators/disetangle_attention/READMD.md
@@ -73,10 +73,10 @@ c) Operator constraints:
 
 * Supported hardware: Atlas A2 series products;
 * Supported CANN versions: 8.2.RC1.alpha001 and later;
-* b: at most 100
-* n: at most 64
-* s: currently must be 256
-* d: currently must be a multiple of 16, at most 512
+* b: value range [0, 100]
+* n: value range [0, 64]
+* s: must be 256
+* d: value range [1, 512]
 * Data type: fp16
 
 ## Operator logic
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/disetangle_attention/DisetangleAttenFusion.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/disetangle_attention/DisetangleAttenFusion.cpp
index 1a7dbc444b0d047d425ba8885e4e1b8c519c8238..dac428408be21aae17f63ddcc1f9ddc551b639d5 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/disetangle_attention/DisetangleAttenFusion.cpp
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/disetangle_attention/DisetangleAttenFusion.cpp
@@ -18,6 +18,7 @@ constexpr uint32_t CONST_4 = 4;
 constexpr uint32_t MAX_BATCH = 100;
 constexpr uint32_t MAX_NUM = 64;
 constexpr uint32_t MAX_DIM = 512;
+constexpr uint32_t SUPPORT_SEQ = 256;
 
 bool pos_attr_safe_get(const std::string pos_attr_type, int& pos_attr)
 {
@@ -34,6 +35,26 @@ bool pos_attr_safe_get(const std::string pos_attr_type, int& pos_attr)
     return true;
 }
 
+void tensor_dtype_check(const at::Tensor& query_layer, const at::Tensor& key_layer, const at::Tensor& value_layer,
+                        const at::Tensor& pos_key_layer, const at::Tensor& pos_query_layer,
+                        const at::Tensor& relative_pos, const at::Tensor& attn_mask)
+{
+    TORCH_CHECK(query_layer.scalar_type() == at::kHalf,
+                "float16 query_layer tensor expected but got dtype: ", query_layer.scalar_type());
+    TORCH_CHECK(key_layer.scalar_type() == at::kHalf,
+                "float16 key_layer tensor expected but got dtype: ", key_layer.scalar_type());
+    TORCH_CHECK(value_layer.scalar_type() == at::kHalf,
+                "float16 value_layer tensor expected but got dtype: ", value_layer.scalar_type());
+    TORCH_CHECK(pos_key_layer.scalar_type() == at::kHalf,
+                "float16 pos_key_layer tensor expected but got dtype: ", pos_key_layer.scalar_type());
+    TORCH_CHECK(pos_query_layer.scalar_type() == at::kHalf,
+                "float16 pos_query_layer tensor expected but got dtype: ", pos_query_layer.scalar_type());
+    TORCH_CHECK(relative_pos.scalar_type() == at::kLong,
+                "int64 relative_pos tensor expected but got dtype: ", relative_pos.scalar_type());
+    TORCH_CHECK(attn_mask.scalar_type() == at::kHalf,
+                "float16 attn_mask tensor expected but got dtype: ", attn_mask.scalar_type());
+}
+
 void tensor_format_check(const at::Tensor& query_layer, const at::Tensor& key_layer, const at::Tensor& value_layer,
                          const at::Tensor& pos_key_layer, const at::Tensor& pos_query_layer,
                          const at::Tensor& relative_pos, const at::Tensor& attn_mask)
@@ -101,15 +122,17 @@ std::tuple<at::Tensor, at::Tensor> DisetangleAttentionPTA(
     auto relative_pos_conti = relative_pos.contiguous();
     auto mask_conti = attn_mask.contiguous();
 
+    tensor_dtype_check(query_layer, key_layer, value_layer, pos_key_layer, pos_query_layer, relative_pos, attn_mask);
     tensor_format_check(query_layer, key_layer, value_layer, pos_key_layer, pos_query_layer, relative_pos, attn_mask);
 
     auto batch = query_layer.size(0);
     auto head = query_layer.size(1);
     auto seq = query_layer.size(2);
     auto dim = query_layer.size(3);
-    TORCH_CHECK(batch <= MAX_BATCH, "current max batch is 100 but got ", batch);
-    TORCH_CHECK(head <= MAX_NUM, "current max head is 64 but got ", head);
-    TORCH_CHECK(dim <= MAX_DIM, "current max dim is 512 but got ", dim);
+    TORCH_CHECK(batch <= MAX_BATCH && batch >= 0, "current batch range is [0, 100] but got ", batch);
+    TORCH_CHECK(head <= MAX_NUM && head >= 0, "current head range is [0, 64] but got ", head);
+    TORCH_CHECK(dim <= MAX_DIM && dim >= CONST_1, "current dim range is [1, 512] but got ", dim);
+    TORCH_CHECK(seq == SUPPORT_SEQ, "current seq only supports 256 but got ", seq);
 
     at::Tensor attn_output = at::empty_like(query_layer_conti);
     at::Tensor attn_probs = at::empty({batch, head, seq, seq}, query_layer_conti.options());
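Reviewer note: the sketch below is a minimal, hypothetical driver (not part of this patch) that exercises the new tensor_dtype_check path with inputs satisfying the documented constraints (b in [0, 100], n in [0, 64], s == 256, d in [1, 512], fp16 inputs, int64 relative_pos). It assumes libtorch is available and that tensor_dtype_check from DisetangleAttenFusion.cpp is visible in the translation unit; the shapes chosen for relative_pos and attn_mask are illustrative only, since this check inspects dtypes, not shapes, and runs fine on CPU tensors without an Ascend device.

// Hypothetical standalone sketch; not part of the patch.
#include <torch/torch.h>

int main()
{
    const int64_t b = 2, n = 8, s = 256, d = 16;  // within the documented ranges

    // fp16 inputs, as the README now requires
    auto q    = torch::randn({b, n, s, d}).to(torch::kHalf);
    auto k    = torch::randn({b, n, s, d}).to(torch::kHalf);
    auto v    = torch::randn({b, n, s, d}).to(torch::kHalf);
    auto pk   = torch::randn({b, n, s, d}).to(torch::kHalf);
    auto pq   = torch::randn({b, n, s, d}).to(torch::kHalf);
    auto rel  = torch::zeros({b, s, s}, torch::kLong);      // int64, per the new check
    auto mask = torch::ones({b, 1, s, s}).to(torch::kHalf); // shape is an assumption

    // Passes: every dtype matches. Swapping any input to float32 (or rel to
    // int32) makes TORCH_CHECK throw a c10::Error carrying the
    // "... expected but got dtype" message added in this patch.
    tensor_dtype_check(q, k, v, pk, pq, rel, mask);
    return 0;
}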