diff --git a/ascend/examples/pytest_ut/test_load_store_scalar.py b/ascend/examples/pytest_ut/test_load_store_scalar.py
new file mode 100644
index 0000000000000000000000000000000000000000..c472c7e8ac98ee1de57f852d69b519dfbcc38bd7
--- /dev/null
+++ b/ascend/examples/pytest_ut/test_load_store_scalar.py
@@ -0,0 +1,36 @@
+import os
+
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+
+torch.set_printoptions(precision=4,
+                       sci_mode=False)
+
+
+@triton.jit
+def load_store_scalar_kernel(
+    x_ptr,
+    output_ptr,
+    BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    x = tl.load(x_ptr)
+    tl.store(output_ptr, x)
+
+
+def load_store_scalar(args):
+    x = args
+    output = torch.empty_like(x)
+    grid = (x.shape[0],)
+    block_size = 1024
+    load_store_scalar_kernel[grid](x, output, BLOCK_SIZE=block_size)
+    return output
+
+
+def test_load_store_scalar():
+    size = (2,)
+    x = torch.rand(size, device="npu")
+    output_triton = load_store_scalar(x)
+    assert torch.allclose(output_triton[0], x[0])
\ No newline at end of file
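The test above drives the interesting case for the adapter: `x_ptr` reaches `tl.load`/`tl.store` as the raw kernel argument, i.e. a block argument with no defining op, instead of going through `tt.addptr`. A minimal sketch contrasting the two pointer paths side by side (the kernel name and launch helper below are illustrative, not part of this patch):

```python
import torch
import torch_npu
import triton
import triton.language as tl


@triton.jit
def scalar_vs_block_kernel(x_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    # Scalar path: the pointer is the kernel argument itself, so in the IR it
    # is a block argument with no defining op -- the case the C++ changes
    # below special-case via `!ptr.getDefiningOp()`.
    first = tl.load(x_ptr)
    # Block path: adding an offset materializes tt.addptr, so PtrAnalysis can
    # recover the access shape from the pointer chain as it did before.
    offsets = tl.arange(0, BLOCK_SIZE)
    block = tl.load(x_ptr + offsets)
    tl.store(out_ptr + offsets, block + first)


def run_scalar_vs_block(x):
    out = torch.empty_like(x)
    # Single program instance; BLOCK_SIZE must be a power of two for tl.arange.
    scalar_vs_block_kernel[(1,)](x, out, BLOCK_SIZE=x.numel())
    return out
```

The C++ changes that follow teach `PtrAnalysis` and the structured `LoadOp` builder to handle the first path, where the pointee type is a plain scalar rather than a ranked tensor.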
diff --git a/ascend/triton-adapter/lib/AnalysisStructured/PtrAnalysis.cpp b/ascend/triton-adapter/lib/AnalysisStructured/PtrAnalysis.cpp
index edc444515f9a6764bb89b2815448207fb1b47653..016817af1a4c33a275daafa5d61476f805add3f2 100644
--- a/ascend/triton-adapter/lib/AnalysisStructured/PtrAnalysis.cpp
+++ b/ascend/triton-adapter/lib/AnalysisStructured/PtrAnalysis.cpp
@@ -2117,11 +2117,11 @@ LogicalResult PtrAnalysis::extractScalarFromLoadedTensor(Operation* op, OpBuilder
     return failure();
   }
 
-  if(ptrMap.lookupOrNull(ptr)){
+  if (ptrMap.lookupOrNull(ptr) || !ptr.getDefiningOp()) {
     auto tensorType = dyn_cast<RankedTensorType>(loadResult.getType());
     auto index = builder.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> indices;
-    for(size_t i = 0; i < tensorType.getRank(); ++i){
+    for (size_t i = 0; i < tensorType.getRank(); ++i) {
       assert(tensorType.getDimSize(i) == 1 && "Input tensor should be of shape tensor<1xanytype>");
       indices.push_back(index);
     }
@@ -2310,6 +2310,10 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
   OpBuilder builder(op);
   auto ptrState = knownPtrs[op.getPtr()];
   auto defaultAttr = builder.getIndexAttr(0);
+  if (!ptr && !op.getPtr().getDefiningOp() &&
+      (dyn_cast<FloatType>(op.getResult().getType()) || dyn_cast<IntegerType>(op.getResult().getType()))) {
+    ptr = op.getPtr();
+  }
 
   if (!ptr && analysisSplat(op, builder, ptr, ptrState).failed()) {
     op->emitRemark("The offset value for the load operation is neither from addptr nor splat");
@@ -2346,14 +2350,21 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
     loadOp->dump();
   });
   mlir::Value loadResult = loadOp.getResult();
-
-  // auto loadArry = cast<RankedTensorType>(ptr.getType()).getShape();
   // Support shape reading for both !ptr and tensor pointer types.
   SmallVector<int64_t> loadShape;
   if (auto ptrType = dyn_cast<triton::PointerType>(ptr.getType())) {
-    auto loadArry = cast<RankedTensorType>(ptrType.getPointeeType()).getShape();
-    loadShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
+    if (auto tensorType = dyn_cast<RankedTensorType>(ptrType.getPointeeType())) {
+      auto loadArry = tensorType.getShape();
+      loadShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
+    } else if (dyn_cast<FloatType>(ptrType.getPointeeType()) || dyn_cast<IntegerType>(ptrType.getPointeeType())) {
+      if (extractScalarFromLoadedTensor(op, builder, loadResult, loc).failed()) {
+        return failure();
+      }
+      op.replaceAllUsesWith(loadResult);
+      op->erase();
+      return success();
+    }
   } else if (auto ptrType = dyn_cast<RankedTensorType>(ptr.getType())) {
     auto loadArry = ptrType.getShape();
     loadShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
@@ -2486,6 +2497,9 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
   auto loc = op.getLoc();
   auto ptrState = knownPtrs[op.getPtr()];
   OpBuilder builder(op);
+  if (!ptr && !op.getPtr().getDefiningOp()) {
+    ptr = op.getPtr();
+  }
 
   if (!ptr && analysisSplat(op, builder, ptr, ptrState).failed()) {
     op->emitRemark("The offset value for the load operation is neither from addptr nor splat");
@@ -2542,8 +2556,12 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op) {
   // Support shape reading for both !ptr and tensor pointer types.
   SmallVector<int64_t> storeShape;
   if (auto ptrType = dyn_cast<triton::PointerType>(ptr.getType())) {
-    auto loadArry = cast<RankedTensorType>(ptrType.getPointeeType()).getShape();
-    storeShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
+    if (!ptr.getDefiningOp()) {
+      storeShape = SmallVector<int64_t>(1, 1);
+    } else {
+      auto loadArry = cast<RankedTensorType>(ptrType.getPointeeType()).getShape();
+      storeShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
+    }
   } else if (auto ptrType = dyn_cast<RankedTensorType>(ptr.getType())) {
     auto loadArry = ptrType.getShape();
     storeShape = SmallVector<int64_t>(loadArry.begin(), loadArry.end());
diff --git a/ascend/triton-adapter/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/ascend/triton-adapter/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp
index 787b6ff0b9f1a015c99d358a295bbf4305e2a86e..ffedf64df27833a7323cd6b7d86318a87d928c39 100644
--- a/ascend/triton-adapter/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp
+++ b/ascend/triton-adapter/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp
@@ -82,11 +82,13 @@ void LoadOp::build(OpBuilder &b, OperationState &state, Value ptr,
     auto ptrType = cast<triton::PointerType>(ptrTensorType.getElementType());
     auto elemType = ptrType.getPointeeType();
     resType = RankedTensorType::get(ptrTensorType.getShape(), elemType);
   } else if (tensorPtrType) {
-    auto tensorType = cast<RankedTensorType>(tensorPtrType.getPointeeType());
-    resType = RankedTensorType::get(tensorType.getShape(),
-                                    tensorType.getElementType());
+    if (auto tensorType = dyn_cast<RankedTensorType>(tensorPtrType.getPointeeType())) {
+      resType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType());
+    } else if (dyn_cast<FloatType>(tensorPtrType.getPointeeType()) ||
+               dyn_cast<IntegerType>(tensorPtrType.getPointeeType())) {
+      resType = RankedTensorType::get({1}, tensorPtrType.getPointeeType());
+    }
   }
 
   build(b, state, resType, ptr, dynamicDims,
         b.getDenseI64ArrayAttr(staticDims), dim_mode, other);
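For completeness, a companion check that drives the new store path (`rewriteStoreOp` with a defining-op-free pointer, where `storeShape` falls back to a single element) could mirror the test above. This is a hypothetical sketch, not part of the patch:

```python
import torch
import torch_npu
import triton
import triton.language as tl


@triton.jit
def store_scalar_const_kernel(out_ptr):
    # The store pointer is the raw kernel argument (no tt.addptr), so the
    # rewritten store sizes its buffer as a single element.
    tl.store(out_ptr, 1.0)


def test_store_scalar_const():
    out = torch.zeros((2,), device="npu")
    store_scalar_const_kernel[(1,)](out)
    assert out[0].item() == 1.0
```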