From b48f81349822fdbf386b548ab9f3f4da5511ba01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=9D=E6=97=B6=E5=8D=97=E6=9F=AF=E4=B8=80=E6=A2=A6?= <1642525756@qq.com> Date: Fri, 26 Sep 2025 11:04:09 +0800 Subject: [PATCH] test(sglang): add UT cases for sglang kernel(compute_src2dst_triton_kernel) --- .../test_compute_src2dst_triton_kernel.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py diff --git a/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py b/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py new file mode 100644 index 0000000..0d61387 --- /dev/null +++ b/ascend/test/sglang/v0.4.8/test_compute_src2dst_triton_kernel.py @@ -0,0 +1,37 @@ +import sys +import pytest +import triton +import torch +import triton.language as tl +sys.path.append("..") +import test_common + + +# source: python\sglang\srt\layers\moe\ep_moe\kernels.py +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def test_compute_src2dst_triton_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + compute_src2dst_triton_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") + \ No newline at end of file -- Gitee