From da2b71fffc2dfbf25da2e2cf974896a2129c78e2 Mon Sep 17 00:00:00 2001
From: w00442059
Date: Tue, 16 Sep 2025 16:16:15 +0800
Subject: [PATCH 1/3] fix(offset): add assert for offset from subi

Add an assertion to check whether an offset is derived from subi:
- add a negativeFlag to PtrOffsetInfo and propagate it during parsing
- add RankedTensorType support for constant parsing
- parse the source and set negativeFlag in parseIntToPtr
- treat arith::SubIOp and arith::SubFOp as sources of negativeFlag
---
 .../TritonToUnstructure/OffsetAnalysis.h      |   5 +-
 .../TritonToUnstructure/OffsetAnalysis.cpp    | 155 +++++++++++++++++-
 .../UnstructureConversionPass.cpp             |  34 +++-
 ascend/triton-adapter/safe_compile.cmake      |   2 +-
 4 files changed, 185 insertions(+), 11 deletions(-)

diff --git a/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h b/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h
index f8b9f05..2b922c2 100644
--- a/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h
+++ b/ascend/triton-adapter/include/TritonToUnstructure/OffsetAnalysis.h
@@ -66,6 +66,7 @@ public:
   Value getPtr() const;
   Value getOffset() const;
   bool isScalarLike() const;
+  bool isNegativeFlag() const;
   SmallVector<bool> &getStructuredRef();
   const SmallVector<bool> &getStructured() const;
   int getRank() const;
@@ -79,7 +80,7 @@ public:
   void setStructured(ArrayRef<bool> structured);
   void setStructured(const PtrOffsetInfo &other);
   void setScalarLike(bool scalarLike);
-
+  void setNegativeFlag(bool negativeFlag);
   bool isStructured(int dim) const;
   bool isStructured() const;
   bool isUnstructured() const;
@@ -90,7 +91,7 @@ private:
   Value offset;
 
   bool scalarLike = false;
-
+  bool negativeFlag = false;
   SmallVector<bool> structured;
 };

diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp
index 8fb98c2..dff0d86 100644
--- a/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp
+++ b/ascend/triton-adapter/lib/TritonToUnstructure/OffsetAnalysis.cpp
@@ -48,12 +48,14 @@ PtrOffsetInfo &PtrOffsetInfo::operator=(const PtrOffsetInfo &other) {
   setOffset(other.getOffset());
   setStructured(other.getStructured());
   setScalarLike(other.isScalarLike());
+  setNegativeFlag(other.isNegativeFlag());
   return *this;
 }
 
 Value PtrOffsetInfo::getPtr() const { return this->ptr; }
 Value PtrOffsetInfo::getOffset() const { return this->offset; }
 bool PtrOffsetInfo::isScalarLike() const { return this->scalarLike; }
+bool PtrOffsetInfo::isNegativeFlag() const { return this->negativeFlag; }
 
 SmallVector<bool> &PtrOffsetInfo::getStructuredRef() { return this->structured; }
 const SmallVector<bool> &PtrOffsetInfo::getStructured() const {
@@ -101,6 +103,10 @@ void PtrOffsetInfo::setStructured(const PtrOffsetInfo &other) {
   this->setStructured(other.getStructured());
 }
 
+void PtrOffsetInfo::setNegativeFlag(bool negativeFlag) {
+  this->negativeFlag = negativeFlag;
+}
+
 void PtrOffsetInfo::setScalarLike(bool scalarLike) {
   this->scalarLike = scalarLike;
 }
@@ -148,6 +154,7 @@ PtrOffsetInfo combineInfo(const PtrOffsetInfo &lhs, const PtrOffsetInfo &rhs) {
   structuredRef.resize(lhs.getRank());
   for (size_t i = 0; i < structuredRef.size(); i++)
     structuredRef[i] = lhs.isStructured(i) && rhs.isStructured(i);
+  info.setNegativeFlag(lhs.isNegativeFlag() || rhs.isNegativeFlag());
   return info;
 }
@@ -212,6 +219,7 @@ void parse(Value operand, const Location &loc, RewriterBase &rewriter,
     for (auto s : data.getStructuredRef())
       os << s;
     os << "\n";
+    os << "FNparse: " << operand << ", isNegativeFlag: " << data.isNegativeFlag() << "\n";
   });
 }
 
@@ -385,6 +393,12 @@ void parseAddPtr(triton::AddPtrOp op, const Location &loc,
   for (size_t i = 0; i < offsetStructured.size(); i++)
     os << offsetStructured[i];
   os << "\n";
+  os << "[parseAddPtr] offsetOffsetInfo.isNegativeFlag(): ";
+  os << offsetOffsetInfo.isNegativeFlag();
+  os << "\n";
+  os << "[parseAddPtr] ptrOffsetInfo.isNegativeFlag(): ";
+  os << ptrOffsetInfo.isNegativeFlag();
+  os << "\n";
   });
 }
 
@@ -417,6 +431,7 @@ void parseSplat(triton::SplatOp op, const Location &loc, RewriterBase &rewriter,
 
   dstOffsetInfo.setStructured(dstType.getRank());
   dstOffsetInfo.setScalarLike(true);
+  dstOffsetInfo.setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   offsetMap[dst] = dstOffsetInfo;
 }
 
@@ -439,6 +454,13 @@ void parseBinaryOp(BinOpTy op, const Location &loc, RewriterBase &rewriter,
     dstOffsetInfo.setStructured(lhsStructured.size());
   else
     dstOffsetInfo.setUnstructured(lhsStructured.size());
+
+  if (isa<arith::SubIOp, arith::SubFOp>(op.getOperation())) {
+    dstOffsetInfo.setNegativeFlag(true);
+  } else {
+    dstOffsetInfo.setNegativeFlag(lhsOffsetInfo.isNegativeFlag() ||
+                                  rhsOffsetInfo.isNegativeFlag());
+  }
   offsetMap[dst] = dstOffsetInfo;
 }
@@ -468,15 +490,118 @@ void parseIndexCast(arith::IndexCastOp op, const Location &loc,
   offsetMap[dst] = offsetMap.at(src);
 }
 
+template <typename AttrTy, typename TypeTy>
+bool isConstantNegative(AttrTy attr, TypeTy type) {
+  if constexpr (std::is_same_v<AttrTy, mlir::IntegerAttr> &&
+                std::is_same_v<TypeTy, mlir::IntegerType>) {
+    return attr.getInt() < 0;
+  } else if constexpr (std::is_same_v<AttrTy, mlir::FloatAttr> &&
+                       std::is_same_v<TypeTy, mlir::FloatType>) {
+    return attr.getValueAsDouble() < 0.0;
+  } else if constexpr (std::is_same_v<AttrTy, mlir::IntegerAttr> &&
+                       std::is_same_v<TypeTy, mlir::IndexType>) {
+    return attr.getInt() < 0;
+  } else if constexpr (std::is_base_of_v<mlir::Attribute, AttrTy> &&
+                       std::is_base_of_v<mlir::Type, TypeTy>) {
+    auto tensorType = mlir::cast<RankedTensorType>(type);
+    auto elemType = tensorType.getElementType();
+
+    if (auto denseIntAttr = dyn_cast<DenseIntElementsAttr>(attr)) {
+      if (auto intElemType = dyn_cast<IntegerType>(elemType)) {
+        for (auto elemVal : denseIntAttr.template getValues<llvm::APInt>()) {
+          auto elemAttr = mlir::IntegerAttr::get(intElemType, elemVal);
+          if (isConstantNegative(elemAttr, intElemType)) {
+            LLVM_DEBUG({
+              llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n";
+            });
+            return true;
+          }
+        }
+        return false;
+      }
+    } else if (auto denseFloatAttr = dyn_cast<DenseFPElementsAttr>(attr)) {
+      if (auto floatElemType = dyn_cast<FloatType>(elemType)) {
+        for (auto elemVal : denseFloatAttr.template getValues<llvm::APFloat>()) {
+          auto elemAttr = mlir::FloatAttr::get(floatElemType, elemVal);
+          if (isConstantNegative(elemAttr, floatElemType)) {
+            LLVM_DEBUG({
+              llvm::dbgs() << DEBUG_TYPE << " PCO: Tensor has negative element: " << elemAttr << "\n";
+            });
+            return true;
+          }
+        }
+        return false;
+      }
+    }
+
+    LLVM_DEBUG({
+      llvm::dbgs() << DEBUG_TYPE << " PCO: Unsupported tensor elemType: " << elemType
+                   << ", tensorType: " << tensorType << "\n";
+    });
+    return false;
+  } else {
+    LLVM_DEBUG({
+      llvm::dbgs() << DEBUG_TYPE << " PCO: Unsupported attr: " << attr
+                   << ", type: " << type << "\n";
+    });
+    return false;
+  }
+}
+
 template <typename ConstOpTy>
 void parseConstantOp(ConstOpTy dst, const Location &loc, RewriterBase &rewriter,
                      llvm::DenseMap<Value, PtrOffsetInfo> &offsetMap) {
-  // Set constant offset map
-  offsetMap[dst] = PtrOffsetInfo();
-  offsetMap[dst].setScalarLike(true);
-  if (auto tensorType = dyn_cast<RankedTensorType>(dst->getResult(0).getType()))
-    offsetMap[dst].setStructured(tensorType.getRank());
+  mlir::Operation *opPtr = nullptr;
+  if constexpr (std::is_pointer_v<ConstOpTy>) {
+    if (dst != nullptr) {
+      opPtr = dst->getOperation();
+    }
+  } else {
+    opPtr = dst.getOperation();
+  }
+
+  mlir::Value opResult = opPtr->getResult(0);
+
+  offsetMap[opResult] = PtrOffsetInfo();
+  offsetMap[opResult].setScalarLike(true);
+
+  if (auto tensorType = mlir::dyn_cast<RankedTensorType>(opResult.getType())) {
+    offsetMap[opResult].setStructured(tensorType.getRank());
+  }
+
+  auto constantOp = mlir::dyn_cast<arith::ConstantOp>(opPtr);
+  if (!constantOp) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Warning: Non-ConstantOp (" << opPtr->getName()
+                   << ") passed to parseConstantOp\n";
+    });
+    return;
+  }
+
+  mlir::Attribute constAttr = constantOp.getValue();
+  mlir::Type resultType = opResult.getType();
+
+  if (auto intType = dyn_cast<IntegerType>(resultType)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constAttr)) {
+      offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, intType));
+    }
+  } else if (auto floatType = dyn_cast<FloatType>(resultType)) {
+    if (auto floatAttr = dyn_cast<FloatAttr>(constAttr)) {
+      offsetMap[opResult].setNegativeFlag(isConstantNegative(floatAttr, floatType));
+    }
+  } else if (auto indexType = dyn_cast<IndexType>(resultType)) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(constAttr)) {
+      offsetMap[opResult].setNegativeFlag(isConstantNegative(intAttr, indexType));
+    }
+  } else if (auto tensorType = dyn_cast<RankedTensorType>(resultType)) {
+    if (auto denseAttr = dyn_cast<DenseElementsAttr>(constAttr)) {
+      offsetMap[opResult].setNegativeFlag(isConstantNegative(denseAttr, tensorType));
+    }
+  } else {
+    llvm_unreachable("PCO: unsupported constant result type in parseConstantOp\n");
+  }
 }
 
 void parseMakeRange(triton::MakeRangeOp op, const Location &loc,
@@ -519,6 +644,7 @@ void parseBitcast(triton::BitcastOp op, const Location &loc,
     offsetMap[dst] = PtrOffsetInfo(srcStructured);
   }
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
 }
 
 void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter,
@@ -530,6 +656,7 @@ void parseLoad(triton::LoadOp op, const Location &loc, RewriterBase &rewriter,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(offsetMap[ptr].isScalarLike());
+  offsetMap[dst].setNegativeFlag(offsetMap[ptr].isNegativeFlag());
   auto &dstStructured = offsetMap[dst].getStructuredRef();
   auto tensorType = dyn_cast<RankedTensorType>(dst.getType());
   if (!tensorType)
@@ -556,6 +683,8 @@ void parseMulI(arith::MulIOp op, const Location &loc, RewriterBase &rewriter,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(lhsScalarLike && rhsScalarLike);
+  offsetMap[dst].setNegativeFlag(lhsOffsetInfo.isNegativeFlag() ||
+                                 rhsOffsetInfo.isNegativeFlag());
   SmallVector<bool> &dstStructured = offsetMap[dst].getStructuredRef();
   dstStructured.resize(maxSize);
   for (size_t i = 0; i < maxSize; i++)
@@ -587,6 +716,7 @@ void parseBroadcast(triton::BroadcastOp op, const Location &loc,
   // Set broadcast offset map
   offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr());
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
 
   if (srcOffsetInfo.getPtr()) {
     RewriterBase::InsertionGuard guard(rewriter);
@@ -621,6 +751,7 @@ void parseExpandDims(triton::ExpandDimsOp op, const Location &loc,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo(srcOffsetInfo.getPtr());
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   if (srcOffsetInfo.getPtr()) {
     RewriterBase::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(op);
@@ -663,6 +794,9 @@ void parseClampF(triton::ClampFOp op, const Location &loc,
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike() &&
                                minOffsetInfo.isScalarLike() &&
                                maxOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag() ||
+                                 minOffsetInfo.isNegativeFlag() ||
+                                 maxOffsetInfo.isNegativeFlag());
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
   if (!dstType)
     return;
@@ -696,6 +830,8 @@ void parseSelect(arith::SelectOp op, const Location &loc,
   offsetMap[dst].setScalarLike(conditionScalarLike && trueValueScalarLike &&
                                falseValueScalarLike);
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
+  offsetMap[dst].setNegativeFlag(trueValueOffsetInfo.isNegativeFlag() ||
+                                 falseValueOffsetInfo.isNegativeFlag());
   if (!dstType)
     return;
   if (offsetMap[dst].isScalarLike())
@@ -715,6 +851,7 @@ void parseFPToSI(arith::FPToSIOp op, const Location &loc,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
   if (!dstType)
     return;
@@ -735,6 +872,7 @@ void parseSIToFP(arith::SIToFPOp op, const Location &loc,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
   if (!dstType)
     return;
@@ -793,6 +931,7 @@ void parseReduce(triton::ReduceOp op, const Location &loc,
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   if (!dstType)
     return;
   SmallVector<bool> &dstStructured = offsetMap[dst].getStructuredRef();
@@ -818,6 +957,7 @@ void parseReduceReturn(triton::ReduceReturnOp op, const Location &loc,
   auto dstType = dyn_cast<RankedTensorType>(dst.getType());
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(srcOffsetInfo.isScalarLike());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
   if (!dstType)
     return;
   SmallVector<bool> &dstStructured = offsetMap[dst].getStructuredRef();
@@ -853,6 +993,7 @@ void parseIf(scf::IfOp op, const Location &loc, RewriterBase &rewriter,
   // Set if offset map
   offsetMap[dst] = PtrOffsetInfo();
   offsetMap[dst].setScalarLike(dstIsScalar);
+  offsetMap[dst].setNegativeFlag(thenOffsetInfo.isNegativeFlag());
   SmallVector<bool> &dstStructured = offsetMap[dst].getStructuredRef();
   dstStructured.resize(thenStructured.size());
   for (size_t i = 0; i < dstStructured.size(); i++)
@@ -907,6 +1048,10 @@ void parseIntToPtr(triton::IntToPtrOp op, const Location &loc,
   auto dst = op.getResult();
   offsetMap[dst] = PtrOffsetInfo(dst);
   offsetMap[dst].setScalarLike(true);
+
+  parse(op.getSrc(), op.getLoc(), rewriter, offsetMap);
+  auto srcOffsetInfo = offsetMap.at(op.getSrc());
+  offsetMap[dst].setNegativeFlag(srcOffsetInfo.isNegativeFlag());
 }
 
 } // namespace triton
diff --git a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp
index c90e210..7ce9992 100644
--- a/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp
+++ b/ascend/triton-adapter/lib/TritonToUnstructure/UnstructureConversionPass.cpp
@@ -161,11 +161,39 @@ LogicalResult UnstructuredMemAccessConverter::matchAndRewrite(
     auto &os = llvm::dbgs();
     os << "Converting " << op->getName() << "\n";
     os << op << "\n";
-    os << ptrOffsetInfo.isStructured() << "\n";
+    os << "parentOp is: " << *op->getParentOp() << "\n";
+    os << "isStructured = " << ptrOffsetInfo.isStructured()
+       << ", isScalarLike = " << ptrOffsetInfo.isScalarLike()
+       << ", isNegativeFlag = " << ptrOffsetInfo.isNegativeFlag() << "\n";
   });
 
+  bool flag = false;
+  if (ptrOffsetInfo.isNegativeFlag()) {
+    auto opOffset = ptrOffsetInfo.getOffset();
+    auto opOffsetType = opOffset.getType();
+    flag = true;
+    Value constantZero;
+    if (auto tensorType = dyn_cast<RankedTensorType>(opOffsetType)) {
+      constantZero = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getZeroAttr(tensorType));
+    } else {
+      constantZero = rewriter.create<arith::ConstantIntOp>(
+          loc, 0, opOffset.getType());
+    }
+    Value cmpResult = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::sge, ptrOffsetInfo.getOffset(), constantZero);
+
+    mlir::StringAttr assertMsg = rewriter.getStringAttr(
+        "AddPtr offset (from subi) must be >= 0");
+
+    rewriter.create<triton::AssertOp>(loc, cmpResult, assertMsg);
+  }
+
   if (ptrOffsetInfo.isStructured() && !ptrOffsetInfo.isScalarLike()) {
-    return failure();
+    if (flag)
+      return success();
+    else
+      return failure();
   }
 
   if constexpr (std::is_same_v<MemAccOpTy, triton::LoadOp>)
@@ -440,4 +468,4 @@ void TritonToUnstructurePass::getDependentDialects(
 
 std::unique_ptr<OperationPass<ModuleOp>> triton::createTritonToUnstructurePass() {
   return std::make_unique<TritonToUnstructurePass>();
-}
\ No newline at end of file
+}
diff --git a/ascend/triton-adapter/safe_compile.cmake b/ascend/triton-adapter/safe_compile.cmake
index 2b29614..f39dda3 100644
--- a/ascend/triton-adapter/safe_compile.cmake
+++ b/ascend/triton-adapter/safe_compile.cmake
@@ -7,4 +7,4 @@ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,-z,now -pie -s")
 set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,now -s")
 set(CMAKE_SKIP_RPATH TRUE)
 set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
-unset(CMAKE_INSTALL_RPATH)
\ No newline at end of file
+unset(CMAKE_INSTALL_RPATH)
--
Gitee

From 9bea23f9f5369efa2cfc94dd9cfed71bbed49e21 Mon Sep 17 00:00:00 2001
From: wangjintang
Date: Mon, 29 Sep 2025 11:38:15 +0800
Subject: [PATCH 2/3] add test_causal_conv1d_update_kernel_gdn.py

---
 .../test_causal_conv1d_update_kernel_gdn.py   | 252 ++++++++++++++++++
 1 file changed, 252 insertions(+)
 create mode 100644 ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py

diff --git a/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py b/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py
new file mode 100644
index 0000000..a178ba0
--- /dev/null
+++ b/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py
@@ -0,0 +1,252 @@
+import pytest
+import torch
+import triton
+import triton.language as tl
+import numpy as np
+
+@triton.heuristics({
+    'USE_INITIAL_STATE': lambda args: args['cache'] is not None,
+    'HAS_WEIGHT': lambda args: args['weight'] is not None,
+    'HAS_BIAS': lambda args: args['bias'] is not None,
+    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
+})
+@triton.jit
+def causal_conv1d_update_kernel(
+    x,
+    cache,
+    residual,
+    y,
+    weight,
+    bias,
+    D: tl.constexpr,
+    W: tl.constexpr,
+    BD: tl.constexpr,
+    BW: tl.constexpr,
+    ACTIVATION: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_RESIDUAL: tl.constexpr,
+):
+    i_d, i_n = tl.program_id(0), tl.program_id(1)
+
+    o_d = i_d * BD + tl.arange(0, BD)
+    o_w = tl.arange(0, BW) + W - BW
+    m_d = o_d < D
+    m_w = o_w >= 0
+    m_c = o_w < W - 1
+
+    # [BD]
+    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=0).to(tl.float32)
+
+    if USE_INITIAL_STATE:
+        # shift the cache by 1 with the last one being discarded
+        p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, W - BW + 1), (BD, BW), (1, 0))
+        # [BD, BW]
+        b_cache = tl.load(p_cache, boundary_check=(0, 1)).to(tl.float32)
+        b_cache = tl.where(m_c[None, :], b_cache, b_x[:, None])
+    else:
+        b_cache = tl.zeros((BD, BW), dtype=tl.float32)
+
+    if HAS_WEIGHT:
+        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0)
+        b_y = tl.sum(b_cache * b_w, 1)
+    else:
+        b_y = tl.sum(b_cache, 1)
+    if HAS_BIAS:
+        b_y += tl.load(bias + o_d, mask=m_d)
+
+    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
+        b_y = b_y * tl.sigmoid(b_y)
+
+    if HAS_RESIDUAL:
+        b_y += tl.load(residual + i_n * D + o_d, mask=m_d, other=0)
+
+    tl.store(y + i_n * D + o_d, tl.cast(b_y, dtype=y.dtype.element_ty, fp_downcast_rounding='rtne'), mask=m_d)
+
+    if USE_INITIAL_STATE:
+        b_cache = tl.cast(b_cache, dtype=cache.dtype.element_ty, fp_downcast_rounding='rtne')
+        # update the cache in-place
+        p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, W - BW), (BD, BW), (1, 0))
+        tl.store(p_cache, b_cache, boundary_check=(0, 1))
+
+def reference_causal_conv1d_update(
+    x, cache, residual, weight, bias, activation='silu', use_initial_state=True,
+    D=None, W=None, BD=16, BW=3
+):
+    batch_size, D = x.shape
+    W = cache.shape[2] if (use_initial_state and W is None) else W
+    dtype = x.dtype
+    device = x.device
+
+    # Initialize the cache
+    if not use_initial_state:
+        cache = torch.zeros(batch_size, D, W, device=device, dtype=dtype)
+    else:
+        cache = cache.clone()
+
+    y = torch.zeros_like(x)
+    num_d_blocks = (D + BD - 1) // BD
+
+    for i_n in range(batch_size):
+        for i_d in range(num_d_blocks):
+            # Feature-block indices
+            o_d_start = i_d * BD
+            o_d_end = min((i_d + 1) * BD, D)
+            o_d = torch.arange(o_d_start, o_d_end, device=device)
+            bd = o_d.numel()
+            m_d = (o_d < D).unsqueeze(1)
+
+            # Window indices (with BW=3, o_w = [0,1,2]: no negative indices, NPU-friendly)
+            o_w = torch.arange(W - BW, W, device=device)  # 3-3=0 → [0,1,2]
+            m_w = (o_w >= 0) & (o_w < W)  # all True ([0,1,2] are all valid)
+            m_c = (o_w < (W - 1)).unsqueeze(0)  # [1,3] → [T,T,F]
+
+            # Load x
+            b_x = x[i_n, o_d].to(torch.float32)
+
+            # Cache load (BW=3 yields no negative indices on NPU, no special handling needed)
+            if use_initial_state:
+                cache_window_offset = W - BW + 1  # 3-3+1=1
+                w_idx_cache = cache_window_offset + torch.arange(BW, device=device)  # [1,2,3]
+                valid_cache_mask = (w_idx_cache < W)  # 3 >= 3 in [1,2,3] → [T,T,F]
+
+                d_idx, w_idx = torch.meshgrid(o_d, w_idx_cache, indexing='ij')
+                valid_mask = m_d & valid_cache_mask.unsqueeze(0)
+                b_cache = torch.zeros(bd, BW, device=device, dtype=torch.float32)
+                b_cache[valid_mask] = cache[i_n, d_idx[valid_mask], w_idx[valid_mask]].to(torch.float32)
+
+                # Shift the cache left (m_c is [T,T,F]; the last column is refilled with b_x)
+                update_mask = ~m_c.expand(bd, -1)
+                update_idx = torch.nonzero(update_mask, as_tuple=True)
+                b_cache[update_idx] = b_x[update_idx[0]]
+
+            else:
+                b_cache = torch.zeros(bd, BW, device=device, dtype=torch.float32)
+
+            # Weight load (no invalid indices when BW=3, compute directly)
+            if weight is not None:
+                w_d_idx, w_w_idx = torch.meshgrid(o_d, o_w, indexing='ij')
+                valid_weight_mask = m_d & m_w.unsqueeze(0)
+                b_w = torch.zeros(bd, BW, device=device, dtype=torch.float32)
+                b_w[valid_weight_mask] = weight[w_d_idx[valid_weight_mask], w_w_idx[valid_weight_mask]].to(torch.float32)
+                b_y = torch.sum(b_cache * b_w, dim=1)
+            else:
+                b_y = torch.sum(b_cache, dim=1)
+
+            # Post-processing (matches the Triton kernel)
+            if bias is not None:
+                b_y += bias[o_d].to(torch.float32)
+            if activation in ['silu', 'swish']:
+                b_y = b_y * torch.sigmoid(b_y)
+            if residual is not None:
+                b_y += residual[i_n, o_d].to(torch.float32)
+
+            y[i_n, o_d] = b_y.to(dtype)
+
+            # Cache update (no negative-index issue on NPU)
+            if use_initial_state:
+                update_window_offset = W - BW  # 3-3=0
+                o_w_update = update_window_offset + torch.arange(BW, device=device)  # [0,1,2]
+                valid_update_mask = (o_w_update >= 0) & (o_w_update < W)  # all True
+                u_d_idx, u_w_idx = torch.meshgrid(o_d, o_w_update, indexing='ij')
+                final_update_mask = m_d & valid_update_mask.unsqueeze(0)
+                cache[i_n, u_d_idx[final_update_mask], u_w_idx[final_update_mask]] = b_cache[final_update_mask].to(dtype)
+
+    return y, cache
+
+def test_causal_conv1d_update_kernel_gdn():
+    # Typical GDN parameter configurations (covering different scenarios)
+    configs = [
+        # Scenario 1: full config (cache, weight, bias, residual)
+        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=True),
+        # Scenario 2: no weight (pure state accumulation)
+        dict(use_initial_state=True, has_weight=False, has_bias=True, has_residual=True),
+        # Scenario 3: no initial cache (start from zero)
+        dict(use_initial_state=False, has_weight=True, has_bias=False, has_residual=True),
+        # Scenario 4: no residual connection
+        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=False),
+        # Scenario 1: full config (cache, weight, bias, residual)
+        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=True),
+        # Scenario 2: no weight (pure state accumulation)
+        dict(use_initial_state=True, has_weight=False, has_bias=True, has_residual=True),
+        # Scenario 3: no initial cache (start from zero)
+        dict(use_initial_state=False, has_weight=True, has_bias=False, has_residual=True),
+        # Scenario 4: no residual connection
+        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=False),
+    ]
+
+    # Common parameters (typical GDN values)
+    batch_size = 2  # batch size
+    D = 64  # feature dim (GDN commonly uses 64/128)
+    W = 3  # kernel size (causal conv commonly uses 3/5)
+    BD = 16  # feature block size (must divide D: 64 = 16 x 4)
+    # BW = 3  # kernel block size (must equal W to avoid boundary issues)
+    activation = 'silu'  # gated activation commonly used in GDN
+
+    # Key: bind BW to the hardware (NPU only supports 3, CUDA only supports 4)
+    # device = 'cuda' if torch.cuda.is_available() else 'npu'
+    device = 'npu'
+    if device == 'cuda':
+        BW = 4  # CUDA requires a power of 2
+    elif device == 'npu':
+        BW = 3  # NPU only supports 3 (BW=4 triggers a memory-access issue, still being root-caused)
+    else:
+        BW = 3  # CPU fallback
+
+    # Keep dtype float32 (NPU does not support double)
+    dtype = torch.float32
+
+    for cfg in configs:
+        # Generate inputs (on the NPU device, matching the kernel)
+        x = torch.randn(batch_size, D, device=device, dtype=dtype)  # current timestep input
+        # Cache init: [batch, D, W], holding the last W-1 steps plus the current one (random init)
+        cache = torch.randn(batch_size, D, W, device=device, dtype=dtype) if cfg['use_initial_state'] else None
+        residual = torch.randn(batch_size, D, device=device, dtype=dtype) if cfg['has_residual'] else None
+        weight = torch.randn(D, W, device=device, dtype=dtype) if cfg['has_weight'] else None
+        bias = torch.randn(D, device=device, dtype=dtype) if cfg['has_bias'] else None
+
+        # Prepare output tensors
+        y_triton = torch.empty_like(x)
+        # Deep-copy the cache (so the reference does not mutate the original)
+        cache_triton = cache.clone() if cache is not None else None
+
+        # Invoke the Triton kernel
+        y_triton = torch.empty_like(x, dtype=dtype)  # explicit dtype
+        cache_triton = cache.clone() if cfg['use_initial_state'] else None
+        grid = (D // BD, batch_size)
+        causal_conv1d_update_kernel[grid](
+            x=x, cache=cache_triton, residual=residual, y=y_triton,
+            weight=weight, bias=bias,
+            D=D, W=W, BD=BD, BW=BW, ACTIVATION=activation,
+            USE_INITIAL_STATE=cfg['use_initial_state'],
+            HAS_WEIGHT=cfg['has_weight'],
+            HAS_BIAS=cfg['has_bias'],
+            HAS_RESIDUAL=cfg['has_residual'],
+        )
+
+        # Invoke the fixed PyTorch reference implementation
+        cache_ref = cache.clone() if cfg['use_initial_state'] else None
+        y_ref, cache_ref = reference_causal_conv1d_update(
+            x=x, cache=cache_ref, residual=residual, weight=weight, bias=bias,
+            activation=activation, use_initial_state=cfg['use_initial_state'],
+            D=D, W=W, BD=BD, BW=BW
+        )
+
+        # Check output consistency (atol < 1e-4 to absorb floating-point differences)
+        y_close = torch.allclose(y_triton, y_ref, atol=1e-4)
+        assert y_close, \
+            f"Output mismatch (config: {cfg})\nTriton: {y_triton[:2, :2]}\nPyTorch: {y_ref[:2, :2]}"
+
+        # Check cache consistency (when the initial state is used)
+        if cfg['use_initial_state']:
+            cache_close = torch.allclose(cache_triton, cache_ref, atol=1e-4)
+            assert cache_close, \
+                f"Cache mismatch (config: {cfg})\nTriton: {cache_triton[:2, :2, :]}\nPyTorch: {cache_ref[:2, :2, :]}"
+
+    print("All scenarios passed!")
+
+
+# Run the test
+if __name__ == "__main__":
+    test_causal_conv1d_update_kernel_gdn()
--
Gitee

From 58388b787a3239b393794401fac02856a5027bb5 Mon Sep 17 00:00:00 2001
From: wangjintang
Date: Mon, 29 Sep 2025 11:41:23 +0800
Subject: [PATCH 3/3] Revert "add test_causal_conv1d_update_kernel_gdn.py"

This reverts commit 9bea23f9f5369efa2cfc94dd9cfed71bbed49e21.
---
 .../test_causal_conv1d_update_kernel_gdn.py   | 252 ------------------
 1 file changed, 252 deletions(-)
 delete mode 100644 ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py

diff --git a/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py b/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py
deleted file mode 100644
index a178ba0..0000000
--- a/ascend/examples/tutorials/GDN/test_causal_conv1d_update_kernel_gdn.py
+++ /dev/null
@@ -1,252 +0,0 @@
-import pytest
-import torch
-import triton
-import triton.language as tl
-import numpy as np
-
-@triton.heuristics({
-    'USE_INITIAL_STATE': lambda args: args['cache'] is not None,
-    'HAS_WEIGHT': lambda args: args['weight'] is not None,
-    'HAS_BIAS': lambda args: args['bias'] is not None,
-    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
-})
-@triton.jit
-def causal_conv1d_update_kernel(
-    x,
-    cache,
-    residual,
-    y,
-    weight,
-    bias,
-    D: tl.constexpr,
-    W: tl.constexpr,
-    BD: tl.constexpr,
-    BW: tl.constexpr,
-    ACTIVATION: tl.constexpr,
-    USE_INITIAL_STATE: tl.constexpr,
-    HAS_WEIGHT: tl.constexpr,
-    HAS_BIAS: tl.constexpr,
-    HAS_RESIDUAL: tl.constexpr,
-):
-    i_d, i_n = tl.program_id(0), tl.program_id(1)
-
-    o_d = i_d * BD + tl.arange(0, BD)
-    o_w = tl.arange(0, BW) + W - BW
-    m_d = o_d < D
-    m_w = o_w >= 0
-    m_c = o_w < W - 1
-
-    # [BD]
-    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=0).to(tl.float32)
-
-    if USE_INITIAL_STATE:
-        # shift the cache by 1 with the last one being discarded
-        p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, W - BW + 1), (BD, BW), (1, 0))
-        # [BD, BW]
-        b_cache = tl.load(p_cache, boundary_check=(0, 1)).to(tl.float32)
-        b_cache = tl.where(m_c[None, :], b_cache, b_x[:, None])
-    else:
-        b_cache = tl.zeros((BD, BW), dtype=tl.float32)
-
-    if HAS_WEIGHT:
-        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0)
-        b_y = tl.sum(b_cache * b_w, 1)
-    else:
-        b_y = tl.sum(b_cache, 1)
-    if HAS_BIAS:
-        b_y += tl.load(bias + o_d, mask=m_d)
-
-    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
-        b_y = b_y * tl.sigmoid(b_y)
-
-    if HAS_RESIDUAL:
-        b_y += tl.load(residual + i_n * D + o_d, mask=m_d, other=0)
-
-    tl.store(y + i_n * D + o_d, tl.cast(b_y, dtype=y.dtype.element_ty, fp_downcast_rounding='rtne'), mask=m_d)
-
-    if USE_INITIAL_STATE:
-        b_cache = tl.cast(b_cache, dtype=cache.dtype.element_ty, fp_downcast_rounding='rtne')
-        # update the cache in-place
-        p_cache = tl.make_block_ptr(cache + i_n * D*W, (D, W), (W, 1), (i_d * BD, W - BW), (BD, BW), (1, 0))
-        tl.store(p_cache, b_cache, boundary_check=(0, 1))
-
-def reference_causal_conv1d_update(
-    x, cache, residual, weight, bias, activation='silu', use_initial_state=True,
-    D=None, W=None, BD=16, BW=3
-):
-    batch_size, D = x.shape
-    W = cache.shape[2] if (use_initial_state and W is None) else W
-    dtype = x.dtype
-    device = x.device
-
-    # Initialize the cache
-    if not use_initial_state:
-        cache = torch.zeros(batch_size, D, W, device=device, dtype=dtype)
-    else:
-        cache = cache.clone()
-
-    y = torch.zeros_like(x)
-    num_d_blocks = (D + BD - 1) // BD
-
-    for i_n in range(batch_size):
-        for i_d in range(num_d_blocks):
-            # Feature-block indices
-            o_d_start = i_d * BD
-            o_d_end = min((i_d + 1) * BD, D)
-            o_d = torch.arange(o_d_start, o_d_end, device=device)
-            bd = o_d.numel()
-            m_d = (o_d < D).unsqueeze(1)
-
-            # Window indices (with BW=3, o_w = [0,1,2]: no negative indices, NPU-friendly)
-            o_w = torch.arange(W - BW, W, device=device)  # 3-3=0 → [0,1,2]
-            m_w = (o_w >= 0) & (o_w < W)  # all True ([0,1,2] are all valid)
-            m_c = (o_w < (W - 1)).unsqueeze(0)  # [1,3] → [T,T,F]
-
-            # Load x
-            b_x = x[i_n, o_d].to(torch.float32)
-
-            # Cache load (BW=3 yields no negative indices on NPU, no special handling needed)
-            if use_initial_state:
-                cache_window_offset = W - BW + 1  # 3-3+1=1
-                w_idx_cache = cache_window_offset + torch.arange(BW, device=device)  # [1,2,3]
-                valid_cache_mask = (w_idx_cache < W)  # 3 >= 3 in [1,2,3] → [T,T,F]
-
-                d_idx, w_idx = torch.meshgrid(o_d, w_idx_cache, indexing='ij')
-                valid_mask = m_d & valid_cache_mask.unsqueeze(0)
-                b_cache = torch.zeros(bd, BW, device=device, dtype=torch.float32)
-                b_cache[valid_mask] = cache[i_n, d_idx[valid_mask], w_idx[valid_mask]].to(torch.float32)
-
-                # Shift the cache left (m_c is [T,T,F]; the last column is refilled with b_x)
-                update_mask = ~m_c.expand(bd, -1)
-                update_idx = torch.nonzero(update_mask, as_tuple=True)
-                b_cache[update_idx] = b_x[update_idx[0]]
-
-            else:
-                b_cache = torch.zeros(bd, BW, device=device, dtype=torch.float32)
-
-            # Weight load (no invalid indices when BW=3, compute directly)
-            if weight is not None:
-                w_d_idx, w_w_idx = torch.meshgrid(o_d, o_w, indexing='ij')
-                valid_weight_mask = m_d & m_w.unsqueeze(0)
-                b_w = torch.zeros(bd, BW, device=device, dtype=torch.float32)
-                b_w[valid_weight_mask] = weight[w_d_idx[valid_weight_mask], w_w_idx[valid_weight_mask]].to(torch.float32)
-                b_y = torch.sum(b_cache * b_w, dim=1)
-            else:
-                b_y = torch.sum(b_cache, dim=1)
-
-            # Post-processing (matches the Triton kernel)
-            if bias is not None:
-                b_y += bias[o_d].to(torch.float32)
-            if activation in ['silu', 'swish']:
-                b_y = b_y * torch.sigmoid(b_y)
-            if residual is not None:
-                b_y += residual[i_n, o_d].to(torch.float32)
-
-            y[i_n, o_d] = b_y.to(dtype)
-
-            # Cache update (no negative-index issue on NPU)
-            if use_initial_state:
-                update_window_offset = W - BW  # 3-3=0
-                o_w_update = update_window_offset + torch.arange(BW, device=device)  # [0,1,2]
-                valid_update_mask = (o_w_update >= 0) & (o_w_update < W)  # all True
-                u_d_idx, u_w_idx = torch.meshgrid(o_d, o_w_update, indexing='ij')
-                final_update_mask = m_d & valid_update_mask.unsqueeze(0)
-                cache[i_n, u_d_idx[final_update_mask], u_w_idx[final_update_mask]] = b_cache[final_update_mask].to(dtype)
-
-    return y, cache
-
-def test_causal_conv1d_update_kernel_gdn():
-    # Typical GDN parameter configurations (covering different scenarios)
-    configs = [
-        # Scenario 1: full config (cache, weight, bias, residual)
-        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=True),
-        # Scenario 2: no weight (pure state accumulation)
-        dict(use_initial_state=True, has_weight=False, has_bias=True, has_residual=True),
-        # Scenario 3: no initial cache (start from zero)
-        dict(use_initial_state=False, has_weight=True, has_bias=False, has_residual=True),
-        # Scenario 4: no residual connection
-        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=False),
-        # Scenario 1: full config (cache, weight, bias, residual)
-        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=True),
-        # Scenario 2: no weight (pure state accumulation)
-        dict(use_initial_state=True, has_weight=False, has_bias=True, has_residual=True),
-        # Scenario 3: no initial cache (start from zero)
-        dict(use_initial_state=False, has_weight=True, has_bias=False, has_residual=True),
-        # Scenario 4: no residual connection
-        dict(use_initial_state=True, has_weight=True, has_bias=True, has_residual=False),
-    ]
-
-    # Common parameters (typical GDN values)
-    batch_size = 2  # batch size
-    D = 64  # feature dim (GDN commonly uses 64/128)
-    W = 3  # kernel size (causal conv commonly uses 3/5)
-    BD = 16  # feature block size (must divide D: 64 = 16 x 4)
-    # BW = 3  # kernel block size (must equal W to avoid boundary issues)
-    activation = 'silu'  # gated activation commonly used in GDN
-
-    # Key: bind BW to the hardware (NPU only supports 3, CUDA only supports 4)
-    # device = 'cuda' if torch.cuda.is_available() else 'npu'
-    device = 'npu'
-    if device == 'cuda':
-        BW = 4  # CUDA requires a power of 2
-    elif device == 'npu':
-        BW = 3  # NPU only supports 3 (BW=4 triggers a memory-access issue, still being root-caused)
-    else:
-        BW = 3  # CPU fallback
-
-    # Keep dtype float32 (NPU does not support double)
-    dtype = torch.float32
-
-    for cfg in configs:
-        # Generate inputs (on the NPU device, matching the kernel)
-        x = torch.randn(batch_size, D, device=device, dtype=dtype)  # current timestep input
-        # Cache init: [batch, D, W], holding the last W-1 steps plus the current one (random init)
-        cache = torch.randn(batch_size, D, W, device=device, dtype=dtype) if cfg['use_initial_state'] else None
-        residual = torch.randn(batch_size, D, device=device, dtype=dtype) if cfg['has_residual'] else None
-        weight = torch.randn(D, W, device=device, dtype=dtype) if cfg['has_weight'] else None
-        bias = torch.randn(D, device=device, dtype=dtype) if cfg['has_bias'] else None
-
-        # Prepare output tensors
-        y_triton = torch.empty_like(x)
-        # Deep-copy the cache (so the reference does not mutate the original)
-        cache_triton = cache.clone() if cache is not None else None
-
-        # Invoke the Triton kernel
-        y_triton = torch.empty_like(x, dtype=dtype)  # explicit dtype
-        cache_triton = cache.clone() if cfg['use_initial_state'] else None
-        grid = (D // BD, batch_size)
-        causal_conv1d_update_kernel[grid](
-            x=x, cache=cache_triton, residual=residual, y=y_triton,
-            weight=weight, bias=bias,
-            D=D, W=W, BD=BD, BW=BW, ACTIVATION=activation,
-            USE_INITIAL_STATE=cfg['use_initial_state'],
-            HAS_WEIGHT=cfg['has_weight'],
-            HAS_BIAS=cfg['has_bias'],
-            HAS_RESIDUAL=cfg['has_residual'],
-        )
-
-        # Invoke the fixed PyTorch reference implementation
-        cache_ref = cache.clone() if cfg['use_initial_state'] else None
-        y_ref, cache_ref = reference_causal_conv1d_update(
-            x=x, cache=cache_ref, residual=residual, weight=weight, bias=bias,
-            activation=activation, use_initial_state=cfg['use_initial_state'],
-            D=D, W=W, BD=BD, BW=BW
-        )
-
-        # Check output consistency (atol < 1e-4 to absorb floating-point differences)
-        y_close = torch.allclose(y_triton, y_ref, atol=1e-4)
-        assert y_close, \
-            f"Output mismatch (config: {cfg})\nTriton: {y_triton[:2, :2]}\nPyTorch: {y_ref[:2, :2]}"
-
-        # Check cache consistency (when the initial state is used)
-        if cfg['use_initial_state']:
-            cache_close = torch.allclose(cache_triton, cache_ref, atol=1e-4)
-            assert cache_close, \
-                f"Cache mismatch (config: {cfg})\nTriton: {cache_triton[:2, :2, :]}\nPyTorch: {cache_ref[:2, :2, :]}"
-
-    print("All scenarios passed!")
-
-
-# Run the test
-if __name__ == "__main__":
-    test_causal_conv1d_update_kernel_gdn()
--
Gitee
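
A minimal kernel-side sketch of the pattern PATCH 1/3 guards (illustrative only, not part of the series): the load offset below comes from a subtraction, which lowers to arith.subi, so OffsetAnalysis marks it with negativeFlag and the conversion pass emits a runtime `offset >= 0` assert rather than rejecting the access. The kernel name, shapes, and the `npu` device string are assumptions for this sketch.

import torch
import triton
import triton.language as tl

@triton.jit
def shifted_copy_kernel(x_ptr, y_ptr, shift, N: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    # `idx - shift` lowers to arith.subi, so this offset carries negativeFlag
    # through parseBinaryOp and reaches the converter's assert path.
    off = idx - shift
    mask = (off >= 0) & (off < N)
    val = tl.load(x_ptr + off, mask=mask, other=0.0)
    tl.store(y_ptr + idx, val, mask=idx < N)

if __name__ == "__main__":
    # Hypothetical usage on an Ascend NPU (requires torch_npu).
    N, BLOCK = 128, 32
    x = torch.randn(N, device="npu")
    y = torch.empty_like(x)
    shifted_copy_kernel[(N // BLOCK,)](x, y, 4, N=N, BLOCK=BLOCK)

With shift > 0 the first program instance produces offsets like -4..27, which the analysis cannot prove non-negative at compile time; the inserted assert trips only if an unmasked lane actually reads below the base pointer.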