From 07f6e82b06d03b00e3ba1ea288b3c9c10ff28afa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=AD=E4=B8=80=E9=93=AD?=
Date: Mon, 22 Sep 2025 11:38:35 +0800
Subject: [PATCH 1/3] feat(propagateNan):test propagateNan

---
 .../test_clamp_max_min_propagatenan.py | 95 +++++++++++++++++++
 1 file changed, 95 insertions(+)
 create mode 100644 ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py

diff --git a/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
new file mode 100644
index 0000000..2b327c0
--- /dev/null
+++ b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
@@ -0,0 +1,95 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+import pytest
+
+import triton
+import triton.language as tl
+import time
+import torch
+import torch_npu
+import test_common
+
+def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan : tl.constexpr, nan_loc):
+    res = torch.clamp(x0, x1, x2)
+    if propagate_nan == 'NONE':
+        res = torch.where(nan_loc == 1, x2, res)
+    if propagate_nan == 'ALL':
+        res[nan_loc == 1] = torch.nan
+    return res
+
+def nan_cmp(dtype, y_cal, y_ref):
+    mask_cal = torch.isnan(y_cal)
+    mask_ref = torch.isnan(y_ref)
+    assert torch.equal(mask_cal, mask_ref)
+    y_cal_fixed = y_cal.masked_fill(mask_cal, 0)
+    y_ref_fixed = y_ref.masked_fill(mask_ref, 0)
+    test_common.validate_cmp(dtype, y_cal_fixed, y_ref_fixed)
+
+@triton.jit
+def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan : tl.constexpr, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr):
+    offset = tl.program_id(0) * XBLOCK
+    base1 = tl.arange(0, XBLOCK_SUB)
+    loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
+    for loop1 in range(loops1):
+        x0 = offset + (loop1 * XBLOCK_SUB) + base1
+        tmp0 = tl.load(in_ptr0 + (x0), None)
+        tmp1 = tl.load(in_ptr1 + (x0), None)
+        tmp2 = tl.load(in_ptr2 + (x0), None)
+        tmp3 = tl.clamp(tmp0, tmp1, tmp2, propagate_nan=getattr(tl.PropagateNan, propagate_nan))
+        tl.store(out_ptr0 + (x0), tmp3, None)
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL'])
+@pytest.mark.parametrize("shape", [(4,4)])
+@pytest.mark.parametrize("ncore", [4])
+@pytest.mark.parametrize("xblock", [4])
+@pytest.mark.parametrize("xblock_sub", [1])
+def test_clamp_propagate_nan(dtype, propagate_nan, shape, ncore, xblock, xblock_sub):
+    x0 = test_common.generate_tensor(shape, dtype).npu()
+    x1 = test_common.generate_tensor(shape, dtype).npu()
+    x2 = test_common.generate_tensor(shape, dtype).npu()
+    mask = (x1 > x2)
+    temp_x1 = torch.where(mask, x2, x1)
+    temp_x2 = torch.where(mask, x1, x2)
+    x1.copy_(temp_x1)
+    x2.copy_(temp_x2)
+    random_tensor = torch.randint(0, 2, shape)
+    random_tensor = random_tensor.to(x0.device)
+    x0[random_tensor == 1] = torch.nan
+    y_ref = torch_clamp_propagate_nan(x0, x1, x2, propagate_nan, random_tensor)
+    y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu()
+    triton_clamp[ncore, 1, 1](x0, x1, x2, y_cal, propagate_nan, xblock, xblock_sub)
+    nan_cmp(dtype, y_cal, y_ref)
+
+@pytest.mark.parametrize("dtype", ['float16', 'float32'])
+@pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL'])
+@pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp'])
+@pytest.mark.parametrize("device", ['npu'])
+def test_propagate_nan(dtype, propagate_nan, func, device):
+
+    @triton.jit
+    def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr):
+        if func == 'clamp':
+            tl.store(
+                C,
+                getattr(tl, func)(tl.load(A), -tl.load(B), tl.load(B),
+                                  propagate_nan=getattr(tl.PropagateNan, propagate_nan)))
+        else:
+            tl.store(C,
+                     getattr(tl, func)(tl.load(A), tl.load(B), propagate_nan=getattr(tl.PropagateNan, propagate_nan)))
+
+    for mode in ['A', 'B', 'both']:
+        if func == 'clamp' and mode == 'B':
+            # clamp does not guarantee propagation from 'min' and 'max' args
+            continue
+        A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype))
+        if mode == 'A' or mode == 'both': A[0] = torch.nan
+        B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype))
+        if mode == 'B' or mode == 'both': B[0] = torch.nan
+        C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype))
+        kernel[(1, )](A, B, C, propagate_nan, func)
+
+        if mode == 'both' or propagate_nan == 'ALL':
+            assert torch.isnan(C[0])
+        else:
+            assert not torch.isnan(C[0])
\ No newline at end of file
--
Gitee

From ca01b617b84c6ad7e412f4610e9aeaf3f2fd6dd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=AD=E4=B8=80=E9=93=AD?=
Date: Mon, 22 Sep 2025 16:04:53 +0800
Subject: [PATCH 2/3] test(propagateNan): add unit tests for propagateNan

---
 .../test_clamp_max_min_propagatenan.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
index 2b327c0..9606eef 100644
--- a/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
+++ b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
@@ -1,15 +1,16 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
+import time
 import pytest
 
 import triton
 import triton.language as tl
-import time
 import torch
 import torch_npu
 import test_common
 
-def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan : tl.constexpr, nan_loc):
+
+def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan : tl.constexpr, nan_loc) :
     res = torch.clamp(x0, x1, x2)
     if propagate_nan == 'NONE':
         res = torch.where(nan_loc == 1, x2, res)
@@ -17,6 +18,7 @@ def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan : tl.constexpr, nan_loc)
         res[nan_loc == 1] = torch.nan
     return res
 
+
 def nan_cmp(dtype, y_cal, y_ref):
     mask_cal = torch.isnan(y_cal)
     mask_ref = torch.isnan(y_ref)
@@ -25,8 +27,9 @@ def nan_cmp(dtype, y_cal, y_ref):
     y_cal_fixed = y_cal.masked_fill(mask_cal, 0)
     y_ref_fixed = y_ref.masked_fill(mask_ref, 0)
     test_common.validate_cmp(dtype, y_cal_fixed, y_ref_fixed)
 
+
 @triton.jit
-def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan : tl.constexpr, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr):
+def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan : tl.constexpr, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr) :
     offset = tl.program_id(0) * XBLOCK
     base1 = tl.arange(0, XBLOCK_SUB)
     loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
@@ -38,9 +41,10 @@ def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan : tl.constex
         tmp3 = tl.clamp(tmp0, tmp1, tmp2, propagate_nan=getattr(tl.PropagateNan, propagate_nan))
         tl.store(out_ptr0 + (x0), tmp3, None)
 
+
 @pytest.mark.parametrize("dtype", ['float16', 'float32'])
 @pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL'])
-@pytest.mark.parametrize("shape", [(4,4)])
+@pytest.mark.parametrize("shape", [(4, 4)])
 @pytest.mark.parametrize("ncore", [4])
 @pytest.mark.parametrize("xblock", [4])
 @pytest.mark.parametrize("xblock_sub", [1])
@@ -61,6 +65,7 @@ def test_clamp_propagate_nan(dtype, propagate_nan, shape, ncore, xblock, xblock_
     triton_clamp[ncore, 1, 1](x0, x1, x2, y_cal, propagate_nan, xblock, xblock_sub)
     nan_cmp(dtype, y_cal, y_ref)
 
+
 @pytest.mark.parametrize("dtype", ['float16', 'float32'])
 @pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL'])
 @pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp'])
@@ -83,9 +88,11 @@ def test_propagate_nan(dtype, propagate_nan, func, device):
             # clamp does not guarantee propagation from 'min' and 'max' args
             continue
         A = torch.randn((1, ), device=device, dtype=getattr(torch, dtype))
-        if mode == 'A' or mode == 'both': A[0] = torch.nan
+        if mode == 'A' or mode == 'both':
+            A[0] = torch.nan
         B = torch.randn((1, ), device=device, dtype=getattr(torch, dtype))
-        if mode == 'B' or mode == 'both': B[0] = torch.nan
+        if mode == 'B' or mode == 'both':
+            B[0] = torch.nan
         C = torch.zeros_like(A, device=device, dtype=getattr(torch, dtype))
         kernel[(1, )](A, B, C, propagate_nan, func)
 
--
Gitee

From 6160c18a2f142286feb5da85c6747a3a9afde516 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=AD=E4=B8=80=E9=93=AD?=
Date: Mon, 22 Sep 2025 16:37:42 +0800
Subject: [PATCH 3/3] test(propagateNan): add unit tests for propagateNan

---
 .../examples/pytest_ut/test_clamp_max_min_propagatenan.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
index 9606eef..1b38a4c 100644
--- a/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
+++ b/ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py
@@ -10,7 +10,7 @@ import torch_npu
 import test_common
 
 
-def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan : tl.constexpr, nan_loc) :
+def torch_clamp_propagate_nan(x0, x1, x2, propagate_nan: tl.constexpr, nan_loc):
     res = torch.clamp(x0, x1, x2)
     if propagate_nan == 'NONE':
         res = torch.where(nan_loc == 1, x2, res)
@@ -29,7 +29,7 @@ def nan_cmp(dtype, y_cal, y_ref):
 
 
 @triton.jit
-def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan : tl.constexpr, XBLOCK : tl.constexpr, XBLOCK_SUB : tl.constexpr) :
+def triton_clamp(in_ptr0, in_ptr1, in_ptr2, out_ptr0, propagate_nan: tl.constexpr, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
     offset = tl.program_id(0) * XBLOCK
     base1 = tl.arange(0, XBLOCK_SUB)
     loops1: tl.constexpr = XBLOCK // XBLOCK_SUB
@@ -61,7 +61,7 @@ def test_clamp_propagate_nan(dtype, propagate_nan, shape, ncore, xblock, xblock_
     random_tensor = random_tensor.to(x0.device)
     x0[random_tensor == 1] = torch.nan
     y_ref = torch_clamp_propagate_nan(x0, x1, x2, propagate_nan, random_tensor)
-    y_cal = torch.zeros(shape, dtype = eval('torch.' + dtype)).npu()
+    y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu()
     triton_clamp[ncore, 1, 1](x0, x1, x2, y_cal, propagate_nan, xblock, xblock_sub)
     nan_cmp(dtype, y_cal, y_ref)
 
--
Gitee
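
Note (not part of the patches above): the nan_cmp helper in the test file compares the NPU output against the reference in two steps, because NaN != NaN would make a plain allclose fail on any propagated NaN. It first requires the NaN masks of both tensors to match exactly, then masks the NaN lanes out and compares the remaining values numerically. A minimal CPU-only sketch of that idea, using plain torch only and a hypothetical helper name (no torch_npu or test_common needed):

import torch

def nan_aware_close(y_cal, y_ref, rtol=1e-3, atol=1e-3):
    # Step 1: NaNs must sit at exactly the same positions in both tensors.
    mask_cal = torch.isnan(y_cal)
    mask_ref = torch.isnan(y_ref)
    if not torch.equal(mask_cal, mask_ref):
        return False
    # Step 2: neutralize the NaN lanes, then do an ordinary tolerance compare.
    return torch.allclose(y_cal.masked_fill(mask_cal, 0.0),
                          y_ref.masked_fill(mask_ref, 0.0),
                          rtol=rtol, atol=atol)

if __name__ == "__main__":
    a = torch.tensor([1.0, float("nan"), 3.0])
    b = torch.tensor([1.0, float("nan"), 3.0])
    assert nan_aware_close(a, b)                              # same values, same NaN position
    assert not nan_aware_close(a, torch.tensor([1.0, 2.0, 3.0]))  # NaN masks differ

The tests themselves are expected to run under pytest on an Ascend NPU host, e.g. "pytest ascend/examples/pytest_ut/test_clamp_max_min_propagatenan.py", assuming torch_npu and the repository's test_common helpers are importable.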