diff --git a/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h b/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h index f8b9f0566d283343a32ad87c20473163e1307ae2..2b922c2b8075c5d63b31cac011bc7f0dc1c6aa08 100644 --- a/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h +++ b/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h @@ -66,6 +66,7 @@ public: Value getPtr() const; Value getOffset() const; bool isScalarLike() const; + bool isNegativeFlag() const; SmallVector &getStructuredRef(); const SmallVector &getStructured() const; int getRank() const; @@ -79,7 +80,7 @@ public: void setStructured(ArrayRef structured); void setStructured(const PtrOffsetInfo &other); void setScalarLike(bool scalarLike); - + void setNegativeFlag(bool negativeFlag); bool isStructured(int dim) const; bool isStructured() const; bool isUnstructured() const; @@ -90,7 +91,7 @@ private: Value offset; bool scalarLike = false; - + bool negativeFlag = false; SmallVector structured; }; diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp index 4f2afb826222131011631e3747f62f82487316e1..096639a75ebd73f1f5a5eeec9bf88db4751980ac 100644 --- a/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp +++ b/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp @@ -48,12 +48,14 @@ PtrOffsetInfo &PtrOffsetInfo::operator=(const PtrOffsetInfo &other) { setOffset(other.getOffset()); setStructured(other.getStructured()); setScalarLike(other.isScalarLike()); + setNegativeFlag(other.isNegativeFlag()); return *this; } Value PtrOffsetInfo::getPtr() const { return this->ptr; } Value PtrOffsetInfo::getOffset() const { return this->offset; } bool PtrOffsetInfo::isScalarLike() const { return this->scalarLike; } +bool PtrOffsetInfo::isNegativeFlag() const { return this->negativeFlag; } SmallVector &PtrOffsetInfo::getStructuredRef() { return this->structured; } const SmallVector &PtrOffsetInfo::getStructured() const { @@ -101,6 +103,10 @@ void PtrOffsetInfo::setStructured(const PtrOffsetInfo &other) { this->setStructured(other.getStructured()); } +void PtrOffsetInfo::setNegativeFlag(bool negativeFlag) { + this->negativeFlag = negativeFlag; +} + void PtrOffsetInfo::setScalarLike(bool scalarLike) { this->scalarLike = scalarLike; } @@ -148,6 +154,7 @@ PtrOffsetInfo combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs) { structuredRef.resize(lhs.getRank()); for (size_t i = 0; i < structuredRef.size(); i++) structuredRef[i] = lhs.isStructured(i) && rhs.isStructured(i); + info.setNegativeFlag(lhs.isNegativeFlag() || rhs.isNegativeFlag()); return info; } @@ -212,6 +219,7 @@ void parse(Value operand, const Location &loc, RewriterBase &rewriter, for (auto s : data.getStructuredRef()) os << s; os << "\n"; + os << "FNparse: " << operand << " ,isNegativeFlag: " << data.isNegativeFlag() << "\n"; }); } @@ -385,6 +393,12 @@ void parseAddPtr(triton::AddPtrOp op, const Location &loc, for (size_t i = 0; i < offsetStructured.size(); i++) os << offsetStructured[i]; os << "\n"; + os << "[parseAddPtr] offsetOffsetInfo.isNegativeFlag(): "; + os << offsetOffsetInfo.isNegativeFlag(); + os << "\n"; + os << "[parseAddPtr] ptrOffsetInfo.isNegativeFlag(): "; + os << ptrOffsetInfo.isNegativeFlag(); + os << "\n"; }); } @@ -416,6 +430,7 @@ void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter, dstOffsetInfo.setStructured(dstType.getRank()); dstOffsetInfo.setScalarLike(true); + dstOffsetInfo.setNegativeFlag(srcOffsetInfo.isNegativeFlag()); offsetMap[dst] = dstOffsetInfo; } @@ -438,6 +453,13 @@ void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter, dstOffsetInfo.setStructured(lhsStructured.size()); else dstOffsetInfo.setUnstructured(lhsStructured.size()); + + if (isa(op.getOperation())) { + dstOffsetInfo.setNegativeFlag(true); + } else { + dstOffsetInfo.setNegativeFlag(lhsOffsetInfo.isNegativeFlag() || + rhsOffsetInfo.isNegativeFlag()); + } offsetMap[dst] = dstOffsetInfo; } @@ -467,15 +489,118 @@ void parseIndexCast(arith::IndexCastOp op, const Location &loc, offsetMap[dst] = offsetMap.at(src); } +template +bool isConstantNegative(AttrTy attr, TypeTy type) { + if constexpr (std::is_same_v && + std::is_same_v) { + return attr.getInt() < 0; + } else if constexpr (std::is_same_v && + std::is_same_v) { + return attr.getValueAsDouble() < 0.0; + } else if constexpr(std::is_same_v && + std::is_same_v) { + return attr.getInt() < 0; + } else if constexpr (std::is_base_of_v && + std::is_base_of_v) { + auto tensorType = mlir::cast(type); + auto elemType = tensorType.getElementType(); + + if (auto denseIntAttr = dyn_cast(attr)) { + if (auto intElemType = dyn_cast(elemType)) { + for (auto elemVal : denseIntAttr.template getValues()) { + auto elemAttr = mlir::IntegerAttr::get(intElemType, elemVal); + if (isConstantNegative(elemAttr, intElemType)) { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n"; + }); + return true; + } + } + return false; + } + } + + else if (auto denseFloatAttr = dyn_cast(attr)) { + if (auto floatElemType = dyn_cast(elemType)) { + for (auto elemVal : denseFloatAttr.template getValues()) { + auto elemAttr = mlir::FloatAttr::get(floatElemType, elemVal); + if (isConstantNegative(elemAttr, floatElemType)) { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n"; + }); + return true; + } + } + return false; + } + } + + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO: Unsupported tensor elemType: " << elemType + << ",tensorType:" << tensorType << "\n"; + }); + return false; + } else { + LLVM_DEBUG({ + llvm::dbgs() << DEBUG_TYPE << " PCO, Unsupported: attr: " << attr + << ", type: " << type << " \n"; + }); + return false; + } +} + template void parseConstantOp(ConstOpTy dst, const Location &loc, RewriterBase &rewriter, llvm::DenseMap &offsetMap) { - // Set constant offset map - offsetMap[dst] = PtrOffsetInfo(); - offsetMap[dst].setScalarLike(true); - if (auto tensorType = dyn_cast(dst->getResult(0).getType())) - offsetMap[dst].setStructured(tensorType.getRank()); + mlir::Operation *opPtr = nullptr; + if constexpr (std::is_pointer_v) { + if (dst != nullptr) { + opPtr = dst->getOperation(); + } + } else { + opPtr = dst.getOperation(); + } + + mlir::Value opResult = opPtr->getResult(0); + + offsetMap[opResult] = PtrOffsetInfo(); + offsetMap[opResult].setScalarLike(true); + if (auto tensorType = mlir::dyn_cast(opResult.getType())) { + offsetMap[opResult].setStructured(tensorType.getRank()); + } + + auto constantOp = mlir::dyn_cast(opPtr); + if (!constantOp) { + LLVM_DEBUG({ + llvm::dbgs() << "Warning: Non-ConstantOp (" << opPtr->getName() + << ") passed to parseConstantOp\n"; + }); + return; + } + + mlir::Attribute constAttr = constantOp.getValue(); + mlir::Type resultType = opResult.getType(); + + if (auto intType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, intType)); + } + } else if (auto floatType = dyn_cast(resultType)) { + if (auto floatAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(floatAttr, floatType)); + } + } else if (auto indexType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, indexType)); + } + } else if (auto indexType = dyn_cast(resultType)) { + if (auto intAttr = dyn_cast(constAttr)) { + offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, indexType)); + } + } else { + llvm_unreachable("PCO: Non-ConstantOp passed to parseConstantOp \n"); + } } void parseMakeRange(triton::MakeRangeOp op, const Location &loc, @@ -518,6 +643,7 @@ void parseBitcast(triton::BitcastOp op, const Location &loc, offsetMap[dst] = PtrOffsetInfo(srcStructured); } offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); } void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, @@ -529,6 +655,8 @@ void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(offsetMap[ptr].isScalarLike()); + offsetMap[dst].setNegativeFlag(offsetMap[ptr].isNegativeFlag()); + auto &dstStructured = offsetMap[dst].getStructuredRef(); auto tensorType = dyn_cast(dst.getType()); if (!tensorType) return; @@ -554,6 +682,8 @@ void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(lhsScalarLike && rhsScalarLike); + offsetMap[dst].setNegativeFlag(lhsOffsetInfo.isNegativeFlag() + || rhsOffsetInfo.isNegativeFlag()); SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); dstStructured.resize(maxSize); for (size_t i = 0; i < maxSize; i++) @@ -585,6 +715,7 @@ void parseBroadcast(triton::BroadcastOp op, const Location &loc, // Set broadcast offset map offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); if (srcOffsetInfo.getPtr()) { RewriterBase::InsertionGuard guard(rewriter); @@ -619,6 +750,7 @@ void parseExpandDims(triton::ExpandDimsOp op, const Location &loc, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr()); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); if (srcOffsetInfo.getPtr()) { RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); @@ -661,6 +793,9 @@ void parseClampF(triton::ClampFOp op, const Location &loc, offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike() && minOffsetInfo.isScalarLike() && maxOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag() || + minOffsetInfo.isNegativeFlag() || + maxOffsetInfo.isNegativeFlag()); auto dstType = dyn_cast(dst.getType()); if (!dstType) return; @@ -693,6 +828,8 @@ void parseSelect(arith::SelectOp op, const Location &loc, offsetMap[dst].setScalarLike(conditionScalarLike && trueValueScalarLike && falseValueScalarLike); auto dstType = dyn_cast(dst.getType()); + offsetMap[dst].setNegativeFlag(trueValueOffsetInfo.isNegativeFlag() || + falseValueOffsetInfo.isNegativeFlag()); if (!dstType) return; if (offsetMap[dst].isScalarLike()) @@ -712,6 +849,7 @@ void parseFPToSI(arith::FPToSIOp op, const Location &loc, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); auto dstType = dyn_cast(dst.getType()); if (!dstType) return; @@ -732,6 +870,7 @@ void parseSIToFP(arith::SIToFPOp op, const Location &loc, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); auto dstType = dyn_cast(dst.getType()); if (!dstType) return; @@ -790,6 +929,7 @@ void parseReduce(triton::ReduceOp op, const Location &loc, auto dstType = dyn_cast(dst.getType()); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); if (!dstType) return; SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); @@ -815,6 +955,7 @@ void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc, auto dstType = dyn_cast(dst.getType()); offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); if (!dstType) return; SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); @@ -850,6 +991,7 @@ void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter, // Set if offset map offsetMap[dst] = PtrOffsetInfo(); offsetMap[dst].setScalarLike(dstIsScalar); + offsetMap[dst].setNegativeFlag(thenOffsetInfo.isNegativeFlag()); SmallVector &dstStructured = offsetMap[dst].getStructuredRef(); dstStructured.resize(thenStructured.size()); for (size_t i = 0; i < dstStructured.size(); i++) @@ -904,6 +1046,10 @@ void parseIntToPtr(triton::IntToPtrOp op, const Location &loc, auto dst = op.getResult(); offsetMap[dst] = PtrOffsetInfo(dst); offsetMap[dst].setScalarLike(true); + + parse(op.getSrc(), op.getLoc(), rewriter, offsetMap); + auto srcOffsetInfo = offsetMap.at(op.getSrc()); + offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag()); } } // namespace triton diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp index c0368fe3f568ac6e743e2b32942510f2f564c809..7c2a1cd2cd670603e904d51f8548d66d2f4133ab 100644 --- a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp +++ b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp @@ -162,9 +162,41 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite( auto &os = llvm::dbgs(); os << "Converting " << op->getName() << "\n"; os << op << "\n"; - os << ptrOffsetInfo.isStructured() << "\n"; + os << "parentOp is : " << *op->getParentOp() << "\n"; + os << "isStructured = " << ptrOffsetInfo.isStructured() \ + << ",isScalarLike = " << ptrOffsetInfo.isScalarLike() \ + << ", isNegativeFlag = " << ptrOffsetInfo.isNegativeFlag() << "\n"; }); + bool flag = false; + if (ptrOffsetInfo.isNegativeFlag()) { + auto opoffset = ptrOffsetInfo.getOffset(); + auto opoffsetType = opoffset.getType(); + flag = true; + Value constantZero; + if (auto tensorType = dyn_cast(opoffsetType)) { + constantZero = rewriter.create( + loc, rewriter.getZeroAttr(tensorType)); + } else { + constantZero = rewriter.create( + loc, 0, opoffset.getType()); + } + Value cmpResult = rewriter.create( + loc, arith::CmpIPredicate::sge, ptrOffsetInfo.getOffset(), constantZero); + + mlir::StringAttr assertMsg = rewriter.getStringAttr( + "AddPtr offset (from subi) must be >= 0"); + + rewriter.create(loc, cmpResult, assertMsg); + } + + if (ptrOffsetInfo.isStructured() && !ptrOffsetInfo.isScalarLike()) { + if (flag) + return success(); + else + return failure(); + } + if constexpr (std::is_same_v) if (ptrOffsetInfo.isScalarLike()) { splatAndLoadScenario(op, ptrOffsetInfo.getRank(), rewriter); @@ -457,4 +489,4 @@ void TritonToUnstructurePass::getDependentDialects( std::unique_ptr> triton::createTritonToUnstructurePass() { return std::make_unique(); -} \ No newline at end of file +} diff --git a/ascend/triton-adapter/safe_compile.cmake b/ascend/triton-adapter/safe_compile.cmake index 2b29614c72ae10bf1ea98a6af6daf5099764f457..f39dda31284f2b1d1cdbf9bf49c3b7db7efbce69 100644 --- a/ascend/triton-adapter/safe_compile.cmake +++ b/ascend/triton-adapter/safe_compile.cmake @@ -7,4 +7,4 @@ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,now -pie -s") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,now -s") set(CMAKE_SKIP_RPATH TRUE) set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) -unset(CMAKE_INSTALL_RPATH) \ No newline at end of file +unset(CMAKE_INSTALL_RPATH)