diff --git a/ascend/examples/pytest_ut/test_ldst.py b/ascend/examples/pytest_ut/test_ldst.py index ed116a914259360cd18a9cae9c87fd820db95bc9..6d419781bcac24e76b6c0271774a829d491c8582 100644 --- a/ascend/examples/pytest_ut/test_ldst.py +++ b/ascend/examples/pytest_ut/test_ldst.py @@ -405,6 +405,53 @@ def test_ldst_indirect_08(): triton_cal = triton_ldst_indirect_08_func(xc, x2, blocksize, lowdimsize) torch.testing.assert_close(triton_cal, torch_ref) +def test_ldst_indirect_09(): + + @triton.jit + def triton_ldst_indirect_09_kernel( + out_ptr0, in_ptr1, in_ptr2, stride_in_r, + offset: tl.constexpr, XS: tl.constexpr, RS: tl.constexpr + ): + pid = tl.program_id(0) + in_idx0 = tl.arange(0, XS) + in_idx1 = tl.arange(0, RS) + tmp0 = pid * XS + tl.load(in_ptr1 + in_idx0) + tmp1 = tl.arange(0, RS) + offset + in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :] + tmp2 = tl.load(in_ptr2 + in_idx2) + tmp2 = tl_math.exp(tmp2) + out0_idx = pid * XS * RS + in_idx0[:, None] * RS + in_idx1[None, :] + tl.store(out_ptr0 + out0_idx, tmp2) + + def triton_ldst_indirect_09_func(xr, x2, offset, xs, rs): + nr = xr.numel() + nc = rs + stride_in_r = x2.stride()[0] + y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device) + triton_ldst_indirect_09_kernel[nr // xs, 1, 1]( + y0, xr, x2, stride_in_r, offset = offset, XS = xs, RS = rs) + return y0 + + def torch_ldst_indirect_09_func(xr, xc, x2): + flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten() + extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()]) + return torch.exp(extracted) + + DEV = "npu" + DTYPE = torch.float32 + offset = 8 + N0, N1 = 16, 32 + blocksize = 8 + lowdimsize = N0 + assert N1 >= N0+offset, "N1 must be >= N0+offset" + assert N0 == lowdimsize, "N0 must be == lowdimsize" + xc = offset + torch.arange(0, N0, device=DEV) + xr = torch.arange(0, blocksize, device=DEV) + x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV) + torch_ref = torch_ldst_indirect_09_func(xr, xc, x2) + triton_cal = triton_ldst_indirect_09_func(xr, x2, offset, blocksize, lowdimsize) + torch.testing.assert_close(triton_cal, torch_ref) + if __name__ == "__main__": test_ldst_indirect_05() print("success: test_ldst_indirect_05") \ No newline at end of file diff --git a/ascend/test/Conversion/TritonToUnstructure/unstructure_mix.mlir b/ascend/test/Conversion/TritonToUnstructure/unstructure_mix.mlir new file mode 100644 index 0000000000000000000000000000000000000000..57189868ddcb081b955e64dc8ace1f9d0d15665c --- /dev/null +++ b/ascend/test/Conversion/TritonToUnstructure/unstructure_mix.mlir @@ -0,0 +1,83 @@ +// RUN: triton-adapter-opt --triton-to-unstructure %s | FileCheck %s + +tt.func public @indirect_mix_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<16> : tensor<1x8xi32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c8_i32 : i32 + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %3 = tt.splat %1 : i32 -> tensor<8xi32> + %4 = arith.addi %3, %2 : tensor<8xi32> + %5 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr> + %7 = tt.addptr %6, %5 : tensor<16x!tt.ptr>, tensor<16xi32> + %8 = tt.load %7 : tensor<16x!tt.ptr> + %9 = tt.expand_dims %2 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %10 = tt.splat %arg3 : i32 -> tensor<1x8xi32> + %11 = arith.muli %9, %10 : tensor<1x8xi32> + %12 = tt.expand_dims %8 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> + %13 = arith.extsi %11 : tensor<1x8xi32> to tensor<1x8xi64> + %14 = tt.broadcast %13 : tensor<1x8xi64> -> tensor<16x8xi64> + %15 = tt.broadcast %12 : tensor<16x1xi64> -> tensor<16x8xi64> + %16 = arith.addi %14, %15 : tensor<16x8xi64> + %17 = tt.splat %arg2 : !tt.ptr -> tensor<16x8x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<16x8x!tt.ptr>, tensor<16x8xi64> + %19 = tt.load %18 : tensor<16x8x!tt.ptr> + %20 = math.exp %19 : tensor<16x8xf32> + %21 = tt.expand_dims %4 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %22 = arith.muli %21, %cst : tensor<1x8xi32> + %23 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> + %24 = tt.broadcast %22 : tensor<1x8xi32> -> tensor<16x8xi32> + %25 = tt.broadcast %23 : tensor<16x1xi32> -> tensor<16x8xi32> + %26 = arith.addi %24, %25 : tensor<16x8xi32> + %27 = tt.splat %arg0 : !tt.ptr -> tensor<16x8x!tt.ptr> + %28 = tt.addptr %27, %26 : tensor<16x8x!tt.ptr>, tensor<16x8xi32> + tt.store %28, %20 : tensor<16x8x!tt.ptr> + tt.return +} + +// CHECK-LABEL: tt.func public @indirect_mix_kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_7:.*]] = arith.constant dense<16> : tensor<1x8xi32> +// CHECK: %[[VAL_8:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_9:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_11:.*]] = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> +// CHECK: %[[VAL_12:.*]] = tt.splat %[[VAL_10]] : i32 -> tensor<8xi32> +// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_11]] : tensor<8xi32> +// CHECK: %[[VAL_14:.*]] = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> +// CHECK: %[[VAL_15:.*]] = tt.splat %[[VAL_1]] : !tt.ptr -> tensor<16x!tt.ptr> +// CHECK: %[[VAL_16:.*]] = tt.addptr %[[VAL_15]], %[[VAL_14]] : tensor<16x!tt.ptr>, tensor<16xi32> +// CHECK: %[[VAL_17:.*]] = tt.load %[[VAL_16]] : tensor<16x!tt.ptr> +// CHECK: %[[VAL_18:.*]] = tt.expand_dims %[[VAL_11]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_19:.*]] = tt.splat %[[VAL_3]] : i32 -> tensor<1x8xi32> +// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : tensor<1x8xi32> +// CHECK: %[[VAL_21:.*]] = tt.expand_dims %[[VAL_17]] {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64> +// CHECK: %[[VAL_22:.*]] = arith.extsi %[[VAL_20]] : tensor<1x8xi32> to tensor<1x8xi64> +// CHECK: %[[VAL_23:.*]] = tt.broadcast %[[VAL_22]] : tensor<1x8xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_24:.*]] = tt.broadcast %[[VAL_21]] : tensor<16x1xi64> -> tensor<16x8xi64> +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : tensor<16x8xi64> +// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<16x8xf32> +// CHECK: %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_26]]) -> (tensor<16x8xf32>) { +// CHECK: %[[VAL_30:.*]] = tensor.extract_slice %[[VAL_25]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] {DiscreteMemAccess} : tensor<16x8xi64> to tensor<1x8xi64> +// CHECK: %[[VAL_31:.*]] = tt.splat %[[VAL_2]] : !tt.ptr -> tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_32:.*]] = tt.addptr %[[VAL_31]], %[[VAL_30]] : tensor<1x8x!tt.ptr>, tensor<1x8xi64> +// CHECK: %[[VAL_33:.*]] = tt.load %[[VAL_32]] {DiscreteMemAccess} : tensor<1x8x!tt.ptr> +// CHECK: %[[VAL_34:.*]] = tensor.insert_slice %[[VAL_33]] into %[[VAL_29]]{{\[}}%[[VAL_28]], 0] [1, 8] [1, 1] : tensor<1x8xf32> into tensor<16x8xf32> +// CHECK: scf.yield {DiscreteMemAccess} %[[VAL_34]] : tensor<16x8xf32> +// CHECK: } {ExtractedLoadOrStore} +// CHECK: %[[VAL_35:.*]] = math.exp %[[VAL_27]] : tensor<16x8xf32> +// CHECK: %[[VAL_36:.*]] = tt.expand_dims %[[VAL_13]] {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> +// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_36]], %[[VAL_7]] : tensor<1x8xi32> +// CHECK: %[[VAL_38:.*]] = tt.expand_dims %[[VAL_14]] {axis = 1 : i32} : tensor<16xi32> -> tensor<16x1xi32> +// CHECK: %[[VAL_39:.*]] = tt.broadcast %[[VAL_37]] : tensor<1x8xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_40:.*]] = tt.broadcast %[[VAL_38]] : tensor<16x1xi32> -> tensor<16x8xi32> +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_39]], %[[VAL_40]] : tensor<16x8xi32> +// CHECK: %[[VAL_42:.*]] = tt.splat %[[VAL_0]] : !tt.ptr -> tensor<16x8x!tt.ptr> +// CHECK: %[[VAL_43:.*]] = tt.addptr %[[VAL_42]], %[[VAL_41]] : tensor<16x8x!tt.ptr>, tensor<16x8xi32> +// CHECK: tt.store %[[VAL_43]], %[[VAL_35]] : tensor<16x8x!tt.ptr> +// CHECK: tt.return +// CHECK: } \ No newline at end of file diff --git a/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h b/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h index ee975662a41cf5bda31f30e1062670625a84e5e1..bca4c61c91d83dec6f4df969eb5d72cea584ab09 100644 --- a/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h +++ b/ascend/triton-adapter/include/TritonToUnstructure/BubbleUpOperation.h @@ -6,10 +6,10 @@ #include "mlir/IR/PatternMatch.h" #define GEN_PASS_DECL_BUBBLEUPOPERATION -#include "../../include/TritonToUnstructure/Passes.h.inc" +#include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" #define GEN_PASS_DEF_BUBBLEUPOPERATION -#include "../../include/TritonToUnstructure/Passes.h.inc" +#include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" namespace mlir { namespace triton { @@ -23,102 +23,59 @@ createBubbleUpOperationPass(const BubbleUpOperationOptions &options = {}); using namespace mlir; using namespace triton; -class BubbleUpExtract : public OpRewritePattern { +template +class BubbleUpExtract : public OpRewritePattern { + static_assert(std::is_same_v || + std::is_same_v); + public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; explicit BubbleUpExtract(MLIRContext *context, bool enableAggressiveMode); - LogicalResult matchAndRewrite(tensor::ExtractOp op, + LogicalResult matchAndRewrite(ExtractOpTy op, PatternRewriter &rewriter) const override; private: - Value createExtractOp(Value value, ArrayRef indices, Location loc, + Value createExtractOp(ExtractOpTy op, Value value, Location loc, PatternRewriter &rewriter) const; template - void bubbleUpIntBinaryOp(Operation *op, BinOpTy binOp, - ArrayRef indices, Location loc, + void bubbleUpIntBinaryOp(ExtractOpTy op, BinOpTy binOp, Location loc, PatternRewriter &rewriter) const; template - void bubbleUpFloatBinaryOp(Operation *op, BinOpTy binOp, - ArrayRef indices, Location loc, + void bubbleUpFloatBinaryOp(ExtractOpTy op, BinOpTy binOp, Location loc, PatternRewriter &rewriter) const; - template - void bubbleUpOperation(Operation *op, ParentOpTy parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const = delete; - - template <> - void bubbleUpOperation(Operation *op, arith::ExtSIOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - - template <> - void bubbleUpOperation(Operation *op, arith::CmpIOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - arith::TruncFOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - - template <> - void bubbleUpOperation(Operation *op, - arith::ExtFOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - - template <> - void bubbleUpOperation(Operation *op, - arith::FPToSIOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - arith::SIToFPOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void - bubbleUpOperation(Operation *op, triton::ClampFOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, arith::CmpFOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - triton::BroadcastOp parentOp, - ArrayRef indices, - Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - triton::ExpandDimsOp parentOp, - ArrayRef indices, - Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - triton::SplatOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, - triton::MakeRangeOp parentOp, - ArrayRef indices, - Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, math::FloorOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; - template <> - void bubbleUpOperation(Operation *op, math::CeilOp parentOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const; + + void bubbleUpOperation(ExtractOpTy op, arith::ExtSIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::CmpIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::TruncFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::ExtFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::FPToSIOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::SIToFPOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::ClampFOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, arith::CmpFOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::BroadcastOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::ExpandDimsOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, triton::MakeRangeOp parentOp, + Location loc, PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, math::FloorOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, math::CeilOp parentOp, Location loc, + PatternRewriter &rewriter) const; + void bubbleUpOperation(ExtractOpTy op, tensor::ExtractSliceOp parentOp, Location loc, + PatternRewriter &rewriter) const; bool enableAggressiveMode; }; @@ -128,4 +85,4 @@ class BubbleUpOperationPass public: explicit BubbleUpOperationPass(const BubbleUpOperationOptions &options); void runOnOperation() override; -}; \ No newline at end of file +}; diff --git a/ascend/triton-adapter/include/TritonToUnstructure/Passes.td b/ascend/triton-adapter/include/TritonToUnstructure/Passes.td index 6086b4595e68b537e5eb084d549e06e3d64bcda7..81a7bd29b7a21780fd5fe6780a0a8f6d2cb324de 100644 --- a/ascend/triton-adapter/include/TritonToUnstructure/Passes.td +++ b/ascend/triton-adapter/include/TritonToUnstructure/Passes.td @@ -6,6 +6,10 @@ include "mlir/Pass/PassBase.td" def TritonToUnstructure : Pass<"triton-to-unstructure", "mlir::ModuleOp"> { let summary = "Convert Triton for unstructure case"; let constructor = "triton::createTritonToUnstructurePass()"; + let options = [ + Option<"forceScalarizeMode", "force-scalarize-mode", "bool", "false", + "Scalarize unstructured memory access even if structured dimensions are mixed.">, + ]; } def BubbleUpOperation : Pass<"bubble-up-operation", "mlir::ModuleOp"> { diff --git a/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h b/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h index c806521243b16d5e992e378d557d09b07c5d9b81..00aeedbc702b3eaf919c49a15f3f8420e3294e14 100644 --- a/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h +++ b/ascend/triton-adapter/include/TritonToUnstructure/UnstructureConversionPass.h @@ -7,13 +7,17 @@ #include "mlir/IR/PatternMatch.h" +#define GEN_PASS_DECL_TRITONTOUNSTRUCTURE +#include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" + #define GEN_PASS_DEF_TRITONTOUNSTRUCTURE #include "ascend/triton-adapter/include/TritonToUnstructure/Passes.h.inc" namespace mlir { namespace triton { -std::unique_ptr> createTritonToUnstructurePass(); +std::unique_ptr> +createTritonToUnstructurePass(const TritonToUnstructureOptions &options = {}); } // namespace triton } // namespace mlir @@ -61,29 +65,36 @@ public: using OpRewritePattern::OpRewritePattern; explicit UnstructuredMemAccessConverter( - MLIRContext *context, + MLIRContext *context, bool forceScalarizeMode, const llvm::DenseMap &offsetMap); LogicalResult matchAndRewrite(MemAccOpTy op, PatternRewriter &rewriter) const override; private: - Value createExtractOp(Location loc, Value value, ArrayRef iterIdx, - PatternRewriter &rewriter) const; + Value createExtractOp(Location loc, Value value, PatternRewriter &rewriter, + ArrayRef iterIdx) const; + Value createExtractOp(Location loc, Value value, PatternRewriter &rewriter, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) const; template typename std::enable_if, void>::type splatAndLoadScenario(MemAccOpTy op, int rank, PatternRewriter &rewriter) const; + template MemAccOpTy createMemAccOp(MemAccOpTy op, Value ptrToAccess, Location loc, - ArrayRef iterIdx, - PatternRewriter &rewriter) const; + PatternRewriter &rewriter, + Args &&...args) const = delete; const llvm::DenseMap &offsetMap; + bool forceScalarizeMode; }; class TritonToUnstructurePass : public ::impl::TritonToUnstructureBase { public: + explicit TritonToUnstructurePass(const TritonToUnstructureOptions &options); void getDependentDialects(DialectRegistry ®istry) const override; void runOnOperation() override; diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp index 949849e1f4283928a805c55cfd599c1956109026..c68852e1e03e19b77f7f60f068623a6126ff4b5e 100644 --- a/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp +++ b/ascend/triton-adapter/lib/TritonToUnstructure/BubbleUpOperation.cpp @@ -7,111 +7,157 @@ #define DEBUG_TYPE "triton-bubble-up-operation" -BubbleUpExtract::BubbleUpExtract(MLIRContext *context, - bool enableAggressiveMode) - : OpRewritePattern(context), +template +BubbleUpExtract::BubbleUpExtract(MLIRContext *context, + bool enableAggressiveMode) + : OpRewritePattern(context), enableAggressiveMode(enableAggressiveMode) {} +template LogicalResult -BubbleUpExtract::matchAndRewrite(tensor::ExtractOp op, - PatternRewriter &rewriter) const { - auto tensorValue = op.getTensor(); +BubbleUpExtract::matchAndRewrite(ExtractOpTy op, + PatternRewriter &rewriter) const { + Value tensorValue; + if constexpr (std::is_same_v) { + tensorValue = op.getTensor(); + } else if constexpr (std::is_same_v) { + tensorValue = op.getSource(); + if (tensorValue.getType() == op.getResult().getType()) { + rewriter.replaceAllUsesWith(op.getResult(), tensorValue); + rewriter.eraseOp(op); + return success(); + } + } else { + llvm_unreachable("Unhandled case"); + } + auto funcOp = op->template getParentOfType(); auto parentOp = tensorValue.getDefiningOp(); - auto indices = - SmallVector(op.getIndices().begin(), op.getIndices().end()); auto loc = op.getLoc(); - if (!parentOp || - (!enableAggressiveMode && !parentOp->hasOneUse())) { + if (!parentOp || (!enableAggressiveMode && !parentOp->hasOneUse())) { return failure(); } + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Before bubble up\n" << op << '\n' << funcOp << "\n"; + }); + if (auto extsiOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, extsiOp, indices, loc, rewriter); + bubbleUpOperation(op, extsiOp, loc, rewriter); } else if (auto addIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, addIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, addIOp, loc, rewriter); } else if (auto subIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, subIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, subIOp, loc, rewriter); } else if (auto mulIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, mulIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, mulIOp, loc, rewriter); } else if (auto divSIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, divSIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, divSIOp, loc, rewriter); } else if (auto remSIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, remSIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, remSIOp, loc, rewriter); } else if (auto maxSIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, maxSIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, maxSIOp, loc, rewriter); } else if (auto minSIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, minSIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, minSIOp, loc, rewriter); } else if (auto andIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, andIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, andIOp, loc, rewriter); } else if (auto orIOp = dyn_cast(parentOp)) { - bubbleUpIntBinaryOp(op, orIOp, indices, loc, rewriter); + bubbleUpIntBinaryOp(op, orIOp, loc, rewriter); } else if (auto cmpIOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, cmpIOp, indices, loc, rewriter); + bubbleUpOperation(op, cmpIOp, loc, rewriter); } else if (auto truncFOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, truncFOp, indices, loc, rewriter); + bubbleUpOperation(op, truncFOp, loc, rewriter); } else if (auto extFOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, extFOp, indices, loc, rewriter); + bubbleUpOperation(op, extFOp, loc, rewriter); } else if (auto fpTosiOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, fpTosiOp, indices, loc, rewriter); + bubbleUpOperation(op, fpTosiOp, loc, rewriter); } else if (auto siTofpOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, siTofpOp, indices, loc, rewriter); + bubbleUpOperation(op, siTofpOp, loc, rewriter); } else if (auto clampFOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, clampFOp, indices, loc, rewriter); + bubbleUpOperation(op, clampFOp, loc, rewriter); } else if (auto addFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, addFOp, indices, loc, rewriter); + bubbleUpFloatBinaryOp(op, addFOp, loc, rewriter); } else if (auto subFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, subFOp, indices, loc, rewriter); + bubbleUpFloatBinaryOp(op, subFOp, loc, rewriter); } else if (auto mulFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, mulFOp, indices, loc, rewriter); + bubbleUpFloatBinaryOp(op, mulFOp, loc, rewriter); } else if (auto divFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, divFOp, indices, loc, rewriter); + bubbleUpFloatBinaryOp(op, divFOp, loc, rewriter); } else if (auto minNumFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, minNumFOp, indices, loc, - rewriter); + bubbleUpFloatBinaryOp(op, minNumFOp, loc, rewriter); } else if (auto maxNumFOp = dyn_cast(parentOp)) { - bubbleUpFloatBinaryOp(op, maxNumFOp, indices, loc, - rewriter); + bubbleUpFloatBinaryOp(op, maxNumFOp, loc, rewriter); } else if (auto cmpFOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, cmpFOp, indices, loc, rewriter); + bubbleUpOperation(op, cmpFOp, loc, rewriter); } else if (auto broadCastOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, broadCastOp, indices, loc, rewriter); + bubbleUpOperation(op, broadCastOp, loc, rewriter); } else if (auto expandDimsOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, expandDimsOp, indices, loc, - rewriter); + bubbleUpOperation(op, expandDimsOp, loc, rewriter); } else if (auto splatOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, splatOp, indices, loc, rewriter); + bubbleUpOperation(op, splatOp, loc, rewriter); } else if (auto makeRangeOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, makeRangeOp, indices, loc, - rewriter); + bubbleUpOperation(op, makeRangeOp, loc, rewriter); } else if (auto floorOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, floorOp, indices, loc, rewriter); + bubbleUpOperation(op, floorOp, loc, rewriter); } else if (auto ceilOp = dyn_cast(parentOp)) { - bubbleUpOperation(op, ceilOp, indices, loc, rewriter); + bubbleUpOperation(op, ceilOp, loc, rewriter); + } else if (auto extractSliceOp = dyn_cast(parentOp)) { + if constexpr (std::is_same_v) { + bubbleUpOperation(op, extractSliceOp, loc, rewriter); + } else { + return failure(); + } } else { return failure(); } if (parentOp->use_empty()) rewriter.eraseOp(parentOp); + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "After bubble up\n" << funcOp << '\n'; + }); + return success(); } -Value BubbleUpExtract::createExtractOp(Value value, ArrayRef indices, - Location loc, - PatternRewriter &rewriter) const { - auto extractedOp = rewriter.create(loc, value, indices); +template +Value BubbleUpExtract::createExtractOp( + ExtractOpTy op, Value value, Location loc, + PatternRewriter &rewriter) const { + llvm_unreachable("Unhandled extract operation"); +} + +template <> +Value BubbleUpExtract::createExtractOp( + tensor::ExtractOp op, Value value, Location loc, + PatternRewriter &rewriter) const { + auto extractedOp = + rewriter.create(loc, value, op.getIndices()); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; +} + +template <> +Value BubbleUpExtract::createExtractOp( + tensor::ExtractSliceOp op, Value value, Location loc, + PatternRewriter &rewriter) const { + auto extractedOp = rewriter.create( + loc, value, op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides()); extractedOp->setAttr(ConverterUtils::discreteAttrName, UnitAttr::get(rewriter.getContext())); return extractedOp; } +template template -void BubbleUpExtract::bubbleUpIntBinaryOp(Operation *op, BinOpTy binOp, - ArrayRef indices, Location loc, - PatternRewriter &rewriter) const { - auto lhs = createExtractOp(binOp.getLhs(), indices, loc, rewriter); - auto rhs = createExtractOp(binOp.getRhs(), indices, loc, rewriter); +void BubbleUpExtract::bubbleUpIntBinaryOp( + ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, binOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, binOp.getRhs(), loc, rewriter); LLVM_DEBUG({ auto &os = llvm::dbgs(); os << "Binary\n" << *op << '\n' << binOp << '\n'; @@ -119,44 +165,43 @@ void BubbleUpExtract::bubbleUpIntBinaryOp(Operation *op, BinOpTy binOp, rewriter.replaceOpWithNewOp(op, lhs, rhs); } +template template -void BubbleUpExtract::bubbleUpFloatBinaryOp(Operation *op, BinOpTy binOp, - ArrayRef indices, - Location loc, - PatternRewriter &rewriter) const { - auto lhs = createExtractOp(binOp.getLhs(), indices, loc, rewriter); - auto rhs = createExtractOp(binOp.getRhs(), indices, loc, rewriter); +void BubbleUpExtract::bubbleUpFloatBinaryOp( + ExtractOpTy op, BinOpTy binOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, binOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, binOp.getRhs(), loc, rewriter); rewriter.replaceOpWithNewOp(op, lhs, rhs); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::ExtSIOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); - auto resultType = cast(parentOp.getOut().getType()); - rewriter.replaceOpWithNewOp(op, resultType.getElementType(), - in); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::ExtSIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::CmpIOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto lhs = createExtractOp(parentOp.getLhs(), indices, loc, rewriter); - auto rhs = createExtractOp(parentOp.getRhs(), indices, loc, rewriter); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::CmpIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, parentOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, parentOp.getRhs(), loc, rewriter); rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), lhs, rhs); } template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, triton::BroadcastOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::BroadcastOp parentOp, Location loc, + PatternRewriter &rewriter) const { auto src = parentOp.getSrc(); auto srcShape = cast(src.getType()).getShape(); SmallVector newIndices; - for (const auto [index, shape] : llvm::zip_equal(indices, srcShape)) { + for (const auto &[index, shape] : + llvm::zip_equal(op.getIndices(), srcShape)) { if (shape == 1) { newIndices.push_back( rewriter.create(loc, rewriter.getIndexAttr(0))); @@ -164,121 +209,235 @@ void BubbleUpExtract::bubbleUpOperation( newIndices.push_back(index); } } - auto extractedOp = createExtractOp(src, newIndices, loc, rewriter); + auto extractedOp = rewriter.create(loc, src, newIndices); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); rewriter.replaceOp(op, extractedOp); } template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, triton::ExpandDimsOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::BroadcastOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newOffsets; + SmallVector newSizes; + bool isScalarLikeSrc = true; + for (const auto &[offset, size, shape] : + llvm::zip_equal(op.getMixedOffsets(), op.getMixedSizes(), srcShape)) { + if (shape == 1) { + newOffsets.push_back(rewriter.getIndexAttr(0)); + newSizes.push_back(rewriter.getIndexAttr(1)); + } else { + newOffsets.push_back(offset); + newSizes.push_back(size); + } + if (getConstantIntValue(newSizes.back()).value_or(-1) != 1) + isScalarLikeSrc = false; + } + auto extractedOp = rewriter.create( + loc, src, newOffsets, newSizes, op.getMixedStrides()); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + if (isScalarLikeSrc) { + SmallVector indices( + srcShape.size(), + rewriter.create(loc, rewriter.getIndexAttr(0))); + auto extractedValue = + rewriter.create(loc, extractedOp, indices); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + extractedValue); + } else { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), extractedOp); + } +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::ExpandDimsOp parentOp, Location loc, + PatternRewriter &rewriter) const { auto src = parentOp.getSrc(); SmallVector newIndices; - for (const auto index : llvm::enumerate(indices)) { - if (index.index() != parentOp.getAxis()) { + for (const auto index : llvm::enumerate(op.getIndices())) { + if (index.index() != parentOp.getAxis()) newIndices.push_back(index.value()); - } } - auto extractedOp = createExtractOp(src, newIndices, loc, rewriter); + auto extractedOp = rewriter.create(loc, src, newIndices); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); rewriter.replaceOp(op, extractedOp); } template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, triton::SplatOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::ExpandDimsOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + auto srcShape = cast(src.getType()).getShape(); + SmallVector newOffsets; + SmallVector newSizes; + SmallVector newStrides; + size_t j = 0; + for (size_t i = 0; i <= srcShape.size(); i++) { + if (i != parentOp.getAxis()) { + newOffsets.push_back(op.getMixedOffsets()[i]); + newSizes.push_back(op.getMixedSizes()[i]); + newStrides.push_back(op.getMixedStrides()[i]); + } + } + auto extractedOp = rewriter.create( + loc, src, newOffsets, newSizes, newStrides); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + rewriter.replaceOpWithNewOp(op, extractedOp, + parentOp.getAxisAttr()); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const { auto src = parentOp.getSrc(); rewriter.replaceOp(op, src); } template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, triton::MakeRangeOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::SplatOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto src = parentOp.getSrc(); + rewriter.replaceOpWithNewOp( + op, cast(op.getResult().getType()), src); +} + +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, triton::MakeRangeOp parentOp, Location loc, + PatternRewriter &rewriter) const { auto resultType = cast(parentOp.getResult().getType()); rewriter.replaceOpWithNewOp( - op, resultType.getElementType(), indices[0]); + op, resultType.getElementType(), op.getIndices()[0]); } template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::TruncFOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); - auto resultType = cast(parentOp.getOut().getType()); - rewriter.replaceOpWithNewOp(op, resultType.getElementType(), - in); +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractSliceOp op, triton::MakeRangeOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto resultType = cast(parentOp.getResult().getType()); + Value idx; + if (auto offsetVal = dyn_cast(op.getMixedOffsets()[0])) { + idx = offsetVal; + } else { + idx = rewriter.create( + op.getLoc(), rewriter.getIndexAttr( + getConstantIntValue(op.getMixedOffsets()[0]).value())); + } + idx = rewriter.create(op.getLoc(), + resultType.getElementType(), idx); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + idx); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::ExtFOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); - auto resultType = cast(parentOp.getOut().getType()); - rewriter.replaceOpWithNewOp(op, resultType.getElementType(), +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::TruncFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::FPToSIOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); - auto resultType = cast(parentOp.getOut().getType()); - rewriter.replaceOpWithNewOp(op, resultType.getElementType(), +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::ExtFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); +} + +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::FPToSIOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), in); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::SIToFPOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto in = createExtractOp(parentOp.getIn(), indices, loc, rewriter); - auto outType = - cast(parentOp.getOut().getType()).getElementType(); - rewriter.replaceOpWithNewOp(op, outType, in); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::SIToFPOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto in = createExtractOp(op, parentOp.getIn(), loc, rewriter); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + in); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, triton::ClampFOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto x = createExtractOp(parentOp.getX(), indices, loc, rewriter); - auto min = createExtractOp(parentOp.getMin(), indices, loc, rewriter); - auto max = createExtractOp(parentOp.getMax(), indices, loc, rewriter); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, triton::ClampFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto x = createExtractOp(op, parentOp.getX(), loc, rewriter); + auto min = createExtractOp(op, parentOp.getMin(), loc, rewriter); + auto max = createExtractOp(op, parentOp.getMax(), loc, rewriter); rewriter.replaceOpWithNewOp(op, x, min, max, parentOp.getPropagateNan()); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, arith::CmpFOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto lhs = createExtractOp(parentOp.getLhs(), indices, loc, rewriter); - auto rhs = createExtractOp(parentOp.getRhs(), indices, loc, rewriter); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, arith::CmpFOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto lhs = createExtractOp(op, parentOp.getLhs(), loc, rewriter); + auto rhs = createExtractOp(op, parentOp.getRhs(), loc, rewriter); rewriter.replaceOpWithNewOp(op, parentOp.getPredicateAttr(), lhs, rhs); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, math::FloorOp parentOp, ArrayRef indices, - Location loc, PatternRewriter &rewriter) const { - auto operand = createExtractOp(parentOp.getOperand(), indices, loc, rewriter); +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, math::FloorOp parentOp, Location loc, + PatternRewriter &rewriter) const { + auto operand = createExtractOp(op, parentOp.getOperand(), loc, rewriter); rewriter.replaceOpWithNewOp(op, operand, parentOp.getFastmath()); } -template <> -void BubbleUpExtract::bubbleUpOperation( - Operation *op, math::CeilOp parentOp, ArrayRef indices, Location loc, +template +void BubbleUpExtract::bubbleUpOperation( + ExtractOpTy op, math::CeilOp parentOp, Location loc, PatternRewriter &rewriter) const { - auto operand = createExtractOp(parentOp.getOperand(), indices, loc, rewriter); + auto operand = createExtractOp(op, parentOp.getOperand(), loc, rewriter); rewriter.replaceOpWithNewOp(op, operand, parentOp.getFastmath()); } +template <> +void BubbleUpExtract::bubbleUpOperation( + tensor::ExtractOp op, tensor::ExtractSliceOp parentOp, + Location loc, PatternRewriter &rewriter) const { + SmallVector newIndices; + for (const auto &[offset, index] : + llvm::zip_equal(parentOp.getMixedOffsets(), op.getIndices())) { + Value offsetVal; + if (isa(offset)) { + offsetVal = offset.template get(); + } else { + offsetVal = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(*getConstantIntValue(offset))); + } + newIndices.push_back( + rewriter.create(op.getLoc(), offsetVal, index)); + } + rewriter + .replaceOpWithNewOp(op, parentOp.getSource(), + newIndices) + ->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); +} + BubbleUpOperationPass::BubbleUpOperationPass( const BubbleUpOperationOptions &options) : BubbleUpOperationBase(options) {} @@ -288,7 +447,9 @@ void BubbleUpOperationPass::runOnOperation() { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx, enableAggressiveMode); + patterns.add, + BubbleUpExtract>(ctx, + enableAggressiveMode); if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) { moduleOp->emitError("failed to apply Patterns"); diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp index d093cb7158c1752034b6f78368028405e8d27d64..3655a85f01c8916926f53d5c17a65a3c22967045 100644 --- a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp +++ b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp @@ -23,88 +23,133 @@ using namespace triton; template Value UnstructuredMemAccessConverter::createExtractOp( - Location loc, Value value, ArrayRef iterIdx, - PatternRewriter &rewriter) const { + Location loc, Value value, PatternRewriter &rewriter, + ArrayRef iterIdx) const { if (!value) return value; - auto extractedOp = rewriter.create(loc, value, iterIdx); + SmallVector indices; + for (auto idx : iterIdx) { + if (auto val = dyn_cast(idx)) { + indices.push_back(val); + } else { + auto idxVal = rewriter.create( + loc, rewriter.getIndexAttr(*getConstantIntValue(idx))); + indices.push_back(idxVal); + } + } + auto extractedOp = rewriter.create(loc, value, indices); extractedOp->setAttr(ConverterUtils::discreteAttrName, UnitAttr::get(rewriter.getContext())); return extractedOp; } template -MemAccOpTy UnstructuredMemAccessConverter::createMemAccOp( - MemAccOpTy op, Value ptrToAccess, Location loc, ArrayRef iterIdx, - PatternRewriter &rewriter) const { - llvm_unreachable("Unhandled discrete memory access operation"); +Value UnstructuredMemAccessConverter::createExtractOp( + Location loc, Value value, PatternRewriter &rewriter, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides) const { + if (!value) + return value; + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Extracting\n"; + os << value << "\n"; + }); + auto extractedOp = rewriter.create( + loc, value, offsets, sizes, strides); + extractedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); + return extractedOp; } template <> +template triton::LoadOp UnstructuredMemAccessConverter::createMemAccOp( - triton::LoadOp op, Value ptrToAccess, Location loc, ArrayRef iterIdx, - PatternRewriter &rewriter) const { + triton::LoadOp op, Value ptrToAccess, Location loc, + PatternRewriter &rewriter, Args &&...args) const { return rewriter.create(loc, ptrToAccess, op.getCache(), - op.getEvict(), false); + op.getEvict(), op.getIsVolatile()); } template <> +template triton::AtomicRMWOp UnstructuredMemAccessConverter::createMemAccOp( triton::AtomicRMWOp op, Value ptrToAccess, Location loc, - ArrayRef iterIdx, PatternRewriter &rewriter) const { - auto extractedValue = createExtractOp(loc, op.getVal(), iterIdx, rewriter); - auto extractedMask = createExtractOp(loc, op.getMask(), iterIdx, rewriter); - auto resultType = cast(op.getResult().getType()); - SmallVector scalarLikeShape(resultType.getRank(), 1); - auto scalarLikeType = - RankedTensorType::get(scalarLikeShape, resultType.getElementType()); - auto splatedPtrToAccess = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), - ptrToAccess); - auto splatedExtractedValue = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), - extractedValue); - auto splatedExtractedMask = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, extractedMask.getType()), - extractedMask); + PatternRewriter &rewriter, Args &&...args) const { + auto extractedValue = + createExtractOp(loc, op.getVal(), rewriter, std::forward(args)...); + auto extractedMask = + createExtractOp(loc, op.getMask(), rewriter, std::forward(args)...); + Type targetType = ptrToAccess.getType(); + if (auto tensorType = dyn_cast(targetType)) { + auto ptrType = cast(tensorType.getElementType()); + targetType = + RankedTensorType::get(tensorType.getShape(), ptrType.getPointeeType()); + } else { + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + targetType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + ptrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + extractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + extractedMask = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedMask.getType()), + extractedMask); + } return rewriter.create( - loc, scalarLikeType, op.getAtomicRmwOpAttr(), splatedPtrToAccess, - splatedExtractedValue, splatedExtractedMask, op.getSemAttr(), - op.getScopeAttr()); + loc, targetType, op.getAtomicRmwOpAttr(), ptrToAccess, extractedValue, + extractedMask, op.getSemAttr(), op.getScopeAttr()); } template <> +template triton::AtomicCASOp UnstructuredMemAccessConverter::createMemAccOp( triton::AtomicCASOp op, Value ptrToAccess, Location loc, - ArrayRef iterIdx, PatternRewriter &rewriter) const { - auto extractedCmp = createExtractOp(loc, op.getCmp(), iterIdx, rewriter); - auto extractedValue = createExtractOp(loc, op.getVal(), iterIdx, rewriter); - auto resultType = cast(op.getResult().getType()); - SmallVector scalarLikeShape(resultType.getRank(), 1); - auto scalarLikeType = - RankedTensorType::get(scalarLikeShape, resultType.getElementType()); - auto splatedPtrToAccess = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), - ptrToAccess); - auto splatedExtractedCmp = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, extractedCmp.getType()), - extractedCmp); - auto splatedExtractedValue = rewriter.create( - loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), - extractedValue); + PatternRewriter &rewriter, Args &&...args) const { + auto extractedCmp = + createExtractOp(loc, op.getCmp(), rewriter, std::forward(args)...); + auto extractedValue = + createExtractOp(loc, op.getVal(), rewriter, std::forward(args)...); + Type targetType = ptrToAccess.getType(); + if (auto tensorType = dyn_cast(targetType)) { + auto ptrType = cast(tensorType.getElementType()); + targetType = + RankedTensorType::get(tensorType.getShape(), ptrType.getPointeeType()); + } else { + auto resultType = cast(op.getResult().getType()); + SmallVector scalarLikeShape(resultType.getRank(), 1); + targetType = + RankedTensorType::get(scalarLikeShape, resultType.getElementType()); + ptrToAccess = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, ptrToAccess.getType()), + ptrToAccess); + extractedCmp = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedCmp.getType()), + extractedCmp); + extractedValue = rewriter.create( + loc, RankedTensorType::get(scalarLikeShape, extractedValue.getType()), + extractedValue); + } return rewriter.create( - loc, scalarLikeType, splatedPtrToAccess, splatedExtractedCmp, - splatedExtractedValue, op.getSemAttr(), op.getScopeAttr()); + loc, targetType, ptrToAccess, extractedCmp, extractedValue, + op.getSemAttr(), op.getScopeAttr()); } template <> +template triton::StoreOp UnstructuredMemAccessConverter::createMemAccOp( triton::StoreOp op, Value ptrToAccess, Location loc, - ArrayRef iterIdx, PatternRewriter &rewriter) const { - auto extractedValue = createExtractOp(loc, op.getValue(), iterIdx, rewriter); - auto extractedMask = createExtractOp(loc, op.getMask(), iterIdx, rewriter); + PatternRewriter &rewriter, Args &&...args) const { + auto extractedValue = createExtractOp(loc, op.getValue(), rewriter, + std::forward(args)...); + auto extractedMask = + createExtractOp(loc, op.getMask(), rewriter, std::forward(args)...); return rewriter.create(loc, ptrToAccess, extractedValue, extractedMask); } @@ -115,9 +160,8 @@ void UnstructuredMemAccessConverter::splatAndLoadScenario< triton::LoadOp>(triton::LoadOp op, int rank, PatternRewriter &rewriter) const { auto loc = op.getLoc(); - SmallVector idx( - rank, rewriter.create(loc, rewriter.getIndexAttr(0))); - auto extractedPtr = createExtractOp(loc, op.getPtr(), idx, rewriter); + SmallVector idx(rank, rewriter.getIndexAttr(0)); + auto extractedPtr = createExtractOp(loc, op.getPtr(), rewriter, idx); Value mask = op.getMask(); Value other = op.getOther(); Value loadedValue = rewriter.create( @@ -134,8 +178,10 @@ void UnstructuredMemAccessConverter::splatAndLoadScenario< template UnstructuredMemAccessConverter::UnstructuredMemAccessConverter( - MLIRContext *context, const llvm::DenseMap &offsetMap) - : OpRewritePattern(context), offsetMap(offsetMap) {} + MLIRContext *context, bool forceScalarizeMode, + const llvm::DenseMap &offsetMap) + : OpRewritePattern(context), + forceScalarizeMode(forceScalarizeMode), offsetMap(offsetMap) {} template LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( @@ -145,9 +191,8 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( auto ptr = op.getPtr(); auto ptrType = dyn_cast(ptr.getType()); - if (!ptrType || op->hasAttr(ConverterUtils::discreteAttrName)) { + if (!ptrType || op->hasAttr(ConverterUtils::discreteAttrName)) return failure(); - } if (!offsetMap.contains(ptr)) return op.emitError() << "PtrOffsetInfo should be computed\n" << ptr; @@ -178,11 +223,17 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( op, op.getPtr(), selectOp.getTrueValue(), selectOp.getCondition(), op.getCache(), op.getEvict()); rewriter.setInsertionPoint(op); + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); } } + if (forceScalarizeMode || + op->template getParentOfType()) { + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + } + auto srcPtr = ptrOffsetInfo.getPtr(); - auto offset = ptrOffsetInfo.getOffset(); + auto ptrOffset = ptrOffsetInfo.getOffset(); // LoadLike is operation with result bool isLoadLike = !op->use_empty(); @@ -195,6 +246,25 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( auto resultElementType = cast(ptrType.getElementType()).getPointeeType(); + int64_t sizeInByte; + if (auto intType = dyn_cast(resultElementType)) { + sizeInByte = intType.getWidth() / 8; + } else if (auto floatType = dyn_cast(resultElementType)) { + sizeInByte = floatType.getWidth() / 8; + } else { + llvm_unreachable("Unhandled element type of tensor"); + } + + for (int i = ptrOffsetInfo.getRank() - 1; i >= 0; i--) { + if (!ptrOffsetInfo.isStructured(i)) + break; + sizeInByte *= resultShape[i]; + } + + // Force scalarize if memory is not aligned + if (sizeInByte % 32 != 0) + ptrOffsetInfo.setUnstructured(ptrOffsetInfo.getRank()); + Value iterArg = nullptr; // Only load case @@ -206,72 +276,90 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( auto insertPoint = rewriter.saveInsertionPoint(); - SmallVector dims(resultShape.size(), rewriter.getIndexAttr(1)); SmallVector offsets; + SmallVector sizes; SmallVector strides; - SmallVector iterIdx; - - SmallVector localMemStrides(1, 1); + SmallVector extractedShape; - for (auto size : llvm::reverse(resultShape)) { - localMemStrides.push_back(localMemStrides.back() * size); - } - localMemStrides.pop_back(); - - std::reverse(localMemStrides.begin(), localMemStrides.end()); - bool isExtractedAttrInserted = false; - for (const auto &[size, localMemStride] : - llvm::zip_equal(resultShape, localMemStrides)) { + for (const auto &[size, structured] : llvm::zip_equal( + resultShape, ptrOffsetInfo.getStructured())) { // handle indirect dimension - strides.push_back(rewriter.getIndexAttr(localMemStride)); + strides.push_back(rewriter.getIndexAttr(1)); Value sizeVal = rewriter.create(loc, rewriter.getIndexAttr(size)); - scf::ForOp forOp; - if (isLoadLike) { - forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx, - ValueRange({iterArg})); - if (!newOpResult) { - newOpResult = forOp->getResult(0); + if (structured) { + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(size)); + extractedShape.push_back(size); + } else { + scf::ForOp forOp; + if (isLoadLike) { + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx, + ValueRange({iterArg})); + if (!newOpResult) { + newOpResult = forOp->getResult(0); + } else { + rewriter.create(loc, forOp->getResult(0)); + } + iterArg = forOp.getRegionIterArg(0); } else { - rewriter.create(loc, forOp->getResult(0)); + forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx); } - iterArg = forOp.getRegionIterArg(0); - } else { - forOp = rewriter.create(loc, zeroIdx, sizeVal, oneIdx); + offsets.push_back(forOp.getInductionVar()); + sizes.push_back(rewriter.getIndexAttr(1)); + extractedShape.push_back(1); + forOp->setAttr("ExtractedLoadOrStore", + UnitAttr::get(rewriter.getContext())); + rewriter.setInsertionPointToStart(forOp.getBody()); } - offsets.push_back(forOp.getInductionVar()); - iterIdx.push_back(forOp.getInductionVar()); - forOp->setAttr("ExtractedLoadOrStore", - UnitAttr::get(rewriter.getContext())); - rewriter.setInsertionPointToStart(forOp.getBody()); } - auto scalarLikeShape = SmallVector(dims.size(), 1); - auto scalarLikeType = - RankedTensorType::get(scalarLikeShape, resultElementType); + bool fullyUnstructured = ptrOffsetInfo.isUnstructured(); + auto extractedType = RankedTensorType::get(extractedShape, resultElementType); - auto extractedOffset = createExtractOp(loc, offset, iterIdx, rewriter); - if (isa(srcPtr.getType())) { - srcPtr = createExtractOp(loc, srcPtr, iterIdx, rewriter); + Value extractedOffset; + if (fullyUnstructured) { + extractedOffset = createExtractOp(loc, ptrOffset, rewriter, offsets); + } else { + extractedOffset = + createExtractOp(loc, ptrOffset, rewriter, offsets, sizes, strides); + } + + LLVM_DEBUG({ + auto &os = llvm::dbgs(); + os << "Extracted offset\n"; + os << extractedOffset << "\n"; + }); + + assert(!isa(srcPtr.getType()) && "src must be ptr type"); + if (!fullyUnstructured) { + srcPtr = rewriter.create( + loc, RankedTensorType::get(extractedShape, srcPtr.getType()), srcPtr); } Value ptrToAccess = rewriter.create( loc, srcPtr.getType(), srcPtr, extractedOffset); - MemAccOpTy accessedValue = - createMemAccOp(op, ptrToAccess, loc, iterIdx, rewriter); - accessedValue->setAttr(ConverterUtils::discreteAttrName, - UnitAttr::get(rewriter.getContext())); + MemAccOpTy accessedOp; + if (fullyUnstructured) { + accessedOp = createMemAccOp(op, ptrToAccess, loc, rewriter, offsets); + } else { + accessedOp = + createMemAccOp(op, ptrToAccess, loc, rewriter, offsets, sizes, strides); + } + + accessedOp->setAttr(ConverterUtils::discreteAttrName, + UnitAttr::get(rewriter.getContext())); if (isLoadLike) { assert(iterArg && "Load case must have iterArg in for loop"); - Value splatedValue = accessedValue->getResult(0); - if (!isa(splatedValue.getType())) { - splatedValue = - rewriter.create(loc, scalarLikeType, splatedValue); + Value accessedValue = accessedOp->getResult(0); + if (!isa(accessedValue.getType())) { + accessedValue = + rewriter.create(loc, extractedType, accessedValue); } auto result = rewriter.create( - loc, splatedValue, iterArg, offsets, dims, strides); + loc, accessedValue, iterArg, offsets, sizes, strides); rewriter.create(loc, result->getResult(0)) ->setAttr(ConverterUtils::discreteAttrName, UnitAttr::get(rewriter.getContext())); @@ -402,6 +490,10 @@ void TritonToUnstructurePass::runParse(MemAccOpTy op) { parse(op.getPtr(), op.getLoc(), rewriter, offsetMap); } +TritonToUnstructurePass::TritonToUnstructurePass( + const TritonToUnstructureOptions &options) + : TritonToUnstructureBase(options) {} + void TritonToUnstructurePass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext *ctx = &getContext(); @@ -424,8 +516,8 @@ void TritonToUnstructurePass::runOnOperation() { patterns.add, UnstructuredMemAccessConverter, UnstructuredMemAccessConverter, - UnstructuredMemAccessConverter>(ctx, - offsetMap); + UnstructuredMemAccessConverter>( + ctx, forceScalarizeMode, offsetMap); LLVM_DEBUG({ auto &os = llvm::dbgs(); @@ -453,7 +545,7 @@ void TritonToUnstructurePass::getDependentDialects( triton::TritonDialect>(); } -std::unique_ptr> -triton::createTritonToUnstructurePass() { - return std::make_unique(); +std::unique_ptr> triton::createTritonToUnstructurePass( + const TritonToUnstructureOptions &options) { + return std::make_unique(options); } \ No newline at end of file