diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index c06fe82d7d2a9bfe5991da01c181731dd6bef56c..2580a07c3ba3a8b6e1e936b68e0c4ad71c63b333 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -116,8 +116,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hybridMgmtBlock->SetRankInfo(rankInfo); // 启动数据处理线程 - KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed, - isIncrementalCheckpoint); + KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed, isIncrementalCheckpoint); isRunning = true; isL3StorageEnabled = rankInfo.isSSDEnabled; @@ -169,7 +168,7 @@ void HybridMgmt::Save(const string& savePath, bool saveDelta) throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } string saveModelType = - saveDelta ? TransferModelType2Str(SaveModelType::DELTA) : TransferModelType2Str(SaveModelType::BASE); + saveDelta ? TransferModelType2Str(SaveModelType::DELTA) : TransferModelType2Str(SaveModelType::BASE); LOG_INFO(MGMT + "Start to save {} model to {}.", saveModelType, savePath); // 数据处理线程上锁 @@ -212,8 +211,6 @@ void HybridMgmt::Save(const string& savePath, bool saveDelta) LOG_INFO(MGMT + "End to save {} model.", saveModelType); // 数据处理线程释放锁 KEY_PROCESS_INSTANCE->LoadSaveUnlock(); - hybridMgmtBlock->FinishSave(); - cvCheckSave.notify_all(); #endif } @@ -443,23 +440,7 @@ void HybridMgmt::Destroy() } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; - mutexDestroy = true; - for (const auto& embInfo : mgmtEmbInfo) { - for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { - string trainKey = MakeKeyName(index, embInfo.name, TRAIN_CHANNEL_ID); - lastUpdateFinishCV[trainKey].notify_all(); - lastLookUpFinishCV[trainKey].notify_all(); - lastSendFinishCV[trainKey].notify_all(); - lastRecvFinishCV[trainKey].notify_all(); - string evalKey = MakeKeyName(index, embInfo.name, EVAL_CHANNEL_ID); - lastUpdateFinishCV[evalKey].notify_all(); - lastLookUpFinishCV[evalKey].notify_all(); - lastSendFinishCV[evalKey].notify_all(); - lastRecvFinishCV[evalKey].notify_all(); - } - } - cvCheckSave.notify_all(); // 防止save异常退出场景阻塞在EvalTask { // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 @@ -471,15 +452,20 @@ void HybridMgmt::Destroy() LOG_DEBUG(MGMT + "destroy hdTransfer end."); } + JoinEmbeddingCacheThread(); + LOG_DEBUG(MGMT + "destroy EmbeddingCacheThread end."); + hybridMgmtBlock->Destroy(); for (auto& t : procThreads) { t->join(); } + procThreads.clear(); + LOG_DEBUG(MGMT + "destroy parseKeyThread end."); + if (cacheManager != nullptr) { cacheManager = nullptr; } - JoinEmbeddingCacheThread(); - procThreads.clear(); + // 等待并销毁接收key的线程 for (auto& t : receiveKeyThreads) { t.join(); @@ -499,21 +485,19 @@ void HybridMgmt::Destroy() void HybridMgmt::TrainTask(TaskType type) { #ifndef GTEST - int channelId = TRAIN_CHANNEL_ID; - int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelId]; + int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[TRAIN_CHANNEL_ID]; do { - hybridMgmtBlock->CheckAndSetBlock(channelId); - if (hybridMgmtBlock->GetBlockStatus(channelId)) { - hybridMgmtBlock->DoBlock(channelId); + hybridMgmtBlock->CheckAndSetBlock(TRAIN_CHANNEL_ID); + if (hybridMgmtBlock->GetBlockStatus(TRAIN_CHANNEL_ID)) { + hybridMgmtBlock->DoBlock(TRAIN_CHANNEL_ID); } if (!isRunning) { return; } - LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, theTrainBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", TRAIN_CHANNEL_ID, theTrainBatchId); if (isBackUpTrainStatus) { RecoverTrainStatus(); } - ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId, type); } while (true); #endif @@ -525,23 +509,16 @@ void HybridMgmt::TrainTask(TaskType type) void HybridMgmt::EvalTask(TaskType type) { #ifndef GTEST - int channelId = EVAL_CHANNEL_ID; - int& evalBatchId = hybridMgmtBlock->hybridBatchId[channelId]; + int& evalBatchId = hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]; do { - hybridMgmtBlock->CheckAndSetBlock(channelId); - if (hybridMgmtBlock->GetBlockStatus(channelId)) { - LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, - hybridMgmtBlock->IsNeedWaitSave()); - std::unique_lock checkSaveLocker(saveMutex); - cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); - - LOG_DEBUG("wake TrainTask"); - hybridMgmtBlock->DoBlock(channelId); + hybridMgmtBlock->CheckAndSetBlock(EVAL_CHANNEL_ID); + if (hybridMgmtBlock->GetBlockStatus(EVAL_CHANNEL_ID)) { + hybridMgmtBlock->DoBlock(EVAL_CHANNEL_ID); } if (!isRunning) { return; } - LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, evalBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", EVAL_CHANNEL_ID, evalBatchId); ParseKeys(EVAL_CHANNEL_ID, evalBatchId, type); } while (true); @@ -595,8 +572,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId, TaskType type) threadPool->enqueueWithFuture([this, info]() { return ProcessEmbInfoDDR(info); }); remainResult.push_back(std::move(remainBatch)); } else { - std::future remainBatch = threadPool->enqueueWithFuture( - [this, info]() { return ProcessEmbInfoL3Storage(info); }); + std::future remainBatch = + threadPool->enqueueWithFuture([this, info]() { return ProcessEmbInfoL3Storage(info); }); remainResult.push_back(std::move(remainBatch)); } break; @@ -641,14 +618,13 @@ bool HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool isGrad) bool isEos = false; auto infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::RESTORE, isEos); if (isEos) { - HandleEosCase(info, remainBatchOut); - return remainBatchOut; + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); + return false; } if (infoVecs == nullptr) { - LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, + LOG_WARN(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, info.batchId); - remainBatchOut = false; - return remainBatchOut; + return false; } LOG_DEBUG("table:{}, channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end", info.name, info.channelId, info.batchId); @@ -656,7 +632,7 @@ bool HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool isGrad) // 动态shape场景下,获取all2all向量(通信量矩阵) SendAll2AllVec(info, remainBatchOut); if (!remainBatchOut) { - return remainBatchOut; + return false; } // 发送查询向量 @@ -671,7 +647,7 @@ bool HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool isGrad) SendUniqKeysAndRestoreVecHBM(info, infoVecs, isGrad); } - // 发送恢复向量 + // 发送恢复向量和hotPos TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name, info.batchId); LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", info.name, @@ -699,7 +675,13 @@ bool HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info) // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoDDR最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + EosL1Que[info.name][info.channelId].Pushv(true); + LOG_DEBUG("Enqueue on EosL1Que, eos status! table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, + info.batchId, info.channelId, EosL1Que[info.name][info.channelId].Size()); + } if (uniqueKeys.empty()) { return remainBatchOut; } @@ -979,11 +961,23 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName, int channelId) std::string swapOutName = embName + SWAP_OUT_STR; std::vector addrs; while (isRunning && lookupAddrSuccess) { + bool isEos = EosL1Que[embName][channelId].WaitAndPop(); if (!isRunning) { return; } + EosL2Que[embName][channelId].Pushv(isEos); + if (isEos) { + LOG_DEBUG("Enqueue on EosL2Que, eos status! table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, " + "EosL2Que.size: {}", + embName, id, channelId, EosL1Que[embName][channelId].Size(), EosL2Que[embName][channelId].Size()); + continue; + } + // swap in std::vector keys = HBMSwapKeyQue[swapInName][channelId].WaitAndPop(); + if (!isRunning) { + return; + } TimeCost lookupAddrsInTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { @@ -1000,6 +994,9 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName, int channelId) keys = HBMSwapKeyQue[swapOutName][channelId].WaitAndPop(); TimeCost lookupAddrsOutTC; rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + if (!isRunning) { + return; + } if (rc != H_OK) { lookupAddrSuccess = false; throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); @@ -1101,14 +1098,12 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) receiveKeyThreads.emplace_back([embInfo, this]() { while (isRunning) { TransferChannel transferName = TransferChannel::KEY_D2H; - size_t ret = hdTransfer->RecvOffsetsAcl(transferName, TRAIN_CHANNEL_ID, - embInfo.name); + size_t ret = hdTransfer->RecvOffsetsAcl(transferName, TRAIN_CHANNEL_ID, embInfo.name); if (ret == 0) { LOG_WARN("Receive empty data."); } else { LOG_INFO("Receive data success, get {} data size: {}.", embInfo.name, ret); - auto aclData = - acltdtGetDataItem(hdTransfer->aclDatasetsForIncrementalCkpt[embInfo.name], 0); + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasetsForIncrementalCkpt[embInfo.name], 0); if (aclData == nullptr) { throw runtime_error("Acl get tensor data failed."); } @@ -1116,12 +1111,12 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) int64_t timeStamp = *ptr; int64_t globalStep = *(ptr + 1); LOG_INFO("Receive {} timeStamp: {}, global step: {}.", embInfo.name, timeStamp, globalStep); - // tensorflow获取的global step是从1开始的,但是在key process中batch id则是从0开始,因此,下面的info中的batchId需要用 - // globalStep - 1 - EmbBaseInfo info = {.batchId=static_cast(globalStep-1), .channelId=TRAIN_CHANNEL_ID, - .name=embInfo.name}; - unique_ptr> keyCountVecInfo = - KEY_PROCESS_INSTANCE->GetKCInfoVec(info); + // tensorflow获取的global step是从1开始的,但是在key process中batch + // id则是从0开始,因此,下面的info中的batchId需要用 globalStep - 1 + EmbBaseInfo info = {.batchId = static_cast(globalStep - 1), + .channelId = TRAIN_CHANNEL_ID, + .name = embInfo.name}; + unique_ptr> keyCountVecInfo = KEY_PROCESS_INSTANCE->GetKCInfoVec(info); if (keyCountVecInfo == nullptr) { LOG_ERROR("Get key count info vector is empty."); throw runtime_error("Get key count info vector is empty."); @@ -1133,8 +1128,8 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) for (int64 i = 0; i < keyCountSize; ++i) { keyCountVec.push_back(static_cast(keyCountVecTmp(i))); } - LOG_INFO("Emb table: {}, channel: {}, size is: {}, data: {}", - embInfo.name, TRAIN_CHANNEL_ID, keyCountSize, VectorToString(keyCountVec)); + LOG_INFO("Emb table: {}, channel: {}, size is: {}, data: {}", embInfo.name, TRAIN_CHANNEL_ID, + keyCountSize, VectorToString(keyCountVec)); // 更新delta表 std::lock_guard lock(keyCountUpdateMtx); @@ -1153,7 +1148,7 @@ void HybridMgmt::updateDeltaInfo(const string& embName, vector& keyCoun auto& embMap = deltaMap[embName]; for (int i = 0; i < keyCountSize; i += KEY_COUNT_ELEMENT_NUM) { emb_key_t key = keyCountVec[i]; - int64_t recentCount = keyCountVec[i+1]; + int64_t recentCount = keyCountVec[i + 1]; KeyInfo& keyInfo = embMap[key]; keyInfo.totalCount += recentCount; keyInfo.recentCount += recentCount; @@ -1189,14 +1184,14 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo .extEmbeddingSize = embInfo.extEmbeddingSize, .channelId = channelId, .name = embInfo.name}; - vector h2dEmb; + vector h2dEmb; auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendDDR], table:{}, batchId:{}, channel:{}", embInfo.name, + batchId, channelId); return; } - EmbeddingSendDDR(info, h2dEmb); } @@ -1218,10 +1213,11 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI vector swapOutAddrs; auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_DEBUG("HybridMgmt is not running or receive empty data when [EmbeddingReceiveDDR], table:{}, batchId:{}, " + "channel:{}", + embInfo.name, batchId, channelId); return; } - EmbeddingUpdateDDR(info, ptr, swapOutAddrs); } @@ -1242,7 +1238,8 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendL3Storage], table:{}, batchId:{}, channel:{}", + embInfo.name, batchId, channelId); return; } @@ -1266,8 +1263,15 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons float* ptr = nullptr; vector swapOutAddrs; int64_t dims0 = 0; - EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); + auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); + if (!isSuccess) { + LOG_DEBUG( + "HybridMgmt is not running or receive empty data when [EmbeddingReceiveL3Storage], table:{}, batchId:{}, " + "channel:{}", + embInfo.name, batchId, channelId); + return; + } EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); } @@ -1284,7 +1288,14 @@ bool HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info) // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoL3Storage最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + EosL1Que[info.name][info.channelId].Pushv(true); + LOG_DEBUG("Enqueue on EosL1Que L3Storage, eos status! table:{}, batchId:{}, channelId:{}", info.name, + info.batchId, info.channelId); + } + if (uniqueKeys.empty()) { return remainBatchOut; } @@ -1359,6 +1370,8 @@ void HybridMgmt::InitDataPipelineForDDR(const string& embName) HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; // 初始化lookup线程 hybridMgmtBlock->lookUpSwapAddrsPushId[embName][TRAIN_CHANNEL_ID] = 0; // 此处初始化,避免多线程竞争导致计数错误 hybridMgmtBlock->lookUpSwapAddrsPushId[embName][EVAL_CHANNEL_ID] = 0; @@ -1380,6 +1393,9 @@ void HybridMgmt::InitDataPipelineForL3Storage(const string& embName, int extEmbe HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; + HBMSwapKeyQue[embName + ADDR_STR]; HBMSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; HBMSwapKeyForL3StorageQue[embName + ADDR_STR]; @@ -1426,30 +1442,44 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { - for (auto& p : HBMSwapAddrsQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : HBMSwapKeyQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : HBMSwapKeyForL3StorageQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapKeyQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapKeyForL3StorageQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapAddrsQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); + for (int channelId = 0; channelId < MAX_CHANNEL_NUM; channelId++) { + // Let ReceiveAndUpdate & LookupAndSend thread stop. + for (const auto& embInfo : mgmtEmbInfo) { + for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { + string key = MakeSwapCVName(index, embInfo.name, channelId); + lastUpdateFinishCV[key].notify_all(); + lastLookUpFinishCV[key].notify_all(); + lastSendFinishCV[key].notify_all(); + lastRecvFinishCV[key].notify_all(); + } + } + + for (auto& p : EosL1Que) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : EosL2Que) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapAddrsQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapKeyQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapKeyForL3StorageQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapKeyQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapKeyForL3StorageQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapAddrsQue) { + p.second[channelId].DestroyQueue(); + } } + for (auto& t : EmbeddingLookUpAndSendThreadPool) { t.join(); } @@ -1461,21 +1491,9 @@ void HybridMgmt::JoinEmbeddingCacheThread() } } -void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) -{ - // Predict do not need to be blocked. - if (info.channelId == EVAL_CHANNEL_ID && alreadyTrainOnce) { - // Eval after train. - hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); - LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); - remainBatchOut = false; -} - bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1483,6 +1501,17 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (!isRunning) { return false; } + bool isEos = EosL2Que[info.name][info.channelId].WaitAndPop(); + if (!isRunning) { + return false; + } + if (isEos) { + LOG_DEBUG("EmbeddingReceiveDDR get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, + info.channelId); + // It cannot return here after send eos, otherwise it will block the next round of switching. + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); + } + TimeCost EmbeddingRecvTC = TimeCost(); swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR][info.channelId].WaitAndPop(); @@ -1509,17 +1538,17 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - LOG_DEBUG("table:{}, accumulate batchId:{}, dims[0]:{}, swapOutAddrs size:{}", info.name, info.batchId, dims[0], - swapOutAddrs.size()); + LOG_DEBUG(MGMT + "In swap thread, finish receive d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, dims[0]:{}, swapOutAddrs size:{}, EmbeddingRecvTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, dims[0], swapOutAddrs.size(), + EmbeddingRecvTC.ElapsedMS()); if (dims[0] != static_cast(swapOutAddrs.size())) { throw runtime_error("data dims[0] != swapOutKeys.size()"); } - - LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name, info.batchId, - info.threadIdx, EmbeddingRecvTC.ElapsedMS()); hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; @@ -1527,7 +1556,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1549,21 +1578,20 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr if (!swapOutAddrs.empty()) { sample = FloatPtrToLimitStr(swapOutAddrs.front(), info.extEmbeddingSize); // print first element } - LOG_DEBUG( - "table:{}, accumulate batchId:{}, thread:{}, receive d2hEmb, ext emb:{}, emb size:{}, emb samples:{}, " - "EmbeddingUpdateTC(ms):{}", - info.name.c_str(), info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), sample, - EmbeddingUpdateTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish update d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, ext emb:{}, emb size:{}, emb samples:{}, EmbeddingUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), + sample, EmbeddingUpdateTC.ElapsedMS()); } hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastUpdateFinishCV[nextKey].notify_all(); } bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] >= info.batchId) || mutexDestroy; @@ -1586,15 +1614,18 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d } hybridMgmtBlock->lastLookUpFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); + LOG_DEBUG(MGMT + "In swap thread, finish embedding lookup, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}", + info.name, info.channelId, info.batchId, info.threadIdx); return true; } void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastSendFinishLocker(lastSendFinishMutex[currentKey]); lastSendFinishCV[currentKey].wait(lastSendFinishLocker, [info, this] { return (hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1603,21 +1634,20 @@ void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEm // 区分通道发送 hdTransfer->Send(TransferChannel::H2D, h2dEmb, info.channelId, info.name, info.batchId); hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); - LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name, info.batchId, - info.threadIdx, SendTC.ElapsedMS()); - // 对于end of sequence场景,key - // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]++; - LOG_DEBUG("h2dNextBatchId, table:{}, channelId:{}, next batchId:{}", info.name, info.channelId, - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); + LOG_DEBUG(MGMT + "In swap thread, finish send h2d embedding, table:{}, channelId:{}, batchId:{}, accumulate " + "batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", + info.name, info.channelId, hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId], info.batchId, + info.threadIdx, SendTC.ElapsedMS()); + hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId]++; } void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo, int channelId) { auto fn = [index, embInfo, channelId, this]() { + LOG_DEBUG(MGMT + "Create LookUpAndSendThread, table:{}, index:{}, channel:{}", embInfo.name, index, channelId); while (true) { lookUpAndSendBatchIdMtx[channelId].lock(); if (hybridMgmtBlock->lookUpAndSendTableBatchId[embInfo.name][channelId] % EMBEDDING_THREAD_NUM == index) { @@ -1633,6 +1663,9 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em lookUpAndSendBatchIdMtx[channelId].unlock(); } if (!isRunning) { + LOG_DEBUG(MGMT + "Destroy LookUpAndSendThread, table:{}, index:{}, channel:{}, batchId:{}", + embInfo.name, index, channelId, + hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1643,6 +1676,8 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo, int channelId) { auto fn = [index, embInfo, channelId, this]() { + LOG_DEBUG(MGMT + "Create ReceiveAndUpdateThread, table:{}, index:{}, channel:{}", embInfo.name, index, + channelId); while (true) { receiveAndUpdateBatchIdMtx[channelId].lock(); if (hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId] % EMBEDDING_THREAD_NUM == @@ -1659,6 +1694,9 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& receiveAndUpdateBatchIdMtx[channelId].unlock(); } if (!isRunning) { + LOG_DEBUG(MGMT + "Destroy ReceiveAndUpdateThread, table:{}, index:{}, channel:{}, batchId:{}", + embInfo.name, index, channelId, + hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1669,7 +1707,7 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1677,6 +1715,17 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (!isRunning) { return false; } + bool isEos = EosL1Que[info.name][info.channelId].WaitAndPop(); + if (!isRunning) { + return false; + } + if (isEos) { + LOG_DEBUG("EmbeddingReceiveL3Storage get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, + info.channelId); + // It cannot return here after send eos, otherwise it will block the next round of switching. + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); + } + // DDR swap out key need to be removed LookUpAndRemoveAddrs(info); @@ -1704,15 +1753,15 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, size_t dimNum = acltdtGetDimNumFromItem(aclData); int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, recv d2h, dims[0]:{}, swapOutAddrs.size:{}", info.name, - info.batchId, info.channelId, dims[0], swapOutAddrs.size()); dims0 = dims[0]; - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish receive d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, dims[0]:{}, swapOutAddrs size:{}, EmbeddingRecvTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, dims[0], swapOutAddrs.size(), + EmbeddingRecvTC.ElapsedMS()); hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1720,7 +1769,7 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1742,9 +1791,10 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); } } - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, EmbeddingUpdateTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish update d2h DDR embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, EmbeddingUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); // L3Storage更新 TimeCost L3StorageUpdateTC = TimeCost(); std::vector swapOutL3StorageAddrOffs = @@ -1760,17 +1810,20 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr } cacheManager->UpdateL3StorageEmb(info.name, embPtr, extEmbeddingSize, swapOutL3StorageKeys, swapOutL3StorageAddrOffs); - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, L3StorageUpdateTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); + + LOG_DEBUG( + MGMT + "In swap thread, finish update d2h L3Storage embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, L3StorageUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastUpdateFinishCV[nextKey].notify_all(); } bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] >= info.batchId) || mutexDestroy; @@ -1815,9 +1868,11 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorlastLookUpFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); return true; @@ -1825,7 +1880,7 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastSendFinishLocker(lastSendFinishMutex[currentKey]); lastSendFinishCV[currentKey].wait(lastSendFinishLocker, [info, this] { return (hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1834,16 +1889,14 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& // 区分通道发送 hdTransfer->Send(TransferChannel::H2D, h2dEmb, info.channelId, info.name, info.batchId); hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); - LOG_DEBUG("table:{}, channelId:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name.c_str(), - info.channelId, info.batchId, info.threadIdx, SendTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish send h2d embedding, table:{}, channelId:{}, batchId:{}, accumulate " + "batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", + info.name, info.channelId, hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId], info.batchId, + info.threadIdx, SendTC.ElapsedMS()); - // 对于end of sequence场景,key - // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]++; - LOG_DEBUG("h2dNextBatchId, table:{}, channelId:{}, next batchId:{}", info.name, info.channelId, - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); + hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId]++; } void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, @@ -1887,6 +1940,9 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vectorL3Storage HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR][info.channelId].Pushv(hbmSwapInfo.swapOutL3StorageKeys); HBMSwapKeyForL3StorageQue[info.name + ADDR_STR][info.channelId].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); + + // normal status + EosL1Que[info.name][info.channelId].Pushv(false); } bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb) @@ -1909,20 +1965,20 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); } } - LOG_DEBUG("table:{}, channel:{}, thread:{}, accumulate batchId:{}, send h2dEmb, emb size:{}, emb samples:{}, " - "embeddingLookupTC(ms):{}", - info.name.c_str(), info.channelId, info.threadIdx, info.batchId, swapInAddrs.size(), - FloatPtrToLimitStr(h2dEmbAddr, swapInAddrs.size() * info.extEmbeddingSize), - embeddingLookupTC.ElapsedMS()); + LOG_DEBUG( + "[BuildH2DEmbedding] table:{}, channel:{}, thread:{}, accumulate batchId:{}, emb size:{}, emb samples:{}, " + "embeddingLookupTC(ms):{}", + info.name.c_str(), info.channelId, info.threadIdx, info.batchId, swapInAddrs.size(), + FloatPtrToLimitStr(h2dEmbAddr, swapInAddrs.size() * info.extEmbeddingSize), embeddingLookupTC.ElapsedMS()); return true; } -vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut) +vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos) { - bool isEos = false; auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos); + // DDR eos send in swap pipeline. if (isEos) { - HandleEosCase(info, remainBatchOut); + remainBatchOut = false; return uniqueKeys; } if (uniqueKeys.empty()) { @@ -2079,13 +2135,16 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, { auto& swapInKeys = swapInKoPair.first; auto& swapOutKeys = swapOutKoPair.first; - - LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, - info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR][info.channelId].Pushv(swapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR][info.channelId].Pushv(swapInKeys); CheckLookupAddrSuccessDDR(); + + EosL1Que[info.name][info.channelId].Pushv(false); + LOG_DEBUG("Enqueue on HBMSwapKeyQue and EosL1Que, table:{}, batchId:{}, channelId:{}, swapInSize:{}, " + "swapOutSize:{}, EosL1Que.size: {}", + info.name, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size(), + EosL1Que[info.name][info.channelId].Size()); } void HybridMgmt::BackUpTrainStatus() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 9cd5c9a92df33abf102c9190274b7e559806594c..0f7f099984822cc14bf782e5c0bc5aa57705f59b 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -17,13 +17,13 @@ See the License for the specific language governing permissions and #define MX_REC_EMB_MGMT_H #include -#include -#include -#include -#include #include #include +#include +#include #include +#include +#include #include "absl/container/flat_hash_map.h" #include "emb_table/embedding_table.h" @@ -37,8 +37,8 @@ See the License for the specific language governing permissions and #include "utils/config.h" #include "utils/singleton.h" #include "utils/task_queue.h" -#include "utils/time_cost.h" #include "utils/thread_pool.h" +#include "utils/time_cost.h" namespace MxRec { using namespace std; @@ -142,7 +142,8 @@ public: void ReceiveKeyThread(const EmbInfo& embInfo); - GTEST_PRIVATE : bool mutexDestroy{false}; +GTEST_PRIVATE : + bool mutexDestroy{false}; // LookupAndSend & ReceiveAndUpdate Condition_Variable_Wait stop. std::mutex lookUpAndSendBatchIdMtx[MAX_CHANNEL_NUM]; // train and eval std::mutex receiveAndUpdateBatchIdMtx[MAX_CHANNEL_NUM]; @@ -169,6 +170,9 @@ public: std::map>[MAX_CHANNEL_NUM]> HBMSwapAddrsQue; std::map>[MAX_CHANNEL_NUM]> DDRSwapAddrsQue; + std::map[MAX_CHANNEL_NUM]> EosL1Que; + std::map[MAX_CHANNEL_NUM]> EosL2Que; + std::mutex evictMut; std::map> trainKeysSet; @@ -181,9 +185,6 @@ public: std::map>> trainTestSwitchInfoStore{}; std::atomic lookupAddrSuccess{true}; - std::mutex saveMutex; - std::condition_variable cvCheckSave; - unique_ptr threadPool; void SetFeatureTypeForLoad(vector& loadFeatures); @@ -234,8 +235,8 @@ private: bool isRunning; bool isLoad{false}; bool isInitialized{false}; - bool alreadyTrainOnce = false; // 用于判断是否为predict模式 - bool isBackUpTrainStatus = false; // whether the train state has been backed up + bool alreadyTrainOnce = false; // 用于判断是否为predict模式 + bool isBackUpTrainStatus = false; // whether the train state has been backed up bool isIncrementalCkpt; map> deltaMap; absl::flat_hash_map keyBatchIdMap; @@ -259,8 +260,6 @@ private: void JoinEmbeddingCacheThread(); - void HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); @@ -286,7 +285,7 @@ private: bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); - vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); + vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos); vector GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 8943044a158d9ddacd93468b5a4f89738ba97b2d..3069e073f9db8e638788a0d415c16404b5e71eb3 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -41,7 +41,6 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "blocking by save saveInterval {} pythonBatchId {} hybridBatchId {}", saveInterval, pythonBatchId[channelId], hybridBatchId[channelId]); isBlock[TRAIN_CHANNEL_ID] = true; - finishSave = false; } if (stepsInterval[channelId] == -1) { return; @@ -175,7 +174,7 @@ void HybridMgmtBlock::ResetAll(int channelId) b.second[channelId] = 0; } // L3 data pipeline, swap - for (auto& b : h2dNextBatchId) { + for (auto& b : h2dSendBatchId) { b.second[channelId] = 0; } @@ -187,36 +186,6 @@ void HybridMgmtBlock::ResetAll(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "after reset block status," " channelId:{}, pythonBatchId:{}, readEmbedBatchId:{}, hybridBatchId:{}", channelId, pythonBatchId[channelId], readEmbedBatchId[channelId], hybridBatchId[channelId]); - - LOG_DEBUG("Start to reset isNeedSendEos"); - Singleton::GetInstance()->SetEos(0, channelId); -} - -/// 检查当前的步数是否可以进行save -/// \return 0 is legal, 1 需要回退一步, -1 表示错误 -int HybridMgmtBlock::CheckSaveEmbMapValid() -{ - // 检查数据通道此时的HashMap是否被提前处理了 - if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) { - LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is checking the step and checking that the parameters are normal. " - "The number of steps in the previous round is " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); - return 0; - } else if (pythonBatchId[lastRunChannelId] + 1 == hybridBatchId[lastRunChannelId]) { - // 在通道切换时,上一个通道处理的数据超出了python侧的调用 - LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is checking the step, and the parameters have been processed one step " - "in advance. The number of steps in the previous round was " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); - - return 1; - } else { - // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 - LOG_DEBUG(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId {} hybridBatchId {}", lastRunChannelId, - hybridBatchId[lastRunChannelId]); - return -1; - } } bool HybridMgmtBlock::GetBlockStatus(int channelId) @@ -256,22 +225,4 @@ void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep) HybridMgmtBlock::~HybridMgmtBlock() { Destroy(); -} - -void HybridMgmtBlock::Wake(int channelId) -{ - isBlock[channelId] = false; -} - -bool HybridMgmtBlock::IsNeedWaitSave() -{ - if (saveInterval != 0 && saveInterval != -1 && hybridBatchId[TRAIN_CHANNEL_ID] % saveInterval == 0 && !finishSave) { - return true; - } - return false; -} - -void HybridMgmtBlock::FinishSave() -{ - finishSave = true; } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index b80a284872abc12eab399b3ffa4ade08914bc154..65fb29f766f321651d302043a4bdb4639fc7824c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -45,8 +45,8 @@ namespace MxRec { int maxTrainStep = 0; int stepsInterval[MAX_CHANNEL_NUM] = {0, 0}; // 通道i运行多少步后切换为通道j - std::map lookUpSwapAddrsPushId; // L2 pipeline, key->addr - std::map h2dNextBatchId; // L3 pipeline, use for eos + std::map lookUpSwapAddrsPushId; // L2 pipeline, key->addr + std::map h2dSendBatchId; // L3 pipeline, swap thread std::map lookUpAndSendTableBatchId; std::map receiveAndUpdateTableBatchId; std::map lastUpdateFinishStep; @@ -76,8 +76,6 @@ namespace MxRec { void ResetAll(int channelId); - int CheckSaveEmbMapValid(); - bool GetBlockStatus(int channelId); void SetBlockStatus(int channelId, bool block); @@ -90,19 +88,12 @@ namespace MxRec { void Destroy(); - void Wake(int channelId); - - bool IsNeedWaitSave(); - - void FinishSave(); - private: // 控制通道阻塞的变量 bool isBlock[MAX_CHANNEL_NUM] = {true, true}; // 控制训练了多少步进行保存的步数 int saveInterval = 0; RankInfo rankInfo; - bool finishSave = true; }; class HybridMgmtBlockingException : public std::exception { diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 6c82c84620d91889ab3f317ff69104e9ef8b399e..e1ef1018af7f1759879d25dadfdf9ab38ad25fb7 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -96,8 +96,7 @@ namespace MxRec { static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 static absl::flat_hash_map m_embStatus; // 用于“准入&淘汰”功能解耦 - - GTEST_PRIVATE : + GTEST_PRIVATE : // 解析m_table2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 888e35d92e2ab6a859b11ddd9519067aa2aa798f..a391d42c536c90483ce610230af2018ebcae4e19 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -12,24 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "key_process.h" #include #include -#include #include +#include +#include "emb_table/embedding_mgmt.h" +#include "hd_transfer/hd_transfer.h" +#include "ock_ctr_common/include/error_code.h" #include "utils/common.h" +#include "utils/config.h" #include "utils/logger.h" #include "utils/safe_queue.h" #include "utils/singleton.h" #include "utils/time_cost.h" -#include "utils/config.h" -#include "emb_table/embedding_mgmt.h" -#include "hd_transfer/hd_transfer.h" -#include "ock_ctr_common/include/error_code.h" using namespace std; using namespace chrono; @@ -43,8 +42,7 @@ void KeyProcess::SetupHotEmbUpdateStep() } bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues, - int seed, bool isIncrementalCkpt) + const vector& thresholdValues, int seed, bool isIncrementalCkpt) { readySendEosCnt[TRAIN_CHANNEL_ID].store(0); readySendEosCnt[EVAL_CHANNEL_ID].store(0); @@ -261,7 +259,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; - batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue + batch = GetBatchData(channel, threadId); // Get batch data from SingletonQueue. LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; @@ -315,10 +313,11 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) if (!isKeyProcessTaskSuccess) { break; } - LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," - " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}", + LOG_INFO(KEY_PROCESS + "getAndProcessTC(ms):{}, key process cost:{}," + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}, isEos:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, - batch->channel, threadId, batch->batchId); + batch->channel, threadId, batch->batchId, batch->isEos); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); @@ -329,9 +328,9 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} channelId:{}, threadId:{}", rankInfo.rankId, channel, threadId); } -void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, - vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec) +void KeyProcess::HashSplitHelper(const unique_ptr& batch, vector& splitKeys, vector& restore, + vector& hotPos, vector>& keyCount, + vector& keyCountVec) { TimeCost uniqueTc; // Deduplicate the Key, and model parallel requires bucketing, data parallel does not. @@ -464,6 +463,10 @@ KeysT KeyProcess::BroadcastGlobalDpIdUnique(const unique_ptr& batch, bool KeyProcess::KeyProcessTaskHelperForDp(unique_ptr& batch, int channel, int threadId) { + if (batch->isEos) { + HandleEos(batch, channel, threadId); + return true; + } vector splitKeys; vector restore; vector hotPos; @@ -542,6 +545,10 @@ bool KeyProcess::KeyProcessTaskHelperForDp(unique_ptr& batch, int cha bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { + if (batch->isEos) { + HandleEos(batch, channel, threadId); + return true; + } vector splitKeys; vector restore; vector hotPos; @@ -572,10 +579,14 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) - if (!rankInfo.isDDR) { EmbeddingMgmt::Instance()->Key2Offset(batch->name, lookupKeys, channel); } + if (!rankInfo.isDDR) { + EmbeddingMgmt::Instance()->Key2Offset(batch->name, lookupKeys, channel); + } // Static all2all,need send count - if (!rankInfo.useStatic) { SendA2A(scAll, batch->name, batch->channel, batch->batchId); } + if (!rankInfo.useStatic) { + SendA2A(scAll, batch->name, batch->channel, batch->batchId); + } TimeCost pushResultTC; auto tensors = make_unique>(); @@ -590,6 +601,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, hotPos.resize(hotEmbTotCount[batch->name], 0); tensors->push_back(Vec2TensorI32(hotPos)); + // Tensors contains restore、hotPos、restoreSec&unique、idOffset in order when HBM mode, and is pushed in infolist. if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); @@ -613,6 +625,27 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, return true; } +void KeyProcess::HandleEos(unique_ptr& batch, int channel, int threadId) +{ + if (!rankInfo.isDDR) { // HBM + std::unique_lock lockGuard(mut); + infoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); + return; + } + // DDR + vector uniqueKeys; + std::unique_lock lockGuard(mut); + uniqueKeysList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); +} + KeysT KeyProcess::FeatureAdmitForDp(KeysT& lookupKeys, KeysT& globalDpIdVec) { KeysT globalDpIdUniqueVec; @@ -732,7 +765,7 @@ void KeyProcess::PushResultHBM(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); lockGuard.unlock(); } @@ -741,8 +774,9 @@ void KeyProcess::PushResultDDR(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); - uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(uniqueKeys))); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + uniqueKeysList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); restoreVecSecList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(restoreVecSec))); lockGuard.unlock(); } @@ -751,8 +785,8 @@ void KeyProcess::PushKeyCountHBM(unique_ptr& batch, unique_ptr lockGuard(mut); keyCountStorage.push_front(move(tensors)); - keyCountInfoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, - keyCountStorage.begin())); + keyCountInfoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, keyCountStorage.begin())); lockGuard.unlock(); LOG_INFO("Push key count to list success."); } @@ -775,6 +809,10 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const while (true) { batch = batchQueue->TryPop(); if (batch != nullptr) { + if (batch->isEos) { + LOG_INFO("GetBatchData eos, table name:{}, batchId:{}, channelId:{} threadId:{}", batch->name, + batch->batchId, channel, commId); + } break; } this_thread::sleep_for(100us); @@ -987,7 +1025,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, ve batch->batchId); // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 - if (rankInfo.useStatic) { // maybe move after all2all + if (rankInfo.useStatic) { // maybe move after all2all ProcessKeysWithStatic(batch, splitKeys); } @@ -997,7 +1035,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, ve LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, batchName:{}, MPI_Allgatherv finish." " processSplitKeysTC(ms):{}", batch->channel, id, batch->batchId, batch->name, processSplitKeysTC.ElapsedMS()); - return { keyRecv, scAll, ss }; + return {keyRecv, scAll, ss}; } KeysT keySend; @@ -1177,8 +1215,8 @@ tuple, vector, vector>> KeyProcess::Hash return {splitKeys, restore, keyCount}; } -tuple, vector, vector, vector> KeyProcess::HotHashSplit(const -unique_ptr& batch) +tuple, vector, vector, vector> KeyProcess::HotHashSplit( + const unique_ptr& batch) { EASY_FUNCTION(profiler::colors::Gold) emb_key_t* batchData = batch->sample.data(); @@ -1345,8 +1383,7 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - auto retCode = MPI_Allgather(keyScLocal.data(), sendAndRecvCount, MPI_INT, - scAll.data(), sendAndRecvCount, MPI_INT, + auto retCode = MPI_Allgather(keyScLocal.data(), sendAndRecvCount, MPI_INT, scAll.data(), sendAndRecvCount, MPI_INT, comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode); @@ -1421,7 +1458,7 @@ T KeyProcess::GetInfo(info_list_t& list, const EmbBaseInfo& info) return move(t); } -template +template T KeyProcess::GetKeyCountVec(info_list_t& list, const EmbBaseInfo& info) { std::lock_guard lockGuard(mut); @@ -1462,14 +1499,16 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos) } try { auto infoVec = GetInfo(uniqueKeysList, info); - ret = get>(infoVec); - break; - } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetUniqueKeysEos(info, startTime); + isEos = get(infoVec); if (isEos) { + LOG_INFO(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", info.name, info.channelId, info.batchId); break; } + ret = get>(infoVec); + break; + } catch (EmptyList&) { + LOG_TRACE("getting unique info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", + info.name, info.channelId, info.batchId, hybridMgmtBlock->readEmbedBatchId[info.channelId]); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed table:{}, channel:{}, mgmt batchId:{}, wrong top", info.name, info.channelId, @@ -1480,42 +1519,6 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos) return ret; } -bool KeyProcess::IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::_V2::system_clock::time_point& startTime) -{ - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - auto endTime = std::chrono::system_clock::now(); - - // readEmbKey start with 0 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - std::chrono::duration elapsedTime = endTime - startTime; - if (info.batchId != 0 && elapsedTime.count() >= timeoutGetUniqueKeysEmpty) { - LOG_DEBUG("table:{}, channelId:{}, isNeedSendEos:{}, current batchId:{}, L1 pipeline readEmbKeyBatchId:{}, " - "L2 pipeline lookUpSwapAddrsPushId:{}, L3 pipeline h2dNextBatchId:{}", - info.name, info.channelId, isNeedSendEos[info.channelId], info.batchId, readEmbKeyBatchId, - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId], - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); - startTime = std::chrono::system_clock::now(); - } - // Check '>= readEmbedBatchIdAll' condition to avoid send eos before handle all batch data from readEmbKey Op. - if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId] == - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId] && - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId] >= - hybridMgmtBlock->readEmbedBatchId[info.channelId]) { - LOG_INFO("table:{}, channelId:{} current batchId:{}, GetUniqueKeys eos, L1 pipeline readEmbKeyBatchId:{}, " - "L2 pipeline hybridBatchId:{}, L3 pipeline lookUpSwapAddrsPushId:{}, L4 pipeline h2dNextBatchId:{}", - info.name, info.channelId, info.batchId, readEmbKeyBatchId, - hybridMgmtBlock->hybridBatchId[info.channelId], - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId], - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); - return true; - } - LOG_TRACE("getting uniqueKeys failed, table:{}, channel:{}, mgmt batchId:{}, readEmbKey batchId:{}, list is empty", - info.name, info.channelId, info.batchId, readEmbKeyBatchId); - return false; -} - std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) { TimeCost tc = TimeCost(); @@ -1539,18 +1542,8 @@ std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) auto ret = GetInfo(restoreVecSecList, info); return get>(ret); } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - // readEmbKey真实的次数是readEmbedBatchId减1 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId] == info.batchId) { - LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", info.channelId, - info.batchId); - throw runtime_error("GetRestoreVecSec eos, code should not reach here"); - } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", - info.name, info.channelId, info.batchId, readEmbKeyBatchId); + info.name, info.channelId, info.batchId, hybridMgmtBlock->readEmbedBatchId[info.channelId]); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); @@ -1597,8 +1590,8 @@ void KeyProcess::SendEos(const std::string& embName, int batchId, int channel) this_thread::sleep_for(1000ms); } readySendEosCnt[channel].store(0); - isNeedSendEos[channel] = false; - LOG_DEBUG("isNeedSendEos set to false, table:{}, channelId:{} batchId:{}", embName, channel, batchId); + + LOG_DEBUG("sendEos finish all, table:{}, channelId:{} batchId:{}", embName, channel, batchId); #endif } @@ -1645,19 +1638,19 @@ unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo& info, Proce try { auto infoVec = GetInfo(*list, info); + isEos = get(infoVec); + if (isEos) { + LOG_INFO(KEY_PROCESS "GetInfoVec eos! {}[{}]:{}", info.name, info.channelId, info.batchId); + break; + } auto it = get>>::iterator>(infoVec); ret = std::move(*it); std::unique_lock lockGuard(mut); storage.erase(it); break; } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); - if (isEos) { - break; - } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", - info.name, info.channelId, info.batchId, (hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1)); + info.name, info.channelId, info.batchId, hybridMgmtBlock->readEmbedBatchId[info.channelId]); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); @@ -1686,7 +1679,6 @@ unique_ptr> KeyProcess::GetKCInfoVec(const EmbBaseInfo& info) keyCountStorage.erase(it); break; } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); LOG_TRACE("getting info failed, list is empty."); this_thread::sleep_for(1ms); } catch (WrongListTop&) { @@ -1712,7 +1704,7 @@ void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int std::unique_lock lockGuard(mut); storage.push_front(move(tensors)); - all2AllList[embName][channel].push(make_tuple(batch, embName, storage.begin())); + all2AllList[embName][channel].push(make_tuple(batch, embName, false, storage.begin())); lockGuard.unlock(); } @@ -1829,35 +1821,21 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) } } -void KeyProcess::SetEos(int status, int channelId) -{ - unique_lock lockGuard(eosMutex); - LOG_INFO("isNeedSendEos status is changed, channel:{}, before status:{}, input status:{}", channelId, - isNeedSendEos[channelId], status); - isNeedSendEos[channelId] = (status == 1); -} - -bool KeyProcess::IsGetInfoVecEos(int batch, const string& embName, int channel) +void KeyProcess::EnqueueEosBatch(int64_t batchNum, int channelId) { - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - - // 避免eos在keyProcess还未处理完数据时插队到通道前面, readEmbKey真实的次数是readEmbedBatchId减1 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; - if (rankInfo.isDDR) { - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch && - hybridMgmtBlock->h2dNextBatchId[embName][channel] == batch) { - LOG_ERROR("channelId:{} batchId:{}, GetInfoVec eos, code should not reach here", channel, batch); - throw runtime_error("GetInfoVec eos, code should not reach here"); - } - } else { - LOG_TRACE("table:{}, channelId:{}, readEmbKeyBatchId:{}, batchId:{}, isNeedSendEos:{}", embName, channel, - readEmbKeyBatchId, batch, isNeedSendEos[channel]); - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { - LOG_INFO("table:{}, channelId:{} batchId:{}, GetInfoVec eos", embName, channel, batch); - return true; - } + LOG_INFO("Enqueue dataSet eos on batch queue, channel:{}, eos number:{}", channelId, batchNum); + int threadNum = GetThreadNumEnv(); + int batchQueueId = int(batchNum % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); + auto queue = SingletonQueue::GetInstances(batchQueueId); + for (auto& emb : embInfos) { + auto batchData = queue->GetOne(); // get dirty or empty data block + batchData->name = emb.first; + batchData->channel = channelId; + batchData->batchId = batchNum; + batchData->sample = {0, 0, 0, 0, 0, 0, 0, 0}; // fake data + batchData->isEos = true; + queue->Pushv(move(batchData)); } - return false; } void KeyProcess::SendEosTensor(const std::string& embName, int channel) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 6afa73aa41b67b573938675b14fad5d7904438f3..a10b5a9c14e7ea6eeec1d27753721c6620630112 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -68,9 +68,6 @@ private: const char* errorMessage; }; -constexpr int MPI_ABNORMAL_SEND_VALUE = 0; // MPI异常通信时发送0 -constexpr int MPI_NORMAL_SEND_VALUE = 1; // MPI正常通信时发送1 - class EmptyList : public std::exception {}; class WrongListTop : public std::exception {}; @@ -78,7 +75,7 @@ class WrongListTop : public std::exception {}; class KeyProcess { public: bool Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues = {}, int seed = 0, bool isIncrementalCkpt = false); + const vector& thresholdValues = {}, int seed = 0, bool isIncrementalCkpt = false); unique_ptr> GetInfoVec(const EmbBaseInfo& info, ProcessedInfo type, bool& isEos); @@ -192,25 +189,25 @@ public: uniqueKeys.resize(lookupKeys.size(), -1); } - void SetEos(int status, int channelId); + void EnqueueEosBatch(int64_t batchNum, int channelId); void SendEos(const string& embName, int batchId, int channel); + void HandleEos(unique_ptr& batch, int channel, int threadId); + bool isRunning{false}; - bool isIncrementalCheckpoint {false}; + bool isIncrementalCheckpoint{false}; std::mutex destroyMutex; - std::mutex eosMutex; + inline bool HasEmbName(const string& embName) { return embInfos.find(embName) != embInfos.end(); }; - GTEST_PRIVATE : - - int - Start(); +GTEST_PRIVATE : + int Start(); template T GetInfo(info_list_t& list, const EmbBaseInfo& info); @@ -242,12 +239,8 @@ public: int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; - // for end-of-sequence case - bool isNeedSendEos[2] = {false, false}; // 表示各表通道0、1的eos状态 atomic readySendEosCnt[2]; atomic finishSendEosCnt[2]; - const double timeoutGetUniqueKeys = 30.0; // 如果超时仍未获取到数据将触发EOS - const double timeoutGetUniqueKeysEmpty = 1.0; // 如果超时仍未获取到数据将打印信息 void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); @@ -275,7 +268,7 @@ public: void GetUniqueConfig(ock::ctr::UniqueConf& uniqueConf); void InitializeUnique(ock::ctr::UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, ock::ctr::UniquePtr& unique); + const unique_ptr& batch, ock::ctr::UniquePtr& unique); void ProcessBatchWithFastUnique(const unique_ptr& batch, ock::ctr::UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut); @@ -287,8 +280,8 @@ public: auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; - auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector, - vector>; + auto HotHashSplit(const unique_ptr& batch) + -> tuple, vector, vector, vector>; void PaddingAlltoallVC(vector& splitKeys) const; @@ -341,12 +334,11 @@ public: vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); - void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, - vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec); + void HashSplitHelper(const unique_ptr& batch, vector& splitKeys, vector& restore, + vector& hotPos, vector>& keyCount, vector& keyCountVec); - vector GetCountRecvForDp(const unique_ptr& batch, const int id, - vector& keyCount, vector scAll); + vector GetCountRecvForDp(const unique_ptr& batch, const int id, vector& keyCount, + vector scAll); KeysT FeatureAdmitForDp(KeysT& lookupKeys, KeysT& globalDpIdVec); @@ -362,10 +354,6 @@ public: string DumpSplitKeys(vector>& splitKeys) const; - bool IsGetInfoVecEos(int batch, const string& embName, int channel); - - bool IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::_V2::system_clock::time_point& startTime); - void SendEosTensor(const std::string& embName, int channel); }; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index f267652ebcb7eef942905dbce8bedaa6e3043814..3cadc5f58c063170d65a2e30931b30f44af53724 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -188,7 +188,7 @@ namespace MxRec { } // Make key for mutex and cv in swap pipeline, id: threadId, channelId: train/eval. - string MakeKeyName(int id, const string& tableName, int channelId) + string MakeSwapCVName(int id, const string& tableName, int channelId) { return to_string(id) + tableName + to_string(channelId); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 744944880079a4c0e3be0198653d88e8aa07dffb..609e3460c8d7528e4d48d9c541c5f244e285c2d6 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -122,9 +122,9 @@ using freq_num_t = int64_t; using EmbNameT = std::string; using KeysT = std::vector; using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector -using UinqueKeyT = std::tuple>; +using UinqueKeyT = std::tuple>; using RestoreVecSecT = std::tuple>; -using TensorInfoT = std::tuple>>::iterator>; +using TensorInfoT = std::tuple>>::iterator>; namespace HybridOption { const unsigned int USE_STATIC = 0x001; @@ -190,6 +190,7 @@ struct Batch { size_t batchSize; int batchId; int channel = 0; + bool isEos = false; time_t timestamp{-1}; }; @@ -520,14 +521,13 @@ struct KeySendInfo { }; struct KeyInfo { - int64_t lastUseTime; // 最后使用时间 - int64_t recentCount; // 最近使用次数 - bool isChanged; // 是否有变更 - int64_t batchID; // batch id - int64_t totalCount; // key总使用次数 - - KeyInfo(): lastUseTime(0), recentCount(0), isChanged(false), - batchID(0), totalCount(0) {} + int64_t lastUseTime; // 最后使用时间 + int64_t recentCount; // 最近使用次数 + bool isChanged; // 是否有变更 + int64_t batchID; // batch id + int64_t totalCount; // key总使用次数 + + KeyInfo() : lastUseTime(0), recentCount(0), isChanged(false), batchID(0), totalCount(0) {} }; using EmbMemT = absl::flat_hash_map; @@ -596,7 +596,6 @@ enum class CkptDataType { std::string CkptDataTypeName(CkptDataType type); - enum CTRLogLevel { // can't use enum class due to compatibility for AccCTR DEBUG = 0, INFO, @@ -629,7 +628,7 @@ bool CheckFilePermission(const string& filePath); int GetStepFromPath(const string& loadPath); -string MakeKeyName(int id, const string& tableName, int channelId); +string MakeSwapCVName(int id, const string& tableName, int channelId); } // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h index a42e514776789dc747868483a03fa6790228310a..6042e54c420391f5247f00b7f2653d8950056f6f 100644 --- a/src/core/utils/task_queue.h +++ b/src/core/utils/task_queue.h @@ -82,7 +82,7 @@ namespace MxRec { void DestroyQueue() { finished = true; - dataCond.notify_one(); + dataCond.notify_all(); } bool Empty() const diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index afc3fe3ad8696c9acea89c8e70ed3d270d4710f9..dc46aa20704721cb01e0d65845b11be78edb5e39 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -20,8 +20,11 @@ See the License for the specific language governing permissions and #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" + #if defined(TF_VERSION_TF2) #include "tensorflow/core/data/name_utils.h" #endif @@ -76,16 +79,18 @@ class EosDatasetOp::Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId, int32_t maxTrainSteps, - int32_t maxEvalSteps) + int32_t maxEvalSteps, + const DataTypeVector& outputTypes, + const std::vector& outputShapes) : DatasetBase(DatasetContext(ctx)), input_(input), channelId_(channelId), maxTrainSteps_(maxTrainSteps), maxEvalSteps_(maxEvalSteps), + outputTypes_(outputTypes), + outputShapes_(outputShapes), id_(g_datasetId[channelId]) { input_->Ref(); - auto os_input = input->output_shapes(); - output_shapes_ = os_input; MPI_Comm_group(MPI_COMM_WORLD, &g_group); MPI_Comm_create(MPI_COMM_WORLD, g_group, &g_comm[channelId]); @@ -118,14 +123,14 @@ public: this, prefix_para}); } - const DataTypeVector &output_dtypes() const override + const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); + return outputTypes_; } - const std::vector &output_shapes() const override + const std::vector& output_shapes() const override { - return output_shapes_; + return outputShapes_; } string DebugString() const override @@ -186,7 +191,6 @@ private: } #endif - Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override @@ -198,9 +202,10 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - auto keyProcess = Singleton::GetInstance(); - auto datasetId = dataset()->id_; auto channelId = dataset()->channelId_; + PrintOutput(out_tensors, channelId); + + auto keyProcess = Singleton::GetInstance(); if (channelId == 0 && iter_times_ == dataset()->maxTrainSteps_) { *end_of_sequence = true; } @@ -217,7 +222,8 @@ private: &req); CheckCommFinished(req, channelId); - keyProcess->SetEos(1, dataset()->channelId_); + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -232,7 +238,9 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - keyProcess->SetEos(1, dataset()->channelId_); + + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); @@ -276,6 +284,29 @@ private: return Status::OK(); } + void PrintOutput(std::vector *out_tensors, int channelId) + { + // Out size equals to zero when batch eos. + int outSize = out_tensors->size(); + if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { + for (const auto& t : *out_tensors) { + DataType tensor_type = t.dtype(); + TensorShape tensor_shape = t.shape(); + LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, " + "tensor_type: {}, tensor_shape: {}", + channelId, + iter_times_, + outSize, + tensor_type, + tensor_shape.DebugString()); + } + } + if (outSize <= 0) { + LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}", channelId, + iter_times_, outSize); + } + } + private: static constexpr int GET_NEXT_CONTINUE = 1; static constexpr int GET_NEXT_TERMINATE = 0; @@ -293,11 +324,15 @@ private: int32_t channelId_; int32_t maxTrainSteps_; int32_t maxEvalSteps_; - std::vector output_shapes_; + const DataTypeVector outputTypes_; + std::vector outputShapes_; int id_; }; -EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) {} +EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &outputTypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &outputShapes_)); +} void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { @@ -307,7 +342,8 @@ void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, Dataset OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxTrainSteps, &maxTrainSteps)); int32_t maxEvalSteps; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxEvalSteps, &maxEvalSteps)); - *output = new Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps); + *output = new (std::nothrow) Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps, outputTypes_, outputShapes_); + OP_REQUIRES(ctx, *output != nullptr, errors::InvalidArgument("EosDatasetOp: new dataset failed")); } REGISTER_OP("EosDataset") diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h index bf30c6b9b79e3aa8deaddaadb758e17b2b181685..117f2794743a6badd48dec9c102d8ba255686cba 100644 --- a/src/dataset_tf/eos_dataset_op.h +++ b/src/dataset_tf/eos_dataset_op.h @@ -44,6 +44,8 @@ namespace data { private: class Dataset; + DataTypeVector outputTypes_; + std::vector outputShapes_; }; // class EosDatasetOp } // namespace data } // namespace tensorflow diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 98fca9615339c4ed3dad1cad1ab37fb9c7153d49..123c7e1eae4f3dce16717a221a068243fbc35baa 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -189,6 +189,12 @@ namespace MxRec { Singleton::GetInstance()->ClearTransChannel(channelId); threadNum = GetThreadNumEnv(); + if (threadNum <= 0) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ...")); + return; + } + auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -278,6 +284,7 @@ namespace MxRec { batchData->name = embNames.at(i); size_t len = splits(i); batchData->channel = channelId; + batchData->isEos = false; batchData->batchId = ids[0]; batchData->sample.resize(len); if (isTimestamp) { @@ -381,6 +388,11 @@ namespace MxRec { Singleton::GetInstance()->ClearTransChannel(channelId); threadNum = GetThreadNumEnv(); + if (threadNum <= 0) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ...")); + return; + } auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -466,6 +478,7 @@ namespace MxRec { batchData->name = embNames.at(i); size_t len = splits.at(i); batchData->channel = channelId; + batchData->isEos = false; batchData->batchId = batchId; batchData->sample.resize(len); if (isTimestamp) { diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index c51d4be238bf19566df5954368d9ee43e761ac51..9fd2db80f8fe19c3f6c09d68e6d9545c2f51df66 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -94,31 +94,6 @@ TEST_F(HybridMgmtBlockTest, ResetAll) ASSERT_EQ(hybridMgmtBlock->hybridBatchId[0], 0); } -TEST_F(HybridMgmtBlockTest, CheckSaveEmbMapValid) -{ - hybridMgmtBlock = std::make_unique(); - hybridMgmtBlock->SetStepInterval(1, 1); - hybridMgmtBlock->lastRunChannelId = 0; - - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 0; - hybridMgmtBlock->CheckSaveEmbMapValid(); - int status0 = hybridMgmtBlock->CheckSaveEmbMapValid(); - - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 1; - hybridMgmtBlock->CheckSaveEmbMapValid(); - int status1 = hybridMgmtBlock->CheckSaveEmbMapValid(); - - int step2 = 2; - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = step2; - int status2 = hybridMgmtBlock->CheckSaveEmbMapValid(); - ASSERT_EQ(status0, 0); - ASSERT_EQ(status1, 1); - ASSERT_EQ(status2, -1); -} - TEST_F(HybridMgmtBlockTest, CountPythonStep) { hybridMgmtBlock = std::make_unique();