From 5e46cf154de509f6a3ff59d32b99445bab055006 Mon Sep 17 00:00:00 2001 From: y00806855 Date: Wed, 18 Jun 2025 17:20:25 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0all2all=5Fvc=5Fmesh?= =?UTF-8?q?=E9=80=9A=E4=BF=A1=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cust_op/lccl/emb_custom.json | 75 ++++ .../lccl/op_host/lccl_all_to_all_vc_mesh.cpp | 122 ++++++ .../op_host/lccl_all_to_all_vc_mesh_tiling.h | 31 ++ cust_op/lccl/op_kernel/all2all_vc_mesh.h | 354 ++++++++++++++++++ .../lccl/op_kernel/lccl_all2all_vc_mesh.cpp | 30 ++ cust_op/lccl/run.sh | 1 + src/ops_tf/hybrid_dataset_ops.cpp | 22 ++ 7 files changed, 635 insertions(+) create mode 100644 cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp create mode 100644 cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h create mode 100644 cust_op/lccl/op_kernel/all2all_vc_mesh.h create mode 100644 cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp diff --git a/cust_op/lccl/emb_custom.json b/cust_op/lccl/emb_custom.json index 3751f4bc..6ffb2818 100644 --- a/cust_op/lccl/emb_custom.json +++ b/cust_op/lccl/emb_custom.json @@ -243,5 +243,80 @@ "type": "int" } ] + }, + { + "op": "LcclAllToAllVCMesh", + "language": "cpp", + "input_desc": [ + { + "name": "send_data", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "fp32","fp16" + ] + }, + { + "name": "send_count_matrix", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int64","int64" + ] + }, + { + "name": "shape_vec", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int32","int32" + ] + }, + { + "name": "peer_mem", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int64","int64" + ] + } + ], + "output_desc": [ + { + "name": "recv_data", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "fp32","fp32" + ] + } + ], + "attr": [ + { + "name": "rank", + "param_type": "required", + "type": "int" + }, + { + "name": "rank_size", + "param_type": "required", + "type": "int" + }, + { + "name": "dim", + "param_type": "required", + "type": "int" + } + ] } ] \ No newline at end of file diff --git a/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp new file mode 100644 index 00000000..0c131d26 --- /dev/null +++ b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "lccl_all_to_all_vc_mesh_tiling.h" + +#include "register/op_def_registry.h" + +static int g_magic = 9; +static int g_blockDim = 64; +namespace optiling { + static ge::graphStatus TilingFunc(gert::TilingContext* context) + { + LcclAllToAllVCMeshTilingData tiling; + + auto sendBuff = context->GetInputTensor(0); + auto* attrs = context->GetAttrs(); + const auto* rank_ = attrs->GetAttrPointer(0); + const auto* rankSize_ = attrs->GetAttrPointer(1); + int rank = static_cast(*rank_); + int rankSize = static_cast(*rankSize_); + + tiling.set_rank(rank); + tiling.set_rankSize(rankSize); + + tiling.set_magic(g_magic); + + context->SetBlockDim(g_blockDim); + + // 参考官网默认值 设置workSpace大小 + uint32_t sysWorkspaceSize = 16 * 1024 * 1024; + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = sysWorkspaceSize; + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; + } +} + + +namespace ge { + static ge::graphStatus InferShape(gert::InferShapeContext* context) + { + const gert::Shape* x1_shape = context->GetInputShape(0); + const gert::Shape* x3_shape = context->GetInputShape(2); + gert::Shape* y_shape = context->GetOutputShape(0); + + y_shape->SetDim(0, x3_shape->GetDim(0)); + y_shape->SetDim(1, x1_shape->GetDim(1)); + y_shape->SetDim(2, 1); + + return ge::GRAPH_SUCCESS; + } + + static ge::graphStatus InferDataType(gert::InferDataTypeContext* context) + { + const auto inputDataType = context->GetInputDataType(0); + context->SetOutputDataType(0, inputDataType); + return ge::GRAPH_SUCCESS; + } +} + + +namespace ops { + class LcclAllToAllVCMesh : public OpDef { + public: + explicit LcclAllToAllVCMesh(const char* name) : OpDef(name) + { + this->Input("send_data") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("send_count_matrix") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("shape_vec") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("peer_mem") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("recv_data") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("rank").Int(); + this->Attr("rank_size").Int(); + this->Attr("dim").Int(); + + this->SetInferShape(ge::InferShape); + this->SetInferDataType(ge::InferDataType); + + this->AICore() + .SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910_95"); + } + }; + + OP_ADD(LcclAllToAllVCMesh); +} diff --git a/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h new file mode 100644 index 00000000..b5403378 --- /dev/null +++ b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LCCL_ALL_TO_ALL_V_C_MESH_TILING_H +#define LCCL_ALL_TO_ALL_V_C_MESH_TILING_H + +#include "register/tilingdata_base.h" + +namespace optiling { + BEGIN_TILING_DATA_DEF(LcclAllToAllVCMeshTilingData) + TILING_DATA_FIELD_DEF(int64_t, rank); + TILING_DATA_FIELD_DEF(int64_t, rankSize); + TILING_DATA_FIELD_DEF(int64_t, magic); + END_TILING_DATA_DEF; + + REGISTER_TILING_DATA_CLASS(LcclAllToAllVCMesh, LcclAllToAllVCMeshTilingData) +} + +#endif // LCCL_ALL_TO_ALL_V_C_MESH_TILING_H \ No newline at end of file diff --git a/cust_op/lccl/op_kernel/all2all_vc_mesh.h b/cust_op/lccl/op_kernel/all2all_vc_mesh.h new file mode 100644 index 00000000..5f389061 --- /dev/null +++ b/cust_op/lccl/op_kernel/all2all_vc_mesh.h @@ -0,0 +1,354 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LCCL_ALL2ALLVC_MESH_H +#define LCCL_ALL2ALLVC_MESH_H + +#include "collectives.h" +#include "ipc_queue.h" + +using namespace AscendC; + +template +class All2AllVCMesh : public Collectives { + constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; // 非法rank + constexpr static int64_t CORE_NUMS_PER_STAGE = 16; // 每个阶段提供的最大核数 + constexpr static int64_t SHARE_QUE_DEPTH = 16; // 单个共享队列深度 + + constexpr static int64_t SINGLE_RANK_MAX_NUM = CORE_NUMS_PER_STAGE; + constexpr static int64_t MULTI_RANK_SIZE = (LCAL_MAX_RANK_SIZE + SINGLE_RANK_MAX_NUM - 1) / SINGLE_RANK_MAX_NUM; + + constexpr static int64_t IDLER_CORE = 0; // 闲置的核 + constexpr static int64_t PRODUCER_CORE = 1; // 生产组,负责向共享内存写入数据,input->share,或者share->share + constexpr static int64_t CONSUMER_CORE = 2; // 消费组,负责从共享内存读出数据,share->output + +public: + __aicore__ inline All2AllVCMesh(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) + { + } + + __aicore__ inline void Init(GM_ADDR input, GM_ADDR send_count_matrix, GM_ADDR shape_vec, GM_ADDR peer_mem, + GM_ADDR output, int64_t rank, int64_t rankSize, int64_t magic) + { + this->root = 0; + this->len = 0; + this->magic = magic; + this->rank = rank; + this->rankSize = rankSize; + + blockIdx = GetBlockIdx(); + blockNum = GetBlockNum(); + + GlobalTensor peerMemsAddrGm; + int64_t peer_mem_addr = reinterpret_cast(peer_mem); + peerMemsAddrGm.SetGlobalBuffer((__gm__ int64_t*)peer_mem_addr, rankSize * sizeof (int64_t)); + for (int i = 0; i < rankSize; ++i) { + shareAddrs[i] = (GM_ADDR)(peerMemsAddrGm.GetValue(i))+ + (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + } + + sendCountMatrixGm.SetGlobalBuffer((__gm__ int64_t*)send_count_matrix, rankSize * rankSize * sizeof (int64_t)); + // 初始化共享内存信息 + InitShare(); + // 初始化核分组 + InitCoreGroup(); + // 初始化数据切片 + InitDataSlice(); + + // 初始化输入输出 + for (int j = 0; j < rankSize; j++) { + sendLen += sendCountMatrixGm.GetValue(rank * rankSize + j); + } + inputGt.SetGlobalBuffer((__gm__ T*)input); + for (int j = 0; j < rankSize; j++) { + revLen += sendCountMatrixGm.GetValue(j * rankSize + rank); + } + outputGt.SetGlobalBuffer((__gm__ T*)output); + } + + __aicore__ inline void Process() + { + if (coreGroup == PRODUCER_CORE) { + ProducerStage(); + } + if (coreGroup == CONSUMER_CORE) { + ConsumerStage(); + } + } + +private: + // 计算rank数量较大时的queNum 以及 每个队列里单块可放入的元素数量queElemLen + __aicore__ inline void InitShare() + { + int64_t queNum = CORE_NUMS_PER_STAGE; // 共享内存最小切分数量为单阶段核数 + if (rankSize > CORE_NUMS_PER_STAGE) { + queNum = rankSize; + } + queElemLen = IPC_BUFF_MAX_SIZE / sizeof(T) / queNum / SHARE_QUE_DEPTH; // 计算共享队列元素大小 + } + + __aicore__ inline void InitCoreGroup() + { + // 每个rank在每个stage分到的core数量, 多卡下为1 + coreNumPerRank = CORE_NUMS_PER_STAGE / rankSize > 1 ? + CORE_NUMS_PER_STAGE / rankSize : 1; + // 多卡下为CORE_NUMS_PER_STAGE + coreNumPerStage = coreNumPerRank * rankSize < CORE_NUMS_PER_STAGE ? 
+ coreNumPerRank * rankSize : CORE_NUMS_PER_STAGE; + // 一个core处理多少rank + rankNumPerCore = CeilDiv(rankSize, coreNumPerStage); + + // 单核负责多rank时,计算虚拟核索引并储存(索引为多个,由一个物理核执行其它虚拟核索引的操作) + // 多卡场景下flagNumPerStage 为 ranksize + flagNumPerStage = coreNumPerStage * rankNumPerCore; + // 将core 分类到不同的stage, 并且找到本core对应要处理的rank + if (blockIdx < coreNumPerStage) { + coreGroup = PRODUCER_CORE; + for (auto i = 0; i < rankNumPerCore; ++i) { + groupCoreIdx[i] = blockIdx * rankNumPerCore + i; + } + } else if (blockIdx < coreNumPerStage + coreNumPerStage) { + coreGroup = CONSUMER_CORE; + for (auto i = 0; i < rankNumPerCore; ++i) { + groupCoreIdx[i] = blockIdx * rankNumPerCore + i - flagNumPerStage; + } + } else { + coreGroup = IDLER_CORE; + } + } + + __aicore__ inline void InitDataSlice() + { + queLen = queElemLen * SHARE_QUE_DEPTH; // 一个que的可放入的元素数量 + queSize = queLen * sizeof(T); + + // 生产者负责搬运本rank的输入数据至共享内存,input-->share + if (coreGroup == PRODUCER_CORE) { + ProducerDataSlice(); + } else if (coreGroup == CONSUMER_CORE) { + ConsumerDataSlice(); + } + } + + __aicore__ inline void ProducerDataSlice() + { + maxSliceNum = 0; + for (auto i = 0; i < rankNumPerCore; ++i) { + // 当前核负责的rank, 因为是基于trackRankSize计算的groupCoreIdx,所以要乘RANK_SIZE_TWO + targetRank[i] = groupCoreIdx[i] / coreNumPerRank; + if (targetRank[i] >= rankSize) { + targetRank[i] = INVALID_RANK_NUM; + continue; + } + // 当前核负责的ipcQue + writeQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + groupCoreIdx[i] * queSize, queLen, queElemLen); + + // 当前核负责的数据长度和偏移 + sendOffset[i] = 0; + for (int j = 0; j < targetRank[i]; j++) { + sendOffset[i] += sendCountMatrixGm.GetValue(rank * rankSize + j); + } + inputDataLen[i] = sendCountMatrixGm.GetValue(rank * rankSize + targetRank[i]); + SplitData(inputDataLen[i], coreNumPerRank, groupCoreIdx[i] % coreNumPerRank, inputOffset[i], + inputLen[i], sendOffset[i]); + // 当前核负责的数据切片数,能分成一个que中的多少小块 + sliceNum[i] = CeilDiv(inputLen[i], queElemLen); + if (sliceNum[i] > maxSliceNum){ + maxSliceNum = sliceNum[i]; + } + } + } + + __aicore__ inline void ConsumerDataSlice() + { + maxSliceNum = 0; + for (auto i = 0; i < rankNumPerCore; ++i) { + // 当前核负责的rank + targetRank[i] = groupCoreIdx[i] / coreNumPerRank; + if (targetRank[i] >= rankSize) { + targetRank[i] = INVALID_RANK_NUM; + continue; + } + // 当前核负责的ipcQue + readQue[i].Init(&sync, magic, shareAddrs[targetRank[i]] + IPC_DATA_OFFSET + (rank * coreNumPerRank + groupCoreIdx[i] % coreNumPerRank) * queSize, + queLen, queElemLen); + // 当前核负责的数据长度和偏移 + revOffset[i] = 0; + for (int j = 0; j < targetRank[i]; j++) { + revOffset[i] += sendCountMatrixGm.GetValue(j * rankSize + rank); + } + outputDataLen[i] = sendCountMatrixGm.GetValue(targetRank[i] * rankSize + rank); + + SplitData(outputDataLen[i], coreNumPerRank, groupCoreIdx[i] % coreNumPerRank, outputOffset[i], + outputLen[i], revOffset[i]); + // 当前核负责的数据切片数 + sliceNum[i] = CeilDiv(outputLen[i], queElemLen); + if (sliceNum[i] > maxSliceNum){ + maxSliceNum = sliceNum[i]; + } + } + } + + __aicore__ inline void SplitData(const int64_t totalLen, const int64_t useCoreNum, const int64_t useCoreIdx, + int64_t& dataOffset, int64_t& dataLen, int startOffset) + { + // 向上整除获取每个core切分的数据个数 + dataLen = CeilDiv(totalLen, useCoreNum); + // 数据量极小或略微超过核数的情况,后面若干个core数据量为0 + dataOffset = useCoreIdx * dataLen + startOffset; // 使用当前block在useBlock里的相对索引来计算偏移 + if (useCoreIdx * dataLen >= totalLen) { + dataOffset = totalLen + startOffset; + dataLen = 0; + return; + } + // 非整除情况,最后一个core数据量为剩余数据量 + if (dataOffset + dataLen - startOffset > 
totalLen) { + dataLen = totalLen - useCoreIdx * dataLen; + } + } + + __aicore__ inline void ProducerStage() + { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + + // 写共享内存队列时,需要等待当前rank + waitRankListForWrite[i][0] = targetRank[i]; + waitNumForWrite[i] = 1; + waitBlockForWrite[i] = rank * coreNumPerRank + groupCoreIdx[i] % coreNumPerRank + flagNumPerStage; + } + InputToSharePipeline(); + } + + __aicore__ inline void InputToSharePipeline() + { + int64_t flagValue[MULTI_RANK_SIZE]; // 要等待标志位的储存值 + for (auto i = 0; i < rankNumPerCore; ++i) { + flagValue[i] = -1; // 统一赋值为-1,便于后续小于判断 + } + // 以最多切片sliceNum[0]为切片数进行循环,切片数不足的不拷贝 + for (auto sliceIdx = 0; sliceIdx < maxSliceNum; ++sliceIdx) { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + InputToShareSlice(i, sliceIdx, flagValue[i]); + } + } + } + + __aicore__ inline void InputToShareSlice(int64_t idx, int64_t sliceIdx, int64_t& flagValue) + { + readGt = inputGt[sliceIdx * queElemLen + inputOffset[idx]]; + // 计算当前切片拷贝数据量,数据量为0时不拷贝 + copyLen = inputLen[idx] - queElemLen * sliceIdx; + if (copyLen > queElemLen) { + copyLen = queElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + writeQue[idx].DeQue(waitRankListForWrite[idx], waitNumForWrite[idx], waitBlockForWrite[idx]); + writeGt = writeQue[idx].EnQue(); + if (copyLen > 0) { + CpGM2GMPingPong(copyLen * sizeof(T), readGt, writeGt, COPYONLY); + } + sync.SetInnerFlag(magic, sliceIdx, rank, groupCoreIdx[idx]); + } + + __aicore__ inline void ConsumerStage() + { + int64_t flagValue[MULTI_RANK_SIZE]; // 要等待标志位的储存值 + for (auto i = 0; i < rankNumPerCore; ++i) { + flagValue[i] = -1; + } + // 以最多切片sliceNum[0]为切片数进行循环,切片数不足的不拷贝 + for (auto sliceIdx = 0; sliceIdx < maxSliceNum; ++sliceIdx) { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + ShareToOutputSlice(i, sliceIdx, flagValue[i]); + } + } + } + + __aicore__ inline void ShareToOutputSlice(int64_t idx, int64_t sliceIdx, int64_t& flagValue) + { + // 计算当前切片拷贝数据量,数据量为0时不拷贝 + copyLen = outputLen[idx] - queElemLen * sliceIdx; + if (copyLen > queElemLen) { + copyLen = queElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + + // 拉取本rank数据 + if (flagValue < sliceIdx) { + sync.WaitInnerFlag(magic, sliceIdx, targetRank[idx], rank * coreNumPerRank + groupCoreIdx[idx]%coreNumPerRank); + flagValue = sync.GetInnerFlag(targetRank[idx], rank * coreNumPerRank + groupCoreIdx[idx]%coreNumPerRank) & EVENT_ID_MASK; + } + readGt = readQue[idx].ReadFront(); + if (copyLen > 0) { + writeGt = outputGt[sliceIdx * queElemLen + outputOffset[idx]]; + CpGM2GMPingPong(copyLen * sizeof(T), readGt, writeGt, COPYONLY); + } + sync.SetInnerFlag(magic, sliceIdx, rank, groupCoreIdx[idx] + flagNumPerStage); + } + + GlobalTensor inputGt; + GlobalTensor outputGt; + GlobalTensor readGt; + GlobalTensor writeGt; + GlobalTensor sendCountMatrixGm; + int64_t maxSliceNum; + int64_t revLen = 0; + int64_t sendLen = 0; + int64_t sendOffset[MULTI_RANK_SIZE]; + int64_t revOffset[MULTI_RANK_SIZE]; + int64_t inputDataLen[MULTI_RANK_SIZE]; + int64_t outputDataLen[MULTI_RANK_SIZE]; + + int waitRankListForWrite[MULTI_RANK_SIZE][1]; // 写共享内存时,需要等待的rank列表 + int waitNumForWrite[MULTI_RANK_SIZE]; // 写共享内存时,需要等待的数量 + int waitBlockForWrite[MULTI_RANK_SIZE]; // 写共享内存时,需要等待的标志位 + + int64_t queLen; + int64_t queSize; + int64_t coreNumPerStage; // 每个阶段使用的核数 + int64_t flagNumPerStage; // 每个阶段使用的同步标志位数 + int64_t coreNumPerRank; // 每个rank数据分配的核数 
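For reference, the per-core split that ProducerDataSlice/ConsumerDataSlice delegate to SplitData() can be checked on the host. A minimal sketch of the same arithmetic, assuming CeilDiv rounds up like the kernel helper; SplitDataSketch and the sample values are illustrative only, not part of the operator:

// Mirrors SplitData(): chunk = ceil(totalLen / useCoreNum); trailing cores may get 0 elements,
// and the last busy core takes whatever remainder is left.
#include <cstdint>
#include <cstdio>

static int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }  // assumed to match the kernel helper

static void SplitDataSketch(int64_t totalLen, int64_t useCoreNum, int64_t useCoreIdx,
                            int64_t &dataOffset, int64_t &dataLen, int64_t startOffset)
{
    dataLen = CeilDiv(totalLen, useCoreNum);          // per-core chunk, rounded up
    dataOffset = useCoreIdx * dataLen + startOffset;  // offset inside this rank's block
    if (useCoreIdx * dataLen >= totalLen) {           // tiny inputs: trailing cores get nothing
        dataOffset = totalLen + startOffset;
        dataLen = 0;
        return;
    }
    if (dataOffset + dataLen - startOffset > totalLen) {  // last busy core takes the remainder
        dataLen = totalLen - useCoreIdx * dataLen;
    }
}

int main()
{
    // totalLen = 10 elements over 4 cores, startOffset = 100:
    // cores 0-2 get 3 elements at offsets 100/103/106, core 3 gets the 1-element tail at 109.
    for (int64_t core = 0; core < 4; ++core) {
        int64_t off = 0;
        int64_t len = 0;
        SplitDataSketch(10, 4, core, off, len, 100);
        std::printf("core %lld: offset %lld, len %lld\n",
                    static_cast<long long>(core), static_cast<long long>(off), static_cast<long long>(len));
    }
    return 0;
}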
+ int64_t rankNumPerCore; // 每个核负责的rank数 + int64_t coreGroup; // 当前核的功能分组 + int64_t groupCoreIdx[MULTI_RANK_SIZE]; // 当前核在组内的索引,可以为等效核索引 + int64_t targetRank[MULTI_RANK_SIZE]; // 当前核负责的rank + + IpcQueue readQue[MULTI_RANK_SIZE]; // 读端共享内存队列 + IpcQueue writeQue[MULTI_RANK_SIZE]; // 写端共享内存队列 + int64_t queElemLen; // 共享内存队列里每个元素大小(以T计) + + int64_t sliceNum[MULTI_RANK_SIZE]; // 当前核负责的数据切片总数 + int64_t copyLen; // 当前拷贝数据片的长度(以T计) + int64_t inputOffset[MULTI_RANK_SIZE]; // 当前核负责的input偏移(以T计) + int64_t inputLen[MULTI_RANK_SIZE]; // 当前核负责的input长度(以T计) + int64_t outputOffset[MULTI_RANK_SIZE]; // 当前核负责的output偏移(以T计) + int64_t outputLen[MULTI_RANK_SIZE]; // 当前核负责的output长度(以T计) +}; + +#endif // LCCL_ALL2ALLVC_MESH_H diff --git a/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp b/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp new file mode 100644 index 00000000..3b222678 --- /dev/null +++ b/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp @@ -0,0 +1,30 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel_operator.h" +#include "all2all_vc_mesh.h" + +extern "C" __global__ __aicore__ void lccl_all_to_all_v_c_mesh(GM_ADDR send_data, GM_ADDR send_count_matrix, + GM_ADDR shape_vec, GM_ADDR peer_mem, GM_ADDR rev_data, + GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + + All2AllVCMesh opKernel(tiling_data.rank, tiling_data.rankSize, (1 << 2)); + opKernel.Init(send_data, send_count_matrix, shape_vec, peer_mem, rev_data, + tiling_data.rank, tiling_data.rankSize, tiling_data.magic); + + opKernel.Process(); +} \ No newline at end of file diff --git a/cust_op/lccl/run.sh b/cust_op/lccl/run.sh index c1820505..51e03992 100644 --- a/cust_op/lccl/run.sh +++ b/cust_op/lccl/run.sh @@ -23,6 +23,7 @@ rm -rf ./custom_op msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 0 -op LcclAllToAll msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 1 -op LcclAllUss msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 1 -op LcclGatherAll +msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910_95 -lan cpp -out ./custom_op -m 1 -op LcclAllToAllVCMesh cp -rf op_kernel custom_op/ cp -rf op_host custom_op/ diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 2cc7e813..14167838 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -716,6 +716,28 @@ namespace tensorflow { }); REGISTER_KERNEL_BUILDER(Name("LcclAllToAll").Device(DEVICE_CPU), MxRec::CustOps); + REGISTER_OP("LcclAllToAllVCMesh") + .Input("send_data: float") + .Input("send_count_matrix: int64") + .Input("shape_vec: int32") + .Input("peer_mem: int64") + .Attr("rank: int") + .Attr("rank_size: int") + .Attr("dim: int") + .Output("rev_data: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle dataShape; + 
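            // recv_data is inferred as [shape_vec.dim(0), dim, 1]: the row count comes from
            // input 2 (shape_vec) and "dim" is the op attribute registered above.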
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &dataShape)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(dataShape, 0); + int64_t shape1 = c->Value(rows); + int dim = 0; + c->GetAttr("dim", &dim); + c->set_output(0, c->MakeShape({shape1, dim, 1})); + return Status::OK(); + }); + REGISTER_KERNEL_BUILDER(Name("LcclAllToAllVCMesh").Device(DEVICE_CPU), MxRec::CustOps); + REGISTER_OP("LcclGatherAll") .Input("emb_table: float") .Input("lookup: int32") -- Gitee From c450a8c29f19cfeda1060a67a5e4bf986049ca0e Mon Sep 17 00:00:00 2001 From: ybl Date: Wed, 18 Jun 2025 17:23:38 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0all2all=5Fvc=5Fmesh?= =?UTF-8?q?=E9=80=9A=E4=BF=A1=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cust_op/lccl/emb_custom.json | 75 ++++ .../lccl/op_host/lccl_all_to_all_vc_mesh.cpp | 122 ++++++ .../op_host/lccl_all_to_all_vc_mesh_tiling.h | 31 ++ cust_op/lccl/op_kernel/all2all_vc_mesh.h | 354 ++++++++++++++++++ .../lccl/op_kernel/lccl_all2all_vc_mesh.cpp | 30 ++ cust_op/lccl/run.sh | 1 + src/ops_tf/hybrid_dataset_ops.cpp | 22 ++ 7 files changed, 635 insertions(+) create mode 100644 cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp create mode 100644 cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h create mode 100644 cust_op/lccl/op_kernel/all2all_vc_mesh.h create mode 100644 cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp diff --git a/cust_op/lccl/emb_custom.json b/cust_op/lccl/emb_custom.json index 3751f4bc..6ffb2818 100644 --- a/cust_op/lccl/emb_custom.json +++ b/cust_op/lccl/emb_custom.json @@ -243,5 +243,80 @@ "type": "int" } ] + }, + { + "op": "LcclAllToAllVCMesh", + "language": "cpp", + "input_desc": [ + { + "name": "send_data", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "fp32","fp16" + ] + }, + { + "name": "send_count_matrix", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int64","int64" + ] + }, + { + "name": "shape_vec", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int32","int32" + ] + }, + { + "name": "peer_mem", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "int64","int64" + ] + } + ], + "output_desc": [ + { + "name": "recv_data", + "param_type": "required", + "format": [ + "ND","ND" + ], + "type": [ + "fp32","fp32" + ] + } + ], + "attr": [ + { + "name": "rank", + "param_type": "required", + "type": "int" + }, + { + "name": "rank_size", + "param_type": "required", + "type": "int" + }, + { + "name": "dim", + "param_type": "required", + "type": "int" + } + ] } ] \ No newline at end of file diff --git a/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp new file mode 100644 index 00000000..0c131d26 --- /dev/null +++ b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh.cpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lccl_all_to_all_vc_mesh_tiling.h" + +#include "register/op_def_registry.h" + +static int g_magic = 9; +static int g_blockDim = 64; +namespace optiling { + static ge::graphStatus TilingFunc(gert::TilingContext* context) + { + LcclAllToAllVCMeshTilingData tiling; + + auto sendBuff = context->GetInputTensor(0); + auto* attrs = context->GetAttrs(); + const auto* rank_ = attrs->GetAttrPointer(0); + const auto* rankSize_ = attrs->GetAttrPointer(1); + int rank = static_cast(*rank_); + int rankSize = static_cast(*rankSize_); + + tiling.set_rank(rank); + tiling.set_rankSize(rankSize); + + tiling.set_magic(g_magic); + + context->SetBlockDim(g_blockDim); + + // 参考官网默认值 设置workSpace大小 + uint32_t sysWorkspaceSize = 16 * 1024 * 1024; + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = sysWorkspaceSize; + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; + } +} + + +namespace ge { + static ge::graphStatus InferShape(gert::InferShapeContext* context) + { + const gert::Shape* x1_shape = context->GetInputShape(0); + const gert::Shape* x3_shape = context->GetInputShape(2); + gert::Shape* y_shape = context->GetOutputShape(0); + + y_shape->SetDim(0, x3_shape->GetDim(0)); + y_shape->SetDim(1, x1_shape->GetDim(1)); + y_shape->SetDim(2, 1); + + return ge::GRAPH_SUCCESS; + } + + static ge::graphStatus InferDataType(gert::InferDataTypeContext* context) + { + const auto inputDataType = context->GetInputDataType(0); + context->SetOutputDataType(0, inputDataType); + return ge::GRAPH_SUCCESS; + } +} + + +namespace ops { + class LcclAllToAllVCMesh : public OpDef { + public: + explicit LcclAllToAllVCMesh(const char* name) : OpDef(name) + { + this->Input("send_data") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("send_count_matrix") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("shape_vec") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("peer_mem") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("recv_data") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("rank").Int(); + this->Attr("rank_size").Int(); + this->Attr("dim").Int(); + + this->SetInferShape(ge::InferShape); + this->SetInferDataType(ge::InferDataType); + + this->AICore() + .SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend910_95"); + } + }; + + OP_ADD(LcclAllToAllVCMesh); +} diff --git a/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h new file mode 100644 index 00000000..b5403378 --- /dev/null +++ b/cust_op/lccl/op_host/lccl_all_to_all_vc_mesh_tiling.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef LCCL_ALL_TO_ALL_V_C_MESH_TILING_H +#define LCCL_ALL_TO_ALL_V_C_MESH_TILING_H + +#include "register/tilingdata_base.h" + +namespace optiling { + BEGIN_TILING_DATA_DEF(LcclAllToAllVCMeshTilingData) + TILING_DATA_FIELD_DEF(int64_t, rank); + TILING_DATA_FIELD_DEF(int64_t, rankSize); + TILING_DATA_FIELD_DEF(int64_t, magic); + END_TILING_DATA_DEF; + + REGISTER_TILING_DATA_CLASS(LcclAllToAllVCMesh, LcclAllToAllVCMeshTilingData) +} + +#endif // LCCL_ALL_TO_ALL_V_C_MESH_TILING_H \ No newline at end of file diff --git a/cust_op/lccl/op_kernel/all2all_vc_mesh.h b/cust_op/lccl/op_kernel/all2all_vc_mesh.h new file mode 100644 index 00000000..5f389061 --- /dev/null +++ b/cust_op/lccl/op_kernel/all2all_vc_mesh.h @@ -0,0 +1,354 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LCCL_ALL2ALLVC_MESH_H +#define LCCL_ALL2ALLVC_MESH_H + +#include "collectives.h" +#include "ipc_queue.h" + +using namespace AscendC; + +template +class All2AllVCMesh : public Collectives { + constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; // 非法rank + constexpr static int64_t CORE_NUMS_PER_STAGE = 16; // 每个阶段提供的最大核数 + constexpr static int64_t SHARE_QUE_DEPTH = 16; // 单个共享队列深度 + + constexpr static int64_t SINGLE_RANK_MAX_NUM = CORE_NUMS_PER_STAGE; + constexpr static int64_t MULTI_RANK_SIZE = (LCAL_MAX_RANK_SIZE + SINGLE_RANK_MAX_NUM - 1) / SINGLE_RANK_MAX_NUM; + + constexpr static int64_t IDLER_CORE = 0; // 闲置的核 + constexpr static int64_t PRODUCER_CORE = 1; // 生产组,负责向共享内存写入数据,input->share,或者share->share + constexpr static int64_t CONSUMER_CORE = 2; // 消费组,负责从共享内存读出数据,share->output + +public: + __aicore__ inline All2AllVCMesh(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) + { + } + + __aicore__ inline void Init(GM_ADDR input, GM_ADDR send_count_matrix, GM_ADDR shape_vec, GM_ADDR peer_mem, + GM_ADDR output, int64_t rank, int64_t rankSize, int64_t magic) + { + this->root = 0; + this->len = 0; + this->magic = magic; + this->rank = rank; + this->rankSize = rankSize; + + blockIdx = GetBlockIdx(); + blockNum = GetBlockNum(); + + GlobalTensor peerMemsAddrGm; + int64_t peer_mem_addr = reinterpret_cast(peer_mem); + peerMemsAddrGm.SetGlobalBuffer((__gm__ int64_t*)peer_mem_addr, rankSize * sizeof (int64_t)); + for (int i = 0; i < rankSize; ++i) { + shareAddrs[i] = (GM_ADDR)(peerMemsAddrGm.GetValue(i))+ + (this->magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + } + + sendCountMatrixGm.SetGlobalBuffer((__gm__ int64_t*)send_count_matrix, rankSize * rankSize * sizeof (int64_t)); + // 初始化共享内存信息 + InitShare(); + // 初始化核分组 + InitCoreGroup(); + // 初始化数据切片 + InitDataSlice(); + + // 初始化输入输出 + for (int j = 0; j < rankSize; j++) { + sendLen += sendCountMatrixGm.GetValue(rank * rankSize + j); + } + inputGt.SetGlobalBuffer((__gm__ T*)input); + for (int j = 0; j < rankSize; j++) { + revLen += sendCountMatrixGm.GetValue(j * rankSize + rank); + } + outputGt.SetGlobalBuffer((__gm__ T*)output); + } + + __aicore__ inline void Process() + { + if (coreGroup == PRODUCER_CORE) { + ProducerStage(); + } + if (coreGroup == CONSUMER_CORE) { + ConsumerStage(); + } + } + +private: + // 计算rank数量较大时的queNum 以及 每个队列里单块可放入的元素数量queElemLen + __aicore__ inline void InitShare() + { + int64_t queNum = CORE_NUMS_PER_STAGE; // 共享内存最小切分数量为单阶段核数 + if (rankSize > CORE_NUMS_PER_STAGE) { + queNum = rankSize; + } + queElemLen = IPC_BUFF_MAX_SIZE / sizeof(T) / queNum / SHARE_QUE_DEPTH; // 计算共享队列元素大小 + } + + __aicore__ inline void InitCoreGroup() + { + // 每个rank在每个stage分到的core数量, 多卡下为1 + coreNumPerRank = CORE_NUMS_PER_STAGE / rankSize > 1 ? + CORE_NUMS_PER_STAGE / rankSize : 1; + // 多卡下为CORE_NUMS_PER_STAGE + coreNumPerStage = coreNumPerRank * rankSize < CORE_NUMS_PER_STAGE ? 
+ coreNumPerRank * rankSize : CORE_NUMS_PER_STAGE; + // 一个core处理多少rank + rankNumPerCore = CeilDiv(rankSize, coreNumPerStage); + + // 单核负责多rank时,计算虚拟核索引并储存(索引为多个,由一个物理核执行其它虚拟核索引的操作) + // 多卡场景下flagNumPerStage 为 ranksize + flagNumPerStage = coreNumPerStage * rankNumPerCore; + // 将core 分类到不同的stage, 并且找到本core对应要处理的rank + if (blockIdx < coreNumPerStage) { + coreGroup = PRODUCER_CORE; + for (auto i = 0; i < rankNumPerCore; ++i) { + groupCoreIdx[i] = blockIdx * rankNumPerCore + i; + } + } else if (blockIdx < coreNumPerStage + coreNumPerStage) { + coreGroup = CONSUMER_CORE; + for (auto i = 0; i < rankNumPerCore; ++i) { + groupCoreIdx[i] = blockIdx * rankNumPerCore + i - flagNumPerStage; + } + } else { + coreGroup = IDLER_CORE; + } + } + + __aicore__ inline void InitDataSlice() + { + queLen = queElemLen * SHARE_QUE_DEPTH; // 一个que的可放入的元素数量 + queSize = queLen * sizeof(T); + + // 生产者负责搬运本rank的输入数据至共享内存,input-->share + if (coreGroup == PRODUCER_CORE) { + ProducerDataSlice(); + } else if (coreGroup == CONSUMER_CORE) { + ConsumerDataSlice(); + } + } + + __aicore__ inline void ProducerDataSlice() + { + maxSliceNum = 0; + for (auto i = 0; i < rankNumPerCore; ++i) { + // 当前核负责的rank, 因为是基于trackRankSize计算的groupCoreIdx,所以要乘RANK_SIZE_TWO + targetRank[i] = groupCoreIdx[i] / coreNumPerRank; + if (targetRank[i] >= rankSize) { + targetRank[i] = INVALID_RANK_NUM; + continue; + } + // 当前核负责的ipcQue + writeQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + groupCoreIdx[i] * queSize, queLen, queElemLen); + + // 当前核负责的数据长度和偏移 + sendOffset[i] = 0; + for (int j = 0; j < targetRank[i]; j++) { + sendOffset[i] += sendCountMatrixGm.GetValue(rank * rankSize + j); + } + inputDataLen[i] = sendCountMatrixGm.GetValue(rank * rankSize + targetRank[i]); + SplitData(inputDataLen[i], coreNumPerRank, groupCoreIdx[i] % coreNumPerRank, inputOffset[i], + inputLen[i], sendOffset[i]); + // 当前核负责的数据切片数,能分成一个que中的多少小块 + sliceNum[i] = CeilDiv(inputLen[i], queElemLen); + if (sliceNum[i] > maxSliceNum){ + maxSliceNum = sliceNum[i]; + } + } + } + + __aicore__ inline void ConsumerDataSlice() + { + maxSliceNum = 0; + for (auto i = 0; i < rankNumPerCore; ++i) { + // 当前核负责的rank + targetRank[i] = groupCoreIdx[i] / coreNumPerRank; + if (targetRank[i] >= rankSize) { + targetRank[i] = INVALID_RANK_NUM; + continue; + } + // 当前核负责的ipcQue + readQue[i].Init(&sync, magic, shareAddrs[targetRank[i]] + IPC_DATA_OFFSET + (rank * coreNumPerRank + groupCoreIdx[i] % coreNumPerRank) * queSize, + queLen, queElemLen); + // 当前核负责的数据长度和偏移 + revOffset[i] = 0; + for (int j = 0; j < targetRank[i]; j++) { + revOffset[i] += sendCountMatrixGm.GetValue(j * rankSize + rank); + } + outputDataLen[i] = sendCountMatrixGm.GetValue(targetRank[i] * rankSize + rank); + + SplitData(outputDataLen[i], coreNumPerRank, groupCoreIdx[i] % coreNumPerRank, outputOffset[i], + outputLen[i], revOffset[i]); + // 当前核负责的数据切片数 + sliceNum[i] = CeilDiv(outputLen[i], queElemLen); + if (sliceNum[i] > maxSliceNum){ + maxSliceNum = sliceNum[i]; + } + } + } + + __aicore__ inline void SplitData(const int64_t totalLen, const int64_t useCoreNum, const int64_t useCoreIdx, + int64_t& dataOffset, int64_t& dataLen, int startOffset) + { + // 向上整除获取每个core切分的数据个数 + dataLen = CeilDiv(totalLen, useCoreNum); + // 数据量极小或略微超过核数的情况,后面若干个core数据量为0 + dataOffset = useCoreIdx * dataLen + startOffset; // 使用当前block在useBlock里的相对索引来计算偏移 + if (useCoreIdx * dataLen >= totalLen) { + dataOffset = totalLen + startOffset; + dataLen = 0; + return; + } + // 非整除情况,最后一个core数据量为剩余数据量 + if (dataOffset + dataLen - startOffset > 
totalLen) { + dataLen = totalLen - useCoreIdx * dataLen; + } + } + + __aicore__ inline void ProducerStage() + { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + + // 写共享内存队列时,需要等待当前rank + waitRankListForWrite[i][0] = targetRank[i]; + waitNumForWrite[i] = 1; + waitBlockForWrite[i] = rank * coreNumPerRank + groupCoreIdx[i] % coreNumPerRank + flagNumPerStage; + } + InputToSharePipeline(); + } + + __aicore__ inline void InputToSharePipeline() + { + int64_t flagValue[MULTI_RANK_SIZE]; // 要等待标志位的储存值 + for (auto i = 0; i < rankNumPerCore; ++i) { + flagValue[i] = -1; // 统一赋值为-1,便于后续小于判断 + } + // 以最多切片sliceNum[0]为切片数进行循环,切片数不足的不拷贝 + for (auto sliceIdx = 0; sliceIdx < maxSliceNum; ++sliceIdx) { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + InputToShareSlice(i, sliceIdx, flagValue[i]); + } + } + } + + __aicore__ inline void InputToShareSlice(int64_t idx, int64_t sliceIdx, int64_t& flagValue) + { + readGt = inputGt[sliceIdx * queElemLen + inputOffset[idx]]; + // 计算当前切片拷贝数据量,数据量为0时不拷贝 + copyLen = inputLen[idx] - queElemLen * sliceIdx; + if (copyLen > queElemLen) { + copyLen = queElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + writeQue[idx].DeQue(waitRankListForWrite[idx], waitNumForWrite[idx], waitBlockForWrite[idx]); + writeGt = writeQue[idx].EnQue(); + if (copyLen > 0) { + CpGM2GMPingPong(copyLen * sizeof(T), readGt, writeGt, COPYONLY); + } + sync.SetInnerFlag(magic, sliceIdx, rank, groupCoreIdx[idx]); + } + + __aicore__ inline void ConsumerStage() + { + int64_t flagValue[MULTI_RANK_SIZE]; // 要等待标志位的储存值 + for (auto i = 0; i < rankNumPerCore; ++i) { + flagValue[i] = -1; + } + // 以最多切片sliceNum[0]为切片数进行循环,切片数不足的不拷贝 + for (auto sliceIdx = 0; sliceIdx < maxSliceNum; ++sliceIdx) { + for (auto i = 0; i < rankNumPerCore; ++i) { + if (targetRank[i] == INVALID_RANK_NUM) { + continue; + } + ShareToOutputSlice(i, sliceIdx, flagValue[i]); + } + } + } + + __aicore__ inline void ShareToOutputSlice(int64_t idx, int64_t sliceIdx, int64_t& flagValue) + { + // 计算当前切片拷贝数据量,数据量为0时不拷贝 + copyLen = outputLen[idx] - queElemLen * sliceIdx; + if (copyLen > queElemLen) { + copyLen = queElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + + // 拉取本rank数据 + if (flagValue < sliceIdx) { + sync.WaitInnerFlag(magic, sliceIdx, targetRank[idx], rank * coreNumPerRank + groupCoreIdx[idx]%coreNumPerRank); + flagValue = sync.GetInnerFlag(targetRank[idx], rank * coreNumPerRank + groupCoreIdx[idx]%coreNumPerRank) & EVENT_ID_MASK; + } + readGt = readQue[idx].ReadFront(); + if (copyLen > 0) { + writeGt = outputGt[sliceIdx * queElemLen + outputOffset[idx]]; + CpGM2GMPingPong(copyLen * sizeof(T), readGt, writeGt, COPYONLY); + } + sync.SetInnerFlag(magic, sliceIdx, rank, groupCoreIdx[idx] + flagNumPerStage); + } + + GlobalTensor inputGt; + GlobalTensor outputGt; + GlobalTensor readGt; + GlobalTensor writeGt; + GlobalTensor sendCountMatrixGm; + int64_t maxSliceNum; + int64_t revLen = 0; + int64_t sendLen = 0; + int64_t sendOffset[MULTI_RANK_SIZE]; + int64_t revOffset[MULTI_RANK_SIZE]; + int64_t inputDataLen[MULTI_RANK_SIZE]; + int64_t outputDataLen[MULTI_RANK_SIZE]; + + int waitRankListForWrite[MULTI_RANK_SIZE][1]; // 写共享内存时,需要等待的rank列表 + int waitNumForWrite[MULTI_RANK_SIZE]; // 写共享内存时,需要等待的数量 + int waitBlockForWrite[MULTI_RANK_SIZE]; // 写共享内存时,需要等待的标志位 + + int64_t queLen; + int64_t queSize; + int64_t coreNumPerStage; // 每个阶段使用的核数 + int64_t flagNumPerStage; // 每个阶段使用的同步标志位数 + int64_t coreNumPerRank; // 每个rank数据分配的核数 
+ int64_t rankNumPerCore; // 每个核负责的rank数 + int64_t coreGroup; // 当前核的功能分组 + int64_t groupCoreIdx[MULTI_RANK_SIZE]; // 当前核在组内的索引,可以为等效核索引 + int64_t targetRank[MULTI_RANK_SIZE]; // 当前核负责的rank + + IpcQueue readQue[MULTI_RANK_SIZE]; // 读端共享内存队列 + IpcQueue writeQue[MULTI_RANK_SIZE]; // 写端共享内存队列 + int64_t queElemLen; // 共享内存队列里每个元素大小(以T计) + + int64_t sliceNum[MULTI_RANK_SIZE]; // 当前核负责的数据切片总数 + int64_t copyLen; // 当前拷贝数据片的长度(以T计) + int64_t inputOffset[MULTI_RANK_SIZE]; // 当前核负责的input偏移(以T计) + int64_t inputLen[MULTI_RANK_SIZE]; // 当前核负责的input长度(以T计) + int64_t outputOffset[MULTI_RANK_SIZE]; // 当前核负责的output偏移(以T计) + int64_t outputLen[MULTI_RANK_SIZE]; // 当前核负责的output长度(以T计) +}; + +#endif // LCCL_ALL2ALLVC_MESH_H diff --git a/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp b/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp new file mode 100644 index 00000000..3b222678 --- /dev/null +++ b/cust_op/lccl/op_kernel/lccl_all2all_vc_mesh.cpp @@ -0,0 +1,30 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel_operator.h" +#include "all2all_vc_mesh.h" + +extern "C" __global__ __aicore__ void lccl_all_to_all_v_c_mesh(GM_ADDR send_data, GM_ADDR send_count_matrix, + GM_ADDR shape_vec, GM_ADDR peer_mem, GM_ADDR rev_data, + GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + + All2AllVCMesh opKernel(tiling_data.rank, tiling_data.rankSize, (1 << 2)); + opKernel.Init(send_data, send_count_matrix, shape_vec, peer_mem, rev_data, + tiling_data.rank, tiling_data.rankSize, tiling_data.magic); + + opKernel.Process(); +} \ No newline at end of file diff --git a/cust_op/lccl/run.sh b/cust_op/lccl/run.sh index c1820505..51e03992 100644 --- a/cust_op/lccl/run.sh +++ b/cust_op/lccl/run.sh @@ -23,6 +23,7 @@ rm -rf ./custom_op msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 0 -op LcclAllToAll msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 1 -op LcclAllUss msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 1 -op LcclGatherAll +msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910_95 -lan cpp -out ./custom_op -m 1 -op LcclAllToAllVCMesh cp -rf op_kernel custom_op/ cp -rf op_host custom_op/ diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 2cc7e813..14167838 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -716,6 +716,28 @@ namespace tensorflow { }); REGISTER_KERNEL_BUILDER(Name("LcclAllToAll").Device(DEVICE_CPU), MxRec::CustOps); + REGISTER_OP("LcclAllToAllVCMesh") + .Input("send_data: float") + .Input("send_count_matrix: int64") + .Input("shape_vec: int32") + .Input("peer_mem: int64") + .Attr("rank: int") + .Attr("rank_size: int") + .Attr("dim: int") + .Output("rev_data: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle dataShape; + 
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &dataShape)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(dataShape, 0); + int64_t shape1 = c->Value(rows); + int dim = 0; + c->GetAttr("dim", &dim); + c->set_output(0, c->MakeShape({shape1, dim, 1})); + return Status::OK(); + }); + REGISTER_KERNEL_BUILDER(Name("LcclAllToAllVCMesh").Device(DEVICE_CPU), MxRec::CustOps); + REGISTER_OP("LcclGatherAll") .Input("emb_table: float") .Input("lookup: int32") -- Gitee From 501478d1db0de9b12f7faa6a76c92871ae0bda0d Mon Sep 17 00:00:00 2001 From: y00806855 Date: Sat, 28 Jun 2025 11:58:36 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=80=9A=E4=BF=A1?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cust_op/lccl/op_kernel/sync_collectives.h | 288 ++++++++++++++-------- 1 file changed, 186 insertions(+), 102 deletions(-) diff --git a/cust_op/lccl/op_kernel/sync_collectives.h b/cust_op/lccl/op_kernel/sync_collectives.h index 0a821fa9..cb8144ec 100644 --- a/cust_op/lccl/op_kernel/sync_collectives.h +++ b/cust_op/lccl/op_kernel/sync_collectives.h @@ -1,25 +1,19 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
*/ - #ifndef LCCL_SYNC_H #define LCCL_SYNC_H #include "comm_args.h" using namespace AscendC; +using namespace Lcal; // 同步标志位占用长度 constexpr int64_t FLAG_UNIT_INT_NUM = 4; @@ -28,88 +22,169 @@ constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t); // magic作为比较值时高位偏移量 constexpr int64_t MAGIC_OFFSET = 32; +// 多轮循环复用ipcBuff时,magic初始左移位数 +constexpr int64_t MAGIC_ORIGIN_OFFSET = 10; +constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1); + class SyncCollectives { public: - __aicore__ inline SyncCollectives() {} + __aicore__ inline SyncCollectives() + {} - __aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, int blockIdx, int blockNum) + __aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs, TBuf &tBuf) { this->rank = rank; this->rankSize = rankSize; this->shareAddrs = shareAddrs; - this->blockIdx = blockIdx; - this->blockNum = blockNum; + this->blockIdx = block_idx; + this->blockNum = block_num; // 单个标志段长度 - segmentCount = blockNum * FLAG_UNIT_INT_NUM; + segmentCount = block_num * FLAG_UNIT_INT_NUM; // 初始化当前核对应的卡内/卡间同步地址 - blockInnerSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + blockIdx * FLAG_UNIT_INT_NUM; - blockOuterSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + segmentCount + blockIdx * FLAG_UNIT_INT_NUM; - // 初始化标志位数据搬运队列,一次最多可搬运blockNum个标志 - pipe.InitBuffer(syncSetQue, PING_PONG_SIZE, blockNum * SYNC_UNIT_SIZE); - pipe.InitBuffer(syncWaitQue, PING_PONG_SIZE, blockNum * SYNC_UNIT_SIZE); + localSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]); + basicSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]) + block_idx * FLAG_UNIT_INT_NUM; + blockOuterSyncAddr = (__gm__ int64_t *)(shareAddrs[rank]) + segmentCount + block_idx * FLAG_UNIT_INT_NUM; + this->tBuf = tBuf; + } + + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v); + } + + /** + * @brief 设置指定卡的指定eventID的flag,设置的值为 magic 和 value 组合而成的值。 + * @param magic 算子批次,最终会组合到要set的flag的数值中高32位去 + * @param value 具体的最终要set的flag的数值中低32位的值 + * @param eventID 实际上从物理地址来看,是以共享内存首地址起往后的偏移量(要进行缩放,不是偏移量绝对值)。 + * @param rank 这个rank是在CommArgs结构体内peerMems数组内对应的rankId,并非global或local的id。 + * (91093场景local不适用,910B多机场景global不适用。) + */ + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v); + } + + __aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId) + { + return (blockMultiplier * blockNum) + targetCoreId; + } + + /** + * @brief 等待指定卡的指定eventID的flag变为 magic 和 value 组合而成的值。 + * @param magic 算子批次,最终会组合到要wait的flag的数值中高32位去 + * @param value 具体的最终要wait的flag的数值中低32位的值 + * @param eventID 实际上从物理地址来看,是以共享内存首地址起往后的偏移量。 + * @param rank 这个rank是在CommArgs结构体内peerMems数组内对应的rankId,并非global或local的id。 + * (91093场景local不适用,910B多机场景global不适用。) + */ + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); + } + + /** + * @brief 相比起WaitSyncFlag函数,额外允许 远端Flag > 期望要check的FlagValue的值 通过校验。 + */ + __aicore__ inline void WaitSyncGreaterFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t 
*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v, false); + } + + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v); + } + + __aicore__ inline void WaitSyncGreaterFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[this->rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v, false); + } + + /** + * @brief 等待指定卡的指定eventID往后的flagNum个flag变为 magic 和 value 组合而成的值。
+ * 注:[eventID, eventID + flagNum) + */ + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v); + } + + __aicore__ inline void WaitSyncGreaterFlag( + int32_t magic, int32_t value, int32_t eventID, int32_t rank, int64_t flagNum) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, flagNum, v, false); } // 设置单个卡内同步标志(内存A) __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID) { - int64_t value = GetFlagValue(magic, eventID); - SetFlag(blockInnerSyncAddr, value); + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(basicSyncAddr, value); } __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) { - int64_t value = GetFlagValue(magic, eventID); - SetFlag((__gm__ int64_t*)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value); + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag((__gm__ int64_t *)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value); } // 等待单个卡内同步标志(内存A) __aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) { - int64_t value = GetFlagValue(magic, eventID); - WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value); + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankPartFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value); } // 等待整个rank内所有卡内同步标志(内存A) __aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) { - int64_t value = GetFlagValue(magic, eventID); - WaitOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankAllFlag((__gm__ int64_t *)(shareAddrs[waitRank]), value); } // 检验整个rank内所有卡内同步标志(内存A) __aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) { - int64_t value = GetFlagValue(magic, eventID); - return CheckOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); + int64_t value = MergeMagicWithValue(magic, eventID); + return CheckOneRankAllFlag((__gm__ int64_t *)(shareAddrs[waitRank]), value); } // 设置单个卡间同步标志(内存B) __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID) { - int64_t value = GetFlagValue(magic, eventID); + int64_t value = MergeMagicWithValue(magic, eventID); SetFlag(blockOuterSyncAddr, value); } __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) { - __gm__ int64_t* flagAddr = GetOuterFlagAddr(setRank, setBlock); - int64_t value = GetFlagValue(magic, eventID); + __gm__ int64_t *flagAddr = GetOuterFlagAddr(setRank, setBlock); + int64_t value = MergeMagicWithValue(magic, eventID); SetFlag(flagAddr, value); } // 等待单个卡间同步标志(内存B) __aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) { - int64_t value = GetFlagValue(magic, eventID); - __gm__ int64_t* flagAddr = GetOuterFlagAddr(waitRank, waitBlock); + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr = GetOuterFlagAddr(waitRank, waitBlock); WaitOneRankPartFlag(flagAddr, 1, value); } // 等待整个rank内所有卡间同步标志(内存B) __aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t 
eventID, int64_t rank) { - int64_t value = GetFlagValue(magic, eventID); - __gm__ int64_t* flagAddr; + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; flagAddr = GetOuterFlagAddr(rank, 0); WaitOneRankPartFlag(flagAddr, blockNum, value); } @@ -117,8 +192,8 @@ public: // 等待所有rank从startBlock开始的flagNum个卡间同步标志(内存B) __aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum) { - int64_t value = GetFlagValue(magic, eventID); - __gm__ int64_t* flagAddr; + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; int waitRank; for (auto r = 0; r < rankSize; ++r) { waitRank = (rank + r) % rankSize; // 错峰读取rank标志,防止多核并发拷贝影响性能 @@ -128,11 +203,11 @@ public: } // 检验所有rank从startBlock开始的flagNum个卡间同步标志(内存B) - __aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, - int64_t startBlock, int64_t flagNum) + __aicore__ inline bool CheckAllRankPartOuterFlag( + int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum) { - int64_t value = GetFlagValue(magic, eventID); - __gm__ int64_t* flagAddr; + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; int waitRank; for (auto r = 0; r < rankSize; ++r) { waitRank = (rank + r) % rankSize; // 错峰读取rank标志,防止多核并发拷贝影响性能 @@ -157,55 +232,52 @@ public: } // 低级接口,设置同步标志 - __aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue) + __aicore__ inline void SetFlag(__gm__ int64_t *setAddr, int64_t setValue) { - set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); GlobalTensor globalSet; globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM); - LocalTensor localSet = syncSetQue.AllocTensor(); + LocalTensor localSet = tBuf.GetWithOffset(1, 0); localSet.SetValue(0, setValue); // 将global同步标识拷贝至local - set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); // 等待SetValue完成 + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待SetValue完成 DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM); - set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); // 等待UB->GM完成 - - syncSetQue.FreeTensor(localSet); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待UB->GM完成 } // 低级接口,等待同步标志 - __aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue) + __aicore__ inline void WaitFlag(__gm__ int64_t *waitAddr, int64_t waitValue) { WaitOneRankPartFlag(waitAddr, 1, waitValue); } // 读取一个标志位,返回立即数 - __aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr) + __aicore__ inline int64_t GetFlag(__gm__ int64_t *waitAddr) { GlobalTensor globalWait; globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); - LocalTensor localWait = syncWaitQue.AllocTensor(); + LocalTensor localWait = tBuf.GetWithOffset(1, 0); // 将global拷贝至local DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM); - set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); // 等待GM->UB + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); // 等待GM->UB int64_t res = localWait.GetValue(0); - syncWaitQue.FreeTensor(localWait); return res; } // 获取单个卡内多个连续的同步标志 - __aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, - int64_t startBlock, int64_t flagNum) + 
@@ -157,55 +232,52 @@ public:
     }

     // Low-level interface: set a sync flag
-    __aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue)
+    __aicore__ inline void SetFlag(__gm__ int64_t *setAddr, int64_t setValue)
     {
-        set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
-        wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
-        set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
-        wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
+        AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0);
+        AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0);
         GlobalTensor<int64_t> globalSet;
         globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM);
-        LocalTensor<int64_t> localSet = syncSetQue.AllocTensor<int64_t>();
+        LocalTensor<int64_t> localSet = tBuf.GetWithOffset<int64_t>(1, 0);
         localSet.SetValue(0, setValue);

         // Copy the sync flag from local to global
-        set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0);
-        wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); // wait for SetValue to finish
+        AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0); // wait for SetValue to finish
         DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM);
-        set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
-        wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); // wait for UB->GM to finish
-
-        syncSetQue.FreeTensor(localSet);
+        AscendC::SetFlag<HardEvent::MTE3_S>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::MTE3_S>(EVENT_ID0); // wait for UB->GM to finish
     }

     // Low-level interface: wait for a sync flag
-    __aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue)
+    __aicore__ inline void WaitFlag(__gm__ int64_t *waitAddr, int64_t waitValue)
     {
         WaitOneRankPartFlag(waitAddr, 1, waitValue);
     }

     // Read a single flag and return it as an immediate value
-    __aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr)
+    __aicore__ inline int64_t GetFlag(__gm__ int64_t *waitAddr)
     {
         GlobalTensor<int64_t> globalWait;
         globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM);
-        LocalTensor<int64_t> localWait = syncWaitQue.AllocTensor<int64_t>();
+        LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(1, 0);

         // Copy the flag from global to local
         DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM);
-        set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
-        wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); // wait for GM->UB
+        AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // wait for GM->UB
         int64_t res = localWait.GetValue(0);
-        syncWaitQue.FreeTensor(localWait);
         return res;
     }

     // Wait for several consecutive sync flags on a single card
-    __aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank,
-        int64_t startBlock, int64_t flagNum)
+    __aicore__ inline void WaitOneRankPartOuterFlag(
+        int32_t magic, int32_t eventID, int64_t waitRank, int64_t startBlock, int64_t flagNum)
     {
-        int64_t value = GetFlagValue(magic, eventID);
-        __gm__ int64_t* flagAddr;
+        int64_t value = MergeMagicWithValue(magic, eventID);
+        __gm__ int64_t *flagAddr;
         flagAddr = GetOuterFlagAddr(waitRank, startBlock);
         WaitOneRankPartFlag(flagAddr, flagNum, value);
     }
@@ -213,88 +285,101 @@ public:
     // Get a single intra-card sync flag (memory A)
     __aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock)
     {
-        return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM);
+        return GetFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM);
     }

     __aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock)
     {
-        return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM);
+        return GetFlag((__gm__ int64_t *)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM);
     }

-public:
-    __aicore__ inline int64_t GetFlagValue(int32_t magic, int32_t eventID)
+private:
+    __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value)
     {
         // magic forms the high bits and the event/value the low bits, merged into one word for comparison
-        return (static_cast<int64_t>(magic) << MAGIC_OFFSET) + static_cast<int64_t>(eventID);
+        return (static_cast<int64_t>(magic) << MAGIC_OFFSET) | static_cast<int64_t>(value);
     }

-    __aicore__ inline __gm__ int64_t* GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock)
+    __aicore__ inline __gm__ int64_t *GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock)
     {
-        return (__gm__ int64_t*)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM;
+        return (__gm__ int64_t *)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM;
     }

-    __aicore__ inline __gm__ int64_t* GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock)
+    __aicore__ inline __gm__ int64_t *GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock)
     {
-        return (__gm__ int64_t*)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM;
+        return (__gm__ int64_t *)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM;
     }

-    // Wait for part of the sync flags within one rank
-    __aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
+    /**
+     * @brief Wait for part of the sync flags within one rank
+     * @param int64_t waitAddr address of the first flag to wait on (inclusive)
+     * @param int64_t flagNum number of flags to wait on
+     * @param int64_t checkValue expected flag value
+     * @param bool mustEqual controls the extra check applied once the remote flag value is already >= checkValue:
+     *     true means the MAGIC_MASK part must match exactly; false also accepts a remote MAGIC_MASK part
+     *     greater than or equal to that of checkValue.
+     * @return
+     */
+    __aicore__ inline void WaitOneRankPartFlag(
+        __gm__ int64_t *waitAddr, int64_t flagNum, int64_t checkValue, bool mustEqual = true)
     {
         GlobalTensor<int64_t> globalWait;
         globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
-        LocalTensor<int64_t> localWait = syncWaitQue.AllocTensor<int64_t>();
+        LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);
         bool isSync = true;
+        int64_t checkedFlagNum = 0;
         do {
+            int64_t remainToCheck = flagNum - checkedFlagNum;
             // Copy the sync flags from global to local
-            DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM);
-            set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
-            wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); // wait for GM->UB
+            DataCopy(localWait, globalWait[checkedFlagNum * FLAG_UNIT_INT_NUM], remainToCheck * FLAG_UNIT_INT_NUM);
+            AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
+            AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // wait for GM->UB
             // Check whether the sync flags have reached checkValue
             isSync = true;
-            for (auto i = 0; i < flagNum; ++i) {
+            for (auto i = 0; i < remainToCheck; ++i) {
                 // Keep waiting while some core has not yet reached the checkValue stage
-                if (localWait.GetValue(i * FLAG_UNIT_INT_NUM) < checkValue) {
+                int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
+                if ((mustEqual && (v < checkValue || ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK)))) ||
+                    ((!mustEqual) && (v < checkValue || ((v & MAGIC_MASK) < (checkValue & MAGIC_MASK))))) {
                     isSync = false;
                     break;
                 }
+                checkedFlagNum++;
             }
         } while (!isSync);
-        syncWaitQue.FreeTensor(localWait);
     }

     // Wait for all sync flags within one rank
-    __aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
+    __aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t *waitAddr, int64_t checkValue)
     {
         WaitOneRankPartFlag(waitAddr, blockNum, checkValue);
     }

     // Check part of the sync flags within one rank; copies only once
-    __aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue)
+    __aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t *waitAddr, int64_t flagNum, int64_t checkValue)
     {
         GlobalTensor<int64_t> globalWait;
         globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM);
-        LocalTensor<int64_t> localWait = syncWaitQue.AllocTensor<int64_t>();
+        LocalTensor<int64_t> localWait = tBuf.GetWithOffset<int64_t>(flagNum * FLAG_UNIT_INT_NUM, 0);

         // Copy the sync flags from global to local
         DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM);
-        set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0);
-        wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); // wait for GM->UB
+        AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0);
+        AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); // wait for GM->UB

         // Check whether the sync flags have reached checkValue
         bool isSync = true;
         for (auto i = 0; i < flagNum; ++i) {
             // Keep waiting while some core has not yet reached the checkValue stage
-            if (localWait.GetValue(i * FLAG_UNIT_INT_NUM) < checkValue) {
+            int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM);
+            if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) {
                 isSync = false;
                 break;
             }
         }
-        syncWaitQue.FreeTensor(localWait);
         return isSync;
     }

     // Check all sync flags within one rank; copies only once
-    __aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue)
+    __aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t *waitAddr, int64_t checkValue)
     {
         return CheckOneRankPartFlag(waitAddr, blockNum, checkValue);
     }
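Below is a standalone sketch of the flag word produced by MergeMagicWithValue and of the two acceptance rules WaitOneRankPartFlag applies. MAGIC_OFFSET and MAGIC_MASK are defined elsewhere in this header; the 32-bit split used here is an assumption for illustration only.

#include <cstdint>
#include <cstdio>

// Assumed split for illustration: magic in the high 32 bits, event/value in the low 32 bits.
constexpr int64_t MAGIC_OFFSET = 32;
constexpr int64_t MAGIC_MASK = ~((static_cast<int64_t>(1) << MAGIC_OFFSET) - 1);

int64_t Merge(int32_t magic, int32_t value)
{
    return (static_cast<int64_t>(magic) << MAGIC_OFFSET) | static_cast<int64_t>(value);
}

// mustEqual == true : the remote flag must be from the same magic epoch and have reached checkValue.
// mustEqual == false: a flag from a newer magic epoch is accepted as well (WaitSyncGreaterFlag path).
bool Accepted(int64_t remote, int64_t checkValue, bool mustEqual)
{
    if (mustEqual) {
        return remote >= checkValue && (remote & MAGIC_MASK) == (checkValue & MAGIC_MASK);
    }
    return remote >= checkValue && (remote & MAGIC_MASK) >= (checkValue & MAGIC_MASK);
}

int main()
{
    const int64_t expected = Merge(9, 3);
    std::printf("%d\n", Accepted(Merge(9, 3), expected, true));   // 1: same epoch, value reached
    std::printf("%d\n", Accepted(Merge(9, 2), expected, true));   // 0: value not reached yet
    std::printf("%d\n", Accepted(Merge(10, 1), expected, true));  // 0: newer epoch rejected when mustEqual
    std::printf("%d\n", Accepted(Merge(10, 1), expected, false)); // 1: newer epoch accepted otherwise
    return 0;
}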
@@ -304,11 +389,10 @@ public:
     int blockNum;
     GM_ADDR *shareAddrs;
     int64_t segmentCount; // length of one group of sync flags (counted in int64_t elements)
-    __gm__ int64_t* blockInnerSyncAddr; // intra-card sync flag address of the current block
-    __gm__ int64_t* blockOuterSyncAddr; // inter-card sync flag address of the current block
-    TQue<QuePosition::VECOUT, 1> syncSetQue; // queue for copying sync flags from local to global
-    TQue<QuePosition::VECIN, 1> syncWaitQue; // queue for copying sync flags from global to local
-    TPipe pipe;
+    __gm__ int64_t *localSyncAddr;
+    __gm__ int64_t *basicSyncAddr; // intra-card sync flag address of the current block
+    __gm__ int64_t *blockOuterSyncAddr; // inter-card sync flag address of the current block
+    TBuf<> tBuf;
 };

-#endif // LCCL_SYNC_H
+#endif // LCCL_SYNC_H
\ No newline at end of file
-- 
Gitee
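Finally, a hedged sketch of the call order a collective kernel typically builds from these primitives. It is illustrative only: `sync`, `magic` and the STAGE_* event ids are placeholders, and the kernels in this patch series may sequence things differently. Each block first publishes its own progress, then waits for the matching flags before touching peer memory.

// 1. Publish: this block finished preparing its local data (memory A, own rank).
sync.SetInnerFlag(magic, STAGE_LOCAL_READY);

// 2. Wait until every block on this rank has reached the same stage.
sync.WaitRankInnerFlag(magic, STAGE_LOCAL_READY, rank);

// 3. Publish readiness to the other ranks (memory B).
sync.SetOuterFlag(magic, STAGE_PEER_READY);

// 4. Wait for the outer flags of all blocks on all ranks; only then is peer memory read.
sync.WaitAllRankPartOuterFlag(magic, STAGE_PEER_READY, 0, blockNum);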