diff --git a/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py b/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py
new file mode 100644
index 0000000000000000000000000000000000000000..361399d6c872b2526a83af61a3594bfc311c276b
--- /dev/null
+++ b/ascend/test/sglang/v0.4.8/test_sparse_fwd_kernel_flash_decode_stage3.py
@@ -0,0 +1,82 @@
+import sys
+
+sys.path.append("..")
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+
+import test_common
+
+
+# source: python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
+@triton.jit
+def _sparse_fwd_kernel_flash_decode_stage3(
+    Mid_O,  # [batch, head, seq_block_num, head_dim]
+    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
+    O,  # [batch, head, head_dim]
+    seq_len,  # NOTE: this could be a constexpr, but we may support a dynamic heavy-token count in the future
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_mid_o_eb,
+    stride_mid_o_eh,
+    stride_obs,
+    stride_oh,
+    BLOCK_SEQ: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    # Number of sequence blocks to reduce over: ceil(seq_len / BLOCK_SEQ), clamped at 0.
+    block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ
+
+    sum_exp = 0.0
+    max_logic = -float("inf")
+    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
+
+    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
+    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
+    for block_seq_n in range(0, block_n_size, 1):
+        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
+        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
+        new_max_logic = tl.maximum(tlogic, max_logic)
+
+        # Online log-sum-exp merge: rescale the running accumulator to the new max.
+        old_scale = tl.exp(max_logic - new_max_logic)
+        acc *= old_scale
+        exp_logic = tl.exp(tlogic - new_max_logic)
+        acc += exp_logic * tv
+        sum_exp = sum_exp * old_scale + exp_logic
+        max_logic = new_max_logic
+
+    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
+    return
+
+
+def test_sparse_fwd_kernel_flash_decode_stage3(ptfile_path):
+    try:
+        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
+    except Exception as e:
+        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")
+
+    # ptfile format:
+    # [input_data] (dict):
+    #     key : value
+    # [gpu_output] (dict):
+    #     key : value
+    # [grid]:
+    #     (1,)
+    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')
+
+    _sparse_fwd_kernel_flash_decode_stage3[data["grid"]](**input_data)
+
+    # Compare the NPU results against the saved GPU reference.
+    try:
+        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
+    except ValueError as e:
+        pytest.fail(f"precision comparison against the GPU reference failed: {str(e)}")
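
For readers checking the math: stage 3 merges per-block partial attention outputs with an online log-sum-exp, rescaling the running accumulator whenever a larger max appears. A minimal PyTorch sketch of the same combine follows; flash_decode_stage3_ref is a hypothetical helper, not part of this test or of sglang, and it assumes the contiguous [batch, head, seq_block_num, head_dim] layout described in the kernel's comments.

import torch

def flash_decode_stage3_ref(mid_o, mid_lse, seq_len, block_seq):
    # Same ceil-division block count as the kernel, clamped at 0 for empty sequences.
    n_blocks = 0 if seq_len <= 0 else (seq_len + block_seq - 1) // block_seq
    lse = mid_lse[..., :n_blocks].float()        # [batch, head, n_blocks]
    v = mid_o[..., :n_blocks, :].float()         # [batch, head, n_blocks, head_dim]
    m = lse.max(dim=-1, keepdim=True).values     # global max == the kernel's final max_logic
    w = torch.exp(lse - m)                       # per-block exp_logic
    acc = (w.unsqueeze(-1) * v).sum(dim=-2)      # weighted sum of partial outputs
    return acc / w.sum(dim=-1, keepdim=True)     # divide by sum_exp

Because the kernel rescales acc and sum_exp to the new max at every step, its loop computes exactly this closed form, just one block at a time.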
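
The ptfile layout in the test's comment is terse, so a sketch of how a compatible fixture could be built may help. Everything below is assumed rather than taken from the real capture pipeline: the key names simply mirror the kernel's parameter list so that **input_data splats cleanly, the shapes are examples, and gpu_output holds a CPU-computed stand-in where the real files store tensors captured from a CUDA run.

import torch

# Example dimensions (assumed); grid is (batch, head), one program per (batch, head) pair.
batch, head, n_blocks, head_dim, block_seq, seq_len = 2, 4, 8, 128, 64, 500

mid_o = torch.randn(batch, head, n_blocks, head_dim)
mid_lse = torch.randn(batch, head, n_blocks)
out = torch.empty(batch, head, head_dim)

input_data = {
    "Mid_O": mid_o,
    "Mid_O_LogExpSum": mid_lse,
    "O": out,
    "seq_len": seq_len,
    "stride_mid_ob": mid_o.stride(0),
    "stride_mid_oh": mid_o.stride(1),
    "stride_mid_os": mid_o.stride(2),
    "stride_mid_o_eb": mid_lse.stride(0),
    "stride_mid_o_eh": mid_lse.stride(1),
    "stride_obs": out.stride(0),
    "stride_oh": out.stride(1),
    "BLOCK_SEQ": block_seq,
    "BLOCK_DMODEL": head_dim,
}

# CPU stand-in for the saved GPU output (ceil(500 / 64) == 8, so every block is live).
m = mid_lse.max(dim=-1, keepdim=True).values
w = torch.exp(mid_lse - m)
ref_o = (w.unsqueeze(-1) * mid_o).sum(dim=2) / w.sum(dim=-1, keepdim=True)

torch.save(
    {"input_data": input_data, "gpu_output": {"O": ref_o}, "grid": (batch, head)},
    "test_sparse_fwd_kernel_flash_decode_stage3.pt",
)

Note that "O" appears in both dicts on purpose: the kernel writes its result into input_data["O"] in place, and compare_data_precision then checks it against gpu_output["O"].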