From f6b1067a4aeb2a1a69cb19ef02ad6b412d15df0c Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 19 May 2025 11:01:05 +0800 Subject: [PATCH 01/23] =?UTF-8?q?[feat]rab=E7=AE=97=E5=AD=90=E3=80=82?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E3=80=81=E6=AD=A3=E5=90=91=E7=AE=97=E5=AD=90?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias.cpp | 204 ++++++++++++++ .../op_host/relative_attn_bias_tiling.h | 34 +++ .../relative_attn_bias/op_kernel/rab_common.h | 40 +++ .../op_kernel/relative_attn_bias.cpp | 34 +++ .../op_kernel/relative_attn_bias_kernel.h | 35 +++ .../op_kernel/relative_attn_bias_pos.h | 178 ++++++++++++ .../op_kernel/relative_attn_bias_time.h | 261 ++++++++++++++++++ .../relative_attn_bias.json | 82 ++++++ .../operators/relative_attn_bias/run.sh | 58 ++++ .../relative_attn_bias/relative_attn_bias.py | 210 ++++++++++++++ .../relative_attn_bias_v200.py | 205 ++++++++++++++ .../2.6.0/relative_attn_bias/CMakeLists.txt | 31 +++ .../2.6.0/relative_attn_bias/build_ops.sh | 20 ++ .../relative_attn_bias/relative_attn_bias.cpp | 76 +++++ 14 files changed, 1468 insertions(+) create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h create mode 100644 
mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/relative_attn_bias.json create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/CMakeLists.txt create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/build_ops.sh create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp new file mode 100644 index 00000000..bf20f394 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp @@ -0,0 +1,204 @@ +/** +* @file relative_attn_bias.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+*
+*/
+
+#include <cmath>
+#include <cstdio>
+#include <vector>
+
+#include "relative_attn_bias_tiling.h"
+#include "register/op_def_registry.h"
+#include "tiling/tiling_api.h"
+#include "tiling/platform/platform_ascendc.h"
+
+constexpr int32_t RESERVER_UB_SIZE = (5 * 1024);
+constexpr int32_t DATA_ALIGN_BYTES = 32;
+constexpr uint8_t NUM_BUFFER = 2;
+// Largest row window the time-bias kernel can process per iteration: the kernel keeps one
+// SequenceParams entry per row in a fixed array of MAX_SEQ_CNT (128) entries
+// (op_kernel/rab_common.h). Keep the two values in sync.
+constexpr uint64_t MAX_TIME_STRIDE = 128;
+
+// input index
+constexpr int IDENTITY_INDEX = 1;
+constexpr int TIMESTAMPS_INDEX = 2;
+constexpr int TIMESTAMPS_WEIGHTS_INDEX = 3;
+// output index
+constexpr int RAB_POSITION_INDEX = 0;
+constexpr int RAB_TIME_INDEX = 1;
+// attr index
+constexpr int PAST_VALID_LENS_INDEX = 0;
+constexpr int BUCKET_DIV_INDEX = 1;
+
+namespace optiling {
+/**
+ * Builds the tiling data consumed by the RelativeAttnBias kernels.
+ *
+ * Derives the sequence length s from identity (2s, 2s), the batch size from the number of
+ * past_valid_lens entries, the bucket table shape from timestamps_weights
+ * (num_layer, num_buckets + 1), and the per-iteration window sizes (positionStride in elements,
+ * timeStride in rows) from the unified-buffer size of the current platform.
+ *
+ * @param context tiling context provided by the framework
+ * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when an input/attribute is missing,
+ *         a shape is empty, the batch exceeds MAX_BATCH_SIZE or the unified buffer is too small
+ */
+static ge::graphStatus TilingFunc(gert::TilingContext* context)
+{
+    if (context == nullptr) {
+        printf("[ERROR]Tiling context is nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+    auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    size_t coreNum = ascendPlatform.GetCoreNumAiv();
+    if (coreNum == 0) {
+        printf("[ERROR]No available aicore\n");
+        return ge::GRAPH_FAILED;
+    }
+
+    // Fetch every framework pointer once and validate it before the first dereference.
+    const auto* identityStorageShape = context->GetInputShape(IDENTITY_INDEX);
+    const auto* tswStorageShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX);
+    const auto* identityTensor = context->GetInputTensor(IDENTITY_INDEX);
+    const auto* timestampsTensor = context->GetInputTensor(TIMESTAMPS_INDEX);
+    const gert::RuntimeAttrs* attrs = context->GetAttrs();
+    if (identityStorageShape == nullptr || tswStorageShape == nullptr || identityTensor == nullptr ||
+        timestampsTensor == nullptr || attrs == nullptr) {
+        printf("[ERROR]Input shape, input tensor or attrs is nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+    const auto* pastValidLensPtr = attrs->GetAttrPointer<gert::ContinuousVector>(PAST_VALID_LENS_INDEX);
+    const float* bucketDivisorPtr = attrs->GetFloat(BUCKET_DIV_INDEX);
+    if (pastValidLensPtr == nullptr || bucketDivisorPtr == nullptr) {
+        printf("[ERROR]Attr past_valid_lens or bucket_divisor is nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+
+    RelativeAttnBiasTilingData tilingData;
+    // Sequence length: identity has shape (2s, 2s).
+    const auto& identityShape = identityStorageShape->GetStorageShape();
+    int s = identityShape.GetDim(0) / 2;
+    if (s <= 0) {
+        // s == 0 would make alignSeqLen 0 and divide by zero below.
+        printf("[ERROR]Invalid sequence length\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.set_s(s);
+
+    // Batch size is the number of entries in past_valid_lens.
+    const size_t batchCnt = pastValidLensPtr->GetSize();
+    if (batchCnt > static_cast<size_t>(MAX_BATCH_SIZE)) {
+        // The tiling struct and the local copy below only hold MAX_BATCH_SIZE entries.
+        printf("[ERROR]Batch size exceeds MAX_BATCH_SIZE\n");
+        return ge::GRAPH_FAILED;
+    }
+    int bs = static_cast<int>(batchCnt);
+    tilingData.set_bs(bs);
+
+    const auto* pastValidLensData = reinterpret_cast<const int64_t*>(pastValidLensPtr->GetData());
+    uint32_t pastValidLens[MAX_BATCH_SIZE] = {0};  // unused tail stays zero instead of stack garbage
+    for (int i = 0; i < bs; ++i) {
+        pastValidLens[i] = static_cast<uint32_t>(pastValidLensData[i]);
+    }
+    tilingData.set_pastValidLens(pastValidLens);
+
+    // timestamps_weights has shape (num_layer, num_buckets + 1).
+    const auto& tswShape = tswStorageShape->GetStorageShape();
+    int numLayer = tswShape.GetDim(0);
+    int numBuckets = tswShape.GetDim(1);
+    if (numLayer <= 0 || numBuckets <= 0) {
+        printf("[ERROR]Invalid timestamps_weights shape\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.set_numBuckets(numBuckets);
+    tilingData.set_numLayer(numLayer);
+
+    float divs = *bucketDivisorPtr;
+    float clampMax = std::exp((numBuckets - 1) * divs);
+    tilingData.set_bucketDivisor(divs);
+    tilingData.set_clampMax(clampMax);
+
+    // Usable unified-buffer bytes.
+    uint64_t ub = 0;
+    ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub);
+    if (ub <= static_cast<uint64_t>(RESERVER_UB_SIZE)) {
+        printf("[ERROR]Unified buffer is too small\n");
+        return ge::GRAPH_FAILED;
+    }
+    ub -= RESERVER_UB_SIZE;
+
+    // Element types and sizes.
+    auto floatType = identityTensor->GetDataType();
+    auto intType = timestampsTensor->GetDataType();
+    int floatSize = ge::GetSizeByDataType(floatType);
+    int intSize = ge::GetSizeByDataType(intType);
+    if (floatSize <= 0 || floatSize > DATA_ALIGN_BYTES || intSize <= 0) {
+        printf("[ERROR]Unsupported data type\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.set_floatType(floatType);
+    tilingData.set_intType(intType);
+
+    // Position kernel window (elements per chunk). The kernel keeps NUM_BUFFER copies of a
+    // 2 * stride identity buffer and a stride position buffer, and addresses the second half
+    // of the identity buffer at element offset `stride`, so round the stride down to a whole
+    // number of 32-byte blocks to keep that sub-tensor aligned.
+    const int elemsPerBlock = DATA_ALIGN_BYTES / floatSize;
+    int stride = ub / (NUM_BUFFER * 3 * floatSize);
+    stride = stride / elemsPerBlock * elemsPerBlock;
+    if (stride <= 0) {
+        printf("[ERROR]Unified buffer is too small for the position window\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.set_positionStride(stride);
+
+    // Time kernel: first reserve the UB space taken by the timestamps_weights table.
+    const uint64_t tswReserve = static_cast<uint64_t>(numBuckets) * numLayer * floatSize +
+                                static_cast<uint64_t>(numLayer) * DATA_ALIGN_BYTES;
+    if (ub <= tswReserve) {
+        printf("[ERROR]Unified buffer is too small for timestamps_weights\n");
+        return ge::GRAPH_FAILED;
+    }
+    ub -= tswReserve;
+    uint32_t alignSeqLen = (s * floatSize + DATA_ALIGN_BYTES - 1) / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES / floatSize;
+    const uint64_t bytesPerElem = sizeof(float) + intSize;  // one float and one int buffer per element
+
+    // First estimate of the row window, ignoring the clamp scratch buffer.
+    uint64_t timeStride = ub / bytesPerElem / alignSeqLen;
+    if (timeStride > MAX_TIME_STRIDE) {
+        timeStride = MAX_TIME_STRIDE;
+    }
+    if (timeStride == 0) {
+        printf("[ERROR]Unified buffer is too small for the time window\n");
+        return ge::GRAPH_FAILED;
+    }
+
+    // Scratch space required by ClampMax/ClampMin for that window.
+    std::vector<int64_t> shape_vec = {static_cast<int64_t>(timeStride * alignSeqLen)};
+    ge::Shape shape(shape_vec);
+    uint32_t maxBuff = 0;
+    uint32_t minBuff = 0;
+    AscendC::GetClampMaxMinTmpSize(shape, sizeof(float), false, maxBuff, minBuff);
+    tilingData.set_buffSize(maxBuff);
+    if (ub <= maxBuff) {
+        printf("[ERROR]Unified buffer is too small for the clamp scratch buffer\n");
+        return ge::GRAPH_FAILED;
+    }
+
+    // Final row window once the scratch buffer has been carved out.
+    timeStride = (ub - maxBuff) / bytesPerElem / alignSeqLen;
+    if (timeStride > MAX_TIME_STRIDE) {
+        timeStride = MAX_TIME_STRIDE;
+    }
+    if (timeStride == 0) {
+        // A zero stride would make the kernel's `offset += stride` loop spin forever.
+        printf("[ERROR]Unified buffer is too small for the time window\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.set_timeStride(timeStride);
+
+    context->SetBlockDim(coreNum);
+    auto rawTilingData = context->GetRawTilingData();
+    if (rawTilingData == nullptr) {
+        printf("[ERROR]Raw tiling data is nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+    tilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
+    rawTilingData->SetDataSize(tilingData.GetDataSize());
+
+    return ge::GRAPH_SUCCESS;
+}
+} // namespace optiling
+
+namespace ge {
+/**
+ * Infers the two output shapes.
+ *
+ * rab_pos  : (bs, 2s, 2s)              - same extent as identity (2s, 2s), one slice per batch.
+ * rab_time : (numLayers, bs, s, 1, s, 1) - s is half of identity's extent, matching the value
+ *            TilingFunc passes to the time kernel, which writes numLayer * bs * s * s elements.
+ *
+ * @param context infer-shape context provided by the framework
+ * @return GRAPH_SUCCESS, or GRAPH_FAILED when a shape or attribute is missing
+ */
+static ge::graphStatus InferShape(gert::InferShapeContext* context)
+{
+    if (context == nullptr) {
+        return GRAPH_FAILED;
+    }
+    gert::Shape* rabPosOutShape = context->GetOutputShape(RAB_POSITION_INDEX);
+    gert::Shape* rabTimeOutShape = context->GetOutputShape(RAB_TIME_INDEX);
+    const gert::Shape* identityShape = context->GetInputShape(IDENTITY_INDEX);
+    const gert::Shape* tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX);
+    const gert::RuntimeAttrs* attrs = context->GetAttrs();
+    if (rabPosOutShape == nullptr || rabTimeOutShape == nullptr || identityShape == nullptr ||
+        tswShape == nullptr || attrs == nullptr) {
+        return GRAPH_FAILED;
+    }
+    const auto* pastValidLensPtr = attrs->GetAttrPointer<gert::ContinuousVector>(PAST_VALID_LENS_INDEX);
+    if (pastValidLensPtr == nullptr) {
+        return GRAPH_FAILED;
+    }
+
+    int bs = pastValidLensPtr->GetSize();
+    int doubleSeqLen = identityShape->GetDim(0);  // identityShape(2s, 2s)
+    int s = doubleSeqLen / 2;
+    int numLayers = tswShape->GetDim(0);          // timestamps_weights(num_layer, num_buckets + 1)
+
+    rabPosOutShape->SetDimNum(3);
+    rabPosOutShape->SetDim(0, bs);
+    rabPosOutShape->SetDim(1, doubleSeqLen);
+    rabPosOutShape->SetDim(2, doubleSeqLen);
+
+    // These dims belong to the time output; writing them into rabPosOutShape would both
+    // corrupt the position shape and leave the time shape unset.
+    rabTimeOutShape->SetDimNum(6);
+    rabTimeOutShape->SetDim(0, numLayers);
+    rabTimeOutShape->SetDim(1, bs);
+    rabTimeOutShape->SetDim(2, s);
+    rabTimeOutShape->SetDim(3, 1);
+    rabTimeOutShape->SetDim(4, s);
+    rabTimeOutShape->SetDim(5, 1);
+    return GRAPH_SUCCESS;
+}
+} // namespace ge
+
+namespace ops {
+// Operator prototype: four required inputs, two required outputs and two attributes.
+// Every DataType list has two entries: fp32 in the first combination, fp16 in the second;
+// timestamps is int32 in both.
+class RelativeAttnBias : public OpDef {
+public:
+    explicit RelativeAttnBias(const char* name) : OpDef(name)
+    {
+        this->Input("rel_pos_bias")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("identity")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("timestamps")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("timestamps_weights")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("rab_pos")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("rab_time")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Attr("past_valid_lens").ListInt();
+        this->Attr("bucket_divisor").Float();
+
+        this->SetInferShape(ge::InferShape);
+
+        OpAICoreConfig aicore_config;
+        aicore_config.DynamicCompileStaticFlag(true)
+            .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false")
+            .ExtendCfgInfo("coreType.value", "AiCore")
+            .ExtendCfgInfo("prebuildPattern.value", "Opaque");
+
+        this->AICore().SetTiling(optiling::TilingFunc);
+        this->AICore().AddConfig("ascend910", aicore_config);
+        this->AICore().AddConfig("ascend910b", aicore_config);
+        this->AICore().AddConfig("ascend910_93", aicore_config);
+        this->AICore().AddConfig("ascend310p", aicore_config);
+    }
+};
+
+OP_ADD(RelativeAttnBias);
+
+} // namespace ops
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h
new file mode 100644
index 00000000..ab627f5e
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h
@@ -0,0 +1,34 @@
+/**
+* @file relative_attn_bias_tiling.h
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
+#include "register/tilingdata_base.h"
+// Capacity of the pastValidLens array below, i.e. the largest batch the tiling can describe.
+constexpr int MAX_BATCH_SIZE = 512;
+
+namespace optiling {
+// Tiling data handed from the host TilingFunc to the RelativeAttnBias kernels.
+// The kernels read these fields by name through GET_TILING_DATA.
+BEGIN_TILING_DATA_DEF(RelativeAttnBiasTilingData)
+// Sequence length; TilingFunc sets it to identity.dim(0) / 2.
+TILING_DATA_FIELD_DEF(int64_t, s);
+// Batch size; TilingFunc sets it to the number of past_valid_lens entries.
+TILING_DATA_FIELD_DEF(int64_t, bs);
+// Elements processed per chunk by the position kernel.
+TILING_DATA_FIELD_DEF(int64_t, positionStride);
+// Rows processed per iteration by the time kernel.
+TILING_DATA_FIELD_DEF(int64_t, timeStride);
+// Copy of the past_valid_lens attribute; only the first bs entries are written by TilingFunc.
+TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens);
+
+// bucket_divisor attribute; the time kernel multiplies by its reciprocal.
+TILING_DATA_FIELD_DEF(float, bucketDivisor);
+// timestamps_weights.dim(1), the row length of the bucket table.
+TILING_DATA_FIELD_DEF(int64_t, numBuckets);
+// timestamps_weights.dim(0).
+TILING_DATA_FIELD_DEF(int64_t, numLayer);
+// exp((numBuckets - 1) * bucketDivisor), computed on the host.
+TILING_DATA_FIELD_DEF(float, clampMax);
+
+// Data type of the identity input as returned by GetDataType(); the kernel entry compares it
+// with TYPE_FP32 / TYPE_FP16.
+TILING_DATA_FIELD_DEF(int, floatType);
+// Data type of the timestamps input as returned by GetDataType().
+TILING_DATA_FIELD_DEF(int, intType);
+// Scratch bytes reported by GetClampMaxMinTmpSize (the max value) for the clamp operations.
+TILING_DATA_FIELD_DEF(int, buffSize);
+
+END_TILING_DATA_DEF;
+REGISTER_TILING_DATA_CLASS(RelativeAttnBias, RelativeAttnBiasTilingData)
+}
+#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h
new file mode 100644
index 00000000..4b784a01
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h
@@ -0,0 +1,40 @@
+/**
+* @file rab_common.h
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+
+#ifndef MXREC_ADD_ONS_RAB_COMMON_H
+#define MXREC_ADD_ONS_RAB_COMMON_H
+
+#include "kernel_operator.h"
+// Alignment unit, in bytes, used by the kernels' Ceil() helpers and aligned/unaligned copy split.
+constexpr int DATA_ALIGN_BYTES = 32;
+// Capacity of the per-batch pastValidLens arrays (same value as the host tiling header).
+constexpr int MAX_BATCH_SIZE = 512;
+// Buffer count passed to InitBuffer for the position kernel queues.
+constexpr int NUM_BUFFER = 2;
+// Capacity of the SequenceParams array the time kernel fills once per row window.
+constexpr int MAX_SEQ_CNT = 128;
+// Byte window used to split Gather calls in the time kernel (divided by sizeof(floatType) there).
+constexpr int GATHER_PROCESS_WINDOW = 4096;
+
+// Type codes compared against the floatType tiling field, which the host fills from
+// GetDataType(); presumably these mirror the ge::DataType enum values - confirm.
+constexpr int8_t TYPE_FP32 = 0;
+constexpr int8_t TYPE_FP16 = 1;
+constexpr int8_t TYPE_INT32 = 3;
+constexpr int8_t TYPE_INT64 = 9;
+
+using namespace AscendC;
+
+// Bundle of the global-memory addresses passed to the kernel entry point, in the same order
+// as the entry point's parameters.
+struct Args{
+    // pos_bias
+    GM_ADDR positionBias;
+    GM_ADDR identity;
+    // ts_bias
+    GM_ADDR timestamps;
+    GM_ADDR timestampsWeights;
+    // out
+    GM_ADDR rabPosOut;
+    GM_ADDR rabTimeOut;
+
+    GM_ADDR workspace;
+    GM_ADDR tiling;
+};
+#endif //MXREC_ADD_ONS_RAB_COMMON_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp
new file mode 100644
index 00000000..510ee1a6
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp
@@ -0,0 +1,34 @@
+/**
+* @file relative_attn_bias.cpp
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+#include "rab_common.h"
+#include "relative_attn_bias_kernel.h"
+#include "kernel_operator.h"
+
+/**
+ * Kernel entry point of RelativeAttnBias.
+ *
+ * Packs the global-memory addresses into Args and runs the kernel instantiated for the
+ * floating-point type recorded in the tiling data (TYPE_FP32 -> float, TYPE_FP16 -> half).
+ * Any other type code falls through without doing any work.
+ */
+extern "C" __global__ __aicore__ void relative_attn_bias(
+    GM_ADDR positionBias,
+    GM_ADDR identity,
+    GM_ADDR timestamps,
+    GM_ADDR timestampsWeights,
+    GM_ADDR rabPosOut,
+    GM_ADDR rabTimeOut,
+    GM_ADDR workspace,
+    GM_ADDR tiling)
+{
+    GET_TILING_DATA(tilingData, tiling);
+    Args args{
+        positionBias, identity, timestamps, timestampsWeights, rabPosOut, rabTimeOut, workspace, tiling
+    };
+    // RelativeAttnBias is a class template, so each branch must name its element type.
+    if (tilingData.floatType == TYPE_FP32) {
+        RelativeAttnBias<float> kernel;
+        kernel.Compute(args);
+    } else if (tilingData.floatType == TYPE_FP16) {
+        RelativeAttnBias<half> kernel;
+        kernel.Compute(args);
+    }
+}
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h
new file mode 100644
index 00000000..0abce386
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h
@@ -0,0 +1,35 @@
+/**
+* @file relative_attn_bias_kernel.h
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H
+
+#include "rab_common.h"
+#include "relative_attn_bias_pos.h"
+#include "relative_attn_bias_time.h"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+/**
+ * Top-level kernel: runs the position-bias stage (only when SUPPORT_V200 is defined) followed
+ * by the time-bias stage.
+ *
+ * @tparam floatType element type of the floating-point inputs and outputs
+ */
+template <typename floatType>
+class RelativeAttnBias {
+public:
+    __aicore__ inline RelativeAttnBias() {}
+
+    /**
+     * Executes both stages on the buffers described by args.
+     * @param args global-memory addresses of inputs, outputs, workspace and tiling data
+     */
+    __aicore__ inline void Compute(Args args)
+    {
+#ifdef SUPPORT_V200
+        // The position output is produced by this kernel only on SUPPORT_V200 builds.
+        RelativeAttnBiasPos<floatType> rabPos;
+        rabPos.Compute(args);
+#endif
+        RelativeAttnBiasTime<floatType> rabTime;
+        rabTime.Compute(args);
+    }
+};
+
+#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h
new file mode 100644
index 00000000..eafc12b4
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h
@@ -0,0 +1,178 @@
+/**
+* @file relative_attn_bias_pos.h
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H
+#include "rab_common.h"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+// Position-bias stage. With S = 2 * tiling.s, for every batch b and row r it writes
+//   out[b, r, :] = rel_pos_bias[min(r, pastValidLens[b]), :] * (1 - identity[r, :])
+//                  + identity[r, :] * rel_pos_bias[0, 0]
+// where rel_pos_bias and identity are (S, S) and out is (bs, S, S). Rows are split across cores.
+// NOTE(review): the template parameter list and the template arguments of LocalTensor /
+// GlobalTensor / TQue are missing in this text (the body refers to `floatType`); they look
+// lost in extraction - restore from the original file before building.
+template
+class RelativeAttnBiasPos {
+public:
+    __aicore__ inline RelativeAttnBiasPos() {}
+
+    // Reads the tiling data, binds the global tensors, allocates the UB queues and computes the
+    // row range [rowOffset, rowOffset + totalRow) owned by this core.
+    __aicore__ inline void Init(Args args)
+    {
+        GET_TILING_DATA(tilingData, args.tiling);
+        s = 2 * tilingData.s;  // this stage works on the full (2s, 2s) extent
+        bs = tilingData.bs;
+        stride = tilingData.positionStride;
+        for (auto i = 0; i < bs; ++i) {
+            pastValidLens[i] = tilingData.pastValidLens[i];
+        }
+
+        posBiasGT.SetGlobalBuffer((__gm__ floatType*)args.positionBias, s * s);
+        identityGT.SetGlobalBuffer((__gm__ floatType*)args.identity, s * s);
+        rabPosBiasOutGT.SetGlobalBuffer((__gm__ floatType*)args.rabPosOut, bs * s * s);
+
+        // The identity buffer holds two halves of `stride` elements each (see ComputeIdentity).
+        pipe.InitBuffer(queIdentityIn, NUM_BUFFER, Ceil(2 * stride * sizeof(floatType)));
+        pipe.InitBuffer(quePosIn, NUM_BUFFER, Ceil(stride * sizeof(floatType)));
+
+        // Even row split: the first (s % blockNum) cores take one extra row.
+        int64_t totalTableSizeSplit = s % GetBlockNum();
+        int64_t baseLen = s / GetBlockNum();
+        if (GetBlockIdx() >= totalTableSizeSplit) {
+            totalRow = baseLen;
+            rowOffset = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen;
+        } else {
+            totalRow = baseLen + 1;
+            rowOffset = GetBlockIdx() * (baseLen + 1);
+        }
+        REL_POS_BIAS_FIRST = posBiasGT.GetValue(0);
+    }
+
+    // Loads cnt identity elements starting at flat offset `offset` and leaves in the queue a
+    // tensor whose first half is identity * rel_pos_bias[0, 0] and whose second half (at element
+    // offset `stride`) is 1 - identity.
+    __aicore__ inline void ComputeIdentity(int offset, int cnt)
+    {
+        // DataCopyIn identity
+        LocalTensor identityUb = queIdentityIn.AllocTensor();
+
+        // Copy length is rounded up to a whole number of 32-byte blocks.
+        DataCopy(identityUb, identityGT[offset], Ceil(cnt * sizeof(floatType)) / sizeof(floatType));
+        queIdentityIn.EnQue(identityUb);
+
+        // Compute identity * rel_pos_bias[0, 0], (1 - identity)
+        LocalTensor identityFilledUb = queIdentityIn.DeQue();
+
+        // Second half: (1 - identity)
+        // NOTE(review): identityFilledUb[stride] is only 32-byte aligned when
+        // stride * sizeof(floatType) is a multiple of 32 - confirm the host guarantees that.
+        Muls(identityFilledUb[stride], identityFilledUb, (floatType) -1, cnt);
+        Adds(identityFilledUb[stride], identityFilledUb[stride], (floatType) 1, cnt);
+
+        // First half: identity * rel_pos_bias[0, 0]
+        Muls(identityFilledUb, identityFilledUb, REL_POS_BIAS_FIRST, cnt);
+
+        queIdentityIn.EnQue(identityFilledUb);
+    }
+
+    // Loads cnt elements of rel_pos_bias row `row`, starting at column `offset`, into quePosIn.
+    __aicore__ inline void DataCopyIn(int row, int offset, int cnt)
+    {
+        LocalTensor posBiasUb = quePosIn.AllocTensor();
+        DataCopy(posBiasUb, posBiasGT[row * s + offset], Ceil(cnt * sizeof(floatType)) / sizeof(floatType));
+        quePosIn.EnQue(posBiasUb);
+    }
+
+    // Combines the loaded rel_pos_bias chunk with the prepared identity tensor:
+    // pos = pos * (1 - identity) + identity * rel_pos_bias[0, 0].
+    __aicore__ inline void ComputeRabBias(LocalTensor& identityCalcUb, int cnt)
+    {
+        LocalTensor posBiasUb = quePosIn.DeQue();
+        Mul(posBiasUb, posBiasUb, identityCalcUb[stride], cnt);
+        Add(posBiasUb, posBiasUb, identityCalcUb, cnt);
+        pipe_barrier(PIPE_ALL);
+        quePosIn.EnQue(posBiasUb);
+    }
+
+    // Rounds a up to the next multiple of b (default: 32 bytes); returns 0 when b is 0.
+    __aicore__ inline int64_t Ceil(int64_t a, int64_t b=DATA_ALIGN_BYTES)
+    {
+        if (b == 0) {
+            return 0;
+        }
+        return (a + b - 1) / b * b;
+    }
+
+    // Writes cnt result elements to the output at flat element offset `offset`, splitting the
+    // copy into a 32-byte-aligned part and an unaligned tail.
+    __aicore__ inline void DataCopyOut(int offset, int cnt)
+    {
+        uint32_t datasize = cnt * sizeof(floatType);
+        uint32_t alignLen = datasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES;
+        uint32_t unAlignLen = datasize - alignLen;
+        uint32_t alignCnt = alignLen / sizeof(floatType);
+        uint32_t unAlignCnt = unAlignLen / sizeof(floatType);
+
+        LocalTensor posBiasUb = quePosIn.DeQue();
+        // Copy out the aligned part
+        // NOTE(review): the count passed here is cnt, not alignCnt, although the tail is
+        // written separately below - confirm whether alignCnt was intended.
+        if (alignLen > 0) {
+            DataCopy(rabPosBiasOutGT[offset], posBiasUb, cnt);
+        }
+        // Copy out the unaligned tail
+        if (unAlignLen > 0) {
+#ifdef SUPPORT_V200
+            // Zero the lanes of the last 32-byte block that lie beyond the valid tail, then add
+            // the whole block into global memory with atomic add enabled.
+            // NOTE(review): presumably relies on the output being zero-initialised - confirm.
+            uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unAlignCnt);
+            uint64_t mask[2] = {mask0, 0};
+            Duplicate(posBiasUb[alignCnt], (floatType) 0, mask, 1, 1, 1);
+            set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+            wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+            SetAtomicAdd();
+            DataCopy(rabPosBiasOutGT[offset + alignCnt],
+                     posBiasUb[alignCnt],
+                     Ceil(unAlignLen) / sizeof(floatType));
+            SetAtomicNone();
+#else
+            // DataCopyPad writes exactly unAlignLen bytes.
+            const DataCopyExtParams dataCopyExtParams{1, unAlignLen, 0, 0, 0};
+            DataCopyPad(rabPosBiasOutGT[offset + alignCnt],
+                        posBiasUb[alignCnt],
+                        dataCopyExtParams);
+#endif
+        }
+        quePosIn.FreeTensor(posBiasUb);
+    }
+
+    // Entry point of the stage: for every owned row and every column chunk of at most `stride`
+    // elements, prepares the identity terms once and reuses them for all batches.
+    __aicore__ inline void Compute(Args args)
+    {
+        Init(args);
+        for (int row=rowOffset; row < rowOffset + totalRow; ++row) {
+            int offset = 0;
+            for (int j = 0; j < (s + stride - 1) / stride; ++j) {
+                int remain = s - offset;
+                int cnt = remain > stride ? stride : remain;
+                ComputeIdentity(offset + row * s, cnt);
+                LocalTensor identityCalcUb = queIdentityIn.DeQue();
+
+                for (int b = 0; b < bs; ++b) {
+                    int valid_len = pastValidLens[b];
+                    // Rows beyond the batch's valid length read rel_pos_bias row valid_len.
+                    int valid_row = row > valid_len ? valid_len : row;
+                    DataCopyIn(valid_row, offset, cnt);
+                    ComputeRabBias(identityCalcUb, cnt);
+                    // NOTE(review): 32-bit int; b * s * s can overflow for large bs and s.
+                    int padOutPtr = b * s * s + row * s + j * stride;
+                    DataCopyOut(padOutPtr, cnt);
+                }
+                queIdentityIn.FreeTensor(identityCalcUb);
+                offset += cnt;
+            }
+        }
+    }
+
+private:
+    // shape
+    int s;
+    int bs;
+    int stride;
+    // tiling
+    int rowOffset;  // first row of identity / rel_pos_bias (s, s) handled by this core
+    int totalRow;   // number of rows handled by this core
+
+private:
+    TPipe pipe;
+    TQue queIdentityIn;
+    TQue queIdentityCalcIn;  // NOTE(review): never initialised or used in this class
+    TQue quePosIn;
+
+    GlobalTensor identityGT;
+    GlobalTensor posBiasGT;
+    GlobalTensor rabPosBiasOutGT;
+    uint32_t pastValidLens[MAX_BATCH_SIZE];
+    floatType REL_POS_BIAS_FIRST;  // rel_pos_bias[0, 0], read in Init via posBiasGT.GetValue(0)
+
+};
+
+#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h
new file mode 100644
index 00000000..7a0f7752
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h
@@ -0,0 +1,261 @@
+/**
+* @file relative_attn_bias_time.h
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
+#include
+#include "rab_common.h"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+// Per-row bookkeeping for one window of the time-bias stage.
+struct SequenceParams {
+    int startIndexGT;  // element offset in timestamps of the sequence this row belongs to
+    int startIndexUb;  // element offset of this row inside the UB window
+    int subValue;      // timestamp of the row's own position, subtracted from the sequence
+};
+
+// Time-bias stage. Every flat row index r in [0, bs * s) identifies (batch, position). For it
+// the kernel takes the batch's s timestamps, subtracts the timestamp at that position, maps the
+// differences to bucket indices via trunc(min(log(max(|d|, 1)) / bucketDivisor, numBuckets)),
+// and for every layer gathers timestamps_weights[layer][bucket] into
+// out[(layer * bs * s + r) * s ...]. Rows are split across cores.
+// NOTE(review): the header name after the bare #include above, the template parameter list and
+// the template arguments of LocalTensor / GlobalTensor / TQue are missing in this text (the
+// body refers to `floatType`); they look lost in extraction - restore before building.
+template
+class RelativeAttnBiasTime {
+public:
+    __aicore__ inline RelativeAttnBiasTime() {}
+
+    // Reads the tiling data, derives the aligned/unaligned split of one output row, binds the
+    // global tensors, allocates the UB queues and computes this core's row range.
+    __aicore__ inline void Init(Args args)
+    {
+        GET_TILING_DATA(tilingData, args.tiling);
+        s = tilingData.s;
+        bs = tilingData.bs;
+        stride = tilingData.timeStride;
+        // Row length in elements, rounded up so that a row of floatType fills whole 32-byte blocks.
+        alignSeqLen = Ceil(s * sizeof(floatType)) / sizeof(floatType);
+
+        int totalLen = bs * s;
+        uint32_t seqDatasize = s * sizeof(floatType);
+        alignLen = seqDatasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES;
+        alignCnt = alignLen / sizeof(floatType);
+        unalignLen = seqDatasize - alignLen;
+        unalignCnt = unalignLen / sizeof(floatType);
+
+        div = 1 / tilingData.bucketDivisor;  // stored as a reciprocal so the kernel can multiply
+        numBuckets = tilingData.numBuckets;
+        alignNumBuckets = Ceil(numBuckets * sizeof(floatType)) / sizeof(floatType);
+        numLayer = tilingData.numLayer;
+
+        clampMin = 1; // fixed to 1, following the reference (simulation) code
+        clampMax = tilingData.clampMax;  // NOTE(review): stored but not read anywhere in this class
+
+        timestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.timestamps, bs * s);
+        timestampsWeightsGT.SetGlobalBuffer((__gm__ floatType*)args.timestampsWeights, numBuckets * numLayer);
+        rabTimeBiasOutGT.SetGlobalBuffer((__gm__ floatType*)args.rabTimeOut, numLayer * bs * s * s);
+
+        pipe.InitBuffer(queTimestamps, 1, stride * alignSeqLen * sizeof(int32_t));
+        pipe.InitBuffer(queTimestampsFloat, 1, stride * alignSeqLen * sizeof(float));
+        pipe.InitBuffer(queTimestampsWeights, 1, alignNumBuckets * numLayer * sizeof(floatType));
+        pipe.InitBuffer(tmpQue, 1, Ceil(tilingData.buffSize));
+
+
+        // Even row split: the first (totalLen % blockNum) cores take one extra row.
+        int totalTableSizeSplit = totalLen % GetBlockNum();
+        int baseLen = totalLen / GetBlockNum();
+        if (GetBlockIdx() >= totalTableSizeSplit) {
+            processRowLen = baseLen;
+            startIndex = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen;
+        } else {
+            processRowLen = baseLen + 1;
+            startIndex = GetBlockIdx() * (baseLen + 1);
+        }
+    }
+
+    // Fills params[0..cnt) for the rows [offset, offset + cnt): the row's own timestamp and the
+    // source/destination offsets of the sequence it belongs to.
+    __aicore__ inline void FillSeqParams(SequenceParams* params, int offset, int cnt)
+    {
+        LocalTensor ts = queTimestamps.AllocTensor();
+        // Ceil(cnt) rounds the element count up to a multiple of 32.
+        DataCopy(ts, timestampsGT[offset], Ceil(cnt));
+        // NOTE(review): the values are read with GetValue right after DataCopy with no visible
+        // synchronisation in between - confirm the copy is guaranteed to have completed.
+        for (int i=0; i < cnt; ++i) {
+            int seqSubValue = ts.GetValue(i);
+            int seqId = (offset + i) / s;
+            int seqOffsetUb = i * alignSeqLen;
+            int seqOffsetGT = seqId * s;
+
+            params[i].startIndexGT = seqOffsetGT;
+            params[i].startIndexUb = seqOffsetUb;
+            params[i].subValue = seqSubValue;
+        }
+        queTimestamps.FreeTensor(ts);
+    }
+
+    // Loads, for each of the cnt rows, the full timestamp sequence of its batch into the row's
+    // slot of the UB window (alignSeqLen elements per row).
+    __aicore__ inline void DataCopyIn(SequenceParams* params, int cnt)
+    {
+        LocalTensor ts = queTimestamps.AllocTensor();
+        for (int i=0; i < cnt; ++i) {
+            SequenceParams param = params[i];
+            int startIndexGT = param.startIndexGT;
+            int startIndexUb = param.startIndexUb;
+
+            DataCopy(ts[startIndexUb], timestampsGT[startIndexGT], alignSeqLen);
+        }
+        queTimestamps.EnQue(ts);
+    }
+
+    // Turns the loaded timestamps into Gather byte offsets, in place in the integer window:
+    // subtract each row's own timestamp, then abs -> clamp(min 1) -> log -> * (1 / divisor) ->
+    // clamp(max numBuckets) -> truncate to int -> * sizeof(floatType).
+    __aicore__ inline void ComputeBucketTimestamps(SequenceParams* params, int rowCnt)
+    {
+
+        LocalTensor tsInt = queTimestamps.DeQue();
+        // Float view over the integer window, used as the second float working buffer.
+        LocalTensor tsTmp = tsInt.template ReinterpretCast();
+        LocalTensor ts = queTimestampsFloat.AllocTensor();
+        LocalTensor buff = tmpQue.AllocTensor();
+
+        for (int i=0; i < rowCnt; ++i) {
+            SequenceParams param = params[i];
+            int startIndexUb = param.startIndexUb;
+            int value = param.subValue;
+            Adds(tsInt[startIndexUb], tsInt[startIndexUb], (int32_t) -value, s);
+        }
+
+        uint32_t cnt = rowCnt * alignSeqLen;
+        Cast(ts, tsInt, RoundMode::CAST_NONE, cnt);
+
+        Abs(ts, ts, cnt);
+        ClampMin(tsTmp, ts, buff, clampMin, cnt);
+        Log(ts, tsTmp, cnt);
+        Muls(ts, ts, div, cnt);
+        // NOTE(review): numBuckets is the row length of timestamps_weights, so an index equal to
+        // numBuckets is one past the last entry of a row - confirm whether numBuckets - 1 was
+        // intended as the upper bound.
+        ClampMax(tsTmp, ts, buff, (float) numBuckets, cnt);
+
+        Cast(tsInt, tsTmp, RoundMode::CAST_TRUNC, cnt);
+        Muls(tsInt, tsInt, (int32_t) sizeof(floatType), cnt); // Gather expects offsets in bytes
+
+        tmpQue.FreeTensor(buff);
+        queTimestampsFloat.FreeTensor(ts);
+        queTimestamps.EnQue(tsInt);
+    }
+
+    // Gathers the weights of layer `layer` for all bucket offsets of the window, at most
+    // GATHER_PROCESS_WINDOW / sizeof(floatType) elements per Gather call.
+    __aicore__ inline void IndexSelect(LocalTensor& tsw, LocalTensor& tsInt, int layer, int rowCnt)
+    {
+        uint32_t cnt = rowCnt * alignSeqLen;
+        LocalTensor rabTime = queTimestampsFloat.AllocTensor();
+        uint32_t processLenMax = GATHER_PROCESS_WINDOW / sizeof(floatType);
+        uint32_t tmpOffset = 0;
+        while (tmpOffset < cnt) {
+            uint32_t processLen = (cnt - tmpOffset) > processLenMax ? processLenMax : (cnt - tmpOffset);
+            Gather(rabTime[tmpOffset], tsw[layer * alignNumBuckets], tsInt[tmpOffset], (uint32_t) 0, processLen);
+            tmpOffset += processLen;
+        }
+        queTimestampsFloat.EnQue(rabTime);
+    }
+
+    // Writes rowCnt result rows of s elements each to the output, starting at flat element
+    // offset ptr; each row is split into a 32-byte-aligned part and an unaligned tail.
+    __aicore__ inline void DataCopyOut(uint32_t ptr, int rowCnt)
+    {
+        LocalTensor rabTime = queTimestampsFloat.DeQue();
+
+        for (int i = 0; i < rowCnt; ++i) {
+            uint32_t ptrUb = i * alignSeqLen;
+
+            // Copy out the aligned part
+            // NOTE(review): the count passed here is s, not alignCnt, although the tail is
+            // written separately below - confirm whether alignCnt was intended.
+            if (alignLen > 0) {
+                DataCopy(rabTimeBiasOutGT[ptr + i * s], rabTime[ptrUb], s);
+            }
+            // Copy out the unaligned tail
+            if (unalignLen == 0) {continue;}
+#ifdef SUPPORT_V200
+            // Zero the lanes of the last 32-byte block that lie beyond the valid tail, then add
+            // the whole block into global memory with atomic add enabled.
+            // NOTE(review): presumably relies on the output being zero-initialised - confirm.
+            uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unalignCnt);
+            uint64_t mask[2] = {mask0, 0};
+            Duplicate(rabTime[ptrUb + alignCnt], (floatType) 0, mask, 1, 1, 1);
+            set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+            wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
+            SetAtomicAdd();
+            DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt],
+                     rabTime[ptrUb + alignCnt],
+                     Ceil(unalignLen) / sizeof(floatType));
+            SetAtomicNone();
+#else
+            // DataCopyPad writes exactly unalignLen bytes.
+            const DataCopyExtParams dataCopyExtParams{1, unalignLen, 0, 0, 0};
+
+            DataCopyPad(rabTimeBiasOutGT[ptr + i * s + alignCnt],
+                        rabTime[ptrUb + alignCnt],
+                        dataCopyExtParams);
+#endif
+        }
+        queTimestampsFloat.FreeTensor(rabTime);
+    }
+
+    // Loads timestamps_weights into UB, one row per layer, each row padded to alignNumBuckets.
+    __aicore__ inline void DataCopyInTsw()
+    {
+        LocalTensor tsw = queTimestampsWeights.AllocTensor();
+        for (int n = 0; n < numLayer; ++n) {
+            // Copies alignNumBuckets elements per row, i.e. up to the padding past the row end.
+            DataCopy(tsw[n * alignNumBuckets], timestampsWeightsGT[n * numBuckets], alignNumBuckets);
+        }
+        queTimestampsWeights.EnQue(tsw);
+    }
+
+    // Rounds a up to the next multiple of b (default: 32); returns 0 when b is 0.
+    __aicore__ inline int64_t Ceil(int64_t a, int64_t b=DATA_ALIGN_BYTES)
+    {
+        if (b == 0) {
+            return 0;
+        }
+        return (a + b - 1) / b * b;
+    }
+
+    // Entry point of the stage: walks this core's rows in windows of at most `stride` rows;
+    // the bucket offsets of a window are computed once and reused for every layer.
+    __aicore__ inline void Compute(Args args)
+    {
+        Init(args);
+        DataCopyInTsw();
+        LocalTensor tsw = queTimestampsWeights.DeQue();
+
+        for (int offset = 0; offset < processRowLen; offset += stride) {
+            int rowOffset = offset + startIndex;
+            int rowCnt = stride > (processRowLen - offset) ? (processRowLen - offset) : stride;
+
+            // NOTE(review): rowCnt can be as large as stride; nothing visible here bounds
+            // stride by MAX_SEQ_CNT, so a larger stride would overrun this array - confirm the
+            // host limits timeStride.
+            SequenceParams params[MAX_SEQ_CNT];
+            FillSeqParams(params, rowOffset, rowCnt);
+            DataCopyIn(params, rowCnt);
+            ComputeBucketTimestamps(params, rowCnt);
+
+            LocalTensor tsInt = queTimestamps.DeQue();
+            for (int n = 0; n < numLayer; ++n) {
+                IndexSelect(tsw, tsInt, n, rowCnt);
+                pipe_barrier(PIPE_ALL);
+
+                // NOTE(review): 32-bit arithmetic; numLayer * bs * s * s can exceed 2^32.
+                uint32_t ptr = (n * bs * s + rowOffset) * s;
+                DataCopyOut(ptr, rowCnt);
+            }
+            queTimestamps.FreeTensor(tsInt);
+
+        }
+        queTimestampsWeights.FreeTensor(tsw);
+    }
+private:
+    // shape
+    uint32_t s;            // sequence length
+    uint32_t alignSeqLen;  // s rounded up to whole 32-byte blocks of floatType, in elements
+    uint32_t bs;           // batch size
+    uint32_t stride;       // rows per window
+    // align
+    uint32_t alignLen;     // bytes of one row covered by whole 32-byte blocks
+    uint32_t alignCnt;     // the same, in elements
+    uint32_t unalignLen;   // remaining tail bytes of one row
+    uint32_t unalignCnt;   // the same, in elements
+    // tiling
+    uint32_t startIndex;     // first flat row handled by this core
+    uint32_t processRowLen;  // number of rows handled by this core
+
+    float div;                // 1 / bucketDivisor
+    int32_t numBuckets;       // row length of timestamps_weights
+    int32_t alignNumBuckets;  // numBuckets rounded up to whole 32-byte blocks, in elements
+    int32_t numLayer;
+    float clampMin;
+    float clampMax;
+
+private:
+    GlobalTensor timestampsGT;
+    GlobalTensor timestampsWeightsGT;
+    GlobalTensor rabTimeBiasOutGT;
+
+    TPipe pipe;
+    TQue queTimestamps;
+    TQue queTimestampsFloat;
+    TQue queTimestampsWeights;
+    TQue tmpQue;
+
+};
+#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/relative_attn_bias.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/relative_attn_bias.json
new file mode 100644
index 00000000..b1559e7c
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/relative_attn_bias.json
@@ -0,0 +1,82 @@
+[
+    {
+        "op": "RelativeAttnBias",
+        "language": "cpp",
+        "input_desc": [
+            {
+                "name": "rel_pos_bias",
+                "param_type": "required",
+                "format": [
"ND" + ], + "type": [ + "float" + ] + }, + { + "name": "identity", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + }, + { + "name": "timestamps", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int64" + ] + }, + { + "name": "timestamps_weights", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + } + ], + "output_desc": [ + { + "name": "rab_pos", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + }, + { + "name": "rab_time", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + } + ], + "attr": [ + { + "name": "past_valid_lens", + "param_type": "required", + "type": "list_int" + }, + { + "name": "bucket_divisor", + "param_type": "required", + "type": "float" + } + ] + } +] \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh new file mode 100644 index 00000000..e4ee6838 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. 
+# ==============================================================================
+
+set -e
+
+# Locate msopgen and add its directory to PATH.
+# `head -n 1` keeps a single hit when several toolkit versions are installed.
+msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin | head -n 1)
+if [ -z "$msopgen_path" ]; then
+    echo "ERROR, msopgen not found under /usr/local/Ascend/."
+    exit 1
+fi
+parent_dir=$(dirname "$msopgen_path")
+export PATH=$parent_dir:$PATH
+
+# Target SoC; may be overridden by the first command-line argument.
+ai_core="ai_core-Ascend910B1"
+if [ "$#" -eq 1 ]; then
+    ai_core=$1
+fi
+
+# Generate the buildable operator project with msopgen.
+rm -rf ./relative_attn_bias
+python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias.json -f tf -c "${ai_core}" -lan cpp -out ./relative_attn_bias -m 0 -op RelativeAttnBias
+# Replace the generated stubs with the sources kept next to this script.
+# The host sources live in op_host/ (the directory copied in below), not host/.
+rm -rf relative_attn_bias/op_kernel/*.h
+rm -rf relative_attn_bias/op_kernel/*.cpp
+rm -rf relative_attn_bias/op_host/*.h
+rm -rf relative_attn_bias/op_host/*.cpp
+cp -rf op_kernel relative_attn_bias/
+cp -rf op_host relative_attn_bias/
+
+cd relative_attn_bias
+
+# Make sure CMakePresets.json exists in the generated project.
+if [ ! -f "CMakePresets.json" ]; then
+    echo "ERROR, CMakePresets.json file not exist."
+    exit 1
+fi
+
+# Disable CRC checksum generation for the self-extracting package.
+sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake
+
+# Point the presets at the actual CANN installation path.
+sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json
+# Change vendor_name so this package does not overwrite operators installed with the default
+# vendor_name "customize".
+# vendor_name must match CUST_PKG_PATH in the aclnn CMakeLists.txt, otherwise the aclnn call fails.
+# vendor_name must not contain "customize": that breaks the parsing of config.ini under the CANN
+# vendors directory when several operator packages are deployed.
+sed -i 's:"customize":"relative_attn_bias":g' CMakePresets.json
+
+# The 310P build enables the SUPPORT_V200 code paths in the kernel headers.
+if [ "$ai_core" = "ai_core-Ascend310P3" ]; then
+    sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_kernel.h
+    sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_time.h
+    sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos.h
+fi
+
+# Switch the value two lines below the first ENABLE_SOURCE_PACKAGE entry from True to False.
+line=$(awk '/ENABLE_SOURCE_PACKAGE/{print NR; exit}' CMakePresets.json)
+if [ -z "$line" ]; then
+    echo "ERROR, ENABLE_SOURCE_PACKAGE not found in CMakePresets.json."
+    exit 1
+fi
+line=$((line + 2))
+sed -i "${line}s/True/False/g" CMakePresets.json
+
+bash build.sh
+
+# Install the operator package that was just built.
+bash ./build_out/custom_opp*.run
diff --git
a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py new file mode 100644 index 00000000..435ab85d --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py @@ -0,0 +1,210 @@ +import random +import sysconfig +import time + +import torch +import torch_npu +import pytest +import torch.nn.functional as F + +torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + +DEVICE = "npu:7" +NUM_BUCKETS = 128 +BUCKET_DIVISOR = 0.301 + + +def create_pos_w(train_len: int, num_layers: int): + return torch.range(0, 2 * train_len).unsqueeze(1).repeat(1, num_layers) + + +def create_past_valid_lens(bs: int, past_len: int): + return torch.randint(0, past_len, (bs,)) + + +def create_timestamps(train_len: int, candidate_len: int, past_valid_lens: torch.Tensor): + bs = past_valid_lens.size(0) + timestamps = torch.zeros(bs, train_len + candidate_len // 2) + for i, valid_len in enumerate(past_valid_lens): + if valid_len > 0: + timestamps[i, :valid_len] = torch.range(1, valid_len.int()) + + if candidate_len <= 0: + return timestamps + timestamps[:, -candidate_len // 2:] = train_len + 1 + return timestamps + + +def create_timestamps_weights(num_layers: int): + """ + :param num_layers: + :return: timestamps_weights(num_layers, NUM_BUCKETS + 1) + """ + return torch.range(0, NUM_BUCKETS).repeat(num_layers).reshape(num_layers, NUM_BUCKETS + 1) + + +def init_rel_pos_bias(pos_w: torch.Tensor, train_len: int, candidate_len: int, num_layers: int): + rel_pos_bias_list, identity_list = [], [] + + max_len = train_len + candidate_len // 2 + max_len_x2 = train_len * 2 + candidate_len + for layer_num in range(num_layers): + t = F.pad(pos_w[:2 * train_len - 1, layer_num], [0, train_len]).repeat(train_len) + t = t[..., :-train_len].reshape(1, train_len, 3 * train_len - 2) + r = (2 * train_len - 
1) // 2 + + _rel_pos_bias = t[:, :, r:-r] + _rel_pos_bias = torch.nn.functional.pad(_rel_pos_bias, + (0, candidate_len // 2, 0, candidate_len // 2), + 'constant', + 0.0) + _rel_pos_bias = _rel_pos_bias.unsqueeze(-1).repeat(1, 1, 2, 2).reshape(1, max_len_x2, max_len_x2) + + pos_indices = torch.arange(max_len).repeat(max_len).view(max_len, max_len).to(_rel_pos_bias.device) + pos_indices = pos_indices.unsqueeze(-1).repeat(1, 2, 2).reshape(max_len * 2, max_len * 2) + identity = (pos_indices.t() == pos_indices).float() + + rel_pos_bias_list.append(_rel_pos_bias.squeeze(0)) + identity_list.append(identity) + + return torch.stack(rel_pos_bias_list), torch.stack(identity_list) + + +def rab_npu(rel_pos_bias: torch.Tensor, + identity: torch.Tensor, + timestamps: torch.Tensor, + timestamps_weights: torch.Tensor, + past_valid_lens: torch.Tensor): + """ + past_len = 1 ~ 4000 + candidate_len = 256 ~ 600 + bs = 1 ~ 10 + + :param rel_pos_bias: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param identity: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param past_valid_lens: [bs] + :return: [bs][1][past_len * 2 + candidate_len + 2][past_len * 2 + candidate_len + 2] + """ + + rab_pos, rab_time = torch.ops.mxrec.relative_attn_bias(rel_pos_bias, + identity, + timestamps, + timestamps_weights, + past_valid_lens.tolist(), + BUCKET_DIVISOR) + return rab_pos, rab_time + + +def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): + """ + num_buckets = 128 + num_layers = 1 - 20 + past_len = 1 - 4000 + candidate_len = 256 - 600 + + :param ts_w: [num_buckets + 1][num_layers] + :param timestamps: [bs][past_len + candidate_len // 2] + :param bucketization_divisor: float + :return: [num_layers][bs][1][2 * past_len + candidate_len + 1][2 * past_len + candidate_len + 2] + """ + infer_len = timestamps.shape[1] * 2 + bs = timestamps.shape[0] + num_layers = ts_w.shape[1] + + timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2) + diff_timestamps = 
timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) + + clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) + diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / BUCKET_DIVISOR + + bucket_timestamps = diff_timestamps.long().view(-1) + rab_time = torch.index_select(ts_w, dim=0, index=bucket_timestamps) + rab_time = rab_time.t().view(num_layers, bs, infer_len, infer_len) + return rab_time + + +def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_valid_lens: torch.Tensor): + """ + past_len = 1 ~ 4000 + candidate_len = 256 ~ 600 + bs = 1 ~ 10 + + :param rel_pos_bias: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param identity: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param past_valid_lens: [bs] + :return: [bs][1][past_len * 2 + candidate_len + 2][past_len * 2 + candidate_len + 2] + """ + bs = past_valid_lens.shape[0] + rel_pos_bias_list = rel_pos_bias[:].unsqueeze(0).repeat(bs, 1, 1) + for i, valid_len in enumerate(past_valid_lens): + rel_pos_bias_list[i, valid_len:, :] = rel_pos_bias[valid_len, :] + + rel_pos_bias_list = rel_pos_bias_list * (1 - identity) + identity * rel_pos_bias_list[0, 0, 0] + rel_pos_bias_list = rel_pos_bias_list[:, :identity.shape[0], :identity.shape[0]] + return rel_pos_bias_list + + +@torch.no_grad() +def rab(num_layers, train_len, candidate_len, bs, dtype): + print(f"\n{num_layers}\t{train_len}\t{candidate_len}\t{bs}\t{dtype}", end="\t") + t0 = time.time() + layer_num = random.randint(0, num_layers - 1) + + pos_w = create_pos_w(train_len, num_layers).to(dtype) + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) + timestamps_weights = create_timestamps_weights(num_layers).to(dtype) + rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, + train_len=train_len, + candidate_len=candidate_len, + 
num_layers=num_layers) + rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) + t1 = time.time() + print(f"create_data: {t1 - t0:.4f}s", end="\t") + + rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) + identity_list = identity_list.to(DEVICE) + timestamps = timestamps.to(DEVICE) + timestamps_weights = timestamps_weights.to(DEVICE) + past_valid_lens = past_valid_lens.to(DEVICE) + torch_npu.npu.synchronize() + t2 = time.time() + print(f"to_device: {t2 - t1:.4f}s", end="\t") + + rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], + identity=identity_list[layer_num, ...], + timestamps=timestamps, + timestamps_weights=timestamps_weights, + past_valid_lens=past_valid_lens) + torch_npu.npu.synchronize() + t3 = time.time() + print(f"rab_npu: {t3 - t2:.4f}s", end="\t") + + # rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...], + # identity=identity_list[layer_num, ...], + # past_valid_lens=past_valid_lens) + rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), + timestamps=timestamps) + torch_npu.npu.synchronize() + t4 = time.time() + print(f"rab_golden: {t4 - t3:.4f}s", end="\t") + + # assert torch.allclose(rab_pos_out_golden, rab_pos_out) + assert torch.allclose(rab_time_out_golden, rab_time_out) + + +@pytest.mark.parametrize("num_layers", [1, 8]) +@pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) +@pytest.mark.parametrize("candidate_len", [600]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): + rab(num_layers, train_len, candidate_len, bs, dtype) + + +@pytest.mark.parametrize("num_layers", [1, 8]) +@pytest.mark.parametrize("train_len,bs", [(500, 128), (1000, 32), (1000, 64), (4000, 8)]) +@pytest.mark.parametrize("candidate_len", [0]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def 
test_rab_train(num_layers, train_len, candidate_len, bs, dtype): + rab(num_layers, train_len, candidate_len, bs, dtype) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py new file mode 100644 index 00000000..25a2297e --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py @@ -0,0 +1,205 @@ +import random +import sysconfig +import time + +import torch +import torch_npu +import pytest +import torch.nn.functional as F + +torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + +DEVICE = "npu:0" +NUM_BUCKETS = 128 +BUCKET_DIVISOR = 0.301 + + +def create_pos_w(train_len: int, num_layers: int): + return torch.range(0, 2 * train_len).unsqueeze(1).repeat(1, num_layers) + + +def create_past_valid_lens(bs: int, past_len: int): + return torch.randint(0, past_len, (bs,)) + + +def create_timestamps(train_len: int, candidate_len: int, past_valid_lens: torch.Tensor): + bs = past_valid_lens.size(0) + timestamps = torch.zeros(bs, train_len + candidate_len // 2) + for i, valid_len in enumerate(past_valid_lens): + if valid_len > 0: + timestamps[i, :valid_len] = torch.range(1, valid_len.int()) + + if candidate_len <= 0: + return timestamps + timestamps[:, -candidate_len // 2:] = train_len + 1 + return timestamps + + +def create_timestamps_weights(num_layers: int): + """ + :param num_layers: + :return: timestamps_weights(num_layers, NUM_BUCKETS + 1) + """ + return torch.range(0, NUM_BUCKETS).repeat(num_layers).reshape(num_layers, NUM_BUCKETS + 1) + + +def init_rel_pos_bias(pos_w: torch.Tensor, train_len: int, candidate_len: int, num_layers: int): + rel_pos_bias_list, identity_list = [], [] + + max_len = train_len + candidate_len // 2 + max_len_x2 = train_len * 2 + candidate_len + for layer_num in range(num_layers): + t = F.pad(pos_w[:2 * 
train_len - 1, layer_num], [0, train_len]).repeat(train_len) + t = t[..., :-train_len].reshape(1, train_len, 3 * train_len - 2) + r = (2 * train_len - 1) // 2 + + _rel_pos_bias = t[:, :, r:-r] + _rel_pos_bias = torch.nn.functional.pad(_rel_pos_bias, + (0, candidate_len // 2, 0, candidate_len // 2), + 'constant', + 0.0) + _rel_pos_bias = _rel_pos_bias.unsqueeze(-1).repeat(1, 1, 2, 2).reshape(1, max_len_x2, max_len_x2) + + pos_indices = torch.arange(max_len).repeat(max_len).view(max_len, max_len).to(_rel_pos_bias.device) + pos_indices = pos_indices.unsqueeze(-1).repeat(1, 2, 2).reshape(max_len * 2, max_len * 2) + identity = (pos_indices.t() == pos_indices).float() + + rel_pos_bias_list.append(_rel_pos_bias.squeeze(0)) + identity_list.append(identity) + + return torch.stack(rel_pos_bias_list), torch.stack(identity_list) + + +def rab_npu(rel_pos_bias: torch.Tensor, + identity: torch.Tensor, + timestamps: torch.Tensor, + timestamps_weights: torch.Tensor, + past_valid_lens: torch.Tensor): + """ + past_len = 1 ~ 4000 + candidate_len = 256 ~ 600 + bs = 1 ~ 10 + + :param rel_pos_bias: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param identity: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param past_valid_lens: [bs] + :return: [bs][1][past_len * 2 + candidate_len + 2][past_len * 2 + candidate_len + 2] + """ + + rab_pos, rab_time = torch.ops.mxrec.relative_attn_bias(rel_pos_bias, + identity, + timestamps, + timestamps_weights, + past_valid_lens.tolist(), + BUCKET_DIVISOR) + return rab_pos, rab_time + + +def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): + """ + num_buckets = 128 + num_layers = 1 - 20 + past_len = 1 - 4000 + candidate_len = 256 - 600 + + :param ts_w: [num_buckets + 1][num_layers] + :param timestamps: [bs][past_len + candidate_len // 2] + :param bucketization_divisor: float + :return: [num_layers][bs][1][2 * past_len + candidate_len + 1][2 * past_len + candidate_len + 2] + """ + infer_len = 
timestamps.shape[1] * 2 + bs = timestamps.shape[0] + num_layers = ts_w.shape[1] + + timestamps = timestamps.unsqueeze(-1).repeat(1, 1, 2) + diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) + + clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) + diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / BUCKET_DIVISOR + + bucket_timestamps = diff_timestamps.long().view(-1) + rab_time = torch.index_select(ts_w, dim=0, index=bucket_timestamps) + rab_time = rab_time.t().view(num_layers, bs, infer_len, infer_len) + return rab_time + + +def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_valid_lens: torch.Tensor): + """ + past_len = 1 ~ 4000 + candidate_len = 256 ~ 600 + bs = 1 ~ 10 + + :param rel_pos_bias: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param identity: [past_len * 2 + candidate_len][past_len * 2 + candidate_len] + :param past_valid_lens: [bs] + :return: [bs][1][past_len * 2 + candidate_len + 2][past_len * 2 + candidate_len + 2] + """ + bs = past_valid_lens.shape[0] + rel_pos_bias_list = rel_pos_bias[:].unsqueeze(0).repeat(bs, 1, 1) + for i, valid_len in enumerate(past_valid_lens): + rel_pos_bias_list[i, valid_len:, :] = rel_pos_bias[valid_len, :] + + rel_pos_bias_list = rel_pos_bias_list * (1 - identity) + identity * rel_pos_bias_list[0, 0, 0] + rel_pos_bias_list = rel_pos_bias_list[:, :identity.shape[0], :identity.shape[0]] + return rel_pos_bias_list + + +@torch.no_grad() +def rab(num_layers, train_len, candidate_len, bs, dtype): + print(f"\n{num_layers}\t{train_len}\t{candidate_len}\t{bs}\t{dtype}", end="\t") + t0 = time.time() + layer_num = random.randint(0, num_layers - 1) + + pos_w = create_pos_w(train_len, num_layers).to(dtype) + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) + timestamps_weights = 
create_timestamps_weights(num_layers).to(dtype) + rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, + train_len=train_len, + candidate_len=candidate_len, + num_layers=num_layers) + rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) + t1 = time.time() + print(f"create_data: {t1 - t0:.4f}s", end="\t") + + rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) + identity_list = identity_list.to(DEVICE) + timestamps = timestamps.to(DEVICE) + timestamps_weights = timestamps_weights.to(DEVICE) + past_valid_lens = past_valid_lens.to(DEVICE) + torch_npu.npu.synchronize() + t2 = time.time() + print(f"to_device: {t2 - t1:.4f}s", end="\t") + + rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], + identity=identity_list[layer_num, ...], + timestamps=timestamps, + timestamps_weights=timestamps_weights, + past_valid_lens=past_valid_lens) + torch_npu.npu.synchronize() + t3 = time.time() + print(f"rab_npu: {t3 - t2:.4f}s", end="\t") + + rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...].to("cpu"), + identity=identity_list[layer_num, ...].to("cpu"), + past_valid_lens=past_valid_lens.to("cpu")) + rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1).to("cpu"), + timestamps=timestamps.to("cpu")) + torch_npu.npu.synchronize() + t4 = time.time() + print(f"rab_golden: {t4 - t3:.4f}s", end="\t") + + rab_pos_out, rab_time_out = rab_pos_out.to("cpu"), rab_time_out.to("cpu") + torch_npu.npu.synchronize() + + assert torch.allclose(rab_pos_out_golden, rab_pos_out) + assert torch.allclose(rab_time_out_golden, rab_time_out) + + +@pytest.mark.parametrize("num_layers", [8]) +@pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) +@pytest.mark.parametrize("candidate_len", [600]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): + 
rab(num_layers, train_len, candidate_len, bs, dtype) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/CMakeLists.txt b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/CMakeLists.txt new file mode 100644 index 00000000..f5bf4e2b --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.10) + +project(relative_attn_bias) + +execute_process( + COMMAND python3 -c "import site; print(site.getsitepackages()[0])" + OUTPUT_VARIABLE python_site_packages_path +) +string(STRIP "${python_site_packages_path}" python_site_packages_path) + + +set(CMAKE_CXX_FLAGS "-fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -fPIE -pie ${CMAKE_CXX_FLAGS}") +set(CMAKE_CXX_FLAGS "-fabi-version=11 ${CMAKE_CXX_FLAGS}") +set(PYTORCH_INSTALL_PATH ${python_site_packages_path}/torch) +set(PYTORCH_NPU_INSTALL_PATH ${python_site_packages_path}/torch_npu) + +link_directories(${PYTORCH_INSTALL_PATH}/lib) +link_directories(${PYTORCH_NPU_INSTALL_PATH}/lib) + +add_library(relative_attn_bias SHARED relative_attn_bias.cpp) + +target_compile_features(relative_attn_bias PRIVATE cxx_std_17) +target_compile_options(relative_attn_bias PRIVATE -D_GLIBCXX_USE_CXX11_ABI=1) + +include_directories(${PYTORCH_NPU_INSTALL_PATH}/include/third_party/acl/inc) +include_directories(${PYTORCH_NPU_INSTALL_PATH}/include) +include_directories(${PYTORCH_INSTALL_PATH}/include) +include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed) +include_directories(${PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include) + +target_link_libraries(relative_attn_bias PUBLIC c10 torch torch_cpu torch_npu) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/build_ops.sh b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/build_ops.sh new file mode 100644 index 
00000000..c10e3ca3 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/build_ops.sh @@ -0,0 +1,20 @@ +#!/bin/bash +if [ -n "$ASCEND_INSTALL_PATH" ]; then + _ASCEND_INSTALL_PATH=$ASCEND_INSTALL_PATH +elif [ -n "$ASCEND_HOME_PATH" ]; then + _ASCEND_INSTALL_PATH=$ASCEND_HOME_PATH +else + if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then + _ASCEND_INSTALL_PATH=$HOME/Ascend/ascend-toolkit/latest + else + _ASCEND_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit/latest + fi +fi +source $_ASCEND_INSTALL_PATH/bin/setenv.bash + +set -e +rm -rf build +mkdir -p build +cmake -B build +cmake --build build -j +export LD_LIBRARY_PATH=$ASCEND_OPP_PATH/vendors/relative_attn_bias/op_api/lib:$LD_LIBRARY_PATH \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp new file mode 100644 index 00000000..8472bbb0 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -0,0 +1,76 @@ +/** + * @file relative_attn_bias.cpp + * + * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ */ +#include +#include +#include +#include + +#include "../common/pytorch_npu_helper.hpp" +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using tensor_list = std::vector; +using namespace at; +using namespace std; + +std::tuple relative_attn_bias_impl_npu( + const Tensor &rel_pos_bias, + const Tensor &identity, + const Tensor ×tamps, + const Tensor ×tamps_weights, + const at::IntArrayRef past_valid_lens, + const double bucket_divisor) +{ + auto rel_pos_bias_conti = rel_pos_bias.contiguous(); + auto identity_conti = identity.contiguous(); + auto timestamps_conti = timestamps.contiguous(); + auto timestamps_weights_conti = timestamps_weights.contiguous(); + + const int bs = past_valid_lens.size(); + const int s = rel_pos_bias.size(0); // (2s, 2s) + const int _s = s / 2; // (2s, 2s) + const int num_layers = timestamps_weights.size(0); + + at::Tensor rab_pos_out = at::zeros({bs, s, s}, rel_pos_bias_conti.options()); + at::Tensor rab_time_out = at::zeros({num_layers, bs, _s, 1, _s, 1}, timestamps_weights_conti.options()); + + EXEC_NPU_CMD(aclnnRelativeAttnBias, + rel_pos_bias_conti, + identity_conti, + timestamps_conti, + timestamps_weights_conti, + past_valid_lens, + bucket_divisor, + rab_pos_out, + rab_time_out); + rab_time_out = rab_time_out.repeat({1, 1, 1, 2, 1, 2}) + .reshape({num_layers, bs, s, s}); + return {rab_pos_out, rab_time_out}; +} + +TORCH_LIBRARY_FRAGMENT(mxrec, m) +{ + m.def("relative_attn_bias(Tensor rel_pos_bias, " + " Tensor identity, " + " Tensor timestamps, " + " Tensor timestamps_weights, " + " int[] past_valid_lens," + " float bucket_divisor" + " ) -> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) +{ + m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); +} + +TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m) +{ + m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); +} -- Gitee From 1f404dd72eb3e8932c40e8f21b2ac0e608980a49 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 19 May 2025 
15:40:41 +0800 Subject: [PATCH 02/23] =?UTF-8?q?[feat]rab=E7=AE=97=E5=AD=90=E3=80=82?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E3=80=81=E6=AD=A3=E5=90=91=E7=AE=97=E5=AD=90?= =?UTF-8?q?clean=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias.cpp | 42 ++++++++++----- .../op_host/relative_attn_bias_tiling.h | 15 +++--- .../relative_attn_bias/op_kernel/rab_common.h | 15 +++--- .../op_kernel/relative_attn_bias.cpp | 8 ++- .../op_kernel/relative_attn_bias_kernel.h | 17 +++--- .../op_kernel/relative_attn_bias_pos.h | 42 +++++++-------- .../op_kernel/relative_attn_bias_time.h | 53 +++++++++---------- .../relative_attn_bias/relative_attn_bias.py | 17 ++---- .../relative_attn_bias_v200.py | 19 ++----- .../relative_attn_bias/relative_attn_bias.cpp | 50 +++++++---------- 10 files changed, 122 insertions(+), 156 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp index bf20f394..43dcf324 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias.cpp @@ -25,6 +25,16 @@ constexpr int RAB_TIME_INDEX = 1; // attr index constexpr int PAST_VALID_LENS_INDEX = 0; constexpr int BUCKET_DIV_INDEX = 1; +// output dim +constexpr int RAB_POS_OUT_DIM = 3; +constexpr int RAB_TIME_OUT_DIM = 6; +constexpr int DIM_PLACE_HOLDER = 1; +constexpr int DIM0 = 0; +constexpr int DIM1 = 1; +constexpr int DIM2 = 2; +constexpr int DIM3 = 3; +constexpr int DIM4 = 4; +constexpr int DIM5 = 5; namespace optiling { static ge::graphStatus TilingFunc(gert::TilingContext* context) @@ -78,6 +88,10 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context) int intSize = ge::GetSizeByDataType(intType); tilingData.set_floatType(floatType); 
tilingData.set_intType(intType); + if (floatSize == 0) { + printf("[ERROR]float type(%d) error. sizeof(float) = %d\n", floatType, floatSize); + return ge::GRAPH_FAILED; + } // 计算一次处理的窗口大小(stride) int stride = ub / (NUM_BUFFER * 3 * floatSize); @@ -122,25 +136,25 @@ static ge::graphStatus InferShape(gert::InferShapeContext* context) const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); int bs = pastValidLensPtr->GetSize(); const gert::Shape* identityShape = context->GetInputShape(IDENTITY_INDEX); - int s = identityShape->GetDim(0); // identityShape(2s, 2s) + int s = identityShape->GetDim(DIM0); // identityShape(2s, 2s) - rabPosOutShape->SetDimNum(3); - rabPosOutShape->SetDim(0, bs); - rabPosOutShape->SetDim(1, s); - rabPosOutShape->SetDim(2, s); + rabPosOutShape->SetDimNum(RAB_POS_OUT_DIM); + rabPosOutShape->SetDim(DIM0, bs); + rabPosOutShape->SetDim(DIM1, s); + rabPosOutShape->SetDim(DIM2, s); const gert::Shape* tShape = context->GetInputShape(TIMESTAMPS_INDEX); const gert::Shape* tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX); gert::Shape* rabTimeOutShape = context->GetOutputShape(RAB_TIME_INDEX); - int numLayers = tswShape->GetDim(1); - - rabTimeOutShape->SetDimNum(6); - rabPosOutShape->SetDim(0, numLayers); - rabPosOutShape->SetDim(1, bs); - rabPosOutShape->SetDim(2, s); - rabPosOutShape->SetDim(3, 1); - rabPosOutShape->SetDim(4, s); - rabPosOutShape->SetDim(5, 1); + int numLayers = tswShape->GetDim(DIM1); + + rabTimeOutShape->SetDimNum(RAB_TIME_OUT_DIM); + rabPosOutShape->SetDim(DIM0, numLayers); + rabPosOutShape->SetDim(DIM1, bs); + rabPosOutShape->SetDim(DIM2, s); + rabPosOutShape->SetDim(DIM3, DIM_PLACE_HOLDER); + rabPosOutShape->SetDim(DIM4, s); + rabPosOutShape->SetDim(DIM5, DIM_PLACE_HOLDER); return GRAPH_SUCCESS; } } // namespace ge diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h 
b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h index ab627f5e..13babbd6 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_host/relative_attn_bias_tiling.h @@ -1,10 +1,9 @@ /** -* @file relative_attn_bias_tiling.h -* -* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. -* -*/ - + * @file relative_attn_bias_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ #ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H #define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H @@ -30,5 +29,5 @@ TILING_DATA_FIELD_DEF(int, buffSize); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(RelativeAttnBias, RelativeAttnBiasTilingData) -} -#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +} // namespace optiling +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h index 4b784a01..78d0b3f2 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/rab_common.h @@ -1,10 +1,9 @@ /** -* @file rab_common.h -* -* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. -* -*/ - + * @file rab_common.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ #ifndef MXREC_ADD_ONS_RAB_COMMON_H #define MXREC_ADD_ONS_RAB_COMMON_H @@ -23,7 +22,7 @@ constexpr int8_t TYPE_INT64 = 9; using namespace AscendC; -struct Args{ +struct Args { // pos_bias GM_ADDR positionBias; GM_ADDR identity; @@ -37,4 +36,4 @@ struct Args{ GM_ADDR workspace; GM_ADDR tiling; }; -#endif //MXREC_ADD_ONS_RAB_COMMON_H +#endif // MXREC_ADD_ONS_RAB_COMMON_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp index 510ee1a6..2aa793ec 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias.cpp @@ -9,8 +9,7 @@ #include "relative_attn_bias_kernel.h" #include "kernel_operator.h" -extern "C" __global__ __aicore__ void relative_attn_bias( - GM_ADDR positionBias, +extern "C" __global__ __aicore__ void relative_attn_bias(GM_ADDR positionBias, GM_ADDR identity, GM_ADDR timestamps, GM_ADDR timestampsWeights, @@ -24,11 +23,10 @@ extern "C" __global__ __aicore__ void relative_attn_bias( positionBias, identity, timestamps, timestampsWeights, rabPosOut, rabTimeOut, workspace, tiling }; if (tilingData.floatType == TYPE_FP32) { - RelativeAttnBias kernel; + RelativeAttnBiasKernel kernel; kernel.Compute(args); } else if (tilingData.floatType == TYPE_FP16) { - RelativeAttnBias kernel; + RelativeAttnBiasKernel kernel; kernel.Compute(args); } - } diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h index 0abce386..6599441f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h @@ -1,9 +1,9 
@@ /** -* @file relative_attn_bias_kernel.h -* -* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. -* -*/ + * @file relative_attn_bias_kernel.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ #ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H #define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H @@ -14,8 +14,8 @@ #include "kernel_operator.h" using namespace AscendC; -template -class RelativeAttnBias { +template +class RelativeAttnBiasKernel { public: __aicore__ inline RelativeAttnBias() {} @@ -29,7 +29,6 @@ public: RelativeAttnBiasTime rabTime; rabTime.Compute(args); } - }; -#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_KERNEL_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h index eafc12b4..63f64159 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h @@ -1,10 +1,9 @@ /** -* @file relative_attn_bias_pos.h -* -* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. -* -*/ - + * @file relative_attn_bias_pos.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ #ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H #define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H @@ -12,7 +11,9 @@ #include "kernel_operator.h" using namespace AscendC; -template +constexpr SEQ_EXPAND = 2; // rab_pos中序列长度为原本输入的两倍 + +template class RelativeAttnBiasPos { public: __aicore__ inline RelativeAttnBiasPos() {} @@ -20,7 +21,7 @@ public: __aicore__ inline void Init(Args args) { GET_TILING_DATA(tilingData, args.tiling); - s = 2 * tilingData.s; + s = SEQ_EXPAND * tilingData.s; bs = tilingData.bs; stride = tilingData.positionStride; for (auto i = 0; i < bs; ++i) { @@ -31,7 +32,7 @@ public: identityGT.SetGlobalBuffer((__gm__ floatType*)args.identity, s * s); rabPosBiasOutGT.SetGlobalBuffer((__gm__ floatType*)args.rabPosOut, bs * s * s); - pipe.InitBuffer(queIdentityIn, NUM_BUFFER, Ceil(2 * stride * sizeof(floatType))); + pipe.InitBuffer(queIdentityIn, NUM_BUFFER, Ceil(SEQ_EXPAND * stride * sizeof(floatType))); pipe.InitBuffer(quePosIn, NUM_BUFFER, Ceil(stride * sizeof(floatType))); int64_t totalTableSizeSplit = s % GetBlockNum(); @@ -58,8 +59,8 @@ public: LocalTensor identityFilledUb = queIdentityIn.DeQue(); // 后半段 (1 - identity) - Muls(identityFilledUb[stride], identityFilledUb, (floatType) -1, cnt); - Adds(identityFilledUb[stride], identityFilledUb[stride], (floatType) 1, cnt); + Muls(identityFilledUb[stride], identityFilledUb, (floatType)-1, cnt); + Adds(identityFilledUb[stride], identityFilledUb[stride], (floatType)1, cnt); // 前半段 identity * rel_pos_bias[0, 0] Muls(identityFilledUb, identityFilledUb, REL_POS_BIAS_FIRST, cnt); @@ -83,7 +84,7 @@ public: quePosIn.EnQue(posBiasUb); } - __aicore__ inline int64_t Ceil(int64_t a, int64_t b=DATA_ALIGN_BYTES) + __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) { if (b == 0) { return 0; @@ -109,19 +110,15 @@ public: #ifdef SUPPORT_V200 uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unAlignCnt); uint64_t mask[2] = {mask0, 0}; - Duplicate(posBiasUb[alignCnt], 
(floatType) 0, mask, 1, 1, 1); + Duplicate(posBiasUb[alignCnt], (floatType)0, mask, 1, 1, 1); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); SetAtomicAdd(); - DataCopy(rabPosBiasOutGT[offset + alignCnt], - posBiasUb[alignCnt], - Ceil(unAlignLen) / sizeof(floatType)); + DataCopy(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], Ceil(unAlignLen) / sizeof(floatType)); SetAtomicNone(); #else const DataCopyExtParams dataCopyExtParams{1, unAlignLen, 0, 0, 0}; - DataCopyPad(rabPosBiasOutGT[offset + alignCnt], - posBiasUb[alignCnt], - dataCopyExtParams); + DataCopyPad(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], dataCopyExtParams); #endif } quePosIn.FreeTensor(posBiasUb); @@ -130,7 +127,7 @@ public: __aicore__ inline void Compute(Args args) { Init(args); - for (int row=rowOffset; row < rowOffset + totalRow; ++row) { + for (int row = rowOffset; row < rowOffset + totalRow; ++row) { int offset = 0; for (int j = 0; j < (s + stride - 1) / stride; ++j) { int remain = s - offset; @@ -159,7 +156,7 @@ private: int stride; // tiling int rowOffset; // identity、rel_pos_bias(s, s)的行偏移 - int totalRow; // 需要处理的总行数 + int totalRow; // 需要处理的总行数 private: TPipe pipe; @@ -172,7 +169,6 @@ private: GlobalTensor rabPosBiasOutGT; uint32_t pastValidLens[MAX_BATCH_SIZE]; floatType REL_POS_BIAS_FIRST; // identity[0, 0] - }; -#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h index 7a0f7752..6e17ac7a 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h @@ -1,10 +1,9 @@ /** -* @file relative_attn_bias_time.h -* -* Copyright (C) 2025. 
Huawei Technologies Co., Ltd. All rights reserved. -* -*/ - + * @file relative_attn_bias_time.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ #ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H #define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H @@ -19,7 +18,7 @@ struct SequenceParams { int subValue; }; -template +template class RelativeAttnBiasTime { public: __aicore__ inline RelativeAttnBiasTime() {} @@ -31,7 +30,7 @@ public: bs = tilingData.bs; stride = tilingData.timeStride; alignSeqLen = Ceil(s * sizeof(floatType)) / sizeof(floatType); - + int totalLen = bs * s; uint32_t seqDatasize = s * sizeof(floatType); alignLen = seqDatasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES; @@ -56,7 +55,6 @@ public: pipe.InitBuffer(queTimestampsWeights, 1, alignNumBuckets * numLayer * sizeof(floatType)); pipe.InitBuffer(tmpQue, 1, Ceil(tilingData.buffSize)); - int totalTableSizeSplit = totalLen % GetBlockNum(); int baseLen = totalLen / GetBlockNum(); if (GetBlockIdx() >= totalTableSizeSplit) { @@ -72,7 +70,7 @@ public: { LocalTensor ts = queTimestamps.AllocTensor(); DataCopy(ts, timestampsGT[offset], Ceil(cnt)); - for (int i=0; i < cnt; ++i) { + for (int i = 0; i < cnt; ++i) { int seqSubValue = ts.GetValue(i); int seqId = (offset + i) / s; int seqOffsetUb = i * alignSeqLen; @@ -88,7 +86,7 @@ public: __aicore__ inline void DataCopyIn(SequenceParams* params, int cnt) { LocalTensor ts = queTimestamps.AllocTensor(); - for (int i=0; i < cnt; ++i) { + for (int i = 0; i < cnt; ++i) { SequenceParams param = params[i]; int startIndexGT = param.startIndexGT; int startIndexUb = param.startIndexUb; @@ -100,17 +98,16 @@ public: __aicore__ inline void ComputeBucketTimestamps(SequenceParams* params, int rowCnt) { - LocalTensor tsInt = queTimestamps.DeQue(); LocalTensor tsTmp = tsInt.template ReinterpretCast(); LocalTensor ts = queTimestampsFloat.AllocTensor(); LocalTensor buff = tmpQue.AllocTensor(); - for (int i=0; i < rowCnt; ++i) { + for (int i = 0; i < rowCnt; 
++i) { SequenceParams param = params[i]; int startIndexUb = param.startIndexUb; int value = param.subValue; - Adds(tsInt[startIndexUb], tsInt[startIndexUb], (int32_t) -value, s); + Adds(tsInt[startIndexUb], tsInt[startIndexUb], (int32_t)-value, s); } uint32_t cnt = rowCnt * alignSeqLen; @@ -120,10 +117,10 @@ public: ClampMin(tsTmp, ts, buff, clampMin, cnt); Log(ts, tsTmp, cnt); Muls(ts, ts, div, cnt); - ClampMax(tsTmp, ts, buff, (float) numBuckets, cnt); + ClampMax(tsTmp, ts, buff, (float)numBuckets, cnt); Cast(tsInt, tsTmp, RoundMode::CAST_TRUNC, cnt); - Muls(tsInt, tsInt, (int32_t) sizeof(floatType), cnt); // 计算gather时的偏移量单位为bytes + Muls(tsInt, tsInt, (int32_t)sizeof(floatType), cnt); // 计算gather时的偏移量单位为bytes tmpQue.FreeTensor(buff); queTimestampsFloat.FreeTensor(ts); @@ -138,7 +135,7 @@ public: uint32_t tmpOffset = 0; while (tmpOffset < cnt) { uint32_t processLen = (cnt - tmpOffset) > processLenMax ? processLenMax : (cnt - tmpOffset); - Gather(rabTime[tmpOffset], tsw[layer * alignNumBuckets], tsInt[tmpOffset], (uint32_t) 0, processLen); + Gather(rabTime[tmpOffset], tsw[layer * alignNumBuckets], tsInt[tmpOffset], (uint32_t)0, processLen); tmpOffset += processLen; } queTimestampsFloat.EnQue(rabTime); @@ -156,26 +153,25 @@ public: DataCopy(rabTimeBiasOutGT[ptr + i * s], rabTime[ptrUb], s); } // 非对齐拷出 - if (unalignLen == 0) {continue;} + if (unalignLen == 0) { + continue; + } #ifdef SUPPORT_V200 uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unalignCnt); uint64_t mask[2] = {mask0, 0}; - Duplicate(rabTime[ptrUb + alignCnt], (floatType) 0, mask, 1, 1, 1); + Duplicate(rabTime[ptrUb + alignCnt], (floatType)0, mask, 1, 1, 1); set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); SetAtomicAdd(); - DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt], - rabTime[ptrUb + alignCnt], + DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], Ceil(unalignLen) / sizeof(floatType)); SetAtomicNone(); #else 
const DataCopyExtParams dataCopyExtParams{1, unalignLen, 0, 0, 0}; - DataCopyPad(rabTimeBiasOutGT[ptr + i * s + alignCnt], - rabTime[ptrUb + alignCnt], - dataCopyExtParams); + DataCopyPad(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], dataCopyExtParams); #endif - } + } queTimestampsFloat.FreeTensor(rabTime); } @@ -188,7 +184,7 @@ public: queTimestampsWeights.EnQue(tsw); } - __aicore__ inline int64_t Ceil(int64_t a, int64_t b=DATA_ALIGN_BYTES) + __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) { if (b == 0) { return 0; @@ -220,10 +216,10 @@ public: DataCopyOut(ptr, rowCnt); } queTimestamps.FreeTensor(tsInt); - } queTimestampsWeights.FreeTensor(tsw); } + private: // shape uint32_t s; @@ -256,6 +252,5 @@ private: TQue queTimestampsFloat; TQue queTimestampsWeights; TQue tmpQue; - }; -#endif //MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py index 435ab85d..058c9889 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py @@ -1,11 +1,10 @@ import random import sysconfig -import time -import torch -import torch_npu import pytest +import torch import torch.nn.functional as F +import torch_npu torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") @@ -146,8 +145,6 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_vali @torch.no_grad() def rab(num_layers, train_len, candidate_len, bs, dtype): - print(f"\n{num_layers}\t{train_len}\t{candidate_len}\t{bs}\t{dtype}", end="\t") - t0 = time.time() layer_num = random.randint(0, num_layers - 1) pos_w = create_pos_w(train_len, num_layers).to(dtype) @@ 
-159,8 +156,6 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): candidate_len=candidate_len, num_layers=num_layers) rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) - t1 = time.time() - print(f"create_data: {t1 - t0:.4f}s", end="\t") rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) identity_list = identity_list.to(DEVICE) @@ -168,8 +163,6 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): timestamps_weights = timestamps_weights.to(DEVICE) past_valid_lens = past_valid_lens.to(DEVICE) torch_npu.npu.synchronize() - t2 = time.time() - print(f"to_device: {t2 - t1:.4f}s", end="\t") rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], identity=identity_list[layer_num, ...], @@ -177,8 +170,6 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): timestamps_weights=timestamps_weights, past_valid_lens=past_valid_lens) torch_npu.npu.synchronize() - t3 = time.time() - print(f"rab_npu: {t3 - t2:.4f}s", end="\t") # rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...], # identity=identity_list[layer_num, ...], @@ -186,9 +177,7 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), timestamps=timestamps) torch_npu.npu.synchronize() - t4 = time.time() - print(f"rab_golden: {t4 - t3:.4f}s", end="\t") - + # assert torch.allclose(rab_pos_out_golden, rab_pos_out) assert torch.allclose(rab_time_out_golden, rab_time_out) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py index 25a2297e..c1f55b33 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py @@ -1,11 +1,10 
@@ import random import sysconfig -import time -import torch -import torch_npu import pytest +import torch import torch.nn.functional as F +import torch_npu torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") @@ -146,8 +145,6 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_vali @torch.no_grad() def rab(num_layers, train_len, candidate_len, bs, dtype): - print(f"\n{num_layers}\t{train_len}\t{candidate_len}\t{bs}\t{dtype}", end="\t") - t0 = time.time() layer_num = random.randint(0, num_layers - 1) pos_w = create_pos_w(train_len, num_layers).to(dtype) @@ -159,8 +156,6 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): candidate_len=candidate_len, num_layers=num_layers) rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) - t1 = time.time() - print(f"create_data: {t1 - t0:.4f}s", end="\t") rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) identity_list = identity_list.to(DEVICE) @@ -168,17 +163,14 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): timestamps_weights = timestamps_weights.to(DEVICE) past_valid_lens = past_valid_lens.to(DEVICE) torch_npu.npu.synchronize() - t2 = time.time() - print(f"to_device: {t2 - t1:.4f}s", end="\t") rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], identity=identity_list[layer_num, ...], timestamps=timestamps, timestamps_weights=timestamps_weights, past_valid_lens=past_valid_lens) + rab_pos_out, rab_time_out = rab_pos_out.to("cpu"), rab_time_out.to("cpu") torch_npu.npu.synchronize() - t3 = time.time() - print(f"rab_npu: {t3 - t2:.4f}s", end="\t") rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...].to("cpu"), identity=identity_list[layer_num, ...].to("cpu"), @@ -186,12 +178,7 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1).to("cpu"), timestamps=timestamps.to("cpu")) 
torch_npu.npu.synchronize() - t4 = time.time() - print(f"rab_golden: {t4 - t3:.4f}s", end="\t") - rab_pos_out, rab_time_out = rab_pos_out.to("cpu"), rab_time_out.to("cpu") - torch_npu.npu.synchronize() - assert torch.allclose(rab_pos_out_golden, rab_pos_out) assert torch.allclose(rab_time_out_golden, rab_time_out) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 8472bbb0..4b93a520 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -15,43 +15,33 @@ #include "../common/pytorch_npu_helper.hpp" using torch::autograd::AutogradContext; using torch::autograd::Function; -using tensor_list = std::vector; using namespace at; using namespace std; -std::tuple relative_attn_bias_impl_npu( - const Tensor &rel_pos_bias, - const Tensor &identity, - const Tensor &timestamps, - const Tensor &timestamps_weights, - const at::IntArrayRef past_valid_lens, - const double bucket_divisor) +std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, + const Tensor& identity, + const Tensor& timestamps, + const Tensor& timestampsWeights, + const at::IntArrayRef pastValidLens, + const double bucketDivisor) { - auto rel_pos_bias_conti = rel_pos_bias.contiguous(); - auto identity_conti = identity.contiguous(); - auto timestamps_conti = timestamps.contiguous(); - auto timestamps_weights_conti = timestamps_weights.contiguous(); + auto relPosBiasConti = relPosBias.contiguous(); + auto identityConti = identity.contiguous(); + auto timestampsConti = timestamps.contiguous(); + auto timestampsWeightsConti = timestampsWeights.contiguous(); - const int bs = past_valid_lens.size(); - const int s = rel_pos_bias.size(0); // (2s, 2s) - const int _s = s / 2; // 
(2s, 2s) - const int num_layers = timestamps_weights.size(0); + const int bs = pastValidLens.size(); + const int sx2 = relPosBias.size(0); // relPosBias(2s, 2s) + const int s = sx2 / 2; + const int numLayers = timestampsWeights.size(0); - at::Tensor rab_pos_out = at::zeros({bs, s, s}, rel_pos_bias_conti.options()); - at::Tensor rab_time_out = at::zeros({num_layers, bs, _s, 1, _s, 1}, timestamps_weights_conti.options()); + at::Tensor rabPosOut = at::zeros({bs, sx2, sx2}, relPosBiasConti.options()); + at::Tensor rabTimeOut = at::zeros({numLayers, bs, s, 1, s, 1}, timestampsWeightsConti.options()); - EXEC_NPU_CMD(aclnnRelativeAttnBias, - rel_pos_bias_conti, - identity_conti, - timestamps_conti, - timestamps_weights_conti, - past_valid_lens, - bucket_divisor, - rab_pos_out, - rab_time_out); - rab_time_out = rab_time_out.repeat({1, 1, 1, 2, 1, 2}) - .reshape({num_layers, bs, s, s}); - return {rab_pos_out, rab_time_out}; + EXEC_NPU_CMD(aclnnRelativeAttnBias, relPosBiasConti, identityConti, timestampsConti, timestampsWeightsConti, + pastValidLens, bucketDivisor, rabPosOut, rabTimeOut); + rabTimeOut = rabTimeOut.repeat({1, 1, 1, 2, 1, 2}).reshape({numLayers, bs, sx2, sx2}); + return {rabPosOut, rabTimeOut}; } TORCH_LIBRARY_FRAGMENT(mxrec, m) -- Gitee From 6ad9608c831984d97a20dbfd36fb0919a6eac631 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 20 May 2025 14:42:10 +0800 Subject: [PATCH 03/23] =?UTF-8?q?[feat]rab=E7=AE=97=E5=AD=90=E3=80=82clean?= =?UTF-8?q?=20code=20+=20=E6=A3=80=E8=A7=86=E6=84=8F=E8=A7=81=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/relative_attn_bias_kernel.h | 2 +- .../relative_attn_bias/op_kernel/relative_attn_bias_pos.h | 6 +++--- .../relative_attn_bias/op_kernel/relative_attn_bias_time.h | 4 ++-- .../rec_for_torch/operators/relative_attn_bias/run.sh | 2 +- .../2.6.0/relative_attn_bias/relative_attn_bias.cpp | 6 +++--- 5 files changed, 10 insertions(+), 10 
deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h index 6599441f..58a58368 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_kernel.h @@ -17,7 +17,7 @@ using namespace AscendC; template class RelativeAttnBiasKernel { public: - __aicore__ inline RelativeAttnBias() {} + __aicore__ inline RelativeAttnBiasKernel() {} __aicore__ inline void Compute(Args args) { diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h index 63f64159..48961279 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_pos.h @@ -11,7 +11,7 @@ #include "kernel_operator.h" using namespace AscendC; -constexpr SEQ_EXPAND = 2; // rab_pos中序列长度为原本输入的两倍 +constexpr int SEQ_EXPAND = 2; // rab_pos中序列长度为原本输入的两倍 template class RelativeAttnBiasPos { @@ -111,8 +111,8 @@ public: uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unAlignCnt); uint64_t mask[2] = {mask0, 0}; Duplicate(posBiasUb[alignCnt], (floatType)0, mask, 1, 1, 1); - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + quePosIn.EnQue(posBiasUb); + posBiasUb = quePosIn.DeQue(); SetAtomicAdd(); DataCopy(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], Ceil(unAlignLen) / sizeof(floatType)); SetAtomicNone(); diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h 
b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h index 6e17ac7a..133b71a6 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/op_kernel/relative_attn_bias_time.h @@ -160,8 +160,8 @@ public: uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(floatType))) - (1ul << unalignCnt); uint64_t mask[2] = {mask0, 0}; Duplicate(rabTime[ptrUb + alignCnt], (floatType)0, mask, 1, 1, 1); - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + queTimestampsFloat.EnQue(rabTime); + rabTime = queTimestampsFloat.DeQue(); SetAtomicAdd(); DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], Ceil(unalignLen) / sizeof(floatType)); diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh index e4ee6838..0a6d0783 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias/run.sh @@ -54,5 +54,5 @@ sed -i "${line}s/True/False/g" CMakePresets.json bash build.sh -# # 安装编译成功的算子包 +# 安装编译成功的算子包 bash ./build_out/custom_opp*.run diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 4b93a520..056fe4a7 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -18,11 +18,11 @@ using torch::autograd::Function; using namespace at; using namespace std; -std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, +std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, const 
Tensor& identity, - const Tensor& timestamps, + const Tensor& timestamps, const Tensor& timestampsWeights, - const at::IntArrayRef pastValidLens, + const at::IntArrayRef pastValidLens, const double bucketDivisor) { auto relPosBiasConti = relPosBias.contiguous(); -- Gitee From 0c8535f7b7ea9476ee29aeaccade755617139028 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 21 May 2025 09:38:48 +0800 Subject: [PATCH 04/23] =?UTF-8?q?[feat]rab=E7=AE=97=E5=AD=90=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=94=A8=E4=BE=8B=E5=91=BD=E5=90=8D=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../{relative_attn_bias.py => test_relative_attn_bias.py} | 0 ...relative_attn_bias_v200.py => test_relative_attn_bias_v200.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/{relative_attn_bias.py => test_relative_attn_bias.py} (100%) rename mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/{relative_attn_bias_v200.py => test_relative_attn_bias_v200.py} (100%) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py similarity index 100% rename from mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias.py rename to mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py similarity index 100% rename from mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_v200.py rename to 
mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py -- Gitee From 2a8fdd9903ca8651bd081e797d6892ebbf17cebc Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 21 May 2025 19:11:58 +0800 Subject: [PATCH 05/23] =?UTF-8?q?[feat]rab=E5=8F=8D=E5=90=91=E3=80=82?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 153 ++++++++++++++++++ .../relative_attn_bias_backward_tiling.h | 28 ++++ .../op_kernel/rab_common.h | 29 ++++ .../op_kernel/relative_attn_bias_backward.cpp | 29 ++++ .../op_kernel/relative_attn_bias_backward.h | 150 +++++++++++++++++ .../relative_attn_bias_backward.json | 47 ++++++ .../relative_attn_bias_backward/run.sh | 61 +++++++ .../relative_attn_bias_backward.py | 72 +++++++++ .../relative_attn_bias/relative_attn_bias.cpp | 24 +++ 9 files changed, 593 insertions(+) create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/rab_common.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh create mode 100644 mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp new file mode 100644 index 00000000..c706f624 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -0,0 +1,153 @@ +/** +* @file relative_attn_bias_backward.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +* +*/ + +#include +#include "relative_attn_bias_backward_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" +#include "../../../common/ops_log.h" + +constexpr int32_t RESERVER_UB_SIZE = (20 * 1024); +constexpr int32_t DATA_ALIGN_BYTES = 32; +constexpr uint8_t NUM_BUFFER = 2; + +// input index +constexpr int TIMESTAMPS_WEIGHTS_GRAD_INDEX = 0; +constexpr int BUCKET_TIMESTAMPS_INDEX = 1; +// output index +constexpr int RAB_POSITION_INDEX = 0; +constexpr int RAB_TIME_INDEX = 1; +// attr index +constexpr int NUM_BUCKET_INDEX = 0; +// output dim +constexpr int RAB_POS_OUT_DIM = 3; +constexpr int RAB_TIME_OUT_DIM = 6; +constexpr int DIM_PLACE_HOLDER = 1; +constexpr int DIM0 = 0; +constexpr int DIM1 = 1; +constexpr int DIM2 = 2; +constexpr int DIM3 = 3; +constexpr int DIM4 = 4; +constexpr int DIM5 = 5; + +namespace optiling { +static ge::graphStatus TimeTilingFunc(gert::TilingContext* context) +{ + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + + RelativeAttnBiasTilingData tilingData; + // 获取、校验必要shape数据 + auto gradShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetStorageShape(); // grad(n, b, 2s, 2s) + int numBuckets = *context->GetAttrs()->GetInt(NUM_BUCKET_INDEX); + int numLayer = gradShape.GetDim(DIM0); + int batchsize = gradShape.GetDim(DIM1); + int s = 
gradShape.GetDim(DIM2); + int s2 = gradShape.GetDim(DIM3); + + OPS_CHECK(numBuckets <= 0, + OPS_LOG_E("Tiling Debug", "NumBuckets is invalid."), + return ge::GRAPH_FAILED); + OPS_CHECK(numLayer <= 0, + OPS_LOG_E("Tiling Debug", "Numlayer is invalid."), + return ge::GRAPH_FAILED); + OPS_CHECK(batchsize <= 0, + OPS_LOG_E("Tiling Debug", "Batchsize is invalid."), + return ge::GRAPH_FAILED); + OPS_CHECK(s <= 0 || s != s2, + OPS_LOG_E("Tiling Debug", "Sequence len is invalid."), + return ge::GRAPH_FAILED); + + tilingData.set_numBuckets(numBuckets); + tilingData.set_numLayer(numLayer); + tilingData.set_bs(batchsize); + tilingData.set_s(s); + // 获取计算中使用的步长等数据 + uint64_t ub; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); + ub = ub - RESERVER_UB_SIZE; + // 获取数据类型 + auto floatType = context->GetInputTensor(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetDataType(); + auto intType = context->GetInputTensor(BUCKET_TIMESTAMPS_INDEX)->GetDataType(); + int floatSize = ge::GetSizeByDataType(floatType); + int intSize = ge::GetSizeByDataType(intType); + OPS_CHECK(floatSize == 0 || intSize == 0, + OPS_LOG_E("Tiling Debug", "Invalid data type."), + return ge::GRAPH_FAILED); + // 去除tswGrad所需ub + ub = ub - numBuckets * numLayer * floatSize; + // 计算单次处理的block大小 + int stride = ub / (intSize + floatSize); + tilingData.set_floatType(floatType); + tilingData.set_intType(intType); + tilingData.set_timeStride(stride); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFunc(gert::TilingContext* context) +{ + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("rabTimeGrad", context->GetInputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("bucketTimestamps", context->GetInputShape(BUCKET_TIMESTAMPS_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("attrs", context->GetAttrs(), return ge::GRAPH_FAILED); + + auto 
ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + OPS_CHECK(coreNum == 0, + OPS_LOG_E("Tiling Debug", "Core num is 0."), + return ge::GRAPH_FAILED); + auto ret = TimeTilingFunc(context); + + context->SetBlockDim(coreNum); + auto rowTilingData = context->GetRawTilingData(); + OPS_LOG_E_IF_NULL("GetRawTilingData", rowTilingData, return ge::GRAPH_FAILED); + tilingData.SaveToBuffer(rowTilingData->GetData(), rowTilingData->GetCapacity()); + rowTilingData->SetDataSize(tilingData.GetDataSize()); + return ret; +} +} // namespace optiling + +namespace ops { +class RelativeAttnBiasBackward : public OpDef { +public: + explicit RelativeAttnBiasBackward(const char* name) : OpDef(name) + { + this->Input("rab_time_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("bucket_timestamps") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("timestamps_weights_grad") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("num_buckets").Int(); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("prebuildPattern.value", "Opaque"); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + } +}; + +OP_ADD(RelativeAttnBiasBackward); + +} // namespace ops \ No newline at end of file diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h new file mode 100644 index 00000000..74536842 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h @@ -0,0 +1,28 @@ +/** + * @file relative_attn_bias_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(RelativeAttnBiasBackwardTilingData) +TILING_DATA_FIELD_DEF(int64_t, s); +TILING_DATA_FIELD_DEF(int64_t, bs); +TILING_DATA_FIELD_DEF(int64_t, timeStride); + +TILING_DATA_FIELD_DEF(float, bucketDivisor); +TILING_DATA_FIELD_DEF(int64_t, numBuckets); +TILING_DATA_FIELD_DEF(int64_t, numLayer); + +TILING_DATA_FIELD_DEF(int, floatType); +TILING_DATA_FIELD_DEF(int, intType); + +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(RelativeAttnBiasBackward, RelativeAttnBiasBackwardTilingData) +} // namespace optiling +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/rab_common.h new file mode 100644 index 00000000..2e317177 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/rab_common.h @@ -0,0 +1,29 @@ +/** + * @file rab_common.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RAB_COMMON_H +#define MXREC_ADD_ONS_RAB_COMMON_H + +#include "kernel_operator.h" +constexpr int DATA_ALIGN_BYTES = 32; + +constexpr int8_t TYPE_FP32 = 0; +constexpr int8_t TYPE_FP16 = 1; +constexpr int8_t TYPE_INT32 = 3; +constexpr int8_t TYPE_INT64 = 9; + +using namespace AscendC; + +struct Args { + GM_ADDR rabTimeGrad; + GM_ADDR bucketTimestamps; + GM_ADDR timestampsWeightsGrad; + + GM_ADDR workspace; + GM_ADDR tiling; +}; +#endif // MXREC_ADD_ONS_RAB_COMMON_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp new file mode 100644 index 00000000..ab857f14 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp @@ -0,0 +1,29 @@ +/** +* @file relative_attn_bias_backward.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+* +*/ + +#include "rab_common.h" +#include "relative_attn_bias_backward.h" +#include "kernel_operator.h" + +extern "C" __global__ __aicore__ void relative_attn_bias_backward(GM_ADDR rabTimeGrad, + GM_ADDR bucketTimestamps, + GM_ADDR timestampsWeightsGrad, + GM_ADDR workspace, + GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + Args args{ + rabTimeGrad, bucketTimestamps, timestampsWeightsGrad, workspace, tiling + }; + if (tilingData.floatType == TYPE_FP32) { + RelativeAttnBiasBackward kernel; + kernel.Compute(args); + } else if (tilingData.floatType == TYPE_FP16) { + RelativeAttnBiasBackward kernel; + kernel.Compute(args); + } +} diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h new file mode 100644 index 00000000..15074b29 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -0,0 +1,150 @@ +/** + * @file relative_attn_bias_backward.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_RELATIVE_ATTN_BIAS_BACKWARD_H +#define MXREC_RELATIVE_ATTN_BIAS_BACKWARD_H +#include "rab_common.h" +#include "kernel_operator.h" + +template +class RelativeAttnBiasBackward { +public: + __aicore__ inline RelativeAttnBiasBackward() {} + + __aicore__ inline void InitTensor(Args args) + { + tsGradGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeGrad, numLayer * bs * s * s); + inQueBucketTimestamps.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, numLayer * bs * s * s); + outQueTswGradOut.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, numLayer * bs * s * s); + + pipe.InitBuffer(inQueTsGrad, 1, CeilUp(stride * sizeof(FloatType), DATA_ALIGN_BYTES)); + pipe.InitBuffer(inQueBucketTimestamps, 1, CeilUp(stride * sizeof(int32_t), DATA_ALIGN_BYTES)); + pipe.InitBuffer(outQueTswGradOut, 1, CeilUp(numBuckets * numLayer * sizeof(FloatType), DATA_ALIGN_BYTES)); + } + + __aicore__ inline void InitTiling() + { + int totalLen = bs * s * s; + int totalTableSizeSplit = totalLen % GetBlockNum(); + int baseLen = totalLen / GetBlockNum(); + // 计算总共要处理的数据量、数据起始位置 + if (GetBlockIdx() >= totalTableSizeSplit) { + processLen = baseLen; + startGT = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen; + } else { + processLen = baseLen + 1; + startGT = GetBlockIdx() * (baseLen + 1); + } + } + + __aicore__ inline void Init(Args args) + { + GET_TILING_DATA(tilingData, args.tiling); + s = tilingData.s; + bs = tilingData.bs; + stride = tilingData.timeStride; + numBuckets = tilingData.numBuckets; + numLayer = tilingData.numLayer; + + InitTensor(args); + InitTiling(); + } + + __aicore__ inline void InitTswGrad() + { + LocalTensor gradOut = tswGradOutGT.AllocTensor(); + Duplicate(gradOut, (FloatType) 0, CeilUp(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); + tswGradOutGT.EnQue(gradOut); + } + + __aicore__ inline void DataCopyInIndex(uint32_t offset, uint32_t cnt) + { + LocalTensor bucketTimestamps = 
inQueBucketTimestamps.AllocTensor(); + DataCopy(bucketTimestamps, bucketTimestampsGT[offset + startGT], cnt + DATA_ALIGN_BYTES); + inQueBucketTimestamps.EnQue(bucketTimestamps); + } + + __aicore__ inline void DataCopyInGrad(uint8_t layer, uint32_t offset, uint32_t cnt) + { + LocalTensor grad = inQueTsGrad.AllocTensor(); + DataCopy(grad, tsGradGT[offset + layer * bs * s * s], cnt + DATA_ALIGN_BYTES) + inQueTsGrad.EnQue(grad); + } + + __aicore__ inline void ScatterAdd(LocalTensor dst, + LocalTensor src, + LocalTensor index, + uint8_t layer, + uint32_t cnt) + { + __ubuf__ FloatType* dstAddr = reinterpret_cast<__ubuf__ FloatType*>(dst.GetPhyAddr()); + __ubuf__ FloatType* srcAddr = reinterpret_cast<__ubuf__ FloatType*>(src.GetPhyAddr()); + __ubuf__ int32_t* indexAddr = reinterpret_cast<__ubuf__ int32_t*>(index.GetPhyAddr()); + uint32_t layerOffset = layer * numBuckets; + for (int i = 0; i < cnt; ++i) { + const auto ind = indexAddr[i]; + const auto value = src[i]; + dst[layerOffset + ind] += value; + } + } + + __aicore__ inline void DataCopyOut(LocalTensor gradOut) + { + // 同步计算结果 + tswGradOutGT.EnQue(gradOut); + gradOut = tswGradOutGT.DeQue; + + SetAtomicAdd(); + DataCopy(outQueTswGradOut, gradOut, CeilUp(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); + SetAtomicNone(); + } + + __aicore__ inline void Compute(Args args) + { + Init(args); + InitTswGrad(); + + uint32_t offset = 0; + LocalTensor gradOut = tswGradOutGT.DeQue(); + while (offset < processLen) { + uint32_t remain = processLen - offset; + uint32_t cnt = remain > stride ? 
stride : remain; + + DataCopyInIndex(offset, cnt); + LocalTensor index = inQueBucketTimestamps.DeQue(); + for (uint8_t n = 0; n < numLayer; ++n) { + DataCopyInGrad(n, offset, cnt); + LocalTensor grad = tsGradGT.DeQue(); + ScatterAdd(gradOut, grad, index, n, cnt) + } + inQueBucketTimestamps.FreeTensor(index); + } + DataCopyOut(gradOut); + tswGradOutGT.FreeTensor(gradOut); + } + +private: + GlobalTensor tsGradGT; + GlobalTensor bucketTimestampsGT; + GlobalTensor tswGradOutGT; + + TPipe pipe; + TQue inQueTsGrad; + TQue inQueBucketTimestamps; + TQue outQueTswGradOut; +private: + // shape + uint32_t s; + uint32_t bs; + uint32_t stride; + uint32_t numBuckets; + uint32_t numLayer; + // tiling + uint32_t processLen; + uint32_t startGT; +}; +#endif // MXREC_RELATIVE_ATTN_BIAS_BACKWARD_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json new file mode 100644 index 00000000..70b0752f --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json @@ -0,0 +1,47 @@ +[ + { + "op": "RelativeAttnBiasBackward", + "language": "cpp", + "input_desc": [ + { + "name": "rab_time_grad", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + }, + { + "name": "bucket_timestamps", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int32" + ] + } + ], + "output_desc": [ + { + "name": "timestamps_weights_grad", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "float" + ] + } + ], + "attr": [ + { + "name": "num_buckets", + "param_type": "required", + "type": "int" + } + ] + } +] \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh new file mode 100644 index 00000000..bbb141e2 --- 
/dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# ============================================================================== + +set -e + +# 查找msopgen的路径,加入到环境变量PATH中 +msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin) +parent_dir=$(dirname "$msopgen_path") +export PATH=$parent_dir:$PATH + +ai_core="ai_core-Ascend910B1" +if [ "$#" -eq 1 ]; then + ai_core=$1 +fi + +# 利用msopgen生成可编译文件 +rm -rf ./relative_attn_bias_backward +python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_backward.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_backward -m 0 -op RelativeAttnBias +rm -rf relative_attn_bias_backward/op_kernel/*.h +rm -rf relative_attn_bias_backward/op_kernel/*.cpp +rm -rf relative_attn_bias_backward/host/*.h +rm -rf relative_attn_bias_backward/host/*.cpp +cp -rf op_kernel relative_attn_bias_backward/ +cp -rf op_host relative_attn_bias_backward/ + +cd relative_attn_bias_backward + +# 判断当前目录下是否存在CMakePresets.json文件 +if [ ! -f "CMakePresets.json" ]; then + echo "ERROR, CMakePresets.json file not exist." 
+ exit 1 +fi + +# 禁止生成CRC校验和 +sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake + +# 修改cann安装路径 +sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json +# 修改vendor_name 防止覆盖之前vendor_name为customize的算子; +# vendor_name需要和aclnn中的CMakeLists.txt中的CUST_PKG_PATH值同步,不同步aclnn会调用失败; +# vendor_name字段值不能包含customize;包含会导致多算子部署场景CANN的vendors路径下config.ini文件内容截取错误 +sed -i 's:"customize":"relative_attn_bias_backward":g' CMakePresets.json + +line=`awk '/ENABLE_SOURCE_PACKAGE/{print NR}' CMakePresets.json` +line=`expr ${line} + 2` +sed -i "${line}s/True/False/g" CMakePresets.json + +# 增加LOG_CPP编译选项支持错误日志打印 +sed -i "1 i include(../../../cmake/func.cmake)" ./op_host/CMakeLists.txt + +line1=`awk '/tartet_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line1}s/OP_TILING_LIB/OP_TILING_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +line2=`awk '/tartet_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line2}s/OP_PROTO_LIB/OP_PROTO_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +bash build.sh + +# 安装编译成功的算子包 +bash ./build_out/custom_opp*.run diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py new file mode 100644 index 00000000..8762bb93 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math +import sysconfig + +import pytest +import torch +import torch_npu + +torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + +DEVICE = "npu:7" +NUM_BUCKETS = 128 + + +def create_rab_time_grad(num_layers: int, batchsize: int, s: int): + nearest10 = math.ceil(math.log10(s+1)) + table = torch.arange(s * s).reshape(s, s) / (10 **nearest10) # 用于排查哪些索引上的结果有问题 + batch = torch.arange(batchsize).reshape(-1, 1, 1) + result = batch + table.unsqueeze(0) # (b, s, s) + return result.unsqueeze(0).repeat(num_layers, 1, 1, 1) # (n, b, s, s) + + +def create_bucket_timestamps(batchsize: int, s: int): + repeat_times = batchsize * s * s // NUM_BUCKETS + 1 + result = torch.arange(NUM_BUCKETS).repeat(repeat_times)[:batchsize * s * s] + return result.reshape(batchsize, s, s) # (b, s, s) + + +def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): + num_layers, b, s, _ = rab_time_grad.shape + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS).to(rab_time_grad.device) + + bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) + .repeat(1, 1, 2, 1, 2) + .reshape(b, s, s)) + for n, grad in enumerate(rab_time_grad): + tsw_grad[n] = tsw_grad[n].scatter_add(src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0) + return tsw_grad + + +def rab_backward_op(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): + return torch.ops.mxrec.relative_attn_bias_backward(rab_time_grad, bucket_timestamps, 
NUM_BUCKETS) + + +@torch.no_grad() +def rab_backward(num_layers: int, batchsize: int, s: int): + grad = create_rab_time_grad(num_layers, batchsize, s) + bucket_timestamps = create_bucket_timestamps(batchsize, s // 2) + + op_result = rab_backward_op(grad, bucket_timestamps) + golden_result = rab_backward_golden(grad, bucket_timestamps) + assert torch.allclose(op_result, golden_result) + + +if __name__ == '__main__': + rab_backward(8, 1, 10) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 6b8d9737..b7f88521 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -44,6 +44,26 @@ std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, return {rabPosOut, rabTimeOut}; } +Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, + const Tensor& bucketTimestamps, + const int numBuckets) +{ + const int numLayers = rabTimeGrad.size(0); // rabTimeGrad(n, b, 2s, 2s) + const int batchsize = rabTimeGrad.size(1); // rabTimeGrad(n, b, 2s, 2s) + const int sx2 = rabTimeGrad.size(2); // rabTimeGrad(n, b, 2s, 2s) + const int s = sx2 / 2; + + auto rabTimeGradConti = rabTimeGrad.contiguous(); + auto bucketTimestampsConti = bucketTimestamps.contiguous(); // (n, b, s, s) + bucketTimestampsConti = bucketTimestampsConti.reshape({numLayers, batchsize, s, 1, s, 1}) + .repeat({1, 1, 1, 2, 1, 2}) + .reshape({numLayers, batchsize, sx2, sx2}); + + at::Tensor rabTimeGradOut = at::zeros({numLayers, numBuckets}, rabTimeGrad.options()); + EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, rabTimeGradOut, numBuckets); + return rabTimeGradOut; +} + TORCH_LIBRARY_FRAGMENT(mxrec, m) { 
m.def("relative_attn_bias(Tensor rel_pos_bias, " @@ -53,6 +73,10 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) " int[] past_valid_lens," " float bucket_divisor" " ) -> (Tensor, Tensor)"); + m.def("relative_attn_bias(Tensor rab_time_grad, " + " Tensor bucket_timestamps, " + " int num_buckets" + " ) -> Tensor") } TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) -- Gitee From aa7bcca616ea155bda95b5e817db12d8e786913b Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 22 May 2025 11:38:55 +0800 Subject: [PATCH 06/23] =?UTF-8?q?[feat]rab=E5=8F=8D=E5=90=91=E3=80=82debug?= =?UTF-8?q?=E7=BC=96=E8=AF=91=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 9 ++---- .../op_kernel/relative_attn_bias_backward.h | 28 +++++++++---------- .../relative_attn_bias_backward/run.sh | 2 +- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index c706f624..9313df4e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -36,12 +36,8 @@ constexpr int DIM4 = 4; constexpr int DIM5 = 5; namespace optiling { -static ge::graphStatus TimeTilingFunc(gert::TilingContext* context) +static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tilingData, gert::TilingContext* context) { - auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); - size_t coreNum = ascendPlatform.GetCoreNumAiv(); - - RelativeAttnBiasTilingData tilingData; // 获取、校验必要shape数据 auto gradShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetStorageShape(); // grad(n, b, 2s, 2s) int 
numBuckets = *context->GetAttrs()->GetInt(NUM_BUCKET_INDEX); @@ -102,7 +98,8 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context) OPS_CHECK(coreNum == 0, OPS_LOG_E("Tiling Debug", "Core num is 0."), return ge::GRAPH_FAILED); - auto ret = TimeTilingFunc(context); + RelativeAttnBiasBackwardTilingData tilingData; + auto ret = TimeTilingFunc(tilingData, context); context->SetBlockDim(coreNum); auto rowTilingData = context->GetRawTilingData(); diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h index 15074b29..6a4b9e3f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -18,12 +18,12 @@ public: __aicore__ inline void InitTensor(Args args) { tsGradGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeGrad, numLayer * bs * s * s); - inQueBucketTimestamps.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, numLayer * bs * s * s); - outQueTswGradOut.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, numLayer * bs * s * s); + bucketTimestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, numLayer * bs * s * s); + tswGradOutGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, numLayer * bs * s * s); - pipe.InitBuffer(inQueTsGrad, 1, CeilUp(stride * sizeof(FloatType), DATA_ALIGN_BYTES)); - pipe.InitBuffer(inQueBucketTimestamps, 1, CeilUp(stride * sizeof(int32_t), DATA_ALIGN_BYTES)); - pipe.InitBuffer(outQueTswGradOut, 1, CeilUp(numBuckets * numLayer * sizeof(FloatType), DATA_ALIGN_BYTES)); + pipe.InitBuffer(inQueTsGrad, 1, AlignTo32(stride * sizeof(FloatType))); + pipe.InitBuffer(inQueBucketTimestamps, 1, AlignTo32(stride * sizeof(int32_t))); + 
pipe.InitBuffer(outQueTswGradOut, 1, AlignTo32(numBuckets * numLayer * sizeof(FloatType))); } __aicore__ inline void InitTiling() @@ -56,15 +56,15 @@ public: __aicore__ inline void InitTswGrad() { - LocalTensor gradOut = tswGradOutGT.AllocTensor(); - Duplicate(gradOut, (FloatType) 0, CeilUp(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); - tswGradOutGT.EnQue(gradOut); + LocalTensor gradOut = outQueTswGradOut.AllocTensor(); + Duplicate(gradOut, (FloatType) 0, AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); + outQueTswGradOut.EnQue(gradOut); } __aicore__ inline void DataCopyInIndex(uint32_t offset, uint32_t cnt) { LocalTensor bucketTimestamps = inQueBucketTimestamps.AllocTensor(); - DataCopy(bucketTimestamps, bucketTimestampsGT[offset + startGT], cnt + DATA_ALIGN_BYTES); + DataCopy(bucketTimestamps, bucketTimestampsGT[offset + startGT], cnt + DATA_ALIGN_BYTES / sizeof(int32_t)); inQueBucketTimestamps.EnQue(bucketTimestamps); } @@ -95,11 +95,11 @@ public: __aicore__ inline void DataCopyOut(LocalTensor gradOut) { // 同步计算结果 - tswGradOutGT.EnQue(gradOut); - gradOut = tswGradOutGT.DeQue; + outQueTswGradOut.EnQue(gradOut); + gradOut = outQueTswGradOut.DeQue; SetAtomicAdd(); - DataCopy(outQueTswGradOut, gradOut, CeilUp(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); + DataCopy(tswGradOutGT, gradOut, AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); SetAtomicNone(); } @@ -119,12 +119,12 @@ public: for (uint8_t n = 0; n < numLayer; ++n) { DataCopyInGrad(n, offset, cnt); LocalTensor grad = tsGradGT.DeQue(); - ScatterAdd(gradOut, grad, index, n, cnt) + ScatterAdd(gradOut, grad, index, n, cnt); } inQueBucketTimestamps.FreeTensor(index); } DataCopyOut(gradOut); - tswGradOutGT.FreeTensor(gradOut); + outQueTswGradOut.FreeTensor(gradOut); } private: diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh 
b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh index bbb141e2..38ff04e6 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/run.sh @@ -16,7 +16,7 @@ fi # 利用msopgen生成可编译文件 rm -rf ./relative_attn_bias_backward -python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_backward.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_backward -m 0 -op RelativeAttnBias +python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_backward.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_backward -m 0 -op RelativeAttnBiasBackward rm -rf relative_attn_bias_backward/op_kernel/*.h rm -rf relative_attn_bias_backward/op_kernel/*.cpp rm -rf relative_attn_bias_backward/host/*.h -- Gitee From fad0f7312a912e34a9ac34aa6e7a8483b70f5a9b Mon Sep 17 00:00:00 2001 From: zhoucy Date: Sat, 24 May 2025 16:40:06 +0800 Subject: [PATCH 07/23] =?UTF-8?q?[feat]rab=E5=8F=8D=E5=90=91=E3=80=82float?= =?UTF-8?q?32=E8=B7=91=E9=80=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 38 +++++-- .../op_kernel/relative_attn_bias_backward.h | 100 ++++++++++++------ ...py => test_relative_attn_bias_backward.py} | 57 ++++++---- .../relative_attn_bias/relative_attn_bias.cpp | 33 +++--- 4 files changed, 149 insertions(+), 79 deletions(-) rename mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/{relative_attn_bias_backward.py => test_relative_attn_bias_backward.py} (47%) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 9313df4e..84b0e257 100644 --- 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -17,23 +17,18 @@ constexpr int32_t DATA_ALIGN_BYTES = 32; constexpr uint8_t NUM_BUFFER = 2; // input index -constexpr int TIMESTAMPS_WEIGHTS_GRAD_INDEX = 0; +constexpr int INPUT_GRAD_INDEX = 0; constexpr int BUCKET_TIMESTAMPS_INDEX = 1; // output index -constexpr int RAB_POSITION_INDEX = 0; -constexpr int RAB_TIME_INDEX = 1; +constexpr int TIMESTAMPS_WEIGHTS_GRAD_INDEX = 0; // attr index constexpr int NUM_BUCKET_INDEX = 0; // output dim -constexpr int RAB_POS_OUT_DIM = 3; -constexpr int RAB_TIME_OUT_DIM = 6; -constexpr int DIM_PLACE_HOLDER = 1; +constexpr int TSW_GRAD_OUT_DIM = 2; constexpr int DIM0 = 0; constexpr int DIM1 = 1; constexpr int DIM2 = 2; constexpr int DIM3 = 3; -constexpr int DIM4 = 4; -constexpr int DIM5 = 5; namespace optiling { static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tilingData, gert::TilingContext* context) @@ -77,9 +72,14 @@ static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tiling OPS_LOG_E("Tiling Debug", "Invalid data type."), return ge::GRAPH_FAILED); // 去除tswGrad所需ub - ub = ub - numBuckets * numLayer * floatSize; + ub = ub - numBuckets * numLayer * sizeof(float); // 计算单次处理的block大小 - int stride = ub / (intSize + floatSize); + int stride; + if (floatType == ge::DataType::DT_FLOAT16) { + stride = ub / (intSize + floatSize + sizeof(float)); // 申请额外内存做cast + } else { + stride = ub / (intSize + sizeof(float)); + } tilingData.set_floatType(floatType); tilingData.set_intType(intType); tilingData.set_timeStride(stride); @@ -110,6 +110,22 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context) } } // namespace optiling +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + gert::Shape* tswGradOutShape = 
context->GetOutputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX); + const gert::Shape* tsGradShape = context->GetInputShape(INPUT_GRAD_INDEX); // (n, b, 2s, 2s) + int n = tsGradShape.GetDim(DIM0); + int numBuckets = *context->GetAttrs()->GetInt(NUM_BUCKET_INDEX); + + rabPosOutShape->SetDimNum(TSW_GRAD_OUT_DIM); + rabPosOutShape->SetDim(DIM0, n); + rabPosOutShape->SetDim(DIM1, numBuckets); + return GRAPH_SUCCESS; +} +} // namespace ge + + namespace ops { class RelativeAttnBiasBackward : public OpDef { public: @@ -132,6 +148,8 @@ public: .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("num_buckets").Int(); + this->SetInferShape(ge::InferShape); + OpAICoreConfig aicore_config; aicore_config.DynamicCompileStaticFlag(true) .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h index 6a4b9e3f..f442b121 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -7,6 +7,7 @@ #ifndef MXREC_RELATIVE_ATTN_BIAS_BACKWARD_H #define MXREC_RELATIVE_ATTN_BIAS_BACKWARD_H +#include #include "rab_common.h" #include "kernel_operator.h" @@ -18,12 +19,15 @@ public: __aicore__ inline void InitTensor(Args args) { tsGradGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeGrad, numLayer * bs * s * s); - bucketTimestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, numLayer * bs * s * s); + bucketTimestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, bs * s * s); tswGradOutGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, numLayer * bs * s * s); - pipe.InitBuffer(inQueTsGrad, 1, AlignTo32(stride * sizeof(FloatType))); + 
pipe.InitBuffer(inQueTsGrad, 1, AlignTo32(stride * sizeof(float))); pipe.InitBuffer(inQueBucketTimestamps, 1, AlignTo32(stride * sizeof(int32_t))); - pipe.InitBuffer(outQueTswGradOut, 1, AlignTo32(numBuckets * numLayer * sizeof(FloatType))); + pipe.InitBuffer(outQueTswGradOut, 1, AlignTo32(numBuckets * numLayer * sizeof(float))); + if (std::is_same::value) { + pipe.InitBuffer(tmpQue, 1, AlignTo32(stride * sizeof(FloatType))); + } } __aicore__ inline void InitTiling() @@ -56,51 +60,77 @@ public: __aicore__ inline void InitTswGrad() { - LocalTensor gradOut = outQueTswGradOut.AllocTensor(); - Duplicate(gradOut, (FloatType) 0, AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); + LocalTensor gradOut = outQueTswGradOut.AllocTensor(); + Duplicate(gradOut, (float)0, AlignTo32(numLayer * numBuckets * sizeof(float)) / sizeof(float)); outQueTswGradOut.EnQue(gradOut); } __aicore__ inline void DataCopyInIndex(uint32_t offset, uint32_t cnt) { LocalTensor bucketTimestamps = inQueBucketTimestamps.AllocTensor(); - DataCopy(bucketTimestamps, bucketTimestampsGT[offset + startGT], cnt + DATA_ALIGN_BYTES / sizeof(int32_t)); + DataCopy(bucketTimestamps, bucketTimestampsGT[offset], AlignTo32(cnt * sizeof(int32_t)) / sizeof(int32_t)); inQueBucketTimestamps.EnQue(bucketTimestamps); } - __aicore__ inline void DataCopyInGrad(uint8_t layer, uint32_t offset, uint32_t cnt) + __aicore__ inline void DataCopyInGrad(uint32_t layer, uint32_t offset, uint32_t cnt) { - LocalTensor grad = inQueTsGrad.AllocTensor(); - DataCopy(grad, tsGradGT[offset + layer * bs * s * s], cnt + DATA_ALIGN_BYTES) - inQueTsGrad.EnQue(grad); + if (std::is_same::value) { + // 数据拷入 + LocalTensor gradFP16 = tmpQue.AllocTensor(); + DataCopy(grad, tsGradGT[offset + layer * bs * s * s], + AlignTo32(cnt * sizeof(FloatType)) / sizeof(FloatType)); + tmpQue.EnQue(gradFP16); + gradFP16 = tmpQue.DeQue(); + // 数据转换 + LocalTensor gradFP32 = inQueTsGrad.AllocTensor(); + Cast(gradFP32, gradFP16, cnt); + + 
inQueTsGrad.EnQue(gradFP32); + tmpQue.FreeTensor(gradFP16); + } else { + // 数据拷入 + LocalTensor gradFP32 = inQueTsGrad.AllocTensor(); + DataCopy(gradFP32, tsGradGT[offset + layer * bs * s * s], + AlignTo32(cnt * sizeof(FloatType)) / sizeof(FloatType)); + inQueTsGrad.EnQue(gradFP32); + } } - __aicore__ inline void ScatterAdd(LocalTensor dst, - LocalTensor src, - LocalTensor index, - uint8_t layer, - uint32_t cnt) + __aicore__ inline void ScatterAdd(LocalTensor& dst, LocalTensor& src, LocalTensor& index, + uint32_t layer, uint32_t cnt) { - __ubuf__ FloatType* dstAddr = reinterpret_cast<__ubuf__ FloatType*>(dst.GetPhyAddr()); - __ubuf__ FloatType* srcAddr = reinterpret_cast<__ubuf__ FloatType*>(src.GetPhyAddr()); - __ubuf__ int32_t* indexAddr = reinterpret_cast<__ubuf__ int32_t*>(index.GetPhyAddr()); uint32_t layerOffset = layer * numBuckets; + __ubuf__ float* dstAddr = reinterpret_cast<__ubuf__ float*>(dst[layerOffset].GetPhyAddr()); + __ubuf__ float* srcAddr = reinterpret_cast<__ubuf__ float*>(src.GetPhyAddr()); + __ubuf__ int32_t* indexAddr = reinterpret_cast<__ubuf__ int32_t*>(index.GetPhyAddr()); for (int i = 0; i < cnt; ++i) { const auto ind = indexAddr[i]; - const auto value = src[i]; - dst[layerOffset + ind] += value; + const auto value = srcAddr[i]; + dstAddr[ind] += value; } } - __aicore__ inline void DataCopyOut(LocalTensor gradOut) + __aicore__ inline void DataCopyOut(LocalTensor& gradOut) { // 同步计算结果 + uint32_t alignCnt = AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType); outQueTswGradOut.EnQue(gradOut); - gradOut = outQueTswGradOut.DeQue; - - SetAtomicAdd(); - DataCopy(tswGradOutGT, gradOut, AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType)); - SetAtomicNone(); + if (std::is_same::value) { + gradOut = outQueTswGradOut.DeQue(); + LocalTensor gradOutFP16 = gradOut.template ReinterpretCast(); + Cast(gradOutFP16, gradOut, RoundMode::CAST_NONE, numLayer * numBuckets); + outQueTswGradOut.EnQue(gradOutFP16); + 
gradOutFP16 = outQueTswGradOut.DeQue(); + + SetAtomicAdd(); + DataCopy(tswGradOutGT, gradOutFP16, alignCnt); + SetAtomicNone(); + } else if (std::is_same::value) { + LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); + SetAtomicAdd(); + DataCopy(tswGradOutGT, gradOutFP32, alignCnt); + SetAtomicNone(); + } } __aicore__ inline void Compute(Args args) @@ -109,19 +139,25 @@ public: InitTswGrad(); uint32_t offset = 0; - LocalTensor gradOut = tswGradOutGT.DeQue(); + LocalTensor gradOut = tswGradOutGT.DeQue(); while (offset < processLen) { uint32_t remain = processLen - offset; uint32_t cnt = remain > stride ? stride : remain; - DataCopyInIndex(offset, cnt); + DataCopyInIndex(startGT + offset, cnt); LocalTensor index = inQueBucketTimestamps.DeQue(); - for (uint8_t n = 0; n < numLayer; ++n) { - DataCopyInGrad(n, offset, cnt); - LocalTensor grad = tsGradGT.DeQue(); + for (uint32_t n = 0; n < numLayer; ++n) { + DataCopyInGrad(n, startGT + offset, cnt); + pipe_barrier(PIPE_ALL); + + LocalTensor grad = inQueTsGrad.DeQue(); ScatterAdd(gradOut, grad, index, n, cnt); + pipe_barrier(PIPE_ALL); + + inQueTsGrad.FreeTensor(grad); } inQueBucketTimestamps.FreeTensor(index); + offset += cnt; } DataCopyOut(gradOut); outQueTswGradOut.FreeTensor(gradOut); @@ -133,9 +169,11 @@ private: GlobalTensor tswGradOutGT; TPipe pipe; + TQue tmpQue; TQue inQueTsGrad; TQue inQueBucketTimestamps; TQue outQueTswGradOut; + private: // shape uint32_t s; diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py similarity index 47% rename from mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py rename to mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py index 8762bb93..e5d734fb 100644 --- 
a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/relative_attn_bias_backward.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py @@ -29,29 +29,27 @@ NUM_BUCKETS = 128 def create_rab_time_grad(num_layers: int, batchsize: int, s: int): - nearest10 = math.ceil(math.log10(s+1)) - table = torch.arange(s * s).reshape(s, s) / (10 **nearest10) # 用于排查哪些索引上的结果有问题 - batch = torch.arange(batchsize).reshape(-1, 1, 1) - result = batch + table.unsqueeze(0) # (b, s, s) - return result.unsqueeze(0).repeat(num_layers, 1, 1, 1) # (n, b, s, s) + return torch.randn(num_layers, batchsize, s, s) * 1e-4 def create_bucket_timestamps(batchsize: int, s: int): - repeat_times = batchsize * s * s // NUM_BUCKETS + 1 - result = torch.arange(NUM_BUCKETS).repeat(repeat_times)[:batchsize * s * s] - return result.reshape(batchsize, s, s) # (b, s, s) + result = torch.arange(batchsize * s) % NUM_BUCKETS + result = result.unsqueeze(-1).repeat(1, 1, s) + return result -def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): +def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor, dtype: torch.dtype): num_layers, b, s, _ = rab_time_grad.shape - tsw_grad = torch.zeros(num_layers, NUM_BUCKETS).to(rab_time_grad.device) + tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) .repeat(1, 1, 2, 1, 2) .reshape(b, s, s)) - for n, grad in enumerate(rab_time_grad): - tsw_grad[n] = tsw_grad[n].scatter_add(src=grad.view(-1), index=bucket_timestamps_expand.view(-1), dim=0) - return tsw_grad + for n, grad in enumerate(rab_time_grad.to(torch.float32)): + tsw_grad[n], _ = torch.ops.mxrec.index_select_for_rank1_backward(grad.view(-1), + tsw_grad[n], + bucket_timestamps_expand.view(-1)) + return tsw_grad.to(dtype) def rab_backward_op(rab_time_grad: 
torch.Tensor, bucket_timestamps: torch.Tensor): @@ -59,14 +57,33 @@ def rab_backward_op(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor @torch.no_grad() -def rab_backward(num_layers: int, batchsize: int, s: int): - grad = create_rab_time_grad(num_layers, batchsize, s) - bucket_timestamps = create_bucket_timestamps(batchsize, s // 2) +def rab_backward(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): + torch_npu.npu.set_device(DEVICE) + grad = create_rab_time_grad(num_layers, batchsize, s).to(dtype).to(DEVICE) + bucket_timestamps = create_bucket_timestamps(batchsize, s // 2).to(torch.int32).to(DEVICE) + torch_npu.npu.synchronize() + + golden_result = rab_backward_golden(grad, bucket_timestamps, dtype) op_result = rab_backward_op(grad, bucket_timestamps) - golden_result = rab_backward_golden(grad, bucket_timestamps) - assert torch.allclose(op_result, golden_result) + assert torch.allclose(op_result, golden_result, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("num_layers", [1, 8]) +@pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) +@pytest.mark.parametrize("candidate_len", [600]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): + s = 2 * train_len + candidate_len + rab_backward(num_layers, bs, s, dtype) + +@pytest.mark.parametrize("num_layers", [1, 8]) +@pytest.mark.parametrize("train_len,bs", [(500, 128), (1000, 32), (1000, 64), (4000, 8)]) +@pytest.mark.parametrize("candidate_len", [0]) +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_rab_train(num_layers, train_len, candidate_len, bs, dtype): + s = 2 * train_len + candidate_len + rab_backward(num_layers, bs, s, dtype) -if __name__ == '__main__': - rab_backward(8, 1, 10) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index b7f88521..935d4271 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -18,12 +18,9 @@ using torch::autograd::Function; using namespace at; using namespace std; -std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, - const Tensor& identity, - const Tensor& timestamps, - const Tensor& timestampsWeights, - const at::IntArrayRef pastValidLens, - const double bucketDivisor) +std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, const Tensor& identity, + const Tensor& timestamps, const Tensor& timestampsWeights, + const at::IntArrayRef pastValidLens, const double bucketDivisor) { auto relPosBiasConti = relPosBias.contiguous(); auto identityConti = identity.contiguous(); @@ -44,23 +41,21 @@ std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, return {rabPosOut, rabTimeOut}; } -Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, - const Tensor& bucketTimestamps, - const int numBuckets) +Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Tensor& bucketTimestamps, + const int64_t numBuckets) { const int numLayers = rabTimeGrad.size(0); // rabTimeGrad(n, b, 2s, 2s) const int batchsize = rabTimeGrad.size(1); // rabTimeGrad(n, b, 2s, 2s) - const int sx2 = rabTimeGrad.size(2); // rabTimeGrad(n, b, 2s, 2s) + const int sx2 = rabTimeGrad.size(2); // rabTimeGrad(n, b, 2s, 2s) const int s = sx2 / 2; auto rabTimeGradConti = rabTimeGrad.contiguous(); auto bucketTimestampsConti = bucketTimestamps.contiguous(); // (n, b, s, s) - bucketTimestampsConti = bucketTimestampsConti.reshape({numLayers, batchsize, s, 1, s, 1}) - .repeat({1, 1, 1, 2, 1, 2}) - .reshape({numLayers, batchsize, sx2, sx2}); + bucketTimestampsConti = + 
bucketTimestampsConti.reshape({batchsize, s, 1, s, 1}).repeat({1, 1, 2, 1, 2}).reshape({batchsize, sx2, sx2}); at::Tensor rabTimeGradOut = at::zeros({numLayers, numBuckets}, rabTimeGrad.options()); - EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, rabTimeGradOut, numBuckets); + EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, numBuckets, rabTimeGradOut); return rabTimeGradOut; } @@ -73,18 +68,20 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) " int[] past_valid_lens," " float bucket_divisor" " ) -> (Tensor, Tensor)"); - m.def("relative_attn_bias(Tensor rab_time_grad, " - " Tensor bucket_timestamps, " - " int num_buckets" - " ) -> Tensor") + m.def("relative_attn_bias_backward(Tensor rab_time_grad, " + " Tensor bucket_timestamps, " + " int num_buckets" + " ) -> Tensor"); } TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) { m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m) { m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } -- Gitee From 682525a840c8c4ff56734daddccc2bcf6c3fcf46 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Mon, 26 May 2025 11:26:32 +0800 Subject: [PATCH 08/23] =?UTF-8?q?[feat]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82debug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/relative_attn_bias_backward.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h index f442b121..d46b5fab 100644 --- 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -77,13 +77,13 @@ public: if (std::is_same::value) { // 数据拷入 LocalTensor gradFP16 = tmpQue.AllocTensor(); - DataCopy(grad, tsGradGT[offset + layer * bs * s * s], + DataCopy(gradFP16, tsGradGT[offset + layer * bs * s * s], AlignTo32(cnt * sizeof(FloatType)) / sizeof(FloatType)); tmpQue.EnQue(gradFP16); gradFP16 = tmpQue.DeQue(); // 数据转换 LocalTensor gradFP32 = inQueTsGrad.AllocTensor(); - Cast(gradFP32, gradFP16, cnt); + Cast(gradFP32, gradFP16, RoundMode::CAST_NONE, cnt); inQueTsGrad.EnQue(gradFP32); tmpQue.FreeTensor(gradFP16); @@ -118,7 +118,7 @@ public: if (std::is_same::value) { gradOut = outQueTswGradOut.DeQue(); LocalTensor gradOutFP16 = gradOut.template ReinterpretCast(); - Cast(gradOutFP16, gradOut, RoundMode::CAST_NONE, numLayer * numBuckets); + Cast(gradOutFP16, gradOut, RoundMode::CAST_TRUNC, numLayer * numBuckets); outQueTswGradOut.EnQue(gradOutFP16); gradOutFP16 = outQueTswGradOut.DeQue(); @@ -139,7 +139,7 @@ public: InitTswGrad(); uint32_t offset = 0; - LocalTensor gradOut = tswGradOutGT.DeQue(); + LocalTensor gradOut = outQueTswGradOut.DeQue(); while (offset < processLen) { uint32_t remain = processLen - offset; uint32_t cnt = remain > stride ? 
stride : remain; -- Gitee From 3202943ee43937f0ba8e5342f14d5a4eaeaf0b3d Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 11:42:21 +0800 Subject: [PATCH 09/23] =?UTF-8?q?[feat]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E8=B7=91=E9=80=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 16 ++++++++-- .../op_kernel/relative_attn_bias_backward.h | 31 ++++++++++++------- .../test_relative_attn_bias_backward.py | 13 +++++--- .../relative_attn_bias/relative_attn_bias.cpp | 5 +-- 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 84b0e257..0867c2ad 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -35,22 +35,34 @@ static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tiling { // 获取、校验必要shape数据 auto gradShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetStorageShape(); // grad(n, b, 2s, 2s) + auto indexShape = context->GetInputShape(BUCKET_TIMESTAMPS_INDEX)->GetStorageShape(); // grad(b, 2s, 2s) + int numBuckets = *context->GetAttrs()->GetInt(NUM_BUCKET_INDEX); int numLayer = gradShape.GetDim(DIM0); int batchsize = gradShape.GetDim(DIM1); int s = gradShape.GetDim(DIM2); int s2 = gradShape.GetDim(DIM3); + int indexBatchsize = indexShape.GetDim(DIM0); + int indexS1 = indexShape.GetDim(DIM1); + int indexS2 = indexShape.GetDim(DIM2); + + OPS_CHECK(gradShape.GetDimNum() != 4, + OPS_LOG_E("Tiling Debug", "Grad shape is invalid."), + return ge::GRAPH_FAILED); + OPS_CHECK(indexShape.GetDimNum() != 3, + OPS_LOG_E("Tiling 
Debug", "bucket_timestamps shape is invalid."), + return ge::GRAPH_FAILED); OPS_CHECK(numBuckets <= 0, OPS_LOG_E("Tiling Debug", "NumBuckets is invalid."), return ge::GRAPH_FAILED); OPS_CHECK(numLayer <= 0, OPS_LOG_E("Tiling Debug", "Numlayer is invalid."), return ge::GRAPH_FAILED); - OPS_CHECK(batchsize <= 0, + OPS_CHECK(batchsize <= 0 || batchsize != indexBatchsize, OPS_LOG_E("Tiling Debug", "Batchsize is invalid."), return ge::GRAPH_FAILED); - OPS_CHECK(s <= 0 || s != s2, + OPS_CHECK(s <= 0 || s != s2 || s != indexS1 || s != indexS2, OPS_LOG_E("Tiling Debug", "Sequence len is invalid."), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h index d46b5fab..a99b4f10 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -20,7 +20,7 @@ public: { tsGradGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeGrad, numLayer * bs * s * s); bucketTimestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.bucketTimestamps, bs * s * s); - tswGradOutGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, numLayer * bs * s * s); + tswGradOutGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeightsGrad, tswTableSize); pipe.InitBuffer(inQueTsGrad, 1, AlignTo32(stride * sizeof(float))); pipe.InitBuffer(inQueBucketTimestamps, 1, AlignTo32(stride * sizeof(int32_t))); @@ -53,6 +53,7 @@ public: stride = tilingData.timeStride; numBuckets = tilingData.numBuckets; numLayer = tilingData.numLayer; + tswTableSize = numLayer * numBuckets; InitTensor(args); InitTiling(); @@ -96,8 +97,11 @@ public: } } - __aicore__ inline void ScatterAdd(LocalTensor& dst, LocalTensor& src, LocalTensor& index, - uint32_t layer, uint32_t 
cnt) + __aicore__ inline void ScatterAdd(LocalTensor& dst, + LocalTensor& src, + LocalTensor& index, + uint32_t layer, + uint32_t cnt) { uint32_t layerOffset = layer * numBuckets; __ubuf__ float* dstAddr = reinterpret_cast<__ubuf__ float*>(dst[layerOffset].GetPhyAddr()); @@ -113,22 +117,24 @@ public: __aicore__ inline void DataCopyOut(LocalTensor& gradOut) { // 同步计算结果 - uint32_t alignCnt = AlignTo32(numLayer * numBuckets * sizeof(FloatType)) / sizeof(FloatType); + uint32_t alignTswTableSize = AlignTo32(tswTableSize * sizeof(FloatType)) / sizeof(FloatType); outQueTswGradOut.EnQue(gradOut); + LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); + if (std::is_same::value) { - gradOut = outQueTswGradOut.DeQue(); - LocalTensor gradOutFP16 = gradOut.template ReinterpretCast(); - Cast(gradOutFP16, gradOut, RoundMode::CAST_TRUNC, numLayer * numBuckets); - outQueTswGradOut.EnQue(gradOutFP16); - gradOutFP16 = outQueTswGradOut.DeQue(); + LocalTensor gradOutFP16 = tmpQue.AllocTensor(); + Cast(gradOutFP16, gradOutFP32, RoundMode::CAST_ROUND, tswTableSize); + tmpQue.EnQue(gradOutFP16); + gradOutFP16 = tmpQue.DeQue(); SetAtomicAdd(); - DataCopy(tswGradOutGT, gradOutFP16, alignCnt); + DataCopy(tswGradOutGT, gradOutFP16, alignTswTableSize); SetAtomicNone(); + + tmpQue.FreeTensor(gradOutFP16); } else if (std::is_same::value) { - LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); SetAtomicAdd(); - DataCopy(tswGradOutGT, gradOutFP32, alignCnt); + DataCopy(tswGradOutGT, gradOutFP32, alignTswTableSize); SetAtomicNone(); } } @@ -181,6 +187,7 @@ private: uint32_t stride; uint32_t numBuckets; uint32_t numLayer; + uint32_t tswTableSize; // tiling uint32_t processLen; uint32_t startGT; diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py index e5d734fb..c861fc81 100644 --- 
a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py @@ -44,7 +44,8 @@ def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Te bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) .repeat(1, 1, 2, 1, 2) - .reshape(b, s, s)) + .reshape(b, s, s) + .to(torch.int64)) for n, grad in enumerate(rab_time_grad.to(torch.float32)): tsw_grad[n], _ = torch.ops.mxrec.index_select_for_rank1_backward(grad.view(-1), tsw_grad[n], @@ -64,9 +65,13 @@ def rab_backward(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): bucket_timestamps = create_bucket_timestamps(batchsize, s // 2).to(torch.int32).to(DEVICE) torch_npu.npu.synchronize() - golden_result = rab_backward_golden(grad, bucket_timestamps, dtype) - op_result = rab_backward_op(grad, bucket_timestamps) - assert torch.allclose(op_result, golden_result, rtol=1e-5, atol=1e-5) + golden_result = rab_backward_golden(grad, bucket_timestamps, dtype).to("cpu") + op_result = rab_backward_op(grad, bucket_timestamps).to("cpu") + loss = 1e-5 + if dtype == torch.float16: + op_result = op_result.to(torch.float32) + loss = 1e-3 + assert torch.allclose(op_result, golden_result, rtol=loss, atol=loss) @pytest.mark.parametrize("num_layers", [1, 8]) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 935d4271..8df45253 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -51,8 +51,9 @@ Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Ten auto 
rabTimeGradConti = rabTimeGrad.contiguous(); auto bucketTimestampsConti = bucketTimestamps.contiguous(); // (n, b, s, s) - bucketTimestampsConti = - bucketTimestampsConti.reshape({batchsize, s, 1, s, 1}).repeat({1, 1, 2, 1, 2}).reshape({batchsize, sx2, sx2}); + bucketTimestampsConti = bucketTimestampsConti.view({batchsize, s, 1, s, 1}) + .repeat({1, 1, 2, 1, 2}) + .reshape({batchsize, sx2, sx2}); at::Tensor rabTimeGradOut = at::zeros({numLayers, numBuckets}, rabTimeGrad.options()); EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, numBuckets, rabTimeGradOut); -- Gitee From d23128a038ef336c1e480a8a502baa2110bd91ce Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 14:26:45 +0800 Subject: [PATCH 10/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E7=BC=96=E8=AF=91debug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 8 ++++---- .../op_kernel/relative_attn_bias_backward.h | 3 ++- .../test_relative_attn_bias_backward.py | 12 +++++------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 0867c2ad..14b6539d 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -127,12 +127,12 @@ static ge::graphStatus InferShape(gert::InferShapeContext* context) { gert::Shape* tswGradOutShape = context->GetOutputShape(TIMESTAMPS_WEIGHTS_GRAD_INDEX); const gert::Shape* tsGradShape = context->GetInputShape(INPUT_GRAD_INDEX); // (n, b, 2s, 2s) - int n = tsGradShape.GetDim(DIM0); + int n = tsGradShape->GetDim(DIM0); int 
numBuckets = *context->GetAttrs()->GetInt(NUM_BUCKET_INDEX); - rabPosOutShape->SetDimNum(TSW_GRAD_OUT_DIM); - rabPosOutShape->SetDim(DIM0, n); - rabPosOutShape->SetDim(DIM1, numBuckets); + tswGradOutShape->SetDimNum(TSW_GRAD_OUT_DIM); + tswGradOutShape->SetDim(DIM0, n); + tswGradOutShape->SetDim(DIM1, numBuckets); return GRAPH_SUCCESS; } } // namespace ge diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h index a99b4f10..c857db47 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.h @@ -119,9 +119,9 @@ public: // 同步计算结果 uint32_t alignTswTableSize = AlignTo32(tswTableSize * sizeof(FloatType)) / sizeof(FloatType); outQueTswGradOut.EnQue(gradOut); - LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); if (std::is_same::value) { + LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); LocalTensor gradOutFP16 = tmpQue.AllocTensor(); Cast(gradOutFP16, gradOutFP32, RoundMode::CAST_ROUND, tswTableSize); tmpQue.EnQue(gradOutFP16); @@ -133,6 +133,7 @@ public: tmpQue.FreeTensor(gradOutFP16); } else if (std::is_same::value) { + LocalTensor gradOutFP32 = outQueTswGradOut.DeQue(); SetAtomicAdd(); DataCopy(tswGradOutGT, gradOutFP32, alignTswTableSize); SetAtomicNone(); diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py index c861fc81..ef18a1f3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py +++ 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py @@ -15,7 +15,6 @@ # limitations under the License. # ============================================================================== -import math import sysconfig import pytest @@ -43,9 +42,9 @@ def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Te tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) bucket_timestamps_expand = (bucket_timestamps.reshape(b, s // 2, 1, s // 2, 1) - .repeat(1, 1, 2, 1, 2) - .reshape(b, s, s) - .to(torch.int64)) + .repeat(1, 1, 2, 1, 2) + .reshape(b, s, s) + .to(torch.int64)) for n, grad in enumerate(rab_time_grad.to(torch.float32)): tsw_grad[n], _ = torch.ops.mxrec.index_select_for_rank1_backward(grad.view(-1), tsw_grad[n], @@ -78,7 +77,7 @@ def rab_backward(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): @pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) @pytest.mark.parametrize("candidate_len", [600]) @pytest.mark.parametrize("bs", [1, 2, 4]) -@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): s = 2 * train_len + candidate_len rab_backward(num_layers, bs, s, dtype) @@ -87,8 +86,7 @@ def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("num_layers", [1, 8]) @pytest.mark.parametrize("train_len,bs", [(500, 128), (1000, 32), (1000, 64), (4000, 8)]) @pytest.mark.parametrize("candidate_len", [0]) -@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_train(num_layers, train_len, candidate_len, bs, dtype): s = 2 * train_len + candidate_len rab_backward(num_layers, bs, s, dtype) - -- Gitee From 23bb48cc721293aa9b6bdc8ee934c2f1a092c921 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 
14:34:27 +0800 Subject: [PATCH 11/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8Btorch16=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../relative_attn_bias/test_relative_attn_bias_backward.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py index ef18a1f3..8c11fce0 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py @@ -66,10 +66,7 @@ def rab_backward(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): golden_result = rab_backward_golden(grad, bucket_timestamps, dtype).to("cpu") op_result = rab_backward_op(grad, bucket_timestamps).to("cpu") - loss = 1e-5 - if dtype == torch.float16: - op_result = op_result.to(torch.float32) - loss = 1e-3 + loss = 1e-5 if dtype == torch.float32 else 1e-3 assert torch.allclose(op_result, golden_result, rtol=loss, atol=loss) -- Gitee From 95fcd19ccfa5d2251a696cfad29862f39b49dc18 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 15:27:19 +0800 Subject: [PATCH 12/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8Btorch16=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_relative_attn_bias_backward.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py index 8c11fce0..78177fbd 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_backward.py @@ -37,7 +37,7 @@ def create_bucket_timestamps(batchsize: int, s: int): return result -def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor, dtype: torch.dtype): +def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): num_layers, b, s, _ = rab_time_grad.shape tsw_grad = torch.zeros(num_layers, NUM_BUCKETS, dtype=torch.float32).to(rab_time_grad.device) @@ -49,7 +49,7 @@ def rab_backward_golden(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Te tsw_grad[n], _ = torch.ops.mxrec.index_select_for_rank1_backward(grad.view(-1), tsw_grad[n], bucket_timestamps_expand.view(-1)) - return tsw_grad.to(dtype) + return tsw_grad def rab_backward_op(rab_time_grad: torch.Tensor, bucket_timestamps: torch.Tensor): @@ -64,8 +64,8 @@ def rab_backward(num_layers: int, batchsize: int, s: int, dtype: torch.dtype): bucket_timestamps = create_bucket_timestamps(batchsize, s // 2).to(torch.int32).to(DEVICE) torch_npu.npu.synchronize() - golden_result = rab_backward_golden(grad, bucket_timestamps, dtype).to("cpu") - op_result = rab_backward_op(grad, bucket_timestamps).to("cpu") + golden_result = rab_backward_golden(grad, bucket_timestamps).to("cpu") + op_result = rab_backward_op(grad, bucket_timestamps).to(torch.float32).to("cpu") loss = 1e-5 if dtype == torch.float32 else 1e-3 assert torch.allclose(op_result, golden_result, rtol=loss, atol=loss) -- Gitee From a041e72b6a11a9d7742bccffbf4f913d2fc52ad2 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 15:48:19 +0800 Subject: [PATCH 13/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= 
=?UTF-8?q?=E3=80=82include=E5=A4=B4=E6=96=87=E4=BB=B6=E9=A1=BA=E5=BA=8F?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_kernel/relative_attn_bias_backward.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp index ab857f14..ff9337df 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp @@ -5,9 +5,9 @@ * */ +#include "kernel_operator.h" #include "rab_common.h" #include "relative_attn_bias_backward.h" -#include "kernel_operator.h" extern "C" __global__ __aicore__ void relative_attn_bias_backward(GM_ADDR rabTimeGrad, GM_ADDR bucketTimestamps, -- Gitee From 9a2e873221aa5543504aa4b0167f3dc178bf4f89 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 16:09:42 +0800 Subject: [PATCH 14/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E9=AD=94=E9=AC=BC=E6=95=B0=E5=AD=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 14b6539d..77060152 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ 
-25,6 +25,9 @@ constexpr int TIMESTAMPS_WEIGHTS_GRAD_INDEX = 0; constexpr int NUM_BUCKET_INDEX = 0; // output dim constexpr int TSW_GRAD_OUT_DIM = 2; +constexpr int RAB_TIME_GRAD_DIM = 4; +constexpr int BUCKET_TIMESTAMPS_DIM = 3; + constexpr int DIM0 = 0; constexpr int DIM1 = 1; constexpr int DIM2 = 2; @@ -47,10 +50,10 @@ static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tiling int indexS1 = indexShape.GetDim(DIM1); int indexS2 = indexShape.GetDim(DIM2); - OPS_CHECK(gradShape.GetDimNum() != 4, + OPS_CHECK(gradShape.GetDimNum() != RAB_TIME_GRAD_DIM, OPS_LOG_E("Tiling Debug", "Grad shape is invalid."), return ge::GRAPH_FAILED); - OPS_CHECK(indexShape.GetDimNum() != 3, + OPS_CHECK(indexShape.GetDimNum() != BUCKET_TIMESTAMPS_DIM, OPS_LOG_E("Tiling Debug", "bucket_timestamps shape is invalid."), return ge::GRAPH_FAILED); OPS_CHECK(numBuckets <= 0, @@ -112,6 +115,9 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context) return ge::GRAPH_FAILED); RelativeAttnBiasBackwardTilingData tilingData; auto ret = TimeTilingFunc(tilingData, context); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } context->SetBlockDim(coreNum); auto rowTilingData = context->GetRawTilingData(); -- Gitee From 7bfb56f0abedcdedbcbae47b8f88654b564f8d4e Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 16:18:52 +0800 Subject: [PATCH 15/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E5=91=BD=E5=90=8D=E4=BF=AE=E6=94=B9+json=E5=AF=B9?= =?UTF-8?q?=E9=BD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 20 +++++++++---------- .../relative_attn_bias_backward.json | 12 +++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp 
b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 77060152..4168384a 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -79,24 +79,24 @@ static ge::graphStatus TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tiling ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); ub = ub - RESERVER_UB_SIZE; // 获取数据类型 - auto floatType = context->GetInputTensor(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetDataType(); - auto intType = context->GetInputTensor(BUCKET_TIMESTAMPS_INDEX)->GetDataType(); - int floatSize = ge::GetSizeByDataType(floatType); - int intSize = ge::GetSizeByDataType(intType); - OPS_CHECK(floatSize == 0 || intSize == 0, + auto gradDataType = context->GetInputTensor(TIMESTAMPS_WEIGHTS_GRAD_INDEX)->GetDataType(); + auto indexDataType = context->GetInputTensor(BUCKET_TIMESTAMPS_INDEX)->GetDataType(); + int gradSize = ge::GetSizeByDataType(gradDataType); + int indexSize = ge::GetSizeByDataType(indexDataType); + OPS_CHECK(gradSize == 0 || indexSize == 0, OPS_LOG_E("Tiling Debug", "Invalid data type."), return ge::GRAPH_FAILED); // 去除tswGrad所需ub ub = ub - numBuckets * numLayer * sizeof(float); // 计算单次处理的block大小 int stride; - if (floatType == ge::DataType::DT_FLOAT16) { - stride = ub / (intSize + floatSize + sizeof(float)); // 申请额外内存做cast + if (gradDataType == ge::DataType::DT_FLOAT16) { + stride = ub / (indexSize + gradSize + sizeof(float)); // 申请额外内存做cast } else { - stride = ub / (intSize + sizeof(float)); + stride = ub / (indexSize + sizeof(float)); } - tilingData.set_floatType(floatType); - tilingData.set_intType(intType); + tilingData.set_floatType(gradDataType); + tilingData.set_intType(indexDataType); tilingData.set_timeStride(stride); return ge::GRAPH_SUCCESS; } diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json index 70b0752f..ae0d8f8f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/relative_attn_bias_backward.json @@ -7,20 +7,20 @@ "name": "rab_time_grad", "param_type": "required", "format": [ - "ND" + "ND", "ND" ], "type": [ - "float" + "fp32", "fp16" ] }, { "name": "bucket_timestamps", "param_type": "required", "format": [ - "ND" + "ND", "ND" ], "type": [ - "int32" + "int32", "int32" ] } ], @@ -29,10 +29,10 @@ "name": "timestamps_weights_grad", "param_type": "required", "format": [ - "ND" + "ND", "ND" ], "type": [ - "float" + "fp32", "fp16" ] } ], -- Gitee From 5e1546b99139d7152736cf48f005302865bee959 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Wed, 28 May 2025 16:23:35 +0800 Subject: [PATCH 16/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=5Fbackward?= =?UTF-8?q?=E3=80=82=E5=91=BD=E5=90=8D=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_backward.cpp | 4 ++-- .../op_host/relative_attn_bias_backward_tiling.h | 4 ++-- .../op_kernel/relative_attn_bias_backward.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp index 4168384a..13092840 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward.cpp @@ -95,8 +95,8 @@ static ge::graphStatus 
TimeTilingFunc(RelativeAttnBiasBackwardTilingData& tiling } else { stride = ub / (indexSize + sizeof(float)); } - tilingData.set_floatType(gradDataType); - tilingData.set_intType(indexDataType); + tilingData.set_gradDataType(gradDataType); + tilingData.set_indexDataType(indexDataType); tilingData.set_timeStride(stride); return ge::GRAPH_SUCCESS; } diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h index 74536842..40b6cce3 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_host/relative_attn_bias_backward_tiling.h @@ -19,8 +19,8 @@ TILING_DATA_FIELD_DEF(float, bucketDivisor); TILING_DATA_FIELD_DEF(int64_t, numBuckets); TILING_DATA_FIELD_DEF(int64_t, numLayer); -TILING_DATA_FIELD_DEF(int, floatType); -TILING_DATA_FIELD_DEF(int, intType); +TILING_DATA_FIELD_DEF(int, gradDataType); +TILING_DATA_FIELD_DEF(int, indexDataType); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(RelativeAttnBiasBackward, RelativeAttnBiasBackwardTilingData) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp index ff9337df..17b6ad1e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_backward/op_kernel/relative_attn_bias_backward.cpp @@ -19,10 +19,10 @@ extern "C" __global__ __aicore__ void relative_attn_bias_backward(GM_ADDR rabTim Args args{ rabTimeGrad, bucketTimestamps, timestampsWeightsGrad, workspace, tiling }; - if (tilingData.floatType == TYPE_FP32) { + 
if (tilingData.gradDataType == TYPE_FP32) { RelativeAttnBiasBackward kernel; kernel.Compute(args); - } else if (tilingData.floatType == TYPE_FP16) { + } else if (tilingData.gradDataType == TYPE_FP16) { RelativeAttnBiasBackward kernel; kernel.Compute(args); } -- Gitee From 640a734f0be0b4ded93197c2f13ebda99d60f5fa Mon Sep 17 00:00:00 2001 From: zhoucy Date: Thu, 29 May 2025 09:47:08 +0800 Subject: [PATCH 17/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8Bclean=20code=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../torch_demo/relative_attn_bias/test_relative_attn_bias.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py index 0e1fe420..83bb6f01 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py @@ -189,10 +189,6 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): past_valid_lens=past_valid_lens) torch_npu.npu.synchronize() - # 验证训练正向精度时需注释rab_pos_golden部分 - rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...], - identity=identity_list[layer_num, ...], - past_valid_lens=past_valid_lens) rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), timestamps=timestamps) torch_npu.npu.synchronize() -- Gitee From d3d86148dd24918363f1d1ed5ed2a1a2fb84b137 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 30 May 2025 09:11:27 +0800 Subject: [PATCH 18/23] =?UTF-8?q?[feat]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86position?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
--- .../op_host/relative_attn_bias.cpp | 185 ++++++++++++++++++ .../op_host/relative_attn_bias_tiling.h | 25 +++ .../op_kernel/rab_common.h | 35 ++++ .../op_kernel/relative_attn_bias.cpp | 29 +++ .../op_kernel/relative_attn_bias_pos.h | 174 ++++++++++++++++ .../relative_attn_bias_pos.json | 47 +++++ .../operators/relative_attn_bias_pos/run.sh | 67 +++++++ .../test_relative_attn_bias_v200.py | 35 ++++ .../relative_attn_bias/relative_attn_bias.cpp | 21 ++ 9 files changed, 618 insertions(+) create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp new file mode 100644 index 00000000..11c74a21 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp @@ -0,0 +1,185 @@ +/** +* @file relative_attn_bias.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+* +*/ + +#include +#include "relative_attn_bias_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" +#include "../../../common/ops_log.h" + +constexpr int32_t RESERVER_UB_SIZE = (20 * 1024); +constexpr uint8_t NUM_BUFFER = 2; + +// input index +constexpr int REL_POS_BIAS_INDEX = 0; +constexpr int IDENTITY_INDEX = 1; +// output index +constexpr int RAB_POSITION_INDEX = 0; +// attr index +constexpr int PAST_VALID_LENS_INDEX = 0; +// output dim +constexpr int REL_POS_BIAS_DIM = 2; +constexpr int IDENTITY_DIM = 2; +constexpr int RAB_POS_OUT_DIM = 3; +constexpr int DIM0 = 0; +constexpr int DIM1 = 1; +constexpr int DIM2 = 2; + +namespace optiling { +static ge::graphStatus PosTilingFunc(TilingData& tilingData, gert::TilingContext* context) +{ + // 设置past_valid_len + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + int batchsize = pastValidLensPtr->GetSize(); + OPS_CHECK(batchsize <= 0, + OPS_LOG_E("Tiling Debug", "mismatch batchsize of past_valid_len and timestamps."), + return ge::GRAPH_FAILED); + tilingData.set_bs(batchsize); + + auto *pastValidLensData = const_cast(reinterpret_cast(pastValidLensPtr->GetData())); + uint32_t pastValidLens[MAX_BATCH_SIZE]; + for (auto i = 0; i < batchsize; ++i) { + pastValidLens[i] = pastValidLensData[i]; + } + tilingData.set_pastValidLens(pastValidLens); + // 校验s + auto biasShape = context->GetInputShape(REL_POS_BIAS_INDEX)->GetStorageShape(); // (2s, 2s) + auto identityShape = context->GetInputShape(IDENTITY_INDEX)->GetStorageShape(); // (2s, 2s) + + int biasSeqLen = biasShape.GetDim(DIM0); + int biasSeqLen2 = biasShape.GetDim(DIM1); + int idSeqLen = identityShape.GetDim(DIM0); + int idSeqLen2 = identityShape.GetDim(DIM1); + + OPS_CHECK(biasShape.GetDimNum() != REL_POS_BIAS_DIM, + OPS_LOG_E("Tiling Debug", "Invalid rel_pos_bias shape."), + return ge::GRAPH_FAILED); + 
OPS_CHECK(identityShape.GetDimNum() != IDENTITY_DIM, + OPS_LOG_E("Tiling Debug", "Invalid identity shape."), + return ge::GRAPH_FAILED); + OPS_CHECK(biasSeqLen != biasSeqLen2 || biasSeqLen != idSeqLen || biasSeqLen != idSeqLen2, + OPS_LOG_E("Tiling Debug", "Mismatch sequence len of rel_pos_bias and identity."), + return ge::GRAPH_FAILED); + tilingData.set_s(biasSeqLen / 2); + + // 获取ub + uint64_t ub; + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); + ub = ub - RESERVER_UB_SIZE; + // 获取数据类型 + auto identityType = context->GetInputTensor(IDENTITY_INDEX)->GetDataType(); + auto biasType = context->GetInputTensor(REL_POS_BIAS_INDEX)->GetDataType(); + + int identitySize = ge::GetSizeByDataType(identityType); + OPS_CHECK(identityType != biasType, + OPS_LOG_E("Tiling Debug", "Mismatch data type of identity and rel_pos_bias."), + return ge::GRAPH_FAILED); + OPS_CHECK(identitySize == 0, + OPS_LOG_E("Tiling Debug", "Invalid data type."), + return ge::GRAPH_FAILED); + tilingData.set_dataType(identityType); + + // 计算一次处理的窗口大小(stride) + int stride = ub / (NUM_BUFFER * 3 * identitySize); + tilingData.set_stride(stride); +} + +static ge::graphStatus TilingFunc(gert::TilingContext* context) +{ + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("relPosBiasShape", context->GetInputShape(REL_POS_BIAS_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("identityShape", context->GetInputShape(IDENTITY_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("attrs", context->GetAttrs(), return ge::GRAPH_FAILED); + + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + OPS_LOG_E_IF_NULL("past_valid_len", pastValidLensPtr, return ge::GRAPH_FAILED); + + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + OPS_CHECK(coreNum == 0, + OPS_LOG_E("Tiling Debug", "Core num is 
0."), + return ge::GRAPH_FAILED); + context->SetBlockDim(coreNum); + + TilingData tilingData; + auto ret = PosTilingFunc(tilingData, context); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + auto rowTilingData = context->GetRawTilingData(); + OPS_LOG_E_IF_NULL("GetRawTilingData", rowTilingData, return ge::GRAPH_FAILED); + tilingData.SaveToBuffer(rowTilingData->GetData(), rowTilingData->GetCapacity()); + rowTilingData->SetDataSize(tilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + gert::Shape* rabPosOutShape = context->GetOutputShape(RAB_POSITION_INDEX); + + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + int bs = pastValidLensPtr->GetSize(); + const gert::Shape* identityShape = context->GetInputShape(IDENTITY_INDEX); + int s = identityShape->GetDim(DIM0); // identityShape(2s, 2s) + + rabPosOutShape->SetDimNum(RAB_POS_OUT_DIM); + rabPosOutShape->SetDim(DIM0, bs); + rabPosOutShape->SetDim(DIM1, s); + rabPosOutShape->SetDim(DIM2, s); + + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class RelativeAttnBias : public OpDef { +public: + explicit RelativeAttnBias(const char* name) : OpDef(name) + { + this->Input("rel_pos_bias") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("identity") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("rab_pos") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("past_valid_lens").ListInt(); + + 
this->SetInferShape(ge::InferShape); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("prebuildPattern.value", "Opaque"); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend310p", aicore_config); + } +}; + +OP_ADD(RelativeAttnBias); + +} // namespace ops \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h new file mode 100644 index 00000000..17d05ef5 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h @@ -0,0 +1,25 @@ +/** + * @file relative_attn_bias_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#include "register/tilingdata_base.h" +constexpr int MAX_BATCH_SIZE = 512; + +namespace optiling { +BEGIN_TILING_DATA_DEF(TilingData) +TILING_DATA_FIELD_DEF(int64_t, s); +TILING_DATA_FIELD_DEF(int64_t, bs); +TILING_DATA_FIELD_DEF(int64_t, stride); +TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens); + +TILING_DATA_FIELD_DEF(int, dataType); + +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(RelativeAttnBias, TilingData) +} // namespace optiling +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h new file mode 100644 index 00000000..62b5ddcb --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h @@ -0,0 +1,35 @@ +/** + * @file rab_common.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RAB_COMMON_H +#define MXREC_ADD_ONS_RAB_COMMON_H + +#include "kernel_operator.h" +constexpr int DATA_ALIGN_BYTES = 32; +constexpr int MAX_BATCH_SIZE = 512; +constexpr int NUM_BUFFER = 2; +constexpr int MAX_SEQ_CNT = 128; +constexpr int GATHER_PROCESS_WINDOW = 4096; + +constexpr int8_t TYPE_FP32 = 0; +constexpr int8_t TYPE_FP16 = 1; +constexpr int8_t TYPE_INT32 = 3; +constexpr int8_t TYPE_INT64 = 9; + +using namespace AscendC; + +struct Args { + // pos_bias + GM_ADDR positionBias; + GM_ADDR identity; + // out + GM_ADDR rabPosOut; + + GM_ADDR workspace; + GM_ADDR tiling; +}; +#endif // MXREC_ADD_ONS_RAB_COMMON_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp new file mode 100644 index 00000000..249f4c2b --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp @@ -0,0 +1,29 @@ +/** +* @file relative_attn_bias.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+* +*/ + +#include "rab_common.h" +#include "relative_attn_bias_pos.h" +#include "kernel_operator.h" + +extern "C" __global__ __aicore__ void relative_attn_bias(GM_ADDR positionBias, + GM_ADDR identity, + GM_ADDR rabPosOut, + GM_ADDR workspace, + GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + Args args{ + positionBias, identity, rabPosOut, workspace, tiling + }; + if (tilingData.dataType == TYPE_FP32) { + RelativeAttnBiasPos kernel; + kernel.Compute(args); + } else if (tilingData.floatType == TYPE_FP16) { + RelativeAttnBiasPos kernel; + kernel.Compute(args); + } +} diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h new file mode 100644 index 00000000..bd3fccea --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h @@ -0,0 +1,174 @@ +/** + * @file relative_attn_bias_pos.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H +#include "rab_common.h" +#include "kernel_operator.h" +using namespace AscendC; + +constexpr int SEQ_EXPAND = 2; // rab_pos中序列长度为原本输入的两倍 + +template +class RelativeAttnBiasPos { +public: + __aicore__ inline RelativeAttnBiasPos() {} + + __aicore__ inline void Init(Args args) + { + GET_TILING_DATA(tilingData, args.tiling); + s = SEQ_EXPAND * tilingData.s; + bs = tilingData.bs; + stride = tilingData.positionStride; + for (auto i = 0; i < bs; ++i) { + pastValidLens[i] = tilingData.pastValidLens[i]; + } + + posBiasGT.SetGlobalBuffer((__gm__ FloatType*)args.positionBias, s * s); + identityGT.SetGlobalBuffer((__gm__ FloatType*)args.identity, s * s); + rabPosBiasOutGT.SetGlobalBuffer((__gm__ FloatType*)args.rabPosOut, bs * s * s); + + pipe.InitBuffer(queIdentityIn, NUM_BUFFER, Ceil(SEQ_EXPAND * stride * sizeof(FloatType))); + pipe.InitBuffer(quePosIn, NUM_BUFFER, Ceil(stride * sizeof(FloatType))); + + int64_t totalTableSizeSplit = s % GetBlockNum(); + int64_t baseLen = s / GetBlockNum(); + if (GetBlockIdx() >= totalTableSizeSplit) { + totalRow = baseLen; + rowOffset = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen; + } else { + totalRow = baseLen + 1; + rowOffset = GetBlockIdx() * (baseLen + 1); + } + REL_POS_BIAS_FIRST = posBiasGT.GetValue(0); + } + + __aicore__ inline void ComputeIdentity(int offset, int cnt) + { + // DataCopyIn identity + LocalTensor identityUb = queIdentityIn.AllocTensor(); + + DataCopy(identityUb, identityGT[offset], Ceil(cnt * sizeof(FloatType)) / sizeof(FloatType)); + queIdentityIn.EnQue(identityUb); + + // Compute identity * rel_pos_bias[0, 0], (1 - identity) + LocalTensor identityFilledUb = queIdentityIn.DeQue(); + + // 后半段 (1 - identity) + Muls(identityFilledUb[stride], identityFilledUb, (FloatType)-1, cnt); + Adds(identityFilledUb[stride], identityFilledUb[stride], (FloatType)1, cnt); + + // 前半段 
identity * rel_pos_bias[0, 0] + Muls(identityFilledUb, identityFilledUb, REL_POS_BIAS_FIRST, cnt); + + queIdentityIn.EnQue(identityFilledUb); + } + + __aicore__ inline void DataCopyIn(int row, int offset, int cnt) + { + LocalTensor posBiasUb = quePosIn.AllocTensor(); + DataCopy(posBiasUb, posBiasGT[row * s + offset], Ceil(cnt * sizeof(FloatType)) / sizeof(FloatType)); + quePosIn.EnQue(posBiasUb); + } + + __aicore__ inline void ComputeRabBias(LocalTensor& identityCalcUb, int cnt) + { + LocalTensor posBiasUb = quePosIn.DeQue(); + Mul(posBiasUb, posBiasUb, identityCalcUb[stride], cnt); + Add(posBiasUb, posBiasUb, identityCalcUb, cnt); + pipe_barrier(PIPE_ALL); + quePosIn.EnQue(posBiasUb); + } + + __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) + { + if (b == 0) { + return 0; + } + return (a + b - 1) / b * b; + } + + __aicore__ inline void DataCopyOut(int offset, int cnt) + { + uint32_t datasize = cnt * sizeof(FloatType); + uint32_t alignLen = datasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES; + uint32_t unAlignLen = datasize - alignLen; + uint32_t alignCnt = alignLen / sizeof(FloatType); + uint32_t unAlignCnt = unAlignLen / sizeof(FloatType); + + LocalTensor posBiasUb = quePosIn.DeQue(); + // 对齐部分拷出 + if (alignLen > 0) { + DataCopy(rabPosBiasOutGT[offset], posBiasUb, cnt); + } + // 非对齐部分拷出 + if (unAlignLen > 0) { +#ifdef SUPPORT_V200 + uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(FloatType))) - (1ul << unAlignCnt); + uint64_t mask[2] = {mask0, 0}; + Duplicate(posBiasUb[alignCnt], (FloatType)0, mask, 1, 1, 1); + quePosIn.EnQue(posBiasUb); + posBiasUb = quePosIn.DeQue(); + SetAtomicAdd(); + DataCopy(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], Ceil(unAlignLen) / sizeof(FloatType)); + SetAtomicNone(); +#else + const DataCopyExtParams dataCopyExtParams{1, unAlignLen, 0, 0, 0}; + DataCopyPad(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], dataCopyExtParams); +#endif + } + quePosIn.FreeTensor(posBiasUb); + } + + __aicore__ 
inline void Compute(Args args) + { + Init(args); + for (int row = rowOffset; row < rowOffset + totalRow; ++row) { + int offset = 0; + for (int j = 0; j < (s + stride - 1) / stride; ++j) { + int remain = s - offset; + int cnt = remain > stride ? stride : remain; + ComputeIdentity(offset + row * s, cnt); + LocalTensor identityCalcUb = queIdentityIn.DeQue(); + + for (int b = 0; b < bs; ++b) { + int valid_len = pastValidLens[b]; + int valid_row = row > valid_len ? valid_len : row; + DataCopyIn(valid_row, offset, cnt); + ComputeRabBias(identityCalcUb, cnt); + int padOutPtr = b * s * s + row * s + j * stride; + DataCopyOut(padOutPtr, cnt); + } + queIdentityIn.FreeTensor(identityCalcUb); + offset += cnt; + } + } + } + +private: + // shape + int s; + int bs; + int stride; + // tiling + int rowOffset; // identity、rel_pos_bias(s, s)的行偏移 + int totalRow; // 需要处理的总行数 + +private: + TPipe pipe; + TQue queIdentityIn; + TQue queIdentityCalcIn; + TQue quePosIn; + + GlobalTensor identityGT; + GlobalTensor posBiasGT; + GlobalTensor rabPosBiasOutGT; + uint32_t pastValidLens[MAX_BATCH_SIZE]; + FloatType REL_POS_BIAS_FIRST; // identity[0, 0] +}; + +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json new file mode 100644 index 00000000..d5ed4f18 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json @@ -0,0 +1,47 @@ +[ + { + "op": "RelativeAttnBias", + "language": "cpp", + "input_desc": [ + { + "name": "rel_pos_bias", + "param_type": "required", + "format": [ + "ND", "ND" + ], + "type": [ + "fp16", "fp32" + ] + }, + { + "name": "identity", + "param_type": "required", + "format": [ + "ND", "ND" + ], + "type": [ + "fp16", "fp32" + ] + } + ], + "output_desc": [ + { + "name": "rab_pos", + "param_type": "required", + "format": [ + "ND", "ND" + ], + 
"type": [ + "fp16", "fp32" + ] + } + ], + "attr": [ + { + "name": "past_valid_lens", + "param_type": "required", + "type": "list_int" + } + ] + } +] \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh new file mode 100644 index 00000000..45326162 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# ============================================================================== + +set -e + +# 查找msopgen的路径,加入到环境变量PATH中 +msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin) +parent_dir=$(dirname "$msopgen_path") +export PATH=$parent_dir:$PATH + +ai_core="ai_core-Ascend910B1" +if [ "$#" -eq 1 ]; then + ai_core=$1 +fi + +# 利用msopgen生成可编译文件 +rm -rf ./relative_attn_bias_pos +python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_pos.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_pos -m 0 -op RelativeAttnBiasPos +rm -rf relative_attn_bias_pos/op_kernel/*.h +rm -rf relative_attn_bias_pos/op_kernel/*.cpp +rm -rf relative_attn_bias_pos/host/*.h +rm -rf relative_attn_bias_pos/host/*.cpp +cp -rf op_kernel relative_attn_bias_pos/ +cp -rf op_host relative_attn_bias_pos/ + +cd relative_attn_bias_pos + +# 判断当前目录下是否存在CMakePresets.json文件 +if [ ! -f "CMakePresets.json" ]; then + echo "ERROR, CMakePresets.json file not exist." 
+ exit 1 +fi + +# 禁止生成CRC校验和 +sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake + +# 修改cann安装路径 +sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json +# 修改vendor_name 防止覆盖之前vendor_name为customize的算子; +# vendor_name需要和aclnn中的CMakeLists.txt中的CUST_PKG_PATH值同步,不同步aclnn会调用失败; +# vendor_name字段值不能包含customize;包含会导致多算子部署场景CANN的vendors路径下config.ini文件内容截取错误 +sed -i 's:"customize":"relative_attn_bias_pos":g' CMakePresets.json + +if [ "$ai_core" = "ai_core-Ascend310P3" ]; then + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_kernel.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_time.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_pos.h +fi + +line=`awk '/ENABLE_SOURCE_PACKAGE/{print NR}' CMakePresets.json` +line=`expr ${line} + 2` +sed -i "${line}s/True/False/g" CMakePresets.json + +# 增加LOG_CPP编译选项支持错误日志打印 +sed -i "1 i include(../../../cmake/func.cmake)" ./op_host/CMakeLists.txt + +line1=`awk '/tartet_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line1}s/OP_TILING_LIB/OP_TILING_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +line2=`awk '/tartet_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line2}s/OP_PROTO_LIB/OP_PROTO_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +bash build.sh + +# 安装编译成功的算子包 +bash ./build_out/custom_opp*.run diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py index 3a24438b..8ce7fce4 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py @@ -160,6 +160,32 @@ def rab_pos_golden(rel_pos_bias: 
torch.Tensor, identity: torch.Tensor, past_vali return rel_pos_bias_list +@torch.no_grad() +def rab_pos(num_layers, train_len, candidate_len, bs, dtype): + torch_npu.npu.set_device(DEVICE) + pos_w = create_pos_w(train_len, num_layers).to(dtype) + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, + train_len=train_len, + candidate_len=candidate_len, + num_layers=num_layers) + rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) + + rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) + identity_list = identity_list.to(DEVICE) + past_valid_lens = past_valid_lens.to(DEVICE) + torch_npu.npu.synchronize() + + for rel_pos_bias, identity in zip(rel_pos_bias_list, identity_list): + op_result = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens.tolist()).to('cpu') + golden_result = rab_pos_golden(rel_pos_bias=rel_pos_bias.to('cpu'), + identity=identity.to('cpu'), + past_valid_lens=past_valid_lens.to('cpu')) + assert torch.allclose(op_result, golden_result) + + @torch.no_grad() def rab(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) @@ -208,3 +234,12 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): rab(num_layers, train_len, candidate_len, bs, dtype) + + +@pytest.mark.parametrize("num_layers", [8]) +@pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) +@pytest.mark.parametrize("candidate_len", [600]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_rab_pos_eval(num_layers, train_len, candidate_len, bs, dtype): + rab_pos(num_layers, train_len, candidate_len, bs, dtype) diff --git 
a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 8df45253..e99d93d2 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -41,6 +41,21 @@ std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, return {rabPosOut, rabTimeOut}; } +Tensor relative_attn_bias_pos_impl_npu(const Tensor& relPosBias, const Tensor& identity, + const at::IntArrayRef pastValidLens) +{ + auto relPosBiasConti = relPosBias.contiguous(); + auto identityConti = identity.contiguous(); + + const int bs = pastValidLens.size(); + const int sx2 = relPosBias.size(0); // relPosBias(2s, 2s) + + at::Tensor rabPosOut = at::zeros({bs, sx2, sx2}, relPosBiasConti.options()); + + EXEC_NPU_CMD(aclnnRelativeAttnBiasPos, relPosBiasConti, identityConti, pastValidLens, rabPosOut); + return rabPosOut; +} + Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Tensor& bucketTimestamps, const int64_t numBuckets) { @@ -69,6 +84,10 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) " int[] past_valid_lens," " float bucket_divisor" " ) -> (Tensor, Tensor)"); + m.def("relative_attn_bias_pos(Tensor rel_pos_bias, " + " Tensor identity, " + " int[] past_valid_lens" + " ) -> Tensor"); m.def("relative_attn_bias_backward(Tensor rab_time_grad, " " Tensor bucket_timestamps, " " int num_buckets" @@ -78,11 +97,13 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) { m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m) { m.impl("relative_attn_bias", 
&relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } -- Gitee From 24c2d37a175885554f5821f9fa669d746c403490 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Fri, 30 May 2025 14:37:55 +0800 Subject: [PATCH 19/23] =?UTF-8?q?[feat]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86time=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxrec_add_ons/build/build.sh | 2 +- .../op_host/relative_attn_bias.cpp | 1 + .../op_host/relative_attn_bias.cpp | 193 +++++++++++++ .../op_host/relative_attn_bias_tiling.h | 31 +++ .../op_kernel/rab_common.h | 39 +++ .../op_kernel/relative_attn_bias.cpp | 29 ++ .../op_kernel/relative_attn_bias_time.h | 255 ++++++++++++++++++ .../relative_attn_bias.json | 53 ++++ .../operators/relative_attn_bias_time/run.sh | 67 +++++ .../test_relative_attn_bias.py | 28 +- .../test_relative_attn_bias_v200.py | 30 +-- .../relative_attn_bias/relative_attn_bias.cpp | 59 ++-- 12 files changed, 705 insertions(+), 82 deletions(-) create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json create mode 100644 mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh diff --git a/mxrec_add_ons/build/build.sh b/mxrec_add_ons/build/build.sh index 
38720b99..cc8eba35 100644 --- a/mxrec_add_ons/build/build.sh +++ b/mxrec_add_ons/build/build.sh @@ -32,7 +32,7 @@ permute2d_sparse_data split_embedding_codegen_forward_unweighted dense_to_jagged " -support_310p_list="gather_for_rank1 hstu_dense_forward_fuxi relative_attn_bias" +support_310p_list="gather_for_rank1 hstu_dense_forward_fuxi relative_attn_bias_pos relative_attn_bias_time" cd "${MxRec_DIR}" diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp index 11c74a21..4d71e14f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp @@ -88,6 +88,7 @@ static ge::graphStatus PosTilingFunc(TilingData& tilingData, gert::TilingContext // 计算一次处理的窗口大小(stride) int stride = ub / (NUM_BUFFER * 3 * identitySize); tilingData.set_stride(stride); + return ge::GRAPH_SUCCESS; } static ge::graphStatus TilingFunc(gert::TilingContext* context) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp new file mode 100644 index 00000000..13046a0a --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp @@ -0,0 +1,193 @@ +/** +* @file relative_attn_bias.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+*
+*/
+
+#include
+#include "relative_attn_bias_tiling.h"
+#include "register/op_def_registry.h"
+#include "tiling/tiling_api.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "../../../common/ops_log.h"
+
+constexpr int32_t RESERVER_UB_SIZE = (20 * 1024);
+constexpr int32_t DATA_ALIGN_BYTES = 32;
+
+// input index
+constexpr int TIMESTAMPS_INDEX = 0;
+constexpr int TIMESTAMPS_WEIGHTS_INDEX = 1;
+// output index
+constexpr int RAB_TIME_INDEX = 0;
+// attr index
+constexpr int BUCKET_DIV_INDEX = 0;
+// input/output dim
+constexpr int TIMESTAMPS_DIM = 2;
+constexpr int TIMESTAMPS_WEIGHTS_DIM = 2;
+constexpr int RAB_TIME_OUT_DIM = 6;
+constexpr int DIM_PLACE_HOLDER = 1;
+constexpr int DIM0 = 0;
+constexpr int DIM1 = 1;
+constexpr int DIM2 = 2;
+constexpr int DIM3 = 3;
+constexpr int DIM4 = 4;
+constexpr int DIM5 = 5;
+// constrain of params
+constexpr int MAX_S = 4300;
+
+namespace optiling {
+static ge::graphStatus TimeTilingFunc(TilingData& tilingData, gert::TilingContext* context)
+{
+    auto tsShape = context->GetInputShape(TIMESTAMPS_INDEX)->GetStorageShape(); // (b, s)
+    auto tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX)->GetStorageShape(); // (num_layer, num_buckets)
+
+    OPS_CHECK(tsShape.GetDimNum() != TIMESTAMPS_DIM, // ranks are validated before any GetDim() call below
+              OPS_LOG_E("Tiling Debug", "Invalid timestamps shape."),
+              return ge::GRAPH_FAILED);
+    OPS_CHECK(tswShape.GetDimNum() != TIMESTAMPS_WEIGHTS_DIM,
+              OPS_LOG_E("Tiling Debug", "Invalid timestamps_weights shape."),
+              return ge::GRAPH_FAILED);
+
+    int bs = tsShape.GetDim(DIM0); // (b, s)
+    int s = tsShape.GetDim(DIM1); // (b, s)
+    int numLayer = tswShape.GetDim(DIM0); // (num_layer, num_buckets)
+    int numBuckets = tswShape.GetDim(DIM1); // (num_layer, num_buckets)
+    float divs = *context->GetAttrs()->GetFloat(BUCKET_DIV_INDEX);
+    float clampMax = exp((numBuckets - 1) * divs); // |dt| at which log(|dt|) / divs reaches the last bucket
+    OPS_CHECK(s <= 0 || s > MAX_S, // s == 0 would make alignSeqLen 0 and divide by zero below
+              OPS_LOG_E("Tiling Debug", "Invalid len of timestamps sequence."),
+              return ge::GRAPH_FAILED);
+    OPS_CHECK(bs <= 0,
+              OPS_LOG_E("Tiling Debug", "Invalid batchsize of timestamps."),
+              return ge::GRAPH_FAILED);
+
+    tilingData.set_bs(bs);
+    tilingData.set_s(s);
+    tilingData.set_numLayer(numLayer);
+    tilingData.set_numBuckets(numBuckets);
+    tilingData.set_bucketDivisor(divs);
+    tilingData.set_clampMax(clampMax);
+
+    // Size stride and the Clamp scratch buffer from the UB left after the reserved area.
+    auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    uint64_t ub;
+    ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub);
+    ub = ub - RESERVER_UB_SIZE;
+    // Data types (and element sizes) of both inputs
+    auto tswType = context->GetInputTensor(TIMESTAMPS_WEIGHTS_INDEX)->GetDataType();
+    auto tsType = context->GetInputTensor(TIMESTAMPS_INDEX)->GetDataType();
+    int tswSize = ge::GetSizeByDataType(tswType);
+    int tsSize = ge::GetSizeByDataType(tsType);
+    tilingData.set_tswType(tswType);
+    tilingData.set_tsType(tsType);
+    // First-pass stride, before the Clamp scratch buffer is accounted for
+    ub -= numBuckets * numLayer * tswSize + numLayer * DATA_ALIGN_BYTES; // UB reserved for the weights table
+    uint32_t alignSeqLen = (s * tswSize + DATA_ALIGN_BYTES - 1) / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES / tswSize;
+    uint32_t stride = ub / (sizeof(float) + tsSize) / alignSeqLen;
+
+    // Scratch space the Clamp API needs for one stride worth of elements
+    std::vector shape_vec = {stride * alignSeqLen};
+    ge::Shape shape(shape_vec);
+    uint32_t maxBuff = 0;
+    uint32_t minBuff = 0;
+    AscendC::GetClampMaxMinTmpSize(shape, sizeof(float), false, maxBuff, minBuff);
+    tilingData.set_buffSize(maxBuff);
+
+    stride = (ub - maxBuff) / (sizeof(float) + tsSize) / alignSeqLen; // final stride once the scratch is reserved
+    OPS_CHECK(stride == 0, OPS_LOG_E("Tiling Debug", "UB too small for one sequence."), return ge::GRAPH_FAILED);
+    tilingData.set_stride(stride);
+    return ge::GRAPH_SUCCESS;
+}
+
+static ge::graphStatus TilingFunc(gert::TilingContext* context)
+{
+    OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED);
+    OPS_LOG_E_IF_NULL("timestampShape", context->GetInputShape(TIMESTAMPS_INDEX), return ge::GRAPH_FAILED);
+    OPS_LOG_E_IF_NULL("tswShape", context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX), return ge::GRAPH_FAILED);
+    OPS_LOG_E_IF_NULL("attrs", context->GetAttrs(), return ge::GRAPH_FAILED);
+
+    auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
+    size_t coreNum = ascendPlatform.GetCoreNumAiv();
+    OPS_CHECK(coreNum == 0,
+              OPS_LOG_E("Tiling Debug", "Core num is 0."),
+              return ge::GRAPH_FAILED);
+
+    TilingData tilingData; // the class declared by BEGIN_TILING_DATA_DEF(TilingData) in relative_attn_bias_tiling.h
+    auto ret = TimeTilingFunc(tilingData, context);
+    if (ret != ge::GRAPH_SUCCESS) {
+        return ret;
+    }
+
+    context->SetBlockDim(coreNum);
+    auto rowTilingData = context->GetRawTilingData();
+    OPS_LOG_E_IF_NULL("GetRawTilingData", rowTilingData, return ge::GRAPH_FAILED);
+    tilingData.SaveToBuffer(rowTilingData->GetData(), rowTilingData->GetCapacity());
+    rowTilingData->SetDataSize(tilingData.GetDataSize());
+
+    return ge::GRAPH_SUCCESS;
+}
+} // namespace optiling
+
+namespace ge {
+static ge::graphStatus InferShape(gert::InferShapeContext* context)
+{
+    const gert::Shape* tsShape = context->GetInputShape(TIMESTAMPS_INDEX);
+    const gert::Shape* tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX);
+    gert::Shape* rabTimeOutShape = context->GetOutputShape(RAB_TIME_INDEX);
+    int bs = tsShape->GetDim(DIM0);
+    int s = tsShape->GetDim(DIM1);
+    int numLayers = tswShape->GetDim(DIM0); // weights are (num_layer, num_buckets), as in TimeTilingFunc
+
+    rabTimeOutShape->SetDimNum(RAB_TIME_OUT_DIM); // output is (num_layer, bs, s, 1, s, 1)
+    rabTimeOutShape->SetDim(DIM0, numLayers);
+    rabTimeOutShape->SetDim(DIM1, bs);
+    rabTimeOutShape->SetDim(DIM2, s);
+    rabTimeOutShape->SetDim(DIM3, DIM_PLACE_HOLDER);
+    rabTimeOutShape->SetDim(DIM4, s);
+    rabTimeOutShape->SetDim(DIM5, DIM_PLACE_HOLDER);
+    return GRAPH_SUCCESS;
+}
+} // namespace ge
+
+namespace ops {
+class RelativeAttnBias : public OpDef {
+public:
+    explicit RelativeAttnBias(const char* name) : OpDef(name)
+    {
+        this->Input("timestamps")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32, ge::DT_INT32})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Input("timestamps_weights")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Output("rab_time")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT, ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+        this->Attr("bucket_divisor").Float();
+
+        this->SetInferShape(ge::InferShape);
+
+        OpAICoreConfig aicore_config;
+        aicore_config.DynamicCompileStaticFlag(true)
+            .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false")
+            .ExtendCfgInfo("coreType.value", "AiCore")
+            .ExtendCfgInfo("prebuildPattern.value", "Opaque");
+
+        this->AICore().SetTiling(optiling::TilingFunc);
+        this->AICore().AddConfig("ascend910", aicore_config);
+        this->AICore().AddConfig("ascend910b", aicore_config);
+        this->AICore().AddConfig("ascend910_93", aicore_config);
+        this->AICore().AddConfig("ascend310p", aicore_config);
+    }
+};
+
+OP_ADD(RelativeAttnBias); // NOTE(review): relative_attn_bias.json and run.sh name this op RelativeAttnBiasTime - confirm
+
+} // namespace ops
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h
new file mode 100644
index 00000000..3bfcca17
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h
@@ -0,0 +1,31 @@
+/**
+ * @file relative_attn_bias_tiling.h
+ *
+ * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ */
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
+#include "register/tilingdata_base.h"
+constexpr int MAX_BATCH_SIZE = 512;
+
+namespace optiling {
+BEGIN_TILING_DATA_DEF(TilingData) // host-to-kernel tiling parameters, filled by TimeTilingFunc in relative_attn_bias.cpp
+TILING_DATA_FIELD_DEF(int64_t, s); // dim 1 of the (b, s) timestamps input
+TILING_DATA_FIELD_DEF(int64_t, bs); // dim 0 of the (b, s) timestamps input
+TILING_DATA_FIELD_DEF(int64_t, stride); // rows per iteration, derived by the host from the available UB
+
+TILING_DATA_FIELD_DEF(float, bucketDivisor); // value of the bucket_divisor attribute
+TILING_DATA_FIELD_DEF(int64_t, numBuckets); // dim 1 of the (num_layer, num_buckets) weights input
+TILING_DATA_FIELD_DEF(int64_t, numLayer); // dim 0 of the (num_layer, num_buckets) weights input
+TILING_DATA_FIELD_DEF(float, clampMax); // exp((numBuckets - 1) * bucketDivisor), computed on the host
+
+TILING_DATA_FIELD_DEF(int, tswType); // data type of timestamps_weights as returned by GetDataType()
+TILING_DATA_FIELD_DEF(int, tsType); // data type of timestamps as returned by GetDataType()
+TILING_DATA_FIELD_DEF(int, buffSize); // max temp size reported by AscendC::GetClampMaxMinTmpSize
+
+END_TILING_DATA_DEF;
+REGISTER_TILING_DATA_CLASS(RelativeAttnBias, TilingData) // NOTE(review): the op JSON says RelativeAttnBiasTime - confirm
+} // namespace optiling
+#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h
new file mode 100644
index 00000000..78d0b3f2
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h
@@ -0,0 +1,39 @@
+/**
+ * @file rab_common.h
+ *
+ * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ */
+
+#ifndef MXREC_ADD_ONS_RAB_COMMON_H
+#define MXREC_ADD_ONS_RAB_COMMON_H
+
+#include "kernel_operator.h"
+constexpr int DATA_ALIGN_BYTES = 32; // default alignment unit (bytes) used by the kernel's Ceil() helper
+constexpr int MAX_BATCH_SIZE = 512;
+constexpr int NUM_BUFFER = 2;
+constexpr int MAX_SEQ_CNT = 128; // capacity of the SequenceParams array in RelativeAttnBiasTime::Compute
+constexpr int GATHER_PROCESS_WINDOW = 4096; // bytes; divided by sizeof(FloatType) to bound one Gather call
+
+constexpr int8_t TYPE_FP32 = 0; // compared with tilingData.tswType; presumably mirrors ge::DataType - confirm
+constexpr int8_t TYPE_FP16 = 1;
+constexpr int8_t TYPE_INT32 = 3;
+constexpr int8_t TYPE_INT64 = 9;
+
+using namespace AscendC;
+
+struct Args { // kernel-entry arguments; member order matters because relative_attn_bias.cpp brace-initialises it
+    // pos_bias
+    GM_ADDR positionBias; // not read by relative_attn_bias_time.h
+    GM_ADDR identity; // not read by relative_attn_bias_time.h
+    // ts_bias
+    GM_ADDR timestamps;
+    GM_ADDR timestampsWeights;
+    // out
+    GM_ADDR rabPosOut; // not read by relative_attn_bias_time.h
+    GM_ADDR rabTimeOut;
+
+    GM_ADDR workspace;
+    GM_ADDR tiling;
+};
+#endif // MXREC_ADD_ONS_RAB_COMMON_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp
new file mode 100644
index 00000000..b2f3ec35
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp
@@ -0,0 +1,29 @@
+/**
+* @file relative_attn_bias.cpp
+*
+* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+*
+*/
+
+#include "kernel_operator.h"
+#include "rab_common.h"
+#include "relative_attn_bias_time.h"
+
+extern "C" __global__ __aicore__ void relative_attn_bias(GM_ADDR timestamps,
+                                                         GM_ADDR timestampsWeights,
+                                                         GM_ADDR rabTimeOut,
+                                                         GM_ADDR workspace,
+                                                         GM_ADDR tiling)
+{
+    GET_TILING_DATA(tilingData, tiling);
+    Args args{ // positional init must follow Args' member order; the position-bias slots are unused here
+        nullptr, nullptr, timestamps, timestampsWeights, nullptr, rabTimeOut, workspace, tiling
+    };
+    if (tilingData.tswType == TYPE_FP32) { // dispatch on the weights' data type recorded by the host tiling
+        RelativeAttnBiasTime kernel;
+        kernel.Compute(args);
+    } else if (tilingData.tswType == TYPE_FP16) {
+        RelativeAttnBiasTime kernel;
+        kernel.Compute(args);
+    }
+}
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h
new file mode 100644
index 00000000..8f859551
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h
@@ -0,0 +1,255 @@
+/**
+ * @file relative_attn_bias_time.h
+ *
+ * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ */
+
+#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
+#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
+#include
+#include "rab_common.h"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+struct SequenceParams { // per-row bookkeeping filled by FillSeqParams
+    int startIndexGT; // offset in timestampsGT of the sequence this row belongs to
+    int startIndexUb; // offset of this row's slot in the UB tensors
+    int subValue; // the row's own timestamp, subtracted from the whole sequence
+};
+
+template
+class RelativeAttnBiasTime {
+public:
+    __aicore__ inline RelativeAttnBiasTime() {}
+
+    __aicore__ inline void Init(Args args) // read tiling, bind GM buffers, size UB queues, split rows over cores
+    {
+        GET_TILING_DATA(tilingData, args.tiling);
+        s = tilingData.s;
+        bs = tilingData.bs;
+        stride = tilingData.stride > MAX_SEQ_CNT ? MAX_SEQ_CNT : tilingData.stride; // params[] holds MAX_SEQ_CNT rows
+        alignSeqLen = Ceil(s * sizeof(FloatType)) / sizeof(FloatType);
+
+        int totalLen = bs * s;
+        uint32_t seqDatasize = s * sizeof(FloatType);
+        alignLen = seqDatasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES;
+        alignCnt = alignLen / sizeof(FloatType);
+        unalignLen = seqDatasize - alignLen;
+        unalignCnt = unalignLen / sizeof(FloatType);
+
+        div = 1 / tilingData.bucketDivisor;
+        numBuckets = tilingData.numBuckets;
+        alignNumBuckets = Ceil(numBuckets * sizeof(FloatType)) / sizeof(FloatType);
+        numLayer = tilingData.numLayer;
+
+        clampMin = 1; // fixed to 1, following the reference (simulation) code
+        clampMax = tilingData.clampMax;
+
+        timestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.timestamps, bs * s);
+        timestampsWeightsGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeights, numBuckets * numLayer);
+        rabTimeBiasOutGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeOut, numLayer * bs * s * s);
+
+        pipe.InitBuffer(queTimestamps, 1, stride * alignSeqLen * sizeof(int32_t));
+        pipe.InitBuffer(queTimestampsFloat, 1, stride * alignSeqLen * sizeof(float));
+        pipe.InitBuffer(queTimestampsWeights, 1, alignNumBuckets * numLayer * sizeof(FloatType));
+        pipe.InitBuffer(tmpQue, 1, Ceil(tilingData.buffSize));
+
+        int totalTableSizeSplit = totalLen % GetBlockNum(); // the first `remainder` cores take one extra row
+        int baseLen = totalLen / GetBlockNum();
+        if (GetBlockIdx() >= totalTableSizeSplit) {
+            processRowLen = baseLen;
+            startIndex = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen;
+        } else {
+            processRowLen = baseLen + 1;
+            startIndex = GetBlockIdx() * (baseLen + 1);
+        }
+    }
+
+    __aicore__ inline void FillSeqParams(SequenceParams* params, int offset, int cnt) // rows [offset, offset + cnt)
+    {
+        LocalTensor ts = queTimestamps.AllocTensor();
+        DataCopy(ts, timestampsGT[offset], Ceil(cnt)); // NOTE(review): no EnQue/DeQue or barrier before GetValue - confirm
+        for (int i = 0; i < cnt; ++i) {
+            int seqSubValue = ts.GetValue(i);
+            int seqId = (offset + i) / s;
+            int seqOffsetUb = i * alignSeqLen;
+            int seqOffsetGT = seqId * s;
+
+            params[i].startIndexGT = seqOffsetGT;
+            params[i].startIndexUb = seqOffsetUb;
+            params[i].subValue = seqSubValue;
+        }
+        queTimestamps.FreeTensor(ts);
+    }
+
+    __aicore__ inline void DataCopyIn(SequenceParams* params, int cnt) // load each row's full sequence into UB
+    {
+        LocalTensor ts = queTimestamps.AllocTensor();
+        for (int i = 0; i < cnt; ++i) {
+            SequenceParams param = params[i];
+            int startIndexGT = param.startIndexGT;
+            int startIndexUb = param.startIndexUb;
+
+            DataCopy(ts[startIndexUb], timestampsGT[startIndexGT], alignSeqLen);
+        }
+        queTimestamps.EnQue(ts);
+    }
+
+    __aicore__ inline void ComputeBucketTimestamps(SequenceParams* params, int rowCnt) // timestamps -> byte offsets
+    {
+        LocalTensor tsInt = queTimestamps.DeQue();
+        LocalTensor tsTmp = tsInt.template ReinterpretCast();
+        LocalTensor ts = queTimestampsFloat.AllocTensor();
+        LocalTensor buff = tmpQue.AllocTensor();
+
+        for (int i = 0; i < rowCnt; ++i) {
+            SequenceParams param = params[i];
+            int startIndexUb = param.startIndexUb;
+            int value = param.subValue;
+            Adds(tsInt[startIndexUb], tsInt[startIndexUb], (int32_t)-value, s); // difference to the row's own timestamp
+        }
+
+        uint32_t cnt = rowCnt * alignSeqLen;
+        Cast(ts, tsInt, RoundMode::CAST_NONE, cnt);
+
+        Abs(ts, ts, cnt);
+        ClampMin(tsTmp, ts, buff, clampMin, cnt);
+        Log(ts, tsTmp, cnt);
+        Muls(ts, ts, div, cnt);
+        ClampMax(tsTmp, ts, buff, (float)(numBuckets - 1), cnt); // last valid index of a numBuckets-entry row
+
+        Cast(tsInt, tsTmp, RoundMode::CAST_TRUNC, cnt);
+        Muls(tsInt, tsInt, (int32_t)sizeof(FloatType), cnt); // Gather offsets are expressed in bytes
+
+        tmpQue.FreeTensor(buff);
+        queTimestampsFloat.FreeTensor(ts);
+        queTimestamps.EnQue(tsInt);
+    }
+
+    __aicore__ inline void IndexSelect(LocalTensor& tsw, LocalTensor& tsInt, int layer, int rowCnt)
+    {
+        uint32_t cnt = rowCnt * alignSeqLen;
+        LocalTensor rabTime = queTimestampsFloat.AllocTensor();
+        uint32_t processLenMax = GATHER_PROCESS_WINDOW / sizeof(FloatType);
+        uint32_t tmpOffset = 0;
+        while (tmpOffset < cnt) { // gather in windows of at most processLenMax elements
+            uint32_t processLen = (cnt - tmpOffset) > processLenMax ? processLenMax : (cnt - tmpOffset);
+            Gather(rabTime[tmpOffset], tsw[layer * alignNumBuckets], tsInt[tmpOffset], (uint32_t)0, processLen);
+            tmpOffset += processLen;
+        }
+        queTimestampsFloat.EnQue(rabTime);
+    }
+
+    __aicore__ inline void DataCopyOut(uint32_t ptr, int rowCnt) // write rowCnt rows of s elements starting at ptr
+    {
+        LocalTensor rabTime = queTimestampsFloat.DeQue();
+
+        for (int i = 0; i < rowCnt; ++i) {
+            uint32_t ptrUb = i * alignSeqLen;
+
+            // Copy out the 32-byte-aligned part of the row
+            if (alignLen > 0) {
+                DataCopy(rabTimeBiasOutGT[ptr + i * s], rabTime[ptrUb], alignCnt);
+            }
+            // Copy out the unaligned tail, if any
+            if (unalignLen == 0) {
+                continue;
+            }
+#ifdef SUPPORT_V200
+            uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(FloatType))) - (1ul << unalignCnt);
+            uint64_t mask[2] = {mask0, 0};
+            Duplicate(rabTime[ptrUb + alignCnt], (FloatType)0, mask, 1, 1, 1); // zero the lanes past the tail
+            queTimestampsFloat.EnQue(rabTime);
+            rabTime = queTimestampsFloat.DeQue();
+            SetAtomicAdd();
+            DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt],
+                     Ceil(unalignLen) / sizeof(FloatType));
+            SetAtomicNone();
+#else
+            const DataCopyExtParams dataCopyExtParams{1, unalignLen, 0, 0, 0};
+            DataCopyPad(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], dataCopyExtParams);
+#endif
+        }
+        queTimestampsFloat.FreeTensor(rabTime);
+    }
+
+    __aicore__ inline void DataCopyInTsw() // load every layer's weights row into an aligned UB slot
+    {
+        LocalTensor tsw = queTimestampsWeights.AllocTensor();
+        for (int n = 0; n < numLayer; ++n) {
+            DataCopy(tsw[n * alignNumBuckets], timestampsWeightsGT[n * numBuckets], alignNumBuckets);
+        }
+        queTimestampsWeights.EnQue(tsw);
+    }
+
+    __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) // round a up to a multiple of b
+    {
+        if (b == 0) {
+            return 0;
+        }
+        return (a + b - 1) / b * b;
+    }
+
+    __aicore__ inline void Compute(Args args) // entry point: process this core's rows, stride rows at a time
+    {
+        Init(args);
+        DataCopyInTsw();
+        LocalTensor tsw = queTimestampsWeights.DeQue();
+
+        for (int offset = 0; offset < processRowLen; offset += stride) {
+            int rowOffset = offset + startIndex;
+            int rowCnt = stride > (processRowLen - offset) ? (processRowLen - offset) : stride;
+
+            SequenceParams params[MAX_SEQ_CNT];
+            FillSeqParams(params, rowOffset, rowCnt);
+            DataCopyIn(params, rowCnt);
+            ComputeBucketTimestamps(params, rowCnt);
+
+            LocalTensor tsInt = queTimestamps.DeQue();
+            for (int n = 0; n < numLayer; ++n) { // the same bucket offsets are reused for every layer
+                IndexSelect(tsw, tsInt, n, rowCnt);
+                pipe_barrier(PIPE_ALL);
+
+                uint32_t ptr = (n * bs * s + rowOffset) * s;
+                DataCopyOut(ptr, rowCnt);
+            }
+            queTimestamps.FreeTensor(tsInt);
+        }
+        queTimestampsWeights.FreeTensor(tsw);
+    }
+
+private:
+    // shape
+    uint32_t s;
+    uint32_t alignSeqLen;
+    uint32_t bs;
+    uint32_t stride;
+    // align
+    uint32_t alignLen;
+    uint32_t alignCnt;
+    uint32_t unalignLen;
+    uint32_t unalignCnt;
+    // tiling
+    uint32_t startIndex;
+    uint32_t processRowLen;
+
+    float div;
+    int32_t numBuckets;
+    int32_t alignNumBuckets;
+    int32_t numLayer;
+    float clampMin;
+    float clampMax;
+
+private:
+    GlobalTensor timestampsGT;
+    GlobalTensor timestampsWeightsGT;
+    GlobalTensor rabTimeBiasOutGT;
+
+    TPipe pipe;
+    TQue queTimestamps;
+    TQue queTimestampsFloat;
+    TQue queTimestampsWeights;
+    TQue tmpQue;
+};
+#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json
new file mode 100644
index 00000000..bfb8a8c0
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json
@@ -0,0 +1,53 @@
+[
+    {
+        "op": "RelativeAttnBiasTime",
+        "language": "cpp",
+        "input_desc": [
+            {
+                "name": "timestamps",
+                "param_type": "required",
+                "format": 
[
+                    "ND",
+                    "ND"
+                ],
+                "type": [
+                    "int32",
+                    "int32"
+                ]
+            },
+            {
+                "name": "timestamps_weights",
+                "param_type": "required",
+                "format": [
+                    "ND",
+                    "ND"
+                ],
+                "type": [
+                    "fp32",
+                    "fp16"
+                ]
+            }
+        ],
+        "output_desc": [
+            {
+                "name": "rab_time",
+                "param_type": "required",
+                "format": [
+                    "ND",
+                    "ND"
+                ],
+                "type": [
+                    "fp32",
+                    "fp16"
+                ]
+            }
+        ],
+        "attr": [
+            {
+                "name": "bucket_divisor",
+                "param_type": "required",
+                "type": "float"
+            }
+        ]
+    }
+]
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh
new file mode 100644
index 00000000..c0041861
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+# ==============================================================================
+
+set -e
+
+# Locate msopgen and put its directory on PATH
+msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin)
+parent_dir=$(dirname "$msopgen_path")
+export PATH=$parent_dir:$PATH
+
+ai_core="ai_core-Ascend910B1"
+if [ "$#" -eq 1 ]; then
+    ai_core=$1
+fi
+
+# Generate the buildable operator project with msopgen (the op description in this directory is relative_attn_bias.json)
+rm -rf ./relative_attn_bias_time
+python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_time -m 0 -op RelativeAttnBiasTime
+rm -rf relative_attn_bias_time/op_kernel/*.h
+rm -rf relative_attn_bias_time/op_kernel/*.cpp
+rm -rf relative_attn_bias_time/op_host/*.h
+rm -rf relative_attn_bias_time/op_host/*.cpp
+cp -rf op_kernel relative_attn_bias_time/
+cp -rf op_host relative_attn_bias_time/
+
+cd relative_attn_bias_time
+
+# Make sure CMakePresets.json exists in the generated project
+if [ ! -f "CMakePresets.json" ]; then
+    echo "ERROR, CMakePresets.json file not exist."
+    exit 1
+fi
+
+# Disable CRC checksum generation
+sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake
+
+# Point the presets at the CANN install path
+sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json
+# Change vendor_name so operators previously installed under vendor_name "customize" are not overwritten;
+# vendor_name must stay in sync with CUST_PKG_PATH in the aclnn CMakeLists.txt, otherwise aclnn calls fail;
+# vendor_name must not contain "customize": it breaks config.ini parsing under CANN vendors/ with several operators
+sed -i 's:"customize":"relative_attn_bias_time":g' CMakePresets.json
+
+if [ "$ai_core" = "ai_core-Ascend310P3" ]; then
+    # SUPPORT_V200 is only tested in relative_attn_bias_time.h; the *_kernel.h / *_pos.h headers of the
+    # former fused operator are not part of this op_kernel directory, and sed on them aborts under `set -e`.
+    sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_time.h
+fi
+
+line=`awk '/ENABLE_SOURCE_PACKAGE/{print NR}' CMakePresets.json`
+line=`expr ${line} + 2`
+sed -i "${line}s/True/False/g" CMakePresets.json
+
+# Add the LOG_CPP compile definition so error logs are printed
+sed -i "1 i include(../../../cmake/func.cmake)" ./op_host/CMakeLists.txt
+
+line1=`awk '/tartet_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)/{print NR}' ./op_host/CMakeLists.txt`
+sed -i "${line1}s/OP_TILING_LIB/OP_TILING_LIB LOG_CPP/g" ./op_host/CMakeLists.txt
+
+line2=`awk '/tartet_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)/{print NR}' ./op_host/CMakeLists.txt`
+sed -i "${line2}s/OP_PROTO_LIB/OP_PROTO_LIB LOG_CPP/g" ./op_host/CMakeLists.txt
+
+bash build.sh
+
+# Install the operator package that was just built
+bash ./build_out/custom_opp*.run
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py
index 83bb6f01..a2edc02f 100644
--- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py
@@ -161,37 +161,23 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, 
identity: torch.Tensor, past_vali @torch.no_grad() -def rab(num_layers, train_len, candidate_len, bs, dtype): +def rab_time(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) - layer_num = random.randint(0, num_layers - 1) - pos_w = create_pos_w(train_len, num_layers).to(dtype) past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) timestamps_weights = create_timestamps_weights(num_layers).to(dtype) - rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, - train_len=train_len, - candidate_len=candidate_len, - num_layers=num_layers) - rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) - - rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) - identity_list = identity_list.to(DEVICE) + timestamps = timestamps.to(DEVICE) timestamps_weights = timestamps_weights.to(DEVICE) - past_valid_lens = past_valid_lens.to(DEVICE) - torch_npu.npu.synchronize() - - rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], - identity=identity_list[layer_num, ...], - timestamps=timestamps, - timestamps_weights=timestamps_weights, - past_valid_lens=past_valid_lens) torch_npu.npu.synchronize() + rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, + timestamps=timestamps) rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), timestamps=timestamps) torch_npu.npu.synchronize() + assert torch.allclose(rab_time_out_golden, rab_time_out) @@ -201,7 +187,7 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) + rab_time(num_layers, train_len, candidate_len, bs, dtype) 
@pytest.mark.parametrize("num_layers", [1, 8]) @@ -209,4 +195,4 @@ def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("candidate_len", [0]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_train(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) + rab_time(num_layers, train_len, candidate_len, bs, dtype) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py index 8ce7fce4..db8e2af5 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py @@ -187,43 +187,23 @@ def rab_pos(num_layers, train_len, candidate_len, bs, dtype): @torch.no_grad() -def rab(num_layers, train_len, candidate_len, bs, dtype): +def rab_time(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) - layer_num = random.randint(0, num_layers - 1) - pos_w = create_pos_w(train_len, num_layers).to(dtype) past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) timestamps_weights = create_timestamps_weights(num_layers).to(dtype) - rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, - train_len=train_len, - candidate_len=candidate_len, - num_layers=num_layers) - rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) - rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) - identity_list = identity_list.to(DEVICE) timestamps = timestamps.to(DEVICE) timestamps_weights = timestamps_weights.to(DEVICE) - past_valid_lens = past_valid_lens.to(DEVICE) - torch_npu.npu.synchronize() - - 
rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], - identity=identity_list[layer_num, ...], - timestamps=timestamps, - timestamps_weights=timestamps_weights, - past_valid_lens=past_valid_lens) - rab_pos_out, rab_time_out = rab_pos_out.to("cpu"), rab_time_out.to("cpu") torch_npu.npu.synchronize() - rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...].to("cpu"), - identity=identity_list[layer_num, ...].to("cpu"), - past_valid_lens=past_valid_lens.to("cpu")) + rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, + timestamps=timestamps).to("cpu") rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1).to("cpu"), timestamps=timestamps.to("cpu")) torch_npu.npu.synchronize() - assert torch.allclose(rab_pos_out_golden, rab_pos_out) assert torch.allclose(rab_time_out_golden, rab_time_out) @@ -232,8 +212,8 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("candidate_len", [600]) @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) +def test_rab_time_eval(num_layers, train_len, candidate_len, bs, dtype): + rab_time(num_layers, train_len, candidate_len, bs, dtype) @pytest.mark.parametrize("num_layers", [8]) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index e99d93d2..57e7bf64 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -18,29 +18,6 @@ using torch::autograd::Function; using namespace at; 
using namespace std; -std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, const Tensor& identity, - const Tensor& timestamps, const Tensor& timestampsWeights, - const at::IntArrayRef pastValidLens, const double bucketDivisor) -{ - auto relPosBiasConti = relPosBias.contiguous(); - auto identityConti = identity.contiguous(); - auto timestampsConti = timestamps.contiguous(); - auto timestampsWeightsConti = timestampsWeights.contiguous(); - - const int bs = pastValidLens.size(); - const int sx2 = relPosBias.size(0); // relPosBias(2s, 2s) - const int s = sx2 / 2; - const int numLayers = timestampsWeights.size(0); - - at::Tensor rabPosOut = at::zeros({bs, sx2, sx2}, relPosBiasConti.options()); - at::Tensor rabTimeOut = at::zeros({numLayers, bs, s, 1, s, 1}, timestampsWeightsConti.options()); - - EXEC_NPU_CMD(aclnnRelativeAttnBias, relPosBiasConti, identityConti, timestampsConti, timestampsWeightsConti, - pastValidLens, bucketDivisor, rabPosOut, rabTimeOut); - rabTimeOut = rabTimeOut.repeat({1, 1, 1, 2, 1, 2}).reshape({numLayers, bs, sx2, sx2}); - return {rabPosOut, rabTimeOut}; -} - Tensor relative_attn_bias_pos_impl_npu(const Tensor& relPosBias, const Tensor& identity, const at::IntArrayRef pastValidLens) { @@ -56,6 +33,22 @@ Tensor relative_attn_bias_pos_impl_npu(const Tensor& relPosBias, const Tensor& i return rabPosOut; } +Tensor relative_attn_bias_time_impl_npu(const Tensor& timestamps, const Tensor& timestampsWeights, + const double bucketDivisor) +{ + auto timestampsConti = timestamps.contiguous(); + auto timestampsWeightsConti = timestampsWeights.contiguous(); + const int numLayers = timestampsWeights.size(0); + const int bs = timestampsConti.size(0); + const int s = timestampsConti.size(1); + const int sx2 = s * 2; + + at::Tensor rabTimeOut = at::zeros({numLayers, bs, s, 1, s, 1}, timestampsWeightsConti.options()); + EXEC_NPU_CMD(aclnnRelativeAttnBiasTime, timestampsConti, timestampsWeightsConti, bucketDivisor, rabTimeOut); + rabTimeOut = 
rabTimeOut.repeat({1, 1, 1, 2, 1, 2}).reshape({numLayers, bs, sx2, sx2}); + return rabTimeOut; +} + Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Tensor& bucketTimestamps, const int64_t numBuckets) { @@ -66,9 +59,8 @@ Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Ten auto rabTimeGradConti = rabTimeGrad.contiguous(); auto bucketTimestampsConti = bucketTimestamps.contiguous(); // (n, b, s, s) - bucketTimestampsConti = bucketTimestampsConti.view({batchsize, s, 1, s, 1}) - .repeat({1, 1, 2, 1, 2}) - .reshape({batchsize, sx2, sx2}); + bucketTimestampsConti = + bucketTimestampsConti.view({batchsize, s, 1, s, 1}).repeat({1, 1, 2, 1, 2}).reshape({batchsize, sx2, sx2}); at::Tensor rabTimeGradOut = at::zeros({numLayers, numBuckets}, rabTimeGrad.options()); EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, numBuckets, rabTimeGradOut); @@ -77,13 +69,10 @@ Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Ten TORCH_LIBRARY_FRAGMENT(mxrec, m) { - m.def("relative_attn_bias(Tensor rel_pos_bias, " - " Tensor identity, " - " Tensor timestamps, " - " Tensor timestamps_weights, " - " int[] past_valid_lens," - " float bucket_divisor" - " ) -> (Tensor, Tensor)"); + m.def("relative_attn_bias_time(Tensor timestamps, " + " Tensor timestamps_weights, " + " float bucket_divisor" + " ) -> Tensor"); m.def("relative_attn_bias_pos(Tensor rel_pos_bias, " " Tensor identity, " " int[] past_valid_lens" @@ -96,14 +85,14 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) { - m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); + m.impl("relative_attn_bias_time", &relative_attn_bias_time_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m) { - m.impl("relative_attn_bias", 
&relative_attn_bias_impl_npu); m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); + m.impl("relative_attn_bias_time", &relative_attn_bias_time_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } -- Gitee From 081802b49536ed67b7484a3a03b9fc1a3ec2881f Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 3 Jun 2025 10:13:28 +0800 Subject: [PATCH 20/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86time&pos=E5=90=8E=E7=BC=96=E8=AF=91debug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...tn_bias.cpp => relative_attn_bias_pos.cpp} | 10 ++++---- ...ling.h => relative_attn_bias_pos_tiling.h} | 10 ++++---- .../op_kernel/rab_common.h | 2 -- ...tn_bias.cpp => relative_attn_bias_pos.cpp} | 12 +++++----- .../relative_attn_bias_pos.json | 2 +- ...n_bias.cpp => relative_attn_bias_time.cpp} | 23 ++++++++++--------- ...ing.h => relative_attn_bias_time_tiling.h} | 12 +++++----- .../op_kernel/rab_common.h | 6 ----- ...n_bias.cpp => relative_attn_bias_time.cpp} | 12 +++++----- .../op_kernel/relative_attn_bias_time.h | 2 +- ...bias.json => relative_attn_bias_time.json} | 0 .../test_relative_attn_bias.py | 10 ++++---- .../test_relative_attn_bias_v200.py | 3 ++- ...split_embedding_codegen_lookup_function.py | 2 +- 14 files changed, 51 insertions(+), 55 deletions(-) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/{relative_attn_bias.cpp => relative_attn_bias_pos.cpp} (97%) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/{relative_attn_bias_tiling.h => relative_attn_bias_pos_tiling.h} (64%) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/{relative_attn_bias.cpp => relative_attn_bias_pos.cpp} (55%) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/{relative_attn_bias.cpp => relative_attn_bias_time.cpp} (91%) rename 
mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/{relative_attn_bias_tiling.h => relative_attn_bias_time_tiling.h} (64%) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/{relative_attn_bias.cpp => relative_attn_bias_time.cpp} (55%) rename mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/{relative_attn_bias.json => relative_attn_bias_time.json} (100%) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp similarity index 97% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp index 4d71e14f..5f7a7173 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp @@ -1,12 +1,12 @@ /** -* @file relative_attn_bias.cpp +* @file relative_attn_bias_pos.cpp * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
* */ #include -#include "relative_attn_bias_tiling.h" +#include "relative_attn_bias_pos_tiling.h" #include "register/op_def_registry.h" #include "tiling/tiling_api.h" #include "tiling/platform/platform_ascendc.h" @@ -144,9 +144,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext* context) } // namespace ge namespace ops { -class RelativeAttnBias : public OpDef { +class RelativeAttnBiasPos : public OpDef { public: - explicit RelativeAttnBias(const char* name) : OpDef(name) + explicit RelativeAttnBiasPos(const char* name) : OpDef(name) { this->Input("rel_pos_bias") .ParamType(REQUIRED) @@ -181,6 +181,6 @@ public: } }; -OP_ADD(RelativeAttnBias); +OP_ADD(RelativeAttnBiasPos); } // namespace ops \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h similarity index 64% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h index 17d05ef5..d66dafa2 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h @@ -1,12 +1,12 @@ /** - * @file relative_attn_bias_tiling.h + * @file relative_attn_bias_pos_tiling.h * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
* */ -#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H -#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H #include "register/tilingdata_base.h" constexpr int MAX_BATCH_SIZE = 512; @@ -20,6 +20,6 @@ TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens); TILING_DATA_FIELD_DEF(int, dataType); END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(RelativeAttnBias, TilingData) +REGISTER_TILING_DATA_CLASS(RelativeAttnBiasPos, TilingData) } // namespace optiling -#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h index 62b5ddcb..fb7148d4 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h @@ -12,8 +12,6 @@ constexpr int DATA_ALIGN_BYTES = 32; constexpr int MAX_BATCH_SIZE = 512; constexpr int NUM_BUFFER = 2; -constexpr int MAX_SEQ_CNT = 128; -constexpr int GATHER_PROCESS_WINDOW = 4096; constexpr int8_t TYPE_FP32 = 0; constexpr int8_t TYPE_FP16 = 1; diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp similarity index 55% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp index 249f4c2b..06798639 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias.cpp +++ 
b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp @@ -1,5 +1,5 @@ /** -* @file relative_attn_bias.cpp +* @file relative_attn_bias_pos.cpp * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. * @@ -9,11 +9,11 @@ #include "relative_attn_bias_pos.h" #include "kernel_operator.h" -extern "C" __global__ __aicore__ void relative_attn_bias(GM_ADDR positionBias, - GM_ADDR identity, - GM_ADDR rabPosOut, - GM_ADDR workspace, - GM_ADDR tiling) +extern "C" __global__ __aicore__ void relative_attn_bias_pos(GM_ADDR positionBias, + GM_ADDR identity, + GM_ADDR rabPosOut, + GM_ADDR workspace, + GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); Args args{ diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json index d5ed4f18..c986b36e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json @@ -1,6 +1,6 @@ [ { - "op": "RelativeAttnBias", + "op": "RelativeAttnBiasPos", "language": "cpp", "input_desc": [ { diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp similarity index 91% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp index 13046a0a..f9779cb3 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp @@ -1,12 +1,12 @@ /** -* @file relative_attn_bias.cpp +* @file 
relative_attn_bias_time.cpp * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. * */ #include -#include "relative_attn_bias_tiling.h" +#include "relative_attn_bias_time_tiling.h" #include "register/op_def_registry.h" #include "tiling/tiling_api.h" #include "tiling/platform/platform_ascendc.h" @@ -37,14 +37,14 @@ constexpr int DIM5 = 5; constexpr int MAX_S = 4300; namespace optiling { -static ge::graphStatus TimeTilingFunc(TilingData& tilingData, gert::TilingContext* context) +static ge::graphStatus TimeTilingFunc(RelativeAttnBiasTimeTilingData& tilingData, gert::TilingContext* context) { auto tsShape = context->GetInputShape(TIMESTAMPS_INDEX)->GetStorageShape(); // (b, s) auto tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX)->GetStorageShape(); // (num_layer, num_buckets) int batchsize = tsShape.GetDim(DIM0); // (b, s) int s = tsShape.GetDim(DIM1); // (b, s) - int numLayers = tswShape.GetDim(DIM0); // (num_layer, num_buckets) + int numLayer = tswShape.GetDim(DIM0); // (num_layer, num_buckets) int numBuckets = tswShape.GetDim(DIM1); // (num_layer, num_buckets) float divs = *context->GetAttrs()->GetFloat(BUCKET_DIV_INDEX); float clampMax = exp((numBuckets - 1) * divs); @@ -58,11 +58,11 @@ static ge::graphStatus TimeTilingFunc(TilingData& tilingData, gert::TilingContex OPS_CHECK(s > MAX_S, OPS_LOG_E("Tiling Debug", "Len of timestamps sequence larger than limit."), return ge::GRAPH_FAILED); - OPS_CHECK(bs <= 0, + OPS_CHECK(batchsize <= 0, OPS_LOG_E("Tiling Debug", "Invalid batchsize of timestamps."), return ge::GRAPH_FAILED); - tilingData.set_bs(bs); + tilingData.set_bs(batchsize); tilingData.set_s(s); tilingData.set_numLayer(numLayer); tilingData.set_numBuckets(numBuckets); @@ -72,6 +72,7 @@ static ge::graphStatus TimeTilingFunc(TilingData& tilingData, gert::TilingContex // 计算stride、buff // 获取ub uint64_t ub; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); 
ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); ub = ub - RESERVER_UB_SIZE; // 获取数据类型 @@ -84,7 +85,7 @@ static ge::graphStatus TimeTilingFunc(TilingData& tilingData, gert::TilingContex // 计算不含buff的stride长度 ub -= numBuckets * numLayer * tswSize + numLayer * DATA_ALIGN_BYTES; // 减去tsw预留ub uint32_t alignSeqLen = (s * tswSize + DATA_ALIGN_BYTES - 1) / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES / tswSize; - stride = ub / (sizeof(float) + tsSize) / alignSeqLen; + uint32_t stride = ub / (sizeof(float) + tsSize) / alignSeqLen; // 计算clamp buff所需空间 std::vector shape_vec = {stride * alignSeqLen}; @@ -113,7 +114,7 @@ static ge::graphStatus TilingFunc(gert::TilingContext* context) OPS_LOG_E("Tiling Debug", "Core num is 0."), return ge::GRAPH_FAILED); - RelativeAttnBiasTilingData tilingData; + RelativeAttnBiasTimeTilingData tilingData; auto ret = TimeTilingFunc(tilingData, context); if (ret != ge::GRAPH_SUCCESS) { return ret; @@ -151,9 +152,9 @@ static ge::graphStatus InferShape(gert::InferShapeContext* context) } // namespace ge namespace ops { -class RelativeAttnBias : public OpDef { +class RelativeAttnBiasTime : public OpDef { public: - explicit RelativeAttnBias(const char* name) : OpDef(name) + explicit RelativeAttnBiasTime(const char* name) : OpDef(name) { this->Input("timestamps") .ParamType(REQUIRED) @@ -188,6 +189,6 @@ public: } }; -OP_ADD(RelativeAttnBias); +OP_ADD(RelativeAttnBiasTime); } // namespace ops \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h similarity index 64% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h index 3bfcca17..5579bb1a 100644 --- 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h @@ -1,17 +1,17 @@ /** - * @file relative_attn_bias_tiling.h + * @file relative_attn_bias_time_tiling.h * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. * */ -#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H -#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H #include "register/tilingdata_base.h" constexpr int MAX_BATCH_SIZE = 512; namespace optiling { -BEGIN_TILING_DATA_DEF(TilingData) +BEGIN_TILING_DATA_DEF(RelativeAttnBiasTimeTilingData) TILING_DATA_FIELD_DEF(int64_t, s); TILING_DATA_FIELD_DEF(int64_t, bs); TILING_DATA_FIELD_DEF(int64_t, stride); @@ -26,6 +26,6 @@ TILING_DATA_FIELD_DEF(int, tsType); TILING_DATA_FIELD_DEF(int, buffSize); END_TILING_DATA_DEF; -REGISTER_TILING_DATA_CLASS(RelativeAttnBias, TilingData) +REGISTER_TILING_DATA_CLASS(RelativeAttnBiasTime, RelativeAttnBiasTimeTilingData) } // namespace optiling -#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TILING_H +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h index 78d0b3f2..b574b0e1 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h @@ -10,8 +10,6 @@ #include "kernel_operator.h" constexpr int DATA_ALIGN_BYTES = 32; -constexpr int MAX_BATCH_SIZE = 512; -constexpr int NUM_BUFFER = 2; constexpr int MAX_SEQ_CNT = 128; constexpr int GATHER_PROCESS_WINDOW = 4096; @@ -23,14 +21,10 @@ constexpr int8_t TYPE_INT64 = 9; using namespace AscendC; 
struct Args { - // pos_bias - GM_ADDR positionBias; - GM_ADDR identity; // ts_bias GM_ADDR timestamps; GM_ADDR timestampsWeights; // out - GM_ADDR rabPosOut; GM_ADDR rabTimeOut; GM_ADDR workspace; diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp similarity index 55% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp index b2f3ec35..54583eb0 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp @@ -1,5 +1,5 @@ /** -* @file relative_attn_bias.cpp +* @file relative_attn_bias_time.cpp * * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
* @@ -9,11 +9,11 @@ #include "rab_common.h" #include "relative_attn_bias_time.h" -extern "C" __global__ __aicore__ void relative_attn_bias(GM_ADDR timestamps, - GM_ADDR timestampsWeights, - GM_ADDR rabTimeOut, - GM_ADDR workspace, - GM_ADDR tiling) +extern "C" __global__ __aicore__ void relative_attn_bias_time(GM_ADDR timestamps, + GM_ADDR timestampsWeights, + GM_ADDR rabTimeOut, + GM_ADDR workspace, + GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); Args args{ diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h index 8f859551..4ee26a1f 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h @@ -28,7 +28,7 @@ public: GET_TILING_DATA(tilingData, args.tiling); s = tilingData.s; bs = tilingData.bs; - stride = tilingData.timeStride; + stride = tilingData.stride; alignSeqLen = Ceil(s * sizeof(FloatType)) / sizeof(FloatType); int totalLen = bs * s; diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias_time.json similarity index 100% rename from mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias.json rename to mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias_time.json diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py index a2edc02f..cb1c9929 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py +++ 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py @@ -111,7 +111,7 @@ def rab_npu(rel_pos_bias: torch.Tensor, return rab_pos, rab_time -def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): +def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float): """ num_buckets = 128 num_layers = 1 - 20 @@ -131,7 +131,7 @@ def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) - diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / BUCKET_DIVISOR + diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / bucketization_divisor bucket_timestamps = diff_timestamps.long().view(-1) rab_time = torch.index_select(ts_w, dim=0, index=bucket_timestamps) @@ -173,9 +173,11 @@ def rab_time(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.synchronize() rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, - timestamps=timestamps) + timestamps=timestamps, + bucket_divisor=BUCKET_DIVISOR) rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), - timestamps=timestamps) + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR) torch_npu.npu.synchronize() assert torch.allclose(rab_time_out_golden, rab_time_out) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py index db8e2af5..fcc9296c 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py @@ -199,7 +199,8 @@ def 
rab_time(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.synchronize() rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, - timestamps=timestamps).to("cpu") + timestamps=timestamps, + bucket_divisor=BUCKET_DIVISOR).to("cpu") rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1).to("cpu"), timestamps=timestamps.to("cpu")) torch_npu.npu.synchronize() diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py index c68eaee0..f1b8a2d3 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py @@ -51,7 +51,7 @@ OPTIMIZER_PARAM = { @dataclass class LookupParams: - tables: list[int] + tables: list[list[int]] mutile_hots: list[int] batch_size: int pooling_mode: PoolingMode -- Gitee From 2f27ce4e21f80e2748489b8d9f80148e127501c8 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 3 Jun 2025 10:29:51 +0800 Subject: [PATCH 21/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86pos=20debug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_pos.cpp | 3 +- .../op_host/relative_attn_bias_pos_tiling.h | 2 +- .../op_kernel/relative_attn_bias_pos.cpp | 4 +-- .../op_kernel/relative_attn_bias_pos.h | 2 +- .../test_relative_attn_bias.py | 28 +++++++++++++++++++ 5 files changed, 34 insertions(+), 5 deletions(-) diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp index 5f7a7173..0e0ed4d2 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp @@ -70,13 +70,14 @@ static ge::graphStatus PosTilingFunc(TilingData& tilingData, gert::TilingContext // 获取ub uint64_t ub; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); ub = ub - RESERVER_UB_SIZE; // 获取数据类型 auto identityType = context->GetInputTensor(IDENTITY_INDEX)->GetDataType(); auto biasType = context->GetInputTensor(REL_POS_BIAS_INDEX)->GetDataType(); - int identitySize = ge::GetSizeByDataType(identityType); + int identitySize = ge::GetSizeByDataType(biasType); OPS_CHECK(identityType != biasType, OPS_LOG_E("Tiling Debug", "Mismatch data type of identity and rel_pos_bias."), return ge::GRAPH_FAILED); diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h index d66dafa2..313563a8 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h @@ -17,7 +17,7 @@ TILING_DATA_FIELD_DEF(int64_t, bs); TILING_DATA_FIELD_DEF(int64_t, stride); TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens); -TILING_DATA_FIELD_DEF(int, dataType); +TILING_DATA_FIELD_DEF(int, biasType); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(RelativeAttnBiasPos, TilingData) diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp index 06798639..1c97532e 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp @@ -19,10 +19,10 @@ extern "C" __global__ __aicore__ void relative_attn_bias_pos(GM_ADDR positionBia Args args{ positionBias, identity, rabPosOut, workspace, tiling }; - if (tilingData.dataType == TYPE_FP32) { + if (tilingData.biasType == TYPE_FP32) { RelativeAttnBiasPos kernel; kernel.Compute(args); - } else if (tilingData.floatType == TYPE_FP16) { + } else if (tilingData.biasType == TYPE_FP16) { RelativeAttnBiasPos kernel; kernel.Compute(args); } diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h index bd3fccea..0dcb1c4c 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h @@ -23,7 +23,7 @@ public: GET_TILING_DATA(tilingData, args.tiling); s = SEQ_EXPAND * tilingData.s; bs = tilingData.bs; - stride = tilingData.positionStride; + stride = tilingData.stride; for (auto i = 0; i < bs; ++i) { pastValidLens[i] = tilingData.pastValidLens[i]; } diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py index cb1c9929..7f4c7781 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py +++ 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py @@ -160,6 +160,32 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_vali return rel_pos_bias_list +@torch.no_grad() +def rab_pos(num_layers, train_len, candidate_len, bs, dtype): + torch_npu.npu.set_device(DEVICE) + pos_w = create_pos_w(train_len, num_layers).to(dtype) + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, + train_len=train_len, + candidate_len=candidate_len, + num_layers=num_layers) + rel_pos_bias_list, identity_list = rel_pos_bias_list.to(dtype), identity_list.to(dtype) + + rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) + identity_list = identity_list.to(DEVICE) + past_valid_lens = past_valid_lens.to(DEVICE) + torch_npu.npu.synchronize() + + for rel_pos_bias, identity in zip(rel_pos_bias_list, identity_list): + rab_pos_out = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens.tolist()) + rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens) + assert torch.allclose(rab_pos_out_golden, rab_pos_out) + + @torch.no_grad() def rab_time(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) @@ -190,6 +216,7 @@ def rab_time(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): rab_time(num_layers, train_len, candidate_len, bs, dtype) + rab_pos(num_layers, train_len, candidate_len, bs, dtype) @pytest.mark.parametrize("num_layers", [1, 8]) @@ -198,3 +225,4 @@ def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_train(num_layers, train_len, candidate_len, bs, 
dtype): rab_time(num_layers, train_len, candidate_len, bs, dtype) + rab_pos(num_layers, train_len, candidate_len, bs, dtype) -- Gitee From b00a0daeb913273e9573dafb2df44683947ce944 Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 3 Jun 2025 10:35:45 +0800 Subject: [PATCH 22/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86pos=20debug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../op_host/relative_attn_bias_pos.cpp | 8 ++++---- .../op_host/relative_attn_bias_pos_tiling.h | 2 +- .../op_kernel/relative_attn_bias_pos.cpp | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp index 0e0ed4d2..a520ffcf 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp @@ -77,17 +77,17 @@ static ge::graphStatus PosTilingFunc(TilingData& tilingData, gert::TilingContext auto identityType = context->GetInputTensor(IDENTITY_INDEX)->GetDataType(); auto biasType = context->GetInputTensor(REL_POS_BIAS_INDEX)->GetDataType(); - int identitySize = ge::GetSizeByDataType(biasType); + int biasDataSize = ge::GetSizeByDataType(biasType); OPS_CHECK(identityType != biasType, OPS_LOG_E("Tiling Debug", "Mismatch data type of identity and rel_pos_bias."), return ge::GRAPH_FAILED); - OPS_CHECK(identitySize == 0, + OPS_CHECK(biasDataSize < 1, OPS_LOG_E("Tiling Debug", "Invalid data type."), return ge::GRAPH_FAILED); - tilingData.set_dataType(identityType); + tilingData.set_dataType(biasType); // 计算一次处理的窗口大小(stride) - int stride = ub / (NUM_BUFFER * 3 * identitySize); + int stride = ub / (NUM_BUFFER * 3 * biasDataSize); tilingData.set_stride(stride); return 
ge::GRAPH_SUCCESS; } diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h index 313563a8..d66dafa2 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h @@ -17,7 +17,7 @@ TILING_DATA_FIELD_DEF(int64_t, bs); TILING_DATA_FIELD_DEF(int64_t, stride); TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens); -TILING_DATA_FIELD_DEF(int, biasType); +TILING_DATA_FIELD_DEF(int, dataType); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(RelativeAttnBiasPos, TilingData) diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp index 1c97532e..211edd74 100644 --- a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp @@ -19,10 +19,10 @@ extern "C" __global__ __aicore__ void relative_attn_bias_pos(GM_ADDR positionBia Args args{ positionBias, identity, rabPosOut, workspace, tiling }; - if (tilingData.biasType == TYPE_FP32) { + if (tilingData.dataType == TYPE_FP32) { RelativeAttnBiasPos kernel; kernel.Compute(args); - } else if (tilingData.biasType == TYPE_FP16) { + } else if (tilingData.dataType == TYPE_FP16) { RelativeAttnBiasPos kernel; kernel.Compute(args); } -- Gitee From 7a5893eeedf3147b8ecfa614698b3889cb78c69b Mon Sep 17 00:00:00 2001 From: zhoucy Date: Tue, 3 Jun 2025 10:40:49 +0800 Subject: [PATCH 23/23] =?UTF-8?q?[fix]relative=5Fattn=5Fbias=E3=80=82?= =?UTF-8?q?=E6=8B=86=E5=88=86pos=20debug?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../relative_attn_bias/test_relative_attn_bias.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py index 7f4c7781..a6144e77 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py @@ -15,7 +15,6 @@ # limitations under the License. # ============================================================================== -import random import sysconfig import pytest @@ -178,11 +177,11 @@ def rab_pos(num_layers, train_len, candidate_len, bs, dtype): for rel_pos_bias, identity in zip(rel_pos_bias_list, identity_list): rab_pos_out = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias, - identity=identity, - past_valid_lens=past_valid_lens.tolist()) + identity=identity, + past_valid_lens=past_valid_lens.tolist()) rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias, - identity=identity, - past_valid_lens=past_valid_lens) + identity=identity, + past_valid_lens=past_valid_lens) assert torch.allclose(rab_pos_out_golden, rab_pos_out) -- Gitee