diff --git a/mxrec_add_ons/build/build.sh b/mxrec_add_ons/build/build.sh index 38720b991443d5b5153ea79c696c8475aa97ddcb..cc8eba35610d340e8609bc23aaf8a2ba4b0c9208 100644 --- a/mxrec_add_ons/build/build.sh +++ b/mxrec_add_ons/build/build.sh @@ -32,7 +32,7 @@ permute2d_sparse_data split_embedding_codegen_forward_unweighted dense_to_jagged " -support_310p_list="gather_for_rank1 hstu_dense_forward_fuxi relative_attn_bias" +support_310p_list="gather_for_rank1 hstu_dense_forward_fuxi relative_attn_bias_pos relative_attn_bias_time" cd "${MxRec_DIR}" diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a520ffcf0571308d72ed5c99ad5909ce7b1d86ba --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos.cpp @@ -0,0 +1,187 @@ +/** +* @file relative_attn_bias_pos.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +* +*/ + +#include +#include "relative_attn_bias_pos_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" +#include "../../../common/ops_log.h" + +constexpr int32_t RESERVER_UB_SIZE = (20 * 1024); +constexpr uint8_t NUM_BUFFER = 2; + +// input index +constexpr int REL_POS_BIAS_INDEX = 0; +constexpr int IDENTITY_INDEX = 1; +// output index +constexpr int RAB_POSITION_INDEX = 0; +// attr index +constexpr int PAST_VALID_LENS_INDEX = 0; +// output dim +constexpr int REL_POS_BIAS_DIM = 2; +constexpr int IDENTITY_DIM = 2; +constexpr int RAB_POS_OUT_DIM = 3; +constexpr int DIM0 = 0; +constexpr int DIM1 = 1; +constexpr int DIM2 = 2; + +namespace optiling { +static ge::graphStatus PosTilingFunc(TilingData& tilingData, gert::TilingContext* context) +{ + // 设置past_valid_len + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + int batchsize = pastValidLensPtr->GetSize(); + OPS_CHECK(batchsize <= 0, + OPS_LOG_E("Tiling Debug", "mismatch batchsize of past_valid_len and timestamps."), + return ge::GRAPH_FAILED); + tilingData.set_bs(batchsize); + + auto *pastValidLensData = const_cast(reinterpret_cast(pastValidLensPtr->GetData())); + uint32_t pastValidLens[MAX_BATCH_SIZE]; + for (auto i = 0; i < batchsize; ++i) { + pastValidLens[i] = pastValidLensData[i]; + } + tilingData.set_pastValidLens(pastValidLens); + // 校验s + auto biasShape = context->GetInputShape(REL_POS_BIAS_INDEX)->GetStorageShape(); // (2s, 2s) + auto identityShape = context->GetInputShape(IDENTITY_INDEX)->GetStorageShape(); // (2s, 2s) + + int biasSeqLen = biasShape.GetDim(DIM0); + int biasSeqLen2 = biasShape.GetDim(DIM1); + int idSeqLen = identityShape.GetDim(DIM0); + int idSeqLen2 = identityShape.GetDim(DIM1); + + OPS_CHECK(biasShape.GetDimNum() != REL_POS_BIAS_DIM, + OPS_LOG_E("Tiling Debug", "Invalid rel_pos_bias shape."), + return ge::GRAPH_FAILED); + OPS_CHECK(identityShape.GetDimNum() != IDENTITY_DIM, + OPS_LOG_E("Tiling Debug", "Invalid identity shape."), + return ge::GRAPH_FAILED); + OPS_CHECK(biasSeqLen != biasSeqLen2 || biasSeqLen != idSeqLen || biasSeqLen != idSeqLen2, + OPS_LOG_E("Tiling Debug", "Mismatch sequence len of rel_pos_bias and identity."), + return ge::GRAPH_FAILED); + tilingData.set_s(biasSeqLen / 2); + + // 获取ub + uint64_t ub; + auto ascendPlatform = 
platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); + ub = ub - RESERVER_UB_SIZE; + // 获取数据类型 + auto identityType = context->GetInputTensor(IDENTITY_INDEX)->GetDataType(); + auto biasType = context->GetInputTensor(REL_POS_BIAS_INDEX)->GetDataType(); + + int biasDataSize = ge::GetSizeByDataType(biasType); + OPS_CHECK(identityType != biasType, + OPS_LOG_E("Tiling Debug", "Mismatch data type of identity and rel_pos_bias."), + return ge::GRAPH_FAILED); + OPS_CHECK(biasDataSize < 1, + OPS_LOG_E("Tiling Debug", "Invalid data type."), + return ge::GRAPH_FAILED); + tilingData.set_dataType(biasType); + + // 计算一次处理的窗口大小(stride) + int stride = ub / (NUM_BUFFER * 3 * biasDataSize); + tilingData.set_stride(stride); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFunc(gert::TilingContext* context) +{ + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("relPosBiasShape", context->GetInputShape(REL_POS_BIAS_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("identityShape", context->GetInputShape(IDENTITY_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("attrs", context->GetAttrs(), return ge::GRAPH_FAILED); + + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + OPS_LOG_E_IF_NULL("past_valid_len", pastValidLensPtr, return ge::GRAPH_FAILED); + + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + OPS_CHECK(coreNum == 0, + OPS_LOG_E("Tiling Debug", "Core num is 0."), + return ge::GRAPH_FAILED); + context->SetBlockDim(coreNum); + + TilingData tilingData; + auto ret = PosTilingFunc(tilingData, context); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + auto rowTilingData = context->GetRawTilingData(); + OPS_LOG_E_IF_NULL("GetRawTilingData", rowTilingData, return ge::GRAPH_FAILED); + tilingData.SaveToBuffer(rowTilingData->GetData(), rowTilingData->GetCapacity()); + rowTilingData->SetDataSize(tilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + gert::Shape* rabPosOutShape = context->GetOutputShape(RAB_POSITION_INDEX); + + const gert::RuntimeAttrs* attrs = context->GetAttrs(); + const auto pastValidLensPtr = attrs->GetAttrPointer(PAST_VALID_LENS_INDEX); + int bs = pastValidLensPtr->GetSize(); + const gert::Shape* identityShape = context->GetInputShape(IDENTITY_INDEX); + int s = identityShape->GetDim(DIM0); // identityShape(2s, 2s) + + rabPosOutShape->SetDimNum(RAB_POS_OUT_DIM); + rabPosOutShape->SetDim(DIM0, bs); + rabPosOutShape->SetDim(DIM1, s); + rabPosOutShape->SetDim(DIM2, s); + + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class RelativeAttnBiasPos : public OpDef { +public: + explicit RelativeAttnBiasPos(const char* name) : OpDef(name) + { + this->Input("rel_pos_bias") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("identity") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("rab_pos") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + 
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("past_valid_lens").ListInt(); + + this->SetInferShape(ge::InferShape); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("prebuildPattern.value", "Opaque"); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend310p", aicore_config); + } +}; + +OP_ADD(RelativeAttnBiasPos); + +} // namespace ops \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..d66dafa2db1ae4252d06797b35930db49f0a2faf --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_host/relative_attn_bias_pos_tiling.h @@ -0,0 +1,25 @@ +/** + * @file relative_attn_bias_pos_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H +#include "register/tilingdata_base.h" +constexpr int MAX_BATCH_SIZE = 512; + +namespace optiling { +BEGIN_TILING_DATA_DEF(TilingData) +TILING_DATA_FIELD_DEF(int64_t, s); +TILING_DATA_FIELD_DEF(int64_t, bs); +TILING_DATA_FIELD_DEF(int64_t, stride); +TILING_DATA_FIELD_DEF_ARR(uint32_t, MAX_BATCH_SIZE, pastValidLens); + +TILING_DATA_FIELD_DEF(int, dataType); + +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(RelativeAttnBiasPos, TilingData) +} // namespace optiling +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h new file mode 100644 index 0000000000000000000000000000000000000000..fb7148d486164dfc9a8abab4d491b705a626a54a --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/rab_common.h @@ -0,0 +1,33 @@ +/** + * @file rab_common.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RAB_COMMON_H +#define MXREC_ADD_ONS_RAB_COMMON_H + +#include "kernel_operator.h" +constexpr int DATA_ALIGN_BYTES = 32; +constexpr int MAX_BATCH_SIZE = 512; +constexpr int NUM_BUFFER = 2; + +constexpr int8_t TYPE_FP32 = 0; +constexpr int8_t TYPE_FP16 = 1; +constexpr int8_t TYPE_INT32 = 3; +constexpr int8_t TYPE_INT64 = 9; + +using namespace AscendC; + +struct Args { + // pos_bias + GM_ADDR positionBias; + GM_ADDR identity; + // out + GM_ADDR rabPosOut; + + GM_ADDR workspace; + GM_ADDR tiling; +}; +#endif // MXREC_ADD_ONS_RAB_COMMON_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp new file mode 100644 index 0000000000000000000000000000000000000000..211edd74457b593f6aa589233ee529ecca32d5eb --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.cpp @@ -0,0 +1,29 @@ +/** +* @file relative_attn_bias_pos.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. +* +*/ + +#include "rab_common.h" +#include "relative_attn_bias_pos.h" +#include "kernel_operator.h" + +extern "C" __global__ __aicore__ void relative_attn_bias_pos(GM_ADDR positionBias, + GM_ADDR identity, + GM_ADDR rabPosOut, + GM_ADDR workspace, + GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + Args args{ + positionBias, identity, rabPosOut, workspace, tiling + }; + if (tilingData.dataType == TYPE_FP32) { + RelativeAttnBiasPos kernel; + kernel.Compute(args); + } else if (tilingData.dataType == TYPE_FP16) { + RelativeAttnBiasPos kernel; + kernel.Compute(args); + } +} diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h new file mode 100644 index 0000000000000000000000000000000000000000..0dcb1c4c9f45663242876d1e180dc07a26776845 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/op_kernel/relative_attn_bias_pos.h @@ -0,0 +1,174 @@ +/** + * @file relative_attn_bias_pos.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+ * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H +#include "rab_common.h" +#include "kernel_operator.h" +using namespace AscendC; + +constexpr int SEQ_EXPAND = 2; // rab_pos中序列长度为原本输入的两倍 + +template +class RelativeAttnBiasPos { +public: + __aicore__ inline RelativeAttnBiasPos() {} + + __aicore__ inline void Init(Args args) + { + GET_TILING_DATA(tilingData, args.tiling); + s = SEQ_EXPAND * tilingData.s; + bs = tilingData.bs; + stride = tilingData.stride; + for (auto i = 0; i < bs; ++i) { + pastValidLens[i] = tilingData.pastValidLens[i]; + } + + posBiasGT.SetGlobalBuffer((__gm__ FloatType*)args.positionBias, s * s); + identityGT.SetGlobalBuffer((__gm__ FloatType*)args.identity, s * s); + rabPosBiasOutGT.SetGlobalBuffer((__gm__ FloatType*)args.rabPosOut, bs * s * s); + + pipe.InitBuffer(queIdentityIn, NUM_BUFFER, Ceil(SEQ_EXPAND * stride * sizeof(FloatType))); + pipe.InitBuffer(quePosIn, NUM_BUFFER, Ceil(stride * sizeof(FloatType))); + + int64_t totalTableSizeSplit = s % GetBlockNum(); + int64_t baseLen = s / GetBlockNum(); + if (GetBlockIdx() >= totalTableSizeSplit) { + totalRow = baseLen; + rowOffset = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen; + } else { + totalRow = baseLen + 1; + rowOffset = GetBlockIdx() * (baseLen + 1); + } + REL_POS_BIAS_FIRST = posBiasGT.GetValue(0); + } + + __aicore__ inline void ComputeIdentity(int offset, int cnt) + { + // DataCopyIn identity + LocalTensor identityUb = queIdentityIn.AllocTensor(); + + DataCopy(identityUb, identityGT[offset], Ceil(cnt * sizeof(FloatType)) / sizeof(FloatType)); + queIdentityIn.EnQue(identityUb); + + // Compute identity * rel_pos_bias[0, 0], (1 - identity) + LocalTensor identityFilledUb = queIdentityIn.DeQue(); + + // 后半段 (1 - identity) + Muls(identityFilledUb[stride], identityFilledUb, (FloatType)-1, cnt); + Adds(identityFilledUb[stride], identityFilledUb[stride], (FloatType)1, cnt); + + // 前半段 identity * rel_pos_bias[0, 0] + Muls(identityFilledUb, identityFilledUb, REL_POS_BIAS_FIRST, cnt); + + queIdentityIn.EnQue(identityFilledUb); + } + + __aicore__ inline void DataCopyIn(int row, int offset, int cnt) + { + LocalTensor posBiasUb = quePosIn.AllocTensor(); + DataCopy(posBiasUb, posBiasGT[row * s + offset], Ceil(cnt * sizeof(FloatType)) / sizeof(FloatType)); + quePosIn.EnQue(posBiasUb); + } + + __aicore__ inline void ComputeRabBias(LocalTensor& identityCalcUb, int cnt) + { + LocalTensor posBiasUb = quePosIn.DeQue(); + Mul(posBiasUb, posBiasUb, identityCalcUb[stride], cnt); + Add(posBiasUb, posBiasUb, identityCalcUb, cnt); + pipe_barrier(PIPE_ALL); + quePosIn.EnQue(posBiasUb); + } + + __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) + { + if (b == 0) { + return 0; + } + return (a + b - 1) / b * b; + } + + __aicore__ inline void DataCopyOut(int offset, int cnt) + { + uint32_t datasize = cnt * sizeof(FloatType); + uint32_t alignLen = datasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES; + uint32_t unAlignLen = datasize - alignLen; + uint32_t alignCnt = alignLen / sizeof(FloatType); + uint32_t unAlignCnt = unAlignLen / sizeof(FloatType); + + LocalTensor posBiasUb = quePosIn.DeQue(); + // 对齐部分拷出 + if (alignLen > 0) { + DataCopy(rabPosBiasOutGT[offset], posBiasUb, cnt); + } + // 非对齐部分拷出 + if (unAlignLen > 0) { +#ifdef SUPPORT_V200 + uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(FloatType))) - (1ul << unAlignCnt); + uint64_t mask[2] = {mask0, 0}; + Duplicate(posBiasUb[alignCnt], (FloatType)0, mask, 1, 1, 1); + 
quePosIn.EnQue(posBiasUb); + posBiasUb = quePosIn.DeQue(); + SetAtomicAdd(); + DataCopy(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], Ceil(unAlignLen) / sizeof(FloatType)); + SetAtomicNone(); +#else + const DataCopyExtParams dataCopyExtParams{1, unAlignLen, 0, 0, 0}; + DataCopyPad(rabPosBiasOutGT[offset + alignCnt], posBiasUb[alignCnt], dataCopyExtParams); +#endif + } + quePosIn.FreeTensor(posBiasUb); + } + + __aicore__ inline void Compute(Args args) + { + Init(args); + for (int row = rowOffset; row < rowOffset + totalRow; ++row) { + int offset = 0; + for (int j = 0; j < (s + stride - 1) / stride; ++j) { + int remain = s - offset; + int cnt = remain > stride ? stride : remain; + ComputeIdentity(offset + row * s, cnt); + LocalTensor identityCalcUb = queIdentityIn.DeQue(); + + for (int b = 0; b < bs; ++b) { + int valid_len = pastValidLens[b]; + int valid_row = row > valid_len ? valid_len : row; + DataCopyIn(valid_row, offset, cnt); + ComputeRabBias(identityCalcUb, cnt); + int padOutPtr = b * s * s + row * s + j * stride; + DataCopyOut(padOutPtr, cnt); + } + queIdentityIn.FreeTensor(identityCalcUb); + offset += cnt; + } + } + } + +private: + // shape + int s; + int bs; + int stride; + // tiling + int rowOffset; // identity、rel_pos_bias(s, s)的行偏移 + int totalRow; // 需要处理的总行数 + +private: + TPipe pipe; + TQue queIdentityIn; + TQue queIdentityCalcIn; + TQue quePosIn; + + GlobalTensor identityGT; + GlobalTensor posBiasGT; + GlobalTensor rabPosBiasOutGT; + uint32_t pastValidLens[MAX_BATCH_SIZE]; + FloatType REL_POS_BIAS_FIRST; // identity[0, 0] +}; + +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_POS_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json new file mode 100644 index 0000000000000000000000000000000000000000..c986b36e593af467ae49b1f3c90f0ac25bbba2c6 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/relative_attn_bias_pos.json @@ -0,0 +1,47 @@ +[ + { + "op": "RelativeAttnBiasPos", + "language": "cpp", + "input_desc": [ + { + "name": "rel_pos_bias", + "param_type": "required", + "format": [ + "ND", "ND" + ], + "type": [ + "fp16", "fp32" + ] + }, + { + "name": "identity", + "param_type": "required", + "format": [ + "ND", "ND" + ], + "type": [ + "fp16", "fp32" + ] + } + ], + "output_desc": [ + { + "name": "rab_pos", + "param_type": "required", + "format": [ + "ND", "ND" + ], + "type": [ + "fp16", "fp32" + ] + } + ], + "attr": [ + { + "name": "past_valid_lens", + "param_type": "required", + "type": "list_int" + } + ] + } +] \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..453261620af39a8c7205b4e0cf9f57a88b9bf15d --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_pos/run.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. 
+# ============================================================================== + +set -e + +# 查找msopgen的路径,加入到环境变量PATH中 +msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin) +parent_dir=$(dirname "$msopgen_path") +export PATH=$parent_dir:$PATH + +ai_core="ai_core-Ascend910B1" +if [ "$#" -eq 1 ]; then + ai_core=$1 +fi + +# 利用msopgen生成可编译文件 +rm -rf ./relative_attn_bias_pos +python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_pos.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_pos -m 0 -op RelativeAttnBiasPos +rm -rf relative_attn_bias_pos/op_kernel/*.h +rm -rf relative_attn_bias_pos/op_kernel/*.cpp +rm -rf relative_attn_bias_pos/host/*.h +rm -rf relative_attn_bias_pos/host/*.cpp +cp -rf op_kernel relative_attn_bias_pos/ +cp -rf op_host relative_attn_bias_pos/ + +cd relative_attn_bias_pos + +# 判断当前目录下是否存在CMakePresets.json文件 +if [ ! -f "CMakePresets.json" ]; then + echo "ERROR, CMakePresets.json file not exist." + exit 1 +fi + +# 禁止生成CRC校验和 +sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake + +# 修改cann安装路径 +sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json +# 修改vendor_name 防止覆盖之前vendor_name为customize的算子; +# vendor_name需要和aclnn中的CMakeLists.txt中的CUST_PKG_PATH值同步,不同步aclnn会调用失败; +# vendor_name字段值不能包含customize;包含会导致多算子部署场景CANN的vendors路径下config.ini文件内容截取错误 +sed -i 's:"customize":"relative_attn_bias_pos":g' CMakePresets.json + +if [ "$ai_core" = "ai_core-Ascend310P3" ]; then + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_kernel.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_time.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_pos_pos.h +fi + +line=`awk '/ENABLE_SOURCE_PACKAGE/{print NR}' CMakePresets.json` +line=`expr ${line} + 2` +sed -i "${line}s/True/False/g" CMakePresets.json + +# 增加LOG_CPP编译选项支持错误日志打印 +sed -i "1 i include(../../../cmake/func.cmake)" ./op_host/CMakeLists.txt + +line1=`awk '/tartet_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line1}s/OP_TILING_LIB/OP_TILING_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +line2=`awk '/tartet_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line2}s/OP_PROTO_LIB/OP_PROTO_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +bash build.sh + +# 安装编译成功的算子包 +bash ./build_out/custom_opp*.run diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f9779cb32cdf1650786a326b55a43d725514cb14 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time.cpp @@ -0,0 +1,194 @@ +/** +* @file relative_attn_bias_time.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+* +*/ + +#include +#include "relative_attn_bias_time_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" +#include "../../../common/ops_log.h" + +constexpr int32_t RESERVER_UB_SIZE = (20 * 1024); +constexpr int32_t DATA_ALIGN_BYTES = 32; + +// input index +constexpr int TIMESTAMPS_INDEX = 0; +constexpr int TIMESTAMPS_WEIGHTS_INDEX = 1; +// output index +constexpr int RAB_TIME_INDEX = 0; +// attr index +constexpr int BUCKET_DIV_INDEX = 0; +// input/output dim +constexpr int TIMESTAMPS_DIM = 2; +constexpr int TIMESTAMPS_WEIGHTS_DIM = 2; +constexpr int RAB_TIME_OUT_DIM = 6; +constexpr int DIM_PLACE_HOLDER = 1; +constexpr int DIM0 = 0; +constexpr int DIM1 = 1; +constexpr int DIM2 = 2; +constexpr int DIM3 = 3; +constexpr int DIM4 = 4; +constexpr int DIM5 = 5; +// constrain of params +constexpr int MAX_S = 4300; + +namespace optiling { +static ge::graphStatus TimeTilingFunc(RelativeAttnBiasTimeTilingData& tilingData, gert::TilingContext* context) +{ + auto tsShape = context->GetInputShape(TIMESTAMPS_INDEX)->GetStorageShape(); // (b, s) + auto tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX)->GetStorageShape(); // (num_layer, num_buckets) + + int batchsize = tsShape.GetDim(DIM0); // (b, s) + int s = tsShape.GetDim(DIM1); // (b, s) + int numLayer = tswShape.GetDim(DIM0); // (num_layer, num_buckets) + int numBuckets = tswShape.GetDim(DIM1); // (num_layer, num_buckets) + float divs = *context->GetAttrs()->GetFloat(BUCKET_DIV_INDEX); + float clampMax = exp((numBuckets - 1) * divs); + + OPS_CHECK(tsShape.GetDimNum() != TIMESTAMPS_DIM, + OPS_LOG_E("Tiling Debug", "Invalid timestamps shape."), + return ge::GRAPH_FAILED); + OPS_CHECK(tswShape.GetDimNum() != TIMESTAMPS_WEIGHTS_DIM, + OPS_LOG_E("Tiling Debug", "Invalid timestamps_weights shape."), + return ge::GRAPH_FAILED); + OPS_CHECK(s > MAX_S, + OPS_LOG_E("Tiling Debug", "Len of timestamps sequence larger than limit."), + return ge::GRAPH_FAILED); + OPS_CHECK(batchsize <= 0, + OPS_LOG_E("Tiling Debug", "Invalid batchsize of timestamps."), + return ge::GRAPH_FAILED); + + tilingData.set_bs(batchsize); + tilingData.set_s(s); + tilingData.set_numLayer(numLayer); + tilingData.set_numBuckets(numBuckets); + tilingData.set_bucketDivisor(divs); + tilingData.set_clampMax(clampMax); + + // 计算stride、buff + // 获取ub + uint64_t ub; + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + ascendPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub); + ub = ub - RESERVER_UB_SIZE; + // 获取数据类型 + auto tswType = context->GetInputTensor(TIMESTAMPS_WEIGHTS_INDEX)->GetDataType(); + auto tsType = context->GetInputTensor(TIMESTAMPS_INDEX)->GetDataType(); + int tswSize = ge::GetSizeByDataType(tswType); + int tsSize = ge::GetSizeByDataType(tsType); + tilingData.set_tswType(tswType); + tilingData.set_tsType(tsType); + // 计算不含buff的stride长度 + ub -= numBuckets * numLayer * tswSize + numLayer * DATA_ALIGN_BYTES; // 减去tsw预留ub + uint32_t alignSeqLen = (s * tswSize + DATA_ALIGN_BYTES - 1) / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES / tswSize; + uint32_t stride = ub / (sizeof(float) + tsSize) / alignSeqLen; + + // 计算clamp buff所需空间 + std::vector shape_vec = {stride * alignSeqLen}; + ge::Shape shape(shape_vec); + uint32_t maxBuff = 0; + uint32_t minBuff = 0; + AscendC::GetClampMaxMinTmpSize(shape, sizeof(float), false, maxBuff, minBuff); + tilingData.set_buffSize(maxBuff); + + // 重新计算stride长度 + stride = (ub - maxBuff) / (sizeof(float) + tsSize) / alignSeqLen; + 
tilingData.set_stride(stride); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus TilingFunc(gert::TilingContext* context) +{ + OPS_LOG_E_IF_NULL("context", context, return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("timestampShape", context->GetInputShape(TIMESTAMPS_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("tswShape", context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX), return ge::GRAPH_FAILED); + OPS_LOG_E_IF_NULL("attrs", context->GetAttrs(), return ge::GRAPH_FAILED); + + auto ascendPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + size_t coreNum = ascendPlatform.GetCoreNumAiv(); + OPS_CHECK(coreNum == 0, + OPS_LOG_E("Tiling Debug", "Core num is 0."), + return ge::GRAPH_FAILED); + + RelativeAttnBiasTimeTilingData tilingData; + auto ret = TimeTilingFunc(tilingData, context); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + context->SetBlockDim(coreNum); + auto rowTilingData = context->GetRawTilingData(); + OPS_LOG_E_IF_NULL("GetRawTilingData", rowTilingData, return ge::GRAPH_FAILED); + tilingData.SaveToBuffer(rowTilingData->GetData(), rowTilingData->GetCapacity()); + rowTilingData->SetDataSize(tilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + const gert::Shape* tsShape = context->GetInputShape(TIMESTAMPS_INDEX); + const gert::Shape* tswShape = context->GetInputShape(TIMESTAMPS_WEIGHTS_INDEX); + gert::Shape* rabTimeOutShape = context->GetOutputShape(RAB_TIME_INDEX); + int bs = tsShape->GetDim(DIM0); + int s = tsShape->GetDim(DIM1); + int numLayers = tswShape->GetDim(DIM1); + + rabTimeOutShape->SetDimNum(RAB_TIME_OUT_DIM); + rabTimeOutShape->SetDim(DIM0, numLayers); + rabTimeOutShape->SetDim(DIM1, bs); + rabTimeOutShape->SetDim(DIM2, s); + rabTimeOutShape->SetDim(DIM3, DIM_PLACE_HOLDER); + rabTimeOutShape->SetDim(DIM4, s); + rabTimeOutShape->SetDim(DIM5, DIM_PLACE_HOLDER); + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class RelativeAttnBiasTime : public OpDef { +public: + explicit RelativeAttnBiasTime(const char* name) : OpDef(name) + { + this->Input("timestamps") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("timestamps_weights") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("rab_time") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("bucket_divisor").Float(); + + this->SetInferShape(ge::InferShape); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .ExtendCfgInfo("jitCompile.flag", "static_false,dynamic_false") + .ExtendCfgInfo("coreType.value", "AiCore") + .ExtendCfgInfo("prebuildPattern.value", "Opaque"); + + this->AICore().SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910", aicore_config); + this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910_93", aicore_config); + this->AICore().AddConfig("ascend310p", aicore_config); + } +}; + +OP_ADD(RelativeAttnBiasTime); + +} // namespace ops \ No newline at end of file diff --git 
a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..5579bb1ae7744d4ecee01f4a0b1be4640c1d195c --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_host/relative_attn_bias_time_tiling.h @@ -0,0 +1,31 @@ +/** + * @file relative_attn_bias_time_tiling.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H +#include "register/tilingdata_base.h" +constexpr int MAX_BATCH_SIZE = 512; + +namespace optiling { +BEGIN_TILING_DATA_DEF(RelativeAttnBiasTimeTilingData) +TILING_DATA_FIELD_DEF(int64_t, s); +TILING_DATA_FIELD_DEF(int64_t, bs); +TILING_DATA_FIELD_DEF(int64_t, stride); + +TILING_DATA_FIELD_DEF(float, bucketDivisor); +TILING_DATA_FIELD_DEF(int64_t, numBuckets); +TILING_DATA_FIELD_DEF(int64_t, numLayer); +TILING_DATA_FIELD_DEF(float, clampMax); + +TILING_DATA_FIELD_DEF(int, tswType); +TILING_DATA_FIELD_DEF(int, tsType); +TILING_DATA_FIELD_DEF(int, buffSize); + +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(RelativeAttnBiasTime, RelativeAttnBiasTimeTilingData) +} // namespace optiling +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_TILING_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h new file mode 100644 index 0000000000000000000000000000000000000000..b574b0e12c854093c3e365e33a84b1114937a2e3 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/rab_common.h @@ -0,0 +1,33 @@ +/** + * @file rab_common.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ + +#ifndef MXREC_ADD_ONS_RAB_COMMON_H +#define MXREC_ADD_ONS_RAB_COMMON_H + +#include "kernel_operator.h" +constexpr int DATA_ALIGN_BYTES = 32; +constexpr int MAX_SEQ_CNT = 128; +constexpr int GATHER_PROCESS_WINDOW = 4096; + +constexpr int8_t TYPE_FP32 = 0; +constexpr int8_t TYPE_FP16 = 1; +constexpr int8_t TYPE_INT32 = 3; +constexpr int8_t TYPE_INT64 = 9; + +using namespace AscendC; + +struct Args { + // ts_bias + GM_ADDR timestamps; + GM_ADDR timestampsWeights; + // out + GM_ADDR rabTimeOut; + + GM_ADDR workspace; + GM_ADDR tiling; +}; +#endif // MXREC_ADD_ONS_RAB_COMMON_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp new file mode 100644 index 0000000000000000000000000000000000000000..54583eb058d1bedd8e5461a9689642fd1395653d --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.cpp @@ -0,0 +1,29 @@ +/** +* @file relative_attn_bias_time.cpp +* +* Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. 
+* +*/ + +#include "kernel_operator.h" +#include "rab_common.h" +#include "relative_attn_bias_time.h" + +extern "C" __global__ __aicore__ void relative_attn_bias_time(GM_ADDR timestamps, + GM_ADDR timestampsWeights, + GM_ADDR rabTimeOut, + GM_ADDR workspace, + GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + Args args{ + timestamps, timestampsWeights, rabTimeOut, workspace, tiling + }; + if (tilingData.tswType == TYPE_FP32) { + RelativeAttnBiasTime kernel; + kernel.Compute(args); + } else if (tilingData.tswType == TYPE_FP16) { + RelativeAttnBiasTime kernel; + kernel.Compute(args); + } +} diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h new file mode 100644 index 0000000000000000000000000000000000000000..4ee26a1fb7f66de36c3b33a8965af444fd0ecf47 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/op_kernel/relative_attn_bias_time.h @@ -0,0 +1,255 @@ +/** + * @file relative_attn_bias_time.h + * + * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved. + * + */ + +#ifndef MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H +#define MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H +#include +#include "rab_common.h" +#include "kernel_operator.h" +using namespace AscendC; + +struct SequenceParams { + int startIndexGT; + int startIndexUb; + int subValue; +}; + +template +class RelativeAttnBiasTime { +public: + __aicore__ inline RelativeAttnBiasTime() {} + + __aicore__ inline void Init(Args args) + { + GET_TILING_DATA(tilingData, args.tiling); + s = tilingData.s; + bs = tilingData.bs; + stride = tilingData.stride; + alignSeqLen = Ceil(s * sizeof(FloatType)) / sizeof(FloatType); + + int totalLen = bs * s; + uint32_t seqDatasize = s * sizeof(FloatType); + alignLen = seqDatasize / DATA_ALIGN_BYTES * DATA_ALIGN_BYTES; + alignCnt = alignLen / sizeof(FloatType); + unalignLen = seqDatasize - alignLen; + unalignCnt = unalignLen / sizeof(FloatType); + + div = 1 / tilingData.bucketDivisor; + numBuckets = tilingData.numBuckets; + alignNumBuckets = Ceil(numBuckets * sizeof(FloatType)) / sizeof(FloatType); + numLayer = tilingData.numLayer; + + clampMin = 1; // 根据仿真代码,指定为1 + clampMax = tilingData.clampMax; + + timestampsGT.SetGlobalBuffer((__gm__ int32_t*)args.timestamps, bs * s); + timestampsWeightsGT.SetGlobalBuffer((__gm__ FloatType*)args.timestampsWeights, numBuckets * numLayer); + rabTimeBiasOutGT.SetGlobalBuffer((__gm__ FloatType*)args.rabTimeOut, numLayer * bs * s * s); + + pipe.InitBuffer(queTimestamps, 1, stride * alignSeqLen * sizeof(int32_t)); + pipe.InitBuffer(queTimestampsFloat, 1, stride * alignSeqLen * sizeof(float)); + pipe.InitBuffer(queTimestampsWeights, 1, alignNumBuckets * numLayer * sizeof(FloatType)); + pipe.InitBuffer(tmpQue, 1, Ceil(tilingData.buffSize)); + + int totalTableSizeSplit = totalLen % GetBlockNum(); + int baseLen = totalLen / GetBlockNum(); + if (GetBlockIdx() >= totalTableSizeSplit) { + processRowLen = baseLen; + startIndex = totalTableSizeSplit * (baseLen + 1) + (GetBlockIdx() - totalTableSizeSplit) * baseLen; + } else { + processRowLen = baseLen + 1; + startIndex = GetBlockIdx() * (baseLen + 1); + } + } + + __aicore__ inline void FillSeqParams(SequenceParams* params, int offset, int cnt) + { + LocalTensor ts = queTimestamps.AllocTensor(); + DataCopy(ts, timestampsGT[offset], Ceil(cnt)); + for (int i = 0; i < cnt; ++i) { + int seqSubValue = ts.GetValue(i); + int seqId = (offset + 
i) / s; + int seqOffsetUb = i * alignSeqLen; + int seqOffsetGT = seqId * s; + + params[i].startIndexGT = seqOffsetGT; + params[i].startIndexUb = seqOffsetUb; + params[i].subValue = seqSubValue; + } + queTimestamps.FreeTensor(ts); + } + + __aicore__ inline void DataCopyIn(SequenceParams* params, int cnt) + { + LocalTensor ts = queTimestamps.AllocTensor(); + for (int i = 0; i < cnt; ++i) { + SequenceParams param = params[i]; + int startIndexGT = param.startIndexGT; + int startIndexUb = param.startIndexUb; + + DataCopy(ts[startIndexUb], timestampsGT[startIndexGT], alignSeqLen); + } + queTimestamps.EnQue(ts); + } + + __aicore__ inline void ComputeBucketTimestamps(SequenceParams* params, int rowCnt) + { + LocalTensor tsInt = queTimestamps.DeQue(); + LocalTensor tsTmp = tsInt.template ReinterpretCast(); + LocalTensor ts = queTimestampsFloat.AllocTensor(); + LocalTensor buff = tmpQue.AllocTensor(); + + for (int i = 0; i < rowCnt; ++i) { + SequenceParams param = params[i]; + int startIndexUb = param.startIndexUb; + int value = param.subValue; + Adds(tsInt[startIndexUb], tsInt[startIndexUb], (int32_t)-value, s); + } + + uint32_t cnt = rowCnt * alignSeqLen; + Cast(ts, tsInt, RoundMode::CAST_NONE, cnt); + + Abs(ts, ts, cnt); + ClampMin(tsTmp, ts, buff, clampMin, cnt); + Log(ts, tsTmp, cnt); + Muls(ts, ts, div, cnt); + ClampMax(tsTmp, ts, buff, (float)numBuckets, cnt); + + Cast(tsInt, tsTmp, RoundMode::CAST_TRUNC, cnt); + Muls(tsInt, tsInt, (int32_t)sizeof(FloatType), cnt); // 计算gather时的偏移量单位为bytes + + tmpQue.FreeTensor(buff); + queTimestampsFloat.FreeTensor(ts); + queTimestamps.EnQue(tsInt); + } + + __aicore__ inline void IndexSelect(LocalTensor& tsw, LocalTensor& tsInt, int layer, int rowCnt) + { + uint32_t cnt = rowCnt * alignSeqLen; + LocalTensor rabTime = queTimestampsFloat.AllocTensor(); + uint32_t processLenMax = GATHER_PROCESS_WINDOW / sizeof(FloatType); + uint32_t tmpOffset = 0; + while (tmpOffset < cnt) { + uint32_t processLen = (cnt - tmpOffset) > processLenMax ? 
processLenMax : (cnt - tmpOffset); + Gather(rabTime[tmpOffset], tsw[layer * alignNumBuckets], tsInt[tmpOffset], (uint32_t)0, processLen); + tmpOffset += processLen; + } + queTimestampsFloat.EnQue(rabTime); + } + + __aicore__ inline void DataCopyOut(uint32_t ptr, int rowCnt) + { + LocalTensor rabTime = queTimestampsFloat.DeQue(); + + for (int i = 0; i < rowCnt; ++i) { + uint32_t ptrUb = i * alignSeqLen; + + // 对齐部分拷出 + if (alignLen > 0) { + DataCopy(rabTimeBiasOutGT[ptr + i * s], rabTime[ptrUb], s); + } + // 非对齐拷出 + if (unalignLen == 0) { + continue; + } +#ifdef SUPPORT_V200 + uint64_t mask0 = (1ul << (DATA_ALIGN_BYTES / sizeof(FloatType))) - (1ul << unalignCnt); + uint64_t mask[2] = {mask0, 0}; + Duplicate(rabTime[ptrUb + alignCnt], (FloatType)0, mask, 1, 1, 1); + queTimestampsFloat.EnQue(rabTime); + rabTime = queTimestampsFloat.DeQue(); + SetAtomicAdd(); + DataCopy(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], + Ceil(unalignLen) / sizeof(FloatType)); + SetAtomicNone(); +#else + const DataCopyExtParams dataCopyExtParams{1, unalignLen, 0, 0, 0}; + DataCopyPad(rabTimeBiasOutGT[ptr + i * s + alignCnt], rabTime[ptrUb + alignCnt], dataCopyExtParams); +#endif + } + queTimestampsFloat.FreeTensor(rabTime); + } + + __aicore__ inline void DataCopyInTsw() + { + LocalTensor tsw = queTimestampsWeights.AllocTensor(); + for (int n = 0; n < numLayer; ++n) { + DataCopy(tsw[n * alignNumBuckets], timestampsWeightsGT[n * numBuckets], alignNumBuckets); + } + queTimestampsWeights.EnQue(tsw); + } + + __aicore__ inline int64_t Ceil(int64_t a, int64_t b = DATA_ALIGN_BYTES) + { + if (b == 0) { + return 0; + } + return (a + b - 1) / b * b; + } + + __aicore__ inline void Compute(Args args) + { + Init(args); + DataCopyInTsw(); + LocalTensor tsw = queTimestampsWeights.DeQue(); + + for (int offset = 0; offset < processRowLen; offset += stride) { + int rowOffset = offset + startIndex; + int rowCnt = stride > (processRowLen - offset) ? 
(processRowLen - offset) : stride; + + SequenceParams params[MAX_SEQ_CNT]; + FillSeqParams(params, rowOffset, rowCnt); + DataCopyIn(params, rowCnt); + ComputeBucketTimestamps(params, rowCnt); + + LocalTensor tsInt = queTimestamps.DeQue(); + for (int n = 0; n < numLayer; ++n) { + IndexSelect(tsw, tsInt, n, rowCnt); + pipe_barrier(PIPE_ALL); + + uint32_t ptr = (n * bs * s + rowOffset) * s; + DataCopyOut(ptr, rowCnt); + } + queTimestamps.FreeTensor(tsInt); + } + queTimestampsWeights.FreeTensor(tsw); + } + +private: + // shape + uint32_t s; + uint32_t alignSeqLen; + uint32_t bs; + uint32_t stride; + // align + uint32_t alignLen; + uint32_t alignCnt; + uint32_t unalignLen; + uint32_t unalignCnt; + // tiling + uint32_t startIndex; + uint32_t processRowLen; + + float div; + int32_t numBuckets; + int32_t alignNumBuckets; + int32_t numLayer; + float clampMin; + float clampMax; + +private: + GlobalTensor timestampsGT; + GlobalTensor timestampsWeightsGT; + GlobalTensor rabTimeBiasOutGT; + + TPipe pipe; + TQue queTimestamps; + TQue queTimestampsFloat; + TQue queTimestampsWeights; + TQue tmpQue; +}; +#endif // MXREC_ADD_ONS_RELATIVE_ATTN_BIAS_TIME_H diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias_time.json b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias_time.json new file mode 100644 index 0000000000000000000000000000000000000000..bfb8a8c0f2eef3a7e387edce9e7cc776e37ab4f7 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/relative_attn_bias_time.json @@ -0,0 +1,53 @@ +[ + { + "op": "RelativeAttnBiasTime", + "language": "cpp", + "input_desc": [ + { + "name": "timestamps", + "param_type": "required", + "format": [ + "ND", + "ND" + ], + "type": [ + "int32", + "int32" + ] + }, + { + "name": "timestamps_weights", + "param_type": "required", + "format": [ + "ND", + "ND" + ], + "type": [ + "fp32", + "fp16" + ] + } + ], + "output_desc": [ + { + "name": "rab_time", + "param_type": "required", + "format": [ + "ND", + "ND" + ], + "type": [ + "fp32", + "fp16" + ] + } + ], + "attr": [ + { + "name": "bucket_divisor", + "param_type": "required", + "type": "float" + } + ] + } +] \ No newline at end of file diff --git a/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..c004186114f88758b71b0b5b3af448a2236e7086 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/operators/relative_attn_bias_time/run.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. 
+# ============================================================================== + +set -e + +# 查找msopgen的路径,加入到环境变量PATH中 +msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin) +parent_dir=$(dirname "$msopgen_path") +export PATH=$parent_dir:$PATH + +ai_core="ai_core-Ascend910B1" +if [ "$#" -eq 1 ]; then + ai_core=$1 +fi + +# 利用msopgen生成可编译文件 +rm -rf ./relative_attn_bias_time +python3 /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/bin/msopgen gen -i relative_attn_bias_time.json -f tf -c ${ai_core} -lan cpp -out ./relative_attn_bias_time -m 0 -op RelativeAttnBiasTime +rm -rf relative_attn_bias_time/op_kernel/*.h +rm -rf relative_attn_bias_time/op_kernel/*.cpp +rm -rf relative_attn_bias_time/host/*.h +rm -rf relative_attn_bias_time/host/*.cpp +cp -rf op_kernel relative_attn_bias_time/ +cp -rf op_host relative_attn_bias_time/ + +cd relative_attn_bias_time + +# 判断当前目录下是否存在CMakePresets.json文件 +if [ ! -f "CMakePresets.json" ]; then + echo "ERROR, CMakePresets.json file not exist." + exit 1 +fi + +# 禁止生成CRC校验和 +sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake + +# 修改cann安装路径 +sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json +# 修改vendor_name 防止覆盖之前vendor_name为customize的算子; +# vendor_name需要和aclnn中的CMakeLists.txt中的CUST_PKG_PATH值同步,不同步aclnn会调用失败; +# vendor_name字段值不能包含customize;包含会导致多算子部署场景CANN的vendors路径下config.ini文件内容截取错误 +sed -i 's:"customize":"relative_attn_bias_time":g' CMakePresets.json + +if [ "$ai_core" = "ai_core-Ascend310P3" ]; then + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_time_kernel.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_time_time.h + sed -i "1i #define SUPPORT_V200" ./op_kernel/relative_attn_bias_time_pos.h +fi + +line=`awk '/ENABLE_SOURCE_PACKAGE/{print NR}' CMakePresets.json` +line=`expr ${line} + 2` +sed -i "${line}s/True/False/g" CMakePresets.json + +# 增加LOG_CPP编译选项支持错误日志打印 +sed -i "1 i include(../../../cmake/func.cmake)" ./op_host/CMakeLists.txt + +line1=`awk '/tartet_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line1}s/OP_TILING_LIB/OP_TILING_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +line2=`awk '/tartet_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB)/{print NR}' ./op_host/CMakeLists.txt` +sed -i "${line2}s/OP_PROTO_LIB/OP_PROTO_LIB LOG_CPP/g" ./op_host/CMakeLists.txt + +bash build.sh + +# 安装编译成功的算子包 +bash ./build_out/custom_opp*.run diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py index 83bb6f0110533ef1397fce71af2879a4168e8f15..a6144e77ad14b0357cd93264a437465ed28d5478 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias.py @@ -15,7 +15,6 @@ # limitations under the License. 
# ============================================================================== -import random import sysconfig import pytest @@ -111,7 +110,7 @@ def rab_npu(rel_pos_bias: torch.Tensor, return rab_pos, rab_time -def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): +def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor, bucketization_divisor: float): """ num_buckets = 128 num_layers = 1 - 20 @@ -131,7 +130,7 @@ def rab_time_golden(ts_w: torch.Tensor, timestamps: torch.Tensor): diff_timestamps = timestamps.reshape(bs, infer_len, 1) - timestamps.reshape(bs, 1, infer_len) clamp_max = torch.exp(torch.tensor(NUM_BUCKETS * BUCKET_DIVISOR)) - diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / BUCKET_DIVISOR + diff_timestamps = torch.log(torch.abs(diff_timestamps).clamp(1, clamp_max)) / bucketization_divisor bucket_timestamps = diff_timestamps.long().view(-1) rab_time = torch.index_select(ts_w, dim=0, index=bucket_timestamps) @@ -161,14 +160,10 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_vali @torch.no_grad() -def rab(num_layers, train_len, candidate_len, bs, dtype): +def rab_pos(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) - - layer_num = random.randint(0, num_layers - 1) pos_w = create_pos_w(train_len, num_layers).to(dtype) past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) - timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) - timestamps_weights = create_timestamps_weights(num_layers).to(dtype) rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, train_len=train_len, candidate_len=candidate_len, @@ -177,21 +172,39 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) identity_list = identity_list.to(DEVICE) - timestamps = timestamps.to(DEVICE) - timestamps_weights = timestamps_weights.to(DEVICE) past_valid_lens = past_valid_lens.to(DEVICE) torch_npu.npu.synchronize() - rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], - identity=identity_list[layer_num, ...], - timestamps=timestamps, - timestamps_weights=timestamps_weights, - past_valid_lens=past_valid_lens) + for rel_pos_bias, identity in zip(rel_pos_bias_list, identity_list): + rab_pos_out = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens.tolist()) + rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens) + assert torch.allclose(rab_pos_out_golden, rab_pos_out) + + +@torch.no_grad() +def rab_time(num_layers, train_len, candidate_len, bs, dtype): + torch_npu.npu.set_device(DEVICE) + + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) + timestamps_weights = create_timestamps_weights(num_layers).to(dtype) + + timestamps = timestamps.to(DEVICE) + timestamps_weights = timestamps_weights.to(DEVICE) torch_npu.npu.synchronize() + rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, + timestamps=timestamps, + bucket_divisor=BUCKET_DIVISOR) rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1), - timestamps=timestamps) + timestamps=timestamps, + bucketization_divisor=BUCKET_DIVISOR) torch_npu.npu.synchronize() + assert torch.allclose(rab_time_out_golden, 
rab_time_out) @@ -201,7 +214,8 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) + rab_time(num_layers, train_len, candidate_len, bs, dtype) + rab_pos(num_layers, train_len, candidate_len, bs, dtype) @pytest.mark.parametrize("num_layers", [1, 8]) @@ -209,4 +223,5 @@ def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("candidate_len", [0]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rab_train(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) + rab_time(num_layers, train_len, candidate_len, bs, dtype) + rab_pos(num_layers, train_len, candidate_len, bs, dtype) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py index 3a24438b139262e2f6632b6efc237644fccf26c4..fcc9296c052341c13b55ae36882a552f493e7c92 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/relative_attn_bias/test_relative_attn_bias_v200.py @@ -161,14 +161,10 @@ def rab_pos_golden(rel_pos_bias: torch.Tensor, identity: torch.Tensor, past_vali @torch.no_grad() -def rab(num_layers, train_len, candidate_len, bs, dtype): +def rab_pos(num_layers, train_len, candidate_len, bs, dtype): torch_npu.npu.set_device(DEVICE) - - layer_num = random.randint(0, num_layers - 1) pos_w = create_pos_w(train_len, num_layers).to(dtype) past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) - timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) - timestamps_weights = create_timestamps_weights(num_layers).to(dtype) rel_pos_bias_list, identity_list = init_rel_pos_bias(pos_w=pos_w, train_len=train_len, candidate_len=candidate_len, @@ -177,27 +173,38 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): rel_pos_bias_list = rel_pos_bias_list.to(DEVICE) identity_list = identity_list.to(DEVICE) - timestamps = timestamps.to(DEVICE) - timestamps_weights = timestamps_weights.to(DEVICE) past_valid_lens = past_valid_lens.to(DEVICE) torch_npu.npu.synchronize() - rab_pos_out, rab_time_out = rab_npu(rel_pos_bias=rel_pos_bias_list[layer_num, ...], - identity=identity_list[layer_num, ...], - timestamps=timestamps, - timestamps_weights=timestamps_weights, - past_valid_lens=past_valid_lens) - rab_pos_out, rab_time_out = rab_pos_out.to("cpu"), rab_time_out.to("cpu") + for rel_pos_bias, identity in zip(rel_pos_bias_list, identity_list): + op_result = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias, + identity=identity, + past_valid_lens=past_valid_lens.tolist()).to('cpu') + golden_result = rab_pos_golden(rel_pos_bias=rel_pos_bias.to('cpu'), + identity=identity.to('cpu'), + past_valid_lens=past_valid_lens.to('cpu')) + assert torch.allclose(op_result, golden_result) + + +@torch.no_grad() +def rab_time(num_layers, train_len, candidate_len, bs, dtype): + torch_npu.npu.set_device(DEVICE) + + past_valid_lens = create_past_valid_lens(bs, train_len).to(torch.int32) + timestamps = create_timestamps(train_len, candidate_len, past_valid_lens).to(torch.int32) + 
timestamps_weights = create_timestamps_weights(num_layers).to(dtype) + + timestamps = timestamps.to(DEVICE) + timestamps_weights = timestamps_weights.to(DEVICE) torch_npu.npu.synchronize() - rab_pos_out_golden = rab_pos_golden(rel_pos_bias=rel_pos_bias_list[layer_num, ...].to("cpu"), - identity=identity_list[layer_num, ...].to("cpu"), - past_valid_lens=past_valid_lens.to("cpu")) + rab_time_out = torch.ops.mxrec.relative_attn_bias_time(timestamps_weights=timestamps_weights, + timestamps=timestamps, + bucket_divisor=BUCKET_DIVISOR).to("cpu") rab_time_out_golden = rab_time_golden(ts_w=timestamps_weights.transpose(0, 1).to("cpu"), timestamps=timestamps.to("cpu")) torch_npu.npu.synchronize() - assert torch.allclose(rab_pos_out_golden, rab_pos_out) assert torch.allclose(rab_time_out_golden, rab_time_out) @@ -206,5 +213,14 @@ def rab(num_layers, train_len, candidate_len, bs, dtype): @pytest.mark.parametrize("candidate_len", [600]) @pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) -def test_rab_eval(num_layers, train_len, candidate_len, bs, dtype): - rab(num_layers, train_len, candidate_len, bs, dtype) +def test_rab_time_eval(num_layers, train_len, candidate_len, bs, dtype): + rab_time(num_layers, train_len, candidate_len, bs, dtype) + + +@pytest.mark.parametrize("num_layers", [8]) +@pytest.mark.parametrize("train_len", [500, 1000, 2000, 4000]) +@pytest.mark.parametrize("candidate_len", [600]) +@pytest.mark.parametrize("bs", [1, 2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_rab_pos_eval(num_layers, train_len, candidate_len, bs, dtype): + rab_pos(num_layers, train_len, candidate_len, bs, dtype) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py index c68eaee0bc3faa877bed7b76ec7268791b1b778a..f1b8a2d3c6ff5f0e55cf736ad653131f6b7759f2 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/split_embedding_codegen_lookup_adagrad_function/test_split_embedding_codegen_lookup_function.py @@ -51,7 +51,7 @@ OPTIMIZER_PARAM = { @dataclass class LookupParams: - tables: list[int] + tables: list[list[int]] mutile_hots: list[int] batch_size: int pooling_mode: PoolingMode diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp index 8df45253f3e61df83b7bad9647702d47faf382df..57e7bf6423a102dae98537ff917afc1b714ce8fb 100644 --- a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_library/2.6.0/relative_attn_bias/relative_attn_bias.cpp @@ -18,27 +18,35 @@ using torch::autograd::Function; using namespace at; using namespace std; -std::tuple relative_attn_bias_impl_npu(const Tensor& relPosBias, const Tensor& identity, - const Tensor& timestamps, const Tensor& timestampsWeights, - const at::IntArrayRef pastValidLens, const double bucketDivisor) +Tensor relative_attn_bias_pos_impl_npu(const Tensor& relPosBias, const Tensor& identity, + const 
at::IntArrayRef pastValidLens) { auto relPosBiasConti = relPosBias.contiguous(); auto identityConti = identity.contiguous(); - auto timestampsConti = timestamps.contiguous(); - auto timestampsWeightsConti = timestampsWeights.contiguous(); const int bs = pastValidLens.size(); const int sx2 = relPosBias.size(0); // relPosBias(2s, 2s) - const int s = sx2 / 2; - const int numLayers = timestampsWeights.size(0); at::Tensor rabPosOut = at::zeros({bs, sx2, sx2}, relPosBiasConti.options()); - at::Tensor rabTimeOut = at::zeros({numLayers, bs, s, 1, s, 1}, timestampsWeightsConti.options()); - EXEC_NPU_CMD(aclnnRelativeAttnBias, relPosBiasConti, identityConti, timestampsConti, timestampsWeightsConti, - pastValidLens, bucketDivisor, rabPosOut, rabTimeOut); + EXEC_NPU_CMD(aclnnRelativeAttnBiasPos, relPosBiasConti, identityConti, pastValidLens, rabPosOut); + return rabPosOut; +} + +Tensor relative_attn_bias_time_impl_npu(const Tensor& timestamps, const Tensor& timestampsWeights, + const double bucketDivisor) +{ + auto timestampsConti = timestamps.contiguous(); + auto timestampsWeightsConti = timestampsWeights.contiguous(); + const int numLayers = timestampsWeights.size(0); + const int bs = timestampsConti.size(0); + const int s = timestampsConti.size(1); + const int sx2 = s * 2; + + at::Tensor rabTimeOut = at::zeros({numLayers, bs, s, 1, s, 1}, timestampsWeightsConti.options()); + EXEC_NPU_CMD(aclnnRelativeAttnBiasTime, timestampsConti, timestampsWeightsConti, bucketDivisor, rabTimeOut); rabTimeOut = rabTimeOut.repeat({1, 1, 1, 2, 1, 2}).reshape({numLayers, bs, sx2, sx2}); - return {rabPosOut, rabTimeOut}; + return rabTimeOut; } Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Tensor& bucketTimestamps, @@ -51,9 +59,8 @@ Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Ten auto rabTimeGradConti = rabTimeGrad.contiguous(); auto bucketTimestampsConti = bucketTimestamps.contiguous(); // (n, b, s, s) - bucketTimestampsConti = bucketTimestampsConti.view({batchsize, s, 1, s, 1}) - .repeat({1, 1, 2, 1, 2}) - .reshape({batchsize, sx2, sx2}); + bucketTimestampsConti = + bucketTimestampsConti.view({batchsize, s, 1, s, 1}).repeat({1, 1, 2, 1, 2}).reshape({batchsize, sx2, sx2}); at::Tensor rabTimeGradOut = at::zeros({numLayers, numBuckets}, rabTimeGrad.options()); EXEC_NPU_CMD(aclnnRelativeAttnBiasBackward, rabTimeGradConti, bucketTimestampsConti, numBuckets, rabTimeGradOut); @@ -62,13 +69,14 @@ Tensor relative_attn_bias_backward_impl_npu(const Tensor& rabTimeGrad, const Ten TORCH_LIBRARY_FRAGMENT(mxrec, m) { - m.def("relative_attn_bias(Tensor rel_pos_bias, " - " Tensor identity, " - " Tensor timestamps, " - " Tensor timestamps_weights, " - " int[] past_valid_lens," - " float bucket_divisor" - " ) -> (Tensor, Tensor)"); + m.def("relative_attn_bias_time(Tensor timestamps, " + " Tensor timestamps_weights, " + " float bucket_divisor" + " ) -> Tensor"); + m.def("relative_attn_bias_pos(Tensor rel_pos_bias, " + " Tensor identity, " + " int[] past_valid_lens" + " ) -> Tensor"); m.def("relative_attn_bias_backward(Tensor rab_time_grad, " " Tensor bucket_timestamps, " " int num_buckets" @@ -77,12 +85,14 @@ TORCH_LIBRARY_FRAGMENT(mxrec, m) TORCH_LIBRARY_IMPL(mxrec, PrivateUse1, m) { - m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); + m.impl("relative_attn_bias_time", &relative_attn_bias_time_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); } 
TORCH_LIBRARY_IMPL(fbgemm, PrivateUse1, m) { - m.impl("relative_attn_bias", &relative_attn_bias_impl_npu); + m.impl("relative_attn_bias_pos", &relative_attn_bias_pos_impl_npu); + m.impl("relative_attn_bias_time", &relative_attn_bias_time_impl_npu); m.impl("relative_attn_bias_backward", &relative_attn_bias_backward_impl_npu); }
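
For reference, a minimal Python usage sketch of the two ops registered by this change. The op names, argument names, and output shapes follow the TORCH_LIBRARY fragment and the updated demo tests above; the tensor sizes, the random inputs, and the bucket_divisor value are illustrative assumptions, and it is assumed that torch_npu and the mxrec plugin library have already been loaded, as the demo tests do.

import torch
import torch_npu  # assumed available; the mxrec plugin .so must already be loaded

# Illustrative sizes (assumptions): batch 2, half sequence length s = 8, 4 layers, 128 buckets.
bs, s, num_layers, num_buckets = 2, 8, 4, 128
bucket_divisor = 0.301  # assumed value; the tests use their own BUCKET_DIVISOR constant

rel_pos_bias = torch.randn(2 * s, 2 * s).npu()   # (2s, 2s), one layer's relative position bias
identity = torch.eye(2 * s).npu()                # (2s, 2s), same dtype as rel_pos_bias
past_valid_lens = [5, 7]                         # int[] attribute, one entry per batch element

timestamps = torch.randint(0, 10000, (bs, s), dtype=torch.int32).npu()  # (b, s)
timestamps_weights = torch.randn(num_layers, num_buckets).npu()         # (num_layer, num_buckets)

# rab_pos: (bs, 2s, 2s); the tests call this once per layer with that layer's slices.
rab_pos = torch.ops.mxrec.relative_attn_bias_pos(rel_pos_bias=rel_pos_bias,
                                                 identity=identity,
                                                 past_valid_lens=past_valid_lens)

# rab_time: (num_layer, bs, 2s, 2s) after the repeat/reshape done in relative_attn_bias_time_impl_npu.
rab_time = torch.ops.mxrec.relative_attn_bias_time(timestamps=timestamps,
                                                   timestamps_weights=timestamps_weights,
                                                   bucket_divisor=bucket_divisor)

The relative_attn_bias_backward op is left unchanged by this split and still serves the time branch.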