From 905fb9d2a37f1f017bfc180a25bd072ff57bd502 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 11:57:38 +0800 Subject: [PATCH 01/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E5=81=B6=E5=8F=91=E6=8A=A5=E9=94=99=E7=A9=BA=E9=97=B4?= =?UTF-8?q?=E4=B8=8D=E8=B6=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_tf/eos_dataset_op.cc | 135 ++++++++++++++++++++++++++++--- src/dataset_tf/eos_dataset_op.h | 2 + 2 files changed, 127 insertions(+), 10 deletions(-) diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index afc3fe3a..eb28bbe2 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -20,8 +20,11 @@ See the License for the specific language governing permissions and #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" + #if defined(TF_VERSION_TF2) #include "tensorflow/core/data/name_utils.h" #endif @@ -76,16 +79,20 @@ class EosDatasetOp::Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId, int32_t maxTrainSteps, - int32_t maxEvalSteps) + int32_t maxEvalSteps, + const DataTypeVector& outputTypes, + const std::vector& outputShapes) : DatasetBase(DatasetContext(ctx)), input_(input), channelId_(channelId), maxTrainSteps_(maxTrainSteps), maxEvalSteps_(maxEvalSteps), + outputTypes_(outputTypes), + outputShapes_(outputShapes), id_(g_datasetId[channelId]) { input_->Ref(); - auto os_input = input->output_shapes(); - output_shapes_ = os_input; +// auto os_input = input->output_shapes(); +// output_shapes_ = os_input; MPI_Comm_group(MPI_COMM_WORLD, &g_group); MPI_Comm_create(MPI_COMM_WORLD, g_group, &g_comm[channelId]); @@ -118,14 +125,14 @@ public: this, prefix_para}); } - const DataTypeVector &output_dtypes() const override + const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); + return outputTypes_; } - const std::vector &output_shapes() const override + const std::vector& output_shapes() const override { - return output_shapes_; + return outputShapes_; } string DebugString() const override @@ -186,6 +193,86 @@ private: } #endif + int64_t GetTensorElementNum(size_t index) { + PartialTensorShape tensor_shape = dataset()->output_shapes()[index]; + int64_t element_number = 1LL; + for (int32_t i = 0; i < tensor_shape.dims(); i++) { + element_number *= tensor_shape.dim_size(i); + } + return element_number; + } + + bool IsUnknowShape(const PartialTensorShape& output_shapes) const { + if (output_shapes.unknown_rank()) { + return true; + } + for (int32_t i = 0; i < output_shapes.dims(); i++) { + if (output_shapes.dim_size(i) == -1) { + return true; + } + } + return false; + } + Tensor CreateTensorByShape(const PartialTensorShape& output_shapes, const DataType& tensor_data_type) { + TensorShape tf_shape; + for (int32_t i = 0; i < output_shapes.dims(); i++) { + tf_shape.AddDim(output_shapes.dim_size(i)); + } + LOG_INFO("[LQK] CreateTensorByShape, tensor shape: {}", tf_shape.DebugString()); + return Tensor(tensor_data_type, tf_shape); + } + + std::vector CreateOutputVecTensor() + { + size_t output_shape_size = dataset()->output_shapes().size(); + size_t output_type_size = dataset()->output_dtypes().size(); + LOG_INFO("[LQK] output_shape_size: {}, output_type_size: {}", output_shape_size, output_type_size); + if (output_shape_size != output_type_size) { + LOG_ERROR("[LQK] output_shape_size: {} is not equal to output_type_size: {}", output_shape_size, + output_type_size); + return {}; + } + std::vector result; + for (size_t i = 0UL; i < output_shape_size; i++) { + DataType tensor_data_type = dataset()->output_dtypes().at(i); + if (tensor_data_type == DT_STRING) { + LOG_ERROR("[LQK] current tensor type is DT_STRING"); + return{}; + } + LOG_INFO("[LQK] current tensor type is: {}", tensor_data_type); + LOG_INFO("[LQK] current tensor dim is: {}, dim[0].dim_Size is {}", dataset()->output_shapes()[i].dims(), + dataset()->output_shapes()[i].dim_size(0)); + if (dataset()->output_shapes()[i].dims() == 2) { + LOG_INFO("[LQK] current tensor dim[1].dim_Size is {}", dataset()->output_shapes()[i].dim_size(1)); + } + if (IsUnknowShape(dataset()->output_shapes()[i])) { + LOG_INFO("[LQK] output shape is unknown shape"); + Tensor tensor(tensor_data_type, TensorShape({3, 1})); + if (dataset()->output_shapes()[i].dims() == -1) { + tensor = Tensor(tensor_data_type, TensorShape({1})); + } + + // 获取指针 + auto tensor_data = const_cast(tensor.tensor_data().data()); + auto tensor_size = tensor.tensor_data().size(); + LOG_INFO("[LQK] IsUnknowShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", + tensor.DebugString(), tensor_size, tensor.NumElements()); + + memset_s(tensor_data, tensor_size, 0, tensor_size); + + LOG_INFO("[LQK] IsUnknowShape, after memset tensor: {}", tensor.DebugString()); + + result.push_back(tensor); + continue; + } + Tensor a = CreateTensorByShape(dataset()->output_shapes()[i], tensor_data_type); + LOG_INFO("[LQK] success create know shape tensor: {}", a.DebugString()); + + result.push_back(a); + } + return result; + } + Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, @@ -198,6 +285,28 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + int outSize = out_tensors->size(); + if (outSize > 0) { + for (const auto& t : *out_tensors) { + DataType tensor_type = t.dtype(); + TensorShape tensor_shape = t.shape(); + LOG_INFO("[LQK] GetNext eos, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + "tensor_shape: {}", + dataset()->channelId_, + iter_times_, + outSize, + tensor_type, + tensor_shape.DebugString()); + } + } + if (!is_second_eos && *end_of_sequence) { + is_second_eos = true; + *end_of_sequence = false; + *out_tensors = CreateOutputVecTensor(); + } else if (is_second_eos) { + *end_of_sequence = true; + } + auto keyProcess = Singleton::GetInstance(); auto datasetId = dataset()->id_; auto channelId = dataset()->channelId_; @@ -287,17 +396,22 @@ private: GUARDED_BY(mu_); std::unique_ptr input_impl_ GUARDED_BY(mu_); + bool is_second_eos = false; }; const DatasetBase *input_; int32_t channelId_; int32_t maxTrainSteps_; int32_t maxEvalSteps_; - std::vector output_shapes_; + const DataTypeVector outputTypes_; + std::vector outputShapes_; int id_; }; -EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) {} +EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &outputTypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &outputShapes_)); +} void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { @@ -307,7 +421,8 @@ void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, Dataset OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxTrainSteps, &maxTrainSteps)); int32_t maxEvalSteps; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxEvalSteps, &maxEvalSteps)); - *output = new Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps); + *output = new (std::nothrow) Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps, outputTypes_, outputShapes_); + OP_REQUIRES(ctx, *output != nullptr, errors::InvalidArgument("EosDatasetOp: new dataset failed")); } REGISTER_OP("EosDataset") diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h index bf30c6b9..117f2794 100644 --- a/src/dataset_tf/eos_dataset_op.h +++ b/src/dataset_tf/eos_dataset_op.h @@ -44,6 +44,8 @@ namespace data { private: class Dataset; + DataTypeVector outputTypes_; + std::vector outputShapes_; }; // class EosDatasetOp } // namespace data } // namespace tensorflow -- Gitee From 6f19ec3b2917168c5874bc03903f82803de89fc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 12:10:36 +0800 Subject: [PATCH 02/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- src/core/key_process/key_process.cpp | 3 +++ src/core/utils/common.h | 18 ++++++++++++++++++ src/dataset_tf/eos_dataset_op.cc | 1 + 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 4801f95b..2d89b117 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -658,7 +658,7 @@ void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut SendUniqKeysAndRestoreVecHBM(info, infoVecs, isGrad); } - // 发送恢复向量 + // 发送恢复向量和hotPos TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", info.name, diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 1cb9f992..3a1f5aec 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -546,6 +546,9 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const while (true) { batch = batchQueue->TryPop(); if (batch != nullptr) { + if (batch->CheckAndSetEos()) { + LOG_INFO("GetBatchData eos, channelId:{} threadId:{}", channel, commId); + } break; } this_thread::sleep_for(100us); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 8c7528f4..cf5f46e7 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -183,11 +183,29 @@ struct Batch { return s; } + bool CheckAndSetEos() + { + int num = sample.size(); + if (num < 3 || num % 3 != 0) { + return false; + } + + for (int i = 0; i < num; i++) { + if (sample[i] != 0) + { + return false; + } + } + isEos = true; + return true; + } + std::vector sample; std::string name; size_t batchSize; int batchId; int channel = 0; + bool isEos = false; time_t timestamp{-1}; }; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index eb28bbe2..dae71e57 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -213,6 +213,7 @@ private: } return false; } + Tensor CreateTensorByShape(const PartialTensorShape& output_shapes, const DataType& tensor_data_type) { TensorShape tf_shape; for (int32_t i = 0; i < output_shapes.dims(); i++) { -- Gitee From 019d36426e3339faced7b7faec4836d3c0d9681f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 15:17:36 +0800 Subject: [PATCH 03/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 93 +++++++++++++++++++--------- src/core/utils/common.h | 4 +- src/dataset_tf/eos_dataset_op.cc | 4 +- 3 files changed, 69 insertions(+), 32 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3a1f5aec..6e8e0418 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -305,9 +305,9 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) break; } LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," - " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}", + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}, isEos:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name, batch->channel, threadId, batch->batchId); + batch->name, batch->channel, threadId, batch->batchId, batch->isEos); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); @@ -397,6 +397,26 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { + if (batch->isEos) { + if (!rankInfo.isDDR) { // HBM +// auto tensors = make_unique>(); + std::unique_lock lockGuard(mut); +// storage.push_front(move(tensors)); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", + batch->name, batch->batchId, batch->channel, threadId); + return true; + } + // DDR + vector uniqueKeys; + std::unique_lock lockGuard(mut); + uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", + batch->name, batch->batchId, batch->channel, threadId); + return true; + } vector splitKeys; vector restore; vector hotPos; @@ -440,11 +460,12 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, hotPos.resize(hotEmbTotCount[batch->name], 0); tensors->push_back(Vec2TensorI32(hotPos)); + // HBM把restore、unique、idoffset做成了Tensor,放到infolist里面了(hbm第一个get的是tensors) if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); PushResultHBM(batch, move(tensors)); - } else { + } else { // DDR 则保留原有的数据结构,idoffset在mgmt侧组装(ddr第一个get的是unique) std::vector lookupKeysUint(lookupKeys.begin(), lookupKeys.end()); vector uniqueKeys; vector restoreVecSec; @@ -513,7 +534,7 @@ void KeyProcess::PushResultHBM(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); lockGuard.unlock(); } @@ -522,8 +543,8 @@ void KeyProcess::PushResultDDR(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); - uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(uniqueKeys))); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); restoreVecSecList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(restoreVecSec))); lockGuard.unlock(); } @@ -1195,15 +1216,24 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, } try { auto infoVec = GetInfo(uniqueKeysList, info); - ret = get>(infoVec); - break; - } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetUniqueKeysEos(info, startTime, lookUpSwapInAddrsPushId); + isEos = get(infoVec); if (isEos) { + LOG_WARN(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", + info.name, info.channelId, info.batchId); break; } - this_thread::sleep_for(1ms); + ret = get>(infoVec); + break; + } catch (EmptyList&) { +// unique_lock lockEosGuard(eosMutex); +// isEos = IsGetUniqueKeysEos(info, startTime, lookUpSwapInAddrsPushId); +// if (isEos) { +// break; +// } + LOG_DEBUG(KEY_PROCESS "GetUniqueKeys EmptyList! {}[{}]:{}", + info.name, info.channelId, info.batchId); + + this_thread::sleep_for(10ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed table:{}, channel:{}, mgmt batchId:{}, wrong top", info.name, info.channelId, info.channelId); @@ -1279,16 +1309,16 @@ std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) auto ret = GetInfo(restoreVecSecList, info); return get>(ret); } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - // readEmbKey真实的次数是readEmbedBatchId减1 +// unique_lock lockEosGuard(eosMutex); +// // readEmbKey真实的次数是readEmbedBatchId减1 int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && - hybridMgmtBlock->h2dNextBatchId[info.name] == info.batchId) { - LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", - info.channelId, info.batchId); - throw runtime_error("GetRestoreVecSec eos, code should not reach here"); - } +// // 避免eos在keyProcess还未处理完数据时插队到通道前面 +// if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && +// hybridMgmtBlock->h2dNextBatchId[info.name] == info.batchId) { +// LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", +// info.channelId, info.batchId); +// throw runtime_error("GetRestoreVecSec eos, code should not reach here"); +// } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", info.name, info.channelId, info.batchId, readEmbKeyBatchId); this_thread::sleep_for(1ms); @@ -1386,20 +1416,27 @@ unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo &info, Proce try { auto infoVec = GetInfo(*list, info); + isEos = get(infoVec); + if (isEos) { + LOG_WARN(KEY_PROCESS "GetInfoVec eos! {}[{}]:{}", info.name, info.channelId, info.batchId); + break; + } auto it = get>>::iterator>(infoVec); ret = std::move(*it); std::unique_lock lockGuard(mut); storage.erase(it); break; } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); - if (isEos) { - break; - } +// unique_lock lockEosGuard(eosMutex); +// isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); +// if (isEos) { +// break; +// } + LOG_DEBUG(KEY_PROCESS "GetInfoVec EmptyList! {}[{}]:{}", info.name, info.channelId, info.batchId); + LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", info.name, info.channelId, info.batchId, (hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1)); - this_thread::sleep_for(1ms); + this_thread::sleep_for(10ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); this_thread::sleep_for(1ms); @@ -1423,7 +1460,7 @@ void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int std::unique_lock lockGuard(mut); storage.push_front(move(tensors)); - all2AllList[embName][channel].push(make_tuple(batch, embName, storage.begin())); + all2AllList[embName][channel].push(make_tuple(batch, embName, false, storage.begin())); lockGuard.unlock(); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index cf5f46e7..e9598b64 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -120,9 +120,9 @@ using freq_num_t = int64_t; using EmbNameT = std::string; using KeysT = std::vector; using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector -using UinqueKeyT = std::tuple>; +using UinqueKeyT = std::tuple>; using RestoreVecSecT = std::tuple>; -using TensorInfoT = std::tuple>>::iterator>; +using TensorInfoT = std::tuple>>::iterator>; namespace HybridOption { const unsigned int USE_STATIC = 0x001; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index dae71e57..e34f93ea 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -327,7 +327,7 @@ private: &req); CheckCommFinished(req, channelId); - keyProcess->SetEos(1, dataset()->channelId_); +// keyProcess->SetEos(1, dataset()->channelId_); LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -342,7 +342,7 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - keyProcess->SetEos(1, dataset()->channelId_); +// keyProcess->SetEos(1, dataset()->channelId_); LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); -- Gitee From cfcf4213bc8b91f13912544eaab864fbf1326bd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 15:50:45 +0800 Subject: [PATCH 04/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 17 +++++++++-------- src/core/utils/common.h | 7 +------ src/dataset_tf/eos_dataset_op.cc | 15 +++++++++++++-- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6e8e0418..7d0c551f 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -568,7 +568,8 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const batch = batchQueue->TryPop(); if (batch != nullptr) { if (batch->CheckAndSetEos()) { - LOG_INFO("GetBatchData eos, channelId:{} threadId:{}", channel, commId); + LOG_INFO("GetBatchData eos, table name:{}, batchId:{}, channelId:{} threadId:{}", batch->name, + batch->batchId, channel, commId); } break; } @@ -1230,10 +1231,10 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, // if (isEos) { // break; // } - LOG_DEBUG(KEY_PROCESS "GetUniqueKeys EmptyList! {}[{}]:{}", - info.name, info.channelId, info.batchId); +// LOG_DEBUG(KEY_PROCESS "GetUniqueKeys EmptyList! {}[{}]:{}", +// info.name, info.channelId, info.batchId); - this_thread::sleep_for(10ms); + this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed table:{}, channel:{}, mgmt batchId:{}, wrong top", info.name, info.channelId, info.channelId); @@ -1368,8 +1369,8 @@ void KeyProcess::SendEos(const std::string& embName, int batchId, int channel, b this_thread::sleep_for(1000ms); } readySendEosCnt[channel].store(0); - isNeedSendEos[channel] = false; - LOG_DEBUG("isNeedSendEos set to false, table:{}, channelId:{} batchId:{}", embName, channel, batchId); +// isNeedSendEos[channel] = false; + LOG_DEBUG("sendEos finish all, table:{}, channelId:{} batchId:{}", embName, channel, batchId); #endif } @@ -1432,11 +1433,11 @@ unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo &info, Proce // if (isEos) { // break; // } - LOG_DEBUG(KEY_PROCESS "GetInfoVec EmptyList! {}[{}]:{}", info.name, info.channelId, info.batchId); +// LOG_DEBUG(KEY_PROCESS "GetInfoVec EmptyList! {}[{}]:{}", info.name, info.channelId, info.batchId); LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", info.name, info.channelId, info.batchId, (hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1)); - this_thread::sleep_for(10ms); + this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", info.name, info.channelId, info.batchId); this_thread::sleep_for(1ms); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index e9598b64..e68168ae 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -185,12 +185,7 @@ struct Batch { bool CheckAndSetEos() { - int num = sample.size(); - if (num < 3 || num % 3 != 0) { - return false; - } - - for (int i = 0; i < num; i++) { + for (int i = 0; i < 8; i++) { if (sample[i] != 0) { return false; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index e34f93ea..72be6696 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -220,7 +220,18 @@ private: tf_shape.AddDim(output_shapes.dim_size(i)); } LOG_INFO("[LQK] CreateTensorByShape, tensor shape: {}", tf_shape.DebugString()); - return Tensor(tensor_data_type, tf_shape); + + Tensor tmp(tensor_data_type, tf_shape); + auto tensor_data = const_cast(tmp.tensor_data().data()); + auto tensor_size = tmp.tensor_data().size(); + LOG_INFO("[LQK] KnownShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", + tmp.DebugString(), tensor_size, tmp.NumElements()); + + memset_s(tensor_data, tensor_size, 0, tensor_size); + + LOG_INFO("[LQK] KnownShape, after memset tensor: {}", tmp.DebugString()); + + return tmp; } std::vector CreateOutputVecTensor() @@ -248,7 +259,7 @@ private: } if (IsUnknowShape(dataset()->output_shapes()[i])) { LOG_INFO("[LQK] output shape is unknown shape"); - Tensor tensor(tensor_data_type, TensorShape({3, 1})); + Tensor tensor(tensor_data_type, TensorShape({8, 1})); if (dataset()->output_shapes()[i].dims() == -1) { tensor = Tensor(tensor_data_type, TensorShape({1})); } -- Gitee From 208b3abc4364ca37815dac7a2ff31d24b13087c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 16:08:45 +0800 Subject: [PATCH 05/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 7d0c551f..fd9d054b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1189,11 +1189,11 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, TimeCost tc = TimeCost(); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - bool cancelMonitor = false; - thread timeoutMonitor; - if (info.batchId != 0) { - timeoutMonitor = StartEosMonitorThread(info, cancelMonitor); - } +// bool cancelMonitor = false; +// thread timeoutMonitor; +// if (info.batchId != 0) { +// timeoutMonitor = StartEosMonitorThread(info, cancelMonitor); +// } // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector @@ -1241,10 +1241,10 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, this_thread::sleep_for(1ms); } } - cancelMonitor = true; - if (timeoutMonitor.joinable()) { - timeoutMonitor.join(); - } +// cancelMonitor = true; +// if (timeoutMonitor.joinable()) { +// timeoutMonitor.join(); +// } return ret; } -- Gitee From dd18a01666c3116504d7ca0194c6959887f75592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 16:44:32 +0800 Subject: [PATCH 06/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20deb?= =?UTF-8?q?ug=20ddr=20error=20when=20eval=20eos?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2d89b117..e0932026 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -490,16 +490,18 @@ void HybridMgmt::TrainTask(TaskType type) void HybridMgmt::EvalTask(TaskType type) { #ifndef GTEST - int channelId = EVAL_CHANNEL_ID; - int& evalBatchId = hybridMgmtBlock->hybridBatchId[channelId]; + int& evalBatchId = hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]; do { - hybridMgmtBlock->CheckAndSetBlock(channelId); - if (hybridMgmtBlock->GetBlockStatus(channelId)) { + hybridMgmtBlock->CheckAndSetBlock(EVAL_CHANNEL_ID); + if (hybridMgmtBlock->GetBlockStatus(EVAL_CHANNEL_ID)) { LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, hybridMgmtBlock->IsNeedWaitSave()); std::unique_lock checkSaveLocker(saveMutex); cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); + LOG_DEBUG("eval channel block, python batch id:{}, hybridBatchId:{}", + hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID], evalBatchId); + if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { // Before waking the data process for training, Recover the backed-up training state RecoverTrainStatus(); @@ -510,12 +512,12 @@ void HybridMgmt::EvalTask(TaskType type) } LOG_DEBUG("wake TrainTask"); - hybridMgmtBlock->DoBlock(channelId); + hybridMgmtBlock->DoBlock(EVAL_CHANNEL_ID); } if (!isRunning) { return; } - LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, evalBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", EVAL_CHANNEL_ID, evalBatchId); ParseKeys(EVAL_CHANNEL_ID, evalBatchId, type); } while (true); -- Gitee From 6da208ed3595261a557a5f447646758fc7a71bc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 20:00:40 +0800 Subject: [PATCH 07/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20error=20when=20eval=20eos?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 39 ++++++---------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e0932026..54ce943d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1381,49 +1381,25 @@ void HybridMgmt::HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBat hybridMgmtBlock->SetBlockStatus(TRAIN_CHANNEL_ID, true); } +// DDR void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) { LOG_INFO("GetUniqueKeys get eos, handle final batch for current epoch, table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); bool sendAllChannel = false; - if (info.channelId == TRAIN_CHANNEL_ID) { - vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); - LOG_INFO("GetUniqueKeys get eos, send pos for train channel, table:{}, batchId:{}", info.name, info.batchId); - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); - remainBatchOut = false; - return; - } + vector emptySwapOutPos; + SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); + LOG_INFO("GetUniqueKeys get eos, send pos for channel, table:{}, batchId:{}, channel:{}", info.name, info.batchId, + info.channelId); if (!alreadyTrainOnce) { // predict场景 LOG_INFO("ProcessEmbInfoDDR first run in eval channel, assume as predict mode, start handle eos"); - std::vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); sendAllChannel = true; } else { hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); - if (hybridMgmtBlock->IsNeedWaitSave()) { - // train+eval+save场景 - // 当前step n之后需要save,涉及save到train的状态切换。需要: - // 1. 补发pos以启动eval step n-1并完成。 - // 2. eval step n遇到eos结束 - // 3. 开始save,完成后唤醒train的ProcessEmbInfoDDR,所以需要在此之前改变specialProcessStatus - LOG_DEBUG("eval encounter eos and need save after this step" - "send pos change specialProcessStatus, current status:{}, modify to status:{}", - ProcessStatus2Str(specialProcessStatus[info.name]), - ProcessStatus2Str(ProcessStatus::AFTER_SWITCH_FIRST_BATCH)); - vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); - specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; - } else { - // train+eval+train场景 - // 交给train的ProcessEmbInfoDDR启动最后n-1步eval - // train发送pos让eval step n-1跑完,到eval step n时各channel遇到eos后结束(train、eval共享的channel除外) - LOG_INFO("GetUniqueKeys get eos, skip send pos for eval channel, table:{}, batchId:{}", info.name, - info.batchId); - } + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; } KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); remainBatchOut = false; @@ -1774,6 +1750,7 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); } +// HBM void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut) { bool sendAllChannel = false; @@ -1784,7 +1761,7 @@ void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channe } else { // train+eval场景 hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); - LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); + LOG_INFO("GetInfoVec[RESTORE]: {}, get eos from eval channel, SetBlockStatus=true", embName); } } KEY_PROCESS_INSTANCE->SendEos(embName, batchId, channelId, sendAllChannel); -- Gitee From a956f45ea8039d5505e40f5a93a2e60677f967e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 24 Jul 2024 11:30:32 +0800 Subject: [PATCH 08/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20finish=20too=20early?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 113 +++++++++++++++++++++++---- src/core/hybrid_mgmt/hybrid_mgmt.h | 8 +- 2 files changed, 101 insertions(+), 20 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 54ce943d..ea66e83a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -693,7 +693,18 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoDDR最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + std::vector addrs{nullptr, nullptr}; + if (info.channelId == EVAL_CHANNEL_ID) { + addrs.push_back(nullptr); + } + HBMSwapAddrsQue[info.name + SWAP_IN_STR].Pushv(addrs); + HBMSwapAddrsQue[info.name + SWAP_OUT_STR].Pushv(addrs); + LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}, addrs:{}", + info.name, info.batchId, info.channelId, addrs.size()); + } if (uniqueKeys.empty()) { return; } @@ -1080,13 +1091,16 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo .name = embInfo.name}; vector h2dEmb; - auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); + bool isEos = false; + auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb, isEos); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } + if (!isEos) { + EmbeddingSendDDR(info, h2dEmb); + } - EmbeddingSendDDR(info, h2dEmb); } void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo) @@ -1127,13 +1141,16 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E .name = embInfo.name}; vector h2dEmb; - auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); + bool isEos = false; + auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb, isEos); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } - EmbeddingSendL3Storage(info, h2dEmb); + if (!isEos) { + EmbeddingSendL3Storage(info, h2dEmb); + } } void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo) @@ -1177,7 +1194,8 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoL3Storage最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (uniqueKeys.empty()) { return; } @@ -1401,7 +1419,7 @@ void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); +// KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); remainBatchOut = false; } @@ -1417,6 +1435,26 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto TimeCost EmbeddingRecvTC = TimeCost(); swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); + if (swapOutAddrs.size() == 2 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr) { // eos + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } + if (swapOutAddrs.size() == 3 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr && swapOutAddrs[2] == nullptr) { // eos + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } if (!isRunning) { return false; } @@ -1488,7 +1526,7 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); } -bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) +bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1506,12 +1544,14 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d return false; } - bool isSuccess = BuildH2DEmbedding(info, h2dEmb); + bool isSuccess = BuildH2DEmbedding(info, h2dEmb, isEos); if (!isSuccess) { return false; } - lastLookUpFinishStepMap[info.name]++; + if (!isEos) { + lastLookUpFinishStepMap[info.name]++; + } cvLastLookUpFinishMap[info.name][info.cvNotifyIndex].notify_all(); return true; @@ -1600,6 +1640,26 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, TimeCost EmbeddingRecvTC = TimeCost(); // finish时会pop空vector,因此需要额外判定isRunning swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); + if (swapOutAddrs.size() == 2 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr) { // eos + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } + if (swapOutAddrs.size() == 3 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr && swapOutAddrs[2] == nullptr) { // eos + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } if (!isRunning) { return false; } @@ -1681,7 +1741,7 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); } -bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) +bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1721,12 +1781,14 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) +bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) { std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); + if (swapInAddrs.size() == 2 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr) { // eos + isEos = true; + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); + return true; + } + if (swapInAddrs.size() == 3 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr && swapInAddrs[2] == nullptr) { // eos + isEos = true; + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); + return true; + } if (!isRunning) { return false; } @@ -1932,9 +2014,8 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE return true; } -vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut) +vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos) { - bool isEos = false; auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos, lookUpSwapInAddrsPushId); if (isEos) { HandleEosCase(info, remainBatchOut); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 5f94c96d..558e0690 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -258,7 +258,7 @@ private: void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); - bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb); + bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); @@ -266,7 +266,7 @@ private: void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); - bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb); + bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); void EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& h2dEmb); @@ -284,9 +284,9 @@ private: void HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, vector& swapOutKeys); - bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); + bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); - vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); + vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos); vector GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); -- Gitee From d288e5308073f375369bd4d8ab48648a41923858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 24 Jul 2024 11:47:58 +0800 Subject: [PATCH 09/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20finish=20too=20early,=20but=20ssd=20wait=20to=20be=20m?= =?UTF-8?q?atched?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 32 ++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ea66e83a..5555120e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -696,14 +696,15 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut bool isEos = false; auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { - std::vector addrs{nullptr, nullptr}; + vector swapInKeys{0, 0}; if (info.channelId == EVAL_CHANNEL_ID) { - addrs.push_back(nullptr); + swapInKeys = vector(2,1); } - HBMSwapAddrsQue[info.name + SWAP_IN_STR].Pushv(addrs); - HBMSwapAddrsQue[info.name + SWAP_OUT_STR].Pushv(addrs); - LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}, addrs:{}", - info.name, info.batchId, info.channelId, addrs.size()); + + LOG_DEBUG("[LQK] enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}", info.name, + info.batchId, info.channelId, swapInKeys.size()); + HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); + CheckLookupAddrSuccessDDR(); } if (uniqueKeys.empty()) { return; @@ -993,6 +994,25 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName) } // swap in std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); + if (keys.size() == 2 && keys[0] == keys[1] == 0) //train eos + { + std::vector addrs{nullptr, nullptr}; + HBMSwapAddrsQue[swapInName].Pushv(addrs); + HBMSwapAddrsQue[swapOutName].Pushv(addrs); + LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}", + embName, id, TRAIN_CHANNEL_ID); + continue; + } + if (keys.size() == 2 && keys[0] == keys[1] == 1) //eval eos + { + std::vector addrs{nullptr, nullptr, nullptr}; + HBMSwapAddrsQue[swapInName].Pushv(addrs); + HBMSwapAddrsQue[swapOutName].Pushv(addrs); + LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}", + embName, id, EVAL_CHANNEL_ID); + continue; + } + TimeCost lookupAddrsInTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { -- Gitee From a8071c81cc5747218998e7388c77b414c6b6874d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 24 Jul 2024 11:57:50 +0800 Subject: [PATCH 10/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20finish=20too=20early,=20but=20ssd=20wait=20to=20be=20m?= =?UTF-8?q?atched?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 5555120e..29862aa0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1456,6 +1456,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (swapOutAddrs.size() == 2 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr) { // eos + LOG_INFO("EmbeddingReceiveDDR get eos from train channel"); bool sendAllChannel = false; if (!alreadyTrainOnce) { // predict场景 @@ -1466,6 +1467,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto return true; } if (swapOutAddrs.size() == 3 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr && swapOutAddrs[2] == nullptr) { // eos + LOG_INFO("EmbeddingReceiveDDR get eos from eval channel"); bool sendAllChannel = false; if (!alreadyTrainOnce) { // predict场景 @@ -1991,6 +1993,7 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE { std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); if (swapInAddrs.size() == 2 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr) { // eos + LOG_INFO("BuildH2DEmbedding get eos from train channel"); isEos = true; bool sendAllChannel = false; if (!alreadyTrainOnce) { @@ -2001,6 +2004,7 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE return true; } if (swapInAddrs.size() == 3 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr && swapInAddrs[2] == nullptr) { // eos + LOG_INFO("BuildH2DEmbedding get eos from eval channel"); isEos = true; bool sendAllChannel = false; if (!alreadyTrainOnce) { -- Gitee From 4479879e91bc957aa6d40b27743b2c1b63793e72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 24 Jul 2024 16:02:18 +0800 Subject: [PATCH 11/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20finish=20too=20early?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 189 +++++++++++---------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 14 +- 2 files changed, 85 insertions(+), 118 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 29862aa0..579f6599 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -696,15 +696,9 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut bool isEos = false; auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { - vector swapInKeys{0, 0}; - if (info.channelId == EVAL_CHANNEL_ID) { - swapInKeys = vector(2,1); - } - - LOG_DEBUG("[LQK] enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}", info.name, - info.batchId, info.channelId, swapInKeys.size()); - HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); - CheckLookupAddrSuccessDDR(); + EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); + LOG_DEBUG("[LQK] enqueue EosL1Que DDR, table:{}, batchId:{}, channelId:{}", info.name, + info.batchId, info.channelId); } if (uniqueKeys.empty()) { return; @@ -992,27 +986,18 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName) if (!isRunning) { return; } - // swap in - std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); - if (keys.size() == 2 && keys[0] == keys[1] == 0) //train eos - { - std::vector addrs{nullptr, nullptr}; - HBMSwapAddrsQue[swapInName].Pushv(addrs); - HBMSwapAddrsQue[swapOutName].Pushv(addrs); - LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}", - embName, id, TRAIN_CHANNEL_ID); - continue; - } - if (keys.size() == 2 && keys[0] == keys[1] == 1) //eval eos - { - std::vector addrs{nullptr, nullptr, nullptr}; - HBMSwapAddrsQue[swapInName].Pushv(addrs); - HBMSwapAddrsQue[swapOutName].Pushv(addrs); - LOG_DEBUG("[LQK] enqueue HBMSwapAddrsQue eos, table:{}, batchId:{}, channelId:{}", - embName, id, EVAL_CHANNEL_ID); + pair keyChannel = EosL1Que[embName].WaitAndPop(); + if (keyChannel.first) { + EosL2Que[embName].Pushv(make_pair(true, keyChannel.second)); + LOG_DEBUG("[LQK] enqueue EosL2Que eos, table:{}, batchId:{}, channelId:{}", + embName, id, keyChannel.second); continue; + } else { + EosL2Que[embName].Pushv(make_pair(false, keyChannel.second)); } + // swap in + std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); TimeCost lookupAddrsInTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { @@ -1111,15 +1096,12 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo .name = embInfo.name}; vector h2dEmb; - bool isEos = false; - auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb, isEos); + auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } - if (!isEos) { - EmbeddingSendDDR(info, h2dEmb); - } + EmbeddingSendDDR(info, h2dEmb); } @@ -1138,13 +1120,15 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI float* ptr = nullptr; vector swapOutAddrs; - auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); + bool isEos = false; + auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } - - EmbeddingUpdateDDR(info, ptr, swapOutAddrs); + if (!isEos) { + EmbeddingUpdateDDR(info, ptr, swapOutAddrs); + } } void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo) @@ -1161,16 +1145,13 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E .name = embInfo.name}; vector h2dEmb; - bool isEos = false; - auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb, isEos); + auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } - if (!isEos) { - EmbeddingSendL3Storage(info, h2dEmb); - } + EmbeddingSendL3Storage(info, h2dEmb); } void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, const EmbInfo& embInfo) @@ -1189,9 +1170,11 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons float* ptr = nullptr; vector swapOutAddrs; int64_t dims0 = 0; - EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); - - EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); + bool isEos = false; + EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); + if (!isEos) { + EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); + } } /// 构造训练所需的各种向量数据 @@ -1216,6 +1199,12 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa // 获取GlobalUnique向量 bool isEos = false; auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); + LOG_DEBUG("[LQK] enqueue EosL1Que SSD, table:{}, batchId:{}, channelId:{}", info.name, + info.batchId, info.channelId); + } + if (uniqueKeys.empty()) { return; } @@ -1301,6 +1290,8 @@ void HybridMgmt::InitDataPipelineForDDR(const string& embName) HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; // 初始化lookup线程 lookUpSwapInAddrsPushId[embName]; // 此处初始化,避免多线程竞争导致计数错误 lookUpSwapInAddrsThreads.emplace_back( @@ -1317,6 +1308,9 @@ void HybridMgmt::InitDataPipelineForL3Storage(const string& embName, int extEmbe HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; + HBMSwapKeyQue[embName + ADDR_STR]; HBMSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; HBMSwapKeyForL3StorageQue[embName + ADDR_STR]; @@ -1365,6 +1359,12 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { + for (auto& p : EosL1Que) { + p.second.DestroyQueue(); + } + for (auto& p : EosL2Que) { + p.second.DestroyQueue(); + } for (auto& p : HBMSwapAddrsQue) { p.second.DestroyQueue(); } @@ -1443,7 +1443,7 @@ void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) remainBatchOut = false; } -bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs) +bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos) { std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { @@ -1452,31 +1452,22 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (!isRunning) { return false; } - TimeCost EmbeddingRecvTC = TimeCost(); - - swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); - if (swapOutAddrs.size() == 2 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr) { // eos - LOG_INFO("EmbeddingReceiveDDR get eos from train channel"); - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); - cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); - return true; - } - if (swapOutAddrs.size() == 3 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr && swapOutAddrs[2] == nullptr) { // eos - LOG_INFO("EmbeddingReceiveDDR get eos from eval channel"); + pair keyChannel = EosL2Que[info.name].WaitAndPop(); + if (keyChannel.first) { + isEos = true; + LOG_INFO("EmbeddingReceiveDDR get eos from channel: {}", keyChannel.second); bool sendAllChannel = false; if (!alreadyTrainOnce) { // predict场景 sendAllChannel = true; } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); return true; } + + TimeCost EmbeddingRecvTC = TimeCost(); + swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return false; } @@ -1548,7 +1539,7 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); } -bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) +bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1566,14 +1557,12 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d return false; } - bool isSuccess = BuildH2DEmbedding(info, h2dEmb, isEos); + bool isSuccess = BuildH2DEmbedding(info, h2dEmb); if (!isSuccess) { return false; } - if (!isEos) { - lastLookUpFinishStepMap[info.name]++; - } + lastLookUpFinishStepMap[info.name]++; cvLastLookUpFinishMap[info.name][info.cvNotifyIndex].notify_all(); return true; @@ -1647,7 +1636,7 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& } bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, - int64_t& dims0) + int64_t& dims0, bool& isEos) { std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { @@ -1656,32 +1645,27 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (!isRunning) { return false; } - // DDR swap out key need to be removed - LookUpAndRemoveAddrs(info); - TimeCost EmbeddingRecvTC = TimeCost(); - // finish时会pop空vector,因此需要额外判定isRunning - swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); - if (swapOutAddrs.size() == 2 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr) { // eos - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); - cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); - return true; - } - if (swapOutAddrs.size() == 3 && swapOutAddrs[0] == nullptr && swapOutAddrs[1] == nullptr && swapOutAddrs[2] == nullptr) { // eos + pair keyChannel = EosL1Que[info.name].WaitAndPop(); + if (keyChannel.first) { + isEos = true; + LOG_INFO("EmbeddingReceiveL3Storage get eos from channel: {}", keyChannel.second); bool sendAllChannel = false; if (!alreadyTrainOnce) { // predict场景 sendAllChannel = true; } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); return true; } + + // DDR swap out key need to be removed + LookUpAndRemoveAddrs(info); + + TimeCost EmbeddingRecvTC = TimeCost(); + // finish时会pop空vector,因此需要额外判定isRunning + swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return false; } @@ -1763,7 +1747,7 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr cvLastUpdateFinishMap[info.name][info.cvNotifyIndex].notify_all(); } -bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) +bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) { std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutexMap[info.name][info.threadIdx]); cvLastUpdateFinishMap[info.name][info.threadIdx].wait(lastUpdateFinishLocker, [info, this] { @@ -1803,14 +1787,12 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorL3Storage HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutL3StorageKeys); HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); + + EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); } -bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos) +bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb) { std::vector swapInAddrs = HBMSwapAddrsQue[info.name + SWAP_IN_STR].WaitAndPop(); - if (swapInAddrs.size() == 2 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr) { // eos - LOG_INFO("BuildH2DEmbedding get eos from train channel"); - isEos = true; - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, TRAIN_CHANNEL_ID, sendAllChannel); - return true; - } - if (swapInAddrs.size() == 3 && swapInAddrs[0] == nullptr && swapInAddrs[1] == nullptr && swapInAddrs[2] == nullptr) { // eos - LOG_INFO("BuildH2DEmbedding get eos from eval channel"); - isEos = true; - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, EVAL_CHANNEL_ID, sendAllChannel); - return true; - } if (!isRunning) { return false; } @@ -2284,8 +2246,9 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(swapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); - CheckLookupAddrSuccessDDR(); + + EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); } bool HybridMgmt::IsTrainAndEvalCase() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 558e0690..c75811ec 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -166,6 +166,9 @@ public: std::map>> HBMSwapAddrsQue; std::map>> DDRSwapAddrsQue; + std::map>> EosL1Que; // pair + std::map>> EosL2Que; + std::mutex evictMut; std::map> trainKeysSet; @@ -254,19 +257,20 @@ private: void HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); + bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos); void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); - bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); + bool EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb); void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); - bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0); + bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0, + bool& isEos); void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); - bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); + bool EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb); void EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& h2dEmb); @@ -284,7 +288,7 @@ private: void HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, vector& swapOutKeys); - bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb, bool& isEos); + bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos); -- Gitee From 6babb68c493d2f33eb08b4b8328174b128a24e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 24 Jul 2024 16:59:49 +0800 Subject: [PATCH 12/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset,=20fix?= =?UTF-8?q?=20ddr=20finish=20too=20early?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 579f6599..fe78e244 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -697,8 +697,8 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); - LOG_DEBUG("[LQK] enqueue EosL1Que DDR, table:{}, batchId:{}, channelId:{}", info.name, - info.batchId, info.channelId); + LOG_DEBUG("[LQK] enqueue EosL1Que DDR, table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, + info.batchId, info.channelId, EosL1Que.size()); } if (uniqueKeys.empty()) { return; @@ -989,11 +989,13 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName) pair keyChannel = EosL1Que[embName].WaitAndPop(); if (keyChannel.first) { EosL2Que[embName].Pushv(make_pair(true, keyChannel.second)); - LOG_DEBUG("[LQK] enqueue EosL2Que eos, table:{}, batchId:{}, channelId:{}", - embName, id, keyChannel.second); + LOG_DEBUG("[LQK] enqueue EosL2Que eos, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", + embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); continue; } else { EosL2Que[embName].Pushv(make_pair(false, keyChannel.second)); + LOG_DEBUG("[LQK] enqueue EosL2Que normal, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", + embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); } // swap in @@ -2249,6 +2251,9 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, CheckLookupAddrSuccessDDR(); EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); + LOG_DEBUG("enqueue EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, + info.batchId, info.channelId, EosL1Que.size()); + } bool HybridMgmt::IsTrainAndEvalCase() -- Gitee From 4ffca33e660cb8cde1562c26a348f287a54012cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 23 Jul 2024 11:57:38 +0800 Subject: [PATCH 13/31] =?UTF-8?q?=E3=80=90FIX=E3=80=91eos=20dataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 164 ++++++++++++++++++--------- src/core/hybrid_mgmt/hybrid_mgmt.h | 10 +- src/core/key_process/key_process.cpp | 115 +++++++++++++------ src/core/utils/common.h | 17 ++- src/dataset_tf/eos_dataset_op.cc | 151 ++++++++++++++++++++++-- src/dataset_tf/eos_dataset_op.h | 2 + 6 files changed, 351 insertions(+), 108 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 4801f95b..4be1b99d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -490,16 +490,18 @@ void HybridMgmt::TrainTask(TaskType type) void HybridMgmt::EvalTask(TaskType type) { #ifndef GTEST - int channelId = EVAL_CHANNEL_ID; - int& evalBatchId = hybridMgmtBlock->hybridBatchId[channelId]; + int& evalBatchId = hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]; do { - hybridMgmtBlock->CheckAndSetBlock(channelId); - if (hybridMgmtBlock->GetBlockStatus(channelId)) { + hybridMgmtBlock->CheckAndSetBlock(EVAL_CHANNEL_ID); + if (hybridMgmtBlock->GetBlockStatus(EVAL_CHANNEL_ID)) { LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, hybridMgmtBlock->IsNeedWaitSave()); std::unique_lock checkSaveLocker(saveMutex); cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); + LOG_DEBUG("eval channel block, python batch id:{}, hybridBatchId:{}", + hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID], evalBatchId); + if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { // Before waking the data process for training, Recover the backed-up training state RecoverTrainStatus(); @@ -510,12 +512,12 @@ void HybridMgmt::EvalTask(TaskType type) } LOG_DEBUG("wake TrainTask"); - hybridMgmtBlock->DoBlock(channelId); + hybridMgmtBlock->DoBlock(EVAL_CHANNEL_ID); } if (!isRunning) { return; } - LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, evalBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", EVAL_CHANNEL_ID, evalBatchId); ParseKeys(EVAL_CHANNEL_ID, evalBatchId, type); } while (true); @@ -658,7 +660,7 @@ void HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool& remainBatchOut SendUniqKeysAndRestoreVecHBM(info, infoVecs, isGrad); } - // 发送恢复向量 + // 发送恢复向量和hotPos TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, info.channelId, info.name); LOG_DEBUG("table:{}, sendRestoreSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", info.name, @@ -691,7 +693,13 @@ void HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info, bool& remainBatchOut // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoDDR最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); + LOG_DEBUG("[LQK] enqueue EosL1Que DDR, table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, + info.batchId, info.channelId, EosL1Que.size()); + } if (uniqueKeys.empty()) { return; } @@ -978,6 +986,18 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName) if (!isRunning) { return; } + pair keyChannel = EosL1Que[embName].WaitAndPop(); + if (keyChannel.first) { + EosL2Que[embName].Pushv(make_pair(true, keyChannel.second)); + LOG_DEBUG("[LQK] enqueue EosL2Que eos, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", + embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); + continue; + } else { + EosL2Que[embName].Pushv(make_pair(false, keyChannel.second)); + LOG_DEBUG("[LQK] enqueue EosL2Que normal, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", + embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); + } + // swap in std::vector keys = HBMSwapKeyQue[swapInName].WaitAndPop(); TimeCost lookupAddrsInTC; @@ -1083,8 +1103,8 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo LOG_INFO("HybridMgmt is not running"); return; } - EmbeddingSendDDR(info, h2dEmb); + } void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbInfo& embInfo) @@ -1102,13 +1122,15 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI float* ptr = nullptr; vector swapOutAddrs; - auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); + bool isEos = false; + auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); if (!isSuccess) { LOG_INFO("HybridMgmt is not running"); return; } - - EmbeddingUpdateDDR(info, ptr, swapOutAddrs); + if (!isEos) { + EmbeddingUpdateDDR(info, ptr, swapOutAddrs); + } } void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo) @@ -1150,9 +1172,11 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons float* ptr = nullptr; vector swapOutAddrs; int64_t dims0 = 0; - EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); - - EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); + bool isEos = false; + EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); + if (!isEos) { + EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); + } } /// 构造训练所需的各种向量数据 @@ -1175,7 +1199,14 @@ void HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBa // 只有在每次GetUniqueKeys的时候才知道上游是否已经EOS // 注意GetUniqueKeys与EOS关联,需要在ProcessEmbInfoL3Storage最先调用,如需调整位置,请参考并适配其他函数 // 获取GlobalUnique向量 - auto uniqueKeys = GetUniqueKeys(info, remainBatchOut); + bool isEos = false; + auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); + if (isEos) { + EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); + LOG_DEBUG("[LQK] enqueue EosL1Que SSD, table:{}, batchId:{}, channelId:{}", info.name, + info.batchId, info.channelId); + } + if (uniqueKeys.empty()) { return; } @@ -1261,6 +1292,8 @@ void HybridMgmt::InitDataPipelineForDDR(const string& embName) HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; // 初始化lookup线程 lookUpSwapInAddrsPushId[embName]; // 此处初始化,避免多线程竞争导致计数错误 lookUpSwapInAddrsThreads.emplace_back( @@ -1277,6 +1310,9 @@ void HybridMgmt::InitDataPipelineForL3Storage(const string& embName, int extEmbe HBMSwapAddrsQue[embName + SWAP_IN_STR]; HBMSwapAddrsQue[embName + SWAP_OUT_STR]; + EosL1Que[embName]; + EosL2Que[embName]; + HBMSwapKeyQue[embName + ADDR_STR]; HBMSwapKeyForL3StorageQue[embName + SWAP_IN_STR]; HBMSwapKeyForL3StorageQue[embName + ADDR_STR]; @@ -1325,6 +1361,12 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { + for (auto& p : EosL1Que) { + p.second.DestroyQueue(); + } + for (auto& p : EosL2Que) { + p.second.DestroyQueue(); + } for (auto& p : HBMSwapAddrsQue) { p.second.DestroyQueue(); } @@ -1379,55 +1421,31 @@ void HybridMgmt::HandleReachMaxStepCase(const EmbBaseInfo& info, bool& remainBat hybridMgmtBlock->SetBlockStatus(TRAIN_CHANNEL_ID, true); } +// DDR void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) { LOG_INFO("GetUniqueKeys get eos, handle final batch for current epoch, table:{}, channel:{}, batchId:{}", info.name, info.channelId, info.batchId); bool sendAllChannel = false; - if (info.channelId == TRAIN_CHANNEL_ID) { - vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); - LOG_INFO("GetUniqueKeys get eos, send pos for train channel, table:{}, batchId:{}", info.name, info.batchId); - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); - remainBatchOut = false; - return; - } + vector emptySwapOutPos; + SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); + LOG_INFO("GetUniqueKeys get eos, send pos for channel, table:{}, batchId:{}, channel:{}", info.name, info.batchId, + info.channelId); if (!alreadyTrainOnce) { // predict场景 LOG_INFO("ProcessEmbInfoDDR first run in eval channel, assume as predict mode, start handle eos"); - std::vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); sendAllChannel = true; } else { hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); - if (hybridMgmtBlock->IsNeedWaitSave()) { - // train+eval+save场景 - // 当前step n之后需要save,涉及save到train的状态切换。需要: - // 1. 补发pos以启动eval step n-1并完成。 - // 2. eval step n遇到eos结束 - // 3. 开始save,完成后唤醒train的ProcessEmbInfoDDR,所以需要在此之前改变specialProcessStatus - LOG_DEBUG("eval encounter eos and need save after this step" - "send pos change specialProcessStatus, current status:{}, modify to status:{}", - ProcessStatus2Str(specialProcessStatus[info.name]), - ProcessStatus2Str(ProcessStatus::AFTER_SWITCH_FIRST_BATCH)); - vector emptySwapOutPos; - SendTensorForSwap(info, lastSwapInPosMap[info.name], emptySwapOutPos); - specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; - } else { - // train+eval+train场景 - // 交给train的ProcessEmbInfoDDR启动最后n-1步eval - // train发送pos让eval step n-1跑完,到eval step n时各channel遇到eos后结束(train、eval共享的channel除外) - LOG_INFO("GetUniqueKeys get eos, skip send pos for eval channel, table:{}, batchId:{}", info.name, - info.batchId); - } + specialProcessStatus[info.name] = ProcessStatus::AFTER_SWITCH_FIRST_BATCH; } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); +// KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId, sendAllChannel); remainBatchOut = false; } -bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs) +bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos) { std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { @@ -1436,8 +1454,21 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (!isRunning) { return false; } - TimeCost EmbeddingRecvTC = TimeCost(); + pair keyChannel = EosL2Que[info.name].WaitAndPop(); + if (keyChannel.first) { + isEos = true; + LOG_INFO("EmbeddingReceiveDDR get eos from channel: {}", keyChannel.second); + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } + TimeCost EmbeddingRecvTC = TimeCost(); swapOutAddrs = HBMSwapAddrsQue[info.name + SWAP_OUT_STR].WaitAndPop(); if (!isRunning) { return false; @@ -1607,7 +1638,7 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& } bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, - int64_t& dims0) + int64_t& dims0, bool& isEos) { std::unique_lock lastRecvFinishLocker(lastRecvFinishMutexMap[info.name][info.threadIdx]); cvLastRecvFinishMap[info.name][info.threadIdx].wait(lastRecvFinishLocker, [info, this] { @@ -1616,6 +1647,21 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (!isRunning) { return false; } + + pair keyChannel = EosL1Que[info.name].WaitAndPop(); + if (keyChannel.first) { + isEos = true; + LOG_INFO("EmbeddingReceiveL3Storage get eos from channel: {}", keyChannel.second); + bool sendAllChannel = false; + if (!alreadyTrainOnce) { + // predict场景 + sendAllChannel = true; + } + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); + cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + return true; + } + // DDR swap out key need to be removed LookUpAndRemoveAddrs(info); @@ -1772,6 +1818,7 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& LOG_DEBUG("h2dNextBatchId, table:{}, next batchId:{}", info.name, hybridMgmtBlock->h2dNextBatchId[info.name]); } +// HBM void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut) { bool sendAllChannel = false; @@ -1782,7 +1829,7 @@ void HybridMgmt::HandleEosCaseHBM(const string& embName, int batchId, int channe } else { // train+eval场景 hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); - LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); + LOG_INFO("GetInfoVec[RESTORE]: {}, get eos from eval channel, SetBlockStatus=true", embName); } } KEY_PROCESS_INSTANCE->SendEos(embName, batchId, channelId, sendAllChannel); @@ -1832,6 +1879,10 @@ void HybridMgmt::HandleFirstBatchCaseDDR(const EmbBaseInfo& info, info.batchId, info.channelId, swapInKeys.size(), emptySwapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(emptySwapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); + + EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); + LOG_DEBUG("enqueue EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, + info.batchId, info.channelId, EosL1Que.size()); } void HybridMgmt::HandleFirstBatchCaseL3Storage(const EmbBaseInfo& info, @@ -1924,6 +1975,8 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vectorL3Storage HBMSwapKeyForL3StorageQue[info.name + SWAP_OUT_STR].Pushv(hbmSwapInfo.swapOutL3StorageKeys); HBMSwapKeyForL3StorageQue[info.name + ADDR_STR].Pushv(hbmSwapInfo.swapOutL3StorageAddrOffs); + + EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); } bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb) @@ -1953,9 +2006,8 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE return true; } -vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut) +vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos) { - bool isEos = false; auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos, lookUpSwapInAddrsPushId); if (isEos) { HandleEosCase(info, remainBatchOut); @@ -2200,8 +2252,12 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR].Pushv(swapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR].Pushv(swapInKeys); - CheckLookupAddrSuccessDDR(); + + EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); + LOG_DEBUG("enqueue EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, + info.batchId, info.channelId, EosL1Que.size()); + } bool HybridMgmt::IsTrainAndEvalCase() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 5f94c96d..c75811ec 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -166,6 +166,9 @@ public: std::map>> HBMSwapAddrsQue; std::map>> DDRSwapAddrsQue; + std::map>> EosL1Que; // pair + std::map>> EosL2Que; + std::mutex evictMut; std::map> trainKeysSet; @@ -254,7 +257,7 @@ private: void HandleEosCaseHBM(const string& embName, int batchId, int channelId, bool& remainBatchOut); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); + bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos); void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); @@ -262,7 +265,8 @@ private: void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); - bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0); + bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0, + bool& isEos); void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); @@ -286,7 +290,7 @@ private: bool BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dEmb); - vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut); + vector GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos); vector GetRestoreVecSec(const EmbBaseInfo& info, bool& remainBatchOut); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 1cb9f992..fd9d054b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -305,9 +305,9 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) break; } LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," - " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}", + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}, isEos:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name, batch->channel, threadId, batch->batchId); + batch->name, batch->channel, threadId, batch->batchId, batch->isEos); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); @@ -397,6 +397,26 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { + if (batch->isEos) { + if (!rankInfo.isDDR) { // HBM +// auto tensors = make_unique>(); + std::unique_lock lockGuard(mut); +// storage.push_front(move(tensors)); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", + batch->name, batch->batchId, batch->channel, threadId); + return true; + } + // DDR + vector uniqueKeys; + std::unique_lock lockGuard(mut); + uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", + batch->name, batch->batchId, batch->channel, threadId); + return true; + } vector splitKeys; vector restore; vector hotPos; @@ -440,11 +460,12 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, hotPos.resize(hotEmbTotCount[batch->name], 0); tensors->push_back(Vec2TensorI32(hotPos)); + // HBM把restore、unique、idoffset做成了Tensor,放到infolist里面了(hbm第一个get的是tensors) if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); PushResultHBM(batch, move(tensors)); - } else { + } else { // DDR 则保留原有的数据结构,idoffset在mgmt侧组装(ddr第一个get的是unique) std::vector lookupKeysUint(lookupKeys.begin(), lookupKeys.end()); vector uniqueKeys; vector restoreVecSec; @@ -513,7 +534,7 @@ void KeyProcess::PushResultHBM(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); lockGuard.unlock(); } @@ -522,8 +543,8 @@ void KeyProcess::PushResultDDR(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); - uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(uniqueKeys))); + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); restoreVecSecList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(restoreVecSec))); lockGuard.unlock(); } @@ -546,6 +567,10 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const while (true) { batch = batchQueue->TryPop(); if (batch != nullptr) { + if (batch->CheckAndSetEos()) { + LOG_INFO("GetBatchData eos, table name:{}, batchId:{}, channelId:{} threadId:{}", batch->name, + batch->batchId, channel, commId); + } break; } this_thread::sleep_for(100us); @@ -1164,11 +1189,11 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, TimeCost tc = TimeCost(); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - bool cancelMonitor = false; - thread timeoutMonitor; - if (info.batchId != 0) { - timeoutMonitor = StartEosMonitorThread(info, cancelMonitor); - } +// bool cancelMonitor = false; +// thread timeoutMonitor; +// if (info.batchId != 0) { +// timeoutMonitor = StartEosMonitorThread(info, cancelMonitor); +// } // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector @@ -1192,14 +1217,23 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, } try { auto infoVec = GetInfo(uniqueKeysList, info); - ret = get>(infoVec); - break; - } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetUniqueKeysEos(info, startTime, lookUpSwapInAddrsPushId); + isEos = get(infoVec); if (isEos) { + LOG_WARN(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", + info.name, info.channelId, info.batchId); break; } + ret = get>(infoVec); + break; + } catch (EmptyList&) { +// unique_lock lockEosGuard(eosMutex); +// isEos = IsGetUniqueKeysEos(info, startTime, lookUpSwapInAddrsPushId); +// if (isEos) { +// break; +// } +// LOG_DEBUG(KEY_PROCESS "GetUniqueKeys EmptyList! {}[{}]:{}", +// info.name, info.channelId, info.batchId); + this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed table:{}, channel:{}, mgmt batchId:{}, wrong top", @@ -1207,10 +1241,10 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos, this_thread::sleep_for(1ms); } } - cancelMonitor = true; - if (timeoutMonitor.joinable()) { - timeoutMonitor.join(); - } +// cancelMonitor = true; +// if (timeoutMonitor.joinable()) { +// timeoutMonitor.join(); +// } return ret; } @@ -1276,16 +1310,16 @@ std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) auto ret = GetInfo(restoreVecSecList, info); return get>(ret); } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - // readEmbKey真实的次数是readEmbedBatchId减1 +// unique_lock lockEosGuard(eosMutex); +// // readEmbKey真实的次数是readEmbedBatchId减1 int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && - hybridMgmtBlock->h2dNextBatchId[info.name] == info.batchId) { - LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", - info.channelId, info.batchId); - throw runtime_error("GetRestoreVecSec eos, code should not reach here"); - } +// // 避免eos在keyProcess还未处理完数据时插队到通道前面 +// if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && +// hybridMgmtBlock->h2dNextBatchId[info.name] == info.batchId) { +// LOG_ERROR("channelId:{} batchId:{}, GetRestoreVecSec eos, code should not reach here", +// info.channelId, info.batchId); +// throw runtime_error("GetRestoreVecSec eos, code should not reach here"); +// } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", info.name, info.channelId, info.batchId, readEmbKeyBatchId); this_thread::sleep_for(1ms); @@ -1335,8 +1369,8 @@ void KeyProcess::SendEos(const std::string& embName, int batchId, int channel, b this_thread::sleep_for(1000ms); } readySendEosCnt[channel].store(0); - isNeedSendEos[channel] = false; - LOG_DEBUG("isNeedSendEos set to false, table:{}, channelId:{} batchId:{}", embName, channel, batchId); +// isNeedSendEos[channel] = false; + LOG_DEBUG("sendEos finish all, table:{}, channelId:{} batchId:{}", embName, channel, batchId); #endif } @@ -1383,17 +1417,24 @@ unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo &info, Proce try { auto infoVec = GetInfo(*list, info); + isEos = get(infoVec); + if (isEos) { + LOG_WARN(KEY_PROCESS "GetInfoVec eos! {}[{}]:{}", info.name, info.channelId, info.batchId); + break; + } auto it = get>>::iterator>(infoVec); ret = std::move(*it); std::unique_lock lockGuard(mut); storage.erase(it); break; } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); - isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); - if (isEos) { - break; - } +// unique_lock lockEosGuard(eosMutex); +// isEos = IsGetInfoVecEos(info.batchId, info.name, info.channelId); +// if (isEos) { +// break; +// } +// LOG_DEBUG(KEY_PROCESS "GetInfoVec EmptyList! {}[{}]:{}", info.name, info.channelId, info.batchId); + LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", info.name, info.channelId, info.batchId, (hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1)); this_thread::sleep_for(1ms); @@ -1420,7 +1461,7 @@ void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int std::unique_lock lockGuard(mut); storage.push_front(move(tensors)); - all2AllList[embName][channel].push(make_tuple(batch, embName, storage.begin())); + all2AllList[embName][channel].push(make_tuple(batch, embName, false, storage.begin())); lockGuard.unlock(); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 8c7528f4..e68168ae 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -120,9 +120,9 @@ using freq_num_t = int64_t; using EmbNameT = std::string; using KeysT = std::vector; using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector -using UinqueKeyT = std::tuple>; +using UinqueKeyT = std::tuple>; using RestoreVecSecT = std::tuple>; -using TensorInfoT = std::tuple>>::iterator>; +using TensorInfoT = std::tuple>>::iterator>; namespace HybridOption { const unsigned int USE_STATIC = 0x001; @@ -183,11 +183,24 @@ struct Batch { return s; } + bool CheckAndSetEos() + { + for (int i = 0; i < 8; i++) { + if (sample[i] != 0) + { + return false; + } + } + isEos = true; + return true; + } + std::vector sample; std::string name; size_t batchSize; int batchId; int channel = 0; + bool isEos = false; time_t timestamp{-1}; }; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index afc3fe3a..72be6696 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -20,8 +20,11 @@ See the License for the specific language governing permissions and #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" + #if defined(TF_VERSION_TF2) #include "tensorflow/core/data/name_utils.h" #endif @@ -76,16 +79,20 @@ class EosDatasetOp::Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId, int32_t maxTrainSteps, - int32_t maxEvalSteps) + int32_t maxEvalSteps, + const DataTypeVector& outputTypes, + const std::vector& outputShapes) : DatasetBase(DatasetContext(ctx)), input_(input), channelId_(channelId), maxTrainSteps_(maxTrainSteps), maxEvalSteps_(maxEvalSteps), + outputTypes_(outputTypes), + outputShapes_(outputShapes), id_(g_datasetId[channelId]) { input_->Ref(); - auto os_input = input->output_shapes(); - output_shapes_ = os_input; +// auto os_input = input->output_shapes(); +// output_shapes_ = os_input; MPI_Comm_group(MPI_COMM_WORLD, &g_group); MPI_Comm_create(MPI_COMM_WORLD, g_group, &g_comm[channelId]); @@ -118,14 +125,14 @@ public: this, prefix_para}); } - const DataTypeVector &output_dtypes() const override + const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); + return outputTypes_; } - const std::vector &output_shapes() const override + const std::vector& output_shapes() const override { - return output_shapes_; + return outputShapes_; } string DebugString() const override @@ -186,6 +193,98 @@ private: } #endif + int64_t GetTensorElementNum(size_t index) { + PartialTensorShape tensor_shape = dataset()->output_shapes()[index]; + int64_t element_number = 1LL; + for (int32_t i = 0; i < tensor_shape.dims(); i++) { + element_number *= tensor_shape.dim_size(i); + } + return element_number; + } + + bool IsUnknowShape(const PartialTensorShape& output_shapes) const { + if (output_shapes.unknown_rank()) { + return true; + } + for (int32_t i = 0; i < output_shapes.dims(); i++) { + if (output_shapes.dim_size(i) == -1) { + return true; + } + } + return false; + } + + Tensor CreateTensorByShape(const PartialTensorShape& output_shapes, const DataType& tensor_data_type) { + TensorShape tf_shape; + for (int32_t i = 0; i < output_shapes.dims(); i++) { + tf_shape.AddDim(output_shapes.dim_size(i)); + } + LOG_INFO("[LQK] CreateTensorByShape, tensor shape: {}", tf_shape.DebugString()); + + Tensor tmp(tensor_data_type, tf_shape); + auto tensor_data = const_cast(tmp.tensor_data().data()); + auto tensor_size = tmp.tensor_data().size(); + LOG_INFO("[LQK] KnownShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", + tmp.DebugString(), tensor_size, tmp.NumElements()); + + memset_s(tensor_data, tensor_size, 0, tensor_size); + + LOG_INFO("[LQK] KnownShape, after memset tensor: {}", tmp.DebugString()); + + return tmp; + } + + std::vector CreateOutputVecTensor() + { + size_t output_shape_size = dataset()->output_shapes().size(); + size_t output_type_size = dataset()->output_dtypes().size(); + LOG_INFO("[LQK] output_shape_size: {}, output_type_size: {}", output_shape_size, output_type_size); + if (output_shape_size != output_type_size) { + LOG_ERROR("[LQK] output_shape_size: {} is not equal to output_type_size: {}", output_shape_size, + output_type_size); + return {}; + } + std::vector result; + for (size_t i = 0UL; i < output_shape_size; i++) { + DataType tensor_data_type = dataset()->output_dtypes().at(i); + if (tensor_data_type == DT_STRING) { + LOG_ERROR("[LQK] current tensor type is DT_STRING"); + return{}; + } + LOG_INFO("[LQK] current tensor type is: {}", tensor_data_type); + LOG_INFO("[LQK] current tensor dim is: {}, dim[0].dim_Size is {}", dataset()->output_shapes()[i].dims(), + dataset()->output_shapes()[i].dim_size(0)); + if (dataset()->output_shapes()[i].dims() == 2) { + LOG_INFO("[LQK] current tensor dim[1].dim_Size is {}", dataset()->output_shapes()[i].dim_size(1)); + } + if (IsUnknowShape(dataset()->output_shapes()[i])) { + LOG_INFO("[LQK] output shape is unknown shape"); + Tensor tensor(tensor_data_type, TensorShape({8, 1})); + if (dataset()->output_shapes()[i].dims() == -1) { + tensor = Tensor(tensor_data_type, TensorShape({1})); + } + + // 获取指针 + auto tensor_data = const_cast(tensor.tensor_data().data()); + auto tensor_size = tensor.tensor_data().size(); + LOG_INFO("[LQK] IsUnknowShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", + tensor.DebugString(), tensor_size, tensor.NumElements()); + + memset_s(tensor_data, tensor_size, 0, tensor_size); + + LOG_INFO("[LQK] IsUnknowShape, after memset tensor: {}", tensor.DebugString()); + + result.push_back(tensor); + continue; + } + Tensor a = CreateTensorByShape(dataset()->output_shapes()[i], tensor_data_type); + LOG_INFO("[LQK] success create know shape tensor: {}", a.DebugString()); + + result.push_back(a); + } + return result; + } + Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, @@ -198,6 +297,28 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + int outSize = out_tensors->size(); + if (outSize > 0) { + for (const auto& t : *out_tensors) { + DataType tensor_type = t.dtype(); + TensorShape tensor_shape = t.shape(); + LOG_INFO("[LQK] GetNext eos, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + "tensor_shape: {}", + dataset()->channelId_, + iter_times_, + outSize, + tensor_type, + tensor_shape.DebugString()); + } + } + if (!is_second_eos && *end_of_sequence) { + is_second_eos = true; + *end_of_sequence = false; + *out_tensors = CreateOutputVecTensor(); + } else if (is_second_eos) { + *end_of_sequence = true; + } + auto keyProcess = Singleton::GetInstance(); auto datasetId = dataset()->id_; auto channelId = dataset()->channelId_; @@ -217,7 +338,7 @@ private: &req); CheckCommFinished(req, channelId); - keyProcess->SetEos(1, dataset()->channelId_); +// keyProcess->SetEos(1, dataset()->channelId_); LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -232,7 +353,7 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - keyProcess->SetEos(1, dataset()->channelId_); +// keyProcess->SetEos(1, dataset()->channelId_); LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); @@ -287,17 +408,22 @@ private: GUARDED_BY(mu_); std::unique_ptr input_impl_ GUARDED_BY(mu_); + bool is_second_eos = false; }; const DatasetBase *input_; int32_t channelId_; int32_t maxTrainSteps_; int32_t maxEvalSteps_; - std::vector output_shapes_; + const DataTypeVector outputTypes_; + std::vector outputShapes_; int id_; }; -EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) {} +EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &outputTypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &outputShapes_)); +} void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output) { @@ -307,7 +433,8 @@ void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, Dataset OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxTrainSteps, &maxTrainSteps)); int32_t maxEvalSteps; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxEvalSteps, &maxEvalSteps)); - *output = new Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps); + *output = new (std::nothrow) Dataset(ctx, input, channel, maxTrainSteps, maxEvalSteps, outputTypes_, outputShapes_); + OP_REQUIRES(ctx, *output != nullptr, errors::InvalidArgument("EosDatasetOp: new dataset failed")); } REGISTER_OP("EosDataset") diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h index bf30c6b9..117f2794 100644 --- a/src/dataset_tf/eos_dataset_op.h +++ b/src/dataset_tf/eos_dataset_op.h @@ -44,6 +44,8 @@ namespace data { private: class Dataset; + DataTypeVector outputTypes_; + std::vector outputShapes_; }; // class EosDatasetOp } // namespace data } // namespace tensorflow -- Gitee From ef52d462095a1e70f0d5042aa769b7caa645326f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Fri, 9 Aug 2024 17:11:06 +0800 Subject: [PATCH 14/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 117 ++++++++++++------------ src/core/hybrid_mgmt/hybrid_mgmt.h | 18 ++-- src/core/key_process/key_process.cpp | 127 ++++++++------------------- src/core/key_process/key_process.h | 4 - src/core/utils/common.h | 19 ++-- src/dataset_tf/eos_dataset_op.cc | 2 - 6 files changed, 112 insertions(+), 175 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 7fb2feeb..18b6e32e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -116,8 +116,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hybridMgmtBlock->SetRankInfo(rankInfo); // 启动数据处理线程 - KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed, - isIncrementalCheckpoint); + KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed, isIncrementalCheckpoint); isRunning = true; isL3StorageEnabled = rankInfo.isSSDEnabled; @@ -169,7 +168,7 @@ void HybridMgmt::Save(const string& savePath, bool saveDelta) throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } string saveModelType = - saveDelta ? TransferModelType2Str(SaveModelType::DELTA) : TransferModelType2Str(SaveModelType::BASE); + saveDelta ? TransferModelType2Str(SaveModelType::DELTA) : TransferModelType2Str(SaveModelType::BASE); LOG_INFO(MGMT + "Start to save {} model to {}.", saveModelType, savePath); // 数据处理线程上锁 @@ -510,7 +509,9 @@ void HybridMgmt::TrainTask(TaskType type) return; } LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, theTrainBatchId); - + if (isBackUpTrainStatus) { + RecoverTrainStatus(); + } ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId, type); } while (true); #endif @@ -535,10 +536,11 @@ void HybridMgmt::EvalTask(TaskType type) hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID], evalBatchId); if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { - // Before waking the data process for training, Recover the backed-up training state - RecoverTrainStatus(); hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); - } else { + } else if (!isRunning) { + return; + } + else { std::this_thread::sleep_for(SLEEP_MS); continue; } @@ -603,8 +605,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId, TaskType type) threadPool->enqueueWithFuture([this, info]() { return ProcessEmbInfoDDR(info); }); remainResult.push_back(std::move(remainBatch)); } else { - std::future remainBatch = threadPool->enqueueWithFuture( - [this, info]() { return ProcessEmbInfoL3Storage(info); }); + std::future remainBatch = + threadPool->enqueueWithFuture([this, info]() { return ProcessEmbInfoL3Storage(info); }); remainResult.push_back(std::move(remainBatch)); } break; @@ -711,7 +713,7 @@ bool HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info) auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); - LOG_DEBUG("[LQK] enqueue EosL1Que DDR, table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, + LOG_DEBUG("Enqueue on EosL1Que, table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, info.batchId, info.channelId, EosL1Que.size()); } if (uniqueKeys.empty()) { @@ -993,23 +995,28 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName, int channelId) std::string swapOutName = embName + SWAP_OUT_STR; std::vector addrs; while (isRunning && lookupAddrSuccess) { + pair eosChannel = EosL1Que[embName].WaitAndPop(); if (!isRunning) { return; } - pair keyChannel = EosL1Que[embName].WaitAndPop(); - if (keyChannel.first) { - EosL2Que[embName].Pushv(make_pair(true, keyChannel.second)); - LOG_DEBUG("[LQK] enqueue EosL2Que eos, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", - embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); + if (eosChannel.first) { + EosL2Que[embName].Pushv(make_pair(true, eosChannel.second)); + LOG_DEBUG( + "Enqueue on EosL2Que, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", + embName, id, eosChannel.second, EosL1Que.size(), EosL2Que.size()); continue; } else { - EosL2Que[embName].Pushv(make_pair(false, keyChannel.second)); - LOG_DEBUG("[LQK] enqueue EosL2Que normal, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", - embName, id, keyChannel.second, EosL1Que.size(), EosL2Que.size()); + EosL2Que[embName].Pushv(make_pair(false, eosChannel.second)); + LOG_DEBUG("Enqueue on EosL2Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, " + "EosL2Que.size: {}", + embName, id, eosChannel.second, EosL1Que.size(), EosL2Que.size()); } // swap in std::vector keys = HBMSwapKeyQue[swapInName][channelId].WaitAndPop(); + if (!isRunning) { + return; + } TimeCost lookupAddrsInTC; int rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); if (rc != H_OK) { @@ -1026,6 +1033,9 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName, int channelId) keys = HBMSwapKeyQue[swapOutName][channelId].WaitAndPop(); TimeCost lookupAddrsOutTC; rc = embCache->EmbeddingLookupAddrs(embName, keys, addrs); + if (!isRunning) { + return; + } if (rc != H_OK) { lookupAddrSuccess = false; throw runtime_error("EmbeddingLookupAddrs failed! error code: " + std::to_string(rc)); @@ -1127,14 +1137,12 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) receiveKeyThreads.emplace_back([embInfo, this]() { while (isRunning) { TransferChannel transferName = TransferChannel::KEY_D2H; - size_t ret = hdTransfer->RecvOffsetsAcl(transferName, TRAIN_CHANNEL_ID, - embInfo.name); + size_t ret = hdTransfer->RecvOffsetsAcl(transferName, TRAIN_CHANNEL_ID, embInfo.name); if (ret == 0) { LOG_WARN("Receive empty data."); } else { LOG_INFO("Receive data success, get {} data size: {}.", embInfo.name, ret); - auto aclData = - acltdtGetDataItem(hdTransfer->aclDatasetsForIncrementalCkpt[embInfo.name], 0); + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasetsForIncrementalCkpt[embInfo.name], 0); if (aclData == nullptr) { throw runtime_error("Acl get tensor data failed."); } @@ -1142,12 +1150,12 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) int64_t timeStamp = *ptr; int64_t globalStep = *(ptr + 1); LOG_INFO("Receive {} timeStamp: {}, global step: {}.", embInfo.name, timeStamp, globalStep); - // tensorflow获取的global step是从1开始的,但是在key process中batch id则是从0开始,因此,下面的info中的batchId需要用 - // globalStep - 1 - EmbBaseInfo info = {.batchId=static_cast(globalStep-1), .channelId=TRAIN_CHANNEL_ID, - .name=embInfo.name}; - unique_ptr> keyCountVecInfo = - KEY_PROCESS_INSTANCE->GetKCInfoVec(info); + // tensorflow获取的global step是从1开始的,但是在key process中batch + // id则是从0开始,因此,下面的info中的batchId需要用 globalStep - 1 + EmbBaseInfo info = {.batchId = static_cast(globalStep - 1), + .channelId = TRAIN_CHANNEL_ID, + .name = embInfo.name}; + unique_ptr> keyCountVecInfo = KEY_PROCESS_INSTANCE->GetKCInfoVec(info); if (keyCountVecInfo == nullptr) { LOG_ERROR("Get key count info vector is empty."); throw runtime_error("Get key count info vector is empty."); @@ -1159,8 +1167,8 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) for (int64 i = 0; i < keyCountSize; ++i) { keyCountVec.push_back(static_cast(keyCountVecTmp(i))); } - LOG_INFO("Emb table: {}, channel: {}, size is: {}, data: {}", - embInfo.name, TRAIN_CHANNEL_ID, keyCountSize, VectorToString(keyCountVec)); + LOG_INFO("Emb table: {}, channel: {}, size is: {}, data: {}", embInfo.name, TRAIN_CHANNEL_ID, + keyCountSize, VectorToString(keyCountVec)); // 更新delta表 std::lock_guard lock(keyCountUpdateMtx); @@ -1179,7 +1187,7 @@ void HybridMgmt::updateDeltaInfo(const string& embName, vector& keyCoun auto& embMap = deltaMap[embName]; for (int i = 0; i < keyCountSize; i += KEY_COUNT_ELEMENT_NUM) { emb_key_t key = keyCountVec[i]; - int64_t recentCount = keyCountVec[i+1]; + int64_t recentCount = keyCountVec[i + 1]; KeyInfo& keyInfo = embMap[key]; keyInfo.totalCount += recentCount; keyInfo.recentCount += recentCount; @@ -1317,8 +1325,8 @@ bool HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info) auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); - LOG_DEBUG("[LQK] enqueue EosL1Que SSD, table:{}, batchId:{}, channelId:{}", info.name, - info.batchId, info.channelId); + LOG_DEBUG("Enqueue on EosL1Que L3Storage, table:{}, batchId:{}, channelId:{}", info.name, info.batchId, + info.channelId); } if (uniqueKeys.empty()) { @@ -1527,20 +1535,17 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); + + pair eosChannel = EosL2Que[info.name].WaitAndPop(); if (!isRunning) { return false; } - pair keyChannel = EosL2Que[info.name].WaitAndPop(); - if (keyChannel.first) { + string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + if (eosChannel.first) { isEos = true; - LOG_INFO("EmbeddingReceiveDDR get eos from channel: {}", keyChannel.second); - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); - cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + LOG_INFO("EmbeddingReceiveDDR get eos from channel: {}", eosChannel.second); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, eosChannel.second); + lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1580,7 +1585,6 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name, info.batchId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; @@ -1735,21 +1739,17 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); + + pair eosChannel = EosL1Que[info.name].WaitAndPop(); if (!isRunning) { return false; } - - pair keyChannel = EosL1Que[info.name].WaitAndPop(); - if (keyChannel.first) { + string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + if (eosChannel.first) { isEos = true; - LOG_INFO("EmbeddingReceiveL3Storage get eos from channel: {}", keyChannel.second); - bool sendAllChannel = false; - if (!alreadyTrainOnce) { - // predict场景 - sendAllChannel = true; - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, keyChannel.second, sendAllChannel); - cvLastRecvFinishMap[info.name][info.cvNotifyIndex].notify_all(); + LOG_INFO("EmbeddingReceiveL3Storage get eos from channel: {}", eosChannel.second); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, eosChannel.second); + lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1788,7 +1788,6 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name.c_str(), info.batchId, info.channelId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1964,6 +1963,7 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& h2dE vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos) { - auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos, lookUpSwapInAddrsPushId); + auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos); if (isEos) { HandleEosCase(info, remainBatchOut); return uniqueKeys; @@ -2164,9 +2164,8 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, CheckLookupAddrSuccessDDR(); EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); - LOG_DEBUG("enqueue EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, + LOG_DEBUG("Enqueue on EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, info.batchId, info.channelId, EosL1Que.size()); - } void HybridMgmt::BackUpTrainStatus() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 11091fd2..988af792 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -17,13 +17,13 @@ See the License for the specific language governing permissions and #define MX_REC_EMB_MGMT_H #include -#include -#include -#include -#include #include #include +#include +#include #include +#include +#include #include "absl/container/flat_hash_map.h" #include "emb_table/embedding_table.h" @@ -37,8 +37,8 @@ See the License for the specific language governing permissions and #include "utils/config.h" #include "utils/singleton.h" #include "utils/task_queue.h" -#include "utils/time_cost.h" #include "utils/thread_pool.h" +#include "utils/time_cost.h" namespace MxRec { using namespace std; @@ -169,8 +169,8 @@ public: std::map>[MAX_CHANNEL_NUM]> HBMSwapAddrsQue; std::map>[MAX_CHANNEL_NUM]> DDRSwapAddrsQue; - std::map>> EosL1Que; // pair - std::map>> EosL2Que; + std::map>> EosL1Que; // pair + std::map>> EosL2Que; std::mutex evictMut; @@ -237,8 +237,8 @@ private: bool isRunning; bool isLoad{false}; bool isInitialized{false}; - bool alreadyTrainOnce = false; // 用于判断是否为predict模式 - bool isBackUpTrainStatus = false; // whether the train state has been backed up + bool alreadyTrainOnce = false; // 用于判断是否为predict模式 + bool isBackUpTrainStatus = false; // whether the train state has been backed up bool isIncrementalCkpt; map> deltaMap; absl::flat_hash_map keyBatchIdMap; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a5d4c5ff..a3a73ac1 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and ==============================================================================*/ #include "key_process.h" +#include + #include #include -#include - #include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "ock_ctr_common/include/error_code.h" @@ -41,8 +41,7 @@ void KeyProcess::SetupHotEmbUpdateStep() } bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues, - int seed, bool isIncrementalCkpt) + const vector& thresholdValues, int seed, bool isIncrementalCkpt) { readySendEosCnt[TRAIN_CHANNEL_ID].store(0); readySendEosCnt[EVAL_CHANNEL_ID].store(0); @@ -259,7 +258,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; - batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue + batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; @@ -305,8 +304,9 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) if (!KeyProcessTaskHelper(batch, channel, threadId)) { break; } - LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," - " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}, isEos:{}", + LOG_INFO(KEY_PROCESS + "getAndProcessTC(ms):{}, key process cost:{}," + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}, isEos:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, batch->channel, threadId, batch->batchId, batch->isEos); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); @@ -319,16 +319,16 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} channelId:{}, threadId:{}", rankInfo.rankId, channel, threadId); } -void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, - vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec) +void KeyProcess::HashSplitHelper(const unique_ptr& batch, vector& splitKeys, vector& restore, + vector& hotPos, vector>& keyCount, + vector& keyCountVec) { TimeCost uniqueTc; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { tie(splitKeys, restore, keyCount) = HashSplitWithFAAE(batch); // 按存储dev id切分并去重 } else { - tie(splitKeys, restore, hotPos, keyCountVec) = HotHashSplit(batch); // 按存储dev id切分并去重 + tie(splitKeys, restore, hotPos, keyCountVec) = HotHashSplit(batch); // 按存储dev id切分并去重 } LOG_DEBUG("uniqueTc(ms):{}", uniqueTc.ElapsedMS()); } @@ -400,23 +400,25 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { if (batch->isEos) { - if (!rankInfo.isDDR) { // HBM -// auto tensors = make_unique>(); + if (!rankInfo.isDDR) { // HBM + // auto tensors = make_unique>(); std::unique_lock lockGuard(mut); -// storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + // storage.push_front(move(tensors)); + infoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); lockGuard.unlock(); - LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", - batch->name, batch->batchId, batch->channel, threadId); + LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); return true; } // DDR vector uniqueKeys; std::unique_lock lockGuard(mut); - uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + uniqueKeysList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); lockGuard.unlock(); - LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", - batch->name, batch->batchId, batch->channel, threadId); + LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); return true; } vector splitKeys; @@ -449,10 +451,14 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) - if (!rankInfo.isDDR) { EmbeddingMgmt::Instance()->Key2Offset(batch->name, lookupKeys, channel); } + if (!rankInfo.isDDR) { + EmbeddingMgmt::Instance()->Key2Offset(batch->name, lookupKeys, channel); + } // Static all2all,need send count - if (!rankInfo.useStatic) { SendA2A(scAll, batch->name, batch->channel, batch->batchId); } + if (!rankInfo.useStatic) { + SendA2A(scAll, batch->name, batch->channel, batch->batchId); + } TimeCost pushResultTC; auto tensors = make_unique>(); @@ -554,7 +560,8 @@ void KeyProcess::PushResultDDR(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); - uniqueKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + uniqueKeysList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); restoreVecSecList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(restoreVecSec))); lockGuard.unlock(); } @@ -563,8 +570,8 @@ void KeyProcess::PushKeyCountHBM(unique_ptr& batch, unique_ptr lockGuard(mut); keyCountStorage.push_front(move(tensors)); - keyCountInfoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, - keyCountStorage.begin())); + keyCountInfoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, keyCountStorage.begin())); lockGuard.unlock(); LOG_INFO("Push key count to list success."); } @@ -945,8 +952,8 @@ tuple, vector, vector>> KeyProcess::Hash return {splitKeys, restore, keyCount}; } -tuple, vector, vector, vector> KeyProcess::HotHashSplit(const -unique_ptr& batch) +tuple, vector, vector, vector> KeyProcess::HotHashSplit( + const unique_ptr& batch) { EASY_FUNCTION(profiler::colors::Gold) emb_key_t* batchData = batch->sample.data(); @@ -1017,7 +1024,7 @@ unique_ptr& batch) UpdateHotMap(keyCountMapByEmbName, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch); - return { splitKeys, restore, hotPos, keyCountVec }; + return {splitKeys, restore, hotPos, keyCountVec}; } void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, @@ -1212,7 +1219,7 @@ T KeyProcess::GetInfo(info_list_t& list, const EmbBaseInfo& info) return move(t); } -template +template T KeyProcess::GetKeyCountVec(info_list_t& list, const EmbBaseInfo& info) { std::lock_guard lockGuard(mut); @@ -1255,8 +1262,7 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos) auto infoVec = GetInfo(uniqueKeysList, info); isEos = get(infoVec); if (isEos) { - LOG_WARN(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", - info.name, info.channelId, info.batchId); + LOG_WARN(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", info.name, info.channelId, info.batchId); break; } ret = get>(infoVec); @@ -1274,42 +1280,6 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos) return ret; } -bool KeyProcess::IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::_V2::system_clock::time_point& startTime) -{ - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - auto endTime = std::chrono::system_clock::now(); - - // readEmbKey start with 0 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[info.channelId] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - std::chrono::duration elapsedTime = endTime - startTime; - if (info.batchId != 0 && elapsedTime.count() >= timeoutGetUniqueKeysEmpty) { - LOG_DEBUG("table:{}, channelId:{}, isNeedSendEos:{}, current batchId:{}, L1 pipeline readEmbKeyBatchId:{}, " - "L2 pipeline lookUpSwapAddrsPushId:{}, L3 pipeline h2dNextBatchId:{}", - info.name, info.channelId, isNeedSendEos[info.channelId], info.batchId, readEmbKeyBatchId, - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId], - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); - startTime = std::chrono::system_clock::now(); - } - // Check '>= readEmbedBatchIdAll' condition to avoid send eos before handle all batch data from readEmbKey Op. - if (isNeedSendEos[info.channelId] && readEmbKeyBatchId < info.batchId && - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId] == - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId] && - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId] >= - hybridMgmtBlock->readEmbedBatchId[info.channelId]) { - LOG_INFO("table:{}, channelId:{} current batchId:{}, GetUniqueKeys eos, L1 pipeline readEmbKeyBatchId:{}, " - "L2 pipeline hybridBatchId:{}, L3 pipeline lookUpSwapAddrsPushId:{}, L4 pipeline h2dNextBatchId:{}", - info.name, info.channelId, info.batchId, readEmbKeyBatchId, - hybridMgmtBlock->hybridBatchId[info.channelId], - hybridMgmtBlock->lookUpSwapAddrsPushId[info.name][info.channelId], - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); - return true; - } - LOG_TRACE("getting uniqueKeys failed, table:{}, channel:{}, mgmt batchId:{}, readEmbKey batchId:{}, list is empty", - info.name, info.channelId, info.batchId, readEmbKeyBatchId); - return false; -} - std::vector KeyProcess::GetRestoreVecSec(const EmbBaseInfo& info) { TimeCost tc = TimeCost(); @@ -1621,29 +1591,6 @@ void KeyProcess::SetEos(int status, int channelId) isNeedSendEos[channelId] = (status == 1); } -bool KeyProcess::IsGetInfoVecEos(int batch, const string& embName, int channel) -{ - HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - - // 避免eos在keyProcess还未处理完数据时插队到通道前面, readEmbKey真实的次数是readEmbedBatchId减1 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; - if (rankInfo.isDDR) { - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch && - hybridMgmtBlock->h2dNextBatchId[embName][channel] == batch) { - LOG_ERROR("channelId:{} batchId:{}, GetInfoVec eos, code should not reach here", channel, batch); - throw runtime_error("GetInfoVec eos, code should not reach here"); - } - } else { - LOG_TRACE("table:{}, channelId:{}, readEmbKeyBatchId:{}, batchId:{}, isNeedSendEos:{}", embName, channel, - readEmbKeyBatchId, batch, isNeedSendEos[channel]); - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { - LOG_INFO("table:{}, channelId:{} batchId:{}, GetInfoVec eos", embName, channel, batch); - return true; - } - } - return false; -} - void KeyProcess::SendEosTensor(const std::string& embName, int channel) { #ifndef GTEST diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index ac99b42d..cf2f1114 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -311,10 +311,6 @@ public: string DumpSplitKeys(vector>& splitKeys) const; - bool IsGetInfoVecEos(int batch, const string& embName, int channel); - - bool IsGetUniqueKeysEos(const EmbBaseInfo& info, std::chrono::_V2::system_clock::time_point& startTime); - void SendEosTensor(const std::string& embName, int channel); }; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 594d0862..98449b58 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -187,8 +187,7 @@ struct Batch { bool CheckAndSetEos() { for (int i = 0; i < 8; i++) { - if (sample[i] != 0) - { + if (sample[i] != 0) { return false; } } @@ -527,14 +526,13 @@ struct KeySendInfo { }; struct KeyInfo { - int64_t lastUseTime; // 最后使用时间 - int64_t recentCount; // 最近使用次数 - bool isChanged; // 是否有变更 - int64_t batchID; // batch id - int64_t totalCount; // key总使用次数 - - KeyInfo(): lastUseTime(0), recentCount(0), isChanged(false), - batchID(0), totalCount(0) {} + int64_t lastUseTime; // 最后使用时间 + int64_t recentCount; // 最近使用次数 + bool isChanged; // 是否有变更 + int64_t batchID; // batch id + int64_t totalCount; // key总使用次数 + + KeyInfo() : lastUseTime(0), recentCount(0), isChanged(false), batchID(0), totalCount(0) {} }; using EmbMemT = absl::flat_hash_map; @@ -603,7 +601,6 @@ enum class CkptDataType { std::string CkptDataTypeName(CkptDataType type); - enum CTRLogLevel { // can't use enum class due to compatibility for AccCTR DEBUG = 0, INFO, diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 72be6696..1cd8c8bc 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -91,8 +91,6 @@ public: outputShapes_(outputShapes), id_(g_datasetId[channelId]) { input_->Ref(); -// auto os_input = input->output_shapes(); -// output_shapes_ = os_input; MPI_Comm_group(MPI_COMM_WORLD, &g_group); MPI_Comm_create(MPI_COMM_WORLD, g_group, &g_comm[channelId]); -- Gitee From 68410cec83c085ca5e2f97d1c66a9715e2dd8227 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 12 Aug 2024 16:10:08 +0800 Subject: [PATCH 15/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 3 -- src/core/key_process/key_process.cpp | 63 +++++++++++++--------- src/core/key_process/key_process.h | 8 ++- src/core/utils/common.h | 11 ---- src/dataset_tf/eos_dataset_op.cc | 21 ++++---- 5 files changed, 52 insertions(+), 54 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 8943044a..8643f924 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -187,9 +187,6 @@ void HybridMgmtBlock::ResetAll(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "after reset block status," " channelId:{}, pythonBatchId:{}, readEmbedBatchId:{}, hybridBatchId:{}", channelId, pythonBatchId[channelId], readEmbedBatchId[channelId], hybridBatchId[channelId]); - - LOG_DEBUG("Start to reset isNeedSendEos"); - Singleton::GetInstance()->SetEos(0, channelId); } /// 检查当前的步数是否可以进行save diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a3a73ac1..c65b4b66 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -400,25 +400,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { if (batch->isEos) { - if (!rankInfo.isDDR) { // HBM - // auto tensors = make_unique>(); - std::unique_lock lockGuard(mut); - // storage.push_front(move(tensors)); - infoList[batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); - lockGuard.unlock(); - LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, - batch->batchId, batch->channel, threadId); - return true; - } - // DDR - vector uniqueKeys; - std::unique_lock lockGuard(mut); - uniqueKeysList[batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); - lockGuard.unlock(); - LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, - batch->batchId, batch->channel, threadId); + HandleEos(batch, channel, threadId); return true; } vector splitKeys; @@ -497,6 +479,29 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, return true; } +void KeyProcess::HandleEos(unique_ptr& batch, int channel, int threadId) +{ + if (!rankInfo.isDDR) { // HBM + // auto tensors = make_unique>(); + std::unique_lock lockGuard(mut); + // storage.push_front(move(tensors)); + infoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper hbm eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); + return; + } + // DDR + vector uniqueKeys; + std::unique_lock lockGuard(mut); + uniqueKeysList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, batch->isEos, move(uniqueKeys))); + lockGuard.unlock(); + LOG_INFO("KeyProcessTaskHelper ddr eos, batch name:{}, batch id: {}, channelId:{} threadId:{}", batch->name, + batch->batchId, batch->channel, threadId); +} + void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, KeysT& lookupKeys, int channel) { LOG_INFO(KEY_PROCESS "rank:{}, channel:{}, useSumSameIdGradients:{} ...", rankInfo.rankId, channel, @@ -594,7 +599,7 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const while (true) { batch = batchQueue->TryPop(); if (batch != nullptr) { - if (batch->CheckAndSetEos()) { + if (batch->isEos) { LOG_INFO("GetBatchData eos, table name:{}, batchId:{}, channelId:{} threadId:{}", batch->name, batch->batchId, channel, commId); } @@ -1583,12 +1588,20 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) } } -void KeyProcess::SetEos(int status, int channelId) +void KeyProcess::EnqueEosBatch(int64_t batchNum, int channelId) { - unique_lock lockGuard(eosMutex); - LOG_INFO("isNeedSendEos status is changed, channel:{}, before status:{}, input status:{}", channelId, - isNeedSendEos[channelId], status); - isNeedSendEos[channelId] = (status == 1); + LOG_INFO("DataSet eos, channel:{}, eos number:{}", channelId, batchNum); + int threadNum = GetThreadNumEnv(); + int batchQueueId = int(batchNum % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); + auto queue = SingletonQueue::GetInstances(batchQueueId); + for (auto& emb : embInfos) { + auto batchData = queue->GetOne(); // get dirty or empty data block + batchData->name = emb.first; + batchData->channel = channelId; + batchData->batchId = batchNum; + batchData->isEos = true; + queue->Pushv(move(batchData)); + } } void KeyProcess::SendEosTensor(const std::string& embName, int channel) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index cf2f1114..2f4eab22 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -156,10 +156,12 @@ public: } } - void SetEos(int status, int channelId); + void EnqueEosBatch(int64_t batchNum, int channelId); void SendEos(const string& embName, int batchId, int channel); + void HandleEos(unique_ptr& batch, int channel, int threadId); + bool isRunning{false}; bool isIncrementalCheckpoint {false}; @@ -206,12 +208,8 @@ public: int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; - // for end-of-sequence case - bool isNeedSendEos[2] = {false, false}; // 表示各表通道0、1的eos状态 atomic readySendEosCnt[2]; atomic finishSendEosCnt[2]; - const double timeoutGetUniqueKeys = 30.0; // 如果超时仍未获取到数据将触发EOS - const double timeoutGetUniqueKeysEmpty = 1.0; // 如果超时仍未获取到数据将打印信息 void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 98449b58..b5008df3 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -184,17 +184,6 @@ struct Batch { return s; } - bool CheckAndSetEos() - { - for (int i = 0; i < 8; i++) { - if (sample[i] != 0) { - return false; - } - } - isEos = true; - return true; - } - std::vector sample; std::string name; size_t batchSize; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 1cd8c8bc..29de1f34 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -295,6 +295,7 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + // todo debug print int outSize = out_tensors->size(); if (outSize > 0) { for (const auto& t : *out_tensors) { @@ -309,13 +310,13 @@ private: tensor_shape.DebugString()); } } - if (!is_second_eos && *end_of_sequence) { - is_second_eos = true; - *end_of_sequence = false; - *out_tensors = CreateOutputVecTensor(); - } else if (is_second_eos) { - *end_of_sequence = true; - } +// if (!is_second_eos && *end_of_sequence) { +// is_second_eos = true; +// *end_of_sequence = false; +// *out_tensors = CreateOutputVecTensor(); +// } else if (is_second_eos) { +// *end_of_sequence = true; +// } auto keyProcess = Singleton::GetInstance(); auto datasetId = dataset()->id_; @@ -336,7 +337,7 @@ private: &req); CheckCommFinished(req, channelId); -// keyProcess->SetEos(1, dataset()->channelId_); + keyProcess->EnqueEosBatch(iter_times_, dataset()->channelId_); LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -351,7 +352,7 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; -// keyProcess->SetEos(1, dataset()->channelId_); + keyProcess->EnqueEosBatch(iter_times_, dataset()->channelId_); LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); @@ -406,7 +407,7 @@ private: GUARDED_BY(mu_); std::unique_ptr input_impl_ GUARDED_BY(mu_); - bool is_second_eos = false; +// bool is_second_eos = false; }; const DatasetBase *input_; -- Gitee From 8880b8e339be5669308868a4ec3ea803971a4f95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 12 Aug 2024 19:56:21 +0800 Subject: [PATCH 16/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 7 +++---- src/core/key_process/key_process.h | 2 +- src/dataset_tf/eos_dataset_op.cc | 8 ++++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index c65b4b66..43724a39 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -482,9 +482,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, void KeyProcess::HandleEos(unique_ptr& batch, int channel, int threadId) { if (!rankInfo.isDDR) { // HBM - // auto tensors = make_unique>(); std::unique_lock lockGuard(mut); - // storage.push_front(move(tensors)); infoList[batch->name][batch->channel].push( make_tuple(batch->batchId, batch->name, batch->isEos, storage.begin())); lockGuard.unlock(); @@ -1588,9 +1586,9 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) } } -void KeyProcess::EnqueEosBatch(int64_t batchNum, int channelId) +void KeyProcess::EnqueueEosBatch(int64_t batchNum, int channelId) { - LOG_INFO("DataSet eos, channel:{}, eos number:{}", channelId, batchNum); + LOG_INFO("Enqueue data set eos on batch queue, channel:{}, eos number:{}", channelId, batchNum); int threadNum = GetThreadNumEnv(); int batchQueueId = int(batchNum % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); auto queue = SingletonQueue::GetInstances(batchQueueId); @@ -1599,6 +1597,7 @@ void KeyProcess::EnqueEosBatch(int64_t batchNum, int channelId) batchData->name = emb.first; batchData->channel = channelId; batchData->batchId = batchNum; + batchData->sample = {0, 0, 0, 0, 0, 0, 0, 0}; // fake data batchData->isEos = true; queue->Pushv(move(batchData)); } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 2f4eab22..c8ce024c 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -156,7 +156,7 @@ public: } } - void EnqueEosBatch(int64_t batchNum, int channelId); + void EnqueueEosBatch(int64_t batchNum, int channelId); void SendEos(const string& embName, int batchId, int channel); diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 29de1f34..5db20221 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -295,13 +295,13 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - // todo debug print + // Out size equals to zero when batch eos. int outSize = out_tensors->size(); if (outSize > 0) { for (const auto& t : *out_tensors) { DataType tensor_type = t.dtype(); TensorShape tensor_shape = t.shape(); - LOG_INFO("[LQK] GetNext eos, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " "tensor_shape: {}", dataset()->channelId_, iter_times_, @@ -337,7 +337,7 @@ private: &req); CheckCommFinished(req, channelId); - keyProcess->EnqueEosBatch(iter_times_, dataset()->channelId_); + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -352,7 +352,7 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - keyProcess->EnqueEosBatch(iter_times_, dataset()->channelId_); + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); -- Gitee From 89731fa096bca55f7164b4df373aaee83e74d72a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 12 Aug 2024 20:22:40 +0800 Subject: [PATCH 17/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_tf/eos_dataset_op.cc | 19 ++++++++----------- src/ops_tf/hybrid_dataset_ops.cpp | 2 ++ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 5db20221..2f9b5f93 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -310,13 +310,6 @@ private: tensor_shape.DebugString()); } } -// if (!is_second_eos && *end_of_sequence) { -// is_second_eos = true; -// *end_of_sequence = false; -// *out_tensors = CreateOutputVecTensor(); -// } else if (is_second_eos) { -// *end_of_sequence = true; -// } auto keyProcess = Singleton::GetInstance(); auto datasetId = dataset()->id_; @@ -336,8 +329,10 @@ private: MPI_Iallreduce(MPI_IN_PLACE, &getNextStatus, 1, MPI_INT, MPI_SUM, g_comm[channelId], &req); CheckCommFinished(req, channelId); - - keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + // Max step is achieved, no need to send eos. + if (outSize == 0) { + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + } LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -352,7 +347,10 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + // Max step is achieved, no need to send eos. + if (outSize == 0) { + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + } LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); @@ -407,7 +405,6 @@ private: GUARDED_BY(mu_); std::unique_ptr input_impl_ GUARDED_BY(mu_); -// bool is_second_eos = false; }; const DatasetBase *input_; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 98fca961..5b358884 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -278,6 +278,7 @@ namespace MxRec { batchData->name = embNames.at(i); size_t len = splits(i); batchData->channel = channelId; + batchData->isEos = false; batchData->batchId = ids[0]; batchData->sample.resize(len); if (isTimestamp) { @@ -466,6 +467,7 @@ namespace MxRec { batchData->name = embNames.at(i); size_t len = splits.at(i); batchData->channel = channelId; + batchData->isEos = false; batchData->batchId = batchId; batchData->sample.resize(len); if (isTimestamp) { -- Gitee From b2ac4ff07eeca5bc9d50002637c51d558b231a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 13 Aug 2024 15:03:44 +0800 Subject: [PATCH 18/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 78 ++++++++++++---------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 4 +- src/dataset_tf/eos_dataset_op.cc | 5 +- 3 files changed, 38 insertions(+), 49 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 18b6e32e..c43ec426 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -535,16 +535,6 @@ void HybridMgmt::EvalTask(TaskType type) LOG_DEBUG("eval channel block, python batch id:{}, hybridBatchId:{}", hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID], evalBatchId); - if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) { - hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID); - } else if (!isRunning) { - return; - } - else { - std::this_thread::sleep_for(SLEEP_MS); - continue; - } - LOG_DEBUG("wake TrainTask"); hybridMgmtBlock->DoBlock(EVAL_CHANNEL_ID); } @@ -712,9 +702,9 @@ bool HybridMgmt::ProcessEmbInfoDDR(const EmbBaseInfo& info) bool isEos = false; auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { - EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); - LOG_DEBUG("Enqueue on EosL1Que, table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, - info.batchId, info.channelId, EosL1Que.size()); + EosL1Que[info.name][info.channelId].Pushv(true); + LOG_DEBUG("Enqueue on EosL1Que, eos status! table:{}, batchId:{}, channelId:{}, EosL1Que size: {}", info.name, + info.batchId, info.channelId, EosL1Que[info.name][info.channelId].Size()); } if (uniqueKeys.empty()) { return remainBatchOut; @@ -995,21 +985,16 @@ void HybridMgmt::LookUpSwapAddrs(const string& embName, int channelId) std::string swapOutName = embName + SWAP_OUT_STR; std::vector addrs; while (isRunning && lookupAddrSuccess) { - pair eosChannel = EosL1Que[embName].WaitAndPop(); + bool isEos = EosL1Que[embName][channelId].WaitAndPop(); if (!isRunning) { return; } - if (eosChannel.first) { - EosL2Que[embName].Pushv(make_pair(true, eosChannel.second)); - LOG_DEBUG( - "Enqueue on EosL2Que, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, EosL2Que.size: {}", - embName, id, eosChannel.second, EosL1Que.size(), EosL2Que.size()); - continue; - } else { - EosL2Que[embName].Pushv(make_pair(false, eosChannel.second)); - LOG_DEBUG("Enqueue on EosL2Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, " + EosL2Que[embName][channelId].Pushv(isEos); + if (isEos) { + LOG_DEBUG("Enqueue on EosL2Que, eos status! table:{}, batchId:{}, channelId:{}, EosL1Que size:{}, " "EosL2Que.size: {}", - embName, id, eosChannel.second, EosL1Que.size(), EosL2Que.size()); + embName, id, channelId, EosL1Que[embName][channelId].Size(), EosL2Que[embName][channelId].Size()); + continue; } // swap in @@ -1324,9 +1309,9 @@ bool HybridMgmt::ProcessEmbInfoL3Storage(const EmbBaseInfo& info) bool isEos = false; auto uniqueKeys = GetUniqueKeys(info, remainBatchOut, isEos); if (isEos) { - EosL1Que[info.name].Pushv(make_pair(true, info.channelId)); - LOG_DEBUG("Enqueue on EosL1Que L3Storage, table:{}, batchId:{}, channelId:{}", info.name, info.batchId, - info.channelId); + EosL1Que[info.name][info.channelId].Pushv(true); + LOG_DEBUG("Enqueue on EosL1Que L3Storage, eos status! table:{}, batchId:{}, channelId:{}", info.name, + info.batchId, info.channelId); } if (uniqueKeys.empty()) { @@ -1476,10 +1461,12 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { for (auto& p : EosL1Que) { - p.second.DestroyQueue(); + p.second[TRAIN_CHANNEL_ID].DestroyQueue(); + p.second[EVAL_CHANNEL_ID].DestroyQueue(); } for (auto& p : EosL2Que) { - p.second.DestroyQueue(); + p.second[TRAIN_CHANNEL_ID].DestroyQueue(); + p.second[EVAL_CHANNEL_ID].DestroyQueue(); } for (auto& p : HBMSwapAddrsQue) { p.second[TRAIN_CHANNEL_ID].DestroyQueue(); @@ -1536,15 +1523,15 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); - pair eosChannel = EosL2Que[info.name].WaitAndPop(); + isEos = EosL2Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; } string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); - if (eosChannel.first) { - isEos = true; - LOG_INFO("EmbeddingReceiveDDR get eos from channel: {}", eosChannel.second); - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, eosChannel.second); + if (isEos) { + LOG_DEBUG("EmbeddingReceiveDDR get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, + info.channelId); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1740,15 +1727,15 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); - pair eosChannel = EosL1Que[info.name].WaitAndPop(); + isEos = EosL1Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; } string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); - if (eosChannel.first) { - isEos = true; - LOG_INFO("EmbeddingReceiveL3Storage get eos from channel: {}", eosChannel.second); - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, eosChannel.second); + if (isEos) { + LOG_DEBUG("EmbeddingReceiveL3Storage get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, + info.channelId); + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; } @@ -1964,7 +1951,7 @@ void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& h2dEmb) @@ -2155,17 +2142,16 @@ void HybridMgmt::EnqueueSwapInfo(const EmbBaseInfo& info, pair, { auto& swapInKeys = swapInKoPair.first; auto& swapOutKeys = swapOutKoPair.first; - - LOG_DEBUG("enqueue HBMSwapKeyQue table:{}, batchId:{}, channelId:{}, swapInSize:{}, swapOutSize:{}", info.name, - info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size()); HBMSwapKeyQue[info.name + SWAP_OUT_STR][info.channelId].Pushv(swapOutKeys); HBMSwapKeyQue[info.name + SWAP_IN_STR][info.channelId].Pushv(swapInKeys); CheckLookupAddrSuccessDDR(); - EosL1Que[info.name].Pushv(make_pair(false, info.channelId)); - LOG_DEBUG("Enqueue on EosL1Que, normal status, table:{}, batchId:{}, channelId:{}, EosL1Que.size: {}", info.name, - info.batchId, info.channelId, EosL1Que.size()); + EosL1Que[info.name][info.channelId].Pushv(false); + LOG_DEBUG("Enqueue on HBMSwapKeyQue and EosL1Que, table:{}, batchId:{}, channelId:{}, swapInSize:{}, " + "swapOutSize:{}, EosL1Que.size: {}", + info.name, info.batchId, info.channelId, swapInKeys.size(), swapOutKeys.size(), + EosL1Que[info.name][info.channelId].Size()); } void HybridMgmt::BackUpTrainStatus() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 988af792..473d95b7 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -169,8 +169,8 @@ public: std::map>[MAX_CHANNEL_NUM]> HBMSwapAddrsQue; std::map>[MAX_CHANNEL_NUM]> DDRSwapAddrsQue; - std::map>> EosL1Que; // pair - std::map>> EosL2Que; + std::map[MAX_CHANNEL_NUM]> EosL1Que; + std::map[MAX_CHANNEL_NUM]> EosL2Que; std::mutex evictMut; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 2f9b5f93..4c8819d6 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -301,7 +301,7 @@ private: for (const auto& t : *out_tensors) { DataType tensor_type = t.dtype(); TensorShape tensor_shape = t.shape(); - LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " "tensor_shape: {}", dataset()->channelId_, iter_times_, @@ -309,6 +309,9 @@ private: tensor_type, tensor_shape.DebugString()); } + } else { + LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}", dataset()->channelId_, + iter_times_, outSize); } auto keyProcess = Singleton::GetInstance(); -- Gitee From b17e2ebd54d0f4983f70ead5ba5cce994cf95786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 13 Aug 2024 15:31:12 +0800 Subject: [PATCH 19/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index c43ec426..8bcae3d1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1212,7 +1212,8 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_INFO("HybridMgmt is not running when [LookUpAndSendDDR], table:{}, batchId:{}, channel:{}", embInfo.name, + batchId, channelId); return; } EmbeddingSendDDR(info, h2dEmb); @@ -1237,7 +1238,9 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI bool isEos = false; auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_INFO("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " + "channel:{}", + embInfo.name, batchId, channelId); return; } if (!isEos) { @@ -1262,7 +1265,8 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running"); + LOG_INFO("HybridMgmt is not running when [LookUpAndSendL3Storage], table:{}, batchId:{}, channel:{}", + embInfo.name, batchId, channelId); return; } @@ -1287,7 +1291,13 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons vector swapOutAddrs; int64_t dims0 = 0; bool isEos = false; - EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); + auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); + if (!isSuccess) { + LOG_INFO("HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " + "channel:{}", + embInfo.name, batchId, channelId); + return; + } if (!isEos) { EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); } -- Gitee From 4d8183d638d72e9b5973edb8a3c5a8cb8f5a4c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Tue, 13 Aug 2024 19:35:32 +0800 Subject: [PATCH 20/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 126 +++++++++++++-------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- src/core/utils/common.cpp | 2 +- src/core/utils/common.h | 2 +- src/core/utils/task_queue.h | 2 +- 5 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 8bcae3d1..95cd5075 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -442,22 +442,7 @@ void HybridMgmt::Destroy() } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; - mutexDestroy = true; - for (const auto& embInfo : mgmtEmbInfo) { - for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { - string trainKey = MakeKeyName(index, embInfo.name, TRAIN_CHANNEL_ID); - lastUpdateFinishCV[trainKey].notify_all(); - lastLookUpFinishCV[trainKey].notify_all(); - lastSendFinishCV[trainKey].notify_all(); - lastRecvFinishCV[trainKey].notify_all(); - string evalKey = MakeKeyName(index, embInfo.name, EVAL_CHANNEL_ID); - lastUpdateFinishCV[evalKey].notify_all(); - lastLookUpFinishCV[evalKey].notify_all(); - lastSendFinishCV[evalKey].notify_all(); - lastRecvFinishCV[evalKey].notify_all(); - } - } cvCheckSave.notify_all(); // 防止save异常退出场景阻塞在EvalTask { @@ -474,11 +459,14 @@ void HybridMgmt::Destroy() for (auto& t : procThreads) { t->join(); } + procThreads.clear(); + if (cacheManager != nullptr) { cacheManager = nullptr; } JoinEmbeddingCacheThread(); - procThreads.clear(); + LOG_DEBUG(MGMT + "destroy EmbeddingCacheThread end."); + // 等待并销毁接收key的线程 for (auto& t : receiveKeyThreads) { t.join(); @@ -1470,38 +1458,44 @@ void HybridMgmt::InitEmbeddingCache(const vector& embInfos) void HybridMgmt::JoinEmbeddingCacheThread() { - for (auto& p : EosL1Que) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : EosL2Que) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : HBMSwapAddrsQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : HBMSwapKeyQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : HBMSwapKeyForL3StorageQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapKeyQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapKeyForL3StorageQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); - } - for (auto& p : DDRSwapAddrsQue) { - p.second[TRAIN_CHANNEL_ID].DestroyQueue(); - p.second[EVAL_CHANNEL_ID].DestroyQueue(); + for (int channelId = 0; channelId < MAX_CHANNEL_NUM; channelId++) { + // Let ReceiveAndUpdate & LookupAndSend thread stop. + for (const auto& embInfo : mgmtEmbInfo) { + for (int index = 0; index < EMBEDDING_THREAD_NUM; index++) { + string key = MakeSwapCVName(index, embInfo.name, channelId); + lastUpdateFinishCV[key].notify_all(); + lastLookUpFinishCV[key].notify_all(); + lastSendFinishCV[key].notify_all(); + lastRecvFinishCV[key].notify_all(); + } + } + + for (auto& p : EosL1Que) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : EosL2Que) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapAddrsQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapKeyQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : HBMSwapKeyForL3StorageQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapKeyQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapKeyForL3StorageQue) { + p.second[channelId].DestroyQueue(); + } + for (auto& p : DDRSwapAddrsQue) { + p.second[channelId].DestroyQueue(); + } } + for (auto& t : EmbeddingLookUpAndSendThreadPool) { t.join(); } @@ -1527,7 +1521,7 @@ void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1537,7 +1531,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (!isRunning) { return false; } - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); if (isEos) { LOG_DEBUG("EmbeddingReceiveDDR get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); @@ -1589,7 +1583,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1619,13 +1613,13 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr } hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastUpdateFinishCV[nextKey].notify_all(); } bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] >= info.batchId) || mutexDestroy; @@ -1648,7 +1642,7 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d } hybridMgmtBlock->lastLookUpFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); return true; @@ -1656,7 +1650,7 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastSendFinishLocker(lastSendFinishMutex[currentKey]); lastSendFinishCV[currentKey].wait(lastSendFinishLocker, [info, this] { return (hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1665,7 +1659,7 @@ void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEm // 区分通道发送 hdTransfer->Send(TransferChannel::H2D, h2dEmb, info.channelId, info.name, info.batchId); hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name, info.batchId, info.threadIdx, SendTC.ElapsedMS()); @@ -1680,6 +1674,7 @@ void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEm void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo, int channelId) { auto fn = [index, embInfo, channelId, this]() { + LOG_DEBUG(MGMT + "Create LookUpAndSendThread, table:{}, index:{}, channel:{}", embInfo.name, index, channelId); while (true) { lookUpAndSendBatchIdMtx[channelId].lock(); if (hybridMgmtBlock->lookUpAndSendTableBatchId[embInfo.name][channelId] % EMBEDDING_THREAD_NUM == index) { @@ -1695,6 +1690,8 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em lookUpAndSendBatchIdMtx[channelId].unlock(); } if (!isRunning) { + LOG_DEBUG(MGMT + "Destroy LookUpAndSendThread, table:{}, index:{}, channel:{}, batchId:{}", + embInfo.name, index, channelId, hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1705,6 +1702,7 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo, int channelId) { auto fn = [index, embInfo, channelId, this]() { + LOG_DEBUG(MGMT + "Create ReceiveAndUpdateThread, table:{}, index:{}, channel:{}", embInfo.name, index, channelId); while (true) { receiveAndUpdateBatchIdMtx[channelId].lock(); if (hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId] % EMBEDDING_THREAD_NUM == @@ -1721,6 +1719,8 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& receiveAndUpdateBatchIdMtx[channelId].unlock(); } if (!isRunning) { + LOG_DEBUG(MGMT + "Destroy ReceiveAndUpdateThread, table:{}, index:{}, channel:{}, batchId:{}", + embInfo.name, index, channelId, hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1731,7 +1731,7 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0, bool& isEos) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1741,7 +1741,7 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (!isRunning) { return false; } - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); if (isEos) { LOG_DEBUG("EmbeddingReceiveL3Storage get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); @@ -1792,7 +1792,7 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1836,13 +1836,13 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr info.batchId, info.channelId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastUpdateFinishCV[nextKey].notify_all(); } bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastUpdateFinishLocker(lastUpdateFinishMutex[currentKey]); lastUpdateFinishCV[currentKey].wait(lastUpdateFinishLocker, [info, this] { return (hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId] >= info.batchId) || mutexDestroy; @@ -1889,7 +1889,7 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorlastLookUpFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); return true; @@ -1897,7 +1897,7 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vector& h2dEmb) { - string currentKey = MakeKeyName(info.threadIdx, info.name, info.channelId); + string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastSendFinishLocker(lastSendFinishMutex[currentKey]); lastSendFinishCV[currentKey].wait(lastSendFinishLocker, [info, this] { return (hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; @@ -1906,7 +1906,7 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& // 区分通道发送 hdTransfer->Send(TransferChannel::H2D, h2dEmb, info.channelId, info.name, info.batchId); hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; - string nextKey = MakeKeyName(info.cvNotifyIndex, info.name, info.channelId); + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); LOG_DEBUG("table:{}, channelId:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name.c_str(), info.channelId, info.batchId, info.threadIdx, SendTC.ElapsedMS()); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 473d95b7..0c9c35ff 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -142,7 +142,7 @@ public: void ReceiveKeyThread(const EmbInfo& embInfo); - GTEST_PRIVATE : bool mutexDestroy{false}; + GTEST_PRIVATE : bool mutexDestroy{false}; // LookupAndSend & ReceiveAndUpdate Condition_Variable_Wait stop. std::mutex lookUpAndSendBatchIdMtx[MAX_CHANNEL_NUM]; // train and eval std::mutex receiveAndUpdateBatchIdMtx[MAX_CHANNEL_NUM]; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index f267652e..3cadc5f5 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -188,7 +188,7 @@ namespace MxRec { } // Make key for mutex and cv in swap pipeline, id: threadId, channelId: train/eval. - string MakeKeyName(int id, const string& tableName, int channelId) + string MakeSwapCVName(int id, const string& tableName, int channelId) { return to_string(id) + tableName + to_string(channelId); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index b5008df3..357ce890 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -622,7 +622,7 @@ bool CheckFilePermission(const string& filePath); int GetStepFromPath(const string& loadPath); -string MakeKeyName(int id, const string& tableName, int channelId); +string MakeSwapCVName(int id, const string& tableName, int channelId); } // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h index a42e5147..6042e54c 100644 --- a/src/core/utils/task_queue.h +++ b/src/core/utils/task_queue.h @@ -82,7 +82,7 @@ namespace MxRec { void DestroyQueue() { finished = true; - dataCond.notify_one(); + dataCond.notify_all(); } bool Empty() const -- Gitee From efe4a52ac0ac068bd9831a7616b25c562843e15f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 14 Aug 2024 10:28:17 +0800 Subject: [PATCH 21/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 49 +++++---------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 5 --- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 51 ---------------------- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 11 ----- src/core/key_process/key_process.cpp | 4 +- 5 files changed, 14 insertions(+), 106 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 95cd5075..21644507 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -211,8 +211,6 @@ void HybridMgmt::Save(const string& savePath, bool saveDelta) LOG_INFO(MGMT + "End to save {} model.", saveModelType); // 数据处理线程释放锁 KEY_PROCESS_INSTANCE->LoadSaveUnlock(); - hybridMgmtBlock->FinishSave(); - cvCheckSave.notify_all(); #endif } @@ -443,7 +441,6 @@ void HybridMgmt::Destroy() // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; mutexDestroy = true; - cvCheckSave.notify_all(); // 防止save异常退出场景阻塞在EvalTask { // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 @@ -486,17 +483,16 @@ void HybridMgmt::Destroy() void HybridMgmt::TrainTask(TaskType type) { #ifndef GTEST - int channelId = TRAIN_CHANNEL_ID; - int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelId]; + int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[TRAIN_CHANNEL_ID]; do { - hybridMgmtBlock->CheckAndSetBlock(channelId); - if (hybridMgmtBlock->GetBlockStatus(channelId)) { - hybridMgmtBlock->DoBlock(channelId); + hybridMgmtBlock->CheckAndSetBlock(TRAIN_CHANNEL_ID); + if (hybridMgmtBlock->GetBlockStatus(TRAIN_CHANNEL_ID)) { + hybridMgmtBlock->DoBlock(TRAIN_CHANNEL_ID); } if (!isRunning) { return; } - LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, theTrainBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", TRAIN_CHANNEL_ID, theTrainBatchId); if (isBackUpTrainStatus) { RecoverTrainStatus(); } @@ -515,15 +511,6 @@ void HybridMgmt::EvalTask(TaskType type) do { hybridMgmtBlock->CheckAndSetBlock(EVAL_CHANNEL_ID); if (hybridMgmtBlock->GetBlockStatus(EVAL_CHANNEL_ID)) { - LOG_DEBUG("eval channel block at batchId:{}, needWaitSave:{}", evalBatchId, - hybridMgmtBlock->IsNeedWaitSave()); - std::unique_lock checkSaveLocker(saveMutex); - cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; }); - - LOG_DEBUG("eval channel block, python batch id:{}, hybridBatchId:{}", - hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID], evalBatchId); - - LOG_DEBUG("wake TrainTask"); hybridMgmtBlock->DoBlock(EVAL_CHANNEL_ID); } if (!isRunning) { @@ -629,14 +616,13 @@ bool HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool isGrad) bool isEos = false; auto infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(info, ProcessedInfo::RESTORE, isEos); if (isEos) { - HandleEosCase(info, remainBatchOut); - return remainBatchOut; + KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); + return false; } if (infoVecs == nullptr) { - LOG_INFO(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, + LOG_WARN(MGMT + "table:{}, channelId:{} batchId:{}, ParseKeys infoVecs empty !", info.name, info.channelId, info.batchId); - remainBatchOut = false; - return remainBatchOut; + return false; } LOG_DEBUG("table:{}, channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end", info.name, info.channelId, info.batchId); @@ -644,7 +630,7 @@ bool HybridMgmt::ProcessEmbInfoHBM(const EmbBaseInfo& info, bool isGrad) // 动态shape场景下,获取all2all向量(通信量矩阵) SendAll2AllVec(info, remainBatchOut); if (!remainBatchOut) { - return remainBatchOut; + return false; } // 发送查询向量 @@ -1507,18 +1493,6 @@ void HybridMgmt::JoinEmbeddingCacheThread() } } -void HybridMgmt::HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut) -{ - // Predict do not need to be blocked. - if (info.channelId == EVAL_CHANNEL_ID && alreadyTrainOnce) { - // Eval after train. - hybridMgmtBlock->SetBlockStatus(EVAL_CHANNEL_ID, true); - LOG_INFO("GetUniqueKeys get eos from eval channel, SetBlockStatus=true"); - } - KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); - remainBatchOut = false; -} - bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos) { string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); @@ -1995,8 +1969,9 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE vector HybridMgmt::GetUniqueKeys(const EmbBaseInfo& info, bool& remainBatchOut, bool& isEos) { auto uniqueKeys = KEY_PROCESS_INSTANCE->GetUniqueKeys(info, isEos); + // DDR eos send in swap pipeline. if (isEos) { - HandleEosCase(info, remainBatchOut); + remainBatchOut = false; return uniqueKeys; } if (uniqueKeys.empty()) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 0c9c35ff..10eecf4a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -184,9 +184,6 @@ public: std::map>> trainTestSwitchInfoStore{}; std::atomic lookupAddrSuccess{true}; - std::mutex saveMutex; - std::condition_variable cvCheckSave; - unique_ptr threadPool; void SetFeatureTypeForLoad(vector& loadFeatures); @@ -262,8 +259,6 @@ private: void JoinEmbeddingCacheThread(); - void HandleEosCase(const EmbBaseInfo& info, bool& remainBatchOut); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos); void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 8643f924..f6b36ef6 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -41,7 +41,6 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "blocking by save saveInterval {} pythonBatchId {} hybridBatchId {}", saveInterval, pythonBatchId[channelId], hybridBatchId[channelId]); isBlock[TRAIN_CHANNEL_ID] = true; - finishSave = false; } if (stepsInterval[channelId] == -1) { return; @@ -189,43 +188,11 @@ void HybridMgmtBlock::ResetAll(int channelId) channelId, pythonBatchId[channelId], readEmbedBatchId[channelId], hybridBatchId[channelId]); } -/// 检查当前的步数是否可以进行save -/// \return 0 is legal, 1 需要回退一步, -1 表示错误 -int HybridMgmtBlock::CheckSaveEmbMapValid() -{ - // 检查数据通道此时的HashMap是否被提前处理了 - if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) { - LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is checking the step and checking that the parameters are normal. " - "The number of steps in the previous round is " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); - return 0; - } else if (pythonBatchId[lastRunChannelId] + 1 == hybridBatchId[lastRunChannelId]) { - // 在通道切换时,上一个通道处理的数据超出了python侧的调用 - LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is checking the step, and the parameters have been processed one step " - "in advance. The number of steps in the previous round was " - "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); - - return 1; - } else { - // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 - LOG_DEBUG(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId {} hybridBatchId {}", lastRunChannelId, - hybridBatchId[lastRunChannelId]); - return -1; - } -} - bool HybridMgmtBlock::GetBlockStatus(int channelId) { return isBlock[channelId]; } -void HybridMgmtBlock::SetBlockStatus(int channelId, bool block) -{ - isBlock[channelId] = block; -} - void HybridMgmtBlock::Destroy() { if (!isRunning) { @@ -253,22 +220,4 @@ void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep) HybridMgmtBlock::~HybridMgmtBlock() { Destroy(); -} - -void HybridMgmtBlock::Wake(int channelId) -{ - isBlock[channelId] = false; -} - -bool HybridMgmtBlock::IsNeedWaitSave() -{ - if (saveInterval != 0 && saveInterval != -1 && hybridBatchId[TRAIN_CHANNEL_ID] % saveInterval == 0 && !finishSave) { - return true; - } - return false; -} - -void HybridMgmtBlock::FinishSave() -{ - finishSave = true; } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index b80a2848..ab28a267 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -76,12 +76,8 @@ namespace MxRec { void ResetAll(int channelId); - int CheckSaveEmbMapValid(); - bool GetBlockStatus(int channelId); - void SetBlockStatus(int channelId, bool block); - void SetRankInfo(RankInfo ri); void SetStepInterval(int trainStep, int evalStep); @@ -90,19 +86,12 @@ namespace MxRec { void Destroy(); - void Wake(int channelId); - - bool IsNeedWaitSave(); - - void FinishSave(); - private: // 控制通道阻塞的变量 bool isBlock[MAX_CHANNEL_NUM] = {true, true}; // 控制训练了多少步进行保存的步数 int saveInterval = 0; RankInfo rankInfo; - bool finishSave = true; }; class HybridMgmtBlockingException : public std::exception { diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 43724a39..606e1233 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1265,7 +1265,7 @@ vector KeyProcess::GetUniqueKeys(const EmbBaseInfo& info, bool& isEos) auto infoVec = GetInfo(uniqueKeysList, info); isEos = get(infoVec); if (isEos) { - LOG_WARN(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", info.name, info.channelId, info.batchId); + LOG_INFO(KEY_PROCESS "GetUniqueKeys eos! {}[{}]:{}", info.name, info.channelId, info.batchId); break; } ret = get>(infoVec); @@ -1404,7 +1404,7 @@ unique_ptr> KeyProcess::GetInfoVec(const EmbBaseInfo& info, Proce auto infoVec = GetInfo(*list, info); isEos = get(infoVec); if (isEos) { - LOG_WARN(KEY_PROCESS "GetInfoVec eos! {}[{}]:{}", info.name, info.channelId, info.batchId); + LOG_INFO(KEY_PROCESS "GetInfoVec eos! {}[{}]:{}", info.name, info.channelId, info.batchId); break; } auto it = get>>::iterator>(infoVec); -- Gitee From 1b2fe697e4c141f55d309d6fe655e991775837cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 14 Aug 2024 11:15:52 +0800 Subject: [PATCH 22/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- .../key_process/feature_admit_and_evict.h | 3 +- src/core/key_process/key_process.cpp | 29 ++++++++++--------- src/core/key_process/key_process.h | 28 +++++++----------- 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 10eecf4a..b84d2d83 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -142,7 +142,8 @@ public: void ReceiveKeyThread(const EmbInfo& embInfo); - GTEST_PRIVATE : bool mutexDestroy{false}; // LookupAndSend & ReceiveAndUpdate Condition_Variable_Wait stop. +GTEST_PRIVATE : + bool mutexDestroy{false}; // LookupAndSend & ReceiveAndUpdate Condition_Variable_Wait stop. std::mutex lookUpAndSendBatchIdMtx[MAX_CHANNEL_NUM]; // train and eval std::mutex receiveAndUpdateBatchIdMtx[MAX_CHANNEL_NUM]; diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 6c82c846..e1ef1018 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -96,8 +96,7 @@ namespace MxRec { static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 static absl::flat_hash_map m_embStatus; // 用于“准入&淘汰”功能解耦 - - GTEST_PRIVATE : + GTEST_PRIVATE : // 解析m_table2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 8a6d033e..3f937eaf 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -14,21 +14,21 @@ See the License for the specific language governing permissions and ==============================================================================*/ #include "key_process.h" +#include +#include + #include #include -#include -#include - +#include "emb_table/embedding_mgmt.h" +#include "hd_transfer/hd_transfer.h" +#include "ock_ctr_common/include/error_code.h" #include "utils/common.h" +#include "utils/config.h" #include "utils/logger.h" #include "utils/safe_queue.h" #include "utils/singleton.h" #include "utils/time_cost.h" -#include "utils/config.h" -#include "emb_table/embedding_mgmt.h" -#include "hd_transfer/hd_transfer.h" -#include "ock_ctr_common/include/error_code.h" using namespace std; using namespace chrono; @@ -463,6 +463,10 @@ KeysT KeyProcess::BroadcastGlobalDpIdUnique(const unique_ptr& batch, bool KeyProcess::KeyProcessTaskHelperForDp(unique_ptr& batch, int channel, int threadId) { + if (batch->isEos) { + HandleEos(batch, channel, threadId); + return true; + } vector splitKeys; vector restore; vector hotPos; @@ -1021,7 +1025,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, ve batch->batchId); // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 - if (rankInfo.useStatic) { // maybe move after all2all + if (rankInfo.useStatic) { // maybe move after all2all ProcessKeysWithStatic(batch, splitKeys); } @@ -1031,7 +1035,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, ve LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, batchName:{}, MPI_Allgatherv finish." " processSplitKeysTC(ms):{}", batch->channel, id, batch->batchId, batch->name, processSplitKeysTC.ElapsedMS()); - return { keyRecv, scAll, ss }; + return {keyRecv, scAll, ss}; } KeysT keySend; @@ -1379,8 +1383,7 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - auto retCode = MPI_Allgather(keyScLocal.data(), sendAndRecvCount, MPI_INT, - scAll.data(), sendAndRecvCount, MPI_INT, + auto retCode = MPI_Allgather(keyScLocal.data(), sendAndRecvCount, MPI_INT, scAll.data(), sendAndRecvCount, MPI_INT, comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode); @@ -1826,11 +1829,11 @@ void KeyProcess::EnqueueEosBatch(int64_t batchNum, int channelId) int batchQueueId = int(batchNum % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); auto queue = SingletonQueue::GetInstances(batchQueueId); for (auto& emb : embInfos) { - auto batchData = queue->GetOne(); // get dirty or empty data block + auto batchData = queue->GetOne(); // get dirty or empty data block batchData->name = emb.first; batchData->channel = channelId; batchData->batchId = batchNum; - batchData->sample = {0, 0, 0, 0, 0, 0, 0, 0}; // fake data + batchData->sample = {0, 0, 0, 0, 0, 0, 0, 0}; // fake data batchData->isEos = true; queue->Pushv(move(batchData)); } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 61c0f375..20bd95b7 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -68,9 +68,6 @@ private: const char* errorMessage; }; -constexpr int MPI_ABNORMAL_SEND_VALUE = 0; // MPI异常通信时发送0 -constexpr int MPI_NORMAL_SEND_VALUE = 1; // MPI正常通信时发送1 - class EmptyList : public std::exception {}; class WrongListTop : public std::exception {}; @@ -78,7 +75,7 @@ class WrongListTop : public std::exception {}; class KeyProcess { public: bool Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues = {}, int seed = 0, bool isIncrementalCkpt = false); + const vector& thresholdValues = {}, int seed = 0, bool isIncrementalCkpt = false); unique_ptr> GetInfoVec(const EmbBaseInfo& info, ProcessedInfo type, bool& isEos); @@ -200,7 +197,7 @@ public: bool isRunning{false}; - bool isIncrementalCheckpoint {false}; + bool isIncrementalCheckpoint{false}; std::mutex destroyMutex; std::mutex eosMutex; @@ -209,10 +206,8 @@ public: return embInfos.find(embName) != embInfos.end(); }; - GTEST_PRIVATE : - - int - Start(); +GTEST_PRIVATE : + int Start(); template T GetInfo(info_list_t& list, const EmbBaseInfo& info); @@ -273,7 +268,7 @@ public: void GetUniqueConfig(ock::ctr::UniqueConf& uniqueConf); void InitializeUnique(ock::ctr::UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, ock::ctr::UniquePtr& unique); + const unique_ptr& batch, ock::ctr::UniquePtr& unique); void ProcessBatchWithFastUnique(const unique_ptr& batch, ock::ctr::UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut); @@ -285,8 +280,8 @@ public: auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; - auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector, - vector>; + auto HotHashSplit(const unique_ptr& batch) + -> tuple, vector, vector, vector>; void PaddingAlltoallVC(vector& splitKeys) const; @@ -339,12 +334,11 @@ public: vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); - void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, - vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec); + void HashSplitHelper(const unique_ptr& batch, vector& splitKeys, vector& restore, + vector& hotPos, vector>& keyCount, vector& keyCountVec); - vector GetCountRecvForDp(const unique_ptr& batch, const int id, - vector& keyCount, vector scAll); + vector GetCountRecvForDp(const unique_ptr& batch, const int id, vector& keyCount, + vector scAll); KeysT FeatureAdmitForDp(KeysT& lookupKeys, KeysT& globalDpIdVec); -- Gitee From 2dc1c73854b73a7c1ea90d3a4f8f13a122ac93af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 14 Aug 2024 11:17:30 +0800 Subject: [PATCH 23/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 1 - src/core/key_process/key_process.h | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3f937eaf..33a1e197 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1679,7 +1679,6 @@ unique_ptr> KeyProcess::GetKCInfoVec(const EmbBaseInfo& info) keyCountStorage.erase(it); break; } catch (EmptyList&) { - unique_lock lockEosGuard(eosMutex); LOG_TRACE("getting info failed, list is empty."); this_thread::sleep_for(1ms); } catch (WrongListTop&) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 20bd95b7..a10b5a9c 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -200,7 +200,7 @@ public: bool isIncrementalCheckpoint{false}; std::mutex destroyMutex; - std::mutex eosMutex; + inline bool HasEmbName(const string& embName) { return embInfos.find(embName) != embInfos.end(); -- Gitee From d9fead8db70fe85361e52b378c6d1c81c0fac84c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 14 Aug 2024 15:31:43 +0800 Subject: [PATCH 24/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 22 ++++++++++++++-------- src/dataset_tf/eos_dataset_op.cc | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1a161e7b..1ebbeebc 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -452,17 +452,19 @@ void HybridMgmt::Destroy() LOG_DEBUG(MGMT + "destroy hdTransfer end."); } + JoinEmbeddingCacheThread(); + LOG_DEBUG(MGMT + "destroy EmbeddingCacheThread end."); + hybridMgmtBlock->Destroy(); for (auto& t : procThreads) { t->join(); } procThreads.clear(); + LOG_DEBUG(MGMT + "destroy parseKeyThread end."); if (cacheManager != nullptr) { cacheManager = nullptr; } - JoinEmbeddingCacheThread(); - LOG_DEBUG(MGMT + "destroy EmbeddingCacheThread end."); // 等待并销毁接收key的线程 for (auto& t : receiveKeyThreads) { @@ -1186,7 +1188,7 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running when [LookUpAndSendDDR], table:{}, batchId:{}, channel:{}", embInfo.name, + LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendDDR], table:{}, batchId:{}, channel:{}", embInfo.name, batchId, channelId); return; } @@ -1212,7 +1214,7 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI bool isEos = false; auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " + LOG_DEBUG("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " "channel:{}", embInfo.name, batchId, channelId); return; @@ -1239,7 +1241,7 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running when [LookUpAndSendL3Storage], table:{}, batchId:{}, channel:{}", + LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendL3Storage], table:{}, batchId:{}, channel:{}", embInfo.name, batchId, channelId); return; } @@ -1267,7 +1269,7 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons bool isEos = false; auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); if (!isSuccess) { - LOG_INFO("HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " + LOG_DEBUG("HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " "channel:{}", embInfo.name, batchId, channelId); return; @@ -1500,7 +1502,9 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); - + if (!isRunning) { + return false; + } isEos = EosL2Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; @@ -1710,7 +1714,9 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, lastRecvFinishCV[currentKey].wait(lastRecvFinishLocker, [info, this] { return (hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId] == info.batchId) || mutexDestroy; }); - + if (!isRunning) { + return false; + } isEos = EosL1Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 4c8819d6..35e8f617 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -301,7 +301,7 @@ private: for (const auto& t : *out_tensors) { DataType tensor_type = t.dtype(); TensorShape tensor_shape = t.shape(); - LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + LOG_TRACE("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " "tensor_shape: {}", dataset()->channelId_, iter_times_, -- Gitee From 2a470106d5cc54396500a8eff2d6bf3336608f6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Wed, 14 Aug 2024 15:59:42 +0800 Subject: [PATCH 25/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_tf/eos_dataset_op.cc | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 35e8f617..3fcb7bc8 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -301,7 +301,7 @@ private: for (const auto& t : *out_tensors) { DataType tensor_type = t.dtype(); TensorShape tensor_shape = t.shape(); - LOG_TRACE("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " "tensor_shape: {}", dataset()->channelId_, iter_times_, @@ -332,10 +332,9 @@ private: MPI_Iallreduce(MPI_IN_PLACE, &getNextStatus, 1, MPI_INT, MPI_SUM, g_comm[channelId], &req); CheckCommFinished(req, channelId); - // Max step is achieved, no need to send eos. - if (outSize == 0) { - keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); - } + + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + LOG_DEBUG("[ACTIVE] GetNext eos was triggered actively, channel: {}, iter: {}", dataset()->channelId_, iter_times_); @@ -350,10 +349,9 @@ private: if (getNextStatus < g_rankSize) { *end_of_sequence = true; - // Max step is achieved, no need to send eos. - if (outSize == 0) { - keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); - } + + keyProcess->EnqueueEosBatch(iter_times_, dataset()->channelId_); + LOG_DEBUG( "[PASSIVE] GetNext eos was triggered passively, channel: {}, iter: {}, sum: {}", dataset()->channelId_, iter_times_, getNextStatus); -- Gitee From 36a0c37a77e00d0439faa2b10521b1fbf687edb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Thu, 15 Aug 2024 11:08:27 +0800 Subject: [PATCH 26/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 106 ++++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 7 +- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 6 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 25 ----- 4 files changed, 67 insertions(+), 77 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1ebbeebc..ac359e6d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1184,12 +1184,12 @@ void HybridMgmt::EmbeddingLookUpAndSendDDR(int batchId, int index, const EmbInfo .extEmbeddingSize = embInfo.extEmbeddingSize, .channelId = channelId, .name = embInfo.name}; - vector h2dEmb; + vector h2dEmb; auto isSuccess = EmbeddingLookUpDDR(info, h2dEmb); if (!isSuccess) { LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendDDR], table:{}, batchId:{}, channel:{}", embInfo.name, - batchId, channelId); + batchId, channelId); return; } EmbeddingSendDDR(info, h2dEmb); @@ -1215,8 +1215,8 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); if (!isSuccess) { LOG_DEBUG("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " - "channel:{}", - embInfo.name, batchId, channelId); + "channel:{}", + embInfo.name, batchId, channelId); return; } if (!isEos) { @@ -1242,7 +1242,7 @@ void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const E auto isSuccess = EmbeddingLookUpL3Storage(info, h2dEmb); if (!isSuccess) { LOG_DEBUG("HybridMgmt is not running when [LookUpAndSendL3Storage], table:{}, batchId:{}, channel:{}", - embInfo.name, batchId, channelId); + embInfo.name, batchId, channelId); return; } @@ -1269,9 +1269,10 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons bool isEos = false; auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); if (!isSuccess) { - LOG_DEBUG("HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " - "channel:{}", - embInfo.name, batchId, channelId); + LOG_DEBUG( + "HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " + "channel:{}", + embInfo.name, batchId, channelId); return; } if (!isEos) { @@ -1544,15 +1545,15 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - LOG_DEBUG("table:{}, accumulate batchId:{}, dims[0]:{}, swapOutAddrs size:{}", info.name, info.batchId, dims[0], - swapOutAddrs.size()); + LOG_DEBUG(MGMT + "In swap thread, finish receive d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, dims[0]:{}, swapOutAddrs size:{}, EmbeddingRecvTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, dims[0], swapOutAddrs.size(), + EmbeddingRecvTC.ElapsedMS()); if (dims[0] != static_cast(swapOutAddrs.size())) { throw runtime_error("data dims[0] != swapOutKeys.size()"); } - LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name, info.batchId, - info.threadIdx, EmbeddingRecvTC.ElapsedMS()); hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; lastRecvFinishCV[nextKey].notify_all(); @@ -1583,11 +1584,10 @@ void HybridMgmt::EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr if (!swapOutAddrs.empty()) { sample = FloatPtrToLimitStr(swapOutAddrs.front(), info.extEmbeddingSize); // print first element } - LOG_DEBUG( - "table:{}, accumulate batchId:{}, thread:{}, receive d2hEmb, ext emb:{}, emb size:{}, emb samples:{}, " - "EmbeddingUpdateTC(ms):{}", - info.name.c_str(), info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), sample, - EmbeddingUpdateTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish update d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, ext emb:{}, emb size:{}, emb samples:{}, EmbeddingUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, info.extEmbeddingSize, swapOutAddrs.size(), + sample, EmbeddingUpdateTC.ElapsedMS()); } hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; @@ -1623,6 +1623,9 @@ bool HybridMgmt::EmbeddingLookUpDDR(const EmbTaskInfo& info, vector& h2d string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); + LOG_DEBUG(MGMT + "In swap thread, finish embedding lookup, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}", + info.name, info.channelId, info.batchId, info.threadIdx); return true; } @@ -1639,14 +1642,12 @@ void HybridMgmt::EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEm hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); - LOG_DEBUG("table:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name, info.batchId, - info.threadIdx, SendTC.ElapsedMS()); - // 对于end of sequence场景,key - // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]++; - LOG_DEBUG("h2dNextBatchId, table:{}, channelId:{}, next batchId:{}", info.name, info.channelId, - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); + LOG_DEBUG(MGMT + "In swap thread, finish send h2d embedding, table:{}, channelId:{}, batchId:{}, accumulate " + "batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", + info.name, info.channelId, hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId], info.batchId, + info.threadIdx, SendTC.ElapsedMS()); + hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId]++; } void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& embInfo, int channelId) @@ -1669,7 +1670,8 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em } if (!isRunning) { LOG_DEBUG(MGMT + "Destroy LookUpAndSendThread, table:{}, index:{}, channel:{}, batchId:{}", - embInfo.name, index, channelId, hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); + embInfo.name, index, channelId, + hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1680,7 +1682,8 @@ void HybridMgmt::CreateEmbeddingLookUpAndSendThread(int index, const EmbInfo& em void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& embInfo, int channelId) { auto fn = [index, embInfo, channelId, this]() { - LOG_DEBUG(MGMT + "Create ReceiveAndUpdateThread, table:{}, index:{}, channel:{}", embInfo.name, index, channelId); + LOG_DEBUG(MGMT + "Create ReceiveAndUpdateThread, table:{}, index:{}, channel:{}", embInfo.name, index, + channelId); while (true) { receiveAndUpdateBatchIdMtx[channelId].lock(); if (hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId] % EMBEDDING_THREAD_NUM == @@ -1698,7 +1701,8 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& } if (!isRunning) { LOG_DEBUG(MGMT + "Destroy ReceiveAndUpdateThread, table:{}, index:{}, channel:{}, batchId:{}", - embInfo.name, index, channelId, hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); + embInfo.name, index, channelId, + hybridMgmtBlock->receiveAndUpdateTableBatchId[embInfo.name][channelId]); return; } } @@ -1757,13 +1761,13 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, size_t dimNum = acltdtGetDimNumFromItem(aclData); int64_t dims[dimNum]; acltdtGetDimsFromItem(aclData, dims, dimNum); - - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, recv d2h, dims[0]:{}, swapOutAddrs.size:{}", info.name, - info.batchId, info.channelId, dims[0], swapOutAddrs.size()); dims0 = dims[0]; - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, EmbeddingRecvTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, EmbeddingRecvTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish receive d2h embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, dims[0]:{}, swapOutAddrs size:{}, EmbeddingRecvTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, dims[0], swapOutAddrs.size(), + EmbeddingRecvTC.ElapsedMS()); + hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; lastRecvFinishCV[nextKey].notify_all(); return true; @@ -1794,9 +1798,10 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); } } - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, EmbeddingUpdateTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish update d2h DDR embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, EmbeddingUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, EmbeddingUpdateTC.ElapsedMS()); // L3Storage更新 TimeCost L3StorageUpdateTC = TimeCost(); std::vector swapOutL3StorageAddrOffs = @@ -1812,8 +1817,11 @@ void HybridMgmt::EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr } cacheManager->UpdateL3StorageEmb(info.name, embPtr, extEmbeddingSize, swapOutL3StorageKeys, swapOutL3StorageAddrOffs); - LOG_DEBUG("table:{}, accumulate batchId:{}, channelId:{}, thread:{}, L3StorageUpdateTC(ms):{}", info.name.c_str(), - info.batchId, info.channelId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); + + LOG_DEBUG( + MGMT + "In swap thread, finish update d2h L3Storage embedding, table:{}, channelId:{}, accumulate batchId:{}, " + "thread:{}, L3StorageUpdateTC(ms):{}", + info.name, info.channelId, info.batchId, info.threadIdx, L3StorageUpdateTC.ElapsedMS()); hybridMgmtBlock->lastUpdateFinishStep[info.name][info.channelId]++; string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); @@ -1867,7 +1875,9 @@ bool HybridMgmt::EmbeddingLookUpL3Storage(const EmbTaskInfo& info, vectorlastLookUpFinishStep[info.name][info.channelId]++; string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastLookUpFinishCV[nextKey].notify_all(); @@ -1888,14 +1898,12 @@ void HybridMgmt::EmbeddingSendL3Storage(const EmbTaskInfo& info, vector& hybridMgmtBlock->lastSendFinishStep[info.name][info.channelId]++; string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastSendFinishCV[nextKey].notify_all(); - LOG_DEBUG("table:{}, channelId:{}, accumulate batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", info.name.c_str(), - info.channelId, info.batchId, info.threadIdx, SendTC.ElapsedMS()); + LOG_DEBUG(MGMT + "In swap thread, finish send h2d embedding, table:{}, channelId:{}, batchId:{}, accumulate " + "batchId:{}, thread:{}, SendH2DEmbTC(ms):{}", + info.name, info.channelId, hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId], info.batchId, + info.threadIdx, SendTC.ElapsedMS()); - // 对于end of sequence场景,key - // process需要基于h2dNextBatchId等待每个table都完成了最后1个step发送,才能发EOS至各channel - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]++; - LOG_DEBUG("h2dNextBatchId, table:{}, channelId:{}, next batchId:{}", info.name, info.channelId, - hybridMgmtBlock->h2dNextBatchId[info.name][info.channelId]); + hybridMgmtBlock->h2dSendBatchId[info.name][info.channelId]++; } void HybridMgmt::HandleDataSwapForL3Storage(const EmbBaseInfo& info, vector& swapInKeys, @@ -1964,11 +1972,11 @@ bool HybridMgmt::BuildH2DEmbedding(const EmbTaskInfo& info, vector& h2dE throw runtime_error("memcpy_s failed, error code:" + to_string(rc)); } } - LOG_DEBUG("table:{}, channel:{}, thread:{}, accumulate batchId:{}, send h2dEmb, emb size:{}, emb samples:{}, " - "embeddingLookupTC(ms):{}", - info.name.c_str(), info.channelId, info.threadIdx, info.batchId, swapInAddrs.size(), - FloatPtrToLimitStr(h2dEmbAddr, swapInAddrs.size() * info.extEmbeddingSize), - embeddingLookupTC.ElapsedMS()); + LOG_DEBUG( + "[BuildH2DEmbedding] table:{}, channel:{}, thread:{}, accumulate batchId:{}, emb size:{}, emb samples:{}, " + "embeddingLookupTC(ms):{}", + info.name.c_str(), info.channelId, info.threadIdx, info.batchId, swapInAddrs.size(), + FloatPtrToLimitStr(h2dEmbAddr, swapInAddrs.size() * info.extEmbeddingSize), embeddingLookupTC.ElapsedMS()); return true; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index f6b36ef6..3069e073 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -174,7 +174,7 @@ void HybridMgmtBlock::ResetAll(int channelId) b.second[channelId] = 0; } // L3 data pipeline, swap - for (auto& b : h2dNextBatchId) { + for (auto& b : h2dSendBatchId) { b.second[channelId] = 0; } @@ -193,6 +193,11 @@ bool HybridMgmtBlock::GetBlockStatus(int channelId) return isBlock[channelId]; } +void HybridMgmtBlock::SetBlockStatus(int channelId, bool block) +{ + isBlock[channelId] = block; +} + void HybridMgmtBlock::Destroy() { if (!isRunning) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index ab28a267..65fb29f7 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -45,8 +45,8 @@ namespace MxRec { int maxTrainStep = 0; int stepsInterval[MAX_CHANNEL_NUM] = {0, 0}; // 通道i运行多少步后切换为通道j - std::map lookUpSwapAddrsPushId; // L2 pipeline, key->addr - std::map h2dNextBatchId; // L3 pipeline, use for eos + std::map lookUpSwapAddrsPushId; // L2 pipeline, key->addr + std::map h2dSendBatchId; // L3 pipeline, swap thread std::map lookUpAndSendTableBatchId; std::map receiveAndUpdateTableBatchId; std::map lastUpdateFinishStep; @@ -78,6 +78,8 @@ namespace MxRec { bool GetBlockStatus(int channelId); + void SetBlockStatus(int channelId, bool block); + void SetRankInfo(RankInfo ri); void SetStepInterval(int trainStep, int evalStep); diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index c51d4be2..9fd2db80 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -94,31 +94,6 @@ TEST_F(HybridMgmtBlockTest, ResetAll) ASSERT_EQ(hybridMgmtBlock->hybridBatchId[0], 0); } -TEST_F(HybridMgmtBlockTest, CheckSaveEmbMapValid) -{ - hybridMgmtBlock = std::make_unique(); - hybridMgmtBlock->SetStepInterval(1, 1); - hybridMgmtBlock->lastRunChannelId = 0; - - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 0; - hybridMgmtBlock->CheckSaveEmbMapValid(); - int status0 = hybridMgmtBlock->CheckSaveEmbMapValid(); - - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 1; - hybridMgmtBlock->CheckSaveEmbMapValid(); - int status1 = hybridMgmtBlock->CheckSaveEmbMapValid(); - - int step2 = 2; - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = step2; - int status2 = hybridMgmtBlock->CheckSaveEmbMapValid(); - ASSERT_EQ(status0, 0); - ASSERT_EQ(status1, 1); - ASSERT_EQ(status2, -1); -} - TEST_F(HybridMgmtBlockTest, CountPythonStep) { hybridMgmtBlock = std::make_unique(); -- Gitee From 981c4ec7b565716f24df1013b09442a4d0d784ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Thu, 15 Aug 2024 14:48:33 +0800 Subject: [PATCH 27/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 35 +++++++++++----------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 5 ++-- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ac359e6d..cec4d3b9 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1211,17 +1211,14 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI float* ptr = nullptr; vector swapOutAddrs; - bool isEos = false; - auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs, isEos); + auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); if (!isSuccess) { LOG_DEBUG("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " "channel:{}", embInfo.name, batchId, channelId); return; } - if (!isEos) { - EmbeddingUpdateDDR(info, ptr, swapOutAddrs); - } + EmbeddingUpdateDDR(info, ptr, swapOutAddrs); } void HybridMgmt::EmbeddingLookUpAndSendL3Storage(int batchId, int index, const EmbInfo& embInfo, int channelId) @@ -1266,8 +1263,8 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons float* ptr = nullptr; vector swapOutAddrs; int64_t dims0 = 0; - bool isEos = false; - auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0, isEos); + + auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); if (!isSuccess) { LOG_DEBUG( "HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " @@ -1275,9 +1272,7 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons embInfo.name, batchId, channelId); return; } - if (!isEos) { - EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); - } + EmbeddingUpdateL3Storage(info, ptr, swapOutAddrs, dims0); } /// 构造训练所需的各种向量数据 @@ -1496,7 +1491,7 @@ void HybridMgmt::JoinEmbeddingCacheThread() } } -bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos) +bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs) { string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); @@ -1506,17 +1501,14 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (!isRunning) { return false; } - isEos = EosL2Que[info.name][info.channelId].WaitAndPop(); + bool isEos = EosL2Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; } - string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); if (isEos) { LOG_DEBUG("EmbeddingReceiveDDR get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); - lastRecvFinishCV[nextKey].notify_all(); - return true; } TimeCost EmbeddingRecvTC = TimeCost(); @@ -1553,8 +1545,9 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (dims[0] != static_cast(swapOutAddrs.size())) { throw runtime_error("data dims[0] != swapOutKeys.size()"); } - hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; + + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; @@ -1711,7 +1704,7 @@ void HybridMgmt::CreateEmbeddingReceiveAndUpdateThread(int index, const EmbInfo& } bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, - int64_t& dims0, bool& isEos) + int64_t& dims0) { string currentKey = MakeSwapCVName(info.threadIdx, info.name, info.channelId); std::unique_lock lastRecvFinishLocker(lastRecvFinishMutex[currentKey]); @@ -1721,17 +1714,14 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (!isRunning) { return false; } - isEos = EosL1Que[info.name][info.channelId].WaitAndPop(); + bool isEos = EosL1Que[info.name][info.channelId].WaitAndPop(); if (!isRunning) { return false; } - string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); if (isEos) { LOG_DEBUG("EmbeddingReceiveL3Storage get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); - lastRecvFinishCV[nextKey].notify_all(); - return true; } // DDR swap out key need to be removed @@ -1767,8 +1757,9 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, "thread:{}, dims[0]:{}, swapOutAddrs size:{}, EmbeddingRecvTC(ms):{}", info.name, info.channelId, info.batchId, info.threadIdx, dims[0], swapOutAddrs.size(), EmbeddingRecvTC.ElapsedMS()); - hybridMgmtBlock->lastRecvFinishStep[info.name][info.channelId]++; + + string nextKey = MakeSwapCVName(info.cvNotifyIndex, info.name, info.channelId); lastRecvFinishCV[nextKey].notify_all(); return true; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index b84d2d83..0f7f0999 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -260,7 +260,7 @@ private: void JoinEmbeddingCacheThread(); - bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, bool& isEos); + bool EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs); void EmbeddingUpdateDDR(const EmbTaskInfo& info, const float* embPtr, vector& swapOutAddrs); @@ -268,8 +268,7 @@ private: void EmbeddingSendDDR(const EmbTaskInfo& info, vector& h2dEmb); - bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0, - bool& isEos); + bool EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, vector& swapOutAddrs, int64_t& dims0); void EmbeddingUpdateL3Storage(const EmbTaskInfo& info, float* embPtr, vector& swapOutAddrs, int64_t& dims0); -- Gitee From 5842628a0626bbf5250d071e6731d4d5177a127e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Fri, 16 Aug 2024 14:27:15 +0800 Subject: [PATCH 28/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 2 +- src/dataset_tf/eos_dataset_op.cc | 94 ---------------------------- src/ops_tf/hybrid_dataset_ops.cpp | 11 ++++ 3 files changed, 12 insertions(+), 95 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 33a1e197..dbd87303 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1823,7 +1823,7 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) void KeyProcess::EnqueueEosBatch(int64_t batchNum, int channelId) { - LOG_INFO("Enqueue data set eos on batch queue, channel:{}, eos number:{}", channelId, batchNum); + LOG_INFO("Enqueue dataSet eos on batch queue, channel:{}, eos number:{}", channelId, batchNum); int threadNum = GetThreadNumEnv(); int batchQueueId = int(batchNum % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); auto queue = SingletonQueue::GetInstances(batchQueueId); diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 3fcb7bc8..b8fa866e 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -191,99 +191,6 @@ private: } #endif - int64_t GetTensorElementNum(size_t index) { - PartialTensorShape tensor_shape = dataset()->output_shapes()[index]; - int64_t element_number = 1LL; - for (int32_t i = 0; i < tensor_shape.dims(); i++) { - element_number *= tensor_shape.dim_size(i); - } - return element_number; - } - - bool IsUnknowShape(const PartialTensorShape& output_shapes) const { - if (output_shapes.unknown_rank()) { - return true; - } - for (int32_t i = 0; i < output_shapes.dims(); i++) { - if (output_shapes.dim_size(i) == -1) { - return true; - } - } - return false; - } - - Tensor CreateTensorByShape(const PartialTensorShape& output_shapes, const DataType& tensor_data_type) { - TensorShape tf_shape; - for (int32_t i = 0; i < output_shapes.dims(); i++) { - tf_shape.AddDim(output_shapes.dim_size(i)); - } - LOG_INFO("[LQK] CreateTensorByShape, tensor shape: {}", tf_shape.DebugString()); - - Tensor tmp(tensor_data_type, tf_shape); - auto tensor_data = const_cast(tmp.tensor_data().data()); - auto tensor_size = tmp.tensor_data().size(); - LOG_INFO("[LQK] KnownShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", - tmp.DebugString(), tensor_size, tmp.NumElements()); - - memset_s(tensor_data, tensor_size, 0, tensor_size); - - LOG_INFO("[LQK] KnownShape, after memset tensor: {}", tmp.DebugString()); - - return tmp; - } - - std::vector CreateOutputVecTensor() - { - size_t output_shape_size = dataset()->output_shapes().size(); - size_t output_type_size = dataset()->output_dtypes().size(); - LOG_INFO("[LQK] output_shape_size: {}, output_type_size: {}", output_shape_size, output_type_size); - if (output_shape_size != output_type_size) { - LOG_ERROR("[LQK] output_shape_size: {} is not equal to output_type_size: {}", output_shape_size, - output_type_size); - return {}; - } - std::vector result; - for (size_t i = 0UL; i < output_shape_size; i++) { - DataType tensor_data_type = dataset()->output_dtypes().at(i); - if (tensor_data_type == DT_STRING) { - LOG_ERROR("[LQK] current tensor type is DT_STRING"); - return{}; - } - LOG_INFO("[LQK] current tensor type is: {}", tensor_data_type); - LOG_INFO("[LQK] current tensor dim is: {}, dim[0].dim_Size is {}", dataset()->output_shapes()[i].dims(), - dataset()->output_shapes()[i].dim_size(0)); - if (dataset()->output_shapes()[i].dims() == 2) { - LOG_INFO("[LQK] current tensor dim[1].dim_Size is {}", dataset()->output_shapes()[i].dim_size(1)); - } - if (IsUnknowShape(dataset()->output_shapes()[i])) { - LOG_INFO("[LQK] output shape is unknown shape"); - Tensor tensor(tensor_data_type, TensorShape({8, 1})); - if (dataset()->output_shapes()[i].dims() == -1) { - tensor = Tensor(tensor_data_type, TensorShape({1})); - } - - // 获取指针 - auto tensor_data = const_cast(tensor.tensor_data().data()); - auto tensor_size = tensor.tensor_data().size(); - LOG_INFO("[LQK] IsUnknowShape, create tensor: {}, tensor size: {}, tensor.NumElements:{}", - tensor.DebugString(), tensor_size, tensor.NumElements()); - - memset_s(tensor_data, tensor_size, 0, tensor_size); - - LOG_INFO("[LQK] IsUnknowShape, after memset tensor: {}", tensor.DebugString()); - - result.push_back(tensor); - continue; - } - Tensor a = CreateTensorByShape(dataset()->output_shapes()[i], tensor_data_type); - LOG_INFO("[LQK] success create know shape tensor: {}", a.DebugString()); - - result.push_back(a); - } - return result; - } - - Status GetNextInternal(IteratorContext *ctx, std::vector *out_tensors, bool *end_of_sequence) override @@ -315,7 +222,6 @@ private: } auto keyProcess = Singleton::GetInstance(); - auto datasetId = dataset()->id_; auto channelId = dataset()->channelId_; if (channelId == 0 && iter_times_ == dataset()->maxTrainSteps_) { *end_of_sequence = true; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 5b358884..123c7e1e 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -189,6 +189,12 @@ namespace MxRec { Singleton::GetInstance()->ClearTransChannel(channelId); threadNum = GetThreadNumEnv(); + if (threadNum <= 0) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ...")); + return; + } + auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -382,6 +388,11 @@ namespace MxRec { Singleton::GetInstance()->ClearTransChannel(channelId); threadNum = GetThreadNumEnv(); + if (threadNum <= 0) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ...")); + return; + } auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); -- Gitee From bec0a56182375ee904d741f3be07ca1a76b18ccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 19 Aug 2024 09:25:01 +0800 Subject: [PATCH 29/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 ++-- src/core/key_process/key_process.cpp | 10 +++++----- src/dataset_tf/eos_dataset_op.cc | 27 +++++++++++++++------------ 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index cec4d3b9..72712074 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1213,7 +1213,7 @@ void HybridMgmt::EmbeddingReceiveAndUpdateDDR(int batchId, int index, const EmbI vector swapOutAddrs; auto isSuccess = EmbeddingReceiveDDR(info, ptr, swapOutAddrs); if (!isSuccess) { - LOG_DEBUG("HybridMgmt is not running or receive empty data when [ReceiveAndUpdateDDR], table:{}, batchId:{}, " + LOG_DEBUG("HybridMgmt is not running or receive empty data when [EmbeddingReceiveDDR], table:{}, batchId:{}, " "channel:{}", embInfo.name, batchId, channelId); return; @@ -1267,7 +1267,7 @@ void HybridMgmt::EmbeddingReceiveAndUpdateL3Storage(int batchId, int index, cons auto isSuccess = EmbeddingReceiveL3Storage(info, ptr, swapOutAddrs, dims0); if (!isSuccess) { LOG_DEBUG( - "HybridMgmt is not running or receive empty data when [LookUpAndSendL3Storage], table:{}, batchId:{}, " + "HybridMgmt is not running or receive empty data when [EmbeddingReceiveL3Storage], table:{}, batchId:{}, " "channel:{}", embInfo.name, batchId, channelId); return; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index dbd87303..a391d42c 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -14,12 +14,12 @@ See the License for the specific language governing permissions and ==============================================================================*/ #include "key_process.h" -#include -#include - #include #include +#include +#include + #include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "ock_ctr_common/include/error_code.h" @@ -601,7 +601,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, hotPos.resize(hotEmbTotCount[batch->name], 0); tensors->push_back(Vec2TensorI32(hotPos)); - // HBM把restore、unique、idoffset做成了Tensor,放到infolist里面了(hbm第一个get的是tensors) + // Tensors contains restore、hotPos、restoreSec&unique、idOffset in order when HBM mode, and is pushed in infolist. if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); @@ -609,7 +609,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, if (isIncrementalCheckpoint) { PushKeyCountHBM(batch, move(keyCountTensors)); } - } else { // DDR 保留原有的数据结构,idoffset在上层mgmt组装(ddr第一个get的是unique) + } else { std::vector lookupKeysUint(lookupKeys.begin(), lookupKeys.end()); vector uniqueKeys; vector restoreVecSec; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index b8fa866e..1098aede 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -204,19 +204,22 @@ private: // Out size equals to zero when batch eos. int outSize = out_tensors->size(); - if (outSize > 0) { - for (const auto& t : *out_tensors) { - DataType tensor_type = t.dtype(); - TensorShape tensor_shape = t.shape(); - LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " - "tensor_shape: {}", - dataset()->channelId_, - iter_times_, - outSize, - tensor_type, - tensor_shape.DebugString()); + if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { + if (outSize > 0) { + for (const auto& t : *out_tensors) { + DataType tensor_type = t.dtype(); + TensorShape tensor_shape = t.shape(); + LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " + "tensor_shape: {}", + dataset()->channelId_, + iter_times_, + outSize, + tensor_type, + tensor_shape.DebugString()); + } } - } else { + } + if (outSize <= 0) { LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}", dataset()->channelId_, iter_times_, outSize); } -- Gitee From 6b1f3dc1a8a1e70b6a10666c72254f6dd7123e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 19 Aug 2024 09:38:19 +0800 Subject: [PATCH 30/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_tf/eos_dataset_op.cc | 47 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 1098aede..dc46aa20 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -202,30 +202,10 @@ private: } TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - // Out size equals to zero when batch eos. - int outSize = out_tensors->size(); - if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { - if (outSize > 0) { - for (const auto& t : *out_tensors) { - DataType tensor_type = t.dtype(); - TensorShape tensor_shape = t.shape(); - LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, tensor_type: {}, " - "tensor_shape: {}", - dataset()->channelId_, - iter_times_, - outSize, - tensor_type, - tensor_shape.DebugString()); - } - } - } - if (outSize <= 0) { - LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}", dataset()->channelId_, - iter_times_, outSize); - } + auto channelId = dataset()->channelId_; + PrintOutput(out_tensors, channelId); auto keyProcess = Singleton::GetInstance(); - auto channelId = dataset()->channelId_; if (channelId == 0 && iter_times_ == dataset()->maxTrainSteps_) { *end_of_sequence = true; } @@ -304,6 +284,29 @@ private: return Status::OK(); } + void PrintOutput(std::vector *out_tensors, int channelId) + { + // Out size equals to zero when batch eos. + int outSize = out_tensors->size(); + if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) { + for (const auto& t : *out_tensors) { + DataType tensor_type = t.dtype(); + TensorShape tensor_shape = t.shape(); + LOG_DEBUG("Iterator getNext normal, channel: {}, iter: {}, outTensor size: {}, " + "tensor_type: {}, tensor_shape: {}", + channelId, + iter_times_, + outSize, + tensor_type, + tensor_shape.DebugString()); + } + } + if (outSize <= 0) { + LOG_DEBUG("Iterator getNext eos, channel: {}, iter: {}, outTensor size: {}", channelId, + iter_times_, outSize); + } + } + private: static constexpr int GET_NEXT_CONTINUE = 1; static constexpr int GET_NEXT_TERMINATE = 0; -- Gitee From fb6184d5c4c2b7931261c565b915b734774c96c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BD=97=E5=B9=B8=E8=BF=90?= Date: Mon, 19 Aug 2024 09:42:54 +0800 Subject: [PATCH 31/31] =?UTF-8?q?=E3=80=90FEAT=E3=80=91eos=20data=20driver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 72712074..2580a07c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1508,6 +1508,7 @@ bool HybridMgmt::EmbeddingReceiveDDR(const EmbTaskInfo& info, float*& ptr, vecto if (isEos) { LOG_DEBUG("EmbeddingReceiveDDR get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); + // It cannot return here after send eos, otherwise it will block the next round of switching. KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); } @@ -1721,6 +1722,7 @@ bool HybridMgmt::EmbeddingReceiveL3Storage(const EmbTaskInfo& info, float*& ptr, if (isEos) { LOG_DEBUG("EmbeddingReceiveL3Storage get eos, table:{}, batchId:{}, channel: {}", info.name, info.batchId, info.channelId); + // It cannot return here after send eos, otherwise it will block the next round of switching. KEY_PROCESS_INSTANCE->SendEos(info.name, info.batchId, info.channelId); } -- Gitee