# NOTE(review): this file arrived as a whitespace-mangled git patch; the Python
# module below is reconstructed from the patch payload. Original patch metadata:
#   From:    zzy
#   Date:    Fri, 26 Sep 2025 15:05:03 +0800
#   Subject: [PATCH 1/2 + 2/2] test(sglang): add UT cases for sglang kernel
#            (alloc_decode_kernel)
#   Path:    ascend/test/sglang/v0.4.8/test_alloc_decode_kernel.py
import sys

import pytest
import torch
import triton
import triton.language as tl

# Make the shared test helpers importable when running from this directory.
sys.path.append("..")
import test_common  # noqa: E402 -- must follow the sys.path tweak above


# Verbatim copy of the kernel under test -- keep token-identical to upstream
# so the NPU replay exercises exactly the shipped kernel.
# source: python/sglang/srt/mem_cache/allocator.py
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    # One program per request; pid indexes into the batch.
    pid = tl.program_id(0)

    # Prefix-load seq lens for requests [0..pid]; pre_lens = lens before this
    # decode step (each active request grows by exactly one token).
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    # Ceil-divide to page counts; a request needs a new page only when the
    # extra token crosses a page boundary.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)


def test_alloc_decode_kernel(ptfile_path):
    """Replay a recorded GPU run of ``alloc_decode_kernel`` on NPU and compare.

    Loads a ``.pt`` fixture captured on GPU, moves the inputs to the NPU,
    launches the kernel with the recorded grid, then checks the NPU outputs
    against the recorded GPU outputs via ``test_common.compare_data_precision``.

    Parameters
    ----------
    ptfile_path : str
        Path to the fixture file (pytest fixture / CLI-provided).
    """
    try:
        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
    except Exception as e:
        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")

    # ptfile format:
    # [input_data] (dict):
    #     key : value
    # [gpu_output] (dict):
    #     key : value
    # [grid] :
    #     (1,)
    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')

    # Kernel writes in place into out_indices / ret_values inside input_data.
    alloc_decode_kernel[data["grid"]](**input_data)

    # compare the results of GPU and NPU.
    try:
        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
    except ValueError as e:
        # FIX: original message was a placeholder-free f-string that bound the
        # caught exception and discarded it; surface it for diagnosability.
        pytest.fail(f"The testcase failed: {e}")