From 175d516a2188db4189692c4a03b13c1c164b467c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=8D=9A=E6=A5=9A?= Date: Wed, 13 Aug 2025 17:34:20 +0800 Subject: [PATCH] refactor: remove the 16 fold label for plastic surgery. --- .../python/triton_patch/runtime/jit.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/triton_patch/python/triton_patch/runtime/jit.py b/triton_patch/python/triton_patch/runtime/jit.py index f6eaf52..db52f8d 100644 --- a/triton_patch/python/triton_patch/runtime/jit.py +++ b/triton_patch/python/triton_patch/runtime/jit.py @@ -277,7 +277,6 @@ class KernelParam: def compute_spec_key(v, align): - if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): return "D" elif isinstance(v, int): @@ -292,6 +291,18 @@ def compute_spec_key(v, align): dtype2str = {} +def compute_spec_key_ascend(value, align): + if align and hasattr(value, "data_ptr") and (value.data_ptr() % 16 == 0): + return "D" + elif isinstance(value, int): + # bool is a subclass of int, so we don't check explicitly above. + if align and (value % 16 == 0): + return "N" + elif value == 1: + return "1" + return "N" + + def mangle_type(arg, is_const=False): if arg is None: @@ -374,9 +385,9 @@ def create_function_from_signature(sig, kparams, backend): non_constexpr_vals.append(name) if not kp.do_not_specialize: if not kp.do_not_specialize_on_alignment: - specialisations.append('compute_spec_key(%s, align=True)' % name) + specialisations.append('compute_spec_key_ascend(%s, align=True)' % name) else: - specialisations.append('compute_spec_key(%s, align=False)' % name) + specialisations.append('compute_spec_key_ascend(%s, align=False)' % name) if kp.annotation_type: signature_types.append('"%s"' % kp.annotation_type) else: @@ -402,7 +413,7 @@ def create_function_from_signature(sig, kparams, backend): } func_namespace['mangle_type'] = mangle_type - func_namespace['compute_spec_key'] = backend.compute_spec_key + func_namespace['compute_spec_key_ascend'] = compute_spec_key_ascend # Execute the function string in func_namespace to create the function exec(func_body, func_namespace) @@ -584,7 +595,6 @@ class JITFunction(KernelInterface[T]): # compute cache key key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) kernel = self.cache[device].get(key, None) - if kernel is None: # Kernel is not cached; we have to compile. options = backend.parse_options(kwargs) -- Gitee