import sys

import pytest
import torch
import triton
import triton.language as tl

# Make the shared test helpers importable when running from this directory.
sys.path.append("..")
import test_common


# Kernel under test, copied from:
#   python/sglang/srt/layers/moe/ep_moe/kernels.py (sglang v0.4.8)
# Do not alter the kernel body: the test replays a recorded launch on NPU and
# compares against a GPU golden output, so it must match upstream exactly.
@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    # One program instance per expert.
    cur_expert = tl.program_id(0)

    # Exclusive prefix sum of per-expert token counts gives each expert's
    # start offset. Every program redundantly writes the whole
    # expert_start_loc table (BLOCK_EXPERT_NUM must cover num_experts).
    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)

    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    # Fill m_indices[start : start + token_num] with this expert's id.
    # NOTE(review): the store below is unmasked, so the last iteration may
    # write up to BLOCK_E - 1 slots past the expert's token range; upstream
    # presumably relies on m_indices being padded to a BLOCK_E multiple —
    # confirm against the caller before reusing this kernel elsewhere.
    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )


def test_context_fwd_kernel(ptfile_path):
    """Replay a recorded kernel launch on NPU and compare with the GPU golden output.

    Args:
        ptfile_path: path to a ``.pt`` fixture holding ``input_data`` (kernel
            arguments), the launch ``grid``, and the reference ``gpu_output``.
    """
    try:
        data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False)
    except Exception as e:
        pytest.fail(f"load file {ptfile_path} failed: {str(e)}")

    # Move the recorded inputs to the NPU and launch with the recorded grid.
    input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu')
    _fwd_kernel_ep_scatter_1[data['grid']](**input_data)

    try:
        # The kernel mutates its output tensors in place, so the comparison
        # runs against input_data after the launch.
        test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu')
    except ValueError as e:
        # Include the comparison details in the failure message; the original
        # discarded `e` and had its f-string literal broken across two lines.
        pytest.fail(f"The testcase failed: {str(e)}")