From c60ecd9f9802e96aa0999890fabca312fbdc479f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Thu, 8 Aug 2024 17:22:54 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E4=BF=9D=E5=AD=98=E6=8A=A5=E9=94=99=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E4=BF=9D=E5=AD=98key=E5=92=8Cembedding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/emb_table/embedding_ddr.cpp | 3 ++- src/core/emb_table/embedding_ddr.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp index 97c10402..f4027fea 100644 --- a/src/core/emb_table/embedding_ddr.cpp +++ b/src/core/emb_table/embedding_ddr.cpp @@ -190,7 +190,8 @@ void EmbeddingDDR::LoadOptimizerSlot(const string &savePath, vector& keyInfo) { SyncLatestEmbedding(pythonBatchId); vector keys; diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h index f7a0dc0b..3ccbfb4a 100644 --- a/src/core/emb_table/embedding_ddr.h +++ b/src/core/emb_table/embedding_ddr.h @@ -44,7 +44,7 @@ public: void LoadOptimizerSlot(const string& savePath, vector>& optimizerSlots); - void Save(const string& savePath, const int pythonBatchId); + void Save(const string& savePath, const int pythonBatchId, bool saveDelta, const map& keyInfo); void SyncLatestEmbedding(const int pythonBatchId); -- Gitee From 5357f677fc2575902305bd747694d81a6ad56bff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Thu, 8 Aug 2024 18:40:07 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91DDR=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E4=BF=9D=E5=AD=98=E6=8A=A5=E9=94=99=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E4=BF=9D=E5=AD=98key=E5=92=8Cembedding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index f9c17dad..9861756b 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -654,7 +654,7 @@ def patch_for_second_or_step_timer(): def checkpoint_saver_hook_init(self, checkpoint_dir, save_secs=None, save_steps=None, saver=None, - checkpoint_basename="model.ckpt", scaffold=None, listeners=None): + checkpoint_basename="model.ckpt", scaffold=None, listeners=None, save_graph_def=True): logging.info("Create CheckpointSaverHook.") if saver is not None and scaffold is not None: raise ValueError("You cannot provide both saver and scaffold.") @@ -670,6 +670,7 @@ def checkpoint_saver_hook_init(self, checkpoint_dir, save_secs=None, save_steps= is_incremental_checkpoint=self._is_incremental_checkpoint) self._listeners = listeners or [] self._steps_per_run = 1 + self._save_graph_def = save_graph_def def after_run_checkpoint_saver_hook(self, run_context, run_values): -- Gitee From 8111f97878e465a50fc8593fd6ecb843fcbc0ced Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Mon, 12 Aug 2024 21:21:01 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=90=8D=E4=B8=BA=E5=A4=A7=E5=86=99=E5=AD=97?= =?UTF-8?q?=E6=AF=8D=E5=BC=80=E5=A4=B4=EF=BC=9B=E8=B0=83=E6=95=B4=E5=BC=80?= =?UTF-8?q?=E5=90=AF=E5=A2=9E=E9=87=8F=E6=97=B6=E4=BF=9D=E5=AD=98key?= =?UTF-8?q?=E3=80=81embed=E7=9A=84=E5=A4=9A=E4=BD=99=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 29 ++++++++++++------------ src/core/emb_table/embedding_dynamic.cpp | 17 +++++++------- src/core/emb_table/embedding_static.cpp | 20 ++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 ++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 9861756b..89eb391e 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -279,21 +279,22 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra if write_meta_graph: write_meta_graph_task(self, checkpoint_file=checkpoint_file, meta_graph_suffix=meta_graph_suffix, sess=sess, strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) + if is_incremental_checkpoint: + save_cost_time = time.time() - start_save_time + save_dir, _ = os.path.split(save_path) + export_tag = "Seconds" if save_delta else "DueTime" + model_index_info = {"timestamp": str(int(start_save_time)), "export_tag": export_tag, + "type": saved_model_type, "global_step": int(global_step), + "cost_ms": int(save_cost_time * 1000)} + if save_delta: + delta_model_version = "delta_" + str(int(global_step)) + write_delta_export_time_ms(save_dir, {delta_model_version: int(save_cost_time * 1000)}) + update_model_index(save_dir, model_index_info) + + # 当保存base的时候清空delta目录 + if not save_delta: + clear_delta_models(save_dir) comm.Barrier() - if is_incremental_checkpoint: - save_cost_time = time.time() - start_save_time - save_dir, _ = os.path.split(save_path) - export_tag = "Seconds" if save_delta else "DueTime" - model_index_info = {"timestamp": str(int(start_save_time)), "export_tag": export_tag, "type": saved_model_type, - "global_step": int(global_step), "cost_ms": int(save_cost_time * 1000)} - if save_delta: - delta_model_version = "delta_" + str(int(global_step)) - write_delta_export_time_ms(save_dir, {delta_model_version: int(save_cost_time * 1000)}) - update_model_index(save_dir, model_index_info) - - # 当保存base的时候清空delta目录 - if not save_delta: - clear_delta_models(save_dir) return model_checkpoint_path diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index c53cef7b..8e6b28dc 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -143,15 +143,16 @@ void EmbeddingDynamic::SaveKey(const string& savePath, bool saveDelta, const map deviceKey.clear(); embAddress.clear(); - for (const auto &it: keyOffsetMap) { - if (saveDelta) { - auto result = keyInfo.find(it.first); - if (result == keyInfo.end() || !result->second.isChanged) { - continue; - } + if (saveDelta) { + for (const auto& it : keyInfo) { + deviceKey.push_back(it.first); + embAddress.push_back(keyOffsetMap[it.first]); + } + } else { + for (const auto &it: keyOffsetMap) { + deviceKey.push_back(it.first); + embAddress.push_back(it.second); } - deviceKey.push_back(it.first); - embAddress.push_back(it.second); } if (fileSystemPtr_ == nullptr) { diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp index 0aef7b9e..904c1211 100644 --- a/src/core/emb_table/embedding_static.cpp +++ b/src/core/emb_table/embedding_static.cpp @@ -90,18 +90,18 @@ void EmbeddingStatic::SaveKey(const string& savePath, bool saveDelta, const map< deviceKey.clear(); deviceOffset.clear(); - for (const auto& it: keyOffsetMap) { - // When saving a delta model, you need to first extract the keys from deltaMap[name] where isChanged is true - // from the keyOffsetMap. - if (saveDelta) { - auto result = keyInfo.find(it.first); - if (result == keyInfo.end() || !result->second.isChanged) { - continue; - } + if (saveDelta) { + for (const auto& it : keyInfo) { + deviceKey.push_back(it.first); + deviceOffset.push_back(keyOffsetMap[it.first]); + } + } else { + for (const auto& it: keyOffsetMap) { + deviceKey.push_back(it.first); + deviceOffset.push_back(it.second); } - deviceKey.push_back(it.first); - deviceOffset.push_back(it.second); } + LOG_INFO("Device key size: {}, device offset size: {}.", deviceKey.size(), deviceOffset.size()); if (fileSystemPtr_ == nullptr) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1da3e1a9..9ba00c29 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1144,7 +1144,7 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) // 更新delta表 std::lock_guard lock(keyCountUpdateMtx); - updateDeltaInfo(embInfo.name, keyCountVec, timeStamp, globalStep); + UpdateDeltaInfo(embInfo.name, keyCountVec, timeStamp, globalStep); keyBatchIdMap[embInfo.name]++; keyCountUpdateCv.notify_all(); } @@ -1152,7 +1152,7 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) }); } -void HybridMgmt::updateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, +void HybridMgmt::UpdateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId) { auto keyCountSize = keyCountVec.size(); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 9cd5c9a9..3af4583e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -213,7 +213,7 @@ public: void SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, const vector& swapOutPosUint); - void updateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId); + void UpdateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId); void ResetDeltaInfo(); -- Gitee From 6117aa82e366428604d2f928a9666ec039ec3a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Mon, 12 Aug 2024 21:21:01 +0800 Subject: [PATCH 04/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=90=8D=E4=B8=BA=E5=A4=A7=E5=86=99=E5=AD=97?= =?UTF-8?q?=E6=AF=8D=E5=BC=80=E5=A4=B4=EF=BC=9B=E8=B0=83=E6=95=B4=E5=BC=80?= =?UTF-8?q?=E5=90=AF=E5=A2=9E=E9=87=8F=E6=97=B6=E4=BF=9D=E5=AD=98key?= =?UTF-8?q?=E3=80=81embed=E7=9A=84=E5=A4=9A=E4=BD=99=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 29 ++++++++++++------------ src/core/emb_table/embedding_dynamic.cpp | 17 +++++++------- src/core/emb_table/embedding_static.cpp | 20 ++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 ++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 9861756b..92a9ac6a 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -279,21 +279,22 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra if write_meta_graph: write_meta_graph_task(self, checkpoint_file=checkpoint_file, meta_graph_suffix=meta_graph_suffix, sess=sess, strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) + if is_incremental_checkpoint: + save_cost_time = time.time() - start_save_time + save_dir, _ = os.path.split(save_path) + export_tag = "Seconds" if save_delta else "DueTime" + model_index_info = {"timestamp": str(int(start_save_time)), "export_tag": export_tag, + "type": saved_model_type, "global_step": int(global_step), + "cost_ms": int(save_cost_time * 1000)} + if save_delta: + delta_model_version = "delta_" + str(int(global_step)) + write_delta_export_time_ms(save_dir, {delta_model_version: int(save_cost_time * 1000)}) + update_model_index(save_dir, model_index_info) + + # When saving base model, clear delta model directories + if not save_delta: + clear_delta_models(save_dir) comm.Barrier() - if is_incremental_checkpoint: - save_cost_time = time.time() - start_save_time - save_dir, _ = os.path.split(save_path) - export_tag = "Seconds" if save_delta else "DueTime" - model_index_info = {"timestamp": str(int(start_save_time)), "export_tag": export_tag, "type": saved_model_type, - "global_step": int(global_step), "cost_ms": int(save_cost_time * 1000)} - if save_delta: - delta_model_version = "delta_" + str(int(global_step)) - write_delta_export_time_ms(save_dir, {delta_model_version: int(save_cost_time * 1000)}) - update_model_index(save_dir, model_index_info) - - # 当保存base的时候清空delta目录 - if not save_delta: - clear_delta_models(save_dir) return model_checkpoint_path diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index c53cef7b..8e6b28dc 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -143,15 +143,16 @@ void EmbeddingDynamic::SaveKey(const string& savePath, bool saveDelta, const map deviceKey.clear(); embAddress.clear(); - for (const auto &it: keyOffsetMap) { - if (saveDelta) { - auto result = keyInfo.find(it.first); - if (result == keyInfo.end() || !result->second.isChanged) { - continue; - } + if (saveDelta) { + for (const auto& it : keyInfo) { + deviceKey.push_back(it.first); + embAddress.push_back(keyOffsetMap[it.first]); + } + } else { + for (const auto &it: keyOffsetMap) { + deviceKey.push_back(it.first); + embAddress.push_back(it.second); } - deviceKey.push_back(it.first); - embAddress.push_back(it.second); } if (fileSystemPtr_ == nullptr) { diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp index 0aef7b9e..904c1211 100644 --- a/src/core/emb_table/embedding_static.cpp +++ b/src/core/emb_table/embedding_static.cpp @@ -90,18 +90,18 @@ void EmbeddingStatic::SaveKey(const string& savePath, bool saveDelta, const map< deviceKey.clear(); deviceOffset.clear(); - for (const auto& it: keyOffsetMap) { - // When saving a delta model, you need to first extract the keys from deltaMap[name] where isChanged is true - // from the keyOffsetMap. - if (saveDelta) { - auto result = keyInfo.find(it.first); - if (result == keyInfo.end() || !result->second.isChanged) { - continue; - } + if (saveDelta) { + for (const auto& it : keyInfo) { + deviceKey.push_back(it.first); + deviceOffset.push_back(keyOffsetMap[it.first]); + } + } else { + for (const auto& it: keyOffsetMap) { + deviceKey.push_back(it.first); + deviceOffset.push_back(it.second); } - deviceKey.push_back(it.first); - deviceOffset.push_back(it.second); } + LOG_INFO("Device key size: {}, device offset size: {}.", deviceKey.size(), deviceOffset.size()); if (fileSystemPtr_ == nullptr) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1da3e1a9..9ba00c29 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1144,7 +1144,7 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) // 更新delta表 std::lock_guard lock(keyCountUpdateMtx); - updateDeltaInfo(embInfo.name, keyCountVec, timeStamp, globalStep); + UpdateDeltaInfo(embInfo.name, keyCountVec, timeStamp, globalStep); keyBatchIdMap[embInfo.name]++; keyCountUpdateCv.notify_all(); } @@ -1152,7 +1152,7 @@ void HybridMgmt::ReceiveKeyThread(const EmbInfo& embInfo) }); } -void HybridMgmt::updateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, +void HybridMgmt::UpdateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId) { auto keyCountSize = keyCountVec.size(); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 9cd5c9a9..3af4583e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -213,7 +213,7 @@ public: void SendTensorForSwap(const EmbBaseInfo& info, const vector& swapInPosUint, const vector& swapOutPosUint); - void updateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId); + void UpdateDeltaInfo(const string& embName, vector& keyCountVec, int64_t timeStamp, int64_t batchId); void ResetDeltaInfo(); -- Gitee From ee999feb70d4a055898c7c68742d93a4b6be815e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Mon, 12 Aug 2024 21:28:47 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=E4=B8=BA=E8=8B=B1=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 89eb391e..a05bef06 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -291,7 +291,7 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra write_delta_export_time_ms(save_dir, {delta_model_version: int(save_cost_time * 1000)}) update_model_index(save_dir, model_index_info) - # 当保存base的时候清空delta目录 + # When saving base model, clear delta model directories. if not save_delta: clear_delta_models(save_dir) comm.Barrier() -- Gitee From 300b036cb0a078bc380aa18483b2bb4a4d1da90f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Wed, 14 Aug 2024 18:09:18 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E8=A7=A3=E5=86=B3?= =?UTF-8?q?=E5=BC=80=E5=90=AF=E5=A2=9E=E9=87=8F=E4=BF=9D=E5=AD=98=E6=97=B6?= =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E4=BF=9D=E5=AD=98=E7=9A=84key=E4=B8=8E?= =?UTF-8?q?=E9=A2=84=E6=9C=9F=E4=B8=8D=E4=B8=80=E8=87=B4=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/emb_table/embedding_dynamic.cpp | 12 +++++- src/core/emb_table/embedding_static.cpp | 12 ++++-- src/core/key_process/key_process.cpp | 51 +++++++++++++++--------- src/core/key_process/key_process.h | 5 +-- 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index 8e6b28dc..2f244eaf 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -145,8 +145,13 @@ void EmbeddingDynamic::SaveKey(const string& savePath, bool saveDelta, const map if (saveDelta) { for (const auto& it : keyInfo) { - deviceKey.push_back(it.first); - embAddress.push_back(keyOffsetMap[it.first]); + auto result = keyOffsetMap.find(it.first); + if (result == keyOffsetMap.end()) { + LOG_DEBUG("Key: {} not in keyOffsetMap."); + continue; + } + deviceKey.push_back(result->first); + embAddress.push_back(result->second); } } else { for (const auto &it: keyOffsetMap) { @@ -155,6 +160,9 @@ void EmbeddingDynamic::SaveKey(const string& savePath, bool saveDelta, const map } } + LOG_INFO("Get device keys and embAddress, table: {}, save path: {}, rank id: {}, device key size: {}, device " + "embAddress size: {}.", name, savePath, rankId_, deviceKey.size(), embAddress.size()); + if (fileSystemPtr_ == nullptr) { throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); } diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp index 904c1211..93324e49 100644 --- a/src/core/emb_table/embedding_static.cpp +++ b/src/core/emb_table/embedding_static.cpp @@ -92,8 +92,13 @@ void EmbeddingStatic::SaveKey(const string& savePath, bool saveDelta, const map< if (saveDelta) { for (const auto& it : keyInfo) { - deviceKey.push_back(it.first); - deviceOffset.push_back(keyOffsetMap[it.first]); + auto result = keyOffsetMap.find(it.first); + if (result == keyOffsetMap.end()) { + LOG_DEBUG("Key: {} not in keyOffsetMap."); + continue; + } + deviceKey.push_back(result->first); + deviceOffset.push_back(result->second); } } else { for (const auto& it: keyOffsetMap) { @@ -102,7 +107,8 @@ void EmbeddingStatic::SaveKey(const string& savePath, bool saveDelta, const map< } } - LOG_INFO("Device key size: {}, device offset size: {}.", deviceKey.size(), deviceOffset.size()); + LOG_INFO("Get device keys and offsets, table: {}, save path: {}, rank id: {}, device key size: {}, device offset " + "size: {}.", name, savePath, rankId_, deviceKey.size(), deviceOffset.size()); if (fileSystemPtr_ == nullptr) { throw runtime_error("failed to obtain the file system pointer, the file system pointer is null."); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3a79ac26..975bb732 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -321,14 +321,14 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec) + vector >& keyCount) { TimeCost uniqueTc; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { tie(splitKeys, restore, keyCount) = HashSplitWithFAAE(batch); // 按存储dev id切分并去重 } else { - tie(splitKeys, restore, hotPos, keyCountVec) = HotHashSplit(batch); // 按存储dev id切分并去重 + tie(splitKeys, restore, hotPos, keyCount) = HotHashSplit(batch); // 按存储dev id切分并去重 } LOG_DEBUG("uniqueTc(ms):{}", uniqueTc.ElapsedMS()); } @@ -405,7 +405,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, vector> keyCount; vector keyCountVec; TimeCost totalTimeCost = TimeCost(); - HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount, keyCountVec); + HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, threadId, splitKeys); vector countRecv; @@ -413,6 +413,19 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); } + if (isIncrementalCheckpoint) { + countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); + map tmpKeyCountMap; + auto keySize = lookupKeys.size(); + for (int i = 0; i < keySize; ++i) { + tmpKeyCountMap[lookupKeys[i]] += countRecv[i]; + } + auto tmpKeyCountMapSize = tmpKeyCountMap.size(); + for (const auto& it : tmpKeyCountMap) { + keyCountVec.push_back(it.first); + keyCountVec.push_back(it.second); + } + } std::lock_guard lock(loadSaveMut[channel][threadId]); RecordKeyCountMap(batch); BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size())); @@ -920,7 +933,7 @@ tuple, vector, vector>> KeyProcess::Hash return {splitKeys, restore, keyCount}; } -tuple, vector, vector, vector> KeyProcess::HotHashSplit(const +tuple, vector, vector, vector>> KeyProcess::HotHashSplit(const unique_ptr& batch) { EASY_FUNCTION(profiler::colors::Gold) @@ -928,10 +941,9 @@ unique_ptr& batch) size_t miniBs = batch->Size(); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); - absl::flat_hash_map uKey; // 用于去重查询 + absl::flat_hash_map> uKey; // 用于去重查询 absl::flat_hash_map keyCountMapByEmbName; - absl::flat_hash_map keyCountOneBatch; - vector keyCountVec; + vector> keyCount(rankInfo.rankSize); std::shared_lock lock(g_smut); auto hotMap = hotKey[batch->name]; lock.unlock(); @@ -944,13 +956,11 @@ unique_ptr& batch) if (batch->batchId % hotEmbUpdateStep == 0) { keyCountMapByEmbName[key]++; } - if (isIncrementalCheckpoint) { - keyCountOneBatch[key]++; - } emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); auto result = uKey.find(key); if (result != uKey.end()) { // // already in splitKeys - restore[i] = result->second; + restore[i] = result->second.first; + uKey[key].second++; continue; } // new key in current batch @@ -970,16 +980,19 @@ unique_ptr& batch) // restore记录去重后key在桶内偏移量(用于计算恢复向量) restore[i] = static_cast(splitKeys[devId].size() + (hotOffset - 1)); } - uKey[key] = restore[i]; + uKey[key].first = restore[i]; + uKey[key].second = 1; } - if (isIncrementalCheckpoint) { - for (auto& it : keyCountOneBatch) { - keyCountVec.emplace_back(it.first); - keyCountVec.emplace_back(it.second); + + // Process key count in splitKeys + for (int j = 0; j < rankInfo.rankSize; ++j) { + vector count; + for (size_t k = 0; k < splitKeys[j].size(); ++k) { + count.emplace_back(uKey[splitKeys[j][k]].second); } + keyCount[j] = count; } - LOG_INFO(KEY_PROCESS "Hot hash split, batch id: {}, batch name: {}, channel: {}, kc size: {}, data: {}", - batch->batchId, batch->name, batch->channel, keyCountVec.size(), VectorToString(keyCountVec)); + if (GlogConfig::gStatOn) { size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { @@ -992,7 +1005,7 @@ unique_ptr& batch) UpdateHotMap(keyCountMapByEmbName, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch); - return { splitKeys, restore, hotPos, keyCountVec }; + return { splitKeys, restore, hotPos, keyCount }; } void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index ac99b42d..4d0d1acc 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -243,7 +243,7 @@ public: auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector, - vector>; + vector>>; void PaddingAlltoallVC(vector& splitKeys) const; @@ -296,8 +296,7 @@ public: vector scAll, vector ss); void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, - vector & restore, vector & hotPos, - vector >& keyCount, vector& keyCountVec); + vector & restore, vector & hotPos, vector >& keyCount); template inline vector Count2Start(const vector& count) const -- Gitee From 2517e612f036d0540bd796462c1e640702868528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Thu, 15 Aug 2024 21:15:41 +0800 Subject: [PATCH 07/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E8=A7=A3=E5=86=B3?= =?UTF-8?q?=E5=BC=80=E5=90=AF=E5=A2=9E=E9=87=8F=E4=BF=9D=E5=AD=98=E6=97=B6?= =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E4=BF=9D=E5=AD=98=E7=9A=84key=E4=B8=8E?= =?UTF-8?q?=E9=A2=84=E6=9C=9F=E4=B8=8D=E4=B8=80=E8=87=B4=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/core/asc/build_graph.py | 2 +- mx_rec/saver/patch.py | 16 ++++++++-------- src/core/key_process/key_process.cpp | 6 +++--- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index cc8c9cf9..133e4b91 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -148,7 +148,7 @@ def get_preprocessed_tensor_for_asc(table, config): id_offsets, swap_info = get_id_offsets(max_lookup_vec_size, config) is_incremental_checkpoint = ConfigInitializer.get_instance().is_incremental_checkpoint - if is_incremental_checkpoint: + if is_incremental_checkpoint and config["channel_id"] == TRAIN_CHANNEL_ID: table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(table) channel_name = f"{table_instance.table_name}_key_d2h_{TRAIN_CHANNEL_ID}" # send timestamp and global step tensor to host diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index a05bef06..8dcde975 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -310,22 +310,24 @@ def restore(self, sess, save_path): restore_model_version = ConfigInitializer.get_instance().restore_model_version directory, base_name = os.path.split(save_path) + model_type = BASE_MODEL if is_incremental_checkpoint: + model_type = get_model_type_by_version(directory, base_name.split("-")[1]) if not restore_model_version: # open incremental checkpoint and restore_model_version is none, then restore the latest model - restore_model_version = base_name.split("-")[1] - model_type = get_model_type_by_version(directory, restore_model_version) if model_type == DELTA_MODEL: # get the newest base model and then restore delta models one by one - base_model, delta_models = get_base_and_delta_models(directory, restore_model_version) + base_model, delta_models = get_base_and_delta_models(directory, base_name.split("-")[1]) delta_models_str = " ".join(delta_models) logger.info(f"Restore %s model from base model: %s and delta models: %s.", model_type, base_model, delta_models_str) read_base_delta_and_write(directory, base_model, delta_models) - save_path = os.path.join(directory, base_name) else: base_name = base_name.split("-")[0] + "-" + restore_model_version model_type = get_model_type_by_version(directory, restore_model_version) + if not model_type: + logger.error("Get model type by version failed, %s step model not exists.", restore_model_version) + raise ValueError(f"Get model type by version failed, {restore_model_version} step model not exists.") if model_type == DELTA_MODEL: # get the newest base model and then restore delta models one by one base_model, delta_models = get_base_and_delta_models(directory, restore_model_version) @@ -333,10 +335,8 @@ def restore(self, sess, save_path): logger.info(f"Restore %s model from base model: %s and delta models: %s.", model_type, base_model, delta_models_str) read_base_delta_and_write(directory, base_model, delta_models) - save_path = os.path.join(directory, base_name) - else: - model_type = BASE_MODEL - save_path = os.path.join(directory, base_name) + + save_path = os.path.join(directory, base_name) if not check_characters_is_valid(save_path): raise ValueError("save_path contains invalid characters such as newline, " diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 975bb732..1de2c5dc 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -413,7 +413,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); } - if (isIncrementalCheckpoint) { + if (isIncrementalCheckpoint && !channel) { countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); map tmpKeyCountMap; auto keySize = lookupKeys.size(); @@ -453,7 +453,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, // 将keyCountVec放进tensor里并推到一个队列里 auto keyCountTensors = make_unique>(); - if (isIncrementalCheckpoint) { + if (isIncrementalCheckpoint && !channel) { keyCountTensors->push_back(Vec2TensorI64(keyCountVec)); } @@ -464,7 +464,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); PushResultHBM(batch, move(tensors)); - if (isIncrementalCheckpoint) { + if (isIncrementalCheckpoint && !channel) { PushKeyCountHBM(batch, move(keyCountTensors)); } } else { -- Gitee From 5de2597428817075496e138748dacb3c0105dd40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Fri, 16 Aug 2024 09:10:05 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E6=B8=85=E7=90=86cl?= =?UTF-8?q?eancode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/core/key_process/key_process.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 2d589f5e..a7bb18b7 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -556,7 +556,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); } - if (isIncrementalCheckpoint && !channel) { + if (isIncrementalCheckpoint && channel == 0) { countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); map tmpKeyCountMap; auto keySize = lookupKeys.size(); @@ -596,7 +596,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, // 将keyCountVec放进tensor里并推到一个队列里 auto keyCountTensors = make_unique>(); - if (isIncrementalCheckpoint && !channel) { + if (isIncrementalCheckpoint && channel == 0) { keyCountTensors->push_back(Vec2TensorI64(keyCountVec)); } @@ -607,7 +607,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); PushResultHBM(batch, move(tensors)); - if (isIncrementalCheckpoint && !channel) { + if (isIncrementalCheckpoint && channel == 0) { PushKeyCountHBM(batch, move(keyCountTensors)); } } else { -- Gitee From 2512a0fe86d94699ad270182a3c81487836af8f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Thu, 22 Aug 2024 21:13:43 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=A2=9E=E9=87=8F=E5=8A=A0=E8=BD=BD=E6=97=B6=E5=A4=9A=E5=8D=A1?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E5=A4=B1=E8=B4=A5=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 74c81387..515caa0a 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -314,23 +314,20 @@ def restore(self, sess, save_path): directory, base_name = os.path.split(save_path) model_type = BASE_MODEL + if is_incremental_checkpoint: - model_type = get_model_type_by_version(directory, base_name.split("-")[1]) - if not restore_model_version: - # open incremental checkpoint and restore_model_version is none, then restore the latest model - if model_type == DELTA_MODEL: - # get the newest base model and then restore delta models one by one - base_model, delta_models = get_base_and_delta_models(directory, base_name.split("-")[1]) - delta_models_str = " ".join(delta_models) - logger.info(f"Restore %s model from base model: %s and delta models: %s.", model_type, base_model, - delta_models_str) - read_base_delta_and_write(directory, base_model, delta_models) - else: + if restore_model_version is not None: base_name = base_name.split("-")[0] + "-" + str(restore_model_version) - model_type = get_model_type_by_version(directory, str(restore_model_version)) - if not model_type: - logger.error("Get model type by version failed, %s step model not exists.", restore_model_version) - raise ValueError(f"Get model type by version failed, {restore_model_version} step model not exists.") + restore_model_version = base_name.split("-")[1] + model_type = get_model_type_by_version(directory, restore_model_version) + if not model_type: + logger.error("Get model type by version failed, %s step model not exists.", restore_model_version) + raise ValueError(f"Get model type by version failed, {restore_model_version} step model not exists.") + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + comm.Barrier() + if should_write_data(rank, save_path): if model_type == DELTA_MODEL: # get the newest base model and then restore delta models one by one base_model, delta_models = get_base_and_delta_models(directory, str(restore_model_version)) @@ -338,6 +335,7 @@ def restore(self, sess, save_path): logger.info(f"Restore %s model from base model: %s and delta models: %s.", model_type, base_model, delta_models_str) read_base_delta_and_write(directory, base_model, delta_models) + comm.Barrier() save_path = os.path.join(directory, base_name) -- Gitee From 944feb980b847e56687cec92c4e51281b0bd713d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E9=9C=96?= Date: Fri, 23 Aug 2024 14:29:35 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E3=80=90FIX=E3=80=91=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=BC=80=E5=90=AF=E5=A2=9E=E9=87=8F=E8=AE=AD=E7=BB=83=E6=97=B6?= =?UTF-8?q?=E4=BC=9A=E5=81=B6=E7=8E=B0=E5=8D=A1=E4=BD=8F=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mx_rec/saver/patch.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 515caa0a..947bf050 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -606,17 +606,21 @@ def should_trigger_for_step(self, step: int) -> bool: return True return False + should_trigger = False if self._save_checkpoint_due_time is not None: if time.time() >= self._last_triggered_base_time + self._save_checkpoint_due_time: self._is_delta = False - return True + should_trigger = True if self._save_delta_checkpoints_secs is not None: if time.time() >= self._last_triggered_delta_time + self._save_delta_checkpoints_secs: self._is_delta = True - return True + should_trigger = True + + comm = MPI.COMM_WORLD + result = comm.allreduce(should_trigger, op=MPI.LOR) - return False + return result def update_last_triggered_step(self, step: int) -> (Optional[float], Optional[int]): -- Gitee