diff --git a/examples/mmoe/config.py b/examples/mmoe/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..08cfe9e70460d0425e63c14c5a4d495f7d8a369a
--- /dev/null
+++ b/examples/mmoe/config.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+
+import tensorflow as tf
+from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
+from npu_bridge.estimator.npu.npu_config import NPURunConfig
+
+from mx_rec.constants.constants import CacheModeEnum
+
+SSD_DATA_PATH = ["ssd_data"]
+
+
+class LearningRateScheduler:
+    """
+    LR scheduler for the dense and sparse optimizers. It currently returns
+    constant rates; the TF-based ops keep the interface graph-mode friendly
+    so decay or warmup stages can be added later.
+    """
+
+    def __init__(self, base_lr_dense, base_lr_sparse):
+        self.base_lr_dense = base_lr_dense
+        self.base_lr_sparse = base_lr_sparse
+
+    def calc(self):
+        # used for the constant stage
+        lr_factor_constant = tf.cast(1.0, tf.float32)
+
+        lr_sparse = self.base_lr_sparse * lr_factor_constant
+        lr_dense = self.base_lr_dense * lr_factor_constant
+        return lr_dense, lr_sparse
+
+
+class Config:
+    def __init__(self) -> None:
+        self.rank_id = int(os.getenv("OMPI_COMM_WORLD_RANK")) if os.getenv("OMPI_COMM_WORLD_RANK") else None
+        tmp = os.getenv("TRAIN_RANK_SIZE")
+        if tmp is None:
+            raise ValueError("please export TRAIN_RANK_SIZE")
+        self.rank_size = int(tmp)
+
+        self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH")
+        self.train_file_pattern = "train"
+        self.test_file_pattern = "test"
+
+        self.batch_size = 32
+        self.line_per_sample = 1
+        self.train_epoch = 100
+        self.test_epoch = 100
+        self.expert_num = 8
+        self.gate_num = 2
+        self.expert_size = 16
+        self.tower_size = 8
+
+        self.perform_shuffle = False
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.int64
+
+        self.feat_cnt = 26
+        self.__set_emb_table_size()
+
+        self.field_num = 26
+        self.send_count = self.get_send_count(self.rank_size)
+
+        self.emb_dim = self.expert_num * self.expert_size + self.gate_num * self.expert_num
+        self.hashtable_threshold = 1
+
+        self.USE_PIPELINE_TEST = False
+
+        self.global_step = tf.Variable(0, trainable=False)
+        _lr_scheduler = LearningRateScheduler(
+            0.001,
+            0.001
+        )
+        self.learning_rate = _lr_scheduler.calc()
+
+    @staticmethod
+    def get_send_count(rank_size):
+        try:
+            return 46000 // rank_size
+        except ZeroDivisionError as exp:
+            raise ZeroDivisionError('Rank size can not be zero.') from exp
+
+    def __set_emb_table_size(self) -> None:
+        self.cache_mode = os.getenv("CACHE_MODE")
+        if self.cache_mode is None:
+            raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]")
+
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 0
+        elif self.cache_mode == CacheModeEnum.DDR.value:
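+            # DDR mode: hot entries stay in the device-side (HBM) cache while the
+            # rest of the table spills to host memory; the sizes below are small
+            # demo values scaled by rank count, not tuned settings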
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 1000 * self.rank_size
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            self.dev_vocab_size = 1000 * self.rank_size
+            self.host_vocab_size = 1000 * self.rank_size
+            self.ssd_vocab_size = 1000 * self.rank_size
+        else:
+            raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]")
+
+    def get_emb_table_cfg(self) -> dict:
+        if self.cache_mode == CacheModeEnum.HBM.value:
+            return {"device_vocabulary_size": self.dev_vocab_size}
+        elif self.cache_mode == CacheModeEnum.DDR.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size}
+        elif self.cache_mode == CacheModeEnum.SSD.value:
+            return {"device_vocabulary_size": self.dev_vocab_size,
+                    "host_vocabulary_size": self.host_vocab_size,
+                    "ssd_vocabulary_size": self.ssd_vocab_size,
+                    "ssd_data_path": SSD_DATA_PATH}
+        else:
+            raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation")
+
+
+def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
+    session_config = tf.ConfigProto(allow_soft_placement=False,
+                                    log_device_placement=False)
+    session_config.gpu_options.allow_growth = True
+    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
+    custom_op.name = "NpuOptimizer"
+    custom_op.parameter_map["mix_compile_mode"].b = False
+    custom_op.parameter_map["use_off_line"].b = True
+    custom_op.parameter_map["min_group_size"].b = 1
+    # optional settings: level0:pairwise;level1:pairwise
+    custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh")
+    custom_op.parameter_map["enable_data_pre_proc"].b = True
+    custom_op.parameter_map["iterations_per_loop"].i = 10
+    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
+    custom_op.parameter_map["hcom_parallel"].b = False
+    custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini")
+    custom_op.parameter_map["op_execute_timeout"].i = 2000
+    custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes(
+        str(13 * 1024 * 1024 * 1024))  # total 31 need 13;
+    custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(18 * 1024 * 1024 * 1024))  # need 25
+    custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("DNN_VM_AICPU:3,AIcoreEngine:3")
+
+    if dump_data:
+        custom_op.parameter_map["enable_dump"].b = True
+        custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path)
+        custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps)
+        custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")
+
+    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
+
+    return session_config
+
+
+def get_npu_run_config():
+    session_config = tf.ConfigProto(allow_soft_placement=False,
+                                    log_device_placement=False)
+
+    session_config.gpu_options.allow_growth = True
+    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
+    custom_op.name = "NpuOptimizer"
+    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
+
+    run_config = NPURunConfig(
+        save_summary_steps=1000,
+        save_checkpoints_steps=100,
+        keep_checkpoint_max=5,
+        session_config=session_config,
+        log_step_count_steps=20,
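+        # the remaining options mirror sess_config() above; note the estimator
+        # path uses iterations_per_loop=1, while the sess.run path sinks 10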
+        precision_mode='allow_mix_precision',
+        enable_data_pre_proc=True,
+        iterations_per_loop=1,
+        jit_compile=False,
+        op_compiler_cache_mode="enable",
+        HCCL_algorithm="level0:fullmesh;level1:fullmesh"  # optional: level0:pairwise;level1:pairwise
+    )
+    return run_config
diff --git a/examples/mmoe/main_mxrec.py b/examples/mmoe/main_mxrec.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02566aa9339016181cc42ab0672bbddbfb8e9e4
--- /dev/null
+++ b/examples/mmoe/main_mxrec.py
@@ -0,0 +1,469 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import shutil
+import time
+import warnings
+import random
+from glob import glob
+
+import tensorflow as tf
+from sklearn.metrics import roc_auc_score
+import numpy as np
+from npu_bridge.npu_init import *
+from config import sess_config, Config, SSD_DATA_PATH, CacheModeEnum
+from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET
+from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func
+from mx_rec.core.asc.manager import start_asc_pipeline
+from mx_rec.core.embedding import create_table, sparse_lookup
+from mx_rec.core.feature_process import EvictHook
+from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook
+from mx_rec.constants.constants import ASCEND_TIMESTAMP
+from mx_rec.util.initialize import ConfigInitializer, init, terminate_config_initializer
+from mx_rec.util.ops import import_host_pipeline_ops
+import mx_rec.util as mxrec_util
+from mx_rec.util.variable import get_dense_and_sparse_variable
+from mx_rec.util.log import logger
+from optimizer import get_dense_and_sparse_optimizer
+
+from model import MyModel
+
+npu_plugin.set_device_sat_mode(0)
+
+dense_hashtable_seed = 128
+sparse_hashtable_seed = 128
+shuffle_seed = 128
+random.seed(shuffle_seed)
+
+
+def add_timestamp_func(batch):
+    timestamp = import_host_pipeline_ops().return_timestamp(tf.cast(batch['label'], dtype=tf.int64))
+    batch["timestamp"] = timestamp
+    return batch
+
+
+def make_batch_and_iterator(config, feature_spec_list, is_training, dump_graph, is_use_faae=False):
+    if config.USE_PIPELINE_TEST:
+        num_parallel = 1
+    else:
+        num_parallel = 8
+
+    def extract_fn(data_record):
+        features = {
+            # Extract features using the keys set during creation
+            'label': tf.compat.v1.FixedLenFeature(shape=(2 * config.line_per_sample,), dtype=tf.int64),
+            'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(29 * config.line_per_sample,), dtype=tf.int64),
+            'dense_feature': tf.compat.v1.FixedLenFeature(shape=(11 * config.line_per_sample,), dtype=tf.float32),
+        }
+        sample = tf.compat.v1.parse_single_example(data_record, features)
+        return sample
+
+    def reshape_fn(batch):
+        batch['label'] = tf.reshape(batch['label'], [-1, 2])
+        batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 11])
+        batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 29])
+        return batch
+
+    if is_training:
+        files_list = glob(os.path.join(config.data_path, config.train_file_pattern) + '/*.tfrecord')
+    else:
+        files_list = glob(os.path.join(config.data_path, config.test_file_pattern) + '/*.tfrecord')
+    dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel)
+    batch_size = config.batch_size // config.line_per_sample
+
+    dataset = dataset.shard(config.rank_size, config.rank_id)
+    if is_training:
+        dataset = dataset.shuffle(batch_size * 1000, seed=shuffle_seed)
+        dataset = dataset.repeat(config.train_epoch)
+    else:
+        dataset = dataset.repeat(config.test_epoch)
+    dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel).batch(batch_size,
+                                                                             drop_remainder=True)
+    dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel)
+    if is_use_faae:
+        dataset = dataset.map(add_timestamp_func)
+
+    if not MODIFY_GRAPH_FLAG:
+        insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph)
+        dataset = dataset.map(insert_fn)
+
+    dataset = dataset.prefetch(100)
+
+    iterator = dataset.make_initializable_iterator()
+    batch = iterator.get_next()
+    return batch, iterator
+
+
+def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph):
+    embedding_list = []
+    logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, "
+                 f"hash_table_list: {len(hash_table_list)}")
+    for feature, hash_table in zip(feature_list, hash_table_list):
+        if MODIFY_GRAPH_FLAG:
+            feature = batch["sparse_feature"]
+        embedding = sparse_lookup(hash_table, feature, cfg.send_count, dim=None, is_train=is_train,
+                                  name="user_embedding_lookup", modify_graph=modify_graph, batch=batch,
+                                  access_and_evict_config=None)
+        embedding_list.append(embedding)
+
+    if len(embedding_list) == 1:
+        emb = embedding_list[0]
+    elif len(embedding_list) > 1:
+        emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False)
+    else:
+        raise ValueError("the length of embedding_list must be greater than or equal to 1.")
+    emb = tf.reduce_sum(emb, axis=1)
+    my_model = MyModel()
+    model_output = my_model.build_model(embedding=emb,
+                                        dense_feature=batch["dense_feature"],
+                                        label=batch["label"],
+                                        is_training=is_train,
+                                        seed=dense_hashtable_seed)
+    return model_output
+
+
+def evaluate():
+    print("read_test dataset")
+    if not MODIFY_GRAPH_FLAG:
+        eval_label = eval_model.get("label")
+        sess.run([eval_iterator.initializer])
+    else:
+        # In sess.run mode, reusing the label tensor from the original batch would
+        # hit a GetNext timeout; take the label from the re-initialized eval dataset instead
+        eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(False).get("label")
+        sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)])
+    log_loss_list = []
+    pred_income_list = []
+    pred_mat_list = []
+    label_income_list = []
+    label_mat_list = []
+    eval_current_steps = 0
+    finished = False
+    print("eval begin")
+
+    while not finished:
+        eval_current_steps += 1
+        eval_start = time.time()
+        try:
+            eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_label])
+        except tf.errors.OutOfRangeError:
+            break
+        eval_cost = time.time() - eval_start
+        qps_eval = (1 / eval_cost) * rank_size * cfg.batch_size
+        log_loss_list += list(eval_loss.reshape(-1))
+        pred_income = pred[0]
+        pred_mat = pred[1]
+        pred_income_list += list(pred_income.reshape(-1))
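+        # income is task 0 and marital status task 1 (see MyModel.build_model)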
+        pred_mat_list += list(pred_mat.reshape(-1))
+        label_income_list += list(label[:, 0].reshape(-1))
+        label_mat_list += list(label[:, 1].reshape(-1))
+        print(f"eval current_steps: {eval_current_steps}, qps: {qps_eval}")
+        if eval_current_steps == eval_steps:
+            finished = True
+
+    auc_income = roc_auc_score(label_income_list, pred_income_list)
+    auc_mat = roc_auc_score(label_mat_list, pred_mat_list)
+    mean_log_loss = np.mean(log_loss_list)
+    return auc_income, auc_mat, mean_log_loss
+
+
+def evaluate_fix(step):
+    print("read_test dataset evaluate_fix")
+    if not MODIFY_GRAPH_FLAG:
+        sess.run([eval_iterator.initializer])
+    else:
+        sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)])
+    log_loss_list = []
+    pred_list = []
+    label_list = []
+    eval_current_steps = 0
+    finished = False
+    print("eval begin")
+    while not finished:
+        try:
+            eval_current_steps += 1
+            eval_loss, pred, label = sess.run([eval_model.get("loss"), eval_model.get("pred"), eval_model.get("label")])
+            log_loss_list += list(eval_loss.reshape(-1))
+            # pred is a list of per-task arrays; stack it task-major and flatten
+            # the labels the same way so predictions and labels stay aligned
+            pred_list += list(np.array(pred).reshape(-1))
+            label_list += list(label.transpose().reshape(-1))
+            print(f"eval current_steps: {eval_current_steps}")
+
+            if eval_current_steps == eval_steps:
+                finished = True
+        except tf.errors.OutOfRangeError:
+            finished = True
+
+    label_numpy = np.array(label_list)
+    pred_numpy = np.array(pred_list)
+    if not os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}"):
+        os.makedirs(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}")
+
+    if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy"):
+        os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy")
+    if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy"):
+        os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy")
+    if os.path.exists(f"flag_{rank_id}.txt"):
+        os.remove(f"flag_{rank_id}.txt")
+    np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy", label_numpy)
+    np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy", pred_numpy)
+    os.mknod(f"flag_{rank_id}.txt")
+    while True:
+        file_exists_list = [os.path.exists(f"flag_{i}.txt") for i in range(rank_size)]
+        if sum(file_exists_list) == rank_size:
+            print("All ranks saved.")
+            break
+        else:
+            print("Waiting for other ranks to save numpy files...")
+            time.sleep(1)
+            continue
+
+    auc = roc_auc_score(label_list, pred_list)
+    mean_log_loss = np.mean(log_loss_list)
+    return auc, mean_log_loss
+
+
+def create_feature_spec_list(use_timestamp=False):
+    access_threshold = None
+    eviction_threshold = None
+    if use_timestamp:
+        access_threshold = 1000
+        eviction_threshold = 180
+
+    feature_spec_list = [FeatureSpec("sparse_feature", table_name="sparse_embeddings", batch_size=cfg.batch_size,
+                                     access_threshold=access_threshold, eviction_threshold=eviction_threshold)]
+    if use_multi_lookup:
+        feature_spec_list.append(FeatureSpec("sparse_feature", table_name="sparse_embeddings",
+                                             batch_size=cfg.batch_size,
+                                             access_threshold=access_threshold,
+                                             eviction_threshold=eviction_threshold))
+    if use_timestamp:
+        feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True))
+    return feature_spec_list
+
+
+def _del_related_dir(del_path: str) -> None:
+    if not os.path.isabs(del_path):
+        del_path = os.path.join(os.getcwd(), del_path)
+    dirs = glob(del_path)
+    for sub_dir in dirs:
+        shutil.rmtree(sub_dir, ignore_errors=True)
logger.info(f"Delete dir:{sub_dir}") + + +def _clear_saved_model() -> None: + _del_related_dir("/root/ascend/log/*") + _del_related_dir("kernel*") + _del_related_dir("model_dir_rank*") + _del_related_dir("op_cache") + + if os.getenv("CACHE_MODE", "") != CacheModeEnum.SSD.value: + return + logger.info("Current cache mode is SSD, and file overwrite is not allowed in SSD mode, deleting exist directory" + " then create empty directory for this use case.") + for sub_path in SSD_DATA_PATH: + _del_related_dir(sub_path) + os.makedirs(sub_path, mode=0o550, exist_ok=True) + logger.info(f"Create dir:{sub_path}") + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + warnings.filterwarnings("ignore") + _clear_saved_model() + + rank_size = int(os.getenv("TRAIN_RANK_SIZE")) if os.getenv("TRAIN_RANK_SIZE") else None + interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None + train_steps = 1000 + eval_steps = 1000 + + try: + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 0))) + MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) + use_faae = bool(int(os.getenv("USE_FAAE", 0))) + except ValueError as err: + raise ValueError("please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + "or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + + use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) + logger.info(f"USE_DYNAMIC:{use_dynamic}") + init(train_steps=train_steps, eval_steps=eval_steps, + use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion) + + rank_id = mxrec_util.communication.hccl_ops.get_rank_id() + cfg = Config() + feature_spec_list_train = None + feature_spec_list_eval = None + if use_faae: + feature_spec_list_train = create_feature_spec_list(use_timestamp=True) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=True) + else: + feature_spec_list_train = create_feature_spec_list(use_timestamp=False) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) + + train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, + dump_graph=True, is_use_faae=use_faae) + eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False, + dump_graph=False, is_use_faae=use_faae) + logger.info(f"train_batch: {train_batch}") + + if use_faae: + cfg.dev_vocab_size = cfg.dev_vocab_size // 2 + + optimizer_list = [get_dense_and_sparse_optimizer(cfg)] + + # note: variance_scaling_initializer only support HBM mode + emb_initializer = tf.constant_initializer(value=0.1) + sparse_hashtable = create_table( + key_dtype=cfg.key_type, + dim=tf.TensorShape([cfg.emb_dim]), + name="sparse_embeddings", + emb_initializer=emb_initializer, + **cfg.get_emb_table_cfg() + ) + if use_faae: + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, train_batch["timestamp"]) + + sparse_hashtable_list = [sparse_hashtable, sparse_hashtable] if use_multi_lookup else [sparse_hashtable] + train_model = model_forward(feature_spec_list_train, sparse_hashtable_list, train_batch, + is_train=True, modify_graph=MODIFY_GRAPH_FLAG) + eval_model = model_forward(feature_spec_list_eval, sparse_hashtable_list, eval_batch, + is_train=False, modify_graph=MODIFY_GRAPH_FLAG) + + dense_variables, sparse_variables = get_dense_and_sparse_variable() + trainable_varibles = [] + trainable_varibles.extend(dense_variables) + if use_dynamic_expansion: + 
+        trainable_variables.append(tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)[0])
+    else:
+        trainable_variables.extend(sparse_variables)
+    rank_size = mxrec_util.communication.hccl_ops.get_rank_size()
+    train_ops = []
+    # multi task training
+    for loss, (dense_optimizer, sparse_optimizer) in zip([train_model.get("loss")], optimizer_list):
+        # do dense optimization
+        grads = dense_optimizer.compute_gradients(loss, var_list=trainable_variables)
+        avg_grads = []
+        for grad, var in grads[:-1]:
+            if rank_size > 1:
+                grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None
+            if grad is not None:
+                avg_grads.append((grad / rank_size, var))  # average the allreduced gradient across ranks
+        # apply gradients: update variables
+        train_ops.append(dense_optimizer.apply_gradients(avg_grads))
+
+        if use_dynamic_expansion:
+            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+            # do sparse optimization by addr
+            sparse_grads = list(grads[-1])  # local_embedding
+            grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)]
+            train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+        else:
+            # do sparse optimization
+            sparse_grads = list(grads[-1])
+            print("sparse_grads_tensor:", sparse_grads)
+            grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
+            train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+
+    with tf.control_dependencies(train_ops):
+        train_ops = tf.no_op()
+    cfg.learning_rate = [cfg.learning_rate[0], cfg.learning_rate[1]]
+
+    if MODIFY_GRAPH_FLAG:
+        modify_graph_and_start_emb_cache(dump_graph=True)
+    else:
+        start_asc_pipeline()
+
+    hook_list = []
+    if use_faae:
+        hook_evict = EvictHook(evict_enable=True, evict_time_interval=120)
+        hook_list.append(hook_evict)
+    if MODIFY_GRAPH_FLAG:  # add the hook in this scenario to handle graph verification issues
+        hook_list.append(GraphModifierHook(modify_graph=False))
+
+    if use_faae:
+        sess = tf.compat.v1.train.MonitoredTrainingSession(
+            hooks=hook_list,
+            config=sess_config(dump_data=False)
+        )
+        sess.graph._unsafe_unfinalize()
+        if not MODIFY_GRAPH_FLAG:
+            sess.run(train_iterator.initializer)
+        else:
+            sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True))
+    else:
+        sess = tf.compat.v1.Session(config=sess_config(dump_data=False))
+        sess.run(tf.compat.v1.global_variables_initializer())
+        if not MODIFY_GRAPH_FLAG:
+            sess.run(train_iterator.initializer)
+        else:
+            sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True))
+
+    epoch = 0
+    cost_sum = 0
+    qps_sum = 0
+    best_auc_income = 0
+    best_auc_mat = 0
+    iteration_per_loop = 10
+
+    train_ops = util.set_iteration_per_loop(sess, train_ops, iteration_per_loop)
+
+    i = 0
+    while True:
+        i += 1
+        logger.info(f"################ training at step {i * iteration_per_loop} ################")
+        start_time = time.time()
+
+        try:
+            grad, loss, lr, global_step = sess.run([train_ops, train_model.get("loss"),
+                                                    cfg.learning_rate, cfg.global_step])
+        except tf.errors.OutOfRangeError:
+            logger.info("Encountered the end of sequence for training.")
+            break
+
+        end_time = time.time()
+        cost_time = end_time - start_time
+        qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop
+        cost_sum += cost_time
+        logger.info(f"step: {i * iteration_per_loop}; training loss: {loss}")
+        logger.info(f"step: {i * iteration_per_loop}; grad: {grad}")
+        logger.info(f"step: {i * iteration_per_loop}; lr: {lr}")
+        logger.info(f"global step: {global_step}")
+        logger.info(f"step: {i * iteration_per_loop}; current sess cost time: {cost_time:.10f}; current QPS: {qps}")
current QPS: {qps}") + logger.info(f"training at step:{i * iteration_per_loop}, table[{sparse_hashtable.table_name}], " + f"table size:{sparse_hashtable.size()}, table capacity:{sparse_hashtable.capacity()}") + + if i % (train_steps // iteration_per_loop) == 0: + if interval is not None: + test_auc_income, test_auc_mat, test_mean_log_loss = evaluate_fix(i * iteration_per_loop) + else: + test_auc_income, test_auc_mat, test_mean_log_loss = evaluate() + print("Test auc income: {};Test auc mat: {} ;log_loss: {} ".format(test_auc_income, + test_auc_mat, test_mean_log_loss)) + best_auc_income = max(best_auc_income, test_auc_income) + best_auc_mat = max(best_auc_mat, test_auc_mat) + logger.info(f"training step: {i * iteration_per_loop}, best auc income: " + f"{best_auc_income} , best auc mat: {best_auc_mat}") + + + sess.close() + + terminate_config_initializer() + logger.info("Demo done!") diff --git a/examples/mmoe/model.py b/examples/mmoe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8cbb7ba8f2f823b6333a1c35cc0e150773f32298 --- /dev/null +++ b/examples/mmoe/model.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import time +from easydict import EasyDict as edict + +import tensorflow as tf + + +model_cfg = edict() +model_cfg.loss_mode = "batch" +LOSS_OP_NAME = "loss" +LABEL_OP_NAME = "label" +VAR_LIST = "variable" +PRED_OP_NAME = "pred" + + +class MyModel: + def __init__(self, expert_num=8, expert_size=16, tower_size=8, gate_num=2): + + self.expert_num = expert_num + self.expert_size = expert_size + self.tower_size = tower_size + self.gate_num = gate_num + + + def expert_layer(self, _input): + param_expert = [] + for i in range(0, self.expert_num): + expert_linear = tf.layers.dense(_input, units=self.expert_size, activation=None, name=f'expert_layer_{i}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + param_expert.append(expert_linear) + return param_expert + + + def gate_layer(self, _input): + param_gate = [] + for i in range(0, self.gate_num): + gate_linear = tf.layers.dense(_input, units=self.expert_num, activation=None, name=f'gate_layer_{i}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + param_gate.append(gate_linear) + return param_gate + + + def tower_layer(self, _input, layer_name): + tower_linear = tf.layers.dense(_input, units=self.tower_size, activation='relu', + name=f'tower_layer_{layer_name}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + + tower_linear_out = tf.layers.dense(tower_linear, units=2, activation=None, + name=f'tower_payer_out_{layer_name}', + kernel_initializer=tf.constant_initializer(value=0.1), + bias_initializer=tf.constant_initializer(value=0.1)) + 
+        return tower_linear_out
+
+    def build_model(self,
+                    embedding=None,
+                    dense_feature=None,
+                    label=None,
+                    is_training=True,
+                    seed=None):
+        with tf.variable_scope("mmoe", reuse=tf.AUTO_REUSE):
+            dense_expert = self.expert_layer(dense_feature)
+            dense_gate = self.gate_layer(dense_feature)
+
+            all_expert = []
+            _slice_num = 0
+            for i in range(0, self.expert_num):
+                slice_num_end = _slice_num + self.expert_size
+                cur_expert = tf.add(dense_expert[i], embedding[:, _slice_num:slice_num_end])
+                cur_expert = tf.nn.relu(cur_expert)
+                all_expert.append(cur_expert)
+                _slice_num = slice_num_end
+
+            expert_concat = tf.concat(all_expert, axis=1)
+            expert_concat = tf.reshape(expert_concat, [-1, self.expert_num, self.expert_size])
+
+            output_layers = []
+            out_pred = []
+            for i in range(0, self.gate_num):
+                slice_gate_end = _slice_num + self.expert_num
+                cur_gate = tf.add(dense_gate[i], embedding[:, _slice_num:slice_gate_end])
+                cur_gate = tf.nn.softmax(cur_gate)
+
+                cur_gate = tf.reshape(cur_gate, [-1, self.expert_num, 1])
+
+                cur_gate_expert = tf.multiply(x=expert_concat, y=cur_gate)
+                cur_gate_expert = tf.reduce_sum(cur_gate_expert, axis=1)
+
+                out = self.tower_layer(cur_gate_expert, i)
+                out = tf.nn.softmax(out)
+                out = tf.clip_by_value(out, clip_value_min=1e-15, clip_value_max=1.0 - 1e-15)
+                output_layers.append(out)
+                # out is already a softmax distribution; keep the positive-class
+                # probability directly instead of re-normalizing across the batch
+                out_pred.append(out[:, 1])
+                _slice_num = slice_gate_end
+            trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='mmoe')
+
+            label_income = label[:, 0:1]
+            label_mat = label[:, 1:]
+
+            pred_income_1 = tf.slice(output_layers[0], [0, 1], [-1, 1])
+            pred_marital_1 = tf.slice(output_layers[1], [0, 1], [-1, 1])
+
+            cost_income = tf.losses.log_loss(labels=tf.cast(label_income, tf.float32), predictions=pred_income_1,
+                                             epsilon=1e-4)
+            cost_marital = tf.losses.log_loss(labels=tf.cast(label_mat, tf.float32), predictions=pred_marital_1,
+                                              epsilon=1e-4)
+
+            avg_cost_income = tf.reduce_mean(cost_income)
+            avg_cost_marital = tf.reduce_mean(cost_marital)
+
+            loss = 0.5 * (avg_cost_income + avg_cost_marital)
+
+        return {LOSS_OP_NAME: loss,
+                PRED_OP_NAME: out_pred,
+                LABEL_OP_NAME: label,
+                VAR_LIST: trainable_variables}
diff --git a/examples/mmoe/op_impl_mode.ini b/examples/mmoe/op_impl_mode.ini
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/examples/mmoe/optimizer.py b/examples/mmoe/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5469c705c7476f007aa56ffc6f8af85ee328fc05
--- /dev/null
+++ b/examples/mmoe/optimizer.py
@@ -0,0 +1,33 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.optimizers.lazy_adam import create_hash_optimizer
+from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address
+
+
+def get_dense_and_sparse_optimizer(cfg):
+    dense_optimizer = tf.train.AdamOptimizer(learning_rate=cfg.learning_rate[0])
+    use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion
+    if use_dynamic_expansion:
+        sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate[1])
+    else:
+        sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1])
+
+    return dense_optimizer, sparse_optimizer
diff --git a/examples/mmoe/run.sh b/examples/mmoe/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6c1424435bac57a727cfc1980e25e08c6c5ba74c
--- /dev/null
+++ b/examples/mmoe/run.sh
@@ -0,0 +1,99 @@
+#!/bin/bash
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+cur_path=$(dirname "$(readlink -f "$0")")
+
+so_path=$1
+mx_rec_package_path=$2
+hccl_cfg_json=$3
+dlrm_criteo_data_path=$4
+ip=$5  # passed in for the no-ranktable case
+
+interface="lo"
+num_server=1
+local_rank_size=8
+num_process=$((num_server * local_rank_size))
+export TRAIN_RANK_SIZE=$num_process
+
+################# parameter configuration ######################
+export USE_DYNAMIC=0            # 0: static shape; 1: dynamic shape
+export CACHE_MODE="HBM"         # HBM;DDR;SSD
+export USE_FAAE=0               # 0: disable admission/eviction; 1: enable admission/eviction
+export USE_DYNAMIC_EXPANSION=0  # 0: disable dynamic expansion; 1: enable dynamic expansion
+export USE_MULTI_LOOKUP=0       # 0: one lookup per table; 1: multiple lookups per table
+export USE_MODIFY_GRAPH=0       # 0: feature-spec mode; 1: automatic graph-modification mode
+################################################
+echo "CACHE_MODE:${CACHE_MODE}"
+
+export HCCL_CONNECT_TIMEOUT=1200
+export DLRM_CRITEO_DATA_PATH=${dlrm_criteo_data_path}
+export PYTHONPATH=${mx_rec_package_path}:${so_path}:$PYTHONPATH
+export LD_PRELOAD=/usr/lib64/libgomp.so.1
+export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
+export ASCEND_DEVICE_ID=0
+export RANK_ID_START=0
+export JOB_ID=10086
+export CUSTOMIZED_OPS_LIB_PATH=${so_path}/libcust_ops.so  # TODO: configure this path
+export MXREC_LOG_LEVEL="INFO"
+export TF_CPP_MIN_LOG_LEVEL=3
+export ASCEND_GLOBAL_LOG_LEVEL=3
+#export USE_FAAE=1
+export ENABLE_FORCE_V2_CONTROL=1
+
+export PROFILING_OPTIONS='{"output":"/home/yz/profiling",
+                "training_trace":"on",
+                "task_trace":"on",
+                "aicpu":"on",
+                "fp_point":"",
+                "bp_point":"",
+                "aic_metrics":"PipeUtilization"}'
+
+RANK_ID_START=0
+
+export MXREC_MODE="ASC"
+echo "MXREC_MODE is $MXREC_MODE"
+export py=main_mxrec.py
+echo "py is $py"
+
+# distinguish between ranktable and no-ranktable startup
+if [ -n "$ip" ]; then
+    # no-ranktable branch
+    echo "Current is no ranktable solution."
+    echo "Input node ip: $ip, please make sure this ip is available."
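+    # the CM_* variables below bootstrap HCCL without a ranktable file: a chief
+    # endpoint plus per-node worker settings replace the static JSON config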
+    export CM_CHIEF_IP=$ip              # chief node ip
+    export CM_CHIEF_PORT=60001          # chief node listen port
+    export CM_CHIEF_DEVICE=0            # chief node device id
+    export CM_WORKER_IP=$ip             # current node ip
+    export CM_WORKER_SIZE=$num_process  # number of devices participating in cluster training
+    echo "CM_CHIEF_IP=$CM_CHIEF_IP"
+    echo "CM_CHIEF_PORT=$CM_CHIEF_PORT"
+    echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
+    echo "CM_WORKER_IP=$CM_WORKER_IP"
+    echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
+else
+    # ranktable branch
+    echo "Current is ranktable solution, hccl json file:${hccl_cfg_json}"
+    export RANK_SIZE=$num_process
+    echo "RANK_SIZE=${RANK_SIZE}, please make sure the hccl configuration json file matches this parameter"
+    export RANK_TABLE_FILE=${hccl_cfg_json}
+fi
+
+echo "use horovod to start tasks"
+# GLOG_stderrthreshold  -2:TRACE -1:DEBUG 0:INFO 1:WARN 2:ERROR, defaults to INFO
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
+
+horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
+python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${num_process}p.log
diff --git a/src/AccCTR/src/common/util/error_code.h b/src/AccCTR/src/common/util/error_code.h
index b30bfd830016466ef4b3fea8ec1cd9955d983335..87c8ffe61ca2dbc6f2f87692ac299d501ef2d9e0 100644
--- a/src/AccCTR/src/common/util/error_code.h
+++ b/src/AccCTR/src/common/util/error_code.h
@@ -43,6 +43,7 @@ using CTRCode = enum : int {
     H_TABLE_NAME_EMPTY = 22,
     H_PREFILL_BUFFER_SIZE_INVALID = 23,
     H_TABLE_NAME_TOO_LONG = 24,
+    H_EMB_CACHE_INFO_LOST = 25
 };
 }
 }
diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp
index 76e90abc4c1760f24e9cf5f6a2f6d35fbbaca287..5257882036a2ccbd3e52096bfee4b0aa3b1720b3 100644
--- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp
+++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.cpp
@@ -253,8 +253,7 @@ int EmbCacheManagerImpl::ExportDeviceKeyOffsetPairs(const std::string& tableName
     if (checkTableNameRet != H_OK) {
         return checkTableNameRet;
     }
-    OffsetMapper& om = offsetMappers[tableName];
-    koVec = om.ExportSortedKVPairs();
+    koVec = offsetMappers[tableName].ExportSortedKVPairs();
     return H_OK;
 }
 
@@ -318,6 +317,61 @@ int EmbCacheManagerImpl::LoadEmbTableInfos(std::string tableName, const std::vec
     return H_OK;
 }
 
+int EmbCacheManagerImpl::BackUpTrainStatus(const std::string& tableName)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+
+    // Back up the key-offset correspondence on the device
+    kvVecsBackUp[tableName] = offsetMappers[tableName].ExportVec();
+
+    auto embInfo = embCacheInfos.find(tableName);
+    if (embInfo == embCacheInfos.end()) {
+        return H_EMB_CACHE_INFO_LOST;
+    }
+    uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO;
+    uint32_t maxCacheSize = embInfo->second.maxCacheSize;
+
+    auto om = offsetMappersBackUp.find(tableName);
+    if (om != offsetMappersBackUp.end()) {
+        offsetMappersBackUp[tableName].UnInitialize();
+    }
+    offsetMappersBackUp[tableName].Initialize(reserve, maxCacheSize);
+    offsetMappersBackUp[tableName] = offsetMappers[tableName];
+
+    return H_OK;
+}
+
+int EmbCacheManagerImpl::RecoverTrainStatus(const std::string& tableName)
+{
+    int checkTableNameRet = CheckValidTableName(tableName);
+    if (checkTableNameRet != H_OK) {
+        return checkTableNameRet;
+    }
+
+    auto embInfo = embCacheInfos.find(tableName);
+    if (embInfo == embCacheInfos.end()) {
+        return H_EMB_CACHE_INFO_LOST;
+    }
+    uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO;
+    uint32_t maxCacheSize = embInfo->second.maxCacheSize;
+
+    offsetMappers[tableName].UnInitialize();
+    offsetMappers[tableName].Initialize(reserve, maxCacheSize);
+    offsetMappers[tableName] = offsetMappersBackUp[tableName];
+
+    // Recover the key-offset correspondence on the device; take a reference so
+    // clearing below actually releases the stored backup
+    auto& kvVecBackUp = kvVecsBackUp[tableName];
+    for (const auto& kvPair: kvVecBackUp) {
+        offsetMappers[tableName].Put(kvPair.first, kvPair.second);
+    }
+
+    kvVecBackUp.clear();
+    return H_OK;
+}
+
 void EmbCacheManagerImpl::Destroy()
 {
     for (auto it = offsetMappers.begin(); it != offsetMappers.end(); it++) {
@@ -422,3 +476,17 @@ uint32_t EmbCacheManagerImpl::GetUsage(const std::string& tableName)
 {
     return embTables[tableName].GetUsage();
 }
+
+int EmbCacheManagerImpl::ResetOffsetMappers()
+{
+    for (auto it = offsetMappers.begin(); it != offsetMappers.end(); it++) {
+        auto embInfo = embCacheInfos.find(it->first);
+        if (embInfo == embCacheInfos.end()) {
+            return H_EMB_CACHE_INFO_LOST;
+        }
+        it->second.UnInitialize();
+        uint32_t reserve = embInfo->second.maxCacheSize / VOCAB_CACHE_RATIO;
+        it->second.Initialize(reserve, embInfo->second.maxCacheSize);
+    }
+    return H_OK;
+}
diff --git a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h
index 80fbcd46c90e91f6ea5eb80537e316f76a701ed2..e4a240ae4945556dac3ec724c39d0485b569b84e 100644
--- a/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h
+++ b/src/AccCTR/src/embedding_cache/cache_manager/cache_manager.h
@@ -73,12 +73,20 @@ public:
         const std::vector>& embeddings,
         const std::vector>& optimizerSlots) override;
 
+    int BackUpTrainStatus(const std::string& tableName) override;
+
+    int RecoverTrainStatus(const std::string& tableName) override;
+
+    int ResetOffsetMappers() override;
+
     uint32_t GetUsage(const std::string& tableName) override;
 
 private:
     std::map embCacheInfos;
    std::map<std::string, OffsetMapper> offsetMappers;
+    std::map<std::string, OffsetMapper> offsetMappersBackUp;
     std::map embTables;
+    std::map<std::string, std::vector<std::pair<uint64_t, uint32_t>>> kvVecsBackUp;
 
     int CheckValidTableName(const std::string& tableName);
 
diff --git a/src/AccCTR/src/embedding_cache/limited_set.h b/src/AccCTR/src/embedding_cache/limited_set.h
index 036a64775b226909dd76745083b7f8e8f2f75038..f7bc2e1e6fac570becffce4cd772da036afc285c 100644
--- a/src/AccCTR/src/embedding_cache/limited_set.h
+++ b/src/AccCTR/src/embedding_cache/limited_set.h
@@ -20,19 +20,21 @@ limitations under the License.
 namespace EmbCache {
 
+static constexpr int64_t NODE_DEFAULT_VALUE = -1;
+
 class LimitedSet {
 public:
     struct Node {
         uint64_t value;
         Node *prev, *next;
-        Node(uint64_t val = -1) : value(val), prev(nullptr), next(nullptr) {}
+        Node(uint64_t val = NODE_DEFAULT_VALUE) : value(val), prev(nullptr), next(nullptr) {}
     };
 
-    LimitedSet(uint64_t maxRange) : head(new Node(-1)), tail(new Node(-1))
+    LimitedSet(uint64_t maxRange) : head(new Node(NODE_DEFAULT_VALUE)), tail(new Node(NODE_DEFAULT_VALUE))
     {
         nodes.resize(maxRange);
         for (auto &node : nodes) {
-            node = new Node(-1);
+            node = new Node(NODE_DEFAULT_VALUE);
         }
         head->next = tail;
         tail->prev = head;
@@ -47,6 +49,21 @@ public:
         delete tail;
     }
 
+    LimitedSet(const LimitedSet& other): head(new Node(NODE_DEFAULT_VALUE)), tail(new Node(NODE_DEFAULT_VALUE))
+    {
+        nodes.resize(other.nodes.size());
+        for (auto& node: nodes) {
+            node = new Node(NODE_DEFAULT_VALUE);
+        }
+
+        head->next = tail;
+        tail->prev = head;
+
+        for (Node* node = other.head->next; node != other.tail; node = node->next) {
+            insert(node->value);
+        }
+    }
+
     void insert(uint64_t value)
     {
         if (nodes[value]->value == value) {
@@ -69,7 +86,7 @@ public:
         Node *node = nodes[value];
         node->prev->next = node->next;
         node->next->prev = node->prev;
-        node->value = -1;
+        node->value = NODE_DEFAULT_VALUE;
     }
 
     bool find(uint64_t value)
diff --git a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h
index f42a0d3fc0c0e4c4ea14e88c0183d050e7fdba74..1ad470c5bf4fb9bf7dea441f7874466849b13b39 100644
--- a/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h
+++ b/src/AccCTR/src/embedding_cache/offset_mapper/offset_mapper.h
@@ -35,6 +35,38 @@ public:
 
     ~OffsetMapper() = default;
 
+    OffsetMapper(const OffsetMapper& other): maxCacheSize(other.maxCacheSize), useLength(other.useLength),
+        validPos(new LimitedSet(*other.validPos)),
+        evictPos(new LimitedSet(*other.evictPos)),
+        pos2Key(other.pos2Key), lastBatchPos(other.lastBatchPos),
+        evictSize(other.evictSize)
+    {
+    }
+
+    OffsetMapper& operator=(const OffsetMapper& other)
+    {
+        if (this != &other) {
+            delete validPos;
+            validPos = nullptr;
+            delete evictPos;
+            evictPos = nullptr;
+
+            if (other.validPos != nullptr) {
+                validPos = new LimitedSet(*other.validPos);
+            }
+            if (other.evictPos != nullptr) {
+                evictPos = new LimitedSet(*other.evictPos);
+            }
+
+            maxCacheSize = other.maxCacheSize;
+            useLength = other.useLength;
+            pos2Key = other.pos2Key;
+            lastBatchPos = other.lastBatchPos;
+            evictSize = other.evictSize;
+        }
+        return *this;
+    }
+
     bool Initialize(uint32_t reserve, uint32_t maxSize = 0)
     {
         maxCacheSize = maxSize;
diff --git a/src/AccCTR/src/include/embedding_cache.h b/src/AccCTR/src/include/embedding_cache.h
index 4adf1fbf57dfcf35d2e157a32e5fd778aa2982a5..c0468549312a8b0a4834b0774eaa9e10eaadcf3b 100644
--- a/src/AccCTR/src/include/embedding_cache.h
+++ b/src/AccCTR/src/include/embedding_cache.h
@@ -315,6 +315,26 @@ public:
     virtual int LoadEmbTableInfos(std::string tableName, const std::vector& keys,
         const std::vector>& embeddings,
         const std::vector>& optimizerSlots) = 0;
+
+    /* *
+     * When switching the channel to eval, back up the current table's offsetMapper object.
+     * @Param tableName: embedding table name
+     * @Return errorCode
+     */
+    virtual int BackUpTrainStatus(const std::string& tableName) = 0;
+
+    /* *
+     * When switching the eval channel back to train, recover the current table's offsetMapper object to the backup state.
+     * @Param tableName: embedding table name
+     * @Return errorCode
+     */
+    virtual int RecoverTrainStatus(const std::string& tableName) = 0;
+
+    /* *
+     * Reset the offsetMapper object to revert to its initialized state after loading.
+     * @Return errorCode
+     */
+    virtual int ResetOffsetMappers() = 0;
 };
 } // namespace EmbCache
diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp
index d163aa6d3e12389761ee74e1b13c841f8fe0a583..d05b35019f9b9af0e019616dba27f1d0eccaa8de 100644
--- a/src/core/emb_table/embedding_ddr.cpp
+++ b/src/core/emb_table/embedding_ddr.cpp
@@ -78,6 +78,11 @@ void EmbeddingDDR::Load(const string& savePath, map
+    int rs = embCache->ResetOffsetMappers();
+    if (rs != 0) {
+        throw runtime_error("embCache->ResetOffsetMappers failed, err code: " + to_string(rs));
+    }
 }
 
 void EmbeddingDDR::LoadKey(const string &savePath, vector &keys)
@@ -187,15 +192,13 @@ void EmbeddingDDR::LoadOptimizerSlot(const string &savePath, vector
     vector keys;
     vector> embeddings;
     vector> optimizerSlots;
     auto step = GetStepFromPath(savePath);
-    if (step > 0) {
-        SyncLatestEmbedding();
-        embCache->GetEmbTableInfos(name, keys, embeddings, optimizerSlots);
-    }
+    embCache->GetEmbTableInfos(name, keys, embeddings, optimizerSlots);
 
     SaveKey(savePath, keys);
     SaveEmbedding(savePath, embeddings);
@@ -376,3 +379,13 @@ void EmbeddingDDR::SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache)
 {
     this->embCache = embCache;
 }
+
+void EmbeddingDDR::BackUpTrainStatus()
+{
+    embCache->BackUpTrainStatus(name);
+}
+
+void EmbeddingDDR::RecoverTrainStatus()
+{
+    embCache->RecoverTrainStatus(name);
+}
diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h
index ac5c5878baf99c853ea925348cf22adb58f62f5b..26d85e606b37b414ba62fc41b9782cc4b30fceb9 100644
--- a/src/core/emb_table/embedding_ddr.h
+++ b/src/core/emb_table/embedding_ddr.h
@@ -73,6 +73,9 @@ public:
     void SaveEmbAndOptim(const string& savePath);
     void SetEmbCache(ock::ctr::EmbCacheManagerPtr embCache);
 
+    void BackUpTrainStatus();
+    void RecoverTrainStatus();
+
 GTEST_PRIVATE:
     void EvictDeleteEmb(const vector& keys);
 
diff --git a/src/core/emb_table/embedding_mgmt.cpp b/src/core/emb_table/embedding_mgmt.cpp
index 9e7dcbb09c34afd19d8d965d8766069623cffdfa..d889cdba58ea51f95c743448ea29f19c77c56cc2 100644
--- a/src/core/emb_table/embedding_mgmt.cpp
+++ b/src/core/emb_table/embedding_mgmt.cpp
@@ -196,3 +196,17 @@ void EmbeddingMgmt::SetEmbCacheForEmbTable(const ock::ctr::EmbCacheManagerPtr& e
         table.second->SetEmbCache(embCache);
     }
 }
+
+void EmbeddingMgmt::BackUpTrainStatusBeforeLoad()
+{
+    for (auto& table: embeddings) {
+        table.second->BackUpTrainStatus();
+    }
+}
+
+void EmbeddingMgmt::RecoverTrainStatus()
+{
+    for (auto& table: embeddings) {
+        table.second->RecoverTrainStatus();
+    }
+}
\ No newline at end of file
diff --git a/src/core/emb_table/embedding_mgmt.h b/src/core/emb_table/embedding_mgmt.h
index ef10678667255e38c0718bc1ce5857c179304635..9dd0e363292f1dae84d618333e4d216c891e4494 100644
--- a/src/core/emb_table/embedding_mgmt.h
+++ b/src/core/emb_table/embedding_mgmt.h
@@ -89,6 +89,16 @@ public:
      */
     void Save(const string& filePath);
 
+    /**
+     * In estimator mode, when switching from train to eval, back up the training state of all tables.
+     */
+    void BackUpTrainStatusBeforeLoad();
+
+    /**
+     * In estimator mode, when switching from eval to train, recover the training state of all tables.
+     */
+    void RecoverTrainStatus();
+
     /**
      * Get the DeviceOffsets of all tables; these offsets are used on the Python side to extract
      * each key's embedding when saving.
      */
diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp
index 61874b1fcea290f7d7421cd17be432921f416cb2..0db152ed8ce0ece19dc1f073ff915abac3fedc1e 100644
--- a/src/core/emb_table/embedding_static.cpp
+++ b/src/core/emb_table/embedding_static.cpp
@@ -160,11 +160,23 @@ void EmbeddingStatic::LoadKey(const string& savePath)
     }
 
     maxOffset = keyOffsetMap.size();
-    free(static_cast(buf));
 }
 
 vector EmbeddingStatic::GetDeviceOffset()
 {
     return deviceOffset;
-}
\ No newline at end of file
+}
+
+void EmbeddingStatic::BackUpTrainStatus()
+{
+    keyOffsetMapBackUp = keyOffsetMap;
+}
+
+void EmbeddingStatic::RecoverTrainStatus()
+{
+    if (!keyOffsetMapBackUp.empty()) {
+        keyOffsetMap = keyOffsetMapBackUp;
+        keyOffsetMapBackUp.clear();
+    }
+}
diff --git a/src/core/emb_table/embedding_static.h b/src/core/emb_table/embedding_static.h
index 6515f586736a769f08cc1b868a6c8add1a63504e..6f772e0891a09dd24e05983a125d2e046f01095e 100644
--- a/src/core/emb_table/embedding_static.h
+++ b/src/core/emb_table/embedding_static.h
@@ -39,6 +39,10 @@ public:
 
     void Save(const string& savePath);
 
+    void BackUpTrainStatus();
+
+    void RecoverTrainStatus();
+
     vector GetDeviceOffset();
 
 GTEST_PRIVATE:
diff --git a/src/core/emb_table/embedding_table.cpp b/src/core/emb_table/embedding_table.cpp
index b4eb23795ec7b6bd1ea44724c29d11704d54f48..12b0137a984b7e76925b783ac27b14e10a166b43 100644
--- a/src/core/emb_table/embedding_table.cpp
+++ b/src/core/emb_table/embedding_table.cpp
@@ -143,6 +143,14 @@ void EmbeddingTable::Save(const string& filePath)
 {
 }
 
+void EmbeddingTable::BackUpTrainStatus()
+{
+}
+
+void EmbeddingTable::RecoverTrainStatus()
+{
+}
+
 void EmbeddingTable::MakeDir(const string& dirName)
 {
     if (fileSystemPtr_ == nullptr) {
diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h
index ef7418870229bd954f55e54f8581de1f539cd532..da6a42bee1e9f6eb719ed139e3ff4d7af0d773b1 100644
--- a/src/core/emb_table/embedding_table.h
+++ b/src/core/emb_table/embedding_table.h
@@ -77,6 +77,10 @@ public:
 
     void MakeDir(const string& dirName);
 
+    virtual void BackUpTrainStatus();
+
+    virtual void RecoverTrainStatus();
+
     virtual vector GetDeviceOffset();
 
     vector GetLoadOffset();
@@ -97,6 +101,7 @@ public:
     size_t ssdVocabSize;
     size_t maxOffset;
     absl::flat_hash_map keyOffsetMap;
+    absl::flat_hash_map keyOffsetMapBackUp;
 
     std::vector evictDevPos;  // records keys evicted from HBM
     std::vector evictHostPos;  // records the host-side eviction list
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index f8ad92161310f9e60287443588ae3b726f36f6c6..4801f95b14df5ef31add10146e87df7bec3e7ce8 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -206,12 +206,6 @@ bool HybridMgmt::Load(const string& loadPath, vector warmStartTables)
         throw runtime_error("HybridMgmt not initialized. Call Initialize first.");
     }
 
-    if (mgmtRankInfo.isDDR && IsTrainAndEvalCase()) {
-        LOG_INFO("estimator train and eval case, skip loading, "
-            "host will reuse data in memory while evaluating since is's same as saved data");
-        return true;
-    }
-
     // lock the data-processing threads
     KEY_PROCESS_INSTANCE->LoadSaveLock();
 
@@ -221,6 +215,7 @@ bool HybridMgmt::Load(const string& loadPath, vector warmStartTables)
     Checkpoint loadCkpt;
     vector loadFeatures;
     SetFeatureTypeForLoad(loadFeatures);
+    BackUpTrainStatus();
 
     if (warmStartTables.size() == 0) {
         EmbeddingMgmt::Instance()->Load(loadPath, trainKeysSet);
@@ -256,10 +251,15 @@ bool HybridMgmt::Load(const string& loadPath, vector warmStartTables)
         featAdmitNEvict.LoadHistoryRecords(loadData.histRec);
     }
 
+    int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[TRAIN_CHANNEL_ID];
     if (isL3StorageEnabled) {
         LOG_DEBUG(MGMT + "Start host side load: L3Storage key freq map");
         auto step = GetStepFromPath(loadPath);
-        cacheManager->Load(mgmtEmbInfo, step, trainKeysSet);
+        // In load-and-train mode or predict mode, the SSD load must actually run;
+        // in estimator train-and-eval mode, the load before eval is skipped
+        if (theTrainBatchId == 0) {
+            cacheManager->Load(mgmtEmbInfo, step, trainKeysSet);
+        }
     }
 
     LOG_DEBUG(MGMT + "Finish host side load process");
@@ -501,6 +501,8 @@ void HybridMgmt::EvalTask(TaskType type)
         cvCheckSave.wait(checkSaveLocker, [this] { return !hybridMgmtBlock->IsNeedWaitSave() || mutexDestroy; });
 
         if (hybridMgmtBlock->pythonBatchId[EVAL_CHANNEL_ID] >= hybridMgmtBlock->hybridBatchId[EVAL_CHANNEL_ID]) {
+            // Before waking the data process for training, recover the backed-up training state
+            RecoverTrainStatus();
             hybridMgmtBlock->Wake(TRAIN_CHANNEL_ID);
         } else {
             std::this_thread::sleep_for(SLEEP_MS);
@@ -2213,3 +2215,34 @@ bool HybridMgmt::IsTrainAndEvalCase()
     }
     return alreadyTrainOnce && isChannelSwitchCase;
 }
+
+void HybridMgmt::BackUpTrainStatus()
+{
+    int channelID = TRAIN_CHANNEL_ID;
+    int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelID];
+    if (theTrainBatchId == 0) {
+        return;
+    }
+
+    LOG_INFO("In estimator train-and-eval mode, start to back up train status, "
+        "current train batchId: {} .", theTrainBatchId);
+    // In estimator train-and-eval mode, back up training state before loading.
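+    // The backup covers both layers of state: per-table key->offset mappings
+    // (EmbeddingMgmt) and, when L3 storage is enabled, the DDR/SSD frequency
+    // bookkeeping handled by cacheManager below.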
+    EmbeddingMgmt::Instance()->BackUpTrainStatusBeforeLoad();
+
+    if (isL3StorageEnabled) {
+        cacheManager->BackUpTrainStatus();
+    }
+    isBackUpTrainStatus = true;
+}
+
+void HybridMgmt::RecoverTrainStatus()
+{
+    if (!isBackUpTrainStatus) {
+        return;
+    }
+    EmbeddingMgmt::Instance()->RecoverTrainStatus();
+
+    if (isL3StorageEnabled) {
+        cacheManager->RecoverTrainStatus();
+    }
+    isBackUpTrainStatus = false;
+}
\ No newline at end of file
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 57a7ddd16145e1a130edf45d14d0b4847d7ddcba..5f94c96dafd99f411aad51e981dda2a0045d014a 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -131,6 +131,10 @@ public:
 
     void ProcessEmbInfoL3Storage(const EmbBaseInfo& info, bool& remainBatchOut);
 
+    void BackUpTrainStatus();
+
+    void RecoverTrainStatus();
+
 GTEST_PRIVATE : bool mutexDestroy{false};
     std::mutex lookUpAndSendBatchIdMtx;
     std::mutex receiveAndUpdateBatchIdMtx;
@@ -219,6 +223,7 @@ private:
     bool isLoad{false};
     bool isInitialized{false};
     bool alreadyTrainOnce = false;  // used to determine whether this is predict mode
+    bool isBackUpTrainStatus = false;  // whether the train state has been backed up
 
     map lookUpSwapInAddrsPushId;  // handles the EOS case: triggered when the consumer catches up with the producer and no upstream data arrives for a long time
     map specialProcessStatus;
diff --git a/src/core/l3_storage/cache_manager.cpp b/src/core/l3_storage/cache_manager.cpp
index ee3d7bc54890353f0cdf2aa2280bb357a32f728c..7ea68e1439103cbbbce32e5d082570087b155a42 100644
--- a/src/core/l3_storage/cache_manager.cpp
+++ b/src/core/l3_storage/cache_manager.cpp
@@ -32,10 +32,10 @@ void CacheManager::Init(ock::ctr::EmbCacheManagerPtr embCachePtr, vector
     this->embCache = std::move(embCachePtr);
 
     for (auto& emb : mgmtEmbInfo) {
-        EmbBaseInfo baseInfo {emb.ssdVocabSize, emb.ssdDataPath, false};
+        EmbBaseInfo baseInfo {emb.ssdVocabSize, emb.ssdDataPath, false, emb.extEmbeddingSize};
         embBaseInfos.emplace(emb.name, baseInfo);
         preProcessMapper[emb.name].Initialize(emb.name, emb.hostVocabSize, emb.ssdVocabSize);
     }
@@ -293,3 +293,73 @@ void CacheManager::FetchL3StorageEmb2DDR(string tableName, uint32_t extEmbedding
     embeddingTaskStep++;
     evictWaitCond.notify_all();
 }
+
+void CacheManager::BackUpTrainStatus()
+{
+    ddrKeyFreqMapBackUp = ddrKeyFreqMap;
+    excludeDDRKeyCountMapBackUp = excludeDDRKeyCountMap;
+}
+
+void CacheManager::RecoverTrainStatus()
+{
+    for (const auto& pair: excludeDDRKeyCountMapBackUp) {
+        auto tableName = pair.first;
+
+        std::vector ssdKeysBeforeEval;
+        std::vector ssdKeysAfterEval;
+        std::vector swapInKeys;
+        std::vector swapOutKeys;
+
+        for (const auto& keyMap : pair.second) {
+            ssdKeysBeforeEval.push_back(keyMap.first);
+        }
+        for (const auto& keyMap : excludeDDRKeyCountMap[tableName]) {
+            ssdKeysAfterEval.push_back(keyMap.first);
+        }
+
+        GetSwapInAndSwapOutKeys(ssdKeysBeforeEval, ssdKeysAfterEval, swapInKeys, swapOutKeys);
+
+        // ddr <-> ssd
+        // ddr -> lookup address, ssd -> insert embedding, ddr -> remove embedding
+        vector swapInKeysAddr;
+        int rc = embCache->EmbeddingLookupAddrs(tableName, swapInKeys, swapInKeysAddr);
+        if (rc != 0) {
+            throw runtime_error("EmbeddingLookUpAddrs failed! error code: " + std::to_string(rc));
+        }
+        auto extEmbeddingSize = embBaseInfos[tableName].extEmbeddingSize;
+        l3Storage->InsertEmbeddingsByAddr(tableName, swapInKeys, swapInKeysAddr, extEmbeddingSize);
+        rc = embCache->EmbeddingRemove(tableName, swapInKeys);
+        if (rc != 0) {
+            throw runtime_error("EmbeddingRemove failed! error code: " + std::to_string(rc));
error code: " + std::to_string(rc)); + } + + // ssd->fetch embedding, ddr->EmbeddingUpdate, ssd->delete embedding + auto swapOutEmbeddings = l3Storage->FetchEmbeddings(tableName, swapOutKeys); + vector swapOutFlattenEmbeddings; + for (auto& emb : swapOutEmbeddings) { + swapOutFlattenEmbeddings.insert(swapOutFlattenEmbeddings.cend(), emb.cbegin(), emb.cend()); + } + rc = embCache->EmbeddingUpdate(tableName, swapOutKeys, swapOutFlattenEmbeddings.data()); + l3Storage->DeleteEmbeddings(tableName, swapOutKeys); + } + + ddrKeyFreqMap = ddrKeyFreqMapBackUp; + excludeDDRKeyCountMap = excludeDDRKeyCountMapBackUp; +} + +void CacheManager::GetSwapInAndSwapOutKeys(vector& ssdKeysBeforeEval, + vector& ssdKeysAfterEval, + vector& swapInKeys, vector& swapOutKeys) +{ + std::sort(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end()); + std::sort(ssdKeysAfterEval.begin(), ssdKeysAfterEval.end()); + vector intersectionKeys; + std::set_intersection(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end(), ssdKeysAfterEval.begin(), + ssdKeysAfterEval.end(), std::back_inserter(intersectionKeys)); + + std::set_difference(ssdKeysBeforeEval.begin(), ssdKeysBeforeEval.end(), intersectionKeys.begin(), + intersectionKeys.end(), std::back_inserter(swapInKeys)); + std::set_difference(ssdKeysAfterEval.begin(), ssdKeysAfterEval.end(), intersectionKeys.begin(), + intersectionKeys.end(), std::back_inserter(swapOutKeys)); +} + diff --git a/src/core/l3_storage/cache_manager.h b/src/core/l3_storage/cache_manager.h index 79335788b615138d68ab497567d8f57a61c10896..34e7f0c24a758856fb4c5e4ed16a36fc6bb3bdb0 100644 --- a/src/core/l3_storage/cache_manager.h +++ b/src/core/l3_storage/cache_manager.h @@ -107,10 +107,20 @@ namespace MxRec { int64_t GetTableUsage(const string& tableName); + void BackUpTrainStatus(); + + void RecoverTrainStatus(); + + void GetSwapInAndSwapOutKeys(vector& ssdKeysBeforeEval, + vector& ssdKeysAfterEval, + vector& swapInKeys, vector& swapOutKeys); + // DDR内每个表中emb数据频次缓存;map unordered_map ddrKeyFreqMap; + unordered_map ddrKeyFreqMapBackUp; // 每张表中非DDR内key的出现次数 unordered_map> excludeDDRKeyCountMap; + unordered_map> excludeDDRKeyCountMapBackUp; // 每一个table对应一个PreProcessMapper,预先推演HBM->DDR的情况 std::unordered_map preProcessMapper; @@ -125,6 +135,7 @@ namespace MxRec { uint64_t maxTableSize; vector savePath; bool isExist; + int extEmbeddingSize; }; void CreateL3StorageTableIfNotExist(const std::string& embTableName); diff --git a/src/core/ock_ctr_common/include/embedding_cache.h b/src/core/ock_ctr_common/include/embedding_cache.h index f3bc9e23adb0fec9ea682bbac0f412bec31dae83..ce807f160e17b2ef5917e7386c3f88ae2ce3f95d 100644 --- a/src/core/ock_ctr_common/include/embedding_cache.h +++ b/src/core/ock_ctr_common/include/embedding_cache.h @@ -315,6 +315,26 @@ public: virtual int LoadEmbTableInfos(std::string tableName, const std::vector& keys, const std::vector>& embeddings, const std::vector>& optimizerSlots) = 0; + + /* * + * When switch the channel to eval, backup the current table's offsetMapper object. + * @Param tableName: embedding table name + * @Return errorCode + */ + virtual int BackUpTrainStatus(const std::string& tableName) = 0; + + /* * + * When switch the eval channel back to train, Recover the current table's offsetMapper object to the backup state. + * @Param tableName: embedding table name + * @Return errorCode + */ + virtual int RecoverTrainStatus(const std::string& tableName) = 0; + + /* * + * Reset the offsetMapper object to revert to its initialized state after loading. 
+     * @Return errorCode
+     */
+    virtual int ResetOffsetMappers() = 0;
 };
 } // namespace EmbCache
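For orientation, here is a minimal sketch of the call order these new interfaces are designed for, mirroring the HybridMgmt wiring above. The driver function, the qualified class name, and the table name are illustrative assumptions, not part of this patch:

    // Hypothetical driver around one estimator eval pass; assumes a manager
    // instance implementing the interface above and a table "sparse_embeddings".
    int EvalOnceWithBackup(EmbCache::EmbCacheManagerImpl& cache)
    {
        // 1. Training has progressed: snapshot the key->offset state first.
        int rc = cache.BackUpTrainStatus("sparse_embeddings");
        if (rc != 0) {
            return rc;  // e.g. H_EMB_CACHE_INFO_LOST if the table info is missing
        }

        // 2. ... run the eval channel; eval lookups may reshuffle device offsets ...

        // 3. Switching back to train: restore the snapshot taken in step 1.
        return cache.RecoverTrainStatus("sparse_embeddings");
    }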