# Reconstructed from patch adding: ascend/test/sglang/v0.4.8/test_alloc_decode_kernel.py
import sys

import pytest
import torch
import triton
import triton.language as tl

sys.path.append("..")
import test_common


# source: python/sglang/srt/mem_cache/allocator.py
# Copied verbatim from SGLang v0.4.8 so the NPU build runs exactly the kernel
# that produced the reference GPU outputs stored in the .pt fixture.
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,   # int tensor: per-request sequence length AFTER this decode step
    last_loc_ptr,   # int tensor: per-request last allocated token slot
    free_page_ptr,  # int tensor: pool of free page ids to hand out
    out_indices,    # output int tensor: token slot allocated for each request
    ret_values,     # output scalar: total number of newly allocated pages
    bs_upper: tl.constexpr,   # upper bound on batch size (tl.arange needs a constexpr)
    page_size: tl.constexpr,  # tokens per KV-cache page
):
    # One program per request: decide whether request `pid` still has room in
    # its current page (reuse last_loc + 1) or must take a fresh page.
    pid = tl.program_id(0)

    # Load lengths for requests [0..pid]; lanes beyond pid are masked off and
    # hold unspecified values (no `other=` given).
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    # Lengths before this decode step (one token fewer). For masked lanes the
    # else-branch reuses seq_lens, so their before/after page counts are equal
    # and cancel to zero below regardless of the unspecified loaded value.
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    # Pages needed per request before/after the step (ceil division).
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # 1 if this request crosses a page boundary on this step, else 0.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    # Inclusive prefix sum over requests [0..pid]; subtracting this request's
    # own count gives its exclusive offset into the free-page list.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: the last program's inclusive prefix sum covers the whole
    # batch, i.e. the total number of pages consumed this step.
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # Still inside the current page: the next slot follows the last one.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # Crossed a page boundary: claim a fresh page, use its first slot.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)


def test_alloc_decode_kernel(ptfile_path):
    """Replay a recorded GPU run of ``alloc_decode_kernel`` on NPU and compare.

    ``ptfile_path`` (pytest fixture) names a ``torch.save``d dict:
        input_data (dict): keyword arguments for the kernel launch
        gpu_output (dict): tensors captured from the reference GPU run
        grid (tuple): launch grid, e.g. ``(1,)``
    """
    try:
        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
    except Exception as e:
        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")

    # Move the recorded inputs to the NPU. The kernel writes its results into
    # out_indices / ret_values in place, so input_data doubles as the output.
    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')

    alloc_decode_kernel[data["grid"]](**input_data)

    # Compare the NPU results against the recorded GPU reference on CPU.
    try:
        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
    except ValueError as e:
        # Surface the precision-mismatch detail instead of discarding it.
        pytest.fail(f"The testcase failed: {e}")