From 2191e9191bfaeb557a03396a9a8b7799a8bf5385 Mon Sep 17 00:00:00 2001
From: zhuxg33 <1074959344@qq.com>
Date: Thu, 25 Sep 2025 19:02:41 +0800
Subject: [PATCH] test(sglang): add UT case for sglang kernel (_fwd_kernel,
 extend attention)

---
 .../sglang/v0.4.8/test_extend_fwd_kernel.py | 304 ++++++++++++++++++
 1 file changed, 304 insertions(+)
 create mode 100644 ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py

diff --git a/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py b/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py
new file mode 100644
index 0000000..e301801
--- /dev/null
+++ b/ascend/test/sglang/v0.4.8/test_extend_fwd_kernel.py
@@ -0,0 +1,304 @@
+import sys
+import pytest
+import triton
+import torch
+import triton.language as tl
+
+sys.path.append("..")  # test_common lives in the parent directory
+import test_common
+
+
+# source: python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def tanh(x):
+    # tanh expressed as a scaled sigmoid, as defined in the sglang source file
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    mask_ptr,
+    mask_indptr,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    SLIDING_WINDOW_SIZE: tl.constexpr,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
+    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
+    STORE_TRANSPOSE: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
+
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
+        )
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        final_mask = mask_m[:, None] & mask_n[None, :]
+        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            final_mask &= custom_mask
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask
+        qk = tl.where(final_mask, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the triangle part
+
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        elif IS_CAUSAL:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    if STORE_TRANSPOSE:
+        tl.store(
+            O_Extend + offs_o.T,
+            (acc / deno[:, None]).T,
+            mask=(mask_m[:, None] & mask_dv[None, :]).T,
+        )
+    else:
+        tl.store(
+            O_Extend + offs_o,
+            acc / deno[:, None],
+            mask=mask_m[:, None] & mask_dv[None, :],
+        )
+
+
+def test_extend_fwd_kernel(ptfile_path):
+    try:
+        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
+    except Exception as e:
+        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")
+
+    # ptfile format:
+    #   [input_data] (dict):
+    #       key : value  (kernel arguments by name)
+    #   [gpu_output] (dict):
+    #       key : value  (golden outputs captured on GPU)
+    #   [grid]:
+    #       launch grid, e.g. (1,)
+    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')
+
+    _fwd_kernel[data["grid"]](**input_data)
+
+    # Compare the NPU results against the golden GPU outputs.
+    try:
+        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
+    except ValueError as e:
+        pytest.fail(f"The testcase failed: {str(e)}")
\ No newline at end of file
-- 
Gitee
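
Reviewer note (not part of the patch): both loops in _fwd_kernel use the same streaming-softmax bookkeeping, i.e. a running row max (e_max), a running denominator (deno), and a rescaled accumulator (acc). Below is a minimal PyTorch reference of that update rule, handy for sanity-checking the kernel on small shapes; the function name and block size are illustrative, not taken from the patch.

    import torch

    def online_softmax_attention(q, k, v, block_n=16):
        # q: [M, D], k: [N, D], v: [N, Dv]; computes softmax(q @ k.T) @ v
        # block by block, mirroring the e_max/deno/acc updates in _fwd_kernel.
        m = q.shape[0]
        acc = torch.zeros(m, v.shape[1], dtype=torch.float32)
        deno = torch.zeros(m, dtype=torch.float32)
        e_max = torch.full((m,), float("-inf"))
        for start in range(0, k.shape[0], block_n):
            qk = (q @ k[start:start + block_n].T).float()
            n_e_max = torch.maximum(qk.max(dim=1).values, e_max)
            re_scale = torch.exp(e_max - n_e_max)   # rescale earlier partial sums
            p = torch.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + p.sum(dim=1)
            acc = acc * re_scale[:, None] + p @ v[start:start + block_n].float()
            e_max = n_e_max
        return acc / deno[:, None]

For random float inputs this should agree with torch.softmax(q @ k.T, dim=-1) @ v up to floating-point tolerance, which makes it a useful oracle when debugging precision mismatches reported by compare_data_precision.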
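Reviewer note (not part of the patch): the test consumes a .pt fixture with input_data / gpu_output / grid keys, but the patch does not show how that file is produced. A sketch of a generator, assuming it runs on a CUDA host after launching the same kernel there; the helper name and the choice of O_Extend as the compared output key are assumptions.

    import torch

    def save_fixture(input_data: dict, grid: tuple, path: str):
        # Hypothetical helper: call after _fwd_kernel[grid](**input_data) has run
        # on CUDA, so input_data["O_Extend"] already holds the golden result.
        to_cpu = lambda v: v.cpu() if isinstance(v, torch.Tensor) else v
        torch.save(
            {
                "input_data": {k: to_cpu(v) for k, v in input_data.items()},
                "gpu_output": {"O_Extend": to_cpu(input_data["O_Extend"])},  # assumed key
                "grid": grid,
            },
            path,
        )

Storing everything on CPU keeps the fixture loadable with map_location='cpu' exactly as test_extend_fwd_kernel does.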