diff --git a/ascend/examples/pytest_ut/test_lanzcos.py b/ascend/examples/pytest_ut/test_lanzcos.py index 4fd54669d0fb944395740974df9c18804abf9844..d21e7c91e4b802b5d84f10b0f93f492080b01be7 100644 --- a/ascend/examples/pytest_ut/test_lanzcos.py +++ b/ascend/examples/pytest_ut/test_lanzcos.py @@ -6,8 +6,49 @@ import math import pytest +def profiler_wrapper(fn, *args): + result_path = "./result_profiling" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False, + ) + with torch_npu.profiler.profile( + activities=[ + # torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + schedule=torch_npu.profiler.schedule( + wait=wait, + warmup=warmup, + active=active, + repeat=repeat, + skip_first=skip_first, + ), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), + record_shapes=True, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + @triton.jit -def lanczos_resize_kernel( +def triton_kernel( img_src_ptr, img_dst_ptr, img_coeffs_ptr, @@ -68,8 +109,9 @@ def lanczos_resize_kernel( tl.store(img_dst_ptr + dest_offs, res, mask=dst_mask) -def lanczos_resize_triton(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols): +def triton_func(img_src, img_dst, c_lanczosCoeffs): N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape R_H = float(dst_rows) / src_rows R_W = float(dst_cols) / src_cols @@ -81,7 +123,7 @@ def lanczos_resize_triton(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols) triton.cdiv(dst_rows, meta["BLOCK_SIZE"]), triton.cdiv(dst_cols, meta["BLOCK_SIZE"]), ) - lanczos_resize_kernel[grid]( + triton_kernel[grid]( img_src, img_dst, c_lanczosCoeffs, @@ -103,8 +145,137 @@ def lanczos_resize_triton(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols) return img_dst -def lanczos_resize_cpu(img_src, img_dst, img_coeffs, dst_rows, dst_cols): +@triton.jit +def triton_kernel_opt1( + img_src_ptr, + img_dst_ptr, + img_coeffs_ptr, + src_rows, + src_cols, + dst_rows, + dst_cols, + R_H, + R_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + HBLOCK_SIZE: tl.constexpr, + WBLOCK_SIZE: tl.constexpr, +): + tl.static_assert( + HBLOCK_SIZE == 1, "HBLOCK_SIZE must be 1 in this row load + gather optimization" + ) + block_id_c = tl.program_id(0) + block_id_h = tl.program_id(1) + block_id_w = tl.program_id(2) + dest_h_offs = block_id_h * HBLOCK_SIZE + tl.arange(0, HBLOCK_SIZE) + dest_w_offs = block_id_w * WBLOCK_SIZE + tl.arange(0, WBLOCK_SIZE) + dest_offs = ( + block_id_c[None, None] * stride_out_c + + dest_h_offs[:, None] * stride_out_h + + dest_w_offs[None, :] * stride_out_w + ) + + RR_H = 1.0 / R_H + RR_W = 1.0 / R_W + + fy = (dest_h_offs + 0.5) * RR_H - 0.5 + sy = tl.floor(fy) + fx = (dest_w_offs + 0.5) * RR_W - 0.5 + sx = tl.floor(fx) + + idxY = tl.floor((fy - sy) * 24.999999).to(tl.int32) + idxX = tl.floor((fx - sx) * 24.999999).to(tl.int32) + tableIndex = idxY[:, None] * 25 + idxX[None, :] + res = tl.zeros((HBLOCK_SIZE, WBLOCK_SIZE), tl.float32) + + for ii in range(4): 
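+        # NOTE: (fy - sy, fx - sx) is the sub-pixel position of the destination sample;
+        # the 24.999999 scale quantizes it into 25 bins per axis, so tableIndex selects
+        # one of 25 * 25 precomputed coefficient rows and (ii * 4 + jj) selects one of
+        # the 16 taps of the 4x4 Lanczos window, i.e. img_coeffs is effectively indexed
+        # as a flattened [25 * 25, 16] table (25 * 25 * 16 = 10000, matching the
+        # c_lanczosCoeffs allocation in the test). Each (ii, jj) iteration clamps the
+        # source coordinates to the image border, loads one source row contiguously,
+        # gathers the required pixels, and accumulates src_val * coeffs into res.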
+ for jj in range(4): + src_offsets = ( + block_id_c[None, None] * stride_in_c + + (tl.clamp((sy + ii - 1), 0, src_rows - 1)).to(tl.int32)[:, None] + * stride_in_h + + (tl.clamp((sx + jj - 1), 0, src_cols - 1)).to(tl.int32)[None, :] + * stride_in_w + ) + # original: indirect load will be convertd to loop of scalar load + # src_val = tl.load(img_src_ptr + src_offsets) + # optimization: contiguous load of single row + gather + src_offset_min = tl.min(src_offsets, 1) + MAX_CONTIGUOUS_LEN: tl.constexpr = 1024 # in fact 640 is enough + # NOTE: you can use the following 2 lines to check the contiguous length + # src_offset_max = tl.max(src_offsets, 1) + # tl.device_print("idxMax - idxMax = ", src_offset_max - src_offset_min) + holder_idx = tl.arange(0, MAX_CONTIGUOUS_LEN) + # TODO: triton-adapter currently does not support the following 2 lines of load + # holder_vals = tl.load(img_src_ptr + src_offset_min + holder_idx[None, :]) + # holder_vals = tl.load(img_src_ptr + src_offset_min + holder_idx)[None, :] + # NOTE: currently triton-adapter requires the expand_dim op (i.e., [None, :]) + # to be used after the load + src_idx = src_offset_min + holder_idx + holder_vals = tl.load(img_src_ptr + src_idx)[None, :] + src_val = tl.gather(holder_vals, src_offsets - src_offset_min, 1) + # + coeffs_offs = tableIndex[:, :] * 16 + (ii * 4 + jj)[None, None] + # original: indirect load will be convertd to loop of scalar load + # coeffs = tl.load(img_coeffs_ptr + coeffs_offs) + # optimization: contiguous load of single row + gather + coeffs_offs_min = tl.min(coeffs_offs, 1) + img_coeffs_idx = coeffs_offs_min + holder_idx + holder_vals_coeffs = tl.load(img_coeffs_ptr + img_coeffs_idx)[None, :] + coeffs = tl.gather(holder_vals_coeffs, coeffs_offs - coeffs_offs_min, 1) + # + res = res + src_val * coeffs + dst_mask = (dest_h_offs[:, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) + res = tl.clamp(res, 0.0, 1.0) + tl.store(img_dst_ptr + dest_offs, res, mask=dst_mask) + + +def triton_func_opt1(img_src, img_dst, c_lanczosCoeffs): + N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape + R_H = float(dst_rows) / src_rows + R_W = float(dst_cols) / src_cols + + stride_in_n, stride_in_c, stride_in_h, stride_in_w = img_src.stride() + stride_out_n, stride_out_c, stride_out_h, stride_out_w = img_dst.stride() + bs_h = 1 + bs_w = dst_cols + grid = lambda meta: ( + C, + triton.cdiv(dst_rows, meta["HBLOCK_SIZE"]), + triton.cdiv(dst_cols, meta["WBLOCK_SIZE"]), + ) + triton_kernel_opt1[grid]( + img_src, + img_dst, + c_lanczosCoeffs, + src_rows, + src_cols, + dst_rows, + dst_cols, + R_H, + R_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + bs_h, + bs_w, + ) + return img_dst + + +def numpy_func(img_src, img_dst, img_coeffs): N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape R_H = float(dst_rows) / src_rows R_W = float(dst_cols) / src_cols for i in range(dst_rows): @@ -131,7 +302,12 @@ def lanczos_resize_cpu(img_src, img_dst, img_coeffs, dst_rows, dst_cols): img_dst[0, :, i, j] = np.clip(res, 0.0, 1.0) -@pytest.mark.parametrize("shapes", [[360, 640, 140, 280],]) +@pytest.mark.parametrize( + "shapes", + [ + [360, 640, 140, 280], + ], +) def test_lanzcos(shapes): c_lanczosCoeffs = torch.randn(10000, dtype=torch.float32, device="npu") / 4.0 src_rows, src_cols, dst_rows, dst_cols = shapes @@ -141,18 +317,21 @@ def test_lanzcos(shapes): dtype=img_src.dtype, device=img_src.device, ) - resized_image = 
lanczos_resize_triton( - img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols - ) + # cpu version img_src_cpu = img_src.cpu().numpy() img_dst_cpu = torch.zeros( (1, img_src_cpu.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device="cpu" ).numpy() - lanczos_resize_cpu( - img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy(), dst_rows, dst_cols + numpy_func(img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy()) + # base version + img_dst_npu = triton_func(img_src, img_dst, c_lanczosCoeffs) + torch.testing.assert_close( + img_dst_npu.cpu(), torch.from_numpy(img_dst_cpu), atol=1.0 / 255, rtol=0 ) + # optimized version + img_dst_npu_op1 = triton_func_opt1(img_src, img_dst, c_lanczosCoeffs) torch.testing.assert_close( - resized_image.cpu(), torch.from_numpy(img_dst_cpu), atol=1.0 / 255, rtol=0 + img_dst_npu_op1.cpu(), torch.from_numpy(img_dst_cpu), atol=1.0 / 255, rtol=0 ) @@ -215,18 +394,14 @@ if __name__ == "__main__": dtype=img_src.dtype, device=img_src.device, ) - resized_image = lanczos_resize_triton( - img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols - ) + resized_image = triton_func(img_src, img_dst, c_lanczosCoeffs) resized_cpu = resized_image.cpu().numpy() print("==========run cpu===============") img_src_cpu = img_src.cpu().numpy() img_dst_cpu = torch.zeros( (1, img_src_cpu.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device="cpu" ).numpy() - lanczos_resize_cpu( - img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy(), dst_rows, dst_cols - ) + numpy_func(img_src_cpu, img_dst_cpu, c_lanczosCoeffs.cpu().numpy()) print("==========compare result===============") diff = np.abs(resized_cpu - img_dst_cpu) @@ -238,16 +413,22 @@ if __name__ == "__main__": ) print("==========profiling===============") - accelerate, eager_time, triton_time = benchmark_test( - lanczos_resize_cpu, - lanczos_resize_triton, - ref_args=( - img_src_cpu, - img_dst_cpu, - c_lanczosCoeffs.cpu().numpy(), - dst_rows, - dst_cols, - ), - triton_args=(img_src, img_dst, c_lanczosCoeffs, dst_rows, dst_cols), - name="lanzcos", - ) + + def wrapper_func(img_src, img_dst, c_lanczosCoeffs): + triton_cal = triton_func(img_src, img_dst, c_lanczosCoeffs) + triton_cal = triton_func_opt1(img_src, img_dst, c_lanczosCoeffs) + + profiler_wrapper(wrapper_func, img_src, img_dst, c_lanczosCoeffs) + + # CPU version is so slow that it is not suitable for profiling + # accelerate, eager_time, triton_time = benchmark_test( + # numpy_func, + # triton_func, + # ref_args=( + # img_src_cpu, + # img_dst_cpu, + # c_lanczosCoeffs.cpu().numpy(), + # ), + # triton_args=(img_src, img_dst, c_lanczosCoeffs), + # name="lanzcos", + # ) diff --git a/ascend/examples/pytest_ut/test_nearest.py b/ascend/examples/pytest_ut/test_nearest.py index de734ace8f2e245e3b4529a1251be357bfd8f241..675418bc2270b3247616d2c4ed914b40caf5c58f 100644 --- a/ascend/examples/pytest_ut/test_nearest.py +++ b/ascend/examples/pytest_ut/test_nearest.py @@ -6,30 +6,79 @@ import math import numpy as np import pytest + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False, + ) + with torch_npu.profiler.profile( + activities=[ + # torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + 
schedule=torch_npu.profiler.schedule( + wait=wait, + warmup=warmup, + active=active, + repeat=repeat, + skip_first=skip_first, + ), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), + record_shapes=True, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + @triton.jit -def nearest_resize_kernel( - img_src_ptr, img_dst_ptr, src_rows, src_cols, dst_rows, dst_cols, - RR_H, RR_W, C, - stride_in_h, stride_in_w, stride_in_c, - stride_out_h, stride_out_w, stride_out_c, - BLOCK_SIZE: tl.constexpr +def triton_kernel( + img_src_ptr, + img_dst_ptr, + src_rows, + src_cols, + dst_rows, + dst_cols, + RR_H, + RR_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + BLOCK_SIZE: tl.constexpr, ): - #RR_H和RR_W分别为高和宽的缩放比例 + # RR_H和RR_W分别为高和宽的缩放比例 block_id_c = tl.program_id(0) block_id_h = tl.program_id(1) block_id_w = tl.program_id(2) - dest_h_offs = ( - block_id_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - ) - dest_w_offs = ( - block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - ) + dest_h_offs = block_id_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + dest_w_offs = block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) dest_offs = ( - block_id_c[ None, None] * stride_out_c - + dest_h_offs[ :, None] * stride_out_h - + dest_w_offs[ None, :] * stride_out_w + block_id_c[None, None] * stride_out_c + + dest_h_offs[:, None] * stride_out_h + + dest_w_offs[None, :] * stride_out_w ) - #根据output image的坐标值(dest_h_offs, dest_w_offs)计算input image的坐标值(sy, sx) + # 根据output image的坐标值(dest_h_offs, dest_w_offs)计算input image的坐标值(sy, sx) fy = dest_h_offs * RR_H sy = tl.floor(fy) fx = dest_w_offs * RR_W @@ -37,13 +86,15 @@ def nearest_resize_kernel( src_offsets = ( block_id_c[None, None] * stride_in_c - + tl.clamp(sy, 0, src_rows -1)[:, None].to(tl.int32) * stride_in_h - + tl.clamp(sx, 0, src_cols -1)[None, :].to(tl.int32) * stride_in_w) + + tl.clamp(sy, 0, src_rows - 1)[:, None].to(tl.int32) * stride_in_h + + tl.clamp(sx, 0, src_cols - 1)[None, :].to(tl.int32) * stride_in_w + ) src_val = tl.load(img_src_ptr + src_offsets) - dst_mask = (dest_h_offs[ :, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) + dst_mask = (dest_h_offs[:, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) tl.store(img_dst_ptr + dest_offs, src_val, mask=dst_mask) -def triton_kernel(img_src, img_dst): + +def triton_func(img_src, img_dst): N, C, src_rows, src_cols = img_src.shape _, _, dst_rows, dst_cols = img_dst.shape R_H = float(dst_rows) / src_rows @@ -58,48 +109,194 @@ def triton_kernel(img_src, img_dst): triton.cdiv(dst_rows, meta["BLOCK_SIZE"]), triton.cdiv(dst_cols, meta["BLOCK_SIZE"]), ) - nearest_resize_kernel[grid]( - img_src, img_dst, src_rows, src_cols, dst_rows, dst_cols, - RR_H, RR_W, C, - stride_in_h, stride_in_w, stride_in_c, - stride_out_h, stride_out_w, stride_out_c, - bs) + triton_kernel[grid]( + img_src, + img_dst, + src_rows, + src_cols, + dst_rows, + dst_cols, + RR_H, + RR_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + bs, + ) return img_dst -def nearest_resize_cpu(img_src, img_dst): + +@triton.jit +def triton_kernel_opt1( + img_src_ptr, + img_dst_ptr, + src_rows, + src_cols, + dst_rows, + dst_cols, + RR_H, + RR_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, 
+ stride_out_w, + stride_out_c, + HBLOCK_SIZE: tl.constexpr, + WBLOCK_SIZE: tl.constexpr, +): + tl.static_assert( + HBLOCK_SIZE == 1, "HBLOCK_SIZE must be 1 in this row load + gather optimization" + ) + # RR_H和RR_W分别为高和宽的缩放比例 + block_id_c = tl.program_id(0) # 4 + block_id_h = tl.program_id(1) # triton.cdiv(140, HBLOCK_SIZE=16) = 9 + block_id_w = tl.program_id(2) # triton.cdiv(280, WBLOCK_SIZE=16) = 18 + dest_h_offs = block_id_h * HBLOCK_SIZE + tl.arange(0, HBLOCK_SIZE) + dest_w_offs = block_id_w * WBLOCK_SIZE + tl.arange(0, WBLOCK_SIZE) + dest_offs = ( + block_id_c[None, None] * stride_out_c + + dest_h_offs[:, None] * stride_out_h + + dest_w_offs[None, :] * stride_out_w + ) + # 根据output image的坐标值(dest_h_offs, dest_w_offs)计算input image的坐标值(sy, sx) + fy = dest_h_offs * RR_H + sy = tl.floor(fy) + fx = dest_w_offs * RR_W + sx = tl.floor(fx) + + src_offsets = ( + block_id_c[None, None] * stride_in_c + + tl.clamp(sy, 0, src_rows - 1)[:, None].to(tl.int32) * stride_in_h + + tl.clamp(sx, 0, src_cols - 1)[None, :].to(tl.int32) * stride_in_w + ) + # original: indirect load will be convertd to loop of scalar load + # src_val = tl.load(img_src_ptr + src_offsets) + # optimization: contiguous load of single row + gather + src_offset_min = tl.min(src_offsets, 1) + MAX_CONTIGUOUS_LEN: tl.constexpr = 1024 # in fact 640 is enough + # NOTE: you can use the following 2 lines to check the contiguous length + # src_offset_max = tl.max(src_offsets, 1) + # tl.device_print("idxMax - idxMax = ", src_offset_max - src_offset_min) + holder_idx = tl.arange(0, MAX_CONTIGUOUS_LEN) + # TODO: triton-adapter currently does not support the following 2 lines of load + # holder_vals = tl.load(img_src_ptr + src_offset_min + holder_idx[None, :]) + # holder_vals = tl.load(img_src_ptr + src_offset_min + holder_idx)[None, :] + # NOTE: currently triton-adapter requires the expand_dim op (i.e., [None, :]) + # to be used after the load + src_idx = src_offset_min + holder_idx + holder_vals = tl.load(img_src_ptr + src_idx)[None, :] + src_val = tl.gather(holder_vals, src_offsets - src_offset_min, 1) + dst_mask = (dest_h_offs[:, None] < dst_rows) & (dest_w_offs[None, :] < dst_cols) + tl.store(img_dst_ptr + dest_offs, src_val, mask=dst_mask) + + +def triton_func_opt1(img_src, img_dst): + N, C, src_rows, src_cols = img_src.shape + _, _, dst_rows, dst_cols = img_dst.shape + R_H = float(dst_rows) / src_rows + R_W = float(dst_cols) / src_cols + RR_H = 1.0 / R_H + RR_W = 1.0 / R_W + stride_in_n, stride_in_c, stride_in_h, stride_in_w = img_src.stride() + stride_out_n, stride_out_c, stride_out_h, stride_out_w = img_dst.stride() + bs_h = 1 + bs_w = dst_cols + grid = lambda meta: ( + C, + triton.cdiv(dst_rows, meta["HBLOCK_SIZE"]), + triton.cdiv(dst_cols, meta["WBLOCK_SIZE"]), + ) + triton_kernel_opt1[grid]( + img_src, + img_dst, + src_rows, + src_cols, + dst_rows, + dst_cols, + RR_H, + RR_W, + C, + stride_in_h, + stride_in_w, + stride_in_c, + stride_out_h, + stride_out_w, + stride_out_c, + bs_h, + bs_w, + ) + return img_dst + + +def numpy_func(img_src, img_dst): N, C, src_rows, src_cols = img_src.shape _, _, dst_rows, dst_cols = img_dst.shape - #RR_H和RR_W分别为高和宽的缩放比例 - RR_H = src_rows / float(dst_rows) - RR_W = src_cols / float(dst_cols) - #根据output image的坐标值(i,j)计算input image的坐标值(sy, sx) + # RR_H和RR_W分别为高和宽的缩放比例 + RR_H = src_rows / float(dst_rows) + RR_W = src_cols / float(dst_cols) + # 根据output image的坐标值(i,j)计算input image的坐标值(sy, sx) for i in range(dst_rows): for j in range(dst_cols): - fy = (i * RR_H) + fy = i * RR_H sy = math.floor(fy) - fx = 
(j * RR_W) + fx = j * RR_W sx = math.floor(fx) - src_val = img_src[0, :, np.clip(sy, 0, src_rows -1), np.clip(sx, 0, src_cols -1)] + src_val = img_src[ + 0, :, np.clip(sy, 0, src_rows - 1), np.clip(sx, 0, src_cols - 1) + ] img_dst[0, :, i, j] = src_val return img_dst -@pytest.mark.parametrize("shapes", [[360, 640, 140, 280],]) + +@pytest.mark.parametrize( + "shapes", + [ + [360, 640, 140, 280], + ], +) def test_nearest(shapes): src_rows, src_cols, dst_rows, dst_cols = shapes - img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device='npu') - img_dst = torch.zeros((1, img_src.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device=img_src.device) - torch_ref = nearest_resize_cpu(img_src.cpu(), img_dst.cpu()) - triton_cal = triton_kernel(img_src, img_dst) + img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device="npu") + img_dst = torch.zeros( + (1, img_src.shape[1], dst_rows, dst_cols), + dtype=img_src.dtype, + device=img_src.device, + ) + torch_ref = numpy_func(img_src.cpu(), img_dst.cpu()) + triton_cal = triton_func(img_src, img_dst) torch.testing.assert_close(torch_ref.npu(), triton_cal) + triton_cal_opt1 = triton_func_opt1(img_src, img_dst) + torch.testing.assert_close(torch_ref.npu(), triton_cal_opt1) + if __name__ == "__main__": src_rows, src_cols = 360, 640 dst_rows, dst_cols = 140, 280 - img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device='npu') - img_dst = torch.zeros((1, img_src.shape[1], dst_rows, dst_cols), dtype=img_src.dtype, device=img_src.device) + img_src = torch.rand(1, 4, src_rows, src_cols, dtype=torch.float32, device="npu") + img_dst = torch.zeros( + (1, img_src.shape[1], dst_rows, dst_cols), + dtype=img_src.dtype, + device=img_src.device, + ) - assert img_src.shape[0] == 1, "currently supports only shape[0] == 1 which does not change the functionality of thie case" - torch_ref = nearest_resize_cpu(img_src.cpu(), img_dst.cpu()) - triton_cal = triton_kernel(img_src, img_dst) + assert ( + img_src.shape[0] == 1 + ), "currently supports only shape[0] == 1 which does not change the functionality of thie case" + torch_ref = numpy_func(img_src.cpu(), img_dst.cpu()) + triton_cal = triton_func(img_src, img_dst) torch.testing.assert_close(torch_ref.npu(), triton_cal) - print("success") \ No newline at end of file + triton_cal_opt1 = triton_func_opt1(img_src, img_dst) + torch.testing.assert_close(torch_ref.npu(), triton_cal_opt1) + + def wrapper_func(img_src, img_dst): + triton_cal = triton_func(img_src, img_dst) + triton_cal = triton_func_opt1(img_src, img_dst) + + profiler_wrapper(wrapper_func, img_src, img_dst) + print("success") diff --git a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp index e3aa8f806f53e8a83099810f02a9ee847ea41ea8..1e452f88c192c696348d87dbb23df7999ae0fb39 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -38,1194 +38,1230 @@ namespace triton { // return (v1 > v2) ? 
v1 : v2; // } -SmallVector &BlockData::getOffsetsRef() { return this->offsets; } +SmallVector &BlockData::getOffsetsRef() +{ + return this->offsets; +} -SmallVector &BlockData::getSizesRef() { return this->sizes; } +SmallVector &BlockData::getSizesRef() +{ + return this->sizes; +} -SmallVector &BlockData::getStridesRef() { return this->strides; } +SmallVector &BlockData::getStridesRef() +{ + return this->strides; +} -Value &BlockData::getSourceRef() { return this->source; } +Value &BlockData::getSourceRef() +{ + return this->source; +} -OpFoldResult &BlockData::getScalarRef() { return this->scalar; } +OpFoldResult &BlockData::getScalarRef() +{ + return this->scalar; +} -SmallVector BlockData::getOffsets() const { - return this->offsets; +SmallVector BlockData::getOffsets() const +{ + return this->offsets; } -SmallVector BlockData::getSizes() const { return this->sizes; } +SmallVector BlockData::getSizes() const +{ + return this->sizes; +} -SmallVector BlockData::getStrides() const { - return this->strides; +SmallVector BlockData::getStrides() const +{ + return this->strides; } -OpFoldResult BlockData::getOffset(int index) const { - return this->offsets[index]; +OpFoldResult BlockData::getOffset(int index) const +{ + return this->offsets[index]; } -OpFoldResult BlockData::getSize(int index) const { return this->sizes[index]; } +OpFoldResult BlockData::getSize(int index) const +{ + return this->sizes[index]; +} -OpFoldResult BlockData::getStride(int index) const { - return this->strides[index]; +OpFoldResult BlockData::getStride(int index) const +{ + return this->strides[index]; } -OpFoldResult BlockData::getScalar() const { return this->scalar; } +OpFoldResult BlockData::getScalar() const +{ + return this->scalar; +} -Value BlockData::getSource() const { return this->source; } +Value BlockData::getSource() const +{ + return this->source; +} -MemAccType BlockData::getMemAccType() const { return this->memAccTy; }; +MemAccType BlockData::getMemAccType() const +{ + return this->memAccTy; +}; -MemAccType &BlockData::getMemAccTypeRef() { return this->memAccTy; }; +MemAccType &BlockData::getMemAccTypeRef() +{ + return this->memAccTy; +}; -bool BlockData::isScalar() const { return !(this->scalar).isNull(); } +bool BlockData::isScalar() const +{ + return !(this->scalar).isNull(); +} -bool BlockData::isEmpty() const { - return !(this->getRank() || this->source || !(this->scalar).isNull()); +bool BlockData::isEmpty() const +{ + return !(this->getRank() || this->source || !(this->scalar).isNull()); } -bool BlockData::hasSource() const { return this->source != nullptr; } +bool BlockData::hasSource() const +{ + return this->source != nullptr; +} -void BlockData::removeSource() { this->source = nullptr; }; +void BlockData::removeSource() +{ + this->source = nullptr; +}; -bool BlockData::hasResElemTy() const { return this->resElemTy != nullptr; } +bool BlockData::hasResElemTy() const +{ + return this->resElemTy != nullptr; +} -Type &BlockData::getResElemTyRef() { return this->resElemTy; } +Type &BlockData::getResElemTyRef() +{ + return this->resElemTy; +} -Type BlockData::getResElemTy() const { return this->resElemTy; } +Type BlockData::getResElemTy() const +{ + return this->resElemTy; +} -int64_t BlockData::getRank() const { - assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); - return this->offsets.size(); +int64_t BlockData::getRank() const +{ + assert(offsets.size() == sizes.size() && offsets.size() == strides.size()); + return this->offsets.size(); } -void 
BlockData::setResElemTy(const Type &Ty) { this->resElemTy = Ty; } +void BlockData::setResElemTy(const Type &Ty) +{ + this->resElemTy = Ty; +} -void BlockData::setScalar(const OpFoldResult &scalar) { this->scalar = scalar; } +void BlockData::setScalar(const OpFoldResult &scalar) +{ + this->scalar = scalar; +} -void BlockData::setSource(const Value &src) { this->source = src; } +void BlockData::setSource(const Value &src) +{ + this->source = src; +} -void BlockData::setOffsets(const SmallVector &offsets) { - this->offsets = offsets; +void BlockData::setOffsets(const SmallVector &offsets) +{ + this->offsets = offsets; } -void BlockData::setStrides(const SmallVector &strides) { - this->strides = strides; +void BlockData::setStrides(const SmallVector &strides) +{ + this->strides = strides; } -void BlockData::setSizes(const SmallVector &szs) { - this->sizes = szs; +void BlockData::setSizes(const SmallVector &szs) +{ + this->sizes = szs; } -void BlockData::setMemAccTy(const MemAccType &v) { this->memAccTy = v; } +void BlockData::setMemAccTy(const MemAccType &v) +{ + this->memAccTy = v; +} -void BlockData::setMemAccVal(const MemAccVal v) { this->memAccTy.value = v; } +void BlockData::setMemAccVal(const MemAccVal v) +{ + this->memAccTy.value = v; +} -OpFoldResult BlockData::inferBlockOffset(const Location &loc, - OpBuilder &builder) const { - OpFoldResult retOffset = builder.getIndexAttr(0); - for (auto ofr : offsets) { - retOffset = addOpFoldResult(retOffset, ofr, loc, builder); - } - return retOffset; +OpFoldResult BlockData::inferBlockOffset(const Location &loc, OpBuilder &builder) const +{ + OpFoldResult retOffset = builder.getIndexAttr(0); + for (auto ofr : offsets) { + retOffset = addOpFoldResult(retOffset, ofr, loc, builder); + } + return retOffset; } -MemRefType BlockData::getResultMemrefType(int64_t offset, - ArrayRef resultShape) const { - SmallVector staticStrides; - SmallVector dynamicStrides; - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - auto elementType = - dyn_cast(this->source.getType()).getElementType(); - auto layout = - StridedLayoutAttr::get(this->source.getContext(), offset, staticStrides); - return MemRefType::get(resultShape, elementType, layout); +MemRefType BlockData::getResultMemrefType(int64_t offset, ArrayRef resultShape) const +{ + SmallVector staticStrides; + SmallVector dynamicStrides; + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + + auto elementType = dyn_cast(this->source.getType()).getElementType(); + auto layout = StridedLayoutAttr::get(this->source.getContext(), offset, staticStrides); + return MemRefType::get(resultShape, elementType, layout); } -void BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - // When both left block and right block have source, it is indirect load. - assert(!(lBlock.hasSource() && rBlock.hasSource()) && - "Don't support each BlockData has own base source pointer"); - this->source = - lBlock.hasSource() ? 
lBlock.getSourceRef() : rBlock.getSourceRef(); - - assert(!(lBlock.hasResElemTy() && rBlock.hasResElemTy())); - if (lBlock.hasResElemTy()) { - assert(lBlock.hasSource()); - this->resElemTy = lBlock.getResElemTyRef(); - } else if (rBlock.hasResElemTy()) { - assert(rBlock.hasSource()); - this->resElemTy = rBlock.getResElemTyRef(); - } - - // Acctually `scalar` should be accumulated into `offset` and `stride` finally - // In addBlock, just pass `scalar` when: - // 1. both lhs and rhs have `scalar` - // 2. otherwise, both lhs and rhs are scalar type with rank 0 - // Except above, original `scalar` has been fused into `offset` under add. - if (lBlock.isScalar() && rBlock.isScalar()) { - auto addScalar = addOpFoldResult(lBlock.getScalarRef(), - rBlock.getScalarRef(), loc, rewriter); - this->scalar = addScalar; - } else if (lBlock.getRank() == 0) { - // When both lhs and rhs are scalar type with rank 0, just try passing - // potential `scalar` - this->scalar = - lBlock.isScalar() ? lBlock.getScalarRef() : rBlock.getScalarRef(); - } - - for (const auto &[lOffset, rOffset] : - llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { - this->offsets.push_back(addOpFoldResult(lOffset, rOffset, loc, rewriter)); - } - - for (const auto &[lStride, rStride] : - llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { - this->strides.push_back(addOpFoldResult(lStride, rStride, loc, rewriter)); - } - - // Both sizes are same implicitly under `add` - this->sizes = lBlock.getSizesRef(); - - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); +void BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc, ConversionPatternRewriter &rewriter) +{ + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); + // When both left block and right block have source, it is indirect load. + assert(!(lBlock.hasSource() && rBlock.hasSource()) && "Don't support each BlockData has own base source pointer"); + this->source = lBlock.hasSource() ? lBlock.getSourceRef() : rBlock.getSourceRef(); + + assert(!(lBlock.hasResElemTy() && rBlock.hasResElemTy())); + if (lBlock.hasResElemTy()) { + assert(lBlock.hasSource()); + this->resElemTy = lBlock.getResElemTyRef(); + } else if (rBlock.hasResElemTy()) { + assert(rBlock.hasSource()); + this->resElemTy = rBlock.getResElemTyRef(); + } + + // Acctually `scalar` should be accumulated into `offset` and `stride` finally + // In addBlock, just pass `scalar` when: + // 1. both lhs and rhs have `scalar` + // 2. otherwise, both lhs and rhs are scalar type with rank 0 + // Except above, original `scalar` has been fused into `offset` under add. + if (lBlock.isScalar() && rBlock.isScalar()) { + auto addScalar = addOpFoldResult(lBlock.getScalarRef(), rBlock.getScalarRef(), loc, rewriter); + this->scalar = addScalar; + } else if (lBlock.getRank() == 0) { + // When both lhs and rhs are scalar type with rank 0, just try passing + // potential `scalar` + this->scalar = lBlock.isScalar() ? 
lBlock.getScalarRef() : rBlock.getScalarRef(); + } + + for (const auto &[lOffset, rOffset] : llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(addOpFoldResult(lOffset, rOffset, loc, rewriter)); + } + + for (const auto &[lStride, rStride] : llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(addOpFoldResult(lStride, rStride, loc, rewriter)); + } + + // Both sizes are same implicitly under `add` + this->sizes = lBlock.getSizesRef(); + + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); } -void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); +void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc, ConversionPatternRewriter &rewriter) +{ + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - assert(!(lBlock.hasSource() && rBlock.hasSource())); + assert(!(lBlock.hasSource() && rBlock.hasSource())); - assert( - (lBlock.isScalar() ^ rBlock.isScalar()) && - "Currently only support one and only one scalar in function mulBlock()"); + assert((lBlock.isScalar() ^ rBlock.isScalar()) && + "Currently only support one and only one scalar in function mulBlock()"); - BlockData *lb = &lBlock; - BlockData *rb = &rBlock; - if (lb->isScalar()) { - std::swap(lb, rb); - } + BlockData *lb = &lBlock; + BlockData *rb = &rBlock; + if (lb->isScalar()) { + std::swap(lb, rb); + } - // In mulBlock, `scalar` will be accumulated into `offset` and `stride` - OpFoldResult rScalar = rb->getScalarRef(); - for (const auto &lOffset : lb->getOffsetsRef()) { - this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter)); - } + // In mulBlock, `scalar` will be accumulated into `offset` and `stride` + OpFoldResult rScalar = rb->getScalarRef(); + for (const auto &lOffset : lb->getOffsetsRef()) { + this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter)); + } - for (const auto &lStride : lb->getStridesRef()) { - this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter)); - } + for (const auto &lStride : lb->getStridesRef()) { + this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter)); + } - this->sizes = lb->getSizesRef(); + this->sizes = lb->getSizesRef(); - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); } -void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, - ConversionPatternRewriter &rewriter) { - assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); +void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc, ConversionPatternRewriter &rewriter) +{ + assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank()); - assert(!(lBlock.hasSource() && rBlock.hasSource())); + assert(!(lBlock.hasSource() && rBlock.hasSource())); - for (const auto &[lOffset, rOffset] : - llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { - 
this->offsets.push_back(divOpFoldResult(lOffset, rOffset, loc, rewriter)); - } + for (const auto &[lOffset, rOffset] : llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) { + this->offsets.push_back(divOpFoldResult(lOffset, rOffset, loc, rewriter)); + } - for (const auto &[lStride, rStride] : - llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { - this->strides.push_back(divOpFoldResult(lStride, rStride, loc, rewriter)); - } + for (const auto &[lStride, rStride] : llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) { + this->strides.push_back(divOpFoldResult(lStride, rStride, loc, rewriter)); + } - this->sizes = lBlock.getSizesRef(); + this->sizes = lBlock.getSizesRef(); - this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); - this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); - // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), - // rBlock.getMemAccType())); + this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef()); + this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef()); + // this->setMemAccTy(selectMaxMemAccTy(lBlock.getMemAccType(), + // rBlock.getMemAccType())); } -memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef resultShape, - const Location &loc, - OpBuilder &builder) const { - OpFoldResult resOffset = this->inferBlockOffset(loc, builder); - auto resultType = this->getResultMemrefType( - isa(resOffset) ? getConstantIntValue(resOffset).value() - : ShapedType::kDynamic, - resultShape); - - return builder.create( - loc, resultType, this->source, resOffset, this->sizes, this->strides); +memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef resultShape, const Location &loc, + OpBuilder &builder) const +{ + OpFoldResult resOffset = this->inferBlockOffset(loc, builder); + auto resultType = this->getResultMemrefType( + isa(resOffset) ? 
getConstantIntValue(resOffset).value() : ShapedType::kDynamic, resultShape); + + return builder.create(loc, resultType, this->source, resOffset, this->sizes, + this->strides); } -void BlockData::dump() const { - llvm::outs() << "[INFO][BEG] BlockData info\n"; - llvm::outs() << "offsets has " << offsets.size() << " items\n"; - int cnt = 0; - for (auto it = offsets.begin(); it != offsets.end(); it++) { - llvm::outs() << "offsets[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "sizes has " << sizes.size() << " items\n"; - cnt = 0; - for (auto it = sizes.begin(); it != sizes.end(); it++) { - llvm::outs() << "sizes[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "strides has " << strides.size() << " items\n"; - cnt = 0; - for (auto it = strides.begin(); it != strides.end(); it++) { - llvm::outs() << "strides[" << cnt++ << "] = " << *it << "\n"; - } - llvm::outs() << "source = " << source << "\n"; - llvm::outs() << "scalar = " << scalar << "\n"; - llvm::outs() << "resElemTy = " << resElemTy << "\n"; - llvm::outs() << "memAccTy = " << memAccTy.toString() << "\n"; - llvm::outs() << "[INFO][END] BlockData info\n"; +void BlockData::dump() const +{ + auto &os = llvm::outs(); + os << "[INFO][BEG] BlockData info\n"; + os << "offsets has " << offsets.size() << " items\n"; + int cnt = 0; + for (auto it = offsets.begin(); it != offsets.end(); it++) { + os << "offsets[" << cnt++ << "] = " << *it << "\n"; + } + os << "sizes has " << sizes.size() << " items\n"; + cnt = 0; + for (auto it = sizes.begin(); it != sizes.end(); it++) { + os << "sizes[" << cnt++ << "] = " << *it << "\n"; + } + os << "strides has " << strides.size() << " items\n"; + cnt = 0; + for (auto it = strides.begin(); it != strides.end(); it++) { + os << "strides[" << cnt++ << "] = " << *it << "\n"; + } + os << "source = " << source << "\n"; + os << "scalar = " << scalar << "\n"; + os << "resElemTy = " << resElemTy << "\n"; + os << "memAccTy = " << memAccTy.toString() << "\n"; + os << "[INFO][END] BlockData info\n"; + os.flush(); } -Value BlockDataParser::getScalarMemRef(Value ptr, Value memref, - const Location &loc, - ConversionPatternRewriter &rewriter) { - assert(isa(ptr.getType()) && "expect a scalar pointer"); - if (ptr.getDefiningOp()) { - if (auto castOp = memref.getDefiningOp()) { - return castOp.getResult(); - } else { - llvm_unreachable("pointer value is defined by an unexpected op"); +Value BlockDataParser::getScalarMemRef(Value ptr, Value memref, const Location &loc, + ConversionPatternRewriter &rewriter) +{ + assert(isa(ptr.getType()) && "expect a scalar pointer"); + if (ptr.getDefiningOp()) { + if (auto castOp = memref.getDefiningOp()) { + return castOp.getResult(); + } else { + llvm_unreachable("pointer value is defined by an unexpected op"); + } } - } - - assert(isa(ptr) && - "pointer should be produced by addptr or block argument"); - BlockData data; - data.setSource(memref); - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(1)); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - auto castOp = data.createCastOp(SmallVector(1, 1), loc, rewriter); - return castOp.getResult(); + + assert(isa(ptr) && "pointer should be produced by addptr or block argument"); + BlockData data; + data.setSource(memref); + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + auto castOp = data.createCastOp(SmallVector(1, 1), loc, 
rewriter); + return castOp.getResult(); } -void BlockDataParser::parse( - Value operand, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - if (known.find(operand) != known.end()) { - return data = known.lookup(operand), void(); - } - - if (isa(operand.getType())) { - data.setScalar(getOpFoldResultOfLayoutInfo(operand, rewriter)); - return; - } - - // - if (isa(operand.getType())) { - // Just consider two state: ptr and ptr> - auto remappedPtr = rewriter.getRemappedValue(operand); - assert(remappedPtr); - if (auto op = operand.getDefiningOp()) { - if (auto addPtrOp = dyn_cast(op)) { +void BlockDataParser::parse(Value operand, BlockData &data, const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + if (known.find(operand) != known.end()) { + return data = known.lookup(operand), void(); + } + + if (isa(operand.getType())) { + data.setScalar(getOpFoldResultOfLayoutInfo(operand, rewriter)); + return; + } + + // + if (isa(operand.getType())) { + // Just consider two state: ptr and ptr> + auto remappedPtr = rewriter.getRemappedValue(operand); + assert(remappedPtr); + if (auto op = operand.getDefiningOp()) { + if (auto addPtrOp = dyn_cast(op)) { + parseAddPtr(addPtrOp, data, loc, rewriter, known); + } else if (auto bitcastOp = dyn_cast(op)) { + parseBitcast(bitcastOp, data, loc, rewriter, known); + } else if (auto makeTensorPtrOp = dyn_cast(op)) { + parseTensorPtr(makeTensorPtrOp, data, loc, rewriter, known); + } else if (auto advanceOp = dyn_cast(op)) { + // To support + // ptr_0 = tl.advance(ptr) + // ptr_1 = tl.advance(ptr_0) + parseTensorPtr(advanceOp, data, loc, rewriter, known); + } else { + llvm_unreachable("Unexpected operand defining operation, a scalar " + "pointer can only be produced by AddPtrOp or direct block ptr"); + } + } else { + data.setSource(remappedPtr); + } + return; + } + + // not a scalar pointer + if (auto addOp = operand.getDefiningOp()) { + parseAdd(addOp, data, loc, rewriter, known); + } else if (auto mulOp = operand.getDefiningOp()) { + parseMul(mulOp, data, loc, rewriter, known); + } else if (auto addPtrOp = operand.getDefiningOp()) { parseAddPtr(addPtrOp, data, loc, rewriter, known); - } else if (auto bitcastOp = dyn_cast(op)) { + } else if (auto constOp = operand.getDefiningOp()) { + parseConstSplat(constOp, data, loc, rewriter, known); + } else if (auto broadcastOp = operand.getDefiningOp()) { + parseBroadcast(broadcastOp, data, loc, rewriter, known); + } else if (auto splatOp = operand.getDefiningOp()) { + parseSplat(splatOp, data, loc, rewriter, known); + } else if (auto expandDimsOp = operand.getDefiningOp()) { + parseExpandDims(expandDimsOp, data, loc, rewriter, known); + } else if (auto remOp = operand.getDefiningOp()) { + parseRem(remOp, data, loc, rewriter, known); + } else if (auto bitcastOp = operand.getDefiningOp()) { parseBitcast(bitcastOp, data, loc, rewriter, known); - } else if (auto makeTensorPtrOp = dyn_cast(op)) { - parseTensorPtr(makeTensorPtrOp, data, loc, rewriter, known); - } else if (auto advanceOp = dyn_cast(op)) { - // To support - // ptr_0 = tl.advance(ptr) - // ptr_1 = tl.advance(ptr_0) - parseTensorPtr(advanceOp, data, loc, rewriter, known); - } else { - llvm_unreachable( - "Unexpected operand defining operation, a scalar " - "pointer can only be produced by AddPtrOp or direct block ptr"); - } + } else if (auto extsiOp = operand.getDefiningOp()) { + parseExtSI(extsiOp, data, loc, rewriter, known); + } else if (auto divOp = 
operand.getDefiningOp()) { + parseDiv(divOp, data, loc, rewriter, known); + } else if (auto makeRangeOp = operand.getDefiningOp()) { + parseMakeRange(makeRangeOp, data, loc, rewriter, known); + } else if (auto reduceOp = operand.getDefiningOp()) { + parseIndirectLoad(reduceOp, data, loc, rewriter, known); + } else if (auto loadOp = operand.getDefiningOp()) { + parseIndirectLoad(loadOp, data, loc, rewriter, known); + } else if (auto castOp = operand.getDefiningOp()) { + parseIndirectLoad(castOp, data, loc, rewriter, known); } else { - data.setSource(remappedPtr); + operand.dump(); + llvm_unreachable("encountered AddPtrOp produced by unsupported operation"); } - return; - } - - // not a scalar pointer - if (auto addOp = operand.getDefiningOp()) { - parseAdd(addOp, data, loc, rewriter, known); - } else if (auto mulOp = operand.getDefiningOp()) { - parseMul(mulOp, data, loc, rewriter, known); - } else if (auto addPtrOp = operand.getDefiningOp()) { - parseAddPtr(addPtrOp, data, loc, rewriter, known); - } else if (auto constOp = operand.getDefiningOp()) { - parseConstSplat(constOp, data, loc, rewriter, known); - } else if (auto broadcastOp = operand.getDefiningOp()) { - parseBroadcast(broadcastOp, data, loc, rewriter, known); - } else if (auto splatOp = operand.getDefiningOp()) { - parseSplat(splatOp, data, loc, rewriter, known); - } else if (auto expandDimsOp = - operand.getDefiningOp()) { - parseExpandDims(expandDimsOp, data, loc, rewriter, known); - } else if (auto remOp = operand.getDefiningOp()) { - parseRem(remOp, data, loc, rewriter, known); - } else if (auto bitcastOp = operand.getDefiningOp()) { - parseBitcast(bitcastOp, data, loc, rewriter, known); - } else if (auto extsiOp = operand.getDefiningOp()) { - parseExtSI(extsiOp, data, loc, rewriter, known); - } else if (auto divOp = operand.getDefiningOp()) { - parseDiv(divOp, data, loc, rewriter, known); - } else if (auto makeRangeOp = operand.getDefiningOp()) { - parseMakeRange(makeRangeOp, data, loc, rewriter, known); - } else if (auto reduceOp = operand.getDefiningOp()) { - parseReduce(reduceOp, data, loc, rewriter, known); - } else if (auto loadOp = operand.getDefiningOp()) { - parseIndirectLoad(loadOp, data, loc, rewriter, known); - } else if (auto castOp = operand.getDefiningOp()) { - parseIndirectLoad(castOp, data, loc, rewriter, known); - } else { - operand.dump(); - llvm_unreachable("encountered AddPtrOp produced by unsupported operation"); - } + LLVM_DEBUG({ + llvm::dbgs() << "[PtrAnalysis] parseOp: " << *(operand.getDefiningOp()) << "\n"; + data.dump(); + llvm::dbgs().flush(); + }); } -void BlockDataParser::parseAdd( - arith::AddIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), rBlock, loc, rewriter, known); - data.addBlock(lBlock, rBlock, loc, rewriter); +void BlockDataParser::parseAdd(arith::AddIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &known) +{ + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.addBlock(lBlock, rBlock, loc, rewriter); } -void BlockDataParser::parseMul( - arith::MulIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), 
rBlock, loc, rewriter, known); +void BlockDataParser::parseMul(arith::MulIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &known) +{ + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); - data.mulBlock(lBlock, rBlock, loc, rewriter); + data.mulBlock(lBlock, rBlock, loc, rewriter); } -void BlockDataParser::parseDiv( - arith::DivSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - BlockData lBlock, rBlock; - parse(op.getLhs(), lBlock, loc, rewriter, known); - parse(op.getRhs(), rBlock, loc, rewriter, known); - data.divBlock(lBlock, rBlock, loc, rewriter); +void BlockDataParser::parseDiv(arith::DivSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &known) +{ + BlockData lBlock, rBlock; + parse(op.getLhs(), lBlock, loc, rewriter, known); + parse(op.getRhs(), rBlock, loc, rewriter, known); + data.divBlock(lBlock, rBlock, loc, rewriter); } // TODO : support modulos -void BlockDataParser::parseRem( - arith::RemSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(false && "Address expression with modulo is not supported yet, it " - "shall be analysis at linearize."); +void BlockDataParser::parseRem(arith::RemSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap &known) +{ + assert(false && "Address expression with modulo is not supported yet, it " + "shall be analysis at linearize."); } -void BlockDataParser::parseMakeRange( - triton::MakeRangeOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - auto shape = dyn_cast(op.getType()).getShape(); - - auto start = op.getStart(); - auto end = op.getEnd(); - auto stride = (end >= start) && (end - start <= shape[0]); - assert(stride == 1 && - "make_range op should always return a tensor of stride 1"); - - data.getOffsetsRef().push_back(rewriter.getIndexAttr(start)); - data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); - data.getStridesRef().push_back(rewriter.getIndexAttr(stride)); +void BlockDataParser::parseMakeRange(triton::MakeRangeOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + auto shape = dyn_cast(op.getType()).getShape(); + + auto start = op.getStart(); + auto end = op.getEnd(); + auto stride = (end >= start) && (end - start <= shape[0]); + assert(stride == 1 && "make_range op should always return a tensor of stride 1"); + + data.getOffsetsRef().push_back(rewriter.getIndexAttr(start)); + data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(stride)); } -void BlockDataParser::parseExpandDims( - triton::ExpandDimsOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - parse(op.getSrcMutable().get(), data, loc, rewriter, known); - auto resShape = dyn_cast(op.getResult().getType()).getShape(); - auto axis = op.getAxis(); - - assert(resShape[axis] == 1 && - "The destiny shape of changed dimension should be 1"); - - data.getOffsetsRef().insert(data.getOffsetsRef().begin() + axis, - 
rewriter.getIndexAttr(0)); - data.getSizesRef().insert(data.getSizesRef().begin() + axis, - rewriter.getIndexAttr(1)); - data.getStridesRef().insert(data.getStridesRef().begin() + axis, - rewriter.getIndexAttr(0)); +void BlockDataParser::parseExpandDims(triton::ExpandDimsOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + + parse(op.getSrcMutable().get(), data, loc, rewriter, known); + auto resShape = dyn_cast(op.getResult().getType()).getShape(); + auto axis = op.getAxis(); + + assert(resShape[axis] == 1 && "The destiny shape of changed dimension should be 1"); + + data.getOffsetsRef().insert(data.getOffsetsRef().begin() + axis, rewriter.getIndexAttr(0)); + data.getSizesRef().insert(data.getSizesRef().begin() + axis, rewriter.getIndexAttr(1)); + data.getStridesRef().insert(data.getStridesRef().begin() + axis, rewriter.getIndexAttr(0)); } -void BlockDataParser::parseBitcast( - triton::BitcastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - parse(op.getSrc(), data, loc, rewriter, known); - - auto resType = op.getResult().getType(); - mlir::Type resElemPointeeTy; - if (auto resShapedTy = dyn_cast(resType)) { - auto resElemTy = resShapedTy.getElementType(); - resElemPointeeTy = - dyn_cast(resElemTy).getPointeeType(); - } else { - resElemPointeeTy = dyn_cast(resType).getPointeeType(); - } - data.setResElemTy(resElemPointeeTy); +void BlockDataParser::parseBitcast(triton::BitcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + parse(op.getSrc(), data, loc, rewriter, known); + + auto resType = op.getResult().getType(); + mlir::Type resElemPointeeTy; + if (auto resShapedTy = dyn_cast(resType)) { + auto resElemTy = resShapedTy.getElementType(); + resElemPointeeTy = dyn_cast(resElemTy).getPointeeType(); + } else { + resElemPointeeTy = dyn_cast(resType).getPointeeType(); + } + data.setResElemTy(resElemPointeeTy); } -void BlockDataParser::parseExtSI( - arith::ExtSIOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - parse(op.getIn(), data, loc, rewriter, known); +void BlockDataParser::parseExtSI(arith::ExtSIOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + parse(op.getIn(), data, loc, rewriter, known); } -void BlockDataParser::parseBroadcast( - triton::BroadcastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - auto src = op.getSrcMutable().get(); - auto dst = op.getResult(); - assert(isa(src.getType()) && - "tt.broadcast's input should be a tensor"); - - auto srcShape = dyn_cast(src.getType()).getShape(); - auto dstShape = dyn_cast(dst.getType()).getShape(); - assert(srcShape.size() == dstShape.size() && - "rank of source shoule be equal to destnation"); - - parse(src, data, loc, rewriter, known); - - for (const auto &[idx, src_dst] : - llvm::enumerate(llvm::zip(srcShape, dstShape))) { - const auto &[srcAxis, dstAxis] = src_dst; - if (srcAxis == dstAxis) { - continue; +void BlockDataParser::parseBroadcast(triton::BroadcastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + 
const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + + auto src = op.getSrcMutable().get(); + auto dst = op.getResult(); + assert(isa(src.getType()) && "tt.broadcast's input should be a tensor"); + + auto srcShape = dyn_cast(src.getType()).getShape(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + assert(srcShape.size() == dstShape.size() && "rank of source shoule be equal to destnation"); + + parse(src, data, loc, rewriter, known); + + for (const auto &[idx, src_dst] : llvm::enumerate(llvm::zip(srcShape, dstShape))) { + const auto &[srcAxis, dstAxis] = src_dst; + if (srcAxis == dstAxis) { + continue; + } + assert(srcAxis < dstAxis && "srcShape of broadcastOp must be less than dstShape."); + data.getSizesRef()[idx] = rewriter.getIndexAttr(dstAxis); } - assert(srcAxis < dstAxis && - "srcShape of broadcastOp must be less than dstShape."); - data.getSizesRef()[idx] = rewriter.getIndexAttr(dstAxis); - } } -void BlockDataParser::parseSplat( - triton::SplatOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - auto src = op.getSrc(); - auto dst = op.getResult(); - auto dstShape = dyn_cast(dst.getType()).getShape(); - - parse(src, data, loc, rewriter, known); - - if (isa(src.getType()) || - isa(src.getType())) { - if (!data.isEmpty()) { - data.getOffsetsRef().clear(); - data.getSizesRef().clear(); - data.getStridesRef().clear(); +void BlockDataParser::parseSplat(triton::SplatOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + auto src = op.getSrc(); + auto dst = op.getResult(); + auto dstShape = dyn_cast(dst.getType()).getShape(); + + parse(src, data, loc, rewriter, known); + + if (isa(src.getType()) || isa(src.getType())) { + if (!data.isEmpty()) { + data.getOffsetsRef().clear(); + data.getSizesRef().clear(); + data.getStridesRef().clear(); + } + for (auto dstAxis : dstShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + } else { + op->emitError("Block data Analysis: unsupported splat pattern"); + return; } - for (auto dstAxis : dstShape) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis)); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + if (data.isScalar()) { + data.getOffsetsRef()[0] = data.getScalarRef(); } - } else { - op->emitError("Block data Analysis: unsupported splat pattern"); - return; - } - if (data.isScalar()) { - data.getOffsetsRef()[0] = data.getScalarRef(); - } } -void BlockDataParser::parseConstSplat( - arith::ConstantOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); +void BlockDataParser::parseConstSplat(arith::ConstantOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); - DenseElementsAttr denseAttr = dyn_cast(op.getValue()); - assert(denseAttr && denseAttr.isSplat() && - isa(denseAttr.getElementType())); + DenseElementsAttr denseAttr = dyn_cast(op.getValue()); + assert(denseAttr && denseAttr.isSplat() && isa(denseAttr.getElementType())); - auto innerVal = denseAttr.getValues()[0].getValue(); - auto innerValIndexAttr = 
rewriter.getIndexAttr(innerVal.getSExtValue()); + auto innerVal = denseAttr.getValues()[0].getValue(); + auto innerValIndexAttr = rewriter.getIndexAttr(innerVal.getSExtValue()); - // for mul state - data.setScalar(innerValIndexAttr); + // for mul state + data.setScalar(innerValIndexAttr); - auto resType = dyn_cast(op.getResult().getType()); - size_t loopLimit = resType.getShape().size(); - for (auto i = 0; i < loopLimit; i++) { - // Add original dense val to first dim offset for add state - if (i == 0) { - data.getOffsetsRef().push_back(innerValIndexAttr); - } else { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + auto resType = dyn_cast(op.getResult().getType()); + size_t loopLimit = resType.getShape().size(); + for (auto i = 0; i < loopLimit; i++) { + // Add original dense val to first dim offset for add state + if (i == 0) { + data.getOffsetsRef().push_back(innerValIndexAttr); + } else { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } + data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i])); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); } - data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i])); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } } template -std::enable_if_t || - std::is_same_v> -BlockDataParser::parseTensorPtr( - T op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - Value remappedValue = rewriter.getRemappedValue(op); - if (auto castOp = remappedValue.getDefiningOp()) { - parseReinterpretCast(castOp, data, loc, rewriter, known); - } else { - llvm_unreachable("the value should be mapped to memref.reinterpret_cast"); - } -} - -void BlockDataParser::parseAddPtr( - triton::AddPtrOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - BlockData ptrBlock, offsetBlock; - parse(op.getPtr(), ptrBlock, op.getLoc(), rewriter, known); - parse(op.getOffset(), offsetBlock, op.getLoc(), rewriter, known); - - assert(ptrBlock.hasSource() && - "Ptr field should provide source/base pointer"); - // offset has source means offset is from tl.load and other ops(TODO) - if (offsetBlock.hasSource()) { - ptrBlock.setMemAccTy(offsetBlock.getMemAccType()); - offsetBlock.removeSource(); - } - - // handle for loop & scalar - if (ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0) { - offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1)); - offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef()); - offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0)); - } - - assert(ptrBlock.getRank() == offsetBlock.getRank() && - "ptr and offset should have same rank"); - LLVM_DEBUG({ - auto &os = llvm::dbgs(); - os << "[parseAddPtr][BEG] =========================\n"; - os << "[parseAddPtr] op is " << op << "\n"; - for (int i = 0; i < ptrBlock.getRank(); i++) { - os << "ptrBlock.getOffsetsRef()[" << i - << "] = " << ptrBlock.getOffsetsRef()[i] << "\n"; - os << "ptrBlock.getSizesRef()[" << i - << "] = " << ptrBlock.getSizesRef()[i] << "\n"; - os << "ptrBlock.getStridesRef()[" << i - << "] = " << ptrBlock.getStridesRef()[i] << "\n"; - os << "offsetBlock.getOffsetsRef()[" << i - << "] = " << offsetBlock.getOffsetsRef()[i] << "\n"; - os << "offsetBlock.getSizesRef()[" << i - << "] = " << offsetBlock.getSizesRef()[i] << "\n"; - os << "offsetBlock.getStridesRef()[" << i - << "] = " << 
offsetBlock.getStridesRef()[i] << "\n"; +std::enable_if_t || std::is_same_v> +BlockDataParser::parseTensorPtr(T op, BlockData &data, const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + + Value remappedValue = rewriter.getRemappedValue(op); + if (auto castOp = remappedValue.getDefiningOp()) { + parseReinterpretCast(castOp, data, loc, rewriter, known); + } else { + llvm_unreachable("the value should be mapped to memref.reinterpret_cast"); } - os << "[parseAddPtr][END] -------------------------\n"; - }); - data.addBlock(ptrBlock, offsetBlock, op.getLoc(), rewriter); } -void BlockDataParser::parseReinterpretCast( - memref::ReinterpretCastOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - assert(data.isEmpty()); - - data.setOffsets(op.getMixedOffsets()); - data.setSizes(op.getMixedSizes()); - data.setStrides(op.getMixedStrides()); - data.setSource(op.getSource()); - - // In memref::ReinterpretCastOp, offset means the total of collapsing multiple - // dimensions, which corresponds to first dim offset in block data. - // Here populate the rest of the dimensions with zeroes. - assert(data.getOffsetsRef().size() == 1); - size_t loopLimit = data.getSizesRef().size(); - for (size_t i = 1; i < loopLimit; i++) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - } - - // Block data cannot accept patterns of size 1 and non-zero const stride from - // memref type. Set stride back to zero if this scenario is detected. - loopLimit = data.getStridesRef().size(); - for (size_t i = 0; i < loopLimit; i++) { - auto strideConst = getConstantIntValue(data.getStridesRef()[i]); - auto sizeConst = getConstantIntValue(data.getSizesRef()[i]); - assert(sizeConst.has_value()); - if (sizeConst.value() == 1 && strideConst.has_value()) { - data.getStridesRef()[i] = rewriter.getIndexAttr(0); +void BlockDataParser::parseAddPtr(triton::AddPtrOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + + BlockData ptrBlock, offsetBlock; + parse(op.getPtr(), ptrBlock, op.getLoc(), rewriter, known); + parse(op.getOffset(), offsetBlock, op.getLoc(), rewriter, known); + + assert(ptrBlock.hasSource() && "Ptr field should provide source/base pointer"); + // offset has source means offset is from tl.load and other ops(TODO) + if (offsetBlock.hasSource()) { + ptrBlock.setMemAccTy(offsetBlock.getMemAccType()); + offsetBlock.removeSource(); } - } -} -void BlockDataParser::parseReduce( - triton::ReduceOp op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - - const std::string scenarioMessages = - "PtsAnalysis supports indirectly block load in the following scenario\n" - "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n" - "s = tl.min(B) # s is a scalar\n" - "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n"; - - auto reduce_src = op->getOperand(0); - BlockData srcBlock; - parse(reduce_src, srcBlock, loc, rewriter, known); - if (!srcBlock.hasSource()) { - llvm_unreachable(scenarioMessages.c_str()); - } - if (!isa(srcBlock.getSource().getDefiningOp())) { - llvm_unreachable(scenarioMessages.c_str()); - } - - auto reduce_result = op->getResult(0); - auto shaped_ty = dyn_cast(reduce_result.getType()); - auto shape = shaped_ty.getShape(); - auto ops = llvm::map_to_vector(op.getBody()->without_terminator(), - 
[](Operation &op) { return &op; }); - // Support only the case: scalar = tl.load(1D tensor) - if (shape.size() != 1 || op.getAxis() != 0 || ops.size() != 1 || - !isa(ops.front())) { - llvm_unreachable(scenarioMessages.c_str()); - } - - auto castOp = rewriter.create( - loc, RankedTensorType::get(shape, rewriter.getIndexType()), - reduce_result); - auto offset = castOp.getResult(); - if (data.isEmpty()) { - data.getOffsetsRef().push_back(offset); - data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0])); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - } else { - llvm_unreachable("parseReduce with offset already setup not yet supported"); - } -} + // handle for loop & scalar + if (ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0) { + offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1)); + offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef()); + offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } -template -void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, - ConversionPatternRewriter &rewriter, - const llvm::SmallDenseMap &known) { - // FIXME: assume single result of operation - auto opRes = op->getResult(0); - auto opResTy = opRes.getType(); - std::vector resShape; - if (auto shapedResTy = dyn_cast(opResTy)) { - // For now, we consider this is UnstrucMemAcc because we have no other info. - // Visiting other ops may change the type due to more info. - data.setMemAccVal(MemAccVal::UnstrucMemAcc); - resShape = shapedResTy.getShape().vec(); - } else { - // scalar load means this is used as offset. It is StrucMemAcc. - data.setMemAccVal(MemAccVal::StrucMemAcc); - resShape.push_back(1); - } - for (auto &s : resShape) { - data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); - data.getSizesRef().push_back(rewriter.getIndexAttr(s)); - data.getStridesRef().push_back(rewriter.getIndexAttr(1)); - } - // set the source in BlockData so that we know an indirect-load op exists in - // the chain. 
- data.setSource(opRes); + assert(ptrBlock.getRank() == offsetBlock.getRank() && "ptr and offset should have same rank"); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "[parseAddPtr][BEG] =========================\n"; + os << "[parseAddPtr] op is " << op << "\n"; + for (int i = 0; i < ptrBlock.getRank(); i++) { + os << "ptrBlock.getOffsetsRef()[" << i << "] = " << ptrBlock.getOffsetsRef()[i] << "\n"; + os << "ptrBlock.getSizesRef()[" << i << "] = " << ptrBlock.getSizesRef()[i] << "\n"; + os << "ptrBlock.getStridesRef()[" << i << "] = " << ptrBlock.getStridesRef()[i] << "\n"; + os << "offsetBlock.getOffsetsRef()[" << i << "] = " << offsetBlock.getOffsetsRef()[i] << "\n"; + os << "offsetBlock.getSizesRef()[" << i << "] = " << offsetBlock.getSizesRef()[i] << "\n"; + os << "offsetBlock.getStridesRef()[" << i << "] = " << offsetBlock.getStridesRef()[i] << "\n"; + } + os << "[parseAddPtr][END] -------------------------\n"; + os.flush(); + }); + data.addBlock(ptrBlock, offsetBlock, op.getLoc(), rewriter); } -void BlockDataParser::rewriteAddPtr( - triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known) { - auto insertPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(op); +void BlockDataParser::parseReinterpretCast(memref::ReinterpretCastOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + assert(data.isEmpty()); + + data.setOffsets(op.getMixedOffsets()); + data.setSizes(op.getMixedSizes()); + data.setStrides(op.getMixedStrides()); + data.setSource(op.getSource()); + + // In memref::ReinterpretCastOp, offset means the total of collapsing multiple + // dimensions, which corresponds to first dim offset in block data. + // Here populate the rest of the dimensions with zeroes. + assert(data.getOffsetsRef().size() == 1); + size_t loopLimit = data.getSizesRef().size(); + for (size_t i = 1; i < loopLimit; i++) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(0)); + } - BlockData data; - parseAddPtr(op, data, op.getLoc(), rewriter, known); + // Block data cannot accept patterns of size 1 and non-zero const stride from + // memref type. Set stride back to zero if this scenario is detected. + loopLimit = data.getStridesRef().size(); + for (size_t i = 0; i < loopLimit; i++) { + auto strideConst = getConstantIntValue(data.getStridesRef()[i]); + auto sizeConst = getConstantIntValue(data.getSizesRef()[i]); + assert(sizeConst.has_value()); + if (sizeConst.value() == 1 && strideConst.has_value()) { + data.getStridesRef()[i] = rewriter.getIndexAttr(0); + } + } +} - if (data.getMemAccTypeRef().isUnstructured()) { - // TODO: Based on more info, try to create a performant IR - rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data); - LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); - return; - } +void BlockDataParser::parseReduce(triton::ReduceOp op, BlockData &data, const Location &loc, + ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + + const std::string scenarioMessages = "PtsAnalysis supports indirectly block load in the following scenario\n" + "s = tl.min(B) # s is a scalar\n" + "s = op(s) ... 
\n" + "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n"; + + auto reduceSrc = op->getOperand(0); + BlockData srcBlock; + parse(reduceSrc, srcBlock, loc, rewriter, known); + if (!srcBlock.hasSource()) { + llvm_unreachable(scenarioMessages.c_str()); + } - if (data.getSizesRef().size() == 0) { - data.getSizesRef().push_back(rewriter.getIndexAttr(1)); - data.getStridesRef().push_back(rewriter.getIndexAttr(0)); - data.getOffsetsRef().push_back(data.getScalarRef()); - } - - ArrayRef resultShape; - // shape {1,} is stub for single ptr - SmallVector stubScalarTypeShape(1, 1); - if (auto shapedType = dyn_cast(op.getResult().getType())) { - resultShape = shapedType.getShape(); - } else { - assert(data.getRank() == 1); - resultShape = stubScalarTypeShape; - } - - known[op.getResult()] = data; - - // If there are dimensions with size 1 and stride 0, replace 0 stride with the - // product of sizes of all lower dimensions. This avoids creating memref with - // zero stride. - // And here store the unmodified state into known ptrs, since any following - // pointer arithmetic operations should still use the original 0 stride. - auto inferedSize = 1; - for (int i = data.getSizesRef().size() - 1; i >= 0; i--) { - auto strideConst = getConstantIntValue(data.getStridesRef()[i]); - auto sizeConst = getConstantIntValue(data.getSizesRef()[i]); - assert(sizeConst.has_value()); - if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { - data.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + auto reduceRes = op->getResult(0); + auto resShapeTy = dyn_cast(reduceRes.getType()); + auto resShape = resShapeTy.getShape(); + auto srcShapeTy = dyn_cast(reduceSrc.getType()); + auto srcShape = srcShapeTy.getShape(); + auto ops = llvm::map_to_vector(op.getBody()->without_terminator(), [](Operation &op) { return &op; }); + // Support only the case: scalar = tl.load(1D tensor) + if (resShape.size() != 1 || op.getAxis() != srcShape.size() - 1 || ops.size() != 1 || + !isa(ops.front())) + { + llvm_unreachable(scenarioMessages.c_str()); } - inferedSize *= sizeConst.value(); - } - - if (data.hasResElemTy()) { - // Handle bitcast scenario - auto memrefType = dyn_cast(data.getSourceRef().getType()) - .cloneWith(std::nullopt, data.getResElemTyRef()); - UnrealizedConversionCastOp castOp = - rewriter.create( - op.getLoc(), memrefType, data.getSourceRef()); - data.setSource(castOp.getOutputs()[0]); - } - - // ToDo: need to handle module scenario - - memref::ReinterpretCastOp castOp = - data.createCastOp(resultShape, op.getLoc(), rewriter); - Value src = castOp.getResult(); - LLVM_DEBUG({ - llvm::dbgs() << "cast MemRefType:\n"; - castOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); - - rewriter.replaceOp(op, src); - rewriter.restoreInsertionPoint(insertPoint); -} -void BlockDataParser::rewriteAdvanceOp( - triton::AdvanceOp op, ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known) { - OpBuilder::InsertionGuard insertionGuard(rewriter); - rewriter.setInsertionPoint(op); - auto loc = op.getLoc(); - - BlockData blockData; - parse(op.getOperand(0), blockData, loc, rewriter, known); - - // region [BUGFIX] Add the code block below following the same logic as 'BlockDataParser::rewriteAddPtr' function. 
- known[op.getResult()] = blockData; - auto inferedSize = 1; - for (int i = blockData.getSizesRef().size() - 1; i >= 0; i--) { - auto strideConst = getConstantIntValue(blockData.getStridesRef()[i]); - auto sizeConst = getConstantIntValue(blockData.getSizesRef()[i]); - assert(sizeConst.has_value()); - if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { - blockData.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + auto castOp = + rewriter.create(loc, RankedTensorType::get(resShape, rewriter.getIndexType()), reduceRes); + auto offset = castOp.getResult(); + if (data.isEmpty()) { + data.getOffsetsRef().push_back(offset); + data.getSizesRef().push_back(rewriter.getIndexAttr(resShape[0])); + data.getStridesRef().push_back(rewriter.getIndexAttr(1)); + } else { + llvm_unreachable("parseReduce with offset already setup not yet supported"); } - inferedSize *= sizeConst.value(); - } - // endregion - - SmallVector incrementOffsets = - llvm::map_to_vector(op.getOffsets(), [&](Value offset) { - return getOpFoldResultOfLayoutInfo(offset, rewriter); - }); - - SmallVector newOffsets; - for (const auto [increment, originalOffset, stride] : - llvm::zip(incrementOffsets, blockData.getOffsetsRef(), - blockData.getStridesRef())) { - auto curDimOffset = - addOpFoldResult(mulOpFoldResult(increment, stride, loc, rewriter), - originalOffset, loc, rewriter); - - newOffsets.push_back(curDimOffset); - } - - blockData.getOffsetsRef().clear(); - - for (auto offset : newOffsets) - blockData.getOffsetsRef().push_back(offset); - - SmallVector scalarShape(1, 1); // Stub shape - ArrayRef resultShape; - auto pointerType = cast(op.getResult().getType()); - - if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { - resultShape = shapedType.getShape(); - } else { - // scalar pointer, should produce a one dimensional memref - resultShape = scalarShape; - assert(blockData.getRank() == 1); - } - - auto newOp = blockData.createCastOp(resultShape, loc, rewriter); - rewriter.replaceOp(op, newOp.getResult()); - - known[newOp.getResult()] = blockData; } -void BlockDataParser::rewriteYieldOp( - scf::YieldOp op, ConversionPatternRewriter &rewriter, - const std::set blockArgIdxSet, - const llvm::SmallDenseMap &known) { - // Any inserted instruction should be before this yield - OpBuilder::InsertionGuard insertionGuard{rewriter}; - rewriter.setInsertionPoint(op); - - auto adaptor = scf::YieldOp::Adaptor(op); - - SmallVector initArgState; - SmallVector operands(adaptor.getOperands()); - - // For each of the init arg that we added additional Values in for loop, we - // need to add corresponding Values as yield operands. The loop below gathers - // BlockData for those values. 
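For reference, the indirect block-load scenario that the new parseReduce accepts corresponds to a Triton kernel of roughly the following shape (a minimal sketch; the kernel and pointer names are hypothetical, and A_ptr is assumed to hold integer offsets):

import triton
import triton.language as tl

@triton.jit
def indirect_load_kernel(A_ptr, C_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    B = tl.load(A_ptr + offs)      # B is a 1D tensor of integer offsets
    s = tl.min(B, axis=0)          # s is a scalar produced by a single reduce op
    D = tl.load(C_ptr + s + offs)  # s acts as the scalar offset of a second load
    tl.store(out_ptr + offs, D)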
- for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { - if (auto mappedV = rewriter.getRemappedValue(v)) { - // If this value is a tensor of pointers produced by AddPtrOp, - // we should have already converted to a ReinterpretCastOp without - // layout information for the normal cases - if (v.getDefiningOp() || - v.getDefiningOp() || - v.getDefiningOp()) { - if (auto castOp = mappedV.getDefiningOp()) { - v = castOp; +template +void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc, ConversionPatternRewriter &rewriter, + const llvm::SmallDenseMap &known) +{ + // FIXME: assume single result of operation + auto opRes = op->getResult(0); + auto opResTy = opRes.getType(); + std::vector resShape; + bool resShapeAllOne = false; + Value scalarOffset; + if (auto shapedResTy = dyn_cast(opResTy)) { + // For now, we consider this is UnstrucMemAcc because we have no other info. + // Visiting other ops may change the type due to more info. + resShape = shapedResTy.getShape().vec(); + resShapeAllOne = llvm::all_of(resShape, [](int64_t dim) { return dim == 1; }); + if (resShapeAllOne) { + // the shape is all one indicating in fact a scalar load + data.setMemAccVal(MemAccVal::StrucMemAcc); + Value zero = rewriter.create(loc, rewriter.getIndexAttr(0)); + auto extractOp = rewriter.create(op->getLoc(), opRes, zero); + scalarOffset = + rewriter.create(op->getLoc(), rewriter.getIndexType(), extractOp->getResult(0)); } else { - llvm_unreachable("mapped value defined by an unexpected op"); + data.setMemAccVal(MemAccVal::UnstrucMemAcc); } - } else { - // If this value is not a tensor of pointers, we will use the - // mapped value, and rely on the conversion will happen later - // automatically when we legalize loop body. - - // TODO: - // The scenario where a value is a tensor of pointers but not - // produced by AddPtrOp is not supported - if (isa(mappedV.getType()) && - isa( - dyn_cast(mappedV.getType()).getElementType())) - llvm_unreachable("unsupported scenario where a value is a tensor of " - "pointers but not produced by AddPtrOp"); - v = mappedV; - } + } else { + // scalar load means this is used as offset. It is StrucMemAcc. + resShape.push_back(1); + data.setMemAccVal(MemAccVal::StrucMemAcc); + scalarOffset = rewriter.create(op->getLoc(), rewriter.getIndexType(), opRes); } - if (blockArgIdxSet.find(i) == blockArgIdxSet.end()) - continue; + if (data.getMemAccTypeRef().isUnstructured()) { + for (auto &s : resShape) { + data.getOffsetsRef().push_back(rewriter.getIndexAttr(ShapedType::kDynamic)); + data.getSizesRef().push_back(rewriter.getIndexAttr(s)); + data.getStridesRef().push_back(rewriter.getIndexAttr(ShapedType::kDynamic)); + } + } else { + // scalar load used as offset + for (auto &s : resShape) { + data.getOffsetsRef().push_back(scalarOffset); + data.getSizesRef().push_back(rewriter.getIndexAttr(s)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + } + } + // set the source in BlockData so that we know an indirect-load op exists in + // the chain. 
+ data.setSource(opRes); +} + +void BlockDataParser::rewriteAddPtr(triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, llvm::SmallDenseMap &known) +{ + auto insertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + BlockData data; + parseAddPtr(op, data, op.getLoc(), rewriter, known); - auto reintCastOp = v.getDefiningOp(); - assert( - reintCastOp || - (isa(v.getType()) && - isa(dyn_cast(v.getType()).getElementType()))); + if (data.getMemAccTypeRef().isUnstructured()) { + // TODO: Based on more info, try to create a performant IR + rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data); + LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; }); + return; + } - BlockData state; - if (reintCastOp) { - parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known); + if (data.getSizesRef().size() == 0) { + data.getSizesRef().push_back(rewriter.getIndexAttr(1)); + data.getStridesRef().push_back(rewriter.getIndexAttr(0)); + data.getOffsetsRef().push_back(data.getScalarRef()); + } + + ArrayRef resultShape; + // shape {1,} is stub for single ptr + SmallVector stubScalarTypeShape(1, 1); + if (auto shapedType = dyn_cast(op.getResult().getType())) { + resultShape = shapedType.getShape(); } else { - parse(v, state, op.getLoc(), rewriter, known); + assert(data.getRank() == 1); + resultShape = stubScalarTypeShape; } - initArgState.push_back(state); - } - - // For each of the BlockData recorded in the last step, extract value - // that correspond to offset and stride for each dimension and append - // them to yield operands. - for (auto state : initArgState) { - for (auto offset : state.getOffsetsRef()) { - // offsets can be IntAttr zeroes, since reinterpret_cast collapses - // them for the input memref, and the for loop may not update - // offsets other than offsets[0]. Create constants Values for those - // zeroes. - if (isa(offset)) { - auto constOffset = offset.get(); - assert(isa(constOffset) && - dyn_cast(constOffset).getInt() == 0 && - "attribute offsets should be zeroes"); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr(0)); - operands.push_back(constOp.getResult()); - } else { - operands.push_back(offset.get()); - } + + known[op.getResult()] = data; + + // If there are dimensions with size 1 and stride 0, replace 0 stride with the + // product of sizes of all lower dimensions. This avoids creating memref with + // zero stride. + // And here store the unmodified state into known ptrs, since any following + // pointer arithmetic operations should still use the original 0 stride. 
+ auto inferedSize = 1; + for (int i = data.getSizesRef().size() - 1; i >= 0; i--) { + auto strideConst = getConstantIntValue(data.getStridesRef()[i]); + auto sizeConst = getConstantIntValue(data.getSizesRef()[i]); + assert(sizeConst.has_value()); + if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { + data.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + } + inferedSize *= sizeConst.value(); } - for (OpFoldResult stride : state.getStridesRef()) { - assert(!isa(stride) && "BlockData strides for yield within for" - " loop not expected to be attribute."); - operands.push_back(stride.get()); + if (data.hasResElemTy()) { + // Handle bitcast scenario + auto memrefType = + dyn_cast(data.getSourceRef().getType()).cloneWith(std::nullopt, data.getResElemTyRef()); + UnrealizedConversionCastOp castOp = + rewriter.create(op.getLoc(), memrefType, data.getSourceRef()); + data.setSource(castOp.getOutputs()[0]); } - } - - // Yield is a terminator op that must be at the end of the function - rewriter.setInsertionPointAfter(op); - auto newOp = rewriter.replaceOpWithNewOp(op, operands); - assert(op->getNumResults() == 0); - - LLVM_DEBUG({ - llvm::dbgs() << "new yield:"; - newOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); + + // ToDo: need to handle module scenario + + memref::ReinterpretCastOp castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); + Value src = castOp.getResult(); + LLVM_DEBUG({ + llvm::dbgs() << "cast MemRefType:\n"; + castOp.getOperation()->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + rewriter.replaceOp(op, src); + rewriter.restoreInsertionPoint(insertPoint); } -void BlockDataParser::rewriteForOp( - scf::ForOp op, ConversionPatternRewriter &rewriter, - llvm::SmallDenseMap &known) { - - SmallVector newInitArgs; - - SmallVector, 5> initArgIndexIfBlockData; - SmallVector, 5> knownPtrsTmp; - std::set blockArgIdxSet; - - // Create a new list of init args - for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { - auto mappedV = rewriter.getRemappedValue(arg); - memref::ReinterpretCastOp reintCastOp; - - // If this init arg is supposed to be remapped, use the remapped - // value instead. - // In addition, if this init arg is a memref created by a reinterpret_cast - // or a tensor of index, there is a chance that it will be used in addptr. - // Create BlockData for each such init arg. - if (mappedV) { - // TODO: - // Passing a block argument pointer directly into a for loop not - // supported. - assert(!(isa(mappedV) && - isa(mappedV.getType())) && - "cannot take pointer block argument as init arg for for loop"); - if (auto op = mappedV.getDefiningOp()) { - // Record memref::ReinterpretCastOp - reintCastOp = op; - newInitArgs.push_back(mappedV); - } else { - newInitArgs.push_back(mappedV); - } - } else { - newInitArgs.push_back(arg); +void BlockDataParser::rewriteAdvanceOp(triton::AdvanceOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) +{ + OpBuilder::InsertionGuard insertionGuard(rewriter); + rewriter.setInsertionPoint(op); + auto loc = op.getLoc(); + + BlockData blockData; + parse(op.getOperand(0), blockData, loc, rewriter, known); + + // region [BUGFIX] Add the code block below following the same logic as 'BlockDataParser::rewriteAddPtr' function. 
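The zero-stride fix-up used in rewriteAddPtr above (and copied for tt.advance by the [BUGFIX] block below) can be summarized by this small Python sketch; infer_strides is a hypothetical helper shown only for illustration:

def infer_strides(sizes, strides):
    # Dimensions of size 1 with stride 0 would otherwise yield a memref with a
    # zero stride, so give them the product of the sizes of all lower dimensions.
    inferred = 1
    for i in reversed(range(len(sizes))):
        if sizes[i] == 1 and strides[i] == 0:
            strides[i] = inferred
        inferred *= sizes[i]
    return strides

# Example: sizes [4, 1, 16] with strides [16, 0, 1] become [16, 16, 1].
print(infer_strides([4, 1, 16], [16, 0, 1]))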
+ known[op.getResult()] = blockData; + auto inferedSize = 1; + for (int i = blockData.getSizesRef().size() - 1; i >= 0; i--) { + auto strideConst = getConstantIntValue(blockData.getStridesRef()[i]); + auto sizeConst = getConstantIntValue(blockData.getSizesRef()[i]); + assert(sizeConst.has_value()); + if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) { + blockData.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize); + } + inferedSize *= sizeConst.value(); } + // endregion - auto indexTensor = - isa(arg.getType()) && - isa(dyn_cast(arg.getType()).getElementType()); + SmallVector incrementOffsets = llvm::map_to_vector( + op.getOffsets(), [&](Value offset) { return getOpFoldResultOfLayoutInfo(offset, rewriter); }); - // Handle memref::ReinterpretCastOp and tensor specially - if (!reintCastOp && !indexTensor) - continue; + SmallVector newOffsets; + for (const auto [increment, originalOffset, stride] : + llvm::zip(incrementOffsets, blockData.getOffsetsRef(), blockData.getStridesRef())) + { + auto curDimOffset = + addOpFoldResult(mulOpFoldResult(increment, stride, loc, rewriter), originalOffset, loc, rewriter); - BlockData data; - if (reintCastOp) { - parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter, - llvm::SmallDenseMap(0)); + newOffsets.push_back(curDimOffset); + } + + blockData.getOffsetsRef().clear(); + + for (auto offset : newOffsets) + blockData.getOffsetsRef().push_back(offset); + + SmallVector scalarShape(1, 1); // Stub shape + ArrayRef resultShape; + auto pointerType = cast(op.getResult().getType()); + + if (auto shapedType = dyn_cast(pointerType.getPointeeType())) { + resultShape = shapedType.getShape(); } else { - parse(arg, data, op.getLoc(), rewriter, - llvm::SmallDenseMap(0)); + // scalar pointer, should produce a one dimensional memref + resultShape = scalarShape; + assert(blockData.getRank() == 1); } - // Record the BlockData for later processing - initArgIndexIfBlockData.push_back(std::make_pair(i, data)); - } - - // Set insertion point to be before the for loop for new variables passed - // into the new loop. 
- auto origIp = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(op); - - // For each of the BlockData recorded in the last step, insert new - // instructions to describe offset and stride for each dimension and append - // them to init args - for (auto [i, data] : initArgIndexIfBlockData) { - // For each dimension, if the corresponding offset and stride is an - // integer attribute, create a constant value and append them at the - // end of init arg list, which is prepared for calculate layout info with - // loop interation index - for (auto [j, dataOffset] : llvm::enumerate(data.getOffsetsRef())) { - - if (isa(dataOffset)) { - auto constDataOffset = dataOffset.get(); - assert(isa(constDataOffset)); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr( - dyn_cast(constDataOffset).getInt())); - newInitArgs.push_back(constOp.getResult()); - data.getOffsetsRef()[j] = constOp.getResult(); - } else { - assert(isa(dataOffset.get().getType())); - newInitArgs.push_back(dataOffset.get()); - } + auto newOp = blockData.createCastOp(resultShape, loc, rewriter); + rewriter.replaceOp(op, newOp.getResult()); + + known[newOp.getResult()] = blockData; +} + +void BlockDataParser::rewriteYieldOp(scf::YieldOp op, ConversionPatternRewriter &rewriter, + const std::set blockArgIdxSet, + const llvm::SmallDenseMap &known) +{ + // Any inserted instruction should be before this yield + OpBuilder::InsertionGuard insertionGuard {rewriter}; + rewriter.setInsertionPoint(op); + + auto adaptor = scf::YieldOp::Adaptor(op); + + SmallVector initArgState; + SmallVector operands(adaptor.getOperands()); + + // For each of the init arg that we added additional Values in for loop, we + // need to add corresponding Values as yield operands. The loop below gathers + // BlockData for those values. + for (auto [i, v] : llvm::enumerate(adaptor.getOperands())) { + if (auto mappedV = rewriter.getRemappedValue(v)) { + // If this value is a tensor of pointers produced by AddPtrOp, + // we should have already converted to a ReinterpretCastOp without + // layout information for the normal cases + if (v.getDefiningOp() || v.getDefiningOp() || + v.getDefiningOp()) + { + if (auto castOp = mappedV.getDefiningOp()) { + v = castOp; + } else { + llvm_unreachable("mapped value defined by an unexpected op"); + } + } else { + // If this value is not a tensor of pointers, we will use the + // mapped value, and rely on the conversion will happen later + // automatically when we legalize loop body. 
+ + // TODO: + // The scenario where a value is a tensor of pointers but not + // produced by AddPtrOp is not supported + if (isa(mappedV.getType()) && + isa(dyn_cast(mappedV.getType()).getElementType())) + llvm_unreachable("unsupported scenario where a value is a tensor of " + "pointers but not produced by AddPtrOp"); + v = mappedV; + } + } + + if (blockArgIdxSet.find(i) == blockArgIdxSet.end()) + continue; + + auto reintCastOp = v.getDefiningOp(); + assert(reintCastOp || + (isa(v.getType()) && isa(dyn_cast(v.getType()).getElementType()))); + + BlockData state; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known); + } else { + parse(v, state, op.getLoc(), rewriter, known); + } + initArgState.push_back(state); } - for (auto [j, dataStride] : llvm::enumerate(data.getStridesRef())) { - - if (isa(dataStride)) { - auto constDataStride = dataStride.get(); - assert(isa(constDataStride)); - auto constOp = rewriter.create( - op.getLoc(), rewriter.getIndexAttr( - dyn_cast(constDataStride).getInt())); - newInitArgs.push_back(constOp.getResult()); - data.getStridesRef()[j] = constOp.getResult(); - } else { - assert(isa(dataStride.get().getType())); - newInitArgs.push_back(dataStride.get()); - } + // For each of the BlockData recorded in the last step, extract value + // that correspond to offset and stride for each dimension and append + // them to yield operands. + for (auto state : initArgState) { + for (auto offset : state.getOffsetsRef()) { + // offsets can be IntAttr zeroes, since reinterpret_cast collapses + // them for the input memref, and the for loop may not update + // offsets other than offsets[0]. Create constants Values for those + // zeroes. + if (isa(offset)) { + auto constOffset = offset.get(); + assert(isa(constOffset) && dyn_cast(constOffset).getInt() == 0 && + "attribute offsets should be zeroes"); + auto constOp = rewriter.create(op.getLoc(), rewriter.getIndexAttr(0)); + operands.push_back(constOp.getResult()); + } else { + operands.push_back(offset.get()); + } + } + + for (OpFoldResult stride : state.getStridesRef()) { + assert(!isa(stride) && "BlockData strides for yield within for" + " loop not expected to be attribute."); + operands.push_back(stride.get()); + } } - // Note that we want the knownPtrs to be indexed by block arg, but we - // only have index for now. Also, the blockdata we record is the init - // arg, but want to to use newly created block arg. These block args - // are not created yet. We will translate this mapping later. - knownPtrsTmp.push_back(std::make_pair(i, data)); - blockArgIdxSet.emplace(i); - - // If the original init arg is a memref produced by reinterpret_cast, - // create a new memref using new strides and offsets created above. - // This produces a canonicalized memref, which will match what the - // for loop generates if it modifies the memref. 
E.g., original - // reinterpret_cast can produce a memref with const stride: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + - // s0 + d1 - // * s1)>> - // The new reinterpret_cast will always have dynamic stride and - // offset: - // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 - // + s0 + d1 * s2)>> - if (newInitArgs[i].getDefiningOp()) { - SmallVector resultShape; - for (auto size : data.getSizesRef()) { - auto constSize = getConstantIntValue(size); - assert(constSize && "expected constant size"); - resultShape.push_back(constSize.value()); - } - - // In current block data layout info, strides and offsets must be dynamic - // value - auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); - - LLVM_DEBUG({ - llvm::dbgs() << "new reinterpret_cast with dynamic sizes " - "and offsets:"; - castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + // Yield is a terminator op that must be at the end of the function + rewriter.setInsertionPointAfter(op); + auto newOp = rewriter.replaceOpWithNewOp(op, operands); + assert(op->getNumResults() == 0); + + LLVM_DEBUG({ + llvm::dbgs() << "new yield:"; + newOp.getOperation()->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); llvm::dbgs() << "\n"; - }); + }); +} - newInitArgs[i] = castOp.getResult(); +void BlockDataParser::rewriteForOp(scf::ForOp op, ConversionPatternRewriter &rewriter, + llvm::SmallDenseMap &known) +{ + + SmallVector newInitArgs; + + SmallVector, 5> initArgIndexIfBlockData; + SmallVector, 5> knownPtrsTmp; + std::set blockArgIdxSet; + + // Create a new list of init args + for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { + auto mappedV = rewriter.getRemappedValue(arg); + memref::ReinterpretCastOp reintCastOp; + + // If this init arg is supposed to be remapped, use the remapped + // value instead. + // In addition, if this init arg is a memref created by a reinterpret_cast + // or a tensor of index, there is a chance that it will be used in addptr. + // Create BlockData for each such init arg. + if (mappedV) { + // TODO: + // Passing a block argument pointer directly into a for loop not + // supported. + assert(!(isa(mappedV) && isa(mappedV.getType())) && + "cannot take pointer block argument as init arg for for loop"); + if (auto op = mappedV.getDefiningOp()) { + // Record memref::ReinterpretCastOp + reintCastOp = op; + newInitArgs.push_back(mappedV); + } else { + newInitArgs.push_back(mappedV); + } + } else { + newInitArgs.push_back(arg); + } + + auto indexTensor = + isa(arg.getType()) && isa(dyn_cast(arg.getType()).getElementType()); + + // Handle memref::ReinterpretCastOp and tensor specially + if (!reintCastOp && !indexTensor) + continue; + + BlockData data; + if (reintCastOp) { + parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter, llvm::SmallDenseMap(0)); + } else { + parse(arg, data, op.getLoc(), rewriter, llvm::SmallDenseMap(0)); + } + + // Record the BlockData for later processing + initArgIndexIfBlockData.push_back(std::make_pair(i, data)); } - } - rewriter.restoreInsertionPoint(origIp); + // Set insertion point to be before the for loop for new variables passed + // into the new loop. 
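The kernels this scf.for rewriting targets are, roughly, those that advance a tensor of pointers on every iteration, as in the hypothetical Triton sketch below; the per-dimension offsets and strides discussed here become the extra loop-carried iter_args:

import triton
import triton.language as tl

@triton.jit
def strided_sum_kernel(in_ptr, out_ptr, n_iters, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    p = in_ptr + offs                # becomes a memref.reinterpret_cast init arg
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    for _ in range(0, n_iters):      # lowered to scf.for with extra offset/stride iter_args
        acc += tl.load(p)
        p += BLOCK                   # updates the first-dimension offset each iteration
    tl.store(out_ptr + offs, acc)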
+ auto origIp = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + + // For each of the BlockData recorded in the last step, insert new + // instructions to describe offset and stride for each dimension and append + // them to init args + for (auto [i, data] : initArgIndexIfBlockData) { + // For each dimension, if the corresponding offset and stride is an + // integer attribute, create a constant value and append them at the + // end of init arg list, which is prepared for calculate layout info with + // loop interation index + for (auto [j, dataOffset] : llvm::enumerate(data.getOffsetsRef())) { + + if (isa(dataOffset)) { + auto constDataOffset = dataOffset.get(); + assert(isa(constDataOffset)); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(dyn_cast(constDataOffset).getInt())); + newInitArgs.push_back(constOp.getResult()); + data.getOffsetsRef()[j] = constOp.getResult(); + } else { + assert(isa(dataOffset.get().getType())); + newInitArgs.push_back(dataOffset.get()); + } + } - // Create a new scf::ForOp that uses updated init args and same loop body - auto newOp = rewriter.create( - op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), - newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - IRMapping mapping; - mapping.map(op.getInductionVar(), iv); - mapping.map(op.getInitArgs(), newInitArgs); - mapping.map(op.getRegionIterArgs(), args); + for (auto [j, dataStride] : llvm::enumerate(data.getStridesRef())) { + + if (isa(dataStride)) { + auto constDataStride = dataStride.get(); + assert(isa(constDataStride)); + auto constOp = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(dyn_cast(constDataStride).getInt())); + newInitArgs.push_back(constOp.getResult()); + data.getStridesRef()[j] = constOp.getResult(); + } else { + assert(isa(dataStride.get().getType())); + newInitArgs.push_back(dataStride.get()); + } + } - for (auto &bodyOp : op.getRegion().getOps()) { - b.clone(bodyOp, mapping); + // Note that we want the knownPtrs to be indexed by block arg, but we + // only have index for now. Also, the blockdata we record is the init + // arg, but want to to use newly created block arg. These block args + // are not created yet. We will translate this mapping later. + knownPtrsTmp.push_back(std::make_pair(i, data)); + blockArgIdxSet.emplace(i); + + // If the original init arg is a memref produced by reinterpret_cast, + // create a new memref using new strides and offsets created above. + // This produces a canonicalized memref, which will match what the + // for loop generates if it modifies the memref. 
E.g., original + // reinterpret_cast can produce a memref with const stride: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + + // s0 + d1 + // * s1)>> + // The new reinterpret_cast will always have dynamic stride and + // offset: + // - memref<4x256xbf16, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + // + s0 + d1 * s2)>> + if (newInitArgs[i].getDefiningOp()) { + SmallVector resultShape; + for (auto size : data.getSizesRef()) { + auto constSize = getConstantIntValue(size); + assert(constSize && "expected constant size"); + resultShape.push_back(constSize.value()); + } + + // In current block data layout info, strides and offsets must be dynamic + // value + auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter); + + LLVM_DEBUG({ + llvm::dbgs() << "new reinterpret_cast with dynamic sizes " + "and offsets:"; + castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); + + newInitArgs[i] = castOp.getResult(); } - }); - - // Convert the book-keeping data structure to use the correct key and value. - // Key is converted from init arg index to newly created block arg, and - // Value's BlockData fields are converted from init arg to newly created block - // arg - int cnt = op.getRegionIterArgs().size(); - for (auto [i, data] : knownPtrsTmp) { - for (auto it = data.getOffsetsRef().begin(); - it != data.getOffsetsRef().end(); it++) { - *it = newOp.getRegionIterArgs()[cnt]; - cnt++; } - for (auto it = data.getStridesRef().begin(); - it != data.getStridesRef().end(); it++) { - *it = newOp.getRegionIterArgs()[cnt]; - cnt++; + rewriter.restoreInsertionPoint(origIp); + + // Create a new scf::ForOp that uses updated init args and same loop body + auto newOp = rewriter.create(op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), + newInitArgs, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + IRMapping mapping; + mapping.map(op.getInductionVar(), iv); + mapping.map(op.getInitArgs(), newInitArgs); + mapping.map(op.getRegionIterArgs(), args); + + for (auto &bodyOp : op.getRegion().getOps()) { + b.clone(bodyOp, mapping); + } + }); + + // Convert the book-keeping data structure to use the correct key and value. + // Key is converted from init arg index to newly created block arg, and + // Value's BlockData fields are converted from init arg to newly created block + // arg + int cnt = op.getRegionIterArgs().size(); + for (auto [i, data] : knownPtrsTmp) { + for (auto it = data.getOffsetsRef().begin(); it != data.getOffsetsRef().end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + for (auto it = data.getStridesRef().begin(); it != data.getStridesRef().end(); it++) { + *it = newOp.getRegionIterArgs()[cnt]; + cnt++; + } + + auto key = newOp.getRegionIterArgs()[i]; + known.insert(std::make_pair(key, data)); + } + assert(static_cast(cnt) == newOp.getRegionIterArgs().size() && "expect to remap all new block args"); + + // Replace only the results that correspond to the original scf.for + auto resultsToReplaceWith = ResultRange(newOp.result_begin(), newOp.result_begin() + op.getNumResults()); + rewriter.replaceOp(op, resultsToReplaceWith); + + // Update the loop body. Manually invoke the rewrite logic on addptr and yield + // in the loop body, so we can take advantage of the states we built up + for (auto &bodyOp : newOp.getRegion().getOps()) { + if (auto addptrOp = dyn_cast(bodyOp)) { + // FIXME: Constructed adaptor here does not hold the transformed op info. 
+ auto adaptor = triton::AddPtrOp::Adaptor(addptrOp); + rewriteAddPtr(addptrOp, adaptor, rewriter, known); + } else if (auto advanceOp = dyn_cast(bodyOp)) { + rewriteAdvanceOp(advanceOp, rewriter, known); + } else if (auto forOp = dyn_cast(bodyOp)) { + // TODO: + // Nested for loops are not supported at the moment + assert(0 && "nested loops currently not supported"); + } } - auto key = newOp.getRegionIterArgs()[i]; - known.insert(std::make_pair(key, data)); - } - assert(static_cast(cnt) == newOp.getRegionIterArgs().size() && - "expect to remap all new block args"); - - // Replace only the results that correspond to the original scf.for - auto resultsToReplaceWith = ResultRange( - newOp.result_begin(), newOp.result_begin() + op.getNumResults()); - rewriter.replaceOp(op, resultsToReplaceWith); - - // Update the loop body. Manually invoke the rewrite logic on addptr and yield - // in the loop body, so we can take advantage of the states we built up - for (auto &bodyOp : newOp.getRegion().getOps()) { - if (auto addptrOp = dyn_cast(bodyOp)) { - // FIXME: Constructed adaptor here does not hold the transformed op info. - auto adaptor = triton::AddPtrOp::Adaptor(addptrOp); - rewriteAddPtr(addptrOp, adaptor, rewriter, known); - } else if (auto advanceOp = dyn_cast(bodyOp)) { - rewriteAdvanceOp(advanceOp, rewriter, known); - } else if (auto forOp = dyn_cast(bodyOp)) { - // TODO: - // Nested for loops are not supported at the moment - assert(0 && "nested loops currently not supported"); + if (op.getNumRegionIterArgs()) { + auto yieldOp = cast(newOp.getBody()->getTerminator()); + rewriteYieldOp(yieldOp, rewriter, blockArgIdxSet, known); } - } - - if (op.getNumRegionIterArgs()) { - auto yieldOp = cast(newOp.getBody()->getTerminator()); - rewriteYieldOp(yieldOp, rewriter, blockArgIdxSet, known); - } - - LLVM_DEBUG({ - llvm::dbgs() << "new for\n"; - newOp.getOperation()->print(llvm::dbgs(), - OpPrintingFlags().printGenericOpForm()); - llvm::dbgs() << "\n"; - }); + + LLVM_DEBUG({ + llvm::dbgs() << "new for\n"; + newOp.getOperation()->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm()); + llvm::dbgs() << "\n"; + }); } /// @brief Rewrite the triton::AddPtrOp to handle unstructured memory access. @@ -1233,72 +1269,66 @@ void BlockDataParser::rewriteForOp( /// @param adaptor The adaptor of the triton::AddPtrOp, used to get operands. /// @param rewriter The pattern rewriter used to modify the IR. /// @param data The BlockData containing information about the memory access. -void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( - triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, BlockData &data) { - auto loc = op.getLoc(); - auto &offsets = data.getOffsetsRef(); - auto &blockSizes = data.getSizesRef(); - auto &strides = data.getStridesRef(); - Value ptrOffset = adaptor.getOffset(); - Value zeroIdx = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value oneIdx = - rewriter.create(loc, rewriter.getIndexAttr(1)); - auto addptrRes = op.getResult(); - assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); - auto loadOp = *(addptrRes.user_begin()); - - // Prepare empty tensor for loop based scalar load - // FIXME: We use cast here because addptr must return tensor>. - // True? 
- auto resTy = cast(addptrRes.getType()); - auto resEPtrTy = resTy.getElementType(); - auto resETy = cast(resEPtrTy).getPointeeType(); - Value loaded = rewriter.create(loc, blockSizes, resETy); - SmallVector initArgs; - initArgs.push_back(loaded); - - SmallVector forLBs; - SmallVector forUBs; - SmallVector forSteps; - for (auto &s : offsets) { - forLBs.push_back(zeroIdx); - } - for (auto &s : blockSizes) { - forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); - } - for (auto &s : strides) { - forSteps.push_back(oneIdx); - } - SmallVector ivs; - OpBuilder builder(op); - auto loop = createNestedLoops( - builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, - initArgs, - [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, - ValueRange iterArgs) { - OpBuilder::InsertionGuard g(bB); - bB.setInsertionPointToStart(bB.getBlock()); - - Value scalarOffsetRaw = - bB.create(bLoc, ptrOffset, allIVs); - Value scalarOffset = bB.create( - bLoc, bB.getIndexType(), scalarOffsetRaw); - // Replace offset & size. Only single element. - data.getOffsetsRef().clear(); - data.getOffsetsRef().push_back(scalarOffset); - data.getSizesRef().clear(); - data.getSizesRef().push_back(bB.getIndexAttr(1)); - data.getStridesRef().clear(); - data.getStridesRef().push_back(bB.getIndexAttr(1)); - memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); - rewriter.replaceOp(op, castOp); - // Move tt.load using this tt.addptr into this block - loadOp->moveAfter(castOp); - loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); - bB.create(bLoc, iterArgs); - }); +void BlockDataParser::rewriteAddPtrToUnstrucMemAcc(triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, BlockData &data) +{ + auto loc = op.getLoc(); + auto &offsets = data.getOffsetsRef(); + auto &blockSizes = data.getSizesRef(); + auto &strides = data.getStridesRef(); + Value ptrOffset = adaptor.getOffset(); + Value zeroIdx = rewriter.create(loc, rewriter.getIndexAttr(0)); + Value oneIdx = rewriter.create(loc, rewriter.getIndexAttr(1)); + auto addptrRes = op.getResult(); + assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); + auto loadOp = *(addptrRes.user_begin()); + + // Prepare empty tensor for loop based scalar load + // FIXME: We use cast here because addptr must return tensor>. + // True? + auto resTy = cast(addptrRes.getType()); + auto resEPtrTy = resTy.getElementType(); + auto resETy = cast(resEPtrTy).getPointeeType(); + Value loaded = rewriter.create(loc, blockSizes, resETy); + SmallVector initArgs; + initArgs.push_back(loaded); + + SmallVector forLBs; + SmallVector forUBs; + SmallVector forSteps; + for (auto &s : offsets) { + forLBs.push_back(zeroIdx); + } + for (auto &s : blockSizes) { + forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); + } + for (auto &s : strides) { + forSteps.push_back(oneIdx); + } + SmallVector ivs; + OpBuilder builder(op); + auto loop = createNestedLoops(builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, initArgs, + [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, ValueRange iterArgs) { + OpBuilder::InsertionGuard g(bB); + bB.setInsertionPointToStart(bB.getBlock()); + + Value scalarOffsetRaw = bB.create(bLoc, ptrOffset, allIVs); + Value scalarOffset = + bB.create(bLoc, bB.getIndexType(), scalarOffsetRaw); + // Replace offset & size. Only single element. 
+ data.getOffsetsRef().clear(); + data.getOffsetsRef().push_back(scalarOffset); + data.getSizesRef().clear(); + data.getSizesRef().push_back(bB.getIndexAttr(1)); + data.getStridesRef().clear(); + data.getStridesRef().push_back(bB.getIndexAttr(1)); + memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); + rewriter.replaceOp(op, castOp); + // Move tt.load using this tt.addptr into this block + loadOp->moveAfter(castOp); + loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); + bB.create(bLoc, iterArgs); + }); } } // namespace triton diff --git a/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp b/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp index 4b096316d477785cd6be5f34f5eba8ba623ea714..5d6aab7452eb72d2d35f663ae048e905b8c3a82b 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/UseAnalysis.cpp @@ -22,341 +22,345 @@ using namespace dataflow; #define DEBUG_TYPE "triton-use-analysis" -std::string stringifyUseType(UseType useTy) { - std::string ret; - if (useTy == UseType::MetaUse) { - ret = "MetaUse"; - } else if (useTy == UseType::DataUse) { - ret = "DataUse"; - } else if (useTy == UseType::MixUse) { - ret = "MixUse"; - } else if (useTy == UseType::Undefined) { - ret = "Undefined"; - } - return ret; +std::string stringifyUseType(UseType useTy) +{ + std::string ret; + if (useTy == UseType::MetaUse) { + ret = "MetaUse"; + } else if (useTy == UseType::DataUse) { + ret = "DataUse"; + } else if (useTy == UseType::MixUse) { + ret = "MixUse"; + } else if (useTy == UseType::Undefined) { + ret = "Undefined"; + } + return ret; } #if LLVM_VERSION_MAJOR >= 20 -LogicalResult -triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) { +LogicalResult triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) +{ #else -void triton::UseAnalysis::visitOperation(Operation *op, - ArrayRef operands, - ArrayRef results) { +void triton::UseAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) +{ #endif - - if (op->getResults().size() == 1) { - auto resultType = dyn_cast(op->getResult(0).getType()); - if (resultType && isa(resultType.getElementType())) { - for (auto opnd : operands) { - propagateUse(opnd, UseType::MetaUse); - } + // If an Op's result is PtrTensor, this op's operands are all MetaUse. + // TODO: We do not handle TensorPtr, which is defined by tt.make_tensor_ptr, + // because handling tt.load also propagates to tt.make_tensor_ptr. 
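As a rough illustration of how these use tags land on an ordinary kernel (a hypothetical example): values that only feed pointer arithmetic or masks end up MetaUse, values that are loaded or stored as data end up DataUse, and anything used both ways becomes MixUse and is cloned for its meta users later in this pass.

import triton
import triton.language as tl

@triton.jit
def use_tag_example(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)                        # feeds only pointers and the mask -> MetaUse
    mask = offs < n                                   # load/store mask                  -> MetaUse
    x = tl.load(in_ptr + offs, mask=mask, other=0.0)
    y = x * 2.0                                       # value written by tt.store        -> DataUse
    tl.store(out_ptr + offs, y, mask=mask)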
+ if (op->getResults().size() == 1) { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (resultType && isa(resultType.getElementType())) { + for (auto opnd : operands) { + propagateUse(opnd, UseType::MetaUse); + } + } } - } - TypeSwitch(op) - .Case([&](auto load) { - propagateUse(operands[0], UseType::MetaUse); - auto mask = load.getMask(); - auto other = load.getOther(); - if (mask) { - assert(mask != other && "mask and other cannot be the same"); - propagateUse(operands[1], UseType::MetaUse); - } - if (other) { - propagateUse(operands[2], UseType::MetaUse); - } - }) - .Case([&](auto store) { - propagateUse(operands[0], UseType::MetaUse); - propagateUse(operands[1], UseType::DataUse); - auto value = store.getValue(); - auto mask = store.getMask(); - if (mask) { - assert(mask != value && "mask and data cannot be the same"); - propagateUse(operands[2], UseType::MetaUse); - } - }) - // Consider triton::AtomicRMWOp as store operation - .Case([&](auto atomicOp) { - propagateUse(operands[0], UseType::MetaUse); - propagateUse(operands[1], UseType::DataUse); - auto value = atomicOp.getVal(); - auto mask = atomicOp.getMask(); - if (mask) { - assert(mask != value && "mask and data cannot be the same"); - propagateUse(operands[2], UseType::MetaUse); - } - }) - .Case([&](auto dot) { - propagateResults(operands[0], results); - propagateResults(operands[1], results); + TypeSwitch(op) + .Case([&](auto load) { + propagateUse(operands[0], UseType::MetaUse); + auto mask = load.getMask(); + auto other = load.getOther(); + if (mask) { + assert(mask != other && "mask and other cannot be the same"); + propagateUse(operands[1], UseType::MetaUse); + } + if (other) { + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto store) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = store.getValue(); + auto mask = store.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + // Consider triton::AtomicRMWOp as store operation + .Case([&](auto atomicOp) { + propagateUse(operands[0], UseType::MetaUse); + propagateUse(operands[1], UseType::DataUse); + auto value = atomicOp.getVal(); + auto mask = atomicOp.getMask(); + if (mask) { + assert(mask != value && "mask and data cannot be the same"); + propagateUse(operands[2], UseType::MetaUse); + } + }) + .Case([&](auto dot) { + propagateResults(operands[0], results); + propagateResults(operands[1], results); - auto opc = dot.getC(); - triton::SplatOp splat; - if (opc) { - splat = opc.template getDefiningOp(); - } + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } - if (opc && splat && splat.getSrc().getDefiningOp()) { - propagateUse(operands[2], UseType::MetaUse); - } else { - propagateUse(operands[2], UseType::DataUse); - } - }) - .Default([&](Operation *op) { - // this condition account for tt.addptr - for (auto operand : operands) { - propagateResults(operand, results); - } - }); + if (opc && splat && splat.getSrc().getDefiningOp()) { + propagateUse(operands[2], UseType::MetaUse); + } else { + propagateUse(operands[2], UseType::DataUse); + } + }) + .Default([&](Operation *op) { + // this condition account for tt.addptr + for (auto operand : operands) { + propagateResults(operand, results); + } + }); #if LLVM_VERSION_MAJOR >= 20 - return success(); + return success(); #endif } -LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) { - 
MLIRContext *context = funcOp.getContext(); - SymbolTableCollection symbolTable; - - DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(symbolTable); - if (failed(solver.initializeAndRun(funcOp))) { - return failure(); - } - auto &os = llvm::dbgs(); - // Walk the func op, convert tags on operands to tags on operations - funcOp.walk([&](Operation *op) { - LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; }); - UseType useType = UseType::Undefined; - for (auto result : op->getResults()) { - LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; }); - auto use = solver.lookupState(result); - assert(use && "Lattice value not found"); - auto thisUseType = use->type; - LLVM_DEBUG({ - os << "[UseAnalysis] ==========> useType is " - << stringifyUseType(thisUseType) << "\n"; - }); - if (thisUseType == UseType::Undefined) { - continue; - } - if (useType == UseType::Undefined) { - useType = thisUseType; - } - if (thisUseType == UseType::MixUse || thisUseType != useType) { - useType = UseType::MixUse; - break; - } +LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) +{ + MLIRContext *context = funcOp.getContext(); + SymbolTableCollection symbolTable; + // First we run the Dataflow analysis to mark UseInfo of operands + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(symbolTable); + if (failed(solver.initializeAndRun(funcOp))) { + return failure(); } - if (useType == UseType::Undefined) { - LLVM_DEBUG({ op->setAttr("Undefined", UnitAttr::get(context)); }); - return; - } else if (useType == UseType::MetaUse) { - if (!isa(op)) { - assert(op->getNumResults() == 1 && - "Ops used for meta computation are expected to have one result"); - } - for (auto it = 0; it < op->getNumResults(); ++it) { - // Only set the tag if the operation uses tensors - if (isa(op->getResult(it).getType()) || - (isa(op) && - isa(op->getResult(it).getType()))) { - // Setting tag for erasing op later - op->setAttr("MetaUse", UnitAttr::get(context)); + llvm::DenseMap> opMetaUsers; + auto &os = llvm::dbgs(); + // Then we walk the func op, convert tags on operands to tags on operations + funcOp.walk([&](Operation *op) { + LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; }); + UseType opUseType = UseType::Undefined; + for (auto result : op->getResults()) { + LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; }); + auto resUse = solver.lookupState(result); + if (!resUse) { + llvm::report_fatal_error("Lattice value not found"); + } + auto resUseTy = resUse->type; + LLVM_DEBUG({ os << "[UseAnalysis] ==========> resUseTy is " << stringifyUseType(resUseTy) << "\n"; }); + if (resUseTy == UseType::Undefined) { + continue; + } + if (opUseType == UseType::Undefined) { + opUseType = resUseTy; + } + if (resUseTy == UseType::MixUse || resUseTy != opUseType) { + opUseType = UseType::MixUse; + break; + } } - } - return; - } else if (useType == UseType::DataUse) { - LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); }); - return; - } - assert(useType == UseType::MixUse); + if (opUseType == UseType::Undefined) { + op->setAttr("Undefined", UnitAttr::get(context)); + return; + } else if (opUseType == UseType::MetaUse) { + // FIXME: Why do we need this assert? 
+ // if (!isa(op)) { + // assert(op->getNumResults() == 1 && + // "Ops used for meta computation are expected to have one result"); + // } + for (auto i = 0; i < op->getNumResults(); ++i) { + // Only set the tag if the operation uses tensors + if (isa(op->getResult(i).getType()) || + (isa(op) && isa(op->getResult(i).getType()))) + { + // Setting tag for erasing op later + op->setAttr("MetaUse", UnitAttr::get(context)); + } + } + return; + } else if (opUseType == UseType::DataUse) { + op->setAttr("DataUse", UnitAttr::get(context)); + return; + } - // If the operation only produces scalars, no need to clone it - bool shapedResult = true; - for (auto result : op->getResults()) - shapedResult &= isa(result.getType()); - if (!shapedResult) { - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); - return; - } + op->setAttr("MixUse", UnitAttr::get(context)); - llvm::SetVector metaUsers; - for (auto result : op->getResults()) { - for (auto user : result.getUsers()) { - TypeSwitch(user) - .Case([&](auto load) { - auto ptr = load.getPtr(); - auto mask = load.getMask(); - auto other = load.getOther(); - if (result == ptr || result == mask || result == other) { - metaUsers.insert(user); - } - }) - .Case([&](auto store) { - auto ptr = store.getPtr(); - auto mask = store.getMask(); - if (result == ptr || result == mask) { - metaUsers.insert(user); - } - }) - .Case([&](auto atomicOp) { - auto ptr = atomicOp.getPtr(); - auto mask = atomicOp.getMask(); - if (result == ptr || result == mask) - metaUsers.insert(user); - }) - .Case([&](auto dot) { - auto opc = dot.getC(); - triton::SplatOp splat; - if (opc) { - splat = opc.template getDefiningOp(); - } + // Collect the metaUsers of this Op + auto &metaUsers = opMetaUsers[op]; + for (auto result : op->getResults()) { + for (auto user : result.getUsers()) { + TypeSwitch(user) + .Case([&](auto load) { + auto ptr = load.getPtr(); + auto mask = load.getMask(); + auto other = load.getOther(); + if (result == ptr || result == mask || result == other) { + metaUsers.insert(user); + } + }) + .Case([&](auto store) { + auto ptr = store.getPtr(); + auto mask = store.getMask(); + if (result == ptr || result == mask) { + metaUsers.insert(user); + } + }) + .Case([&](auto atomicOp) { + auto ptr = atomicOp.getPtr(); + auto mask = atomicOp.getMask(); + if (result == ptr || result == mask) + metaUsers.insert(user); + }) + .Case([&](auto dot) { + auto opc = dot.getC(); + triton::SplatOp splat; + if (opc) { + splat = opc.template getDefiningOp(); + } - if (opc && splat && - splat.getSrc().getDefiningOp()) { - metaUsers.insert(user); - } - }) - .Default([&](Operation *op) { - bool allMeta = true; - for (auto res : op->getResults()) { - auto resUse = solver.lookupState(res); - if (resUse->type != UseType::MetaUse) { - allMeta = false; - break; - } - } - if (allMeta) { - metaUsers.insert(user); - } - }); - } - } + if (opc && splat && splat.getSrc().getDefiningOp()) { + metaUsers.insert(user); + } + }) + .Default([&](Operation *op) { + bool allMeta = true; + for (auto res : op->getResults()) { + auto resUse = solver.lookupState(res); + if (resUse->type != UseType::MetaUse) { + allMeta = false; + break; + } + } + if (allMeta) { + metaUsers.insert(user); + } + }); + } + } - // If the operation doesn't have direct meta users, no need to clone it - if (metaUsers.empty()) { - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); - return; - } + return; + }); - // Clone the operation; switch all meta users to use the clone - OpBuilder builder(op); - auto clone = 
builder.clone(*op); - LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); }); + // Now ops are marked with UseType. + // Post-process to handle exceptional cases. + LLVM_DEBUG({ os << "[UseAnalysis] Before post-process, funcOp is " << *funcOp << "\n"; }); + funcOp.walk([&](Operation *op) { + // Handle indirect load case. + // For example, load(1st) -> computeOp -> load(2nd). + // The first load is IndirectLoadInterfaceOp. + // Do not inplace replace MetaUse by MixUse. Because the condition checking + // depends on that the op has the attr of MetaUse. + // Handle the indirect load interface op + // We first trace from the 1st load to the 2nd load with the ops between + // them marked as MixUse. Then we traceback from the 2nd load to mark defs + // MixUse. + if (opIsIndirectLoad(op) || opIsIndirectCalc(op)) { + LLVM_DEBUG({ os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n"; }); + llvm::SmallPtrSet stopOps; + // Modify the users of this op's result. + traverseForwardUpdateUserChainIf( + op, + /*conditionFn*/ + [op](Operation *curOp) { + // tt.addptr should always be treated as MetaUse + return isMetaUse(curOp) && curOp != op && !isa(curOp); + }, + /*stopFn*/ + [&](Operation *curOp) { + // triton::LoadOp without MetaUse means it is an indirect load + // instead of the load providing the offset. + // The pattern is as follows, + // load -> ops -> load + // We need to ensure the intermediate ops are marked MixUse + // so that they will be replaced instead of be erased without + // conversion. + return isa(curOp) && !curOp->hasAttr("MetaUse"); + }, + /*actionFn*/ + [](OpBuilder &b, Operation *op) { op->setAttr("MixUse", UnitAttr::get(op->getContext())); }, stopOps); + LLVM_DEBUG({ + os << "[UseAnalysis] stopOps are \n"; + int i = 0; + for (auto it = stopOps.begin(); it != stopOps.end(); it++) { + os << i++ << ": " << *(*it) << "\n"; + } + }); + LLVM_DEBUG({ os << "[UseAnalysis] After traverseForward, funcOp is " << *funcOp << "\n"; }); + for (auto it = stopOps.begin(); it != stopOps.end(); it++) { + auto stopOp = *it; + traverseBackwardUpdateOperandChainIf( + stopOp, + [stopOp](Operation *curOp) { + return isMetaUse(curOp) && curOp != stopOp && !isa(curOp); + }, + [](OpBuilder &b, Operation *op) { op->setAttr("MixUse", UnitAttr::get(op->getContext())); }); + } + LLVM_DEBUG({ os << "[UseAnalysis] After traverseBackward of stopOp, funcOp is " << *funcOp << "\n"; }); + // Modify this op. + op->setAttr("MixUse", UnitAttr::get(op->getContext())); + } + }); + // Remove MetaUse in case of MixUse existing in the op + funcOp.walk([&](Operation *op) { + if (isMetaUse(op) && isMixUse(op)) { + op->removeAttr("MetaUse"); + } + }); - // Setting tag for erasing op later - clone->setAttr("MetaUse", UnitAttr::get(context)); + // clone Op with MixUse attr + funcOp.walk([&](Operation *op) { + // If the operation only produces scalars, no need to clone it + bool shapedResult = false; + for (auto result : op->getResults()) + shapedResult |= isa(result.getType()); + if (!shapedResult) { + return; + } - for (auto [res_i, result] : llvm::enumerate(op->getResults())) { - for (auto user : metaUsers) { - for (auto &operand : user->getOpOperands()) { - if (operand.get() == result) { - operand.set(clone->getResult(res_i)); - } + // If the operation doesn't have direct meta users, no need to clone it + // But why the op without metaUser is marked as MixUse? 
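// The indirect-load post-processing above walks the def-use chain forward from the first
// (offset-producing) load, tagging the intermediate ops MixUse, and stops at loads that are
// not MetaUse, i.e. the truly indirect load; the real pass additionally leaves tt.addptr
// untouched. A self-contained toy of that condition/stop/action traversal pattern; node
// names are illustrative only.
#include <functional>
#include <set>
#include <vector>

struct Node {
    const char *name;
    std::vector<Node *> users;   // forward (def -> use) edges
    bool tagged = false;
};

// Mirrors traverseForwardUpdateUserChainIf: visit users transitively, record stop nodes,
// and apply the action wherever the condition holds.
static void traverseForward(Node *n, const std::function<bool(Node *)> &condition,
                            const std::function<bool(Node *)> &stop,
                            const std::function<void(Node *)> &action,
                            std::set<Node *> &stopNodes) {
    if (!n) return;
    if (stop(n)) { stopNodes.insert(n); return; }
    if (condition(n)) action(n);
    for (Node *user : n->users)
        traverseForward(user, condition, stop, action, stopNodes);
}

int main() {
    // offset_load -> offset_to_index -> indirect_load (the second load is the stop node).
    Node indirectLoad{"indirect_load", {}};
    Node offsetToIndex{"offset_to_index", {&indirectLoad}};
    Node offsetLoad{"offset_load", {&offsetToIndex}};
    std::set<Node *> stops;
    traverseForward(&offsetLoad,
                    [](Node *n) { return n->name != std::string("offset_load"); },
                    [](Node *n) { return n->name == std::string("indirect_load"); },
                    [](Node *n) { n->tagged = true; },   // stands in for setting MixUse
                    stops);
    return (stops.size() == 1 && offsetToIndex.tagged) ? 0 : 1;
}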
+ auto &metaUsers = opMetaUsers[op]; + if (metaUsers.empty()) { + return; } - } - } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] Before post-process, funcOp is " << *funcOp << "\n"; - }); - // Post-process - funcOp.walk([&](Operation *op) { - // Handle indirect load case. - // For example, load(1st) -> computeOp -> load(2nd). - // The first load is IndirectLoadInterfaceOp. - // Do not inplace replace MetaUse by MixUse. Because the condition checking - // depends on that the op has the attr of MetaUse. - // Handle the indirect load interface op - // We first trace from the 1st load to the 2nd load with the ops between - // them marked as MixUse. Then we traceback from the 2nd load to mark defs - // MixUse. - if (opIsIndirectLoad(op) || opIsIndirectCalc(op)) { - LLVM_DEBUG({ - os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n"; - }); - llvm::SmallPtrSet stopOps; - // Modify the users of this op's result. - traverseForwardUpdateUserChainIf( - op, - /*conditionFn*/ - [op](Operation *curOp) { return isMetaUse(curOp) && curOp != op; }, - /*stopFn*/ - [&](Operation *curOp) { - // triton::LoadOp without MetaUse means it is an indirect load - // instead of the load providing the offset. - // The pattern is as follows, - // load -> ops -> load - // We need to ensure the intermediate ops are marked MixUse - // so that they will be replaced instead of be erased without - // conversion. - return isa(curOp) && !curOp->hasAttr("MetaUse"); - }, - /*actionFn*/ - [](OpBuilder &b, Operation *op) { - op->setAttr("MixUse", UnitAttr::get(op->getContext())); - }, - stopOps); - LLVM_DEBUG({ - os << "[UseAnalysis] stopOps are \n"; - int i = 0; - for (auto it = stopOps.begin(); it != stopOps.end(); it++) { - os << i++ << ": " << *(*it) << "\n"; + // Even if tt.reduce is used for tt.load, it should not be deleted because + // we need to do real tt.reduce at runtime. + // TODO: Add more ops here + if (isa(op)) { + return; } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] After trace, funcOp is " << *funcOp << "\n"; - }); - for (auto it = stopOps.begin(); it != stopOps.end(); it++) { - auto stopOp = *it; - traverseBackwardUpdateOperandChainIf( - stopOp, - [stopOp](Operation *curOp) { - return isMetaUse(curOp) && curOp != stopOp; - }, - [](OpBuilder &b, Operation *op) { - op->setAttr("MixUse", UnitAttr::get(op->getContext())); - }); - } - LLVM_DEBUG({ - os << "[UseAnalysis] After traceback of stopOp, funcOp is " << *funcOp - << "\n"; - }); - // Modify this op. - op->setAttr("MixUse", UnitAttr::get(op->getContext())); - } - }); - // Remove MetaUse in case of MixUse existing in the op - funcOp.walk([&](Operation *op) { - if (isMetaUse(op) && isMixUse(op)) { - op->removeAttr("MetaUse"); - } - }); - LLVM_DEBUG({ - os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n"; - }); - return success(); -} -MetaUseEraser::MetaUseEraser(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + // Clone the operation; switch all meta users to use the clone + OpBuilder builder(op); + auto clone = builder.clone(*op); + // Setting tag for erasing op later + clone->setAttr("MetaUse", UnitAttr::get(context)); + clone->removeAttr("MixUse"); + // Set metaUser's operand to the result of this cloned Op so that + // the meraUsers are erased along with each other. 
+ for (auto [res_i, result] : llvm::enumerate(op->getResults())) { + for (auto user : metaUsers) { + for (auto &operand : user->getOpOperands()) { + if (operand.get() == result) { + operand.set(clone->getResult(res_i)); + } + } + } + } + }); -LogicalResult MetaUseEraser::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { - if (isa(op)) { - return rewriter.notifyMatchFailure(op, - "AddPtrOp will be handled separately"); - } - if (isMetaUse(op)) { - rewriter.eraseOp(op); + LLVM_DEBUG({ os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n"; }); return success(); - } - return rewriter.notifyMatchFailure(op, "requires meta ops"); +} + +MetaUseEraser::MetaUseEraser(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/10, context) {} + +LogicalResult MetaUseEraser::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const +{ + if (isa(op)) { + return rewriter.notifyMatchFailure(op, "AddPtrOp will be handled separately"); + } + if (isMetaUse(op)) { + rewriter.eraseOp(op); + return success(); + } + return rewriter.notifyMatchFailure(op, "requires meta ops"); } diff --git a/ascend/triton-adapter/lib/Utils/Utils.cpp b/ascend/triton-adapter/lib/Utils/Utils.cpp index efa758ab0df1c42ac0d053f40d56a1132eefe7c6..37095b9905234a064199f8261de31ae07c6d0128 100644 --- a/ascend/triton-adapter/lib/Utils/Utils.cpp +++ b/ascend/triton-adapter/lib/Utils/Utils.cpp @@ -46,724 +46,702 @@ namespace mlir { namespace ConverterUtils { -Value getTransposedValue(Value source, const Location loc, - ConversionPatternRewriter &rewriter, - llvm::ArrayRef order) { - auto sourceType = cast(source.getType()); - auto sourceRank = sourceType.getRank(); - - SmallVector perm(order); - SmallVector originalShape(sourceType.getShape()); - SmallVector transposedShape(sourceRank); - for (size_t i = 0; i < sourceRank; i++) { - transposedShape[i] = originalShape[perm[i]]; - } - - Value transposeInit = rewriter.create( - loc, transposedShape, sourceType.getElementType()); - - Value transpose = - rewriter.create(loc, source, transposeInit, perm) - .getResults()[0]; - - return transpose; +Value getTransposedValue(Value source, const Location loc, ConversionPatternRewriter &rewriter, + llvm::ArrayRef order) +{ + auto sourceType = cast(source.getType()); + auto sourceRank = sourceType.getRank(); + + SmallVector perm(order); + SmallVector originalShape(sourceType.getShape()); + SmallVector transposedShape(sourceRank); + for (size_t i = 0; i < sourceRank; i++) { + transposedShape[i] = originalShape[perm[i]]; + } + + Value transposeInit = rewriter.create(loc, transposedShape, sourceType.getElementType()); + + Value transpose = rewriter.create(loc, source, transposeInit, perm).getResults()[0]; + + return transpose; } -SmallVector getNParallelLoopsAttrs(unsigned n) { - return SmallVector(n, utils::IteratorType::parallel); +SmallVector getNParallelLoopsAttrs(unsigned n) +{ + return SmallVector(n, utils::IteratorType::parallel); } -Value getScalarValue(Value operand, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector ops; - auto reconstructScalarValue = [&](Value src) { - for (auto op = ops.rbegin(); op != ops.rend(); ++op) { - src = mlir::TypeSwitch(*op) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = 
dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return rewriter.create(loc, resType, src); - }) - .Default([](Operation *op) { - llvm_unreachable("unsupported op in generating "); - return nullptr; - }); - } - return src; - }; - - while (true) { - if (!dyn_cast(operand.getType())) { - return reconstructScalarValue(operand); - } else if (auto op = operand.getDefiningOp()) { - if (auto attr = dyn_cast(op.getValue())) { - if (!attr.isSplat()) { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load " - "produced by unsupported instruction"; - return nullptr; +Value getScalarValue(Value operand, Location loc, ConversionPatternRewriter &rewriter) +{ + SmallVector ops; + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = mlir::TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return rewriter.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize(rewriter, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; } - auto elemValue = attr.getSplatValue(); - auto constOp = arith::ConstantOp::materialize( - rewriter, elemValue, attr.getElementType(), op.getLoc()); - return reconstructScalarValue(constOp.getResult()); - } - } else if (auto op = operand.getDefiningOp()) { - operand = op.getSrc(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load produced " - "by unsupported instruction"; - return nullptr; } - } - return nullptr; + return nullptr; } -memref::SubViewOp makeSubViewOp(Value src, - const llvm::SmallVector &sizes, - const Location &loc, - ConversionPatternRewriter &rewriter) { - auto srcType = dyn_cast(src.getType()); - SmallVector offsets(srcType.getRank(), - rewriter.getIndexAttr(0)); - SmallVector strides(srcType.getRank(), - rewriter.getIndexAttr(1)); - auto dstType = - memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - 
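// getScalarValue above recovers the scalar behind a splatted tensor (e.g. the `other` value
// of a masked load) by peeling defining ops and replaying any recorded extension casts on the
// recovered scalar. The type arguments of those checks were stripped in rendering; in comparable
// Triton-to-Linalg lowerings the peeled ops are the splat and the integer/float extensions, so
// the concrete types below are assumptions, not recovered from this patch. Headers as in Utils.cpp.
static mlir::Value getScalarValueSkeletonSketch(mlir::Value operand) {
    while (true) {
        if (!mlir::dyn_cast<mlir::ShapedType>(operand.getType()))
            return operand;                               // scalar reached: replay recorded casts here
        if (auto cst = operand.getDefiningOp<mlir::arith::ConstantOp>()) {
            // splat DenseElementsAttr -> materialize its single element as an arith.constant
            return mlir::Value();                         // elided in this sketch
        } else if (auto splat = operand.getDefiningOp<mlir::triton::SplatOp>()) {
            operand = splat.getSrc();
        } else if (auto ext = operand.getDefiningOp<mlir::arith::ExtSIOp>()) {
            operand = ext.getIn();                        // remember, then re-apply on the scalar
        } else if (auto ext = operand.getDefiningOp<mlir::arith::ExtFOp>()) {
            operand = ext.getIn();
        } else {
            return mlir::Value();                         // unsupported producer
        }
    }
}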
return rewriter.create(loc, dyn_cast(dstType), - src, offsets, sizes, strides); +memref::SubViewOp makeSubViewOp(Value src, const llvm::SmallVector &sizes, const Location &loc, + ConversionPatternRewriter &rewriter) +{ + auto srcType = dyn_cast(src.getType()); + SmallVector offsets(srcType.getRank(), rewriter.getIndexAttr(0)); + SmallVector strides(srcType.getRank(), rewriter.getIndexAttr(1)); + auto dstType = memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, dyn_cast(dstType), src, offsets, sizes, strides); } -std::optional getFullShapeOp(Value val, - ConversionPatternRewriter &rewriter) { - assert(isa(val.getType())); - - if (isa(val)) { - auto blockArg = dyn_cast(val); - auto blockOp = blockArg.getOwner()->getParentOp(); - if (isa(blockOp)) { - auto forOp = dyn_cast(blockOp); - auto operand = forOp.getTiedLoopInit(blockArg)->get(); - return getFullShapeOp(operand, rewriter); - } else { - emitError(val.getLoc()) - << "getFullShapeOp() only support ReinterpretCastOp " - "and scf.for's block argument, but got : " - << val << "\n"; +std::optional getFullShapeOp(Value val, ConversionPatternRewriter &rewriter) +{ + assert(isa(val.getType())); + + if (isa(val)) { + auto blockArg = dyn_cast(val); + auto blockOp = blockArg.getOwner()->getParentOp(); + if (isa(blockOp)) { + auto forOp = dyn_cast(blockOp); + auto operand = forOp.getTiedLoopInit(blockArg)->get(); + return getFullShapeOp(operand, rewriter); + } else { + emitError(val.getLoc()) << "getFullShapeOp() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + } + return std::nullopt; } - return std::nullopt; - } - if (!isa(val.getDefiningOp())) { - emitError(val.getLoc()) - << "getFullShapeOp() only support ReinterpretCastOp " - "and scf.for's block argument, but got : " - << val << "\n"; - return std::nullopt; - } + if (!isa(val.getDefiningOp())) { + emitError(val.getLoc()) << "getFullShapeOp() only support ReinterpretCastOp " + "and scf.for's block argument, but got : " + << val << "\n"; + return std::nullopt; + } - auto reCastOp = val.getDefiningOp(); - if (reCastOp->hasAttr("tensor_ptr_full_shape")) - return reCastOp; + auto reCastOp = val.getDefiningOp(); + if (reCastOp->hasAttr("tensor_ptr_full_shape")) + return reCastOp; - return getFullShapeOp(reCastOp.getSource(), rewriter); + return getFullShapeOp(reCastOp.getSource(), rewriter); } -SmallVector -getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, - const Location &loc, ConversionPatternRewriter &rewriter) { - if (isa(ptr.getType())) - ptr = rewriter.getRemappedValue(ptr); - - auto shapedType = dyn_cast_if_present(ptr.getType()); - assert(shapedType && shapedType.hasStaticShape()); - - auto fullShapeOp = getFullShapeOp(ptr, rewriter); - - assert(fullShapeOp.has_value()); - SmallVector boundarySize = - getAsIndexOpFoldResult(rewriter.getContext(), shapedType.getShape()); - - auto fullShapeReCast = - dyn_cast(fullShapeOp.value()); - OpFoldResult curPtrOffset; - if (auto curReCast = ptr.getDefiningOp()) { - curPtrOffset = curReCast.getConstifiedMixedOffset(); - } else if (isa(ptr) && - isa(ptr.getParentBlock()->getParentOp())) { - // Here's to process loop state where ptr is just from loop interator. 
- // Following assertion corresponds to conversion result from `rewriteFor` - auto blockArg = dyn_cast(ptr); - auto forOp = dyn_cast(ptr.getParentBlock()->getParentOp()); - auto initReCastOfLoop = forOp.getTiedLoopInit(blockArg) - ->get() - .getDefiningOp(); - assert(initReCastOfLoop && initReCastOfLoop.getOffsets().size() == 1); - Value initReCastOffset = initReCastOfLoop.getOffsets()[0]; - - for (OpOperand &use : initReCastOffset.getUses()) { - if (use.getOwner() == initReCastOfLoop) - continue; - else if (use.getOwner() == forOp) - curPtrOffset = OpFoldResult(forOp.getTiedLoopRegionIterArg(&use)); - else - llvm_unreachable("Illegal interation offset after rewriteFor"); - } - } else { - llvm_unreachable("Unsupported state when check tensor_ptr boundary"); - } - - assert(curPtrOffset); - - OpFoldResult offsetShift = subOpFoldResult( - curPtrOffset, fullShapeReCast.getConstifiedMixedOffset(), loc, rewriter); - - for (int i = 0; i < shapedType.getRank(); ++i) { - if (llvm::find(boundaryCheck, i) != boundaryCheck.end()) { - auto fullShape = fullShapeReCast.getConstifiedMixedSizes()[i]; - - OpFoldResult curOffset = divOpFoldResult( - offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, - rewriter); - OpFoldResult curLeftSize = - maxOpFoldResult(subOpFoldResult(fullShape, curOffset, loc, rewriter), - rewriter.getIndexAttr(0), loc, rewriter); - - boundarySize[i] = - minOpFoldResult(boundarySize[i], curLeftSize, loc, rewriter); - - offsetShift = remOpFoldResult( - offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, - rewriter); - } - } - - return boundarySize; +SmallVector getBoundarySizes(llvm::ArrayRef boundaryCheck, Value ptr, const Location &loc, + ConversionPatternRewriter &rewriter) +{ + if (isa(ptr.getType())) + ptr = rewriter.getRemappedValue(ptr); + + auto shapedType = dyn_cast_if_present(ptr.getType()); + assert(shapedType && shapedType.hasStaticShape()); + + auto fullShapeOp = getFullShapeOp(ptr, rewriter); + + assert(fullShapeOp.has_value()); + SmallVector boundarySize = getAsIndexOpFoldResult(rewriter.getContext(), shapedType.getShape()); + + auto fullShapeReCast = dyn_cast(fullShapeOp.value()); + OpFoldResult curPtrOffset; + if (auto curReCast = ptr.getDefiningOp()) { + curPtrOffset = curReCast.getConstifiedMixedOffset(); + } else if (isa(ptr) && isa(ptr.getParentBlock()->getParentOp())) { + // Here's to process loop state where ptr is just from loop interator. 
+ // Following assertion corresponds to conversion result from `rewriteFor` + auto blockArg = dyn_cast(ptr); + auto forOp = dyn_cast(ptr.getParentBlock()->getParentOp()); + auto initReCastOfLoop = forOp.getTiedLoopInit(blockArg)->get().getDefiningOp(); + assert(initReCastOfLoop && initReCastOfLoop.getOffsets().size() == 1); + Value initReCastOffset = initReCastOfLoop.getOffsets()[0]; + + for (OpOperand &use : initReCastOffset.getUses()) { + if (use.getOwner() == initReCastOfLoop) + continue; + else if (use.getOwner() == forOp) + curPtrOffset = OpFoldResult(forOp.getTiedLoopRegionIterArg(&use)); + else + llvm_unreachable("Illegal interation offset after rewriteFor"); + } + } else { + llvm_unreachable("Unsupported state when check tensor_ptr boundary"); + } + + assert(curPtrOffset); + + OpFoldResult offsetShift = subOpFoldResult(curPtrOffset, fullShapeReCast.getConstifiedMixedOffset(), loc, rewriter); + + for (int i = 0; i < shapedType.getRank(); ++i) { + if (llvm::find(boundaryCheck, i) != boundaryCheck.end()) { + auto fullShape = fullShapeReCast.getConstifiedMixedSizes()[i]; + + OpFoldResult curOffset = + divOpFoldResult(offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, rewriter); + OpFoldResult curLeftSize = maxOpFoldResult(subOpFoldResult(fullShape, curOffset, loc, rewriter), + rewriter.getIndexAttr(0), loc, rewriter); + + boundarySize[i] = minOpFoldResult(boundarySize[i], curLeftSize, loc, rewriter); + + offsetShift = remOpFoldResult(offsetShift, fullShapeReCast.getConstifiedMixedStrides()[i], loc, rewriter); + } + } + + return boundarySize; } -SmallVector getBroadcastDims(RankedTensorType src, - RankedTensorType dst) { - SmallVector broadcastDims; - auto srcShape = src.getShape(); - auto dstShape = dst.getShape(); - - for (size_t i = 0; i < srcShape.size(); ++i) { - if (dstShape[i] != srcShape[i]) { - assert(srcShape[i] == 1 && - "Size of source broadcast dimension must be 1"); - broadcastDims.push_back(i); - } - } - assert(!broadcastDims.empty() && "Cannot identify broadcast dimension"); - return broadcastDims; +SmallVector getBroadcastDims(RankedTensorType src, RankedTensorType dst) +{ + SmallVector broadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] != srcShape[i]) { + assert(srcShape[i] == 1 && "Size of source broadcast dimension must be 1"); + broadcastDims.push_back(i); + } + } + assert(!broadcastDims.empty() && "Cannot identify broadcast dimension"); + return broadcastDims; } // Dimensions of collapesd tensor is all unbroadcast dims -SmallVector getUnbroadcastDims(RankedTensorType src, - RankedTensorType dst) { - SmallVector unbroadcastDims; - auto srcShape = src.getShape(); - auto dstShape = dst.getShape(); - - for (size_t i = 0; i < srcShape.size(); ++i) { - if (dstShape[i] == srcShape[i]) { - unbroadcastDims.emplace_back(srcShape[i]); - } - } - return unbroadcastDims; +SmallVector getUnbroadcastDims(RankedTensorType src, RankedTensorType dst) +{ + SmallVector unbroadcastDims; + auto srcShape = src.getShape(); + auto dstShape = dst.getShape(); + + for (size_t i = 0; i < srcShape.size(); ++i) { + if (dstShape[i] == srcShape[i]) { + unbroadcastDims.emplace_back(srcShape[i]); + } + } + return unbroadcastDims; } } // namespace ConverterUtils namespace triton { -mlir::Operation * -findFirstMatchingOperandDef(mlir::Operation *rootOp, - const std::function &condFn) { - LLVM_DEBUG(llvm::dbgs() << "[findFirstMatchingOperandDef] Current op: " - << *rootOp << "\n"); - 
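// getBoundarySizes above clips each checked dimension of a block pointer against the tensor's
// full shape: the linear offset relative to the full-shape cast is decomposed dim by dim
// (div/rem by the stride), and the in-bounds size becomes min(blockSize, max(fullSize - dimOffset, 0)).
// A self-contained numeric sketch of that arithmetic; unlike the real code, the toy clips every
// dimension instead of only those listed in boundaryCheck, and the values are illustrative.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> boundarySizes(int64_t offsetShift,
                                          const std::vector<int64_t> &blockShape,
                                          const std::vector<int64_t> &fullSizes,
                                          const std::vector<int64_t> &strides) {
    std::vector<int64_t> result = blockShape;
    for (size_t i = 0; i < blockShape.size(); ++i) {
        int64_t dimOffset = offsetShift / strides[i];                 // offset along this dim
        int64_t left = std::max<int64_t>(fullSizes[i] - dimOffset, 0); // elements still in bounds
        result[i] = std::min(result[i], left);
        offsetShift %= strides[i];                                     // carry remainder to next dim
    }
    return result;
}

int main() {
    // A 32x32 block starting at row 120 of a 128x64 tensor keeps only 8 valid rows.
    auto sizes = boundarySizes(/*offsetShift=*/120 * 64, {32, 32}, {128, 64}, {64, 1});
    assert(sizes[0] == 8 && sizes[1] == 32);
    return 0;
}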
mlir::Value lhs = nullptr; - mlir::Value rhs = nullptr; - if (auto op = dyn_cast(rootOp)) { - lhs = op.getPtr(); - rhs = op.getOffset(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getLhs(); - rhs = op.getRhs(); - } else if (auto op = dyn_cast(rootOp)) { - lhs = op.getSrc(); - } else if (auto op = dyn_cast(rootOp)) { - } else { - rootOp->emitRemark("Backtracing encounters unsupported Operation"); - return nullptr; - } - // Backtrace operands - if (!lhs) { - return nullptr; - } - auto lhsDef = lhs.getDefiningOp(); - mlir::Operation *targetOp; - if (lhsDef) { - if (condFn(lhsDef)) { - targetOp = lhsDef; +mlir::Operation *findFirstMatchingOperandDef(mlir::Operation *rootOp, const std::function &condFn) +{ + LLVM_DEBUG(llvm::dbgs() << "[findFirstMatchingOperandDef] Current op: " << *rootOp << "\n"); + mlir::Value lhs = nullptr; + mlir::Value rhs = nullptr; + if (auto op = dyn_cast(rootOp)) { + lhs = op.getPtr(); + rhs = op.getOffset(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getLhs(); + rhs = op.getRhs(); + } else if (auto op = dyn_cast(rootOp)) { + lhs = op.getSrc(); + } else if (auto op = dyn_cast(rootOp)) { } else { - targetOp = findFirstMatchingOperandDef(lhsDef, condFn); + rootOp->emitRemark("Backtracing encounters unsupported Operation"); + return nullptr; } - if (targetOp) { - return targetOp; + // Backtrace operands + if (!lhs) { + return nullptr; } - } - if (!rhs) { - return nullptr; - } - auto rhsDef = rhs.getDefiningOp(); - if (rhsDef) { - if (condFn(rhsDef)) { - targetOp = rhsDef; - } else { - targetOp = findFirstMatchingOperandDef(rhsDef, condFn); + auto lhsDef = lhs.getDefiningOp(); + mlir::Operation *targetOp; + if (lhsDef) { + if (condFn(lhsDef)) { + targetOp = lhsDef; + } else { + targetOp = findFirstMatchingOperandDef(lhsDef, condFn); + } + if (targetOp) { + return targetOp; + } + } + if (!rhs) { + return nullptr; } - if (targetOp) { - return targetOp; + auto rhsDef = rhs.getDefiningOp(); + if (rhsDef) { + if (condFn(rhsDef)) { + targetOp = rhsDef; + } else { + targetOp = findFirstMatchingOperandDef(rhsDef, condFn); + } + if (targetOp) { + return targetOp; + } } - } - return nullptr; + return nullptr; } -void traverseBackwardUpdateOperandChainIf( - Operation *op, std::function conditionFn, - std::function actionFn, - OpBuilder &builder) { +void traverseBackwardUpdateOperandChainIf(Operation *op, std::function conditionFn, + std::function actionFn, OpBuilder &builder) +{ - if (!op) - return; + if (!op) + return; - if (conditionFn(op)) { - actionFn(builder, op); - } + if (conditionFn(op)) { + actionFn(builder, op); + } - for (Value operand : op->getOperands()) { - // TODO: handle BlockArgument - if (Operation *defOp = operand.getDefiningOp()) { - traverseBackwardUpdateOperandChainIf(defOp, conditionFn, actionFn, - builder); + for (Value operand : op->getOperands()) { + // TODO: handle 
BlockArgument + if (Operation *defOp = operand.getDefiningOp()) { + traverseBackwardUpdateOperandChainIf(defOp, conditionFn, actionFn, builder); + } } - } } // Note: rootOp will also be processed. -void traverseBackwardUpdateOperandChainIf( - Operation *rootOp, std::function conditionFn, - std::function actionFn) { +void traverseBackwardUpdateOperandChainIf(Operation *rootOp, std::function conditionFn, + std::function actionFn) +{ - OpBuilder builder(rootOp->getContext()); + OpBuilder builder(rootOp->getContext()); - traverseBackwardUpdateOperandChainIf(rootOp, conditionFn, actionFn, builder); + traverseBackwardUpdateOperandChainIf(rootOp, conditionFn, actionFn, builder); } -void traverseForwardUpdateUserChainIf( - Operation *op, std::function conditionFn, - std::function stopFn, - std::function actionFn, OpBuilder &builder, - llvm::SmallPtrSet &stopOps) { +void traverseForwardUpdateUserChainIf(Operation *op, std::function conditionFn, + std::function stopFn, + std::function actionFn, OpBuilder &builder, + llvm::SmallPtrSet &stopOps) +{ - if (!op) { - return; - } + if (!op) { + return; + } - if (stopFn(op)) { - stopOps.insert(op); - return; - } + if (stopFn(op)) { + stopOps.insert(op); + return; + } - if (conditionFn(op)) { - actionFn(builder, op); - } + if (conditionFn(op)) { + actionFn(builder, op); + } - for (auto res : op->getResults()) { - for (auto userOp : res.getUsers()) { - traverseForwardUpdateUserChainIf(userOp, conditionFn, stopFn, actionFn, - builder, stopOps); + for (auto res : op->getResults()) { + for (auto userOp : res.getUsers()) { + traverseForwardUpdateUserChainIf(userOp, conditionFn, stopFn, actionFn, builder, stopOps); + } } - } } // Note: rootOp will also be processed. -void traverseForwardUpdateUserChainIf( - Operation *rootOp, std::function conditionFn, - std::function stopFn, - std::function actionFn, - llvm::SmallPtrSet &stopOps) { +void traverseForwardUpdateUserChainIf(Operation *rootOp, std::function conditionFn, + std::function stopFn, + std::function actionFn, + llvm::SmallPtrSet &stopOps) +{ - OpBuilder builder(rootOp->getContext()); + OpBuilder builder(rootOp->getContext()); - traverseForwardUpdateUserChainIf(rootOp, conditionFn, stopFn, actionFn, - builder, stopOps); + traverseForwardUpdateUserChainIf(rootOp, conditionFn, stopFn, actionFn, builder, stopOps); } -bool isMetaUse(Operation *op) { return op->hasAttr("MetaUse"); } +bool isMetaUse(Operation *op) +{ + return op->hasAttr("MetaUse"); +} -bool isMixUse(Operation *op) { return op->hasAttr("MixUse"); } +bool isMixUse(Operation *op) +{ + return op->hasAttr("MixUse"); +} -IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op) { - auto ty = IndirectLoadInterfaceOpType::Undefined; - if (isMetaUse(op)) { - if (isa(op)) { - ty = IndirectLoadInterfaceOpType::Load; - } else if (isa(op)) { - ty = IndirectLoadInterfaceOpType::Calc; +IndirectLoadInterfaceOpType getIndirectLoadInterfaceOpType(Operation *op) +{ + if (!isMetaUse(op) && !isMixUse(op)) { + return IndirectLoadInterfaceOpType::Undefined; } - } - return ty; + // The semantics of these ops indicate that they cannot be statically analyzed. + static const auto &opTypeMap = *new llvm::DenseMap( + {{TypeID::get(), IndirectLoadInterfaceOpType::Load}, + {TypeID::get(), IndirectLoadInterfaceOpType::Calc}, + {TypeID::get(), IndirectLoadInterfaceOpType::Calc}}); + + auto it = opTypeMap.find(op->getRegisteredInfo()->getTypeID()); + return it != opTypeMap.end() ? 
it->second : IndirectLoadInterfaceOpType::Undefined; } -bool opIsIndirectLoad(Operation *op) { - auto opType = getIndirectLoadInterfaceOpType(op); - return opType == IndirectLoadInterfaceOpType::Load; +bool opIsIndirectLoad(Operation *op) +{ + auto opType = getIndirectLoadInterfaceOpType(op); + return opType == IndirectLoadInterfaceOpType::Load; } -bool opIsIndirectCalc(Operation *op) { - auto opType = getIndirectLoadInterfaceOpType(op); - return opType == IndirectLoadInterfaceOpType::Calc; +bool opIsIndirectCalc(Operation *op) +{ + auto opType = getIndirectLoadInterfaceOpType(op); + return opType == IndirectLoadInterfaceOpType::Calc; } -scf::ForOp createNestedLoops( - OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, - ValueRange LBs, ValueRange UBs, ValueRange steps, SmallVector &ivs, - ValueRange initArgs, - function_ref &, ValueRange)> - bodyBuilder) { +scf::ForOp createNestedLoops(OpBuilder &builder, Location loc, unsigned currentDim, unsigned totalDims, ValueRange LBs, + ValueRange UBs, ValueRange steps, SmallVector &ivs, ValueRange initArgs, + function_ref &, ValueRange)> bodyBuilder) +{ - if (currentDim >= totalDims) { - bodyBuilder(builder, loc, ivs, initArgs); - return nullptr; - } - - auto loop = builder.create( - loc, LBs[currentDim], UBs[currentDim], steps[currentDim], initArgs, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, - ValueRange iterArgs) { - ivs.push_back(iv); - auto innerLoop = createNestedLoops(nestedBuilder, nestedLoc, - currentDim + 1, totalDims, LBs, UBs, - steps, ivs, iterArgs, bodyBuilder); - if (innerLoop) { - nestedBuilder.create(loc, innerLoop.getResults()); - } - }); + if (currentDim >= totalDims) { + bodyBuilder(builder, loc, ivs, initArgs); + return nullptr; + } - return loop; + auto loop = + builder.create(loc, LBs[currentDim], UBs[currentDim], steps[currentDim], initArgs, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange iterArgs) { + ivs.push_back(iv); + auto innerLoop = + createNestedLoops(nestedBuilder, nestedLoc, currentDim + 1, totalDims, LBs, + UBs, steps, ivs, iterArgs, bodyBuilder); + if (innerLoop) { + nestedBuilder.create(loc, innerLoop.getResults()); + } + }); + + return loop; } -ModuleOp getModuleOpFromOperation(Operation *op) { - Operation *parent = op; - while (parent != nullptr && !isa(parent)) { - parent = parent->getParentOp(); // 向上查找 - } - return cast(parent); // 如果没找到会抛出异常 +ModuleOp getModuleOpFromOperation(Operation *op) +{ + Operation *parent = op; + while (parent != nullptr && !isa(parent)) { + parent = parent->getParentOp(); // 向上查找 + } + return cast(parent); // 如果没找到会抛出异常 } } // namespace triton -static Value createConstIndexValueOp(const Location &loc, OpBuilder &b, - int64_t value) { - return b.create(loc, b.getIndexAttr(value)).getResult(); +static Value createConstIndexValueOp(const Location &loc, OpBuilder &b, int64_t value) +{ + return b.create(loc, b.getIndexAttr(value)).getResult(); } -static std::optional getConstantOfAttr(const OpFoldResult &arg) { - if (isa(arg)) { - return getConstantIntValue(arg); - } +static std::optional getConstantOfAttr(const OpFoldResult &arg) +{ + if (isa(arg)) { + return getConstantIntValue(arg); + } - return std::nullopt; + return std::nullopt; } // TODO: imply these function below -OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); - - if (lhsInt && rhsInt) - return 
b.getIndexAttr(lhsInt.value() + rhsInt.value()); - - if (!lhsInt && rhsInt && rhsInt.value() == 0) - return lhs; - if (!rhsInt && lhsInt && lhsInt.value() == 0) - return rhs; - - auto lhsValue = dyn_cast(lhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); - - auto rhsValue = dyn_cast(rhs); - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); - - return b.create(loc, lhsValue, rhsValue).getResult(); +OpFoldResult addOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() + rhsInt.value()); + + if (!lhsInt && rhsInt && rhsInt.value() == 0) + return lhs; + if (!rhsInt && lhsInt && lhsInt.value() == 0) + return rhs; + + auto lhsValue = dyn_cast(lhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + auto rhsValue = dyn_cast(rhs); + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); +OpFoldResult subOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); - if (lhsInt && rhsInt) - return b.getIndexAttr(lhsInt.value() - rhsInt.value()); + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() - rhsInt.value()); - if (!lhsInt && rhsInt && rhsInt.value() == 0) - return lhs; + if (!lhsInt && rhsInt && rhsInt.value() == 0) + return lhs; - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); - return b.create(loc, lhsValue, rhsValue).getResult(); + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); - - if (lhsInt && rhsInt) - return b.getIndexAttr(lhsInt.value() * rhsInt.value()); - - if (lhsInt) { - if (lhsInt.value() == 0) - return lhs; - if (lhsInt.value() == 1) - return rhs; - } - if (rhsInt) { - if (rhsInt.value() == 0) - return rhs; - if (rhsInt.value() == 1) - return lhs; - } - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); - - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); - - return b.create(loc, lhsValue, rhsValue).getResult(); 
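// The *OpFoldResult helpers in this hunk all share one shape: fold when both sides are
// compile-time constants, short-circuit algebraic identities (x+0, x*1, x*0, x/1), and only
// otherwise materialize an arith op on index values. A self-contained toy of that
// fold-or-emit structure, using optional constants in place of OpFoldResult.
#include <cassert>
#include <cstdint>
#include <optional>

// std::nullopt stands in for "a runtime Value"; a present int64_t stands in for an
// IntegerAttr constant. emittedOps counts how many arith ops would have been created.
static int emittedOps = 0;

static std::optional<int64_t> addFold(std::optional<int64_t> lhs, std::optional<int64_t> rhs) {
    if (lhs && rhs) return *lhs + *rhs;   // constant-fold into an attribute
    if (rhs && *rhs == 0) return lhs;     // x + 0 == x, keep the runtime value
    if (lhs && *lhs == 0) return rhs;
    ++emittedOps;                         // otherwise an arith.addi on index values is built
    return std::nullopt;
}

int main() {
    assert(addFold(3, 4) == 7 && emittedOps == 0);
    assert(addFold(std::nullopt, 0) == std::nullopt && emittedOps == 0);  // identity, no op emitted
    addFold(std::nullopt, 5);
    assert(emittedOps == 1);
    return 0;
}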
+OpFoldResult mulOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() * rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + if (lhsInt.value() == 1) + return rhs; + } + if (rhsInt) { + if (rhsInt.value() == 0) + return rhs; + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); - - if (rhsInt && rhsInt.value() == 0) { - emitError(loc) << "cannot div 0!"; - return OpFoldResult(); - } - - if (lhsInt && rhsInt) - return b.getIndexAttr(lhsInt.value() / rhsInt.value()); - - if (lhsInt) { - if (lhsInt.value() == 0) - return lhs; - } - - if (rhsInt) { - if (rhsInt.value() == 1) - return lhs; - } - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); - - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); - - return b.create(loc, lhsValue, rhsValue).getResult(); +OpFoldResult divOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot div 0!"; + return OpFoldResult(); + } + + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() / rhsInt.value()); + + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } + + if (rhsInt) { + if (rhsInt.value() == 1) + return lhs; + } + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); +OpFoldResult remOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); - if (rhsInt && rhsInt.value() == 0) { - emitError(loc) << "cannot remainder by 0!"; - return OpFoldResult(); - } + if (rhsInt && rhsInt.value() == 0) { + emitError(loc) << "cannot remainder by 0!"; + return OpFoldResult(); + } - if (lhsInt && rhsInt) - return b.getIndexAttr(lhsInt.value() % rhsInt.value()); + if (lhsInt && rhsInt) + return b.getIndexAttr(lhsInt.value() % rhsInt.value()); - if (lhsInt) { - if (lhsInt.value() == 0) - return lhs; - } + if (lhsInt) { + if (lhsInt.value() == 0) + return lhs; + } - auto lhsValue = dyn_cast(lhs), rhsValue = 
dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); - return b.create(loc, lhsValue, rhsValue).getResult(); + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); - if (lhsInt && rhsInt) - return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); - - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); - - return b.create(loc, lhsValue, rhsValue).getResult(); +OpFoldResult minOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::min(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); } -OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, - const Location &loc, OpBuilder &b) { - auto lhsInt = getConstantOfAttr(lhs); - auto rhsInt = getConstantOfAttr(rhs); - if (lhsInt && rhsInt) - return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); - - auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); - if (lhsInt) - lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); - else - assert(isa(lhsValue.getType())); - - if (rhsInt) - rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); - else - assert(isa(rhsValue.getType())); - - return b.create(loc, lhsValue, rhsValue).getResult(); +OpFoldResult maxOpFoldResult(const OpFoldResult &lhs, const OpFoldResult &rhs, const Location &loc, OpBuilder &b) +{ + auto lhsInt = getConstantOfAttr(lhs); + auto rhsInt = getConstantOfAttr(rhs); + if (lhsInt && rhsInt) + return b.getIndexAttr(std::max(lhsInt.value(), rhsInt.value())); + + auto lhsValue = dyn_cast(lhs), rhsValue = dyn_cast(rhs); + if (lhsInt) + lhsValue = createConstIndexValueOp(loc, b, lhsInt.value()); + else + assert(isa(lhsValue.getType())); + + if (rhsInt) + rhsValue = createConstIndexValueOp(loc, b, rhsInt.value()); + else + assert(isa(rhsValue.getType())); + + return b.create(loc, lhsValue, rhsValue).getResult(); } -LogicalResult -addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, - linalg::ReduceOp reduceOp) { - // To verify whether the operation of the reduceOp is ReduceWithIndex - // TODO: maybe a better way of judging? 
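// addReduceWithIndexAttrIfNeeded below tags a linalg.reduce whose combiner begins with a
// comparison as an argmax/argmin-style reduction via the "reduce_mode" string attribute.
// A self-contained sketch of the predicate-to-mode mapping it applies; predicate names
// mirror arith.cmpf / arith.cmpi.
#include <cassert>
#include <string>

// Map the comparison predicate at the head of the reduce combiner to the "reduce_mode"
// attribute value; an empty string means "not a with-index reduction".
static std::string reduceModeFor(const std::string &predicate) {
    if (predicate == "ogt" || predicate == "sgt") return "max_with_index";
    if (predicate == "olt" || predicate == "slt") return "min_with_index";
    return "";
}

int main() {
    assert(reduceModeFor("ogt") == "max_with_index");
    assert(reduceModeFor("slt") == "min_with_index");
    assert(reduceModeFor("oeq").empty());
    return 0;
}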
- auto ctx = reduceOp.getContext(); - Block &body = reduceOp.getCombiner().front(); - auto yieldOp = dyn_cast(body.getTerminator()); - - auto yieldValue = yieldOp.getValues(); - if (yieldValue.size() == 0) { - return failure(); - } - - auto opIter = reduceOp.getBody()->without_terminator().begin(); - auto cmpMaskOp = dyn_cast(*opIter); - const StringRef reduceRef = "reduce_mode"; - if (cmpMaskOp) { - if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OGT) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); - } else if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OLT) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); - } - } - - auto cmpMaskIOp = dyn_cast(*opIter); - if (cmpMaskIOp) { - if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::sgt) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); - } else if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::slt) { - reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); - } - } - - return success(); +LogicalResult addReduceWithIndexAttrIfNeeded(ConversionPatternRewriter &rewriter, linalg::ReduceOp reduceOp) +{ + // To verify whether the operation of the reduceOp is ReduceWithIndex + // TODO: maybe a better way of judging? + auto ctx = reduceOp.getContext(); + Block &body = reduceOp.getCombiner().front(); + auto yieldOp = dyn_cast(body.getTerminator()); + + auto yieldValue = yieldOp.getValues(); + if (yieldValue.size() == 0) { + return failure(); + } + + auto opIter = reduceOp.getBody()->without_terminator().begin(); + auto cmpMaskOp = dyn_cast(*opIter); + const StringRef reduceRef = "reduce_mode"; + if (cmpMaskOp) { + if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OGT) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); + } else if (cmpMaskOp.getPredicate() == arith::CmpFPredicate::OLT) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); + } + } + + auto cmpMaskIOp = dyn_cast(*opIter); + if (cmpMaskIOp) { + if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::sgt) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("max_with_index")); + } else if (cmpMaskIOp.getPredicate() == arith::CmpIPredicate::slt) { + reduceOp->setAttr(reduceRef, rewriter.getStringAttr("min_with_index")); + } + } + + return success(); } // Fold layout constant info to attr, otherwise convert to index type value -OpFoldResult getOpFoldResultOfLayoutInfo(Value value, OpBuilder &builder) { - OpFoldResult constantFold = getAsOpFoldResult(value); - if (llvm::isa(constantFold)) { - assert(isa(constantFold.get())); - return constantFold; - } - - if (!isa(value.getType())) - llvm_unreachable("Illegal data type when parse block data layout info"); - - if (!isa(value.getType())) { - if (value.getType().isInteger(/*width*/ 1)) - value = builder.create( - value.getLoc(), builder.getIndexType(), value); - else - value = builder.create(value.getLoc(), - builder.getIndexType(), value); - } +OpFoldResult getOpFoldResultOfLayoutInfo(Value value, OpBuilder &builder) +{ + OpFoldResult constantFold = getAsOpFoldResult(value); + if (llvm::isa(constantFold)) { + assert(isa(constantFold.get())); + return constantFold; + } + + if (!isa(value.getType())) + llvm_unreachable("Illegal data type when parse block data layout info"); + + if (!isa(value.getType())) { + if (value.getType().isInteger(/*width*/ 1)) + value = builder.create(value.getLoc(), builder.getIndexType(), value); + else + value = builder.create(value.getLoc(), 
builder.getIndexType(), value); + } - return value; + return value; } } // namespace mlir
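// The two builder.create calls at the end of getOpFoldResultOfLayoutInfo lost their op types
// in rendering. Given that the goal is to turn an integer layout value into an index, the most
// likely pairing (an assumption, not recovered from the patch) is an unsigned index cast for i1
// and a plain index cast otherwise. Headers as already included by Utils.cpp.
static mlir::Value castLayoutValueToIndexSketch(mlir::Value value, mlir::OpBuilder &builder) {
    if (!mlir::isa<mlir::IndexType>(value.getType())) {
        if (value.getType().isInteger(/*width=*/1))
            // i1 must not be sign-extended, so an unsigned index cast is the natural fit.
            value = builder.create<mlir::arith::IndexCastUIOp>(value.getLoc(),
                                                               builder.getIndexType(), value);
        else
            value = builder.create<mlir::arith::IndexCastOp>(value.getLoc(),
                                                             builder.getIndexType(), value);
    }
    return value;
}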