From 82bd1abe522d032e3d6ebd47912a5d64f5cbfcbb Mon Sep 17 00:00:00 2001 From: WangjunZou Date: Wed, 30 Jul 2025 14:50:18 +0800 Subject: [PATCH] feat(triton language): add explicit block sync interface --- .../TritonToLinalg/TritonOpConverter.h | 185 +++++----- .../TritonToLinalg/TritonToLinalgPass.h | 7 +- .../lib/TritonToLinalg/TritonOpConverter.cpp | 175 +++++---- .../lib/TritonToLinalg/TritonToLinalgPass.cpp | 343 ++++++++++-------- .../python/triton_patch/language/core.py | 35 +- triton_patch/python/triton_patch/testing.py | 2 + 6 files changed, 438 insertions(+), 309 deletions(-) diff --git a/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h b/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h index 2427253..7eb129c 100644 --- a/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h +++ b/ascend/triton-adapter/include/TritonToLinalg/TritonOpConverter.h @@ -190,15 +190,16 @@ class ReductionOpBaseConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(OpTy op, - typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { auto sourceType = cast(adaptor.getOperands().front().getType()); assert(sourceType.hasRank() && "Expected input is ranked"); int64_t axis = op.getAxis(); - assert(axis >= 0 && axis < sourceType.getRank() && "Expected reduction axis is within operand's rank"); + assert(axis >= 0 && axis < sourceType.getRank() && + "Expected reduction axis is within operand's rank"); auto reductionOps = this->getRedOps(op); if (reductionOps.size() == 1) { @@ -211,7 +212,7 @@ protected: llvm::SmallVector getRedOps(OpTy redOp) const { auto redBody = redOp.getBody(); return llvm::map_to_vector(redBody->without_terminator(), - [](Operation &op) { return &op; }); + [](Operation &op) { return &op; }); } arith::ConstantOp getRedBaseConstOp(ConversionPatternRewriter &rewriter, @@ -219,85 +220,89 @@ protected: Type constantType) const { const int64_t bitWidth = constantType.getIntOrFloatBitWidth(); - auto attr = llvm::TypeSwitch(redOp) - .Case([&](arith::AddFOp) { - return rewriter.getFloatAttr(constantType, 0.f); - }) - .Case([&](arith::AddIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](arith::MulFOp) { - return rewriter.getFloatAttr(constantType, 1.f); - }) - .template Case([&](auto) { - return rewriter.getFloatAttr( - constantType, -std::numeric_limits::infinity()); - }) - .template Case([&](auto) { - return rewriter.getFloatAttr( - constantType, std::numeric_limits::infinity()); - }) - .Case([&](arith::MinSIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::maxIntN(bitWidth)); - }) - .Case([&](arith::MinUIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::maxUIntN(bitWidth)); - }) - .Case([&](arith::MaxSIOp) { - return rewriter.getIntegerAttr(constantType, - llvm::minIntN(bitWidth)); - }) - .Case([&](arith::MaxUIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](arith::OrIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Case([&](arith::AndIOp) { - return rewriter.getIntegerAttr(constantType, 1); - }) - .Case([&](arith::XOrIOp) { - return rewriter.getIntegerAttr(constantType, 0); - }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not supported yet"); - return nullptr; - }); - - return rewriter.create(redOp->getLoc(), constantType, attr); + auto attr = + llvm::TypeSwitch(redOp) + .Case([&](arith::AddFOp) { + return rewriter.getFloatAttr(constantType, 0.f); + }) + .Case([&](arith::AddIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::MulFOp) { + return rewriter.getFloatAttr(constantType, 1.f); + }) + .template Case([&](auto) { + return rewriter.getFloatAttr( + constantType, -std::numeric_limits::infinity()); + }) + .template Case([&](auto) { + return rewriter.getFloatAttr( + constantType, std::numeric_limits::infinity()); + }) + .Case([&](arith::MinSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxIntN(bitWidth)); + }) + .Case([&](arith::MinUIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::maxUIntN(bitWidth)); + }) + .Case([&](arith::MaxSIOp) { + return rewriter.getIntegerAttr(constantType, + llvm::minIntN(bitWidth)); + }) + .Case([&](arith::MaxUIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::OrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Case([&](arith::AndIOp) { + return rewriter.getIntegerAttr(constantType, 1); + }) + .Case([&](arith::XOrIOp) { + return rewriter.getIntegerAttr(constantType, 0); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not supported yet"); + return nullptr; + }); + + return rewriter.create(redOp->getLoc(), constantType, + attr); } bool requiresF32Conversion(const Type elemType, Operation *redOp) const { return isa(elemType) && - elemType.getIntOrFloatBitWidth() < - Float32Type::get(elemType.getContext()).getWidth() && - (isa(redOp) || isa(redOp)); + elemType.getIntOrFloatBitWidth() < + Float32Type::get(elemType.getContext()).getWidth() && + (isa(redOp) || isa(redOp)); } - Value getRedElement( - Value lhs, Value rhs, const Location loc, Operation *redOp, OpBuilder &b, - const bool convertLhsToF32Precision) const { - return llvm::TypeSwitch(redOp) - .template Case([&](auto redOp) { - if (convertLhsToF32Precision) { - lhs = b.create(loc, Float32Type::get(b.getContext()), - lhs); - } - return b.create(loc, lhs, rhs); - }) - .template Case( - [&](auto redOp) { return b.create(loc, lhs, rhs); }) - .Default([](Operation *op) { - op->dump(); - llvm_unreachable("Reduction op not yet supported"); - return nullptr; - }); + Value getRedElement(Value lhs, Value rhs, const Location loc, + Operation *redOp, OpBuilder &b, + const bool convertLhsToF32Precision) const { + return llvm::TypeSwitch(redOp) + .template Case([&](auto redOp) { + if (convertLhsToF32Precision) { + lhs = b.create(loc, Float32Type::get(b.getContext()), + lhs); + } + return b.create(loc, lhs, rhs); + }) + .template Case( + [&](auto redOp) { + return b.create(loc, lhs, rhs); + }) + .Default([](Operation *op) { + op->dump(); + llvm_unreachable("Reduction op not yet supported"); + return nullptr; + }); } virtual bool isReductionOpSupported(Operation *redOp) const = 0; @@ -322,13 +327,14 @@ protected: bool isReductionOpSupported(Operation *redOp) const override; LogicalResult - convertToTargetOp(triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + convertToTargetOp(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; LogicalResult - convertToTargetOpExtended(triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + convertToTargetOpExtended(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; - }; class ScanConverter : public ReductionOpBaseConverter { @@ -346,9 +352,9 @@ protected: ConversionPatternRewriter &rewriter) const override; LogicalResult - convertToTargetOpExtended(triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, + convertToTargetOpExtended(triton::ScanOp op, + typename triton::ScanOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; - }; class ExternElementwiseClOpConverter @@ -519,6 +525,19 @@ struct MatmulConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override; }; +// Here wanna use ElementwiseInlineAsmOp with special asm info to express Ascend +// unique instruction +class ElementwiseInlineAsmConverter + : public OpConversionPattern { + using OpConversionPattern< + triton::ElementwiseInlineAsmOp>::OpConversionPattern; + +public: + LogicalResult + matchAndRewrite(triton::ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // end of namespace TTOpConverters #endif diff --git a/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h b/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h index 5e17e05..cf5b42f 100644 --- a/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h +++ b/ascend/triton-adapter/include/TritonToLinalg/TritonToLinalgPass.h @@ -10,6 +10,8 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -57,8 +59,9 @@ private: void addTensorKindToArguments(OpTy op, triton::FuncOp func, TensorKind tensorKind); void convertTTFunc(triton::FuncOp func, const bool existDot); - // 处理嵌套的if/else - void transformNestedIfElse(Operation &nestedBranch, OpBuilder &builder); + + LogicalResult convertMultipleBlockControlFlow(Operation *funcOp, + OpBuilder &builder); void addDynamicLegal(ConversionTarget &target, TritonTypeConverter &tritonTypeConverter); diff --git a/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp b/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp index bb02abb..a3daf99 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/TritonOpConverter.cpp @@ -23,9 +23,12 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/ValueRange.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" namespace TTOpConverters { using namespace mlir; @@ -512,57 +515,67 @@ MakeTensorPtrCanonicalizer::matchAndRewrite(triton::MakeTensorPtrOp op, return success(); } -LogicalResult ReduceSingleCanonicalizer::matchAndRewrite(triton::ReduceOp reduceOp, PatternRewriter &rewriter) const -{ - auto srcs = reduceOp.getSrcs(); - bool allSrcSingleElem = true; - for (auto src : srcs) { - auto srcType = cast(src.getType()); - auto srcShape = srcType.getShape(); - int64_t numel = 1; - for (auto s : srcShape) { - numel *= s; - } - if (numel != 1) { - allSrcSingleElem = false; - break; - } +LogicalResult +ReduceSingleCanonicalizer::matchAndRewrite(triton::ReduceOp reduceOp, + PatternRewriter &rewriter) const { + auto srcs = reduceOp.getSrcs(); + bool allSrcSingleElem = true; + for (auto src : srcs) { + auto srcType = cast(src.getType()); + auto srcShape = srcType.getShape(); + int64_t numel = 1; + for (auto s : srcShape) { + numel *= s; } - - if (!allSrcSingleElem) { - return rewriter.notifyMatchFailure(reduceOp, "reduce's srcs are not all with single element"); + if (numel != 1) { + allSrcSingleElem = false; + break; } + } - auto results = reduceOp.getResult(); - auto loc = reduceOp->getLoc(); - auto zero = rewriter - .create(loc, rewriter.getIndexType(), - rewriter.getIntegerAttr(rewriter.getIndexType(), 0)) - .getResult(); - for (int i = 0; i < srcs.size(); i++) { - auto src = srcs[i]; - auto srcType = cast(src.getType()); - auto srcRank = srcType.getRank(); - auto res = results[i]; - Value extracted; - if (srcRank == 1) { - // vector reduce generates a scalar result - extracted = rewriter.create(loc, src, zero).getResult(); - } else { - auto srcShape = srcType.getShape(); - auto resType = cast(res.getType()); - auto resShape = resType.getShape(); - auto collapseReassociationIndicesOptional = getReassociationIndicesForCollapse(srcShape, resShape); - if (!collapseReassociationIndicesOptional.has_value()) { - return rewriter.notifyMatchFailure(reduceOp, "Failure with getReassociationIndicesForCollapse call"); - } - auto collapseReassociationIndices = collapseReassociationIndicesOptional.value(); - extracted = rewriter.create(loc, src, collapseReassociationIndices).getResult(); - } - res.replaceAllUsesWith(extracted); + if (!allSrcSingleElem) { + return rewriter.notifyMatchFailure( + reduceOp, "reduce's srcs are not all with single element"); + } + + auto results = reduceOp.getResult(); + auto loc = reduceOp->getLoc(); + auto zero = rewriter + .create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 0)) + .getResult(); + for (int i = 0; i < srcs.size(); i++) { + auto src = srcs[i]; + auto srcType = cast(src.getType()); + auto srcRank = srcType.getRank(); + auto res = results[i]; + Value extracted; + if (srcRank == 1) { + // vector reduce generates a scalar result + extracted = + rewriter.create(loc, src, zero).getResult(); + } else { + auto srcShape = srcType.getShape(); + auto resType = cast(res.getType()); + auto resShape = resType.getShape(); + auto collapseReassociationIndicesOptional = + getReassociationIndicesForCollapse(srcShape, resShape); + if (!collapseReassociationIndicesOptional.has_value()) { + return rewriter.notifyMatchFailure( + reduceOp, "Failure with getReassociationIndicesForCollapse call"); + } + auto collapseReassociationIndices = + collapseReassociationIndicesOptional.value(); + extracted = rewriter + .create( + loc, src, collapseReassociationIndices) + .getResult(); } + res.replaceAllUsesWith(extracted); + } - return success(); + return success(); } LogicalResult DenseConstantConverter::matchAndRewrite( @@ -774,14 +787,15 @@ BroadcastConverter::matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor, // Reduce Converter bool ReduceConverter::isReductionOpSupported(Operation *redOp) const { return isa(redOp); + arith::MaxNumFOp, arith::MinimumFOp, arith::MinNumFOp, + arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp, arith::MaxUIOp, + arith::AndIOp, arith::OrIOp, arith::XOrIOp>(redOp); } -LogicalResult ReduceConverter::convertToTargetOp( - triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { +LogicalResult +ReduceConverter::convertToTargetOp(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto source = adaptor.getOperands().front(); auto sourceType = cast(source.getType()); auto elemType = sourceType.getElementType(); @@ -794,7 +808,8 @@ LogicalResult ReduceConverter::convertToTargetOp( // subview that skips over each first element. if (!this->isReductionOpSupported(reductionOps.front())) { return rewriter.notifyMatchFailure( - op, "Only support lowering reduction with single op and limited types of reducetion"); + op, "Only support lowering reduction with single op and limited types " + "of reducetion"); } auto rop = reductionOps.front(); @@ -821,19 +836,22 @@ LogicalResult ReduceConverter::convertToTargetOp( .getResult(0); } - Value finalResult = rewriter.create( + Value finalResult = + rewriter + .create( loc, ValueRange{source}, ValueRange{initTensor}, SmallVector{axis}, [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { assert(inputs.size() == 2); - Value result = this->getRedElement(inputs[0], inputs[1], loc, rop, - opBuilder, false); + Value result = this->getRedElement(inputs[0], inputs[1], loc, + rop, opBuilder, false); opBuilder.create(loc, result); }) .getResult(0); if (sourceType.getRank() == 1) { - finalResult = rewriter.create(loc, constantType, finalResult); + finalResult = + rewriter.create(loc, constantType, finalResult); } rewriter.replaceOp(op, finalResult); @@ -841,7 +859,8 @@ LogicalResult ReduceConverter::convertToTargetOp( } LogicalResult ReduceConverter::convertToTargetOpExtended( - triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { + triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); auto elemTypes = op.getElementTypes(); @@ -899,16 +918,18 @@ bool ScanConverter::isReductionOpSupported(Operation *redOp) const { return isa(redOp); } -LogicalResult ScanConverter::convertToTargetOp( - triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { +LogicalResult +ScanConverter::convertToTargetOp(triton::ScanOp op, + typename triton::ScanOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto reductionOps = this->getRedOps(op); // Reduction of arbitrary operations isn't supported because using the first // element across the reduction dimension requires us to iterate over a // subview that skips over each first element. if (!this->isReductionOpSupported(reductionOps.front())) { return rewriter.notifyMatchFailure( - op, "Only support lowering reduction with single op and limited types of reducetion"); + op, "Only support lowering reduction with single op and limited types " + "of reducetion"); } llvm::SmallString<64> funcName; @@ -953,7 +974,8 @@ LogicalResult ScanConverter::convertToTargetOp( LogicalResult ScanConverter::convertToTargetOpExtended( triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { - return op->emitError("tt.scan with multiple ops inside the body unsupported!"); + return op->emitError( + "tt.scan with multiple ops inside the body unsupported!"); } LogicalResult ExternElementwiseClOpConverter::matchAndRewrite( @@ -1279,4 +1301,31 @@ MatmulConverter::matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, } return success(); } + +LogicalResult ElementwiseInlineAsmConverter::matchAndRewrite( + triton::ElementwiseInlineAsmOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto moduleOp = op->getParentOfType(); + + StringAttr stubName = op.getAsmStringAttr(); + // White list to constrain ElementwiseInlineAsmOp usage + if (!stubName.strref().starts_with("ascend_block_sync")) + return op->emitOpError("Unsupported use of ElementwiseInlineAsmOp"); + + SymbolTableCollection symCollection; + auto stubFunc = + symCollection.lookupNearestSymbolFrom(moduleOp, stubName); + if (!stubFunc) { + OpBuilder::InsertionGuard guard(rewriter); + + rewriter.setInsertionPointToStart(moduleOp.getBody()); + stubFunc = rewriter.create( + op->getLoc(), stubName, /*void()*/ rewriter.getFunctionType({}, {})); + stubFunc.setPrivate(); + } + + rewriter.create(op.getLoc(), stubFunc, ValueRange{}); + rewriter.eraseOp(op); + return success(); +} } // namespace TTOpConverters diff --git a/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp b/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp index e81c1fe..7bdfb2c 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/TritonToLinalgPass.cpp @@ -6,6 +6,10 @@ #include "TritonToLinalg/UseAnalysis.h" #include "Utils/InterleaveOptimization.h" #include "Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -24,13 +28,17 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LogicalResult.h" +#include #include #include @@ -89,52 +97,155 @@ void TritonToLinalgPass::addProgramInfo(triton::FuncOp func, } } -static void setBlockArgumentAttr(BlockArgument blockArg, triton::FuncOp func, TensorKind tensorKind) -{ - unsigned argIdx = blockArg.getArgNumber(); - auto existingAttr = func.getArgAttrOfType(argIdx, "tt.tensor_kind"); - TensorKind oldVal = existingAttr ? static_cast(existingAttr.getInt()) : TensorKind::NONE; - - TensorKind finalVal = tensorKind; - if ((oldVal == TensorKind::INPUT && tensorKind == TensorKind::OUTPUT) || - (oldVal == TensorKind::OUTPUT && tensorKind == TensorKind::INPUT)) { - finalVal = TensorKind::INPUT_OUTPUT; - } else if (oldVal == TensorKind::INPUT_OUTPUT) { - finalVal = oldVal; - } +static void setBlockArgumentAttr(BlockArgument blockArg, triton::FuncOp func, + TensorKind tensorKind) { + unsigned argIdx = blockArg.getArgNumber(); + auto existingAttr = + func.getArgAttrOfType(argIdx, "tt.tensor_kind"); + TensorKind oldVal = existingAttr + ? static_cast(existingAttr.getInt()) + : TensorKind::NONE; + + TensorKind finalVal = tensorKind; + if ((oldVal == TensorKind::INPUT && tensorKind == TensorKind::OUTPUT) || + (oldVal == TensorKind::OUTPUT && tensorKind == TensorKind::INPUT)) { + finalVal = TensorKind::INPUT_OUTPUT; + } else if (oldVal == TensorKind::INPUT_OUTPUT) { + finalVal = oldVal; + } - func.setArgAttr(argIdx, "tt.tensor_kind", - IntegerAttr::get(IntegerType::get(func.getContext(), INT_BIT_WIDTH), static_cast(finalVal))); + func.setArgAttr( + argIdx, "tt.tensor_kind", + IntegerAttr::get(IntegerType::get(func.getContext(), INT_BIT_WIDTH), + static_cast(finalVal))); } template -void TritonToLinalgPass::addTensorKindToArguments(OpTy op, triton::FuncOp func, TensorKind tensorKind) -{ - Value ptr = op.getPtr(); - if (!ptr) - return; - - Value cur = ptr; - llvm::SmallPtrSet visited; - // 回溯 def-use 链,找到起源 BlockArgument - while (visited.insert(cur).second) { - // 如果是 BlockArgument,则尝试设置属性 - if (auto blockArg = dyn_cast(cur)) { - if (blockArg.getOwner() == &func.getBody().front()) { - auto type = blockArg.getType(); - // 检查是否是 triton::PointerType - if (!isa(type)) - break; - setBlockArgumentAttr(blockArg, func, tensorKind); - break; - } - } - - Operation *defOp = cur.getDefiningOp(); - if (!defOp) - break; - cur = defOp->getOperand(0); +void TritonToLinalgPass::addTensorKindToArguments(OpTy op, triton::FuncOp func, + TensorKind tensorKind) { + Value ptr = op.getPtr(); + if (!ptr) + return; + + Value cur = ptr; + llvm::SmallPtrSet visited; + // 回溯 def-use 链,找到起源 BlockArgument + while (visited.insert(cur).second) { + // 如果是 BlockArgument,则尝试设置属性 + if (auto blockArg = dyn_cast(cur)) { + if (blockArg.getOwner() == &func.getBody().front()) { + auto type = blockArg.getType(); + // 检查是否是 triton::PointerType + if (!isa(type)) + break; + setBlockArgumentAttr(blockArg, func, tensorKind); + break; + } } + + Operation *defOp = cur.getDefiningOp(); + if (!defOp) + break; + cur = defOp->getOperand(0); + } +} + +// Here consider that every block's terminator must be either triton::ReturnOp +// or cf::CondBranchOp. And when meet cf::CondBranchOp, its both branch will +// reach different return finally. +LogicalResult +TritonToLinalgPass::convertMultipleBlockControlFlow(Operation *funcOp, + OpBuilder &builder) { + assert(isa(funcOp)); + + SmallVector candidate; + SmallVector eraseBlocks; + for (Block &block : dyn_cast(funcOp).getBody()) { + auto curTerminator = block.getTerminator(); + if (isa(curTerminator)) + candidate.push_back(curTerminator); + else if (isa(curTerminator)) + assert(candidate.size() > 0); + else + return failure(); + + if (!block.isEntryBlock()) + eraseBlocks.push_back(&block); + } + + assert(!candidate.empty()); + + llvm::BitVector visitFlag(candidate.size(), false); + + // Recursive function to convert all cf::CondBranchOp to scf::IfOp + std::function convertToSCF = + [&](Operation *op, Operation *insertPosOp) -> void { + auto condBranchOp = dyn_cast_if_present(op); + auto iter = llvm::find(candidate, condBranchOp); + assert(condBranchOp && iter != candidate.end()); + visitFlag.set(iter - candidate.begin()); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(insertPosOp); + + // Well, here force to destory original control flow + builder.create( + condBranchOp->getLoc(), condBranchOp.getCondition(), + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + SmallVector movedOps = llvm::map_to_vector( + condBranchOp.getTrueDest()->without_terminator(), + [](Operation &op) { return &op; }); + for (auto *innerOp : movedOps) { + innerOp->moveBefore(builder.getInsertionBlock(), + builder.getInsertionPoint()); + } + + auto blockTerm = condBranchOp.getTrueDest()->getTerminator(); + if (isa(blockTerm)) { + assert(!movedOps.empty()); + convertToSCF(blockTerm, movedOps.back()); + } + + builder.create(loc); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + SmallVector movedOps = llvm::map_to_vector( + condBranchOp.getFalseDest()->without_terminator(), + [](Operation &op) { return &op; }); + for (auto *innerOp : movedOps) { + innerOp->moveBefore(builder.getInsertionBlock(), + builder.getInsertionPoint()); + } + + auto blockTerm = condBranchOp.getFalseDest()->getTerminator(); + if (isa(blockTerm)) { + assert(!movedOps.empty()); + convertToSCF(blockTerm, movedOps.back()); + } + + builder.create(loc); + }); + }; + + Block::iterator insertOp(candidate.front()); + --insertOp; + convertToSCF(candidate.front(), &(*insertOp)); + + if (!visitFlag.all()) + return failure(); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(candidate.front()); + builder.create(candidate.front()->getLoc()); + + for (Operation *eachTerm : candidate) + eachTerm->erase(); + for (Block *block : llvm::reverse(eraseBlocks)) + block->erase(); + + return success(); } void TritonToLinalgPass::convertTTFunc(triton::FuncOp func, @@ -225,123 +336,25 @@ void TritonToLinalgPass::convertTTFunc(triton::FuncOp func, IRMapping map; funcBody.cloneInto(&funcFuncBody, map); + // Only when there exists any cf::CondBranchOp, the func region would have + // multiple blocks. And here's to convert cf back to scf. + if (!funcFuncBody.hasOneBlock()) { + if (failed(convertMultipleBlockControlFlow(funcFunc, builder))) { + llvm_unreachable("Encounter unsupported control flow"); + } + } + + assert(funcFuncBody.hasOneBlock()); + for (Block &block : funcFuncBody.getBlocks()) { auto term = block.getTerminator(); - if (auto condBranch = dyn_cast(term)) { - SmallVector trueOps; - SmallVector falseOps; - bool trueHasReturn = false; - bool falseHasReturn = false; - for (Operation &op : condBranch.getTrueDest()->without_terminator()) { - if (dyn_cast(&op)) { - transformNestedIfElse(op, builder); - } - trueOps.push_back(&op); - if (isa(op)) { - trueHasReturn = true; - } - } - for (Operation &op : condBranch.getFalseDest()->without_terminator()) { - if (dyn_cast(&op)) { - transformNestedIfElse(op, builder); - } - falseOps.push_back(&op); - if (isa(op)) { - falseHasReturn = true; - } - } - builder.setInsertionPoint(condBranch); - auto ifOp = builder.create ( - condBranch.getLoc(), - condBranch.getCondition(), - [&](OpBuilder &thenBuilder, Location loc) { - for (Operation *op : trueOps) { - op->moveBefore(thenBuilder.getInsertionBlock(), thenBuilder.getInsertionPoint()); - } - if (!trueHasReturn) { - thenBuilder.create(loc); - } - }, - [&](OpBuilder &elseBuilder, Location loc) { - for (Operation *op : falseOps) { - op->moveBefore(elseBuilder.getInsertionBlock(), elseBuilder.getInsertionPoint()); - } - if (!falseHasReturn) { - elseBuilder.create(loc); - } - } - ); - if (!trueHasReturn && !falseHasReturn) { - Block *afterBlock = condBranch->getBlock(); - if (!afterBlock->empty()) { - builder.setInsertionPointToEnd(afterBlock); - builder.create(condBranch.getLoc()); - } - } - condBranch.erase(); - condBranch.getTrueDest()->erase(); - condBranch.getFalseDest()->erase(); - } else { - builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); - term->erase(); - } + builder.setInsertionPoint(term); + builder.create(func.getLoc(), term->getOperands()); + term->erase(); } func.erase(); } -// 处理嵌套的if/else -void TritonToLinalgPass::transformNestedIfElse(Operation &op, OpBuilder &builder) { - auto nestedBranch = dyn_cast(&op); - SmallVector nestedTrueOps; - SmallVector nestedFalseOps; - bool nestedTrueHasReturn = false; - bool nestedFalseHasReturn = false; - - for (Operation &op : nestedBranch.getTrueDest()->without_terminator()) { - if (dyn_cast(&op)) { - transformNestedIfElse(op, builder); - } - nestedTrueOps.push_back(&op); - if (isa(op)) { - nestedTrueHasReturn = true; - } - } - for (Operation &op : nestedBranch.getFalseDest()->without_terminator()) { - if (dyn_cast(&op)) { - transformNestedIfElse(op, builder); - } - nestedFalseOps.push_back(&op); - if (isa(op)) { - nestedFalseHasReturn = true; - } - } - builder.setInsertionPoint(nestedBranch); - auto nestedIfOp = builder.create( - nestedBranch.getLoc(), - nestedBranch.getCondition(), - [&](OpBuilder &thenBuilder, Location loc) { - for (Operation *op : nestedTrueOps) { - op->moveBefore(thenBuilder.getInsertionBlock(), thenBuilder.getInsertionPoint()); - } - if (!nestedTrueHasReturn) { - thenBuilder.create(loc); - } - }, - [&](OpBuilder &elseBuilder, Location loc) { - for (Operation *op : nestedFalseOps) { - op->moveBefore(elseBuilder.getInsertionBlock(), elseBuilder.getInsertionPoint()); - } - if (!nestedTrueHasReturn) { - elseBuilder.create(loc); - } - } - ); - nestedBranch.erase(); - nestedBranch.getTrueDest()->erase(); - nestedBranch.getFalseDest()->erase(); -} - void TritonToLinalgPass::addDynamicLegal( ConversionTarget &target, TritonTypeConverter &tritonTypeConverter) { target.addLegalDialect< @@ -427,6 +440,8 @@ void TritonToLinalgPass::populateTritonToLinalgCanonicalizationPatterns(RewriteP patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + // ToDo: Here AtomicRMW explains both load and store semantic when it + // returns value, while this should only aims to store operation. patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); @@ -521,6 +536,8 @@ void TritonToLinalgPass::populateTritonToLinalgConversionPatterns( patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); if (!this->namedOps) { linalg::populateElementwiseToLinalgConversionPatterns(patterns); @@ -565,12 +582,16 @@ void TritonToLinalgPass::runOnOperation() { // 1.标准化 LoadStore ScalarStoreCanonicalizer this->populateTritonToLinalgCanonicalizationPatterns(canonicalizerPatterns); + // As there exists def-use dependency in canonicalization, here would enable + // up to down traversal order + GreedyRewriteConfig config = GreedyRewriteConfig(); + config.useTopDownTraversal = true; if (failed(applyPatternsAndFoldGreedily(moduleOp, - std::move(canonicalizerPatterns)))) { + std::move(canonicalizerPatterns), + config))) { moduleOp->emitError("failed to apply Canonicalizer Patterns"); signalPassFailure(); } - // 2.使用分析 moduleOp.walk([this](triton::FuncOp op) { if (failed(runUseAnalysis(op))) { @@ -673,16 +694,17 @@ void TritonToLinalgPass::runOnOperation() { auto context = func.getContext(); constexpr int64_t syncBlockLockArgIdx = 0; - NamedAttribute syncBlockLockArgAttr(StringAttr::get(context, "syncBlockLock"), - UnitAttr::get(context)); + NamedAttribute syncBlockLockArgAttr( + StringAttr::get(context, "syncBlockLock"), UnitAttr::get(context)); MemRefType syncBlockLockArgType = MemRefType::get(SmallVector(1, ShapedType::kDynamic), IntegerType::get(context, 8)); - func.insertArgument(syncBlockLockArgIdx, // argIndex - syncBlockLockArgType, // argType + func.insertArgument(syncBlockLockArgIdx, // argIndex + syncBlockLockArgType, // argType nullptr, func->getLoc()); // dicAttr func->setAttr("SyncBlockLockArgIdx", - IntegerAttr::get(IntegerType::get(&getContext(), 64), 0)); // 64: 64位整型 + IntegerAttr::get(IntegerType::get(&getContext(), 64), + 0)); // 64: 64位整型 constexpr int64_t workspaceArgIdx = 1; MemRefType workspaceArgType = @@ -695,7 +717,8 @@ void TritonToLinalgPass::runOnOperation() { /*argType*/ workspaceArgType, /*dicAttr*/ nullptr, func->getLoc()); func->setAttr("WorkspaceArgIdx", - IntegerAttr::get(IntegerType::get(&getContext(), 64), 1)); // 64: 64位整型 + IntegerAttr::get(IntegerType::get(&getContext(), 64), + 1)); // 64: 64位整型 } // Fix the Location info diff --git a/triton_patch/python/triton_patch/language/core.py b/triton_patch/python/triton_patch/language/core.py index 8c2f667..2ad049b 100644 --- a/triton_patch/python/triton_patch/language/core.py +++ b/triton_patch/python/triton_patch/language/core.py @@ -1,5 +1,5 @@ import os -from typing import List +from typing import List, Set, Dict from triton._C.libtriton import ir from triton.language import semantic as real_semantic @@ -10,11 +10,13 @@ from triton.language.core import ( builtin, constexpr, dtype as real_dtype, + int8, float32, tensor, check_bit_width, _unwrap_if_constexpr, range, + inline_asm_elementwise, ) # from triton.language.core import _unwrap_if_constexpr, _unwrap_shape @@ -337,3 +339,34 @@ def compile_hint(ptr, hint_name, hint_val=None, _builder=None): assert isinstance(hint_name, str), f"hint name: {hint_name} is not string" hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val semantic.compile_hint(ptr, hint_name, hint_val, _builder) + +@builtin +def explicit_intra_block_sync(mode, block_arch, producer_pipeline, _builder=None): + PARAM_WHITELISTS: Dict[str, Set[str]] = { + "mode": {"set", "wait"}, + "block_arch": {"cube", "vector"}, + "producer_pipeline": {"pipe_fix", "pipe_mte3"} + } + + mode = _constexpr_to_value(mode) + block_arch = _constexpr_to_value(block_arch) + producer_pipeline = _constexpr_to_value(producer_pipeline) + def validate_param(param_name: str, value: str): + whitelist = PARAM_WHITELISTS[param_name] + + if value not in whitelist: + valid_options = ", ".join(sorted(whitelist)) + raise ValueError( + f"arg '{param_name}' not supported value: '{value}'" + f"The supported value are: {valid_options}" + ) + + validate_param("mode", mode) + validate_param("block_arch", block_arch) + validate_param("producer_pipeline", producer_pipeline) + + asm_alias = f"ascend_block_sync_{mode}_{block_arch}_{producer_pipeline}" + return inline_asm_elementwise(asm_alias, + "", [], dtype=int8, is_pure=False, pack=1, # stub args + _builder=_builder) + diff --git a/triton_patch/python/triton_patch/testing.py b/triton_patch/python/triton_patch/testing.py index ca142a7..16078e5 100644 --- a/triton_patch/python/triton_patch/testing.py +++ b/triton_patch/python/triton_patch/testing.py @@ -622,6 +622,7 @@ from .triton_patch.language.core import ( __rshift__, parallel, compile_hint, + explicit_intra_block_sync ) from .triton_patch.language.standard import flip, sigmoid, softmax from .triton_patch.language.math import ( @@ -674,6 +675,7 @@ language.tensor.__rshift__ = __rshift__ language.trans = trans language.parallel = parallel language.compile_hint = compile_hint +language.explicit_intra_block_sync = explicit_intra_block_sync # from .triton_patch.language.core import dtype, pointer_type, block_type, function_type # language.core.dtype = dtype -- Gitee