From 4cb35ba52f8c9fbfc669a6210d531e92cc008083 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Thu, 28 Aug 2025 20:57:47 +0800 Subject: [PATCH 1/9] Added regression test as starting point. --- .../TritonToLinalg/indirect_load.mlir | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 ascend/test/Conversion/TritonToLinalg/indirect_load.mlir diff --git a/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir b/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir new file mode 100644 index 0000000..54c9947 --- /dev/null +++ b/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir @@ -0,0 +1,83 @@ +// RUN: triton-adapter-opt --triton-to-linalg %s | FileCheck %s + +module { + tt.func public @indirect_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + %c256_i64 = arith.constant 256 : i64 + %cst = arith.constant dense<0.000000e+00> : tensor<256xf32> + %cst_0 = arith.constant dense<0> : tensor<256xi32> + %cst_1 = arith.constant dense<-2147483648> : tensor<256xi64> + %cst_2 = arith.constant dense<2147483647> : tensor<256xi64> + %c-2147483648_i64 = arith.constant -2147483648 : i64 + %c2147483647_i64 = arith.constant 2147483647 : i64 + %c256_i32 = arith.constant 256 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.extsi %0 : i32 to i64 + %2 = arith.muli %1, %c256_i64 : i64 + %3 = arith.cmpi sle, %2, %c2147483647_i64 : i64 + %4 = arith.cmpi sge, %2, %c-2147483648_i64 : i64 + %5 = arith.andi %3, %4 : i1 + tt.assert %5, "int32 overflow detected for operation mul" : i1 + %6 = arith.muli %0, %c256_i32 : i32 + %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %8 = tt.splat %6 : i32 -> tensor<256xi32> + %9 = arith.extsi %6 : i32 to i64 + %10 = tt.splat %9 : i64 -> tensor<256xi64> + %11 = arith.extsi %7 : tensor<256xi32> to tensor<256xi64> + %12 = arith.addi %10, %11 : tensor<256xi64> + %13 = arith.cmpi sle, %12, %cst_2 : tensor<256xi64> + %14 = arith.cmpi sge, %12, %cst_1 : tensor<256xi64> + %15 = arith.andi %13, %14 : tensor<256xi1> + tt.assert %15, "int32 overflow detected for operation add" : tensor<256xi1> + %16 = arith.addi %8, %7 : tensor<256xi32> + %17 = tt.splat %arg3 : i32 -> tensor<256xi32> + %18 = arith.cmpi slt, %16, %17 : tensor<256xi32> + %19 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> + %20 = tt.addptr %19, %16 : tensor<256x!tt.ptr>, tensor<256xi32> + %21 = tt.load %20, %18, %cst_0 : tensor<256x!tt.ptr> + %22 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> + %23 = tt.addptr %22, %21 : tensor<256x!tt.ptr>, tensor<256xi32> + %24 = tt.load %23, %18, %cst : tensor<256x!tt.ptr> + %25 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %26 = tt.addptr %25, %16 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %26, %24 : tensor<256x!tt.ptr> + tt.return + } +} + +// CHECK-LABEL: func.func @indirect_load( +// CHECK-SAME: %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref, %[[VAL_2:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %[[VAL_4:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { +// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_13:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_14:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_15:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_17:.*]] = tensor.empty() : tensor<256xf32> +// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_9]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : i32 to index +// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_3]] to offset: {{\[}}%[[VAL_19]]], sizes: [256], strides: [1] : memref to memref<256xi32, strided<[1], offset: ?>> +// CHECK: %[[VAL_21:.*]] = memref.alloc() : memref<256xi32> +// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_14]] : index +// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_24:.*]] = arith.maxsi %[[VAL_19]], %[[VAL_23]] : index +// CHECK: %[[VAL_25:.*]] = arith.minsi %[[VAL_22]], %[[VAL_24]] : index +// CHECK: %[[VAL_26:.*]] = arith.subi %[[VAL_25]], %[[VAL_19]] : index +// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_26]], %[[VAL_14]] : index +// CHECK: scf.if %[[VAL_27]] { +// CHECK: linalg.fill ins(%[[VAL_16]] : i32) outs(%[[VAL_21]] : memref<256xi32>) +// CHECK: } +// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_20]][0] {{\[}}%[[VAL_26]]] [1] : memref<256xi32, strided<[1], offset: ?>> to memref> +// CHECK: %[[VAL_29:.*]] = memref.subview %[[VAL_21]][0] {{\[}}%[[VAL_26]]] [1] : memref<256xi32> to memref> +// CHECK: memref.copy %[[VAL_28]], %[[VAL_29]] : memref> to memref> +// CHECK: %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_21]] restrict writable : memref<256xi32> +// CHECK: %[[VAL_31:.*]] = scf.for %[[VAL_32:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_12]] iter_args(%[[VAL_33:.*]] = %[[VAL_17]]) -> (tensor<256xf32>) { +// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_30]]{{\[}}%[[VAL_32]]] : tensor<256xi32> +// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : i32 to index +// CHECK: %[[VAL_36:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_35]]], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_13]]] : memref<1xf32, strided<[1], offset: ?>> +// CHECK: %[[VAL_38:.*]] = tensor.insert %[[VAL_37]] into %[[VAL_33]]{{\[}}%[[VAL_32]]] : tensor<256xf32> +// CHECK: scf.yield %[[VAL_38]] : tensor<256xf32> +// CHECK: } +// CHECK: %[[VAL_39:.*]] = memref.reinterpret_cast %[[VAL_4]] to offset: {{\[}}%[[VAL_19]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_31]] in writable %[[VAL_39]] : (tensor<256xf32>, memref<256xf32, strided<[1], offset: ?>>) -> () +// CHECK: return +// CHECK: } -- Gitee From a5bd078897156a1e0e6685c2f224dac3a4d1c563 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 15:46:50 +0800 Subject: [PATCH 2/9] First try: Generic Linalg operation --- .../lib/TritonToLinalg/BlockPtrAnalysis.cpp | 83 +++++-------------- .../lib/TritonToLinalg/LoadStoreConverter.cpp | 64 +++++++------- 2 files changed, 54 insertions(+), 93 deletions(-) diff --git a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp index 1f53f63..397fe5c 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -1286,68 +1286,27 @@ void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, ConversionPatternRewriter &rewriter, BlockData &data) { auto loc = op.getLoc(); - auto &offsets = data.getOffsetsRef(); - auto &blockSizes = data.getSizesRef(); - auto &strides = data.getStridesRef(); - Value ptrOffset = adaptor.getOffset(); - Value zeroIdx = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value oneIdx = - rewriter.create(loc, rewriter.getIndexAttr(1)); - auto addptrRes = op.getResult(); - assert(addptrRes.hasOneUse() && "Invalid: tt.addptr has multiple users"); - auto loadOp = *(addptrRes.user_begin()); - - // Prepare empty tensor for loop based scalar load - // FIXME: We use cast here because addptr must return tensor>. - // True? - auto resTy = cast(addptrRes.getType()); - auto resEPtrTy = resTy.getElementType(); - auto resETy = cast(resEPtrTy).getPointeeType(); - Value loaded = rewriter.create(loc, blockSizes, resETy); - SmallVector initArgs; - initArgs.push_back(loaded); - - SmallVector forLBs; - SmallVector forUBs; - SmallVector forSteps; - for (auto &s : offsets) { - forLBs.push_back(zeroIdx); - } - for (auto &s : blockSizes) { - forUBs.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); - } - for (auto &s : strides) { - forSteps.push_back(oneIdx); - } - SmallVector ivs; - OpBuilder builder(op); - auto loop = createNestedLoops( - builder, loc, 0, blockSizes.size(), forLBs, forUBs, forSteps, ivs, - initArgs, - [&](OpBuilder &bB, Location bLoc, SmallVector &allIVs, - ValueRange iterArgs) { - OpBuilder::InsertionGuard g(bB); - bB.setInsertionPointToStart(bB.getBlock()); - - Value scalarOffsetRaw = - bB.create(bLoc, ptrOffset, allIVs); - Value scalarOffset = bB.create( - bLoc, bB.getIndexType(), scalarOffsetRaw); - // Replace offset & size. Only single element. - data.getOffsetsRef().clear(); - data.getOffsetsRef().push_back(scalarOffset); - data.getSizesRef().clear(); - data.getSizesRef().push_back(bB.getIndexAttr(1)); - data.getStridesRef().clear(); - data.getStridesRef().push_back(bB.getIndexAttr(1)); - memref::ReinterpretCastOp castOp = data.createCastOp({1}, bLoc, bB); - rewriter.replaceOp(op, castOp); - // Move tt.load using this tt.addptr into this block - loadOp->moveAfter(castOp); - loadOp->setAttr("IndirectLoad", UnitAttr::get(op.getContext())); - bB.create(bLoc, iterArgs); - }); + + // The addptr result must feed exactly one load so we can collapse the pair + Value addPtrRes = op.getResult(); + assert(addPtrRes.hasOneUse() && + "tt.addptr used multiple times – gather conversion requires a single " + "tt.load user"); + auto loadOp = cast(*addPtrRes.user_begin()); + + // We keep the base pointer and the index tensor separately; the load will be + // rewritten to a linalg.gather later in LoadConverter. + Value basePtr = adaptor.getPtr(); + Value indexTnsr = adaptor.getOffset(); + + // Mark the load so the later converter knows to emit a gather. + loadOp->setAttr("IndirectLoad", rewriter.getUnitAttr()); + + // Replace the tt.addptr result with the original pointer; thread the index + // tensor into the load (operand #1 is unused in the unstructured case). + rewriter.replaceOp(op, basePtr); + loadOp->setOperand(0, basePtr); + loadOp->setOperand(1, indexTnsr); } } // namespace triton diff --git a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp index 1259903..0b08d2f 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -102,38 +103,39 @@ LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const { LogicalResult LoadConverter::continueModifyFromAddPtrConverter( triton::LoadOp &op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - auto forOp = op->getParentOfType(); - Operation *firstOp = &forOp.getBody()->front(); - auto extractOp = cast(firstOp); - auto ivs = extractOp.getIndices(); - // Single iterArg which is inserted by AddPtrConverter. - auto iterArg = forOp.getRegionIterArg(0); - auto ptr = adaptor.getPtr(); + // Only loads tagged by BlockPtrAnalysis are handled here. + if (!op->hasAttr("IndirectLoad")) + return failure(); + + auto loc = op.getLoc(); + Value base = adaptor.getPtr(); // operand #0: base pointer + Value indices = op.getOperand(1); // operand #1: index tensor + + // Result type of the gather (same as the original load result). + Type resultTy = op.getResult().getType(); + + // Build a linalg.generic that gathers elements at "indices" from "base". + auto maps = SmallVector{ + // map for base tensor + rewriter.getMultiDimIdentityMap(1), + // map for indices tensor + rewriter.getMultiDimIdentityMap(1)}; + + auto iterTypes = SmallVector{getParallelIteratorTypeName()}; + + auto gather = rewriter.create( + loc, resultTy, + /*inputs=*/ValueRange{base, indices}, + /*outputs=*/ValueRange{}, + maps, iterTypes, + [&](OpBuilder &b, Location l, ValueRange args) { + // args[0] – element from base; args[1] – index + Value elem = b.create(l, args[0], args[1]); + b.create(l, elem); + }); - rewriter.setInsertionPointAfter(op); - Value castVal = ptr.getDefiningOp(); - Value idxZero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value loadVal = - rewriter.create(loc, castVal, ValueRange{idxZero}); - Value insertedVal = - rewriter.create(loc, loadVal, iterArg, ValueRange{ivs}); - // a yield op is already created by AddPtrConverter. - // so we need to replace it with a new yield op. - Operation *terminator = forOp.getBody()->getTerminator(); - scf::YieldOp oldYieldOp = cast(terminator); - auto yieldOp = rewriter.create(loc, ValueRange{insertedVal}); - rewriter.replaceOp(oldYieldOp, yieldOp); - // Now the scf.for is complete, we can replace tt.load with it. - auto rank = cast(op.getResult().getType()).getShape().size(); - Operation *rootForOp = op; - while (rank != 0) { - rank--; - rootForOp = rootForOp->getParentOfType(); - } - rewriter.replaceOp(op, rootForOp); - LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(rootForOp) << "\n"; }); + // Replace the original load with the gather op. + rewriter.replaceOp(op, gather.getResult(0)); return success(); } -- Gitee From 534cd2eede61a594e57b5807656c0c7da392f6d8 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 17:25:57 +0800 Subject: [PATCH 3/9] Adding LinalgExt and plugging it. --- ascend/triton-adapter/include/CMakeLists.txt | 1 + .../include/LinalgExt/CMakeLists.txt | 8 +++ .../include/LinalgExt/LinalgExtDialect.h | 8 +++ .../include/LinalgExt/LinalgExtOps.h | 10 +++ .../include/LinalgExt/LinalgExtOps.td | 42 +++++++++++++ .../TritonToLinalg/LoadStoreConverter.h | 4 -- ascend/triton-adapter/lib/CMakeLists.txt | 1 + .../lib/LinalgExt/CMakeLists.txt | 3 + .../lib/LinalgExt/LinalgExtDialect.cpp | 15 +++++ .../lib/TritonToLinalg/BlockPtrAnalysis.cpp | 16 ++--- .../lib/TritonToLinalg/LoadStoreConverter.cpp | 62 +++++-------------- .../lib/TritonToLinalg/TritonToLinalgPass.cpp | 7 ++- 12 files changed, 111 insertions(+), 66 deletions(-) create mode 100644 ascend/triton-adapter/include/LinalgExt/CMakeLists.txt create mode 100644 ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h create mode 100644 ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h create mode 100644 ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td create mode 100644 ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt create mode 100644 ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp diff --git a/ascend/triton-adapter/include/CMakeLists.txt b/ascend/triton-adapter/include/CMakeLists.txt index 9fed756..4782084 100644 --- a/ascend/triton-adapter/include/CMakeLists.txt +++ b/ascend/triton-adapter/include/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(LinalgExt) add_subdirectory(TritonToAnnotation) add_subdirectory(TritonToHIVM) add_subdirectory(TritonToLinalg) \ No newline at end of file diff --git a/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt b/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt new file mode 100644 index 0000000..98a3100 --- /dev/null +++ b/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt @@ -0,0 +1,8 @@ +# ascend/triton-adapter/include/LinalgExt/CMakeLists.txt +set(LLVM_TARGET_DEFINITIONS LinalgExtOps.td) +mlir_tablegen(LinalgExtOps.h.inc -gen-op-decls) +mlir_tablegen(LinalgExtOps.cpp.inc -gen-op-defs) +mlir_tablegen(LinalgExtOpsDialect.h.inc -gen-dialect-decls -dialect=linalg_ext) +mlir_tablegen(LinalgExtOpsDialect.cpp.inc -gen-dialect-defs -dialect=linalg_ext) +add_public_tablegen_target(LinalgExtOpsIncGen) +add_dependencies(mlir-headers LinalgExtOpsIncGen) diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h b/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h new file mode 100644 index 0000000..2fe2113 --- /dev/null +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h @@ -0,0 +1,8 @@ +#ifndef LINALGEXT_LINALGEXTDIALECT_H +#define LINALGEXT_LINALGEXTDIALECT_H + +#include "mlir/IR/Dialect.h" + +#include "LinalgExt/LinalgExtOpsDialect.h.inc" + +#endif diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h new file mode 100644 index 0000000..0fb7727 --- /dev/null +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h @@ -0,0 +1,10 @@ +#ifndef LINALGEXT_LINALGEXTOPS_H +#define LINALGEXT_LINALGEXTOPS_H + +#include "LinalgExt/LinalgExtDialect.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "LinalgExt/LinalgExtOps.h.inc" + +#endif diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td new file mode 100644 index 0000000..2e88496 --- /dev/null +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td @@ -0,0 +1,42 @@ +#ifndef LINALG_EXT_OPS_TD +#define LINALG_EXT_OPS_TD + +include "mlir/IR/OpBase.td" +include "mlir/IR/BuiltinOps.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/Dialect/Linalg/IR/LinalgOps.td" + +def LinalgExt_Dialect : Dialect { + let name = "linalg_ext"; + let cppNamespace = "::linalg_ext"; +} + +class LinalgExt_Op traits = []> + : Op; + +def GathermaskOp : LinalgExt_Op<"gathermask", [NoSideEffect]> { + let summary = "Gather with optional mask/other"; + let description = [{ + Gathers elements from `base` at `indices`. If a mask is supplied and the + mask value is false, the corresponding result element is `other`. + }]; + + let arguments = (ins + AnyTensor:$base, + AnyTensor:$indices, + Optional:$mask, + Optional:$other + ); + let results = (outs AnyTensor:$result); + + let assemblyFormat = [{ + $base `,` $indices (`,` $mask `,` $other^)? attr-dict + `:` type($base) `,` type($indices) `->` type($result) + }]; + + let builders = [ + OpBuilder<(ins "Type":$resultTy, "Value":$base, "Value":$indices, + "Value":$mask = "Value()", "Value":$other = "Value()")> + ]; +} +#endif // LINALG_EXT_OPS_TD diff --git a/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h b/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h index c1d2dbc..2e0bd22 100644 --- a/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h +++ b/ascend/triton-adapter/include/TritonToLinalg/LoadStoreConverter.h @@ -40,10 +40,6 @@ private: LogicalResult checkModifiedByAddPtrConverter(triton::LoadOp &op) const; - LogicalResult - continueModifyFromAddPtrConverter(triton::LoadOp &op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const; - void fillTensorWithOtherForMaskScenario(Value other, memref::AllocOp localMem, ArrayRef maskDim, diff --git a/ascend/triton-adapter/lib/CMakeLists.txt b/ascend/triton-adapter/lib/CMakeLists.txt index eb3a681..1d1ec28 100644 --- a/ascend/triton-adapter/lib/CMakeLists.txt +++ b/ascend/triton-adapter/lib/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(LinalgExt) add_subdirectory(TritonToAnnotation) add_subdirectory(TritonToHIVM) add_subdirectory(TritonToLinalg) diff --git a/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt new file mode 100644 index 0000000..78a9b77 --- /dev/null +++ b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_dialect(LinalgExtOps linalg_ext + TD_FILE ${CMAKE_SOURCE_DIR}/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td) +add_mlir_doc(LinalgExtOps LinalgExtDialect LinalgExt/ -gen-op-doc) diff --git a/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp b/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp new file mode 100644 index 0000000..5456283 --- /dev/null +++ b/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp @@ -0,0 +1,15 @@ +#include "LinalgExt/LinalgExtDialect.h" +#include "LinalgExt/LinalgExtOps.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace mlir; +using namespace linalg_ext; + +#include "LinalgExt/LinalgExtOpsDialect.cpp.inc" + +void LinalgExtDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "LinalgExt/LinalgExtOps.cpp.inc" + >(); +} diff --git a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp index 397fe5c..0cadf8c 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -1285,25 +1285,17 @@ void BlockDataParser::rewriteForOp( void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, ConversionPatternRewriter &rewriter, BlockData &data) { - auto loc = op.getLoc(); - - // The addptr result must feed exactly one load so we can collapse the pair - Value addPtrRes = op.getResult(); - assert(addPtrRes.hasOneUse() && - "tt.addptr used multiple times – gather conversion requires a single " - "tt.load user"); + auto addPtrRes = op.getResult(); + assert(addPtrRes.hasOneUse() && "tt.addptr has multiple users"); auto loadOp = cast(*addPtrRes.user_begin()); - // We keep the base pointer and the index tensor separately; the load will be - // rewritten to a linalg.gather later in LoadConverter. Value basePtr = adaptor.getPtr(); Value indexTnsr = adaptor.getOffset(); - // Mark the load so the later converter knows to emit a gather. + // Tag the consumer load so LoadStoreConverter emits a gather. loadOp->setAttr("IndirectLoad", rewriter.getUnitAttr()); - // Replace the tt.addptr result with the original pointer; thread the index - // tensor into the load (operand #1 is unused in the unstructured case). + // Thread the base pointer and index tensor directly. rewriter.replaceOp(op, basePtr); loadOp->setOperand(0, basePtr); loadOp->setOperand(1, indexTnsr); diff --git a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp index 0b08d2f..a3c20b7 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp @@ -5,6 +5,7 @@ // //===----------------------------------------------------------------------===// +#include "LinalgExt/LinalgExtOps.h" #include "TritonToLinalg/LoadStoreConverter.h" #include "TritonToLinalg/BlockPtrAnalysis.h" #include "TritonToLinalg/MaskAnalysis.h" @@ -93,51 +94,6 @@ LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const { return success(); } -/// @brief Continue to modify the triton::LoadOp from the state modified by the -/// AddPtrConverter. -/// @param op The triton::LoadOp operation to be processed. -/// @param adaptor The adaptor for the operation, used to obtain operands. -/// @param rewriter The pattern rewriter used to rewrite the operation. -/// @return Return success if the operation is successful; otherwise, return -/// failure. -LogicalResult LoadConverter::continueModifyFromAddPtrConverter( - triton::LoadOp &op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // Only loads tagged by BlockPtrAnalysis are handled here. - if (!op->hasAttr("IndirectLoad")) - return failure(); - - auto loc = op.getLoc(); - Value base = adaptor.getPtr(); // operand #0: base pointer - Value indices = op.getOperand(1); // operand #1: index tensor - - // Result type of the gather (same as the original load result). - Type resultTy = op.getResult().getType(); - - // Build a linalg.generic that gathers elements at "indices" from "base". - auto maps = SmallVector{ - // map for base tensor - rewriter.getMultiDimIdentityMap(1), - // map for indices tensor - rewriter.getMultiDimIdentityMap(1)}; - - auto iterTypes = SmallVector{getParallelIteratorTypeName()}; - - auto gather = rewriter.create( - loc, resultTy, - /*inputs=*/ValueRange{base, indices}, - /*outputs=*/ValueRange{}, - maps, iterTypes, - [&](OpBuilder &b, Location l, ValueRange args) { - // args[0] – element from base; args[1] – index - Value elem = b.create(l, args[0], args[1]); - b.create(l, elem); - }); - - // Replace the original load with the gather op. - rewriter.replaceOp(op, gather.getResult(0)); - return success(); -} void LoadConverter::fillTensorWithOtherForMaskScenario( Value other, memref::AllocOp localMem, ArrayRef maskDim, @@ -186,9 +142,19 @@ LogicalResult LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // Check if tt.load is modified by AddPtrConverter to a specified state. - if (checkModifiedByAddPtrConverter(op).succeeded()) { - return continueModifyFromAddPtrConverter(op, adaptor, rewriter); + // Indirect loads (load → load chains) become linalg_ext.gathermask. + if (op->hasAttr("IndirectLoad")) { + auto loc = op.getLoc(); + Value base = adaptor.getPtr(); + Value indices = op.getOperand(1); + Value mask = op.getMask(); + Value other = op.getOther(); + auto resultTy = cast(op.getType()); + + auto gather = rewriter.create( + loc, resultTy, base, indices, mask, other); + rewriter.replaceOp(op, gather.getResult()); + return success(); } auto ptr = adaptor.getPtr(); diff --git a/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp b/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp index 1b44a2c..b42dece 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp @@ -1,3 +1,5 @@ +#include "LinalgExt/LinalgExtDialect.h" +#include "LinalgExt/LinalgExtOps.h" #include "TritonToLinalg/TritonToLinalgPass.h" #include "TritonToLinalg/ArgMinMaxConverter.h" #include "TritonToLinalg/FunctionConverter.h" @@ -355,10 +357,11 @@ void TritonToLinalgPass::addDynamicLegal( linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, cf::ControlFlowDialect, tensor::TensorDialect, LLVM::LLVMDialect, bufferization::BufferizationDialect, memref::MemRefDialect, + linalg_ext::LinalgExtDialect, annotation::AnnotationDialect, hivm::HIVMDialect>(); // add legal dialect on condition - target.addLegalOp(); + target.addLegalOp(); // 根据条件判断需要转换的OP target.addDynamicallyLegalOp( @@ -541,7 +544,7 @@ void TritonToLinalgPass::getDependentDialects(DialectRegistry ®istry) const { registry.insert(); + memref::MemRefDialect, linalg_ext::LinalgExtDialect>(); } LogicalResult TritonToLinalgPass::processDescriptorOperations(ModuleOp moduleOp) -- Gitee From a973253bf03c66d85caae7ad387617b6e02c85a1 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 17:41:33 +0800 Subject: [PATCH 4/9] Cmake works but build fails. --- .../include/LinalgExt/CMakeLists.txt | 1 - .../include/LinalgExt/LinalgExtOps.td | 2 +- .../triton-adapter/lib/LinalgExt/CMakeLists.txt | 16 +++++++++++++--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt b/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt index 98a3100..1a98590 100644 --- a/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt +++ b/ascend/triton-adapter/include/LinalgExt/CMakeLists.txt @@ -1,4 +1,3 @@ -# ascend/triton-adapter/include/LinalgExt/CMakeLists.txt set(LLVM_TARGET_DEFINITIONS LinalgExtOps.td) mlir_tablegen(LinalgExtOps.h.inc -gen-op-decls) mlir_tablegen(LinalgExtOps.cpp.inc -gen-op-defs) diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td index 2e88496..2c96a3b 100644 --- a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td @@ -11,7 +11,7 @@ def LinalgExt_Dialect : Dialect { let cppNamespace = "::linalg_ext"; } -class LinalgExt_Op traits = []> +class LinalgExt_Op traits = []> : Op; def GathermaskOp : LinalgExt_Op<"gathermask", [NoSideEffect]> { diff --git a/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt index 78a9b77..a3a01af 100644 --- a/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt +++ b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt @@ -1,3 +1,13 @@ -add_mlir_dialect(LinalgExtOps linalg_ext - TD_FILE ${CMAKE_SOURCE_DIR}/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td) -add_mlir_doc(LinalgExtOps LinalgExtDialect LinalgExt/ -gen-op-doc) +add_mlir_dialect_library(MLIRLinalgExt + LinalgExtDialect.cpp + + DEPENDS + LinalgExtOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR +) + +# Docs: point TD_FILE to the include location +add_mlir_doc(LinalgExtOps LinalgExtDialect LinalgExt/ -gen-op-doc + TD_FILE ${CMAKE_SOURCE_DIR}/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td) \ No newline at end of file -- Gitee From a8ac087c82c73554d5605fa69b6f37863247fd39 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 18:43:32 +0800 Subject: [PATCH 5/9] Compile but does not link. --- .../include/LinalgExt/LinalgExtDialect.h | 1 + .../include/LinalgExt/LinalgExtOps.h | 8 +++++ .../include/LinalgExt/LinalgExtOps.td | 33 +++++++++++-------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h b/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h index 2fe2113..b42bffb 100644 --- a/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtDialect.h @@ -2,6 +2,7 @@ #define LINALGEXT_LINALGEXTDIALECT_H #include "mlir/IR/Dialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "LinalgExt/LinalgExtOpsDialect.h.inc" diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h index 0fb7727..3d2ebc5 100644 --- a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.h @@ -1,8 +1,16 @@ #ifndef LINALGEXT_LINALGEXTOPS_H #define LINALGEXT_LINALGEXTOPS_H + #include "LinalgExt/LinalgExtDialect.h" + #include "mlir/IR/OpImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #define GET_OP_CLASSES #include "LinalgExt/LinalgExtOps.h.inc" diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td index 2c96a3b..4fd0951 100644 --- a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td @@ -2,19 +2,24 @@ #define LINALG_EXT_OPS_TD include "mlir/IR/OpBase.td" -include "mlir/IR/BuiltinOps.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/Dialect/Linalg/IR/LinalgOps.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" + def LinalgExt_Dialect : Dialect { let name = "linalg_ext"; - let cppNamespace = "::linalg_ext"; + let cppNamespace = "::mlir::linalg_ext"; } -class LinalgExt_Op traits = []> +class LinalgExt_Op traits = []> : Op; -def GathermaskOp : LinalgExt_Op<"gathermask", [NoSideEffect]> { +def GathermaskOp + : LinalgExt_Op<"gathermask", [Pure, AttrSizedOperandSegments]> { let summary = "Gather with optional mask/other"; let description = [{ Gathers elements from `base` at `indices`. If a mask is supplied and the @@ -27,16 +32,18 @@ def GathermaskOp : LinalgExt_Op<"gathermask", [NoSideEffect]> { Optional:$mask, Optional:$other ); - let results = (outs AnyTensor:$result); - let assemblyFormat = [{ - $base `,` $indices (`,` $mask `,` $other^)? attr-dict - `:` type($base) `,` type($indices) `->` type($result) - }]; + let results = (outs AnyTensor:$result); let builders = [ - OpBuilder<(ins "Type":$resultTy, "Value":$base, "Value":$indices, - "Value":$mask = "Value()", "Value":$other = "Value()")> + OpBuilder<(ins + "::mlir::Type":$resultTy, + "::mlir::Value":$base, + "::mlir::Value":$indices, + CArg<"::mlir::Value","::mlir::Value()">:$mask, + CArg<"::mlir::Value","::mlir::Value()">:$other + )> ]; + } #endif // LINALG_EXT_OPS_TD -- Gitee From 7e1124c9ae4e01f5851f1318bba08720f866ecfa Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 18:57:34 +0800 Subject: [PATCH 6/9] The linker is still failing. --- ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt | 10 +++++++++- .../triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp | 2 +- ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp | 5 +++++ .../triton-adapter/lib/TritonToLinalg/CMakeLists.txt | 1 + 4 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp diff --git a/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt index a3a01af..b0009bd 100644 --- a/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt +++ b/ascend/triton-adapter/lib/LinalgExt/CMakeLists.txt @@ -1,13 +1,21 @@ add_mlir_dialect_library(MLIRLinalgExt LinalgExtDialect.cpp + LinalgExtOps.cpp DEPENDS LinalgExtOpsIncGen LINK_LIBS PUBLIC MLIRIR + MLIRSideEffectInterfaces + MLIRLinalgDialect + MLIRTilingInterface +) + +target_include_directories(MLIRLinalgExt PUBLIC + ${CMAKE_SOURCE_DIR}/ascend/triton-adapter/include + ${CMAKE_BINARY_DIR}/ascend/triton-adapter/include ) -# Docs: point TD_FILE to the include location add_mlir_doc(LinalgExtOps LinalgExtDialect LinalgExt/ -gen-op-doc TD_FILE ${CMAKE_SOURCE_DIR}/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td) \ No newline at end of file diff --git a/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp b/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp index 5456283..a63ccf4 100644 --- a/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp +++ b/ascend/triton-adapter/lib/LinalgExt/LinalgExtDialect.cpp @@ -3,7 +3,7 @@ #include "mlir/IR/DialectImplementation.h" using namespace mlir; -using namespace linalg_ext; +using namespace mlir::linalg_ext; #include "LinalgExt/LinalgExtOpsDialect.cpp.inc" diff --git a/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp b/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp new file mode 100644 index 0000000..d7a4a6c --- /dev/null +++ b/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp @@ -0,0 +1,5 @@ +#include "LinalgExt/LinalgExtOps.h" +#include "mlir/IR/Builders.h" + +#define GET_OP_CLASSES +#include "LinalgExt/LinalgExtOps.cpp.inc" \ No newline at end of file diff --git a/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt b/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt index 5730838..77bbd25 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt +++ b/ascend/triton-adapter/lib/TritonToLinalg/CMakeLists.txt @@ -15,6 +15,7 @@ add_triton_library(TritonToLinalg LINK_LIBS PUBLIC MLIRArithDialect MLIRDialectUtils + MLIRLinalgExt MLIRIR MLIRMathDialect MLIRPass -- Gitee From 9dd13609ca00b8ffd7eb1daa54a21ee69cff073a Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Fri, 29 Aug 2025 19:00:30 +0800 Subject: [PATCH 7/9] Adding Type and Value build instead of *Range --- ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp b/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp index d7a4a6c..c8d9f49 100644 --- a/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp +++ b/ascend/triton-adapter/lib/LinalgExt/LinalgExtOps.cpp @@ -2,4 +2,13 @@ #include "mlir/IR/Builders.h" #define GET_OP_CLASSES -#include "LinalgExt/LinalgExtOps.cpp.inc" \ No newline at end of file +#include "LinalgExt/LinalgExtOps.cpp.inc" + +using namespace mlir; +using namespace mlir::linalg_ext; + +void GathermaskOp::build(OpBuilder &b, OperationState &st, + Type resultTy, Value base, Value indices, + Value mask, Value other) { + build(b, st, TypeRange{resultTy}, base, indices, mask, other); +} \ No newline at end of file -- Gitee From db2613f9999d237ecb16d44f5e8d2767a40de74d Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Mon, 1 Sep 2025 04:27:53 +0800 Subject: [PATCH 8/9] Trying to convert base address generation in gathermask. --- .../include/LinalgExt/LinalgExtOps.td | 2 +- .../lib/TritonToLinalg/BlockPtrAnalysis.cpp | 58 ++++++++++++++++--- .../lib/TritonToLinalg/LoadStoreConverter.cpp | 16 ----- 3 files changed, 51 insertions(+), 25 deletions(-) diff --git a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td index 4fd0951..53abb67 100644 --- a/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td +++ b/ascend/triton-adapter/include/LinalgExt/LinalgExtOps.td @@ -27,7 +27,7 @@ def GathermaskOp }]; let arguments = (ins - AnyTensor:$base, + AnyMemRef:$base, AnyTensor:$indices, Optional:$mask, Optional:$other diff --git a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp index 0cadf8c..b40ff7e 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/BlockPtrAnalysis.cpp @@ -6,6 +6,7 @@ //===----------------------------------------------------------------------===// #include "TritonToLinalg/BlockPtrAnalysis.h" #include "Utils/Utils.h" +#include "LinalgExt/LinalgExtOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -1285,20 +1286,61 @@ void BlockDataParser::rewriteForOp( void BlockDataParser::rewriteAddPtrToUnstrucMemAcc( triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor, ConversionPatternRewriter &rewriter, BlockData &data) { + auto addPtrRes = op.getResult(); assert(addPtrRes.hasOneUse() && "tt.addptr has multiple users"); auto loadOp = cast(*addPtrRes.user_begin()); + auto loc = op.getLoc(); - Value basePtr = adaptor.getPtr(); - Value indexTnsr = adaptor.getOffset(); + // 1) Result type (of the load) + auto resTy = cast(loadOp.getType()); + auto elemTy = resTy.getElementType(); + + // 2) Compute linear block size N = product(blockSizes) + auto &blockSizes = data.getSizesRef(); + SmallVector dynSizes; + dynSizes.reserve(blockSizes.size()); + for (auto s : blockSizes) + dynSizes.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, s)); + + // multiply all sizes -> 1D length + Value len = dynSizes.front(); + for (size_t i = 1; i < dynSizes.size(); ++i) + len = rewriter.create(loc, len, dynSizes[i]); + + // 3) Program BlockData to a single 1-D contiguous view starting at offset 0 + auto &offs = data.getOffsetsRef(); + auto &sizes = data.getSizesRef(); + auto &strides = data.getStridesRef(); + offs.clear(); + offs.push_back(rewriter.getIndexAttr(0)); // OpFoldResult := Attribute + sizes.clear(); + sizes.push_back(len); // dynamic length + strides.clear(); + strides.push_back(rewriter.getIndexAttr(1)); + + // 4) Create memref view: memref> (base for gather) + memref::ReinterpretCastOp baseView = data.createCastOp({1}, loc, rewriter); + + // 5) Wire up gather (single op), no loops, no pointer tensors involved. + Value indices = adaptor.getOffset(); // tensor<...xi32> (same shape as result) + Value mask = loadOp.getMask(); // tensor<...xi1> or null + Value other = loadOp.getOther(); // tensor<...xT> or null + + // If the op requires an explicit mask, synthesize all-true when absent. + if (!mask) { + auto mTy = RankedTensorType::get(resTy.getShape(), rewriter.getI1Type()); + Value c1 = rewriter.create(loc, rewriter.getBoolAttr(true)); + Value temp = rewriter.create(loc, resTy.getShape(), mTy.getElementType()); + mask = rewriter.create(loc, c1, temp).getResult(0); + } - // Tag the consumer load so LoadStoreConverter emits a gather. - loadOp->setAttr("IndirectLoad", rewriter.getUnitAttr()); + auto gather = rewriter.create( + loc, resTy, /*base=*/baseView, /*indices=*/indices, /*mask=*/mask, /*other=*/other); - // Thread the base pointer and index tensor directly. - rewriter.replaceOp(op, basePtr); - loadOp->setOperand(0, basePtr); - loadOp->setOperand(1, indexTnsr); + // 6) Replace the consumer and erase the addptr. + rewriter.replaceOp(loadOp, gather.getResult()); + rewriter.eraseOp(op); } } // namespace triton diff --git a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp index a3c20b7..171087a 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp @@ -5,7 +5,6 @@ // //===----------------------------------------------------------------------===// -#include "LinalgExt/LinalgExtOps.h" #include "TritonToLinalg/LoadStoreConverter.h" #include "TritonToLinalg/BlockPtrAnalysis.h" #include "TritonToLinalg/MaskAnalysis.h" @@ -142,21 +141,6 @@ LogicalResult LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // Indirect loads (load → load chains) become linalg_ext.gathermask. - if (op->hasAttr("IndirectLoad")) { - auto loc = op.getLoc(); - Value base = adaptor.getPtr(); - Value indices = op.getOperand(1); - Value mask = op.getMask(); - Value other = op.getOther(); - auto resultTy = cast(op.getType()); - - auto gather = rewriter.create( - loc, resultTy, base, indices, mask, other); - rewriter.replaceOp(op, gather.getResult()); - return success(); - } - auto ptr = adaptor.getPtr(); auto mask = op.getMask(); auto other = op.getOther(); -- Gitee From a6eeb5733bd03924cec0a7e19a8a86eee92057b8 Mon Sep 17 00:00:00 2001 From: Mario Drumond m00933363 Date: Mon, 1 Sep 2025 15:47:10 +0800 Subject: [PATCH 9/9] Fixing tests. --- .../TritonToLinalg/indirect_load.mlir | 81 +++++++++++-------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir b/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir index 54c9947..e022451 100644 --- a/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir +++ b/ascend/test/Conversion/TritonToLinalg/indirect_load.mlir @@ -43,41 +43,56 @@ module { tt.return } } - +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @indirect_load( // CHECK-SAME: %[[VAL_0:.*]]: memref, %[[VAL_1:.*]]: memref, %[[VAL_2:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %[[VAL_3:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 0 : i32}, %[[VAL_4:.*]]: memref {tt.divisibility = 16 : i32, tt.tensor_kind = 1 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: i32, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i32) attributes {SyncBlockLockArgIdx = 0 : i64, WorkspaceArgIdx = 1 : i64, global_kernel = "", mix_mode = "aiv"} { -// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_13:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_14:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_15:.*]] = arith.constant 256 : i32 -// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_17:.*]] = tensor.empty() : tensor<256xf32> -// CHECK: %[[VAL_18:.*]] = arith.muli %[[VAL_9]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_18]] : i32 to index -// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_3]] to offset: {{\[}}%[[VAL_19]]], sizes: [256], strides: [1] : memref to memref<256xi32, strided<[1], offset: ?>> -// CHECK: %[[VAL_21:.*]] = memref.alloc() : memref<256xi32> -// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_14]] : index -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_5]] : i32 to index -// CHECK: %[[VAL_24:.*]] = arith.maxsi %[[VAL_19]], %[[VAL_23]] : index -// CHECK: %[[VAL_25:.*]] = arith.minsi %[[VAL_22]], %[[VAL_24]] : index -// CHECK: %[[VAL_26:.*]] = arith.subi %[[VAL_25]], %[[VAL_19]] : index -// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_26]], %[[VAL_14]] : index -// CHECK: scf.if %[[VAL_27]] { -// CHECK: linalg.fill ins(%[[VAL_16]] : i32) outs(%[[VAL_21]] : memref<256xi32>) -// CHECK: } -// CHECK: %[[VAL_28:.*]] = memref.subview %[[VAL_20]][0] {{\[}}%[[VAL_26]]] [1] : memref<256xi32, strided<[1], offset: ?>> to memref> -// CHECK: %[[VAL_29:.*]] = memref.subview %[[VAL_21]][0] {{\[}}%[[VAL_26]]] [1] : memref<256xi32> to memref> -// CHECK: memref.copy %[[VAL_28]], %[[VAL_29]] : memref> to memref> -// CHECK: %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_21]] restrict writable : memref<256xi32> -// CHECK: %[[VAL_31:.*]] = scf.for %[[VAL_32:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_12]] iter_args(%[[VAL_33:.*]] = %[[VAL_17]]) -> (tensor<256xf32>) { -// CHECK: %[[VAL_34:.*]] = tensor.extract %[[VAL_30]]{{\[}}%[[VAL_32]]] : tensor<256xi32> -// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : i32 to index -// CHECK: %[[VAL_36:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_35]]], sizes: [1], strides: [1] : memref to memref<1xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_13]]] : memref<1xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_38:.*]] = tensor.insert %[[VAL_37]] into %[[VAL_33]]{{\[}}%[[VAL_32]]] : tensor<256xf32> -// CHECK: scf.yield %[[VAL_38]] : tensor<256xf32> +// CHECK: %[[VAL_12:.*]] = arith.constant 256 : index +// CHECK: %[[VAL_13:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_16:.*]] = tensor.empty() : tensor<256xf32> +// CHECK: %[[VAL_17:.*]] = linalg.fill ins(%[[VAL_15]] : f32) outs(%[[VAL_16]] : tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[VAL_18:.*]] = tensor.empty() : tensor<256xi32> +// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_9]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]]], iterator_types = ["parallel"]} outs(%[[VAL_18]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_21:.*]]: i32): +// CHECK: %[[VAL_22:.*]] = linalg.index 0 : index +// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32 +// CHECK: linalg.yield %[[VAL_23]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_24:.*]] = linalg.fill ins(%[[VAL_19]] : i32) outs(%[[VAL_18]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_25:.*]] = tensor.empty() : tensor<256xi1> +// CHECK: %[[VAL_26:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_24]], %[[VAL_20]] : tensor<256xi32>, tensor<256xi32>) outs(%[[VAL_24]] : tensor<256xi32>) { +// CHECK: ^bb0(%[[VAL_27:.*]]: i32, %[[VAL_28:.*]]: i32, %[[VAL_29:.*]]: i32): +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] {MixUse} : i32 +// CHECK: linalg.yield %[[VAL_30]] : i32 +// CHECK: } -> tensor<256xi32> +// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_5]] : i32) outs(%[[VAL_18]] : tensor<256xi32>) -> tensor<256xi32> +// CHECK: %[[VAL_32:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel"]} ins(%[[VAL_26]], %[[VAL_31]] : tensor<256xi32>, tensor<256xi32>) outs(%[[VAL_25]] : tensor<256xi1>) { +// CHECK: ^bb0(%[[VAL_33:.*]]: i32, %[[VAL_34:.*]]: i32, %[[VAL_35:.*]]: i1): +// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_33]], %[[VAL_34]] {MixUse} : i32 +// CHECK: linalg.yield %[[VAL_36]] : i1 +// CHECK: } -> tensor<256xi1> +// CHECK: %[[VAL_37:.*]] = arith.index_cast %[[VAL_19]] : i32 to index +// CHECK: %[[VAL_38:.*]] = memref.reinterpret_cast %[[VAL_3]] to offset: {{\[}}%[[VAL_37]]], sizes: [256], strides: [1] : memref to memref<256xi32, strided<[1], offset: ?>> +// CHECK: %[[VAL_39:.*]] = memref.alloc() : memref<256xi32> +// CHECK: %[[VAL_40:.*]] = arith.addi %[[VAL_37]], %[[VAL_12]] : index +// CHECK: %[[VAL_41:.*]] = arith.index_cast %[[VAL_5]] : i32 to index +// CHECK: %[[VAL_42:.*]] = arith.maxsi %[[VAL_37]], %[[VAL_41]] : index +// CHECK: %[[VAL_43:.*]] = arith.minsi %[[VAL_40]], %[[VAL_42]] : index +// CHECK: %[[VAL_44:.*]] = arith.subi %[[VAL_43]], %[[VAL_37]] : index +// CHECK: %[[VAL_45:.*]] = arith.cmpi slt, %[[VAL_44]], %[[VAL_12]] : index +// CHECK: scf.if %[[VAL_45]] { +// CHECK: linalg.fill ins(%[[VAL_14]] : i32) outs(%[[VAL_39]] : memref<256xi32>) // CHECK: } -// CHECK: %[[VAL_39:.*]] = memref.reinterpret_cast %[[VAL_4]] to offset: {{\[}}%[[VAL_19]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_31]] in writable %[[VAL_39]] : (tensor<256xf32>, memref<256xf32, strided<[1], offset: ?>>) -> () +// CHECK: %[[VAL_46:.*]] = memref.subview %[[VAL_38]][0] {{\[}}%[[VAL_44]]] [1] : memref<256xi32, strided<[1], offset: ?>> to memref> +// CHECK: %[[VAL_47:.*]] = memref.subview %[[VAL_39]][0] {{\[}}%[[VAL_44]]] [1] : memref<256xi32> to memref> +// CHECK: memref.copy %[[VAL_46]], %[[VAL_47]] : memref> to memref> +// CHECK: %[[VAL_48:.*]] = bufferization.to_tensor %[[VAL_39]] restrict writable : memref<256xi32> +// CHECK: %[[VAL_49:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: {{\[}}%[[VAL_12]]], strides: [1] : memref to memref<1xf32, strided<[1]>> +// CHECK: %[[VAL_50:.*]] = "linalg_ext.gathermask"(%[[VAL_49]], %[[VAL_48]], %[[VAL_32]], %[[VAL_17]]) <{operandSegmentSizes = array}> : (memref<1xf32, strided<[1]>>, tensor<256xi32>, tensor<256xi1>, tensor<256xf32>) -> tensor<256xf32> +// CHECK: %[[VAL_51:.*]] = memref.reinterpret_cast %[[VAL_4]] to offset: {{\[}}%[[VAL_37]]], sizes: [256], strides: [1] : memref to memref<256xf32, strided<[1], offset: ?>> +// CHECK: bufferization.materialize_in_destination %[[VAL_50]] in writable %[[VAL_51]] : (tensor<256xf32>, memref<256xf32, strided<[1], offset: ?>>) -> () // CHECK: return // CHECK: } + -- Gitee