diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 82e99a3f98119fdf8a7083f275fdef68c2617cee..ef5b0d159d7b2a06cad03bb939cc9845510398c9 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -250,7 +250,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra tf_logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") save_check(latest_filename, sess) - + could_invoke_set_save_info = ((not context.executing_eagerly()) and self.sparse_saver and + self.sparse_saver.config_instance) + if could_invoke_set_save_info: + self.sparse_saver.config_instance.hybrid_manager_config.set_save_op_info(False) if global_step is not None: checkpoint_file = get_checkpoint_file(self, global_step, sess, save_path) else: @@ -298,6 +301,9 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra if not save_delta: clear_delta_models(save_dir) comm.Barrier() + + if could_invoke_set_save_info: + self.sparse_saver.config_instance.hybrid_manager_config.set_save_op_info(True) return model_checkpoint_path diff --git a/mx_rec/util/config_utils/hybrid_mgmt_utils.py b/mx_rec/util/config_utils/hybrid_mgmt_utils.py index 38d29b272eeec24613754570ac076918350b6d0b..293087c2a9b0d72c8d162203a1e02ca99427af03 100644 --- a/mx_rec/util/config_utils/hybrid_mgmt_utils.py +++ b/mx_rec/util/config_utils/hybrid_mgmt_utils.py @@ -98,3 +98,9 @@ class HybridManagerConfig: raise RuntimeError("ASC manager not exist.") self.asc_manager.fetch_device_emb() logger.debug("request of fetching embedding from device to host for saving has been send") + + def set_save_op_info(self, is_save_end: bool) -> None: + if self.asc_manager is None: + raise RuntimeError("ASC manager not exist.") + self.asc_manager.set_save_op_info(is_save_end) + logger.debug("Request to set save op info end.") diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index b0bac0e8dd309898cfa5303eb9ef69dc26224829..e96e890e6d895275f34d99ea04f74481cb2f5abc 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -2214,4 +2214,14 @@ void HybridMgmt::GetDeltaModelKeys(const string& savePath, bool saveDelta, } } } -} \ No newline at end of file +} + +void HybridMgmt::SetSaveOpInfo(bool isSaveEnd) +{ + if (isSaveEnd) { + KEY_PROCESS_INSTANCE->SetPythonSaveEndInfo(); + } else { + KEY_PROCESS_INSTANCE->SetPythonSaveStartInfo(); + } +} + diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index f7b41d2b547329501238ba75319b4bb2e9f07718..da87ebedfbcb7d5e802c397279dc6d7ed7a925ee 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -142,6 +142,8 @@ public: void ReceiveKeyThread(const EmbInfo& embInfo); + void SetSaveOpInfo(bool isSaveEnd); + GTEST_PRIVATE : bool mutexDestroy{false}; // LookupAndSend & ReceiveAndUpdate Condition_Variable_Wait stop. std::mutex lookUpAndSendBatchIdMtx[MAX_CHANNEL_NUM]; // train and eval diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3b7e032cb5927248fea522fa8ffa7d334ceecb77..7016dc031e522eb91bbfd91736d883f08656d62c 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1590,6 +1590,9 @@ void KeyProcess::SendEos(const std::string& embName, int batchId, int channel) destroyMutex.unlock(); return; } + + WaitSaveEnd(embName, batchId, channel); + SendEosTensor(embName, channel); destroyMutex.unlock(); LOG_INFO("channelId:{} batchId:{}, the embName:{} SendEos end, release destroyMutex", channel, batchId, embName); @@ -1878,3 +1881,46 @@ void KeyProcess::SendEosTensor(const std::string& embName, int channel) } #endif } + +void KeyProcess::SetPythonSaveStartInfo() +{ + // Add an element assign false value, indicates a save op start. + LOG_INFO("Python save operation start."); + this->saveOpRecords_.emplace_back(false); +} + +void KeyProcess::SetPythonSaveEndInfo() +{ + if (this->saveOpRecords_.empty()) { + throw runtime_error("failed to set save op end because save records is empty."); + } + // Set the last element as true, indicates a save op end. + if (this->saveOpRecords_.size() > SAVE_RECORD_LENGTH) { + this->saveOpRecords_.clear(); + this->saveOpRecords_.emplace_back(true); + } else { + this->saveOpRecords_[saveOpRecords_.size() - 1] = true; + } + LOG_INFO("Python save operation end."); +} + +void KeyProcess::WaitSaveEnd(const std::string& embName, int batchId, int channel) +{ + // Before sending eos, wait for the save operation to complete. + // Sleep for 3 seconds and wait for a possible save operation. + this_thread::sleep_for(3000ms); + int loop_cnt = 1; + while (!saveOpRecords_.empty() && !saveOpRecords_[saveOpRecords_.size() - 1] + && loop_cnt <= SAVE_RECORD_CHECK_TIMES) { + this_thread::sleep_for(1000ms); + loop_cnt++; + } + + if (loop_cnt <= SAVE_RECORD_CHECK_TIMES) { + LOG_DEBUG("[EOS] table:{}, channelId:{} batchId:{}, before send eos, check save records loop times:{}.", + embName, channel, batchId, loop_cnt); + } else { + LOG_WARN("[EOS] table:{}, channelId:{} batchId:{}, before send eos, check save records loop times:{}.", + embName, channel, batchId, loop_cnt); + } +} diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 0296d2d7b75c0545b07e9600f863a80b396436f9..12a9fd2b8bcd80075154b23efc8fa7257ad06953 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -35,6 +35,9 @@ See the License for the specific language governing permissions and namespace MxRec { using namespace std; +constexpr int SAVE_RECORD_CHECK_TIMES = 25; +constexpr size_t SAVE_RECORD_LENGTH = 10; + template struct Cmp { bool operator()(const T& a, const T& b) const @@ -199,6 +202,10 @@ public: bool isIncrementalCheckpoint{false}; + void SetPythonSaveStartInfo(); + + void SetPythonSaveEndInfo(); + std::mutex destroyMutex; inline bool HasEmbName(const string& embName) @@ -355,6 +362,10 @@ GTEST_PRIVATE : string DumpSplitKeys(vector>& splitKeys) const; void SendEosTensor(const std::string& embName, int channel); + + void WaitSaveEnd(const std::string& embName, int batchId, int channel); + + std::vector saveOpRecords_ = {}; }; #define KEY_PROCESS_INSTANCE Singleton::GetInstance() diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 07a2163b22ca492294810b0990030e5dd0058d05..73bcc99d0a47ba387e39e869a3ac29e149fabd3f 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -236,7 +236,8 @@ namespace { .def("get_table_size", &MxRec::HybridMgmt::GetTableSize, py::arg("table_name")) .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name")) .def("set_optim_info", &MxRec::HybridMgmt::SetOptimizerInfo, py::arg("table_name"), - py::arg("optimizer_info")); + py::arg("optimizer_info")) + .def("set_save_op_info", &MxRec::HybridMgmt::SetSaveOpInfo, py::arg("is_save_end")); } void GetThresholdValue(pybind11::module_& m)