# ascend/test/sglang/v0.4.8/test_triton_ops_merge_state_kernel.py
# (reconstructed from a whitespace-mangled `git diff` paste; the original
# diff added this file as new, so only the file body is restored here)
import sys

# NOTE(review): the shared helpers presumably live one directory up; the
# path must be extended BEFORE `test_common` is imported. The original
# appended it after the import, which would break a clean run from this
# directory — confirm against how the suite is normally invoked.
sys.path.append("..")

import pytest
import torch
import triton
import triton.language as tl

import test_common


# source: python/sglang/srt/layers/attention/triton_ops/merge_state.py
@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    """Merge two partial attention outputs into one.

    One program instance handles a single (token, head) pair: it combines
    the prefix and suffix partial outputs using softmax-rescaling weights
    derived from their log-sum-exp (LSE) values, and optionally stores the
    merged LSE as well.
    """
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)  # unused; kept to mirror the upstream source
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    # A +inf LSE marks an empty/invalid partition; map it to -inf so it
    # contributes zero weight instead of poisoning the max/exp arithmetic.
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    # Numerically stable combination: subtract the shared max before exp.
    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    # PADDED_HEAD_SIZE is the next power-of-two >= HEAD_SIZE required by
    # tl.arange; the mask keeps loads/stores within the real head size.
    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    # Weighted average of the two partial outputs (weights sum to 1).
    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )


def test_triton_ops_merge_state_kernel(ptfile_path):
    """Replay a GPU-recorded invocation of ``merge_state_kernel`` on NPU.

    Loads a ``.pt`` capture (inputs, GPU reference outputs, launch grid),
    launches the kernel on the NPU with the captured inputs, and compares
    the NPU results against the GPU reference.

    ``ptfile_path`` — path to the capture file; presumably supplied by a
    conftest fixture / CLI option (not visible in this file).
    """
    try:
        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
    except Exception as e:
        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")

    # ptfile format:
    # [input_data] (dict):
    #     key : value
    # [gpu_output] (dict):
    #     key : value
    # [grid] :
    #     (1,)
    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')

    # Launch with the same grid that was recorded on GPU; the kernel writes
    # its results into the tensors inside input_data in place.
    merge_state_kernel[data["grid"]](**input_data)

    # Compare the results of GPU and NPU. Include the comparison error in
    # the failure message instead of discarding it (original dropped `e`).
    try:
        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
    except ValueError as e:
        pytest.fail(f"The testcase failed: {e}")