From 22279840eed7150fd443d023451006739a57accc Mon Sep 17 00:00:00 2001 From: public ostmssso Date: Tue, 30 Jan 2024 14:58:39 +0800 Subject: [PATCH 001/551] Add README.md --- README.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 00000000..02cb5989 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# mxRec + +null \ No newline at end of file -- Gitee From b2f605f5c9aafd0a88e3c59d2e973302351c5a02 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 1 Apr 2022 17:48:05 +0800 Subject: [PATCH 002/551] Match-id-a31264ee4e866f7c3c3d4547d1331f97f5bcf0d8 --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 02cb5989..086c739d 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,2 @@ # mxRec -null \ No newline at end of file -- Gitee From 53eb11c802717fc1900eb507356d2aa0ff006e4b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 15 May 2023 14:21:30 +0800 Subject: [PATCH 003/551] Match-id-a9b1e3bdda4ae93624a9266885724793763063aa --- .gitmodules | 9 + build/build.sh | 190 ++++ example/__init__.py | 0 example/little_demo/config.py | 112 ++ example/little_demo/dataset.py | 60 ++ example/little_demo/main.py | 248 +++++ example/little_demo/model.py | 93 ++ example/little_demo/op_impl_mode.ini | 3 + example/little_demo/optimizer.py | 22 + example/little_demo/random_data_generator.py | 52 + example/little_demo/run.sh | 46 + mx_rec/__init__.py | 9 + mx_rec/core/__init__.py | 0 mx_rec/core/asc/__init__.py | 1 + mx_rec/core/asc/build_graph.py | 142 +++ mx_rec/core/asc/feature_spec.py | 173 ++++ mx_rec/core/asc/helper.py | 453 ++++++++ mx_rec/core/asc/manager.py | 195 ++++ mx_rec/core/embedding.py | 826 +++++++++++++++ mx_rec/graph/__init__.py | 1 + mx_rec/graph/modifier.py | 425 ++++++++ mx_rec/graph/patch.py | 29 + mx_rec/graph/utils.py | 91 ++ mx_rec/optimizers/__init__.py | 0 mx_rec/optimizers/adagrad.py | 119 +++ mx_rec/optimizers/base.py | 38 + mx_rec/optimizers/ftrl.py | 244 +++++ mx_rec/optimizers/ftrl_t.py | 255 +++++ mx_rec/optimizers/ftrl_t_dense.py | 191 ++++ mx_rec/optimizers/gradient_descent.py | 55 + mx_rec/optimizers/gradient_descent_by_addr.py | 184 ++++ mx_rec/optimizers/lazy_adam.py | 184 ++++ mx_rec/optimizers/lazy_adam_by_addr.py | 307 ++++++ mx_rec/optimizers/momentum.py | 130 +++ mx_rec/saver/__init__.py | 0 mx_rec/saver/patch.py | 352 +++++++ mx_rec/saver/saver.py | 295 ++++++ mx_rec/util/__init__.py | 5 + mx_rec/util/atomic.py | 27 + mx_rec/util/constants.py | 91 ++ mx_rec/util/initialize.py | 617 +++++++++++ mx_rec/util/log.py | 22 + mx_rec/util/ops.py | 18 + mx_rec/util/perf.py | 15 + mx_rec/util/synchronizer.py | 76 ++ mx_rec/util/tf_version_adapter.py | 12 + mx_rec/util/variable.py | 47 + setup.py | 35 + src/CMakeLists.txt | 111 ++ src/build.sh | 29 + src/core/CMakeLists.txt | 51 + src/core/checkpoint/checkpoint.cpp | 459 ++++++++ src/core/checkpoint/checkpoint.h | 98 ++ .../ckpt_data_handler/ckpt_data_handler.cpp | 32 + .../ckpt_data_handler/ckpt_data_handler.h | 61 ++ .../emb_hash_ckpt/emb_hash_ckpt.cpp | 164 +++ .../emb_hash_ckpt/emb_hash_ckpt.h | 59 ++ .../feat_admit_n_evict_ckpt.cpp | 176 ++++ .../feat_admit_n_evict_ckpt.h | 67 ++ .../host_emb_ckpt/host_emb_ckpt.cpp | 143 +++ .../host_emb_ckpt/host_emb_ckpt.h | 59 ++ .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 89 ++ .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.h | 43 + .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 63 ++ .../nddr_offset_ckpt/nddr_offset_ckpt.h | 41 + src/core/emb_hashmap/emb_hashmap.cpp | 275 +++++ 
src/core/emb_hashmap/emb_hashmap.h | 66 ++ src/core/emb_mgmt/emb_mgmt.cpp | 574 ++++++++++ src/core/emb_mgmt/emb_mgmt.h | 129 +++ src/core/emb_table/emb_table.cpp | 241 +++++ src/core/emb_table/emb_table.h | 95 ++ src/core/hd_transfer/acl_channel.h | 34 + src/core/hd_transfer/hd_transfer.cpp | 197 ++++ src/core/hd_transfer/hd_transfer.h | 90 ++ src/core/host_emb/host_emb.cpp | 259 +++++ src/core/host_emb/host_emb.h | 64 ++ .../constant_initializer.cpp | 28 + .../constant_initializer.h | 32 + src/core/initializer/initializer.cpp | 6 + src/core/initializer/initializer.h | 26 + .../random_normal_initializer.cpp | 33 + .../random_normal_initializer.h | 39 + .../truncated_normal_initializer.cpp | 42 + .../truncated_normal_initializer.h | 43 + .../key_process/feature_admit_and_evict.cpp | 317 ++++++ .../key_process/feature_admit_and_evict.h | 102 ++ src/core/key_process/key_process.cpp | 980 ++++++++++++++++++ src/core/key_process/key_process.h | 167 +++ src/core/utils/common.cpp | 107 ++ src/core/utils/common.h | 433 ++++++++ src/core/utils/safe_queue.h | 152 +++ src/core/utils/singleton.h | 53 + src/core/utils/spinlock.h | 175 ++++ src/core/utils/task_queue.h | 103 ++ src/core/utils/time_cost.h | 39 + src/core/utils/unique.h | 774 ++++++++++++++ src/ops_tf/CMakeLists.txt | 7 + src/ops_tf/hybrid_dataset_ops.cpp | 766 ++++++++++++++ src/ops_tf/tf_ops.h | 12 + src/pybind/CMakeLists.txt | 6 + src/pybind/module_main.cpp | 153 +++ src/pybind/module_main.h | 11 + src/test_ut.sh | 71 ++ src/tests/CMakeLists.txt | 62 ++ src/tests/checkpoint/checkpoint_test.cpp | 449 ++++++++ .../ckpt_data_handler_test.cpp | 225 ++++ src/tests/emb_mgmt/emb_mgmt_test.cpp | 226 ++++ src/tests/emb_table/emb_table_test.cpp | 129 +++ src/tests/gtest_main.cpp | 22 + src/tests/host_emb/host_emb_test.cpp | 53 + src/tests/initializer/initializer_test.cpp | 99 ++ .../feature_admit_and_evict_test.cpp | 477 +++++++++ src/tests/key_process/key_process_test.cpp | 373 +++++++ tools/python/key_2_emb_formatter.py | 202 ++++ 114 files changed, 17262 insertions(+) create mode 100644 .gitmodules create mode 100644 build/build.sh create mode 100644 example/__init__.py create mode 100644 example/little_demo/config.py create mode 100644 example/little_demo/dataset.py create mode 100644 example/little_demo/main.py create mode 100644 example/little_demo/model.py create mode 100644 example/little_demo/op_impl_mode.ini create mode 100644 example/little_demo/optimizer.py create mode 100644 example/little_demo/random_data_generator.py create mode 100644 example/little_demo/run.sh create mode 100644 mx_rec/__init__.py create mode 100644 mx_rec/core/__init__.py create mode 100644 mx_rec/core/asc/__init__.py create mode 100644 mx_rec/core/asc/build_graph.py create mode 100644 mx_rec/core/asc/feature_spec.py create mode 100644 mx_rec/core/asc/helper.py create mode 100644 mx_rec/core/asc/manager.py create mode 100644 mx_rec/core/embedding.py create mode 100644 mx_rec/graph/__init__.py create mode 100644 mx_rec/graph/modifier.py create mode 100644 mx_rec/graph/patch.py create mode 100644 mx_rec/graph/utils.py create mode 100644 mx_rec/optimizers/__init__.py create mode 100644 mx_rec/optimizers/adagrad.py create mode 100644 mx_rec/optimizers/base.py create mode 100644 mx_rec/optimizers/ftrl.py create mode 100644 mx_rec/optimizers/ftrl_t.py create mode 100644 mx_rec/optimizers/ftrl_t_dense.py create mode 100644 mx_rec/optimizers/gradient_descent.py create mode 100644 mx_rec/optimizers/gradient_descent_by_addr.py create mode 100644 
mx_rec/optimizers/lazy_adam.py create mode 100644 mx_rec/optimizers/lazy_adam_by_addr.py create mode 100644 mx_rec/optimizers/momentum.py create mode 100644 mx_rec/saver/__init__.py create mode 100644 mx_rec/saver/patch.py create mode 100644 mx_rec/saver/saver.py create mode 100644 mx_rec/util/__init__.py create mode 100644 mx_rec/util/atomic.py create mode 100644 mx_rec/util/constants.py create mode 100644 mx_rec/util/initialize.py create mode 100644 mx_rec/util/log.py create mode 100644 mx_rec/util/ops.py create mode 100644 mx_rec/util/perf.py create mode 100644 mx_rec/util/synchronizer.py create mode 100644 mx_rec/util/tf_version_adapter.py create mode 100644 mx_rec/util/variable.py create mode 100644 setup.py create mode 100644 src/CMakeLists.txt create mode 100644 src/build.sh create mode 100644 src/core/CMakeLists.txt create mode 100644 src/core/checkpoint/checkpoint.cpp create mode 100644 src/core/checkpoint/checkpoint.h create mode 100644 src/core/ckpt_data_handler/ckpt_data_handler.cpp create mode 100644 src/core/ckpt_data_handler/ckpt_data_handler.h create mode 100644 src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h create mode 100644 src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h create mode 100644 src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h create mode 100644 src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h create mode 100644 src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h create mode 100644 src/core/emb_hashmap/emb_hashmap.cpp create mode 100644 src/core/emb_hashmap/emb_hashmap.h create mode 100644 src/core/emb_mgmt/emb_mgmt.cpp create mode 100644 src/core/emb_mgmt/emb_mgmt.h create mode 100644 src/core/emb_table/emb_table.cpp create mode 100644 src/core/emb_table/emb_table.h create mode 100644 src/core/hd_transfer/acl_channel.h create mode 100644 src/core/hd_transfer/hd_transfer.cpp create mode 100644 src/core/hd_transfer/hd_transfer.h create mode 100644 src/core/host_emb/host_emb.cpp create mode 100644 src/core/host_emb/host_emb.h create mode 100644 src/core/initializer/constant_initializer/constant_initializer.cpp create mode 100644 src/core/initializer/constant_initializer/constant_initializer.h create mode 100644 src/core/initializer/initializer.cpp create mode 100644 src/core/initializer/initializer.h create mode 100644 src/core/initializer/random_normal_initializer/random_normal_initializer.cpp create mode 100644 src/core/initializer/random_normal_initializer/random_normal_initializer.h create mode 100644 src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp create mode 100644 src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h create mode 100644 src/core/key_process/feature_admit_and_evict.cpp create mode 100644 src/core/key_process/feature_admit_and_evict.h create mode 100644 src/core/key_process/key_process.cpp create mode 100644 src/core/key_process/key_process.h create mode 100644 src/core/utils/common.cpp create mode 100644 src/core/utils/common.h create mode 100644 src/core/utils/safe_queue.h create mode 
100644 src/core/utils/singleton.h create mode 100644 src/core/utils/spinlock.h create mode 100644 src/core/utils/task_queue.h create mode 100644 src/core/utils/time_cost.h create mode 100644 src/core/utils/unique.h create mode 100644 src/ops_tf/CMakeLists.txt create mode 100644 src/ops_tf/hybrid_dataset_ops.cpp create mode 100644 src/ops_tf/tf_ops.h create mode 100644 src/pybind/CMakeLists.txt create mode 100644 src/pybind/module_main.cpp create mode 100644 src/pybind/module_main.h create mode 100644 src/test_ut.sh create mode 100644 src/tests/CMakeLists.txt create mode 100644 src/tests/checkpoint/checkpoint_test.cpp create mode 100644 src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp create mode 100644 src/tests/emb_mgmt/emb_mgmt_test.cpp create mode 100644 src/tests/emb_table/emb_table_test.cpp create mode 100644 src/tests/gtest_main.cpp create mode 100644 src/tests/host_emb/host_emb_test.cpp create mode 100644 src/tests/initializer/initializer_test.cpp create mode 100644 src/tests/key_process/feature_admit_and_evict_test.cpp create mode 100644 src/tests/key_process/key_process_test.cpp create mode 100644 tools/python/key_2_emb_formatter.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..4b398c49 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,9 @@ +[submodule "src/thirdparty/googletest"] + path = src/thirdparty/googletest + url = https://codehub-dg-y.huawei.com/OpenSourceCenter/googletest.git +[submodule "src/thirdparty/spdlog"] + path = src/thirdparty/spdlog + url = https://codehub-dg-y.huawei.com/OpenSourceCenter/spdlog.git +[submodule "src/thirdparty/pybind11"] + path = src/thirdparty/pybind11 + url = https://codehub-dg-y.huawei.com/OpenSourceCenter/pybind11.git diff --git a/build/build.sh b/build/build.sh new file mode 100644 index 00000000..4f3de7ec --- /dev/null +++ b/build/build.sh @@ -0,0 +1,190 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. +# Description: build script. +# Author: MindX SDK +# Create: 2022 +# History: NA + +set -e +warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } +ARCH="$(uname -m)" +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +cd "$SCRIPT_DIR" +if [ "$(uname -m)" = "aarch64" ] +then + pip3 install virtualenv --force-reinstall + virtualenv -p "$(which python3.7)" tf2_env + source tf2_env/bin/activate + [ ! -f tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl + pip3 install tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl --no-deps + pip3 install setuptools==49.2.1 + tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env + virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate + [ ! 
-f tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl + pip3 install tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl --no-deps + pip3 install setuptools==49.2.1 + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + +if [ "$(uname -m)" = "x86_64" ] +then + pip3 install virtualenv --force-reinstall + virtualenv -p "$(which python3.7)" tf2_env + source tf2_env/bin/activate + [ ! -f tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl ] && wget --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/Tensorflow/2.6.5/package/tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl + pip3 install tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl --no-deps + pip3 install setuptools==49.2.1 + tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env + virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate + [ ! -f tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl + pip3 install tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl --no-deps + pip3 install setuptools==49.2.1 + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + +VERSION_FILE=${SCRIPT_DIR}/../../mindxsdk/build/conf/config.yaml +get_version() { + if [ -f "$VERSION_FILE" ]; then + VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") + if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then + VERSION=${VERSION%.*} + fi + else + VERSION="5.0.T104" + fi +} + +project_root_folder=${SCRIPT_DIR}/.. +project_output_path=${project_root_folder}/output/ +rm -rf "${project_output_path}" +rm -rf "${SCRIPT_DIR}/lib" +get_version +export VERSION +echo "MindX SDK mxrec: ${VERSION}" >> ./version.info + +pkg_dir=mindxsdk-mxrec +[ -d ${pkg_dir} ] && rm -rf ${pkg_dir} +mkdir ${pkg_dir} +mv version.info ${pkg_dir} + +opensource_path=${project_root_folder}/../opensource/opensource +abseil_src_path=${opensource_path}/abseil +echo "${abseil_src_path}" +abseil_install_path=${project_root_folder}/install/abseil + +src_path=${project_root_folder}/src + +cd "${project_root_folder}" + +release_tar=Ascend-${pkg_dir}-${VERSION}-linux-${ARCH}.tar.gz + +install_abseil() +{ + rm -rf "${abseil_install_path}" + echo "${abseil_install_path}" + if [[ ! -d ${abseil_install_path} ]] + then mkdir -p "${abseil_install_path}" + fi + + cd "${abseil_src_path}" + echo "${abseil_src_path}" + cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install + + echo "${project_output_path}"/abseil + mkdir -p "${project_output_path}"/abseil + if [ -d "${abseil_install_path}"/lib64/ ]; then + cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil + elif [ -d "${abseil_install_path}"/lib/ ]; then + cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil + else + echo "${abseil_install_path}"/lib64/ not exist + exit 1 + fi +} + +compile_securec() +{ + if [[ ! 
-d ${project_root_folder}/platform/securec ]]; then + echo "securec is not exist" + exit 1 + fi + + if [[ ! -f ${project_root_folder}/platform/securec/lib/libsecurec.so ]]; then + cd ${project_root_folder}/platform/securec/src + make -j + fi +} + +compile_so_file() +{ + cd "${src_path}" + chmod u+x build.sh + ./build.sh "$1" "${project_root_folder}" + cd .. +} + +collect_so_file() +{ + cd "${src_path}" + rm -rf "${src_path}"/libasc + mkdir -p "${src_path}"/libasc + chmod u+x libasc + + cp -df "${project_root_folder}"/output/*.so* libasc + cp ${project_root_folder}/platform/securec/lib/libsecurec.so libasc +} + +gen_wheel_file() +{ + cd "${project_root_folder}" + touch "${src_path}"/libasc/__init__.py + [ -d "${project_root_folder}"/mx_rec/libasc ] && rm -rf "${project_root_folder}"/mx_rec/libasc + mv "${src_path}"/libasc "${project_root_folder}"/mx_rec + python3 setup.py bdist_wheel + mkdir -p "$1" + mv dist/mx_rec*.whl "$1" + rm -rf "${project_root_folder}"/mx_rec/libasc +} + +gen_tar_file() +{ + cd "${src_path}" + mv "${project_root_folder}"/tf1_whl ../build/${pkg_dir} + mv "${project_root_folder}"/tf2_whl ../build/${pkg_dir} + cp -r "${src_path}"/../example ../build/${pkg_dir} + cd ../build + tar -zvcf "${release_tar}" "${pkg_dir}" || { + warn "compression failed, packages might be broken" + } + + mv "${release_tar}" "${SCRIPT_DIR}"/../output/ + +} + +install_abseil +compile_securec + +echo "-----Build Start tf1 -----" +source "${SCRIPT_DIR}"/tf1_env/bin/activate +compile_so_file "${tf1_path}" +collect_so_file +gen_wheel_file "${project_root_folder}"/tf1_whl +deactivate tf1_env + +echo "-----Build Start tf2 -----" +source "${SCRIPT_DIR}"/tf2_env/bin/activate +compile_so_file "${tf2_path}" +collect_so_file +gen_wheel_file "${project_root_folder}"/tf2_whl +deactivate tf2_env + +echo "-----Build gen tar -----" +gen_tar_file + +echo "-----Done-----" diff --git a/example/__init__.py b/example/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example/little_demo/config.py b/example/little_demo/config.py new file mode 100644 index 00000000..122b168f --- /dev/null +++ b/example/little_demo/config.py @@ -0,0 +1,112 @@ +# coding: UTF-8 +import logging +import math +import tensorflow as tf + +from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig + +from mx_rec.util.initialize import get_rank_size + + +class Config: + def __init__(self, mode="simple", task_name="default"): + self.task_name = task_name + if mode == "simple": + self.generate_simple_config() + else: + self.generate_large_scale_config() + + def generate_simple_config(self): + self.batch_number = 8192 + self.batch_size = 4096 + + self.key_type = tf.int64 + self.label_type = tf.float32 + self.value_type = tf.float32 + + self.item_range = 80000 + self.user_range = 200000 + self.category_range = 5000 + self.item_feat_cnt = 16 + self.user_feat_cnt = 8 + self.category_feat_cnt = 3 + self.access_threshold = 100 + self.eviction_threshold = 60 + + rank_size = get_rank_size() + coefficient = 1.1 + if rank_size != 0: + max_ui_send_cnt = max(self.item_feat_cnt, self.user_feat_cnt) + max_ui_range = max(self.item_range, self.user_range) + self.item_send_cnt = min(int(self.batch_size * self.item_feat_cnt * coefficient), + math.ceil(self.item_range / rank_size)) + self.item_vocab_size = max(self.item_send_cnt * rank_size * rank_size, self.item_range) + self.user_send_cnt = min(int(self.batch_size * max_ui_send_cnt * coefficient), + math.ceil(max_ui_range / rank_size)) + self.user_vocab_size = 
max(self.user_send_cnt * rank_size * rank_size, self.user_range)
+            self.category_send_cnt = min(int(self.batch_size * self.category_feat_cnt * coefficient),
+                                         math.ceil(self.category_range / rank_size))
+        else:
+            raise ZeroDivisionError("rank size must be an integer greater than zero.")
+
+        self.user_hashtable_dim = 32
+        self.user_hashtable_threshold = 1
+        self.item_hashtable_dim = 8
+        self.item_hashtable_threshold = 1
+
+        self.learning_rate = 0.01
+
+    def generate_large_scale_config(self):
+        self.lookup_count = 40
+        self.tensor_name_list = ["sparse_tensor_%d" % i for i in range(self.lookup_count)]
+        self.hashtable_name_list = ["hashtable_%d" % i for i in range(self.lookup_count)]
+        self.batch_size = 9600
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.float32
+
+        self.vocabulary_size = 500000
+        self.feat_cnt = 1
+
+        rank_size = get_rank_size()
+        coefficient = 1.1
+        if rank_size != 0:
+            self.send_cnt = min(int(self.batch_size * self.feat_cnt * coefficient),
+                                math.ceil(self.vocabulary_size / rank_size))
+        else:
+            raise ZeroDivisionError("rank size must be an integer greater than zero.")
+
+        self.hashtable_dim = 8
+        self.learning_rate = 0.01
+
+
+def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
+    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False,
+                                              log_device_placement=False)
+
+    session_config.gpu_options.allow_growth = True
+    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
+    custom_op.name = "NpuOptimizer"
+    custom_op.parameter_map["mix_compile_mode"].b = False
+    custom_op.parameter_map["use_off_line"].b = True
+    custom_op.parameter_map["min_group_size"].b = 1
+    custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise")
+    custom_op.parameter_map["enable_data_pre_proc"].b = True
+    custom_op.parameter_map["iterations_per_loop"].i = 1
+    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
+    custom_op.parameter_map["hcom_parallel"].b = False
+    custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini")
+
+    if dump_data:
+        """
+        To see the details, please refer to the descriptions at the official web site.
+        """
+        custom_op.parameter_map["enable_dump"].b = True
+        custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path)
+        custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps)
+        custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")
+
+    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+
+    return session_config
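Note that sess_config only builds the ConfigProto; the session itself is opened in main.py further below. A minimal sketch of how the two pieces fit together (the driver itself is illustrative and not part of the patch):

# Hypothetical driver, mirroring how main.py consumes Config and sess_config.
import tensorflow as tf

from config import Config, sess_config

cfg = Config(mode="simple")        # batch sizes, vocab sizes, thresholds
session_config = sess_config()     # ConfigProto carrying the NpuOptimizer custom op

with tf.compat.v1.Session(config=session_config) as sess:
    # graph construction and sess.run(...) calls go here
    pass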
diff --git a/example/little_demo/dataset.py b/example/little_demo/dataset.py
new file mode 100644
index 00000000..7a9cf82c
--- /dev/null
+++ b/example/little_demo/dataset.py
@@ -0,0 +1,60 @@
+# coding: UTF-8
+import tensorflow as tf
+
+from random_data_generator import get_data_generator, get_large_scale_data_generator
+from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops
+
+
+def generate_dataset(cfg, use_timestamp=False, batch_number=100):
+    dataset = tf.compat.v1.data.Dataset.from_generator(
+        generator=get_data_generator(cfg, batch_number=batch_number),
+        output_types={"item_ids": cfg.key_type,
+                      "user_ids": cfg.key_type,
+                      "category_ids": cfg.key_type,
+                      "label_0": cfg.label_type,
+                      "label_1": cfg.label_type},
+        output_shapes={"item_ids": tf.TensorShape([cfg.batch_size, cfg.item_feat_cnt]),
+                       "user_ids": tf.TensorShape([cfg.batch_size, cfg.user_feat_cnt]),
+                       "category_ids": tf.TensorShape([cfg.batch_size, cfg.category_feat_cnt]),
+                       "label_0": tf.TensorShape([cfg.batch_size]),
+                       "label_1": tf.TensorShape([cfg.batch_size])})
+    if use_timestamp:
+        dataset = dataset.map(add_timestamp_func)
+
+    rank_size = get_rank_size()
+    rank_id = get_rank_id()
+    if rank_size > 1:
+        dataset = dataset.shard(rank_size, rank_id)
+
+    return dataset
+
+
+def add_timestamp_func(batch):
+    host_pipeline_ops = get_host_pipeline_ops()
+    timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64))
+    batch["timestamp"] = timestamp
+    return batch
+
+
+def generate_large_scale_data(cfg):
+    key_type_list = [cfg.key_type for _ in range(cfg.lookup_count)]
+    output_type_dict = dict(zip(cfg.tensor_name_list, key_type_list))
+    output_type_dict["label_0"] = cfg.label_type
+    output_type_dict["label_1"] = cfg.label_type
+
+    tensor_shape_list = [tf.TensorShape([cfg.batch_size]) for _ in range(cfg.lookup_count)]
+    output_shape_dict = dict(zip(cfg.tensor_name_list, tensor_shape_list))
+    output_shape_dict["label_0"] = tf.TensorShape([cfg.batch_size])
+    output_shape_dict["label_1"] = tf.TensorShape([cfg.batch_size])
+
+    dataset = tf.data.Dataset.from_generator(generator=get_large_scale_data_generator(cfg),
+                                             output_types=output_type_dict,
+                                             output_shapes=output_shape_dict)
+    rank_size = get_rank_size()
+    rank_id = get_rank_id()
+    if rank_size > 1:
+        dataset = dataset.shard(rank_size, rank_id)
+
+    iterator = dataset.make_initializable_iterator()
+    batch = iterator.get_next()
+    return batch, iterator
diff --git a/example/little_demo/main.py b/example/little_demo/main.py
new file mode 100644
index 00000000..671515d4
--- /dev/null
+++ b/example/little_demo/main.py
@@ -0,0 +1,248 @@
+# coding: UTF-8
+import logging
+import os
+import warnings
+from glob import glob
+import tensorflow as tf
+
+from config import sess_config, Config
+from dataset import generate_dataset
+from optimizer import get_dense_and_sparse_optimizer
+from model import MyModel
+from mx_rec.util.tf_version_adapter import hccl_ops
+from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func
+from mx_rec.core.asc.manager import start_asc_pipeline
+from mx_rec.core.embedding import create_table, sparse_lookup
+from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
+from mx_rec.util.constants import MxRecMode, ASCEND_TIMESTAMP
+from mx_rec.util.initialize import get_rank_id, get_rank_size, init, clear_channel, terminate_config_initializer, \
+    set_if_load, get_initializer
+from mx_rec.util.variable import get_dense_and_sparse_variable
+
+tf.compat.v1.disable_eager_execution()
+
+
+def make_batch_and_iterator(is_training, use_timestamp=False, dump_graph=False, batch_number=100):
+    dataset = generate_dataset(cfg, use_timestamp=use_timestamp, batch_number=batch_number)
+    if not MODIFY_GRAPH_FLAG:
+        insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph)
+        dataset = dataset.map(insert_fn)
+    dataset = dataset.prefetch(100)
+    iterator = dataset.make_initializable_iterator()
+    batch = iterator.get_next()
+    return batch, iterator
+
+
+def model_forward(input_list, batch, is_train, modify_graph, config_dict=None):
+    embedding_list = []
+    feature_list, hash_table_list, send_count_list = input_list
+    for feature, hash_table, send_count in zip(feature_list, hash_table_list, send_count_list):
+        access_and_evict_config = None
+        if isinstance(config_dict, dict):
+            access_and_evict_config =
config_dict.get(hash_table.table_name) + embedding = sparse_lookup(hash_table, feature, send_count, dim=None, is_train=is_train, + access_and_evict_config=access_and_evict_config, + name=hash_table.table_name + "_lookup", modify_graph=modify_graph) + + reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False) + embedding_list.append(reduced_embedding) + + my_model = MyModel() + my_model(embedding_list, batch["label_0"], batch["label_1"]) + return my_model + + +def build_graph(hash_table_list, is_train, use_timestamp=False, config_dict=None, batch_number=100): + batch, iterator = make_batch_and_iterator(is_training=is_train, use_timestamp=use_timestamp, + dump_graph=is_train, batch_number=batch_number) + if MODIFY_GRAPH_FLAG: + input_list = [ + [batch["user_ids"], batch["item_ids"]], + [hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt], + ] + if USE_TIMESTAMP: + tf.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) + model = model_forward(input_list, batch, + is_train=is_train, modify_graph=True, config_dict=config_dict) + else: + input_list = [ + [feature_spec for feature_spec in feature_spec_list], + [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt, cfg.item_send_cnt], + ] + model = model_forward(input_list, batch, + is_train=is_train, modify_graph=False, config_dict=config_dict) + + return iterator, model + + +def evaluate(): + if MODIFY_GRAPH_FLAG: + sess.run(get_initializer(False)) + else: + sess.run(eval_iterator.initializer) + clear_channel(is_train_channel=False) + for j in range(1, EVAL_STEPS + 1): + logging.info(f"################ eval at step {j} epoch {EPOCH} ################") + try: + sess.run(eval_model.loss_list) + except tf.errors.OutOfRangeError: + logging.info(f"Encounter the end of Sequence for eval.") + break + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + warnings.filterwarnings("ignore") + + mode = MxRecMode.mapping(os.getenv("MXREC_MODE")) + TRAIN_INTERVAL = 100 + EVAL_STEPS = 10 + SAVING_INTERVAL = 100 + USE_TIMESTAMP = False + + # add dynamic expansion support + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + + # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0 + init(use_mpi = bool(int(os.getenv("USE_MPI"))), + train_interval=TRAIN_INTERVAL, + eval_steps=EVAL_STEPS, + prefetch_batch_number=5, + use_dynamic=int(os.getenv("LITTLE_DEMO_USE_DYNAMIC", 0)), + use_hot=bool(int(os.getenv("USE_HOT", 0))), + use_dynamic_expansion=use_dynamic_expansion) + IF_LOAD = False + rank_id = get_rank_id() + filelist = glob(f"./saved-model/sparse-model-{rank_id}-0") + if filelist: + IF_LOAD = True + set_if_load(IF_LOAD) + + MODIFY_GRAPH_FLAG = False # ASC + use_MPI + modify_graph + + cfg = Config() + # access_threshold unit counts; eviction_threshold unit seconds + ACCESS_AND_EVICT = None + if USE_TIMESTAMP: + feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", + access_threshold=cfg.access_threshold, + eviction_threshold=cfg.eviction_threshold), + FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", + access_threshold=cfg.access_threshold, + eviction_threshold=cfg.eviction_threshold), + FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", + access_threshold=cfg.access_threshold, + eviction_threshold=cfg.eviction_threshold), + 
FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", + access_threshold=cfg.access_threshold, + eviction_threshold=cfg.eviction_threshold), + FeatureSpec("timestamp", is_timestamp=True)] + + config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table) + + else: + feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table"), + FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table"), + FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table"), + FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table")] + + optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(2)] + sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] + + user_hashtable = create_table(key_dtype=tf.int64, + dim=tf.TensorShape([cfg.user_hashtable_dim]), + name='user_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=cfg.user_vocab_size * 10, + host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + optimizer_list=sparse_optimizer_list, + mode=mode) + + item_hashtable = create_table(key_dtype=tf.int64, + dim=tf.TensorShape([cfg.item_hashtable_dim]), + name='item_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=cfg.item_vocab_size * 10, + host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + optimizer_list=sparse_optimizer_list, + mode=mode) + + train_iterator, train_model = build_graph([user_hashtable, item_hashtable], is_train=True, + use_timestamp=USE_TIMESTAMP, config_dict=ACCESS_AND_EVICT, + batch_number=cfg.batch_number) + eval_iterator, eval_model = build_graph([user_hashtable, item_hashtable], is_train=False, + use_timestamp=USE_TIMESTAMP, config_dict=ACCESS_AND_EVICT, + batch_number=cfg.batch_number) + dense_variables, sparse_variables = get_dense_and_sparse_variable() + + rank_size = get_rank_size() + train_ops = [] + # multi task training + for loss, (dense_optimizer, sparse_optimizer) in zip(train_model.loss_list, optimizer_list): + # do dense optimization + grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables) + avg_grads = [] + for grad, var in grads: + if rank_size > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad, var)) + # apply gradients: update variables + train_ops.append(dense_optimizer.apply_gradients(avg_grads)) + + if use_dynamic_expansion: + from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET + train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + # do sparse optimization by addr + local_grads = tf.gradients(loss, train_emb_list) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)] + train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + else: + # do sparse optimization + sparse_grads = tf.gradients(loss, sparse_variables) + grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] + 
train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + + saver = tf.compat.v1.train.Saver() + if MODIFY_GRAPH_FLAG: + logging.info("start to modifying graph") + modify_graph_and_start_emb_cache(dump_graph=True) + else: + start_asc_pipeline() + + with tf.compat.v1.Session(config=sess_config(dump_data=False)) as sess: + if MODIFY_GRAPH_FLAG: + sess.run(get_initializer(True)) + else: + sess.run(train_iterator.initializer) + sess.run(tf.compat.v1.global_variables_initializer()) + EPOCH = 0 + if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): + saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) + else: + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) + + for i in range(1, 201): + logging.info(f"################ training at step {i} ################") + try: + sess.run([train_ops, train_model.loss_list]) + except tf.errors.OutOfRangeError: + logging.info(f"Encounter the end of Sequence for training.") + break + else: + if i % TRAIN_INTERVAL == 0: + EPOCH += 1 + evaluate() + + if i % SAVING_INTERVAL == 0: + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + + terminate_config_initializer() + logging.info("Demo done!") diff --git a/example/little_demo/model.py b/example/little_demo/model.py new file mode 100644 index 00000000..18ab98ca --- /dev/null +++ b/example/little_demo/model.py @@ -0,0 +1,93 @@ +# coding: UTF-8 +from __future__ import print_function + +import tensorflow as tf + + +class MyModel: + def __init__(self): + self.layer_dims = [1024, 512, 256, 128] + self.act_func = 'relu' + self.keep_prob = 0.8 + self._lambda = 4.91e-7 + self.emb_dim = None + self.loss_list = [] + self.predict_list = [] + self.all_layer_dims = None + self.h_w, self.h_b = [], [] + self.h_w_head_0, self.h_w_head_1, self.h_b_head_0, self.h_b_head_1 = None, None, None, None + + def __call__(self, embedding_list, label_0, label_1, is_training=True): + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + embedding = tf.concat(embedding_list, axis=1) + self.emb_dim = embedding.shape.as_list()[-1] + self.all_layer_dims = [self.emb_dim] + self.layer_dims + [1] + + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + for i in range(len(self.all_layer_dims) - 2): + self.h_w.append(tf.compat.v1.get_variable('h%d_w' % (i + 1), shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"])) + self.h_b.append( + tf.compat.v1.get_variable('h%d_b' % (i + 1), shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"])) + i += 1 + self.h_w_head_0 = tf.compat.v1.get_variable('h_w_head_0', shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) + self.h_b_head_0 = tf.compat.v1.get_variable('h_b_head_0', shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) + self.h_w_head_1 = tf.compat.v1.get_variable('h_w_head_1', shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + 
collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) + self.h_b_head_1 = tf.compat.v1.get_variable('h_b_head_1', shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) + + logit_list = self.forward(embedding, self.act_func, self.keep_prob, training=is_training) + + for logit, label in zip(logit_list, (label_0, label_1)): + train_preds = tf.sigmoid(logit) + + basic_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label) + + deep_loss = tf.reduce_mean(basic_loss) # + _lambda * tf.nn.l2_loss(embedding) + self.predict_list.append(train_preds) + self.loss_list.append(deep_loss) + + + def forward(self, embedding, act_func, keep_prob, training): + hidden_output = tf.reshape(embedding, [-1, self.emb_dim]) # *512 + for i, h_w_var in enumerate(self.h_w): + hidden_output = tf.matmul(self.activate(act_func, hidden_output), h_w_var) + hidden_output = hidden_output + self.h_b[i] + + def output_head(hidden_output, h_w, h_b): + hidden_output_branch = tf.matmul(self.activate(act_func, hidden_output), h_w) + logit = hidden_output_branch + h_b + logit = tf.reshape(logit, [-1, ]) + + return logit + + logit_0 = output_head(hidden_output, self.h_w_head_0, self.h_b_head_0) + logit_1 = output_head(hidden_output, self.h_w_head_1, self.h_b_head_1) + logit_list = [logit_0, logit_1] + + return logit_list + + @staticmethod + def activate(act_func, input_x): + if act_func == 'tanh': + return tf.tanh(input_x) + elif act_func == 'relu': + return tf.nn.relu(input_x) + else: + return tf.sigmoid(input_x) diff --git a/example/little_demo/op_impl_mode.ini b/example/little_demo/op_impl_mode.ini new file mode 100644 index 00000000..4a744500 --- /dev/null +++ b/example/little_demo/op_impl_mode.ini @@ -0,0 +1,3 @@ +ScatterNdAdd=support_out_of_bound_index +GatherV2=high_performance +UnsortedSegmentSum=high_performance \ No newline at end of file diff --git a/example/little_demo/optimizer.py b/example/little_demo/optimizer.py new file mode 100644 index 00000000..7a2e2a4b --- /dev/null +++ b/example/little_demo/optimizer.py @@ -0,0 +1,22 @@ +# coding: UTF-8 + +import logging + +import tensorflow as tf + +from mx_rec.optimizers.lazy_adam import create_hash_optimizer +from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address +from mx_rec.util.initialize import get_use_dynamic_expansion + + +def get_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate) + use_dynamic_expansion = get_use_dynamic_expansion() + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate) + logging.info("optimizer lazy_adam_by_addr") + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate) + logging.info("optimizer lazy_adam") + + return dense_optimizer, sparse_optimizer diff --git a/example/little_demo/random_data_generator.py b/example/little_demo/random_data_generator.py new file mode 100644 index 00000000..473ac690 --- /dev/null +++ b/example/little_demo/random_data_generator.py @@ -0,0 +1,52 @@ +# coding: UTF-8 +import logging + +import numpy as np + +from mx_rec.util.initialize import get_rank_id + + +def get_data_generator(config, batch_number): + rank_id = get_rank_id() + + def data_generator(): + i = 0 + while i < batch_number: + item_ids = np.random.randint(0, config.item_range, (config.batch_size, config.item_feat_cnt)) + user_ids 
= np.random.randint(0, config.user_range, (config.batch_size, config.user_feat_cnt))
+            category_ids = np.random.randint(0, config.category_range, (config.batch_size, config.category_feat_cnt))
+            label_0 = np.random.randint(0, 2, (config.batch_size,))
+            label_1 = np.random.randint(0, 2, (config.batch_size,))
+
+            yield {"item_ids": item_ids,
+                   "user_ids": user_ids,
+                   "category_ids": category_ids,
+                   "label_0": label_0,
+                   "label_1": label_1}
+            i += 1
+
+        logging.debug(f"================ end of data generator for {config.task_name} task | rank id {rank_id} "
+                      f"================")
+
+    return data_generator
+
+
+def get_large_scale_data_generator(config):
+    def data_generator():
+        i = 0
+        while True:
+            id_list = [np.random.randint(0, config.vocabulary_size, (config.batch_size,))
+                       for _ in range(config.lookup_count)]
+
+            data_block = dict(zip(config.tensor_name_list, id_list))
+
+            label_0 = np.random.randint(0, 2, (config.batch_size,))
+            label_1 = np.random.randint(0, 2, (config.batch_size,))
+            data_block["label_0"] = label_0
+            data_block["label_1"] = label_1
+
+            logging.debug(f"================ generate NO.{i} step ================")
+            yield data_block
+            i += 1
+
+    return data_generator
diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
new file mode 100644
index 00000000..abbe8cae
--- /dev/null
+++ b/example/little_demo/run.sh
@@ -0,0 +1,46 @@
+kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
+rm -rf /root/ascend/log/*
+rm -rf ./saved-model/*
+rm -rf ./kernel*
+rm -rf ./export_graph/*
+
+cur_path=`pwd`
+mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec"  # please configure for your environment
+so_path=${mx_rec_package_path}/libasc
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x SPDLOG_LEVEL=debug -bind-to none'
+interface="lo"
+local_rank_size=8  # number of NPU cards used per node
+num_server=1       # number of training nodes
+num_process=$((${num_server} * ${local_rank_size}))  # total number of training processes, i.e. total NPU cards used
+
+export HCCL_CONNECT_TIMEOUT=1200  # HCCL collective-communication link-setup timeout, range [120, 7200]
+export PYTHONPATH=${so_path}:$PYTHONPATH  # Python installation path of the environment
+export LD_PRELOAD=/usr/lib64/libgomp.so.1  # path of the GNU OpenMP shared library
+export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
+# Collective-communication (rank table) file; for the format, see the "Preparing the resource
+# configuration file" section of the CANN documentation on the Ascend official site.
+export RANK_TABLE_FILE="${cur_path}/hccl_json_8p.json"  # comment this line out when using the ranktable-free scheme
+export JOB_ID=10086
+# total number of NPU cards used by the training job
+export RANK_SIZE=$num_process  # comment this line out when using the ranktable-free scheme
+export MXREC_LOG_LEVEL="DEBUG"  # framework log level
+export TF_CPP_MIN_LOG_LEVEL=3  # TensorFlow log level; 3 corresponds to FATAL
+# Global and per-module log levels for application logs; see the "Setting the log level" section
+# of the CANN documentation. 0: debug, 1: info, 2: warning, 3: error, 4: null
+export ASCEND_GLOBAL_LOG_LEVEL=3
+export MXREC_MODE="ASC"
+export USE_MPI=1
+export USE_DYNAMIC=0  # 0: static; 1: dynamic
+
+################# enable when using the ranktable-free scheme ######################
+#export CM_CHIEF_IP="192.168.1.1"   # IP address of the chief node
+#export CM_CHIEF_PORT=6000          # listening port of the chief node
+#export CM_CHIEF_DEVICE=0           # device id of the chief node
+#export CM_WORKER_IP="192.168.1.1"  # IP address of the current node
+#export CM_WORKER_SIZE=$num_process # number of devices taking part in the cluster training
+####################################################################################
+
+py=$1
+echo "py is $py"
+echo "use horovod to start tasks"
+horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
+python3.7 ${py} 2>&1 | tee temp.log
+
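run.sh only sets environment switches and launches main.py under horovodrun; the Python side interprets them. As a reading aid, a minimal sketch of how main.py above consumes these variables (the defaults shown are assumptions; note that run.sh exports USE_DYNAMIC while main.py reads LITTLE_DEMO_USE_DYNAMIC):

# Sketch only: mirrors the os.getenv(...) calls in example/little_demo/main.py.
import os

mode = os.getenv("MXREC_MODE")                           # e.g. "ASC", set in run.sh
use_mpi = bool(int(os.getenv("USE_MPI", "0")))           # run.sh exports USE_MPI=1
use_dynamic = int(os.getenv("LITTLE_DEMO_USE_DYNAMIC", 0))  # static (0) vs dynamic (1) shapes
use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))  # HBM-only auto expansion
print(mode, use_mpi, use_dynamic, use_dynamic_expansion)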
diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py
new file mode 100644
index 00000000..3ffb468a
--- /dev/null
+++ b/mx_rec/__init__.py
@@ -0,0 +1,9 @@
+# coding: UTF-8
+from .util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION
+from .util.tf_version_adapter import npu_ops, hccl_ops
+from .saver.patch import patch_for_saver
+from .graph.patch import patch_for_dataset
+
+
+patch_for_saver()
+patch_for_dataset()
diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mx_rec/core/asc/__init__.py b/mx_rec/core/asc/__init__.py
new file mode 100644
index 00000000..e9754a0d
--- /dev/null
+++ b/mx_rec/core/asc/__init__.py
@@ -0,0 +1 @@
+# coding: UTF-8
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
new file mode 100644
index 00000000..e081c798
--- /dev/null
+++ b/mx_rec/core/asc/build_graph.py
@@ -0,0 +1,142 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
+
+import logging
+
+import tensorflow as tf
+
+import mxrec_pybind
+from mx_rec.util.constants import AVOID_TENSOR_POS
+from mx_rec.util.initialize import get_use_static
+from mx_rec.util.tf_version_adapter import npu_ops
+
+
+def get_restore_vector(config):
+    logging.debug(f'Channel {config.get("table_name")}_restore_{config.get("channel_id")} was built for getnext')
+    emb_size = None
+    if config.get("skip_emb_transfer"):
+        if not isinstance(config.get("_emb_size"), int):
+            raise TypeError("_emb_size must be an int")
+        if config.get("_emb_size") < 1:
+            raise ValueError("_emb_size is less than 1")
+        emb_size = config.get("_emb_size")
+    else:
+        if not isinstance(config.get("ext_emb_size"), int):
+            raise TypeError("ext_emb_size must be an int")
+        if config.get("ext_emb_size") < 1:
+            raise ValueError("ext_emb_size is less than 1")
+        emb_size = config.get("ext_emb_size")
+
+    use_hot = config.get("use_hot")
+    hot_pos = None
+
+    if get_use_static():
+        restore_size = config.get("batch_size") * config.get("feat_cnt")
+    else:
+        restore_size = None
+
+    if use_hot:
+        device_id = int(config.get("device_id"))
+        hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size)
+        restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int32, tf.int32],
+            output_shapes=[restore_size, [hot_size]],
+            channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')
+    else:
+        restore_vector = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int32],
+            output_shapes=[restore_size],
+            channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0]
+
+    return restore_vector, hot_pos
+
+
+def get_id_offsets(max_lookup_vec_size, config):
+    logging.debug(f'Channel {config.get("table_name")}_lookup_{config.get("channel_id")} was built for getnext')
+    # Auto expansion currently supports HBM mode only; there is no swap-in/swap-out by default.
+    if config.get("use_dynamic_expansion"):
+        [id_offsets] = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int64],
+            output_shapes=[[max_lookup_vec_size]],
+            channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
+        return id_offsets, [], 0
+
+    [id_offsets] = npu_ops.gen_npu_ops.get_next(
+        output_types=[tf.int32],
+        output_shapes=[[max_lookup_vec_size]],
+        channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
+    if config.get("skip_emb_transfer"):
+        return id_offsets, [], 0
+    swap_pos, swap_len = npu_ops.gen_npu_ops.get_next(
+        output_types=[tf.int32, tf.int32],
+        output_shapes=[[max_lookup_vec_size], []],
+        channel_name=f'{config.get("table_name")}_swap_{config.get("channel_id")}')
+    return id_offsets, swap_pos, swap_len
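+# The helpers above are all driven by a plain config dict. As a sketch (not part of
+# the original patch), the key names below are the ones read via config.get(...)
+# above; the values are illustrative assumptions for one static-shape table.
+# example_config = {
+#     "table_name": "user_table",     # prefix for the getnext channel names
+#     "channel_id": 0,                # train/eval pipeline channel index
+#     "batch_size": 4096,
+#     "feat_cnt": 8,                  # restore vector length = batch_size * feat_cnt (static mode)
+#     "send_count": 1024,
+#     "rank_size": 8,                 # max_lookup_vec_size = send_count * rank_size (static mode)
+#     "use_hot": False,
+#     "device_id": 0,
+#     "skip_emb_transfer": True,      # True: HBM-only, no host<->device embedding swap
+#     "use_dynamic_expansion": False,
+#     "_emb_size": 32,                # embedding width checked in get_restore_vector
+#     "ext_emb_size": 64,             # embedding width including optimizer slots
+# }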
+
+
+def get_all2all_args(use_static: bool, config: dict) -> list:
+    """
+    Get all2all parameters for the dynamic condition.
+    :param use_static: dynamic or static
+    :param config: embedding config
+    :return: all2all parameters
+    """
+    all2all_args = None
+    if not use_static:
+        with tf.compat.v1.variable_scope("all2all"):
+            logging.debug(f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext')
+            all2all_args = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int64],
+                output_shapes=[[config.get("rank_size"), config.get("rank_size")]],
+                channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}',
+                name="a2a_get_next")[0] * config.get("_emb_size")
+
+    return all2all_args
+
+
+def get_preprocessed_tensor_for_asc(table, config):
+    use_static = get_use_static()
+    max_lookup_vec_size = None
+    if use_static:
+        max_lookup_vec_size = config.get("send_count") * config.get("rank_size")
+
+    with tf.compat.v1.variable_scope("restore_vector"):
+        restore_vector, hot_pos = get_restore_vector(config)
+
+    with tf.compat.v1.variable_scope("id_offsets"):
+        id_offsets, swap_pos, swap_len = get_id_offsets(max_lookup_vec_size, config)
+
+    all2all_args = get_all2all_args(use_static, config)
+
+    if config.get("skip_emb_transfer"):
+        swap_in = [tf.no_op()]
+    else:
+        with tf.compat.v1.variable_scope("h2d_emb"):
+            logging.debug(f'Channel {config.get("table_name")}_h2d_{config.get("channel_id")} was built for getnext')
+            h2d_emb = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.float32],
+                output_shapes=[[max_lookup_vec_size, config.get("ext_emb_size")]],
+                channel_name=f'{config.get("table_name")}_h2d_{config.get("channel_id")}')[0]
+            logging.debug(f"h2d_emb shape: {h2d_emb}")
+            if not isinstance(table, list):
+                raise RuntimeError("When emb_transfer is enabled, the optimizer should have slots")
+            if use_static:
+                swap_pos = swap_pos[0:swap_len]
+                h2d_emb = h2d_emb[0:swap_len, :]
+            swap_outs = [tf.gather(one_table, swap_pos) for one_table in table]
+            swap_out = tf.concat(swap_outs, axis=1)
+            logging.debug(
+                f'Channel {config.get("table_name")}_d2h_{config.get("channel_id")} was built for op outfeed.')
+            swap_out_op = npu_ops.outfeed_enqueue_op(
+                channel_name=f'{config.get("table_name")}_d2h_{config.get("channel_id")}', inputs=[swap_out])
+            with tf.control_dependencies([swap_out_op]):
+                # fix empty nd update
+                swap_pos = tf.concat([swap_pos, tf.constant([AVOID_TENSOR_POS])], axis=0)
+                h2d_emb = tf.concat([h2d_emb, tf.constant([[0.1] * config.get("ext_emb_size")])], axis=0)
+                nd_swap_pos = tf.expand_dims(swap_pos, 1)
+                table_num = len(table)
+                h2d_emb_split = tf.split(h2d_emb, table_num, axis=1)
+                swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i])
+                           for i in range(len(table))]
+    return restore_vector, hot_pos, id_offsets, swap_in, all2all_args
diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py
new file mode 100644
index 00000000..0c680027
--- /dev/null
+++ b/mx_rec/core/asc/feature_spec.py
@@ -0,0 +1,173 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
+import logging
+from functools import reduce
+
+import tensorflow as tf
+
+from mx_rec.util.atomic import AtomicInteger
+from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static
+
+feature_spec_global_id = AtomicInteger()
+
+
+class FeatureSpec:
+    instance_count_train = 0
+    instance_count_eval = 0
+    use_timestamp_train = False
+    use_timestamp_eval = False
+
+    def __init__(self, name, **kwargs):
+        feature_spec_global_id.increase()
+        spec_name = name +
f"_{feature_spec_global_id}" + self.name = spec_name + self._index_key = kwargs.get("index_key") if kwargs.get("index_key") else name + self._table_name = kwargs.get("table_name") if kwargs.get("table_name") else name + self._feat_cnt = kwargs.get("feat_count") + self._access_threshold = kwargs.get("access_threshold") + self._eviction_threshold = kwargs.get("eviction_threshold") + self._is_timestamp = kwargs.get("is_timestamp") + self.feat_pos_train = None + self.feat_pos_eval = None + self.dims = None + self.rank = None + self.batch_size = kwargs.get("batch_size") + self.split = None # usually split == batch_size * feature_count + self.initialized = False + self._pipeline_mode = set() + self.check_params() + + @staticmethod + def include_timestamp(is_training): + if is_training: + if FeatureSpec.use_timestamp_train: + raise EnvironmentError(f"Timestamp was set twice for training mode.") + FeatureSpec.use_timestamp_train = True + else: + FeatureSpec.use_timestamp_eval = True + + @staticmethod + def use_timestamp(is_training): + return FeatureSpec.use_timestamp_train if is_training else FeatureSpec.use_timestamp_eval + + @property + def is_timestamp(self): + return self._is_timestamp + + @property + def access_threshold(self): + return self._access_threshold + + @property + def eviction_threshold(self): + return self._eviction_threshold + + @property + def index_key(self): + return self._index_key + + @property + def table_name(self): + return self._table_name + + @property + def feat_cnt(self): + return self._feat_cnt + + def check_params(self): + def check_str(arg, param_name): + if not isinstance(arg, str): + raise TypeError(f"{param_name} should be a string, whose value is {arg} with type '{type(arg)}' " + f"in fact.") + + def check_natural_number(arg, param_name): + if not isinstance(arg, int) or arg < 1: + raise TypeError(f"{param_name} should be a natural number, whose value is {arg} with type " + f"'{type(arg)}' in fact.") + + def check_bool(arg, param_name): + if not isinstance(arg, bool): + raise TypeError(f"{param_name} should be a bool, whose value is {arg} with type " + f"'{type(arg)}' in fact.") + + check_str(self.name, "name") + check_str(self._table_name, "table_name") + + if self._feat_cnt is not None: + check_natural_number(self._feat_cnt, "feat_count") + + if self._access_threshold is not None: + check_natural_number(self._access_threshold, "access_threshold") + elif self._eviction_threshold is not None: + raise ValueError(f"Access_threshold should be configured before eviction_threshold.") + + if self._eviction_threshold is not None: + check_natural_number(self._eviction_threshold, "eviction_threshold") + + if self._is_timestamp is not None: + check_bool(self._is_timestamp, "is_timestamp") + + def set_feat_pos(self, is_training): + if is_training: + self.feat_pos_train = FeatureSpec.instance_count_train + FeatureSpec.instance_count_train += 1 + else: + self.feat_pos_eval = FeatureSpec.instance_count_eval + FeatureSpec.instance_count_eval += 1 + + @property + def pipeline_mode(self): + return self._pipeline_mode + + def insert_pipeline_mode(self, mode): + if not isinstance(mode, bool): + raise TypeError("Is training mode must be a boolean.") + + if mode and mode in self._pipeline_mode: + logging.info(f"FeatureSpec{self.name}. 
Is training mode [{mode}] has been set.")
+            return
+
+        insert_training_mode_channel_id(is_training=mode)
+
+        self._pipeline_mode.add(mode)
+
+    def set_feat_attribute(self, tensor, is_training):
+        self.insert_pipeline_mode(is_training)
+        self.set_feat_pos(is_training)
+        if not self.initialized:
+            self.initialized = True
+
+            if get_use_static():
+                self.dims = tensor.shape.as_list()
+                self.rank = tensor.shape.rank
+                if self.rank < 1:
+                    raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {self.rank} now.")
+
+                inferred_feat_cnt = 1 if self.rank == 1 else reduce(lambda x, y: x * y, self.dims[1:])
+                logging.debug(f"update feature_spec[{self.name}] feature_count "
+                              f"from {self._feat_cnt} to {inferred_feat_cnt} via {self.dims}")
+                self.batch_size = self.dims[0]
+                self._feat_cnt = inferred_feat_cnt
+                self.split = self.batch_size * self._feat_cnt
+            else:
+                tensor = tf.reshape(tensor, [-1])
+                self.dims = tf.shape(tensor)
+                self.rank = 1
+                self.split = tf.math.reduce_prod(tf.shape(tensor))
+                self.batch_size = self.split
+                self._feat_cnt = 1
+
+        else:
+            logging.debug(f"The initialized Feature Spec was set once again.")
+            if get_use_static():
+                if self.dims != tensor.shape.as_list():
+                    raise ValueError(f"Given static Tensor shape mismatches with the last one, whose is_training mode "
+                                     f"is not {is_training}. ")
+            else:
+                if self.dims.shape.as_list() != tf.shape(tf.reshape(tensor, [-1])).shape.as_list():
+                    raise ValueError(f"Given dynamic Tensor shape mismatches with the last one, whose is_training mode "
+                                     f"is not {is_training}. ")
+
+        insert_feature_spec(self, is_training)
+        return tensor, self.table_name, self.feat_cnt, self.split
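FeatureSpec is the user-facing handle in the non-modify-graph path; little_demo above constructs one per input column. A minimal sketch of that usage, with values mirroring the demo's Config fields (the sketch itself is not part of the patch):

# One spec per feature column, optionally carrying admit/evict thresholds.
from mx_rec.core.asc.feature_spec import FeatureSpec

user_spec = FeatureSpec("user_ids", feat_count=8, table_name="user_table",
                        access_threshold=100, eviction_threshold=60)
item_spec = FeatureSpec("item_ids", feat_count=16, table_name="item_table")

# set_feat_attribute is invoked by the insert pipeline in helper.py with the
# actual id tensor; it infers batch_size / split and registers the spec.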
+    '''
+    if tgt_key_specs is not None:
+        return create_asc_insert_func_with_specs(tgt_key_specs=tgt_key_specs, **kwargs)
+    if args_index_list is not None:
+        return create_asc_insert_func_with_acg(args_index_list=args_index_list,
+                                               feature_counts=feature_numbers,
+                                               table_names=table_names,
+                                               **kwargs)
+    raise RuntimeError("get_asc_insert_func was called incorrectly.")
+
+
+def create_asc_insert_func_with_specs(tgt_key_specs, **kwargs):
+    '''
+    FeatureSpec mode.
+    '''
+    return get_asc_insert_func_inner(tgt_key_specs=tgt_key_specs, **kwargs)
+
+
+def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names, **kwargs):
+    '''
+    Auto change graph (ACG) mode.
+    '''
+    return get_asc_insert_func_inner(args_index_list=args_index_list,
+                                     feature_counts=feature_counts,
+                                     table_names=table_names,
+                                     **kwargs)
+
+
+def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None,
+                              table_names=None, **kwargs):
+    both_none = tgt_key_specs is None and args_index_list is None
+    both_no_none = tgt_key_specs is not None and args_index_list is not None
+    if both_none or both_no_none:
+        raise ValueError("Exactly one of tgt_key_specs and args_index_list must be given to get insert tensors.")
+
+    is_training = kwargs.get("is_training", True)
+    dump_graph = kwargs.get("dump_graph", False)
+
+    if tgt_key_specs is not None:
+        if not isinstance(tgt_key_specs, (list, tuple)):
+            tgt_key_specs = [tgt_key_specs]
+
+        def insert_fn_for_feature_specs(*args):
+            data_src = args
+            if len(args) == 1:
+                data_src = args[0]
+
+            read_emb_key_inputs_dict = {"insert_tensors": [], "table_names": [],
+                                        "feature_spec_names": [], "splits": []}
+            get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict)
+            logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict['table_names']}")
+            return do_insert(args,
+                             insert_tensors=read_emb_key_inputs_dict["insert_tensors"],
+                             splits=read_emb_key_inputs_dict["splits"],
+                             table_names=read_emb_key_inputs_dict["table_names"],
+                             input_dict={"is_training": is_training, "dump_graph": dump_graph,
+                                         "timestamp": FeatureSpec.use_timestamp(is_training),
+                                         "feature_spec_names": read_emb_key_inputs_dict["feature_spec_names"],
+                                         "auto_change_graph": False})
+
+        insert_fn = insert_fn_for_feature_specs
+
+    else:
+        if feature_counts is None or table_names is None:
+            raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.")
+
+        def insert_fn_for_arg_indexes(*args):
+            insert_tensors = get_target_tensors_with_args_indexes(args_index_list)
+            # config timestamp later
+            logging.debug(f"do_insert without spec for {table_names}")
+            splits = []
+            for insert_tensor in insert_tensors:
+                split = reduce(lambda x, y: x * y, insert_tensor.shape.as_list())
+                splits.append(split if split is not None else tf.math.reduce_prod(tf.shape(insert_tensor)))
+            return do_insert(args,
+                             insert_tensors=insert_tensors,
+                             splits=splits,
+                             table_names=table_names,
+                             input_dict={"is_training": is_training, "dump_graph": dump_graph,
+                                         "timestamp": FeatureSpec.use_timestamp(is_training),
+                                         "feature_spec_names": None,
+                                         "auto_change_graph": True})
+
+        insert_fn = insert_fn_for_arg_indexes
+
+    return insert_fn
+
+
+def merge_feature_id_request(feature_id_list, split_list, table_name_list, feature_spec_names):
+    if not (len(feature_id_list) == len(split_list) and len(split_list) == len(table_name_list)):
+        raise RuntimeError(f"shape not 
match. len(feature_id_list): {len(feature_id_list)}," + f"len(split_list): {len(split_list)}" + f"len(table_name_list): {len(table_name_list)}") + feature_id_requests = zip(feature_id_list, split_list, table_name_list) + feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name)) + logging.debug(f" features to merge: {feature_id_requests}") + last_table_name = None + last_split = 0 + last_tensorshape_split = 0 + output_feature_id_list = [x[0] for x in feature_id_requests] + output_split_list = [] + output_tensorshape_split_list = [] + output_table_name_list = [] + for feature_id, split, table_name in feature_id_requests: + if last_table_name is None or last_table_name == table_name: + last_table_name = table_name + last_split += split + last_tensorshape_split += tf.math.reduce_prod(tf.shape(feature_id)) + else: + output_table_name_list.append(last_table_name) + output_split_list.append(last_split) + output_tensorshape_split_list.append(last_tensorshape_split) + last_table_name = table_name + last_split = split + last_tensorshape_split = tf.math.reduce_prod(tf.shape(feature_id)) + if last_table_name is not None: + output_table_name_list.append(last_table_name) + output_split_list.append(last_split) + output_tensorshape_split_list.append(last_tensorshape_split) + logging.debug(f"merge request from {table_name_list} {split_list} " + f" to {output_table_name_list} {output_split_list}") + return output_feature_id_list, output_split_list, output_table_name_list, output_tensorshape_split_list + + +def send_feature_id_request_async(feature_id_list, split_list, + table_name_list, input_dict): + is_training = input_dict["is_training"] + timestamp = input_dict["timestamp"] + feature_spec_names = input_dict["feature_spec_names"] + auto_change_graph = input_dict["auto_change_graph"] + host_pipeline_ops = get_host_pipeline_ops() + use_static = get_use_static() + timestamp_feature_id = [] + + if timestamp: + timestamp_feature_id = feature_id_list[:1] + feature_id_list = feature_id_list[1:] + + if not auto_change_graph: # future support acg + feature_id_list, split_list, table_name_list, tensorshape_split_list = \ + merge_feature_id_request(feature_id_list, split_list, + table_name_list, feature_spec_names) + else: + tensorshape_split_list = split_list + + # check training mode order and ensure channel id + channel_id = get_training_mode_channel_id(is_training=is_training) + + if timestamp: + feature_id_list = timestamp_feature_id + feature_id_list + concat_tensor = tf.concat(feature_id_list, axis=0) + + if use_static: + logging.debug(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}") + return host_pipeline_ops.read_emb_key_v2(concat_tensor, channel_id=channel_id, splits=split_list, + emb_name=table_name_list, timestamp=timestamp) + + logging.debug(f"read_emb_key_v2_dynamic, table_name_list: {table_name_list}, " + f"tensorshape_split_list: {tensorshape_split_list}") + pipeline_op = host_pipeline_ops.read_emb_key_v2_dynamic(concat_tensor, tensorshape_split_list, + channel_id=channel_id, + emb_name=table_name_list, timestamp=timestamp) + + return pipeline_op + + +def do_insert(args, insert_tensors, splits, table_names, input_dict): + is_training = input_dict["is_training"] + dump_graph = input_dict["dump_graph"] + timestamp = input_dict["timestamp"] + feature_spec_names = input_dict["feature_spec_names"] + auto_change_graph = input_dict["auto_change_graph"] + + # Only the tables that need to be used after table combination are retained in meituan 
 situation.
+    # The current solution can be wrong in some situations, e.g. when a sparse table has not been auto-merged.
+    from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN
+    new_insert_tensors, new_splits, new_table_names = [], [], []
+    logging.debug(f"In do_insert function, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}")
+    for idx, table_name in enumerate(table_names):
+        if ASCEND_TABLE_NAME_MUST_CONTAIN is not None and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name:
+            logging.info(f"After the tables are combined, the information about the"
+                         f" {table_name} table does not need to be provided to the read_emb_key operator.")
+            continue
+        new_insert_tensors.append(insert_tensors[idx])
+        new_splits.append(splits[idx])
+        new_table_names.append(table_names[idx])
+
+    if timestamp:
+        new_insert_tensors = insert_tensors
+
+    pipeline_op = \
+        send_feature_id_request_async(feature_id_list=new_insert_tensors,
+                                      split_list=new_splits,
+                                      table_name_list=new_table_names,
+                                      input_dict={"is_training": is_training,
+                                                  "timestamp": timestamp,
+                                                  "feature_spec_names": feature_spec_names,
+                                                  "auto_change_graph": auto_change_graph})
+
+    if dump_graph:
+        graph_def = tf.compat.v1.get_default_graph().as_graph_def()
+        tf.compat.v1.train.write_graph(graph_def, "./export_graph", "pipeline_graph.pb", False)
+
+    # The read_emb_key_v2 op has to be exported into the output batch; otherwise
+    # TensorFlow would prune it away during graph optimization.
+    output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+    return output_batch
+
+
+def export_read_emb_key_v2_op(args, pipeline_op):
+    origin_batch = list(args)
+    if isinstance(origin_batch[0], dict):
+        output_batch = origin_batch[0]
+        valid_key = get_valid_op_key(output_batch)
+        output_batch[valid_key] = pipeline_op
+
+    elif len(origin_batch) == 1 and isinstance(origin_batch[0], tf.Tensor):
+        origin_batch.append(pipeline_op)
+        output_batch = tuple(origin_batch)
+
+    elif len(origin_batch) == 2:
+        if isinstance(origin_batch[0], (list, tuple)):
+            origin_batch[0] = list(origin_batch[0])
+            origin_batch[0].append(pipeline_op)
+            origin_batch[0] = tuple(origin_batch[0])
+            output_batch = tuple(origin_batch)
+
+        elif isinstance(origin_batch[0], tf.Tensor):
+            origin_batch[0] = [origin_batch[0]]
+            origin_batch[0].append(pipeline_op)
+            origin_batch[0] = tuple(origin_batch[0])
+            output_batch = tuple(origin_batch)
+
+        else:
+            raise EnvironmentError("An unexpected condition was encountered.")
+
+    else:
+        origin_batch.append(tuple(pipeline_op))
+        output_batch = tuple(origin_batch)
+    return output_batch
+
+
+def get_valid_op_key(batch_dict: dict) -> str:
+    if not isinstance(batch_dict, dict):
+        raise TypeError("batch_dict must be a dict.")
+
+    sorted_keys = sorted(batch_dict)
+    valid_key = f"{sorted_keys[-1]}_read_emb_key"
+
+    return valid_key
+
+
+def get_target_tensors_with_args_indexes(args_index_list):
+    insert_tensors = []
+    graph = tf.get_default_graph()
+
+    for index in args_index_list:
+        tensor = graph.get_tensor_by_name("args_%d:0" % index)
+        if tensor.dtype != tf.int64:
+            logging.debug(f"Input tensor dtype is {tensor.dtype}, which will be cast to tf.int64.")
+            tensor = tf.cast(tensor, tf.int64)
+        insert_tensors.append(tf.reshape(tensor, [-1, ]))
+
+    return insert_tensors
+
+
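+# get_target_tensors_with_feature_specs walks tgt_key_specs and the batch in lockstep:
+# dict specs are matched by key, plain list/tuple specs are matched positionally, a list
+# of FeatureSpec instances is applied to the current (sub-)batch, and every FeatureSpec
+# leaf extracts exactly one tensor. For instance (keys purely illustrative), specs of the
+# form {"user": [user_spec], "item": [item_spec]} expect a batch shaped
+# {"user": <sub-batch>, "item": <sub-batch>}.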
+def get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict):
+    def parse_feature_spec(feature_spec, batch, is_training, read_emb_key_inputs_dict):
+        if isinstance(batch, dict):
+            if feature_spec.index_key not in batch:
+                # a timestamp spec must find its key in the batch; any other spec that is
+                # absent from the batch is simply skipped
+                if feature_spec.is_timestamp:
+                    raise KeyError(f"Cannot find key or index {feature_spec.index_key} in batch.")
+                return
+
+            if not isinstance(batch.get(feature_spec.index_key), tf.Tensor):
+                raise TypeError(f"Target value is not a tensor, which is a {type(batch.get(feature_spec.index_key))}.")
+
+            tensor = batch.get(feature_spec.index_key)
+        elif isinstance(batch, (list, tuple)):
+            if feature_spec.index_key >= len(batch):
+                raise ValueError(f"Index {feature_spec.index_key} is out of range.")
+
+            if not isinstance(batch[feature_spec.index_key], tf.Tensor):
+                raise TypeError(f"Target value is not a tensor, which is a {type(batch[feature_spec.index_key])}.")
+
+            tensor = batch[feature_spec.index_key]
+        else:
+            raise ValueError("Encountered an invalid batch.")
+
+        if feature_spec.is_timestamp is None:
+            tensor, table_name, feat_count, split = feature_spec.set_feat_attribute(tensor, is_training)
+            if tensor.dtype != tf.int64:
+                tensor = tf.cast(tensor, dtype=tf.int64)
+
+            read_emb_key_inputs_dict["insert_tensors"].append(tf.reshape(tensor, [-1, ]))
+            read_emb_key_inputs_dict["table_names"].append(table_name)
+            read_emb_key_inputs_dict["splits"].append(split)
+            read_emb_key_inputs_dict["feature_spec_names"].append(feature_spec.name)
+        elif feature_spec.is_timestamp:
+            if len(tensor.shape.as_list()) != 0:
+                raise ValueError("Given timestamp tensor must be a scalar.")
+            read_emb_key_inputs_dict["insert_tensors"] = [tf.reshape(tensor, [-1, ])] + \
+                                                         read_emb_key_inputs_dict["insert_tensors"]
+            feature_spec.include_timestamp(is_training)
+        elif tensor is not None:
+            raise ValueError("The spec's is_timestamp should be True when the batch contains a timestamp.")
+
+    if isinstance(tgt_key_specs, dict):
+        for key, item in tgt_key_specs.items():
+            get_target_tensors_with_feature_specs(item, batch[key], is_training, read_emb_key_inputs_dict)
+        return
+
+    elif isinstance(tgt_key_specs, (list, tuple)):
+        if is_feature_spec_list(tgt_key_specs):
+            for feature in tgt_key_specs:
+                get_target_tensors_with_feature_specs(feature, batch, is_training, read_emb_key_inputs_dict)
+            return
+
+        elif isinstance(batch, (list, tuple)) and len(tgt_key_specs) == len(batch):
+            for spec, sub_batch in zip(tgt_key_specs, batch):
+                get_target_tensors_with_feature_specs(spec, sub_batch, is_training, read_emb_key_inputs_dict)
+            return
+
+    elif isinstance(tgt_key_specs, FeatureSpec):
+        parse_feature_spec(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+        return
+
+    raise ValueError(f"Please build tgt_key_specs with the same structure as the given batch.
\n\t\t" + f"In fact, tgt_key_specs type is {type(tgt_key_specs)} but batch type is {type(batch)}.") + + +def is_feature_spec_list(specs): + if not isinstance(specs, (list, tuple)): + return False + + for item in specs: + if not isinstance(item, FeatureSpec): + return False + + return True + + +def get_asc_read_raw_func(cfg_list): + batch = {} + int_name_order = [] + int_len_list = [] + float_name_order = [] + float_len_list = [] + line_per_sample_list = [] + host_pipeline_ops = get_host_pipeline_ops() + for cfg in cfg_list: + if cfg.data_type == "int64": + int_name_order.append(cfg.feature_name) + int_len_list.append(cfg.feature_len) + line_per_sample_list.append(cfg.line_per_sample) + + if cfg.data_type == "float": + float_name_order.append(cfg.feature_name) + float_len_list.append(cfg.feature_len) + line_per_sample_list.append(cfg.line_per_sample) + if len(set(line_per_sample_list)) != 1: + raise ValueError(f"Please check that each line_per_sample value should be equal.") + line_per_sample = line_per_sample_list[0] + + def read_raw_fn(data_src): + raw_int_sample, raw_float_sample = host_pipeline_ops.read_raw( + sample=data_src, + int_len=sum(int_len_list) * line_per_sample, + float_len=sum(float_len_list) * line_per_sample, + feat_order=int_name_order + float_name_order + ) + + int_split_res = tf.split(raw_int_sample, [i * line_per_sample_list[0] for i in int_len_list]) + + float_split_res = tf.split(raw_float_sample, [i * line_per_sample_list[0] for i in float_len_list]) + + logging.debug(f"############ Enter read_raw_fn ########") + + for name_id, name in enumerate(int_name_order): + batch[name] = int_split_res[name_id] + + for name_id, name in enumerate(float_name_order): + batch[name] = float_split_res[name_id] + return batch + + return read_raw_fn + + +class ParseConfig: + + def __init__(self, **kwargs): + self.input_keys = set(kwargs.keys()) + self._feature_name = kwargs.get("feature_name") + self._feature_len = int(kwargs.get("feature_len")) + self._data_type = kwargs.get("data_type") + self._line_per_sample = int(kwargs.get("line_per_sample")) + self.check_params() + + @property + def feature_name(self): + return self._feature_name + + @property + def feature_len(self): + return self._feature_len + + @property + def data_type(self): + return self._data_type + + @property + def line_per_sample(self): + return self._line_per_sample + + def check_params(self): + supported_keys = {"feature_name", "feature_len", "line_per_sample", "data_type"} + if self.input_keys != supported_keys: + raise KeyError("Please offer an expected keyword argument") + + if not isinstance(self._feature_name, str): + raise TypeError(f"Please offer a feature_name with string type.") + + if not isinstance(self._data_type, str): + raise TypeError(f"Please offer a data_type with string type.") + + if self._data_type not in ("int64", "float"): + raise TypeError(f"Please offer a data_type with int64 or float type") + + if self._feature_len <= 0: + raise ValueError(f"Please offer a feature_len greater than zero.") + + if self._line_per_sample <= 0: + raise ValueError(f"Please offer a line_per_sample greater than zero.") diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py new file mode 100644 index 00000000..6ea99612 --- /dev/null +++ b/mx_rec/core/asc/manager.py @@ -0,0 +1,195 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Huawei Technologies Co., Ltd +import logging +import os + +import tensorflow as tf + +from mx_rec.util.constants import MxRecMode +from 
 mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
+    is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \
+    export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \
+    get_use_hot, get_use_dynamic_expansion, export_optimizer
+
+
+def generate_table_info_list():
+    from mxrec_pybind import EmbInfo
+    from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN
+    # table_name corresponds to the channel_name used in the operator gen_npu_ops.get_next
+    table_info_list = []
+
+    # check whether DDR is enabled or disabled for all tables
+    host_voc_sizes = [table_instance.host_vocabulary_size for table_instance in export_table_instances().values()]
+    total_host_voc_size = sum(host_voc_sizes)
+    if total_host_voc_size != 0 and 0 in host_voc_sizes:
+        raise ValueError(f"The host-side DDR function of all tables must be used or not used at the same time. "
+                         f"However, the host voc size of each table is {host_voc_sizes}.")
+
+    optimizer = export_optimizer()
+    # generate table info
+    for _, table_instance in export_table_instances().items():
+        # in dynamic expansion mode, ext_emb_size is set by the optimizer
+        if optimizer is not None:
+            table_instance.ext_emb_size = table_instance.scalar_emb_size * (1 + optimizer.slot_num)
+            logging.debug(f"ext_emb_size is reset to be {table_instance.ext_emb_size} for EmbInfo")
+        # Only the tables that are still needed after table combination are retained in the Meituan situation.
+        # The current solution can be wrong in some situations, e.g. when a sparse table has not been auto-merged.
+        logging.debug(f"In EmbInfo, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}")
+        if ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \
+                ASCEND_TABLE_NAME_MUST_CONTAIN not in table_instance.table_name:
+            logging.info(f"After the tables are combined, the information about the"
+                         f" {table_instance.table_name} table does not need to be provided to the EmbInfo.")
+            continue
+        rec_mode_asc_flag = table_instance.mode == MxRecMode.ASC
+        static_shape_rec_flag = rec_mode_asc_flag and get_use_static() and table_instance.send_count > 0
+        dynamic_shape_rec_flag = rec_mode_asc_flag and not get_use_static()
+        if static_shape_rec_flag or dynamic_shape_rec_flag:
+            logging.debug(f"table_instance.slice_device_vocabulary_size: {table_instance.slice_device_vocabulary_size}")
+            logging.debug(f"table_instance.slice_host_vocabulary_size: {table_instance.slice_host_vocabulary_size}")
+            table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.ext_emb_size,
+                                 [table_instance.slice_device_vocabulary_size,
+                                  table_instance.slice_host_vocabulary_size],
+                                 [matched_emb_initializer(table_instance)] +
+                                 matched_opt_slot_initializers(table_instance))
+            table_info_list.append(table_info)
+
+    return table_info_list
+
+
+def matched_emb_initializer(tabel_info):
+    from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo
+    initializer_case_map = {"tf1/tf2_constant_initializer":
+                                isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or
+                                isinstance(tabel_info.emb_initializer, tf.constant_initializer),
+                            "tf1/tf2_random_normal_initializer":
+                                isinstance(tabel_info.emb_initializer, tf.keras.initializers.RandomNormal) or
+                                isinstance(tabel_info.emb_initializer, tf.random_normal_initializer),
+                            "tf1_truncated_normal_initializer":
+                                tf.__version__.startswith("1") and
+                                (isinstance(tabel_info.emb_initializer, 
tf.truncated_normal_initializer) or + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal)), + "tf2_truncated_normal_initializer": + tf.__version__.startswith("2") and + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), + } + if initializer_case_map["tf1/tf2_constant_initializer"]: + initializer = InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, + constant_initializer_info=ConstantInitializerInfo( + constant_val=tabel_info.emb_initializer.value)) + elif initializer_case_map["tf1/tf2_random_normal_initializer"]: + random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed + initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed + )) + elif initializer_case_map["tf1_truncated_normal_initializer"] or \ + initializer_case_map["tf2_truncated_normal_initializer"]: + random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed + initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed + )) + else: + initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + normal_initializer_info=NormalInitializerInfo( + mean=0.0, + stddev=1.0, + seed=0 + )) + return initializer + + +def matched_opt_slot_initializers(table_instance): + from mxrec_pybind import InitializeInfo, ConstantInitializerInfo + + start_index = table_instance.scalar_emb_size + slot_initializers = [] + + for optimizer in table_instance.optimizer_instance_list: + for slot_init_value in optimizer.get_slot_init_values(): + slot_initializer = InitializeInfo(name="constant_initializer", + start=start_index, + len=table_instance.scalar_emb_size, + constant_initializer_info=ConstantInitializerInfo( + constant_val=slot_init_value + )) + slot_initializers.append(slot_initializer) + start_index += table_instance.scalar_emb_size + + return slot_initializers + + +def generate_threshold_list(): + from mxrec_pybind import ThresholdValue + threshold_list = [] + + for _, feature_spec in export_feature_spec().items(): + if feature_spec.eviction_threshold: + threshold = ThresholdValue(feature_spec.table_name, + feature_spec.access_threshold, + feature_spec.eviction_threshold) + threshold_list.append(threshold) + continue + if feature_spec.access_threshold: + threshold = ThresholdValue(feature_spec.table_name, + feature_spec.access_threshold, + -1) + threshold_list.append(threshold) + + return threshold_list + + +def initialize_emb_cache(table_info_list, threshold_list): + from mxrec_pybind import HybridMgmt, RankInfo, USE_STATIC, USE_HOT, USE_DYNAMIC_EXPANSION + + rank_id = get_rank_id() + device_id = get_device_id() + rank_size = get_rank_size() + evaluate_stride = get_train_interval() + eval_steps = get_eval_steps() + n_batch_to_prefetch = get_prefetch_batch_number() + if_load = get_if_load() + option = 0 + if get_use_static(): + option = option | USE_STATIC + if get_use_hot(): + option = option | USE_HOT + if get_use_dynamic_expansion(): + option = option | USE_DYNAMIC_EXPANSION + + if get_training_mode_channel_id(is_training=False) == 0: + rank_info = 
RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, + [eval_steps, evaluate_stride]) + else: + rank_info = RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, + [evaluate_stride, eval_steps]) + + emb_cache = HybridMgmt() + if threshold_list: + emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load, + threshold_values=threshold_list) + else: + emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load) + + set_asc_manager(emb_cache) + logging.info("Preprocessing has been sunk into the host pipeline.") + logging.debug(f"Flag if load is {if_load}.") + logging.debug(f"n_batch_to_prefetch is {n_batch_to_prefetch}.") + logging.debug(f"evaluate_stride is {evaluate_stride}.") + logging.debug(f"eval_steps is {eval_steps}.") + logging.debug(f"threshold_values are {threshold_list}.") + + +def start_asc_pipeline(): + table_info_list = generate_table_info_list() + threshold_list = generate_threshold_list() + if not table_info_list: + logging.warning(f"table_info_list is empty") + + if not is_asc_manager_initialized() and table_info_list: + initialize_emb_cache(table_info_list, threshold_list) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py new file mode 100644 index 00000000..52f7ce9c --- /dev/null +++ b/mx_rec/core/embedding.py @@ -0,0 +1,826 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Huawei Technologies Co., Ltd +import logging +import math +import time +from collections import defaultdict + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import ops + +from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc +from mx_rec.core.asc.helper import FeatureSpec +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ + ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ + ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ + DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID +from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ + insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ + clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ + ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion +from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.util.variable import remove_saving_var + + +def create_table(key_dtype, dim, name, emb_initializer, device_vocabulary_size=1, host_vocabulary_size=0, + optimizer_list=None, mode=MxRecMode.ASC, value_dtype=tf.float32, shard_num=1, + fusion_optimizer_var=True, hashtable_threshold=0): + """ + + Args: + key_dtype: data type for feature id + dim: embedding vector size + name: hash table name + emb_initializer: the initializer for embedding values + device_vocabulary_size: embedding vector numbers on device + host_vocabulary_size: embedding vector numbers on ddr + relation from feature to variable offset will be built + optimizer_list: specify the optimizers to use for current hash table + mode: specify which mode to run for current sparse table + value_dtype: the type of the value tensors. 
+        shard_num: embedding partition number
+        fusion_optimizer_var: whether to fuse optimizer variables with the embedding
+        hashtable_threshold: threshold for choosing between a hash-table-based and a linear-layer-based implementation
+
+    Returns: SparseEmbedding instance
+
+    """
+    config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer,
+                  device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size,
+                  optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num,
+                  fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold)
+    embedding = SparseEmbedding(config)
+    return embedding
+
+
+def sparse_lookup(hashtable, ids, send_count, **kwargs):
+    """
+
+    Args:
+        hashtable: SparseEmbedding instance to be looked up
+        ids: Tensor to lookup from hashtable
+        send_count: used to configure all2all communication parameters
+        kwargs:
+            dim: not in use
+            is_train: whether the lookup is built for the training graph
+            name: will be used to build scope_name together with hashtable name
+            modify_graph: if True, the original graph will be modified before building a Session instance
+
+    Returns: Tensor for lookup result
+
+    """
+
+    def check_lookup_kwargs():
+        kwargs["name"] = kwargs.get("name", hashtable.get_default_lookup_name())
+        if not isinstance(kwargs.get("name"), str):
+            raise TypeError("Given name must be a string.")
+
+        kwargs["modify_graph"] = kwargs.get("modify_graph", False)
+        if not isinstance(kwargs.get("modify_graph"), bool):
+            raise TypeError("Given modify_graph must be a boolean.")
+
+    def check_table_legality_for_feature_spec(table, feature_spec):
+        # check whether the table name carried by the FeatureSpec matches this table
+        if table.table_name != feature_spec.table_name:
+            raise ValueError(f"The table name '{feature_spec.table_name}' specified by FeatureSpec is inconsistent with"
+                             f" the SparseEmbedding table name '{table.table_name}'.")
+
+    def check_modify_graph():
+        if not kwargs.get("modify_graph"):
+            logging.warning(f"MxRecMode {MxRecMode.ASC} should be configured with modify_graph=True.")
+
+    check_lookup_kwargs()
+    scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
+    with tf.compat.v1.variable_scope(scope_name):
+        if hashtable.mode == MxRecMode.ASC:
+            if isinstance(ids, FeatureSpec):
+                check_table_legality_for_feature_spec(hashtable, ids)
+                return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs)
+            else:
+                check_modify_graph()
+                return hashtable.lookup_for_asc(ids, send_count, **kwargs)
+        else:
+            raise EnvironmentError("Invalid MxRec Mode.")
+
+
+class SparseEmbedding:
+    """
+    Each feat_name has its own sparse embedding layer.
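+
+    A typical flow, sketched with illustrative values ('opt' stands for a previously
+    built mxRec optimizer and 'spec' for a FeatureSpec of this table; the sizes are
+    assumptions, not recommendations):
+
+        table = create_table(key_dtype=tf.int64, dim=8, name="user_table",
+                             emb_initializer=tf.compat.v1.truncated_normal_initializer(),
+                             device_vocabulary_size=10000, optimizer_list=[opt])
+        emb = sparse_lookup(table, spec, send_count=4096, is_train=True, name="user_lookup")
+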
+ """ + customized_ops = get_customized_ops() + anchor_tensor_specs = defaultdict(dict) + + def __init__(self, config): + self.embedding_size = config.get("embedding_size") + if isinstance(self.embedding_size, int): + self.embedding_size = tf.TensorShape([self.embedding_size]) + self.device_vocabulary_size = config.get("device_vocabulary_size") + self.host_vocabulary_size = config.get("host_vocabulary_size") + self.table_name = config.get("table_name") + self.key_dtype = config.get("key_dtype") + self._optimizer_instance_list = config.get("optimizer_list") + self.emb_initializer = config.get("emb_initializer") + self._mode = config.get("mode") + self.optimizer_slot_info_list = [] + self._slot_num = dict() + self._send_count = 0 + self._use_feature_mapping = False + self.skip_emb_transfer = True if self.host_vocabulary_size <= 0 else False + self._default_name_count = -1 + self._emb_size = None + self.ext_emb_size = None + self.ext_coefficient = 1 + self._optimizer = dict() + self.slice_device_vocabulary_size = None + self.slice_host_vocabulary_size = None + self.variable = None + self.lookup_info = set() + self.lookup_result = None + self.use_dynamic_expansion = get_use_dynamic_expansion() + + self.set_slice_vocab_size() + self.set_emb_size() + if self._mode == MxRecMode.ASC and is_asc_frozen() and self.table_name in get_name_to_var_dict(): + self.variable = tf.compat.v1.get_variable(self.table_name, + shape=(self.slice_device_vocabulary_size, self._emb_size)) + if not self.skip_emb_transfer: + self.set_ext_emb_size() + else: + self.check_and_format_init_params() + self._initialize_variables() + self.set_ext_emb_size() + tf.compat.v1.add_to_collection(get_ascend_global_hashtable_collection(), self.variable) + + def check_optimizer_instance(self): + for optimizer_instance in self._optimizer_instance_list: + if tf.__version__.startswith("1"): + from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer + if isinstance(optimizer_instance, NPULossScaleOptimizer): + optimizer_instance = getattr(optimizer_instance, '_opt') + else: + from npu_device.train.optimizer.npu_loss_scale_optimizer import NpuLossScaleOptimizer + if isinstance(optimizer_instance, NpuLossScaleOptimizer): + optimizer_instance = getattr(optimizer_instance, '_opt') + + if not isinstance(optimizer_instance, CustomizedOptimizer): + raise ValueError(f"args optimizer list must be a list or an instance of CustomizedOptimizer.") + + def check_and_format_init_params(self): + if not isinstance(self.embedding_size, tf.TensorShape): + raise TypeError("Parameter 'embedding_size' must be a tf.TensorShape instance.") + + if self.embedding_size.ndims != 1: + raise ValueError("Parameter 'embedding_size' can only be one dim shape.") + + if self.mode == MxRecMode.ASC and is_asc_frozen(): + raise EnvironmentError(f"Emb cache management has been established, you cannot build new ASC hash table.") + + if self.mode != MxRecMode.ASC and self.host_vocabulary_size > 0: + raise ValueError(f"Only ASC mode can use host_vocabulary_size > 0.") + + if self.mode == MxRecMode.ASC and not is_mpi_in_use(): + raise EnvironmentError(f"Hash table with ASC mode must use mpi to start task.") + + if not self.skip_emb_transfer and not self._optimizer_instance_list: + raise ValueError("ASC with DDR mode should config optimizers before instantiating sparse table, " + "but nothing was configured.") + if not self.skip_emb_transfer and self.use_dynamic_expansion: + raise ValueError("DDR mode do not support embedding dynamic_expansion for now.") + 
+        self._optimizer_instance_list = [] if self._optimizer_instance_list is None else self._optimizer_instance_list
+        if isinstance(self._optimizer_instance_list, CustomizedOptimizer):
+            self._optimizer_instance_list = [self._optimizer_instance_list]
+
+        if not isinstance(self._optimizer_instance_list, (tuple, list)):
+            raise ValueError("args optimizer list must be a list or an instance of CustomizedOptimizer.")
+        self._optimizer_instance_list = list(self._optimizer_instance_list)
+
+        self.check_optimizer_instance()
+
+    def get_default_lookup_name(self):
+        logging.debug("getting one default lookup name")
+        self._default_name_count += 1
+        return "sparse_lookup_%d" % self._default_name_count
+
+    def set_using_feature_mapping(self):
+        self._use_feature_mapping = True
+
+    def set_emb_size(self):
+        self._emb_size = self.embedding_size.as_list()[0]
+
+    def set_ext_emb_size(self):
+        self.ext_coefficient += len(self.optimizer_slot_info_list)
+        if self.use_dynamic_expansion and len(self._optimizer_instance_list) != 0:
+            self.ext_coefficient += self._slot_num[self.table_name]
+        self.ext_emb_size = self._emb_size * self.ext_coefficient
+        logging.debug(f"init table, ext_emb_size is set to be {self.ext_emb_size}")
+
+    def set_slice_vocab_size(self):
+        rank_size = get_rank_size()
+        if rank_size == 0:
+            raise ZeroDivisionError("Rank size cannot be zero.")
+        if self.use_dynamic_expansion:
+            # in dynamic expansion mode, keep the device-side variable as a placeholder of size 1
+            self.slice_device_vocabulary_size = 1
+            self.slice_host_vocabulary_size = 0
+        else:
+            self.slice_device_vocabulary_size = math.ceil(self.device_vocabulary_size / rank_size)
+            self.slice_host_vocabulary_size = math.ceil(self.host_vocabulary_size / rank_size)
+
+    @property
+    def use_feature_mapping(self):
+        return self._use_feature_mapping
+
+    @property
+    def scalar_emb_size(self):
+        return self._emb_size
+
+    @property
+    def mode(self):
+        return self._mode
+
+    def _record(self):
+        insert_table_instance(self.table_name, self.variable, self)
+        logging.debug(f"Device vocabulary_size for table {self.table_name} is {self.device_vocabulary_size}.")
+        logging.debug(f"Slice_device_vocabulary_size for table {self.table_name} is"
+                      f" {self.slice_device_vocabulary_size}.")
+        logging.debug(f"Host vocabulary size for table {self.table_name} is {self.host_vocabulary_size}.")
+        logging.debug(f"Slice host vocabulary_size for table {self.table_name} is"
+                      f" {self.slice_host_vocabulary_size}.")
+
+    @staticmethod
+    def get_anchor_attribute(anchor, attr):
+        if not isinstance(anchor, tf.Tensor):
+            raise ValueError("Anchor must be a Tensor.")
+
+        if attr not in ASCAnchorAttr:
+            raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.")
+
+        specs = SparseEmbedding.anchor_tensor_specs.get(anchor)
+        if specs is None:
+            raise ValueError(f"Given anchor '{anchor}' was not registered.")
+
+        return specs.get(attr)
+
+    def register_anchor_attribute(self, anchor_ids, feature_spec, kwargs):
+        SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self
+        SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train")
+        SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec
+
+    def check_mode(self, method_mode):
+        if self.mode != method_mode:
+            raise RuntimeError(f"Current sparse table was configured in {self.mode.value} mode, but the sparse "
+                               f"lookup method for {method_mode} is in use.")
+
+    def check_and_format_lookup_params(self, feature, send_count, is_training):
+        logging.debug(f"sparse lookup for table 
{self.table_name} with is_training {is_training}") + + def check_params(): + if not isinstance(is_training, bool): + raise ValueError("Arg is_train should be a boolean.") + + if isinstance(feature, FeatureSpec): + if not feature.initialized: + raise ValueError(f"Feature Spec has not been initialized.") + key_info = "{}_{}".format(feature.name, feature.index_key) + if is_training not in feature.pipeline_mode: + raise ValueError(f"You have not config feature for is training mode '{is_training}', please config " + f"feature with func sparse_lookup at first.") + + elif isinstance(feature, tf.Tensor): + logging.debug("Input feature is a Tensor.") + + else: + raise TypeError(f"Given feature must be a FeatureSpec or tf.Tensor.") + + if is_training not in self.lookup_info: + self.lookup_info.add(is_training) + + if get_use_static(): + if isinstance(send_count, int) and send_count > 0: + if self._send_count and self._send_count != send_count: + logging.warning(f"A new send count {send_count} will be used to replace the old one" + f"({self._send_count}).") + + self._send_count = send_count + else: + raise ValueError("Send count must be a integer which is larger than 0.") + + check_params() + max_int32 = np.iinfo(np.int32).max + if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > max_int32: + raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table " + f"'{self.table_name}', in which slice_device_vocabulary_size was " + f"{self.slice_device_vocabulary_size} and slice_host_vocabulary_size was " + f"{self.slice_host_vocabulary_size} ") + + is_check_mode = self.mode == MxRecMode.ASC and not self.skip_emb_transfer and not self.use_dynamic_expansion + if is_check_mode and self.slice_device_vocabulary_size < self.send_count * get_rank_size(): + raise ValueError(f"Given device_vocabulary_size was too small for table '{self.table_name}', in which " + f"slice_device_vocabulary_size was {self.slice_device_vocabulary_size} and " + f"send_count({self.send_count}) * rank_size({get_rank_size()}) was " + f"{self.send_count * get_rank_size()}") + + if is_check_mode and self.slice_host_vocabulary_size < self.send_count * get_rank_size(): + raise ValueError(f"Given host_vocabulary_size was too small for table '{self.table_name}', in which " + f"slice_host_vocabulary_size was {self.slice_host_vocabulary_size} and " + f"send_count({self.send_count}) * rank_size({get_rank_size()}) was " + f"{self.send_count * get_rank_size()}") + + def set_optimizer(self, key, state_dict): + if key in self._optimizer: + raise ValueError(f"Optimizer {key} has been set for hash table {self.table_name}") + + self._optimizer[key] = state_dict + + @property + def send_count(self): + return self._send_count + + @property + def optimizer(self): + return self._optimizer + + @property + def optimizer_instance_list(self): + return self._optimizer_instance_list + + def lookup_for_asc(self, ids: tf.Tensor, send_count, **kwargs): + """ + + Args: + ids: Tensor to lookup from hashtable + send_count: int, used to config all2all communication parameters + kwargs: + dim: not in use + is_train: + name: not in use + modify_graph: if True, the original graph will be modified before building a Session instance + + Returns: Tensor for lookup result + + """ + logging.debug(f"Enter ASC Branch.") + self.check_mode(MxRecMode.ASC) + is_training = kwargs.get("is_train") + if is_asc_frozen() and is_training: + raise EnvironmentError(f"Cannot build new sparse forward graph after emb cache management was 
built.") + if not kwargs.get("modify_graph"): + raise RuntimeError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") + + rank_size = get_rank_size() + access_threshold = None + eviction_threshold = None + if kwargs.get("access_and_evict_config"): + access_and_evict_config = kwargs.get("access_and_evict_config") + access_threshold = access_and_evict_config.get("access_threshold") + eviction_threshold = access_and_evict_config.get("eviction_threshold") + feature_spec = FeatureSpec(self.table_name, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold) + feature_spec.set_feat_attribute(ids, is_training) + + if is_asc_frozen() and not is_training: + clear_channel(is_train_channel=False) + + self.check_and_format_lookup_params(ids, send_count, is_training) + anchor_ids = tf.identity(ids, name="ids") + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids) + self.register_anchor_attribute(anchor_ids, feature_spec, kwargs) + + use_dynamic_expansion = get_use_dynamic_expansion() + use_static = get_use_static() + use_hot = get_use_hot() + logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, ids: {ids}, use_dynamic_expansion: " + f"{use_dynamic_expansion}, use_static: {use_static}, use_hot: {use_hot}") + + id_offsets = tf.ones(shape=[send_count * rank_size if use_static else 1 * rank_size, ], + dtype=tf.int64 if get_use_dynamic_expansion() else tf.int32, name="id_offsets") + id_offsets = tf.identity(id_offsets, name=ASCAnchorAttr.ID_OFFSETS.value) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.ID_OFFSETS] = id_offsets + local_embeddings = None + if use_dynamic_expansion: + local_embeddings = get_host_pipeline_ops().embedding_lookup_by_address(id_offsets, + embedding_dim=self._emb_size, + embedding_type=1) + if is_training: + tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) + tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + + @tf.custom_gradient + def sparse_forward(table, feat_ids): + logging.debug(f"fp rank size: {rank_size}") + restore_vector = tf.ones(shape=[np.prod(feat_ids.shape.as_list()), ], dtype=tf.int32, name="restore_vector") + restore_vector = tf.identity(restore_vector, name=ASCAnchorAttr.RESTORE_VECTOR.value) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, restore_vector) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.RESTORE_VECTOR] = restore_vector + + all2all_matrix = None + if not use_static: + all2all_matrix = tf.ones(shape=[rank_size, rank_size], dtype=tf.int64, name="all2all_matrix") + all2all_matrix = tf.identity(all2all_matrix, name=ASCAnchorAttr.ALL2ALL_MATRIX.value) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, all2all_matrix) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.ALL2ALL_MATRIX] = all2all_matrix + + hot_pos = None + if use_hot: + import mxrec_pybind + hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / self._emb_size) + hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos") + hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_HOT_POS, hot_pos) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.HOT_POS] = hot_pos + + if not use_dynamic_expansion: + id_offsets_abs = tf.abs(id_offsets) + local_emb = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets") + else: + local_emb = tf.identity(table, name="identity_local_emb") + 
all2all_args = send_count if use_static else all2all_matrix + unique_embeddings = get_own_emb(local_emb, all2all_args, self.scalar_emb_size, use_static) + + if hot_pos is not None: + unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), + unique_embeddings], axis=0) + + embeddings = tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector") + lookup_result = tf.reshape(embeddings, feat_ids.shape.as_list() + [self.scalar_emb_size]) + + def grad(lookup_diff): + embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) + logging.debug(f"bp rank size: {rank_size}") + unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static \ + else tf.shape(unique_embeddings) + unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, + restore_vector, + unique_embeddings_shape[0]) + bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) + if hot_pos is not None: + hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], + tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) + unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) + local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) + + if use_dynamic_expansion: + return local_grad, feat_ids + return ops.IndexedSlices(values=local_grad, indices=id_offsets, dense_shape=tf.shape(table)), feat_ids + + return lookup_result, grad + + if use_dynamic_expansion: + return sparse_forward(local_embeddings, ids) + return sparse_forward(self.variable, ids) + + def lookup_for_asc_with_feature_spec(self, feature_spec: FeatureSpec, send_count: int, **kwargs): + """ + Args: + feature_spec: an instance of FeatureSpec to lookup from hashtable + send_count: int, used to config all2all communication parameters + kwargs: + dim: not in use + is_train: + name: not in use + modify_graph: if True, the original graph will be modified before building a Session instance + + Returns: Tensor for lookup result + + """ + spec_name = feature_spec.name + is_training = kwargs.get("is_train") + if self.lookup_result is not None and spec_name in self.lookup_result \ + and is_training in self.lookup_result[spec_name]: + return self.lookup_result[spec_name][is_training] + + table_name = feature_spec.table_name + same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training] + if len(same_table_feature_spec) == 0: + raise RuntimeError(f"spec_name {spec_name} not in table {table_name}") + if len(same_table_feature_spec) == 1: + lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) + self.lookup_result = {spec_name: {is_training: lookup_result}} + return self.lookup_result[spec_name][is_training] + else: + same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) + same_table_spec_count = len(same_table_feature_spec) + feature_count = [x.feat_cnt * x.batch_size for x in same_table_feature_spec] + total_feature_count = sum(feature_count) + mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", + feat_count=total_feature_count, table_name=table_name) + mock_feature_spec.batch_size = 1 + mock_feature_spec.dims = [1, total_feature_count] + mock_feature_spec.initialized = True + mock_feature_spec.pipeline_mode.add(True) + mock_feature_spec.pipeline_mode.add(False) + lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, + send_count * same_table_spec_count, **kwargs) + 
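# the mock spec fetched every spec of this table in one merged request; split the
+            # flat result back into per-spec tensors, in the same sorted-by-name order
+            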
logging.debug(f"lookup table {table_name} via {feature_count}") + lookup_result = tf.reshape(lookup_result, [-1, self.scalar_emb_size]) + split_size = [x.feat_cnt * x.batch_size for x in same_table_feature_spec] + lookup_result_split = tf.split(lookup_result, split_size) + self.lookup_result = {k.name: {is_training: tf.reshape(v, k.dims + [self.scalar_emb_size])} + for k, v in zip(same_table_feature_spec, lookup_result_split)} + return self.lookup_result[spec_name][is_training] + + def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs): + """ + Args: + feature_spec: an instance of FeatureSpec to lookup from hashtable + send_count: int, used to config all2all communication parameters + kwargs: + dim: not in use + is_train: + name: not in use + modify_graph: if True, the original graph will be modified before building a Session instance + + Returns: Tensor for lookup result + + """ + logging.debug(f"Enter ASC Branch, looking up with FeatureSpec.") + self.check_mode(MxRecMode.ASC) + is_training = kwargs.get("is_train") + self.check_and_format_lookup_params(feature_spec, send_count, is_training) + rank_size = get_rank_size() + device_id = get_device_id() + use_hot = get_use_hot() + use_dynamic_expansion = get_use_dynamic_expansion() + + # check training mode order and ensure channel id + channel_id = get_training_mode_channel_id(is_training=is_training) + logging.debug(f"get preprocessed tensor for asc for table {self.table_name} with skip emb transfer " + f"{self.skip_emb_transfer} is_training: {is_training}, channel_id: {channel_id} .") + + config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, + rank_size=rank_size, channel_id=channel_id, table_name=self.table_name, + skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size, + _emb_size=self._emb_size, use_hot=use_hot, device_id=device_id, + use_dynamic_expansion=use_dynamic_expansion) + + if self.skip_emb_transfer: + restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( + self.variable, config) + else: + variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list] + restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( + variable_list, config) + control_ops = swap_in + + id_offsets = tf.identity(id_offsets, name="identity_addr") + restore_vector = tf.identity(restore_vector, name="identity_restore") + if is_training and use_dynamic_expansion: + tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) + + use_static = get_use_static() + host_pipeline_ops = get_host_pipeline_ops() + + @tf.custom_gradient + def sparse_forward(table): + logging.debug(f"fp rank size: {rank_size}") + if not use_dynamic_expansion: + id_offsets_abs = tf.abs(id_offsets) + local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets") + else: + local_embeddings = tf.identity(table, name="identity_local_emb") + + all2all_args = send_count if use_static else all2all_matrix + unique_embeddings = get_own_emb(local_embeddings, all2all_args, self.scalar_emb_size, use_static) + + if hot_pos is not None: + unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), + unique_embeddings], axis=0) + if use_static: + unique_embeddings_shape = unique_embeddings.shape.as_list() + else: + unique_embeddings_shape = tf.shape(unique_embeddings) + embeddings = 
tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector") + + if use_static: + lookup_result = tf.reshape(embeddings, feature_spec.dims + [self.scalar_emb_size]) + else: + lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) + + def grad(lookup_diff): + embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) + logging.debug(f"bp rank size: {rank_size}") + unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, + restore_vector, + unique_embeddings_shape[0]) + bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) + if hot_pos is not None: + hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], + tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) + unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) + local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) + if use_dynamic_expansion: + update_grad = local_grad + else: + update_grad = ops.IndexedSlices(values=local_grad, indices=id_offsets, + dense_shape=tf.shape(table)) + return update_grad + + return lookup_result, grad + + with tf.control_dependencies(control_ops): + if not use_dynamic_expansion: + return sparse_forward(self.variable) + + local_embeddings = \ + host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self._emb_size, + embedding_type=1) + if is_training: + tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + + return sparse_forward(local_embeddings) + + def _initialize_variables(self): + initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) + self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) + # make sure sparse table variable will not be saved and restored within tf checkpoint. 
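+        # the table is persisted through mxRec's own sparse saver path (mx_rec.saver)
+        # rather than tf.train.Saver, so excluding it here avoids writing a stale device copy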
+ remove_saving_var(self.variable) + self._record() + + if self.use_dynamic_expansion: + for sparse_optimizer_instance in self._optimizer_instance_list: + self._slot_num[self.table_name] = sparse_optimizer_instance.slot_num + logging.info(f"init emb, table name: {self.table_name}, slot_num: {sparse_optimizer_instance.slot_num}") + + if self.mode == MxRecMode.ASC and not self.skip_emb_transfer: + # build optimizer states + for sparse_optimizer_instance in self._optimizer_instance_list: + slot_info_list = sparse_optimizer_instance.initialize_slots(self.variable, self) + self.optimizer_slot_info_list.extend(slot_info_list) + + for slot_info in self.optimizer_slot_info_list: + self.set_optimizer_slot(slot_info) + + @staticmethod + def set_optimizer_slot(slot_info): + slot = slot_info.get("slot") + slot_name = slot_info.get("slot_name") + optimizer = slot_info.get("optimizer") + named_slot_key = slot_info.get("named_slot_key") + + optimizer.insert_slot(slot, named_slot_key, slot_name) + + +def get_own_ids(unique_ids, origin_id_lens, send_cnt, self): + from mx_rec.util.tf_version_adapter import hccl_ops + rank_size = get_rank_size() + if rank_size > 1: + ids_send_cnt = tf.constant([send_cnt] * rank_size, dtype=tf.int64) + ids_send_offset = tf.constant([send_cnt * i for i in range(rank_size)], dtype=tf.int64) + own_ids = hccl_ops.all_to_all_v(send_data=unique_ids, + send_counts=ids_send_cnt, + send_displacements=ids_send_offset, + recv_counts=ids_send_cnt, + recv_displacements=ids_send_offset) + + lens_sc = tf.constant([1] * rank_size, dtype=tf.int64) + lens_sd = tf.constant([i for i in range(rank_size)], dtype=tf.int64) + local_id_lens = hccl_ops.all_to_all_v(send_data=origin_id_lens, + send_counts=lens_sc, + send_displacements=lens_sd, + recv_counts=lens_sc, + recv_displacements=lens_sd) + + else: + own_ids = unique_ids + local_id_lens = origin_id_lens + + def feature_mapping(): + self.set_using_feature_mapping() + id_offsets = SparseEmbedding.customized_ops.feature_mapping(own_ids, table_name=self.table_name) + return id_offsets + + id_offsets = feature_mapping() + id_offsets.set_shape([send_cnt * rank_size]) + + return id_offsets, local_id_lens + + +def get_own_emb(emb, all2all_args, emb_size, use_static): + ''' + obtain embedding of source data + ''' + from mx_rec.util.tf_version_adapter import hccl_ops + rank_size = get_rank_size() + rank_id = get_rank_id() + + src_emb = emb + + reshape_info = [all2all_args * rank_size, emb_size] if use_static else [-1, emb_size] + + if rank_size == 1 and use_static: + return tf.reshape(src_emb, reshape_info) + + if use_static: + emb_send_cnt = tf.constant([all2all_args * emb_size] * rank_size, dtype=tf.int64) + emb_send_offset = tf.constant([all2all_args * emb_size * i for i in range(rank_size)], dtype=tf.int64) + src_emb = hccl_ops.all_to_all_v(send_data=emb, + send_counts=emb_send_cnt, + send_displacements=emb_send_offset, + recv_counts=emb_send_cnt, + recv_displacements=emb_send_offset) + else: + src_emb = hccl_ops.all_to_all_v_c(send_data=emb, + send_count_matrix=all2all_args, + rank=rank_id) + + return tf.reshape(src_emb, reshape_info) + + +class _EvictHook(tf.compat.v1.train.SessionRunHook): + """Sets evict based on global step or time.""" + + def __init__(self, + evict_enable=False, + evict_time_interval=DEFAULT_EVICT_TIME_INTERVAL, + evict_step_interval=None): + self._evict_enable = evict_enable + self._evict_time_interval = evict_time_interval + self._evict_step_interval = evict_step_interval + self._hash_table_instance = dict() + self._start_time 
= time.time() + self._global_step = 0 + self._evict_op = dict() + self._global_step_tensor = None + + self.check_evict_init_params() + logging.info(f"_EvictHook - > evict_time_interval: {self._evict_time_interval}, " + f"evict_step_interval: {self._evict_step_interval}") + + def begin(self): + self._global_step_tensor = tf.compat.v1.train.get_or_create_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created to use _EvictHook.") + self.check_name_and_get_hashtable() + for name, instance in self._hash_table_instance.items(): + scope_name = "{0}//{1}".format(instance.table_name, "evict") + with tf.compat.v1.variable_scope(scope_name): + logging.debug(f'Channel {instance.table_name}_evict_{TRAIN_CHANNEL_ID} was built for op ' + f'getnext') + + use_static = get_use_static() + if use_static: + evict_pos = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[instance.slice_device_vocabulary_size], + channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] + initialized_tensor = instance.emb_initializer( + instance.slice_device_vocabulary_size + instance.embedding_size) + else: + evict_pos = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[None], + channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] + + initialized_tensor = instance.emb_initializer( + evict_pos.shape.as_list()[0] + instance.embedding_size) + + logging.debug(f'evict_pos output shape {evict_pos}, and slice_device_vocabulary_size ' + f'{instance.slice_device_vocabulary_size}, ' + f'initialized_tensor shape: {initialized_tensor}') + + nd_evict_pos = tf.expand_dims(evict_pos, 1) + self._evict_op[name] = tf.compat.v1.scatter_nd_update(instance.variable, nd_evict_pos, + initialized_tensor) + + def after_create_session(self, session, coord): + self._global_step = session.run(self._global_step_tensor) + logging.debug(f"_EvictHook - > after_create_session, step: {self._global_step}") + + def after_run(self, run_context, run_values): + if not self._evict_enable: + return + + self._global_step = run_context.session.run(self._global_step_tensor) + cur_time = time.time() + if cur_time - self._start_time > self._evict_time_interval or \ + (self._evict_step_interval is not None and self._global_step % self._evict_step_interval == 0): + logging.info(f"_EvictHook - > evict switch on!!! 
after_run step: {self._global_step}")
+            trigger_evict()
+            self._start_time = cur_time
+            for name in self._hash_table_instance.keys():
+                run_context.session.run(self._evict_op[name])
+
+    def check_name_and_get_hashtable(self):
+        for _, feature_spec in export_feature_spec().items():
+            if feature_spec.eviction_threshold:
+                logging.debug(f"_EvictHook - > check and get instance: table_names {feature_spec.table_name}")
+                self._hash_table_instance[feature_spec.table_name] = get_table_instance_by_name(feature_spec.table_name)
+
+    def check_evict_init_params(self):
+        def check_type(arg, n_type, param_name):
+            if not isinstance(arg, n_type):
+                raise TypeError(f"{param_name} should be of type '{n_type}', but got value {arg} of type "
+                                f"'{type(arg)}'.")
+            if isinstance(arg, int) and not isinstance(arg, bool) and arg < 1:
+                raise ValueError(f"{param_name} should be greater than 0, but got {arg}.")
+
+        check_type(self._evict_enable, bool, "evict_enable")
+        if self._evict_time_interval is not None:
+            check_type(self._evict_time_interval, int, "evict_time_interval")
+        if self._evict_step_interval is not None:
+            check_type(self._evict_step_interval, int, "evict_step_interval")
diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py
new file mode 100644
index 00000000..e9754a0d
--- /dev/null
+++ b/mx_rec/graph/__init__.py
@@ -0,0 +1 @@
+# coding: UTF-8
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
new file mode 100644
index 00000000..5132e26c
--- /dev/null
+++ b/mx_rec/graph/modifier.py
@@ -0,0 +1,425 @@
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved.
+# Description: build script.
+# Author: MindX SDK
+# pylint: disable=W0212
+import logging
+from collections import defaultdict
+
+import tensorflow as tf
+from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
+
+from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+from mx_rec.core.asc.helper import get_asc_insert_func, FeatureSpec
+from mx_rec.core.asc.manager import start_asc_pipeline
+from mx_rec.core.embedding import SparseEmbedding
+from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_CUTTING_POINT, \
+    ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr, ASCEND_TIMESTAMP
+from mx_rec.util.initialize import get_rank_size, destroy_asc_manager, get_training_mode_channel_id, \
+    get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id
+from mx_rec.util.perf import performance
+from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \
+    record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list
+
+
+def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tensor_names=None,
+                               pipeline_input_indexes=None):
+    input_names = check_input_list(input_names, str)
+    output_names = check_input_list(output_names, str)
+    batch_tensor_names = check_input_list(batch_tensor_names, str)
+    pipeline_input_indexes = check_input_list(pipeline_input_indexes, int)
+    both_is_none = batch_tensor_names is None and pipeline_input_indexes is None
+    both_not_none = batch_tensor_names is not None and pipeline_input_indexes is not None
+    if both_is_none or both_not_none:
+        raise ValueError("Exactly one of the parameters 'batch_tensor_names' and "
+                         "'pipeline_input_indexes' must be given.")
+
+    def map_func(*args):
+        def print_tensors(batch_id, tracker=None):
+            if tracker is None:
+                tracker = []
+            if isinstance(batch_id, dict):
+                for key, 
item in batch_id.items():
+                    print_tensors(item, tracker + [key])
+            if isinstance(batch_id, tf.Tensor):
+                logging.debug(f"######## tracker: {tracker}, tensor: {batch_id} ########")
+
+        for batch in args:
+            print_tensors(batch)
+
+        batch = args[0]
+
+        input_tensors = []
+        if batch_tensor_names is not None:
+            for tensor_name in batch_tensor_names:
+                tensor = batch.get(tensor_name)
+                if tensor is None:
+                    raise ValueError(f"Given input_tensor_name '{tensor_name}' is invalid.")
+
+                input_tensors.append(tensor)
+
+        else:
+            graph = tf.get_default_graph()
+            for index in pipeline_input_indexes:
+                tensor = graph.get_tensor_by_name("args_%d:0" % index)
+                input_tensors.append(tensor)
+
+        output_list = tf.import_graph_def(graph_def, input_map=dict(zip(input_names, input_tensors)),
+                                          return_elements=output_names)
+
+        output_batch = list(args)
+        output_batch.append(tuple(output_list))
+        return tuple(output_batch)
+
+    return map_func
+
+
+def get_input_index_list(cutting_point_list, replacement_specs, mapping_name_list, base_count, timestamp_index=None):
+    input_index_list = []
+    for cutting_point in cutting_point_list:
+        if cutting_point in replacement_specs:
+            index = int(cutting_point.name.split(":")[1])
+
+        elif cutting_point.name in mapping_name_list:
+            index = base_count + mapping_name_list.index(cutting_point.name)
+
+        else:
+            raise ValueError(f"Cannot find a matching output for cutting point tensor named '{cutting_point.name}'.")
+        input_index_list.append(index)
+    if timestamp_index is not None:
+        input_index_list = [timestamp_index] + input_index_list
+
+    return input_index_list
+
+
+def find_make_iterator_op(batch_tensor):
+    graph = tf.get_default_graph()
+    operations = graph.get_operations()
+    for each_op in operations:
+        for input_tensor in batch_tensor.op.inputs:
+            if input_tensor.op.outputs and input_tensor.op.outputs[0] in list(
+                    each_op.inputs) and each_op.type == "MakeIterator":
+                logging.debug(f"Op MakeIterator '{each_op.name}' was found.")
+                return each_op
+
+    raise ValueError("Op MakeIterator was not found.")
+
+
+@performance("find_target_dataset_op")
+def find_target_dataset_op(base_ops, op_type):
+    base_ops = check_input_list(base_ops, tf.Operation)
+    parent_ops = base_ops
+
+    while True:
+        for parent_op in parent_ops:
+            if parent_op.type == op_type:
+                return parent_op
+
+        base_ops = parent_ops
+        parent_ops = []
+        for base_op in base_ops:
+            parent_ops.extend(find_parent_op(base_op))
+
+        if not parent_ops:
+            raise ValueError(f"Op {op_type} was not found.")
+
+
+def get_op_before_optimize_dataset(get_next_op):
+    if get_next_op.type != "IteratorGetNext":
+        raise TypeError(f"Op '{get_next_op}' must be an instance of IteratorGetNext.")
+
+    # looking for the MakeIterator operator which corresponds to the given batch_tensor
+    base_op = find_make_iterator_op(get_next_op.outputs[0])
+    # looking for the op immediately before the OptimizeDataset operator
+    target_op = find_target_dataset_op(base_op, "OptimizeDataset")
+    if find_parent_op(target_op)[0].type == "PrefetchDataset":
+        target_op = find_parent_op(target_op)[0]
+
+    return target_op
+
+
+def get_passing_tensor_list(src_tensors, target_op):
+    def get_passing_tensors(src_tensor):
+        passing_tensors = []
+        tensor_list = [src_tensor]
+        while tensor_list:
+            last_tensor = tensor_list.pop()
+            if last_tensor.op is target_op:
+                passing_tensors.append(last_tensor)
+            else:
+                tensor_list.extend(list(last_tensor.op.inputs))
+
+        return passing_tensors
+
+    src_tensors = check_input_list(src_tensors, tf.Tensor)
+    passing_tensor_list = []
+    
sub_src_tensors = [] + for tensor in src_tensors: + passing_tensors = get_passing_tensors(tensor) + for passing_tensor in passing_tensors: + if passing_tensor not in passing_tensor_list: + passing_tensor_list.append(passing_tensor) + if len(passing_tensors) != 0: + logging.info(f"passing_tensors: {passing_tensors}") + sub_src_tensors.append(tensor) + else: + logging.info(f"Cannot find passing tensor for given tensor '{tensor}'.") + + output_index_list = [int(tensor.name.split(":")[1]) for tensor in passing_tensor_list] + + return passing_tensor_list, output_index_list, sub_src_tensors + + +def find_target_instance_dataset(variant_tensor): + dataset_instance_list = tf.compat.v1.get_collection("dataset_group") + for ins in dataset_instance_list: + if ins._variant_tensor == variant_tensor: + if not isinstance(ins, DatasetV1Adapter): + ins = ins._input_dataset + logging.debug(f"Find target instance '{ins}', whose variant_tensor is '{variant_tensor}'.") + if not isinstance(ins.element_spec, dict) and not ( + isinstance(ins.element_spec, (list, tuple)) and len(ins.element_spec) == 2 and isinstance( + ins.element_spec[0], dict)): + raise NotImplementedError("The found dataset does not return a valid layout.") + + return ins + + raise LookupError(f"Can not find target instance, whose variant_tensor is '{variant_tensor}' respectively.") + + +def get_sub_graph(input_tensors, output_tensors): + input_tensors = check_input_list(input_tensors, tf.Tensor) + output_tensors = check_input_list(output_tensors, tf.Tensor) + input_op_name_list = [tensor.op.name for tensor in input_tensors] + output_op_name_list = [tensor.op.name for tensor in output_tensors] + + graph_def = tf.compat.v1.get_default_graph().as_graph_def() + cut_graph_input = tf.compat.v1.graph_util.extract_sub_graph(graph_def, input_op_name_list) + cut_graph_output = tf.compat.v1.graph_util.extract_sub_graph(graph_def, output_op_name_list) + + node_list = [] + node_list_input = cut_graph_input.node + node_list_output = cut_graph_output.node + for node in node_list_output: + if node not in node_list_input: + node_list.append(node) + + sub_graph_def = tf.compat.v1.GraphDef() + sub_graph_def.node.extend(node_list) + + input_name_list = [tensor.name for tensor in input_tensors] + output_name_list = [tensor.name for tensor in output_tensors] + + return sub_graph_def, input_name_list, output_name_list + + +def update_input_tensor_with_new_batch(replacement_specs, new_get_next_op_name): + graph = tf.compat.v1.get_default_graph() + for old_tensor, item in replacement_specs.items(): + for idx, operator in item: + old_tensor_name = old_tensor.name + output_index = old_tensor_name.split(":")[-1] + new_tensor_name = "%s:%s" % (new_get_next_op_name, output_index) + new_tensor = graph.get_tensor_by_name(new_tensor_name) + operator._update_input(idx, new_tensor) + + +def make_src_to_tgt_mapping(src_element_spec, tgt_element_spec): + # adding '_0' to the prefix + if not isinstance(src_element_spec, (list, tuple)): + src_element_spec = [src_element_spec] + src_sorted_keys = make_sorted_key_to_tensor_list(src_element_spec, []) + tgt_sorted_keys = make_sorted_key_to_tensor_list(tgt_element_spec, []) + index_to_src_key_mapping = dict([(idx, key) for idx, key in enumerate(src_sorted_keys)]) + tgt_key_to_index_mapping = dict([(key, idx) for idx, key in enumerate(tgt_sorted_keys)]) + + original_tensor_count = len(src_sorted_keys) + + def mapping_func(src_idx): + key = index_to_src_key_mapping.get(src_idx) + if key is None: + raise ValueError("Given src_idx is out of 
range.") + + tgt_idx = tgt_key_to_index_mapping.get(key) + return tgt_idx + + return mapping_func, original_tensor_count + + +@performance("graph_modifier") +def modify_graph_and_start_emb_cache(dump_graph=False): + modify_graph_for_asc(dump_graph=dump_graph) + start_asc_pipeline() + + +def generate_get_next_op_specs(cutting_point_list, dump_graph): + get_next_op_map = defaultdict(dict) + for input_tensor in cutting_point_list: + get_next_op = find_target_dataset_op(input_tensor.op, "IteratorGetNext") + if get_next_op not in get_next_op_map: + logging.debug(f"find a new get_next_op named '{get_next_op.name}'") + replacement_specs = record_ops_to_replace(get_next_op) + get_next_op_map[get_next_op]["replacement_specs"] = replacement_specs + passing_tensor_list, batch_tensor_index_list, sub_cutting_point_list = \ + get_passing_tensor_list(cutting_point_list, get_next_op) + get_next_op_map[get_next_op]["passing_tensor_list"] = passing_tensor_list + get_next_op_map[get_next_op]["batch_tensor_index_list"] = batch_tensor_index_list + get_next_op_map[get_next_op]["sub_cutting_point_list"] = sub_cutting_point_list + + sub_graph_def, input_name_list, output_name_list = get_sub_graph(passing_tensor_list, + sub_cutting_point_list) + get_next_op_map[get_next_op]["sub_graph_def"] = sub_graph_def + get_next_op_map[get_next_op]["input_name_list"] = input_name_list + get_next_op_map[get_next_op]["output_name_list"] = output_name_list + get_next_op_map[get_next_op]["is_training"] = \ + SparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING) + + export_pb_graph("cut_graph_%s.pb" % get_next_op.name, dump_graph, graph_def=sub_graph_def) + + return get_next_op_map + + +def get_src_and_generate_tgt_dataset(get_next_op, records): + target_op = get_op_before_optimize_dataset(get_next_op) + src_dataset = find_target_instance_dataset(target_op.outputs[0]) + tgt_dataset = src_dataset.map(get_preprocessing_map_func(records.get("sub_graph_def"), + records.get("input_name_list"), + records.get("output_name_list"), + pipeline_input_indexes=records.get( + "batch_tensor_index_list"))) + + return src_dataset, tgt_dataset + + +def modify_graph_for_asc(dump_graph=False, prefetch=10): + cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) + check_cutting_points(cutting_point_list) + if not cutting_point_list: + logging.warning("Nothing to revise.") + return + + export_pb_graph("old_graph.pb", dump_graph) + get_next_op_map = generate_get_next_op_specs(cutting_point_list, dump_graph) + + for get_next_op, records in get_next_op_map.items(): + is_training = records.get("is_training") + timestamp_index = get_timestamp_index(get_next_op, is_training) + src_dataset, tgt_dataset = get_src_and_generate_tgt_dataset(get_next_op, records) + mapping_func, original_tensor_count = make_src_to_tgt_mapping(src_dataset.element_spec, + tgt_dataset.element_spec) + sub_cutting_point_list = records.get("sub_cutting_point_list") + input_index_list = get_input_index_list(sub_cutting_point_list, + records.get("replacement_specs"), + records.get("output_name_list"), + original_tensor_count, timestamp_index=timestamp_index) + feature_numbers = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for + cutting_point in sub_cutting_point_list] + table_names = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for + cutting_point in sub_cutting_point_list] + tgt_dataset = tgt_dataset.map( + 
get_asc_insert_func(feature_numbers=feature_numbers, table_names=table_names, + args_index_list=input_index_list, is_training=is_training, dump_graph=dump_graph)) + + tgt_dataset = tgt_dataset.prefetch(prefetch) + new_iterator = tgt_dataset.make_initializable_iterator() + new_batch = new_iterator.get_next() + tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) + set_initializer(is_training, new_iterator.initializer) + + try: + one_tensor = [v for _, v in new_batch.items()][0] + except IndexError as err: + raise IndexError(f"Cannot find a tensor from given batch.") from err + new_get_next_op_name = find_target_dataset_op(one_tensor.op, "IteratorGetNext").name + update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) + + for _, cutting_point in enumerate(sub_cutting_point_list): + feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) + table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) + channel_id = get_training_mode_channel_id(is_training) + config = dict( + batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, + send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), + table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, + ext_emb_size=table_instance.ext_emb_size, _emb_size=table_instance._emb_size, use_hot=get_use_hot(), + device_id=get_device_id()) + build_asc_graph(table_instance, cutting_point, config) + + logging.info("Graph has been revised.") + export_pb_graph("new_graph.pb", dump_graph) + + +def get_timestamp_index(get_next_op, is_training): + timestamp_tensor_list = tf.compat.v1.get_collection(ASCEND_TIMESTAMP) + timestamp_index = None + for timestamp in timestamp_tensor_list: + if timestamp in get_next_op.outputs: + timestamp_index = int(timestamp.name.split(":")[1]) + timestamp_feature_spec = get_feature_spec("timestamp") + if timestamp_feature_spec is None: + timestamp_feature_spec = FeatureSpec("timestamp", index_key=timestamp_index, is_timestamp=True) + timestamp_feature_spec.include_timestamp(is_training) + insert_feature_spec(timestamp_feature_spec, is_training) + break + + if timestamp_feature_spec.index_key != timestamp_index: + raise ValueError(f"Given timestamp_index, which is {timestamp_index}, does not match index " + f"key. 
Please double check.")
+            timestamp_feature_spec.include_timestamp(is_training)
+            break
+    return timestamp_index
+
+
+def build_asc_graph(table_instance, cutting_point, config):
+    # the returned results swap_pos and swap_len are not used yet; they will be applied in DDR mode
+    logging.debug(f"try to replace anchors for table {config.get('table_name')} on channel {config.get('channel_id')}")
+    skip_emb_transfer = config.get("skip_emb_transfer")
+    logging.info(f"modifier build_asc_graph skip_emb_transfer: {skip_emb_transfer}")
+    if skip_emb_transfer:
+        restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc(
+            table_instance.variable, config)
+    else:
+        variable_list = [table_instance.variable] \
+                        + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list]
+        restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc(
+            variable_list, config)
+
+    with tf.control_dependencies(swap_in):
+        id_offsets = tf.identity(id_offsets)
+
+    logging.info(f"build_asc_graph -> id_offsets: {id_offsets}")
+    replace_anchor_vec(cutting_point, ASCAnchorAttr.ID_OFFSETS, id_offsets)
+    logging.info(f"build_asc_graph -> restore_vector: {restore_vector}")
+    replace_anchor_vec(cutting_point, ASCAnchorAttr.RESTORE_VECTOR, restore_vector)
+
+    logging.info(f"build_asc_graph -> all2all_matrix: {all2all_matrix}")
+    if not get_use_static():
+        replace_anchor_vec(cutting_point, ASCAnchorAttr.ALL2ALL_MATRIX, all2all_matrix)
+
+    logging.info(f"build_asc_graph -> hot_pos: {hot_pos}")
+    if get_use_hot():
+        replace_anchor_vec(cutting_point, ASCAnchorAttr.HOT_POS, hot_pos)
+
+    logging.debug(f"anchors have been replaced for table {config.get('table_name')} "
+                  f"on channel {config.get('channel_id')}")
+
+
+def replace_anchor_vec(cutting_point, attribute, anchor):
+    anchor_vec = SparseEmbedding.get_anchor_attribute(cutting_point, attribute)
+    replacement_specs_for_anchor_vec = record_ops_to_replace(anchor_vec.op)
+    replace_anchor(replacement_specs_for_anchor_vec, [anchor])
+
+
+class GraphModifierHook(tf.estimator.SessionRunHook):
+    def __init__(self, dump_graph=True, modify_graph=False):
+        self.dump_graph = dump_graph
+        self.modify_graph = modify_graph
+
+    def begin(self):
+        if self.modify_graph:
+            modify_graph_and_start_emb_cache(dump_graph=self.dump_graph)
+        else:
+            start_asc_pipeline()
+
+    def after_create_session(self, session, coord):
+        if self.modify_graph:
+            session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER))
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
new file mode 100644
index 00000000..4f7ab920
--- /dev/null
+++ b/mx_rec/graph/patch.py
@@ -0,0 +1,29 @@
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved.
+# Description: build script.
+# Author: MindX SDK
+
+import weakref
+
+import tensorflow as tf
+from tensorflow.python.data.ops.dataset_ops import DatasetV2
+from tensorflow.python.data.ops.dataset_ops import _VariantTracker
+from tensorflow.python.framework import ops
+
+
+def init_dataset(self, input_data):
+    """
+    input_data: A DT_VARIANT tensor that represents the dataset.
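+
+    A minimal usage sketch (an illustration, not taken from this module's
+    tests): once patch_for_dataset() below has replaced DatasetV2.__init__,
+    every dataset created afterwards registers itself in the "dataset_group"
+    collection, which modifier.py later scans in find_target_instance_dataset:
+
+        patch_for_dataset()  # must run before any dataset is constructed
+        ds = tf.data.Dataset.from_tensor_slices({"uid": [1, 2, 3]})
+        assert ds in tf.compat.v1.get_collection("dataset_group")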
+ """ + # pylint: disable=W + tf.compat.v1.add_to_collection("dataset_group", self) + self._variant_tensor_attr = input_data + # get obj + dataset_obj = weakref.proxy(self) + self._variant_tracker = self._track_trackable( + _VariantTracker(self._variant_tensor, lambda: dataset_obj._trace_variant_creation()()), name="_variant_tracker") + self._graph_attr = ops.get_default_graph() + + +def patch_for_dataset(): + DatasetV2.__init__ = init_dataset diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py new file mode 100644 index 00000000..2e9f0134 --- /dev/null +++ b/mx_rec/graph/utils.py @@ -0,0 +1,91 @@ +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. +# Description: build script. +# Author: MindX SDK +from collections import defaultdict + +import tensorflow as tf + + +def check_input_list(objs, obj_type): + if isinstance(objs, obj_type): + objs = [objs] + + if isinstance(objs, list): + for tensor in objs: + if not isinstance(tensor, obj_type): + raise ValueError(f"Given input parameter must be a {obj_type} or a list of {obj_type}") + + return objs + + +def find_parent_op(operator): + parent_ops = [] + for input_tensor in operator.inputs: + parent_op = input_tensor.op + if isinstance(parent_op, tf.Operation): + parent_ops.append(parent_op) + return parent_ops + + +def check_cutting_points(cutting_point_list): + for tensor in cutting_point_list: + if not isinstance(tensor, tf.Tensor): + raise TypeError(f"Collection ASCEND_CUTTING_POINT can only contain Tensors, but '{tensor}' was found.") + + if tensor.op.type != "Identity": + raise ValueError(f"Cutting point can only be the output of an Operator 'Identity'.") + + +def record_ops_to_replace(src_op): + replacement_specs = defaultdict(list) + output_list = src_op.outputs + op_list = tf.get_default_graph().get_operations() + for tensor in output_list: + for operator in op_list: + if tensor in operator.inputs: + input_index = list(operator.inputs).index(tensor) + replacement_specs[tensor].append((input_index, operator)) + + return replacement_specs + + +def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list): + # pylint: disable=W0212 + if len(replacement_specs) != len(new_tensor_list): + raise ValueError("Given replacement_specs and new_tensor_list must have the same length.") + + for tensor_idx, (_, items) in enumerate(replacement_specs.items()): + for input_idx, operator in items: + operator._update_input(input_idx, new_tensor_list[tensor_idx]) + + +def export_pb_graph(file_name, dump_graph, graph_def=None, export_path="./export_graph"): + if dump_graph: + graph_def = graph_def if graph_def else tf.get_default_graph().as_graph_def() + tf.train.write_graph(graph_def, export_path, file_name, False) + + +def make_sorted_key_to_tensor_list(element_spec, sorted_keys, prefix=""): + if isinstance(element_spec, tf.TensorSpec): + sorted_keys.append(prefix) + return sorted_keys + + elif isinstance(element_spec, dict): + for key, item in element_spec.items(): + if not isinstance(key, str): + raise TypeError(f"The key of element_spec must be a string.") + + prefix = "{0}_{1}".format(prefix, key) + sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix) + sorted_keys = sorted(sorted_keys) + return sorted_keys + + elif isinstance(element_spec, (list, tuple)): + for idx, item in enumerate(element_spec): + prefix = "{0}_{1}".format(prefix, str(idx)) + sorted_keys = make_sorted_key_to_tensor_list(item, sorted_keys, prefix=prefix) + sorted_keys = sorted(sorted_keys) + 
return sorted_keys + + raise TypeError(f"Given element_spec, whose type is {type(element_spec)}, is invalid.") diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py new file mode 100644 index 00000000..9b1b8deb --- /dev/null +++ b/mx_rec/optimizers/adagrad.py @@ -0,0 +1,119 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2023 Huawei Technologies Co., Ltd + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import adagrad, training_ops +from tensorflow.python.training import slot_creator + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_table_instance +from mx_rec.util.variable import remove_saving_var, check_param_type + + +def create_hash_optimizer(learning_rate=0.001, + initial_accumulator_value=0.9, + use_locking=False, + name="Adagrad"): + """ + Create an instance of adagrad hash optimizer + :param learning_rate: A `Tensor` or a floating point value. The learning rate. + :param initial_accumulator_value: A floating point value. Starting value for the accumulators, must be positive. + :param use_locking: If `True` use locks for update operations. + :param name: Optional name prefix for the operations created when applying gradients. Defaults to "Adagrad". + :return: adagrad hash optimizer instance + """ + return CustomizedAdagrad(learning_rate=learning_rate, + initial_accumulator_value=initial_accumulator_value, + use_locking=use_locking, + name=name) + + +class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, + learning_rate, + initial_accumulator_value, + use_locking=False, + name="Adagrad"): + self.optimizer_type = "Adagrad" + super(CustomizedAdagrad, self).__get_name__(name=name) + super(CustomizedAdagrad, self).__init__(learning_rate=learning_rate, + initial_accumulator_value=initial_accumulator_value, + use_locking=use_locking, + name=self.unique_name) + + self._check_input_param() + + def _check_input_param(self): + check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float)) + check_param_type("initial_accumulator_value", self._initial_accumulator_value, (tf.Tensor, float)) + check_param_type("use_locking", self._use_locking, bool) + + def _create_slots(self, var_list): + logging.debug(" Start _create_slots") + for var in var_list: + dtype = var.dtype.base_dtype + if var.get_shape().is_fully_defined(): + init = init_ops.constant_initializer(self._initial_accumulator_value, + dtype=dtype) + else: + init = self._init_constant_op(var, dtype) + + acc_state_name = self._name + "/" + "accumulator" + self._get_or_make_slot_with_initializer(var, init, var.get_shape(), dtype, + "acc", acc_state_name) + + def _apply_sparse(self, grad, var): + acc = self.get_slot(var, "acc") + return training_ops.sparse_apply_adagrad( + var, acc, math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + grad.values, + grad.indices, + use_locking=self._use_locking) + + def _resource_apply_sparse(self, grad, var, indices): + acc = self.get_slot(var, "acc") + return training_ops.resource_sparse_apply_adagrad( + var.handle, acc.handle, 
math_ops.cast(self._learning_rate_tensor, grad.dtype),
+            grad, indices, use_locking=self._use_locking)
+
+    def initialize_slots(self, var, table_instance):
+        # Create the accumulator slot.
+        def create_one_single_slot(var, op_name):
+            new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+            # make sure sparse optimizer states will not be saved and restored within tf checkpoint.
+            return new_slot_variable
+
+        accumulator = create_one_single_slot(var, self._name + "/" + "accumulator")
+        remove_saving_var(accumulator)
+        named_slot_key = (var.op.graph, var.op.name)
+        table_instance = get_table_instance(var)
+        if self._name in table_instance.optimizer:
+            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
+
+        table_instance.set_optimizer(self._name, {"accumulator": accumulator})
+        return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        named_slots = self._slot_dict(slot_name)
+        if named_slots_key in named_slots:
+            raise EnvironmentError(f"named_slots_key must be globally unique, but it is already in use, "
+                                   f"please double check.")
+
+        named_slots[named_slots_key] = slot
+
+    def get_slot_init_values(self):
+        # return the state value list of adagrad that needs to be initialized in ASC DDR.
+        initial_accumulator_value = 0.0
+        return [initial_accumulator_value]
diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py
new file mode 100644
index 00000000..a54e829c
--- /dev/null
+++ b/mx_rec/optimizers/base.py
@@ -0,0 +1,38 @@
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved.
+# Description: build script.
+# Author: MindX SDK
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+
+class CustomizedOptimizer:
+
+    name_counter = defaultdict(int)
+
+    def __init__(self):
+        self.unique_name = ""
+        self.base_name = ""
+
+    def __get_name__(self, name="CustomizedOptimizer"):
+        if name in CustomizedOptimizer.name_counter:
+            CustomizedOptimizer.name_counter[name] += 1
+            count = CustomizedOptimizer.name_counter.get(name)
+
+        else:
+            count = CustomizedOptimizer.name_counter[name]
+        self.unique_name = name + "_" + str(count)
+        self.base_name = name
+
+    def initialize_slots(self, var, table_instance):
+        raise NotImplementedError(f"Please define a concrete implementation in {self.__class__.__name__}")
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        raise NotImplementedError(f"Please define a concrete implementation in {self.__class__.__name__}")
+
+    def get_slot_init_values(self):
+        raise NotImplementedError(f"Please define a concrete implementation in {self.__class__.__name__}")
diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
new file mode 100644
index 00000000..9ccb0099
--- /dev/null
+++ b/mx_rec/optimizers/ftrl.py
@@ -0,0 +1,244 @@
+# coding=utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
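+#
+# Reference note (a sketch of the standard FTRL-proximal update that the
+# sparse apply paths below implement; p = learning_rate_power, typically -0.5):
+#
+#     n_new = n + g^2
+#     sigma = (n_new^(-p) - n^(-p)) / lr
+#     z_new = z + g - sigma * w        # on the _v2 path, g additionally
+#                                      # carries the 2 * l2_shrinkage * w term
+#     w_new = 0                                          if |z_new| <= l1
+#     w_new = (l1 * sign(z_new) - z_new) /
+#             (1 / (n_new^p * lr) + 2 * l2)               otherwise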
+# Description: +# Author: MindX SDK + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gen_state_ops +from tensorflow.python.training import ftrl +from tensorflow.python.training import slot_creator + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_table_instance +from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range + + +def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): + + return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) + + +class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, use_locking=False, name="Ftrl", **kwargs): + self.optimizer_type = "ftrl" + super(CustomizedFtrl, self).__get_name__(name=name) + super(CustomizedFtrl, self).__init__( + learning_rate=learning_rate, + learning_rate_power=kwargs.get("learning_rate_power", -0.5), + initial_accumulator_value=kwargs.get("initial_accumulator_value", 0.1), + l1_regularization_strength=kwargs.get("l1_regularization_strength", 0.0), + l2_regularization_strength=kwargs.get("l2_regularization_strength", 0.0), + use_locking=use_locking, + name=self.unique_name, + accum_name=kwargs.get("accum_name", None), + linear_name=kwargs.get("linear_name", None), + l2_shrinkage_regularization_strength=kwargs.get("l2_shrinkage_regularization_strength", 0.0) + ) + + param_name_list = ["initial_accumulator_value", "l1_regularization_strength", + "l2_regularization_strength", "l2_shrinkage_regularization_strength"] + + def _check_param_type_range(param_name_list): + for name in param_name_list: + if kwargs.get(name, None): + check_param_type(name, kwargs.get(name), (int, float)) + check_param_range(name, kwargs.get(name), 0, 1e4) + + if kwargs.get("accum_name", None): + check_param_type("accum_name", kwargs.get("accum_name"), str) + + if kwargs.get("linear_name", None): + check_param_type("linear_name", kwargs.get("linear_name"), str) + + check_param_type("use_locking", use_locking, bool) + + _check_param_type_range(param_name_list) + + def _apply_sparse_duplicate_indices(self, grad, var): + logging.debug(f"######### _apply_sparse_duplicate_indices {var}") + return self._apply_sparse(grad, var) + + def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): + logging.debug(f"######### _resource_apply_sparse_duplicate_indices {indices}") + return self._resource_apply_sparse(grad, handle, indices) + + def _resource_apply_sparse(self, grad, handle, indices): + logging.debug("Enter _resource_apply_sparse") + if self._l2_shrinkage_regularization_strength <= 0.0: + return self._apply_sparse_shared( + grad, + handle, + indices, + self._resource_scatter_nd_update) + else: + return self._apply_sparse_shared_v2( + grad, + handle, + indices, + self._resource_scatter_nd_update) + + def _apply_sparse(self, grad, var): + logging.debug("Enter _apply_sparse") + if self._l2_shrinkage_regularization_strength <= 0.0: + return self._apply_sparse_shared( + grad.values, + var, + grad.indices, + lambda x, i, v: 
tf.compat.v1.scatter_nd_update(x, i, v))
+        else:
+            return self._apply_sparse_shared_v2(
+                grad.values,
+                var,
+                grad.indices,
+                lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v))
+
+    def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update):
+        logging.debug("Enter _apply_sparse_shared")
+        accum = self.get_slot(var, "accum")
+        linear = self.get_slot(var, "linear")
+        lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
+        l1 = math_ops.cast(self._l1_regularization_strength_tensor, var.dtype.base_dtype)
+        l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype)
+        lr_power = math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype)
+
+        abs_indices = tf.math.maximum(indices, 0)
+        nd_indices = tf.expand_dims(indices, 1)
+        accum_old = tf.gather(accum, abs_indices)
+        linear_old = tf.gather(linear, abs_indices)
+        var_old = tf.gather(var, abs_indices)
+
+        accum_update = accum_old + tf.multiply(grad, grad)
+        with tf.control_dependencies([accum_update]):
+            accum_update_op = scatter_nd_update(accum, nd_indices, accum_update)
+
+        sigma = math_ops.pow(accum_update, -lr_power) - math_ops.pow(accum_old, -lr_power)
+        sigma = tf.divide(sigma, lr)
+
+        # standard FTRL-proximal linear-term update: z += g - sigma * w
+        linear_update = linear_old + grad - tf.multiply(sigma, var_old)
+        with tf.control_dependencies([linear_update]):
+            linear_update_op = scatter_nd_update(linear, nd_indices, linear_update)
+
+        quadratic = tf.divide(1.0, math_ops.pow(accum_update, lr_power) * lr) + 2 * l2
+
+        var_new = tf.math.sign(linear_update) * l1 - linear_update
+        var_new = tf.divide(var_new, quadratic)
+        mask = math_ops.cast(tf.math.greater(tf.abs(linear_update), l1), var.dtype.base_dtype)
+
+        var_update = tf.multiply(var_new, mask)
+
+        var_update_op = scatter_nd_update(var, nd_indices, var_update)
+
+        return control_flow_ops.group(accum_update_op, linear_update_op, var_update_op)
+
+    def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update):
+        logging.debug("Enter _apply_sparse_shared_v2")
+        accum = self.get_slot(var, "accum")
+        linear = self.get_slot(var, "linear")
+        lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
+        l1 = math_ops.cast(self._l1_regularization_strength_tensor, var.dtype.base_dtype)
+        l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype)
+        lr_power = math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype)
+        l2_shrinkage = math_ops.cast(self._l2_shrinkage_regularization_strength_tensor, var.dtype.base_dtype)
+
+        abs_indices = tf.math.maximum(indices, 0)
+        nd_indices = tf.expand_dims(indices, 1)
+        accum_old = tf.gather(accum, abs_indices)
+        linear_old = tf.gather(linear, abs_indices)
+        var_old = tf.gather(var, abs_indices)
+
+        grad_with_shrinkage = grad + 2 * l2_shrinkage * var_old
+
+        accum_update = accum_old + tf.multiply(grad, grad)
+        with tf.control_dependencies([accum_update]):
+            accum_update_op = scatter_nd_update(accum, nd_indices, accum_update)
+
+        sigma = math_ops.pow(accum_update, -lr_power) - math_ops.pow(accum_old, -lr_power)
+        sigma = tf.divide(sigma, lr)
+
+        with tf.control_dependencies([grad_with_shrinkage]):
+            linear_update = linear_old + grad_with_shrinkage - tf.multiply(sigma, var_old)
+        with tf.control_dependencies([linear_update]):
+            linear_update_op = scatter_nd_update(linear, nd_indices, linear_update)
+
+        quadratic = tf.divide(1.0, math_ops.pow(accum_update, lr_power) * lr) + 2 * l2
+
+        var_new = tf.math.sign(linear_update) * l1 - linear_update
+        var_new = tf.divide(var_new, quadratic)
+        mask 
= math_ops.cast(tf.math.greater(tf.abs(linear_update), l1), var.dtype.base_dtype)
+
+        var_update = tf.multiply(var_new, mask)
+
+        var_update_op = scatter_nd_update(var, nd_indices, var_update)
+
+        return control_flow_ops.group(accum_update_op, linear_update_op, var_update_op)
+
+    def _resource_scatter_nd_update(self, x, i, v):
+        with ops.control_dependencies([
+                gen_state_ops.resource_scatter_nd_update(x.handle, i, v)]):
+            return x.value()
+
+    def _create_slots(self, var_list):
+        logging.debug(" Enter _create_slots")
+
+        # Create the accum and linear state slots.
+        accum_state_name = self._name + "/" + "accum"
+        linear_state_name = self._name + "/" + "linear"
+        for each_var in var_list:
+            with ops.colocate_with(each_var):
+                val = constant_op.constant(
+                    self._initial_accumulator_value, dtype=each_var.dtype, shape=each_var.get_shape())
+
+                table_instance = check_and_get_config_via_var(each_var, self.optimizer_type)
+
+                accum = self._get_or_make_slot(each_var, val, "accum", accum_state_name)
+                linear = self._zeros_slot(each_var, "linear", linear_state_name)
+                # make sure sparse optimizer states will not be saved and restored within tf checkpoint.
+                remove_saving_var(accum)
+                remove_saving_var(linear)
+
+                if self._name not in table_instance.optimizer:
+                    table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear})
+
+    def initialize_slots(self, var, table_instance):
+        val = constant_op.constant(
+            self._initial_accumulator_value, dtype=var.dtype, shape=var.get_shape())
+
+        accum = slot_creator.create_slot(var, val, self._name + "/" + "accum")
+        linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear")
+        remove_saving_var(accum)
+        remove_saving_var(linear)
+        named_slot_key = (var.op.graph, var.op.name)
+        table_instance = get_table_instance(var)
+        if self._name in table_instance.optimizer:
+            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
+
+        table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear})
+        return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self},
+                {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        named_slots = self._slot_dict(slot_name)
+        if named_slots_key in named_slots:
+            raise EnvironmentError(f"named_slots_key must be globally unique, but it is already in use, "
+                                   f"please double check.")
+
+        named_slots[named_slots_key] = slot
+
+    def get_slot_init_values(self):
+        # return the state value list of ftrl that needs to be initialized in ASC DDR.
+        initial_linear_value = 0.0
+        return [self._initial_accumulator_value, initial_linear_value]
diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py
new file mode 100644
index 00000000..a2795f6d
--- /dev/null
+++ b/mx_rec/optimizers/ftrl_t.py
@@ -0,0 +1,255 @@
+# coding=utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
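+#
+# Reference note (a sketch of the smoothed FTRL variant implemented below;
+# symbols follow the slot names z, n, g and the weight shadow w):
+#
+#     g_new = grad_factor * g + (1 - grad_factor) * grad      # gradient EMA
+#     rho   = (sqrt(n + g_new^2) - sqrt(n)) / alpha
+#     z_new = (1 - epsilon) * z + g_new - rho * w
+#     n_new = (1 - epsilon) * n + g_new^2
+#     w_new = (lambda1 * sign(z_new) - z_new) /
+#             ((beta + sqrt(n_new)) / alpha + lambda2)         if |z_new| > lambda1, else 0
+#     (the lambda1 == 0 path drops the threshold and uses -z_new / denominator)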
+# Description: +# Author: MindX SDK + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gen_state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import slot_creator + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_table_instance +from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var + + +def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl_t", **kwargs): + + return CustomizedFtrlT(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) + + +class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, use_locking=False, name="Ftrl_t", **kwargs): + self.optimizer_type = "ftrl" + super(CustomizedFtrlT, self).__get_name__(name=name) + + self._learning_rate = learning_rate + self._alpha = kwargs.get("alpha", 0.06) + self._beta = kwargs.get("beta", 1.0) + self._lambda1 = kwargs.get("lambda1", 0.0) + self._lambda2 = kwargs.get("lambda2", 0.0) + self._epsilon = kwargs.get("epsilon", 0.0) + self._grad_factor = kwargs.get("grad_factor", 0.0) + self._z_name = kwargs.get("z_name", None) + self._n_name = kwargs.get("n_name", None) + self._g_name = kwargs.get("g_name", None) + self._learning_rate_tensor = None + self._alpha_tensor = None + self._beta_tensor = None + self._lambda1_tensor = None + self._lambda2_tensor = None + self._epsilon_tensor = None + self._grad_factor_tensor = None + super(CustomizedFtrlT, self).__init__(use_locking, self.unique_name) + + def _prepare(self): + self._learning_rate_tensor = ops.convert_to_tensor( + self._learning_rate, name="learning_rate") + self._alpha_tensor = ops.convert_to_tensor(self._alpha, name="alpha") + self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta") + self._lambda1_tensor = ops.convert_to_tensor(self._lambda1, name="lambda1") + self._lambda2_tensor = ops.convert_to_tensor(self._lambda2, name="lambda2") + self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") + self._grad_factor_tensor = ops.convert_to_tensor(self._grad_factor, name="grad_factor") + + def _apply_sparse_duplicate_indices(self, grad, var): + logging.debug(f"######### _apply_sparse_duplicate_indices {var}") + return self._apply_sparse(grad, var) + + def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): + logging.debug(f"######### _resource_apply_sparse_duplicate_indices {indices}") + return self._resource_apply_sparse(grad, handle, indices) + + def _resource_apply_sparse(self, grad, handle, indices): + logging.debug("Enter _resource_apply_sparse") + if self._lambda1 > 1e-10: + return self._apply_sparse_shared( + grad, + handle, + indices, + self._resource_scatter_nd_update) + else: + return self._apply_sparse_shared_v2( + grad, + handle, + indices, + self._resource_scatter_nd_update) + + def _apply_sparse(self, grad, var): + logging.debug("Enter _apply_sparse") + if self._lambda1 > 1e-10: + return self._apply_sparse_shared( + grad.values, + var, + grad.indices, + lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) + else: + return 
self._apply_sparse_shared_v2( + grad.values, + var, + grad.indices, + lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) + + def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update): + logging.debug("Enter _apply_sparse_shared") + z = self.get_slot(var, "z") + n = self.get_slot(var, "n") + g = self.get_slot(var, "g") + w = self.get_slot(var, "w") + alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) + beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) + lambda1 = math_ops.cast(self._lambda1_tensor, var.dtype.base_dtype) + lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) + epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) + grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) + + abs_indices = tf.math.maximum(indices, 0) + nd_indices = tf.expand_dims(indices, 1) + with tf.control_dependencies([grad]): + z_old = tf.gather(z, abs_indices) + n_old = tf.gather(n, abs_indices) + g_old = tf.gather(g, abs_indices) + var_old = tf.gather(w, abs_indices) + + g_new = grad_factor * g_old + (1.0 - grad_factor) * grad + with tf.control_dependencies([g_new]): + g_update = scatter_nd_update(g, nd_indices, g_new) + + rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) + z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) + with tf.control_dependencies([z_new]): + z_update = scatter_nd_update(z, nd_indices, z_new) + + n_new = (1.0 - epsilon) * n_old + tf.square(g_new) + with tf.control_dependencies([n_new]): + n_update = scatter_nd_update(n, nd_indices, n_new) + + denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2 + numerator = lambda1 * tf.sign(z_new) - z_new + mask = math_ops.cast(tf.math.greater(tf.abs(z_new), lambda1), var.dtype.base_dtype) + var_new = tf.multiply(mask, tf.divide(numerator, denominator)) + with tf.control_dependencies([var_new]): + w_update = scatter_nd_update(w, nd_indices, var_new) + var_update = scatter_nd_update(var, nd_indices, var_new) + + return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update) + + def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update): + logging.debug("Enter _apply_sparse_shared_v2") + z = self.get_slot(var, "z") + n = self.get_slot(var, "n") + g = self.get_slot(var, "g") + w = self.get_slot(var, "w") + alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) + beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) + lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) + epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) + grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) + + abs_indices = tf.math.maximum(indices, 0) + nd_indices = tf.expand_dims(indices, 1) + with tf.control_dependencies([grad]): + z_old = tf.gather(z, abs_indices) + n_old = tf.gather(n, abs_indices) + g_old = tf.gather(g, abs_indices) + var_old = tf.gather(w, abs_indices) + + g_new = grad_factor * g_old + (1.0 - grad_factor) * grad + with tf.control_dependencies([g_new]): + g_update = scatter_nd_update(g, nd_indices, g_new) + + rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) + z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) + with tf.control_dependencies([z_new]): + z_update = scatter_nd_update(z, nd_indices, z_new) + + n_new = (1.0 - epsilon) * n_old + tf.square(g_new) + with tf.control_dependencies([n_new]): + n_update = scatter_nd_update(n, nd_indices, n_new) + + denominator = 
tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2
+        var_new = tf.divide(-1.0 * z_new, denominator)
+        with tf.control_dependencies([var_new]):
+            w_update = scatter_nd_update(w, nd_indices, var_new)
+            var_update = scatter_nd_update(var, nd_indices, var_new)
+
+        return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update)
+
+    def _resource_scatter_nd_update(self, x, i, v):
+        with ops.control_dependencies([
+                gen_state_ops.resource_scatter_nd_update(x.handle, i, v)]):
+            return x.value()
+
+    def _create_slots(self, var_list):
+        logging.debug(" Enter _create_slots")
+
+        # Create the z/n/g/w state slots.
+        z_state_name = self._name + "/" + "z"
+        n_state_name = self._name + "/" + "n"
+        g_state_name = self._name + "/" + "g"
+        w_state_name = self._name + "/" + "w"
+        for each_var in var_list:
+            with ops.colocate_with(each_var):
+                table_instance = check_and_get_config_via_var(each_var, self.optimizer_type)
+
+                z = self._zeros_slot(each_var, "z", z_state_name)
+                n = self._zeros_slot(each_var, "n", n_state_name)
+                g = self._zeros_slot(each_var, "g", g_state_name)
+                w = self._zeros_slot(each_var, "w", w_state_name)
+                # make sure sparse optimizer states will not be saved and restored within tf checkpoint.
+                remove_saving_var(z)
+                remove_saving_var(n)
+                remove_saving_var(g)
+                remove_saving_var(w)
+
+                if self._name not in table_instance.optimizer:
+                    table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w})
+
+    def initialize_slots(self, var, table_instance):
+        z = slot_creator.create_zeros_slot(var, self._name + "/" + "z")
+        n = slot_creator.create_zeros_slot(var, self._name + "/" + "n")
+        g = slot_creator.create_zeros_slot(var, self._name + "/" + "g")
+        w = slot_creator.create_zeros_slot(var, self._name + "/" + "w")
+        remove_saving_var(z)
+        remove_saving_var(n)
+        remove_saving_var(g)
+        remove_saving_var(w)
+        named_slot_key = (var.op.graph, var.op.name)
+        table_instance = get_table_instance(var)
+        if self._name in table_instance.optimizer:
+            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
+
+        table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w})
+        return [{"slot": z, "named_slot_key": named_slot_key, "slot_name": "z", "optimizer": self},
+                {"slot": n, "named_slot_key": named_slot_key, "slot_name": "n", "optimizer": self},
+                {"slot": g, "named_slot_key": named_slot_key, "slot_name": "g", "optimizer": self},
+                {"slot": w, "named_slot_key": named_slot_key, "slot_name": "w", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        named_slots = self._slot_dict(slot_name)
+        if named_slots_key in named_slots:
+            raise EnvironmentError(f"named_slots_key must be globally unique, but it is already in use, "
+                                   f"please double check.")
+
+        named_slots[named_slots_key] = slot
+
+    def get_slot_init_values(self):
+        initial_z_value = 0.0
+        initial_n_value = 0.0
+        initial_g_value = 0.0
+        initial_w_value = 0.0
+        return [initial_z_value, initial_n_value, initial_g_value, initial_w_value]
diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py
new file mode 100644
index 00000000..ae305838
--- /dev/null
+++ b/mx_rec/optimizers/ftrl_t_dense.py
@@ -0,0 +1,191 @@
+# coding=utf-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
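+#
+# Dense counterpart of ftrl_t.py: the same z/n/g/w recurrences are applied to
+# a whole dense variable via tf.compat.v1.assign instead of per-index scatter
+# updates. A hedged usage sketch (variable names are illustrative only):
+#
+#     dense_opt = create_ftrl_dense_optimizer(learning_rate=0.01, alpha=0.06, beta=1.0)
+#     train_op = dense_opt.minimize(loss)  # loss over regular dense tf.Variables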
+# Description: +# Author: MindX SDK + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import gen_state_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import slot_creator + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_table_instance +from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var + + +def create_ftrl_dense_optimizer(learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): + + return CustomizedFtrlTZ(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) + + +class CustomizedFtrlTZ(optimizer.Optimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): + self.optimizer_type = "ftrl" + self._learning_rate = learning_rate + self._alpha = kwargs.get("alpha", 0.06) + self._beta = kwargs.get("beta", 1.0) + self._lambda1 = kwargs.get("lambda1", 0.0) + self._lambda2 = kwargs.get("lambda2", 0.0) + self._epsilon = kwargs.get("epsilon", 0.0) + self._grad_factor = kwargs.get("grad_factor", 0.0) + self._z_name = kwargs.get("z_name", None) + self._n_name = kwargs.get("n_name", None) + self._g_name = kwargs.get("g_name", None) + self._learning_rate_tensor = None + self._alpha_tensor = None + self._beta_tensor = None + self._lambda1_tensor = None + self._lambda2_tensor = None + self._epsilon_tensor = None + self._grad_factor_tensor = None + super(CustomizedFtrlTZ, self).__init__(use_locking, name) + logging.debug("CustomizedFtrlTZ __init__ ok") + + def _prepare(self): + self._learning_rate_tensor = ops.convert_to_tensor( + self._learning_rate, name="learning_rate") + self._alpha_tensor = ops.convert_to_tensor(self._alpha, name="alpha") + self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta") + self._lambda1_tensor = ops.convert_to_tensor(self._lambda1, name="lambda1") + self._lambda2_tensor = ops.convert_to_tensor(self._lambda2, name="lambda2") + self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") + self._grad_factor_tensor = ops.convert_to_tensor(self._grad_factor, name="grad_factor") + + def _resource_apply_dense(self, grad, handle): + if self._lambda1 > 1e-10: + return self._apply_dense_shared( + grad, + handle) + else: + return self._apply_dense_shared_v2( + grad, + handle) + + def _apply_dense(self, grad, var): + if self._lambda1 > 1e-10: + return self._apply_dense_shared( + grad.values, + var) + else: + return self._apply_dense_shared_v2( + grad.values, + var) + + def _apply_dense_shared(self, grad, var): + logging.debug("Enter _apply_dense_shared") + z_var = self.get_slot(var, "z") + n_var = self.get_slot(var, "n") + g_var = self.get_slot(var, "g") + w_var = self.get_slot(var, "w") + alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) + beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) + lambda1 = math_ops.cast(self._lambda1_tensor, var.dtype.base_dtype) + lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) + epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) + grad_factor = math_ops.cast(self._grad_factor_tensor, 
var.dtype.base_dtype)
+
+        z_old = tf.identity(z_var)
+        n_old = tf.identity(n_var)
+        g_old = tf.identity(g_var)
+        var_old = tf.identity(w_var)
+
+        g_new = grad_factor * g_old + (1.0 - grad_factor) * grad
+        with tf.control_dependencies([g_new]):
+            g_update = tf.compat.v1.assign(g_var, g_new)
+
+        rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha)
+        z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old)
+        with tf.control_dependencies([z_new]):
+            z_update = tf.compat.v1.assign(z_var, z_new)
+
+        n_new = (1.0 - epsilon) * n_old + tf.square(g_new)
+        with tf.control_dependencies([n_new]):
+            n_update = tf.compat.v1.assign(n_var, n_new)
+
+        denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2
+        numerator = lambda1 * tf.sign(z_new) - z_new
+        mask = math_ops.cast(tf.math.greater(tf.abs(z_new), lambda1), var.dtype.base_dtype)
+        var_new = tf.multiply(mask, tf.divide(numerator, denominator))
+        with tf.control_dependencies([var_new]):
+            w_update = tf.compat.v1.assign(w_var, var_new)
+            var_update = tf.compat.v1.assign(var, var_new)
+
+        return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update)
+
+    def _apply_dense_shared_v2(self, grad, var):
+        logging.debug("Enter _apply_dense_shared_v2")
+        z_var = self.get_slot(var, "z")
+        n_var = self.get_slot(var, "n")
+        g_var = self.get_slot(var, "g")
+        w_var = self.get_slot(var, "w")
+        alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype)
+        beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype)
+        lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype)
+        epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype)
+        grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype)
+
+        z_old = tf.identity(z_var)
+        n_old = tf.identity(n_var)
+        g_old = tf.identity(g_var)
+        var_old = tf.identity(w_var)
+
+        g_new = grad_factor * g_old + (1.0 - grad_factor) * grad
+        with tf.control_dependencies([g_new]):
+            g_update = tf.compat.v1.assign(g_var, g_new)
+
+        rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha)
+        z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old)
+        with tf.control_dependencies([z_new]):
+            z_update = tf.compat.v1.assign(z_var, z_new)
+
+        n_new = (1.0 - epsilon) * n_old + tf.square(g_new)
+        with tf.control_dependencies([n_new]):
+            n_update = tf.compat.v1.assign(n_var, n_new)
+
+        denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2
+        var_new = tf.divide(-1.0 * z_new, denominator)
+        with tf.control_dependencies([var_new]):
+            w_update = tf.compat.v1.assign(w_var, var_new)
+            var_update = tf.compat.v1.assign(var, var_new)
+
+        return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update)
+
+    def _resource_scatter_nd_update(self, x_input, i_input, v_input):
+        with ops.control_dependencies([
+                gen_state_ops.resource_scatter_nd_update(x_input.handle, i_input, v_input)]):
+            return x_input.value()
+
+    def _create_slots(self, var_list):
+        logging.debug(" Enter _create_slots")
+
+        # Create the z/n/g/w state slots.
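+        # Slot meanings (as used in _apply_dense_shared above): z is the FTRL
+        # linear term, n accumulates the squared smoothed gradients, g holds
+        # the gradient EMA, and w keeps a shadow copy of the weights that is
+        # written back to var on every update.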
+ z_state_name = self._name + "/" + "z" + n_state_name = self._name + "/" + "n" + g_state_name = self._name + "/" + "g" + w_state_name = self._name + "/" + "w" + for each_var in var_list: + with ops.colocate_with(each_var): + z_zero = self._zeros_slot(each_var, "z", z_state_name) + n_zero = self._zeros_slot(each_var, "n", n_state_name) + g_zero = self._zeros_slot(each_var, "g", g_state_name) + w_zero = self._zeros_slot(each_var, "w", w_state_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. + remove_saving_var(z_zero) + remove_saving_var(n_zero) + remove_saving_var(g_zero) + remove_saving_var(w_zero) + diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py new file mode 100644 index 00000000..997d62fa --- /dev/null +++ b/mx_rec/optimizers/gradient_descent.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. +# Description: +# Author: MindX SDK +# Create: 2022-12-01 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf + +from tensorflow.python.ops import gen_state_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import gradient_descent + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.variable import check_param_type + + +def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescent"): + + return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name) + + +class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, use_locking=False, name="GradientDescent"): + self.optimizer_type = "gradient_descent" + super(CustomizedGradientDescent, self).__get_name__(name=name) + super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking, + name=self.unique_name) + + check_param_type("use_locking", use_locking, bool) + + def _apply_sparse_duplicate_indices(self, grad, var): + logging.debug(" Enter _apply_sparse_duplicate_indices") + nd_indices = tf.expand_dims(grad.indices, 1) + nd_value = grad.values * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) + var_update_op = tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking) + return var_update_op + + def _apply_dense(self, grad, var): + logging.debug(" Enter _apply_dense") + raise NotImplementedError("You are using a wrong type of variable.") + + def initialize_slots(self, var, table_instance): + return [] + + def get_slot_init_values(self): + return [] diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py new file mode 100644 index 00000000..3d22bf1b --- /dev/null +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -0,0 +1,184 @@ +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +# Description: CustomizedGradientDescentByAddr. 
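+# Usage sketch (illustrative; `grads` and `addresses` are hypothetical names for
+# the gradient tensors and the address tensors produced by the sparse lookup):
+#     opt = create_hash_optimizer_by_addr(learning_rate=0.001, weight_decay=1e-4)
+#     train_op = opt.apply_gradients(zip(grads, addresses))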
+# Author: MindX SDK + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import logging +from collections import defaultdict + +from tensorflow.python.framework import ops, indexed_slices +from tensorflow.python.ops import math_ops +from tensorflow.python.training import optimizer +from tensorflow.python.eager import context +from tensorflow.python.ops import resource_variable_ops + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer + + +def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescentByAddr"): + optimizer_by_addr = CustomizedGradientDescentByAddr(learning_rate=learning_rate, + weight_decay=weight_decay, + use_locking=use_locking, + name=name) + insert_optimizer(optimizer_by_addr) + return optimizer_by_addr + + +class CustomizedGradientDescentByAddr(optimizer.Optimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescentByAddr"): + self.optimizer_type = "gradient_descent_by_addr" + self.weight_decay = weight_decay + super(CustomizedGradientDescentByAddr, self).__init__(use_locking, name) + + self._learning_rate = learning_rate + self._learning_rate_tensor = None + self._slot_num = 0 + + def _convert_grads_and_addrs(self, grads_and_vars): + converted_grads_and_addrs = [] + for grad, addr in grads_and_vars: + if grad is not None: + try: + # Convert the grad to Tensor or IndexedSlices if necessary. + grad = ops.convert_to_tensor_or_indexed_slices(grad) + except TypeError as error: + raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error + if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): + raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") + processor = _get_processor(addr) + converted_grads_and_addrs.append((grad, addr, processor)) + return converted_grads_and_addrs + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + # No DistributionStrategy case. + grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works. + if not grads_and_vars: + raise ValueError("No variables provided.") + converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars)) + + addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None] + if not addr_list: + raise ValueError("No gradients provided for any address: %s." 
% + ([str(a) for _, a, _ in converted_grads_and_addrs],)) + with ops.init_scope(): + self._create_slots(addr_list) + + update_ops = [] + with ops.name_scope(name, self._name) as name: + self._prepare() + for grad, addr, processor in converted_grads_and_addrs: + if grad is None: + continue + if (context.executing_eagerly() or + resource_variable_ops.is_resource_variable(addr) + and not addr._in_graph_mode): # pylint: disable=protected-access + scope_name = "" + else: + scope_name = addr.op.name + with ops.name_scope( + "update_" + scope_name), ops.colocate_with(addr): + update_ops.append(processor.update_op(self, grad)) + + apply_updates = self._finish(update_ops, name) + + if not context.executing_eagerly(): + if isinstance(apply_updates, ops.Tensor): + logging.debug(">>>>Enter ops.Tensor") + apply_updates = apply_updates.op + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + if apply_updates not in train_op: + logging.debug(">>>>Enter apply_updates not in train_op") + train_op.append(apply_updates) + else: + raise RuntimeError("eager wrong.") + + return apply_updates + + @property + def slot_num(self): + return self._slot_num + + def _prepare(self): + learning_rate = self._call_if_callable(self._learning_rate) + self._learning_rate_tensor = ops.convert_to_tensor( + learning_rate, name="learning_rate") + + def get_slot_init_values(self): + return [] + + def _apply_sparse(self, grad, addr): + logging.debug(">>>> Enter _apply_sparse SGD by addr") + host_pipeline_ops = get_host_pipeline_ops() + dim = grad.shape.as_list()[-1] + if self.weight_decay is None: + nd_value = grad * math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype) + else: + lookup_tensor = \ + host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=dim, embedding_type=1) + nd_value = (grad + math_ops.cast(self.weight_decay, grad.dtype.base_dtype) * lookup_tensor) * math_ops.cast( + self._learning_rate_tensor, grad.dtype.base_dtype) + var_update_op = host_pipeline_ops.embedding_update_by_address(addr, -nd_value, update_type=0) + + return var_update_op + + def _apply_dense(self, grad, var): + logging.debug(">>>> Enter _apply_dense") + raise NotImplementedError("You are using a wrong type of variable.") + + +def get_filtered_grad_fn(grad_fn): + def filtered_grad_fn(*args, **kwargs): + return [(g, a) for g, a in grad_fn(*args, **kwargs) if g is not None] + + return filtered_grad_fn + + +class _OptimizableAddr(metaclass=abc.ABCMeta): + """Interface for abstracting over addresses in the optimizers.""" + + @abc.abstractmethod + def target(self): + """Returns the optimization target for this address.""" + raise NotImplementedError("Calling an abstract method.") + + @abc.abstractmethod + def update_op(self, opt, grad): + """Returns the update ops for updating the address.""" + raise NotImplementedError("Calling an abstract method.") + + +def _get_processor(addr): + """The processor of v.""" + if isinstance(addr, ops.Tensor): + logging.debug(">>>>Enter _get_processor tensor") + return _TensorByAddressProcessor(addr) + raise NotImplementedError("Trying to optimize unsupported type ", addr) + + +class _TensorByAddressProcessor(_OptimizableAddr): + """Processor for Tensor filled with addresses.""" + + def __init__(self, addr): + self._a = addr + + def target(self): + return self._a + + def __str__(self): + return "<_TensorByAddressProcessor(%s)>" % self._a + + def update_op(self, opt, grad): + if isinstance(grad, ops.Tensor): + logging.debug(">>>>Enter update_op ops.Tensor") + update_op = 
opt._apply_sparse(grad, self._a) # pylint: disable=protected-access
+                return update_op
+            else:
+                raise RuntimeError("Only gradients of type Tensor are supported.")
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
new file mode 100644
index 00000000..8d3c0427
--- /dev/null
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -0,0 +1,184 @@
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved.
+# Description: CustomizedLazyAdam optimizer.
+# Author: MindX SDK
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+from collections import defaultdict
+
+import tensorflow as tf
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_state_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import slot_creator
+
+from mx_rec.optimizers.base import CustomizedOptimizer
+from mx_rec.util.initialize import get_table_instance
+from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range
+
+
+def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"):
+    """
+    Args:
+        learning_rate: learning rate
+        beta1: exponential decay rate for the first-moment estimates
+        beta2: exponential decay rate for the second-moment estimates
+        epsilon: small constant added to the denominator for numerical stability
+        name: optional name prefix for the operations created when applying gradients
+
+    Returns: a customized optimizer instance
+    """
+    return CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name)
+
+
+class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
+    name_counter = defaultdict(int)
+
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam"):
+        self.optimizer_type = "LazyAdam"
+        super(CustomizedLazyAdam, self).__get_name__(name=name)
+        super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2,
+                                                 epsilon=epsilon, use_locking=use_locking, name=self.unique_name)
+
+        check_param_type("beta1", beta1, (int, float))
+        check_param_range("beta1", beta1, 0, 1)
+
+        check_param_type("beta2", beta2, (int, float))
+        check_param_range("beta2", beta2, 0, 1)
+
+        check_param_type("epsilon", epsilon, (int, float))
+        check_param_range("epsilon", epsilon, 0, 1)
+
+        check_param_type("use_locking", use_locking, bool)
+
+    def _apply_sparse_duplicate_indices(self, grad, var):
+        # The inherited _apply_sparse_duplicate_indices uses tf.unique and unsorted_segment_sum, which can
+        # introduce dynamic-shape problems; this override forwards duplicate indices to _apply_sparse
+        # unchanged. If you encounter dynamic-shape issues, keep this override enabled.
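+        # For reference, the parent version deduplicates roughly like this (sketch):
+        #     unique_indices, new_index_positions = tf.unique(grad.indices)
+        #     summed_values = tf.math.unsorted_segment_sum(
+        #         grad.values, new_index_positions, tf.shape(unique_indices)[0])
+        # producing tensors whose leading dimension is data-dependent.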
+ logging.debug(f"_apply_sparse_duplicate_indices {var}") + return self._apply_sparse(grad, var) + + def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): + logging.debug(f"_resource_apply_sparse_duplicate_indices {indices}") + return self._resource_apply_sparse(grad, handle, indices) + + def _apply_dense(self, grad, var): + logging.debug("Enter _apply_dense") + raise NotImplementedError("You are using a wrong type of variable.") + + def _cast_to_base_type(self, var): + var_type = var.dtype.base_dtype + temp_lr = math_ops.cast(self._lr_t, var_type) + temp_b1 = math_ops.cast(self._beta1_t, var_type) + temp_b2 = math_ops.cast(self._beta2_t, var_type) + temp_epsilon = math_ops.cast(self._epsilon_t, var_type) + return temp_lr, temp_b1, temp_b2, temp_epsilon + + def _resource_apply_sparse(self, grad, handle, indices): + logging.debug("Enter _resource_apply_sparse") + return self._apply_sparse_shared( + grad, + handle, + indices, + self._resource_scatter_nd_add) + + def _apply_sparse(self, grad, var): + logging.debug("Enter _apply_sparse") + return self._apply_sparse_shared( + grad.values, + var, + grad.indices, + lambda x, i, v: tf.compat.v1.scatter_nd_add(x, i, v)) + + def _apply_sparse_shared(self, grad, var, indices, scatter_nd_add): + power_b1, power_b2 = self._get_beta_accumulators() + power_b1 = math_ops.cast(power_b1, var.dtype.base_dtype) + power_b2 = math_ops.cast(power_b2, var.dtype.base_dtype) + temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(var) + learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) + + abs_indices = tf.math.maximum(indices, 0) + nd_indices = tf.expand_dims(indices, 1) + + momentum = self.get_slot(var, "m") + old_m_slice = tf.gather(momentum, abs_indices) + m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad + m_update_op = scatter_nd_add(momentum, nd_indices, m_t_slice - old_m_slice) + + velocity = self.get_slot(var, "v") + old_v_slice = tf.gather(velocity, abs_indices) + v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) + v_update_op = scatter_nd_add(velocity, nd_indices, v_t_slice - old_v_slice) + + denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon + var_update_op = scatter_nd_add(var, nd_indices, tf.divide(-learning_rate * m_t_slice, denominator_slice)) + return control_flow_ops.group(m_update_op, v_update_op, var_update_op) + + def _resource_scatter_nd_add(self, x, i, v): + with ops.control_dependencies([ + gen_state_ops.resource_scatter_nd_add(x.handle, i, v)]): + return x.value() + + def _create_slots(self, var_list): + logging.debug(" Enter _create_slots") + first_var = min(var_list, key=lambda x: x.name) + self._create_non_slot_variable( + initial_value=self._beta1, name="beta1_power", colocate_with=first_var) + self._create_non_slot_variable( + initial_value=self._beta2, name="beta2_power", colocate_with=first_var) + + # Create slots for the first and second moments. + m_state_name = self._name + "/" + "momentum" + v_state_name = self._name + "/" + "velocity" + for each_var in var_list: + table_instance = check_and_get_config_via_var(each_var, self.optimizer_type) + + momentum = self._zeros_slot(each_var, "m", m_state_name) + velocity = self._zeros_slot(each_var, "v", v_state_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
+            remove_saving_var(momentum)
+            remove_saving_var(velocity)
+
+            if self._name not in table_instance.optimizer:
+                table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+
+    def initialize_slots(self, var, table_instance):
+        # Create slots for the first and second moments.
+        def create_one_single_slot(var, op_name):
+            new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+            # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
+            return new_slot_variable
+
+        momentum = create_one_single_slot(var, self._name + "/" + "momentum")
+        velocity = create_one_single_slot(var, self._name + "/" + "velocity")
+        remove_saving_var(momentum)
+        remove_saving_var(velocity)
+        named_slot_key = (var.op.graph, var.op.name)
+        table_instance = get_table_instance(var)
+        if self._name in table_instance.optimizer:
+            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
+
+        table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+        return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self},
+                {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        named_slots = self._slot_dict(slot_name)
+        if named_slots_key in named_slots:
+            raise EnvironmentError(f"named_slots_key must be globally unique, but this key is already in use; "
+                                   f"please double-check.")
+
+        named_slots[named_slots_key] = slot
+
+    def get_slot_init_values(self):
+        # return state value list of adam that needs to initialize in ASC DDR.
+        initial_momentum_value = 0.0
+        initial_velocity_value = 0.0
+        return [initial_momentum_value, initial_velocity_value]
diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py
new file mode 100644
index 00000000..714634f8
--- /dev/null
+++ b/mx_rec/optimizers/lazy_adam_by_addr.py
@@ -0,0 +1,307 @@
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+# Description: CustomizedLazyAdamByAddress.
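+# Usage sketch (illustrative, mirroring gradient_descent_by_addr above):
+#     opt = create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999)
+#     train_op = opt.apply_gradients(zip(grads, addresses))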
+# Author: MindX SDK + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import logging +from collections import defaultdict + +import tensorflow as tf +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import optimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import indexed_slices + +from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.variable import check_param_type, check_param_range + + +def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, + name="LazyAdamByAddress"): + """ + Args: + learning_rate: learning rate + beta1: + beta2: + epsilon: + name: + + Returns: a customized optimizer instance + """ + + optimizer_by_addr = CustomizedLazyAdamByAddress(learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, name=name) + insert_optimizer(optimizer_by_addr) + return optimizer_by_addr + + +class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, + name="LazyAdamByAddress"): + self.optimizer_type = "LazyAdamByAddress" + super(CustomizedLazyAdamByAddress, self).__init__(use_locking, name) + + self._lr = learning_rate + self._beta1 = beta1 + self._beta2 = beta2 + self._epsilon = epsilon + + self._non_slot_dict = {} + self._slot_num = 2 + + # Tensor versions of the constructor arguments, created in _prepare(). + self._lr_t = None + self._beta1_t = None + self._beta2_t = None + self._epsilon_t = None + + self._check_input_param() + + def _check_input_param(self): + check_param_type("beta1", self._beta1, (int, float)) + check_param_range("beta1", self._beta1, 0, 1) + + check_param_type("beta2", self._beta2, (int, float)) + check_param_range("beta2", self._beta2, 0, 1) + + check_param_type("epsilon", self._epsilon, (int, float)) + check_param_range("epsilon", self._epsilon, 0, 1) + + check_param_type("use_locking", self._use_locking, bool) + + @property + def slot_num(self): + return self._slot_num + + def _get_beta_accumulators(self): + with ops.init_scope(): + if context.executing_eagerly(): + graph = None + else: + graph = ops.get_default_graph() + return (self._get_non_slot_variable("beta1_power", graph=graph), + self._get_non_slot_variable("beta2_power", graph=graph)) + + def _create_slots(self, addr_list): + first_addr = addr_list[0] + self._create_non_slot_variable( + initial_value=self._beta1, name="beta1_power", colocate_with=first_addr) + self._create_non_slot_variable( + initial_value=self._beta2, name="beta2_power", colocate_with=first_addr) + + def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. 
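+        # The only non-slot variables this optimizer creates are the beta1/beta2
+        # power accumulators, keyed by (name, graph) so a single copy is shared by
+        # all addresses; _finish() multiplies them by beta1_t/beta2_t every step.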
+ eager = context.executing_eagerly() + graph = None if eager else tf.get_default_graph() + + key = (name, graph) + var = self._non_slot_dict.get(key, None) + if var is None: + distribution_strategy = distribute_ctx.get_strategy() + with distribution_strategy.extended.colocate_vars_with(colocate_with): + var = variable_scope.variable( + initial_value, name=name, trainable=False, + use_resource=resource_variable_ops.is_resource_variable( + colocate_with)) + self._non_slot_dict[key] = var + return var + + def _prepare(self): + learn_rate = self._call_if_callable(self._lr) + beta1 = self._call_if_callable(self._beta1) + beta2 = self._call_if_callable(self._beta2) + epsilon = self._call_if_callable(self._epsilon) + + self._lr_t = ops.convert_to_tensor(learn_rate, name="learning_rate") + self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") + self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") + self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") + + def _finish(self, update_ops, name_scope): + # Update the power accumulators. + with ops.control_dependencies(update_ops): + beta1_power, beta2_power = self._get_beta_accumulators() + with ops.colocate_with(beta1_power): + update_beta1 = beta1_power.assign( + beta1_power * self._beta1_t, use_locking=self._use_locking) + update_beta2 = beta2_power.assign( + beta2_power * self._beta2_t, use_locking=self._use_locking) + return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope) + + def get_slot_init_values(self): + # return state value list of adam that needs to initialize in ASC DDR. + initial_momentum_value = 0.0 + initial_velocity_value = 0.0 + return [initial_momentum_value, initial_velocity_value] + + def _apply_dense(self, grad, var): + logging.debug(">>>>Enter _apply_dense") + raise NotImplementedError("You are using a wrong type of variable.") + + def _cast_to_base_type(self, var): + var_type = var.dtype.base_dtype + temp_lr = math_ops.cast(self._lr_t, var_type) + temp_b1 = math_ops.cast(self._beta1_t, var_type) + temp_b2 = math_ops.cast(self._beta2_t, var_type) + temp_epsilon = math_ops.cast(self._epsilon_t, var_type) + return temp_lr, temp_b1, temp_b2, temp_epsilon + + def _apply_sparse(self, grad, addr): + logging.debug(">>>> Enter _apply_sparse Lazy_adam by addr") + return self._apply_sparse_shared( + grad, + addr) + + def _apply_sparse_shared(self, grad, addr): + power_b1, power_b2 = self._get_beta_accumulators() + power_b1 = math_ops.cast(power_b1, grad.dtype.base_dtype) + power_b2 = math_ops.cast(power_b2, grad.dtype.base_dtype) + temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(grad) + learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) + + host_pipeline_ops = get_host_pipeline_ops() + dim = grad.shape.as_list()[-1] + combined_tensor = \ + host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) + + split_length = [dim] + [dim] + [dim] + split_tensors = tf.split(combined_tensor, split_length, axis=1) + + old_m_slice = split_tensors[1] + m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad + + old_v_slice = split_tensors[2] + v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) + + denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon + update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ + [v_t_slice - old_v_slice] + update_tensor = tf.concat(update_list, axis=1) + var_update_op = 
host_pipeline_ops.embedding_update_by_address(addr, update_tensor, update_type=0) + var_update_op = tf.identity(var_update_op, name="identity_var_update_op") + + return var_update_op + + def _convert_grads_and_addrs(self, grads_and_vars): + converted_grads_and_addrs = [] + for grad, addr in grads_and_vars: + if grad is not None: + try: + # Convert the grad to Tensor or IndexedSlices if necessary. + grad = ops.convert_to_tensor_or_indexed_slices(grad) + except TypeError as error: + raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error + if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): + raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") + processor = _get_processor(addr) + converted_grads_and_addrs.append((grad, addr, processor)) + return converted_grads_and_addrs + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + + # No DistributionStrategy case. + grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works. + if not grads_and_vars: + raise ValueError("No variables provided.") + + converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars)) + addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None] + if not addr_list: + raise ValueError("No gradients provided for any address: %s." % + ([str(a) for _, a, _ in converted_grads_and_addrs],)) + with ops.init_scope(): + self._create_slots(addr_list) + update_ops = [] + with ops.name_scope(name, self._name) as name: + self._prepare() + for grad, addr, processor in converted_grads_and_addrs: + if grad is None: + continue + if (context.executing_eagerly() or + resource_variable_ops.is_resource_variable(addr) + and not addr._in_graph_mode): # pylint: disable=protected-access + scope_name = "" + else: + scope_name = addr.op.name + with ops.name_scope( + "update_" + scope_name), ops.colocate_with(addr): + update_ops.append(processor.update_op(self, grad)) + + apply_updates = self._finish(update_ops, name) + + if not context.executing_eagerly(): + if isinstance(apply_updates, ops.Tensor): + logging.debug(">>>>Enter ops.Tensor") + apply_updates = apply_updates.op + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + if apply_updates not in train_op: + logging.debug(">>>>Enter apply_updates not in train_op") + train_op.append(apply_updates) + else: + raise RuntimeError("eager wrong.") + + return apply_updates + + +def get_filtered_grad_fn(grad_fn): + def filtered_grad_fn(*args, **kwargs): + return [(g, a) for g, a in grad_fn(*args, **kwargs) if g is not None] + + return filtered_grad_fn + + +class _OptimizableAddr(metaclass=abc.ABCMeta): + """Interface for abstracting over addresses in the optimizers.""" + + @abc.abstractmethod + def target(self): + """Returns the optimization target for this address.""" + raise NotImplementedError("Calling an abstract method.") + + @abc.abstractmethod + def update_op(self, opt, grad): + """Returns the update ops for updating the address.""" + raise NotImplementedError("Calling an abstract method.") + + +class _TensorByAddressProcessor(_OptimizableAddr): + """Processor for Tensor filled with addresses.""" + + def __init__(self, addr): + self._a = addr + + def target(self): + return self._a + + def __str__(self): + return "<_TensorByAddressProcessor(%s)>" % self._a + + def update_op(self, opt, grad): + if isinstance(grad, ops.Tensor): + logging.debug(">>>>Enter update_op ops.Tensor") + update_op = opt._apply_sparse(grad, self._a) # pylint: 
disable=protected-access + return update_op + else: + raise RuntimeError("Only support g with type Tensor.") + + +def _get_processor(addr): + """The processor of v.""" + if isinstance(addr, ops.Tensor): + logging.debug(">>>>Enter _get_processor tensor") + return _TensorByAddressProcessor(addr) + raise NotImplementedError("Trying to optimize unsupported type ", addr) \ No newline at end of file diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py new file mode 100644 index 00000000..f9d67508 --- /dev/null +++ b/mx_rec/optimizers/momentum.py @@ -0,0 +1,130 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2023 Huawei Technologies Co., Ltd + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from collections import defaultdict + +import tensorflow as tf +from tensorflow.python.ops import math_ops +from tensorflow.python.training import training_ops +from tensorflow.python.training import momentum +from tensorflow.python.training import slot_creator + +from mx_rec.optimizers.base import CustomizedOptimizer +from mx_rec.util.initialize import get_table_instance +from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range + + +def create_hash_optimizer(learning_rate_input=0.001, mom=0.9, enable_locking=False, optimizer_name="momentum", + enable_nesterov=False): + """ + Create an instance of hash optimizer + :param learning_rate_input: A `Tensor` or a floating point value. The learning rate. + :param mom: A `Tensor` or a floating point value. The momentum. + :param enable_locking: If `True` use locks for update operations. + :param optimizer_name: Optional name prefix for the operations created when applying gradients. + Defaults to "Momentum". + :param enable_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., 2013). This implementation always + computes gradients at the value of the variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. This implementation is an approximation of + the original formula, valid for high values of momentum. It will compute the "adjusted gradient" in NAG by + assuming that the new gradient will be estimated by the current average gradient plus the product of momentum and + the change in the average gradient. 
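+
+    Example (illustrative):
+        >>> optimizer = create_hash_optimizer(learning_rate_input=0.01, mom=0.9)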
+ :return: momentum hash optimizer instance + """ + return CustomizedMomentum(learning_rate=learning_rate_input, + momentum_var=mom, + use_locking=enable_locking, + name=optimizer_name, + use_nesterov=enable_nesterov) + + +class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): + name_counter = defaultdict(int) + + def __init__(self, + learning_rate, + momentum_var, + use_locking=False, + name="Momentum", + use_nesterov=False): + self.optimizer_type = "Momentum" + super(CustomizedMomentum, self).__get_name__(name=name) + super(CustomizedMomentum, self).__init__(learning_rate=learning_rate, + momentum=momentum_var, + use_locking=use_locking, + name=self.unique_name, + use_nesterov=use_nesterov) + + self._check_input_param() + + def _check_input_param(self): + check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float)) + check_param_type("momentum", self._momentum, (tf.Tensor, float)) + check_param_type("use_locking", self._use_locking, bool) + check_param_type("use_nesterov", self._use_nesterov, bool) + + check_param_range("momentum", self._momentum, 0.0, 1.0) + + def _create_slots(self, var_list): + logging.debug(" Start _create_slots") + m_state_name = self._name + "/" + "momentum" + for var in var_list: + table_instance = check_and_get_config_via_var(var, self.optimizer_type) + momentum_slot = self._zeros_slot(var, "m", m_state_name) + + remove_saving_var(momentum_slot) + if self._name not in table_instance.optimizer: + table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) + logging.debug(" End _create_slots") + + def _apply_sparse(self, grad, var): + mom = self.get_slot(var, "m") + return training_ops.sparse_apply_momentum( + var, mom, math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + grad.values, grad.indices, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), + use_locking=self._use_locking, + use_nesterov=self._use_nesterov).op + + def _resource_apply_sparse(self, grad, var, indices): + mom = self.get_slot(var, "m") + return training_ops.resource_sparse_apply_momentum( + var.handle, mom.handle, math_ops.cast(self._learning_rate_tensor, grad.dtype), + grad, indices, math_ops.cast(self._momentum_tensor, grad.dtype), + use_locking=self._use_locking, + use_nesterov=self._use_nesterov) + + def initialize_slots(self, var, table_instance): + # Create slots for the first and second moments. + def creat_one_single_slot(var, op_name): + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
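+            # The slot is detached from TF checkpointing by the caller via
+            # remove_saving_var(); the sparse Saver in mx_rec.saver persists it instead.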
+ return new_slot_variable + + momentum_slot = creat_one_single_slot(var, self._name + "/" + "momentum") + remove_saving_var(momentum_slot) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) + return [{"slot": momentum_slot, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + # return state value list of momentum that needs to initialize in ASC DDR. + initial_momentum_value = 0.0 + return [initial_momentum_value] diff --git a/mx_rec/saver/__init__.py b/mx_rec/saver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py new file mode 100644 index 00000000..56eba09d --- /dev/null +++ b/mx_rec/saver/patch.py @@ -0,0 +1,352 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Huawei Technologies Co., Ltd + +import os +import time +import logging + +import tensorflow as tf +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.core.protobuf import trackable_object_graph_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.ops import io_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import training_util +from tensorflow.python.util import compat +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.training.saving import saveable_object +from tensorflow.python.training.saving import saveable_object_util + +from mx_rec.saver.saver import Saver as SparseSaver +from mx_rec.util.initialize import get_ascend_global_hashtable_collection + + +def get_sparse_vars(var_list): + # build sparse saver + if var_list is not None: + if not isinstance(var_list, (list, tuple)): + raise TypeError("A non-None var_list must be a list or tuple.") + ascend_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + sparse_var_list = [] + for var in var_list: + if var in ascend_variables: + sparse_var_list.append(var) + var_list = sparse_var_list + else: + var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + return var_list + + +def init_check(defer_build, var_list): + if defer_build and var_list: + raise ValueError( + "If `var_list` is provided then build cannot be deferred. 
Either set defer_build=False or var_list=None.") + if context.executing_eagerly(): + tf_logging.warning("When executing eagerly variables do not necessarily have unique names, " + "and so the variable.name-based lookups Saver performs are error-prone.") + if var_list is None: + raise RuntimeError("eager execution, `var_list` must specify a list or dict of variables to save") + + +def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, + name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, + allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, + filename=None, fid_version=0): + self._last_checkpoints = [] + self._checkpoints_to_be_deleted = [] + self._var_list = var_list + self._is_built = False + self._is_empty = None + + init_check(defer_build, var_list) + self._write_version = write_version + self._reshape = reshape + self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours + self._save_relative_paths = save_relative_paths + self._sharded = sharded + self._restore_sequentially = restore_sequentially + self._max_to_keep = max_to_keep + self._builder = builder + self._name = name + self._filename = filename + self.saver_def = saver_def + self._allow_empty = allow_empty + self._pad_step_number = pad_step_number + # mt customed parameter + self._fid_version = fid_version + + if self.saver_def: + self._check_saver_def() + self._write_version = self.saver_def.version + + if context.executing_eagerly(): + keep_time = self._keep_checkpoint_every_n_hours * 3600 + self._next_checkpoint_time = (time.time() + keep_time) + elif not defer_build: + self.build() + self._object_restore_saver = None + # mxRec Patch + # create sparse saver only when var_list is not None + self.sparse_saver = None + if get_sparse_vars(var_list): + self.sparse_saver = SparseSaver(var_list=var_list, max_to_keep=max_to_keep, prefix_name=filename) + + +def save_check(latest_filename, sess): + if os.path.split(latest_filename)[0]: + raise ValueError("'latest_filename' must not contain path components") + if not context.executing_eagerly() and not isinstance(sess, session.SessionInterface): + raise TypeError("'sess' must be a Session; %s" % sess) + + +def get_model_checkpoint_path(self, checkpoint_file, sess): + if not context.executing_eagerly(): + model_checkpoint_path = sess.run(self.saver_def.save_tensor_name, + {self.saver_def.filename_tensor_name: checkpoint_file}) + # mxRec Patch + # save sparse model, only run when self.sparse_saver is not None + if self.sparse_saver: + self.sparse_saver.save(sess, save_path=checkpoint_file) + logging.info("Save model into dir %s", checkpoint_file) + else: + self._build_eager(checkpoint_file, build_save=True, build_restore=False) + model_checkpoint_path = self.saver_def.save_tensor_name + + return model_checkpoint_path + + +def update_checkpoint_state(self, model_checkpoint_path, parent_save_path, latest_file_name, suffix_meta_graph, + save_path): + self._RecordLastCheckpoint(model_checkpoint_path) + try: + checkpoint_management.update_checkpoint_state_internal(save_dir=parent_save_path, + model_checkpoint_path=model_checkpoint_path, + all_model_checkpoint_paths=self.last_checkpoints, + latest_filename=latest_file_name, + save_relative_paths=self._save_relative_paths) + except errors.NotFoundError as err: + if not gfile.IsDirectory(parent_save_path): + err = ValueError(f"Parent directory of {save_path} doesn't exist, can't 
save.") + raise err + self._MaybeDeleteOldCheckpoints(meta_graph_suffix=suffix_meta_graph) + + +def write_meta_graph_task(self, checkpoint_file, suffix_meta_graph, sess, strip_default_attrs, save_debug_info): + meta_graph_name = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix=suffix_meta_graph) + if not context.executing_eagerly(): + with sess.graph.as_default(): + self.export_meta_graph(meta_graph_name, strip_default_attrs=strip_default_attrs, + save_debug_info=save_debug_info) + + +def get_checkpoint_file(self, global_step, sess, save_path): + if not isinstance(global_step, compat.integral_types): + global_step = training_util.global_step(sess, global_step) + checkpoint_file = f"{save_path}-{global_step}" + if self._pad_step_number: + checkpoint_file = f"{save_path}-{global_step:08d}" + return checkpoint_file + + +def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, + write_state=True, strip_default_attrs=False, save_debug_info=False): + if not self._is_built and not context.executing_eagerly(): + raise RuntimeError("`build()` should be called before save if defer_build==True") + if latest_filename is None: + latest_filename = "checkpoint" + if self._write_version != saver_pb2.SaverDef.V2: + tf_logging.warning("TensorFlow's V1 checkpoint format has been deprecated.") + + save_check(latest_filename, sess) + + if global_step is not None: + checkpoint_file = get_checkpoint_file(self, global_step, sess, save_path) + else: + checkpoint_file = save_path + if os.path.basename(save_path) == latest_filename and not self._sharded: + # Guard against collision between data file and checkpoint state file. + raise ValueError(f"{latest_filename} collides with {save_path}") + + save_path_parent = os.path.dirname(save_path) + model_checkpoint_path = None + if self._is_empty: + return model_checkpoint_path + + model_checkpoint_path = compat.as_str(get_model_checkpoint_path(self, checkpoint_file, sess)) + if write_state: + update_checkpoint_state(self, model_checkpoint_path, save_path_parent, latest_filename, meta_graph_suffix, + save_path) + if write_meta_graph: + write_meta_graph_task(self, checkpoint_file, meta_graph_suffix, sess, strip_default_attrs, save_debug_info) + return model_checkpoint_path + + +def restore(self, sess, save_path): + if save_path is None: + raise ValueError("Can't load save_path when it is None.") + checkpoint_prefix = compat.as_text(save_path) + if self._is_empty: + return + if not checkpoint_management.checkpoint_exists_internal(checkpoint_prefix): + raise ValueError("The passed save_path is not a valid checkpoint: " + + checkpoint_prefix) + + tf_logging.info("Restoring parameters from %s", checkpoint_prefix) + try: + if not context.executing_eagerly(): + sess.run(self.saver_def.restore_op_name, + {self.saver_def.filename_tensor_name: save_path}) + # mxRec Patch + # restore sparse model, only run when self.sparse_saver is not None + if self.sparse_saver: + self.sparse_saver.restore(sess, save_path) + + logging.info("Restore from dir %s", save_path) + else: + self._build_eager(save_path, build_save=False, build_restore=True) + + except errors.NotFoundError as err: + try: + names_to_keys = object_graph_key_mapping(save_path) + except errors.NotFoundError: + raise _wrap_restore_error_with_msg( + err, "a Variable name or other graph key that is missing") from err + + # This is an object-based checkpoint. We'll print a warning and then do + # the restore. 
+ tf_logging.warning( + "Restoring an object-based checkpoint using a name-based saver. This " + "may be somewhat fragile, and will re-build the Saver. Instead, " + "consider loading object-based checkpoints using tf.train.Checkpoint().") + self._object_restore_saver = saver_from_object_based_checkpoint(checkpoint_path=save_path, + var_list=self._var_list, builder=self._builder, + names_to_keys=names_to_keys, + cached_saver=self._object_restore_saver) + + except errors.InvalidArgumentError as err: + raise _wrap_restore_error_with_msg(err, "a mismatch between the current graph and the graph") from err + + +def object_graph_key_mapping(file_path): + reader = pywrap_tensorflow.NewCheckpointReader(file_path) + obj_graph_str = reader.get_tensor(trackable.OBJECT_GRAPH_PROTO_KEY) + obj_graph_proto = (trackable_object_graph_pb2.TrackableObjectGraph()) + obj_graph_proto.ParseFromString(obj_graph_str) + node_name_to_key = {} + for each_node in obj_graph_proto.nodes: + for attribute in each_node.attributes: + node_name_to_key[attribute.full_name] = attribute.checkpoint_key + return node_name_to_key + + +def _wrap_restore_error_with_msg(err, extra_verbiage): + err_msg = ("Restoring from checkpoint failed." + "This is most likely due to {} from the checkpoint." + "Please ensure that you have not altered the graph expected based on the checkpoint. " + "Original error: {}").format(extra_verbiage, err.message) + return err.__class__(err.node_def, err.op, err_msg) + + +def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=None, names_to_keys=None, + cached_saver=None): + if names_to_keys is None: + try: + names_to_keys = object_graph_key_mapping(checkpoint_path) + except errors.NotFoundError as err: + raise ValueError("Checkpoint in %s not an object-based checkpoint." % + checkpoint_path) from err + if var_list is None: + var_list = variables._all_saveable_objects() + if builder is None: + builder = BulkSaverBuilder() + + current_node_names = set() + obj_saveable_list = saveable_object_util.validate_and_slice_inputs(var_list) + + for obj_saveable in obj_saveable_list: + for spec in obj_saveable.specs: + current_node_names.add(spec.name) + previous_node_names = set(names_to_keys.keys()) + missing_names = current_node_names - previous_node_names + if missing_names: + extra_node_names = previous_node_names - current_node_names + intersecting_names = previous_node_names.intersection(current_node_names) + raise errors.NotFoundError( + None, None, + message=("Existing variables not in the checkpoint: %s\n" + "Variables names when this checkpoint was written which don't exist now: %s\n\n" + "(%d variable name(s) did match)\n\n" + "Could not find some variables in the checkpoint (see names above). " + "Saver was attempting to load an object-based checkpoint (saved using tf.train.Checkpoint " + "or tf.keras.Model.save_weights) using variable names. " + "If the checkpoint was written with eager execution enabled, " + "it's possible that variable names have changed (for example missing a '_1' suffix). " + "It's also possible that there are new variables which did not exist " + "when the checkpoint was written. " + "You can construct a Saver(var_list=...) with only the variables which previously existed, " + "and if variable names have changed you may need to make this a dictionary " + "with the old names as keys. 
If you're using an Estimator, " + "you'll need to return a tf.train.Saver inside a tf.train.Scaffold from your model_fn.") % ( + ", ".join(sorted(missing_names)), ", ".join(sorted(extra_node_names)), len(intersecting_names))) + for obj_saveable in obj_saveable_list: + for spec in obj_saveable.specs: + spec.name = names_to_keys.get(spec.name) + if cached_saver is None: + return tf.compat.v1.train.Saver(obj_saveable_list) + return cached_saver + + +class BaseSaverBuilder(object): + VariableSaveable = saveable_object_util.ReferenceVariableSaveable + SaveSpec = saveable_object.SaveSpec + ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable + SaveableObject = saveable_object.SaveableObject + + def __init__(self, write_version=saver_pb2.SaverDef.V2): + self._write_version = write_version + + def save_op(self, file_name, obj_saveable_list): + tensors, tensor_names, tensor_slices = [], [], [] + for obj_saveable in obj_saveable_list: + for spec in obj_saveable.specs: + tensors.append(spec.tensor) + tensor_names.append(spec.name) + tensor_slices.append(spec.slice_spec) + if self._write_version == saver_pb2.SaverDef.V2: + return io_ops.save_v2(file_name, tensor_names, tensor_slices, + tensors) + elif self._write_version == saver_pb2.SaverDef.V1: + return io_ops._save(filename=file_name, tensor_names=tensor_names, tensors=tensors, + tensor_slices=tensor_slices) + else: + raise RuntimeError("Unexpected write_version: " + self._write_version) + + +class BulkSaverBuilder(BaseSaverBuilder): + def bulk_restore(self, filename_tensor, saveables, preferred_shard, restore_sequentially): + restore_specs = [] + del restore_sequentially + for obj_saveable in saveables: + for spec in obj_saveable.specs: + restore_specs.append((spec.name, + spec.slice_spec, + spec.dtype)) + tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) + with ops.device("cpu:0"): + return io_ops.restore_v2(filename_tensor, tensor_names, tensor_slices, tensor_dtypes) + + +def patch_for_saver(): + dense_saver = tf.compat.v1.train.Saver + dense_saver.__init__ = saver_init + dense_saver.save = save + dense_saver.restore = restore + logging.debug("Class tf.train.Saver has been patched.") diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py new file mode 100644 index 00000000..d1593872 --- /dev/null +++ b/mx_rec/saver/saver.py @@ -0,0 +1,295 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +# Copyright 2021-2023 Huawei Technologies Co., Ltd + +import json +import os +import shutil +import logging +import stat +from collections import defaultdict + +import numpy as np +import tensorflow as tf +from tensorflow.python.util import compat + +from mx_rec.util.constants import DataName, DataAttr +from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ + get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, \ + get_ascend_global_hashtable_collection +from mx_rec.util.perf import performance + + +class Saver(object): + customized_ops = get_customized_ops() + + def __init__(self, var_list=None, max_to_keep=3, prefix_name="checkpoint"): + self.max_to_keep = max_to_keep + self._prefix_name = prefix_name + self.var_list = var_list + self.rank_id = get_rank_id() + self.local_rank_id = self.rank_id % 8 + self.rank_size = get_rank_size() + self.local_rank_size = min(self.rank_size, 8) + self.save_op_dict = defaultdict(dict) + self.restore_fetch_list = [] + self.placeholder_dict = defaultdict(dict) + self.build() + + def build(self): + if 
self.var_list is None: + logging.debug(f"optimizer collection name: {get_ascend_global_hashtable_collection()}") + self.var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + + with tf.compat.v1.variable_scope("mx_rec_save"): + self._build_save() + with tf.compat.v1.variable_scope("mx_rec_restore"): + self._build_restore() + + logging.debug("Save & Restore graph was built.") + + def _build_save(self): + for var in self.var_list: + table_instance = get_table_instance(var) + table_name = table_instance.table_name + with tf.compat.v1.variable_scope(table_name): + sub_dict = self.save_op_dict[table_name] + sub_dict[DataName.EMBEDDING.value] = var + if table_instance.optimizer: + sub_dict["optimizer"] = table_instance.optimizer + + def _build_restore(self): + for var in self.var_list: + table_instance = get_table_instance(var) + sub_placeholder_dict = self.placeholder_dict[table_instance.table_name] + with tf.compat.v1.variable_scope(table_instance.table_name): + sub_placeholder_dict[DataName.EMBEDDING.value] = variable = \ + tf.compat.v1.placeholder(dtype=tf.float32, shape=[table_instance.slice_device_vocabulary_size, + table_instance.scalar_emb_size], + name=DataName.EMBEDDING.value) + assign_op = var.assign(variable) + self.restore_fetch_list.append(assign_op) + + if table_instance.optimizer: + self._build_optimizer_restore(sub_placeholder_dict, table_instance) + + def _build_optimizer_restore(self, sub_placeholder_dict, table_instance): + sub_placeholder_dict["optimizer"] = optimizer_placeholder_dict = dict() + optimizer_states = table_instance.optimizer + for optimizer_name, optimizer_state_dict in optimizer_states.items(): + optimizer_placeholder_dict[optimizer_name] = sub_optimizer_placeholder_dict = \ + dict([(state_key, tf.compat.v1.placeholder(dtype=tf.float32, + shape=[table_instance.slice_device_vocabulary_size, + table_instance.scalar_emb_size], + name=state_key)) + for state_key, state in optimizer_state_dict.items()]) + for key_state, state in optimizer_state_dict.items(): + assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) + self.restore_fetch_list.append(assign_op) + + @performance("Save") + def save(self, sess, save_path="model", global_step=None): + logging.debug(f"======== Start saving for rank id {self.rank_id} ========") + save_path = save_path if save_path else self._prefix_name + directory, base_name = os.path.split(save_path) + if global_step is not None: + if not isinstance(global_step, compat.integral_types): + global_step = int(sess.run(global_step)) + ckpt_name = "sparse-%s-%d" % (base_name, global_step) + else: + ckpt_name = "sparse-%s" % base_name + + integrated_path = os.path.join(directory, ckpt_name) + saving_path = os.path.abspath(integrated_path) + if os.path.exists(saving_path): + shutil.rmtree(saving_path, ignore_errors=True) + logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") + os.makedirs(saving_path, exist_ok=True) + logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") + + self._save(sess, saving_path) + logging.info(f"sparse model was saved in dir '{saving_path}' .") + logging.debug(f"======== Saving finished for rank id {self.rank_id} ========") + + def _save(self, sess, root_dir): + if is_asc_manager_initialized(): + save_host_data(root_dir) + logging.debug(f"host data was saved.") + + result = sess.run(self.save_op_dict) + for table_name, dump_data_dict in result.items(): + save_embedding_data(root_dir, table_name, dump_data_dict, 
self.rank_id) + table_instance = get_table_instance_by_name(table_name) + if table_instance.use_feature_mapping: + save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id) + save_offset_data(root_dir, table_name, dump_data_dict, self.rank_id) + if "optimizer" in dump_data_dict: + dump_optimizer_data_dict = dump_data_dict.get("optimizer") + for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items(): + save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, + self.rank_id) + + @performance("Restore") + def restore(self, sess, reading_path): + logging.debug("======== Start restoring ========") + directory, base_name = os.path.split(reading_path) + ckpt_name = "sparse-%s" % base_name + + integrated_path = os.path.join(directory, ckpt_name) + reading_path = os.path.abspath(integrated_path) + if not os.path.exists(reading_path): + raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") + + self._restore(sess, reading_path) + logging.info(f"sparse model was restored from dir '{reading_path}' .") + logging.debug("======== Restoring finished ========") + + def _restore(self, sess, reading_path): + if is_asc_manager_initialized(): + restore_host_data(reading_path) + logging.debug(f"host data was restored.") + + restore_feed_dict = defaultdict(dict) + for table_name, sub_placeholder_dict in self.placeholder_dict.items(): + fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, + NameDescriptor(table_name, DataName.EMBEDDING.value)) + table_instance = get_table_instance_by_name(table_name) + if table_instance.use_feature_mapping: + fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, + NameDescriptor(table_name, DataName.FEATURE_MAPPING.value)) + fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, + NameDescriptor(table_name, DataName.OFFSET.value)) + + if "optimizer" in sub_placeholder_dict: + optimizer_state_placeholder_dict_group = sub_placeholder_dict.get("optimizer") + for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items(): + for state_key in optimizer_state_placeholder_dict: + fill_placeholder(reading_path=reading_path, + placeholder_dict=optimizer_state_placeholder_dict, + feed_dict=restore_feed_dict, + suffix=self.rank_id, + name_descriptor=NameDescriptor(table_name, state_key, + optimizer_name=optimizer_name)) + + sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict) + + +class NameDescriptor: + def __init__(self, table_name, data_name, optimizer_name=None): + self.table_name = table_name + self.data_name = data_name + self.optimizer_name = optimizer_name + + +def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_descriptor): + if name_descriptor.optimizer_name: + target_path = generate_path(reading_path, "Optimizer", name_descriptor.optimizer_name, "HBM", + name_descriptor.table_name, name_descriptor.data_name) + else: + target_path = generate_path(reading_path, "HashTable", "HBM", name_descriptor.table_name, + name_descriptor.data_name) + restore_data_dict = read_binary_data(target_path, suffix, name_descriptor.data_name) + + for key, data in restore_data_dict.items(): + embedding_placeholder = placeholder_dict.get(key) + feed_dict[embedding_placeholder] = data + + +def save_embedding_data(root_dir, table_name, dump_data_dict, suffix): + target_path = generate_path(root_dir, "HashTable", "HBM", table_name, 
+def save_embedding_data(root_dir, table_name, dump_data_dict, suffix):
+    target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.EMBEDDING.value)
+    data_to_write = dump_data_dict.get(DataName.EMBEDDING.value)
+
+    attribute = dict()
+    attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name
+    attribute[DataAttr.SHAPE.value] = data_to_write.shape
+    write_binary_data(target_path, suffix, data_to_write, attributes=attribute)
+
+
+def save_feature_mapping_data(root_dir, table_name, dump_data_dict, suffix):
+    target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.FEATURE_MAPPING.value)
+    data_to_write = dump_data_dict.get(DataName.FEATURE_MAPPING.value)
+    valid_len = dump_data_dict.get(DataName.VALID_LEN.value)
+    data_to_write = data_to_write[:valid_len * 3]
+
+    attribute = dict()
+    attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name
+    attribute[DataName.THRESHOLD.value] = int(dump_data_dict.get(DataName.THRESHOLD.value))
+    write_binary_data(target_path, suffix, data_to_write, attributes=attribute)
+
+
+def save_offset_data(root_dir, table_name, dump_data_dict, suffix):
+    target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.OFFSET.value)
+    data_to_write = dump_data_dict.get(DataName.OFFSET.value)
+    valid_bucket_num = dump_data_dict.get(DataName.VALID_BUCKET_NUM.value)
+    data_to_write = data_to_write[:valid_bucket_num]
+
+    attribute = dict()
+    attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name
+    write_binary_data(target_path, suffix, data_to_write, attributes=attribute)
+
+
+def save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, suffix):
+    for state_key, state in dump_optimizer_data.items():
+        target_path = generate_path(root_dir, "Optimizer", optimizer_name, "HBM", table_name, state_key)
+        data_to_write = state
+
+        attribute = dict()
+        attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name
+        attribute[DataAttr.SHAPE.value] = data_to_write.shape
+        write_binary_data(target_path, suffix, data_to_write, attributes=attribute)
+
+
+def generate_path(*args):
+    return os.path.join(*args)
+
+
+def generate_file_name(suffix):
+    return "slice_%d.data" % suffix, "slice_%d.attribute" % suffix
+
+
+def write_binary_data(writing_path, suffix, data, attributes=None):
+    os.makedirs(writing_path, exist_ok=True)
+    data_file, attribute_file = generate_file_name(suffix)
+    target_data_dir = os.path.join(writing_path, data_file)
+    target_attribute_dir = os.path.join(writing_path, attribute_file)
+    if os.path.exists(target_data_dir):
+        raise FileExistsError(f"Target_data_dir {target_data_dir} already exists before writing.")
+    if os.path.exists(target_attribute_dir):
+        raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} already exists before writing.")
+    data.tofile(target_data_dir)
+
+    if attributes is not None:
+        if not isinstance(attributes, dict):
+            raise TypeError(f"Parameter 'attributes' must be a dict instance, not {type(attributes)}")
+        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
+        mode = stat.S_IRUSR | stat.S_IWUSR
+        with os.fdopen(os.open(target_attribute_dir, flags, mode), 'w') as file:
+            file.write(json.dumps(attributes))
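# --- Editorial example (not part of the original patch) ---
# A minimal offline sketch of how a slice written by write_binary_data() above
# could be inspected. The checkpoint directory and rank id are hypothetical.
import json
import numpy as np

slice_dir = "model/sparse-model/HashTable/HBM/user_table/embedding"  # hypothetical path
with open(f"{slice_dir}/slice_0.attribute", "r") as fin:
    attributes = json.load(fin)  # e.g. {"data_type": "float32", "shape": [vocab_size, emb_size]}

data = np.fromfile(f"{slice_dir}/slice_0.data", dtype=attributes["data_type"])
if "shape" in attributes:
    data = data.reshape(attributes["shape"])
print(data.shape)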
+def read_binary_data(reading_path, suffix, data_name):
+    data_file, attribute_file = generate_file_name(suffix)
+    target_data_dir = os.path.join(reading_path, data_file)
+    target_attribute_dir = os.path.join(reading_path, attribute_file)
+    if not os.path.exists(target_data_dir):
+        raise FileNotFoundError(f"Target_data_dir {target_data_dir} does not exist when reading.")
+    if not os.path.exists(target_attribute_dir):
+        raise FileNotFoundError(f"Target_attribute_dir {target_attribute_dir} does not exist when reading.")
+
+    with open(target_attribute_dir, "r") as fin:
+        attributes = json.load(fin)
+
+    if DataAttr.DATATYPE.value not in attributes:
+        raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.")
+
+    data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value))
+    if DataAttr.SHAPE.value in attributes:
+        data_to_restore = data_to_restore.reshape(attributes.pop(DataAttr.SHAPE.value))
+
+    data_dict = {data_name: data_to_restore}
+    for key, item in attributes.items():
+        data_dict[key] = item
+    logging.debug(f"Attribute: '{target_attribute_dir}' and data file: '{target_data_dir}' have been read.")
+    logging.debug(f"Reading shape is {data_to_restore.shape}.")
+
+    return data_dict
diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py
new file mode 100644
index 00000000..68b05dcd
--- /dev/null
+++ b/mx_rec/util/__init__.py
@@ -0,0 +1,5 @@
+# coding: UTF-8
+from .log import get_log_level
+
+
+get_log_level()
diff --git a/mx_rec/util/atomic.py b/mx_rec/util/atomic.py
new file mode 100644
index 00000000..57c41acb
--- /dev/null
+++ b/mx_rec/util/atomic.py
@@ -0,0 +1,27 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
+
+import threading
+
+
+class AtomicInteger:
+    def __init__(self, value=0):
+        self._value = int(value)
+        self._lock = threading.Lock()
+
+    def increase(self, num=1):
+        with self._lock:
+            self._value += int(num)
+            return self._value
+
+    def decrease(self, num=1):
+        return self.increase(-num)
+
+    def value(self):
+        with self._lock:
+            return self._value
+
+    def __str__(self):
+        return str(self.value())
+
\ No newline at end of file
diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py
new file mode 100644
index 00000000..fe0c84a3
--- /dev/null
+++ b/mx_rec/util/constants.py
@@ -0,0 +1,91 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
+
+from enum import Enum
+
+ASCEND_GLOBAL_HASHTABLE_COLLECTION = "ASCEND_GLOBAL_HASHTABLE_COLLECTION"
+ASCEND_CUTTING_POINT_INITIALIZER = "ASCEND_CUTTING_POINT_INITIALIZER"
+ASCEND_CUTTING_POINT = "ASCEND_CUTTING_POINT"
+ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE"
+ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET"
+ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR = "ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR"
+# dynamic shape identity
+ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX = "ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX"
+# hot embed function identity
+ASCEND_SPARSE_LOOKUP_HOT_POS = "ASCEND_SPARSE_LOOKUP_HOT_POS"
+ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP"
+CUSTOMIZED_OPS_LIB_PATH = "CUSTOMIZED_OPS_LIB_PATH"
+HOST_PIPELINE_OPS_LIB_PATH = "HOST_PIPELINE_OPS_LIB_PATH"
+ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
+
+# substring that the name of an embedding table merged by a third party must contain
+ASCEND_TABLE_NAME_MUST_CONTAIN = None
+
+# this number is a temporary workaround: it prevents the op "scatter_nd_update"
+# from receiving a None tensor as input
+AVOID_TENSOR_POS = 439999
+LOCAL_RANK_SIZE = "LOCAL_RANK_SIZE"  # number of NPU devices used on the current server during training
+MAX_DEVICE_NUM_LOCAL_MACHINE = 16  # maximum number of devices on a single server
+DEFAULT_DEVICE_NUM_LOCAL_MACHINE = 8  # default number of devices on a single server
+
+DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24
+TRAIN_CHANNEL_ID = 0
+EVAL_CHANNEL_ID = 1
+
+
+class BaseEnum(Enum):
+    @classmethod
+    def mapping(cls, key):
+        for mode in cls:
+            if key == mode.value:
+                return mode
+
+        raise KeyError(f"Cannot find a corresponding mode in current Enum "
+                       f"class {cls}, given parameter '{key}' is illegal, "
+                       f"please choose a valid one from "
+                       f"'{list(map(lambda c: c.value, cls))}'.")
+
+
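# --- Editorial note (not part of the original patch) ---
# BaseEnum.mapping resolves a raw value back to its enum member; a hedged
# usage sketch with the MxRecMode enum defined below:
#
#   >>> MxRecMode.mapping("ASC")
#   <MxRecMode.ASC: 'ASC'>
#   >>> MxRecMode.mapping("GPU")  # raises KeyError listing the valid values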
illegal, " + f"please choose a valid one from " + f"'{list(map(lambda c: c.value, cls))}'.") + + +class DataName(Enum): + EMBEDDING = "embedding" + FEATURE_MAPPING = "feature_mapping" + OFFSET = "offset" + THRESHOLD = "threshold" + VALID_LEN = "valid_len" + VALID_BUCKET_NUM = "valid_bucket_num" + + +class DataAttr(Enum): + SHAPE = "shape" + DATATYPE = "data_type" + + +class ASCAnchorAttr(Enum): + TABLE_INSTANCE = "table_instance" + IS_TRAINING = "is_training" + RESTORE_VECTOR = "restore_vector" + ID_OFFSETS = "id_offsets" + FEATURE_SPEC = "feature_spec" + ALL2ALL_MATRIX = "all2all_matrix" + HOT_POS = "hot_pos" + + +class MxRecMode(BaseEnum): + ASC = "ASC" # Ascend Sparse with Cpu-hashtable + + +class OptimizerType(Enum): + LAZY_ADAM = "LazyAdam" + SGD = "SGD" + + @staticmethod + def get_optimizer_state_meta(mode): + if mode in OPTIMIZER_STATE_META: + return OPTIMIZER_STATE_META.get(mode) + + raise ValueError(f"Invalid mode value, please choose one from {list(map(lambda c: c.value, OptimizerType))}") + + +OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], + OptimizerType.SGD: []} diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py new file mode 100644 index 00000000..2cb6d235 --- /dev/null +++ b/mx_rec/util/initialize.py @@ -0,0 +1,617 @@ +# coding: UTF-8 + +import json +import logging +import os +import psutil + +import mx_rec.util.constants +from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ + ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.ops import import_host_pipeline_ops + + +class ConfigInitializer: + _single_instance = None + customized_ops = None + host_pipeline_ops = import_host_pipeline_ops() + + def __init__(self, use_mpi, **kwargs): + self._use_mpi = use_mpi + self._ascend_global_hashtable_collection = ASCEND_GLOBAL_HASHTABLE_COLLECTION + self._comm = None + self._asc_manager = None + self._mpi = None + self._is_frozen = False + self._train_interval = None + self._eval_steps = None + self._prefetch_batch_number = None + self._if_load = None + self._table_instance_dict = dict() + self._name_to_var_dict = dict() + self._table_name_set = set() + self._table_name_to_feature_spec = dict() + self._feature_spec_dict = dict() + self._training_mode_channel_dict = dict() + self._rank_to_device_dict = dict() + self._initializer_dict = {} + self._optimizer_instance = None + + if self._use_mpi: + logging.debug(f"Using mpi to launch task.") + from mpi4py import MPI + self._mpi = MPI + self._comm = MPI.COMM_WORLD + self._rank_id = self._comm.Get_rank() + self._rank_size = self._comm.Get_size() + else: + self._rank_id = kwargs.get("rank_id") + self._rank_size = kwargs.get("rank_size") + + if os.getenv("RANK_TABLE_FILE"): + self.parse_hccl_json() + else: + self.set_device_dict() + self.check_parameters() + self.train_interval = kwargs.get("train_interval", -1) + self.eval_steps = kwargs.get("eval_steps", -1) + self.prefetch_batch_number = kwargs.get("prefetch_batch_number", 1) + self.if_load = kwargs.get("if_load", False) + if_dynamic = kwargs.get("use_dynamic", 1) + + self.use_static = 0 if if_dynamic == 1 else 1 + self.use_hot = kwargs.get("use_hot", True) + self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False) + if kwargs.get("bind_cpu", True): + bind_cpu(self._rank_id, self._rank_size) + + def __del__(self): + self.terminate() + + def terminate(self): + if self._asc_manager is not None: + self.del_asc_manager() + + if self._mpi: + self._mpi.Finalize() + 
logging.debug("MPI has been destroyed.") + + def insert_feature_spec(self, feature, is_training): + self._feature_spec_dict[feature.name] = feature + if feature.table_name not in self._table_name_to_feature_spec: + self._table_name_to_feature_spec[feature.table_name] = {True: [], False: []} + self._table_name_to_feature_spec[feature.table_name][is_training].append(feature) + + def get_feature_spec(self, key): + return self._feature_spec_dict.get(key) + + @property + def feature_spec_dict(self): + return self._feature_spec_dict + + @property + def table_name_set(self): + return self._table_name_set + + @property + def table_name_to_feature_spec(self): + return self._table_name_to_feature_spec + + def parse_hccl_json(self): + rank_table_path = os.getenv("RANK_TABLE_FILE") + if not os.path.exists(rank_table_path): + raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") + with open(rank_table_path, "r", encoding="utf-8") as file: + table_hccl = json.load(file) + if "server_list" not in table_hccl: + raise AttributeError(f"Lack of attribute server_list.") + if not table_hccl["server_list"]: + raise ValueError(f"Server_list is empty.") + if "device" not in table_hccl["server_list"][0]: + raise AttributeError(f"Lack of attribute device.") + + for server_list in table_hccl.get("server_list"): + for device in server_list.get("device"): + if "rank_id" not in device or not device["rank_id"].isdigit(): + raise ValueError(f"hccl_json rank_id wrong.") + if "device_id" not in device or not device["device_id"].isdigit(): + raise ValueError(f"hccl_json device_id wrong.") + self._rank_to_device_dict[int(device["rank_id"])] = int(device["device_id"]) + + def set_device_dict(self): + ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + if not ascend_visible_devices: + raise ValueError("env variable ascend_visible_devices is null.") + if "-" in ascend_visible_devices: + rank_start = int(ascend_visible_devices.strip().split("-")[0]) + elif "," in ascend_visible_devices: + rank_start = int(ascend_visible_devices.strip().split(",")[0]) + elif ascend_visible_devices in ["0", "1", "2", "3", "4", "5", "6", "7"]: + rank_start = int(ascend_visible_devices.strip()) + else: + raise ValueError("invalid env variable ascend_visible_devices.") + rank_size = os.getenv("CM_WORKER_SIZE") + if rank_size: + rank_size = int(rank_size) + local_rank_size = rank_size if rank_size < 8 else 8 + for device_index in range(rank_size): + self._rank_to_device_dict[device_index] = int(device_index % local_rank_size) + rank_start + else: + raise ValueError("get CM_WORKER_SIZE failed.") + + def insert_training_mode_channel_id(self, is_training): + if is_training not in self._training_mode_channel_dict: + # mx_rec has 2 channel for data input. it would bind channel_id to training mode recorded in dict. 
+    def get_training_mode_channel_id(self, is_training):
+        return self._training_mode_channel_dict.get(is_training)
+
+    def insert_table_instance(self, name, key, instance):
+        if key in self._table_instance_dict:
+            raise KeyError(f"Given key {key} has been used.")
+
+        if name in self._table_name_set:
+            raise ValueError(f"Duplicated hashtable name '{name}' was used.")
+
+        logging.debug(f"Record one hash table, with name: {name}, key: {key}.")
+        self._table_name_set.add(name)
+        if name not in self._table_name_to_feature_spec:
+            self._table_name_to_feature_spec[name] = {True: [], False: []}
+        self._name_to_var_dict[name] = key
+        self._table_instance_dict[key] = instance
+
+    def get_table_instance(self, key):
+        if key not in self._table_instance_dict:
+            raise KeyError("Given key does not exist.")
+
+        return self._table_instance_dict.get(key)
+
+    def get_table_instance_by_name(self, table_name):
+        if table_name not in self._name_to_var_dict:
+            raise KeyError("Given table name does not exist.")
+
+        key = self._name_to_var_dict.get(table_name)
+        return self._table_instance_dict.get(key)
+
+    @property
+    def table_instance_dict(self):
+        return self._table_instance_dict
+
+    def insert_optimizer(self, optimizer):
+        self._optimizer_instance = optimizer
+
+    @property
+    def optimizer_instance(self):
+        return self._optimizer_instance
+
+    def check_parameters(self):
+        if not isinstance(self._use_mpi, bool):
+            raise ValueError("Arg use_mpi must be a boolean.")
+
+        if not isinstance(self.rank_id, int) or not isinstance(self.rank_size, int):
+            raise ValueError(f"Args rank_size and rank_id must be integers. {self.rank_id} {self.rank_size}")
+
+        if self.rank_id < 0:
+            raise ValueError(f"Arg rank_id must be greater than or equal to 0, which is {self.rank_id} now.")
+
+        if self.rank_size < 1:
+            raise ValueError(f"Arg rank_size must be greater than or equal to 1, which is {self.rank_size} now.")
+
+        if self.rank_id >= self.rank_size:
+            raise ValueError("Rank_id must be within the range from 0 to rank_size - 1.")
+
+    def freeze(self):
+        self._is_frozen = True
+
+    def unfreeze(self):
+        self._is_frozen = False
+
+    @property
+    def is_frozen(self):
+        return self._is_frozen
+
+    @property
+    def name_to_var_dict(self):
+        return self._name_to_var_dict
+
+    def set_asc_manager(self, manager):
+        from mxrec_pybind import HybridMgmt
+        if not isinstance(manager, HybridMgmt):
+            raise ValueError(f"Given manager must be the instance of {HybridMgmt}, which is {type(manager)} "
+                             f"type currently.")
+        self._asc_manager = manager
+        self.freeze()
+
+    def get_asc_manager(self):
+        return self._asc_manager
+
+    def del_asc_manager(self):
+        self.delete_initializers()
+        self._asc_manager.destroy()
+        self._asc_manager = None
+        self.unfreeze()
+        logging.debug("ASC manager has been destroyed.")
+
+    @property
+    def use_mpi(self):
+        return self._use_mpi
+
+    @property
+    def rank_size(self):
+        return self._rank_size
+
+    @property
+    def rank_id(self):
+        return self._rank_id
+
+    @property
+    def device_id(self):
+        if self._rank_id not in self._rank_to_device_dict:
+            raise KeyError(f"rank id not in rank_to_device_dict. {self._rank_id} {self._rank_to_device_dict}")
+        return self._rank_to_device_dict[self._rank_id]
+    @staticmethod
+    def get_instance():
+        if ConfigInitializer._single_instance is None:
+            raise EnvironmentError("Please initialize the environment for mx_rec first.")
+
+        return ConfigInitializer._single_instance
+
+    @staticmethod
+    def set_instance(use_mpi, **kwargs):
+        if ConfigInitializer._single_instance is not None:
+            raise EnvironmentError("ConfigInitializer has already been initialized; "
+                                   "a second initialization is forbidden.")
+
+        ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs)
+
+    @property
+    def train_interval(self):
+        return self._train_interval
+
+    @property
+    def eval_steps(self):
+        return self._eval_steps
+
+    @train_interval.setter
+    def train_interval(self, interval):
+        check_step(interval)
+        self._train_interval = interval
+
+    @eval_steps.setter
+    def eval_steps(self, steps):
+        check_step(steps)
+        self._eval_steps = steps
+
+    @property
+    def prefetch_batch_number(self):
+        return self._prefetch_batch_number
+
+    @prefetch_batch_number.setter
+    def prefetch_batch_number(self, number):
+        check_step(number, 1)
+        self._prefetch_batch_number = number
+
+    @property
+    def if_load(self):
+        return self._if_load
+
+    @if_load.setter
+    def if_load(self, flag):
+        if not isinstance(flag, bool):
+            raise TypeError("Flag if_load must be a boolean.")
+
+        self._if_load = flag
+
+    @property
+    def ascend_global_hashtable_collection(self):
+        return self._ascend_global_hashtable_collection
+
+    @ascend_global_hashtable_collection.setter
+    def ascend_global_hashtable_collection(self, name):
+        if not isinstance(name, str):
+            raise TypeError(f"collection name '{name}' must be a string.")
+        self._ascend_global_hashtable_collection = name
+
+    def get_initializer(self, is_training):
+        return self._initializer_dict.get(is_training)
+
+    def set_initializer(self, is_training, initializer):
+        if not isinstance(is_training, bool):
+            raise ValueError(f"Given key must be a boolean, but got {is_training}.")
+
+        self._initializer_dict[is_training] = initializer
+
+    def delete_initializers(self):
+        self._initializer_dict = {}
+
+
+def set_ascend_global_hashtable_collection(name=ASCEND_GLOBAL_HASHTABLE_COLLECTION):
+    ConfigInitializer.get_instance().ascend_global_hashtable_collection = name
+
+
+def get_ascend_global_hashtable_collection():
+    return ConfigInitializer.get_instance().ascend_global_hashtable_collection
+
+
+def check_step(param, min_value=-1):
+    if not isinstance(param, int):
+        raise TypeError("Given param must be an integer.")
+
+    if param < min_value:
+        raise ValueError(f"Given param must be larger than or equal to {min_value}.")
+
+    if param == 0:
+        raise ValueError("Arg train_interval or eval_steps cannot be 0.")
+
+
+def init(use_mpi, **kwargs):
+    ConfigInitializer.set_instance(use_mpi, **kwargs)
+    set_ascend_env()
+
+
+def is_mpi_in_use():
+    return ConfigInitializer.get_instance().use_mpi
+
+
+def get_rank_size():
+    return ConfigInitializer.get_instance().rank_size
+
+
+def get_rank_id():
+    return ConfigInitializer.get_instance().rank_id
+
+
+def get_device_id():
+    return ConfigInitializer.get_instance().device_id
+
+
+def set_asc_manager(manager):
+    ConfigInitializer.get_instance().set_asc_manager(manager)
+
+
+def trigger_evict():
+    if not is_asc_manager_initialized():
+        raise RuntimeError("ASC manager does not exist.")
+
+    ConfigInitializer.get_instance().get_asc_manager().evict()
+    logging.debug("Feature evict is triggered by ops.")
+
+
+def clear_channel(is_train_channel=False):
+    if not isinstance(is_train_channel, bool):
+        raise TypeError("Arg is_train_channel must be a boolean.")
+    channel_id = get_training_mode_channel_id(is_train_channel)
+    logging.info(f"clear channel: {channel_id}")
+
+    return ConfigInitializer.get_instance().host_pipeline_ops.clear_channel(channel_id)
+
+
+def is_asc_manager_initialized():
+    return ConfigInitializer.get_instance().get_asc_manager() is not None
+
+
+def save_host_data(root_dir):
+    if not is_asc_manager_initialized():
+        raise RuntimeError("ASC manager does not exist.")
+
+    ConfigInitializer.get_instance().get_asc_manager().save(root_dir)
+    logging.debug("Data from host pipeline has been saved.")
+
+
+def restore_host_data(root_dir):
+    if not is_asc_manager_initialized():
+        raise RuntimeError("ASC manager does not exist.")
+
+    if not ConfigInitializer.get_instance().get_asc_manager().load(root_dir):
+        terminate_config_initializer()
+        raise TypeError("Asc loaded data does not match user setups, "
+                        "please reconsider whether you want to restore from this dir.")
+    logging.debug("Data from host pipeline has been restored.")
+
+
+def destroy_asc_manager():
+    initializer = ConfigInitializer.get_instance()
+    if initializer.get_asc_manager() is not None:
+        logging.debug("start destroy asc manager...")
+        initializer.del_asc_manager()
+    else:
+        logging.warning("ASC manager does not exist, please check your code.")
+
+
+def is_asc_frozen():
+    return ConfigInitializer.get_instance().is_frozen
+
+
+def export_table_name_set():
+    return ConfigInitializer.get_instance().table_name_set
+
+
+def get_host_pipeline_ops():
+    return ConfigInitializer.host_pipeline_ops
+
+
+def get_customized_ops():
+    return ConfigInitializer.customized_ops
+
+
+def get_train_interval():
+    return ConfigInitializer.get_instance().train_interval
+
+
+def get_eval_steps():
+    return ConfigInitializer.get_instance().eval_steps
+
+
+def set_train_interval(interval):
+    ConfigInitializer.get_instance().train_interval = interval
+
+
+def set_eval_steps(steps):
+    ConfigInitializer.get_instance().eval_steps = steps
+
+
+def get_prefetch_batch_number():
+    return ConfigInitializer.get_instance().prefetch_batch_number
+
+
+def set_prefetch_batch_number(number):
+    ConfigInitializer.get_instance().prefetch_batch_number = number
+
+
+def get_table_instance(key):
+    return ConfigInitializer.get_instance().get_table_instance(key)
+
+
+def get_table_instance_by_name(table_name):
+    return ConfigInitializer.get_instance().get_table_instance_by_name(table_name)
+
+
+def insert_table_instance(name, key, instance):
+    ConfigInitializer.get_instance().insert_table_instance(name, key, instance)
+
+
+def export_table_instances():
+    return ConfigInitializer.get_instance().table_instance_dict
+
+
+def insert_optimizer(optimizer):
+    ConfigInitializer.get_instance().insert_optimizer(optimizer)
+
+
+def export_optimizer():
+    return ConfigInitializer.get_instance().optimizer_instance
+
+
+def insert_feature_spec(feature, is_training):
+    ConfigInitializer.get_instance().insert_feature_spec(feature, is_training)
+
+
+def get_feature_spec(key):
+    return ConfigInitializer.get_instance().get_feature_spec(key)
+
+
+def insert_training_mode_channel_id(is_training):
+    ConfigInitializer.get_instance().insert_training_mode_channel_id(is_training)
+
+
+def get_training_mode_channel_id(is_training):
+    return ConfigInitializer.get_instance().get_training_mode_channel_id(is_training)
+
+
+def export_feature_spec():
+    return ConfigInitializer.get_instance().feature_spec_dict
+
+
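# --- Editorial example (not part of the original patch) ---
# A minimal initialization sketch using the module-level helpers above. The
# keyword values are hypothetical; device discovery additionally requires
# either RANK_TABLE_FILE or ASCEND_VISIBLE_DEVICES / CM_WORKER_SIZE to be set.
from mx_rec.util.initialize import init, get_rank_id, get_rank_size, get_device_id

init(use_mpi=True,              # rank_id / rank_size are then taken from MPI
     train_interval=100,
     eval_steps=10,
     prefetch_batch_number=2)

print(get_rank_id(), get_rank_size(), get_device_id())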
+def set_if_load(if_load):
+    ConfigInitializer.get_instance().if_load = if_load
+
+
+def get_if_load():
+    return ConfigInitializer.get_instance().if_load
+
+
+def get_use_static():
+    return ConfigInitializer.get_instance().use_static
+
+
+def get_use_hot():
+    return ConfigInitializer.get_instance().use_hot
+
+
+def get_use_dynamic_expansion():
+    return ConfigInitializer.get_instance().use_dynamic_expansion
+
+
+def terminate_config_initializer():
+    ConfigInitializer.get_instance().terminate()
+
+
+def get_name_to_var_dict():
+    return ConfigInitializer.get_instance().name_to_var_dict
+
+
+def get_initializer(is_training):
+    return ConfigInitializer.get_instance().get_initializer(is_training)
+
+
+def set_initializer(is_training, initializer):
+    ConfigInitializer.get_instance().set_initializer(is_training, initializer)
+
+
+def set_ascend_table_name_must_contain(name="merged"):
+    mx_rec.util.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name
+
+
+def set_ascend_env():
+    """
+    Configure Ascend-related parameters and environment variables, and generate the HCCL configuration.
+    """
+    rank = get_rank_id()
+    rank_size = get_rank_size()
+    local_rank_size = 8
+
+    os.environ["MOX_USE_NPU"] = "1"
+    os.environ["FUSION_TENSOR_SIZE"] = "2000000000"
+    os.environ["MOX_USE_TF_ESTIMATOR"] = "0"
+    os.environ["MOX_USE_TDT"] = "1"
+    os.environ["HEARTBEAT"] = "1"
+    os.environ["CONITNUE_TRAIN"] = "true"
+
+    os.environ["RANK_ID"] = str(rank)
+
+    device_id = str(get_device_id())
+    os.environ["DEVICE_ID"] = device_id
+    os.environ["ASCEND_DEVICE_ID"] = device_id
+    os.environ["DEVICE_INDEX"] = device_id
+
+    if os.getenv("RANK_TABLE_FILE"):
+        os.environ["RANK_SIZE"] = str(rank_size)
+    else:
+        import socket
+        host_name = socket.gethostname()
+        host_ip = socket.gethostbyname(host_name)
+        os.environ["CM_WORKER_IP"] = host_ip
+        os.environ["HCCL_CONNECT_TIMEOUT"] = "1200"
+
+    os.environ["JOB_ID"] = "10086"
+    os.environ["SOC_VERSION"] = "Ascend910"
+    os.environ["GE_AICPU_FLAG"] = "1"
+    os.environ["NEW_GE_FE_ID"] = "1"
+    os.environ["EXPERIMENTAL_DYNAMIC_PARTITION"] = "1"
+    os.environ["ENABLE_FORCE_V2_CONTROL"] = "1"
+
+    logging.debug("Ascend env has been set.")
+
+
" + f"{DEFAULT_DEVICE_NUM_LOCAL_MACHINE} is set as default value") + local_rank_size = DEFAULT_DEVICE_NUM_LOCAL_MACHINE + + total_cpu = cpu_count() + avg_count = math.ceil(total_cpu / local_rank_size) + max_index = total_cpu - 1 + start = rank_id * avg_count + cpu_list = [start + i for i in range(avg_count) if start + i <= max_index] + + process = psutil.Process() + try: + process.cpu_affinity(cpu_list) + except IndexError: + logging.error(f"failed to bind cpu for rank {rank_id}: {cpu_list}") diff --git a/mx_rec/util/log.py b/mx_rec/util/log.py new file mode 100644 index 00000000..8370860e --- /dev/null +++ b/mx_rec/util/log.py @@ -0,0 +1,22 @@ +# coding: UTF-8 +import os +import logging + + +def get_log_level(): + env_log_level = os.getenv("MXREC_LOG_LEVEL") + if env_log_level is None: + env_log_level = "INFO" + + log_level = logging.getLevelName(env_log_level) + if not isinstance(log_level, int): + raise EnvironmentError("A wrong log level string was given.") + + log_format = "%(asctime)s\t%(levelname)s\t%(message)s" + date_format = "%m/%d/%Y %H:%M:%S %p" + + logging.basicConfig(level=log_level, format=log_format, datefmt=date_format) + + +if __name__ == "__main__": + logging.debug("haha") \ No newline at end of file diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py new file mode 100644 index 00000000..b396f5ea --- /dev/null +++ b/mx_rec/util/ops.py @@ -0,0 +1,18 @@ +# coding: UTF-8 +import os +import logging +import tensorflow as tf + +from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH + + +def import_host_pipeline_ops(): + host_pipeline_ops_lib_path = os.getenv(HOST_PIPELINE_OPS_LIB_PATH) + if host_pipeline_ops_lib_path: + logging.debug(f"Using the HOST_PIPELINE_OPS_LIB_PATH '{host_pipeline_ops_lib_path}' to get ops lib.") + return tf.load_op_library(host_pipeline_ops_lib_path) + else: + mx_rec_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) + so_path = os.path.join(mx_rec_dir, 'mx_rec/libasc/libasc_ops.so') + logging.debug(f"Using the DEFAULT PATH '{so_path}' to get ops lib.") + return tf.load_op_library(so_path) diff --git a/mx_rec/util/perf.py b/mx_rec/util/perf.py new file mode 100644 index 00000000..72065812 --- /dev/null +++ b/mx_rec/util/perf.py @@ -0,0 +1,15 @@ +# coding: UTF-8 +import time +import logging + + +def performance(method_name): + def decorator(func): + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = func(*args, **kwargs) + span = time.perf_counter() - start + logging.debug(f"{method_name} method consume {span:.6f}s.") + return result + return wrapper + return decorator \ No newline at end of file diff --git a/mx_rec/util/synchronizer.py b/mx_rec/util/synchronizer.py new file mode 100644 index 00000000..82cd54fe --- /dev/null +++ b/mx_rec/util/synchronizer.py @@ -0,0 +1,76 @@ +# coding: UTF-8 +import logging +import socket +from time import sleep + +from mx_rec.util.initialize import get_rank_id + + +class Communicator: + def __init__(self): + self.socket = socket.socket() + self.host = socket.gethostname() + logging.debug(f"host: {self.host}") + self.port = 12345 + self.rank_id = get_rank_id() + self.local_rank_id = self.rank_id % 8 + self.build_connection() + + def build_connection(self): + if self.local_rank_id == 0: + self.socket.bind((self.host, self.port)) + self.socket.listen(8) + + else: + i = 0 + while True: + try: + self.socket.connect((self.host, self.port)) + break + except ConnectionRefusedError: + sleep(0.01) + + i += 1 + logging.debug(f"Connection failed at the NO.{i} time 
diff --git a/mx_rec/util/synchronizer.py b/mx_rec/util/synchronizer.py
new file mode 100644
index 00000000..82cd54fe
--- /dev/null
+++ b/mx_rec/util/synchronizer.py
@@ -0,0 +1,76 @@
+# coding: UTF-8
+import logging
+import socket
+from time import sleep
+
+from mx_rec.util.initialize import get_rank_id
+
+
+class Communicator:
+    def __init__(self):
+        self.socket = socket.socket()
+        self.host = socket.gethostname()
+        logging.debug(f"host: {self.host}")
+        self.port = 12345
+        self.rank_id = get_rank_id()
+        self.local_rank_id = self.rank_id % 8
+        self.build_connection()
+
+    def build_connection(self):
+        if self.local_rank_id == 0:
+            self.socket.bind((self.host, self.port))
+            self.socket.listen(8)
+
+        else:
+            i = 0
+            while True:
+                try:
+                    self.socket.connect((self.host, self.port))
+                    break
+                except ConnectionRefusedError:
+                    sleep(0.01)
+
+                    i += 1
+                    logging.debug(f"Connection failed at the NO.{i} time for local rank id {self.local_rank_id}, "
+                                  f"rank id {self.rank_id}")
+                    if i > 200:
+                        raise EnvironmentError("Socket connection timed out.")
+
+            logging.debug(f"Connection was built for local rank id {self.local_rank_id}, rank id {self.rank_id}")
+
+
+    def server_reply(self):
+        conn, address = self.socket.accept()
+        client_data = conn.recv(1024).decode()
+        logging.debug(f"connecting address: {address}")
+        logging.debug(f"Receive client msg: {client_data}")
+        conn.send(b"Acknowledged!")
+        conn.close()
+        return client_data
+
+    def client_connect(self):
+        info = str(self.local_rank_id).encode()
+        self.socket.send(info)
+        server_reply = self.socket.recv(1024).decode()
+        if server_reply != "Acknowledged!":
+            raise IOError("Got an unexpected string.")
+
+        logging.debug(f"Got the reply from local rank 0 for local rank id {self.local_rank_id}, "
+                      f"rank id {self.rank_id}.")
+
+        self.socket.close()
+
+
+if __name__ == "__main__":
+    communicator = Communicator()
+    if communicator.local_rank_id != 0:
+        communicator.client_connect()
+
+    else:
+        synchronizer_check_list = [i for i in range(1, 8)]
+        while synchronizer_check_list:
+            idx = int(communicator.server_reply())
+            synchronizer_check_list.remove(idx)
+            logging.info(f"Remove NO.{idx} element for synchronizer_check_list.")
+
+        logging.info("Saver synchronized.")
diff --git a/mx_rec/util/tf_version_adapter.py b/mx_rec/util/tf_version_adapter.py
new file mode 100644
index 00000000..6f8dfa09
--- /dev/null
+++ b/mx_rec/util/tf_version_adapter.py
@@ -0,0 +1,9 @@
+# coding: UTF-8
+import tensorflow as tf
+
+if tf.__version__.startswith("1"):
+    from npu_bridge.hccl import hccl_ops
+    from npu_bridge.estimator import npu_ops
+else:
+    from npu_device.compat.v1.hccl import hccl_ops
+    from npu_device.compat.v1.estimator import npu_ops
\ No newline at end of file
diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py
new file mode 100644
index 00000000..dc7b1c65
--- /dev/null
+++ b/mx_rec/util/variable.py
@@ -0,0 +1,47 @@
+# coding: UTF-8
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+from mx_rec.util.initialize import get_table_instance, get_ascend_global_hashtable_collection
+
+
+def get_dense_and_sparse_variable():
+    dense_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
+    sparse_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection())
+    return dense_variables, sparse_variables
+
+
+def remove_saving_var(variable):
+    global_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    savable_objects = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)
+    if variable in global_variables:
+        global_variables.remove(variable)
+
+    if variable in savable_objects:
+        savable_objects.remove(variable)
+
+
+def check_and_get_config_via_var(variable, optimizer_type: str):
+    table_instance = get_table_instance(variable)
+
+    if not table_instance.skip_emb_transfer and not table_instance.optimizer:
+        raise EnvironmentError(f"When using ASC with DDR, you must pass the '{optimizer_type}' optimizer instances"
+                               f" to the init method of SparseEmbedding.")
+
+    return table_instance
+
+
+def check_param_range(name, value, min_border, max_border):
+    if value > max_border or value < min_border:
+        raise ValueError(f"Please offer a {name} between [{min_border}, {max_border}].")
+
+    return
+
+
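# --- Editorial example (not part of the original patch) ---
# A hedged sketch of how get_dense_and_sparse_variable() above is typically
# used: dense variables go to a regular TensorFlow optimizer, while the sparse
# hash-table variables are handed to one of the mx_rec optimizers. The loss
# construction and optimizer choice here are illustrative assumptions.
import tensorflow as tf
from mx_rec.util.variable import get_dense_and_sparse_variable

dense_vars, sparse_vars = get_dense_and_sparse_variable()
loss = tf.reduce_sum([tf.reduce_sum(v) for v in dense_vars])  # placeholder loss for illustration
dense_opt = tf.compat.v1.train.AdamOptimizer(1e-3)
train_dense = dense_opt.minimize(loss, var_list=dense_vars)
# sparse_vars would be passed to an mx_rec optimizer (see mx_rec/optimizers).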
+def check_param_type(name, value, legal_type):
+    if not isinstance(value, legal_type):
+        raise TypeError(f"Please offer a {name} within types: {legal_type}.")
+
+    return
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..e660a184
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,35 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved.
+# Description: build script.
+# Author: MindX SDK
+# Create: 2022
+# History: NA
+
+import os
+from setuptools import setup, find_packages
+
+try:
+    LONG_DESCRIPTION = open("README.md").read()
+except IOError:
+    LONG_DESCRIPTION = ""
+
+env_version = os.getenv("VERSION")
+VERSION = env_version if env_version is not None else '5.0.T104'
+
+setup(
+    name='mx_rec',
+    version=VERSION,
+    author='HUAWEI Inc',
+    url='https://www.hiascend.com/zh/software/mindx-sdk',
+    description='MindX SDK Recommend',
+    long_description=LONG_DESCRIPTION,
+    # include mx_rec
+    packages=find_packages(
+        where='.',
+        include=["mx_rec*"]
+    ),
+    package_dir={},
+    # other file
+    package_data={'': ['*.yml', '*.sh', '*.so*']},
+    # dependency
+    python_requires='>=3.7.5'
+)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 00000000..ca19df6c
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,111 @@
+cmake_minimum_required(VERSION 3.12)
+project(MxRec LANGUAGES CXX)
+set(CMAKE_CXX_STANDARD 14)
+
+if (DEFINED OMPI_PATH)
+    message("== ompi path: " ${OMPI_PATH})
+else ()
+    message("ERROR no OMPI_PATH")
+endif ()
+if (DEFINED CMAKE_INSTALL_PREFIX)
+    message("== install in " ${CMAKE_INSTALL_PREFIX})
+else ()
+    message("ERROR no CMAKE_INSTALL_PREFIX")
+endif ()
+if (DEFINED PYTHON_PATH)
+    set(PYTHON_INCLUDE_DIR ${PYTHON_PATH}/include/python3.7m)
+    set(PYTHON_LIBRARY ${PYTHON_PATH}/lib/libpython3.7m.so)
+else ()
+    message("ERROR no PYTHON_PATH")
+endif ()
+
+set(CMAKE_PREFIX_PATH ${OMPI_PATH} ${HDF5_PATH})
+find_package(OpenMP REQUIRED)
+find_package(MPI REQUIRED)
+find_package(PythonLibs 3.7 REQUIRED)
+if (CMAKE_BUILD_TYPE MATCHES "Debug")
+    find_package(easy_profiler)
+else ()
+    find_package(easy_profiler PATHS ${EASY_PROFILER_PATH} NO_DEFAULT_PATH)
+endif ()
+if (NOT easy_profiler_FOUND)
+    message("===EASY_PROFILER_NOT_FOUND===")
+else ()
+    message("==EASY_PROFILER_FOUND===")
+    ADD_DEFINITIONS(-DBUILD_WITH_EASY_PROFILER)
+endif ()
+set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb")
+set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wall -DNDEBUG")
+
+option(ENABLE_DEBUG "use debug mode" OFF)
+if (ENABLE_DEBUG)
+    message(ENABLE_DEBUG)
+    add_definitions(-DENABLE_DEBUG)
+endif ()
+
+message("== ASCEND_PATH: ${ASCEND_PATH}")
+
+
+include_directories(SYSTEM ${MPI_INCLUDE_PATH})
+include_directories(${HDF5_INCLUDE_DIRS})
+include_directories(${ASCEND_PATH}/fwkacllib/include)
+include_directories(${ASCEND_PATH}/runtime/include)
+include_directories(${ASCEND_PATH}/compiler/include)
+include_directories(${PROJECT_SOURCE_DIR}/core)
+
+if (DEFINED TF_PATH)
+    message("== TF_PATH: ${TF_PATH}")
+    include_directories(${TF_PATH}/include)
+    link_directories(${TF_PATH})
+else ()
+    message("ERROR no TF_PATH try `find / -name libtensorflow_framework.so.1`")
+endif ()
+
+if (DEFINED ASCEND_DRIVER_PATH)
+    message(ASCEND_DRIVER_PATH ${ASCEND_DRIVER_PATH})
+else ()
+    set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver)
+endif ()
+
+if (${TF_PATH} MATCHES "tensorflow_core")
+    message("== Enable tf1 ==")
+    link_directories(PUBLIC ${ASCEND_PATH}/../../tfplugin/latest/python/site-packages/npu_bridge/)
+    link_directories(PUBLIC ${PYTHON_PATH}/lib/python3.7/site-packages/npu_bridge/)
+    set(TF_LIB "-l:libtensorflow_framework.so.1")
+else ()
+    message("== 
Enable tf2 ==")
+    link_directories(PUBLIC ${ASCEND_PATH}/../../tfplugin/latest/python/site-packages/npu_device/compat/v1/)
+    link_directories(PUBLIC ${PYTHON_PATH}/lib/python3.7/site-packages/npu_device/compat/v1/)
+    set(TF_LIB "-l:libtensorflow_framework.so.2")
+endif ()
+
+if (OPENMP_FOUND)
+    message(STATUS "OPENMP FOUND")
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
+endif ()
+add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0)
+set(SPDLOG_BUILD_SHARED ON)
+if(NOT OPENSOURCE_DIR)
+    set(OPENSOURCE_DIR ${PROJECT_SOURCE_DIR}/../../opensource/opensource/)
+endif()
+
+if(IS_DIRECTORY ${OPENSOURCE_DIR})
+    add_subdirectory(${OPENSOURCE_DIR}/pybind11 pybind11.out)
+    add_subdirectory(${OPENSOURCE_DIR}/spdlog spdlog.out)
+else()
+    message(FATAL_ERROR "INVALID FOLDER")
+endif()
+
+include_directories(${PROJECT_SOURCE_DIR}/../../opensource/opensource/spdlog/include)
+install(TARGETS spdlog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX})
+add_subdirectory(core)
+add_subdirectory(ops_tf)
+add_subdirectory(pybind)
+if (CMAKE_BUILD_TYPE MATCHES "Release")
+    message(STATUS "CMAKE_BUILD_TYPE is Release")
+else()
+    message(STATUS "CMAKE_BUILD_TYPE is Debug. [default]")
+    add_subdirectory(tests)
+endif ()
diff --git a/src/build.sh b/src/build.sh
new file mode 100644
index 00000000..098693ef
--- /dev/null
+++ b/src/build.sh
@@ -0,0 +1,29 @@
+rm -rf build;
+mkdir build && cd build || exit 1
+# HDF5_PATH and EASY_PROFILER_PATH are optional
+python_path="$(dirname $(dirname $(realpath $(which python3))))"
+if [ -d /usr/local/Ascend/ascend-toolkit/latest ]; then
+    ascend_path=/usr/local/Ascend/ascend-toolkit/latest
+elif [ -d /usr/local/Ascend/latest ]; then
+    ascend_path=/usr/local/Ascend/latest
+else
+    echo "ERROR: cannot find toolkit and tfplugin"
+    exit 1
+fi
+
+if [ ! -d "$2"/install/abseil/ ]; then
+    echo "ERROR: $2/install/abseil/ does not exist"
+    exit 1
+fi
+
+cmake -DCMAKE_BUILD_TYPE=Release \
+    -DTF_PATH="$1" \
+    -DOMPI_PATH=/usr/local/openmpi/ \
+    -DPYTHON_PATH="$python_path" \
+    -DEASY_PROFILER_PATH=/ \
+    -DASCEND_PATH="$ascend_path" \
+    -DABSEIL_PATH="$2"/install/abseil/ \
+    -DSECUREC_PATH="$2"/platform/securec \
+    -DCMAKE_INSTALL_PREFIX="$2"/output ..
+make -j +make install diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt new file mode 100644 index 00000000..3b59fb7a --- /dev/null +++ b/src/core/CMakeLists.txt @@ -0,0 +1,51 @@ +cmake_minimum_required(VERSION 3.12) +set(CMAKE_CXX_STANDARD 17) +if(NOT ABSEIL_PATH) + set(ABSEIL_PATH ${PROJECT_SOURCE_DIR}/../install/abseil/) +endif() +message("ABSEIL_PATH: " ${ABSEIL_PATH}) + +if(NOT SECUREC_PATH) + set(SECUREC_PATH ${PROJECT_SOURCE_DIR}/../platform/securec) +endif() +message("SECUREC_PATH: " ${SECUREC_PATH}) + +include_directories(${ABSEIL_PATH}/include) +link_directories(${ABSEIL_PATH}/lib) +include_directories(${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/include) + +file(GLOB_RECURSE MXREC_SRC ./*.cpp) +add_library(ASC SHARED ${MXREC_SRC}) + +target_include_directories(ASC + PUBLIC + ${SECUREC_PATH}/include + ${ASCEND_DRIVER_PATH}/include +) + +target_link_directories(ASC + PUBLIC + ${ASCEND_PATH}/fwkacllib/lib64 + ${ASCEND_PATH}/compiler/lib64 + ${ASCEND_PATH}/runtime/lib64 + ${HDF5_PATH}/lib + ${SECUREC_PATH}/lib + ${ASCEND_DRIVER_PATH}/lib64/driver +) + +target_link_libraries(ASC PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) +target_link_libraries(ASC PUBLIC + -l:_tf_adapter.so + OpenMP::OpenMP_CXX ${MPI_CXX_LIBRARIES} + ${PYTHON_LIBRARY} + PRIVATE spdlog::spdlog + ) +find_package(easy_profiler PATHS ${EASY_PROFILER_PATH} NO_DEFAULT_PATH) +if (easy_profiler_FOUND) + message("==link with easy_profiler==") + target_link_directories(ASC PUBLIC ${EASY_PROFILER_PATH}) + target_link_libraries(ASC PUBLIC easy_profiler ) +endif () + +install(TARGETS ASC LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp new file mode 100644 index 00000000..549d8974 --- /dev/null +++ b/src/core/checkpoint/checkpoint.cpp @@ -0,0 +1,459 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ * Description:
+ * Author: MindX SDK
+ * Create: 2022-11-15
+ */
+
+#include <sys/stat.h>
+#include <dirent.h>
+#include <cstring>
+#include <fstream>
+
+#include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h"
+#include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h"
+#include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h"
+#include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h"
+#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h"
+
+#include "checkpoint.h"
+
+using namespace std;
+using namespace MxRec;
+
+void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector<EmbInfo>& EmbInfo)
+{
+    // TODO: check savePath
+
+    processPath = savePath;
+    rankId = mgmtRankInfo.rankId;
+    deviceId = mgmtRankInfo.deviceId;
+    useDynamicExpansion = mgmtRankInfo.useDynamicExpansion;
+    mgmtEmbInfo = EmbInfo;
+
+    spdlog::info("Start host side saving data.");
+    spdlog::debug("==Start to create save data handler.");
+    SetDataHandler(ckptData);
+    spdlog::debug("==Start save data process.");
+    SaveProcess(ckptData);
+    spdlog::info("Finish host side saving data.");
+}
+
+void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector<EmbInfo>& EmbInfo,
+    const vector<CkptFeatureType>& featureTypes)
+{
+    // TODO: check loadPath
+
+    processPath = loadPath;
+    rankId = mgmtRankInfo.rankId;
+    deviceId = mgmtRankInfo.deviceId;
+    useDynamicExpansion = mgmtRankInfo.useDynamicExpansion;
+    mgmtEmbInfo = EmbInfo;
+
+    spdlog::info("Start host side loading data.");
+    spdlog::debug("==Start to create load data handler.");
+    SetDataHandler(featureTypes);
+    spdlog::debug("==Start load data process.");
+    LoadProcess(ckptData);
+    spdlog::info("Finish host side loading data.");
+}
+
+void Checkpoint::SetDataHandler(CkptData& ckptData)
+{
+    dataHandlers.clear();
+    if (!ckptData.hostEmbs.empty()) {
+        dataHandlers.push_back(make_unique<HostEmbCkpt>());
+    }
+    if (!ckptData.embHashMaps.empty()) {
+        dataHandlers.push_back(make_unique<EmbHashCkpt>());
+    }
+    if (!ckptData.maxOffset.empty()) {
+        dataHandlers.push_back(make_unique<NddrOffsetCkpt>());
+    }
+    if (!ckptData.keyOffsetMap.empty()) {
+        dataHandlers.push_back(make_unique<NddrFeatMapCkpt>());
+    }
+    if (!ckptData.tens2Thresh.empty() && !ckptData.histRec.timestamps.empty() &&
+        !ckptData.histRec.historyRecords.empty()) {
+        dataHandlers.push_back(make_unique<FeatAdmitNEvictCkpt>());
+    }
+}
+
+void Checkpoint::SetDataHandler(const vector<CkptFeatureType>& featureTypes)
+{
+    map<CkptFeatureType, function<void()>> setCkptMap { { CkptFeatureType::HOST_EMB,
+        [&] { dataHandlers.push_back(make_unique<HostEmbCkpt>()); } },
+        { CkptFeatureType::EMB_HASHMAP, [&] { dataHandlers.push_back(make_unique<EmbHashCkpt>()); } },
+        { CkptFeatureType::MAX_OFFSET, [&] { dataHandlers.push_back(make_unique<NddrOffsetCkpt>()); } },
+        { CkptFeatureType::KEY_OFFSET_MAP, [&] { dataHandlers.push_back(make_unique<NddrFeatMapCkpt>()); } },
+        { CkptFeatureType::FEAT_ADMIT_N_EVICT, [&] { dataHandlers.push_back(make_unique<FeatAdmitNEvictCkpt>()); } } };
+
+    for (const auto& featureType : featureTypes) {
+        setCkptMap.at(featureType)();
+    }
+}
+
+void Checkpoint::SaveProcess(CkptData& ckptData)
+{
+    for (const auto& dataHandler : dataHandlers) {
+        dataHandler->SetProcessData(ckptData);
+
+        vector<string> embNames { dataHandler->GetEmbNames() };
+        vector<string> dirNames { dataHandler->GetDirNames() };
+        vector<CkptDataType> saveDataTypes { dataHandler->GetDataTypes() };
+
+        MakeUpperLayerSaveDir(dirNames);
+        MakeDataLayerSaveDir(embNames, saveDataTypes, dataHandler);
+
+        SaveDataset(embNames, saveDataTypes, dataHandler);
+    }
+}
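// --- Editorial note (not part of the original patch) ---
// Layout sketch: assuming one table named "user_table" saved by rank 0,
// SaveProcess above ends up creating paths of the form
//
//     <savePath>/<dir names from GetDirNames()>/user_table/embedding_data/slice_0.data
//     <savePath>/<dir names from GetDirNames()>/user_table/embedding_data/slice_0.attribute
//
// where "embedding_data" is one of the dataDirNames defined in
// ckpt_data_handler.h, and slice_<rankId>.data / .attribute match the file
// naming used by the Python saver (generate_file_name in mx_rec/saver).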
+void Checkpoint::MakeUpperLayerSaveDir(const vector<string>& dirNames)
+{
+    innerDirPath = processPath;
+    MakeSaveDir(innerDirPath);
+
+    for (const auto& dirName : dirNames) {
+        innerDirPath = innerDirPath + dirSeparator + dirName;
+        MakeSaveDir(innerDirPath);
+    }
+}
+
+void Checkpoint::MakeDataLayerSaveDir(const vector<string>& embNames,
+    const vector<CkptDataType>& saveDataTypes,
+    const unique_ptr<CkptDataHandler>& dataHandler)
+{
+    for (const auto& embName : embNames) {
+        auto dataDir { innerDirPath + dirSeparator + embName };
+        MakeSaveDir(dataDir);
+
+        for (const auto& saveDataType : saveDataTypes) {
+            auto dataDirName { dataHandler->GetDataDirName(saveDataType) };
+            auto datasetPath { dataDir + dirSeparator + dataDirName };
+            MakeSaveDir(datasetPath);
+        }
+    }
+}
+
+void Checkpoint::MakeSaveDir(const string& dirName)
+{
+    if (mkdir(dirName.c_str(), dirMode) == -1) {
+        spdlog::debug("Unable to create directory: {}", dirName);
+    }
+}
+
+int Checkpoint::GetEmbeddingSize(const string& embName)
+{
+    for (const auto &embInfo: mgmtEmbInfo) {
+        if (embInfo.name == embName) {
+            return embInfo.embeddingSize;
+        }
+    }
+    return 0;
+}
+
+
+void Checkpoint::SaveDataset(const vector<string>& embNames,
+    const vector<CkptDataType>& saveDataTypes,
+    const unique_ptr<CkptDataHandler>& dataHandler)
+{
+    for (const auto& embName: embNames) {
+        auto dataDir{innerDirPath + dirSeparator + embName};
+        for (const auto& saveDataType: saveDataTypes) {
+            auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) };
+            auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType };
+            auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType };
+
+            spdlog::debug("====Start getting data from handler to: {}", datasetDir);
+            auto transData { dataHandler->GetDataset(saveDataType, embName) };
+
+            // save embedding when dynamic expansion is enabled
+            if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) {
+                auto embedPath { dataDir + dirSeparator + "key_embedding" };
+                auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType };
+                auto embeddingSize = GetEmbeddingSize(embName);
+                MakeSaveDir(embedPath);
+                spdlog::debug("====Start saving embedding data to: {}", embedDatasetDir);
+                WriteEmbedding(transData, embedDatasetDir, embeddingSize);
+            }
+
+            spdlog::debug("====Start saving data to: {}", datasetDir);
+            WriteStream(transData, datasetDir, transData.datasetSize, saveDataType);
+            spdlog::debug("====Start saving data to: {}", attributeDir);
+            WriteStream(transData, attributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE);
+        }
+    }
+}
+
+void Checkpoint::WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize)
+{
+    ofstream writeFile;
+    writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        spdlog::error("Set device failed, device_id:{}", deviceId);
+    }
+
+    auto &transArr = transData.int64Arr;
+    for (size_t i{0}; i < transArr.size(); i += embHashNum) {
+        vector<float> row(embeddingSize);
+        int64_t address = transArr.at(i + 1);
+        float *floatPtr = reinterpret_cast<float *>(address);
+
+        aclError ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float),
+                                   floatPtr, embeddingSize * sizeof(float),
+                                   ACL_MEMCPY_DEVICE_TO_HOST);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtMemcpy failed, ret={}", ret);
+        }
+
+        writeFile.write((const char *) (row.data()), embeddingSize * sizeof(float));
+    }
+#endif
+    writeFile.close();
+}
+void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir)
+{
+    std::ifstream readFile;
+    readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        spdlog::error("Set device failed, device_id:{}", deviceId);
+    }
+
+    auto &AttributeArr = transData.attribute;
+    auto embHashMapSize = AttributeArr.at(0);
+    size_t datasetSize = readFile.tellg();
+    readFile.seekg(0, std::ios::beg);
+    auto embeddingSize = static_cast<size_t>(datasetSize / sizeof(float) / embHashMapSize);
+
+    aclError ret;
+    void *newBlock = nullptr;
+    ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        spdlog::error("aclrtMalloc failed, ret={}", ret);
+    }
+
+    float *floatPtr = static_cast<float *>(newBlock);
+    auto &transArr = transData.int64Arr;
+    for (size_t i{0}; i < transArr.size(); i += embHashNum) {
+        vector<float> row(embeddingSize);
+        readFile.read((char *) (row.data()), embeddingSize * sizeof(float));
+
+        aclError ret = aclrtMemcpy(floatPtr + i * embeddingSize, embeddingSize * sizeof(float),
+                                   row.data(), embeddingSize * sizeof(float),
+                                   ACL_MEMCPY_HOST_TO_DEVICE);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtMemcpy failed, ret={}", ret);
+        }
+
+        int64_t address = reinterpret_cast<int64_t>(floatPtr + i * embeddingSize);
+        transArr.at(i + 1) = address;
+    }
+#endif
+    readFile.close();
+}
+
+void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType)
+{
+    ofstream writeFile;
+    writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
+
+    if (writeFile.is_open()) {
+        size_t idx = 0;
+        size_t writeSize = 0;
+        while (dataSize != 0) {
+            if (dataSize > oneTimeReadWriteLen) {
+                writeSize = oneTimeReadWriteLen;
+            } else {
+                writeSize = dataSize;
+            }
+
+            WriteDataset(transData, writeFile, writeSize, dataType, idx);
+
+            dataSize -= writeSize;
+            idx += writeSize;
+        }
+    } else {
+        spdlog::debug("unable to open save file: {}", dataDir);
+    }
+    writeFile.close();
+}
+
+void Checkpoint::WriteDataset(CkptTransData& transData,
+    ofstream& writeFile,
+    size_t writeSize,
+    CkptDataType dataType,
+    size_t idx)
+{
+    if (int32TransSet.find(dataType) != int32TransSet.end()) {
+        writeFile.write((const char*)(transData.int32Arr.data()) + idx, writeSize);
+    } else if (int64TransSet.find(dataType) != int64TransSet.end()) {
+        writeFile.write((const char*)(transData.int64Arr.data()) + idx, writeSize);
+    } else if (floatTransSet.find(dataType) != floatTransSet.end()) {
+        writeFile.write((const char*)(transData.floatArr.data()) + idx, writeSize);
+    } else if (dataType == CkptDataType::ATTRIBUTE) {
+        writeFile.write((const char*)(transData.attribute.data()) + idx, writeSize);
+    }
+}
+
+void Checkpoint::LoadProcess(CkptData& ckptData)
+{
+    for (const auto& dataHandler : dataHandlers) {
+        vector<string> embNames {};
+        vector<string> dirNames { dataHandler->GetDirNames() };
+        vector<CkptDataType> saveDataTypes { dataHandler->GetDataTypes() };
+
+        GetUpperLayerLoadDir(dirNames);
+        embNames = GetTableLayerLoadDir();
+
+        LoadDataset(embNames, saveDataTypes, dataHandler);
+
+        dataHandler->GetProcessData(ckptData);
+    }
+}
+
+void Checkpoint::GetUpperLayerLoadDir(const vector<string>& dirNames)
+{
+    innerDirPath = processPath;
+    // TODO: check existence
+
+    for (const auto& dirName : dirNames) {
+        innerDirPath = innerDirPath + dirSeparator + dirName;
+        // TODO: check existence
+    }
+}
+vector<string> Checkpoint::GetTableLayerLoadDir()
+{
+    vector<string> loadTableDir;
+    auto dir { opendir(innerDirPath.c_str()) };
+    struct dirent* en;
+    if (dir != nullptr) {
+        while ((en = readdir(dir)) != nullptr) {
+            if (strcmp(en->d_name, currDir.c_str()) != 0 &&
+                strcmp(en->d_name, prevDir.c_str()) != 0) {
+                loadTableDir.emplace_back(en->d_name);
+            }
+        }
+        closedir(dir);
+    }
+    // TODO: may cause memory problem? need to check
+
+    return loadTableDir;
+}
+
+void Checkpoint::LoadDataset(const vector<string>& embNames,
+    const vector<CkptDataType>& saveDataTypes,
+    const unique_ptr<CkptDataHandler>& dataHandler)
+{
+    for (const auto& embName : embNames) {
+        auto dataDir { innerDirPath + dirSeparator + embName };
+        // TODO: check existence
+        for (const auto& saveDataType : saveDataTypes) {
+            auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) };
+            // TODO: check existence
+
+            auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType };
+            auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType };
+
+            CkptTransData transData;
+            auto dataElmtBytes { dataHandler->GetDataElmtBytes(saveDataType) };
+
+            spdlog::debug("====Start reading data from: {}", datasetDir);
+            ReadStream(transData, datasetDir, saveDataType, dataElmtBytes);
+
+            spdlog::debug("====Start reading data from: {}", attributeDir);
+            dataElmtBytes = dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE);
+            ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes);
+
+            // load embedding when dynamic expansion is enabled
+            if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) {
+                auto embedPath { dataDir + dirSeparator + "key_embedding" };
+                auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType };
+                spdlog::debug("====Start loading embedding data from: {}", embedDatasetDir);
+                ReadEmbedding(transData, embedDatasetDir);
+            }
+
+            spdlog::debug("====Start loading data from: {} to data handler.", attributeDir);
+            dataHandler->SetDataset(saveDataType, embName, transData);
+        }
+    }
+}
+
+void Checkpoint::ReadStream(CkptTransData& transData,
+    const string& dataDir,
+    CkptDataType dataType,
+    uint32_t dataElmtBytes)
+{
+    if (dataElmtBytes == 0) {
+        spdlog::error("dataElmtBytes is 0, cannot perform the [/ %] operations");
+        return;
+    }
+    std::ifstream readFile;
+    readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+
+    size_t datasetSize = readFile.tellg();
+    readFile.seekg(0, std::ios::beg);
+
+    if (datasetSize % dataElmtBytes > 0) {
+        spdlog::debug("data is missing or incomplete in load file: {}", dataDir);
+    }
+    auto resizeSize { datasetSize / dataElmtBytes };
+    SetTransDataSize(transData, resizeSize, dataType);
+    if (readFile.is_open()) {
+        size_t idx = 0;
+        size_t readSize = 0;
+        while (datasetSize != 0) {
+            if (datasetSize > oneTimeReadWriteLen) {
+                readSize = oneTimeReadWriteLen;
+            } else {
+                readSize = datasetSize;
+            }
+            ReadDataset(transData, readFile, readSize, dataType, idx);
+
+            datasetSize -= readSize;
+            idx += readSize;
+        }
+    } else {
+        spdlog::debug("unable to open load file: {}", dataDir);
+    }
+
+    readFile.close();
+}
+
+void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType)
+{
+    if (int32TransSet.find(dataType) != int32TransSet.end()) {
+        transData.int32Arr.resize(datasetSize);
+    } else if (int64TransSet.find(dataType) != int64TransSet.end()) {
+        transData.int64Arr.resize(datasetSize);
+    } else if (floatTransSet.find(dataType) != floatTransSet.end()) {
+        transData.floatArr.resize(datasetSize);
+    } else if (dataType == CkptDataType::ATTRIBUTE) {
+        transData.attribute.resize(datasetSize);
+    }
+}
+void Checkpoint::ReadDataset(CkptTransData& transData,
+    ifstream& readFile,
+    size_t readSize,
+    CkptDataType dataType,
+    size_t idx)
+{
+    if (int32TransSet.find(dataType) != int32TransSet.end()) {
+        readFile.read((char*)(transData.int32Arr.data()) + idx, readSize);
+    } else if (int64TransSet.find(dataType) != int64TransSet.end()) {
+        readFile.read((char*)(transData.int64Arr.data()) + idx, readSize);
+    } else if (floatTransSet.find(dataType) != floatTransSet.end()) {
+        readFile.read((char*)(transData.floatArr.data()) + idx, readSize);
+    } else if (dataType == CkptDataType::ATTRIBUTE) {
+        readFile.read((char*)(transData.attribute.data()) + idx, readSize);
+    }
+}
\ No newline at end of file
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
new file mode 100644
index 00000000..c1e7d9bc
--- /dev/null
+++ b/src/core/checkpoint/checkpoint.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ * Description: use to manage model saving and loading process
+ * Author: MindX SDK
+ * Create: 2022-11-15
+ */
+
+#ifndef MX_REC_CHECKPOINT_H
+#define MX_REC_CHECKPOINT_H
+
+#include <fstream>
+#include <memory>
+#include <set>
+
+#include "ckpt_data_handler/ckpt_data_handler.h"
+
+namespace MxRec {
+    using namespace std;
+
+    class Checkpoint {
+    public:
+        Checkpoint() = default;
+        ~Checkpoint() {};
+
+        void SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector<EmbInfo>& EmbInfo);
+        void LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector<EmbInfo>& EmbInfo,
+            const vector<CkptFeatureType>& featureTypes);
+
+    private:
+        const string datasetName { "slice_" };
+        const string dataFileType { ".data" };
+        const string attribFileType { ".attribute" };
+        const string dirSeparator { "/" };
+        const mode_t dirMode { 0777 };
+
+        const string currDir { "." };
}; + + const size_t oneTimeReadWriteLen { 32768 }; // 4096 * 8 + + const set int32TransSet { + CkptDataType::EMB_INFO, + CkptDataType::EMB_CURR_STAT, + CkptDataType::NDDR_OFFSET, + CkptDataType::TENSOR_2_THRESH + }; + const set int64TransSet{ + CkptDataType::EMB_HASHMAP, + CkptDataType::DEV_OFFSET, + CkptDataType::HIST_REC, + CkptDataType::NDDR_FEATMAP + }; + const set floatTransSet{ + CkptDataType::EMB_DATA + }; + + vector> dataHandlers; + string processPath; + string innerDirPath; + + int rankId; + int deviceId; + bool useDynamicExpansion {false}; + vector mgmtEmbInfo; + + const int embHashNum { 2 }; + + void SetDataHandler(CkptData& ckptData); + void SetDataHandler(const vector& featureTypes); + + void SaveProcess(CkptData& ckptData); + void MakeUpperLayerSaveDir(const vector& dirNames); + void MakeDataLayerSaveDir(const vector& embNames, const vector& saveDataTypes, + const unique_ptr& dataHandler); + void MakeSaveDir(const string& dirName); + void SaveDataset(const vector& embNames, const vector& saveDataTypes, + const unique_ptr& dataHandler); + void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); + void WriteDataset(CkptTransData& transData, ofstream& writeFile, size_t writeSize, CkptDataType dataType, + size_t idx); + void WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize); + void ReadEmbedding(CkptTransData& transData, const string& dataDir); + + int GetEmbeddingSize(const string& embName); + + void LoadProcess(CkptData& ckptData); + void GetUpperLayerLoadDir(const vector& dirNames); + vector GetTableLayerLoadDir(); + void LoadDataset(const vector& embNames, const vector& saveDataTypes, + const unique_ptr& dataHandler); + void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); + void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); + void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, + size_t idx); + }; +} + +#endif // MX_REC_CHECKPOINT_H diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp new file mode 100644 index 00000000..abe985a7 --- /dev/null +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2022-11-12 + */ + +#include "ckpt_data_handler.h" + + +using namespace std; +using namespace MxRec; + +uint32_t CkptDataHandler::GetDataElmtBytes(CkptDataType dataType) +{ + return dataElmtBytes.at(static_cast(dataType)); +} + +string CkptDataHandler::GetDataDirName(CkptDataType dataType) +{ + return dataDirNames.at(static_cast(dataType)); +} + +void CkptDataHandler::CleanTransfer() +{ + transferData.int64Arr.clear(); + transferData.int32Arr.clear(); + transferData.floatArr.clear(); + transferData.attribute.clear(); + transferData.datasetSize = 0; + transferData.attributeSize = 0; +} \ No newline at end of file diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h new file mode 100644 index 00000000..87eb9b6a --- /dev/null +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-10 + */ + +#ifndef MX_REC_CKPT_DATA_HANDLER_H +#define MX_REC_CKPT_DATA_HANDLER_H + +#include + +#include "emb_hashmap/emb_hashmap.h" +#include "host_emb/host_emb.h" +#include "utils/common.h" + +namespace MxRec { + using namespace std; + + class CkptDataHandler { + public: + CkptDataHandler() = default; + virtual ~CkptDataHandler() {}; + + virtual void SetProcessData(CkptData& processData) = 0; + virtual void GetProcessData(CkptData& processData) = 0; + + virtual vector GetDataTypes() = 0; + uint32_t GetDataElmtBytes(CkptDataType dataType); + string GetDataDirName(CkptDataType dataType); + + virtual vector GetDirNames() = 0; + virtual vector GetEmbNames() = 0; + virtual CkptTransData GetDataset(CkptDataType dataType, string embName) = 0; + + virtual void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) = 0; + + protected: + const vector dataDirNames { + "embedding_info", + "embedding_data", + "embedding_hashmap", + "dev_offset_2_Batch_n_Key", + "embedding_current_status", + "max_offset", + "key_offset_map", + "tensor_2_threshold", + "history_record" + }; + const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8 }; + + const uint32_t eightBytes { 8 }; + const uint32_t fourBytes { 4 }; + + CkptTransData transferData; + + void CleanTransfer(); + }; +} + +#endif // MX_REC_CKPT_DATA_HANDLER_H diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp new file mode 100644 index 00000000..a8de3aed --- /dev/null +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2022-11-14 + */ + +#include "emb_hash_ckpt.h" +#include + +using namespace std; +using namespace MxRec; + +// remember to comment that the function will take over the control on the mem space + +void EmbHashCkpt::SetProcessData(CkptData& processData) +{ + saveEmbHashMaps.clear(); + loadEmbHashMaps.clear(); + saveEmbHashMaps = std::move(processData.embHashMaps); +} + +void EmbHashCkpt::GetProcessData(CkptData& processData) +{ + processData.embHashMaps = std::move(loadEmbHashMaps); + saveEmbHashMaps.clear(); + loadEmbHashMaps.clear(); +} + +vector EmbHashCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector EmbHashCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector EmbHashCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item : saveEmbHashMaps) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData EmbHashCkpt::GetDataset(CkptDataType dataType, string embName) +{ + map> dataTransMap { { CkptDataType::EMB_HASHMAP, + [=] { SetEmbHashMapTrans(embName); } }, + { CkptDataType::DEV_OFFSET, [=] { SetDevOffsetTrans(embName); } }, + { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStatTrans(embName); } } }; + + CleanTransfer(); + dataTransMap.at(dataType)(); + return move(transferData); +} + +void EmbHashCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + map> dataLoadMap { { CkptDataType::EMB_HASHMAP, [=] { SetEmbHashMap(embName); } }, + { CkptDataType::DEV_OFFSET, [=] { SetDevOffset(embName); } }, + { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStat(embName); } } }; + + CleanTransfer(); + transferData = move(loadedData); + dataLoadMap.at(dataType)(); +} + +void EmbHashCkpt::SetEmbHashMapTrans(string embName) +{ + auto& transArr = 
transferData.int64Arr; + auto& attribute = transferData.attribute; + auto embHashMapSize = saveEmbHashMaps.at(embName).hostHashMap.size(); + + attribute.push_back(embHashMapSize); + embHashMapSize = embHashMapSize * embHashElmtNum; + + attribute.push_back(embHashElmtNum); + attribute.push_back(eightBytes); + transferData.datasetSize = embHashMapSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + transArr.reserve(embHashMapSize); + for (const auto& it : saveEmbHashMaps.at(embName).hostHashMap) { + transArr.push_back(it.first); + transArr.push_back(static_cast(it.second)); + } +} + +void EmbHashCkpt::SetDevOffsetTrans(string embName) +{ + const auto& devOffset2Batch = saveEmbHashMaps.at(embName).devOffset2Batch; + const auto& devOffset2Key = saveEmbHashMaps.at(embName).devOffset2Key; + auto& transArr = transferData.int64Arr; + auto& attribute = transferData.attribute; + auto embDevOffsetSize = devOffset2Batch.size(); + embDevOffsetSize += devOffset2Key.size(); + + attribute.push_back(devOffset2Batch.size()); + attribute.push_back(devOffset2Key.size()); + attribute.push_back(eightBytes); + transferData.datasetSize = embDevOffsetSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(embDevOffsetSize); + transArr.insert(transArr.end(), devOffset2Batch.begin(), devOffset2Batch.end()); + transArr.insert(transArr.end(), devOffset2Key.begin(), devOffset2Key.end()); +} + +void EmbHashCkpt::SetEmbCurrStatTrans(string embName) +{ + auto& transArr = transferData.int32Arr; + auto& attribute = transferData.attribute; + auto embDevOffsetSize = embCurrStatNum; + + attribute.push_back(embCurrStatNum); + attribute.push_back(fourBytes); + transferData.datasetSize = embCurrStatNum * fourBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(embDevOffsetSize); + transArr.push_back(static_cast(saveEmbHashMaps.at(embName).currentUpdatePos)); + transArr.push_back(static_cast(saveEmbHashMaps.at(embName).hostVocabSize)); + transArr.push_back(static_cast(saveEmbHashMaps.at(embName).devVocabSize)); +} + +void EmbHashCkpt::SetEmbHashMap(string embName) +{ + auto& hostHashMap = loadEmbHashMaps[embName].hostHashMap; + const auto& transArr = transferData.int64Arr; + for (size_t i = 0; i < transArr.size(); i += embHashElmtNum) { + if (i + embHashElmtNum > transArr.size()) { + // this is an error, need to log this + } + + hostHashMap[transArr.at(i)] = static_cast(transArr.at(i + 1)); + } +} + +void EmbHashCkpt::SetDevOffset(string embName) +{ + const auto& transArr = transferData.int64Arr; + const auto& attribute = transferData.attribute; + auto& dev2Batch = loadEmbHashMaps[embName].devOffset2Batch; + auto& dev2Key = loadEmbHashMaps[embName].devOffset2Key; + + dev2Batch.resize(attribute.at(attrbDev2BatchIdx)); + dev2Key.reserve(attribute.at(attrbDev2KeyIdx)); + + fill(dev2Batch.begin(), dev2Batch.end(), -1); + dev2Key.insert(dev2Key.begin(), transArr.begin() + attribute.at(attrbDev2BatchIdx), transArr.end()); +} + +void EmbHashCkpt::SetEmbCurrStat(string embName) +{ + auto& embCurrStat = loadEmbHashMaps[embName]; + const auto& transArr = transferData.int32Arr; + + embCurrStat.currentUpdatePos = static_cast(transArr.at(currUpdataPosIdx)); + embCurrStat.hostVocabSize = static_cast(transArr.at(hostVocabIdx)); + embCurrStat.devVocabSize = static_cast(transArr.at(devVocabIdx)); +} \ No newline at end of file diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h 
b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h new file mode 100644 index 00000000..8596f74a --- /dev/null +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2022-11-14 + */ + +#ifndef MX_REC_EMB_HASH_CKPT_H +#define MX_REC_EMB_HASH_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class EmbHashCkpt : public CkptDataHandler { + public: + EmbHashCkpt() = default; + ~EmbHashCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "DDR" }; + const vector saveDataTypes { CkptDataType::EMB_HASHMAP, CkptDataType::DEV_OFFSET, + CkptDataType::EMB_CURR_STAT }; + + const int currUpdataPosIdx { 0 }; + const int hostVocabIdx { 1 }; + const int devVocabIdx { 2 }; + + const int attrbDev2BatchIdx { 0 }; + const int attrbDev2KeyIdx { 1 }; + + const int embHashElmtNum { 2 }; + const int embCurrStatNum { 3 }; + emb_hash_mem_t saveEmbHashMaps; + emb_hash_mem_t loadEmbHashMaps; + + void SetEmbHashMapTrans(string embName); + void SetDevOffsetTrans(string embName); + void SetEmbCurrStatTrans(string embName); + + void SetEmbHashMap(string embName); + void SetDevOffset(string embName); + void SetEmbCurrStat(string embName); + }; +} + +#endif // MX_REC_EMB_HASH_CKPT_H \ No newline at end of file diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp new file mode 100644 index 00000000..6c342253 --- /dev/null +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-22 + */ + +#include + +#include "feat_admit_n_evict_ckpt.h" + +using namespace std; +using namespace MxRec; + +void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) +{ + ClearData(); + if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || + processData.histRec.historyRecords.empty()) { + spdlog::error("Missing Feature Admit and Evict data"); + } + saveTens2Thresh = std::move(processData.tens2Thresh); + saveHistRec = std::move(processData.histRec); +} + +void FeatAdmitNEvictCkpt::GetProcessData(CkptData& processData) +{ + processData.tens2Thresh = std::move(loadTens2Thresh); + processData.histRec = std::move(loadHistRec); + ClearData(); +} + +vector FeatAdmitNEvictCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector FeatAdmitNEvictCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector FeatAdmitNEvictCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item : saveTens2Thresh) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embName) +{ + map> dataTransMap { { CkptDataType::TENSOR_2_THRESH, + [=] { SetTens2ThreshTrans(embName); } }, + { CkptDataType::HIST_REC, [=] { SetHistRecTrans(embName); } } }; + + CleanTransfer(); + dataTransMap.at(dataType)(); + return move(transferData); +} + +void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + map> dataLoadMap { { CkptDataType::TENSOR_2_THRESH, + [=] { SetTens2Thresh(embName); } }, + { CkptDataType::HIST_REC, [=] { SetHistRec(embName); } } }; + + CleanTransfer(); + transferData = move(loadedData); + dataLoadMap.at(dataType)(); +} + +void FeatAdmitNEvictCkpt::ClearData() +{ + saveTens2Thresh.clear(); + loadTens2Thresh.clear(); + saveHistRec.timestamps.clear(); + saveHistRec.historyRecords.clear(); + loadHistRec.timestamps.clear(); + loadHistRec.historyRecords.clear(); +} + +void FeatAdmitNEvictCkpt::SetTens2ThreshTrans(string embName) +{ + auto tens2ThreshSize = GetTens2ThreshSize(); + auto& transArr = transferData.int32Arr; + const auto& tens2Thresh = saveTens2Thresh.at(embName); + + transArr.reserve(tens2ThreshSize); + transArr.push_back(tens2Thresh.countThreshold); + transArr.push_back(tens2Thresh.timeThreshold); +} + +void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) +{ + auto histRecSize = GetHistRecSize(embName); + auto& transArr = transferData.int64Arr; + const auto& timeStamp = saveHistRec.timestamps.at(embName); + const auto& histRecs = saveHistRec.historyRecords.at(embName); + + transArr.reserve(histRecSize); + + transArr.push_back(timeStamp); + for (const auto& histRec : histRecs) { + transArr.push_back(histRec.second.featureId); + transArr.push_back(static_cast(histRec.second.count)); + transArr.push_back(static_cast(histRec.second.lastTime)); + } +} + +void FeatAdmitNEvictCkpt::SetTens2Thresh(string embName) +{ + const auto& transArr = transferData.int32Arr; + auto& tens2Thresh = loadTens2Thresh[embName]; + + tens2Thresh.tensorName = embName; + tens2Thresh.countThreshold = transArr[countThresholdIdx]; + tens2Thresh.timeThreshold = transArr[timeThresholdIdx]; +} + +void FeatAdmitNEvictCkpt::SetHistRec(string embName) +{ + const auto& transArr = transferData.int64Arr; + const auto& attribute = transferData.attribute; + auto& timestamp = loadHistRec.timestamps[embName]; + auto& histRecs = loadHistRec.historyRecords[embName]; + + timestamp = transArr.front(); + + size_t 
featItemInfoTotalSize = attribute.front() * static_cast(featItemInfoSaveNum); + for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { + const auto& featureId = transArr[i + featureIdIdxOffset]; + const auto& count = transArr[i + countIdxOffset]; + const auto& lastTime = transArr[i + lastTimeIdxOffset]; + + histRecs[featureId].featureId = featureId; + histRecs[featureId].count = count; + histRecs[featureId].lastTime = lastTime; + histRecs[featureId].tensorName = embName; + } +} + +int FeatAdmitNEvictCkpt::GetTens2ThreshSize() +{ + auto& attribute = transferData.attribute; + auto& attribSize = transferData.attributeSize; + auto& datasetSize = transferData.datasetSize; + + attribute.push_back(threshValSaveNum); + attribute.push_back(fourBytes); + attribSize = attribute.size() * eightBytes; + + datasetSize = threshValSaveNum * fourBytes; + + return threshValSaveNum; +} + +size_t FeatAdmitNEvictCkpt::GetHistRecSize(string embName) +{ + auto& attribute = transferData.attribute; + auto& attribSize = transferData.attributeSize; + auto& datasetSize = transferData.datasetSize; + + size_t timeStampNum = 1; // there will be only 1 timeStamp per embName + auto hashElmtNum = saveHistRec.historyRecords.at(embName).size(); + attribute.push_back(hashElmtNum); + auto elmtCount = timeStampNum + featItemInfoSaveNum * static_cast(hashElmtNum); + + attribute.push_back(eightBytes); + attribSize = attribute.size() * eightBytes; + + datasetSize = elmtCount * eightBytes; + + return elmtCount; +} diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h new file mode 100644 index 00000000..2b12d315 --- /dev/null +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-22 + */ + +#ifndef MXREC_FEAT_ADMIT_N_EVICT_CKPT_H +#define MXREC_FEAT_ADMIT_N_EVICT_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class FeatAdmitNEvictCkpt : public CkptDataHandler { + public: + FeatAdmitNEvictCkpt() = default; + ~FeatAdmitNEvictCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "DDR" }; + const vector saveDataTypes { CkptDataType::TENSOR_2_THRESH, CkptDataType::HIST_REC }; + + const int featItemInfoSaveNum { 3 }; + const int threshValSaveNum { 2 }; + + const int countThresholdIdx { 0 }; + const int timeThresholdIdx { 1 }; + + const int featItemInfoOffset { 1 }; + + const int featureIdIdxOffset { 0 }; + const int countIdxOffset { 1 }; + const int lastTimeIdxOffset { 2 }; + + tensor_2_thresh_mem_t saveTens2Thresh; + tensor_2_thresh_mem_t loadTens2Thresh; + + AdmitAndEvictData saveHistRec; + AdmitAndEvictData loadHistRec; + + void ClearData(); + + void SetTens2ThreshTrans(string embName); + void SetHistRecTrans(string embName); + + void SetTens2Thresh(string embName); + void SetHistRec(string embName); + + int GetTens2ThreshSize(); + size_t GetHistRecSize(string embName); + }; +} + +#endif // MXREC_FEAT_ADMIT_N_EVICT_CKPT_H diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp new file mode 100644 index 00000000..c4dba3a5 --- /dev/null +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -0,0 +1,143 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-12 + */ + +#include "host_emb_ckpt.h" + + +using namespace std; +using namespace MxRec; + +// remember to comment that the function will take over the control on the mem space + +void HostEmbCkpt::SetProcessData(CkptData& processData) +{ + saveHostEmbs.clear(); + loadHostEmbs.clear(); + saveHostEmbs = std::move(processData.hostEmbs); +} + +void HostEmbCkpt::GetProcessData(CkptData& processData) +{ + processData.hostEmbs = std::move(loadHostEmbs); + saveHostEmbs.clear(); + loadHostEmbs.clear(); +} + +vector HostEmbCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector HostEmbCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector HostEmbCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item : saveHostEmbs) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) +{ + map> dataTransMap { { CkptDataType::EMB_INFO, [=] { SetEmbInfoTrans(embName); } }, + { CkptDataType::EMB_DATA, [=] { SetEmbDataTrans(embName); } } }; + + CleanTransfer(); + dataTransMap.at(dataType)(); + return move(transferData); +} + +void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + map> dataLoadMap { { CkptDataType::EMB_INFO, [=] { SetEmbInfo(embName); } }, + { CkptDataType::EMB_DATA, [=] { SetEmbData(embName); } } }; + + CleanTransfer(); + transferData = move(loadedData); + dataLoadMap.at(dataType)(); +} + +void HostEmbCkpt::SetEmbInfoTrans(string embName) +{ + auto embInfoSize = GetEmbInfoSize(); + auto& transArr = transferData.int32Arr; + const auto& hostEmbInfo = saveHostEmbs.at(embName).hostEmbInfo; + + transArr.reserve(embInfoSize); + transArr.push_back(hostEmbInfo.sendCount); + transArr.push_back(hostEmbInfo.embeddingSize); + transArr.push_back(static_cast(hostEmbInfo.devVocabSize)); + transArr.push_back(static_cast(hostEmbInfo.hostVocabSize)); +} + +void HostEmbCkpt::SetEmbDataTrans(string embName) +{ + auto embDataSize = GetEmbDataSize(embName); + transferData.floatArr.reserve(embDataSize); + for (const auto& item : saveHostEmbs.at(embName).embData) { + transferData.floatArr.insert(transferData.floatArr.end(), item.begin(), item.end()); + } +} + +void HostEmbCkpt::SetEmbInfo(string embName) +{ + auto& hostEmbInfo = loadHostEmbs[embName].hostEmbInfo; + const auto& transArr = transferData.int32Arr; + + hostEmbInfo.name = embName; + hostEmbInfo.sendCount = transArr.at(attribEmbInfoSendCntIdx); + hostEmbInfo.embeddingSize = transArr.at(attribEmbInfoEmbSizeIdx); + hostEmbInfo.devVocabSize = static_cast(transArr.at(attribEmbInfoDevVocabIdx)); + hostEmbInfo.hostVocabSize = static_cast(transArr.at(attribEmbInfoHostVocabIdx)); +} + +void HostEmbCkpt::SetEmbData(string embName) +{ + vector embValues; + auto embDataOuterSize = transferData.attribute.at(attribEmbDataOuterIdx); + auto embDataInnerSize = transferData.attribute.at(attribEmbDataInnerIdx); + auto rawBegin = transferData.floatArr.begin(); + loadHostEmbs[embName].embData.reserve(embDataOuterSize); + for (size_t i = 0; i < embDataOuterSize; ++i) { + size_t beginShift = i * embDataInnerSize; + size_t endShift = (i + 1) * embDataInnerSize; + embValues.reserve(embDataInnerSize); + embValues.insert(embValues.begin(), rawBegin + beginShift, rawBegin + endShift); + loadHostEmbs[embName].embData.push_back(move(embValues)); + } +} + +int HostEmbCkpt::GetEmbInfoSize() +{ + transferData.attribute.push_back(embSveElmtNum); + transferData.attribute.push_back(fourBytes); 
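For reference, the attribute entries pushed here follow the convention shared by all the handlers: element count(s) first, element width in bytes last, so the loader can size its buffers before touching the data blob. A minimal standalone sketch of that convention; BlobMeta and the helper names are illustrative, not part of this patch:

    #include <cstdint>
    #include <vector>

    struct BlobMeta {
        std::vector<int64_t> attribute; // e.g. { elementCount, elementBytes }
    };

    // Describe a float blob the way GetEmbInfoSize/GetEmbDataSize do:
    // count first, element width last.
    BlobMeta DescribeFloats(size_t count)
    {
        BlobMeta meta;
        meta.attribute.push_back(static_cast<int64_t>(count));
        meta.attribute.push_back(sizeof(float));
        return meta;
    }

    // A loader sizes its buffer from the attribute stream before reading data.
    std::vector<float> AllocateFromMeta(const BlobMeta& meta)
    {
        return std::vector<float>(static_cast<size_t>(meta.attribute.front()));
    }

Keeping the width as the final field lets a reader sanity-check the raw blob size (count times width) before deserializing.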
+ transferData.datasetSize = embSveElmtNum * fourBytes; + transferData.attributeSize = transferData.attribute.size() * eightBytes; + + return embSveElmtNum; +} + +size_t HostEmbCkpt::GetEmbDataSize(string embName) +{ + auto embDataOuterSize = saveHostEmbs.at(embName).embData.size(); + transferData.attribute.push_back(embDataOuterSize); + + auto embDataInnerSize = saveHostEmbs.at(embName).embData.at(0).size(); + transferData.attribute.push_back(embDataInnerSize); + + transferData.attribute.push_back(fourBytes); + + transferData.datasetSize = embDataOuterSize * embDataInnerSize * fourBytes; + transferData.attributeSize = transferData.attribute.size() * eightBytes; + + return embDataOuterSize * embDataInnerSize; +} \ No newline at end of file diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h new file mode 100644 index 00000000..b48e3459 --- /dev/null +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2022-11-12 + */ + +#ifndef MX_REC_HOST_EMB_CKPT_H +#define MX_REC_HOST_EMB_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class HostEmbCkpt : public CkptDataHandler { + public: + HostEmbCkpt() = default; + ~HostEmbCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "DDR" }; + const vector saveDataTypes { CkptDataType::EMB_INFO, CkptDataType::EMB_DATA }; + + const int attribEmbInfoSendCntIdx { 0 }; + const int attribEmbInfoEmbSizeIdx { 1 }; + const int attribEmbInfoDevVocabIdx { 2 }; + const int attribEmbInfoHostVocabIdx { 3 }; + + const int attribEmbDataOuterIdx { 0 }; + const int attribEmbDataInnerIdx { 1 }; + + const int embSveElmtNum { 4 }; + emb_mem_t saveHostEmbs; + emb_mem_t loadHostEmbs; + + void SetEmbInfoTrans(string embName); + void SetEmbDataTrans(string embName); + + void SetEmbInfo(string embName); + void SetEmbData(string embName); + + int GetEmbInfoSize(); + size_t GetEmbDataSize(string embName); + }; +} + +#endif // MX_REC_HOST_EMB_CKPT_H diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp new file mode 100644 index 00000000..271af3cc --- /dev/null +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-17 + */ +#include +#include + +#include "nddr_feat_map_ckpt.h" + + +using namespace std; +using namespace MxRec; + +void NddrFeatMapCkpt::SetProcessData(CkptData& processData) +{ + saveKeyOffsetMap.clear(); + loadKeyOffsetMap.clear(); + saveKeyOffsetMap = std::move(processData.keyOffsetMap); +} + +void NddrFeatMapCkpt::GetProcessData(CkptData& processData) +{ + processData.keyOffsetMap = std::move(loadKeyOffsetMap); + saveKeyOffsetMap.clear(); + loadKeyOffsetMap.clear(); +} + +vector NddrFeatMapCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector NddrFeatMapCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector NddrFeatMapCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item : saveKeyOffsetMap) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) +{ + CleanTransfer(); + + auto& transArr = transferData.int64Arr; + auto& attribute = transferData.attribute; + auto embHashMapSize = saveKeyOffsetMap.at(embName).size(); + + attribute.push_back(embHashMapSize); + embHashMapSize = embHashMapSize * embHashElmtNum; + + attribute.push_back(embHashElmtNum); + attribute.push_back(eightBytes); + transferData.datasetSize = embHashMapSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(embHashMapSize); + for (const auto& it : saveKeyOffsetMap.at(embName)) { + transArr.push_back(it.first); + transArr.push_back(it.second); + } + + return move(transferData); +} + +void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + CleanTransfer(); + transferData = move(loadedData); + + auto& hostHashMap = loadKeyOffsetMap[embName]; + const auto& transArr = transferData.int64Arr; + + for (size_t i { 0 }; i < transArr.size(); i += embHashElmtNum) { + if (i + embHashElmtNum > transArr.size()) { + // this is an error, need to log this + } + int64_t key { transArr.at(i) }; + hostHashMap[key] = transArr.at(i + 1); + } +} diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h new file mode 100644 index 00000000..f670c302 --- /dev/null +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-17 + */ + +#ifndef MX_REC_NDDR_FEAT_MAP_CKPT_H +#define MX_REC_NDDR_FEAT_MAP_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class NddrFeatMapCkpt : public CkptDataHandler { + public: + NddrFeatMapCkpt() = default; + ~NddrFeatMapCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "HBM" }; + const vector saveDataTypes { CkptDataType::NDDR_FEATMAP }; + + const int embHashElmtNum { 2 }; + + key_offset_mem_t saveKeyOffsetMap; + key_offset_mem_t loadKeyOffsetMap; + }; +} + +#endif // MX_REC_NDDR_FEAT_MAP_CKPT_H diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp new file mode 100644 index 00000000..17f19b75 --- /dev/null +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2022-11-17 + */ + +#include "nddr_offset_ckpt.h" + + +using namespace std; +using namespace MxRec; + +void NddrOffsetCkpt::SetProcessData(CkptData& processData) +{ + saveMaxOffset.clear(); + loadMaxOffset.clear(); + saveMaxOffset = std::move(processData.maxOffset); +} + +void NddrOffsetCkpt::GetProcessData(CkptData& processData) +{ + processData.maxOffset = std::move(loadMaxOffset); + saveMaxOffset.clear(); + loadMaxOffset.clear(); +} + +vector NddrOffsetCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector NddrOffsetCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector NddrOffsetCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item : saveMaxOffset) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData NddrOffsetCkpt::GetDataset(CkptDataType dataType, string embName) +{ + CleanTransfer(); + transferData.int32Arr.push_back(saveMaxOffset.at(embName)); + transferData.datasetSize = fourBytes; + transferData.attribute.push_back(1); + transferData.attribute.push_back(fourBytes); + transferData.attributeSize = transferData.attribute.size() * eightBytes; + return move(transferData); +} + +void NddrOffsetCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + CleanTransfer(); + transferData = move(loadedData); + loadMaxOffset[embName] = transferData.int32Arr.front(); +} \ No newline at end of file diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h new file mode 100644 index 00000000..c8664e1c --- /dev/null +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-17 + */ + +#ifndef MX_REC_NDDR_OFFSET_CKPT_H +#define MX_REC_NDDR_OFFSET_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class NddrOffsetCkpt : public CkptDataHandler { + public: + NddrOffsetCkpt() = default; + ~NddrOffsetCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "HBM" }; + const vector saveDataTypes { CkptDataType::NDDR_OFFSET }; + + offset_mem_t saveMaxOffset; + offset_mem_t loadMaxOffset; + }; +} + +#endif // MX_REC_NDDR_OFFSET_CKPT_H diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp new file mode 100644 index 00000000..4f739bf3 --- /dev/null +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -0,0 +1,275 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: common module + * Author: MindX SDK + * Date: 2022/11/15 + */ + +#include "emb_hashmap.h" +#include +#include +#include +#include +#include +#include "hd_transfer/hd_transfer.h" +#include "checkpoint/checkpoint.h" + +using namespace MxRec; + +void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad) +{ + this->rankInfo = rankInfo; + if (!ifLoad) { + EmbHashMapInfo embHashMap; + spdlog::info("init emb hash map from scratch"); + for (const auto& embInfo: embInfos) { + embHashMap.devOffset2Batch.resize(embInfo.devVocabSize); + embHashMap.devOffset2Key.resize(embInfo.devVocabSize); + embHashMap.hostVocabSize = embInfo.hostVocabSize; + embHashMap.devVocabSize = embInfo.devVocabSize; + embHashMap.currentUpdatePos = 0; + fill(embHashMap.devOffset2Batch.begin(), embHashMap.devOffset2Batch.end(), -1); + fill(embHashMap.devOffset2Key.begin(), embHashMap.devOffset2Key.end(), -1); + embHashMaps[embInfo.name] = embHashMap; + spdlog::trace("devOffset2Key, {}", embHashMaps.at(embInfo.name).devOffset2Key); + spdlog::trace("devOffset2Batch, {}", embHashMaps.at(embInfo.name).devOffset2Batch); + } + } +} + +vector EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch) +{ + EASY_FUNCTION(profiler::colors::Pink) + auto keepBatch = swapId - iBatch; + FindAndUpdateOffset(embName, keys, swapId, keepBatch); + swapId++; + EASY_BLOCK("hostHashMaps->tdt") + + auto& embHashMap = embHashMaps.at(embName); + vector tmpData; + auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); + tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); + + auto lookupTensorData = tmpData.back().flat(); + for (int i = 0; i < lookUpVecSize; i++) { + lookupTensorData(i) = static_cast(embHashMap.lookUpVec[i]); + } + spdlog::trace("lookupTensor, {}", embHashMap.lookUpVec); + auto swapSize = static_cast(embHashMap.swapPos.size()); + tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); + + auto swapTensorData = tmpData.back().flat(); + for (int i = 0; i < swapSize; i++) { + swapTensorData(i) = static_cast(embHashMap.swapPos[i]); + } + if (swapSize > 0) { + spdlog::debug("swap num: {}", swapSize); + } + spdlog::trace("swapTensor, {}", embHashMap.swapPos); + embHashMap.swapPos.clear(); + 
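The swap tensors built above pair with the slot-assignment policy in FindNewOffset further down in this file: reuse evicted device (HBM) slots first, then evicted host (DDR) slots, and only then grow the table. A minimal standalone sketch of that priority order; TinyHashTable is an illustrative stand-in, not the patch's types:

    #include <cstdint>
    #include <stdexcept>
    #include <unordered_map>
    #include <vector>

    struct TinyHashTable {
        std::unordered_map<int64_t, int32_t> keyToOffset;
        std::vector<int32_t> evictedDev;   // recycled slots in device HBM (used first)
        std::vector<int32_t> evictedHost;  // recycled slots in host DDR (used second)
        int32_t maxOffset = 0;             // high-water mark, grows last
        int32_t capacity = 0;              // analogous to devVocabSize + hostVocabSize

        int32_t Assign(int64_t key)
        {
            auto it = keyToOffset.find(key);
            if (it != keyToOffset.end()) {
                return it->second;  // key was already placed
            }
            int32_t offset;
            if (!evictedDev.empty()) {
                offset = evictedDev.back();
                evictedDev.pop_back();
            } else if (!evictedHost.empty()) {
                offset = evictedHost.back();
                evictedHost.pop_back();
            } else if (maxOffset < capacity) {
                offset = maxOffset++;
            } else {
                throw std::runtime_error("vocab too small");
            }
            keyToOffset[key] = offset;
            return offset;
        }
    };

Preferring recycled device slots keeps hot keys in HBM and delays spilling to the slower host table for as long as possible.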
spdlog::info("current dev emb usage:{}-{}/[{}+{}]", embName, embHashMap.maxOffset, embHashMap.devVocabSize, + embHashMap.hostVocabSize); + tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); + auto swapLen = tmpData.back().flat(); + swapLen(0) = swapSize; + EASY_END_BLOCK + return tmpData; +} + +/* + * 从embHashMaps获取key对应的位置,并更新devOffset2Batch + */ +void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector& keys, + size_t currentBatchId, size_t keepBatchId) +{ + EASY_FUNCTION() + size_t keySize = keys.size(); + auto& embHashMap = embHashMaps.at(embName); + embHashMap.lookUpVec.resize(keySize); + std::fill(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), INVALID_KEY_VALUE); + + FindAndUpdateBatchId(keys, currentBatchId, keySize, embHashMap); + EASY_BLOCK("FindNewOffset") + vector> KeysAndOffset; + + for (size_t i = 0; i < keySize; i++) { + auto key = keys[i]; + if (key == -1) { + continue; + } + auto& offset = embHashMap.lookUpVec[i]; + if (offset == INVALID_KEY_VALUE) { + offset = FindNewOffset(key, embHashMap); + } + if (offset >= static_cast(embHashMap.devVocabSize)) { + embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); + KeysAndOffset.emplace_back(key, i); + } + } + EASY_END_BLOCK + EASY_BLOCK("FindPos") + size_t swapSize = KeysAndOffset.size(); + FindPos(embHashMap, swapSize, currentBatchId, keepBatchId); + EASY_END_BLOCK + EASY_BLOCK("ChangeInfo") +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ + shared(swapSize, KeysAndOffset, embHashMap, currentBatchId) + for (size_t i = 0; i < swapSize; i++) { + auto[key, j] = KeysAndOffset[i]; + int pos = static_cast(embHashMap.swapPos[i]); + ChangeSwapInfo(embHashMap, key, embHashMap.missingKeysHostPos[i] + embHashMap.devVocabSize, + currentBatchId, pos); + embHashMap.lookUpVec[j] = pos; + } + EASY_END_BLOCK +} + +void EmbHashMap::ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, + int pos) +{ + embHashMap.devOffset2Batch[pos] = static_cast(currentBatchId); + embHashMap.hostHashMap[key] = pos; + auto& oldKey = embHashMap.devOffset2Key[pos]; + if (oldKey != -1) { + embHashMap.hostHashMap[oldKey] = hostOffset; + } + oldKey = key; +} + +int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) +{ + int offset; + const auto& iter = embHashMap.hostHashMap.find(key); + if (iter != embHashMap.hostHashMap.end()) { // 由于未全局去重,需要再次查询确保是新key + offset = iter->second; + } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 + offset = embHashMap.evictDevPos.back(); + embHashMap.hostHashMap[key] = offset; + spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, embHashMap.evictDevPos.size()); + embHashMap.evictDevPos.pop_back(); + } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 + offset = embHashMap.evictPos.back(); + embHashMap.hostHashMap[key] = offset; + spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, embHashMap.evictPos.size()); + embHashMap.evictPos.pop_back(); + } else { + embHashMap.hostHashMap[key] = embHashMap.maxOffset; + offset = embHashMap.maxOffset; + embHashMap.maxOffset++; + if (embHashMap.maxOffset == embHashMap.devVocabSize) { + spdlog::info("start using host vocab!"); + } + if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { + spdlog::error("hostVocabSize too small! 
dev:{} host:{}", embHashMap.devVocabSize, + embHashMap.hostVocabSize); + throw runtime_error("hostVocabSize too small"); + } + } + return offset; +} + +void EmbHashMap::FindAndUpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, + EmbHashMapInfo& embHashMap) const +{ + EASY_FUNCTION() + for (size_t i = 0; i < keySize; i++) { + int offset; + auto key = keys[i]; + if (key == -1) { + continue; + } + const auto& iter = embHashMap.hostHashMap.find(key); + if (iter != embHashMap.hostHashMap.end()) { // found + offset = static_cast(iter->second); + embHashMap.lookUpVec[i] = offset; // convert to offset(current) + spdlog::trace("key will be used, {} , offset , {}", key, offset); + if (offset < static_cast(embHashMap.devVocabSize)) { + embHashMap.devOffset2Batch[offset] = currentBatchId; + embHashMap.devOffset2Key[offset] = key; + } + } + } +} + +void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, + size_t keepBatchId) +{ + while (num != 0) { + if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { + embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); + num -= 1; + } + embHashMap.currentUpdatePos++; + embHashMap.freeSize--; + if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { + embHashMap.currentUpdatePos = 0; + } + if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { + spdlog::error("devVocabSize is too small"); + throw runtime_error("devVocabSize is too small"); + } + } +} + +auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map +{ + return embHashMaps; +} + +void EmbHashMap::LoadHashMap(emb_hash_mem_t& loadData) +{ + embHashMaps = std::move(loadData); +} + +void EmbHashMapInfo::SetStartCount() +{ + currentUpdatePosStart = currentUpdatePos; + freeSize = devVocabSize; +} + +bool EmbHashMapInfo::HasFree(size_t i) +{ + return freeSize < i; +} + +/* +* 删除淘汰key的映射关系,并将其offset更新到evictPos,待后续复用 +*/ +void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& keys) +{ + EASY_FUNCTION() + size_t keySize = keys.size(); + auto& embHashMap = embHashMaps.at(embName); + + for (size_t i = 0; i < keySize; i++) { + size_t offset; + auto key = keys[i]; + if (key == -1) { + spdlog::error("evict key equal -1!"); + continue; + } + const auto& iter = embHashMap.hostHashMap.find(key); + if (iter != embHashMap.hostHashMap.end()) { + offset = iter->second; + embHashMap.hostHashMap.erase(iter); + spdlog::trace("evict embName {} , offset , {}", embName, offset); + } else { + // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 + continue; + } + + if (offset < embHashMap.devVocabSize) { + embHashMap.devOffset2Batch[offset] = -1; + embHashMap.devOffset2Key[offset] = -1; + embHashMap.evictDevPos.emplace_back(offset); + } else { + embHashMap.evictPos.emplace_back(offset - embHashMap.devVocabSize); + } + } + + spdlog::info("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {} ", + embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); + spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); +} \ No newline at end of file diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h new file mode 100644 index 00000000..424684e5 --- /dev/null +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
+ * Description: common module + * Author: MindX SDK + * Date: 2022/11/15 + */ + +#ifndef MX_REC_EMB_HASHMAP_H +#define MX_REC_EMB_HASHMAP_H + +#include +#include "absl/container/flat_hash_map.h" +#include +#include +#include "host_emb/host_emb.h" + +namespace MxRec { + using namespace std; + + class EmbHashMap { + public: + EmbHashMap() = default; + + void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); + + vector Process(const string& embName, const std::vector& keys, size_t iBatch); + + void FindAndUpdateOffset(const string& embName, const vector& keys, size_t currentBatchId, + size_t keepBatchId); + + void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, + int pos); + + void FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, + size_t keepBatchId); + + auto GetHashMaps() -> absl::flat_hash_map; + + void LoadHashMap(absl::flat_hash_map& loadData); + + const std::vector& GetMissingKeys(const string& embName) + { + return embHashMaps.at(embName).missingKeysHostPos; + } + + void ClearMissingKeys(const string& embName) + { + embHashMaps.at(embName).missingKeysHostPos.clear(); + } + + void EvictDeleteEmb(const string& embName, const vector& keys); + + absl::flat_hash_map embHashMaps; + + private: + RankInfo rankInfo; + int swapId { 0 }; + + void FindAndUpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, + EmbHashMapInfo& embHashMap) const; + + int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap); + }; +} + +#endif // MX_REC_EMB_HASHMAP_H \ No newline at end of file diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp new file mode 100644 index 00000000..c1293243 --- /dev/null +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -0,0 +1,574 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
+ * Description: common module + * Author: MindX SDK + * Date: 2022/11/15 + */ +#include "emb_mgmt.h" +#include +#include +#include "checkpoint/checkpoint.h" +#include "utils/time_cost.h" + +using namespace MxRec; +using namespace std; + +bool HybridMgmt::Initialize(RankInfo rankInfo, + const vector& embInfos, + int seed, + const vector& thresholdValues, + bool ifLoad) +{ + SetLog(rankInfo.rankId); + if (isRunning) { + return true; + } + MPI_Comm_size(MPI_COMM_WORLD, &rankInfo.rankSize); + int localRankId = rankInfo.deviceId; + spdlog::info(MGMT + "begin initialize, localRankSize:{}, localRankId {}, rank {}", rankInfo.localRankSize, + localRankId, rankInfo.rankId); + rankInfo.localRankId = localRankId; + size_t totHostVocabSize = 0; + for (const auto& emb : embInfos) { + totHostVocabSize += emb.hostVocabSize; + } + if (totHostVocabSize == 0) { + rankInfo.noDDR = true; + } + rankInfo.useDataset = getenv("DATASET") != nullptr; + mgmtRankInfo = rankInfo; + mgmtEmbInfo = embInfos; + skipUpdate = getenv("SKIP_UPDATE") != nullptr; + hdTransfer = Singleton::GetInstance(); + hdTransfer->Init(embInfos, rankInfo.deviceId); + preprocess = Singleton::GetInstance(); + preprocess->Initialize(rankInfo, embInfos, thresholdValues, ifLoad, seed); + preprocess->Start(); + lookUpKeysQueue = make_unique>>(); + restoreQueue = make_unique>>(); + isRunning = true; + if (!rankInfo.noDDR) { + hostEmbs = make_unique(); + hostHashMaps = make_unique(); + hostEmbs->Initialize(embInfos, seed, ifLoad); + hostHashMaps->Init(rankInfo, embInfos, ifLoad); + } + isLoad = ifLoad; + if (!rankInfo.useDataset && !isLoad) { + Start(); + } + for (const auto& info: embInfos) { + spdlog::info(MGMT + "emb[{}] vocab size {}+{} sc:{}", info.name, info.devVocabSize, info.hostVocabSize, + info.sendCount); + } + spdlog::info(MGMT + "end initialize, useDataset:{}, noDDR:{}, maxStep:{}, rank:{}", + rankInfo.useDataset, rankInfo.noDDR, rankInfo.maxStep, rankInfo.rankId); + return true; +} + +bool HybridMgmt::Save(string savePath) +{ + preprocess->LoadSaveLock(); + + CkptData saveData; + Checkpoint saveCkpt; + if (!mgmtRankInfo.noDDR) { + spdlog::debug(MGMT + "Start host side save: ddr mode hashmap"); + saveData.hostEmbs = hostEmbs->GetHostEmbs(); + saveData.embHashMaps = hostHashMaps->GetHashMaps(); + } else { + spdlog::debug(MGMT + "Start host side save: no ddr mode hashmap"); + saveData.maxOffset = preprocess->GetMaxOffset(); + saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); + } + + auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); + if (featAdmitNEvict.GetFunctionSwitch()) { + spdlog::debug(MGMT + "Start host side save: feature admit and evict"); + saveData.tens2Thresh = featAdmitNEvict.GetTensorThresholds(); + saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; + saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; + } + + saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo); + + preprocess->LoadSaveUnlock(); + + return true; +} + +bool HybridMgmt::Load(const string& loadPath) +{ + preprocess->LoadSaveLock(); + + spdlog::debug(MGMT + "Start host side load process"); + + CkptData loadData; + Checkpoint loadCkpt; + vector loadFeatures; + if (!mgmtRankInfo.noDDR) { + loadFeatures.push_back(CkptFeatureType::HOST_EMB); + loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP); + } else { + loadFeatures.push_back(CkptFeatureType::MAX_OFFSET); + loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP); + } + + auto& featAdmitNEvict = 
preprocess->GetFeatAdmitAndEvict();
+    if (featAdmitNEvict.GetFunctionSwitch()) {
+        loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT);
+    }
+
+    loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures);
+    if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) {
+        return false;
+    }
+
+    if (!mgmtRankInfo.noDDR) {
+        spdlog::debug(MGMT + "Start host side load: ddr mode hashmap");
+        hostEmbs->LoadEmb(loadData.hostEmbs);
+        hostHashMaps->LoadHashMap(loadData.embHashMaps);
+    } else {
+        spdlog::debug(MGMT + "Start host side load: no ddr mode hashmap");
+        preprocess->LoadMaxOffset(loadData.maxOffset);
+        preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap);
+    }
+    if (featAdmitNEvict.GetFunctionSwitch()) {
+        spdlog::debug(MGMT + "Start host side load: feature admit and evict");
+        featAdmitNEvict.LoadTensorThresholds(loadData.tens2Thresh);
+        featAdmitNEvict.LoadHistoryRecords(loadData.histRec);
+    }
+
+    spdlog::debug(MGMT + "Finish host side load process");
+
+    preprocess->LoadSaveUnlock();
+
+    if (!mgmtRankInfo.useDataset && isLoad) {
+        Start();
+    }
+
+    return true;
+}
+
+bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData)
+{
+    bool loadDataMatches { true };
+    size_t embTableCount { 0 };
+    const auto& loadHostEmbs { loadData.hostEmbs };
+    for (const auto& setupHostEmbs : mgmtEmbInfo) {
+        const auto& loadEmbTable { loadHostEmbs.find(setupHostEmbs.name) };
+        if (loadEmbTable != loadHostEmbs.end()) {
+            embTableCount++;
+
+            const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo };
+            if (setupHostEmbs.sendCount != loadEmbInfo.sendCount) {
+                spdlog::error(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}",
+                    loadEmbInfo.sendCount, setupHostEmbs.name, setupHostEmbs.sendCount);
+                loadDataMatches = false;
+            }
+            if (setupHostEmbs.embeddingSize != loadEmbInfo.embeddingSize) {
+                spdlog::error(MGMT + "Load data embeddingSize {} for table {} does not match setup embeddingSize {}",
+                    loadEmbInfo.embeddingSize, setupHostEmbs.name, setupHostEmbs.embeddingSize);
+                loadDataMatches = false;
+            }
+            if (setupHostEmbs.devVocabSize != loadEmbInfo.devVocabSize) {
+                spdlog::error(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}",
+                    loadEmbInfo.devVocabSize, setupHostEmbs.name, setupHostEmbs.devVocabSize);
+                loadDataMatches = false;
+            }
+            if (setupHostEmbs.hostVocabSize != loadEmbInfo.hostVocabSize) {
+                spdlog::error(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}",
+                    loadEmbInfo.hostVocabSize, setupHostEmbs.name, setupHostEmbs.hostVocabSize);
+                loadDataMatches = false;
+            }
+            if (!loadDataMatches) {
+                return loadDataMatches;
+            }
+        } else {
+            spdlog::error(MGMT + "Load data does not contain a table named {}", setupHostEmbs.name);
+            return false;
+        }
+    }
+
+    if (embTableCount < loadHostEmbs.size()) {
+        spdlog::error(MGMT + "Load data has {} tables, more than the {} tables matched by the setup",
+            loadHostEmbs.size(), embTableCount);
+        return false;
+    }
+    return true;
+}
+
+void HybridMgmt::Start()
+{
+    if (mgmtRankInfo.noDDR) {
+        auto getInfoTask = [this]() {
+            auto ret = GetInfoTask();
+            spdlog::info("getInfoTask done");
+            return ret;
+        };
+        procThreads.emplace_back(getInfoTask);
+
+        auto sendInfoTask = [this]() {
+            auto ret = SendTask();
+            spdlog::info("sendInfoTask done");
+            return ret;
+        };
+        procThreads.emplace_back(sendInfoTask);
+    }
+
+    if (!mgmtRankInfo.noDDR) {
+        auto parseKeysTask = [this]() {
+            auto ret = ParseKeysTask();
+            spdlog::info("parseKeysTask done");
+            return ret;
+        };
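Start() and the Destroy() method in emb_mgmt.h follow a plain worker-thread lifecycle: each task loop runs until an isRunning flag flips, and shutdown joins every thread. A minimal standalone sketch of the same pattern, assuming std::thread workers and an atomic flag; TinyMgmt is illustrative, not the patch's class:

    #include <atomic>
    #include <thread>
    #include <vector>

    class TinyMgmt {
    public:
        void Start()
        {
            isRunning = true;
            workers.emplace_back([this] {
                while (isRunning) {
                    // ParseKeysTask-style loop body goes here
                }
            });
        }

        void Destroy()
        {
            isRunning = false;        // ask the loops to exit first...
            for (auto& w : workers) { // ...then join them
                w.join();
            }
            workers.clear();
        }

    private:
        std::atomic<bool> isRunning { false };
        std::vector<std::thread> workers;
    };

Flipping the flag before any join mirrors the ordering Destroy() uses: producers stop generating work before blocking consumers are waited on, which avoids joining a thread that is stuck waiting for input.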
procThreads.emplace_back(parseKeysTask); + } +} + +bool HybridMgmt::TrainParseKeys() +{ + do { + if (!isRunning) { + return false; + } + ParseKeys(TRAIN_CHANNEL_ID, getInfoBatchId); + spdlog::info(MGMT + "parseKeysBatchId = {}", getInfoBatchId); + } while (getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1); + + return true; +} + +bool HybridMgmt::EvalParseKeys() +{ + int evalGetInfoBatchId = 0; // 0-99, 0-99 + do { + if (!isRunning) { + return false; + } + bool status = ParseKeys(EVAL_CHANNEL_ID, evalGetInfoBatchId); + if (!status) { + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalGetInfoBatchId; + break; + } + } while (evalGetInfoBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); + + return true; +} + +// 腾讯需要DDR开启特性 +bool HybridMgmt::ParseKeysTask() +{ + while (isRunning) { + spdlog::info(MGMT + "Start Mgmt ParseKeysTask"); + if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) { + if (!TrainParseKeys()) { + return false; + } + } + + if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) { + if (!EvalParseKeys()) { + return false; + } + } + } + + return false; +} + +bool HybridMgmt::GetInfoTask() +{ + while (isRunning) { + spdlog::info(MGMT + "Start Mgmt GetInfoTask"); + do { + if (!isRunning) { + return false; + } + GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); + spdlog::info(MGMT + "getInfoBatchId = {}", getInfoBatchId); + } while (getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1); + + int evalGetInfoBatchId = 0; // 0-99, 0-99 + do { + if (!isRunning) { + return false; + } + bool status = GetLookupAndRestore(EVAL_CHANNEL_ID, evalGetInfoBatchId); + if (!status) { + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalGetInfoBatchId; + break; + } + } while (evalGetInfoBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); + } + return false; +} + +bool HybridMgmt::SendTask() +{ + while (isRunning) { + spdlog::info(MGMT + "Start Mgmt SendTask"); + do { + if (!isRunning) { + return false; + } + SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); +#if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) + spdlog::info(MGMT + "sendBatchId = {}", sendBatchId); + if (trainBatchId == PROFILING_START_BATCH_ID) { + EASY_PROFILER_ENABLE + } else if (trainBatchId == PROFILING_END_BATCH_ID) { + EASY_PROFILER_DISABLE + ::profiler::dumpBlocksToFile( + fmt::format("/home/MX_REC-mgmt-profile-{}.prof", mgmtRankInfo.rankId).c_str()); + } +#endif + } while (sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1); + + int evalSendBatchId = 0; + do { + if (!isRunning) { + return false; + } + bool status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalSendBatchId); + if (!status) { + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalSendBatchId; + break; + } + } while (evalSendBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); + } + return false; +} + +bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) +{ + spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + for (const auto& embInfo: mgmtEmbInfo) { + TimeCost getAllTensorTC; + auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); + if (infoVecs == 
nullptr) { + spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + return false; + } + lookUpKeysQueue->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueue->Pushv(*infoVecs); + TIME_PRINT("getAllTensorTC TimeCost(ms):{}", getAllTensorTC.ElapsedMS()); + } + batchId++; + return true; +} + +bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) +{ + for (const auto& embInfo: mgmtEmbInfo) { + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(ALL2ALL, { *all2all }, channelId, embInfo.name); + } + spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", + batchId, embInfo.name, channelId); + + TimeCost sendTensorTC; + omp_set_num_threads(SEND_TENSOR_TYPE_NUM); +#pragma omp parallel sections + { +#pragma omp section + { + TimeCost sendLookupTC; + auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); + hdTransfer->Send(LOOKUP, lookUpKeys, channelId, embInfo.name); + TIME_PRINT("LOOKUP Send TimeCost(ms):{}", sendLookupTC.ElapsedMS()); + } +#pragma omp section + { + TimeCost sendRestoreTC; + auto restore = restoreQueue->WaitAndPop(); + hdTransfer->Send(RESTORE, restore, channelId, embInfo.name); + TIME_PRINT("RESTORE Send TimeCost(ms):{}", sendRestoreTC.ElapsedMS()); + } + } + TIME_PRINT("sendTensorTC TimeCost(ms):{}", sendTensorTC.ElapsedMS()); + } + batchId++; + return true; +} + +bool HybridMgmt::EndBatch(int batchId, int channelId) +{ + return (batchId % mgmtRankInfo.maxStep[channelId] == 0 && mgmtRankInfo.maxStep[channelId] != -1); +} + +bool HybridMgmt::ParseKeys(int channelId, int& batchId) +{ + spdlog::info(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + TimeCost parseKeyTC; + int start = batchId, iBatch = 0; + bool ifHashmapFree = true, remainBatch = true; + while (true) { + spdlog::info(MGMT + "parse keys, [{}]:{}", channelId, batchId); + for (const auto& embInfo : mgmtEmbInfo) { + auto& embHashMap = hostHashMaps->embHashMaps.at(embInfo.name); + if (iBatch == 0) { + embHashMap.SetStartCount(); + } + auto lookupKeys = preprocess->GetLookupKeys(batchId, embInfo.name, channelId); + if (lookupKeys.empty()) { + remainBatch = false; + break; + } + auto restore = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); + hdTransfer->Send(RESTORE, *restore, channelId, embInfo.name); + + auto tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, iBatch); + hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embInfo.name); + tmpData.erase(tmpData.begin()); + hdTransfer->Send(SWAP, tmpData, channelId, embInfo.name); + + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(ALL2ALL, *all2all, channelId, embInfo.name); + } + + if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch + spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embInfo.name, channelId, + batchId, iBatch, lookupKeys.size()); + ifHashmapFree = false; + } + } + if (!remainBatch) { + EmbHDTransWrap(channelId, batchId, start, iBatch); + return false; + } + batchId++; + iBatch++; + if (EndBatch(batchId, channelId) || iBatch == mgmtRankInfo.nBatch || !ifHashmapFree || !isRunning) { + break; + } + } + if (!isRunning) { + return false; + } + EmbHDTransWrap(channelId, batchId - 1, start, iBatch); + TIME_PRINT("[{}]-{}, parseKeyTC 
TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS());
+    return true;
+}
+
+// send h2d & recv d2h emb
+void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch)
+{
+    if (iBatch == 0) {
+        return;
+    }
+    spdlog::info(MGMT + "trans emb, batchId:[{}-{}]", start, batchId);
+    hostEmbs->Join();
+    EmbHDTrans(channelId, batchId);
+
+    for (int i = 0; i < iBatch - 1; ++i) {
+        // the remaining batches in this window still need an (empty) exchange to keep the device in step
+        spdlog::info(MGMT + "trans emb dummy, batchId:{}, ", start + 1 + i);
+        EmbHDTrans(channelId, batchId);
+    }
+}
+
+void HybridMgmt::EmbHDTrans(int channelId, int batchId)
+{
+    EASY_FUNCTION(profiler::colors::Blue)
+    EASY_VALUE("mgmtProcess", batchId)
+    spdlog::debug(MGMT + "trans emb, batchId:{}, channelId:{}", batchId, channelId);
+    TimeCost tr;
+    for (const auto& embInfo: mgmtEmbInfo) {
+        auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos;
+        auto h2dEmb = hostEmbs->GetH2DEmb(missingKeys, embInfo.name); // order!
+        hdTransfer->Send(H2D, h2dEmb, channelId, embInfo.name, batchId);
+    }
+    for (const auto& embInfo: mgmtEmbInfo) {
+        const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name);
+        if (!(skipUpdate && missingKeys.empty())) {
+            hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order!
+        } // skip when skip update and empty missing keys
+        hostHashMaps->ClearMissingKeys(embInfo.name);
+    }
+    TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId: {} ", tr.ElapsedMS(), batchId);
+}
+
+void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo)
+{
+    EASY_FUNCTION(profiler::colors::Blue)
+    EASY_VALUE("mgmtProcess", batchId)
+    spdlog::info(MGMT + "trans emb dummy, batchId:{}, channelId:{}", batchId, channelId);
+    // drain one D2H message and answer it with an empty H2D transfer
+    auto transferName = TransferChannel::D2H;
+    auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0];
+    hdTransfer->Send(H2D, {}, channelId, embInfo.name);
+}
+
+/*
+ * Eviction is triggered by a hook, based on elapsed time or on step count.
+ */
+void HybridMgmt::Evict()
+{
+    auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict();
+    if (featAdmitNEvict.GetFunctionSwitch()) {
+        featAdmitNEvict.FeatureEvict(evictKeyMap);
+    } else {
+        spdlog::warn(MGMT + "Hook can not trigger evict, because AdmitNEvict is not enabled");
+        return;
+    }
+    spdlog::debug(MGMT + "evict triggered by hook, evict TableNum {} ", evictKeyMap.size());
+
+    if (mgmtRankInfo.noDDR) {
+        for (const auto& evict : evictKeyMap) {
+            preprocess->EvictKeys(evict.first, evict.second);
+        }
+    } else {
+        for (const auto& evict : evictKeyMap) {
+            EvictKeys(evict.first, evict.second);
+        }
+    }
+}
+
+// DDR-mode eviction: remove the key mappings, re-initialize the host table, and send the evicted positions to the device
+void HybridMgmt::EvictKeys(const string& embName, const vector<int64_t>& keys)
+{
+    spdlog::debug(MGMT + "ddr mode, delete emb: [{}]! evict keySize:{}", embName, keys.size());
+    // remove the key-to-offset mappings
+    if (keys.size() != 0) {
+        hostHashMaps->EvictDeleteEmb(embName, keys);
+    }
+
+    // re-initialize the evicted embeddings on the host side
+    auto& evictOffset = hostHashMaps->embHashMaps.at(embName).evictPos;
+    if (evictOffset.size() != 0) {
+        spdlog::debug(MGMT + "ddr mode, delete emb: [{}]! evict size on host:{}", embName, evictOffset.size());
+        hostEmbs->EvictInitEmb(embName, evictOffset);
+    } else {
+        spdlog::info(MGMT + "ddr mode, evict size on host is empty");
+    }
+
+    // send the evicted device positions so the device side can re-initialize those embeddings
+    auto evictDevOffset = hostHashMaps->embHashMaps.at(embName).evictDevPos;
+    spdlog::debug(MGMT + "ddr mode, init dev emb: [{}]! evict size on dev:{}", embName, evictDevOffset.size());
+
+    for (const auto& embInfo : mgmtEmbInfo) {
+        if (embInfo.name != embName) {
+            continue;
+        }
+        if (evictDevOffset.size() > static_cast<size_t>(embInfo.devVocabSize)) {
+            spdlog::error(MGMT + "{} overflow! evict pos on dev {} bigger than dev vocabSize {}",
+                          embName, evictDevOffset.size(), embInfo.devVocabSize);
+        }
+        if (mgmtRankInfo.useStatic) {
+            evictDevOffset.resize(embInfo.devVocabSize, -1);
+        }
+        break;
+    }
+
+    auto tmpData = Vec2TensorI32(evictDevOffset);
+    hdTransfer->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName);
+}
\ No newline at end of file
diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h
new file mode 100644
index 00000000..31eef2f5
--- /dev/null
+++ b/src/core/emb_mgmt/emb_mgmt.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: common module
+ * Author: MindX SDK
+ * Date: 2022/11/15
+ */
+
+#ifndef MX_REC_EMB_MGMT_H
+#define MX_REC_EMB_MGMT_H
+
+#include <map>
+#include "absl/container/flat_hash_map.h"
+#include <memory>
+#include <string>
+#include <thread>
+#include <vector>
+#include "utils/common.h"
+#include "utils/singleton.h"
+#include "utils/task_queue.h"
+#include "hd_transfer/hd_transfer.h"
+#include "host_emb/host_emb.h"
+#include "emb_hashmap/emb_hashmap.h"
+#include "key_process/key_process.h"
+
+namespace MxRec {
+    using namespace std;
+    using namespace tensorflow;
+
+    constexpr int SEND_TENSOR_TYPE_NUM = 2;
+
+    class HybridMgmt {
+    public:
+        HybridMgmt() = default;
+
+        ~HybridMgmt()
+        {
+            if (isRunning) {
+                Destroy();
+            }
+        }
+
+        HybridMgmt(const HybridMgmt&) = delete;
+
+        HybridMgmt& operator=(const HybridMgmt&) = delete;
+
+        bool Initialize(RankInfo rankInfo, const vector<EmbInfo>& embInfos, int seed,
+                        const vector<ThresholdValue>& thresholdValues, bool ifLoad);
+
+        bool Save(string savePath);
+
+        bool Load(const string& loadPath);
+
+        void Start();
+
+        void Destroy()
+        {
+            if (!isRunning) {
+                return;
+            }
+            // signal mgmt to stop first: block new lookup queries and lift the queue limits so nothing hangs
+            isRunning = false;
+            restoreQueue->DestroyQueue();
+            lookUpKeysQueue->DestroyQueue();
+
+            // then signal preprocess, so an in-flight lookup cannot stay stuck
+            preprocess->isRunning = false;
+            // stop hdTransfer to release any recv() that mgmt is blocked on
+            hdTransfer->Destroy();
+            for (auto &i : procThreads) {
+                i.join();
+            }
+            if (hostEmbs != nullptr) {
+                hostEmbs->Join();
+                hostEmbs = nullptr;
+            }
+            procThreads.clear();
+            // finally shut down preprocessing
+            if (preprocess != nullptr) {
+                preprocess->Destroy();
+                preprocess = nullptr;
+            }
+        };
+
+        bool ParseKeys(int channelId, int& batchId);
+
+        void EmbHDTrans(int channelId, int batchId);
+
+        void Evict();
+
+        void EvictKeys(const string& embName, const vector<int64_t>& keys);
+
+    private:
+        int currentBatchId;
+        int trainBatchId = 0; // 0-199, 200-
+        int getInfoBatchId;
+        int sendBatchId;
+        vector<EmbInfo> mgmtEmbInfo;
+        RankInfo mgmtRankInfo;
+        unique_ptr<HostEmb> hostEmbs {};
+        unique_ptr<EmbHashMaps> hostHashMaps {};
+        vector<thread> procThreads {};
+        unique_ptr<TaskQueue<vector<Tensor>>> lookUpKeysQueue;
+        unique_ptr<TaskQueue<vector<Tensor>>> restoreQueue;
+        map<string, vector<int64_t>> evictKeyMap {};
+        KeyProcess *preprocess;
+        HDTransfer *hdTransfer;
+        bool isRunning;
+        bool skipUpdate;
+        bool isLoad { false };
+
+        bool ParseKeysTask();
+        bool GetInfoTask();
+        bool SendTask();
+        bool TrainParseKeys();
+        bool EvalParseKeys();
+
+        bool GetLookupAndRestore(int channelId, int &batchId);
+        bool SendLookupAndRestore(int channelId, int &batchId);
+
+        void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo);
+
+        bool EndBatch(int batchId, int channelId);
+
+        void EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch);
+
+        bool LoadMatchesDDRSetup(const CkptData& loadData);
+    };
+}
+#endif // MX_REC_EMB_MGMT_H
diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp
new file mode 100644
index 00000000..428dd23d
--- /dev/null
+++ b/src/core/emb_table/emb_table.cpp
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ * Description: emb table
+ * Author: MindX SDK
+ * Date: 2023/5/6
+ */
+
+#include <cstdint>
+#include <list>
+#include <map>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+#include "acl/acl.h"
+#include "acl/acl_base.h"
+#include "utils/common.h"
+#include "initializer/initializer.h"
+#include "emb_table/emb_table.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace tensorflow;
+
+void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed)
+{
+#ifndef GTEST
+    this->rankInfo = rInfo;
+    this->seed = seed;
+    this->embInfo = embInfo; // keep the table's EmbInfo: later block growth re-runs its initializers
+    spdlog::info("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.embeddingSize);
+    auto ret = aclrtSetDevice(static_cast<int32_t>(rInfo.deviceId));
+    if (ret != ACL_ERROR_NONE) {
+        spdlog::error("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret);
+        throw AclError();
+    }
+    // work out how much device memory each block of the embedding table needs
+    embSize = embInfo.embeddingSize;
+    blockSize = BLOCK_EMB_COUNT * embSize;
+    for (int i = 0; i < INIT_BLOCK_COUNT; ++i) {
+        // allocate a new memory block
+        void *newBlock = nullptr;
+        aclError ret = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtMalloc failed, ret={}", ret);
+            throw AclError();
+        }
+        if (newBlock == nullptr) {
+            // out of device memory
+            throw OutOfMemoryError();
+        } else {
+            // initialize the newly allocated block
+            RandomInit(newBlock, embInfo.initializeInfos, seed);
+            // append the new block to the block list
+            memoryList.push_back(newBlock);
+            SplitMemoryBlock(newBlock);
+        }
+    }
+    totalCapacity = memoryList.size();
+    spdlog::info("aclrtMalloc success, emb name:{}", embInfo.name);
+#endif
+}
+
+EmbTable::~EmbTable()
+{
+#ifndef GTEST
+    for (void *block : memoryList) {
+        // release every device memory block
+        aclError ret = aclrtFree(block);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtFree failed, ret={}", ret);
+        }
+    }
+#endif
+}
+
+// fetch an available embedding slot address from embeddingList
+int64_t EmbTable::GetEmbAddress()
+{
+#ifndef GTEST
+    if (embeddingList.empty()) {
+        PrintStatus();
+        spdlog::debug("GetEmbAddress, embedding_list size: empty! Add block!");
+        void *addBlock = nullptr;
+        aclError ret = aclrtMalloc(&addBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtMalloc failed, ret={}", ret);
+            throw AclError();
+        }
+        if (addBlock == nullptr) {
+            // out of device memory
+            throw OutOfMemoryError();
+        } else {
+            RandomInit(addBlock, embInfo.initializeInfos, seed);
+            // append the new block to the block list
+            memoryList.push_back(addBlock);
+            SplitMemoryBlock(addBlock);
+            totalCapacity++;
+        }
+    }
+    float *embAddr = embeddingList.front();
+    embeddingList.pop_front();
+    usedCapacity++;
+    return reinterpret_cast<int64_t>(embAddr);
+#endif
+    return 0; // GTEST build has no device memory to hand out
+}
+
+// return an embedding slot address to embeddingList
+void EmbTable::PutEmbAddress(int64_t curAddress)
+{
+    embeddingList.push_back(reinterpret_cast<float*>(curAddress));
+    usedCapacity--;
+}
+
+void EmbTable::RandomInit(void* newBlock, const vector<InitializeInfo>& initializeInfos, int seed)
+{
+#ifndef GTEST
+    spdlog::info("Device GenerateEmbData Start, seed:{}", seed);
+    vector<float> devEmb(blockSize);
+    for (auto initializeInfo: initializeInfos) {
+        // fallback target declared outside the switch so the pointer stays valid after it
+        ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0);
+        Initializer* initializer = &defaultInitializer;
+        switch (initializeInfo.initializerType) {
+            case InitializerType::CONSTANT: {
+                spdlog::info("Device GenerateEmbData using Constant Initializer with value {}.",
+                             initializeInfo.constantInitializerInfo.constantValue);
+                initializer = &initializeInfo.constantInitializer;
+                break;
+            }
+            case InitializerType::TRUNCATED_NORMAL: {
+                spdlog::info("Device GenerateEmbData using Truncated Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.truncatedNormalInitializer;
+                break;
+            }
+            case InitializerType::RANDOM_NORMAL: {
+                spdlog::info("Device GenerateEmbData using Random Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.randomNormalInitializer;
+                break;
+            }
+            default: {
+                spdlog::error("Device Invalid Initializer Type. Using default Constant Initializer with value 0.");
+                break;
+            }
+        }
+        for (int i = 0; i < BLOCK_EMB_COUNT; i++) {
+            initializer->GenerateData(&devEmb[i * embSize], embSize);
+        }
+    }
+    spdlog::info("Device GenerateEmbData End, seed:{}", seed);
+    aclError ret = aclrtMemcpy(newBlock, blockSize * sizeof(float),
+                               devEmb.data(), blockSize * sizeof(float),
+                               ACL_MEMCPY_HOST_TO_DEVICE);
+    if (ret != ACL_SUCCESS) {
+        spdlog::error("aclrtMemcpy failed, ret={}", ret);
+        throw AclError();
+    }
+#endif
+}
+
+
+void EmbTable::SplitMemoryBlock(void *newBlock)
+{
+    if (embSize == 0) {
+        throw std::runtime_error("SplitMemoryBlock by embSize=0!");
+    }
+    for (int i = 0; i < BLOCK_EMB_COUNT; i++) {
+        float *embPtr = static_cast<float*>(newBlock) + i * embSize;
+        embeddingList.push_back(embPtr);
+    }
+}
+
+void EmbTable::PrintStatus()
+{
+    // total capacity of the embedding table, in floats
+    spdlog::info("Total capacity:{}", totalCapacity * blockSize);
+    // capacity that is still unused, in floats
+    spdlog::info("Unused capacity:{}", totalCapacity * blockSize - usedCapacity * embSize);
+}
+
+// serialization for checkpointing: returns device address -> embedding row
+map<int64_t, vector<float>> EmbTable::SaveEmb()
+{
+#ifndef GTEST
+    if (embSize == 0) {
+        throw std::runtime_error("SaveEmb Divided by Zero!");
+    }
+    map<int64_t, vector<float>> savedEmb;
+    for (auto ptr : memoryList) {
+        float* floatPtr = static_cast<float*>(ptr);
+        for (int i = 0; i < BLOCK_EMB_COUNT; ++i) {
+            // copy one embedding row from device memory back to the host
+            vector<float> row(embSize);
+            aclError ret = aclrtMemcpy(row.data(), embSize * sizeof(float),
+                                       floatPtr + i * embSize, embSize * sizeof(float),
+                                       ACL_MEMCPY_DEVICE_TO_HOST);
+            if (ret != ACL_SUCCESS) {
+                spdlog::error("aclrtMemcpy failed, ret={}", ret);
+                throw AclError();
+            }
+            savedEmb[reinterpret_cast<int64_t>(floatPtr + i * embSize)] = move(row);
+        }
+    }
+    return savedEmb;
+#endif
+    return {};
+}
+
+// deserialization: allocate one block that holds all saved rows, copy them in, and return every row's device address
+list<float*> EmbTable::LoadEmb(const vector<vector<float>> &savedEmb)
+{
+#ifndef GTEST
+    list<float*> addressList;
+    int embCapacity = static_cast<int>(savedEmb.size());
+    if (savedEmb.size() == 0 || savedEmb[0].size() == 0) {
+        spdlog::error("Load invalid savedEmb");
+        return addressList;
+    }
+    embSize = static_cast<int>(savedEmb[0].size());
+    void *newBlock = nullptr;
+    aclError ret = aclrtMalloc(&newBlock, static_cast<size_t>(embCapacity) * embSize * sizeof(float),
+                               ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        spdlog::error("aclrtMalloc failed, ret={}", ret);
+        throw AclError();
+    }
+    if (newBlock == nullptr) {
+        // out of device memory
+        throw OutOfMemoryError();
+    }
+    float *floatPtr = static_cast<float*>(newBlock);
+    for (int i = 0; i < embCapacity; i++) {
+        aclError ret = aclrtMemcpy(floatPtr + i * embSize, embSize * sizeof(float),
+                                   savedEmb[i].data(), embSize * sizeof(float),
+                                   ACL_MEMCPY_HOST_TO_DEVICE);
+        if (ret != ACL_SUCCESS) {
+            spdlog::error("aclrtMemcpy failed, ret={}", ret);
+            throw AclError();
+        }
+        addressList.push_back(floatPtr + i * embSize);
+    }
+    memoryList.push_back(newBlock);
+    return addressList;
+#endif
+    return {};
}
\ No newline at end of file
diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h
new file mode 100644
index 00000000..5a1c0927
--- /dev/null
+++ b/src/core/emb_table/emb_table.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
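
The allocator above is a simple block pool: device memory is requested one block (BLOCK_EMB_COUNT embeddings) at a time, each block is split into fixed-size slots, and a free list hands slots out through GetEmbAddress and takes them back through PutEmbAddress. Below is a minimal host-side sketch of the same idea, with plain operator new[] standing in for aclrtMalloc; the names BlockPool, Acquire and Release are illustrative and not part of this patch.

    #include <list>
    #include <vector>

    class BlockPool {
    public:
        explicit BlockPool(int embSize, int embsPerBlock = 1000)
            : embSize(embSize), embsPerBlock(embsPerBlock) {}

        ~BlockPool()
        {
            for (float* block : blocks) {
                delete[] block;
            }
        }

        // Hand out one slot; grow the pool by a whole block when the free list is empty.
        float* Acquire()
        {
            if (freeList.empty()) {
                float* block = new float[static_cast<size_t>(embSize) * embsPerBlock]();
                blocks.push_back(block);
                for (int i = 0; i < embsPerBlock; ++i) {
                    freeList.push_back(block + static_cast<size_t>(i) * embSize);
                }
            }
            float* slot = freeList.front();
            freeList.pop_front();
            return slot;
        }

        // Return a slot to the free list; backing blocks are only freed on destruction.
        void Release(float* slot) { freeList.push_back(slot); }

    private:
        int embSize;
        int embsPerBlock;
        std::list<float*> freeList;   // available embedding slots
        std::vector<float*> blocks;   // owned blocks
    };

Growth happens one whole block at a time, so Acquire stays cheap in the common case and returned slots are reused before any new device memory is requested.
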
+ * Description: emb table + * Author: MindX SDK + * Date: 2023/5/6 + */ + +#ifndef MX_REC_EMB_TABLE_H +#define MX_REC_EMB_TABLE_H + +#include +#include +#include +#include +#include "utils/common.h" +#include +#include + +namespace MxRec { + + using namespace std; + + class EmbTable { + public: + EmbTable() = default; + + void Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed = 0); + + ~EmbTable(); + + // 从embeddingList获取获取一个可用的emb地址 + int64_t GetEmbAddress(); + + // 将一个emb地址放入embeddingList中 + void PutEmbAddress(int64_t curAddress); + + // 打印emb表使用情况 + void PrintStatus(); + + int GetTotalCap(); + + int GetUsedCap(); + + EmbTable(const EmbTable&) = delete; + + EmbTable(EmbTable&&) = delete; + + EmbTable& operator=(const EmbTable&) = delete; + + EmbTable& operator=(EmbTable&&) = delete; + + // 用于保存 + map> SaveEmb(); + + // 用于加载 输入一个vector,创建一个embeddingtable类,申请内存,存储输入信息 , list返回全部地址 + list LoadEmb(const vector> &savedEmb); + + GTEST_PRIVATE: + constexpr static int BLOCK_EMB_COUNT = 1000; + constexpr static int INIT_BLOCK_COUNT = 5; + constexpr static int TEST_EMB_SIZE = 12; + EmbInfo embInfo; + RankInfo rankInfo; + int blockSize = 1; + int embSize = 1; + int totalCapacity = 1; + int usedCapacity = 0; + int seed = 0; + float mean = 0; + float stddev = 1; + // embedding地址的列表 + list embeddingList; + // 内存块列表 + vector memoryList; + + void RandomInit(void* newBlock, const vector &initializeInfos, int seed); + + // embSize由embInfo得出 + void SplitMemoryBlock(void* newBlock); + + // 内部类,抛出内存不足异常 + class OutOfMemoryError : public runtime_error { + public: + OutOfMemoryError() : runtime_error("Out of memory!") {} + }; + + // 内部类,抛出acl异常 + class AclError : public runtime_error { + public: + AclError() : runtime_error("Acl failed!") {} + }; + }; +} + +#endif // MX_REC_EMB_TABLE_MANAGER_H \ No newline at end of file diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h new file mode 100644 index 00000000..ce8da921 --- /dev/null +++ b/src/core/hd_transfer/acl_channel.h @@ -0,0 +1,34 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. +* Description: acl channel api +* Author: MindX SDK +* Date: 2022/11/15 +*/ + +#ifndef ACL_CHANNEL_H_ +#define ACL_CHANNEL_H_ + +#include +#include +#include "acl/acl_tdt.h" +#include "tensorflow/core/framework/tensor.h" + + +namespace tensorflow { +#ifdef CANN5_x + Status RecvTensorByAcl(acltdtChannelHandle *acl_handle, std::vector &tensors); +#else + + Status RecvTensorByAcl(const acltdtChannelHandle* acl_handle, std::vector& tensors); + + Status StopRecvTensorByAcl(acltdtChannelHandle **handle, const std::string &channel_name); + +#endif + + Status SendTensorsByAcl(const acltdtChannelHandle* acl_handle, acltdtTensorType acl_type, + const std::vector& tensors, bool& is_need_resend); + +} // namespace tensorflow + +#endif // ACL_CHANNEL_H_ + diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp new file mode 100644 index 00000000..10ae60ed --- /dev/null +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -0,0 +1,197 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
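
To make the sizing constants in emb_table.h concrete: one block holds BLOCK_EMB_COUNT embeddings, so its byte size is BLOCK_EMB_COUNT * embSize * sizeof(float), and the pool starts with INIT_BLOCK_COUNT such blocks. A quick stand-alone check using the header's own numbers (embSize 12 is TEST_EMB_SIZE, the value the unit tests assume):

    #include <cstdio>

    int main()
    {
        const long blockEmbCount = 1000;  // BLOCK_EMB_COUNT
        const long initBlockCount = 5;    // INIT_BLOCK_COUNT
        const long embSize = 12;          // TEST_EMB_SIZE

        const long floatsPerBlock = blockEmbCount * embSize;
        const long bytesPerBlock = floatsPerBlock * static_cast<long>(sizeof(float));

        std::printf("slots per block : %ld\n", blockEmbCount);                   // 1000
        std::printf("bytes per block : %ld\n", bytesPerBlock);                   // 48000
        std::printf("initial slots   : %ld\n", blockEmbCount * initBlockCount);  // 5000
        std::printf("initial bytes   : %ld\n", bytesPerBlock * initBlockCount);  // 240000
        return 0;
    }
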
+* Description: common module +* Author: MindX SDK +* Date: 2022/11/15 +*/ +#include "hd_transfer.h" +#include +#include +#include +#include "utils/common.h" + +using namespace MxRec; +using namespace std; + +int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) +{ +#ifndef GTEST + spdlog::info(MGMT + "begin hd_transfer initialize, rank:{}", localRankId); + aclError retOk = aclInit(nullptr); + spdlog::info(MGMT + "end aclInit, rank:{}", localRankId); + if (retOk != ACL_SUCCESS) { + spdlog::error(MGMT + "aclInit fail, rank:{}, errno:{}", localRankId, retOk); + return false; + } + spdlog::info(MGMT + "start Set device, rank:{}", localRankId); + auto ret = aclrtSetDevice(static_cast(localRankId)); + if (ret != ACL_ERROR_NONE) { + spdlog::error("Set device failed, device_id:{}", localRankId); + return false; + } + spdlog::info(MGMT + "end Set device, rank:{}", localRankId); + for (const auto& embInfo: embInfos) { + auto embName = embInfo.name; + for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { + CreateChannel(localRankId, embName, i); + } + } + running = true; + spdlog::info("hd_transfer init"); +#endif + return true; +} + +void HDTransfer::Destroy() +{ +#ifndef GTEST + running = false; + spdlog::info(HD + "destroy channel start"); + for (auto& c: transferChannels) { + tensorflow::StopRecvTensorByAcl(&c.second, c.first); + spdlog::info(HD + "destroy channel:{}", c.first); + } + aclFinalize(); +#endif +} + +void HDTransfer::CreateChannel(uint32_t localRankId, const string& embName, int channelNum) +{ +#ifndef GTEST + int channelSize; + const char* env = getenv("HD_CHANNEL_SIZE"); + if (env == nullptr) { + channelSize = LARGE_CHANNEL_SIZE; + } else { + try { + channelSize = stoi(env); + } catch (const std::invalid_argument& e) { + spdlog::warn("wrong HD_CHANNEL_SIZE env {}", e.what()); + channelSize = LARGE_CHANNEL_SIZE; + } catch (const std::out_of_range& e) { + spdlog::warn("wrong HD_CHANNEL_SIZE env {}", e.what()); + channelSize = LARGE_CHANNEL_SIZE; + } + if (channelSize <= 0) { + channelSize = LARGE_CHANNEL_SIZE; + } + } + spdlog::info("user config all2all restore lookup channel size:{}", channelSize); + for (int c = D2H; c != INVALID; c++) { + auto channel = static_cast(c); + string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelNum); + if (TransferChannel2Str(channel) == "all2all" || + TransferChannel2Str(channel) == "restore" || + TransferChannel2Str(channel) == "lookup" || + TransferChannel2Str(channel) == "evict" /* for noDDR */ + ) { + transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), channelSize); + } else { + transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), PING_PONG_SIZE); + } + spdlog::info("create channel:{} {}", sendName, static_cast(transferChannels[sendName])); + } +#endif +} + +void HDTransfer::Send(TransferChannel channel, const vector &tensors, int channelId, const string &embName, + int batchId) +{ + EASY_FUNCTION() + if (!running) { + return; + } +#ifndef GTEST + vector sizes; + for (auto& t: tensors) { + sizes.push_back(t.NumElements()); + } + string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); + + spdlog::info(HD + "hd transfer send {}, send count is {}, size list:{}", sendName, sizes.size(), + sizes); + + if (sizes.size() == 0) { + spdlog::warn("tensors num can not be zero"); + return; + } + bool isNeedResend = false; + int resendTime = 0; + tensorflow::Status status = tensorflow::Status::OK(); + do { + status = 
tensorflow::SendTensorsByAcl(transferChannels[sendName], ACL_TENSOR_DATA_TENSOR, tensors, + isNeedResend); + if (!running) { + return; + } + if (status != tensorflow::Status::OK()) { + spdlog::error(MGMT + "hd send {} error '{}'", sendName, status.error_message()); + throw runtime_error("hd send error"); + } + if (batchId != -1 && resendTime != 0) { + spdlog::warn(MGMT + "hd send {} batch: {} failed, retry: {} ", sendName, batchId, resendTime); + } + resendTime++; + } while (isNeedResend); +#endif +} + +vector HDTransfer::Recv(TransferChannel channel, int channelId, const string& embName) +{ + EASY_FUNCTION() +#ifndef GTEST + std::vector tensors; + string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); + spdlog::debug("hd transfer try recv:{}", recvName); + + tensorflow::Status status = tensorflow::RecvTensorByAcl(transferChannels[recvName], tensors); + if (!running) { + return {}; + } + if (status != tensorflow::Status::OK()) { + spdlog::error(MGMT + "{} hd recv error '{}'", recvName, status.error_message()); + throw runtime_error("hd recv error"); + } + + vector sizes; + for (auto& t: tensors) { + sizes.push_back(t.NumElements()); + } + spdlog::info("hd transfer recv:{}, size:{}", recvName, sizes); + return tensors; +#endif + return {}; +} + +tuple HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) +{ + EASY_FUNCTION() +#ifndef GTEST + std::vector tensors; + string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); + spdlog::debug("hd transfer try recv:{}", recvName); + acltdtDataset* aclDataset = acltdtCreateDataset(); + if (aclDataset == nullptr) { + throw runtime_error(fmt::format("Failed recv:{}.", recvName).c_str()); + } + auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDataset, -1 /* no timeout */); + if (!running) { + return {nullptr, 0}; + } + if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { + throw runtime_error(fmt::format("Failed receive data from acl channel, acl status:{}", aclStatus).c_str()); + } + spdlog::info("hd transfer recv:{}", recvName); + return {aclDataset, acltdtGetDatasetSize(aclDataset)}; +#endif + return {nullptr, 0}; +} + +size_t HDTransfer::QueryChannelSize(const string& channelName) +{ + size_t size = -1; +#ifndef GTEST + acltdtQueryChannelSize(transferChannels[channelName], &size); +#endif + return size; +} diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h new file mode 100644 index 00000000..2e9ff303 --- /dev/null +++ b/src/core/hd_transfer/hd_transfer.h @@ -0,0 +1,90 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
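
Send() above retries as long as the ACL layer reports that the tensors need to be resent, and treats a bad status as fatal. The retry pattern in isolation, with a stub SendOnce standing in for tensorflow::SendTensorsByAcl (the stub and its failure schedule are invented for this example):

    #include <cstdio>
    #include <stdexcept>

    // Stub transport: returns true on success and sets needResend while the channel is "full".
    static bool SendOnce(int payload, bool& needResend)
    {
        static int attempts = 0;
        needResend = (++attempts < 3);  // pretend the first two tries hit a full queue
        return true;
    }

    void SendWithRetry(int payload)
    {
        bool needResend = false;
        int resendTime = 0;
        do {
            if (!SendOnce(payload, needResend)) {
                throw std::runtime_error("hd send error");  // hard failure, as in Send() above
            }
            if (resendTime != 0) {
                std::printf("send retried, attempt %d\n", resendTime);
            }
            resendTime++;
        } while (needResend);
    }

    int main()
    {
        SendWithRetry(42);  // succeeds on the third pass of the loop
        return 0;
    }
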
+* Description: common module +* Author: MindX SDK +* Date: 2022/11/15 +*/ + +#ifndef MX_REC_HD_TRANSFER_H +#define MX_REC_HD_TRANSFER_H + +#include "acl/acl_base.h" +#include "acl/acl.h" +#include "acl/acl_tdt.h" +#include "acl/acl_tdt_queue.h" +#include "acl_channel.h" +#include "utils/common.h" + +#ifndef tdtCreateChannel +#define tdtCreateChannel acltdtCreateChannelWithCapacity +#endif + +namespace MxRec { + using namespace std; + const std::string MGMT = "\033[32m[Mgmt]\033[0m "; + const std::string HD = "\033[32m[HD]\033[0m "; + const std::string HOSTEMB = "\033[32m[HostEmb]\033[0m "; + const int PING_PONG_SIZE = 12; + const int LARGE_CHANNEL_SIZE = 100; + + enum TransferChannel { + D2H, + RESTORE, + ALL2ALL, + LOOKUP, + EVICT, + H2D, + SWAP, + INVALID + }; + + inline string TransferChannel2Str(TransferChannel e) + { + switch (e) { + case D2H: + return "d2h"; + case RESTORE: + return "restore"; + case ALL2ALL: + return "all2all"; + case LOOKUP: + return "lookup"; + case EVICT: + return "evict"; + case H2D: + return "h2d"; + case SWAP: + return "swap"; + default: + throw std::invalid_argument("Invalid TransferChannel"); + } + }; + + class HDTransfer { + public: + HDTransfer() = default; + + int Init(const vector& embInfos, uint32_t localRankId); + + void Send(TransferChannel channel, const vector& tensors, + int channelId, const string& embName, int batchId = -1); + + vector Recv(TransferChannel channel, int channelId, const string& embName); + + tuple RecvAcl(TransferChannel channel, int channelId, const string& embName); + + size_t QueryChannelSize(const string& channelName); + + auto Vec2Tensor(const vector& tmpVec) const -> vector; + + void Destroy(); + + private: +#ifndef GTEST + std::unordered_map transferChannels; +#endif + bool running; + void CreateChannel(uint32_t localRankId, const string& embName, int channelNum); + }; +} +#endif // MX_REC_HD_TRANSFER_H diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp new file mode 100644 index 00000000..dfed1b24 --- /dev/null +++ b/src/core/host_emb/host_emb.cpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
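
CreateChannel in hd_transfer.cpp derives one channel name per (embedding table, TransferChannel value, channel id) as "<emb>_<channel>_<id>", walking the enum from D2H until INVALID. A small sketch of that enumeration, using std::string concatenation instead of fmt::format and made-up table names:

    #include <cstdio>
    #include <string>
    #include <vector>

    enum TransferChannel { D2H, RESTORE, ALL2ALL, LOOKUP, EVICT, H2D, SWAP, INVALID };

    static const char* ChannelStr(TransferChannel c)
    {
        const char* names[] = { "d2h", "restore", "all2all", "lookup", "evict", "h2d", "swap" };
        return names[c];
    }

    int main()
    {
        const std::vector<std::string> embNames = { "user_emb", "item_emb" };  // illustrative
        const int maxChannelNum = 2;  // e.g. train + eval
        for (const auto& emb : embNames) {
            for (int id = 0; id < maxChannelNum; ++id) {
                // same iteration trick as CreateChannel: walk the enum until INVALID
                for (int c = D2H; c != INVALID; ++c) {
                    std::string name = emb + "_" + ChannelStr(static_cast<TransferChannel>(c))
                                       + "_" + std::to_string(id);
                    std::printf("%s\n", name.c_str());
                }
            }
        }
        return 0;
    }
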
+ * Description: common module
+ * Author: MindX SDK
+ * Date: 2022/11/15
+ */
+
+#include "host_emb.h"
+#include <chrono>
+#include <memory>
+#include <thread>
+#include <utility>
+#include <spdlog/stopwatch.h>
+#include "hd_transfer/hd_transfer.h"
+#include "checkpoint/checkpoint.h"
+#include "initializer/initializer.h"
+
+using namespace MxRec;
+using namespace std;
+using namespace chrono;
+
+bool HostEmb::Initialize(const vector<EmbInfo>& embInfos, int seed, bool ifLoad)
+{
+    if (!ifLoad) {
+        for (const auto& embInfo: embInfos) {
+            HostEmbTable hostEmb;
+            hostEmb.hostEmbInfo = embInfo;
+            EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.embeddingSize,
+                             hostEmb.embData);
+            hostEmbs[embInfo.name] = move(hostEmb);
+            spdlog::info(HOSTEMB + "HostEmb Initialize End");
+        }
+    }
+    return true;
+}
+
+void HostEmb::EmbDataGenerator(const vector<InitializeInfo> &initializeInfos, int seed, int vocabSize,
+    int embeddingSize, vector<vector<float>> &embData)
+{
+    spdlog::info(HOSTEMB + "GenerateEmbData Start, seed:{}", seed);
+    embData.clear();
+    embData.resize(vocabSize, vector<float>(embeddingSize));
+
+    for (auto initializeInfo: initializeInfos) {
+        // fallback target declared outside the switch so the pointer stays valid after it
+        ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0);
+        Initializer* initializer = &defaultInitializer;
+
+        switch (initializeInfo.initializerType) {
+            case InitializerType::CONSTANT: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Constant Initializer with value {}.",
+                             initializeInfo.constantInitializerInfo.constantValue);
+                initializer = &initializeInfo.constantInitializer;
+                break;
+            }
+            case InitializerType::TRUNCATED_NORMAL: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Truncated Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.truncatedNormalInitializer;
+                break;
+            }
+            case InitializerType::RANDOM_NORMAL: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Random Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.randomNormalInitializer;
+                break;
+            }
+            default: {
+                spdlog::error(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0.");
+                break;
+            }
+        }
+
+        for (int i = 0; i < vocabSize; i++) {
+            initializer->GenerateData(embData.at(i).data(), embeddingSize);
+        }
+    }
+
+    spdlog::info(HOSTEMB + "GenerateEmbData End, seed:{}", seed);
+}
+
+void HostEmb::LoadEmb(emb_mem_t& loadData)
+{
+    hostEmbs = std::move(loadData);
+}
+
+void HostEmb::Join()
+{
+    spdlog::stopwatch sw;
+    spdlog::debug(HOSTEMB + "hostemb start join {}", procThread.size());
+    for (auto& t: procThread) {
+        t->join();
+    }
+    procThread.clear();
+    spdlog::info(HOSTEMB + "hostemb end join, cost:{}", TO_MS(sw));
+}
+
+/*
+ * Receive the embeddings the device sent back through hdTransfer and write them into the
+ * matching rows of the host-side table. missingKeysHostPos holds the host-side row positions,
+ * i.e. the slots the swapped-out embeddings are written back to.
+ */
+void HostEmb::UpdateEmb(const vector<int64_t>& missingKeysHostPos, int channelId, const string& embName)
+{
+    EASY_FUNCTION(profiler::colors::Purple)
+    auto hdTransfer = Singleton<HDTransfer>::GetInstance();
+    TransferChannel transferName = TransferChannel::D2H;
+    spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId);
+    const auto tensors = hdTransfer->Recv(transferName, channelId, embName);
+    if (tensors.empty()) {
+        spdlog::warn(HOSTEMB + "recv empty data");
+        return;
+    }
+    const Tensor& d2hEmb = tensors[0];
+    spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size());
+    EASY_BLOCK("Update")
+    const float* tensorPtr = d2hEmb.flat<float>().data();
+    auto embeddingSize = hostEmbs[embName].hostEmbInfo.embeddingSize;
+    auto& embData = hostEmbs[embName].embData;
+
+#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \
+    shared(missingKeysHostPos, tensorPtr, embData, embeddingSize)
+    for (size_t i = 0; i < missingKeysHostPos.size(); i++) {
+        auto& dst = embData[missingKeysHostPos[i]];
+        const float* src = tensorPtr + i * embeddingSize; // index by i: advancing a shared pointer would race across threads
+#pragma omp simd
+        for (int j = 0; j < embeddingSize; j++) {
+            dst[j] = src[j];
+        }
+    }
+    spdlog::info(HOSTEMB + "update emb end");
+    EASY_END_BLOCK
+}
+
+void HostEmb::UpdateEmbV2(const vector<int64_t>& missingKeysHostPos, int channelId, const string& embName)
+{
+#ifndef GTEST
+    EASY_FUNCTION(profiler::colors::Purple)
+    procThread.emplace_back(make_unique<thread>(
+        [&, missingKeysHostPos, channelId, embName] {
+            auto hdTransfer = Singleton<HDTransfer>::GetInstance();
+            TransferChannel transferName = TransferChannel::D2H;
+            spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId);
+            auto [aclDataset, size] = hdTransfer->RecvAcl(transferName, channelId, embName);
+            if (size == 0) {
+                spdlog::warn(HOSTEMB + "recv empty data");
+                return;
+            }
+            spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size());
+            EASY_BLOCK("Update")
+            auto& embData = hostEmbs[embName].embData;
+            auto embeddingSize = hostEmbs[embName].hostEmbInfo.embeddingSize;
+            auto aclData = acltdtGetDataItem(aclDataset, 0);
+            if (aclData == nullptr) {
+                throw runtime_error("Acl get tensor data from dataset failed.");
+            }
+            float* ptr = reinterpret_cast<float*>(acltdtGetDataAddrFromItem(aclData));
+#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \
+    shared(missingKeysHostPos, ptr, embData, embeddingSize)
+            for (size_t j = 0; j < missingKeysHostPos.size(); j++) {
+                auto& dst = embData[missingKeysHostPos[j]];
+                const float* src = ptr + j * embeddingSize; // each destination row reads from its own source offset
+#pragma omp simd
+                for (int k = 0; k < embeddingSize; k++) {
+                    dst[k] = src[k];
+                }
+            }
+            if (acltdtDestroyDataset(aclDataset) != ACL_ERROR_NONE) {
+                throw runtime_error("Acl destroy tensor dataset failed.");
+            }
+            spdlog::info(HOSTEMB + "update emb end");
+        }));
+#endif
+}
+
+/*
+ * Collect the host-side embeddings that have to be sent and ship them to the device through
+ * hdTransfer. missingKeysHostPos holds the positions of those embeddings in the host table.
+ */
+vector<Tensor> HostEmb::GetH2DEmb(const vector<int64_t>& missingKeysHostPos, const string& embName)
+{
+    EASY_FUNCTION()
+    vector<Tensor> h2d_emb;
+    const auto& emb = hostEmbs[embName];
+    const int embeddingSize = emb.hostEmbInfo.embeddingSize;
+    h2d_emb.emplace_back(Tensor(tensorflow::DT_FLOAT, {
+        int(missingKeysHostPos.size()), embeddingSize
+    }));
+    auto& tmpTensor = h2d_emb.back();
+    auto tmpData = tmpTensor.flat<float>();
+#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \
+    shared(missingKeysHostPos, emb, tmpData, embeddingSize)
+    for (size_t i = 0; i < missingKeysHostPos.size(); i++) {
+        const auto& src = emb.embData[missingKeysHostPos[i]]; // reference: copying the row here would be wasted work
+#pragma omp simd
+        for (int j = 0; j < embeddingSize; j++) {
+            tmpData(j + i * embeddingSize) = src[j];
+        }
+    }
+    spdlog::info("GetH2DEmb end, missingKeys count:{}", missingKeysHostPos.size());
+    return h2d_emb;
+}
+
+auto HostEmb::GetHostEmbs() -> absl::flat_hash_map<string, HostEmbTable>
+{
+    return hostEmbs;
+}
+
+EmbInfo::EmbInfo(const string &name, int sendCount, int embeddingSize, vector<int> vocabsize,
+                 vector<InitializeInfo> initializeInfos)
+    : name(name), sendCount(sendCount), embeddingSize(embeddingSize), initializeInfos(initializeInfos)
+{
+    devVocabSize = vocabsize[0];
+    hostVocabSize = vocabsize[1];
+}
+
+void HostEmb::EmbPartGenerator(const vector<InitializeInfo> &initializeInfos, vector<vector<float>> &embData,
+    const vector<int64_t>& offset)
+{
+    for (auto initializeInfo: initializeInfos) {
+        // fallback target declared outside the switch so the pointer stays valid after it
+        ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0);
+        Initializer* initializer = &defaultInitializer;
+
+        switch (initializeInfo.initializerType) {
+            case InitializerType::CONSTANT: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Constant Initializer with value {}.",
+                             initializeInfo.constantInitializerInfo.constantValue);
+                initializer = &initializeInfo.constantInitializer;
+                break;
+            }
+            case InitializerType::TRUNCATED_NORMAL: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Truncated Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.truncatedNormalInitializer;
+                break;
+            }
+            case InitializerType::RANDOM_NORMAL: {
+                spdlog::info(HOSTEMB + "GenerateEmbData using Random Normal Initializer with mean: {} stddev: {}.",
+                             initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev);
+                initializer = &initializeInfo.randomNormalInitializer;
+                break;
+            }
+            default: {
+                spdlog::error(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0.");
+                break;
+            }
+        }
+
+        for (size_t i = 0; i < offset.size(); i++) {
+            initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast<int>(embData[0].size()));
+        }
+    }
+}
+
+/*
+ * Re-initialize the evicted slots of the host-side embedding table with its configured initializers.
+ */
+void HostEmb::EvictInitEmb(const string& embName, const vector<int64_t>& offset)
+{
+    auto& hostEmb = GetEmb(embName);
+    EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset);
+
+    spdlog::info(HOSTEMB + "ddr EvictInitEmb! host embName {}, init offsets size: {}", embName, offset.size());
+}
\ No newline at end of file
diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h
new file mode 100644
index 00000000..b23175ba
--- /dev/null
+++ b/src/core/host_emb/host_emb.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
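
GetH2DEmb gathers the host rows listed in missingKeysHostPos into a single flat [n, embeddingSize] tensor before handing it to hdTransfer; the real loops are parallelized with OpenMP pragmas. The same gather reduced to plain vectors, without the OpenMP and TensorFlow dependencies:

    #include <cstdio>
    #include <vector>

    // Gather table rows at the given offsets into one flat row-major buffer.
    std::vector<float> GatherRows(const std::vector<std::vector<float>>& table,
                                  const std::vector<int>& offsets, int embSize)
    {
        std::vector<float> flat(offsets.size() * static_cast<size_t>(embSize));
        for (size_t i = 0; i < offsets.size(); ++i) {
            const std::vector<float>& src = table[offsets[i]];
            for (int j = 0; j < embSize; ++j) {
                flat[i * embSize + j] = src[j];
            }
        }
        return flat;
    }

    int main()
    {
        std::vector<std::vector<float>> table = { {1, 1}, {2, 2}, {3, 3} };
        std::vector<int> missing = { 2, 0 };  // positions to ship to the device
        std::vector<float> flat = GatherRows(table, missing, 2);
        for (float v : flat) {
            std::printf("%.0f ", v);  // prints: 3 3 1 1
        }
        return 0;
    }
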
+ * Description: common module + * Author: MindX SDK + * Date: 2022/11/15 + */ + +#ifndef MX_REC_HOSTEMB_H +#define MX_REC_HOSTEMB_H + +#include +#include +#include +#include +#include "absl/container/flat_hash_map.h" +#include "utils/common.h" +#include "utils/singleton.h" +#include "tensorflow/core/framework/tensor.h" + +namespace MxRec { + using namespace std; + using namespace tensorflow; + + class HostEmb { + public: + HostEmb() = default; + + ~HostEmb() + {}; + + bool Initialize(const vector& embInfos, int seed, bool ifLoad = false); + + void LoadEmb(absl::flat_hash_map& loadData); + + void Join(); + + void UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName); + + void UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName); + + vector GetH2DEmb(const vector& missingKeysHostPos, const string& embName); + + auto GetHostEmbs() -> absl::flat_hash_map; + + void EvictInitEmb(const string& embName, const vector& offset); + + HostEmbTable& GetEmb(const string& embName) + { + return hostEmbs.at(embName); + } + + GTEST_PRIVATE: + absl::flat_hash_map hostEmbs; + + std::vector> procThread; + + void EmbDataGenerator(const vector& initializeInfos, int seed, int vocabSize, int embeddingSize, + vector>& embData); + void EmbPartGenerator(const vector &initializeInfos, vector> &embData, + const vector& offset); + }; +} + +#endif // MX_REC_HOSTEMB_H \ No newline at end of file diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp new file mode 100644 index 00000000..4bced738 --- /dev/null +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: constant initializer module + * Author: MindX SDK + * Date: 2022/12/22 + */ + +#include "constant_initializer.h" +#include + +using namespace std; +using namespace MxRec; + +ConstantInitializer::ConstantInitializer(int start, int len, float value) : start(start), len(len), value(value) {} + +void ConstantInitializer::GenerateData(float* emb, const int embSize) +{ + if (len == 0) { + return; + } + if (embSize < (start + len)) { + spdlog::warn( + "InitializeInfo start {} + len {} is larger than embedding size {}.", + start, len, embSize); + return; + } + std::fill_n(emb + start, len, value); +} \ No newline at end of file diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h new file mode 100644 index 00000000..b763087b --- /dev/null +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
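
GenerateData above writes only the [start, start + len) slice of a row and refuses out-of-range requests, so several initializers can each own a different segment of one embedding vector. A compact illustration of that contract (FillSegment is an invented stand-in for ConstantInitializer::GenerateData):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Fill emb[start, start + len) with value; skip out-of-range requests, as the code above does.
    void FillSegment(std::vector<float>& emb, int start, int len, float value)
    {
        if (len == 0 || static_cast<int>(emb.size()) < start + len) {
            return;  // mirrors the warn-and-return bounds check in GenerateData
        }
        std::fill_n(emb.begin() + start, len, value);
    }

    int main()
    {
        std::vector<float> row(8, 0.0f);
        FillSegment(row, 0, 4, 1.5f);   // first segment
        FillSegment(row, 4, 4, -1.0f);  // second segment, owned by another initializer
        for (float v : row) {
            std::printf("%.1f ", v);    // 1.5 1.5 1.5 1.5 -1.0 -1.0 -1.0 -1.0
        }
        return 0;
    }
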
+ * Description: constant initializer module + * Author: MindX SDK + * Date: 2022/12/22 + */ + +#ifndef MX_REC_CONSTANT_INITIALIZER_H +#define MX_REC_CONSTANT_INITIALIZER_H + +#include +#include "initializer/initializer.h" + +namespace MxRec { + using std::vector; + + class ConstantInitializer : public Initializer { + public: + ConstantInitializer() = default; + ConstantInitializer(int start, int len, float value); + + ~ConstantInitializer() override {}; + + void GenerateData(float* emb, const int embSize) override; + + int start; + int len; + float value; + }; +} + +#endif // MX_REC_CONSTANT_INITIALIZER_H diff --git a/src/core/initializer/initializer.cpp b/src/core/initializer/initializer.cpp new file mode 100644 index 00000000..afb44d20 --- /dev/null +++ b/src/core/initializer/initializer.cpp @@ -0,0 +1,6 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. +* Description: initializer module +* Author: MindX SDK +* Date: 2022/12/22 +*/ \ No newline at end of file diff --git a/src/core/initializer/initializer.h b/src/core/initializer/initializer.h new file mode 100644 index 00000000..0fce164c --- /dev/null +++ b/src/core/initializer/initializer.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: initializer module + * Author: MindX SDK + * Date: 2022/12/22 + */ + +#ifndef MX_REC_INITIALIZER_H +#define MX_REC_INITIALIZER_H + +#include +namespace MxRec { + using std::vector; + + class Initializer { + public: + Initializer() = default; + virtual ~Initializer() {}; + + virtual void GenerateData(float* emb, int embSize)= 0; + int start; + int len; + }; +} + +#endif // MX_REC_INITIALIZER_H \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp new file mode 100644 index 00000000..a5a31381 --- /dev/null +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: random normal initializer module + * Author: MindX SDK + * Date: 2022/12/23 + */ + +#include "random_normal_initializer.h" +#include +#include + +using namespace MxRec; + +RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed) + : start(start), len(len), mean(mean), stddev(stddev), seed(seed) +{ + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); +} + +void RandomNormalInitializer::GenerateData(float* emb, const int embSize) +{ + if (len == 0) { + return; + } + if (embSize < (start + len)) { + spdlog::warn( + "InitializeInfo start {} + len {} is larger than embedding size {}.", + start, len, embSize); + return; + } + std::generate_n(emb + start, len, [&]() { return distribution(generator); }); +} \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h new file mode 100644 index 00000000..e0127ca2 --- /dev/null +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
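
The random-normal initializer seeds its engine once in the constructor and then draws with std::generate_n, so a fixed seed reproduces identical embedding values across runs. A minimal sketch of that seeding behaviour (SampleNormal is illustrative, not from this patch):

    #include <algorithm>
    #include <cstdio>
    #include <random>
    #include <vector>

    // Generate n normal samples from a freshly seeded engine.
    std::vector<float> SampleNormal(int n, float mean, float stddev, int seed)
    {
        std::default_random_engine generator(seed);
        std::normal_distribution<float> distribution(mean, stddev);
        std::vector<float> out(n);
        std::generate(out.begin(), out.end(), [&]() { return distribution(generator); });
        return out;
    }

    int main()
    {
        auto a = SampleNormal(4, 0.0f, 1.0f, 2023);
        auto b = SampleNormal(4, 0.0f, 1.0f, 2023);
        std::printf("reproducible: %s\n", a == b ? "yes" : "no");  // same seed, same values
        return 0;
    }
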
+ * Description: random normal initializer module + * Author: MindX SDK + * Date: 2022/12/23 + */ + +#ifndef MX_REC_RANDOM_NORMAL_INITIALIZER_H +#define MX_REC_RANDOM_NORMAL_INITIALIZER_H + +#include +#include + +#include "initializer/initializer.h" + +namespace MxRec { + using namespace std; + + class RandomNormalInitializer : public Initializer { + public: + RandomNormalInitializer() = default; + RandomNormalInitializer(int start, int len, float mean, float stddev, int seed); + + ~RandomNormalInitializer() override {}; + + void GenerateData(float* emb, const int embSize) override; + + int start; + int len; + float mean; + float stddev; + int seed; + + std::default_random_engine generator; + std::normal_distribution distribution; + }; +} + +#endif // MX_REC_RANDOM_NORMAL_INITIALIZER_H \ No newline at end of file diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp new file mode 100644 index 00000000..a1f49c08 --- /dev/null +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: truncated normal initializer module + * Author: MindX SDK + * Date: 2022/12/22 + */ + +#include "truncated_normal_initializer.h" +#include +#include + +using namespace MxRec; + +TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed) + : start(start), len(len), mean(mean), stddev(stddev), seed(seed) +{ + generator = std::default_random_engine(seed); + distribution = std::normal_distribution(mean, stddev); + minBound = mean - boundNum * stddev; + maxBound = mean + boundNum * stddev; +} + + +void TruncatedNormalInitializer::GenerateData(float* emb, const int embSize) +{ + if (len == 0) { + return; + } + if (embSize < (start + len)) { + spdlog::warn( + "InitializeInfo start {} + len {} is larger than embedding size {}.", + start, len, embSize); + return; + } + std::generate_n(emb + start, len, [&]() { + float tmp = distribution(generator); + while (tmp < minBound || tmp > maxBound) { + tmp = distribution(generator); + } + return tmp; + }); +} \ No newline at end of file diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h new file mode 100644 index 00000000..3c6bb980 --- /dev/null +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
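
The truncated variant enforces its bounds by rejection: any draw outside mean plus or minus boundNum * stddev (boundNum = 2 above) is discarded and redrawn, mirroring the usual two-standard-deviation truncated-normal convention. The sampling loop in isolation:

    #include <cstdio>
    #include <random>

    // Draw one sample from N(mean, stddev^2), redrawing until it lands within mean +/- 2*stddev.
    float TruncatedNormalSample(std::default_random_engine& gen, float mean, float stddev)
    {
        std::normal_distribution<float> dist(mean, stddev);
        const float lo = mean - 2 * stddev;
        const float hi = mean + 2 * stddev;
        float x = dist(gen);
        while (x < lo || x > hi) {  // rejection step, as in GenerateData above
            x = dist(gen);
        }
        return x;
    }

    int main()
    {
        std::default_random_engine gen(42);
        for (int i = 0; i < 5; ++i) {
            std::printf("%.3f\n", TruncatedNormalSample(gen, 0.0f, 1.0f));
        }
        return 0;
    }
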
+ * Description: truncated normal initializer module + * Author: MindX SDK + * Date: 2022/12/22 + */ + +#ifndef MX_REC_TRUNCATED_NORMAL_INITIALIZER_H +#define MX_REC_TRUNCATED_NORMAL_INITIALIZER_H + +#include +#include + +#include "initializer/initializer.h" + +namespace MxRec { + using namespace std; + + class TruncatedNormalInitializer : public Initializer { + public: + TruncatedNormalInitializer() = default; + TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed); + + ~TruncatedNormalInitializer() override {}; + + void GenerateData(float* emb, const int embSize) override; + + int boundNum = 2; + + int start; + int len; + float mean; + float stddev; + int seed; + + std::default_random_engine generator; + std::normal_distribution distribution; + float minBound; + float maxBound; + }; +} + +#endif // MX_REC_TRUNCATED_NORMAL_INITIALIZER_H \ No newline at end of file diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp new file mode 100644 index 00000000..bcbb02e7 --- /dev/null +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -0,0 +1,317 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: operator module + * Author: MindX SDK + * Date: 2022/11/23 + */ + +#include "feature_admit_and_evict.h" +#include +#include +#include +#include +#include +#include "checkpoint/checkpoint.h" + +using namespace MxRec; + +std::vector FeatureAdmitAndEvict::m_cfgThresholds {}; +absl::flat_hash_map FeatureAdmitAndEvict::m_embStatus {}; + +FeatureAdmitAndEvict::FeatureAdmitAndEvict(int recordsInitSize) : m_recordsInitSize(recordsInitSize) {} + +FeatureAdmitAndEvict::~FeatureAdmitAndEvict() +{ + m_isEnableFunction = false; + m_isExit = true; + if (m_evictThread.joinable()) { + m_evictThread.join(); + } +} + +bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValues) +{ + if (!ParseThresholdCfg(thresholdValues)) { + m_isEnableFunction = false; + spdlog::error("Config is error, feature admin-and-evict function is not available ...\n"); + return false; + } + + return true; +} + +// 以下为类的公共接口 +FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, + const std::unique_ptr& batch, keys_t& splitKey, std::vector& keyCount) +{ + if (splitKey.size() != keyCount.size()) { + spdlog::error("splitKey.size {} != keyCount.size {}", splitKey.size(), keyCount.size()); + return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; + } + + // 如果当前 tensorName 不在准入范围之内,则不进行“特征准入”逻辑 + std::string tensorName = batch->name; + absl::flat_hash_map mergeKeys; + mergeKeys.reserve(splitKey.size()); + PreProcessKeys(splitKey, keyCount, mergeKeys); + + std::lock_guard lock(m_syncMutexs); + auto iter = m_recordsData.historyRecords.find(tensorName); + if (iter == m_recordsData.historyRecords.end()) { // 之前tensorName没出现过时,数据初始化 + absl::flat_hash_map records(m_recordsInitSize); + m_recordsData.historyRecords[tensorName] = records; + } + spdlog::debug("FeatureAdmitAndEvict PrintSize, name:[{}], history key:[{}] ...", tensorName, + m_recordsData.historyRecords[tensorName].size()); + + m_recordsData.timestamps[tensorName] = batch->timestamp; + absl::flat_hash_map visitedRecords; + for (auto& key : splitKey) { + if (key == -1) { + continue; + } + + // 特征准入&特征淘汰 + auto it = visitedRecords.find(key); + if (it == visitedRecords.end()) { + visitedRecords[key] = true; + if (FeatureAdmitHelper(channel, tensorName, key, mergeKeys[key]) == + FeatureAdmitType::FEATURE_ADMIT_FAILED) { + 
visitedRecords[key] = false; + key = -1; // 被淘汰的Feature ID + } + continue; + } + + if (visitedRecords[key] == false) { + key = -1; + } + } + + return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; +} + +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(int channel, const std::string& tensorName, + int64_t featureId, uint32_t featureCnt) +{ + // “特征准入”逻辑 + uint32_t currKeyCount = 0; + absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tensorName]; + auto innerIt = historyRecordInfos.find(featureId); + + if (channel == TRAIN_CHANNEL_ID) { + if (innerIt == historyRecordInfos.end()) { + // 维护 m_historyRecords + FeatureItemInfo info(featureId, featureCnt, tensorName, m_recordsData.timestamps[tensorName]); + historyRecordInfos[featureId] = info; + currKeyCount = featureCnt; + } else { + // 维护 m_historyRecords + FeatureItemInfo &info = historyRecordInfos[featureId]; + info.count += featureCnt; + info.lastTime = m_recordsData.timestamps[tensorName]; + currKeyCount = info.count; + } + } else if (channel == EVAL_CHANNEL_ID) { // eval + if (innerIt != historyRecordInfos.end()) { + currKeyCount = historyRecordInfos[featureId].count; + } + } + + // 准入条件判断 + if (currKeyCount >= static_cast(m_tensor2Threshold[tensorName].countThreshold)) { + return FeatureAdmitType::FEATURE_ADMIT_OK; + } + + return FeatureAdmitType::FEATURE_ADMIT_FAILED; +} + +// 特征淘汰接口 +void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) +{ + std::vector tensorNames = GetAllNeedEvictTensorNames(); + if (tensorNames.empty()) { + spdlog::info("EmbNames is empty, no evict function ..."); + return ; + } + if (!m_isEnableFunction) { + spdlog::warn("m_isEnableFunction switch is false, no evict function ..."); + return ; + } + std::lock_guard lock(m_syncMutexs); + // 从 m_historyRecords 中淘汰删除 + size_t tensorCnt = tensorNames.size(); + for (size_t i = 0; i < tensorCnt; ++i) { + FeatureEvictHelper(tensorNames[i], evictKeyMap[tensorNames[i]]); + } +} + +void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::vector& evictKey) +{ + // 从 m_historyRecords 中淘汰删除 + time_t currTime = m_recordsData.timestamps[embName]; + // 从 m_tensor2SortedLastTime 获取当前要淘汰的featureId + SortedRecords lastTimePriority; + for (auto& item : m_recordsData.historyRecords[embName]) { + lastTimePriority.push(item.second); + } + while (!lastTimePriority.empty()) { + if (currTime - lastTimePriority.top().lastTime < m_tensor2Threshold[embName].timeThreshold) { + break; + } + evictKey.emplace_back(lastTimePriority.top().featureId); + lastTimePriority.pop(); + } + + if (evictKey.size() == 0) { + spdlog::info("tensor-name[{}]'s lastTime[{}], had no key to delete ...", embName, currTime); + return; + } + spdlog::info("tensor-name[{}]'s lastTime[{}], had size[{}] keys to delete ...", embName, currTime, + evictKey.size()); + + // 真正从 m_historyRecords 中淘汰 + absl::flat_hash_map& historyRecords = m_recordsData.historyRecords[embName]; + for (size_t k = 0; k < evictKey.size(); ++k) { + historyRecords.erase(evictKey[k]); + } + if (historyRecords.empty()) { + m_recordsData.historyRecords.erase(embName); + } +} + +// 特征淘汰的使能接口 +void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) +{ + if (isEnableEvict) { + spdlog::info("feature admit-and-evict switch is opened ..."); + } else { + spdlog::info("feature admit-and-evict switch is closed ..."); + } + m_isEnableFunction = isEnableEvict; +} +bool FeatureAdmitAndEvict::GetFunctionSwitch() const +{ + return m_isEnableFunction; +} + +void FeatureAdmitAndEvict::PreProcessKeys(const 
std::vector& splitKey, std::vector& keyCount, + absl::flat_hash_map& mergeKeys) +{ + for (size_t i = 0; i < splitKey.size(); ++i) { + if (splitKey[i] == -1) { + continue; + } + + auto it = mergeKeys.find(splitKey[i]); + if (it == mergeKeys.end()) { + mergeKeys[splitKey[i]] = keyCount[i]; + } else { + mergeKeys[splitKey[i]] += keyCount[i]; + } + } +} + +bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& thresholds, + const std::vector& embNames, bool isTimestamp) +{ + for (size_t i = 0; i < thresholds.size(); ++i) { + auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tensorName); + if (it == embNames.end()) { // 配置不存在于当前跑的模型,也要报错 + spdlog::error("embName[{}] is not exist at current model ...", thresholds[i].tensorName); + return false; + } else { + // 同时支持“准入&淘汰”,却没有传时间戳 + if (m_embStatus[*it] == SingleEmbTableStatus::SETS_ERROR) { + spdlog::error("embName[{}] config error, please check ...", embNames[i]); + return false; + } else if (m_embStatus[*it] == SingleEmbTableStatus::SETS_BOTH && !isTimestamp) { + spdlog::error("embName[{}] admit and evict, but no timestamp", embNames[i]); + return false; + } + } + } + + return true; +} + +auto FeatureAdmitAndEvict::GetTensorThresholds() -> tensor_2_thresh_mem_t +{ + std::lock_guard lock(m_syncMutexs); + return m_tensor2Threshold; +} + +auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& +{ + std::lock_guard lock(m_syncMutexs); + return m_recordsData; +} + +void FeatureAdmitAndEvict::LoadTensorThresholds(tensor_2_thresh_mem_t& loadData) +{ + std::lock_guard lock(m_syncMutexs); + m_tensor2Threshold = std::move(loadData); +} + +void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) +{ + std::lock_guard lock(m_syncMutexs); + m_recordsData = std::move(loadData); +} + +// 解析m_tensor2Threshold +bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) +{ + if (thresholdValues.empty()) { + spdlog::error("thresholdValues is empty ..."); + return false; + } + + m_cfgThresholds = thresholdValues; + for (const auto& value : thresholdValues) { + spdlog::info("embName[{}], count[{}], time[{}] ...", + value.tensorName, value.countThreshold, value.timeThreshold); + auto it = m_tensor2Threshold.find(value.tensorName); + if (it != m_tensor2Threshold.end()) { + // train和eval同时开启,会出现表重复配置 + spdlog::info("[{}] is repeated configuration ...", value.tensorName); + return true; + } + m_tensor2Threshold[value.tensorName] = value; + + if (value.countThreshold != -1 && value.timeThreshold != -1) { + m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_BOTH; + } else if (value.countThreshold != -1 && value.timeThreshold == -1) { + m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; + } else { + spdlog::error("[{}] config error, have evict but no admit ...", value.tensorName); + m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ERROR; + return false; + } + } + + return true; +} + +std::vector FeatureAdmitAndEvict::GetAllNeedEvictTensorNames() +{ + std::vector names; + std::lock_guard lock(m_syncMutexs); + for (const auto& record : m_recordsData.historyRecords) { + // 只获取支持特征准入的embName + if (m_embStatus[record.first] == SingleEmbTableStatus::SETS_BOTH) { + names.emplace_back(record.first); + } + } + return names; +} + +void FeatureAdmitAndEvict::ResetAllRecords() +{ + std::lock_guard lock(m_syncMutexs); + for (auto& record : m_recordsData.historyRecords) { + record.second.clear(); + } + m_recordsData.historyRecords.clear(); + 
m_recordsData.timestamps.clear(); +} \ No newline at end of file diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h new file mode 100644 index 00000000..a355711d --- /dev/null +++ b/src/core/key_process/feature_admit_and_evict.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: operator module + * Author: MindX SDK + * Date: 2022/11/23 + */ + +#ifndef FEATURE_ADMIT_AND_EVICT_H +#define FEATURE_ADMIT_AND_EVICT_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "host_emb/host_emb.h" +#include "utils/common.h" +#include "utils/safe_queue.h" +#include "utils/singleton.h" + +namespace MxRec { + enum class FeatureAdmitType { + FEATURE_ADMIT_OK = 0, + FEATURE_ADMIT_FAILED + }; + enum class FeatureAdmitReturnType { + FEATURE_ADMIT_RETURN_OK = 0, + FEATURE_ADMIT_RETURN_ERROR + }; + + enum class SingleEmbTableStatus { + SETS_BOTH = 0, // 准入&淘汰功能,正常(threshold配置正常 && 传batch时间戳) + SETS_NONE, // 没有配置(不支持) + SETS_ONLY_ADMIT, // 只支持准入功能 + SETS_ERROR // 只有淘汰、没有准入的错误,或其它 + }; + + const int DEFAULT_RECORDS_INIT_SIZE = 10000; + const int FEATURE_EVICT_TIME_INTERVAL = 3600 * 24; + const int FEATURE_MIN_TIME_INTERVAL = 10; + + class FeatureAdmitAndEvict { + public: + explicit FeatureAdmitAndEvict(int recordsInitSize = DEFAULT_RECORDS_INIT_SIZE); + ~FeatureAdmitAndEvict(); + + bool Init(const std::vector& thresholdValues); + + // 以下为类的公共接口 + // 特征准入接口 + FeatureAdmitReturnType FeatureAdmit(int channel, const std::unique_ptr& batch, + keys_t& splitKey, std::vector& keyCount); + + // 特征淘汰接口 + void FeatureEvict(map>& evictKeyMap); + + // 特征淘汰的使能接口 + void SetFunctionSwitch(bool isEnableEvict); + bool GetFunctionSwitch() const; + void PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, + absl::flat_hash_map& mergeKeys); + + // 判断配置是否正确的接口 + static bool IsThresholdCfgOK(const std::vector& thresholds, + const std::vector& embNames, bool isTimestamp); + + // 与模型保存加载交互的接口 + auto GetTensorThresholds() -> tensor_2_thresh_mem_t; + auto GetHistoryRecords() -> AdmitAndEvictData&; + + void LoadTensorThresholds(tensor_2_thresh_mem_t& loadData); + void LoadHistoryRecords(AdmitAndEvictData& loadData); + + static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 + static absl::flat_hash_map m_embStatus; // 用于“准入&淘汰”功能解耦 + + GTEST_PRIVATE : + + // 解析m_tensor2Threshold + bool ParseThresholdCfg(const std::vector& thresholdValues); + std::vector GetAllNeedEvictTensorNames(); + FeatureAdmitType FeatureAdmitHelper(int channel, const std::string& tensorName, + int64_t featureId, uint32_t featureCnt); + void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); + void ResetAllRecords(); + + bool m_isEnableFunction { true }; // “特征淘汰”的使能开关 + bool m_isExit { false }; // 淘汰线程退出的标识 + absl::flat_hash_map m_tensor2Threshold; // tensor-X ---> ThresholdValue 映射 + AdmitAndEvictData m_recordsData; + std::mutex m_syncMutexs; // 特征准入与特征淘汰竞争的同步锁 + int m_recordsInitSize { DEFAULT_RECORDS_INIT_SIZE }; // m_historyRecords表初始容量 + std::thread m_evictThread; // 特征淘汰功能,以“线程 + 定时任务”方式实现 + }; +} + +#endif // FEATURE_ADMIT_AND_EVICT_H diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp new file mode 100644 index 00000000..2e24fa52 --- /dev/null +++ b/src/core/key_process/key_process.cpp @@ -0,0 +1,980 @@ +/* + * Copyright (c) Huawei Technologies 
Co., Ltd. 2021-2022. All rights reserved. + * Description: + * Author: MindX SDK + * Date: 2022/11/15 + */ + +#include "key_process.h" +#include +#include +#include +#include +#include +#include "checkpoint/checkpoint.h" +#include "hd_transfer/hd_transfer.h" + +using namespace std; +using namespace chrono; +using namespace MxRec; + +static shared_mutex g_smut; + +template +inline vector Count2Start(const vector& count) +{ + vector start = { 0 }; + for (size_t i = 0; i < count.size() - 1; ++i) { + start.push_back(count[i] + start.back()); + } + return start; +} + +int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, + const vector& thresholdValues, + bool ifLoad, int seed) +{ + this->rankInfo = rInfo; + if (rankInfo.useHot) { + const char* env = getenv("HOT_EMB_UPDATE_STEP"); + if (env == nullptr) { + hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; + } else { + hotEmbUpdateStep = stoi(env); + if (hotEmbUpdateStep == 0) { + hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; + } + } + } + + map scInfo; + for (const auto& info: eInfos) { + embInfos[info.name] = info; + scInfo[info.name] = info.sendCount; + if (rankInfo.useHot) { + hotEmbTotCount[info.name] = static_cast(GetUBSize(rInfo.deviceId) / sizeof(float) * HOT_EMB_CACHE_PCT / + info.embeddingSize); + } + if (rankInfo.useDynamicExpansion) { + // 动态扩容 + embeddingTableMap[info.name].Init(info, rInfo, seed); + spdlog::info(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); + } + + if (rankInfo.rankId == 0 && !ifLoad) { + Key2OffsetInit(info.name); + } + } + spdlog::info(KEY_PROCESS "hot emb count info:{}", hotEmbTotCount); + MPI_Group world_group; + MPI_Comm_group(MPI_COMM_WORLD, &world_group); + for (auto& i: comm) { + for (auto& j: i) { + MPI_Comm_create(MPI_COMM_WORLD, world_group, &j); + } + } + isRunning = true; + + // 特征准入与特征淘汰 + if (!thresholdValues.empty()) { + m_featureAdmitAndEvict.SetFunctionSwitch(true); + m_featureAdmitAndEvict.Init(thresholdValues); + } else { + m_featureAdmitAndEvict.SetFunctionSwitch(false); + spdlog::warn(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); + } + + spdlog::info(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", scInfo, + rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); + return 0; +} + +// bind and start main process +int KeyProcess::Start() +{ + // bind like: + // 0 1 2 3 4 5 0 1 2 3 4 5 + // | rank0 | | rank1 | + // each rank creates KEY_PROCESS_THREAD threads, each thread process one batchdata + spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + auto fn = [this](int channel, int id) { +#ifndef GTEST + auto ret = aclrtSetDevice(static_cast(rankInfo.deviceId)); + if (ret != ACL_ERROR_NONE) { + spdlog::error("Set device failed, device_id:{}", rankInfo.deviceId); + return; + } +#endif + KeyProcessTask(channel, id); + }; // for clean code + for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { + for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { + procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread + } + } + return 0; +} + +auto KeyProcess::GetMaxOffset() -> offset_mem_t +{ + return maxOffset; +} + +auto KeyProcess::GetKeyOffsetMap() -> key_offset_mem_t +{ + return keyOffsetMap; +} + +auto KeyProcess::GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict& +{ + return m_featureAdmitAndEvict; +} + +void KeyProcess::LoadMaxOffset(offset_mem_t& loadData) +{ + maxOffset = std::move(loadData); +} + +void KeyProcess::LoadKeyOffsetMap(key_offset_mem_t& 
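// Illustrative sketch (not part of the patch): the Count2Start helper above is an
// exclusive prefix sum that turns per-rank element counts into the displacement
// array MPI_Alltoallv expects. This standalone copy loops with i + 1 < count.size()
// so an empty input is also safe.
#include <cassert>
#include <vector>

std::vector<int> Count2StartSketch(const std::vector<int>& count)
{
    std::vector<int> start = { 0 };
    for (size_t i = 0; i + 1 < count.size(); ++i) {
        start.push_back(count[i] + start.back());
    }
    return start;
}

int main()
{
    auto s = Count2StartSketch({ 3, 1, 4 });     // counts per peer rank
    assert(s[0] == 0 && s[1] == 3 && s[2] == 4); // rank i's block begins at s[i]
}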
loadData) +{ + keyOffsetMap = std::move(loadData); +} + +// 只在python侧当训练结束时调用,如果出现死锁直接结束程序即可,测试时让进程等待足够长的时间再调用 +void KeyProcess::Destroy() +{ + isRunning = false; + spdlog::info(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); + for (auto& i: procThread) { + i.join(); + } + procThread.clear(); + spdlog::info(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); +} + +void KeyProcess::LoadSaveLock() +{ + for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { + for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + loadSaveMut[channelId][threadId].lock(); + } + } +} + +void KeyProcess::LoadSaveUnlock() +{ + for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { + for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + loadSaveMut[channelId][threadId].unlock(); + } + } +} + +void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] +{ + unique_ptr batch; + ShardedDedup *unique = nullptr; + + spdlog::stopwatch sw; + try { + while (true) { + TimeCost getAndProcesTC; + TimeCost getBatchTC; + batch = GetBatchData(channel, id); // get batch data from SingletonQueue + TIME_PRINT("GetBatchData TimeCost(ms):{}", getBatchTC.ElapsedMS()); + + if (batch == nullptr) { + break; + } + auto getBatchTime = TO_MS(sw); + sw.reset(); + + if (unique == nullptr) { + GroupMethod groupMethod; + groupMethod.SetGroupCount(rankInfo.rankSize); + unique = new ShardedDedup(groupMethod, batch->batchSize, + embInfos[batch->name].sendCount); + } else { + unique->StartNewRound(); + } + auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); + if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { + free(batch->tensorAddr); + batchQueue->PutDirty(move(batch)); + break; + } + TIME_PRINT("getAndProcesTC TimeCost(ms):{}", getAndProcesTC.ElapsedMS()); + spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", + TO_MS(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + free(batch->tensorAddr); + batchQueue->PutDirty(move(batch)); + } + delete unique; + } catch (const EndRunError &e) { + spdlog::debug(KEY_PROCESS "abort run: {}", e.what()); + } + spdlog::info(KEY_PROCESS "KeyProcessTask exit. 
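// Illustrative sketch (not part of the patch): the worker lifecycle behind Start()
// and Destroy() above, reduced to its shape: spawn one worker per (channel, thread)
// pair, flip a stop flag, then join them all. Here the flag is a std::atomic<bool>,
// whereas KeyProcess uses its plain isRunning member.
#include <atomic>
#include <thread>
#include <vector>

int main()
{
    std::atomic<bool> running { true };
    std::vector<std::thread> workers;
    for (int channel = 0; channel < 2; ++channel) { // MAX_CHANNEL_NUM
        for (int id = 0; id < 6; ++id) {            // KEY_PROCESS_THREAD
            workers.emplace_back([&running] {
                while (running.load()) {
                    std::this_thread::yield();      // stands in for "process one batch"
                }
            });
        }
    }
    running = false;          // signal shutdown, as Destroy() does via isRunning
    for (auto& w : workers) { // Destroy() likewise joins before clearing the pool
        w.join();
    }
}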
rank:{} thread:{}, channel:{}", rankInfo.rankId, id, channel); +} + +bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_dedup unique, + int channel, int id, spdlog::stopwatch &sw) +{ + // tuple for keyRec restore hotPos scAll countRecv + std::tuple, vector, vector, vector, vector> rets; + isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && + FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; + TimeCost tc; + auto [lookupKeys, restore, hotPos, scAll, countRecv] = + ProcessBatchWithUniqueCompute(batch, unique, id); + TIME_PRINT("ProcessBatch TimeCost(ms):{}", tc.ElapsedMS()); + + // 特征准入&淘汰 + if (isWithFAAE && + (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys, + countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + rankInfo.rankId, id, channel); + return false; + } + int batchListId = batch->batchId % KEY_PROCESS_THREAD; + + // without host, just device, all embedding vectors were stored in device + // map key to offset directly by lookup keyOffsetMap (hashmap) + if (rankInfo.noDDR) { + TimeCost key2OffsetTc; + Key2Offset(batch->name, lookupKeys); + TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS()); + } + if (!rankInfo.useStatic) { // Static all2all,need send count + SendA2A(scAll, batch->name, batch->channel, batch->batchId); + } + + auto tensors = make_unique>(); + tensors->push_back(Vec2TensorI32(restore)); + if (rankInfo.useHot) { + hotPos.resize(hotEmbTotCount[batch->name], -1); + tensors->push_back(Vec2TensorI32(hotPos)); + } + if (rankInfo.noDDR) { + if (rankInfo.useDynamicExpansion) { + tensors->push_back(Vec2TensorI64(lookupKeys)); + } else { + tensors->push_back(Vec2TensorI32(lookupKeys)); + } + } + TimeCost pushTensorTc; + PushResult(batch, move(tensors), lookupKeys, batchListId); + TIME_PRINT("pushTensorToListTC TimeCost(ms):{}", pushTensorTc.ElapsedMS()); + return true; +} + +vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, + vector>& keyCount, vector scAll, vector ss) +{ + if (rankInfo.useStatic) { + for (auto& cnt: keyCount) { + cnt.resize(embInfos[batch->name].sendCount, 0); + } + } + vector countSend; + for (auto& cnt: keyCount) { + countSend.insert(countSend.end(), cnt.begin(), cnt.end()); + } + vector sc; + for (int i = 0; i < rankInfo.rankSize; ++i) { + sc.push_back(scAll.at(rankInfo.rankSize * rankInfo.rankId + i)); + } + vector rc; // receive count + for (int i = 0; i < rankInfo.rankSize; ++i) { + rc.push_back(scAll.at(i * rankInfo.rankSize + rankInfo.rankId)); + } + auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + vector countRecv; + countRecv.resize(rs.back() + rc.back()); + MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), + rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); + return countRecv; +} + +void KeyProcess::PushResult(unique_ptr& batch, unique_ptr> tensors, + keys_t& lookupKeys, int id) +{ + std::unique_lock lockGuard(getInfoMut[id]); + storage[id].push_front(move(tensors)); + infoList[id][batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, storage[id].begin())); + if (!rankInfo.noDDR) { + lookupKeysList[id][batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, move(lookupKeys))); + } + lockGuard.unlock(); +} + +/* + * 从共享队列SingletonQueue中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 + * commID为线程标识[0, 
KEY_PROCESS_THREAD-1],不同线程、训练或推理数据用不同的共享队列通信 + */ +unique_ptr KeyProcess::GetBatchData(int channel, int commId) +{ + EASY_FUNCTION() + unique_ptr batch = nullptr; + // train data, queue id = thread id [0, KEY_PROCESS_THREAD-1] + auto batchQueue = SingletonQueue::getInstances(commId + KEY_PROCESS_THREAD * channel); + EASY_BLOCK("get samples") + EASY_VALUE("run on CPU", sched_getcpu()) + spdlog::stopwatch sw; + while (true) { + batch = batchQueue->TryPop(); + if (batch != nullptr) { + break; + } else { + this_thread::sleep_for(100us); + } + if (duration_cast(sw.elapsed()).count() > GET_BATCH_TIMEOUT) { + if (commId == 0) { + spdlog::warn(KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. " + "channel[{}] commId[{}]", channel, commId); + } + this_thread::sleep_for(seconds(1)); + sw.reset(); + } + if (!isRunning) { + // 通信终止信号,同步退出,防止线程卡住 + int exitFlag = isRunning; + MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + throw EndRunError("GetBatchData end run."); + } + } + EASY_END_BLOCK + spdlog::info(KEY_PROCESS "GetBatchData get batchId:{}, batchSize:{}, batch.channel:{}, name:{}, " + "channel:{}, commId:{}, ", + batch->batchId, batch->batchSize, batch->channel, batch->name, channel, commId); +#if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) + if (batch->batchId == PROFILING_START_BATCH_ID) { + EASY_PROFILER_ENABLE + } else if (batch->batchId == PROFILING_END_BATCH_ID) { + EASY_PROFILER_ENABLE + ::profiler::dumpBlocksToFile(fmt::format("/home/MX_REC-profile-{}.prof", rankInfo.rankId).c_str()); + } +#endif + return batch; +} + +auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id) + -> tuple, vector, vector, vector> +{ + EASY_FUNCTION(profiler::colors::Purple) + EASY_VALUE("batchId", batch->batchId) + + EASY_BLOCK("ock-unique") + + TimeCost unique_tc; + + SimpleThreadPool pool_; + keys_t keySend; + size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; + if (!rankInfo.useStatic) { + size = batch->batchSize; + } + keySend.resize(size); + vector splitSize(rankInfo.rankSize); + vector uniqueVector(batch->batchSize); + vector restore(batch->batchSize); + vector idCount(batch->batchSize); + vector keyCount(size); + std::shared_lock lock(g_smut); + auto hotMap = hotKey[batch->name]; + lock.unlock(); + vector hotPos; + int hotOffset = 0; + + if (rankInfo.useHot) { + hotPos.resize(hotEmbTotCount[batch->name]); + hotOffset = hotEmbTotCount[batch->name]; + } + absl::flat_hash_map keyCountMap; + + UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, restore.data(), uniqueVector.data(), splitSize.data(), + keySend.data(), idCount.data(), keyCount.data()}; + UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; + UniqueForHot uniqueForHot = {hotOffset, hotPos.data(), hotMap, keyCountMap}; + UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; + + unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum); + + EASY_END_BLOCK + TIME_PRINT("UniqueCompute TimeCost(ms):{}", unique_tc.ElapsedMS()); + + if (rankInfo.useHot) { + UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); + } + + vector sc; // send count + if (rankInfo.useStatic) { + sc.resize(rankInfo.rankSize, embInfos[batch->name].sendCount); + } else { + sc.resize(rankInfo.rankSize); + for (int i = 0;i < rankInfo.rankSize; i++) { + sc[i] = splitSize[i]; + } + } + auto [keyRecv, scAll, 
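// Illustrative sketch (not part of the patch): the shutdown barrier used by
// GetBatchData above (and GetScAll below). Every rank contributes its running flag;
// if the sum falls short of rankSize, some rank is exiting and all ranks unwind
// together instead of hanging in a later collective. Note that the MPI standard
// requires MPI_IN_PLACE when the send and receive buffers alias, as they do here.
#include <mpi.h>
#include <stdexcept>

void CheckPeersAlive(bool isRunning, int rankSize, MPI_Comm comm)
{
    int exitFlag = isRunning ? 1 : 0;
    MPI_Allreduce(MPI_IN_PLACE, &exitFlag, 1, MPI_INT, MPI_SUM, comm);
    if (exitFlag < rankSize) {
        throw std::runtime_error("a peer rank stopped; ending this run");
    }
}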
countRecv] = All2All(sc, id, batch->channel, keySend, keyCount); + + return { keyRecv, restore, hotPos, scAll, countRecv}; +} + +auto KeyProcess::All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount) + -> tuple, vector> +{ + keys_t keyRecv; + TimeCost get_sc_all; + auto scAll = GetScAll(sc, id, channel); // Allgather通信获取所有(不同rank相同thread id的) + TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS()); + + TimeCost all2allTC; + auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 + vector rc(rankInfo.rankSize); // receive count + for (int i = 0; i < rankInfo.rankSize; ++i) { + // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 + rc[i] = scAll.at(i * rankInfo.rankSize + rankInfo.rankId); + } + auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + keyRecv.resize(rs.back() + rc.back()); + EASY_BLOCK("all2all") + MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, + comm[channel][id]); + + vector countRecv(rs.back() + rc.back()); + if (isWithFAAE) { + MPI_Alltoallv(keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), rc.data(), rs.data(), + MPI_UINT32_T, comm[channel][id]); + } + TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); + EASY_END_BLOCK + return {keyRecv, scAll, countRecv}; +} + +auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, + vector& splitKeys) -> tuple, vector> +{ + EASY_FUNCTION(profiler::colors::Purple) + EASY_VALUE("batchId", batch->batchId) + spdlog::info(KEY_PROCESS "ProcessSplitKeys start batchId:{}, channel:{}", batch->batchId, batch->channel); + + // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 + if (rankInfo.useStatic) { // maybe move after all2all + for (auto& i: splitKeys) { + if (static_cast(i.size()) > embInfos[batch->name].sendCount) { + spdlog::error("{}[{}]:{} overflow! set send count bigger than {}", + batch->name, batch->channel, batch->batchId, i.size()); + } + i.resize(embInfos[batch->name].sendCount, -1); + } + } + keys_t keySend; + vector sc; // send count + for (const auto& i: splitKeys) { + sc.push_back(static_cast(i.size())); + keySend.insert(keySend.end(), i.begin(), i.end()); + } + keys_t keyRecv; + auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 + auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 + vector rc; // receive count + for (int i = 0; i < rankInfo.rankSize; ++i) { + // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 + rc.push_back(scAll.at(i * rankInfo.rankSize + rankInfo.rankId)); + } + auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + keyRecv.resize(rs.back() + rc.back()); + spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, + batch->name); + EASY_BLOCK("all2all") + MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, + keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, + comm[batch->channel][id]); + EASY_END_BLOCK + spdlog::trace(KEY_PROCESS "MPI_Alltoallv finish. 
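// Illustrative sketch (not part of the patch): a 2-rank worked example of how All2All
// above reads the gathered count matrix. scAll is row-major; row r holds what rank r
// sends to each peer, so rank r's receive counts are column r.
#include <cassert>
#include <vector>

int main()
{
    const int rankSize = 2;
    const int rankId = 1;              // pretend we are rank 1
    std::vector<int> scAll = { 5, 2,   // rank 0 sends 5 keys to rank 0, 2 to rank 1
                               3, 4 }; // rank 1 sends 3 keys to rank 0, 4 to rank 1
    std::vector<int> rc;               // receive counts = column rankId
    for (int i = 0; i < rankSize; ++i) {
        rc.push_back(scAll.at(i * rankSize + rankId));
    }
    assert(rc[0] == 2 && rc[1] == 4);  // rank 1 receives 2 + 4 = 6 keys in total
}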
rank {} thread {} batch {} {}", + rankInfo.rankId, id, batch->batchId, batch->name); + + return { keyRecv, scAll, ss }; +} + +/* + * 将batch内的key按照所存储的dev id哈希切分并去重,哈希函数为模运算 + * splitKeys返回:将数据的key切分到其所在dev id对应的桶中,并去重。 + * restore返回:去重后key在桶内偏移量(用于计算恢复向量) + */ +auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple, vector> +{ + EASY_FUNCTION(profiler::colors::Gold) + auto* batchData = batch->sample.data(); + size_t miniBs = batch->Size(); + ASSERT(batchData != nullptr); + vector splitKeys(rankInfo.rankSize); + vector restore(batch->Size()); + vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 + absl::flat_hash_map uKey; // 用于去重查询 + EASY_BLOCK("split push back") + for (size_t i = 0; i < miniBs; i++) { + const emb_key_t& key = batchData[i]; + int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + auto result = uKey.find(key); + if (result == uKey.end()) { + splitKeys[devId].push_back(key); + restore[i] = hashSplitLens[devId]++; // restore记录去重后key在桶内偏移量(用于计算恢复向量) + uKey[key] = restore[i]; + } else { // 去重 + restore[i] = result->second; + } + } + EASY_END_BLOCK + if (spdlog::get_level() == spdlog::level::trace) { + stringstream ssTrace; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + ssTrace << '|' << devId << ":"; + for (auto x: splitKeys[devId]) { + ssTrace << x << ','; + } + ssTrace << '|'; + } + spdlog::trace("dump splitKeys\n{}", ssTrace.str()); + } + return { splitKeys, restore }; +} + +auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const + -> tuple, vector, vector>> +{ + EASY_FUNCTION(profiler::colors::Gold) + auto* batchData = batch->sample.data(); + size_t miniBs = batch->Size(); + ASSERT(batchData != nullptr); + vector splitKeys(rankInfo.rankSize); + vector> keyCount(rankInfo.rankSize); // splitKeys在原始batch中对应的频次 + vector restore(batch->Size()); + vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 + absl::flat_hash_map> uKey; // 用于去重查询 + EASY_BLOCK("split push back") + for (size_t i = 0; i < miniBs; i++) { + const emb_key_t& key = batchData[i]; + int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + auto result = uKey.find(key); + if (result == uKey.end()) { + splitKeys[devId].push_back(key); + restore[i] = hashSplitLens[devId]++; // restore记录去重后key在桶内偏移量(用于计算恢复向量) + uKey[key].first = restore[i]; + uKey[key].second = 1; + } else { // 去重 + restore[i] = result->second.first; + uKey[key].second++; + } + } + + // 处理splitKeys对应的count + for (int j = 0; j < rankInfo.rankSize; ++j) { + vector count; + for (size_t k = 0; k < splitKeys[j].size(); ++k) { + count.emplace_back(uKey[splitKeys[j][k]].second); + } + keyCount[j] = count; + } + + EASY_END_BLOCK + if (spdlog::get_level() == spdlog::level::trace) { + stringstream ssTrace; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + ssTrace << '|' << devId << ":"; + for (auto x : splitKeys[devId]) { + ssTrace << x << ','; + } + ssTrace << '|'; + } + spdlog::trace("dump splitKeys\n{}", ssTrace.str()); + } + + return { splitKeys, restore, keyCount }; +} + +auto KeyProcess::HotHashSplit(const unique_ptr& batch) -> +tuple, vector, vector> +{ + EASY_FUNCTION(profiler::colors::Gold) + auto* batchData = batch->sample.data(); + size_t miniBs = batch->Size(); + vector splitKeys(rankInfo.rankSize); + vector restore(batch->Size()); + absl::flat_hash_map uKey; // 用于去重查询 + absl::flat_hash_map keyCountMap; + std::shared_lock lock(g_smut); + auto hotMap = hotKey[batch->name]; + lock.unlock(); + vector 
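// Illustrative sketch (not part of the patch): the bucket assignment used by HashSplit
// above. key & (rankSize - 1) equals key % rankSize only when rankSize is a power of
// two; unlike %, the mask also sends the sentinel key -1 to a valid bucket
// (rankSize - 1) instead of producing the negative remainder -1.
#include <cassert>
#include <cstdint>

int DevIdOf(int64_t key, int rankSize) // assumes rankSize is a power of two
{
    return static_cast<int>(key) & (rankSize - 1);
}

int main()
{
    assert(DevIdOf(10, 4) == 2); // same as 10 % 4
    assert(DevIdOf(-1, 4) == 3); // -1 % 4 would be -1, an unusable index
}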
hotPos(hotEmbTotCount[batch->name]); + vector hotPosDev(hotEmbTotCount[batch->name]); + + int hotCount = 0; + int hotOffset = hotEmbTotCount[batch->name]; + for (size_t i = 0; i < miniBs; i++) { // for mini batch + const emb_key_t& key = batchData[i]; + if (batch->batchId % hotEmbUpdateStep == 0) { + keyCountMap[key]++; + } + int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + auto result = uKey.find(key); + if (result != uKey.end()) { // // already in splitKeys + restore[i] = result->second; + continue; + } + // new key in current batch + splitKeys[devId].push_back(key); // push to bucket + auto hot = hotMap.find(key); + if (hot != hotMap.end()) { // is hot key + if (hot->second == -1) { // is new hot key in this batch + hotPos[hotCount] = splitKeys[devId].size() - 1; // pos in lookup vec (need add ss) for hot-gather + hotPosDev[hotCount] = devId; // which dev, for get ss + hot->second = hotCount; + restore[i] = hotCount++; // get pos of hot emb + } else { + restore[i] = hot->second; + } + } else { // is not hot key + restore[i] = splitKeys[devId].size() + hotOffset - 1; // restore记录去重后key在桶内偏移量(用于计算恢复向量) + } + uKey[key] = restore[i]; + } + + UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); + AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch->name); + + return { splitKeys, restore, hotPos }; +} + +void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, + const string& embName) const +{ + vector splitKeysSize {}; + if (rankInfo.useStatic) { + for (size_t i = 0; i < splitKeys.size(); i++) { + splitKeysSize.push_back(embInfos.at(embName).sendCount); + } + } else { + for (auto& splitKey: splitKeys) { + splitKeysSize.push_back(splitKey.size()); + } + } + auto cs = Count2Start(splitKeysSize); + for (size_t i = 0; i < hotPos.size(); ++i) { + hotPos[i] += cs[hotPosDev[i]]; + } +} + +void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, + const string& embName) +{ + auto& hotMap = hotKey[embName]; + if (refresh) { + priority_queue> pq; // top k key + for (auto& p: keyCountMap) { + pq.push(pair(-p.second, p.first)); + if (pq.size() > count) { + pq.pop(); + } + } + // gen new hot map + std::unique_lock lock(g_smut); + hotMap.clear(); + while (!pq.empty()) { + hotMap.insert(make_pair(pq.top().second, -1)); + pq.pop(); + } + } +} + +/* + * 将本地(rank)batch要发送的key数据量进行Allgather通信,获取所有(不同rank相同thread id的)线程间的通信量矩阵 + * scAll返回:所有线程间的通信量矩阵(按行平铺的一维向量) + */ +vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel) const +{ + EASY_FUNCTION() + vector scAll; + scAll.resize(rankInfo.rankSize * rankInfo.rankSize); + EASY_BLOCK("barrier"); + // 通信终止信号,同步退出,防止线程卡住 + spdlog::stopwatch sw; + int exitFlag = isRunning; + MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + if (exitFlag < rankInfo.rankSize) { + throw EndRunError("GetScAll end run."); + } + EASY_END_BLOCK; + spdlog::debug(KEY_PROCESS "barrier time:{}", TO_MS(sw)); + // allgather keyScLocal(key all2all keyScLocal = device all2all rc) + MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, + scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + spdlog::debug("rank {} key scAll matrix:\n{}", rankInfo.rankId, scAll); + return scAll; +} + +void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) +{ + EASY_FUNCTION(profiler::colors::Blue600) + std::lock_guard lk(key2OffsetMut); 
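// Illustrative sketch (not part of the patch): the top-k selection inside UpdateHotMap
// above. Pushing (-count, key) into a max-ordered priority_queue and popping whenever
// it grows past k keeps exactly the k most frequent keys, because the queue top is
// always the current weakest candidate.
#include <cassert>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

std::vector<int64_t> TopKKeys(const std::vector<std::pair<int64_t, int>>& keyCounts, size_t k)
{
    std::priority_queue<std::pair<int, int64_t>> pq; // (negated count, key)
    for (const auto& p : keyCounts) {
        pq.push({ -p.second, p.first });
        if (pq.size() > k) {
            pq.pop(); // drops the largest negated count, i.e. the smallest real count
        }
    }
    std::vector<int64_t> hot;
    while (!pq.empty()) {
        hot.push_back(pq.top().second);
        pq.pop();
    }
    return hot;
}

int main()
{
    auto hot = TopKKeys({ { 7, 1 }, { 8, 9 }, { 9, 5 } }, 2);
    assert(hot.size() == 2); // keys 8 and 9 survive; key 7 (count 1) is dropped
}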
// lock for PROCESS_THREAD + auto& key2Offset = keyOffsetMap[embName]; + auto& maxOffsetTmp = maxOffset[embName]; + auto& evictPos = evictPosMap[embName]; + auto& curEmbTable = embeddingTableMap[embName]; // empty when not use dynamic expansion + for (long& key : splitKey) { + if (key == -1) { + if (rankInfo.useDynamicExpansion) { + key = 0; + } + continue; + } + const auto& iter = key2Offset.find(key); + if (iter != key2Offset.end()) { + // 老值 + key = iter->second; + } else if (evictPos.size() != 0) { + size_t offset; + // 新值, emb有pos可复用 + offset = evictPos.back(); + spdlog::trace("HBM mode, evictPos is not null, name[{}] key [{}] reuse offset [{}], evictSize [{}]!!!", + embName, key, offset, evictPos.size()); + key2Offset[key] = offset; + key = offset; + evictPos.pop_back(); + } else { + // 新值 + if (rankInfo.useDynamicExpansion) { + auto addr = curEmbTable.GetEmbAddress(); + key2Offset[key] = addr; + key = addr; + maxOffsetTmp++; + } else { + key2Offset[key] = maxOffsetTmp; + key = maxOffsetTmp++; + } + } + } + if (!rankInfo.useDynamicExpansion && maxOffsetTmp > embInfos[embName].devVocabSize) { + spdlog::error("dev cache overflow {}>{}", maxOffsetTmp, embInfos[embName].devVocabSize); + } + spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); +} + +void KeyProcess::Key2OffsetInit(const emb_name_t& embName) +{ + auto& key2Offset = keyOffsetMap[embName]; + auto& offset = maxOffset[embName]; + key2Offset[rankInfo.rankId] = offset; // 0 rank init feature id 0 to offset 0 + offset++; +} + +/* + * 构建恢复向量,以便从去重后的emb向量/key恢复回batch对应的emb向量 + * 输入接收到emb块的偏移blockOffset,batch内每个key在块内的偏移restoreVec + * 输出恢复向量restoreVec,即batch到keySend(平铺的splitKeys)的映射 + * 实现方案2:用map记录keySend中key和表内index/offset的映射,在恢复emb时直接根据batch的key查询该map即可找到receive + * emb中的 位置,时间复杂度:O(map构建keySend.size + map查询),空间复杂度:O(map) + */ +void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, + vector& restoreVec, int hotPosSize) const +{ + EASY_FUNCTION() + int hotNum = 0; + bool spdDebug = (spdlog::get_level() == spdlog::level::debug); + for (size_t i = 0; i < batch->Size(); ++i) { + const emb_key_t d = batch->sample[i]; + int devId = static_cast(d) & (rankInfo.rankSize - 1); + if (restoreVec[i] >= hotPosSize) { + restoreVec[i] += blockOffset[devId]; + } else if (spdDebug) { + hotNum += 1; + } + } + spdlog::debug("hot num in all:{}/{}", hotNum, batch->Size()); +} + +class EmptyList : public std::exception { +}; + +class WrongListTop : public std::exception { +}; + +template +T KeyProcess::GetInfo(array, KEY_PROCESS_THREAD>& list, int batch, const string& embName, int channel) +{ + int batchListId = batch % KEY_PROCESS_THREAD; + std::lock_guard lockGuard(getInfoMut[batchListId]); + if (list[batchListId][embName][channel].empty()) { + spdlog::trace("get info list is empty."); + throw EmptyList(); + } + auto topBatch = get(list[batchListId][embName][channel].top()); + if (topBatch < batch) { + spdlog::error("wrong batch id, top:{} expect:{}, channel:{}, embName: {}, queue_size:{}, " + "may not clear channel", + topBatch, batch, channel, embName, list[batchListId][embName][channel].size()); + this_thread::sleep_for(1s); + } + if (topBatch != batch) { + spdlog::trace("topBatch({}) is not equal batch({}).", topBatch, batch); + throw WrongListTop(); + } + auto t = list[batchListId][embName][channel].top(); + list[batchListId][embName][channel].pop(); + return move(t); +} + +keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) +{ + spdlog::stopwatch 
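// Illustrative sketch (not part of the patch): the offset assignment performed by
// Key2Offset above, reduced to one function. Known keys keep their slot, new keys
// first recycle a slot from the eviction free-list, and only then is a fresh offset
// taken from the high-water mark.
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

size_t AssignOffset(int64_t key, std::unordered_map<int64_t, size_t>& key2Offset,
                    std::vector<size_t>& evictPos, size_t& maxOffset)
{
    auto it = key2Offset.find(key);
    if (it != key2Offset.end()) {
        return it->second;                // old key: reuse its offset
    }
    if (!evictPos.empty()) {
        size_t slot = evictPos.back();    // new key: recycle an evicted slot
        evictPos.pop_back();
        return key2Offset[key] = slot;
    }
    return key2Offset[key] = maxOffset++; // new key: grow the table
}

int main()
{
    std::unordered_map<int64_t, size_t> key2Offset;
    std::vector<size_t> evictPos { 3 };
    size_t maxOffset = 0;
    assert(AssignOffset(42, key2Offset, evictPos, maxOffset) == 3); // recycled slot
    assert(AssignOffset(7, key2Offset, evictPos, maxOffset) == 0);  // fresh slot
    assert(AssignOffset(42, key2Offset, evictPos, maxOffset) == 3); // stable lookup
}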
sw; + while (true) { + if (!isRunning) { + return {}; + } + if (batch != 0 && channel != 0 && duration_cast(sw.elapsed()).count() > KEY_PROCESS_TIMEOUT) { + spdlog::warn(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); + return {}; + } + try { + auto ret = GetInfo(lookupKeysList, batch, embName, channel); + return get(ret); + } catch (EmptyList&) { + spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} no input, wait and retry", + embName, channel, batch); + this_thread::sleep_for(1ms); + } catch (WrongListTop&) { + spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} wrong top", + embName, channel, batch); + this_thread::sleep_for(1ms); + } + } +} + +unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) +{ + spdlog::stopwatch sw; + array, KEY_PROCESS_THREAD>* list; + switch (type) { + case ProcessedInfo::ALL2ALL: + list = &all2AllList; + break; + case ProcessedInfo::RESTORE: + list = &infoList; + break; + default: + throw runtime_error("ERROR list type"); + } + while (true) { + if (!isRunning) { + return nullptr; + } + if (batch != 0 && channel != 0 && duration_cast(sw.elapsed()).count() > KEY_PROCESS_TIMEOUT) { + spdlog::warn(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); + return nullptr; + } + try { + auto ret = GetInfo(*list, batch, embName, channel); + auto it = get>>::iterator>(ret); + auto uTensor = move(*it); + int batchListId = batch % KEY_PROCESS_THREAD; + unique_lock lockGuard(getInfoMut[batchListId]); + storage[batchListId].erase(it); + return uTensor; + } catch (EmptyList&) { + spdlog::trace("GetInfoVec GetInfo failed {}[{}]:{} type: {} no input and retry", + embName, channel, batch, type); + this_thread::sleep_for(1ms); + } catch (WrongListTop&) { + spdlog::trace("GetInfoVec GetInfo failed {}[{}]:{} type: {} wrong top", + embName, channel, batch, type); + this_thread::sleep_for(1ms); + } + } +} + +void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int channel, int batchId) +{ + // 数据放到队列里,在mgmt里面发送(检查发送数据量) + auto tensors = make_unique>(); + Tensor tmpTensor(tensorflow::DT_INT64, { rankInfo.rankSize, rankInfo.rankSize }); + auto tmpData = tmpTensor.matrix(); + for (int i = 0; i < rankInfo.rankSize; ++i) { + for (int j = 0; j < rankInfo.rankSize; ++j) { + tmpData(i, j) = a2aInfo[j * rankInfo.rankSize + i]; + } + } + tensors->emplace_back(move(tmpTensor)); + + int batchListId = batchId % KEY_PROCESS_THREAD; + std::unique_lock lockGuard(getInfoMut[batchListId]); + storage[batchListId].push_front(move(tensors)); + all2AllList[batchListId][embName][channel].push(make_tuple(batchId, embName, storage[batchListId].begin())); + lockGuard.unlock(); +} + +int KeyProcess::GetMaxStep(int channelId) const +{ + return rankInfo.maxStep.at(channelId); +} + +void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm +{ + spdlog::info(KEY_PROCESS "hbm funEvictCall: [{}]! 
keySize:{}", embName, keys.size()); + + // 删除映射关系 + if (keys.size() != 0) { + EvictDeleteDeviceEmb(embName, keys); + } + + // 初始化 dev + EvictInitDeviceEmb(embName, evictPosMap.at(embName)); +} + +void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector& keys) +{ + EASY_FUNCTION(profiler::colors::Blue600) + std::lock_guard lk(key2OffsetMut); // lock for PROCESS_THREAD + + size_t keySize = keys.size(); + auto& devHashMap = keyOffsetMap.at(embName); + auto& evictPos = evictPosMap.at(embName); + + for (size_t i = 0; i < keySize; i++) { + size_t offset; + auto key = keys[i]; + if (key == -1) { + spdlog::error("evict key equal -1!"); + continue; + } + const auto& iter = devHashMap.find(key); + if (iter == devHashMap.end()) { // not found + continue; + } + offset = iter->second; + devHashMap.erase(iter); + evictPos.emplace_back(offset); + spdlog::trace("evict embName {} , offset , {}", embName, offset); + } + spdlog::info(KEY_PROCESS "hbm EvictDeleteDeviceEmb: [{}]! evict size on dev:{}", embName, evictPos.size()); +} + +void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset) +{ + if (offset.size() > embInfos[embName].devVocabSize) { + spdlog::error("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", + embName, offset.size(), embInfos[embName].devVocabSize); + } + if (rankInfo.useStatic) { + offset.resize(embInfos[embName].devVocabSize, -1); + } + + auto trans = Singleton::GetInstance(); + // evict key发送给dev侧,dev侧初始化emb + auto tmpData = Vec2TensorI32(offset); + trans->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); + + spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size()); +} \ No newline at end of file diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h new file mode 100644 index 00000000..d2730707 --- /dev/null +++ b/src/core/key_process/key_process.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Date: 2022/11/15 + */ + +#ifndef MX_REC_KEY_PROCESS_H +#define MX_REC_KEY_PROCESS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/common.h" +#include "utils/safe_queue.h" +#include "utils/unique.h" +#include "utils/spinlock.h" +#include "utils/task_queue.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "host_emb/host_emb.h" +#include "feature_admit_and_evict.h" +#include "emb_table/emb_table.h" + +namespace MxRec { + using namespace std; + + constexpr int UNIQUE_BUCKET = 6; + constexpr int MIN_UNIQUE_THREAD_NUM = 1; + constexpr int MAX_UNIQUE_THREAD_NUM = 8; + + using a2a_info_t = vector; + using sharded_dedup = ShardedDedup*; + + template struct Cmp { + bool operator () (const T &a, const T &b) + { + return get(a) > get(b); // batch id order + } + }; + + template + using heap_t = priority_queue, Cmp>; + template + using info_list_t = map, MAX_QUEUE_NUM>>; + enum class ProcessedInfo { + RESTORE, + ALL2ALL, + INVALID + }; + + class KeyProcess { + public: + int Initialize(const RankInfo& rInfo, const vector& eInfos, + const vector& thresholdValues = {}, bool ifLoad = false, int seed = 0); + + unique_ptr> GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type); + + keys_t GetLookupKeys(int batch, const string& embName, int channel); + + int GetMaxStep(int channelId) const; + + int Start(); + + auto GetMaxOffset() -> offset_mem_t; + + auto GetKeyOffsetMap() -> key_offset_mem_t; + + auto GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&; + + void LoadMaxOffset(offset_mem_t& loadData); + + void LoadKeyOffsetMap(key_offset_mem_t& loadData); + + void Destroy(); + + void LoadSaveLock(); + + void LoadSaveUnlock(); + + void EvictKeys(const string& embName, const vector& keys); + + bool isRunning { false }; + + GTEST_PRIVATE: + + template + T GetInfo(array, KEY_PROCESS_THREAD>& list, int batch, const string& embName, int channel); + + RankInfo rankInfo; + map embInfos; + MPI_Comm comm[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD]; + vector procThread {}; + std::mutex key2OffsetMut {}; + std::mutex loadSaveMut[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD] {}; + std::mutex getInfoMut[KEY_PROCESS_THREAD] {}; + array, KEY_PROCESS_THREAD> lookupKeysList; + list>> storage[KEY_PROCESS_THREAD]; + array, KEY_PROCESS_THREAD> infoList; + array, KEY_PROCESS_THREAD> all2AllList; + map maxOffset {}; + map> keyOffsetMap {}; + FeatureAdmitAndEvict m_featureAdmitAndEvict {}; + map> evictPosMap {}; + map> hotKey {}; + map hotEmbTotCount; + map embeddingTableMap {}; + int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; + bool isWithFAAE; + + void KeyProcessTask(int channel, int id); + + bool KeyProcessTaskHelper(unique_ptr& batch, sharded_dedup unique_, + int channel, int id, spdlog::stopwatch& sw); + auto ProcessSplitKeys(const unique_ptr& batch, int id, + vector& splitKeys) -> tuple, vector>; + auto ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique_, int id) + -> tuple, vector, vector, vector>; + + auto All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount) + -> tuple, vector>; + + auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; + + auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; + + auto HashSplit_withFAAE(const unique_ptr& batch) const + -> tuple, vector, vector>>; + [[nodiscard]] vector GetScAll(const vector& keyScLocal, int commId, int channel) const; + + void 
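// Illustrative sketch (not part of the patch): the effect of the Cmp comparator
// declared above. Ordering tuples with greater-than on the batch-id field turns
// std::priority_queue into a min-heap, so top() always yields the oldest pending
// batch, which is the one the consumers must emit next.
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <vector>

using Item = std::tuple<int, std::string>; // (batchId, embName)

struct CmpByBatchId {
    bool operator()(const Item& a, const Item& b) const
    {
        return std::get<0>(a) > std::get<0>(b); // reversed comparison => min-heap
    }
};

int main()
{
    std::priority_queue<Item, std::vector<Item>, CmpByBatchId> q;
    q.push({ 2, "emb" });
    q.push({ 0, "emb" });
    q.push({ 1, "emb" });
    assert(std::get<0>(q.top()) == 0); // the smallest batch id surfaces first
}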
Key2Offset(const emb_name_t& embName, keys_t& splitKey); + + unique_ptr GetBatchData(int channel, int commId); + + void BuildRestoreVec(const unique_ptr& batch, const vector& rs, + vector& restoreVec, int hotPosSize = 0) const; + + void SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch); + + void Key2OffsetInit(const emb_name_t& embName); + + void EvictDeleteDeviceEmb(const string& embName, const vector& keys); + + void EvictInitDeviceEmb(const string& embName, vector offset); + + void UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, + const string& embName); + + void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys, int id); + + void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, + const string& embName) const; + + vector GetCountRecv(const unique_ptr& batch, int id, + vector>& keyCount, vector scAll, vector ss); + }; +} + + +#endif // MX_REC_KEY_PROCESS_H diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp new file mode 100644 index 00000000..9da6952c --- /dev/null +++ b/src/core/utils/common.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: common module + * Author: MindX SDK + * Create: 2021 + * History: NA + */ + +#include "common.h" +#include +#include +#include +#include + +using namespace std; +using std::chrono::system_clock; + +namespace MxRec { + RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, + const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), + nBatch(nBatch), maxStep(maxStep) + { + MPI_Comm_size(MPI_COMM_WORLD, &rankSize); + if (localRankSize != 0) { + localRankId = rankId % localRankSize; + } + useStatic = option bitand HybridOption::USE_STATIC; + useHot = option bitand HybridOption::USE_HOT; + useDynamicExpansion = option bitand HybridOption::USE_DYNAMIC_EXPANSION; + } + + RankInfo::RankInfo(int localRankSize, int option, int nBatch, const vector& maxStep) + : localRankSize(localRankSize), option(option), nBatch(nBatch), maxStep(maxStep) + { + MPI_Comm_rank(MPI_COMM_WORLD, &rankId); + MPI_Comm_size(MPI_COMM_WORLD, &rankSize); + if (localRankSize != 0) { + localRankId = rankId % localRankSize; + } + useStatic = option bitand HybridOption::USE_STATIC; + useHot = option bitand HybridOption::USE_HOT; + } + + RandomInfo::RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax) + : start(start), len(len), constantVal(constantVal), randomMin(randomMin), randomMax(randomMax) + {} + + ConstantInitializerInfo::ConstantInitializerInfo(float constantValue) + : constantValue(constantValue) + {} + + NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed) + : mean(mean), stddev(stddev), seed(seed) + {} + + InitializeInfo::InitializeInfo(std::string& name, int start, int len, + ConstantInitializerInfo constantInitializerInfo) + : name(name), start(start), len(len), constantInitializerInfo(constantInitializerInfo) + { + if (name == "constant_initializer") { + initializerType = InitializerType::CONSTANT; + constantInitializer = ConstantInitializer(start, len, constantInitializerInfo.constantValue); + } else { + throw std::invalid_argument("Invalid Initializer Type."); + } + } + + InitializeInfo::InitializeInfo(std::string& name, int start, int len, NormalInitializerInfo normalInitializerInfo) + : name(name), start(start), len(len), 
normalInitializerInfo(normalInitializerInfo) + { + if (name == "truncated_normal_initializer") { + initializerType = InitializerType::TRUNCATED_NORMAL; + truncatedNormalInitializer = TruncatedNormalInitializer(start, len, + normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); + } else if (name == "random_normal_initializer") { + initializerType = InitializerType::RANDOM_NORMAL; + randomNormalInitializer = RandomNormalInitializer(start, len, + normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); + } else { + throw std::invalid_argument("Invalid Initializer Type."); + } + } + + void SetLog(int rank) + { + std::string pattern = "[%H:%M:%S.%e] [" + std::to_string(rank) + "] [%^%l%$] %v"; + spdlog::default_logger()->set_pattern(pattern); + auto env_val = spdlog::details::os::getenv("SPDLOG_LEVEL"); + spdlog::cfg::load_env_levels(); + } + + string GetChipName(int devID) + { + int ret = 0; + struct dsmi_chip_info_stru info = {{ 0 }, + { 0 }, + { 0 }}; + ret = dsmi_get_chip_info(devID, &info); + if (ret == 0) { + spdlog::debug("dsmi_get_chip_info successful, ret = {}, chip_name = {}", ret, + reinterpret_cast(info.chip_name)); + return reinterpret_cast(info.chip_name); + } + + throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret)); + } +} diff --git a/src/core/utils/common.h b/src/core/utils/common.h new file mode 100644 index 00000000..d7290fe1 --- /dev/null +++ b/src/core/utils/common.h @@ -0,0 +1,433 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: common module + * Author: MindX SDK + * Create: 2021 + * History: NA + */ + +#ifndef COMMON_H +#define COMMON_H + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "tensorflow/core/framework/tensor.h" +#include "absl/container/flat_hash_map.h" + +#include "initializer/initializer.h" +#include "initializer/constant_initializer/constant_initializer.h" +#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h" +#include "initializer/random_normal_initializer/random_normal_initializer.h" + +#if defined(BUILD_WITH_EASY_PROFILER) + #include + #include +#else + #define EASY_FUNCTION(...) + #define EASY_VALUE(...) + #define EASY_BLOCK(...) + #define EASY_END_BLOCK + #define EASY_PROFILER_ENABLE + #define EASY_PROFILER_DISABLE +#endif + +namespace MxRec { +#define ASSERT(arg) assert(arg) +#define TO_MS(arg) duration_cast((arg).elapsed()) +#define INFO_PTR shared_ptr +#define TIME_PRINT spdlog::info +#define MGMT_CPY_THREADS 4 +#define PROFILING + // read batch cost + // key process cost + using namespace tensorflow; + constexpr int TRAIN_CHANNEL_ID = 0; + constexpr int EVAL_CHANNEL_ID = 1; + constexpr int MAX_CHANNEL_NUM = 2; + constexpr int KEY_PROCESS_THREAD = 6; + constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * KEY_PROCESS_THREAD; + constexpr int KEY_PROCESS_TIMEOUT = 120; + constexpr int GET_BATCH_TIMEOUT = 300; + constexpr size_t DEFAULT_RANDOM_SEED = 10086; + constexpr int INVALID_KEY_VALUE = -1; + constexpr int PROFILING_START_BATCH_ID = 100; + constexpr int PROFILING_END_BATCH_ID = 200; + constexpr int MGMT_THREAD_BIND = 48; + constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; + constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; + constexpr float HOT_EMB_CACHE_PCT = 1. 
/ 3; // hot emb cache percent + + using emb_key_t = int64_t; + using emb_name_t = std::string; + using keys_t = std::vector; + using lookup_key_t = std::tuple; // batch_id quarry_lable keys_vector + using tensor_info_t = std::tuple>>::iterator>; + using EndRunError = std::runtime_error; + + namespace HybridOption { + const int USE_STATIC = 0x001; + const int USE_HOT = 0x001 << 1; + const int USE_DYNAMIC_EXPANSION = 0x001 << 2; + }; + + string GetChipName(int devID); + + namespace UBSize { + const int ASCEND910_PREMIUM_A = 262144; + const int ASCEND910_PRO_B = 262144; + const int ASCEND910_B2 = 196608; + const int ASCEND910_B1 = 196608; + const int ASCEND910_B3 = 196608; + const int ASCEND910_B4 = 196608; + const int ASCEND910_C1 = 196608; + const int ASCEND910_C2 = 196608; + const int ASCEND910_C3 = 196608; + const int ASCEND920_A = 196608; + const int ASCEND910_PRO_A = 262144; + const int ASCEND910_B = 262144; + const int ASCEND910_A = 262144; + }; + + inline int GetUBSize(int devID) + { + std::map ChipUbSizeList = {{"910A", UBSize::ASCEND910_A}, + {"910B", UBSize::ASCEND910_B}, + {"920A", UBSize::ASCEND920_A}, + {"910B1", UBSize::ASCEND910_B1}, + {"910B2", UBSize::ASCEND910_B2}, + {"910B3", UBSize::ASCEND910_B3}, + {"910B4", UBSize::ASCEND910_B4}}; + auto it = ChipUbSizeList.find(GetChipName(devID)); + if (it != ChipUbSizeList.end()) { + return it->second; + } + + throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); + } + + template struct Batch { + size_t Size() + { + return sample.size(); + } + + std::string UnParse() + { + std::string s; + constexpr size_t MAX_DISP_LEN = 20; + int max_len = std::min(sample.size(), MAX_DISP_LEN); + for (int i = 0; i < max_len; i++) { + s += std::to_string(sample[i]) + " "; + } + return s; + } + + std::vector sample; + void *tensorAddr = nullptr; + std::string name; + size_t batchSize; + int batchId; + int channel = 0; + bool isInt64; // true int64 false int32 + time_t timestamp { -1 }; + }; + +struct BatchTask { + vector splits; + vector embNames; + size_t batchSize; + int batchQueueId; + int batchId; + int channelId; + time_t timestamp { -1 }; + bool flag; // true int64 false int32 + const void *tensor; +}; + + using emb_batch_t = Batch; + using batch_task_t = BatchTask; + + + struct RankInfo { + RankInfo() = default; + + RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, + const std::vector& maxStep); + RankInfo(int localRankSize, int option, int nBatch, const std::vector& maxStep); + + int rankId {}; + int deviceId {}; + int rankSize {}; + int localRankId {}; + int localRankSize {}; + bool useStatic { false }; + bool useHot {}; + uint32_t option {}; + int nBatch {}; + bool useDataset { false }; // deprecated + bool noDDR { false }; + bool useDynamicExpansion {false}; + std::vector maxStep; + }; + + enum TensorIndex : uint32_t { + TENSOR_INDEX_0, + TENSOR_INDEX_1, + TENSOR_INDEX_2, + TENSOR_INDEX_3, + TENSOR_INDEX_4, + TENSOR_INDEX_5, + TENSOR_INDEX_6, + TENSOR_INDEX_7, + TENSOR_INDEX_8 + }; + + enum TupleIndex : uint32_t { + TUPLE_INDEX_0 = 0, + TUPLE_INDEX_1, + TUPLE_INDEX_2, + TUPLE_INDEX_3, + TUPLE_INDEX_4, + TUPLE_INDEX_5, + TUPLE_INDEX_6, + TUPLE_INDEX_7 + }; + + struct RandomInfo { + RandomInfo() = default; + + RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax); + + int start; + int len; + float constantVal; + float randomMin; + float randomMax; + }; + + struct ThresholdValue { + ThresholdValue() = default; + ThresholdValue(emb_name_t name, int countThre, int 
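// Illustrative sketch (not part of the patch): decoding the HybridOption bit flags
// defined above, which the RankInfo constructors test with `option bitand ...`.
#include <cassert>

int main()
{
    const int USE_STATIC = 0x001;
    const int USE_HOT = 0x001 << 1;
    const int USE_DYNAMIC_EXPANSION = 0x001 << 2;

    int option = USE_STATIC | USE_DYNAMIC_EXPANSION; // e.g. 0b101 handed in from Python
    assert((option & USE_STATIC) != 0);
    assert((option & USE_HOT) == 0);
    assert((option & USE_DYNAMIC_EXPANSION) != 0);
}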
timeThre) + { + tensorName = name; + countThreshold = countThre; + timeThreshold = timeThre; + } + + emb_name_t tensorName { "" }; // embName + int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 + int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 + }; + + struct FeatureItemInfo { + FeatureItemInfo() = default; + FeatureItemInfo(int64_t id, uint32_t cnt, std::string name, time_t lastT) + : featureId(id), count(cnt), tensorName(name), lastTime(lastT) + {} + + bool operator > (const FeatureItemInfo& item) const + { + return lastTime > item.lastTime; + } + + int64_t featureId { -1 }; + uint32_t count { 0 }; + std::string tensorName { "" }; + time_t lastTime { 0 }; + }; + + using SortedRecords = + std::priority_queue, std::greater>; + using HistoryRecords = absl::flat_hash_map>; + struct AdmitAndEvictData { + HistoryRecords historyRecords; // embName ---> {id, FeatureItemInfo} 映射 + absl::flat_hash_map timestamps; // 用于特征准入&淘汰的时间戳 + }; + + void SetLog(int rank); + + inline void GenerateRandomValue(std::vector& vecData, + std::default_random_engine& generator, + RandomInfo& randomInfo) + { + float min = ((randomInfo.randomMin == 0) ? -0.1f : randomInfo.randomMin); + float max = ((randomInfo.randomMax == 0) ? 0.1f : randomInfo.randomMax); + if (randomInfo.len == 0) { + return; + } + ASSERT(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); + std::uniform_real_distribution distribution(min, max); + std::generate_n(vecData.begin() + randomInfo.start, randomInfo.len, [&]() { return distribution(generator); }); + } + + enum class InitializerType { + CONSTANT, + TRUNCATED_NORMAL, + RANDOM_NORMAL + }; + + struct ConstantInitializerInfo { + ConstantInitializerInfo() = default; + explicit ConstantInitializerInfo(float constantValue); + + float constantValue; + }; + + struct NormalInitializerInfo { + NormalInitializerInfo() = default; + NormalInitializerInfo(float mean, float stddev, int seed); + + float mean; + float stddev; + int seed; + }; + + struct InitializeInfo { + InitializeInfo() = default; + + InitializeInfo(std::string& name, int start, int len, ConstantInitializerInfo constantInitializerInfo); + InitializeInfo(std::string& name, int start, int len, NormalInitializerInfo normalInitializerInfo); + + std::string name; + int start; + int len; + InitializerType initializerType; + + ConstantInitializerInfo constantInitializerInfo; + NormalInitializerInfo normalInitializerInfo; + + ConstantInitializer constantInitializer; + TruncatedNormalInitializer truncatedNormalInitializer; + RandomNormalInitializer randomNormalInitializer; + }; + + template + inline Tensor Vec2TensorI32(const std::vector& data) + { + Tensor tmpTensor(tensorflow::DT_INT32, { static_cast(data.size()) }); + auto tmpData = tmpTensor.flat(); + for (int j = 0; j < static_cast(data.size()); j++) { + tmpData(j) = static_cast(data[j]); + } + return tmpTensor; + } + + template + inline Tensor Vec2TensorI64(const std::vector& data) + { + Tensor tmpTensor(tensorflow::DT_INT64, { static_cast(data.size()) }); + auto tmpData = tmpTensor.flat(); + for (int j = 0; j < static_cast(data.size()); j++) { + tmpData(j) = static_cast(data[j]); + } + return tmpTensor; + } + + struct EmbInfo { + EmbInfo() = default; + + EmbInfo(const std::string& name, + int sendCount, + int embeddingSize, + std::vector vocabsize, + std::vector initializeInfos); + + std::string name; + int sendCount; + int embeddingSize; + size_t devVocabSize; + size_t 
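// Illustrative sketch (not part of the patch): using the GenerateRandomValue helper
// above; it fills positions [start, start + len) with uniform samples drawn from
// [randomMin, randomMax), falling back to +/-0.1 when a bound is 0. Assumes this
// header (common.h) is included.
#include <random>
#include <vector>

void FillExample()
{
    std::vector<float> emb(16, 0.0f);
    std::default_random_engine gen(MxRec::DEFAULT_RANDOM_SEED);
    MxRec::RandomInfo info(4, 8, 0.0f, -0.05f, 0.05f); // start=4, len=8
    MxRec::GenerateRandomValue(emb, gen, info);
    // emb[4..11] now hold uniform samples in [-0.05, 0.05); the rest stay 0
}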
hostVocabSize; + std::vector initializeInfos; + }; + + struct HostEmbTable { + EmbInfo hostEmbInfo; + std::vector> embData; + }; + + struct EmbHashMapInfo { + absl::flat_hash_map hostHashMap; + std::vector devOffset2Batch; // has -1 + std::vector devOffset2Key; + size_t currentUpdatePos; + size_t currentUpdatePosStart; + size_t hostVocabSize; + size_t devVocabSize; + size_t freeSize; + std::vector lookUpVec; + std::vector missingKeysHostPos; + std::vector swapPos; + size_t maxOffset { 0 }; + std::vector evictPos; + std::vector evictDevPos; + + void SetStartCount(); + + bool HasFree(size_t i); + }; + + using emb_mem_t = absl::flat_hash_map; + using emb_hash_mem_t = absl::flat_hash_map; + using offset_mem_t = std::map; + using key_offset_mem_t = std::map>; + using tensor_2_thresh_mem_t = absl::flat_hash_map; + using trans_serialize_t = uint8_t; + + enum class CkptFeatureType { + HOST_EMB = 0, + EMB_HASHMAP = 1, + MAX_OFFSET = 2, + KEY_OFFSET_MAP = 3, + FEAT_ADMIT_N_EVICT = 4 + }; + + struct CkptData { + emb_mem_t hostEmbs; + emb_hash_mem_t embHashMaps; + offset_mem_t maxOffset; + key_offset_mem_t keyOffsetMap; + tensor_2_thresh_mem_t tens2Thresh; + AdmitAndEvictData histRec; + }; + + struct CkptTransData { + std::vector int64Arr; + std::vector floatArr; + std::vector int32Arr; + std::vector transDataset; // may all use this to transfer data + std::vector attribute; // may need to use other form for attributes + size_t datasetSize; + size_t attributeSize; + }; + + enum class CkptDataType { + EMB_INFO = 0, + EMB_DATA = 1, + EMB_HASHMAP = 2, + DEV_OFFSET = 3, + EMB_CURR_STAT = 4, + NDDR_OFFSET = 5, + NDDR_FEATMAP = 6, + TENSOR_2_THRESH = 7, + HIST_REC = 8, + ATTRIBUTE = 9 + }; +} +#define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " +#ifdef GTEST + #define GTEST_PRIVATE public +#else + #define GTEST_PRIVATE private +#endif +#endif diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h new file mode 100644 index 00000000..3c2e2be5 --- /dev/null +++ b/src/core/utils/safe_queue.h @@ -0,0 +1,152 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
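// Illustrative sketch (not part of the patch): what the GTEST_PRIVATE macro defined
// above does in practice. Under a GTEST build the internals compile as public so unit
// tests can reach them; in production builds they stay private. Assumes common.h is
// included.
class GtestPrivateExample {
public:
    int Api() const { return value + 1; }
GTEST_PRIVATE:
    int value { 41 }; // visible to tests only when GTEST is defined
};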
+ * Description: safe queue class + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#ifndef SAFE_QUEUE_H +#define SAFE_QUEUE_H + +#include +#include +#include +#include +#include +#include +#include "common.h" + +template +class SafeQueue { + static constexpr uint64_t DEFAULT_CAP = 10; + +public: + SafeQueue() = default; + + ~SafeQueue() = default; + + SafeQueue(SafeQueue const& other) + { + std::lock_guard lk(other.mut); + dataQueue = other.dataQueue; + } + + SafeQueue& operator=(SafeQueue const& other) + { + if (this == &other) { + return *this; + } + std::lock_guard lk(other.mut); + dataQueue = other.dataQueue; + return *this; + } + + std::unique_ptr GetOne() + { + std::lock_guard lk(mut); + if (emptyQueue.empty()) { + return std::make_unique(); + } else { + auto t = move(emptyQueue.back()); + emptyQueue.pop_back(); + return move(t); + } + } + + std::unique_ptr WaitAndGetOne() + { + { + std::lock_guard lk(mut); + if (creatNum < capacity) { + creatNum++; + return std::make_unique(); + } + } + std::unique_lock locker(mut); + dirtyCond.wait(locker, [this] { return !emptyQueue.empty(); }); + auto t = move(emptyQueue.back()); + emptyQueue.pop_back(); + return move(t); + } + + void PutDirty(std::unique_ptr&& t) + { + std::lock_guard lk(mut); + emptyQueue.push_back(move(t)); + dirtyCond.notify_one(); + } + + void Pushv(std::unique_ptr&& t) // 入队操作 + { + std::lock_guard lk(mut); + dataQueue.push_back(move(t)); + dataCond.notify_one(); + } + + std::unique_ptr WaitAndPop() + { + std::unique_lock lk(mut); + dataCond.wait(lk, [this] { return !dataQueue.empty(); }); + std::unique_ptr res = std::move(dataQueue.front()); + dataQueue.pop_front(); + return move(res); + } + + std::unique_ptr TryPop() + { + std::lock_guard lk(mut); + if (dataQueue.empty()) { + return nullptr; + } + std::unique_ptr res = std::move(dataQueue.front()); + dataQueue.pop_front(); + return move(res); + } + + bool Empty() const + { + std::lock_guard lk(mut); + return dataQueue.empty(); + } + + size_t Size() const + { + std::lock_guard lk(mut); + return dataQueue.size(); + } + +private: + mutable std::mutex mut; + uint64_t capacity = DEFAULT_CAP; + std::atomic creatNum {}; + std::list> dataQueue; + std::list> emptyQueue; + std::condition_variable dataCond; + std::condition_variable dirtyCond; +}; + +template +class SingletonQueue { +public: + static SafeQueue* getInstances(int i) + { + static SafeQueue instance[MxRec::MAX_QUEUE_NUM]; + if (i >= MxRec::MAX_QUEUE_NUM || i < 0) { + return nullptr; + } + return &instance[i]; + }; + + SingletonQueue() = delete; + + ~SingletonQueue() = delete; + + SingletonQueue(T&&) = delete; + + SingletonQueue(const T&) = delete; + + void operator=(const T&) = delete; +}; + +#endif \ No newline at end of file diff --git a/src/core/utils/singleton.h b/src/core/utils/singleton.h new file mode 100644 index 00000000..2ec50940 --- /dev/null +++ b/src/core/utils/singleton.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: singleton module. 
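// Illustrative sketch (not part of the patch): the buffer-recycling protocol the
// SafeQueue/SingletonQueue pair above supports, as the batch pipeline uses it. The
// producer obtains a (possibly recycled) buffer, fills it, and publishes it; the
// consumer pops it and hands the buffer back with PutDirty instead of reallocating.
// Assumes safe_queue.h is included.
#include <memory>

struct Payload { int batchId { 0 }; };

void PipelineExample()
{
    auto* q = SingletonQueue<Payload>::getInstances(0);
    auto buf = q->WaitAndGetOne(); // fresh buffer, or one recycled via PutDirty
    buf->batchId = 7;
    q->Pushv(std::move(buf));      // producer publishes

    auto batch = q->WaitAndPop();  // consumer takes
    q->PutDirty(std::move(batch)); // buffer goes back to the empty pool
}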
+ * Author: MindX SDK
+ * Create: 2022
+ * History: NA
+ */
+
+#ifndef RC_UTILS_SINGLETON_H
+#define RC_UTILS_SINGLETON_H
+
+
+#include <exception>
+#include <iostream>
+
+/**
+ * T must be destructible
+ * @tparam T
+ */
+template <typename T>
+class Singleton {
+public:
+    Singleton() = delete;
+
+    Singleton(const Singleton& singleton) = delete;
+
+    Singleton& operator=(const Singleton& singleton) = delete;
+
+    static T* GetInstance()
+    {
+        try {
+            static T instance;
+            return &instance;
+        } catch (std::exception& e) {
+            std::cerr << " create singleton error" << std::endl;
+            return nullptr;
+        }
+    }
+
+    template <typename... P>
+    static T* GetInstance(P&& ... args)
+    {
+        try {
+            static T instance(std::forward<P>
(args)...); + return &instance; + } catch (std::exception& e) { + std::cerr << " create singleton error" << std::endl; + return nullptr; + } + } +}; + +#endif diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h new file mode 100644 index 00000000..95c0a35c --- /dev/null +++ b/src/core/utils/spinlock.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: spinlock module + * Author: MindX SDK + * Create: 2022 + * History: NA + */ +#ifndef SRC_UTILS_SPINLOCK_H +#define SRC_UTILS_SPINLOCK_H + +#include +#include +#include // NOLINT + +#define DISALLOW_COPY_MOVE_AND_ASSIGN_(type) \ + type(type const &) = delete; \ + type(type &&) noexcept = delete; \ + type &operator=(type const &) = delete + +static __inline void cpu_pause() { +#ifdef __GNUC__ + #ifdef __aarch64__ + __asm volatile("yield" ::: "memory"); +#elif defined(__i386__) || defined(__x86_64__) + __asm__ __volatile__("rep;nop;nop" ::: "memory"); +#else +#error "unknown architecture" +#endif +#else +#error "unknown architecture" +#endif +} + +static constexpr uint16_t g_kMaxSpinCountBeforeThreadYield = 64; + +#ifdef LOCK_NOTHING + +class SpinLock final { + public: + void lock() noexcept {} + bool try_lock() noexcept { return true; } + void unlock() noexcept {} +}; + +#elif defined(USE_MUTEX) + +class SpinLock final { +public: + void lock() noexcept { mt_.lock(); } + bool try_lock() noexcept { return mt_.try_lock(); } + void unlock() noexcept { mt_.unlock(); } + +private: + std::mutex mt_; +}; + +#else + +class SpinLock final { +public: + SpinLock() = default; + + DISALLOW_COPY_MOVE_AND_ASSIGN_(SpinLock); + + inline void lock() noexcept { + while (true) { + if (!lock_.exchange(true, std::memory_order_acquire)) { + break; + } + + uint16_t counter = 0; + while (lock_.load(std::memory_order_relaxed)) { + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } + } + } + } + + inline bool try_lock() noexcept { + if (lock_.load(std::memory_order_relaxed)) { + return false; + } + return !lock_.exchange(true, std::memory_order_acquire); + } + + inline void unlock() noexcept { lock_.store(false, std::memory_order_release); } + +private: + std::atomic lock_{false}; +}; + +class RWSpinLock final { + union LockData { + uint64_t raw; + struct { + uint32_t readers; + uint32_t writer; + } lock; + }; + +public: + RWSpinLock() = default; + + DISALLOW_COPY_MOVE_AND_ASSIGN_(RWSpinLock); + + inline void r_lock() noexcept { + LockData oldData, newData; + while (true) { + uint16_t counter = 0; + for (;;) { + oldData.raw = lock_.load(std::memory_order_relaxed); + if (oldData.lock.writer > 0) { + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } + } else { + break; + } + } + + newData.lock.readers = oldData.lock.readers + 1; + newData.lock.writer = 0; + if (lock_.compare_exchange_weak(oldData.raw, newData.raw, + std::memory_order_acquire, + std::memory_order_relaxed)) { + break; + } + } + } + + inline void w_lock() noexcept { + LockData oldData, newData; + while (true) { + uint16_t counter = 0; + for (;;) { + oldData.raw = lock_.load(std::memory_order_relaxed); + if (oldData.raw != 0) { + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } + } else { + break; + } + } + + newData.lock.readers = 0; + newData.lock.writer = 1; + if 
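// Illustrative sketch (not part of the patch): the SpinLock above exposes
// lock()/try_lock()/unlock(), so it meets the BasicLockable requirements and can be
// driven through the standard RAII guards. Assumes spinlock.h is included.
#include <mutex>

SpinLock g_counterLock;

void Increment(int& counter)
{
    std::lock_guard<SpinLock> guard(g_counterLock); // spins (with yield backoff) until held
    ++counter;
}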
+#endif
diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h
new file mode 100644
index 00000000..14ed1202
--- /dev/null
+++ b/src/core/utils/task_queue.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: task queue module
+ * Author: MindX SDK
+ * Create: 2022
+ * History: NA
+ */
+
+#ifndef TASK_QUEUE_H
+#define TASK_QUEUE_H
+
+#include <condition_variable>
+#include <list>
+#include <mutex>
+#include <utility>
+
+namespace MxRec {
+namespace Common {
+template <typename T> class TaskQueue {
+public:
+    TaskQueue() = default;
+
+    ~TaskQueue() = default;
+
+    TaskQueue(TaskQueue const &other)
+    {
+        std::lock_guard<std::mutex> lk(other.mut);
+        dataQueue = other.dataQueue;
+    }
+
+    TaskQueue &operator = (TaskQueue const &other)
+    {
+        if (this == &other) {
+            return *this;
+        }
+        std::lock_guard<std::mutex> lk(other.mut);
+        dataQueue = other.dataQueue;
+        return *this;
+    }
+
+    void Pushv(T &t)
+    {
+        std::lock_guard<std::mutex> lk(mut);
+        dataQueue.push_back(std::move(t));
+        dataCond.notify_one();
+    }
+
+    void Pushv(T &&t)
+    {
+        std::lock_guard<std::mutex> lk(mut);
+        dataQueue.emplace_back(std::move(t));
+        dataCond.notify_one();
+    }
+
+    // Blocks until an element is available or DestroyQueue() is called;
+    // returns a default-constructed T once the queue has been destroyed.
+    T WaitAndPop()
+    {
+        std::unique_lock<std::mutex> lk(mut);
+        dataCond.wait(lk, [this] {
+            if (!finished) {
+                return !dataQueue.empty();
+            } else {
+                return true;
+            }
+        });
+        T res;
+        if (finished) {
+            return res;
+        }
+        res = dataQueue.front();
+        dataQueue.pop_front();
+        return res;
+    }
+
+    void DestroyQueue()
+    {
+        std::lock_guard<std::mutex> lk(mut);  // publish the flag under the lock
+        finished = true;
+        dataCond.notify_all();  // wake every waiter on shutdown
+    }
+
+    bool Empty() const
+    {
+        std::lock_guard<std::mutex> lk(mut);
+        return dataQueue.empty();
+    }
+
+    size_t Size() const
+    {
+        std::lock_guard<std::mutex> lk(mut);
+        return dataQueue.size();
+    }
+
+private:
+    mutable std::mutex mut;
+    std::list<T> dataQueue;
+    std::condition_variable dataCond;
+    bool finished = false;
+};
+}
+}
+
+
+#endif
diff --git a/src/core/utils/time_cost.h b/src/core/utils/time_cost.h
new file mode 100644
index 00000000..9852d30f
--- /dev/null
+++ b/src/core/utils/time_cost.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
+ * Description: time cost profile module.
+ * Author: MindX SDK
+ * Create: 2022
+ * History: NA
+ */
+
+#ifndef TIMECOST_H
+#define TIMECOST_H
+
+#include <chrono>
+
+class TimeCost {
+public:
+    TimeCost()
+    {
+        start_ = std::chrono::high_resolution_clock::now();
+    }
+
+    double ElapsedSec()
+    {
+        std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now();
+        std::chrono::duration<double> d = std::chrono::duration_cast<std::chrono::duration<double>>(end - start_);
+        return d.count();
+    }
+
+    size_t ElapsedMS()
+    {
+        auto end = std::chrono::high_resolution_clock::now();
+        std::chrono::milliseconds d = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
+        return d.count();
+    }
+
+private:
+    std::chrono::high_resolution_clock::time_point start_;
+};
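+
+// Usage sketch (illustrative): TimeCost is a start-on-construction stopwatch:
+//     TimeCost tc;
+//     DoWork();
+//     spdlog::info("work took {} ms", tc.ElapsedMS());
+// ElapsedMS() truncates to whole milliseconds, so prefer ElapsedSec() when
+// sub-millisecond resolution matters.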
+
+#endif
\ No newline at end of file
diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h
new file mode 100644
index 00000000..47d26dbe
--- /dev/null
+++ b/src/core/utils/unique.h
@@ -0,0 +1,774 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: unique keys module
+ * Author: MindX SDK
+ * Create: 2022
+ * History: NA
+ */
+#ifndef SRC_UTILS_UNIQUE_H
+#define SRC_UTILS_UNIQUE_H
+
+#include <algorithm>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <exception>
+#include <functional>
+#include <future>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "securec.h"
+
+#include <spdlog/spdlog.h>
+#include <spdlog/stopwatch.h>
+
+#include "absl/container/flat_hash_map.h"
+
+#include "common.h"
+#include "spinlock.h"
+#include "time_cost.h"
+
+using namespace MxRec;
+using namespace std;
+
+struct UniqueData {
+    void *inputData;
+    size_t dataSize;
+    int32_t *restore;
+    int64_t *uniqueVector;
+    int32_t *splitSize;
+    int64_t *keySend;
+    int32_t *idCount;
+    int32_t *idCountFill;
+};
+
+struct UniqueFlag {
+    bool isInt64;
+    bool useStatic;
+    bool useHot;
+};
+
+struct UniqueForHot {
+    int hotOffset;
+    int *hotPos;
+    std::map<int64_t, int> &hotMap;
+    absl::flat_hash_map<int64_t, int32_t> &keyCountMap;
+};
+
+struct UniqueThreadNum {
+    int minThread;
+    int maxThread;
+};
+
+class SendCntTooSmallError : public std::exception {
+};
+
+class GroupMethod {
+public:
+    inline int GroupCount()
+    {
+        return groupCount_;
+    }
+    // assumes groupCount_ is a power of two
+    inline int GroupId(uint64_t val)
+    {
+        return static_cast<int>(val & (groupCount_ - 1));
+    }
+    void SetGroupCount(int count)
+    {
+        groupCount_ = count;
+    }
+
+private:
+    int groupCount_;
+};
+
+class SimpleThreadPool {
+public:
+    template <typename TaskT>
+    void SyncRun(const std::vector<TaskT> &tasks)
+    {
+        std::vector<std::future<decltype(std::declval<TaskT>()())>> futs;
+        for (auto &task : tasks) {
+            futs.push_back(std::async(task));
+        }
+        for (auto &fut : futs) {
+            fut.wait();
+        }
+    }
+};
+template <int N> class Dedup {
+    static constexpr uint32_t kMinimalWorkloadPerWorker = 1 << 12;
+    static const int kDefaultBucketCount = 1 << 24;
+    static const int kDefaultBucketCountMask = kDefaultBucketCount - 1;
+
+    template <int M> struct Meta {
+        static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width");
+        SpinLock lock;
+        volatile int8_t count;
+        int8_t pad[3];
+        int32_t replace_base;
+        volatile uint64_t data[M];
+        volatile uint64_t idCount[M];
+    } __attribute__((__aligned__(64)));
+
+    struct Statistics {
+        uint64_t totalUniques = 0;
+        uint64_t totalOverflowUniques = 0;
+    };
+
+public:
+    Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1)
+        : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups)
+    {
+        void *area = aligned_alloc(64, sizeof(Meta<N>) * bucketCount_);
+        table_ = reinterpret_cast<Meta<N> *>(area);
+        Clear(bucketCount_);
+    }
+
+    ~Dedup()
+    {
+        free(table_);
+    }
+
+    static size_t BucketSize()
+    {
+        return sizeof(Meta<N>);
+    }
+
+    void Insert(uint64_t val)
+    {
+        int32_t h = static_cast<int32_t>(hash(val) & bucketCountMask_);
+        Meta<N> *bucket = &table_[h];
+
+        int8_t count = bucket->count;
+
+        int totalCount = 0;
+
+        // optimistic lock-free scan of the bucket
+        for (int i = 0; i < count; ++i) {
+            if (bucket->data[totalCount] == val) {
+                bucket->idCount[totalCount]++;
+                // found one
+                return;
+            }
+            totalCount++;
+        }
+        // try again, this time with lock acquired
+        if (count < N) {
+            std::lock_guard<SpinLock> lg(bucket->lock);
+            for (int i = totalCount; i < bucket->count; ++i) {
+                if (bucket->data[totalCount] == val) {
+                    bucket->idCount[totalCount]++;
+                    // found one
+                    return;
+                }
+                totalCount++;
+            }
+            if (totalCount < N) {
+                bucket->data[totalCount] = val;
+                bucket->count++;
+                bucket->idCount[totalCount]++;
+                return;
+            }
+        }
+        // shift to the overflow reservoir
+        insertOverflow(val);
+    }
+
+    int32_t GetReplaceOffset(uint64_t val)
+    {
+        int32_t h = static_cast<int32_t>(hash(val) & bucketCountMask_);
+        Meta<N> *bucket = &table_[h];
+
+        int8_t count = bucket->count;
+        int totalCount = 0;
+        for (int i = 0; i < count; ++i) {
+            if (bucket->data[totalCount] == val) {
+
// found one + return bucket->replace_base + totalCount; + } + totalCount++; + } + // try again, this time with lock acquired + if (count < N) { + std::lock_guard lg(bucket->lock); + for (int i = totalCount; i < bucket->count; ++i) { + if (bucket->data[totalCount] == val) { + return bucket->replace_base + totalCount; + } + totalCount++; + } + if (totalCount < N) { + return -1; + } + } + return getReplaceOffsetFromOverflow(val); + } + + int32_t GetReplaceOffsetUnsafe(uint64_t val) + { + int32_t h = static_cast(hash(val) & bucketCountMask_); + Meta *bucket = &table_[h]; + + int totalCount = 0; + for (int i = 0; i < bucket->count; ++i) { + if (bucket->data[totalCount] == val) { + // found one + return bucket->replace_base + totalCount; + } + totalCount++; + } + if (totalCount < N) { + return -1; + } + return getReplaceOffsetFromOverflowUnsafe(val); + } + + bool Contains(uint64_t val) + { + int32_t h = static_cast(hash(val) & bucketCountMask_); + Meta *bucket = &table_[h]; + { + std::lock_guard lg(bucket->lock); + int totalCount = 0; + for (int i = 0; i < bucket->count; ++i) { + if (bucket->data[totalCount] == val) { + return true; + } + totalCount++; + } + if (totalCount < N) { + // bucket isn't filled, no hit for sure + return false; + } + } + return checkOverflow(val); + } + + void Clear(uint64_t newBucketCountPowerOf2 = 0) + { + std::lock_guard lg(overflowMutex_); + if (newBucketCountPowerOf2 > 0 && newBucketCountPowerOf2 != (uint64_t)bucketCount_) { + free(table_); + bucketCount_ = newBucketCountPowerOf2; + bucketCountMask_ = bucketCount_ - 1; + table_ = reinterpret_cast *>(aligned_alloc(64, sizeof(Meta) * bucketCount_)); + } + bzero(table_, sizeof(Meta) * bucketCount_); + overflow_.clear(); + idCountOverflow_.clear(); + } + + void NewParameter() + { + int32_t newBucketCountPowerOf2 = bucketCount_; + + if (stats_.totalUniques > 0 && stats_.totalOverflowUniques > kMinimalWorkloadPerWorker) { + // Time to check the proper size of sharded tables for performance + // sake. + uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; + int largeCount = 0; + while (shardedTableSize > stats_.totalUniques * 4 && largeCount_ != 1) { + // too large + newBucketCountPowerOf2 >>= 1; + shardedTableSize >>= 1; + largeCount++; + } + + int count = ((largeCount == 1) && (largeCount != largeCount_)) ? 2 : 1; + for (int i = 0; i < count; i++) { + if (stats_.totalOverflowUniques > kMinimalWorkloadPerWorker) { + newBucketCountPowerOf2 <<= 1; + shardedTableSize <<= 1; + } + } + + while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> 2)) { + newBucketCountPowerOf2 <<= 1; + shardedTableSize <<= 1; + } + + if (largeCount_ != 1) { + largeCount_ = largeCount; + } + } + + Clear(newBucketCountPowerOf2); + bucketCount_ = newBucketCountPowerOf2; + stats_.totalUniques = 0; + stats_.totalOverflowUniques = 0; + } + + // Warning: functions below are not thread safe! 
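+    // (Concretely: finish every concurrent Insert() first, then have a single
+    // thread run Unique()/UniqueRaw(); Replacement() depends on the
+    // replace_base values those traversals write into each bucket.)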
+ // + // Return the unique values + // Also update the hash-order base of each bucket + std::vector Unique() + { + int32_t replace_offset = 0; + std::vector output; + + for (int i = 0; i < bucketCount_; ++i) { + Meta *bucket = &table_[i]; + if (bucket->count == 0) { // 如果桶为0,则跳过 + continue; + } + bucket->replace_base = replace_offset; // 取桶的偏移量 + for (int j = 0; j < bucket->count; ++j) { + auto data = bucket->data[j]; + output.push_back(data); // 挨个桶取数据,然后填到output中去 + } + replace_offset += bucket->count; + } + auto it = overflow_.begin(); // 取overflow里面的,也添加到output中去 + while (it != overflow_.end()) { + output.push_back(it->first); + it->second = replace_offset++; // 记录偏移量++ + ++it; + } + return output; + } + + // Used by ShardedDedup Only! + uint32_t UniqueRaw(int64_t *output, uint32_t priorTotal, int32_t *idCount) + { + uint32_t total = priorTotal; + int32_t replace_offset = priorTotal; + + for (int i = 0; i < bucketCount_; ++i) { + Meta *bucket = &table_[i]; + if (bucket->count == 0) { + continue; + } + bucket->replace_base = replace_offset; + for (int j = 0; j < bucket->count; ++j) { + idCount[total] = bucket->idCount[j]; + output[total++] = bucket->data[j]; + } + replace_offset += bucket->count; + } + auto it = overflow_.begin(); + int32_t totalOverflow = 0; + while (it != overflow_.end()) { + idCount[total] = idCountOverflow_[it->first]; + output[total++] = it->first; + it->second = replace_offset++; + ++it; + ++totalOverflow; + } + + // set total overflow count + stats_.totalUniques = total - priorTotal; + stats_.totalOverflowUniques = totalOverflow; + return total - priorTotal; + } + + void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) { + auto hot = hotMap.find(key); + if (hot != hotMap.end()) { + if (hot->second == -1) { + int pos = hotCount; + hotMap[key] = pos; + hotPosMap[key] = pos; + hotCount++; + } else { + hotPosMap[key] = -1; + } + } + } + + uint32_t UniqueRawForHot(int64_t *output, uint32_t priorTotal, int32_t* idCount, + map &hotMap, map &hotPosMap, int &hotCount, + absl::flat_hash_map &keyCountMap) + { + uint32_t total = priorTotal; + int32_t replace_offset = priorTotal; + + for (int i = 0; i < bucketCount_; ++i) { + Meta *bucket = &table_[i]; + if (bucket->count == 0) { + continue; + } + bucket->replace_base = replace_offset; + for (int j = 0; j < bucket->count; ++j) { + idCount[total] = bucket->idCount[j]; + output[total++] = bucket->data[j]; + handleHotKey(bucket->data[j], hotMap, hotPosMap, hotCount); + keyCountMap[bucket->data[j]] = bucket->idCount[j]; + } + replace_offset += bucket->count; + } + auto it = overflow_.begin(); + int32_t totalOverflow = 0; + while (it != overflow_.end()) { + idCount[total] = idCountOverflow_[it->first]; + keyCountMap[it->first] = idCountOverflow_[it->first]; + output[total++] = it->first; + handleHotKey(it->first, hotMap, hotPosMap, hotCount); + it->second = replace_offset++; + ++it; + ++totalOverflow; + } + + // set total overflow count + stats_.totalUniques = total - priorTotal; + stats_.totalOverflowUniques = totalOverflow; + return total - priorTotal; + } + + std::vector Replacement(const std::vector &input, std::vector *unique = nullptr, + int32_t base = 0) + { + std::vector output; + if (unique) { + *unique = std::move(Unique()); + } + for (auto &val : input) { + output.push_back(GetReplaceOffsetUnsafe(val) + base); + } + return output; + } + +private: + int bucketCount_; + int bucketCountMask_; + int upperRangeIndex_; + int groupCount_; + int largeCount_ { 0 }; + Meta *table_; + std::unordered_map 
overflow_; + std::unordered_map idCountOverflow_; + SpinLock overflowMutex_; + Statistics stats_; + + static inline uint64_t hash(uint64_t val) + { + return val ^ (val >> 16) ^ (val >> 32) ^ (val >> 48); + } + + void insertOverflow(uint64_t val) + { + std::lock_guard lg(overflowMutex_); + auto it = overflow_.find(val); + if (it == overflow_.end()) { + overflow_[val] = 0; + } + idCountOverflow_[val]++; + } + + bool checkOverflow(uint64_t val) + { + std::lock_guard lg(overflowMutex_); + return overflow_.find(val) != overflow_.end(); + } + + int32_t getReplaceOffsetFromOverflow(uint64_t val) + { + std::lock_guard lg(overflowMutex_); + auto it = overflow_.find(val); + return (it != overflow_.end()) ? it->second : -1; + } + + int32_t getReplaceOffsetFromOverflowUnsafe(uint64_t val) + { + auto it = overflow_.find(val); + return (it != overflow_.end()) ? it->second : -1; + } +}; // Dedup + +#define CACHE_LINE_ALIGN(size) (((size) + 63ul) & ~63ul) + +class OneSimpleGroupMethod { +public: + inline int GroupCount() + { + return 1; + } + inline int GroupId(uint64_t val) + { + return 0; + } +}; + +template class ShardedDedup { + static constexpr uint32_t kMinimalWorkloadPerWorker = 1 << 13; + static constexpr int kDefaultDuplicateRatio = 4; + static constexpr int kMinimalWorkerCount = 2; + static constexpr int kMaximalWorkerCount = 32; + +public: + using DedupT = Dedup; + + ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, + int estimatedDuplicateRatio = kDefaultDuplicateRatio) + : groupMethod_(groupMethod), bucketCountPower2_(256), send_cnt_(send_cnt) + { + const int numOfGroupsInShard = groupMethod_.GroupCount(); + + desiredSize += (desiredSize >> 1); + while (bucketCountPower2_ * BucketWidth * numOfGroupsInShard * estimatedDuplicateRatio < desiredSize) { + bucketCountPower2_ <<= 1; + } + for (int32_t i = 0; i < numOfGroupsInShard; ++i) { + dedupShards_.emplace_back(new DedupT(bucketCountPower2_, numOfGroupsInShard)); + } + } + + ~ShardedDedup() {} + + const int NumOfGroupsInEachShard() + { + return groupMethod_.GroupCount(); + } + + /* * + * @brief given the input vector, compute unique values and partition + * them into regions delimited by the partition boundaries passed + * as ctor input (see above) + * + * + * @param pool thread pool which is used by unique task + * @param input the data input + * @param size the size of the data input + * @param uniqueVector unique values + * @param uniqueSize unique of sizes + * @param output the output vector of index values + * @param uniqueIds unique ids final + * @param idCount key count + * @param idCountFill key count and filled zero by send count + * @param isStatic output and idCount Fill isFilled + * @param isInt64 input data is int64 or int32 + * @param useHot hot embedding + * @param offset add hot map size + * @param hotMap hot key map + * @param keyCountMap record key count + * @param minThreadCount min thread number + * @param maxThreadCount max thread number + */ + template + int Compute(ThreadPool *pool, UniqueData &uniqueData, UniqueFlag &uniqueFlag, + UniqueForHot &uniqueForHot, UniqueThreadNum &uniqueThreadNum) + { + // Now kick off the computation + + void *input = uniqueData.inputData; + const size_t size = uniqueData.dataSize; + int64_t *uniqueVector = uniqueData.uniqueVector; + int32_t *uniqueSize = uniqueData.splitSize; + int32_t *output = uniqueData.restore; + int64_t *uniqueIds = uniqueData.keySend; + int32_t *idCount = uniqueData.idCount; + int32_t *idCountFill = uniqueData.idCountFill; + + map &hotMap = 
uniqueForHot.hotMap; + absl::flat_hash_map &keyCountMap = uniqueForHot.keyCountMap; + int offset = uniqueForHot.hotOffset; + int *hotPos = uniqueForHot.hotPos; + + bool useStatic = uniqueFlag.useStatic; + bool useHot = uniqueFlag.useHot; + bool isInt64 = uniqueFlag.isInt64; + + uint32_t minThreadCount = uniqueThreadNum.minThread; + uint32_t maxThreadCount = uniqueThreadNum.maxThread; + + std::vector uniqueSizeVector; + uniqueSizeVector.resize(groupMethod_.GroupCount()); + + size_t inputSize = size; + + uint32_t threadNum = (inputSize + kMinimalWorkloadPerWorker - 1) / kMinimalWorkloadPerWorker; + threadNum = std::min(maxThreadCount, std::max(threadNum, minThreadCount)); + + size_t partSize = (inputSize + threadNum - 1) / threadNum; + + std::vector> tasks; + + for (uint32_t i = 0; i < threadNum; ++i) { + const int numOfGroupsInShard = groupMethod_.GroupCount(); + tasks.push_back([this, i, input, inputSize, partSize, numOfGroupsInShard, isInt64]() -> TaskReturnType { + for (uint64_t j = i * partSize; j < std::min(inputSize, (i + 1) * partSize); ++j) { + auto val = isInt64 ? ((int64_t *)input)[j] : ((int32_t *)input)[j]; + auto group = groupMethod_.GroupId(val); + dedupShards_[group]->Insert(val); + } + return TaskReturnType {}; + }); + } + spdlog::debug("unique finish insert"); + + if (!tasks.empty()) { + pool->SyncRun(tasks); + } + + std::vector baseVector; + // Collect Unique and base vectors + uint64_t base = 0; + uint64_t total = 0; + + int hotCount = 0; + map hotPosMap; + + for (int j = 0; j < groupMethod_.GroupCount(); ++j) { + uint64_t inGroupTotal = 0; + if (useHot) { + inGroupTotal = dedupShards_[j]->UniqueRawForHot(uniqueVector, total, idCount, + hotMap, hotPosMap, hotCount, + keyCountMap); + } else { + inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueVector, total, idCount); + } + uniqueSizeVector[j] = inGroupTotal; + total += inGroupTotal; + } + + spdlog::debug("unique finish uniqueRaw"); + + baseVector.push_back(base); + base += total; + + partSize = CACHE_LINE_ALIGN(partSize); + + int32_t *beginPtr = output; + int32_t *finishPtr = beginPtr + inputSize; + + int32_t *partBeginPtr = beginPtr; + int32_t *partEndPtr = + reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); + + if(uniqueFlag.useStatic){ + for (int i = 0; i < groupMethod_.GroupCount(); i++) { + if (send_cnt_ < uniqueSizeVector[i]){ + spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, uniqueSizeVector[i]); + throw SendCntTooSmallError(); + } + } + } + + std::vector totalUniqueSize; + totalUniqueSize.resize(groupMethod_.GroupCount()); + + size_t totalNumber = 0; + for (int i = 0; i < groupMethod_.GroupCount(); i++) { + totalUniqueSize[i] = totalNumber; + totalNumber += uniqueSizeVector[i]; + } + spdlog::debug("uniqueSize: {}", totalNumber); + + tasks.clear(); + while (partBeginPtr < finishPtr) { + if (partEndPtr > finishPtr) { + partEndPtr = finishPtr; + } + if (partBeginPtr < partEndPtr) { + // Due to cacheline alignment computation, the actual number of + // threads created here may not match threadNum exactly but + // should be +/-1 off. + const int numOfGroupsInShard = groupMethod_.GroupCount(); + tasks.push_back([this, input, &baseVector, beginPtr, partBeginPtr, partEndPtr, numOfGroupsInShard, + totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, hotPosMap]() -> TaskReturnType { + for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { + auto val = isInt64 ? 
((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; + auto group = groupMethod_.GroupId(val); + uint32_t fillOffset = GetFillOffset(useStatic, baseVector, totalUniqueSize, val, group); + ComputeRestore(useHot, offset, hotMap, hotPos, hotPosMap, ptr, val, fillOffset); + } + return TaskReturnType {}; + }); + } + partBeginPtr = partEndPtr; + partEndPtr += partSize; + } + + if (!tasks.empty()) { + pool->SyncRun(tasks); + } + + + TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, uniqueSizeVector); + + return 0; + } + + void ComputeRestore(bool useHot, int offset,const map &hotMap, int *hotPos,const map &hotPosMap, + int32_t *ptr, int64_t val, uint32_t fillOffset) const { + auto hot = hotPosMap.find(val); + if (!useHot) { + *ptr = fillOffset; + } else if (hot == hotPosMap.end()) { + *ptr = offset + fillOffset; + } else if (hot->second == -1) { + *ptr = hotMap.find(val)->second; + } else { + hotPos[hot->second] = fillOffset; + *ptr = hot->second; + } + } + + uint32_t GetFillOffset(bool useStatic, const vector &baseVector, const vector &totalUniqueSize, + int64_t val, int32_t group) { + if (!useStatic) { + return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]; + } else { + return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - totalUniqueSize[group]; + } + } + + void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, const int32_t *idCount, + int32_t *idCountFill, bool useStatic, const std::vector &uniqueSizeVector) const { + int start = 0; + int index = 0; + + for (int i = 0; i < groupCount; i++) { + if (i > 0) { + index += uniqueSizeVector[i - 1]; + } + + if (useStatic) { + start = i * send_cnt_; + } else { + start = index; + } + + size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); + auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); + if (rc != 0) { + spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}",mem_size); + return; + } + + mem_size = uniqueSizeVector[i] * sizeof(int32_t); + rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); + if (rc != 0) { + spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... 
mem_size: {}", mem_size); + return; + } + + int fillLen = send_cnt_ - uniqueSizeVector[i]; + if (useStatic) { + for (int j = 0; j < fillLen; j++) { + uniqueIds[start + uniqueSizeVector[i] + j] = -1; + idCountFill[start + uniqueSizeVector[i] + j] = 0; + } + } + + uniqueSize[i] = uniqueSizeVector[i]; + } + } + + void StartNewRound() + { + for (auto &s : dedupShards_) { + s->NewParameter(); + } + } + +private: + GroupMethod groupMethod_; + int32_t bucketCountPower2_; + std::vector> dedupShards_; + int32_t send_cnt_; +}; +#endif \ No newline at end of file diff --git a/src/ops_tf/CMakeLists.txt b/src/ops_tf/CMakeLists.txt new file mode 100644 index 00000000..055cb842 --- /dev/null +++ b/src/ops_tf/CMakeLists.txt @@ -0,0 +1,7 @@ +cmake_minimum_required(VERSION 3.12) +set(CMAKE_CXX_STANDARD 14) + +file(GLOB_RECURSE MXREC_OP_SRC ./*.cpp) +add_library(asc_ops SHARED ${MXREC_OP_SRC}) +target_link_libraries(asc_ops PUBLIC ASC) +install(TARGETS asc_ops LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) \ No newline at end of file diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp new file mode 100644 index 00000000..48aaeee2 --- /dev/null +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -0,0 +1,766 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: dataset ops. + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/example/example.pb.h" + +#include "securec.h" + +#include "key_process/key_process.h" +#include "key_process/feature_admit_and_evict.h" +#include "utils/common.h" +#include "utils/safe_queue.h" +#include "utils/singleton.h" + +#include "utils/time_cost.h" + +using namespace tensorflow; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; +using namespace std; +using namespace chrono; +using namespace MxRec; + +using OpKernelConstructionPtr = OpKernelConstruction*; +using OpKernelContextPtr = OpKernelContext*; +using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*; + +spdlog::stopwatch staticSw {}; +spdlog::stopwatch staticReadRaw {}; +array, MAX_CHANNEL_NUM> batchIdsInfo {}; + +REGISTER_OP("ClearChannel").Attr("channel_id : int"); + +class ClearChannel : public OpKernel { +public: + explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); + + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("ClearChannel channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:{})", + MAX_CHANNEL_NUM))); + return; + } + } + + ~ClearChannel() = default; + + void Compute(OpKernelContextPtr context) override + { + spdlog::info("clear channel {}", channelId); + batchIdsInfo.at(channelId) = 0; + } + +private: + int channelId {}; +}; + +REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), ClearChannel); + + +// ##################### ReturnTimestamp ####################### +REGISTER_OP("ReturnTimestamp") + .Input("input: int64") + .Output("output: int64") + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); +class ReturnTimestamp : public OpKernel { +public: + explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context) + {} + + ~ReturnTimestamp() = default; + + void Compute(OpKernelContextPtr context) override + { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = time(nullptr); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), ReturnTimestamp); + +// ##################### ReadEmbKeyV2Dynamic ####################### +REGISTER_OP("ReadEmbKeyV2Dynamic") + .Input("sample: T") + .Input("splits: int32") + .Output("output: int32") + .Attr("T: {int64, int32}") + .Attr("channel_id: int") + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + +class ReadEmbKeyV2Dynamic : public OpKernel { +public: + explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context) + { + spdlog::cfg::load_env_levels(); + spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v"); + spdlog::debug("ReadEmbKeyV2Dynamic init"); + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference + OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); + OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + + // 特征准入&淘汰功能 相关校验 + + // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 + if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && + !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("threshold config, or timestamp error ..."))); + return; + } + + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("ReadEmbKeyV2Dynamic channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:{})", + MAX_CHANNEL_NUM))); + return; + } + batchIdsInfo.at(channelId) = 0; + + auto keyProcess = Singleton::GetInstance(); + if (!keyProcess->isRunning) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); + return; + } + maxStep = keyProcess->GetMaxStep(channelId); + } + ~ReadEmbKeyV2Dynamic() = default; + + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + spdlog::debug("enter ReadEmbKeyV2Dynamic"); + spdlog::stopwatch sw; + int batchId = batchIdsInfo.at(channelId).fetch_add(1); + if (channelId == 1) { + if (maxStep != -1 && batchId >= maxStep) { + spdlog::warn("skip excess batch after {}/{}", batchId, maxStep); + return; + } + } + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + const auto& splits = context->input(TENSOR_INDEX_1).flat(); + int fieldNum = 0; + for (int i = 0; i < splits.size(); ++i) { + fieldNum += splits(i); + } + size_t dataSize = inputTensor.NumElements(); + + time_t timestamp = -1; + // 如果传递了时间戳,解析和校验 + if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("timestamp[{}] error, skip excess batch after {}/{}", timestamp, batchId, maxStep))); + return; + } + // 保证所有embNames在m_embStatus中有状态记录 + SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); + + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % KEY_PROCESS_THREAD + KEY_PROCESS_THREAD * channelId; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; + EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); + TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}, " + "splits: {}, dataSize: {}, filedNum: {}", + TO_MS(sw), TO_MS(staticSw), + channelId, batchId, splits.size(), dataSize, fieldNum); + staticSw.reset(); + } + + void EnqueueBatchData(std::vector ids, time_t timestamp, + const Tensor& inputTensor, const TTypes::ConstFlat& splits) + { + auto queue = SingletonQueue::getInstances(ids[1]); + size_t offset = 0; + if (isTimestamp) { + offset += 1; // 前面8个字节是unix时间戳 + } + for (int i = 0; i < splits.size(); ++i) { + auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block + batchData->name = embNames.at(i); + size_t len = splits(i); + batchData->channel = channelId; + batchData->batchId = ids[0]; + batchData->batchSize = len; + if (isTimestamp) { + batchData->timestamp = timestamp; + } + spdlog::debug("batch[{}/{}] flatten bs: {}", ids[0], i+1, len); + std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); + if (batch == nullptr) { + spdlog::error("batch can not be null"); + return; + } + queue->Pushv(move(batch)); + } + TIME_PRINT(KEY_PROCESS "EnqueueBatchData, batchId:{}, channelId:{}", ids[0], channelId); + } + + std::unique_ptr TensorCopy(const Tensor& inputTensor, std::unique_ptr batchData, + const size_t& len, size_t& offset) + { + if (len == 0) { + spdlog::error("len can not be zero"); + return nullptr; + } + TimeCost ct; + void* src = nullptr; + size_t memSize; + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + batchData->isInt64 = false; + memSize = len * sizeof(int32_t); + src = 
reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). + data())) + offset); + } else { + batchData->isInt64 = true; + memSize = len * sizeof(int64_t); + src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). + data())) + offset); + } + batchData->tensorAddr = malloc(memSize); + if (batchData->tensorAddr == nullptr) { + spdlog::error("mmemory allocation failded..."); + } + void* dst = reinterpret_cast(batchData->tensorAddr); + auto rc = memcpy_s(dst, memSize, src, memSize); + if (rc != 0) { + spdlog::error("[ReadEmbKeyV2Dynamic]memcpy_s failded... memSize: {}", memSize); + } + TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); + offset += len; + return move(batchData); + } + + bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, + size_t& dataSize) + { + if (dataSize - fieldNumTmp != 1) { // 说明没有传时间戳 + spdlog::error("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); + return false; + } + + // 前面8个字节、即占一个featureId位,是unix时间戳 + auto src = (const time_t*)inputTensor.tensor_data().data(); + std::copy(src, src + 1, ×tamp); + spdlog::info("current batchId[{}] timestamp[{}]", batchId, timestamp); + dataSize -= 1; + + if (timestamp <= 0) { + spdlog::error("timestamp[{}] <= 0 ", timestamp); + return false; + } + + return true; + } + + void SetCurrEmbNamesStatus(const vector& embeddingNames, + absl::flat_hash_map& embStatus) + { + for (size_t i = 0; i < embeddingNames.size(); ++i) { + auto it = embStatus.find(embeddingNames[i]); + // 对配置了的,进行校验 + if (it == embStatus.end()) { + // 没有配置的,则不需要“准入&淘汰”功能 + embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); + } + } + } + + int channelId {}; + vector embNames {}; + int maxStep = 0; + bool isTimestamp { false }; +}; + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), ReadEmbKeyV2Dynamic); + +// ##################### ReadEmbKeyV2 ####################### +REGISTER_OP("ReadEmbKeyV2") + .Input("sample: T") + .Output("output: int32") + .Attr("T: {int64, int32}") + .Attr("channel_id: int") + .Attr("splits: list(int)") + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + +class ReadEmbKeyV2 : public OpKernel { +public: + explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context) + { + spdlog::cfg::load_env_levels(); + spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v"); + spdlog::debug("ReadEmbKeyV2 init"); + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference + OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); + OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // 每个表的field Number + OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + fieldNum = accumulate(splits.begin(), splits.end(), 0); + + // 特征准入&淘汰功能 相关校验 + + // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 + if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && + !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("threshold config, or timestamp error ..."))); + return; + } + + if (splits.size() != embNames.size()) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("splits & 
embNames size error.{} {}", splits.size(), embNames.size()))); + return; + } + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:{})", + MAX_CHANNEL_NUM))); + return; + } + batchIdsInfo.at(channelId) = 0; + + auto keyProcess = Singleton::GetInstance(); + if (!keyProcess->isRunning) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); + return; + } + maxStep = keyProcess->GetMaxStep(channelId); + } + + ~ReadEmbKeyV2() = default; + + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + spdlog::debug("enter ReadEmbKeyV2"); + spdlog::stopwatch sw; + int batchId = batchIdsInfo.at(channelId)++; + Tensor* output = nullptr; + if (channelId == 1) { + if (maxStep != -1 && batchId >= maxStep) { + spdlog::warn("skip excess batch after {}/{}", batchId, maxStep); + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; + return; + } + } + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + size_t dataSize = inputTensor.NumElements(); + + time_t timestamp = -1; + // 如果传递了时间戳,解析和校验 + if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", + fmt::format("timestamp[{}] error, skip excess batch after {}/{}", timestamp, batchId, maxStep))); + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; + return; + } + // 保证所有embNames在m_embStatus中有状态记录 + SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); + + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % KEY_PROCESS_THREAD + KEY_PROCESS_THREAD * channelId; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; + + TimeCost tc; + EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); + TIME_PRINT("EnqueueBatchData TimeCost(ms):{}", tc.ElapsedMS()); + + TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", TO_MS(sw), + TO_MS(staticSw), channelId, batchId); + staticSw.reset(); + } + + int EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) + { + auto queue = SingletonQueue::getInstances(batchQueueId); + size_t offset = 0; + if (isTimestamp) { + offset += 1; // 前面8个字节是unix时间戳 + } + TimeCost ctAll; + for (size_t i = 0; i < splits.size(); ++i) { + TimeCost tp; + auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block + TIME_PRINT("TryPopTimeCost(ms):{}", tp.ElapsedMS()); + + batchData->name = embNames.at(i); + size_t len = splits.at(i); + batchData->channel = channelId; + batchData->batchId = batchId; + batchData->batchSize = len; + TimeCost fz; + if (isTimestamp) { + batchData->timestamp = timestamp; + } + TIME_PRINT("fz TimeCost(ms):{}", fz.ElapsedMS()); + + std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); + if (batch == nullptr) { + spdlog::error("batch can not be null"); + return -1; + } + queue->Pushv(move(batch)); + } + TIME_PRINT("all copy TimeCost(ms):{}", ctAll.ElapsedMS()); + return 0; + } + + std::unique_ptr TensorCopy(const Tensor& inputTensor, 
std::unique_ptr batchData, + const size_t& len, size_t& offset) + { + if (len == 0) { + spdlog::error("len can not be zero"); + return nullptr; + } + TimeCost ct; + void* src = nullptr; + size_t memSize; + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + batchData->isInt64 = false; + memSize = len * sizeof(int32_t); + src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() + .data())) + offset); + } else { + batchData->isInt64 = true; + memSize = len * sizeof(int64_t); + src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() + .data())) + offset); + } + batchData->tensorAddr = malloc(memSize); + if (batchData->tensorAddr == nullptr) { + spdlog::error("mmemory allocation failded..."); + } + void* dst = reinterpret_cast(batchData->tensorAddr); + auto rc = memcpy_s(dst, memSize, src, memSize); + if (rc != 0) { + spdlog::error("[ReadEmbKeyV2Static]memcpy_s failded... memSize: {}", memSize); + } + TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); + offset += len; + return move(batchData); + } + + bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, + size_t& dataSize) + { + if (dataSize - fieldNumTmp != 1) { // 说明没有传时间戳 + spdlog::error("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); + return false; + } + + // 前面8个字节、即占一个featureId位,是unix时间戳 + auto src = (const time_t*)inputTensor.tensor_data().data(); + std::copy(src, src + 1, ×tamp); + spdlog::info("current batchId[{}] timestamp[{}]", batchId, timestamp); + dataSize -= 1; + + if (timestamp <= 0) { + spdlog::error("timestamp[{}] <= 0 ", timestamp); + return false; + } + + return true; + } + void SetCurrEmbNamesStatus(const vector& embeddingNames, + absl::flat_hash_map& embStatus) + { + for (size_t i = 0; i < embeddingNames.size(); ++i) { + auto it = embStatus.find(embeddingNames[i]); + // 对配置了的,进行校验 + if (it == embStatus.end()) { + // 没有配置的,则不需要“准入&淘汰”功能 + embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); + } + } + } + + int channelId {}; + vector splits {}; + int fieldNum {}; + vector embNames {}; + int maxStep = 0; + bool isTimestamp { false }; +}; + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); + +// ##################### ReadEmbKeyDatasetDummy ####################### +REGISTER_OP("ReadEmbKeyDatasetDummy") + .Input("sample: T") + .Output("lookup_vec: int32") + .Output("restore_vec: int32") + .Attr("T: {int64}") + .Attr("max_lookup_len: int") + .SetShapeFn([](InferenceContextPtr c) { + int temp; + TF_RETURN_IF_ERROR(c->GetAttr("max_lookup_len", &temp)); + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Vector(temp)); + c->set_output(TensorIndex::TENSOR_INDEX_1, c->input(TensorIndex::TENSOR_INDEX_0)); + return Status::OK(); + }); + +class ReadEmbKeyDatasetDummy : public OpKernel { +public: + explicit ReadEmbKeyDatasetDummy(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("max_lookup_len", &lookupLen)); + } + + ~ReadEmbKeyDatasetDummy() override = default; + + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + spdlog::stopwatch sw; + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + auto input = inputTensor.flat(); + const int restoreLen = static_cast(input.size()); + + // write lookup & restore vec + Tensor* lookupVec = nullptr; + Tensor* restoreVecTensor = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, 
TensorShape { lookupLen }, &lookupVec)); + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { restoreLen }, &restoreVecTensor)); + auto l = lookupVec->flat(); + auto r = restoreVecTensor->flat(); + + // dummy data + for (int i { 0 }; i < lookupLen; ++i) { + l(i) = i; + } + for (int i { 0 }; i < restoreLen; ++i) { + r(i) = i % lookupLen; + } + spdlog::warn("dummy read batch cost: {},elapsed from last {}", TO_MS(sw), TO_MS(staticSw)); + staticSw.reset(); + } + + int lookupLen {}; +}; + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyDatasetDummy").Device(DEVICE_CPU), ReadEmbKeyDatasetDummy); + + +// ##################### ReadRaw ####################### +REGISTER_OP("ReadRaw") + .Input("sample: string") + .Output("int_output: int64") + .Output("float_output: float") + .Attr("int_len: int") + .Attr("float_len: int") + .Attr("feat_order: list(string)") + .SetShapeFn([](InferenceContextPtr c) { + int temp; + TF_RETURN_IF_ERROR(c->GetAttr("int_len", &temp)); + c->set_output(TENSOR_INDEX_0, c->Vector(temp)); + TF_RETURN_IF_ERROR(c->GetAttr("float_len", &temp)); + c->set_output(TENSOR_INDEX_1, c->Vector(temp)); + return Status::OK(); + }); + +class ReadRaw : public OpKernel { +public: + explicit ReadRaw(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("int_len", &intLen)); + OP_REQUIRES_OK(context, context->GetAttr("float_len", &floatLen)); + OP_REQUIRES_OK(context, context->GetAttr("feat_order", &featOrder)); + sampleId = 0; + } + + ~ReadRaw() override = default; + + void Compute(OpKernelContextPtr context) override + { + spdlog::stopwatch sw; + Tensor* intTensor = nullptr; + Tensor* floatTensor = nullptr; + int intDataIndex = 0; + int floatDataIndex = 0; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { intLen }, &intTensor)); + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { floatLen }, &floatTensor)); + const Tensor& inputTensor = context->input(TENSOR_INDEX_0); + auto input = inputTensor.flat()(0); + tensorflow::Example example; + if (!example.ParseFromString(input)) { + cerr << "Failed to parse file." 
<< endl; + } + spdlog::stopwatch sw_copy; + auto all_feature_map = example.features().feature(); + for (const auto& featName: featOrder) { + auto& cur_feature_value = all_feature_map.at(featName); + if (cur_feature_value.has_int64_list()) { + auto int64List = cur_feature_value.int64_list(); + int64* flat = intTensor->flat().data() + intDataIndex; + std::copy(int64List.value().begin(), int64List.value().end(), flat); + intDataIndex += int64List.value_size(); + } + if (cur_feature_value.has_float_list()) { + auto floatList = cur_feature_value.float_list(); + float* flat = floatTensor->flat().data() + floatDataIndex; + std::copy(floatList.value().begin(), floatList.value().end(), flat); + floatDataIndex += floatList.value_size(); + } + } + spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, TO_MS(sw), + TO_MS(sw_copy), TO_MS(staticReadRaw)); + staticReadRaw.reset(); + } + + int intLen; + int floatLen; + vector featOrder; + atomic sampleId; +}; + +REGISTER_KERNEL_BUILDER(Name("ReadRaw").Device(DEVICE_CPU), ReadRaw); + + +// ##################### ReadRawDummy ####################### +REGISTER_OP("ReadRawDummy") + .Input("sample: int64") + .Output("int_output: int64") + .Output("float_output: float") + .Attr("int_len: int") + .Attr("float_len: int") + .SetShapeFn([](InferenceContextPtr c) { + int temp; + TF_RETURN_IF_ERROR(c->GetAttr("int_len", &temp)); + c->set_output(TENSOR_INDEX_0, c->Vector(temp)); + TF_RETURN_IF_ERROR(c->GetAttr("float_len", &temp)); + c->set_output(TENSOR_INDEX_1, c->Vector(temp)); + return Status::OK(); + }); + +class ReadRawDummy : public OpKernel { +public: + explicit ReadRawDummy(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("int_len", &intLen)); + OP_REQUIRES_OK(context, context->GetAttr("float_len", &floatLen)); + } + + ~ReadRawDummy() override = default; + + void Compute(OpKernelContextPtr context) override + { + spdlog::stopwatch sw; + Tensor* intTensor = nullptr; + Tensor* floatTensor = nullptr; + + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { intLen }, &intTensor)); + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { floatLen }, &floatTensor)); + + const Tensor& inputTensor = context->input(TENSOR_INDEX_0); + auto input = inputTensor.flat(); + int32_t batchId = input(0); + + spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", TO_MS(sw), TO_MS(staticReadRaw), + batchId); + staticReadRaw.reset(); + } + + int intLen; + int floatLen; +}; + +REGISTER_KERNEL_BUILDER(Name("ReadRawDummy").Device(DEVICE_CPU), ReadRawDummy); + +class CustOps : public OpKernel { +public: + explicit CustOps(OpKernelConstructionPtr context) : OpKernel(context) + { + } + + void Compute(OpKernelContextPtr context) override + { + std::cout << " Cust opp not installed!!" 
<< std::endl; + } + + ~CustOps() = default; +}; + +REGISTER_OP("EmbeddingLookupByAddress") + .Input("address: int64") + .Attr("embedding_dim: int") + .Attr("embedding_type: int") + .Output("y: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle addrShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); + int embSize; + TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embSize)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); + c->set_output(TENSOR_INDEX_0, c->Matrix(rows, embSize)); + return Status::OK(); + }); + +REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupByAddress").Device(DEVICE_CPU), CustOps); + +REGISTER_OP("EmbeddingUpdateByAddress") + .Input("address: int64") + .Input("embedding: float") + .Attr("update_type: int") + .Output("y: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle addrShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); + ShapeHandle embeddingShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &embeddingShape)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); + tensorflow::shape_inference::DimensionHandle cols = c->Dim(embeddingShape, 1); + c->set_output(TENSOR_INDEX_0, c->Matrix(rows, cols)); + return Status::OK(); + }); + +REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); \ No newline at end of file diff --git a/src/ops_tf/tf_ops.h b/src/ops_tf/tf_ops.h new file mode 100644 index 00000000..71b011bc --- /dev/null +++ b/src/ops_tf/tf_ops.h @@ -0,0 +1,12 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: tf ops. + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#ifndef MX_REC_TF_OPS_H +#define MX_REC_TF_OPS_H + +#endif // MX_REC_TF_OPS_H diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt new file mode 100644 index 00000000..ed1bb626 --- /dev/null +++ b/src/pybind/CMakeLists.txt @@ -0,0 +1,6 @@ +cmake_minimum_required(VERSION 3.12) + +pybind11_add_module(mxrec_pybind module_main.cpp) +set_target_properties(mxrec_pybind PROPERTIES LINK_FLAGS "-Wl,-rpath,/") +target_link_libraries(mxrec_pybind PUBLIC ASC) +install(TARGETS mxrec_pybind LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) \ No newline at end of file diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp new file mode 100644 index 00000000..44f02ff0 --- /dev/null +++ b/src/pybind/module_main.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
+ * Description: pybind module + * Author: MindX SDK + * Date: 2022/11/15 + */ +#include "module_main.h" +#include +#include +#include "emb_mgmt/emb_mgmt.h" + +namespace py = pybind11; +using namespace MxRec; + +void GetRankInfo(py::module_& m); + +void GetEmbInfo(py::module_& m); + +void GetRandomInfo(py::module_& m); + +void GetHybridMgmt(py::module_& m); + +void GetThresholdValue(pybind11::module_& m); + +void GetInitializeInfo(pybind11::module_& m); + +void GetConstantInitializerInfo(pybind11::module_& m); + +void GetNormalInitializerInfo(pybind11::module_& m); + +int GetUBHotSize(int devID) +{ + return static_cast(MxRec::GetUBSize(devID)/ sizeof(float) * HOT_EMB_CACHE_PCT) ; +} + +PYBIND11_MODULE(mxrec_pybind, m) +{ + m.def("get_ub_hot_size", &GetUBHotSize, py::arg("device_id")); + + m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); + + m.attr("USE_HOT") = py::int_(HybridOption::USE_HOT); + + m.attr("USE_DYNAMIC_EXPANSION") = py::int_(HybridOption::USE_DYNAMIC_EXPANSION); + + GetRankInfo(m); + + GetEmbInfo(m); + + GetRandomInfo(m); + + GetHybridMgmt(m); + + GetThresholdValue(m); + + GetInitializeInfo(m); + + GetConstantInitializerInfo(m); + + GetNormalInitializerInfo(m); +} + +void GetRankInfo(pybind11::module_& m) +{ + pybind11::class_(m, "RankInfo") + .def(py::init>(), py::arg("rank_id"), py::arg("device_id"), + py::arg("local_rank_size"), py::arg("option"), py::arg("num_batch") = 1, + py::arg("max_step") = vector { -1, -1 }) + .def_readwrite("rank_id", &RankInfo::rankId) + .def_readwrite("device_id", &RankInfo::deviceId) + .def_readwrite("rank_size", &RankInfo::rankSize) + .def_readwrite("local_rank_size", &RankInfo::localRankSize) + .def_readwrite("option", &RankInfo::option) + .def_readwrite("num_batch", &RankInfo::nBatch) + .def_readwrite("max_step", &RankInfo::maxStep); +} + +void GetEmbInfo(pybind11::module_& m) +{ + pybind11::class_(m, "EmbInfo") + .def(pybind11::init, std::vector&>(), + py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("vocab_size"), + py::arg("initialize_infos")) + .def_readwrite("name", &EmbInfo::name) + .def_readwrite("send_count", &EmbInfo::sendCount) + .def_readwrite("embedding_size", &EmbInfo::embeddingSize) + .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) + .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) + .def_readwrite("initialize_infos", &EmbInfo::initializeInfos); +} + +void GetRandomInfo(pybind11::module_& m) +{ + pybind11::class_(m, "RandomInfo") + .def(pybind11::init()) + .def_readwrite("start", &RandomInfo::start) + .def_readwrite("len", &RandomInfo::len) + .def_readwrite("constant_val", &RandomInfo::constantVal) + .def_readwrite("random_min", &RandomInfo::randomMin) + .def_readwrite("random_max", &RandomInfo::randomMax); +} + +void GetInitializeInfo(pybind11::module_ &m) +{ + pybind11::class_(m, "InitializeInfo") + .def(py::init(), py::arg("name"), py::arg("start"), + py::arg("len"), py::arg("constant_initializer_info")) + .def(py::init(), py::arg("name"), py::arg("start"), + py::arg("len"), py::arg("normal_initializer_info")) + .def_readwrite("name", &InitializeInfo::name) + .def_readwrite("start", &InitializeInfo::start) + .def_readwrite("len", &InitializeInfo::len) + .def_readwrite("ConstantInitializerInfo", &InitializeInfo::constantInitializerInfo) + .def_readwrite("NormalInitializerInfo", &InitializeInfo::normalInitializerInfo); +} + +void GetConstantInitializerInfo(pybind11::module_ &m) +{ + pybind11::class_(m, "ConstantInitializerInfo") + .def(py::init(), py::arg("constant_val") = 
0) + .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue); +} + +void GetNormalInitializerInfo(pybind11::module_ &m) +{ + pybind11::class_(m, "NormalInitializerInfo") + .def(py::init(), py::arg("mean") = 0.0, py::arg("stddev") = 1.0, py::arg("seed") = 0) + .def_readwrite("mean", &NormalInitializerInfo::mean) + .def_readwrite("stddev", &NormalInitializerInfo::stddev) + .def_readwrite("seed", &NormalInitializerInfo::seed); +} + +void GetHybridMgmt(pybind11::module_& m) +{ + pybind11::class_(m, "HybridMgmt") + .def(py::init()) + .def("initialize", &MxRec::HybridMgmt::Initialize, py::arg("rank_info"), py::arg("emb_info"), + py::arg("seed") = DEFAULT_RANDOM_SEED, py::arg("threshold_values") = vector {}, + py::arg("if_load") = false) + .def("save", &MxRec::HybridMgmt::Save, py::arg("save_path") = "") + .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "") + .def("destroy", &MxRec::HybridMgmt::Destroy) + .def("evict", &MxRec::HybridMgmt::Evict); +} + +void GetThresholdValue(pybind11::module_& m) +{ + pybind11::class_(m, "ThresholdValue") + .def(pybind11::init()) + .def_readwrite("tensor_name", &ThresholdValue::tensorName) + .def_readwrite("count_threshold", &ThresholdValue::countThreshold) + .def_readwrite("time_threshold", &ThresholdValue::timeThreshold); +} \ No newline at end of file diff --git a/src/pybind/module_main.h b/src/pybind/module_main.h new file mode 100644 index 00000000..e0ffa3bd --- /dev/null +++ b/src/pybind/module_main.h @@ -0,0 +1,11 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. +* Description: main head file +* Author: MindX SDK +* Date: 2022/11/15 +*/ + +#ifndef SPARSE_SSD_DEMO_MODULE_MAIN_H +#define SPARSE_SSD_DEMO_MODULE_MAIN_H + +#endif // SPARSE_SSD_DEMO_MODULE_MAIN_H diff --git a/src/test_ut.sh b/src/test_ut.sh new file mode 100644 index 00000000..175954bb --- /dev/null +++ b/src/test_ut.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. +# Description: NA +# Author: MindX SDK +# Create: 2022 +# History: NA +set -e + +# add mpirun env +export OMPI_ALLOW_RUN_AS_ROOT=1 +export OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 + +source /etc/profile +source /opt/rh/devtoolset-7/enable + +CUR_DIR=$(dirname "$(readlink -f "$0")") + +compile_securec() +{ + if [[ ! -d ${CUR_DIR}/../platform/securec ]]; then + echo "securec is not exist" + exit 1 + fi + + if [[ ! 
-f ${CUR_DIR}/../platform/securec/lib/libsecurec.so ]]; then + cd ${CUR_DIR}/../platform/securec/src + make -j + fi +} +compile_securec + +cd ${CUR_DIR} +cd ../src/ + +find ./ -name "*.sh" -exec dos2unix {} \; +find ./ -name "*.sh" -exec chmod +x {} \; + +rm -rf build + +mkdir build +cd build + +cmake -DCMAKE_BUILD_TYPE=Debug \ + -DTF_PATH=/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow_core \ + -DOMPI_PATH=/usr/local/openmpi/ \ + -DPYTHON_PATH=/opt/buildtools/python-3.7.5/ \ + -DEASY_PROFILER_PATH=/opt/buildtools/ \ + -DASCEND_PATH=/usr/local/Ascend/ascend-toolkit/latest \ + -DABSEIL_PATH=${PWD}/../../install/abseil/ \ + -DSECUREC_PATH=${CUR_DIR}/../platform/securec \ + -DBUILD_TESTS=on -DCOVERAGE=on ../ + +make -j +make install + +# Run Test +mpirun -np 4 ./tests/test_main + +cd ../ + +COVERAGE_FILE=coverage.info +REPORT_FOLDER=coverage_report +lcov --rc lcov_branch_coverage=1 -c -d build -o ${COVERAGE_FILE}_tmp +lcov --rc lcov_branch_coverage=1 -e ${COVERAGE_FILE}_tmp "*src*" -o ${COVERAGE_FILE} +genhtml --rc genhtml_branch_coverage=1 ${COVERAGE_FILE} -o ${REPORT_FOLDER} +rm -rf ${COVERAGE_FILE}_tmp +rm -rf ${COVERAGE_FILE} + +if [[ "$OSTYPE" == "darwin"* ]]; then + open ./${REPORT_FOLDER}/index.html +fi \ No newline at end of file diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt new file mode 100644 index 00000000..0ef60b39 --- /dev/null +++ b/src/tests/CMakeLists.txt @@ -0,0 +1,62 @@ +# 开启测试 +enable_testing() +set(CMAKE_CXX_STANDARD 17) +find_package(GTest REQUIRED) +include_directories(${GTEST_INCLUDE_DIRS}) +add_definitions(-DGTEST) + +# src +file(GLOB_RECURSE MXREC_SRC ${PROJECT_SOURCE_DIR}/core/*.cpp) +message("MXREC_SRC: " ${MXREC_SRC}) +file(GLOB_RECURSE MXREC_TEST_SRC ./*.cpp) +message("MXREC_TEST_SRC: " ${MXREC_TEST_SRC}) + +set(CMAKE_CXX_FLAGS "--coverage") + +add_executable(test_main ${MXREC_SRC} ${MXREC_TEST_SRC}) + +if(NOT SECUREC_PATH) + set(SECUREC_PATH ${PROJECT_SOURCE_DIR}/../platform/securec) +endif() + +if (easy_profiler_FOUND) + message("==link with easy_profiler==") + target_link_directories(test_main PUBLIC ${EASY_PROFILER_PATH}) + target_link_libraries(test_main PUBLIC easy_profiler) +endif () + +include_directories(PRIVATE .) +target_include_directories(test_main + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../core + ${ASCEND_DRIVER_PATH}/include + ${SECUREC_PATH}/include +) + +target_link_directories(test_main + PRIVATE + ${ASCEND_DRIVER_PATH}/lib64/driver + ${SECUREC_PATH}/lib) + +target_link_directories(test_main + PUBLIC + ${ASCEND_PATH}/fwkacllib/lib64 + ${ASCEND_PATH}/compiler/lib64 + ${ASCEND_PATH}/runtime/lib64 + ${HDF5_PATH}/lib + ${SECUREC_PATH}/lib + ${ASCEND_DRIVER_PATH}/lib64/driver + ) + +target_link_libraries(test_main PUBLIC ${TF_LIB} + securec + OpenMP::OpenMP_CXX ${HDF5_CXX_LIBRARIES} ${MPI_CXX_LIBRARIES} + ${PYTHON_LIBRARY} drvdsmi_host + ) + +target_link_libraries(test_main PUBLIC + ${GTEST_BOTH_LIBRARIES} + MPI::MPI_CXX) + +target_link_libraries(test_main PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) \ No newline at end of file diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp new file mode 100644 index 00000000..3cff5e1f --- /dev/null +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -0,0 +1,449 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2022-11-15 + */ + +#include +#include +#include +#include + +#include "checkpoint/checkpoint.h" +#include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" +#include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h" +#include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h" +#include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" +#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" + + +using namespace std; +using namespace MxRec; + +class CheckpointTest : public testing::Test { +protected: + string testPath { "./ckpt_mgmt_test" }; + int rankId; + + int floatBytes { 4 }; + int int32Bytes { 4 }; + int int64Bytes { 8 }; + + int64_t int64Min { static_cast(UINT32_MAX) }; + + int maxChannelNum = MAX_CHANNEL_NUM; + int keyProcessThread = KEY_PROCESS_THREAD; + + int embInfoNum { 10 }; + + float floatMem { 0.5 }; + int64_t featMem { static_cast(UINT32_MAX) }; + int32_t offsetMem { 0 }; + + string name { "table" }; + int sendCount { 8 }; + int embeddingSize { 100 }; + int devVocabSize { 8 }; + int hostVocabSize { 16 }; + + vector testEmbInfos; + RankInfo rankInfo; + + void SetUp() + { + spdlog::set_level(spdlog::level::trace); + int claimed; + + MPI_Query_thread(&claimed); + ASSERT_EQ(claimed, MPI_THREAD_MULTIPLE); + MPI_Comm_rank(MPI_COMM_WORLD, &rankId); + rankInfo.rankId = rankId; + rankInfo.useDynamicExpansion = false; + } + + void SetEmbInfo() + { + int idx { 0 }; + testEmbInfos.resize(embInfoNum); + for (auto& testEmbInfo : testEmbInfos) { + testEmbInfo.name = name + to_string(idx); + testEmbInfo.sendCount = sendCount; + testEmbInfo.embeddingSize = embeddingSize; + testEmbInfo.devVocabSize = devVocabSize; + testEmbInfo.hostVocabSize = hostVocabSize; + ++idx; + } + } + + void SetEmbData(vector>& testEmbData) + { + testEmbData.resize(hostVocabSize); + for (auto& testData : testEmbData) { + testData.resize(embeddingSize); + for (auto& testValue : testData) { + testValue = floatMem; + floatMem++; + } + } + } + + void SetHostEmbs(emb_mem_t& testHostEmbs) + { + vector> testEmbData; + for (const auto& testEmbInfo : testEmbInfos) { + SetEmbData(testEmbData); + HostEmbTable embTable { testEmbInfo, move(testEmbData) }; + testHostEmbs[testEmbInfo.name] = move(embTable); // set test input data + } + } + + void SetHashMapInfo(absl::flat_hash_map& testHash, + vector& testDev2B, + vector& testDev2K) + { + testDev2B.resize(devVocabSize); + testDev2K.resize(devVocabSize); + for (int i { 0 }; i < devVocabSize; ++i) { + testDev2K.at(i) = offsetMem; + testHash[featMem] = offsetMem; + + featMem++; + offsetMem++; + } + fill(testDev2B.begin(), testDev2B.end(), -1); + } + + void SetEmbHashMaps(emb_hash_mem_t& testEmbHashMaps) + { + EmbHashMapInfo embHashInfo; + absl::flat_hash_map testHash; + vector testDev2B; + vector testDev2K; + for (const auto& testEmbInfo : testEmbInfos) { + SetHashMapInfo(testHash, testDev2B, testDev2K); + + embHashInfo.hostHashMap = std::move(testHash); + + embHashInfo.devOffset2Batch = move(testDev2B); + embHashInfo.devOffset2Key = move(testDev2K); + + embHashInfo.currentUpdatePos = 0; + embHashInfo.hostVocabSize = hostVocabSize; + embHashInfo.devVocabSize = devVocabSize; + + testEmbHashMaps[testEmbInfo.name] = move(embHashInfo); + } + } + + void SetMaxOffset(offset_mem_t& testMaxOffset) + { + for (const auto& testEmbInfo : testEmbInfos) { + testMaxOffset[testEmbInfo.name] = offsetMem; + } + } + + void SetKeyOffsetMap(absl::flat_hash_map& testKeyOffsetMap) + { + for (int64_t i { 0 }; i < 
hostVocabSize; ++i) { + testKeyOffsetMap[featMem] = i; + + featMem++; + } + } + + void SetKeyOffsetMaps(key_offset_mem_t& testKeyOffsetMaps) + { + absl::flat_hash_map testKeyOffsetMap; + for (const auto& testEmbInfo : testEmbInfos) { + SetKeyOffsetMap(testKeyOffsetMap); + testKeyOffsetMaps[testEmbInfo.name] = std::move(testKeyOffsetMap); + } + } + + void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold) + { + for (const auto& testEmbInfo : testEmbInfos) { + ThresholdValue val; + val.tensorName = testEmbInfo.name; + val.countThreshold = offsetMem; + val.timeThreshold = offsetMem; + + offsetMem++; + + testTens2Threshold[testEmbInfo.name] = move(val); + } + } + + void SetHistRec(AdmitAndEvictData& histRec) + { + int64_t featureId { int64Min }; + int count { 1 }; + time_t lastTime { 1000 }; + time_t timeStamp { 10000 }; + + for (const auto& testEmbInfo : testEmbInfos) { + auto& historyRecords { histRec.historyRecords[testEmbInfo.name] }; + auto& timestamps { histRec.timestamps[testEmbInfo.name] }; + + timestamps = timeStamp; + + for (int i = 0; i < count; ++i) { + historyRecords[featureId].featureId = featureId; + historyRecords[featureId].count = count; + historyRecords[featureId].tensorName = testEmbInfo.name; + historyRecords[featureId].lastTime = lastTime; + + featureId++; + } + + count++; + lastTime++; + timeStamp++; + } + } +}; + +TEST_F(CheckpointTest, HostEmbs) +{ + emb_mem_t testHostEmbs; + emb_mem_t validHostEmbs; + + SetEmbInfo(); + SetHostEmbs(testHostEmbs); + validHostEmbs = testHostEmbs; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.hostEmbs = testHostEmbs; + validLoadData.hostEmbs = validHostEmbs; + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, + { CkptFeatureType::HOST_EMB }); + + for (const auto& it : validLoadData.hostEmbs) { + const auto& embInfo = testLoadData.hostEmbs.at(it.first).hostEmbInfo; + const auto& embData = testLoadData.hostEmbs.at(it.first).embData; + + EXPECT_EQ(it.second.hostEmbInfo.name, embInfo.name); + EXPECT_EQ(it.second.hostEmbInfo.sendCount, embInfo.sendCount); + EXPECT_EQ(it.second.hostEmbInfo.embeddingSize, embInfo.embeddingSize); + EXPECT_EQ(it.second.hostEmbInfo.devVocabSize, embInfo.devVocabSize); + EXPECT_EQ(it.second.hostEmbInfo.hostVocabSize, embInfo.hostVocabSize); + + EXPECT_EQ(it.second.embData, embData); + } +} + +TEST_F(CheckpointTest, EmbHashMaps) +{ + emb_hash_mem_t testEmbHashMaps; + emb_hash_mem_t validEmbHashMaps; + + SetEmbInfo(); + SetEmbHashMaps(testEmbHashMaps); + validEmbHashMaps = testEmbHashMaps; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.embHashMaps = std::move(testEmbHashMaps); + validLoadData.embHashMaps = std::move(validEmbHashMaps); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::EMB_HASHMAP }); + + EXPECT_EQ(validLoadData.embHashMaps.size(), testLoadData.embHashMaps.size()); + for (const auto& it : validLoadData.embHashMaps) { + EXPECT_EQ(1, testLoadData.embHashMaps.count(it.first)); + + const auto& hostHashMap = testLoadData.embHashMaps.at(it.first).hostHashMap; + const auto& devOffset2Batch = testLoadData.embHashMaps.at(it.first).devOffset2Batch; + const auto& devOffset2Key = testLoadData.embHashMaps.at(it.first).devOffset2Key; + const auto& currentUpdatePos = 
testLoadData.embHashMaps.at(it.first).currentUpdatePos; + const auto& hostVocabSize = testLoadData.embHashMaps.at(it.first).hostVocabSize; + const auto& devVocabSize = testLoadData.embHashMaps.at(it.first).devVocabSize; + + EXPECT_EQ(it.second.hostHashMap, hostHashMap); + + EXPECT_EQ(it.second.devOffset2Batch, devOffset2Batch); + EXPECT_EQ(it.second.devOffset2Key, devOffset2Key); + + EXPECT_EQ(it.second.currentUpdatePos, currentUpdatePos); + EXPECT_EQ(it.second.hostVocabSize, hostVocabSize); + EXPECT_EQ(it.second.devVocabSize, devVocabSize); + } +} + +TEST_F(CheckpointTest, MaxOffset) +{ + offset_mem_t testMaxOffset; + offset_mem_t validMaxOffset; + + SetEmbInfo(); + SetMaxOffset(testMaxOffset); + validMaxOffset = testMaxOffset; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.maxOffset = std::move(testMaxOffset); + validLoadData.maxOffset = std::move(validMaxOffset); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::MAX_OFFSET }); + + EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); + for (const auto& it : validLoadData.maxOffset) { + EXPECT_EQ(1, testLoadData.maxOffset.count(it.first)); + + const auto& maxOffset = testLoadData.maxOffset.at(it.first); + + EXPECT_EQ(it.second, maxOffset); + } +} + +TEST_F(CheckpointTest, KeyOffsetMaps) +{ + key_offset_mem_t testKeyOffsetMaps; + key_offset_mem_t validKeyOffsetMaps; + + SetEmbInfo(); + SetKeyOffsetMaps(testKeyOffsetMaps); + validKeyOffsetMaps = testKeyOffsetMaps; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.keyOffsetMap = std::move(testKeyOffsetMaps); + validLoadData.keyOffsetMap = std::move(validKeyOffsetMaps); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::KEY_OFFSET_MAP }); + + EXPECT_EQ(validLoadData.keyOffsetMap.size(), testLoadData.keyOffsetMap.size()); + for (const auto& it : validLoadData.keyOffsetMap) { + EXPECT_EQ(1, testLoadData.keyOffsetMap.count(it.first)); + const auto& maxOffset = testLoadData.keyOffsetMap.at(it.first); + EXPECT_EQ(it.second, maxOffset); + } +} + +TEST_F(CheckpointTest, AllMgmt) +{ + offset_mem_t testMaxOffset; + offset_mem_t validMaxOffset; + key_offset_mem_t testKeyOffsetMaps; + key_offset_mem_t validKeyOffsetMaps; + + SetEmbInfo(); + SetMaxOffset(testMaxOffset); + validMaxOffset = testMaxOffset; + SetKeyOffsetMaps(testKeyOffsetMaps); + validKeyOffsetMaps = testKeyOffsetMaps; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.maxOffset = std::move(testMaxOffset); + validLoadData.maxOffset = std::move(validMaxOffset); + testSaveData.keyOffsetMap = std::move(testKeyOffsetMaps); + validLoadData.keyOffsetMap = std::move(validKeyOffsetMaps); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, + testLoadData, + rankInfo, + testEmbInfos, + { CkptFeatureType::MAX_OFFSET, CkptFeatureType::KEY_OFFSET_MAP }); + + EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); + for (const auto& it : validLoadData.maxOffset) { + EXPECT_EQ(1, testLoadData.maxOffset.count(it.first)); + + const auto& maxOffset = testLoadData.maxOffset.at(it.first); + + EXPECT_EQ(it.second, maxOffset); + } + + 
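+    // Validate the reloaded key-to-offset maps the same way: every table that was saved must come back with identical contents.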
EXPECT_EQ(validLoadData.keyOffsetMap.size(), testLoadData.keyOffsetMap.size()); + for (const auto& it : validLoadData.keyOffsetMap) { + EXPECT_EQ(1, testLoadData.keyOffsetMap.count(it.first)); + + const auto& maxOffset = testLoadData.keyOffsetMap.at(it.first); + EXPECT_EQ(it.second, maxOffset); + } +} + +TEST_F(CheckpointTest, FeatAdmitNEvict) +{ + tensor_2_thresh_mem_t testTrens2Thresh; + tensor_2_thresh_mem_t validTrens2Thresh; + AdmitAndEvictData testHistRec; + AdmitAndEvictData validHistRec; + + SetEmbInfo(); + SetTens2Threshold(testTrens2Thresh); + validTrens2Thresh = testTrens2Thresh; + SetHistRec(testHistRec); + validHistRec = testHistRec; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.tens2Thresh = testTrens2Thresh; + testSaveData.histRec.timestamps = testHistRec.timestamps; + testSaveData.histRec.historyRecords = testHistRec.historyRecords; + validLoadData.tens2Thresh = validTrens2Thresh; + validLoadData.histRec = validHistRec; + validLoadData.histRec.timestamps = validHistRec.timestamps; + validLoadData.histRec.historyRecords = validHistRec.historyRecords; + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::FEAT_ADMIT_N_EVICT }); + + EXPECT_EQ(validLoadData.tens2Thresh.size(), testLoadData.tens2Thresh.size()); + EXPECT_EQ(validLoadData.histRec.historyRecords.size(), testLoadData.histRec.historyRecords.size()); + for (const auto& it : validLoadData.tens2Thresh) { + EXPECT_EQ(1, testLoadData.tens2Thresh.count(it.first)); + + const auto& tens2Thresh = testLoadData.tens2Thresh.at(it.first); + + EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); + EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); + } + + for (const auto& it : validLoadData.histRec.timestamps) { + EXPECT_EQ(1, testLoadData.histRec.timestamps.count(it.first)); + EXPECT_EQ(1, testLoadData.histRec.historyRecords.count(it.first)); + + const auto& timestamps = testLoadData.histRec.timestamps.at(it.first); + const auto& historyRecords = testLoadData.histRec.historyRecords.at(it.first); + const auto& validHistRec = validLoadData.histRec.historyRecords.at(it.first); + + EXPECT_EQ(it.second, timestamps); + for (const auto& validHR : validHistRec) { + const auto& testHR = historyRecords.at(validHR.first); + + EXPECT_EQ(validHR.second.featureId, testHR.featureId); + EXPECT_EQ(validHR.second.count, testHR.count); + EXPECT_EQ(validHR.second.tensorName, testHR.tensorName); + EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); + } + } +} \ No newline at end of file diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp new file mode 100644 index 00000000..2c13c0c4 --- /dev/null +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. 
+ * Description:
+ * Author: MindX SDK
+ * Create: 2022-12-03
+ */
+
+#include
+#include
+#include
+
+#include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h"
+#include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h"
+#include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h"
+#include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h"
+#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h"
+
+using namespace std;
+using namespace MxRec;
+
+using valid_float_t = absl::flat_hash_map<string, vector<float>>;
+using valid_int_t = absl::flat_hash_map<string, vector<int>>;
+using valid_int64_t = absl::flat_hash_map<string, vector<int64_t>>;
+using valid_dataset_t = absl::flat_hash_map>;
+using valid_attrib_t = absl::flat_hash_map<string, vector<size_t>>;
+
+class CkptDataHandlerTest : public testing::Test {
+protected:
+    int floatBytes { 4 };
+    int int32Bytes { 4 };
+    int int64Bytes { 8 };
+
+    int64_t int64Min { static_cast<int64_t>(UINT32_MAX) };
+
+    int maxChannelNum { MAX_CHANNEL_NUM };
+    int keyProcessThread { KEY_PROCESS_THREAD };
+
+    vector<EmbInfo> testEmbInfos;
+    valid_int_t validEmbInfo;
+    valid_attrib_t validEmbInfoAttrib;
+
+    void SetEmbInfo()
+    {
+        int embInfoNum { 10 };
+
+        string name { "table" };
+        int sendCount { 8 };
+        int embeddingSize { 100 };
+        size_t devVocabSize { 8 };
+        size_t hostVocabSize { 16 };
+
+        int idx { 0 };
+        testEmbInfos.resize(embInfoNum);
+        for (auto& testEmbInfo : testEmbInfos) {
+            testEmbInfo.name = name + to_string(idx);
+            testEmbInfo.sendCount = sendCount;
+            testEmbInfo.embeddingSize = embeddingSize;
+            testEmbInfo.devVocabSize = devVocabSize;
+            testEmbInfo.hostVocabSize = hostVocabSize;
+
+            vector<int> validInt { sendCount,
+                                   embeddingSize,
+                                   static_cast<int>(devVocabSize),
+                                   static_cast<int>(hostVocabSize) };
+            vector<size_t> validAttrib { static_cast<size_t>(int32Bytes), static_cast<size_t>(int32Bytes) };
+
+            validEmbInfo[name + to_string(idx)] = move(validInt);
+            validEmbInfoAttrib[name + to_string(idx)] = move(validAttrib);
+            ++idx;
+        }
+    }
+
+    void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold,
+                           valid_int_t& validArr,
+                           valid_attrib_t& validAttrib)
+    {
+        int countThreshold { 20 };
+        int timeThreshold { 100 };
+
+        for (const auto& testEmbInfo : testEmbInfos) {
+            ThresholdValue val;
+            val.tensorName = testEmbInfo.name;
+            val.countThreshold = countThreshold;
+            val.timeThreshold = timeThreshold;
+
+            vector<int> valid { countThreshold, timeThreshold };
+
+            countThreshold++;
+            timeThreshold++;
+
+            testTens2Threshold[testEmbInfo.name] = move(val);
+            validArr[testEmbInfo.name] = move(valid);
+            validAttrib[testEmbInfo.name].push_back(2); // 2 is element num in one vector
+            validAttrib[testEmbInfo.name].push_back(int32Bytes);
+        }
+    }
+
+    void SetHistRec(AdmitAndEvictData& histRec, valid_int64_t& validArr, valid_attrib_t& validAttrib)
+    {
+        int64_t featureId { int64Min };
+        int count { 1 };
+        time_t lastTime { 1000 };
+        time_t timeStamp { 10000 };
+
+        for (const auto& testEmbInfo : testEmbInfos) {
+            auto& validA { validArr[testEmbInfo.name] };
+            auto& historyRecords { histRec.historyRecords[testEmbInfo.name] };
+            auto& timestamps { histRec.timestamps[testEmbInfo.name] };
+
+            timestamps = timeStamp;
+            validA.push_back(timeStamp);
+
+            for (int i = 0; i < count; ++i) {
+                historyRecords[featureId].featureId = featureId;
+                historyRecords[featureId].count = count;
+                historyRecords[featureId].tensorName = testEmbInfo.name;
+                historyRecords[featureId].lastTime = lastTime;
+
+                validA.push_back(featureId);
+                validA.push_back(count);
+                validA.push_back(lastTime);
+
+                featureId++;
+            }
+
+            // Per-table attribute layout: number of history entries, then the byte width of each stored field.
+            auto& attribute = validAttrib[testEmbInfo.name];
+            
attribute.push_back(count); + attribute.push_back(int64Bytes); + + count++; + lastTime++; + timeStamp++; + } + } +}; + +TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) +{ + tensor_2_thresh_mem_t testTrens2Thresh; + tensor_2_thresh_mem_t validTrens2Thresh; + AdmitAndEvictData testHistRec; + AdmitAndEvictData validHistRec; + + valid_int_t validTrens2ThreshArr; + valid_int64_t validHistRecArr; + valid_attrib_t validTrens2ThreshAttrib; + valid_attrib_t validHistRecAttrib; + + SetEmbInfo(); + SetTens2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); + validTrens2Thresh = testTrens2Thresh; + SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); + validHistRec = testHistRec; + + CkptData testData; + CkptData validData; + FeatAdmitNEvictCkpt testCkpt; + + testData.tens2Thresh = testTrens2Thresh; + testData.histRec.timestamps = testHistRec.timestamps; + testData.histRec.historyRecords = testHistRec.historyRecords; + validData.tens2Thresh = validTrens2Thresh; + validData.histRec.timestamps = validHistRec.timestamps; + validData.histRec.historyRecords = validHistRec.historyRecords; + + testCkpt.SetProcessData(testData); + + vector embNames { testCkpt.GetEmbNames() }; + CkptTransData testSaveData; + EXPECT_EQ(validData.tens2Thresh.size(), embNames.size()); + + for (const auto& embName : embNames) { + EXPECT_EQ(1, validData.tens2Thresh.count(embName)); + + EXPECT_EQ(1, validData.histRec.timestamps.count(embName)); + EXPECT_EQ(1, validData.histRec.historyRecords.count(embName)); + + testSaveData = testCkpt.GetDataset(CkptDataType::TENSOR_2_THRESH, embName); + EXPECT_EQ(validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method + EXPECT_EQ(validTrens2ThreshAttrib.at(embName), testSaveData.attribute); + testSaveData = testCkpt.GetDataset(CkptDataType::HIST_REC, embName); + EXPECT_EQ(validHistRecAttrib.at(embName), testSaveData.attribute); + } + + CkptTransData testLoadData; + for (const auto& embName : embNames) { + testLoadData.int32Arr = validTrens2ThreshArr.at(embName); + testLoadData.attribute = validTrens2ThreshAttrib.at(embName); + testCkpt.SetDataset(CkptDataType::TENSOR_2_THRESH, embName, testLoadData); + + testLoadData.int64Arr = validHistRecArr.at(embName); + testLoadData.attribute = validHistRecAttrib.at(embName); + testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); + } + testCkpt.GetProcessData(testData); + + EXPECT_EQ(validData.tens2Thresh.size(), testData.tens2Thresh.size()); + EXPECT_EQ(validData.histRec.historyRecords.size(), testData.histRec.historyRecords.size()); + for (const auto& it : validData.tens2Thresh) { + EXPECT_EQ(1, testData.tens2Thresh.count(it.first)); + + const auto& tens2Thresh = testData.tens2Thresh.at(it.first); + + EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); + EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); + } + + for (const auto& it : validData.histRec.timestamps) { + EXPECT_EQ(1, testData.histRec.timestamps.count(it.first)); + EXPECT_EQ(1, testData.histRec.historyRecords.count(it.first)); + + const auto& historyRecords = testData.histRec.historyRecords.at(it.first); + const auto& validHistRec = validData.histRec.historyRecords.at(it.first); + + for (const auto& validHR : validHistRec) { + const auto& testHR = historyRecords.at(validHR.first); + + EXPECT_EQ(validHR.second.featureId, testHR.featureId); + EXPECT_EQ(validHR.second.count, testHR.count); + EXPECT_EQ(validHR.second.tensorName, 
testHR.tensorName); + EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); + } + } +} \ No newline at end of file diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp new file mode 100644 index 00000000..ff30385e --- /dev/null +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: emb mgmt test + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include +#include +#include "emb_mgmt/emb_mgmt.h" +#include "host_emb/host_emb.h" +#include "utils/common.h" + +using namespace std; +using namespace MxRec; + +constexpr int DDR_DEVICE_SIZE = 5; +constexpr int DDR_HOST_SIZE = 15; +constexpr int HBM_DEVICE_SIZE = 20; +constexpr int HBM_HOST_SIZE = 0; + +// test class key_process +class EmbMgmtTest : public testing::Test { +protected: + // create RankInfo rankInfo and const vector &embInfos + RankInfo allRank; + int rankId = 0; + int deviceId = 0; + int rankSize = 1; + int localRankSize = 1; + bool useStatic = false; + std::vector maxStep = { 10, 5 }; + vector embInfos; + EmbInfo embInfo; + string name = "model"; + int sendCount = 5; + int embeddingSize = 8; + size_t devVocabSize = 5; + size_t hostVocabSize = 15; + vector randomInfos; + RandomInfo randomInfo; + int start = 0; + int len = hostVocabSize * embeddingSize; + float constantVal = 0; + float randomMin = -0.1f; + float randomMax = 0.1f; + int seed = 10086; + string loadPath; + HybridMgmt* hybridMgmt; + vector initializeInfos; + InitializeInfo initializeInfo; + ConstantInitializerInfo constantInitializerInfo; + string constantInitializerName = "constant_initializer"; + int nBatch = 10; + + void UpdateEmb(vector &missingKeysHostPos, int channelId, const string &embName, + std::unique_ptr &hostEmb, vector &d2h_emb) + { + spdlog::info(HD + "update emb start"); + if (d2h_emb.size() == 0) { + spdlog::info(HD + "emb is none", channelId); + return; + } + + auto tensorPtr = d2h_emb[0].flat().data(); + for (size_t i = 0; i < missingKeysHostPos.size(); i++) { + (hostEmb->GetEmb(embName).embData[missingKeysHostPos[i]]).assign( + tensorPtr, + tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.embeddingSize); + tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.embeddingSize; + } + for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) { + spdlog::info("hostEmb: embName {}, {} is: {}", embName, i, hostEmb->GetEmb(embName).embData[i]); + } + spdlog::info(HD + "update emb end"); + d2h_emb.clear(); + } + + bool Float2TensorVec(const vector>& Datas, vector& tensors) + { + tensors.clear(); + for (auto transferData: Datas) { + Tensor tmpTensor(tensorflow::DT_FLOAT, { (int)transferData.size() }); + auto tmpData = tmpTensor.flat(); + for (uint j = 0; j < transferData.size(); ++j) { + tmpData(j) = transferData[j]; + } + tensors.emplace_back(move(tmpTensor)); + } + return true; + } + + void SetUp() + { + // init key_process (RankInfo rankInfo, const vector &embInfos) + constantInitializerInfo = ConstantInitializerInfo(constantVal); + initializeInfo = InitializeInfo(constantInitializerName, start, embeddingSize, constantInitializerInfo); + initializeInfos.push_back(initializeInfo); + + randomInfo = RandomInfo(start, len, constantVal, randomMin, randomMax); + randomInfos.emplace_back(randomInfo); + } + + void TearDown() + { + // delete + } +}; + +TEST_F(EmbMgmtTest, Initialize) +{ + vector vocabsize = { devVocabSize, hostVocabSize }; + embInfo = EmbInfo(name, sendCount, embeddingSize, 
vocabsize, initializeInfos); + embInfos.emplace_back(embInfo); + vector thresholdValues = {}; + + auto hybridMgmt = Singleton::GetInstance(); + cout << "setup..." << endl; + + allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep); + hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false); + auto hostEmbs = make_unique(); + hostEmbs->Initialize(embInfos, seed, false); + auto hostHashMaps = make_unique(); + hostHashMaps->Init(allRank, embInfos, false); + + int currentBatchId = 0; + vector lookupKeys = { 1, 3, 5, 7 }; + vector tmpData; + vector d2h_emb; + vector> tmpDatas; + tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + auto missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; + spdlog::info("missingKeys {}", missingKeys); + hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); + auto status = Float2TensorVec(tmpDatas, d2h_emb); + ASSERT_EQ(status, true); + UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); + hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); + + lookupKeys = { 2, 3, 5, 6 }; + tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; + spdlog::info("missingKeys {}", missingKeys); + hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); + status = Float2TensorVec(tmpDatas, d2h_emb); + ASSERT_EQ(status, true); + UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); + hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); + + lookupKeys = { 1, 7, 9, 10 }; + tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; + spdlog::info("missingKeys {}", missingKeys); + hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); + Float2TensorVec(tmpDatas, d2h_emb); + status = Float2TensorVec(tmpDatas, d2h_emb); + ASSERT_EQ(status, true); + UpdateEmb(missingKeys, 0, embInfo.name, hostEmbs, d2h_emb); + hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); + + hybridMgmt->Destroy(); +} + +TEST_F(EmbMgmtTest, Initialize_HBM) +{ + devVocabSize = HBM_DEVICE_SIZE; + hostVocabSize = HBM_HOST_SIZE; + vector vocabsize = { devVocabSize, hostVocabSize }; + embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfos.emplace_back(embInfo); + vector thresholdValues; + thresholdValues.emplace_back(name, 1, 1); + + auto hybridMgmt = Singleton::GetInstance(); + cout << "setup..." << endl; + allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep); + hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false); + + hybridMgmt->Destroy(); +} + +TEST_F(EmbMgmtTest, Evict) +{ + size_t devVocabSize = DDR_DEVICE_SIZE; + size_t hostVocabSize = DDR_HOST_SIZE; + vector vocabsize = { devVocabSize, hostVocabSize }; + embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfos.emplace_back(embInfo); + vector thresholdValues; + thresholdValues.emplace_back(name, 1, 1); + + auto hybridMgmt = Singleton::GetInstance(); + cout << "setup..." 
<< endl;
+    allRank = RankInfo(rankId, deviceId, localRankSize, true, nBatch, maxStep);
+    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
+
+    // evict test, ddr
+    hybridMgmt->Evict();
+
+    hybridMgmt->Destroy();
+}
+
+TEST_F(EmbMgmtTest, Evict_HBM)
+{
+    devVocabSize = HBM_DEVICE_SIZE;
+    hostVocabSize = HBM_HOST_SIZE;
+    vector<size_t> vocabsize = { devVocabSize, hostVocabSize };
+    embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos);
+    embInfos.emplace_back(embInfo);
+    vector<ThresholdValue> thresholdValues;
+    thresholdValues.emplace_back(name, 1, 1);
+
+    auto hybridMgmt = Singleton<HybridMgmt>::GetInstance();
+    cout << "setup..." << endl;
+    allRank = RankInfo(rankId, deviceId, localRankSize, true, nBatch, maxStep);
+    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
+
+    // evict test, hbm
+    vector<int64_t> keys = { 1, 3, 5, 7 };
+    hybridMgmt->EvictKeys(name, keys);
+
+    hybridMgmt->Destroy();
+}
diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp
new file mode 100644
index 00000000..0d6716ae
--- /dev/null
+++ b/src/tests/emb_table/emb_table_test.cpp
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: emb table test
+ * Author: MindX SDK
+ * Date: 2023/5/6
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils/common.h"
+#include "emb_table/emb_table.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+using namespace tensorflow;
+
+class EmbTableTest : public testing::Test {
+protected:
+    void SetUp()
+    {
+        spdlog::set_level(spdlog::level::debug);
+        // Set up the EmbInfo used by the tests
+        embInfo.embeddingSize = embTable.TEST_EMB_SIZE;
+        spdlog::info("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}",
+                     embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT);
+        rankInfo.rankId = 0;
+        rankInfo.rankSize = 1;
+        rankInfo.localRankSize = 1;
+        rankInfo.useStatic = true;
+        rankInfo.localRankId = 0;
+        rankInfo.noDDR = false;
+        rankInfo.maxStep = { 1, -1 };
+        rankInfo.deviceId = 0;
+        // Initialize the embedding table
+#ifndef GTEST
+        spdlog::info("rank {} running", rankInfo.deviceId);
+        aclInit(nullptr);
+#endif
+    }
+
+    EmbTable embTable;
+    EmbInfo embInfo;
+    RankInfo rankInfo;
+    aclrtContext context;
+
+    void TearDown() {
+    }
+};
+
+// Verify that initialization completes and records the rank info
TEST_F(EmbTableTest, Init)
+{
+#ifndef GTEST
+    // Initialization must not throw
+    EXPECT_NO_THROW(embTable.Init(embInfo, rankInfo, 0));
+    spdlog::info("embTable Init succeed!");
+    ASSERT_EQ(embTable.rankInfo.rankId, rankInfo.rankId);
+    ASSERT_EQ(embTable.rankInfo.rankSize, rankInfo.rankSize);
+    ASSERT_EQ(embTable.rankInfo.localRankSize, rankInfo.localRankSize);
+    ASSERT_EQ(embTable.rankInfo.useStatic, rankInfo.useStatic);
+    ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId);
+    // Verify the initial capacity
+    spdlog::info("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
+    EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
+#endif
+}
+
+// Test behavior once the embedding list has been drained
+TEST_F(EmbTableTest, GetEmbAddressEmptyList)
+{
+#ifndef GTEST
+    embTable.Init(embInfo, rankInfo, 0);
+    while (!embTable.embeddingList.empty()) {
+        float *embAddr = reinterpret_cast<float *>(embTable.GetEmbAddress());
+        EXPECT_NE(embAddr, nullptr);
+    }
+    ASSERT_EQ(embTable.embeddingList.size(), 0);
+
+    float *curAddr = nullptr;
+    int usedCapacityBefore = embTable.usedCapacity;
+    ASSERT_NO_THROW({
+        curAddr = reinterpret_cast<float *>(embTable.GetEmbAddress());
+    });
+    EXPECT_NE(curAddr, nullptr);
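+    // Exhausting the free list must still hand out a valid address and grow usedCapacity by one.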
+    EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1);
+#endif
+}
+
+// Test the normal allocation path
+TEST_F(EmbTableTest, GetEmbAddressNormal)
+{
+#ifndef GTEST
+    embTable.Init(embInfo, rankInfo, 0);
+    ASSERT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
+    float *curAddr = nullptr;
+    int totalCapacityBefore = embTable.totalCapacity;
+    int usedCapacityBefore = embTable.usedCapacity;
+    ASSERT_NO_THROW({
+        curAddr = reinterpret_cast<float *>(embTable.GetEmbAddress());
+    });
+    EXPECT_NE(curAddr, nullptr);
+    EXPECT_EQ(embTable.totalCapacity, totalCapacityBefore);
+    EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1);
+#endif
+}
+
+// Test that returning an emb address to the embeddingList succeeds
+TEST_F(EmbTableTest, PutEmbAddress)
+{
+#ifndef GTEST
+    embTable.Init(embInfo, rankInfo, 0);
+    int64_t curAddr;
+    int usedCapacityBefore = embTable.usedCapacity;
+    ASSERT_NO_THROW({
+        curAddr = embTable.GetEmbAddress();
+    });
+    EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore + 1);
+    embTable.PutEmbAddress(curAddr);
+    EXPECT_EQ(embTable.usedCapacity, usedCapacityBefore);
+    EXPECT_EQ(curAddr, reinterpret_cast<int64_t>(embTable.embeddingList.back()));
+#endif
+}
diff --git a/src/tests/gtest_main.cpp b/src/tests/gtest_main.cpp
new file mode 100644
index 00000000..50747a2f
--- /dev/null
+++ b/src/tests/gtest_main.cpp
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: gtest main
+ * Author: MindX SDK
+ * Create: 2022
+ * History: NA
+ */
+
+#include
+#include
+
+int main(int argc, char *argv[])
+{
+    int result = 0;
+    ::testing::InitGoogleTest(&argc, argv);
+    int provided;
+    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+    result = RUN_ALL_TESTS();
+    MPI_Finalize();
+
+    return result;
+}
\ No newline at end of file
diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp
new file mode 100644
index 00000000..de74f400
--- /dev/null
+++ b/src/tests/host_emb/host_emb_test.cpp
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: host emb test + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include + +#include "host_emb/host_emb.h" +#include "tensorflow/core/framework/tensor.h" + +using namespace std; +using namespace tensorflow; +using namespace MxRec; + +TEST(HostEmb, Tensor2Float) +{ + shared_ptr>> lookups; + vector host_emb; + host_emb.resize(15); + vector> p(5, vector(3)); + host_emb[0] = 1; + host_emb[1] = 3; + std::cout << host_emb[0] << std::endl; + for (int i = 0; i < 5; i++) { + p[i].assign(host_emb.begin() + i * 3, host_emb.begin() + (i + 1) * 3); + } + std::cout << p[0][0] << std::endl; + std::cout << '5' << std::endl; + vector q; + std::cout << '0' << std::endl; + for (int i = 0; i < 2; i++) { + Tensor tmpTensor(tensorflow::DT_INT32, { 3 }); + std::cout << '1' << std::endl; + auto tmpData = tmpTensor.flat(); + std::cout << '2' << std::endl; + for (int j = 0; j < 3; j++) { + tmpData(j) = p[i][j]; + std::cout << '3' << std::endl; + } + + q.emplace_back(tmpTensor); + std::cout << '4' << std::endl; + } + std::cout << '1' << std::endl; + std::cout << q[0].flat()(0) << std::endl; + std::cout << q[0].flat()(1) << std::endl; + std::cout << q[1].flat()(0) << std::endl; + ASSERT_EQ(1, 1); +} diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp new file mode 100644 index 00000000..8f48ec66 --- /dev/null +++ b/src/tests/initializer/initializer_test.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: initializer test + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include + +#include "initializer/initializer.h" +#include "initializer/constant_initializer/constant_initializer.h" +#include "initializer/truncated_normal_initializer/truncated_normal_initializer.h" +#include "initializer/random_normal_initializer/random_normal_initializer.h" + +using namespace std; +using namespace MxRec; + +TEST(InitializerTest, ConstantInitializerTest) +{ + ConstantInitializer constant_initializer; // start; end; constant_val; + + constant_initializer = ConstantInitializer(1, 5, 7); + + vector> embData; + int vocabSize = 5; + int embeddingSize = 10; + embData.resize(vocabSize, vector(embeddingSize)); + for (int i = 0; i < vocabSize; i++) { + constant_initializer.GenerateData(embData.at(i).data(), embeddingSize); + } + + std::cout << "ConstantInitializerExample:" << std::endl; + for (int i = 0; i < vocabSize; i++) { + for (int j = 0; j < embeddingSize; j++) { + std::cout << embData[i][j] << ' '; + } + std::cout << std::endl; + } + + ASSERT_EQ(embData.at(2).at(2), 7); + ASSERT_EQ(embData.at(2).at(0), 0); +} + +TEST(InitializerTest, TruncatedNormalInitializerTest) +{ + TruncatedNormalInitializer truncatedNormalInitializer; + + truncatedNormalInitializer = TruncatedNormalInitializer(1, 10, 1.0, 0.3, 1); + + vector> embData; + int vocabSize = 5; + int embeddingSize = 11; + embData.resize(vocabSize, vector(embeddingSize)); + for (int i = 0; i < vocabSize; i++) { + truncatedNormalInitializer.GenerateData(embData.at(i).data(), embeddingSize); + } + + std::cout << "mean: " << truncatedNormalInitializer.mean << std::endl; + std::cout << "stddev: " << truncatedNormalInitializer.stddev << std::endl; + std::cout << "minBound: " << truncatedNormalInitializer.minBound << std::endl; + std::cout << "maxBound: " << truncatedNormalInitializer.maxBound << std::endl; + + std::cout << "TruncatedNormalInitializerExample:" << std::endl; + for (int i = 0; i < vocabSize; i++) 
{ + for (int j = 0; j < embeddingSize; j++) { + std::cout << embData[i][j] << ' '; + } + std::cout << std::endl; + } + + ASSERT_EQ(1, 1); +} + +TEST(InitializerTest, RandomNormalInitializerTest) +{ + RandomNormalInitializer randomNormalInitializer(1, 10, 2.0, 0.5, 1); + + vector> embData; + int vocabSize = 5; + int embeddingSize = 11; + embData.resize(vocabSize, vector(embeddingSize)); + for (int i = 0; i < vocabSize; i++) { + randomNormalInitializer.GenerateData(embData.at(i).data(), embeddingSize); + } + + std::cout << "mean: " << randomNormalInitializer.mean << std::endl; + std::cout << "stddev: " << randomNormalInitializer.stddev << std::endl; + + std::cout << "RandomNormalInitializerExample:" << std::endl; + for (int i = 0; i < vocabSize; i++) { + for (int j = 0; j < embeddingSize; j++) { + std::cout << embData[i][j] << ' '; + } + std::cout << std::endl; + } + ASSERT_EQ(1, 1); +} \ No newline at end of file diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp new file mode 100644 index 00000000..b136d984 --- /dev/null +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -0,0 +1,477 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * Description: operator module + * Author: MindX SDK + * Date: 2022/12/08 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils/common.h" +#include "key_process/feature_admit_and_evict.h" +#include "absl/container/flat_hash_map.h" + +using namespace std; +using namespace testing; +using namespace MxRec; + +enum SleepTime : uint32_t { + SLEEP_SECOND_1 = 1, + SLEEP_SECOND_2 = 2, + SLEEP_SECOND_3 = 3, + SLEEP_SECOND_4 = 4, + SLEEP_SECOND_5 = 5, + SLEEP_SECOND_6 = 6, + SLEEP_SECOND_7 = 7, + SLEEP_SECOND_8 = 8, + SLEEP_SECOND_9 = 9, + SLEEP_SECOND_10 = 10 +}; + +using HashMapInfo = absl::flat_hash_map; +struct InputArgs { + keys_t keys; + vector cnt; + keys_t expectKeys; + HashMapInfo lastHistory; + HashMapInfo expectHistory; +}; + +class FeatureAdmitAndEvictTest : public testing::Test { +protected: + HashMapInfo GetHistoryRecords(keys_t& keys, vector& cnt, time_t ts, std::string embName, + HashMapInfo& oldInfos) + { + HashMapInfo newInfos; + unordered_map mergeKeys; + for (size_t i = 0; i < keys.size(); ++i) { + auto it = mergeKeys.find(keys[i]); + if (it == mergeKeys.end()) { + mergeKeys[keys[i]] = cnt[i]; + } else { + mergeKeys[keys[i]] += cnt[i]; + } + } + + for (auto& ele : mergeKeys) { + if (ele.first == -1) { + continue ; + } + + uint32_t oldCnt = 0; + auto it = oldInfos.find(ele.first); + if (it != oldInfos.end()) { // 把原有历史记录累加 + oldCnt = it->second.count; + oldInfos.erase(ele.first); + } + + FeatureItemInfo info = {ele.first, ele.second + oldCnt, embName, ts}; + newInfos.insert(std::pair(ele.first, info)); + } + + for (auto& ele : oldInfos) { // 把原有存在、但当前不存在的,汇总 + newInfos.insert(std::pair(ele.first, ele.second)); + } + + printf("now, expect history info: \n"); + for (auto& ele : newInfos) { + printf("\t{featureId[%ld], count[%d], embName[%s], lastTime[%ld]}\n", ele.second.featureId, + ele.second.count, ele.second.tensorName.c_str(), ele.second.lastTime); + } + printf("\n"); + + return newInfos; + } + bool IsAllTheSameMap(HashMapInfo& records1, HashMapInfo& records2) + { + if (records1.empty() || records1.size() != records2.size()) { + printf("IsAllTheSameMap() 111111\n"); + return false; + } + + for (auto& ele1 : records1) { + FeatureItemInfo& info1 = ele1.second; + auto 
it = records2.find(ele1.first); + if (it == records2.end()) { + printf("IsAllTheSameMap() 222222\n"); + return false; + } + + FeatureItemInfo& info2 = records2[ele1.first]; + if (info1.featureId != info2.featureId || + info1.count != info2.count || + info1.tensorName != info2.tensorName || + info1.lastTime != info2.lastTime) { + printf("IsAllTheSameMap() 333333\n"); + return false; + } + } + + return true; + } + bool IsAllTheSameVector(keys_t& keys1, keys_t& keys2) + { + printf("\nrun ret: keys1 ===> \n\t"); + for (auto &k1 : keys1) { + printf("%ld ", k1); + } + printf("\nexpect ret: keys2 ===> \n\t"); + for (auto &k2 : keys2) { + printf("%ld ", k2); + } + printf("\n\n"); + + if (keys1.empty() || keys1.size() != keys2.size()) { + printf("IsAllTheSameVector() AAAAAA\n"); + return false; + } + + size_t loopTimes = keys1.size(); + for (size_t i = 0; i < loopTimes; ++i) { + if (keys1[i] != keys2[i]) { + printf("IsAllTheSameVector() BBBBBB\n"); + return false; + } + } + return true; + } + void FeatureAdmitCommon(FeatureAdmitAndEvict& faae, int channel, string embName, InputArgs& args) + { + time_t ts = time(nullptr); + keys_t tmpKeys = args.keys; + std::unique_ptr batch = make_unique(); + batch->name = embName; + batch->timestamp = ts; + printf("\n"); + spdlog::info("current admit embName[{}] at time[{}] ...", embName.c_str(), ts); + + // 校验调接口不出错 + ASSERT_EQ(faae.FeatureAdmit(channel, batch, args.keys, args.cnt) != + FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR, true); + + // 校验特征准入的结果 + ASSERT_EQ(IsAllTheSameVector(args.keys, args.expectKeys), true); + + // 校验历史记录表信息 + args.expectHistory = GetHistoryRecords(tmpKeys, args.cnt, ts, embName, args.lastHistory); + std::lock_guard lock(faae.m_syncMutexs); // 与 evict-thread 竞争资源 + ASSERT_EQ(IsAllTheSameMap(faae.m_recordsData.historyRecords[embName], args.expectHistory), true); + } + + void FeatureAdmitCommonMultiThr(FeatureAdmitAndEvict& faae, int channel, string embName, InputArgs& args) + { + time_t ts = time(nullptr); + keys_t tmpKeys = args.keys; + std::unique_ptr batch = make_unique(); + batch->name = embName; + batch->timestamp = ts; + printf("\n"); + spdlog::info("current admit embName[{}] at time[{}] ...", embName.c_str(), ts); + + // 校验调接口不出错 + faae.FeatureAdmit(channel, batch, args.keys, args.cnt); + } + + void TestCaseHelpMultiThr(std::string thrName) + { + printf("\t############# [%s] tid[%lu] ############# begin ...\n", + thrName.c_str(), std::hash{}(std::this_thread::get_id())); + /* + {"tensorAAA", 2, 5} + keys1 = {11, 11, 33, 44, 11, 55, 88, 55} + cnt1 = 1 2 1 3 1 1 4 1 + */ + InputArgs args1 = {keys1, cnt1, {}, initHistory, {}}; // 每个表的第一次记录,要用initHistory追加 + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args1); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); + + /* + {"tensorAAA", 2, 5} + keys2 = {11, 12, 33, 21, 11, 12} + cnt2 = 1 2 1 1 2 3 + */ + InputArgs args2 = {keys2, cnt2, {}, args1.expectHistory, {}}; + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args2); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); + + /* + {"tensorBBB", 3, 7} + keys3 = {123, 121, 121, 212, 211} + cnt3 = 1 2 1 1 2 + */ + InputArgs args3 = {keys3, cnt3, {}, initHistory, {}}; + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args3); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); + + /* + {"tensorAAA", 2, 5} + keys4 = {11, 11, 33, 44, 55, 88, 55} + cnt4 = 1 2 3 2 1 2 1 + */ + InputArgs args4 = {keys4, cnt4, 
{}, args2.expectHistory, {}}; + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args4); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); + + /* + {"tensorBBB", 3, 7} + keys5 = {125, 121, 122, 212, 211} + cnt5 = 1 2 1 3 1 + */ + InputArgs args5 = {keys5, cnt5, {}, args3.expectHistory, {}}; + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args5); + + printf("\t############# [%s] tid[%lu] ############# end ...\n", thrName.c_str(), + std::hash{}(std::this_thread::get_id())); + } + + void StartEvictThread() + { + evictThr = std::thread([&]() { + spdlog::info("Evict-thread start ..."); + + time_t currTime = 0; + time_t lastTime = 0; + while (!isExitFlag) { + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); + currTime = time(nullptr); + if (currTime - lastTime >= SleepTime::SLEEP_SECOND_4) { + spdlog::info("Evict-thread doing at currTime[{}] ...", currTime); + map> evictPosMap {}; + faae.FeatureEvict(evictPosMap); + lastTime = currTime; + } + } + spdlog::info("Evict-thread exit ..."); + }); + } + void WaitEvictThread() + { + map> evictPosMap {}; + faae.FeatureEvict(evictPosMap); // 退出前保证执行了一次“淘汰” + isExitFlag = true; + if (evictThr.joinable()) { + evictThr.join(); + } + } + + // 同时配置“准入、淘汰”阈值 && batch带时间戳,淘汰功能正常 + void TestCase1() + { + faae.ResetAllRecords(); + faae.ParseThresholdCfg(thresholds); + StartEvictThread(); + + printf("Current test single-thread is [%lu]\n", + std::hash{}(std::this_thread::get_id())); + /* + {"tensorAAA", 2, 5} + keys1 = {11, 11, 33, 44, 11, 55, 88, 55} + cnt1 = 1 2 1 3 1 1 4 1 + */ + keys_t expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}; + InputArgs args1 = {keys1, cnt1, expectRet1, initHistory, {}}; // 每个表的第一次记录,要用initHistory追加 + FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args1); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); + + /* + {"tensorAAA", 2, 5} + keys2 = {11, 12, 33, 21, 11, 12} + cnt2 = 1 2 1 1 2 3 + */ + keys_t expectRet2 = {11, 12, 33, -1, 11, 12}; + InputArgs args2 = {keys2, cnt2, expectRet2, args1.expectHistory, {}}; + FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args2); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); + + /* + {"tensorBBB", 3, 7} + keys3 = {123, 121, 121, 212, 211} + cnt3 = 1 2 1 1 2 + */ + keys_t expectRet3 = {-1, 121, 121, -1, -1}; + InputArgs args3 = {keys3, cnt3, expectRet3, initHistory, {}}; + FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args3); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); + + /* + {"tensorAAA", 2, 5} + keys4 = {11, 11, 33, 44, 55, 88, 55} + cnt4 = 1 2 3 2 1 2 1 + */ + keys_t expectRet4 = {11, 11, 33, 44, 55, 88, 55}; + InputArgs args4 = {keys4, cnt4, expectRet4, args2.expectHistory, {}}; + FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args4); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); + + /* + {"tensorBBB", 3, 7} + keys5 = {125, 121, 122, 212, 211} + cnt5 = 1 2 1 3 1 + */ + keys_t expectRet5 = {-1, 121, -1, 212, 211}; + InputArgs args5 = {keys5, cnt5, expectRet5, args3.expectHistory, {}}; + FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args5); + + WaitEvictThread(); + spdlog::info("TestCase1(): single thread test over ..."); + } + + // 进行“准入”逻辑时,若(splitKey.size() != keyCount.size()),则业务报错退出;(说明是前面all2all通信数据错误) + void TestCase2() + { + faae.ResetAllRecords(); + faae.ParseThresholdCfg(thresholds); + + // 测试点:tmpCnt.size() != tmpKeys.size() + keys_t 
tmpKeys = {11, 11, 33, 44, 11, 55, 88, 55}; + vector tmpCnt = {1, 2, 1, 3, 1, 1, 4}; + + std::unique_ptr batch = make_unique(); + batch->name = thresholds[0].tensorName; + batch->timestamp = time(nullptr); + + // 校验调接口,出错 + ASSERT_EQ(faae.FeatureAdmit(0, batch, tmpKeys, tmpCnt) == + FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR, true); + spdlog::info("TestCase2() over ..."); + } + + // 准入、淘汰阈值可单独配置;只配置“准入”阈值、却不配置“淘汰”阈值,功能正常; + // 但是,配置了“淘汰”阈值,就必须得有“准入”阈值,才功能正常; + void TestCase3() + {} + + // 传递时间戳,与不传递时,保证batch数据相同 + void TestCase4() + {} + + // 校验多线程跑的结果 + void CheckMultiThreadRet(keys_t& expectKeys, std::vector& expectCnt, const std::string& embName, + int threadCnt) + { + // 校验历史记录表信息 + unordered_map mergeKeys; + for (size_t i = 0; i < expectKeys.size(); ++i) { + mergeKeys.insert(std::pair(expectKeys[i], expectCnt[i] * threadCnt)); + } + + auto &history = faae.m_recordsData.historyRecords[embName]; + for (auto& ele : mergeKeys) { + auto it = history.find(ele.first); + ASSERT_EQ(it != history.end(), true); + ASSERT_EQ((history[ele.first].featureId == ele.first && + history[ele.first].count == ele.second), true); + } + } + static void TestMultiThread(FeatureAdmitAndEvictTest* testObj, std::string& thrName) + { + testObj->TestCaseHelpMultiThr(thrName); + } + // 多线程跑,特征准入&淘汰功能正常; + void TestCase5() + { + faae.ResetAllRecords(); + faae.ParseThresholdCfg(thresholds); + StartEvictThread(); + + std::thread thrs[KEY_PROCESS_THREAD]; + // 测试多线程的 + for (int i = 0; i < KEY_PROCESS_THREAD; ++i) { + std::string name("thread-"); + name += std::to_string(i); + thrs[i] = std::thread(TestMultiThread, this, std::ref(name)); + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); + } + + for (int i = 0; i < KEY_PROCESS_THREAD; ++i) { + if (thrs[i].joinable()) { + thrs[i].join(); + } + } + + std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_8)); + { + /* + 如果没有淘汰功能 + tensorAAA数据将会是 {11, 12, 21, 33, 44, 55, 88} + 10 5 1 5 5 4 6 + tensorBBB数据将会是 {121, 122, 123, 125, 211, 212}; + 5 1 1 1 3 4 + */ + keys_t expectKeys1 = {11, 33, 44, 55, 88}; // 12,21被淘汰掉了 + vector expectCnt1 = {10, 5, 5, 4, 6}; + keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123被淘汰掉了 + vector expectCnt2 = {5, 1, 1, 3, 4}; + std::lock_guard lock(faae.m_syncMutexs); // 与 evict-thread 竞争资源 + CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, KEY_PROCESS_THREAD); + CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, KEY_PROCESS_THREAD); + } + + WaitEvictThread(); + spdlog::info("TestCase5(): multi thread test over ..."); + } + + // 同时不配置“准入、淘汰”阈值,特征准入&淘汰功能“不支持”; + void TestCase6() + { + faae.ResetAllRecords(); + faae.ParseThresholdCfg(thresholds); + + std::unique_ptr batch = make_unique(); + // 测试点:tensorDDD表没有配置阈值,则不支持 + batch->name = std::string("tensorDDD"); + batch->timestamp = time(nullptr); + + // 校验调接口,不支持 + spdlog::info("TestCase6() over ..."); + } + + bool isExitFlag { false }; + HashMapInfo initHistory; + FeatureAdmitAndEvict faae; + std::thread evictThr; + keys_t keys1 = {11, 11, 33, 44, 11, 55, 88, 55}; + vector cnt1 = {1, 2, 1, 3, 1, 1, 4, 1}; + keys_t keys2 = {11, 12, 33, 21, 11, 12}; + vector cnt2 = {1, 2, 1, 1, 2, 3}; + keys_t keys3 = {123, 121, 121, 212, 211}; + vector cnt3 = {1, 2, 1, 1, 2}; + keys_t keys4 = {11, 11, 33, 44, 55, 88, 55}; + vector cnt4 = {1, 2, 3, 2, 1, 2, 1}; + keys_t keys5 = {125, 121, 122, 212, 211}; + vector cnt5 = {1, 2, 1, 3, 1}; + std::vector thresholds = {{"tensorAAA", 2, 5}, {"tensorBBB", 3, 7}, {"tensorCCC", 5, 
9}}; +}; + +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict1) +{ + TestCase1(); +} +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict2) +{ + TestCase2(); +} +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict3) +{ + TestCase3(); +} +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict4) +{ + TestCase4(); +} +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict5) +{ + TestCase5(); +} +TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict6) +{ + TestCase6(); +} \ No newline at end of file diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp new file mode 100644 index 00000000..c001bb15 --- /dev/null +++ b/src/tests/key_process/key_process_test.cpp @@ -0,0 +1,373 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: key process test + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "utils/common.h" +#include "host_emb/host_emb.h" +#include "key_process/key_process.h" +#include "emb_mgmt/emb_mgmt.h" + +using namespace std; +using namespace MxRec; +using namespace testing; + +static constexpr size_t BATCH_NUM_EACH_THREAD = 5; + +class KeyProcessTest : public testing::Test { +protected: + void SetUp() + { + spdlog::set_level(spdlog::level::debug); + int claimed; + MPI_Query_thread(&claimed); + ASSERT_EQ(claimed, MPI_THREAD_MULTIPLE); + MPI_Comm_rank(MPI_COMM_WORLD, &worldRank); + MPI_Comm_size(MPI_COMM_WORLD, &worldSize); + spdlog::info(KEY_PROCESS "wordRank: {}, worldSize: {}", worldRank, worldSize); + // 初始化rank信息 + rankInfo.rankId = worldRank; + rankInfo.rankSize = worldSize; + rankInfo.localRankSize = worldSize; + rankInfo.useStatic = useStatic; + rankInfo.localRankId = rankInfo.rankId % rankInfo.localRankSize; + rankInfo.noDDR = false; + rankInfo.maxStep = { 1, -1 }; + // 初始化emb信息 + GenEmbInfos(embNum, embInfos, fieldNums); + splits = fieldNums; + } + + vector> PrepareBatch() + { + vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM); + // 向共享队列中写入本进程所有线程要处理的 KEY_PROCESS_THREAD * BATCH_NUM_EACH_THREAD 个batch数据 + for (size_t threadId = 0; threadId < KEY_PROCESS_THREAD; ++threadId) { + int batchQueueId = threadId + KEY_PROCESS_THREAD * channel; + unsigned int seed = batchQueueId * 10; + auto queue = SingletonQueue::getInstances(batchQueueId); + + for (size_t batchNum = 0; batchNum < BATCH_NUM_EACH_THREAD; ++batchNum) { + size_t batchId = + batchNum * KEY_PROCESS_THREAD + threadId; + + for (size_t i = 0; i < embInfos.size(); i++) { // key按照不同emb表的存储切分开 + auto batch = queue->GetOne(); + batch->sample.resize(batchSize * fieldNums[i]); + GenData(batch->sample, 0, seed++); + batch->name = embInfos[i].name; + batch->batchId = batchId; + batch->channel = channel; + spdlog::debug("[{}/{}]" + KEY_PROCESS "PrepareBatch: batchQueueId: {}, {}[{}]{}, sampleSize:{}", worldRank, worldSize, + batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size()); + emb_batch_t temp; + temp.sample = batch->sample; + temp.name = batch->name; + temp.batchId = batch->batchId; + temp.channel = batch->channel; + result[batchQueueId].push_back(temp); + queue->Pushv(std::move(batch)); + } + } + } + return result; + } + + // 生成随机数 + template + void GenData(vector& totBatchData, int start, unsigned int seed = 0) + { + default_random_engine generator { seed }; + uniform_int_distribution distribution(start, randMax); + for (size_t i = 0; i < totBatchData.size(); ++i) { + totBatchData[i] = distribution(generator); + } + } + 
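+    // Count2Start below is an exclusive prefix sum: it turns per-rank element counts into start offsets for the send buffers.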
+ template + inline vector Count2Start(const vector& count) + { + vector start = { 0 }; + for (size_t i = 0; i < count.size() - 1; ++i) { + start.push_back(count[i] + start.back()); + } + return start; + } + + // 生成emb表信息 + bool GenEmbInfos(size_t embNums, vector& allEmbInfos, vector& geFieldNums) + { + default_random_engine generator; + uniform_int_distribution distribution(randMin, randMax); + int embSizeMin = 5, embSizeMax = 8, base = 2; + uniform_int_distribution embSizeDistribution(embSizeMin, embSizeMax); + stringstream ss; + for (unsigned int i = 0; i < embNums; ++i) { + EmbInfo temp; + ss << i; + temp.name = "emb" + ss.str(); + ss.str(""); + ss.clear(); + temp.sendCount = distribution(generator); + temp.embeddingSize = pow(base, embSizeDistribution(generator)); + geFieldNums.push_back(sampleSize); + allEmbInfos.push_back(move(temp)); + } + return true; + } + + auto GetSplitAndRestore(keys_t& sample) -> tuple, vector> + { + vector expectSplitKeys(worldSize); + vector expectRestore(sample.size()); + absl::flat_hash_map uKey; + for (unsigned int i = 0; i < sample.size(); ++i) { + int devId = sample[i] % worldSize; + auto result = uKey.find(sample[i]); + if (result == uKey.end()) { + expectSplitKeys[devId].push_back(sample[i]); + uKey.insert(make_pair(sample[i], expectSplitKeys[devId].size() - 1)); + expectRestore[i] = expectSplitKeys[devId].size() - 1; + } else { + expectRestore[i] = result->second; + } + } + return { expectSplitKeys, expectRestore }; + } + + void PrintHotHashSplit(const vector& splitKeys, + const vector& restore, + const vector& hotPos, int rankSize) + { + for (int i = 0; i < rankSize; ++i) { + std::cout << "splitKeys dev" << i << std::endl; + spdlog::info("{}", splitKeys[i]); + } + std::cout << "restore" << std::endl; + spdlog::info("{}", restore); + std::cout << "hotPos" << std::endl; + spdlog::info("{}", hotPos); + } + + void GetExpectRestore(keys_t& sample, vector& blockOffset, vector& restoreVec) + { + for (unsigned int i = 0; i < sample.size(); ++i) { + int devId = sample[i] % worldSize; + restoreVec[i] += blockOffset[devId]; + } + } + + enum class A2A { + SC, SS, RC, RS, INVALID + }; + + RankInfo rankInfo; + int worldRank {}; + int worldSize {}; + vector splits; + int sampleSize = 20; + int channel = 0; + int randMin = 10; + int randMax = 25; // 最大随机数范围 + // RankInfo rankInfo + int batchSize = 5; + int localRankSize = 2; + bool useStatic = true; + int staticSendCount = 65536; + + int maxRankSize = 8; + + // vector embInfos + int embNum = 1; + vector fieldNums; + + vector src; + vector allRankInfo; + vector embInfos; + unique_ptr batchData; + vector splitKeys; + vector restore; + KeyProcess process; + + void TearDown() + { + // delete + } +}; + +TEST_F(KeyProcessTest, Initialize) +{ + ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + ASSERT_EQ(process.isRunning, true); + ASSERT_EQ(process.rankInfo.rankId, rankInfo.rankId); + ASSERT_EQ(process.rankInfo.rankSize, rankInfo.rankSize); + ASSERT_EQ(process.rankInfo.localRankSize, rankInfo.localRankSize); + ASSERT_EQ(process.rankInfo.useStatic, rankInfo.useStatic); + ASSERT_EQ(process.rankInfo.localRankId, rankInfo.localRankId); + ASSERT_EQ(process.embInfos.size(), embInfos.size()); + for (const EmbInfo& info: embInfos) { + ASSERT_NE(process.embInfos.find(info.name), process.embInfos.end()); + } +} + +TEST_F(KeyProcessTest, Start) +{ + ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + ASSERT_EQ(process.isRunning, true); + ASSERT_EQ(process.Start(), 0); + process.Destroy(); +} + +TEST_F(KeyProcessTest, 
+TEST_F(KeyProcessTest, HashSplit)
+{
+    int rankSize = 4;
+    auto queue = SingletonQueue::getInstances(0);
+    auto batch = queue->GetOne();
+    keys_t batchKeys = { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 };
+    vector<int> expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 };
+    vector<keys_t> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } };
+    batch->sample = std::move(batchKeys);
+    spdlog::debug(KEY_PROCESS "batch sample: {}", batch->sample);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.isRunning, true);
+    process.rankInfo.rankSize = rankSize;
+    auto [splitKeys, restore] = process.HashSplit(batch);
+    for (unsigned int i = 0; i < splitKeys.size(); ++i) {
+        ASSERT_THAT(splitKeys[i], ElementsAreArray(expectSplitKeys[i]));
+    }
+    ASSERT_THAT(restore, ElementsAreArray(expectRestore));
+}
+
+TEST_F(KeyProcessTest, GetScAll)
+{
+    vector<int> keyScLocal(worldSize, worldRank + 1); // initialize the per-rank send counts with worldRank + 1
+    spdlog::debug(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, keyScLocal);
+    vector<int> expectScAll(worldSize * worldSize);
+    for (unsigned int i = 0; i < expectScAll.size(); ++i) {
+        expectScAll[i] = floor(i / worldSize) + 1;
+    }
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.isRunning, true);
+    auto scAll = process.GetScAll(keyScLocal, 0, 0);
+    ASSERT_THAT(scAll, ElementsAreArray(expectScAll));
+}
+
+TEST_F(KeyProcessTest, BuildRestoreVec_4cpu)
+{
+    auto queue = SingletonQueue::getInstances(0);
+    auto batch = queue->GetOne();
+    vector<keys_t> allBatchKeys = { { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 },
+                                    { 5, 17, 26, 9, 27, 22, 27, 28, 15, 3 },
+                                    { 10, 4, 22, 17, 24, 13, 24, 26, 29, 11 },
+                                    { 14, 21, 18, 25, 21, 4, 20, 24, 13, 19 } };
+    vector<vector<int>> allExpectSs = { { 0, 2, 5, 7, 9 }, { 0, 1, 4, 6 }, { 0, 2, 5, 8 }, { 0, 3, 6, 8 } };
+    vector<vector<int>> allExpectRestore = { { 2, 0, 7, 5, 1, 8, 6, 3, 3, 4 },
+                                             { 1, 2, 4, 3, 6, 5, 6, 0, 7, 8 },
+                                             { 5, 0, 6, 2, 1, 3, 1, 7, 4, 8 },
+                                             { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } };
+    batch->sample = std::move(allBatchKeys[worldRank]);
+    spdlog::info(KEY_PROCESS "test BuildRestoreVec: rank {}, batchKeys {}", worldRank, batch->sample);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.isRunning, true);
+    auto [splitKeys, restore] = process.HashSplit(batch);
+    spdlog::debug("rank: {} splitKeys: {}", worldRank, splitKeys);
+    process.BuildRestoreVec(batch, allExpectSs[worldRank], restore);
+    ASSERT_THAT(restore, ElementsAreArray(allExpectRestore[worldRank]));
+}
+
+TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt)
+{
+    PrepareBatch();
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // log the number of CPU cores
+
+    auto fn = [this](int channel, int id) {
+        auto embName = embInfos[0].name;
+        process.hotEmbTotCount[embName] = 10;
+        vector<keys_t> splitKeys;
+        vector<int> restore;
+        vector<int> hotPos;
+        unique_ptr<emb_batch_t> batch;
+        batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue
+        spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId);
+        tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch);
+        spdlog::info("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId,
+            hotPos);
+    }; // for clean code
+    for (int channel = 0; channel < 1; ++channel) {
+        for (int id = 0; id < KEY_PROCESS_THREAD; ++id) {
+            process.procThread.emplace_back(fn, channel, id); // initialize each worker thread with the lambda
+        }
+    }
+    this_thread::sleep_for(20s);
+    process.Destroy();
+}
+
+TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt)
+{
+    PrepareBatch();
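+    // [Editor's note] End-to-end sketch of the rebuilt path exercised below: HotHashSplit
+    // partitions a batch into per-rank key lists plus hot-key positions, ProcessSplitKeys
+    // exchanges the counts/keys across ranks (yielding lookupKeys, scAll and ss), and
+    // BuildRestoreVec then rebases the restore indices so the gathered embeddings can be
+    // scattered back into the original sample order.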
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // log the number of CPU cores
+
+    auto fn = [this](int channel, int id) {
+        auto embName = embInfos[0].name;
+        vector<keys_t> splitKeys;
+        vector<int> restore;
+        vector<int> hotPos;
+        unique_ptr<emb_batch_t> batch;
+        batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue
+        spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId);
+        tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch);
+        auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys);
+        process.BuildRestoreVec(batch, ss, restore, hotPos.size());
+        spdlog::info("rankid :{},batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", rankInfo.rankId,
+            batch->batchId, lookupKeys, scAll, restore);
+    }; // for clean code
+    for (int channel = 0; channel < 1; ++channel) {
+        for (int id = 0; id < KEY_PROCESS_THREAD; ++id) {
+            process.procThread.emplace_back(fn, channel, id); // initialize each worker thread with the lambda
+        }
+    }
+    this_thread::sleep_for(20s);
+    process.Destroy();
+}
+
+TEST_F(KeyProcessTest, Key2Offset)
+{
+    keys_t lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 };
+    keys_t expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 };
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.isRunning, true);
+    process.Key2Offset(emb_name_t(), lookupKeys);
+    spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap);
+    ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset));
+}
+// automated test case
+// covers boundary values and key duplication
+TEST_F(KeyProcessTest, ProcessPrefetchTask)
+{
+    PrepareBatch();
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    process.rankInfo.rankSize = worldSize;
+    process.rankInfo.localRankId = process.rankInfo.rankId % process.rankInfo.localRankSize;
+    ASSERT_EQ(process.isRunning, true);
+    ASSERT_EQ(process.Start(), 0);
+    // invoked once every thread has finished processing (i.e. training is done)
+    this_thread::sleep_for(5s);
+    spdlog::info("wait 20s for thread running");
+    this_thread::sleep_for(20s);
+    process.Destroy();
+}
diff --git a/tools/python/key_2_emb_formatter.py b/tools/python/key_2_emb_formatter.py
new file mode 100644
index 00000000..6b2045ca
--- /dev/null
+++ b/tools/python/key_2_emb_formatter.py
@@ -0,0 +1,202 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
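+# [Editor's note] Typical invocation, with a hypothetical checkpoint directory (the flags
+# are defined in the argparse setup below):
+#     python key_2_emb_formatter.py --path ./saved_ckpt --name key_2_embedding --step 0
+# Caveat: argparse's type=bool treats any non-empty string (even "False") as True, so only
+# pass --ddr when the checkpoint really was saved in DDR mode.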
+# Description: +# Author: MindX SDK +# Create: 2023-01-29 + +import argparse +import json +import logging +import os +import re +import numpy as np + + +parser = argparse.ArgumentParser() +parser.add_argument('--path', type=str, required=True, help='path of the root dir of saved file') +parser.add_argument('--name', type=str, default="key_2_embedding", help='name of output file') +parser.add_argument('--ddr', type=bool, default=False, help='if saved data was from ddr mode, default False') +parser.add_argument('--step', type=int, default=0, help='the step when the data was saved, default 0') + + +def get_verified_path(path): + real_path = os.path.realpath(path) + if os.path.exists(real_path): + return real_path + else: + raise NotADirectoryError(f"{path} is not a valid directory") + + +def get_valid_file_name(name): + invalid_symbols = r"[\/\\\:\*\?\"\<\>\|]" + valid_name = re.sub(invalid_symbols, "_", name) + return valid_name + + +class Formatter: + def __init__(self, saved_file_path, out_file_name, is_ddr_mode, step): + self._device_dir_list = ["HashTable", "HBM"] + self._host_dir_list = ["HashTable", "DDR"] + self._device_emb_dir = "embedding" + self._host_emb_dir = "embedding_data" + self._device_hashmap_dir = "key_offset_map" + self._host_hashmap_dir = "embedding_hashmap" + self._attrib_suffix = ".attribute" + self._data_suffix = ".data" + self._out_file_suffix = ".npy" + + self._saved_file_path = get_verified_path(saved_file_path) + self._out_file_name = get_valid_file_name(out_file_name) + self._sub_dirs = self._get_sub_dirs(step) + self._table_names = None + + self._json_attrib_dtype = "data_type" + self._json_attrib_shape = "shape" + self._host_attrib_dtype = np.uint64 + self._hashmap_dtype = np.uint32 + self._raw_key_dtype = np.uint64 + self._key_dtype = np.int64 + self._raw_key_offset = np.iinfo(np.uint32).max + self._data_dtype = None + + self._is_ddr_mode = is_ddr_mode + + def process(self): + dev_dir = self._set_upper_dir(self._sub_dirs[0], self._device_dir_list) + self._table_names = self._get_table_names(dev_dir) + + transformed_data = [] + for table_name in self._table_names: + combined_key = None + combined_emb = None + for sub_dir in self._sub_dirs: + dev_dir = self._set_upper_dir(sub_dir, self._device_dir_list) + host_dir = self._set_upper_dir(sub_dir, self._host_dir_list) + emb_data = self._data_process(dev_dir, host_dir, table_name) + key, offset = self._hashmap_process(dev_dir, host_dir, table_name) + emb_data = emb_data[offset] + + if combined_key is not None: + combined_key = np.append(combined_key, key, axis=0) + else: + combined_key = key + if combined_emb is not None: + combined_emb = np.append(combined_emb, emb_data, axis=0) + else: + combined_emb = emb_data + + logging.debug(f"{table_name} has combined key {combined_key.shape}" + f" and combined emb {combined_emb.shape}") + + transformed_data.append(table_name) + transformed_data.append(combined_key) + transformed_data.append(combined_emb) + + np.save("./" + self._out_file_name + self._out_file_suffix, transformed_data) + + def _data_process(self, dev_dir, host_dir, table_name): + dev_emb_dir = os.path.join(dev_dir, table_name, self._device_emb_dir) + host_emb_dir = os.path.join(host_dir, table_name, self._host_emb_dir) + + data_file, attribute_file = self._get_file_names(dev_emb_dir) + dev_attribute = self._get_attribute(dev_emb_dir, attribute_file, is_json=True) + if not self._data_dtype: + self._data_dtype = dev_attribute.pop(self._json_attrib_dtype) + + dev_data_shape = 
dev_attribute.pop(self._json_attrib_shape) + emb_data = self._get_data(dev_emb_dir, data_file, self._data_dtype, dev_data_shape) + + if self._is_ddr_mode: + data_file, attribute_file = self._get_file_names(host_emb_dir) + host_attribute = self._get_attribute(host_emb_dir, attribute_file, is_json=False) + host_data_shape = [host_attribute[0], host_attribute[1]] + host_data = self._get_data(host_emb_dir, data_file, self._data_dtype, host_data_shape) + host_data = host_data[:, :dev_data_shape[1]] + emb_data = np.append(emb_data, host_data, axis=0) + + return emb_data + + def _hashmap_process(self, dev_dir, host_dir, table_name): + dev_hashmap_dir = os.path.join(dev_dir, table_name, self._device_hashmap_dir) + host_hashmap_dir = os.path.join(host_dir, table_name, self._host_hashmap_dir) + + if self._is_ddr_mode: + data_file, attribute_file = self._get_file_names(host_hashmap_dir) + else: + data_file, attribute_file = self._get_file_names(dev_hashmap_dir) + + attribute = self._get_attribute(host_hashmap_dir, attribute_file, is_json=False) + data_shape = attribute[:2] + raw_hashmap = self._get_data(host_hashmap_dir, data_file, self._hashmap_dtype, data_shape) + offset = raw_hashmap[:, -1] + raw_key = raw_hashmap[:, :2].astype(self._raw_key_dtype) + key = raw_key[:, 0] * self._raw_key_offset + raw_key[:, 1] + key = key.astype(self._key_dtype) + + return key, offset + + def _get_sub_dirs(self, step): + sub_dirs = [] + for _, sub_dir, _ in os.walk(self._saved_file_path): + sub_dirs.append(sub_dir) + + if not sub_dirs or not sub_dirs[0]: + raise FileNotFoundError(f"There is no sparse checkpoint for given root directory.") + + picked_sub_dirs = [] + for sub_dir in sub_dirs[0]: + if int(sub_dir.split("-")[-1]) == step: + picked_sub_dirs.append(sub_dir) + + if not picked_sub_dirs: + raise FileNotFoundError(f"There is no sparse checkpoint for given training step {step}.") + return picked_sub_dirs + + def _set_upper_dir(self, sub_dir, dir_list): + temp_dir = os.path.join(self._saved_file_path, sub_dir) + for directory in dir_list: + temp_dir = os.path.join(temp_dir, directory) + return temp_dir + + def _get_table_names(self, directory): + if os.path.exists(directory): + table_names = [] + for _, table_name, _ in os.walk(directory): + table_names.append(table_name) + return table_names[0] + else: + raise ValueError("given directory does not contain required subdirectories, cannot search for table names") + + def _get_file_names(self, directory): + files = [] + data_file = None + attribute_file = None + for _, _, file in os.walk(directory): + files.append(file) + for file in files[0]: + if file.find(self._data_suffix) != -1: + data_file = file + elif file.find(self._attrib_suffix) != -1: + attribute_file = file + return data_file, attribute_file + + def _get_attribute(self, directory, file_name, is_json): + file_dir = os.path.join(directory, file_name) + if is_json: + with open(file_dir, "r") as fin: + attributes = json.load(fin) + return attributes + else: + attributes = np.fromfile(file_dir, self._host_attrib_dtype) + return attributes + + def _get_data(self, directory, file_name, dtype, shape): + file_dir = os.path.join(directory, file_name) + data = np.fromfile(file_dir, dtype=dtype) + data = data.reshape(shape) + return data + + +if __name__ == "__main__": + args = parser.parse_args() + formatter = Formatter(saved_file_path=args.path, out_file_name=args.name, is_ddr_mode=args.ddr, step=args.step) + formatter.process() \ No newline at end of file -- Gitee From 84ff1133708fac22471f8fc035b2213aafd8f01e 
Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 16 May 2023 17:23:58 +0800 Subject: [PATCH 004/551] Match-id-c3cf22baece96a7eae44e344fa1aec3904e65b94 --- src/core/checkpoint/checkpoint.cpp | 110 ++++++++++++++---- src/core/checkpoint/checkpoint.h | 6 +- .../ckpt_data_handler/ckpt_data_handler.cpp | 6 + .../ckpt_data_handler/ckpt_data_handler.h | 3 + .../host_emb_ckpt/host_emb_ckpt.cpp | 80 ++++++++----- .../host_emb_ckpt/host_emb_ckpt.h | 11 +- src/core/emb_hashmap/emb_hashmap.cpp | 2 +- src/core/emb_mgmt/emb_mgmt.cpp | 12 +- src/core/host_emb/host_emb.cpp | 20 ++-- src/core/host_emb/host_emb.h | 2 +- src/core/utils/common.h | 4 +- src/tests/checkpoint/checkpoint_test.cpp | 46 ++++++-- 12 files changed, 213 insertions(+), 89 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 549d8974..7b64bb9b 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -61,7 +61,7 @@ void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRa void Checkpoint::SetDataHandler(CkptData& ckptData) { dataHandlers.clear(); - if (!ckptData.hostEmbs.empty()) { + if (ckptData.hostEmbs != nullptr) { dataHandlers.push_back(make_unique()); } if (!ckptData.embHashMaps.empty()) { @@ -265,24 +265,35 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si ofstream writeFile; writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); - if (writeFile.is_open()) { + if (!writeFile.is_open()) { + spdlog::debug("unable to open save file: {}", dataDir); + } + + int loops = 1; + if (dataType == CkptDataType::EMB_DATA) { + loops = transData.floatArr.size(); + } + for (int i = 0; i < loops; i++) { size_t idx = 0; size_t writeSize = 0; - while (dataSize != 0) { - if (dataSize > oneTimeReadWriteLen) { + size_t dataCol = dataSize; + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { writeSize = oneTimeReadWriteLen; } else { - writeSize = dataSize; + writeSize = dataCol; + } + if (floatTransSet.find(dataType) != floatTransSet.end()) { + writeFile.write((const char*)(transData.floatArr[i]) + idx, writeSize); + } else { + WriteDataset(transData, writeFile, writeSize, dataType, idx); } - WriteDataset(transData, writeFile, writeSize, dataType, idx); - - dataSize -= writeSize; + dataCol -= writeSize; idx += writeSize; } - } else { - spdlog::debug("unable to open save file: {}", dataDir); } + writeFile.close(); } @@ -296,8 +307,6 @@ void Checkpoint::WriteDataset(CkptTransData& transData, writeFile.write((const char*)(transData.int32Arr.data()) + idx, writeSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { writeFile.write((const char*)(transData.int64Arr.data()) + idx, writeSize); - } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - writeFile.write((const char*)(transData.floatArr.data()) + idx, writeSize); } else if (dataType == CkptDataType::ATTRIBUTE) { writeFile.write((const char*)(transData.attribute.data()) + idx, writeSize); } @@ -313,7 +322,7 @@ void Checkpoint::LoadProcess(CkptData& ckptData) GetUpperLayerLoadDir(dirNames); embNames = GetTableLayerLoadDir(); - LoadDataset(embNames, saveDataTypes, dataHandler); + LoadDataset(embNames, saveDataTypes, dataHandler, ckptData); dataHandler->GetProcessData(ckptData); } @@ -351,7 +360,8 @@ vector Checkpoint::GetTableLayerLoadDir() void Checkpoint::LoadDataset(const vector& embNames, const vector& saveDataTypes, - const unique_ptr& dataHandler) + const unique_ptr& dataHandler, + 
CkptData& ckptData) { for (const auto& embName : embNames) { auto dataDir { innerDirPath + dirSeparator + embName }; @@ -364,15 +374,20 @@ void Checkpoint::LoadDataset(const vector& embNames, auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType }; CkptTransData transData; - auto dataElmtBytes { dataHandler->GetDataElmtBytes(saveDataType) }; - - spdlog::debug("====Start reading data from: {}", datasetDir); - ReadStream(transData, datasetDir, saveDataType, dataElmtBytes); spdlog::debug("====Start reading data from: {}", attributeDir); - dataElmtBytes = dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE); + auto dataElmtBytes { dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE) }; ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes); + dataElmtBytes = dataHandler->GetDataElmtBytes(saveDataType); + if (saveDataType == CkptDataType::EMB_DATA) { + ReadStreamForEmbData(transData, datasetDir, dataElmtBytes, ckptData, embName); + continue; + } else { + spdlog::debug("====Start reading data from: {}", datasetDir); + ReadStream(transData, datasetDir, saveDataType, dataElmtBytes); + } + // load embedding when use dynamic expansion is open if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { auto embedPath { dataDir + dirSeparator + "key_embedding" }; @@ -382,7 +397,11 @@ void Checkpoint::LoadDataset(const vector& embNames, } spdlog::debug("====Start loading data from: {} to data handler.", attributeDir); - dataHandler->SetDataset(saveDataType, embName, transData); + if ((saveDataType == CkptDataType::EMB_INFO)) { + dataHandler->SetDatasetForLoadEmb(saveDataType, embName, transData, ckptData); + } else { + dataHandler->SetDataset(saveDataType, embName, transData); + } } } } @@ -428,6 +447,57 @@ void Checkpoint::ReadStream(CkptTransData& transData, readFile.close(); } +void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, + const string& dataDir, + uint32_t dataElmtBytes, + CkptData& ckptData, + string embName) +{ + if (dataElmtBytes == 0) { + spdlog::error("dataElmtBytes is 0, don't handle [/ %] operation"); + return ; + } + + auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); + + auto loadHostEmbs = ckptData.hostEmbs; + auto& dst = (*loadHostEmbs)[embName].embData; + dst.reserve(embDataOuterSize); + + std::ifstream readFile; + readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); + + size_t datasetSize = readFile.tellg(); + readFile.seekg(0, std::ios::beg); + + if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { + spdlog::debug("data is missing or incomplete in load file: {}", dataDir); + } + auto onceReadByteSize { datasetSize / embDataOuterSize }; + + if (!readFile.is_open()) { + spdlog::debug("unable to open load file: {}", dataDir); + } + for (size_t i = 0; i < embDataOuterSize; ++i) { + size_t idx = 0; + size_t readSize = 0; + size_t dataCol = onceReadByteSize; + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { + readSize = oneTimeReadWriteLen; + } else { + readSize = dataCol; + } + + readFile.read((char*)(dst[i].data()) + idx, readSize); + + dataCol -= readSize; + idx += readSize; + } + } + readFile.close(); +} + void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType) { if (int32TransSet.find(dataType) != int32TransSet.end()) { diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index c1e7d9bc..548517fb 100644 --- 
a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -64,6 +64,8 @@ namespace MxRec { vector mgmtEmbInfo; const int embHashNum { 2 }; + const int attribEmbDataOuterIdx { 0 }; + const int attribEmbDataInnerIdx { 1 }; void SetDataHandler(CkptData& ckptData); void SetDataHandler(const vector& featureTypes); @@ -87,8 +89,10 @@ namespace MxRec { void GetUpperLayerLoadDir(const vector& dirNames); vector GetTableLayerLoadDir(); void LoadDataset(const vector& embNames, const vector& saveDataTypes, - const unique_ptr& dataHandler); + const unique_ptr& dataHandler, CkptData& ckptData); void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); + void ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, + CkptData& ckptData, string embName); void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, size_t idx); diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index abe985a7..6273fa66 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -29,4 +29,10 @@ void CkptDataHandler::CleanTransfer() transferData.attribute.clear(); transferData.datasetSize = 0; transferData.attributeSize = 0; +} + +void CkptDataHandler::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, + CkptData& ckptData) +{ + throw std::runtime_error("Wrong CkptDataType, only EMB_INFO and EMB_DATA supported for load host emb"); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index 87eb9b6a..ecc7907c 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -35,6 +35,9 @@ namespace MxRec { virtual void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) = 0; + virtual void SetDatasetForLoadEmb( + CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData); + protected: const vector dataDirNames { "embedding_info", diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index c4dba3a5..791e330e 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -15,16 +15,15 @@ using namespace MxRec; void HostEmbCkpt::SetProcessData(CkptData& processData) { - saveHostEmbs.clear(); - loadHostEmbs.clear(); - saveHostEmbs = std::move(processData.hostEmbs); + saveHostEmbs = nullptr; + loadHostEmbs = nullptr; + saveHostEmbs = processData.hostEmbs; } void HostEmbCkpt::GetProcessData(CkptData& processData) { - processData.hostEmbs = std::move(loadHostEmbs); - saveHostEmbs.clear(); - loadHostEmbs.clear(); + saveHostEmbs = nullptr; + loadHostEmbs = nullptr; } vector HostEmbCkpt::GetDataTypes() @@ -40,12 +39,13 @@ vector HostEmbCkpt::GetDirNames() vector HostEmbCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveHostEmbs) { + for (const auto& item : *saveHostEmbs) { embNames.push_back(item.first); } return embNames; } +// save info and data CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) { map> dataTransMap { { CkptDataType::EMB_INFO, [=] { 
SetEmbInfoTrans(embName); } }, @@ -58,19 +58,28 @@ CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::EMB_INFO, [=] { SetEmbInfo(embName); } }, - { CkptDataType::EMB_DATA, [=] { SetEmbData(embName); } } }; + return; +} + +// load info and data +void HostEmbCkpt::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, + CkptData& ckptData) +{ + map> dataLoadMap { + { CkptDataType::EMB_INFO, [&] { SetEmbInfo(embName, ckptData); } }, + { CkptDataType::EMB_DATA, [&] { SetEmbData(embName, ckptData); } } }; CleanTransfer(); transferData = move(loadedData); dataLoadMap.at(dataType)(); } +// save Emb info void HostEmbCkpt::SetEmbInfoTrans(string embName) { auto embInfoSize = GetEmbInfoSize(); auto& transArr = transferData.int32Arr; - const auto& hostEmbInfo = saveHostEmbs.at(embName).hostEmbInfo; + const auto& hostEmbInfo = saveHostEmbs->at(embName).hostEmbInfo; transArr.reserve(embInfoSize); transArr.push_back(hostEmbInfo.sendCount); @@ -79,18 +88,21 @@ void HostEmbCkpt::SetEmbInfoTrans(string embName) transArr.push_back(static_cast(hostEmbInfo.hostVocabSize)); } +// save Emb data void HostEmbCkpt::SetEmbDataTrans(string embName) { - auto embDataSize = GetEmbDataSize(embName); - transferData.floatArr.reserve(embDataSize); - for (const auto& item : saveHostEmbs.at(embName).embData) { - transferData.floatArr.insert(transferData.floatArr.end(), item.begin(), item.end()); + auto embDataRows = GetEmbDataRows(embName); + transferData.floatArr.reserve(embDataRows); + for (auto& item : saveHostEmbs->at(embName).embData) { + transferData.floatArr.push_back(&item[0]); } } -void HostEmbCkpt::SetEmbInfo(string embName) +// load Emb info +void HostEmbCkpt::SetEmbInfo(string embName, CkptData& ckptData) { - auto& hostEmbInfo = loadHostEmbs[embName].hostEmbInfo; + loadHostEmbs = ckptData.hostEmbs; + auto& hostEmbInfo = (*loadHostEmbs)[embName].hostEmbInfo; const auto& transArr = transferData.int32Arr; hostEmbInfo.name = embName; @@ -100,20 +112,10 @@ void HostEmbCkpt::SetEmbInfo(string embName) hostEmbInfo.hostVocabSize = static_cast(transArr.at(attribEmbInfoHostVocabIdx)); } -void HostEmbCkpt::SetEmbData(string embName) +// load Emb data +void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) { - vector embValues; - auto embDataOuterSize = transferData.attribute.at(attribEmbDataOuterIdx); - auto embDataInnerSize = transferData.attribute.at(attribEmbDataInnerIdx); - auto rawBegin = transferData.floatArr.begin(); - loadHostEmbs[embName].embData.reserve(embDataOuterSize); - for (size_t i = 0; i < embDataOuterSize; ++i) { - size_t beginShift = i * embDataInnerSize; - size_t endShift = (i + 1) * embDataInnerSize; - embValues.reserve(embDataInnerSize); - embValues.insert(embValues.begin(), rawBegin + beginShift, rawBegin + endShift); - loadHostEmbs[embName].embData.push_back(move(embValues)); - } + return; } int HostEmbCkpt::GetEmbInfoSize() @@ -128,10 +130,10 @@ int HostEmbCkpt::GetEmbInfoSize() size_t HostEmbCkpt::GetEmbDataSize(string embName) { - auto embDataOuterSize = saveHostEmbs.at(embName).embData.size(); + auto embDataOuterSize = saveHostEmbs->at(embName).embData.size(); transferData.attribute.push_back(embDataOuterSize); - auto embDataInnerSize = saveHostEmbs.at(embName).embData.at(0).size(); + auto embDataInnerSize = saveHostEmbs->at(embName).embData.at(0).size(); 
transferData.attribute.push_back(embDataInnerSize); transferData.attribute.push_back(fourBytes); @@ -140,4 +142,20 @@ size_t HostEmbCkpt::GetEmbDataSize(string embName) transferData.attributeSize = transferData.attribute.size() * eightBytes; return embDataOuterSize * embDataInnerSize; +} + +size_t HostEmbCkpt::GetEmbDataRows(string embName) +{ + auto embDataOuterSize = saveHostEmbs->at(embName).embData.size(); + transferData.attribute.push_back(embDataOuterSize); + + auto embDataInnerSize = saveHostEmbs->at(embName).embData.at(0).size(); + transferData.attribute.push_back(embDataInnerSize); + + transferData.attribute.push_back(fourBytes); + + transferData.datasetSize = embDataInnerSize * fourBytes; + transferData.attributeSize = transferData.attribute.size() * eightBytes; + + return embDataOuterSize; } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h index b48e3459..c1b95ff8 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h @@ -28,6 +28,8 @@ namespace MxRec { CkptTransData GetDataset(CkptDataType dataType, string embName) override; void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + void SetDatasetForLoadEmb( + CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) override; private: const vector fileDirNames { "HashTable", "DDR" }; @@ -42,17 +44,18 @@ namespace MxRec { const int attribEmbDataInnerIdx { 1 }; const int embSveElmtNum { 4 }; - emb_mem_t saveHostEmbs; - emb_mem_t loadHostEmbs; + emb_mem_t* saveHostEmbs; + emb_mem_t* loadHostEmbs; void SetEmbInfoTrans(string embName); void SetEmbDataTrans(string embName); - void SetEmbInfo(string embName); - void SetEmbData(string embName); + void SetEmbInfo(string embName, CkptData& ckptData); + void SetEmbData(string embName, CkptData& ckptData); int GetEmbInfoSize(); size_t GetEmbDataSize(string embName); + size_t GetEmbDataRows(string embName); }; } diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 4f739bf3..8f446a65 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
* Description: common module * Author: MindX SDK * Date: 2022/11/15 diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index c1293243..03f696cb 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -119,6 +119,7 @@ bool HybridMgmt::Load(const string& loadPath) loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); } + loadData.hostEmbs = hostEmbs->GetHostEmbs(); loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) { return false; @@ -126,7 +127,6 @@ bool HybridMgmt::Load(const string& loadPath) if (!mgmtRankInfo.noDDR) { spdlog::debug(MGMT + "Start host side load: ddr mode hashmap"); - hostEmbs->LoadEmb(loadData.hostEmbs); hostHashMaps->LoadHashMap(loadData.embHashMaps); } else { spdlog::debug(MGMT + "Start host side load: no ddr mode hashmap"); @@ -154,10 +154,10 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) { bool loadDataMatches { true }; size_t embTableCount { 0 }; - const auto& loadHostEmbs { loadData.hostEmbs }; + auto loadHostEmbs { loadData.hostEmbs }; for (const auto& setupHostEmbs : mgmtEmbInfo) { - const auto& loadEmbTable { loadHostEmbs.find(setupHostEmbs.name) }; - if (loadEmbTable != loadHostEmbs.end()) { + const auto& loadEmbTable { loadHostEmbs->find(setupHostEmbs.name) }; + if (loadEmbTable != loadHostEmbs->end()) { embTableCount++; const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; @@ -190,9 +190,9 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) } } - if (embTableCount < loadHostEmbs.size()) { + if (embTableCount < loadHostEmbs->size()) { spdlog::error(MGMT + "Load data has {} tables more than setup table num {}", - loadHostEmbs.size(), embTableCount); + loadHostEmbs->size(), embTableCount); return false; } return true; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index dfed1b24..203c9bbc 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -21,15 +21,13 @@ using namespace chrono; bool HostEmb::Initialize(const vector& embInfos, int seed, bool ifLoad) { - if (!ifLoad) { - for (const auto& embInfo: embInfos) { - HostEmbTable hostEmb; - hostEmb.hostEmbInfo = embInfo; - EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.embeddingSize, - hostEmb.embData); - hostEmbs[embInfo.name] = move(hostEmb); - spdlog::info(HOSTEMB + "HostEmb Initialize End"); - } + for (const auto& embInfo: embInfos) { + HostEmbTable hostEmb; + hostEmb.hostEmbInfo = embInfo; + EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.embeddingSize, + hostEmb.embData); + hostEmbs[embInfo.name] = move(hostEmb); + spdlog::info(HOSTEMB + "HostEmb Initialize End"); } return true; } @@ -196,9 +194,9 @@ vector HostEmb::GetH2DEmb(const vector& missingKeysHostPos, cons return h2d_emb; } -auto HostEmb::GetHostEmbs() -> absl::flat_hash_map +auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* { - return hostEmbs; + return &hostEmbs; } EmbInfo::EmbInfo(const string &name, int sendCount, int embeddingSize, vector vocabsize, diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index b23175ba..31928a27 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -40,7 +40,7 @@ namespace MxRec { vector GetH2DEmb(const vector& missingKeysHostPos, const string& embName); - auto GetHostEmbs() -> absl::flat_hash_map; + auto GetHostEmbs() -> absl::flat_hash_map*; void 
EvictInitEmb(const string& embName, const vector& offset); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index d7290fe1..51831b2c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -393,7 +393,7 @@ struct BatchTask { }; struct CkptData { - emb_mem_t hostEmbs; + emb_mem_t* hostEmbs = nullptr; emb_hash_mem_t embHashMaps; offset_mem_t maxOffset; key_offset_mem_t keyOffsetMap; @@ -403,7 +403,7 @@ struct BatchTask { struct CkptTransData { std::vector int64Arr; - std::vector floatArr; + std::vector floatArr; std::vector int32Arr; std::vector transDataset; // may all use this to transfer data std::vector attribute; // may need to use other form for attributes diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 3cff5e1f..c5cbda50 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -21,6 +21,8 @@ using namespace std; using namespace MxRec; +const float MEM_INIT_VALUE = 0.5; + class CheckpointTest : public testing::Test { protected: string testPath { "./ckpt_mgmt_test" }; @@ -37,7 +39,7 @@ protected: int embInfoNum { 10 }; - float floatMem { 0.5 }; + float floatMem { MEM_INIT_VALUE }; int64_t featMem { static_cast(UINT32_MAX) }; int32_t offsetMem { 0 }; @@ -79,6 +81,7 @@ protected: void SetEmbData(vector>& testEmbData) { testEmbData.resize(hostVocabSize); + floatMem = MEM_INIT_VALUE; for (auto& testData : testEmbData) { testData.resize(embeddingSize); for (auto& testValue : testData) { @@ -88,13 +91,30 @@ protected: } } - void SetHostEmbs(emb_mem_t& testHostEmbs) + void SetHostEmbs(std::shared_ptr testHostEmbs) { vector> testEmbData; for (const auto& testEmbInfo : testEmbInfos) { SetEmbData(testEmbData); HostEmbTable embTable { testEmbInfo, move(testEmbData) }; - testHostEmbs[testEmbInfo.name] = move(embTable); // set test input data + testHostEmbs->insert({testEmbInfo.name, move(embTable)}); // set test input data + } + } + + void SetHostEmptyEmbs(std::shared_ptr loadHostEmbs) + { + vector> testEmbData; + for (const auto& testEmbInfo : testEmbInfos) { + // SetEmbData + testEmbData.resize(hostVocabSize); + for (auto& testData : testEmbData) { + testData.resize(embeddingSize); + for (auto& testValue : testData) { + testValue = 0; + } + } + HostEmbTable embTable { testEmbInfo, move(testEmbData) }; + loadHostEmbs->insert({testEmbInfo.name, move(embTable)}); // set test input data } } @@ -206,28 +226,30 @@ protected: TEST_F(CheckpointTest, HostEmbs) { - emb_mem_t testHostEmbs; - emb_mem_t validHostEmbs; - + std::shared_ptr testHostEmbs = std::make_shared(); SetEmbInfo(); SetHostEmbs(testHostEmbs); - validHostEmbs = testHostEmbs; + shared_ptr validHostEmbs = std::make_shared(); + SetHostEmbs(validHostEmbs); + shared_ptr loadHostEmbs = std::make_shared(); + SetHostEmptyEmbs(loadHostEmbs); CkptData testSaveData; CkptData validLoadData; CkptData testLoadData; - testSaveData.hostEmbs = testHostEmbs; - validLoadData.hostEmbs = validHostEmbs; + testSaveData.hostEmbs = testHostEmbs.get(); + validLoadData.hostEmbs = validHostEmbs.get(); + testLoadData.hostEmbs = loadHostEmbs.get(); Checkpoint testCkpt; testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::HOST_EMB }); - for (const auto& it : validLoadData.hostEmbs) { - const auto& embInfo = testLoadData.hostEmbs.at(it.first).hostEmbInfo; - const auto& embData = testLoadData.hostEmbs.at(it.first).embData; + for (const auto& it : 
*validLoadData.hostEmbs) { + const auto& embInfo = testLoadData.hostEmbs->at(it.first).hostEmbInfo; + const auto& embData = testLoadData.hostEmbs->at(it.first).embData; EXPECT_EQ(it.second.hostEmbInfo.name, embInfo.name); EXPECT_EQ(it.second.hostEmbInfo.sendCount, embInfo.sendCount); -- Gitee From b4d49e0e69b173c375185d79c14b84dd9b955576 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 16 May 2023 17:28:27 +0800 Subject: [PATCH 005/551] Match-id-5042cc0db0cdd1c5b924cd712cb5b0e725d9e5ef --- src/core/checkpoint/checkpoint.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 7b64bb9b..498386f6 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -267,6 +267,8 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si if (!writeFile.is_open()) { spdlog::debug("unable to open save file: {}", dataDir); + writeFile.close(); + return; } int loops = 1; @@ -471,12 +473,15 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, readFile.seekg(0, std::ios::beg); if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { - spdlog::debug("data is missing or incomplete in load file: {}", dataDir); + spdlog::error("data is missing or incomplete in load file: {}", dataDir); + throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } auto onceReadByteSize { datasetSize / embDataOuterSize }; if (!readFile.is_open()) { spdlog::debug("unable to open load file: {}", dataDir); + readFile.close(); + return; } for (size_t i = 0; i < embDataOuterSize; ++i) { size_t idx = 0; -- Gitee From dfedb5725f270a018cc871f4d56de7734fa20393 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 17 May 2023 14:32:18 +0800 Subject: [PATCH 006/551] Match-id-1a799e57346b4ba1e308890e7ecdb88d9f997816 --- build/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/build.sh b/build/build.sh index 4f3de7ec..7f68d243 100644 --- a/build/build.sh +++ b/build/build.sh @@ -82,7 +82,7 @@ src_path=${project_root_folder}/src cd "${project_root_folder}" -release_tar=Ascend-${pkg_dir}-${VERSION}-linux-${ARCH}.tar.gz +release_tar=Ascend-${pkg_dir}_${VERSION}_linux-${ARCH}.tar.gz install_abseil() { -- Gitee From 9e92b95bbf6fe6909e55220442b4a26295e99fcc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 May 2023 15:37:55 +0800 Subject: [PATCH 007/551] Match-id-f8dc3e21fee5102689d76dba8ae6ba5ebc0e5312 --- mx_rec/core/asc/helper.py | 55 ++++++++++++++++++++++++++++++++++++-- mx_rec/core/asc/manager.py | 6 +++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 756f5b0d..b2f0fe5e 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -43,7 +43,52 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names feature_counts=feature_counts, table_names=table_names, **kwargs) - +def check_tensor(table_name,table_reachable_tensor): + if "/update_"+table_name+"/" in table_reachable_tensor.name or table_reachable_tensor.op.type == 'ApplyAdam': + return True + if 'gradients/' in table_reachable_tensor.name: + return True + return False +def find_dangling_table(table_names): + op_list = tf.get_default_graph().get_operations() + table_lookup_op = {} + table_reachable_tensor = {} + for op in op_list: + for tn in table_names: + if tn in op.name and op.type == "IdentityN": + if tn not in table_lookup_op: + 
table_lookup_op[tn] = [op] + table_reachable_tensor[tn] = op.outputs + else: + table_lookup_op[tn].append(op) + table_reachable_tensor[tn].extend(op.outputs) + logging.info(f"*********** find tables: {table_lookup_op}***********") + logging.info(f"looking for dangling table") + dangling_table = [] + for k in table_reachable_tensor: + tmp_tensor = table_reachable_tensor[k] + tensors_to_keep = set() + found = False + while tmp_tensor: + out_tensor = [] + + for tensor in tmp_tensor: + if tensor in tensors_to_keep: + continue + if check_tensor(k,tensor): + found = True + break + tensors_to_keep.add(tensor) + for op in op_list: + if tensor in op.inputs: + out_tensor.extend(op.outputs) + if found: + break + tmp_tensor = out_tensor + if not found: + dangling_table.append(k) + + return dangling_table def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): @@ -82,10 +127,16 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ else: if feature_counts is None or table_names is None: raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - + logging.info(f"all table_names: {table_names}") + dangling_tables = find_dangling_table(table_names) + for table_name in dangling_tables: + logging.info(f"In insert found dangling table: {table_name} which does not need to be provided to the EmbInfo.") + table_names.remove(table_name) + logging.info(f"used table_names: {table_names}") def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) # config timestamp later + logging.debug(f"do_insert without spec for {table_names}") splits = [] for insert_tensor in insert_tensors: diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 6ea99612..378a2763 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -11,6 +11,7 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ get_use_hot, get_use_dynamic_expansion, export_optimizer +from mx_rec.core.asc.helper import find_dangling_table def generate_table_info_list(): @@ -28,11 +29,16 @@ def generate_table_info_list(): optimizer = export_optimizer() # generate table info + dangling_table = find_dangling_table([table_instance.table_name for _, table_instance in export_table_instances().items()]) for _, table_instance in export_table_instances().items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: table_instance.ext_emb_size = table_instance.scalar_emb_size * (1 + optimizer.slot_num) logging.debug(f"ext_emb_size is reset to be {table_instance.ext_emb_size} for EmbInfo") + + if table_instance.table_name in dangling_table: + logging.info(f"Found dangling table: {table_instance.table_name} which does not need to be provided to the EmbInfo.") + continue # Only the tables that need to be used after table combination are retained in meituan situation. # Current solution has error in same situations. For example, a sparse table has not been auto-merged. 
logging.debug(f"In EmbInfo, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") -- Gitee From 0580d38341b979f017af57715f36e213a08af75d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 May 2023 16:41:42 +0800 Subject: [PATCH 008/551] Match-id-342689e6d34c7d3f895d78403c68db9f07c485d1 --- mx_rec/util/initialize.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 2cb6d235..bf0bb3f2 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -111,12 +111,17 @@ class ConfigInitializer: raise AttributeError(f"Lack of attribute device.") for server_list in table_hccl.get("server_list"): - for device in server_list.get("device"): + devices = server_list.get("device") + if devices is None: + raise ValueError("device is empty") + local_rank_size = len(devices) + for device in devices: if "rank_id" not in device or not device["rank_id"].isdigit(): raise ValueError(f"hccl_json rank_id wrong.") + rank_id = int(device["rank_id"]) if "device_id" not in device or not device["device_id"].isdigit(): raise ValueError(f"hccl_json device_id wrong.") - self._rank_to_device_dict[int(device["rank_id"])] = int(device["device_id"]) + self._rank_to_device_dict[rank_id] = rank_id % local_rank_size def set_device_dict(self): ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") @@ -135,7 +140,7 @@ class ConfigInitializer: rank_size = int(rank_size) local_rank_size = rank_size if rank_size < 8 else 8 for device_index in range(rank_size): - self._rank_to_device_dict[device_index] = int(device_index % local_rank_size) + rank_start + self._rank_to_device_dict[device_index] = int(device_index % local_rank_size) else: raise ValueError("get CM_WORKER_SIZE failed.") -- Gitee From eb8ca4be1075c6a109739d93abf13f390fc8719e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 May 2023 20:07:35 +0800 Subject: [PATCH 009/551] Match-id-9a6d6288196258d90f2df9cc605abb3977a5a5c2 --- mx_rec/core/asc/helper.py | 81 ++++++++++++++++++++++---------------- mx_rec/core/asc/manager.py | 6 ++- 2 files changed, 52 insertions(+), 35 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index b2f0fe5e..cbf935ab 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -43,53 +43,64 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names feature_counts=feature_counts, table_names=table_names, **kwargs) -def check_tensor(table_name,table_reachable_tensor): - if "/update_"+table_name+"/" in table_reachable_tensor.name or table_reachable_tensor.op.type == 'ApplyAdam': + + +def check_tensor(table_name, table_reachable_tensor): + if "/update_" + table_name + "/" in table_reachable_tensor.name \ + or table_reachable_tensor.op.type == 'ApplyAdam': return True if 'gradients/' in table_reachable_tensor.name: return True return False + + def find_dangling_table(table_names): + def find_table_op(table_name, op, table_lookup_op, table_reachable_tensor): + if table_name in op.name and op.type == "IdentityN": + if table_name not in table_lookup_op: + table_lookup_op[table_name] = [op] + table_reachable_tensor[table_name] = op.outputs + else: + table_lookup_op[table_name].append(op) + table_reachable_tensor[table_name].extend(op.outputs) + op_list = tf.get_default_graph().get_operations() table_lookup_op = {} table_reachable_tensor = {} for op in op_list: - for tn in table_names: - if tn in op.name and op.type == "IdentityN": - if tn not in table_lookup_op: - 
table_lookup_op[tn] = [op] - table_reachable_tensor[tn] = op.outputs - else: - table_lookup_op[tn].append(op) - table_reachable_tensor[tn].extend(op.outputs) + for table_name in table_names: + find_table_op(table_name, op, table_lookup_op, table_reachable_tensor) + logging.info(f"*********** find tables: {table_lookup_op}***********") logging.info(f"looking for dangling table") dangling_table = [] - for k in table_reachable_tensor: - tmp_tensor = table_reachable_tensor[k] - tensors_to_keep = set() - found = False - while tmp_tensor: - out_tensor = [] - - for tensor in tmp_tensor: - if tensor in tensors_to_keep: + + def extend(op_list, tensor, spread_tensors): + for op in op_list: + if tensor in op.inputs: + spread_tensors.extend(op.outputs) + + def bfs_lookup(next_to_visit): + tensors_visited = set() + while next_to_visit: + spread_tensors = [] + for tensor in next_to_visit: + if tensor in tensors_visited: continue - if check_tensor(k,tensor): - found = True - break - tensors_to_keep.add(tensor) - for op in op_list: - if tensor in op.inputs: - out_tensor.extend(op.outputs) - if found: - break - tmp_tensor = out_tensor + if check_tensor(k, tensor): + return True + tensors_visited.add(tensor) + extend(op_list, tensor, spread_tensors) + next_to_visit = spread_tensors + return False + + for k in table_reachable_tensor: + found = bfs_lookup(table_reachable_tensor[k]) if not found: dangling_table.append(k) - return dangling_table + def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): both_none = tgt_key_specs is None and args_index_list is None @@ -109,8 +120,10 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ if len(args) == 1: data_src = args[0] - read_emb_key_inputs_dict = {"insert_tensors": [], "table_names": [], - "feature_spec_names": [], "splits": []} + read_emb_key_inputs_dict = { + "insert_tensors": [], "table_names": [], + "feature_spec_names": [], "splits": [] + } get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict['table_names']}") return do_insert(args, @@ -130,9 +143,11 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ logging.info(f"all table_names: {table_names}") dangling_tables = find_dangling_table(table_names) for table_name in dangling_tables: - logging.info(f"In insert found dangling table: {table_name} which does not need to be provided to the EmbInfo.") + logging.info(f"In insert found dangling table: {table_name} " + f"which does not need to be provided to the EmbInfo.") table_names.remove(table_name) logging.info(f"used table_names: {table_names}") + def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) # config timestamp later diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 378a2763..0d8807f2 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -29,7 +29,8 @@ def generate_table_info_list(): optimizer = export_optimizer() # generate table info - dangling_table = find_dangling_table([table_instance.table_name for _, table_instance in export_table_instances().items()]) + dangling_table = find_dangling_table([table_instance.table_name + for _, table_instance in export_table_instances().items()]) for _, table_instance in export_table_instances().items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if 
optimizer is not None: @@ -37,7 +38,8 @@ def generate_table_info_list(): logging.debug(f"ext_emb_size is reset to be {table_instance.ext_emb_size} for EmbInfo") if table_instance.table_name in dangling_table: - logging.info(f"Found dangling table: {table_instance.table_name} which does not need to be provided to the EmbInfo.") + logging.info(f"Found dangling table: {table_instance.table_name} " + f"which does not need to be provided to the EmbInfo.") continue # Only the tables that need to be used after table combination are retained in meituan situation. # Current solution has error in same situations. For example, a sparse table has not been auto-merged. -- Gitee From 8a7dc0de74fc34eb110dad6a7172c7508da49b52 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 May 2023 20:51:03 +0800 Subject: [PATCH 010/551] Match-id-24fe35ee1ccb5a0c0b790b82fd81b7f1435d0305 --- mx_rec/core/asc/helper.py | 49 +++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index cbf935ab..f213608c 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -45,59 +45,58 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names **kwargs) -def check_tensor(table_name, table_reachable_tensor): - if "/update_" + table_name + "/" in table_reachable_tensor.name \ - or table_reachable_tensor.op.type == 'ApplyAdam': - return True - if 'gradients/' in table_reachable_tensor.name: - return True - return False - - def find_dangling_table(table_names): - def find_table_op(table_name, op, table_lookup_op, table_reachable_tensor): - if table_name in op.name and op.type == "IdentityN": + def check_tensor(table_name, table_reachable_tensor): + if ''.join(["/update_", table_name]) in table_reachable_tensor.name \ + or table_reachable_tensor.op.type == 'ApplyAdam': + return True + if 'gradients/' in table_reachable_tensor.name: + return True + return False + + def find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor): + if table_name in the_op.name and the_op.type == "IdentityN": if table_name not in table_lookup_op: - table_lookup_op[table_name] = [op] - table_reachable_tensor[table_name] = op.outputs + table_lookup_op[table_name] = [the_op] + table_reachable_tensor[table_name] = the_op.outputs else: - table_lookup_op[table_name].append(op) - table_reachable_tensor[table_name].extend(op.outputs) + table_lookup_op[table_name].append(the_op) + table_reachable_tensor[table_name].extend(the_op.outputs) op_list = tf.get_default_graph().get_operations() table_lookup_op = {} table_reachable_tensor = {} - for op in op_list: + for the_op in op_list: for table_name in table_names: - find_table_op(table_name, op, table_lookup_op, table_reachable_tensor) + find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) logging.info(f"*********** find tables: {table_lookup_op}***********") logging.info(f"looking for dangling table") dangling_table = [] def extend(op_list, tensor, spread_tensors): - for op in op_list: - if tensor in op.inputs: - spread_tensors.extend(op.outputs) + for the_op in op_list: + if tensor in the_op.inputs: + spread_tensors.extend(the_op.outputs) - def bfs_lookup(next_to_visit): + def bfs_lookup(table_name,next_to_visit): tensors_visited = set() while next_to_visit: spread_tensors = [] for tensor in next_to_visit: if tensor in tensors_visited: continue - if check_tensor(k, tensor): + if check_tensor(table_name, tensor): return True 
tensors_visited.add(tensor) extend(op_list, tensor, spread_tensors) next_to_visit = spread_tensors return False - for k in table_reachable_tensor: - found = bfs_lookup(table_reachable_tensor[k]) + for table_name, table_op in table_reachable_tensor.items(): + found = bfs_lookup(table_name,table_op) if not found: - dangling_table.append(k) + dangling_table.append(table_name) return dangling_table -- Gitee From 4341fd7e83d0e787f2e8065938f9d2922baa558f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 May 2023 20:54:32 +0800 Subject: [PATCH 011/551] Match-id-4e6e67fc64d73e67f635db847799bc5374fbdf3a --- mx_rec/core/asc/helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index f213608c..e14128f8 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -79,7 +79,7 @@ def find_dangling_table(table_names): if tensor in the_op.inputs: spread_tensors.extend(the_op.outputs) - def bfs_lookup(table_name,next_to_visit): + def bfs_lookup(table_name, next_to_visit): tensors_visited = set() while next_to_visit: spread_tensors = [] @@ -94,7 +94,7 @@ def find_dangling_table(table_names): return False for table_name, table_op in table_reachable_tensor.items(): - found = bfs_lookup(table_name,table_op) + found = bfs_lookup(table_name, table_op) if not found: dangling_table.append(table_name) return dangling_table -- Gitee From c220a4bcbafceebf6faabe6b57b480937ec7d29e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 19 May 2023 16:22:02 +0800 Subject: [PATCH 012/551] Match-id-f57185292c977c4f8153daa27a5c0e6dcf921f06 --- example/little_demo/main.py | 20 ++++++++-------- mx_rec/graph/modifier.py | 5 ++-- mx_rec/optimizers/lazy_adam_by_addr.py | 33 ++++++++++++++++---------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 671515d4..e29d36a9 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -208,7 +208,7 @@ if __name__ == "__main__": grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) - saver = tf.compat.v1.train.Saver() + # saver = tf.compat.v1.train.Saver() if MODIFY_GRAPH_FLAG: logging.info("start to modifying graph") modify_graph_and_start_emb_cache(dump_graph=True) @@ -222,10 +222,10 @@ if __name__ == "__main__": sess.run(train_iterator.initializer) sess.run(tf.compat.v1.global_variables_initializer()) EPOCH = 0 - if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): - saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) - else: - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) + # if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): + # saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) + # else: + # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) for i in range(1, 201): logging.info(f"################ training at step {i} ################") @@ -237,12 +237,12 @@ if __name__ == "__main__": else: if i % TRAIN_INTERVAL == 0: EPOCH += 1 - evaluate() - - if i % SAVING_INTERVAL == 0: - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + # evaluate() - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + # if i % SAVING_INTERVAL == 0: + # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + # + # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) 
terminate_config_initializer() logging.info("Demo done!") diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 5132e26c..44c9de96 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -16,7 +16,8 @@ from mx_rec.core.embedding import SparseEmbedding from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_CUTTING_POINT, \ ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, destroy_asc_manager, get_training_mode_channel_id, \ - get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id + get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, \ + get_use_dynamic_expansion from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list @@ -342,7 +343,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, ext_emb_size=table_instance.ext_emb_size, _emb_size=table_instance._emb_size, use_hot=get_use_hot(), - device_id=get_device_id()) + device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) build_asc_graph(table_instance, cutting_point, config) logging.info("Graph has been revised.") diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 714634f8..9bf7d2f3 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -175,26 +175,35 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): host_pipeline_ops = get_host_pipeline_ops() dim = grad.shape.as_list()[-1] - combined_tensor = \ - host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) - split_length = [dim] + [dim] + [dim] - split_tensors = tf.split(combined_tensor, split_length, axis=1) + addr_m = tf.add(addr, 4*dim) + addr_v = tf.add(addr, 8*dim) + old_m_slice = \ + host_pipeline_ops.embedding_lookup_by_address(addr_m, embedding_dim=dim, embedding_type=1) + old_v_slice = \ + host_pipeline_ops.embedding_lookup_by_address(addr_v, embedding_dim=dim, embedding_type=1) - old_m_slice = split_tensors[1] + # combined_tensor = \ + # host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) + # split_length = [dim] + [dim] + [dim] + # split_tensors = tf.split(combined_tensor, split_length, axis=1) + + # old_m_slice = split_tensors[1] m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad + m_update_op = host_pipeline_ops.embedding_update_by_address(addr_m, m_t_slice - old_m_slice, update_type=0) - old_v_slice = split_tensors[2] + # old_v_slice = split_tensors[2] v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) + v_update_op = host_pipeline_ops.embedding_update_by_address(addr_v, v_t_slice - old_v_slice, update_type=0) denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon - update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ - [v_t_slice - old_v_slice] - update_tensor = tf.concat(update_list, axis=1) - var_update_op = host_pipeline_ops.embedding_update_by_address(addr, update_tensor, update_type=0) - var_update_op = 
tf.identity(var_update_op, name="identity_var_update_op") + # update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ + # [v_t_slice - old_v_slice] + # update_tensor = tf.concat(update_list, axis=1) + var_update_op = host_pipeline_ops.embedding_update_by_address(addr, tf.divide(-learning_rate * m_t_slice, + denominator_slice), update_type=0) - return var_update_op + return control_flow_ops.group(m_update_op, v_update_op, var_update_op) def _convert_grads_and_addrs(self, grads_and_vars): converted_grads_and_addrs = [] -- Gitee From 8cc3acee3f844d837f8b10b38acc7a0360e4ec61 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 20 May 2023 16:32:02 +0800 Subject: [PATCH 013/551] Match-id-dc45273dd7e26b139b41accc5eed77afce85c9dc --- example/little_demo/main.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 671515d4..01ee8bbc 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -53,25 +53,16 @@ def model_forward(input_list, batch, is_train, modify_graph, config_dict=None): def build_graph(hash_table_list, is_train, use_timestamp=False, config_dict=None, batch_number=100): - batch, iterator = make_batch_and_iterator(is_training=is_train, use_timestamp=use_timestamp, - dump_graph=is_train, batch_number=batch_number) + batch, iterator = make_batch_and_iterator(is_training=is_train, use_timestamp=use_timestamp, dump_graph=is_train, + batch_number=batch_number) if MODIFY_GRAPH_FLAG: - input_list = [ - [batch["user_ids"], batch["item_ids"]], - [hash_table_list[0], hash_table_list[1]], - [cfg.user_send_cnt, cfg.item_send_cnt], - ] + feature_list = [batch["user_ids"], batch["item_ids"]] if USE_TIMESTAMP: tf.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) - model = model_forward(input_list, batch, + model = model_forward([feature_list, hash_table_list, [cfg.user_send_cnt, cfg.item_send_cnt]], batch, is_train=is_train, modify_graph=True, config_dict=config_dict) else: - input_list = [ - [feature_spec for feature_spec in feature_spec_list], - [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], - [cfg.user_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt, cfg.item_send_cnt], - ] - model = model_forward(input_list, batch, + model = model_forward([feature_spec_list, hash_table_list, [cfg.user_send_cnt, cfg.item_send_cnt]], batch, is_train=is_train, modify_graph=False, config_dict=config_dict) return iterator, model @@ -106,11 +97,11 @@ if __name__ == "__main__": use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0 - init(use_mpi = bool(int(os.getenv("USE_MPI"))), + init(use_mpi=bool(int(os.getenv("USE_MPI"))), train_interval=TRAIN_INTERVAL, eval_steps=EVAL_STEPS, prefetch_batch_number=5, - use_dynamic=int(os.getenv("LITTLE_DEMO_USE_DYNAMIC", 0)), + use_dynamic=int(os.getenv("USE_DYNAMIC", 0)), use_hot=bool(int(os.getenv("USE_HOT", 0))), use_dynamic_expansion=use_dynamic_expansion) IF_LOAD = False @@ -129,12 +120,6 @@ if __name__ == "__main__": feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold), - FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", - access_threshold=cfg.access_threshold, - 
eviction_threshold=cfg.eviction_threshold), - FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", - access_threshold=cfg.access_threshold, - eviction_threshold=cfg.eviction_threshold), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold), @@ -146,8 +131,6 @@ if __name__ == "__main__": else: feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table"), - FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table"), - FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table"), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table")] optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(2)] -- Gitee From 0ad3b5f4b01771008c52fe2c6057d91c97003866 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 22 May 2023 14:46:36 +0800 Subject: [PATCH 014/551] Match-id-8d9ad893683f9df540dcaa9df2db5ca272517810 --- mx_rec/graph/modifier.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 5132e26c..44c9de96 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -16,7 +16,8 @@ from mx_rec.core.embedding import SparseEmbedding from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_CUTTING_POINT, \ ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, destroy_asc_manager, get_training_mode_channel_id, \ - get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id + get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, \ + get_use_dynamic_expansion from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list @@ -342,7 +343,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, ext_emb_size=table_instance.ext_emb_size, _emb_size=table_instance._emb_size, use_hot=get_use_hot(), - device_id=get_device_id()) + device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) build_asc_graph(table_instance, cutting_point, config) logging.info("Graph has been revised.") -- Gitee From f7f806673cc8145f0a9902dbec61b227a310a858 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 19 May 2023 22:16:49 +0800 Subject: [PATCH 015/551] Match-id-07cb005de2a9d908cffd20068a7a1a6048d5afe1 --- mx_rec/util/initialize.py | 25 ++++++++++++++++--------- src/pybind/CMakeLists.txt | 4 +++- src/pybind/module_main.cpp | 17 ++++++++++++++++- src/tests/CMakeLists.txt | 2 +- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index bf0bb3f2..4bc5f28f 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -1,10 +1,10 @@ # coding: UTF-8 - import json import logging import os import psutil +import mxrec_pybind import mx_rec.util.constants from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ ASCEND_GLOBAL_HASHTABLE_COLLECTION @@ -114,14 +114,16 @@ class 
ConfigInitializer: devices = server_list.get("device") if devices is None: raise ValueError("device is empty") - local_rank_size = len(devices) for device in devices: if "rank_id" not in device or not device["rank_id"].isdigit(): raise ValueError(f"hccl_json rank_id wrong.") rank_id = int(device["rank_id"]) if "device_id" not in device or not device["device_id"].isdigit(): raise ValueError(f"hccl_json device_id wrong.") - self._rank_to_device_dict[rank_id] = rank_id % local_rank_size + device_id = mxrec_pybind.get_logic_id(int(device["device_id"])) + if device_id > 16: + raise ValueError(f"failed to get logic id from physic id.") + self._rank_to_device_dict[rank_id] = device_id def set_device_dict(self): ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") @@ -129,18 +131,23 @@ class ConfigInitializer: raise ValueError("env variable ascend_visible_devices is null.") if "-" in ascend_visible_devices: rank_start = int(ascend_visible_devices.strip().split("-")[0]) + device_list = [i for i in range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]))] elif "," in ascend_visible_devices: - rank_start = int(ascend_visible_devices.strip().split(",")[0]) + device_list = list(map(int, ascend_visible_devices.strip().split(","))) elif ascend_visible_devices in ["0", "1", "2", "3", "4", "5", "6", "7"]: - rank_start = int(ascend_visible_devices.strip()) + device_list = [int(ascend_visible_devices.strip())] else: raise ValueError("invalid env variable ascend_visible_devices.") - rank_size = os.getenv("CM_WORKER_SIZE") + rank_size = int(os.getenv("CM_WORKER_SIZE")) + self._rank_to_device_dict[0] = int(os.getenv("CM_CHIEF_DEVICE")) + device_list.pop(int(os.getenv("CM_CHIEF_DEVICE"))) if rank_size: - rank_size = int(rank_size) local_rank_size = rank_size if rank_size < 8 else 8 - for device_index in range(rank_size): - self._rank_to_device_dict[device_index] = int(device_index % local_rank_size) + for device_index in range(local_rank_size - 1): + device_id = mxrec_pybind.get_logic_id(int(device_list[device_index])) + if device_id > 16: + raise ValueError(f"failed to get logic id from physic id.") + self._rank_to_device_dict[device_index + 1] = device_id else: raise ValueError("get CM_WORKER_SIZE failed.") diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt index ed1bb626..5bf95c14 100644 --- a/src/pybind/CMakeLists.txt +++ b/src/pybind/CMakeLists.txt @@ -2,5 +2,7 @@ cmake_minimum_required(VERSION 3.12) pybind11_add_module(mxrec_pybind module_main.cpp) set_target_properties(mxrec_pybind PROPERTIES LINK_FLAGS "-Wl,-rpath,/") -target_link_libraries(mxrec_pybind PUBLIC ASC) +target_include_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/include) +target_link_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/lib64/driver) +target_link_libraries(mxrec_pybind PUBLIC ASC dcmi) install(TARGETS mxrec_pybind LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) \ No newline at end of file diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 44f02ff0..9426026f 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -4,10 +4,12 @@ * Author: MindX SDK * Date: 2022/11/15 */ -#include "module_main.h" #include #include +#include + #include "emb_mgmt/emb_mgmt.h" +#include "module_main.h" namespace py = pybind11; using namespace MxRec; @@ -33,10 +35,23 @@ int GetUBHotSize(int devID) return static_cast(MxRec::GetUBSize(devID)/ sizeof(float) * HOT_EMB_CACHE_PCT) ; } +uint32_t GetLogicID(uint32_t phyid) +{ + int32_t ret = 0; + uint32_t logicId; + ret =
dcmi_get_device_logicid_from_phyid(phyid, &logicId); + if (ret != 0) { + return ret; + } + return logicId; +} + PYBIND11_MODULE(mxrec_pybind, m) { m.def("get_ub_hot_size", &GetUBHotSize, py::arg("device_id")); + m.def("get_logic_id", &GetLogicID, py::arg("physic_id")); + m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); m.attr("USE_HOT") = py::int_(HybridOption::USE_HOT); diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 0ef60b39..817348a2 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -51,7 +51,7 @@ target_link_directories(test_main target_link_libraries(test_main PUBLIC ${TF_LIB} securec OpenMP::OpenMP_CXX ${HDF5_CXX_LIBRARIES} ${MPI_CXX_LIBRARIES} - ${PYTHON_LIBRARY} drvdsmi_host + ${PYTHON_LIBRARY} drvdsmi_host dcmi ) target_link_libraries(test_main PUBLIC -- Gitee From db5c50a1792cd98b7a79781c9516cdae81265ea5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 22 May 2023 19:46:10 +0800 Subject: [PATCH 016/551] Match-id-d7b784c88a8ab69070a95526aabc387c69ea399a --- mx_rec/core/asc/helper.py | 22 ++++++++++++---------- mx_rec/core/asc/manager.py | 16 +++++----------- mx_rec/util/initialize.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index e14128f8..3831cafa 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -8,7 +8,7 @@ import os import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, insert_feature_spec, insert_training_mode_channel_id, \ - get_training_mode_channel_id, get_use_static + get_training_mode_channel_id, get_use_static, insert_dangling_table from .feature_spec import FeatureSpec @@ -64,9 +64,11 @@ def find_dangling_table(table_names): table_reachable_tensor[table_name].extend(the_op.outputs) op_list = tf.get_default_graph().get_operations() + table_lookup_op = {} table_reachable_tensor = {} for the_op in op_list: + logging.info(f"** the_op: {the_op}**") for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) @@ -97,6 +99,7 @@ def find_dangling_table(table_names): found = bfs_lookup(table_name, table_op) if not found: dangling_table.append(table_name) + insert_dangling_table(table_name) return dangling_table @@ -125,10 +128,17 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ } get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict['table_names']}") + table_names = read_emb_key_inputs_dict["table_names"] + logging.info(f"all table_names: {table_names}") + dangling_tables = find_dangling_table(table_names) + for table_name in dangling_tables: + logging.info(f"In insert found dangling table: {table_name} " + f"which does not need to be provided to the EmbInfo.") + table_names.remove(table_name) return do_insert(args, insert_tensors=read_emb_key_inputs_dict["insert_tensors"], splits=read_emb_key_inputs_dict["splits"], - table_names=read_emb_key_inputs_dict["table_names"], + table_names=table_names, input_dict={"is_training": is_training, "dump_graph": dump_graph, "timestamp": FeatureSpec.use_timestamp(is_training), "feature_spec_names": read_emb_key_inputs_dict["feature_spec_names"], @@ -255,16 +265,8 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): feature_spec_names = input_dict["feature_spec_names"] auto_change_graph = input_dict["auto_change_graph"] - #
The only tables retained after table combination in the Meituan scenario are the ones that still need to be used. - # The current solution has errors in some situations. For example, a sparse table has not been auto-merged. - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN new_insert_tensors, new_splits, new_table_names = [], [], [] - logging.debug(f"In do_insert function, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") for idx, table_name in enumerate(table_names): - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name: - logging.info(f"After the tables are combined, the information about the" - f" {table_name} table does not need to be provided to the read_emb_key operator.") - continue new_insert_tensors.append(insert_tensors[idx]) new_splits.append(splits[idx]) new_table_names.append(table_names[idx]) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 0d8807f2..1e2e987c 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -10,7 +10,7 @@ from mx_rec.util.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ - get_use_hot, get_use_dynamic_expansion, export_optimizer + get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table from mx_rec.core.asc.helper import find_dangling_table @@ -29,8 +29,9 @@ def generate_table_info_list(): optimizer = export_optimizer() # generate table info - dangling_table = find_dangling_table([table_instance.table_name - for _, table_instance in export_table_instances().items()]) + dangling_table = export_dangling_table() + # dangling_table = find_dangling_table([table_instance.table_name + # for _, table_instance in export_table_instances().items()]) for _, table_instance in export_table_instances().items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: @@ -41,14 +42,7 @@ def generate_table_info_list(): logging.info(f"Found dangling table: {table_instance.table_name} " f"which does not need to be provided to the EmbInfo.") continue - # Only the tables that need to be used after table combination are retained in the Meituan scenario. - # The current solution has errors in some situations. For example, a sparse table has not been auto-merged.
- logging.debug(f"In EmbInfo, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ - ASCEND_TABLE_NAME_MUST_CONTAIN not in table_instance.table_name: - logging.info(f"After the tables are combined, the information about the" - f" {table_instance.table_name} table does not need to be provided to the EmbInfo.") - continue + rec_mode_asc_flag = table_instance.mode == MxRecMode.ASC static_shape_rec_flag = rec_mode_asc_flag and get_use_static() and table_instance.send_count > 0 dynamic_shape_rec_flag = rec_mode_asc_flag and not get_use_static() diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 2cb6d235..eaa40e6d 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -28,6 +28,7 @@ class ConfigInitializer: self._prefetch_batch_number = None self._if_load = None self._table_instance_dict = dict() + self._dangling_table = [] self._name_to_var_dict = dict() self._table_name_set = set() self._table_name_to_feature_spec = dict() @@ -147,6 +148,15 @@ class ConfigInitializer: def get_training_mode_channel_id(self, is_training): return self._training_mode_channel_dict.get(is_training) + def insert_dangling_table(self, name): + if name in self._dangling_table: + return + self._dangling_table.append(name) + + @property + def dangling_table(self): + return self._dangling_table + def insert_table_instance(self, name, key, instance): if key in self._table_instance_dict: raise KeyError(f"Given key {key} has been used.") @@ -467,6 +477,10 @@ def get_table_instance_by_name(table_name): return ConfigInitializer.get_instance().get_table_instance_by_name(table_name) +def insert_dangling_table(table_name): + ConfigInitializer.get_instance().insert_dangling_table(table_name) + + def insert_table_instance(name, key, instance): ConfigInitializer.get_instance().insert_table_instance(name, key, instance) @@ -475,6 +489,10 @@ def export_table_instances(): return ConfigInitializer.get_instance().table_instance_dict +def export_dangling_table(): + return ConfigInitializer.get_instance().dangling_table + + def insert_optimizer(optimizer): ConfigInitializer.get_instance().insert_optimizer(optimizer) -- Gitee From 087ef297d55ebb3e832a47756378c90f13df33d1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 23 May 2023 15:55:18 +0800 Subject: [PATCH 017/551] Match-id-a51c936e7ed2bec8b821a065dc738496201eeeae --- CMakeLists.txt | 2 + build.sh | 6 ++ build/build.sh | 88 ++++++++++++------- build/package.sh | 60 +++++++++++++ config.ini | 0 dependency.xml | 0 mx_rec/__init__.py | 5 +- mx_rec/core/__init__.py | 3 + mx_rec/core/asc/__init__.py | 4 +- mx_rec/core/asc/build_graph.py | 16 ++-- mx_rec/core/asc/feature_spec.py | 5 +- mx_rec/core/asc/helper.py | 5 +- mx_rec/core/asc/manager.py | 5 +- mx_rec/core/embedding.py | 23 ++--- mx_rec/graph/__init__.py | 4 +- mx_rec/graph/modifier.py | 11 ++- mx_rec/graph/patch.py | 7 +- mx_rec/graph/utils.py | 8 +- mx_rec/optimizers/__init__.py | 3 + mx_rec/optimizers/adagrad.py | 4 +- mx_rec/optimizers/base.py | 7 +- mx_rec/optimizers/ftrl.py | 7 +- mx_rec/optimizers/ftrl_t.py | 7 +- mx_rec/optimizers/ftrl_t_dense.py | 9 +- mx_rec/optimizers/gradient_descent.py | 8 +- mx_rec/optimizers/gradient_descent_by_addr.py | 7 +- mx_rec/optimizers/lazy_adam.py | 7 +- mx_rec/optimizers/lazy_adam_by_addr.py | 11 ++- mx_rec/optimizers/momentum.py | 4 +- mx_rec/saver/__init__.py | 3 + mx_rec/saver/patch.py | 4 +- mx_rec/saver/saver.py | 4 +- mx_rec/util/__init__.py | 5 +- mx_rec/util/atomic.py | 9 +- 
mx_rec/util/constants.py | 4 +- mx_rec/util/initialize.py | 5 +- mx_rec/util/log.py | 9 +- mx_rec/util/ops.py | 5 +- mx_rec/util/perf.py | 7 +- mx_rec/util/synchronizer.py | 5 +- mx_rec/util/tf_version_adapter.py | 11 ++- mx_rec/util/variable.py | 4 +- setup.py | 9 +- src/build.sh | 9 +- src/test_ut.sh | 40 ++++----- 45 files changed, 293 insertions(+), 166 deletions(-) create mode 100644 CMakeLists.txt create mode 100644 build.sh create mode 100644 build/package.sh create mode 100644 config.ini create mode 100644 dependency.xml diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..deb19638 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,2 @@ +cmake_minimum_required(VERSION 3.12) +project(MxRec LANGUAGES CXX) diff --git a/build.sh b/build.sh new file mode 100644 index 00000000..86e1c715 --- /dev/null +++ b/build.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +# Description: build entrance script. + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +bash "${SCRIPT_DIR}"/build/build.sh diff --git a/build/build.sh b/build/build.sh index 7f68d243..18e9dbc8 100644 --- a/build/build.sh +++ b/build/build.sh @@ -1,14 +1,15 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. # Description: build script. # Author: MindX SDK -# Create: 2022 +# Create: 2021 # History: NA set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } ARCH="$(uname -m)" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +ROOT_DIR=$(dirname "${SCRIPT_DIR}") cd "$SCRIPT_DIR" if [ "$(uname -m)" = "aarch64" ] then @@ -48,7 +49,7 @@ then deactivate tf1_env fi -VERSION_FILE=${SCRIPT_DIR}/../../mindxsdk/build/conf/config.yaml +VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml get_version() { if [ -f "$VERSION_FILE" ]; then VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") @@ -60,35 +61,43 @@ get_version() { fi } -project_root_folder=${SCRIPT_DIR}/.. -project_output_path=${project_root_folder}/output/ -rm -rf "${project_output_path}" -rm -rf "${SCRIPT_DIR}/lib" +remove() +{ + if [ -d "$1" ]; then + rm -rf "$1" + elif [ -f "$1" ]; then + rm -f "$1" + fi +} + +project_output_path="${ROOT_DIR}"/output/ +remove "${project_output_path}" +remove "${SCRIPT_DIR}/lib" get_version export VERSION echo "MindX SDK mxrec: ${VERSION}" >> ./version.info pkg_dir=mindxsdk-mxrec -[ -d ${pkg_dir} ] && rm -rf ${pkg_dir} -mkdir ${pkg_dir} -mv version.info ${pkg_dir} +remove "${pkg_dir}" +mkdir "${pkg_dir}" +mv version.info "${pkg_dir}" -opensource_path=${project_root_folder}/../opensource/opensource +opensource_path="${ROOT_DIR}"/../opensource/opensource abseil_src_path=${opensource_path}/abseil echo "${abseil_src_path}" -abseil_install_path=${project_root_folder}/install/abseil +abseil_install_path="${ROOT_DIR}"/install/abseil -src_path=${project_root_folder}/src +src_path="${ROOT_DIR}"/src -cd "${project_root_folder}" +cd "${ROOT_DIR}" -release_tar=Ascend-${pkg_dir}_${VERSION}_linux-${ARCH}.tar.gz +release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz install_abseil() { - rm -rf "${abseil_install_path}" + remove "${abseil_install_path}" echo "${abseil_install_path}" - if [[ ! -d ${abseil_install_path} ]] + if [[ ! -d "${abseil_install_path}" ]] then mkdir -p "${abseil_install_path}" fi @@ -110,13 +119,13 @@ install_abseil() compile_securec() { - if [[ ! 
-d ${project_root_folder}/platform/securec ]]; then + if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then echo "securec is not exist" exit 1 fi - if [[ ! -f ${project_root_folder}/platform/securec/lib/libsecurec.so ]]; then - cd ${project_root_folder}/platform/securec/src + if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then + cd "${ROOT_DIR}"/platform/securec/src make -j fi } @@ -125,39 +134,39 @@ compile_so_file() { cd "${src_path}" chmod u+x build.sh - ./build.sh "$1" "${project_root_folder}" + ./build.sh "$1" "${ROOT_DIR}" cd .. } collect_so_file() { cd "${src_path}" - rm -rf "${src_path}"/libasc + remove "${src_path}"/libasc mkdir -p "${src_path}"/libasc chmod u+x libasc - cp -df "${project_root_folder}"/output/*.so* libasc - cp ${project_root_folder}/platform/securec/lib/libsecurec.so libasc + cp -df "${ROOT_DIR}"/output/*.so* libasc + cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc } gen_wheel_file() { - cd "${project_root_folder}" + cd "${ROOT_DIR}" touch "${src_path}"/libasc/__init__.py - [ -d "${project_root_folder}"/mx_rec/libasc ] && rm -rf "${project_root_folder}"/mx_rec/libasc - mv "${src_path}"/libasc "${project_root_folder}"/mx_rec + remove "${ROOT_DIR}"/mx_rec/libasc + mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec python3 setup.py bdist_wheel mkdir -p "$1" mv dist/mx_rec*.whl "$1" - rm -rf "${project_root_folder}"/mx_rec/libasc + remove "${ROOT_DIR}"/mx_rec/libasc } gen_tar_file() { cd "${src_path}" - mv "${project_root_folder}"/tf1_whl ../build/${pkg_dir} - mv "${project_root_folder}"/tf2_whl ../build/${pkg_dir} - cp -r "${src_path}"/../example ../build/${pkg_dir} + mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" + mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" + cp -r "${src_path}"/../example ../build/"${pkg_dir}" cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { warn "compression failed, packages might be broken" @@ -167,6 +176,18 @@ gen_tar_file() } +clean() +{ + remove "${ROOT_DIR}"/install + remove "${ROOT_DIR}"/mx_rec.egg-info + remove "${ROOT_DIR}"/src/build + remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" + remove "${ROOT_DIR}"/build/tf1_env + remove "${ROOT_DIR}"/build/tf2_env + remove "${ROOT_DIR}"/build/lib + remove "${ROOT_DIR}"/build/mindxsdk-mxrec +} + install_abseil compile_securec @@ -174,17 +195,18 @@ echo "-----Build Start tf1 -----" source "${SCRIPT_DIR}"/tf1_env/bin/activate compile_so_file "${tf1_path}" collect_so_file -gen_wheel_file "${project_root_folder}"/tf1_whl +gen_wheel_file "${ROOT_DIR}"/tf1_whl deactivate tf1_env echo "-----Build Start tf2 -----" source "${SCRIPT_DIR}"/tf2_env/bin/activate compile_so_file "${tf2_path}" collect_so_file -gen_wheel_file "${project_root_folder}"/tf2_whl +gen_wheel_file "${ROOT_DIR}"/tf2_whl deactivate tf2_env echo "-----Build gen tar -----" gen_tar_file +clean echo "-----Done-----" diff --git a/build/package.sh b/build/package.sh new file mode 100644 index 00000000..3e4a1e07 --- /dev/null +++ b/build/package.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Package script +# Copyright © Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+set -e + +CURDIR=$(dirname "$(readlink -f "$0")") +SCRIPT_NAME=$(basename "$0") +ROOT_PATH=$(readlink -f "$CURDIR"/../) +OUTPUT_PATH="$ROOT_PATH/output" +VERSION_FILE=${CURDIR}/../../mindxsdk/build/conf/config.yaml +get_version() { + if [ -f "$VERSION_FILE" ]; then + VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") + if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then + VERSION=${VERSION%.*} + fi + else + VERSION="5.0.T104" + fi +} + +get_version +export VERSION + + +function make_zip_package() +{ + cd "${OUTPUT_PATH}" + pkg_file=$(ls "$OUTPUT_PATH"/*"${1}"*."${2}") + pkg_file="${pkg_file##*/}" + pkg_release="${pkg_file%."${2}"}" + + package_file="${OUTPUT_PATH}"/package + [ -d "$package_file" ] && rm -rf "$package_file" + mkdir "$package_file" + cp -f "${OUTPUT_PATH}"/crldata.crl "$OUTPUT_PATH/${pkg_release}.${2}.crl" + cp "$pkg_release".* "$package_file" + + cd "$package_file" + chmod 600 "$pkg_release.${2}" + chmod 600 "$pkg_release.${2}".cms + chmod 600 "$pkg_release.${2}".crl + zip_file="${3}$pkg_release.zip" + zip -r "$zip_file" "$pkg_release.${2}" "$pkg_release.${2}".cms "$pkg_release.${2}".crl + + mv "$package_file/$zip_file" "${OUTPUT_PATH}/$zip_file" + echo "zip $zip_file success !" + [ -d "$package_file" ] && rm -rf "$package_file" + return 0 +} + +function main() +{ + make_zip_package Ascend-mindxsdk-mxrec tar.gz + return 0 +} + +echo "begin to execute $SCRIPT_NAME" +main;ret="$?" +echo "finish executing $SCRIPT_NAME, result is $ret" diff --git a/config.ini b/config.ini new file mode 100644 index 00000000..e69de29b diff --git a/dependency.xml b/dependency.xml new file mode 100644 index 00000000..e69de29b diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 3ffb468a..6ad3f558 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + from .util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from .util.tf_version_adapter import npu_ops, hccl_ops from .saver.patch import patch_for_saver diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py index e69de29b..6924f767 100644 --- a/mx_rec/core/__init__.py +++ b/mx_rec/core/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. diff --git a/mx_rec/core/asc/__init__.py b/mx_rec/core/asc/__init__.py index e9754a0d..6924f767 100644 --- a/mx_rec/core/asc/__init__.py +++ b/mx_rec/core/asc/__init__.py @@ -1 +1,3 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index e081c798..153b6f7f 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ import logging @@ -16,11 +16,11 @@ def get_restore_vector(config): logging.debug(f'Channel {config.get("table_name")}_restore_{config.get("channel_id")} was built for getnext') emb_size = None if config.get("skip_emb_transfer"): - if not isinstance(config.get("_emb_size"), int) or config.get("_emb_size") < 1: - raise TypeError(f"_emb_size must be a int") - if config.get("_emb_size") < 1: - raise ValueError(f"_emb_size is less than 1") - emb_size = config.get("_emb_size") + if not isinstance(config.get("emb_size"), int) or config.get("emb_size") < 1: + raise TypeError(f"emb_size must be an int") + if config.get("emb_size") < 1: + raise ValueError(f"emb_size is less than 1") + emb_size = config.get("emb_size") else: if not isinstance(config.get("ext_emb_size"), int) or config.get("ext_emb_size") < 1: raise TypeError(f"ext_emb_size must be an int") @@ -90,7 +90,7 @@ def get_all2all_args(use_static: bool, config: dict) -> list: output_types=[tf.int64], output_shapes=[[config.get("rank_size"), config.get("rank_size")]], channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', - name="a2a_get_next")[0] * config.get("_emb_size") + name="a2a_get_next")[0] * config.get("emb_size") return all2all_args diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 0c680027..75b5b5c7 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -1,6 +1,7 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import logging from functools import reduce diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 756f5b0d..61f585c4 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -1,6 +1,7 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import logging from functools import reduce import os diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 6ea99612..8befc89a 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -1,6 +1,7 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import logging import os diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 52f7ce9c..1b75ffa5 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -1,6 +1,7 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ import logging import math import time @@ -131,7 +132,7 @@ class SparseEmbedding: self._use_feature_mapping = False self.skip_emb_transfer = True if self.host_vocabulary_size <= 0 else False self._default_name_count = -1 - self._emb_size = None + self.emb_size = None self.ext_emb_size = None self.ext_coefficient = 1 self._optimizer = dict() @@ -146,7 +147,7 @@ class SparseEmbedding: self.set_emb_size() if self._mode == MxRecMode.ASC and is_asc_frozen() and self.table_name in get_name_to_var_dict(): self.variable = tf.compat.v1.get_variable(self.table_name, - shape=(self.slice_device_vocabulary_size, self._emb_size)) + shape=(self.slice_device_vocabulary_size, self.emb_size)) if not self.skip_emb_transfer: self.set_ext_emb_size() else: @@ -210,13 +211,13 @@ class SparseEmbedding: self._use_feature_mapping = True def set_emb_size(self): - self._emb_size = self.embedding_size.as_list()[0] + self.emb_size = self.embedding_size.as_list()[0] def set_ext_emb_size(self): self.ext_coefficient += len(self.optimizer_slot_info_list) if self.use_dynamic_expansion and len(self._optimizer_instance_list) != 0: self.ext_coefficient += self._slot_num[self.table_name] - self.ext_emb_size = self._emb_size * self.ext_coefficient + self.ext_emb_size = self.emb_size * self.ext_coefficient logging.debug(f"init table, ext_emb_size is set to be {self.ext_emb_size}") def set_slice_vocab_size(self): @@ -236,7 +237,7 @@ class SparseEmbedding: @property def scalar_emb_size(self): - return self._emb_size + return self.emb_size @property def mode(self): @@ -404,7 +405,7 @@ class SparseEmbedding: local_embeddings = None if use_dynamic_expansion: local_embeddings = get_host_pipeline_ops().embedding_lookup_by_address(id_offsets, - embedding_dim=self._emb_size, + embedding_dim=self.emb_size, embedding_type=1) if is_training: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) @@ -428,7 +429,7 @@ class SparseEmbedding: hot_pos = None if use_hot: import mxrec_pybind - hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / self._emb_size) + hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / self.emb_size) hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos") hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value) tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_HOT_POS, hot_pos) @@ -555,7 +556,7 @@ class SparseEmbedding: config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, rank_size=rank_size, channel_id=channel_id, table_name=self.table_name, skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size, - _emb_size=self._emb_size, use_hot=use_hot, device_id=device_id, + emb_size=self.emb_size, use_hot=use_hot, device_id=device_id, use_dynamic_expansion=use_dynamic_expansion) if self.skip_emb_transfer: @@ -627,7 +628,7 @@ class SparseEmbedding: return sparse_forward(self.variable) local_embeddings = \ - host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self._emb_size, + host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size, embedding_type=1) if is_training: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py index e9754a0d..6924f767 100644 --- a/mx_rec/graph/__init__.py +++ b/mx_rec/graph/__init__.py @@ -1 +1,3 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 44c9de96..ae075efe 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -1,8 +1,7 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. -# Author: MindX SDK -# pylint: disable=W0212 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import logging from collections import defaultdict @@ -342,7 +341,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, - ext_emb_size=table_instance.ext_emb_size, _emb_size=table_instance._emb_size, use_hot=get_use_hot(), + ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.emb_size, use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) build_asc_graph(table_instance, cutting_point, config) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 4f7ab920..8f2add0f 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -1,7 +1,6 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import weakref diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py index 2e9f0134..cd26b9e1 100644 --- a/mx_rec/graph/utils.py +++ b/mx_rec/graph/utils.py @@ -1,7 +1,7 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + from collections import defaultdict import tensorflow as tf diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index e69de29b..6924f767 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 9b1b8deb..9e048ad4 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index a54e829c..56d8ff41 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -1,7 +1,6 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 9ccb0099..bf2b5c0d 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -1,7 +1,6 @@ -# coding=utf-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. -# Description: -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py index a2795f6d..6b1f04fc 100644 --- a/mx_rec/optimizers/ftrl_t.py +++ b/mx_rec/optimizers/ftrl_t.py @@ -1,7 +1,6 @@ -# coding=utf-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py index ae305838..412a7617 100644 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ b/mx_rec/optimizers/ftrl_t_dense.py @@ -1,7 +1,6 @@ -# coding=utf-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division @@ -74,7 +73,7 @@ class CustomizedFtrlTZ(optimizer.Optimizer): return self._apply_dense_shared_v2( grad, handle) - + def _apply_dense(self, grad, var): if self._lambda1 > 1e-10: return self._apply_dense_shared( diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 997d62fa..8313fca6 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -1,8 +1,6 @@ -# coding=utf-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. -# Description: -# Author: MindX SDK -# Create: 2022-12-01 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 3d22bf1b..fbc34e0b 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -1,7 +1,6 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: CustomizedGradientDescentByAddr. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 8d3c0427..649302ff 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -1,7 +1,6 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 714634f8..6a5a598d 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -1,7 +1,6 @@ -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: CustomizedLazyAdamByAddress. -# Author: MindX SDK +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division @@ -238,7 +237,7 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): else: scope_name = addr.op.name with ops.name_scope( - "update_" + scope_name), ops.colocate_with(addr): + f"update_{scope_name}"), ops.colocate_with(addr): update_ops.append(processor.update_op(self, grad)) apply_updates = self._finish(update_ops, name) @@ -304,4 +303,4 @@ def _get_processor(addr): if isinstance(addr, ops.Tensor): logging.debug(">>>>Enter _get_processor tensor") return _TensorByAddressProcessor(addr) - raise NotImplementedError("Trying to optimize unsupported type ", addr) \ No newline at end of file + raise NotImplementedError("Trying to optimize unsupported type ", addr) diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index f9d67508..66c96552 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from __future__ import absolute_import from __future__ import division diff --git a/mx_rec/saver/__init__.py b/mx_rec/saver/__init__.py index e69de29b..6924f767 100644 --- a/mx_rec/saver/__init__.py +++ b/mx_rec/saver/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 56eba09d..706f45b0 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import os import time diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d1593872..337447e3 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import json import os diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py index 68b05dcd..e2c29877 100644 --- a/mx_rec/util/__init__.py +++ b/mx_rec/util/__init__.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ from .log import get_log_level diff --git a/mx_rec/util/atomic.py b/mx_rec/util/atomic.py index 57c41acb..f03ea569 100644 --- a/mx_rec/util/atomic.py +++ b/mx_rec/util/atomic.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import threading @@ -9,14 +9,14 @@ class AtomicInteger(): def __init__(self, value=0): self._value = int(value) self._lock = threading.Lock() - + def increase(self, num=1): with self._lock: self._value += int(num) return self._value def decrease(self, num=1): - return self.inc(-num) + return self.increase(-num) def value(self): with self._lock: @@ -24,4 +24,3 @@ class AtomicInteger(): def __str__(self): return str(self.value()) - \ No newline at end of file diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py index fe0c84a3..3a9be2ec 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/util/constants.py @@ -1,6 +1,6 @@ -#!/usr/bin/python3 +#!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2021-2023 Huawei Technologies Co., Ltd +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from enum import Enum diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 4bc5f28f..73dcd354 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import json import logging import os diff --git a/mx_rec/util/log.py b/mx_rec/util/log.py index 8370860e..0093704b 100644 --- a/mx_rec/util/log.py +++ b/mx_rec/util/log.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import os import logging @@ -16,7 +19,3 @@ def get_log_level(): date_format = "%m/%d/%Y %H:%M:%S %p" logging.basicConfig(level=log_level, format=log_format, datefmt=date_format) - - -if __name__ == "__main__": - logging.debug("haha") \ No newline at end of file diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index b396f5ea..40c99f37 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import os import logging import tensorflow as tf diff --git a/mx_rec/util/perf.py b/mx_rec/util/perf.py index 72065812..66501b5f 100644 --- a/mx_rec/util/perf.py +++ b/mx_rec/util/perf.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import time import logging @@ -12,4 +15,4 @@ def performance(method_name): logging.debug(f"{method_name} method consume {span:.6f}s.") return result return wrapper - return decorator \ No newline at end of file + return decorator diff --git a/mx_rec/util/synchronizer.py b/mx_rec/util/synchronizer.py index 82cd54fe..9fba25dd 100644 --- a/mx_rec/util/synchronizer.py +++ b/mx_rec/util/synchronizer.py @@ -1,4 +1,7 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ import logging import socket from time import sleep diff --git a/mx_rec/util/tf_version_adapter.py b/mx_rec/util/tf_version_adapter.py index 6f8dfa09..d071c5c4 100644 --- a/mx_rec/util/tf_version_adapter.py +++ b/mx_rec/util/tf_version_adapter.py @@ -1,12 +1,15 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + import tensorflow as tf if tf.__version__.startswith("1"): - from npu_bridge.hccl import hccl_ops + from npu_bridge.hccl import hccl_ops else: from npu_device.compat.v1.hccl import hccl_ops if tf.__version__.startswith("1"): - from npu_bridge.estimator import npu_ops + from npu_bridge.estimator import npu_ops else: - from npu_device.compat.v1.estimator import npu_ops \ No newline at end of file + from npu_device.compat.v1.estimator import npu_ops diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py index dc7b1c65..9101616a 100644 --- a/mx_rec/util/variable.py +++ b/mx_rec/util/variable.py @@ -1,4 +1,6 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import tensorflow as tf from tensorflow.python.framework import ops diff --git a/setup.py b/setup.py index e660a184..f0933917 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2021-2025. All rights reserved. -# Description: build script. +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Description: setup script. # Author: MindX SDK # Create: 2022 # History: NA @@ -8,7 +10,8 @@ import os from setuptools import setup, find_packages try: - LONG_DESCRIPTION = open("README.md").read() + with open("README.md") as file: + LONG_DESCRIPTION = file.read() except IOError: LONG_DESCRIPTION = "" diff --git a/src/build.sh b/src/build.sh index 098693ef..c5728012 100644 --- a/src/build.sh +++ b/src/build.sh @@ -1,4 +1,11 @@ -rm -rf build; +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Description: build script. +# Author: MindX SDK +# Create: 2022 +# History: NA +set -e +[ -d build ] && rm -rf build; mkdir build && cd build || exit 1 # HDF5_PATH, EASY_PROFILER_PATH is optional python_path="$(dirname $(dirname $(realpath $(which python3))))" diff --git a/src/test_ut.sh b/src/test_ut.sh index 175954bb..f2bbe2ff 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2020-2022. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. # Description: NA # Author: MindX SDK # Create: 2022 @@ -14,41 +14,41 @@ source /etc/profile source /opt/rh/devtoolset-7/enable CUR_DIR=$(dirname "$(readlink -f "$0")") +ROOT_DIR=$(dirname "${CUR_DIR}") compile_securec() { - if [[ ! -d ${CUR_DIR}/../platform/securec ]]; then + if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then echo "securec is not exist" exit 1 fi - if [[ ! -f ${CUR_DIR}/../platform/securec/lib/libsecurec.so ]]; then - cd ${CUR_DIR}/../platform/securec/src + if [[ ! 
-f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd ${CUR_DIR}/../platform/securec/src + cd "${ROOT_DIR}"/platform/securec/src make -j fi } compile_securec -cd ${CUR_DIR} -cd ../src/ +cd "${ROOT_DIR}"/src find ./ -name "*.sh" -exec dos2unix {} \; find ./ -name "*.sh" -exec chmod +x {} \; -rm -rf build +[ -d build ] && rm -rf build mkdir build cd build cmake -DCMAKE_BUILD_TYPE=Debug \ - -DTF_PATH=/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow_core \ + -DTF_PATH="$(dirname "$(dirname "$(which python3.7)")")"/lib/python3.7/site-packages/tensorflow_core \ -DOMPI_PATH=/usr/local/openmpi/ \ - -DPYTHON_PATH=/opt/buildtools/python-3.7.5/ \ + -DPYTHON_PATH="$(dirname "$(dirname "$(which python3.7)")")" \ -DEASY_PROFILER_PATH=/opt/buildtools/ \ -DASCEND_PATH=/usr/local/Ascend/ascend-toolkit/latest \ - -DABSEIL_PATH=${PWD}/../../install/abseil/ \ + -DABSEIL_PATH="$(dirname "$(dirname "${PWD}")")"/install/abseil/ \ - -DSECUREC_PATH=${CUR_DIR}/../platform/securec \ + -DSECUREC_PATH="${ROOT_DIR}"/platform/securec \ -DBUILD_TESTS=on -DCOVERAGE=on "$(dirname "${PWD}")" make -j make install @@ -56,16 +56,16 @@ make install # Run Test mpirun -np 4 ./tests/test_main -cd ../ +cd "$(dirname "${PWD}")" COVERAGE_FILE=coverage.info REPORT_FOLDER=coverage_report -lcov --rc lcov_branch_coverage=1 -c -d build -o ${COVERAGE_FILE}_tmp -lcov --rc lcov_branch_coverage=1 -e ${COVERAGE_FILE}_tmp "*src*" -o ${COVERAGE_FILE} -genhtml --rc genhtml_branch_coverage=1 ${COVERAGE_FILE} -o ${REPORT_FOLDER} -rm -rf ${COVERAGE_FILE}_tmp -rm -rf ${COVERAGE_FILE} +lcov --rc lcov_branch_coverage=1 -c -d build -o "${COVERAGE_FILE}"_tmp +lcov --rc lcov_branch_coverage=1 -e "${COVERAGE_FILE}"_tmp "*src*" -o "${COVERAGE_FILE}" +genhtml --rc genhtml_branch_coverage=1 "${COVERAGE_FILE}" -o "${REPORT_FOLDER}" +[ -d "${COVERAGE_FILE}"_tmp ] && rm -rf "${COVERAGE_FILE}"_tmp +[ -d "${COVERAGE_FILE}" ] && rm -rf "${COVERAGE_FILE}" if [[ "$OSTYPE" == "darwin"* ]]; then - open ./${REPORT_FOLDER}/index.html -fi \ No newline at end of file + open ./"${REPORT_FOLDER}"/index.html +fi -- Gitee From 212d4e107fe9157dd84350f7dac5ad3faa4a7bbb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 23 May 2023 17:14:30 +0800 Subject: [PATCH 018/551] Match-id-f9d69f7248d0e0c3da1527e86a35b79a5f0c53b1 --- mx_rec/util/initialize.py | 75 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 73dcd354..7d0c435b 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -5,6 +5,8 @@ import json import logging import os +from collections import defaultdict + import psutil import mxrec_pybind @@ -599,13 +601,58 @@ def set_ascend_env(): logging.debug(f"Ascend env has been set.") +def get_available_cpu_num_and_range(): + """ + Get the number of CPUs available in the current environment and their NUMA ranges + Returns: + + """ + cpu_available = os.sched_getaffinity(os.getpid())  # get the cores this process can be bound to + + is_ok = True + cpu_pkg_id_file = "/sys/devices/system/cpu/cpu{}/topology/physical_package_id" + pkg_id2cpu_list = defaultdict(list) + for cpu in cpu_available: + f_path = cpu_pkg_id_file.format(cpu) + if not os.path.exists(f_path): + logging.warning(f"failed to get numa node of cpu: {cpu}") + is_ok = False + break + with open(f_path, "r", encoding="utf-8") as f_in: + pkg_id = f_in.readline().strip() + pkg_id2cpu_list[pkg_id].append(cpu) + + def parse_range(cpu_list, cpu_range): + sorted_cpu_list = sorted(cpu_list) + pre_cpu = sorted_cpu_list[0] + cpu_range.append([pre_cpu]) + + for
From 28c06239009bf317620b71251318f8dfc08a6b84 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 23 May 2023 11:05:40 +0800
Subject: [PATCH 019/551] Match-id-aff4f959a1755e6d80f18ac80a7629b3ed62d239

---
 mx_rec/__init__.py |  5 +++++
 setup.py           | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py
index 6ad3f558..8e92175a 100644
--- a/mx_rec/__init__.py
+++ b/mx_rec/__init__.py
@@ -10,3 +10,8 @@ from .graph.patch import patch_for_dataset
 
 patch_for_saver()
 patch_for_dataset()
+__version__ = "5.0.RC2"
+
+
+def version():
+    return __version__
diff --git a/setup.py b/setup.py
index f0933917..36b406fe 100644
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,25 @@
 # History: NA
 
 import os
+import stat
 from setuptools import setup, find_packages
+import pkg_resources
+from setuptools.extern.packaging import version as packaging_version
+
+
+# Patch Version class to preserve original version string
+class NoNormalizeVersion(packaging_version.Version):
+    def __init__(self, version):
+        self._orig_version = version
+        super().__init__(version)
+
+    def __str__(self):
+        return self._orig_version
+
+
+packaging_version.Version = NoNormalizeVersion
+# Patch safe_version() to prevent version normalization
+pkg_resources.safe_version = lambda v: v
 
 try:
     with open("README.md") as file:
@@ -18,6 +36,21 @@ except IOError:
     LONG_DESCRIPTION = ""
 
 env_version = os.getenv("VERSION")
 VERSION = env_version if env_version is not None else '5.0.T104'
+INIT_FILE = "mx_rec/__init__.py"
+with open(INIT_FILE, 'r') as file:
+    lines = file.readlines()
+
+for idx, line in enumerate(lines):
+    if "__version__ = " not in line:
+        continue
+    lines[idx] = f"__version__ = '{VERSION}'\n"
+    break
+
+FLAG = os.O_WRONLY | os.O_TRUNC
+MODE = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH
+with os.fdopen(os.open(INIT_FILE, FLAG, MODE), 'w') as out:
+    out.writelines(lines)
+
 setup(
     name='mx_rec',
     version=VERSION,
-- Gitee
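
For context on the setup.py change above: setuptools normalizes PEP 440 pre-release strings, so a version such as "5.0.RC2" would otherwise be written out as "5.0rc2" in package metadata and artifact names. A small illustration of the behaviour being patched away (this sketch assumes the standalone packaging library; the patch itself goes through setuptools.extern):

from packaging import version

v = version.Version("5.0.RC2")
print(str(v))  # prints "5.0rc2", the normalized canonical form
# NoNormalizeVersion above keeps the string as written by overriding __str__,
# and the safe_version() patch stops pkg_resources from normalizing it again.
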
+INIT_FILE = "mx_rec/__init__.py" +with open(INIT_FILE, 'r') as file: + lines = file.readlines() + +for idx, line in enumerate(lines): + if "__version__ = " not in line: + continue + lines[idx] = f"__version__ = '{VERSION}'\n" + break + +FLAG = os.O_WRONLY | os.O_TRUNC +MODE = stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH +with os.fdopen(os.open(INIT_FILE, FLAG, MODE), 'w') as out: + out.writelines(lines) + setup( name='mx_rec', version=VERSION, -- Gitee From 388e376558fca4d596a6a2ef1943c1d433c29012 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 23 May 2023 17:38:54 +0800 Subject: [PATCH 020/551] Match-id-5420c36186806655b991dd4f293c6bf5a0a9579d --- mx_rec/core/embedding.py | 29 ++++++++++++++++++++--------- mx_rec/graph/utils.py | 16 +++++++++++++--- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 1b75ffa5..e157acc4 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -10,6 +10,7 @@ from collections import defaultdict import numpy as np import tensorflow as tf from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.helper import FeatureSpec @@ -365,14 +366,16 @@ class SparseEmbedding: """ logging.debug(f"Enter ASC Branch.") + if not kwargs.get("modify_graph"): + raise ValueError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") + self.check_mode(MxRecMode.ASC) is_training = kwargs.get("is_train") if is_asc_frozen() and is_training: - raise EnvironmentError(f"Cannot build new sparse forward graph after emb cache management was built.") - if not kwargs.get("modify_graph"): - raise RuntimeError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") + raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.") + if is_asc_frozen() and not is_training: + clear_channel(is_train_channel=False) - rank_size = get_rank_size() access_threshold = None eviction_threshold = None if kwargs.get("access_and_evict_config"): @@ -384,9 +387,6 @@ class SparseEmbedding: eviction_threshold=eviction_threshold) feature_spec.set_feat_attribute(ids, is_training) - if is_asc_frozen() and not is_training: - clear_channel(is_train_channel=False) - self.check_and_format_lookup_params(ids, send_count, is_training) anchor_ids = tf.identity(ids, name="ids") tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids) @@ -398,6 +398,7 @@ class SparseEmbedding: logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, ids: {ids}, use_dynamic_expansion: " f"{use_dynamic_expansion}, use_static: {use_static}, use_hot: {use_hot}") + rank_size = get_rank_size() id_offsets = tf.ones(shape=[send_count * rank_size if use_static else 1 * rank_size, ], dtype=tf.int64 if get_use_dynamic_expansion() else tf.int32, name="id_offsets") id_offsets = tf.identity(id_offsets, name=ASCAnchorAttr.ID_OFFSETS.value) @@ -414,7 +415,13 @@ class SparseEmbedding: @tf.custom_gradient def sparse_forward(table, feat_ids): logging.debug(f"fp rank size: {rank_size}") - restore_vector = tf.ones(shape=[np.prod(feat_ids.shape.as_list()), ], dtype=tf.int32, name="restore_vector") + if use_static: + restore_vector = tf.ones(shape=[np.prod(feat_ids.shape.as_list()), ], dtype=tf.int32, + name="restore_vector") + else: + restore_vector = tf.ones(shape=[tf.math.reduce_prod(array_ops.shape(feat_ids)[0]), ], 
From c363a6a94f48760143a372e2907d50ed323d53ca Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 24 May 2023 14:18:38 +0800
Subject: [PATCH 021/551] Match-id-9459432ed1aec8c410e0e89ef4e342dc5123a6ee

---
 mx_rec/graph/modifier.py |  5 ++++-
 mx_rec/saver/saver.py    | 23 +++++++++++++++++------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index ae075efe..9b92b1f5 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -16,7 +16,7 @@ from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_CUTTI
     ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr, ASCEND_TIMESTAMP
 from mx_rec.util.initialize import get_rank_size, destroy_asc_manager, get_training_mode_channel_id, \
     get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, \
-    get_use_dynamic_expansion
+    get_use_dynamic_expansion, terminate_config_initializer
 from mx_rec.util.perf import performance
 from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \
     record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list
@@ -423,3 +423,6 @@ class GraphModifierHook(tf.estimator.SessionRunHook):
     def after_create_session(self, session, coord):
         if self.modify_graph:
             session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER))
+
+    def end(self, session):
+        terminate_config_initializer()
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index 337447e3..a862e073 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -89,10 +89,18 @@ class Saver(object):
 
     @performance("Save")
     def save(self, sess, save_path="model", global_step=None):
+        """
+        Save sparse tables
+        :param sess: A Session to use to save the sparse table variables
+        :param save_path: Only absolute path supported
+        :param global_step: If provided the global step number is appended to save_path to create
+            the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
+        :return: None
+        """
         logging.debug(f"======== Start saving for rank id {self.rank_id} ========")
         save_path = save_path if save_path else self._prefix_name
         directory, base_name = os.path.split(save_path)
-        if global_step is not None:
+        if global_step:
             if not isinstance(global_step, compat.integral_types):
                 global_step = int(sess.run(global_step))
             ckpt_name = "sparse-%s-%d" % (base_name, global_step)
@@ -100,7 +108,10 @@
             ckpt_name = "sparse-%s" % base_name
 
         integrated_path = os.path.join(directory, ckpt_name)
-        saving_path = os.path.abspath(integrated_path)
+        saving_path = integrated_path
+        if integrated_path.startswith("/"):
+            saving_path = os.path.abspath(integrated_path)
+
         if os.path.exists(saving_path):
             shutil.rmtree(saving_path, ignore_errors=True)
             logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.")
@@ -112,14 +123,14 @@
         logging.debug(f"======== Saving finished for rank id {self.rank_id} ========")
 
     def _save(self, sess, root_dir):
-        if is_asc_manager_initialized():
-            save_host_data(root_dir)
-            logging.debug(f"host data was saved.")
-
         result = sess.run(self.save_op_dict)
         for table_name, dump_data_dict in result.items():
             save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id)
             table_instance = get_table_instance_by_name(table_name)
+            if is_asc_manager_initialized():
+                save_host_data(root_dir)
+                logging.debug(f"host data was saved.")
+
             if table_instance.use_feature_mapping:
                 save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id)
                 save_offset_data(root_dir, table_name, dump_data_dict, self.rank_id)
-- Gitee
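
For reference, Saver.save above derives a "sparse-" prefixed checkpoint directory from save_path and the optional global step. A standalone sketch of that naming rule (hypothetical helper, not part of the patch):

import os

def sparse_ckpt_dir(save_path="model", global_step=None):
    # Mirrors the naming in Saver.save: sparse-<base>[-<step>] next to save_path.
    directory, base_name = os.path.split(save_path)
    name = "sparse-%s-%d" % (base_name, global_step) if global_step else "sparse-%s" % base_name
    return os.path.join(directory, name)

print(sparse_ckpt_dir("/tmp/ckpt/model", 100))  # /tmp/ckpt/sparse-model-100
print(sparse_ckpt_dir("model"))                 # sparse-model

Note that the new "if global_step:" test treats a step of 0 like None and falls back to the unsuffixed name, which the previous "is not None" check did not.
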
From 7600f5d2ea9600d0bcf187f1e2cf7498de43b8b5 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 23 May 2023 16:45:44 +0800
Subject: [PATCH 022/551] Match-id-a32d67a1b3613ddcf5d3367b43215cfc21efb6b9

---
 CMakeLists.txt                      |  2 +-
 build.sh                            | 28 ++++++++++++++++++++++++++--
 build/build.sh                      |  2 ++
 src/CMakeLists.txt                  |  4 ++--
 tools/python/key_2_emb_formatter.py |  4 +++-
 5 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index deb19638..d56a5b63 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,2 +1,2 @@
-cmake_minimum_required(VERSION 3.12)
+cmake_minimum_required(VERSION 3.20)
 project(MxRec LANGUAGES CXX)
diff --git a/build.sh b/build.sh
index 86e1c715..87183edc 100644
--- a/build.sh
+++ b/build.sh
@@ -2,5 +2,29 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
 # Description: build entrance script.
-SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -bash "${SCRIPT_DIR}"/build/build.sh +set -e +ROOT_DIR=$(dirname "$(readlink -f "$0")") + +remove() +{ + if [ -d "$1" ]; then + rm -rf "$1" + elif [ -f "$1" ]; then + rm -f "$1" + fi +} + +clean() +{ + remove "${ROOT_DIR}"/dist + remove "${ROOT_DIR}"/install + remove "${ROOT_DIR}"/mx_rec.egg-info + remove "${ROOT_DIR}"/src/build + remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" + remove "${ROOT_DIR}"/build/tf1_env + remove "${ROOT_DIR}"/build/tf2_env + remove "${ROOT_DIR}"/build/lib + remove "${ROOT_DIR}"/build/mindxsdk-mxrec +} + +clean diff --git a/build/build.sh b/build/build.sh index 18e9dbc8..b7fc65f3 100644 --- a/build/build.sh +++ b/build/build.sh @@ -103,6 +103,7 @@ install_abseil() cd "${abseil_src_path}" echo "${abseil_src_path}" + remove CMakeCache.txt cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install echo "${project_output_path}"/abseil @@ -178,6 +179,7 @@ gen_tar_file() clean() { + remove "${ROOT_DIR}"/dist remove "${ROOT_DIR}"/install remove "${ROOT_DIR}"/mx_rec.egg-info remove "${ROOT_DIR}"/src/build diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ca19df6c..42d2f2df 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.12) +cmake_minimum_required(VERSION 3.20) project(MxRec LANGUAGES CXX) set(CMAKE_CXX_STANDARD 14) @@ -31,7 +31,7 @@ endif () if (NOT easy_profiler_FOUND) message("===EASY_PROFILER_NOT_FOUND===") else () - message("==EASY_PROFILER_FOUND===") + message("==EASY_PROFILER_FOUND===") ADD_DEFINITIONS(-DBUILD_WITH_EASY_PROFILER) endif () set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb") diff --git a/tools/python/key_2_emb_formatter.py b/tools/python/key_2_emb_formatter.py index 6b2045ca..467bd5c3 100644 --- a/tools/python/key_2_emb_formatter.py +++ b/tools/python/key_2_emb_formatter.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
# Description: # Author: MindX SDK @@ -199,4 +201,4 @@ class Formatter: if __name__ == "__main__": args = parser.parse_args() formatter = Formatter(saved_file_path=args.path, out_file_name=args.name, is_ddr_mode=args.ddr, step=args.step) - formatter.process() \ No newline at end of file + formatter.process() -- Gitee From 304807c6eefe6d919393e85bd0303b7ead181676 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 25 May 2023 17:57:02 +0800 Subject: [PATCH 023/551] Match-id-a0ad58f40311fd89dc042e87e68cca05ede37eaa --- mx_rec/__init__.py | 8 +- mx_rec/core/asc/feature_spec.py | 34 ++-- mx_rec/core/asc/helper.py | 12 +- mx_rec/core/asc/manager.py | 8 +- mx_rec/core/embedding.py | 124 +++++++------- mx_rec/optimizers/adagrad.py | 62 +++---- mx_rec/optimizers/base.py | 2 +- mx_rec/optimizers/ftrl.py | 62 +++---- mx_rec/optimizers/ftrl_t.py | 72 ++++---- mx_rec/optimizers/gradient_descent.py | 14 +- mx_rec/optimizers/gradient_descent_by_addr.py | 46 ++--- mx_rec/optimizers/lazy_adam.py | 70 ++++---- mx_rec/optimizers/lazy_adam_by_addr.py | 116 ++++++------- mx_rec/optimizers/momentum.py | 62 +++---- mx_rec/saver/saver.py | 100 +++++------ mx_rec/util/__init__.py | 2 +- mx_rec/util/atomic.py | 6 +- mx_rec/util/initialize.py | 162 +++++++++--------- 18 files changed, 481 insertions(+), 481 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 8e92175a..bbee9e8f 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -from .util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION -from .util.tf_version_adapter import npu_ops, hccl_ops -from .saver.patch import patch_for_saver -from .graph.patch import patch_for_dataset +from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops +from mx_rec.saver.patch import patch_for_saver +from mx_rec.graph.patch import patch_for_dataset patch_for_saver() diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 75b5b5c7..7c148892 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -39,19 +39,6 @@ class FeatureSpec: self._pipeline_mode = set() self.check_params() - @staticmethod - def include_timestamp(is_training): - if is_training: - if FeatureSpec.use_timestamp_train: - raise EnvironmentError(f"Timestamp was set twice for training mode.") - FeatureSpec.use_timestamp_train = True - else: - FeatureSpec.use_timestamp_eval = True - - @staticmethod - def use_timestamp(is_training): - return FeatureSpec.use_timestamp_train if is_training else FeatureSpec.use_timestamp_eval - @property def is_timestamp(self): return self._is_timestamp @@ -76,6 +63,23 @@ class FeatureSpec: def feat_cnt(self): return self._feat_cnt + @property + def pipeline_mode(self): + return self._pipeline_mode + + @staticmethod + def include_timestamp(is_training): + if is_training: + if FeatureSpec.use_timestamp_train: + raise EnvironmentError(f"Timestamp was set twice for training mode.") + FeatureSpec.use_timestamp_train = True + else: + FeatureSpec.use_timestamp_eval = True + + @staticmethod + def use_timestamp(is_training): + return FeatureSpec.use_timestamp_train if is_training else FeatureSpec.use_timestamp_eval + def check_params(self): def check_str(arg, param_name): if not isinstance(arg, str): @@ -117,10 +121,6 @@ class FeatureSpec: self.feat_pos_eval = FeatureSpec.instance_count_eval FeatureSpec.instance_count_eval += 
1 - @property - def pipeline_mode(self): - return self._pipeline_mode - def insert_pipeline_mode(self, mode): if not isinstance(mode, bool): raise TypeError("Is training mode must be a boolean.") diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 61f585c4..92a034e5 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -10,7 +10,7 @@ import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, insert_feature_spec, insert_training_mode_channel_id, \ get_training_mode_channel_id, get_use_static -from .feature_spec import FeatureSpec +from mx_rec.core.asc.feature_spec import FeatureSpec def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None, @@ -68,14 +68,14 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ read_emb_key_inputs_dict = {"insert_tensors": [], "table_names": [], "feature_spec_names": [], "splits": []} get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) - logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict['table_names']}") + logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict.get('table_names')}") return do_insert(args, - insert_tensors=read_emb_key_inputs_dict["insert_tensors"], - splits=read_emb_key_inputs_dict["splits"], - table_names=read_emb_key_inputs_dict["table_names"], + insert_tensors=read_emb_key_inputs_dict.get("insert_tensors"), + splits=read_emb_key_inputs_dict.get("splits"), + table_names=read_emb_key_inputs_dict.get("table_names"), input_dict={"is_training": is_training, "dump_graph": dump_graph, "timestamp": FeatureSpec.use_timestamp(is_training), - "feature_spec_names": read_emb_key_inputs_dict["feature_spec_names"], + "feature_spec_names": read_emb_key_inputs_dict.get("feature_spec_names"), "auto_change_graph": False}) insert_fn = insert_fn_for_feature_specs diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 8befc89a..9554f74f 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -74,11 +74,11 @@ def matched_emb_initializer(tabel_info): tf.__version__.startswith("2") and isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), } - if initializer_case_map["tf1/tf2_constant_initializer"]: + if initializer_case_map.get("tf1/tf2_constant_initializer"): initializer = InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, constant_initializer_info=ConstantInitializerInfo( constant_val=tabel_info.emb_initializer.value)) - elif initializer_case_map["tf1/tf2_random_normal_initializer"]: + elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, normal_initializer_info=NormalInitializerInfo( @@ -86,8 +86,8 @@ def matched_emb_initializer(tabel_info): stddev=tabel_info.emb_initializer.stddev, seed=random_seed )) - elif initializer_case_map["tf1_truncated_normal_initializer"] or \ - initializer_case_map["tf2_truncated_normal_initializer"]: + elif initializer_case_map.get("tf1_truncated_normal_initializer") or \ + initializer_case_map.get("tf2_truncated_normal_initializer"): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed initializer = InitializeInfo(name="truncated_normal_initializer", start=0, 
len=tabel_info.scalar_emb_size, normal_initializer_info=NormalInitializerInfo( diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index e157acc4..b1106532 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -157,6 +157,53 @@ class SparseEmbedding: self.set_ext_emb_size() tf.compat.v1.add_to_collection(get_ascend_global_hashtable_collection(), self.variable) + @property + def use_feature_mapping(self): + return self._use_feature_mapping + + @property + def scalar_emb_size(self): + return self.emb_size + + @property + def mode(self): + return self._mode + + @property + def send_count(self): + return self._send_count + + @property + def optimizer(self): + return self._optimizer + + @property + def optimizer_instance_list(self): + return self._optimizer_instance_list + + @staticmethod + def get_anchor_attribute(anchor, attr): + if not isinstance(anchor, tf.Tensor): + raise ValueError("Anchor must be a Tensor.") + + if attr not in ASCAnchorAttr: + raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.") + + specs = SparseEmbedding.anchor_tensor_specs.get(anchor) + if specs is None: + raise ValueError(f"Given anchor '{anchor}' was not registered.") + + return specs.get(attr) + + @staticmethod + def set_optimizer_slot(slot_info): + slot = slot_info.get("slot") + slot_name = slot_info.get("slot_name") + optimizer = slot_info.get("optimizer") + named_slot_key = slot_info.get("named_slot_key") + + optimizer.insert_slot(slot, named_slot_key, slot_name) + def check_optimizer_instance(self): for optimizer_instance in self._optimizer_instance_list: if tf.__version__.startswith("1"): @@ -217,7 +264,7 @@ class SparseEmbedding: def set_ext_emb_size(self): self.ext_coefficient += len(self.optimizer_slot_info_list) if self.use_dynamic_expansion and len(self._optimizer_instance_list) != 0: - self.ext_coefficient += self._slot_num[self.table_name] + self.ext_coefficient += self._slot_num.get(self.table_name) self.ext_emb_size = self.emb_size * self.ext_coefficient logging.debug(f"init table, ext_emb_size is set to be {self.ext_emb_size}") @@ -232,41 +279,6 @@ class SparseEmbedding: self.slice_device_vocabulary_size = math.ceil(self.device_vocabulary_size / rank_size) self.slice_host_vocabulary_size = math.ceil(self.host_vocabulary_size / rank_size) - @property - def use_feature_mapping(self): - return self._use_feature_mapping - - @property - def scalar_emb_size(self): - return self.emb_size - - @property - def mode(self): - return self._mode - - def _record(self): - insert_table_instance(self.table_name, self.variable, self) - logging.debug(f"Device vocabulary_size for table {self.table_name} is {self.device_vocabulary_size}.") - logging.debug(f"Slice_device_vocabulary_size for table {self.table_name} is" - f" {self.slice_device_vocabulary_size}.") - logging.debug(f"Host vocabulary size for table {self.table_name} is {self.host_vocabulary_size}.") - logging.debug(f"Slice host vocabulary_size for table {self.table_name} is" - f" {self.slice_host_vocabulary_size}.") - - @staticmethod - def get_anchor_attribute(anchor, attr): - if not isinstance(anchor, tf.Tensor): - raise ValueError("Anchor must be a Tensor.") - - if attr not in ASCAnchorAttr: - raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.") - - specs = SparseEmbedding.anchor_tensor_specs.get(anchor) - if specs is None: - raise ValueError(f"Given anchor '{anchor}' was not registered.") - - return specs.get(attr) - def register_anchor_attribute(self, anchor_ids, feature_spec, kwargs): 
SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train") @@ -338,18 +350,6 @@ class SparseEmbedding: self._optimizer[key] = state_dict - @property - def send_count(self): - return self._send_count - - @property - def optimizer(self): - return self._optimizer - - @property - def optimizer_instance_list(self): - return self._optimizer_instance_list - def lookup_for_asc(self, ids: tf.Tensor, send_count, **kwargs): """ @@ -503,8 +503,8 @@ class SparseEmbedding: spec_name = feature_spec.name is_training = kwargs.get("is_train") if self.lookup_result is not None and spec_name in self.lookup_result \ - and is_training in self.lookup_result[spec_name]: - return self.lookup_result[spec_name][is_training] + and is_training in self.lookup_result.get(spec_name): + return self.lookup_result.get(spec_name).get(is_training) table_name = feature_spec.table_name same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training] @@ -513,7 +513,7 @@ class SparseEmbedding: if len(same_table_feature_spec) == 1: lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) self.lookup_result = {spec_name: {is_training: lookup_result}} - return self.lookup_result[spec_name][is_training] + return self.lookup_result.get(spec_name).get(is_training) else: same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) same_table_spec_count = len(same_table_feature_spec) @@ -534,7 +534,7 @@ class SparseEmbedding: lookup_result_split = tf.split(lookup_result, split_size) self.lookup_result = {k.name: {is_training: tf.reshape(v, k.dims + [self.scalar_emb_size])} for k, v in zip(same_table_feature_spec, lookup_result_split)} - return self.lookup_result[spec_name][is_training] + return self.lookup_result.get(spec_name).get(is_training) def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs): """ @@ -646,6 +646,15 @@ class SparseEmbedding: return sparse_forward(local_embeddings) + def _record(self): + insert_table_instance(self.table_name, self.variable, self) + logging.debug(f"Device vocabulary_size for table {self.table_name} is {self.device_vocabulary_size}.") + logging.debug(f"Slice_device_vocabulary_size for table {self.table_name} is" + f" {self.slice_device_vocabulary_size}.") + logging.debug(f"Host vocabulary size for table {self.table_name} is {self.host_vocabulary_size}.") + logging.debug(f"Slice host vocabulary_size for table {self.table_name} is" + f" {self.slice_host_vocabulary_size}.") + def _initialize_variables(self): initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) @@ -667,15 +676,6 @@ class SparseEmbedding: for slot_info in self.optimizer_slot_info_list: self.set_optimizer_slot(slot_info) - @staticmethod - def set_optimizer_slot(slot_info): - slot = slot_info.get("slot") - slot_name = slot_info.get("slot_name") - optimizer = slot_info.get("optimizer") - named_slot_key = slot_info.get("named_slot_key") - - optimizer.insert_slot(slot, named_slot_key, slot_name) - def get_own_ids(unique_ids, origin_id_lens, send_cnt, self): from mx_rec.util.tf_version_adapter import hccl_ops @@ -815,7 +815,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): trigger_evict() self._start_time = 
cur_time for name in self._hash_table_instance.keys(): - run_context.session.run(self._evict_op[name]) + run_context.session.run(self._evict_op.get(name)) def check_name_and_get_hashtable(self): for _, feature_spec in export_feature_spec().items(): diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 9e048ad4..f1499c25 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -47,7 +47,7 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): use_locking=False, name="Adagrad"): self.optimizer_type = "Adagrad" - super(CustomizedAdagrad, self).__get_name__(name=name) + super(CustomizedAdagrad, self)._get_name(name=name) super(CustomizedAdagrad, self).__init__(learning_rate=learning_rate, initial_accumulator_value=initial_accumulator_value, use_locking=use_locking, @@ -55,6 +55,36 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): self._check_input_param() + def initialize_slots(self, var, table_instance): + # Create slots for the first and second moments. + def creat_one_single_slot(var, op_name): + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. + return new_slot_variable + + accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator") + remove_saving_var(accumulator) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"accumulator": accumulator}) + return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + # return state value list of adagrad that needs to initialize in ASC DDR. + initial_accumulator_value = 0.0 + return [initial_accumulator_value] + def _check_input_param(self): check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float)) check_param_type("initial_accumulator_value", self._initial_accumulator_value, (tf.Tensor, float)) @@ -87,33 +117,3 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): return training_ops.resource_sparse_apply_adagrad( var.handle, acc.handle, math_ops.cast(self._learning_rate_tensor, grad.dtype), grad, indices, use_locking=self._use_locking) - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
- return new_slot_variable - - accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator") - remove_saving_var(accumulator) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"accumulator": accumulator}) - return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - # return state value list of adagrad that needs to initialize in ASC DDR. - initial_accumulator_value = 0.0 - return [initial_accumulator_value] diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 56d8ff41..3f00978a 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -17,7 +17,7 @@ class CustomizedOptimizer: self.unique_name = "" self.base_name = "" - def __get_name__(self, name="CustomizedOptimizer"): + def _get_name(self, name="CustomizedOptimizer"): if name in CustomizedOptimizer.name_counter: CustomizedOptimizer.name_counter[name] += 1 count = CustomizedOptimizer.name_counter.get(name) diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index bf2b5c0d..5498a932 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -34,7 +34,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): def __init__(self, learning_rate, use_locking=False, name="Ftrl", **kwargs): self.optimizer_type = "ftrl" - super(CustomizedFtrl, self).__get_name__(name=name) + super(CustomizedFtrl, self)._get_name(name=name) super(CustomizedFtrl, self).__init__( learning_rate=learning_rate, learning_rate_power=kwargs.get("learning_rate_power", -0.5), @@ -67,6 +67,36 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): _check_param_type_range(param_name_list) + def initialize_slots(self, var, table_instance): + val = constant_op.constant( + self._initial_accumulator_value, dtype=var.dtype, shape=var.get_shape()) + + accum = slot_creator.create_slot(var, val, self._name + "/" + "accum") + linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear") + remove_saving_var(accum) + remove_saving_var(linear) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear}) + return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self}, + {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + # return state value list of ftrl that needs to initialize in ASC DDR. 
+ initial_linear_value = 0.0 + return [self._initial_accumulator_value, initial_linear_value] + def _apply_sparse_duplicate_indices(self, grad, var): logging.debug(f"######### _apply_sparse_duplicate_indices {var}") return self._apply_sparse(grad, var) @@ -211,33 +241,3 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear}) - - def initialize_slots(self, var, table_instance): - val = constant_op.constant( - self._initial_accumulator_value, dtype=var.dtype, shape=var.get_shape()) - - accum = slot_creator.create_slot(var, val, self._name + "/" + "accum") - linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear") - remove_saving_var(accum) - remove_saving_var(linear) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear}) - return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self}, - {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - # return state value list of ftrl that needs to initialize in ASC DDR. - initial_linear_value = 0.0 - return [self._initial_accumulator_value, initial_linear_value] diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py index 6b1f04fc..0dedc009 100644 --- a/mx_rec/optimizers/ftrl_t.py +++ b/mx_rec/optimizers/ftrl_t.py @@ -34,7 +34,7 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): def __init__(self, learning_rate, use_locking=False, name="Ftrl_t", **kwargs): self.optimizer_type = "ftrl" - super(CustomizedFtrlT, self).__get_name__(name=name) + super(CustomizedFtrlT, self)._get_name(name=name) self._learning_rate = learning_rate self._alpha = kwargs.get("alpha", 0.06) @@ -55,6 +55,41 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): self._grad_factor_tensor = None super(CustomizedFtrlT, self).__init__(use_locking, self.unique_name) + def initialize_slots(self, var, table_instance): + z = slot_creator.create_zeros_slot(var, self._name + "/" + "z") + n = slot_creator.create_zeros_slot(var, self._name + "/" + "n") + g = slot_creator.create_zeros_slot(var, self._name + "/" + "g") + w = slot_creator.create_zeros_slot(var, self._name + "/" + "w") + remove_saving_var(z) + remove_saving_var(n) + remove_saving_var(g) + remove_saving_var(w) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) + return [{"slot": z, "named_slot_key": named_slot_key, "slot_name": "z", "optimizer": self}, + {"slot": n, "named_slot_key": named_slot_key, "slot_name": "n", "optimizer": self}, + {"slot": g, "named_slot_key": named_slot_key, "slot_name": "g", "optimizer": self}, + {"slot": w, "named_slot_key": 
named_slot_key, "slot_name": "w", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + initial_z_value = 0.0 + initial_n_value = 0.0 + initial_g_value = 0.0 + initial_w_value = 0.0 + return [initial_z_value, initial_n_value, initial_g_value, initial_w_value] + def _prepare(self): self._learning_rate_tensor = ops.convert_to_tensor( self._learning_rate, name="learning_rate") @@ -217,38 +252,3 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) - - def initialize_slots(self, var, table_instance): - z = slot_creator.create_zeros_slot(var, self._name + "/" + "z") - n = slot_creator.create_zeros_slot(var, self._name + "/" + "n") - g = slot_creator.create_zeros_slot(var, self._name + "/" + "g") - w = slot_creator.create_zeros_slot(var, self._name + "/" + "w") - remove_saving_var(z) - remove_saving_var(n) - remove_saving_var(g) - remove_saving_var(w) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) - return [{"slot": z, "named_slot_key": named_slot_key, "slot_name": "z", "optimizer": self}, - {"slot": n, "named_slot_key": named_slot_key, "slot_name": "n", "optimizer": self}, - {"slot": g, "named_slot_key": named_slot_key, "slot_name": "g", "optimizer": self}, - {"slot": w, "named_slot_key": named_slot_key, "slot_name": "w", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - initial_z_value = 0.0 - initial_n_value = 0.0 - initial_g_value = 0.0 - initial_w_value = 0.0 - return [initial_z_value, initial_n_value, initial_g_value, initial_w_value] diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 8313fca6..ac2206a1 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -29,12 +29,18 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo def __init__(self, learning_rate, use_locking=False, name="GradientDescent"): self.optimizer_type = "gradient_descent" - super(CustomizedGradientDescent, self).__get_name__(name=name) + super(CustomizedGradientDescent, self)._get_name(name=name) super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name) check_param_type("use_locking", use_locking, bool) + def initialize_slots(self, var, table_instance): + return [] + + def get_slot_init_values(self): + return [] + def _apply_sparse_duplicate_indices(self, grad, var): logging.debug(" Enter _apply_sparse_duplicate_indices") nd_indices = tf.expand_dims(grad.indices, 1) @@ -45,9 +51,3 @@ class 
CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo def _apply_dense(self, grad, var): logging.debug(" Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") - - def initialize_slots(self, var, table_instance): - return [] - - def get_slot_init_values(self): - return [] diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index fbc34e0b..596b0375 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -41,20 +41,9 @@ class CustomizedGradientDescentByAddr(optimizer.Optimizer, CustomizedOptimizer): self._learning_rate_tensor = None self._slot_num = 0 - def _convert_grads_and_addrs(self, grads_and_vars): - converted_grads_and_addrs = [] - for grad, addr in grads_and_vars: - if grad is not None: - try: - # Convert the grad to Tensor or IndexedSlices if necessary. - grad = ops.convert_to_tensor_or_indexed_slices(grad) - except TypeError as error: - raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error - if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): - raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") - processor = _get_processor(addr) - converted_grads_and_addrs.append((grad, addr, processor)) - return converted_grads_and_addrs + @property + def slot_num(self): + return self._slot_num def apply_gradients(self, grads_and_vars, global_step=None, name=None): # No DistributionStrategy case. @@ -101,18 +90,29 @@ class CustomizedGradientDescentByAddr(optimizer.Optimizer, CustomizedOptimizer): return apply_updates - @property - def slot_num(self): - return self._slot_num + def get_slot_init_values(self): + return [] + + def _convert_grads_and_addrs(self, grads_and_vars): + converted_grads_and_addrs = [] + for grad, addr in grads_and_vars: + if grad is not None: + try: + # Convert the grad to Tensor or IndexedSlices if necessary. 
+ grad = ops.convert_to_tensor_or_indexed_slices(grad) + except TypeError as error: + raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error + if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): + raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") + processor = _get_processor(addr) + converted_grads_and_addrs.append((grad, addr, processor)) + return converted_grads_and_addrs def _prepare(self): learning_rate = self._call_if_callable(self._learning_rate) self._learning_rate_tensor = ops.convert_to_tensor( learning_rate, name="learning_rate") - def get_slot_init_values(self): - return [] - def _apply_sparse(self, grad, addr): logging.debug(">>>> Enter _apply_sparse SGD by addr") host_pipeline_ops = get_host_pipeline_ops() @@ -168,12 +168,12 @@ class _TensorByAddressProcessor(_OptimizableAddr): def __init__(self, addr): self._a = addr - def target(self): - return self._a - def __str__(self): return "<_TensorByAddressProcessor(%s)>" % self._a + def target(self): + return self._a + def update_op(self, opt, grad): if isinstance(grad, ops.Tensor): logging.debug(">>>>Enter update_op ops.Tensor") diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 649302ff..51cf6e6f 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -42,7 +42,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam"): self.optimizer_type = "LazyAdam" - super(CustomizedLazyAdam, self).__get_name__(name=name) + super(CustomizedLazyAdam, self)._get_name(name=name) super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, use_locking=use_locking, name=self.unique_name) @@ -57,6 +57,40 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): check_param_type("use_locking", use_locking, bool) + def initialize_slots(self, var, table_instance): + # Create slots for the first and second moments. + def creat_one_single_slot(var, op_name): + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. + return new_slot_variable + + momentum = creat_one_single_slot(var, self._name + "/" + "momentum") + velocity = creat_one_single_slot(var, self._name + "/" + "velocity") + remove_saving_var(momentum) + remove_saving_var(velocity) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) + return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}, + {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + # return state value list of adam that needs to initialize in ASC DDR. 
+ initial_momentum_value = 0.0 + initial_velocity_value = 0.0 + return [initial_momentum_value, initial_velocity_value] + def _apply_sparse_duplicate_indices(self, grad, var): # _apply_sparse_duplicate_indices method include tf.unique and unsorted_segment_sum operations which may @@ -147,37 +181,3 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - return new_slot_variable - - momentum = creat_one_single_slot(var, self._name + "/" + "momentum") - velocity = creat_one_single_slot(var, self._name + "/" + "velocity") - remove_saving_var(momentum) - remove_saving_var(velocity) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) - return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}, - {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - # return state value list of adam that needs to initialize in ASC DDR. - initial_momentum_value = 0.0 - initial_velocity_value = 0.0 - return [initial_momentum_value, initial_velocity_value] diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 6a5a598d..42d0e5ba 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -69,6 +69,61 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): self._check_input_param() + @property + def slot_num(self): + return self._slot_num + + def get_slot_init_values(self): + # return state value list of adam that needs to initialize in ASC DDR. + initial_momentum_value = 0.0 + initial_velocity_value = 0.0 + return [initial_momentum_value, initial_velocity_value] + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + + # No DistributionStrategy case. + grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works. + if not grads_and_vars: + raise ValueError("No variables provided.") + + converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars)) + addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None] + if not addr_list: + raise ValueError("No gradients provided for any address: %s." 
% + ([str(a) for _, a, _ in converted_grads_and_addrs],)) + with ops.init_scope(): + self._create_slots(addr_list) + update_ops = [] + with ops.name_scope(name, self._name) as name: + self._prepare() + for grad, addr, processor in converted_grads_and_addrs: + if grad is None: + continue + if (context.executing_eagerly() or + resource_variable_ops.is_resource_variable(addr) + and not addr._in_graph_mode): # pylint: disable=protected-access + scope_name = "" + else: + scope_name = addr.op.name + with ops.name_scope( + f"update_{scope_name}"), ops.colocate_with(addr): + update_ops.append(processor.update_op(self, grad)) + + apply_updates = self._finish(update_ops, name) + + if not context.executing_eagerly(): + if isinstance(apply_updates, ops.Tensor): + logging.debug(">>>>Enter ops.Tensor") + apply_updates = apply_updates.op + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + if apply_updates not in train_op: + logging.debug(">>>>Enter apply_updates not in train_op") + train_op.append(apply_updates) + else: + raise RuntimeError("eager wrong.") + + return apply_updates + def _check_input_param(self): check_param_type("beta1", self._beta1, (int, float)) check_param_range("beta1", self._beta1, 0, 1) @@ -81,10 +136,6 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): check_param_type("use_locking", self._use_locking, bool) - @property - def slot_num(self): - return self._slot_num - def _get_beta_accumulators(self): with ops.init_scope(): if context.executing_eagerly(): @@ -141,12 +192,6 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): beta2_power * self._beta2_t, use_locking=self._use_locking) return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope) - def get_slot_init_values(self): - # return state value list of adam that needs to initialize in ASC DDR. - initial_momentum_value = 0.0 - initial_velocity_value = 0.0 - return [initial_momentum_value, initial_velocity_value] - def _apply_dense(self, grad, var): logging.debug(">>>>Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") @@ -210,51 +255,6 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): converted_grads_and_addrs.append((grad, addr, processor)) return converted_grads_and_addrs - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - - # No DistributionStrategy case. - grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works. - if not grads_and_vars: - raise ValueError("No variables provided.") - - converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars)) - addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None] - if not addr_list: - raise ValueError("No gradients provided for any address: %s." 
% - ([str(a) for _, a, _ in converted_grads_and_addrs],)) - with ops.init_scope(): - self._create_slots(addr_list) - update_ops = [] - with ops.name_scope(name, self._name) as name: - self._prepare() - for grad, addr, processor in converted_grads_and_addrs: - if grad is None: - continue - if (context.executing_eagerly() or - resource_variable_ops.is_resource_variable(addr) - and not addr._in_graph_mode): # pylint: disable=protected-access - scope_name = "" - else: - scope_name = addr.op.name - with ops.name_scope( - f"update_{scope_name}"), ops.colocate_with(addr): - update_ops.append(processor.update_op(self, grad)) - - apply_updates = self._finish(update_ops, name) - - if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): - logging.debug(">>>>Enter ops.Tensor") - apply_updates = apply_updates.op - train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) - if apply_updates not in train_op: - logging.debug(">>>>Enter apply_updates not in train_op") - train_op.append(apply_updates) - else: - raise RuntimeError("eager wrong.") - - return apply_updates - def get_filtered_grad_fn(grad_fn): def filtered_grad_fn(*args, **kwargs): @@ -283,12 +283,12 @@ class _TensorByAddressProcessor(_OptimizableAddr): def __init__(self, addr): self._a = addr - def target(self): - return self._a - def __str__(self): return "<_TensorByAddressProcessor(%s)>" % self._a + def target(self): + return self._a + def update_op(self, opt, grad): if isinstance(grad, ops.Tensor): logging.debug(">>>>Enter update_op ops.Tensor") diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index 66c96552..2424b8f6 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -54,7 +54,7 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): name="Momentum", use_nesterov=False): self.optimizer_type = "Momentum" - super(CustomizedMomentum, self).__get_name__(name=name) + super(CustomizedMomentum, self)._get_name(name=name) super(CustomizedMomentum, self).__init__(learning_rate=learning_rate, momentum=momentum_var, use_locking=use_locking, @@ -63,6 +63,36 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): self._check_input_param() + def initialize_slots(self, var, table_instance): + # Create slots for the first and second moments. + def creat_one_single_slot(var, op_name): + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. + return new_slot_variable + + momentum_slot = creat_one_single_slot(var, self._name + "/" + "momentum") + remove_saving_var(momentum_slot) + named_slot_key = (var.op.graph, var.op.name) + table_instance = get_table_instance(var) + if self._name in table_instance.optimizer: + raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") + + table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) + return [{"slot": momentum_slot, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}] + + def insert_slot(self, slot, named_slots_key, slot_name): + named_slots = self._slot_dict(slot_name) + if named_slots_key in named_slots: + raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " + f"please double check.") + + named_slots[named_slots_key] = slot + + def get_slot_init_values(self): + # return state value list of momentum that needs to initialize in ASC DDR. 
+        initial_momentum_value = 0.0
+        return [initial_momentum_value]
+
     def _check_input_param(self):
         check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float))
         check_param_type("momentum", self._momentum, (tf.Tensor, float))
@@ -98,33 +128,3 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer):
             grad, indices, math_ops.cast(self._momentum_tensor, grad.dtype),
             use_locking=self._use_locking,
             use_nesterov=self._use_nesterov)
-
-    def initialize_slots(self, var, table_instance):
-        # Create slots for the first and second moments.
-        def creat_one_single_slot(var, op_name):
-            new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
-            # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
-            return new_slot_variable
-
-        momentum_slot = creat_one_single_slot(var, self._name + "/" + "momentum")
-        remove_saving_var(momentum_slot)
-        named_slot_key = (var.op.graph, var.op.name)
-        table_instance = get_table_instance(var)
-        if self._name in table_instance.optimizer:
-            raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.")
-
-        table_instance.set_optimizer(self._name, {"momentum": momentum_slot})
-        return [{"slot": momentum_slot, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}]
-
-    def insert_slot(self, slot, named_slots_key, slot_name):
-        named_slots = self._slot_dict(slot_name)
-        if named_slots_key in named_slots:
-            raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, "
-                                   f"please double check.")
-
-        named_slots[named_slots_key] = slot
-
-    def get_slot_init_values(self):
-        # return state value list of momentum that needs to initialize in ASC DDR.
-        initial_momentum_value = 0.0
-        return [initial_momentum_value]
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index a862e073..93a00810 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -48,6 +48,56 @@ class Saver(object):
 
         logging.debug("Save & Restore graph was built.")
 
+    @performance("Save")
+    def save(self, sess, save_path="model", global_step=None):
+        """
+        Save sparse tables
+        :param sess: A Session to use to save the sparse table variables
+        :param save_path: Only an absolute path is supported
+        :param global_step: If provided, the global step number is appended to save_path to create
+            the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
+        :return: None
+        """
+        logging.debug(f"======== Start saving for rank id {self.rank_id} ========")
+        save_path = save_path if save_path else self._prefix_name
+        directory, base_name = os.path.split(save_path)
+        if global_step:
+            if not isinstance(global_step, compat.integral_types):
+                global_step = int(sess.run(global_step))
+            ckpt_name = "sparse-%s-%d" % (base_name, global_step)
+        else:
+            ckpt_name = "sparse-%s" % base_name
+
+        integrated_path = os.path.join(directory, ckpt_name)
+        saving_path = integrated_path
+        if integrated_path.startswith("/"):
+            saving_path = os.path.abspath(integrated_path)
+
+        if os.path.exists(saving_path):
+            shutil.rmtree(saving_path, ignore_errors=True)
+            logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.")
+        os.makedirs(saving_path, exist_ok=True)
+        logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been created.")
+
+        self._save(sess, saving_path)
+        logging.info(f"sparse model was saved in dir '{saving_path}'.")
+        logging.debug(f"======== Saving finished for rank id {self.rank_id} ========")
+
+    @performance("Restore")
+    def restore(self, sess, reading_path):
+        logging.debug("======== Start restoring ========")
+        directory, base_name = os.path.split(reading_path)
+        ckpt_name = "sparse-%s" % base_name
+
+        integrated_path = os.path.join(directory, ckpt_name)
+        reading_path = os.path.abspath(integrated_path)
+        if not os.path.exists(reading_path):
+            raise FileNotFoundError(f"Given dir {reading_path} does not exist, please double check.")
+
+        self._restore(sess, reading_path)
+        logging.info(f"sparse model was restored from dir '{reading_path}'.")
+        logging.debug("======== Restoring finished ========")
+
     def _build_save(self):
         for var in self.var_list:
             table_instance = get_table_instance(var)
@@ -87,41 +137,6 @@ class Saver(object):
             assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state))
             self.restore_fetch_list.append(assign_op)
 
-    @performance("Save")
-    def save(self, sess, save_path="model", global_step=None):
-        """
-        Save sparse tables
-        :param sess: A Session to use to save the sparse table variables
-        :param save_path: Only absolute path supported
-        :param global_step: If provided the global step number is appended to save_path to create
-            the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer.
- :return: None - """ - logging.debug(f"======== Start saving for rank id {self.rank_id} ========") - save_path = save_path if save_path else self._prefix_name - directory, base_name = os.path.split(save_path) - if global_step: - if not isinstance(global_step, compat.integral_types): - global_step = int(sess.run(global_step)) - ckpt_name = "sparse-%s-%d" % (base_name, global_step) - else: - ckpt_name = "sparse-%s" % base_name - - integrated_path = os.path.join(directory, ckpt_name) - saving_path = integrated_path - if integrated_path.startswith("/"): - saving_path = os.path.abspath(integrated_path) - - if os.path.exists(saving_path): - shutil.rmtree(saving_path, ignore_errors=True) - logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") - os.makedirs(saving_path, exist_ok=True) - logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") - - self._save(sess, saving_path) - logging.info(f"sparse model was saved in dir '{saving_path}' .") - logging.debug(f"======== Saving finished for rank id {self.rank_id} ========") - def _save(self, sess, root_dir): result = sess.run(self.save_op_dict) for table_name, dump_data_dict in result.items(): @@ -140,21 +155,6 @@ class Saver(object): save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, self.rank_id) - @performance("Restore") - def restore(self, sess, reading_path): - logging.debug("======== Start restoring ========") - directory, base_name = os.path.split(reading_path) - ckpt_name = "sparse-%s" % base_name - - integrated_path = os.path.join(directory, ckpt_name) - reading_path = os.path.abspath(integrated_path) - if not os.path.exists(reading_path): - raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") - - self._restore(sess, reading_path) - logging.info(f"sparse model was restored from dir '{reading_path}' .") - logging.debug("======== Restoring finished ========") - def _restore(self, sess, reading_path): if is_asc_manager_initialized(): restore_host_data(reading_path) diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py index e2c29877..6b6497b8 100644 --- a/mx_rec/util/__init__.py +++ b/mx_rec/util/__init__.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-from .log import get_log_level +from mx_rec.util.log import get_log_level get_log_level() diff --git a/mx_rec/util/atomic.py b/mx_rec/util/atomic.py index f03ea569..bca4c660 100644 --- a/mx_rec/util/atomic.py +++ b/mx_rec/util/atomic.py @@ -10,6 +10,9 @@ class AtomicInteger(): self._value = int(value) self._lock = threading.Lock() + def __str__(self): + return str(self.value()) + def increase(self, num=1): with self._lock: self._value += int(num) @@ -21,6 +24,3 @@ class AtomicInteger(): def value(self): with self._lock: return self._value - - def __str__(self): - return str(self.value()) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 7d0c435b..e9ae8c5b 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -73,6 +73,86 @@ class ConfigInitializer: def __del__(self): self.terminate() + @property + def feature_spec_dict(self): + return self._feature_spec_dict + + @property + def table_name_set(self): + return self._table_name_set + + @property + def table_name_to_feature_spec(self): + return self._table_name_to_feature_spec + + @property + def table_instance_dict(self): + return self._table_instance_dict + + @property + def optimizer_instance(self): + return self._optimizer_instance + + @property + def is_frozen(self): + return self._is_frozen + + @property + def name_to_var_dict(self): + return self._name_to_var_dict + + @property + def use_mpi(self): + return self._use_mpi + + @property + def rank_size(self): + return self._rank_size + + @property + def rank_id(self): + return self._rank_id + + @property + def device_id(self): + if self._rank_id not in self._rank_to_device_dict: + raise KeyError(f"rank id not in rank_to_device_dict. {self._rank_id} {self._rank_to_device_dict}") + return self._rank_to_device_dict[self._rank_id] + + @property + def train_interval(self): + return self._train_interval + + @property + def eval_steps(self): + return self._eval_steps + + @property + def prefetch_batch_number(self): + return self._prefetch_batch_number + + @property + def if_load(self): + return self._if_load + + @property + def ascend_global_hashtable_collection(self): + return self._ascend_global_hashtable_collection + + @staticmethod + def get_instance(): + if ConfigInitializer._single_instance is None: + raise EnvironmentError("Please init the environment for mx_rec at first.") + + return ConfigInitializer._single_instance + + @staticmethod + def set_instance(use_mpi, **kwargs): + if ConfigInitializer._single_instance is not None: + raise EnvironmentError("ConfigInitializer has been initialized once, twice initialization was forbidden.") + + ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) + def terminate(self): if self._asc_manager is not None: self.del_asc_manager() @@ -90,20 +170,8 @@ class ConfigInitializer: def get_feature_spec(self, key): return self._feature_spec_dict.get(key) - @property - def feature_spec_dict(self): - return self._feature_spec_dict - - @property - def table_name_set(self): - return self._table_name_set - - @property - def table_name_to_feature_spec(self): - return self._table_name_to_feature_spec - def parse_hccl_json(self): - rank_table_path = os.getenv("RANK_TABLE_FILE") + rank_table_path = os.path.realpath(os.getenv("RANK_TABLE_FILE")) if not os.path.exists(rank_table_path): raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") with open(rank_table_path, "r", encoding="utf-8") as file: @@ -191,17 +259,9 @@ class ConfigInitializer: key = 
self._name_to_var_dict.get(table_name) return self._table_instance_dict.get(key) - @property - def table_instance_dict(self): - return self._table_instance_dict - def insert_optimizer(self, optimizer): self._optimizer_instance = optimizer - @property - def optimizer_instance(self): - return self._optimizer_instance - def check_parameters(self): if not isinstance(self._use_mpi, bool): raise ValueError(f"Arg use_mpi must be a boolean.") @@ -224,14 +284,6 @@ class ConfigInitializer: def unfreeze(self): self._is_frozen = False - @property - def is_frozen(self): - return self._is_frozen - - @property - def name_to_var_dict(self): - return self._name_to_var_dict - def set_asc_manager(self, manager): from mxrec_pybind import HybridMgmt if not isinstance(manager, HybridMgmt): @@ -250,46 +302,6 @@ class ConfigInitializer: self.unfreeze() logging.debug("ASC manager has been destroyed.") - @property - def use_mpi(self): - return self._use_mpi - - @property - def rank_size(self): - return self._rank_size - - @property - def rank_id(self): - return self._rank_id - - @property - def device_id(self): - if self._rank_id not in self._rank_to_device_dict: - raise KeyError(f"rank id not in rank_to_device_dict. {self._rank_id} {self._rank_to_device_dict}") - return self._rank_to_device_dict[self._rank_id] - - @staticmethod - def get_instance(): - if ConfigInitializer._single_instance is None: - raise EnvironmentError("Please init the environment for mx_rec at first.") - - return ConfigInitializer._single_instance - - @staticmethod - def set_instance(use_mpi, **kwargs): - if ConfigInitializer._single_instance is not None: - raise EnvironmentError("ConfigInitializer has been initialized once, twice initialization was forbidden.") - - ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) - - @property - def train_interval(self): - return self._train_interval - - @property - def eval_steps(self): - return self._eval_steps - @train_interval.setter def train_interval(self, interval): check_step(interval) @@ -300,19 +312,11 @@ class ConfigInitializer: check_step(steps) self._eval_steps = steps - @property - def prefetch_batch_number(self): - return self._prefetch_batch_number - @prefetch_batch_number.setter def prefetch_batch_number(self, number): check_step(number, 1) self._prefetch_batch_number = number - @property - def if_load(self): - return self._if_load - @if_load.setter def if_load(self, flag): if not isinstance(flag, bool): @@ -320,10 +324,6 @@ class ConfigInitializer: self._if_load = flag - @property - def ascend_global_hashtable_collection(self): - return self._ascend_global_hashtable_collection - @ascend_global_hashtable_collection.setter def ascend_global_hashtable_collection(self, name): if not isinstance(name, str): -- Gitee From 747ab46dfee4eace1ad72d282e7f015f00c72bbf Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 25 May 2023 21:51:31 +0800 Subject: [PATCH 024/551] Match-id-85f7a0608cacbdac224fe97f68991fe141355a34 --- example/little_demo/main.py | 15 ++-- example/little_demo/run.sh | 5 +- mx_rec/core/asc/build_graph.py | 9 ++- mx_rec/core/asc/feature_spec.py | 9 +++ mx_rec/core/asc/helper.py | 43 +++++++---- mx_rec/core/asc/manager.py | 23 +++++- mx_rec/core/embedding.py | 46 +++++++----- mx_rec/graph/modifier.py | 41 +++++++---- src/core/checkpoint/checkpoint.cpp | 2 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 4 +- src/core/emb_mgmt/emb_mgmt.cpp | 55 +++++++++----- src/core/emb_table/emb_table.cpp | 6 +- src/core/hd_transfer/hd_transfer.cpp | 11 ++- 
src/core/host_emb/host_emb.cpp | 16 ++--- src/core/key_process/key_process.cpp | 71 +++++++++++++++---- src/core/key_process/key_process.h | 6 ++ src/core/utils/common.h | 19 ++++- src/ops_tf/hybrid_dataset_ops.cpp | 32 +++++++-- src/pybind/module_main.cpp | 14 ++-- src/tests/checkpoint/checkpoint_test.cpp | 4 +- .../ckpt_data_handler_test.cpp | 2 +- src/tests/emb_mgmt/emb_mgmt_test.cpp | 20 ++++-- src/tests/emb_table/emb_table_test.cpp | 2 +- src/tests/key_process/key_process_test.cpp | 2 +- 24 files changed, 327 insertions(+), 130 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 01ee8bbc..f2d0175f 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -10,7 +10,8 @@ from dataset import generate_dataset from optimizer import get_dense_and_sparse_optimizer from model import MyModel from mx_rec.util.tf_version_adapter import hccl_ops -from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func +from mx_rec.core.asc.feature_spec import FeatureSpec +from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache @@ -56,10 +57,14 @@ def build_graph(hash_table_list, is_train, use_timestamp=False, config_dict=None batch, iterator = make_batch_and_iterator(is_training=is_train, use_timestamp=use_timestamp, dump_graph=is_train, batch_number=batch_number) if MODIFY_GRAPH_FLAG: - feature_list = [batch["user_ids"], batch["item_ids"]] + input_list = [ + [batch["user_ids"], batch["item_ids"], batch["user_ids"], batch["item_ids"]], + [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt], + ] if USE_TIMESTAMP: tf.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) - model = model_forward([feature_list, hash_table_list, [cfg.user_send_cnt, cfg.item_send_cnt]], batch, + model = model_forward(input_list, batch, is_train=is_train, modify_graph=True, config_dict=config_dict) else: model = model_forward([feature_spec_list, hash_table_list, [cfg.user_send_cnt, cfg.item_send_cnt]], batch, @@ -141,7 +146,7 @@ if __name__ == "__main__": name='user_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), device_vocabulary_size=cfg.user_vocab_size * 10, - host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test optimizer_list=sparse_optimizer_list, mode=mode) @@ -150,7 +155,7 @@ if __name__ == "__main__": name='item_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), device_vocabulary_size=cfg.item_vocab_size * 10, - host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test optimizer_list=sparse_optimizer_list, mode=mode) diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index abbe8cae..76e9d1dd 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -28,7 +28,10 @@ export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL export ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL export MXREC_MODE="ASC" export USE_MPI=1 -export USE_DYNAMIC=0 # 0: 静态;1:动态 + +export USE_DYNAMIC=0 # 0:静态shape;1:动态shape +export USE_HOT=0 # 0:关闭hot emb;1: 开启hot emb +export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 
#################使用去除ranktable方案时开启###################### #export CM_CHIEF_IP="192.168.1.1" # 主节点ip diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 153b6f7f..31cf30a6 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -14,7 +14,6 @@ from mx_rec.util.tf_version_adapter import npu_ops def get_restore_vector(config): logging.debug(f'Channel {config.get("table_name")}_restore_{config.get("channel_id")} was built for getnext') - emb_size = None if config.get("skip_emb_transfer"): if not isinstance(config.get("emb_size"), int) or config.get("emb_size") < 1: raise TypeError(f"emb_size must be a int") @@ -95,12 +94,16 @@ def get_all2all_args(use_static: bool, config: dict) -> list: return all2all_args -def get_preprocessed_tensor_for_asc(table, config): +def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify_graph=False): use_static = get_use_static() max_lookup_vec_size = None if use_static: max_lookup_vec_size = config.get("send_count") * config.get("rank_size") + if modify_graph: + config["table_name"] = ids_channel_name + logging.debug(f"GetNext, table_name: {config.get('table_name')}, modify_graph: {modify_graph}") + with tf.compat.v1.variable_scope("restore_vector"): restore_vector, hot_pos = get_restore_vector(config) @@ -138,5 +141,5 @@ def get_preprocessed_tensor_for_asc(table, config): table_num = len(table) h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) - for i in range(len(table))] + for i in range(len(table))] return restore_vector, hot_pos, id_offsets, swap_in, all2all_args diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 7c148892..8e599e59 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -172,3 +172,12 @@ class FeatureSpec: insert_feature_spec(self, is_training) return tensor, self.table_name, self.feat_cnt, self.split + + +def get_feature_spec(table_name, access_and_evict_config): + access_threshold = None + eviction_threshold = None + if access_and_evict_config: + access_threshold = access_and_evict_config.get("access_threshold") + eviction_threshold = access_and_evict_config.get("eviction_threshold") + return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 92a034e5..21f497f3 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -4,12 +4,11 @@ import logging from functools import reduce -import os import tensorflow as tf -from mx_rec.util.initialize import get_host_pipeline_ops, insert_feature_spec, insert_training_mode_channel_id, \ - get_training_mode_channel_id, get_use_static +from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \ + export_table_instances from mx_rec.core.asc.feature_spec import FeatureSpec @@ -142,8 +141,7 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list, featu return output_feature_id_list, output_split_list, output_table_name_list, output_tensorshape_split_list -def send_feature_id_request_async(feature_id_list, split_list, - table_name_list, input_dict): +def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict): is_training = input_dict["is_training"] timestamp = input_dict["timestamp"] feature_spec_names = input_dict["feature_spec_names"] @@ -170,18 
+168,37 @@ def send_feature_id_request_async(feature_id_list, split_list, feature_id_list = timestamp_feature_id + feature_id_list concat_tensor = tf.concat(feature_id_list, axis=0) + ids_channel_name_list = [] + if auto_change_graph: + for _, table_instance in export_table_instances().items(): + if table_instance.table_name not in table_name_list: + logging.info(f"table_name ('{table_instance.table_name}') not in table_name_list: {table_name_list}") + continue + if len(table_instance.channel_name_list) > 1: + ids_channel_name_list.extend(table_instance.channel_name_list) + else: + ids_channel_name_list.append(table_instance.table_name) + if len(ids_channel_name_list) != len(tensorshape_split_list): + raise RuntimeError(f"The length of ids_channel_name_list and tensorshape_split_list must be equal, " + f"ids_channel_name_list: {ids_channel_name_list}, " + f"tensorshape_split_list: {tensorshape_split_list}") + + if len(split_list) == 0 or len(tensorshape_split_list) == 0: + raise RuntimeError(f"The length of split list can not be 0.") + if use_static: - logging.debug(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}") + logging.debug(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}, " + f"ids_channel_name_list: {ids_channel_name_list}") return host_pipeline_ops.read_emb_key_v2(concat_tensor, channel_id=channel_id, splits=split_list, - emb_name=table_name_list, timestamp=timestamp) + emb_name=table_name_list, timestamp=timestamp, + channel_name=ids_channel_name_list, modify_graph=auto_change_graph) logging.debug(f"read_emb_key_v2_dynamic, table_name_list: {table_name_list}, " - f"tensorshape_split_list: {tensorshape_split_list}") - pipeline_op = host_pipeline_ops.read_emb_key_v2_dynamic(concat_tensor, tensorshape_split_list, - channel_id=channel_id, - emb_name=table_name_list, timestamp=timestamp) - - return pipeline_op + f"tensorshape_split_list: {tensorshape_split_list}, ids_channel_name_list: {ids_channel_name_list}") + return host_pipeline_ops.read_emb_key_v2_dynamic(concat_tensor, tensorshape_split_list, + channel_id=channel_id, emb_name=table_name_list, + timestamp=timestamp, channel_name=ids_channel_name_list, + modify_graph=auto_change_graph) def do_insert(args, insert_tensors, splits, table_names, input_dict): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 9554f74f..a8dcf6cb 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -3,7 +3,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
import logging -import os import tensorflow as tf @@ -48,11 +47,29 @@ def generate_table_info_list(): if static_shape_rec_flag or dynamic_shape_rec_flag: logging.debug(f"table_instance.slice_device_vocabulary_size: {table_instance.slice_device_vocabulary_size}") logging.debug(f"table_instance.slice_host_vocabulary_size: {table_instance.slice_host_vocabulary_size}") - table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.ext_emb_size, + if table_instance.modify_graph and len(table_instance.channel_name_list) > 1 \ + and table_instance.slice_host_vocabulary_size > 0: + raise RuntimeError(f"In the case of modify graph, multiple lookups of a table are currently " + f"only compatible with HBM mode.") + if len(table_instance.channel_name_list) == 1: + ids_channel_name = table_instance.channel_name_list[0] + table_instance.channel_name_list = [table_instance.table_name] + try: + table_instance.send_count_map.pop(ids_channel_name) + table_instance.send_count_map[table_instance.table_name] = table_instance.send_count + except KeyError as error: + raise KeyError(f"ids_channel_name '{ids_channel_name}' not in send_count_map " + f"'{table_instance.send_count_map}'") from error + logging.debug(f"table_instance, table_name: {table_instance.table_name}, channel_name_list: " + f"{table_instance.channel_name_list}, send_count_map: {table_instance.send_count_map}") + table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size, + table_instance.ext_emb_size, table_instance.modify_graph, + table_instance.channel_name_list, [table_instance.slice_device_vocabulary_size, table_instance.slice_host_vocabulary_size], [matched_emb_initializer(table_instance)] + - matched_opt_slot_initializers(table_instance)) + matched_opt_slot_initializers(table_instance), + table_instance.send_count_map) table_info_list.append(table_info) return table_info_list diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index b1106532..4becc98a 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -13,7 +13,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc -from mx_rec.core.asc.helper import FeatureSpec +from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ @@ -137,12 +137,16 @@ class SparseEmbedding: self.ext_emb_size = None self.ext_coefficient = 1 self._optimizer = dict() - self.slice_device_vocabulary_size = None - self.slice_host_vocabulary_size = None + self.slice_device_vocabulary_size = 0 + self.slice_host_vocabulary_size = 0 self.variable = None self.lookup_info = set() self.lookup_result = None self.use_dynamic_expansion = get_use_dynamic_expansion() + self.channel_name_list = [] + self.send_count_map = dict() + self.channel_name_dict = {True: [], False: []} + self.modify_graph = False self.set_slice_vocab_size() self.set_emb_size() @@ -350,6 +354,12 @@ class SparseEmbedding: self._optimizer[key] = state_dict + def set_channel_name(self, ids_channel_name, eval_mode): + self.channel_name_list.append(ids_channel_name) + if not eval_mode: + self.channel_name_dict.get(True).insert(0, ids_channel_name) + self.channel_name_dict.get(False).insert(0, 
ids_channel_name) + def lookup_for_asc(self, ids: tf.Tensor, send_count, **kwargs): """ @@ -373,19 +383,12 @@ class SparseEmbedding: is_training = kwargs.get("is_train") if is_asc_frozen() and is_training: raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.") - if is_asc_frozen() and not is_training: - clear_channel(is_train_channel=False) - access_threshold = None - eviction_threshold = None - if kwargs.get("access_and_evict_config"): - access_and_evict_config = kwargs.get("access_and_evict_config") - access_threshold = access_and_evict_config.get("access_threshold") - eviction_threshold = access_and_evict_config.get("eviction_threshold") - feature_spec = FeatureSpec(self.table_name, - access_threshold=access_threshold, - eviction_threshold=eviction_threshold) + feature_spec = get_feature_spec(self.table_name, kwargs.get("access_and_evict_config")) feature_spec.set_feat_attribute(ids, is_training) + # 'clear_channel()' function needs to be executed after 'set_feat_attribute()' function + if is_asc_frozen() and not is_training: + clear_channel(is_train_channel=False) self.check_and_format_lookup_params(ids, send_count, is_training) anchor_ids = tf.identity(ids, name="ids") @@ -395,8 +398,19 @@ class SparseEmbedding: use_dynamic_expansion = get_use_dynamic_expansion() use_static = get_use_static() use_hot = get_use_hot() - logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, ids: {ids}, use_dynamic_expansion: " - f"{use_dynamic_expansion}, use_static: {use_static}, use_hot: {use_hot}") + eval_mode = not is_training and len(self.channel_name_dict.get(not is_training)) == 0 + ids_channel_name = "" + # set in train mode, train and eval mode, eval mode + if is_training or eval_mode: + ids_channel_name = feature_spec.name + "_lookup_ids" + self.set_channel_name(ids_channel_name, eval_mode) + send_count = send_count if send_count is not None else 0 + self._send_count = send_count + self.send_count_map[ids_channel_name] = send_count + self.modify_graph = kwargs.get("modify_graph", True) + logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, anchor_ids: {anchor_ids}, " + f"ids_channel_name: {ids_channel_name}, use_dynamic_expansion: {use_dynamic_expansion}, " + f"use_static: {use_static}, use_hot: {use_hot}") rank_size = get_rank_size() id_offsets = tf.ones(shape=[send_count * rank_size if use_static else 1 * rank_size, ], diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 9b92b1f5..7ceb42dd 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -9,14 +9,15 @@ import tensorflow as tf from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc -from mx_rec.core.asc.helper import get_asc_insert_func, FeatureSpec +from mx_rec.core.asc.helper import get_asc_insert_func +from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding -from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_CUTTING_POINT, \ - ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr, ASCEND_TIMESTAMP -from mx_rec.util.initialize import get_rank_size, destroy_asc_manager, get_training_mode_channel_id, \ - get_feature_spec, insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, \ - get_use_dynamic_expansion, terminate_config_initializer +from mx_rec.util.constants import 
ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ + ASCAnchorAttr, ASCEND_TIMESTAMP +from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ + insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ + terminate_config_initializer from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list @@ -341,9 +342,9 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, - ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.emb_size, use_hot=get_use_hot(), - device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) - build_asc_graph(table_instance, cutting_point, config) + ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.scalar_emb_size, + use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) + build_asc_graph(table_instance, cutting_point, config, is_training) logging.info("Graph has been revised.") export_pb_graph("new_graph.pb", dump_graph) @@ -370,19 +371,31 @@ def get_timestamp_index(get_next_op, is_training): return timestamp_index -def build_asc_graph(table_instance, cutting_point, config): +def build_asc_graph(table_instance, cutting_point, config, is_training): # returned results swap_pos and swap_len were not in used, will be applied for DDR mode logging.debug(f"try to replace anchors for table {config.get('table_name')} on channel {config.get('channel_id')}") skip_emb_transfer = config.get("skip_emb_transfer") logging.info(f"modifier build_asc_graph skip_emb_transfer: {skip_emb_transfer}") + + if len(table_instance.channel_name_list) > 1: + channel_name_queue = table_instance.channel_name_dict.get(is_training) + if len(channel_name_queue) < 1: + raise ValueError(f"The length of channel_name_queue must be greater than or equal to 1.") + ids_channel_name = channel_name_queue.pop() + config["send_count"] = table_instance.send_count_map.get(ids_channel_name) + elif len(table_instance.channel_name_list) == 1: + ids_channel_name = config.get('table_name') + else: + raise ValueError(f"The length of channel_name_list must be greater than or equal to 1.") + if skip_emb_transfer: restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - table_instance.variable, config) + table_instance.variable, config, ids_channel_name, table_instance.modify_graph) else: variable_list = [table_instance.variable] \ - + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list] - restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - variable_list, config) + + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list] + restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( + variable_list, config, ids_channel_name, table_instance.modify_graph) with tf.control_dependencies(swap_in): id_offsets = tf.identity(id_offsets) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 498386f6..b398da94 100644 --- 
a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -147,7 +147,7 @@ int Checkpoint::GetEmbeddingSize(const string& embName) { for (const auto &embInfo: mgmtEmbInfo) { if (embInfo.name == embName) { - return embInfo.embeddingSize; + return embInfo.extEmbeddingSize; } } return 0; diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index 791e330e..dc573a63 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -83,7 +83,7 @@ void HostEmbCkpt::SetEmbInfoTrans(string embName) transArr.reserve(embInfoSize); transArr.push_back(hostEmbInfo.sendCount); - transArr.push_back(hostEmbInfo.embeddingSize); + transArr.push_back(hostEmbInfo.extEmbeddingSize); transArr.push_back(static_cast(hostEmbInfo.devVocabSize)); transArr.push_back(static_cast(hostEmbInfo.hostVocabSize)); } @@ -107,7 +107,7 @@ void HostEmbCkpt::SetEmbInfo(string embName, CkptData& ckptData) hostEmbInfo.name = embName; hostEmbInfo.sendCount = transArr.at(attribEmbInfoSendCntIdx); - hostEmbInfo.embeddingSize = transArr.at(attribEmbInfoEmbSizeIdx); + hostEmbInfo.extEmbeddingSize = transArr.at(attribEmbInfoEmbSizeIdx); hostEmbInfo.devVocabSize = static_cast(transArr.at(attribEmbInfoDevVocabIdx)); hostEmbInfo.hostVocabSize = static_cast(transArr.at(attribEmbInfoHostVocabIdx)); } diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 03f696cb..8cee8a08 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -162,13 +162,14 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; if (setupHostEmbs.sendCount != loadEmbInfo.sendCount) { - spdlog::error(MGMT + "Load data sendCount {} for table {} does not match setup sendCound {}", + spdlog::error(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", setupHostEmbs.sendCount, setupHostEmbs.name, loadEmbInfo.sendCount); loadDataMatches = false; } - if (setupHostEmbs.embeddingSize != loadEmbInfo.embeddingSize) { - spdlog::error(MGMT + "Load data embeddingSize {} for table {} does not match setup embeddingSize {}", - setupHostEmbs.embeddingSize, setupHostEmbs.name, loadEmbInfo.embeddingSize); + if (setupHostEmbs.extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { + spdlog::error(MGMT + "Load data extEmbeddingSize {} for table {} does not match " + "setup extEmbeddingSize {}", + setupHostEmbs.extEmbeddingSize, setupHostEmbs.name, loadEmbInfo.extEmbeddingSize); loadDataMatches = false; } if (setupHostEmbs.devVocabSize != loadEmbInfo.devVocabSize) { @@ -351,14 +352,22 @@ bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost getAllTensorTC; - auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - spdlog::info(MGMT + "ParseKeys infoVecs empty ! 
batchId:{}, channelId:{}", batchId, channelId); - return false; + vector names = {embInfo.name}; + if (embInfo.modifyGraph) { + names = embInfo.channelNames; + } + spdlog::debug(MGMT + "GetLookupAndRestore embInfoName:{}, modifyGraph:{}, names:{}", + embInfo.name, embInfo.modifyGraph, names); + for (const string& name: names) { + auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { + spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + return false; + } + lookUpKeysQueue->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueue->Pushv(*infoVecs); } - lookUpKeysQueue->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueue->Pushv(*infoVecs); TIME_PRINT("getAllTensorTC TimeCost(ms):{}", getAllTensorTC.ElapsedMS()); } batchId++; @@ -368,9 +377,17 @@ bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { + vector names = {embInfo.name}; + if (embInfo.modifyGraph) { + names = embInfo.channelNames; + } + spdlog::debug(MGMT + "SendLookupAndRestore embInfoName:{}, modifyGraph:{}, names:{}", + embInfo.name, embInfo.modifyGraph, names); if (!mgmtRankInfo.useStatic) { - auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, { *all2all }, channelId, embInfo.name); + for (const string& name: names) { + auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(ALL2ALL, { *all2all }, channelId, name); + } } spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", batchId, embInfo.name, channelId); @@ -382,15 +399,19 @@ bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) #pragma omp section { TimeCost sendLookupTC; - auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); - hdTransfer->Send(LOOKUP, lookUpKeys, channelId, embInfo.name); + for (const string& name: names) { + auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); + hdTransfer->Send(LOOKUP, lookUpKeys, channelId, name); + } TIME_PRINT("LOOKUP Send TimeCost(ms):{}", sendLookupTC.ElapsedMS()); } #pragma omp section { TimeCost sendRestoreTC; - auto restore = restoreQueue->WaitAndPop(); - hdTransfer->Send(RESTORE, restore, channelId, embInfo.name); + for (const string& name: names) { + auto restore = restoreQueue->WaitAndPop(); + hdTransfer->Send(RESTORE, restore, channelId, name); + } TIME_PRINT("RESTORE Send TimeCost(ms):{}", sendRestoreTC.ElapsedMS()); } } diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 428dd23d..09cd6aec 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -25,14 +25,14 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) #ifndef GTEST this->rankInfo = rInfo; this->seed = seed; - spdlog::info("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.embeddingSize); + spdlog::info("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); // 计算embedding table需要分配的内存块数 auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); if (ret != ACL_ERROR_NONE) { spdlog::error("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret); throw AclError(); } - embSize = embInfo.embeddingSize; + embSize = embInfo.extEmbeddingSize; blockSize = BLOCK_EMB_COUNT * embSize; for (int i = 0; i < 
INIT_BLOCK_COUNT; ++i) { // 申请新的内存块 @@ -54,7 +54,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) } } totalCapacity = memoryList.size(); - spdlog::info("aclrtMalloc success, emb name:{}", embInfo.name); + spdlog::info("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); #endif } diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 10ae60ed..f3e9618b 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -31,9 +31,14 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) } spdlog::info(MGMT + "end Set device, rank:{}", localRankId); for (const auto& embInfo: embInfos) { - auto embName = embInfo.name; - for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { - CreateChannel(localRankId, embName, i); + vector names = {embInfo.name}; + if (embInfo.modifyGraph) { + names = embInfo.channelNames; + } + for (const string& name: names) { + for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { + CreateChannel(localRankId, name, i); + } } } running = true; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 203c9bbc..6fdb1f54 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -24,7 +24,7 @@ bool HostEmb::Initialize(const vector& embInfos, int seed, bool ifLoad) for (const auto& embInfo: embInfos) { HostEmbTable hostEmb; hostEmb.hostEmbInfo = embInfo; - EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.embeddingSize, + EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.extEmbeddingSize, hostEmb.embData); hostEmbs[embInfo.name] = move(hostEmb); spdlog::info(HOSTEMB + "HostEmb Initialize End"); @@ -111,7 +111,7 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size()); EASY_BLOCK("Update") const float* tensorPtr = d2hEmb.flat().data(); - auto embeddingSize = hostEmbs[embName].hostEmbInfo.embeddingSize; + auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; auto& embData = hostEmbs[embName].embData; #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ @@ -145,7 +145,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size()); EASY_BLOCK("Update") auto& embData = hostEmbs[embName].embData; - auto embeddingSize = hostEmbs[embName].hostEmbInfo.embeddingSize; + auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; auto aclData = acltdtGetDataItem(aclDataset, 0); if (aclData == nullptr) { throw runtime_error("Acl get tensor data from dataset failed."); @@ -176,7 +176,7 @@ vector HostEmb::GetH2DEmb(const vector& missingKeysHostPos, cons EASY_FUNCTION() vector h2d_emb; const auto& emb = hostEmbs[embName]; - const int embeddingSize = emb.hostEmbInfo.embeddingSize; + const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize; h2d_emb.emplace_back(Tensor(tensorflow::DT_FLOAT, { int(missingKeysHostPos.size()), embeddingSize })); @@ -199,14 +199,6 @@ auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* return &hostEmbs; } -EmbInfo::EmbInfo(const string &name, int sendCount, int embeddingSize, vector vocabsize, - vector initializeInfos) - : name(name), sendCount(sendCount), embeddingSize(embeddingSize), initializeInfos(initializeInfos) -{ - devVocabSize = vocabsize[0]; - hostVocabSize = vocabsize[1]; -} - 
void HostEmb::EmbPartGenerator(const vector &initializeInfos, vector> &embData, const vector& offset) { diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 2e24fa52..c9cee4d0 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -49,11 +49,11 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, map scInfo; for (const auto& info: eInfos) { + spdlog::debug(KEY_PROCESS "Init sendCountMap:{}, channelNames:{}", info.sendCountMap, info.channelNames); embInfos[info.name] = info; scInfo[info.name] = info.sendCount; if (rankInfo.useHot) { - hotEmbTotCount[info.name] = static_cast(GetUBSize(rInfo.deviceId) / sizeof(float) * HOT_EMB_CACHE_PCT / - info.embeddingSize); + InitHotEmbTotCount(info, rInfo); } if (rankInfo.useDynamicExpansion) { // 动态扩容 @@ -115,6 +115,25 @@ int KeyProcess::Start() return 0; } +void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) +{ + auto embeddingSize = info.extEmbeddingSize; + if (rankInfo.useDynamicExpansion) { + embeddingSize = info.embeddingSize; + } + hotEmbTotCount[info.name] = static_cast(GetUBSize(rInfo.deviceId) / sizeof(float) * HOT_EMB_CACHE_PCT / + embeddingSize); +} + +auto KeyProcess::GetSendCount(const string& name, const string& channelName, bool modifyGraph) +{ + auto sendCountSize = embInfos[name].sendCount; + if (modifyGraph) { + sendCountSize = embInfos[name].sendCountMap[channelName]; + } + return sendCountSize; +} + auto KeyProcess::GetMaxOffset() -> offset_mem_t { return maxOffset; @@ -184,6 +203,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES TIME_PRINT("GetBatchData TimeCost(ms):{}", getBatchTC.ElapsedMS()); if (batch == nullptr) { + spdlog::info(KEY_PROCESS "batch is nullptr"); break; } auto getBatchTime = TO_MS(sw); @@ -192,8 +212,8 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES if (unique == nullptr) { GroupMethod groupMethod; groupMethod.SetGroupCount(rankInfo.rankSize); - unique = new ShardedDedup(groupMethod, batch->batchSize, - embInfos[batch->name].sendCount); + auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); + unique = new ShardedDedup(groupMethod, batch->batchSize, sendCountSize); } else { unique->StartNewRound(); } @@ -246,7 +266,11 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_de TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count - SendA2A(scAll, batch->name, batch->channel, batch->batchId); + auto embName = batch->name; + if (batch->modifyGraph) { + embName = batch->channelName; + } + SendA2A(scAll, embName, batch->channel, batch->batchId); } auto tensors = make_unique>(); @@ -301,8 +325,13 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr lockGuard(getInfoMut[id]); storage[id].push_front(move(tensors)); - infoList[id][batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, storage[id].begin())); + if (batch->modifyGraph) { + infoList[id][batch->channelName][batch->channel].push( + make_tuple(batch->batchId, batch->channelName, storage[id].begin())); + } else { + infoList[id][batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, storage[id].begin())); + } if (!rankInfo.noDDR) { lookupKeysList[id][batch->name][batch->channel].push( make_tuple(batch->batchId, batch->name, move(lookupKeys))); @@ -346,9 +375,9 @@ unique_ptr 
KeyProcess::GetBatchData(int channel, int commId) } } EASY_END_BLOCK - spdlog::info(KEY_PROCESS "GetBatchData get batchId:{}, batchSize:{}, batch.channel:{}, name:{}, " - "channel:{}, commId:{}, ", - batch->batchId, batch->batchSize, batch->channel, batch->name, channel, commId); + spdlog::info(KEY_PROCESS "GetBatchData get batchId:{}, batchSize:{}, batch.channel:{}, batch.channelName:{}, " + "name:{}, channel:{}, commId:{}, ", + batch->batchId, batch->batchSize, batch->channel, batch->channelName, batch->name, channel, commId); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (batch->batchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE @@ -360,6 +389,18 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) return batch; } +size_t KeyProcess::GetKeySize(const unique_ptr &batch) +{ + size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; + if (batch->modifyGraph) { + size = rankInfo.rankSize * embInfos[batch->name].sendCountMap[batch->channelName]; + } + if (!rankInfo.useStatic) { + size = batch->batchSize; + } + return size; +} + auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id) -> tuple, vector, vector, vector> { @@ -372,10 +413,7 @@ auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba SimpleThreadPool pool_; keys_t keySend; - size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; - if (!rankInfo.useStatic) { - size = batch->batchSize; - } + size_t size = GetKeySize(batch); keySend.resize(size); vector splitSize(rankInfo.rankSize); vector uniqueVector(batch->batchSize); @@ -411,7 +449,7 @@ auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba vector sc; // send count if (rankInfo.useStatic) { - sc.resize(rankInfo.rankSize, embInfos[batch->name].sendCount); + sc.resize(rankInfo.rankSize, GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); } else { sc.resize(rankInfo.rankSize); for (int i = 0;i < rankInfo.rankSize; i++) { @@ -420,6 +458,9 @@ auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba } auto [keyRecv, scAll, countRecv] = All2All(sc, id, batch->channel, keySend, keyCount); + spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " + "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, + batch->channel, batch->channelName, batch->name, restore.size(), keyCount.size()); return { keyRecv, restore, hotPos, scAll, countRecv}; } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index d2730707..3a98aa0a 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -115,6 +115,10 @@ namespace MxRec { int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; + void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); + + auto GetSendCount(const string& name, const string& channelName, bool modifyGraph); + void KeyProcessTask(int channel, int id); bool KeyProcessTaskHelper(unique_ptr& batch, sharded_dedup unique_, @@ -124,6 +128,8 @@ namespace MxRec { auto ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique_, int id) -> tuple, vector, vector, vector>; + size_t GetKeySize(const unique_ptr &batch); + auto All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount) -> tuple, vector>; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 51831b2c..a3555d0f 100644 --- a/src/core/utils/common.h +++ 
b/src/core/utils/common.h @@ -136,10 +136,12 @@ namespace MxRec { std::vector sample; void *tensorAddr = nullptr; std::string name; + std::string channelName; size_t batchSize; int batchId; int channel = 0; bool isInt64; // true int64 false int32 + bool modifyGraph; time_t timestamp { -1 }; }; @@ -340,15 +342,30 @@ struct BatchTask { EmbInfo(const std::string& name, int sendCount, int embeddingSize, + int extEmbeddingSize, + bool modifyGraph, + std::vector channelNames, std::vector vocabsize, - std::vector initializeInfos); + std::vector initializeInfos, + std::map sendCountMap) + : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), + modifyGraph(modifyGraph), channelNames(channelNames), initializeInfos(initializeInfos), + sendCountMap(sendCountMap) + { + devVocabSize = vocabsize[0]; + hostVocabSize = vocabsize[1]; + } std::string name; int sendCount; int embeddingSize; + int extEmbeddingSize; + bool modifyGraph; size_t devVocabSize; size_t hostVocabSize; + std::vector channelNames; std::vector initializeInfos; + std::map sendCountMap; }; struct HostEmbTable { diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 48aaeee2..85a7c261 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -110,8 +110,10 @@ REGISTER_OP("ReadEmbKeyV2Dynamic") .Output("output: int32") .Attr("T: {int64, int32}") .Attr("channel_id: int") - .Attr("emb_name: list(string)") // for which table to lookup - .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .Attr("channel_name: list(string)") // use for multi lookup + .Attr("modify_graph: bool") // auto modify graph enabled .SetShapeFn([](InferenceContextPtr c) { c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); @@ -127,6 +129,8 @@ public: OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + OP_REQUIRES_OK(context, context->GetAttr("channel_name", &channelNames)); + OP_REQUIRES_OK(context, context->GetAttr("modify_graph", &modifyGraph)); // 特征准入&淘汰功能 相关校验 @@ -194,9 +198,9 @@ public: out(0) = batchId; EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}, " - "splits: {}, dataSize: {}, filedNum: {}", + "splits: {}, dataSize: {}, filedNum: {}, channelNames: {}, modifyGraph: {}", TO_MS(sw), TO_MS(staticSw), - channelId, batchId, splits.size(), dataSize, fieldNum); + channelId, batchId, splits.size(), dataSize, fieldNum, channelNames, modifyGraph); staticSw.reset(); } @@ -211,6 +215,10 @@ public: for (int i = 0; i < splits.size(); ++i) { auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block batchData->name = embNames.at(i); + if (modifyGraph) { + batchData->modifyGraph = modifyGraph; + batchData->channelName = channelNames.at(i); + } size_t len = splits(i); batchData->channel = channelId; batchData->batchId = ids[0]; @@ -302,8 +310,10 @@ public: int channelId {}; vector embNames {}; + vector channelNames {}; int maxStep = 0; bool isTimestamp { false }; + bool modifyGraph { false }; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), 
ReadEmbKeyV2Dynamic); @@ -315,8 +325,10 @@ REGISTER_OP("ReadEmbKeyV2") .Attr("T: {int64, int32}") .Attr("channel_id: int") .Attr("splits: list(int)") - .Attr("emb_name: list(string)") // for which table to lookup - .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .Attr("channel_name: list(string)") // use for multi lookup + .Attr("modify_graph: bool") // auto modify graph enabled .SetShapeFn([](InferenceContextPtr c) { c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); @@ -333,6 +345,8 @@ public: OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // 每个表的field Number OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + OP_REQUIRES_OK(context, context->GetAttr("channel_name", &channelNames)); + OP_REQUIRES_OK(context, context->GetAttr("modify_graph", &modifyGraph)); fieldNum = accumulate(splits.begin(), splits.end(), 0); // 特征准入&淘汰功能 相关校验 @@ -430,6 +444,10 @@ public: TIME_PRINT("TryPopTimeCost(ms):{}", tp.ElapsedMS()); batchData->name = embNames.at(i); + if (modifyGraph) { + batchData->modifyGraph = modifyGraph; + batchData->channelName = channelNames.at(i); + } size_t len = splits.at(i); batchData->channel = channelId; batchData->batchId = batchId; @@ -525,8 +543,10 @@ public: vector splits {}; int fieldNum {}; vector embNames {}; + vector channelNames {}; int maxStep = 0; bool isTimestamp { false }; + bool modifyGraph { false }; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 9426026f..6d297dc1 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -93,15 +93,21 @@ void GetRankInfo(pybind11::module_& m) void GetEmbInfo(pybind11::module_& m) { pybind11::class_(m, "EmbInfo") - .def(pybind11::init, std::vector&>(), - py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("vocab_size"), - py::arg("initialize_infos")) + .def(pybind11::init, std::vector, + std::vector&, std::map>(), + py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), + py::arg("ext_embedding_size"), py::arg("modify_graph"), py::arg("channel_name_list"), + py::arg("vocab_size"), py::arg("initialize_infos"), py::arg("send_count_map")) .def_readwrite("name", &EmbInfo::name) .def_readwrite("send_count", &EmbInfo::sendCount) .def_readwrite("embedding_size", &EmbInfo::embeddingSize) + .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize) + .def_readwrite("modify_graph", &EmbInfo::modifyGraph) + .def_readwrite("channel_name_list", &EmbInfo::channelNames) .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) - .def_readwrite("initialize_infos", &EmbInfo::initializeInfos); + .def_readwrite("initialize_infos", &EmbInfo::initializeInfos) + .def_readwrite("send_count_map", &EmbInfo::sendCountMap); } void GetRandomInfo(pybind11::module_& m) diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index c5cbda50..9523debf 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -71,7 +71,7 @@ protected: for (auto& testEmbInfo : testEmbInfos) { testEmbInfo.name = name + to_string(idx); testEmbInfo.sendCount = sendCount; - testEmbInfo.embeddingSize = 
embeddingSize; + testEmbInfo.extEmbeddingSize = embeddingSize; testEmbInfo.devVocabSize = devVocabSize; testEmbInfo.hostVocabSize = hostVocabSize; ++idx; @@ -253,7 +253,7 @@ TEST_F(CheckpointTest, HostEmbs) EXPECT_EQ(it.second.hostEmbInfo.name, embInfo.name); EXPECT_EQ(it.second.hostEmbInfo.sendCount, embInfo.sendCount); - EXPECT_EQ(it.second.hostEmbInfo.embeddingSize, embInfo.embeddingSize); + EXPECT_EQ(it.second.hostEmbInfo.extEmbeddingSize, embInfo.extEmbeddingSize); EXPECT_EQ(it.second.hostEmbInfo.devVocabSize, embInfo.devVocabSize); EXPECT_EQ(it.second.hostEmbInfo.hostVocabSize, embInfo.hostVocabSize); diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 2c13c0c4..69b45ad3 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -54,7 +54,7 @@ protected: for (auto& testEmbInfo : testEmbInfos) { testEmbInfo.name = name + to_string(idx); testEmbInfo.sendCount = sendCount; - testEmbInfo.embeddingSize = embeddingSize; + testEmbInfo.extEmbeddingSize = embeddingSize; testEmbInfo.devVocabSize = devVocabSize; testEmbInfo.hostVocabSize = hostVocabSize; diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index ff30385e..be4d2fbc 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -37,9 +37,13 @@ protected: string name = "model"; int sendCount = 5; int embeddingSize = 8; + int extEmbeddingSize = 24; + bool modifyGraph = false; size_t devVocabSize = 5; size_t hostVocabSize = 15; vector randomInfos; + vector channelNames = {"model_1", "model_2"}; + map sendCountMap = {{"model_1", 500}, {"model_2", 500}}; RandomInfo randomInfo; int start = 0; int len = hostVocabSize * embeddingSize; @@ -68,8 +72,8 @@ protected: for (size_t i = 0; i < missingKeysHostPos.size(); i++) { (hostEmb->GetEmb(embName).embData[missingKeysHostPos[i]]).assign( tensorPtr, - tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.embeddingSize); - tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.embeddingSize; + tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize); + tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize; } for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) { spdlog::info("hostEmb: embName {}, {} is: {}", embName, i, hostEmb->GetEmb(embName).embData[i]); @@ -112,7 +116,8 @@ protected: TEST_F(EmbMgmtTest, Initialize) { vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues = {}; @@ -169,7 +174,8 @@ TEST_F(EmbMgmtTest, Initialize_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); @@ -187,7 +193,8 @@ TEST_F(EmbMgmtTest, Evict) size_t devVocabSize = DDR_DEVICE_SIZE; size_t hostVocabSize = DDR_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = 
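With the EmbInfo constructor widened to take channel_name_list and send_count_map, the pybind11 binding needs pybind11/stl.h so std::vector and std::map convert to and from Python lists and dicts. A reduced sketch under assumed types (EmbInfoSketch stands in for the real EmbInfo):

#include <map>
#include <string>
#include <utility>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // enables std::vector / std::map <-> Python list / dict

namespace py = pybind11;

struct EmbInfoSketch {  // illustrative stand-in, not the module's real struct
    std::string name;
    std::map<std::string, int> sendCountMap;
    EmbInfoSketch(std::string n, std::map<std::string, int> m)
        : name(std::move(n)), sendCountMap(std::move(m)) {}
};

PYBIND11_MODULE(emb_demo, m) {
    py::class_<EmbInfoSketch>(m, "EmbInfoSketch")
        .def(py::init<std::string, std::map<std::string, int>>(),
             py::arg("name"), py::arg("send_count_map"))
        .def_readwrite("send_count_map", &EmbInfoSketch::sendCountMap);
}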
EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); @@ -208,7 +215,8 @@ TEST_F(EmbMgmtTest, Evict_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, vocabsize, initializeInfos); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp index 0d6716ae..7c2febfd 100644 --- a/src/tests/emb_table/emb_table_test.cpp +++ b/src/tests/emb_table/emb_table_test.cpp @@ -27,7 +27,7 @@ protected: { spdlog::set_level(spdlog::level::debug); // 设置测试用的EmbInfo - embInfo.embeddingSize = embTable.TEST_EMB_SIZE; + embInfo.extEmbeddingSize = embTable.TEST_EMB_SIZE; spdlog::info("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}", embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT); rankInfo.rankId = 0; diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index c001bb15..6d1e01b7 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -122,7 +122,7 @@ protected: ss.str(""); ss.clear(); temp.sendCount = distribution(generator); - temp.embeddingSize = pow(base, embSizeDistribution(generator)); + temp.extEmbeddingSize = pow(base, embSizeDistribution(generator)); geFieldNums.push_back(sampleSize); allEmbInfos.push_back(move(temp)); } -- Gitee From dd0618a9bb79b9285be3794e7b1a0687512357f5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 11:04:45 +0800 Subject: [PATCH 025/551] Match-id-7031f7a303e7633b0b63475a26e51a9c97d021be --- src/core/key_process/key_process.cpp | 5 +++-- src/core/utils/unique.h | 27 ++++++++++++++------------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index c9cee4d0..b6f6b90d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -193,7 +193,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES { unique_ptr batch; ShardedDedup *unique = nullptr; - + int preSendCount = 0; spdlog::stopwatch sw; try { while (true) { @@ -209,7 +209,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES auto getBatchTime = TO_MS(sw); sw.reset(); - if (unique == nullptr) { + if (unique == nullptr || preSendCount != embInfos[batch->name].sendCount) { GroupMethod groupMethod; groupMethod.SetGroupCount(rankInfo.rankSize); auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); @@ -217,6 +217,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES } else { unique->StartNewRound(); } + preSendCount = embInfos[batch->name].sendCount; auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { free(batch->tensorAddr); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 47d26dbe..198a8f25 100644 --- 
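The preSendCount guard in KeyProcessTask rebuilds the dedup helper whenever a table's effective send count changes between batches, instead of reusing one sized for the old count. The pattern in isolation (DedupSketch stands in for the real ShardedDedup template):

#include <memory>

struct DedupSketch {                    // stand-in for ShardedDedup
    explicit DedupSketch(int sendCount) { (void)sendCount; }
    void StartNewRound() {}
};

void PrepareDedup(int sendCount, int& prevSendCount, std::unique_ptr<DedupSketch>& unique)
{
    if (unique == nullptr || prevSendCount != sendCount) {
        unique = std::make_unique<DedupSketch>(sendCount);  // config changed: rebuild
    } else {
        unique->StartNewRound();                            // config unchanged: reuse
    }
    prevSendCount = sendCount;
}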
a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -111,7 +111,7 @@ template class Dedup { int8_t pad[3]; int32_t replace_base; volatile uint64_t data[M]; - volatile uint64_t idCount[M]; + std::atomic idCount[M]; } __attribute__((__aligned__(64))); struct Statistics { @@ -732,18 +732,19 @@ public: start = index; } - size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); - auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); - if (rc != 0) { - spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}",mem_size); - return; - } - - mem_size = uniqueSizeVector[i] * sizeof(int32_t); - rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); - if (rc != 0) { - spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size); - return; + if (uniqueSizeVector[i] > 0) { + size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); + auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); + if (rc != 0) { + spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}",mem_size); + return; + } + mem_size = uniqueSizeVector[i] * sizeof(int32_t); + rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); + if (rc != 0) { + spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size); + return; + } } int fillLen = send_cnt_ - uniqueSizeVector[i]; -- Gitee From 63ea8c74b1f1ee5028ff2d12989337bbcbae74c4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 12:12:22 +0800 Subject: [PATCH 026/551] Match-id-dd885e292759bb479ff78aa742fc2023f97c3d96 --- src/core/checkpoint/checkpoint.cpp | 11 +--- src/core/emb_hashmap/emb_hashmap.cpp | 5 +- src/core/emb_hashmap/emb_hashmap.h | 3 +- src/core/emb_mgmt/emb_mgmt.cpp | 16 +++--- src/core/emb_mgmt/emb_mgmt.h | 8 +-- src/core/hd_transfer/hd_transfer.cpp | 2 +- src/core/hd_transfer/hd_transfer.h | 18 +++---- src/core/host_emb/host_emb.cpp | 2 +- src/core/host_emb/host_emb.h | 2 +- .../constant_initializer.cpp | 2 +- .../constant_initializer.h | 2 +- .../random_normal_initializer.cpp | 2 +- .../random_normal_initializer.h | 2 +- .../truncated_normal_initializer.cpp | 2 +- .../truncated_normal_initializer.h | 2 +- .../key_process/feature_admit_and_evict.cpp | 4 +- .../key_process/feature_admit_and_evict.h | 4 +- src/core/key_process/key_process.cpp | 6 +-- src/core/key_process/key_process.h | 2 +- src/core/utils/common.cpp | 4 +- src/core/utils/common.h | 4 +- src/core/utils/spinlock.h | 51 ++++++++++--------- src/core/utils/task_queue.h | 11 ++-- src/core/utils/unique.h | 23 ++++++--- 24 files changed, 97 insertions(+), 91 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index b398da94..675d00c2 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -23,8 +23,6 @@ using namespace MxRec; void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo) { - // TODO: check savePath - processPath = savePath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; @@ -42,8 +40,6 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, const vector& featureTypes) { - // TODO: check loadPath - processPath = loadPath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; @@ -186,7 +182,7 @@ void Checkpoint::SaveDataset(const 
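Huawei securec's memcpy_s rejects a zero destMax, so groups that deduplicated down to nothing must be skipped before copying; that is what the new uniqueSizeVector[i] > 0 guard in TileAndFill does. The guard in standalone form, assuming securec.h is on the include path:

#include <cstdio>
#include "securec.h"  // assumed available; provides memcpy_s

int CopyUniqueGroup(long long* dst, const long long* src, size_t count)
{
    if (count == 0) {
        return 0;  // securec errors on destMax == 0, so an empty group is a no-op
    }
    size_t bytes = count * sizeof(long long);
    if (memcpy_s(dst, bytes, src, bytes) != 0) {
        std::fprintf(stderr, "memcpy_s failed, bytes=%zu\n", bytes);
        return -1;
    }
    return 0;
}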
vector& embNames, } } -void Checkpoint::WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize) +void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize) { ofstream writeFile; writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); @@ -333,11 +329,9 @@ void Checkpoint::LoadProcess(CkptData& ckptData) void Checkpoint::GetUpperLayerLoadDir(const vector& dirNames) { innerDirPath = processPath; - // TODO: check existence for (const auto& dirName : dirNames) { innerDirPath = innerDirPath + dirSeparator + dirName; - // TODO: check existence } } @@ -355,7 +349,6 @@ vector Checkpoint::GetTableLayerLoadDir() } closedir(dir); } - // TODO: may cause memory problem? need to check return loadTableDir; } @@ -367,10 +360,8 @@ void Checkpoint::LoadDataset(const vector& embNames, { for (const auto& embName : embNames) { auto dataDir { innerDirPath + dirSeparator + embName }; - // TODO: check existence for (const auto& saveDataType : saveDataTypes) { auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; - // TODO: check existence auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType }; diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 8f446a65..15ae6ac8 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -109,7 +109,7 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector& keys, size_t curr } } -void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, - size_t keepBatchId) +void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId) { while (num != 0) { if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 424684e5..a2337d01 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -31,8 +31,7 @@ namespace MxRec { void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, int pos); - void FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, - size_t keepBatchId); + void FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId); auto GetHashMaps() -> absl::flat_hash_map; diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 8cee8a08..a7d43c54 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -50,7 +50,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, if (!rankInfo.noDDR) { hostEmbs = make_unique(); hostHashMaps = make_unique(); - hostEmbs->Initialize(embInfos, seed, ifLoad); + hostEmbs->Initialize(embInfos, seed); hostHashMaps->Init(rankInfo, embInfos, ifLoad); } isLoad = ifLoad; @@ -66,7 +66,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, return true; } -bool HybridMgmt::Save(string savePath) +bool HybridMgmt::Save(const string savePath) { preprocess->LoadSaveLock(); @@ -347,7 +347,7 @@ bool HybridMgmt::SendTask() return false; } -bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) +bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) { spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); for 
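WriteEmbedding now takes its inputs by const reference, which fits its role: it only serializes the transfer buffer to disk. Its write pattern, reduced to a sketch with illustrative names:

#include <fstream>
#include <string>
#include <vector>

bool DumpEmbeddingRows(const std::string& path, const std::vector<float>& data)
{
    std::ofstream out(path, std::ios::out | std::ios::trunc | std::ios::binary);
    if (!out.is_open()) {
        return false;
    }
    out.write(reinterpret_cast<const char*>(data.data()),
              static_cast<std::streamsize>(data.size() * sizeof(float)));
    return out.good();
}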
(const auto& embInfo: mgmtEmbInfo) { @@ -374,7 +374,7 @@ bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) return true; } -bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) +bool HybridMgmt::SendLookupAndRestore(const int channelId, const int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { vector names = {embInfo.name}; @@ -430,8 +430,10 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) { spdlog::info(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; - int start = batchId, iBatch = 0; - bool ifHashmapFree = true, remainBatch = true; + int start = batchId; + int iBatch = 0; + bool ifHashmapFree = true; + bool remainBatch = true; while (true) { spdlog::info(MGMT + "parse keys, [{}]:{}", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { @@ -498,7 +500,7 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in } } -void HybridMgmt::EmbHDTrans(int channelId, int batchId) +void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) { EASY_FUNCTION(profiler::colors::Blue) EASY_VALUE("mgmtProcess", batchId) diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h index 31eef2f5..1ca832fa 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/emb_mgmt/emb_mgmt.h @@ -46,7 +46,7 @@ namespace MxRec { bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, const vector& thresholdValues, bool ifLoad); - bool Save(string savePath); + bool Save(const string savePath); bool Load(const string& loadPath); @@ -83,7 +83,7 @@ namespace MxRec { bool ParseKeys(int channelId, int& batchId); - void EmbHDTrans(int channelId, int batchId); + void EmbHDTrans(const int channelId, const int batchId); void Evict(); @@ -114,8 +114,8 @@ namespace MxRec { bool TrainParseKeys(); bool EvalParseKeys(); - bool GetLookupAndRestore(int channelId, int &batchId); - bool SendLookupAndRestore(int channelId, int &batchId); + bool GetLookupAndRestore(const int channelId, int &batchId); + bool SendLookupAndRestore(const int channelId, const int &batchId); void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index f3e9618b..b5c99868 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -60,7 +60,7 @@ void HDTransfer::Destroy() #endif } -void HDTransfer::CreateChannel(uint32_t localRankId, const string& embName, int channelNum) +void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum) { #ifndef GTEST int channelSize; diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 2e9ff303..b840649d 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -27,7 +27,7 @@ namespace MxRec { const int PING_PONG_SIZE = 12; const int LARGE_CHANNEL_SIZE = 100; - enum TransferChannel { + enum class TransferChannel { D2H, RESTORE, ALL2ALL, @@ -41,19 +41,19 @@ namespace MxRec { inline string TransferChannel2Str(TransferChannel e) { switch (e) { - case D2H: + case TransferChannel::D2H: return "d2h"; - case RESTORE: + case TransferChannel::RESTORE: return "restore"; - case ALL2ALL: + case TransferChannel::ALL2ALL: return "all2all"; - case LOOKUP: + case TransferChannel::LOOKUP: return "lookup"; - case EVICT: + case TransferChannel::EVICT: return "evict"; - case H2D: + case TransferChannel::H2D: return 
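Switching TransferChannel to an enum class keeps its enumerators (D2H, SWAP, ...) out of the enclosing MxRec namespace and forces the qualified spelling used in the switch. The idiom in miniature:

#include <stdexcept>
#include <string>

enum class Channel { D2H, H2D };  // scoped: spelled Channel::D2H, never bare D2H

std::string ChannelToStr(Channel c)
{
    switch (c) {
        case Channel::D2H: return "d2h";
        case Channel::H2D: return "h2d";
        default: throw std::invalid_argument("Invalid Channel");  // defensive, as above
    }
}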
"h2d"; - case SWAP: + case TransferChannel::SWAP: return "swap"; default: throw std::invalid_argument("Invalid TransferChannel"); @@ -84,7 +84,7 @@ namespace MxRec { std::unordered_map transferChannels; #endif bool running; - void CreateChannel(uint32_t localRankId, const string& embName, int channelNum); + void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum); }; } #endif // MX_REC_HD_TRANSFER_H diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 6fdb1f54..0c316b1f 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -19,7 +19,7 @@ using namespace MxRec; using namespace std; using namespace chrono; -bool HostEmb::Initialize(const vector& embInfos, int seed, bool ifLoad) +bool HostEmb::Initialize(const vector& embInfos, int seed) { for (const auto& embInfo: embInfos) { HostEmbTable hostEmb; diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index 31928a27..da72e70c 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -28,7 +28,7 @@ namespace MxRec { ~HostEmb() {}; - bool Initialize(const vector& embInfos, int seed, bool ifLoad = false); + bool Initialize(const vector& embInfos, int seed); void LoadEmb(absl::flat_hash_map& loadData); diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 4bced738..09780b88 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -13,7 +13,7 @@ using namespace MxRec; ConstantInitializer::ConstantInitializer(int start, int len, float value) : start(start), len(len), value(value) {} -void ConstantInitializer::GenerateData(float* emb, const int embSize) +void ConstantInitializer::GenerateData(const float* emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h index b763087b..56f95a9e 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.h +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -21,7 +21,7 @@ namespace MxRec { ~ConstantInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(const float* emb, const int embSize) override; int start; int len; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index a5a31381..56765ecf 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -18,7 +18,7 @@ RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, distribution = std::normal_distribution(mean, stddev); } -void RandomNormalInitializer::GenerateData(float* emb, const int embSize) +void RandomNormalInitializer::GenerateData(const float* emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index e0127ca2..eeab2dfa 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ 
b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -23,7 +23,7 @@ namespace MxRec { ~RandomNormalInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(const float* emb, const int embSize) override; int start; int len; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index a1f49c08..2d830490 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -21,7 +21,7 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float } -void TruncatedNormalInitializer::GenerateData(float* emb, const int embSize) +void TruncatedNormalInitializer::GenerateData(const float* emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index 3c6bb980..34058351 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -23,7 +23,7 @@ namespace MxRec { ~TruncatedNormalInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(const float* emb, const int embSize) override; int boundNum = 2; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index bcbb02e7..7cb2caa8 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -91,8 +91,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } -FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(int channel, const std::string& tensorName, - int64_t featureId, uint32_t featureCnt) +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tensorName, + const int64_t featureId, const uint32_t featureCnt) { // “特征准入”逻辑 uint32_t currKeyCount = 0; diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index a355711d..dfc2490c 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -84,8 +84,8 @@ namespace MxRec { // 解析m_tensor2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); std::vector GetAllNeedEvictTensorNames(); - FeatureAdmitType FeatureAdmitHelper(int channel, const std::string& tensorName, - int64_t featureId, uint32_t featureCnt); + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tensorName, + const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b6f6b90d..795bfcd2 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -189,7 +189,7 @@ void KeyProcess::LoadSaveUnlock() } } -void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] +void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id 
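FeatureAdmitHelper implements count-threshold admission: a feature id is admitted only once its accumulated occurrence count crosses the configured threshold. A plausible reading of that bookkeeping as a sketch, with assumed names and semantics:

#include <cstdint>
#include <string>
#include <unordered_map>

// hypothetical reduction of the admit rule; the real class keys counts per tensor name
bool AdmitFeature(std::unordered_map<int64_t, uint32_t>& counts,
                  int64_t featureId, uint32_t featureCnt, uint32_t threshold)
{
    counts[featureId] += featureCnt;        // accumulate occurrences across batches
    return counts[featureId] >= threshold;  // admit once the threshold is reached
}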
[0, KEY_PROCESS_THREAD-1] { unique_ptr batch; ShardedDedup *unique = nullptr; @@ -885,11 +885,11 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) return get(ret); } catch (EmptyList&) { spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} no input, wait and retry", - embName, channel, batch); + embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} wrong top", - embName, channel, batch); + embName, channel, batch); this_thread::sleep_for(1ms); } } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 3a98aa0a..64669f96 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -119,7 +119,7 @@ namespace MxRec { auto GetSendCount(const string& name, const string& channelName, bool modifyGraph); - void KeyProcessTask(int channel, int id); + void KeyProcessTask(const int channel, const int id); bool KeyProcessTaskHelper(unique_ptr& batch, sharded_dedup unique_, int channel, int id, spdlog::stopwatch& sw); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 9da6952c..0ffe6040 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -37,8 +37,8 @@ namespace MxRec { if (localRankSize != 0) { localRankId = rankId % localRankSize; } - useStatic = option bitand HybridOption::USE_STATIC; - useHot = option bitand HybridOption::USE_HOT; + useStatic = option & HybridOption::USE_STATIC; + useHot = option & HybridOption::USE_HOT; } RandomInfo::RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index a3555d0f..84a1ae50 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -263,8 +263,8 @@ struct BatchTask { std::default_random_engine& generator, RandomInfo& randomInfo) { - float min = ((randomInfo.randomMin == 0) ? -0.1f : randomInfo.randomMin); - float max = ((randomInfo.randomMax == 0) ? 0.1f : randomInfo.randomMax); + float min = ((!randomInfo.randomMin) ? -0.1f : randomInfo.randomMin); + float max = ((!randomInfo.randomMax) ? 
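HybridOption is used as a bit mask, so replacing the bitand keyword with & changes spelling only: each capability is one bit tested independently. A sketch with assumed bit values (the real ones live in common.h):

#include <cstdint>

enum HybridOption : uint32_t {
    USE_STATIC = 1u << 0,  // assumed layout, for illustration
    USE_HOT    = 1u << 1,
};

bool UseStatic(uint32_t option) { return (option & USE_STATIC) != 0; }
bool UseHot(uint32_t option)    { return (option & USE_HOT) != 0; }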
0.1f : randomInfo.randomMax); if (randomInfo.len == 0) { return; } diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h index 95c0a35c..d284ef05 100644 --- a/src/core/utils/spinlock.h +++ b/src/core/utils/spinlock.h @@ -17,7 +17,8 @@ type(type &&) noexcept = delete; \ type &operator=(type const &) = delete -static __inline void cpu_pause() { +static __inline void cpu_pause() +{ #ifdef __GNUC__ #ifdef __aarch64__ __asm volatile("yield" ::: "memory"); @@ -62,7 +63,8 @@ public: DISALLOW_COPY_MOVE_AND_ASSIGN_(SpinLock); - inline void lock() noexcept { + inline void lock() noexcept + { while (true) { if (!lock_.exchange(true, std::memory_order_acquire)) { break; @@ -80,7 +82,8 @@ public: } } - inline bool try_lock() noexcept { + inline bool try_lock() noexcept + { if (lock_.load(std::memory_order_relaxed)) { return false; } @@ -107,22 +110,23 @@ public: DISALLOW_COPY_MOVE_AND_ASSIGN_(RWSpinLock); - inline void r_lock() noexcept { - LockData oldData, newData; + inline void r_lock() noexcept + { + LockData oldData; + LockData newData; while (true) { uint16_t counter = 0; for (;;) { oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.lock.writer > 0) { - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } else { + if (oldData.lock.writer <= 0) { break; } + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } } newData.lock.readers = oldData.lock.readers + 1; @@ -135,22 +139,23 @@ public: } } - inline void w_lock() noexcept { - LockData oldData, newData; + inline void w_lock() noexcept + { + LockData oldData; + LockData newData; while (true) { uint16_t counter = 0; for (;;) { oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.raw != 0) { - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } else { + if (oldData.raw == 0) { break; } + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } } newData.lock.readers = 0; diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h index 14ed1202..d44e2a11 100644 --- a/src/core/utils/task_queue.h +++ b/src/core/utils/task_queue.h @@ -38,7 +38,6 @@ public: return *this; } - void Pushv(T &t) { std::lock_guard lk(mut); @@ -57,14 +56,14 @@ public: { std::unique_lock lk(mut); dataCond.wait(lk, [this] { - if (!finished){ + if (!finished) { return !dataQueue.empty(); - } else{ + } else { return true; } }); T res; - if (finished){ + if (finished) { return res; } res = dataQueue.front(); @@ -72,12 +71,12 @@ public: return res; } - void DestroyQueue(){ + void DestroyQueue() + { finished = true; dataCond.notify_one(); } - bool Empty() const { std::lock_guard lk(mut); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 198a8f25..9be61498 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -64,6 +64,16 @@ struct UniqueThreadNum { int maxThread; }; +namespace { + const int LEVEL1_CACHE = 64; + const int DEFAULT_DEDUPLICATION_RATE = 4; + const int DEDUPLICATION_RATE = 2; + const int HASH_SPLIT_BUCKERT_1 = 16; + const int HASH_SPLIT_BUCKERT_2 = 32; + const int HASH_SPLIT_BUCKERT_3 = 48; + const int PRE_APPLY_MEMORY = 256; +} + class SendCntTooSmallError : public std::exception { }; @@ -123,7 +133,7 @@ public: Dedup(int bucketCountPower2 = 
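The spinlock's fast path is an exchange with acquire ordering; under contention it spins on a relaxed read-only load with a CPU pause hint, yielding to the scheduler after a bounded spin count. The pattern condensed:

#include <atomic>
#include <thread>

class TinySpinLock {  // restatement of the test-and-test-and-set pattern above
public:
    void lock() noexcept
    {
        int spins = 0;
        while (lock_.exchange(true, std::memory_order_acquire)) {
            while (lock_.load(std::memory_order_relaxed)) {  // spin on a plain read
                if (++spins > kMaxSpins) {
                    std::this_thread::yield();               // back off to the scheduler
                    spins = 0;
                }
            }
        }
    }
    void unlock() noexcept { lock_.store(false, std::memory_order_release); }

private:
    static constexpr int kMaxSpins = 1000;  // illustrative bound
    std::atomic<bool> lock_ { false };
};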
kDefaultBucketCount, int groups = 1) : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) { - void *area = aligned_alloc(64, sizeof(Meta) * bucketCount_); + void *area = aligned_alloc(LEVEL1_CACHE, sizeof(Meta) * bucketCount_); table_ = reinterpret_cast *>(area); Clear(bucketCount_); } @@ -270,7 +280,7 @@ public: // sake. uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; int largeCount = 0; - while (shardedTableSize > stats_.totalUniques * 4 && largeCount_ != 1) { + while (shardedTableSize > stats_.totalUniques * DEFAULT_DEDUPLICATION_RATE && largeCount_ != 1) { // too large newBucketCountPowerOf2 >>= 1; shardedTableSize >>= 1; @@ -285,7 +295,7 @@ public: } } - while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> 2)) { + while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> DEDUPLICATION_RATE)) { newBucketCountPowerOf2 <<= 1; shardedTableSize <<= 1; } @@ -365,7 +375,8 @@ public: return total - priorTotal; } - void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) { + void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) + { auto hot = hotMap.find(key); if (hot != hotMap.end()) { if (hot->second == -1) { @@ -445,7 +456,7 @@ private: static inline uint64_t hash(uint64_t val) { - return val ^ (val >> 16) ^ (val >> 32) ^ (val >> 48); + return val ^ (val >> HASH_SPLIT_BUCKERT_1) ^ (val >> HASH_SPLIT_BUCKERT_2) ^ (val >> HASH_SPLIT_BUCKERT_3); } void insertOverflow(uint64_t val) @@ -503,7 +514,7 @@ public: ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, int estimatedDuplicateRatio = kDefaultDuplicateRatio) - : groupMethod_(groupMethod), bucketCountPower2_(256), send_cnt_(send_cnt) + : groupMethod_(groupMethod), bucketCountPower2_(PRE_APPLY_MEMORY), send_cnt_(send_cnt) { const int numOfGroupsInShard = groupMethod_.GroupCount(); -- Gitee From e2d1a6448f6790887f152c3cb477035201e5ab5d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 14:10:47 +0800 Subject: [PATCH 027/551] Match-id-a891f2e13553723d0774fddd17e3bf2406071a25 --- example/little_demo/main.py | 2 +- mx_rec/__init__.py | 3 +- mx_rec/core/embedding.py | 13 +- mx_rec/optimizers/base.py | 19 ++ mx_rec/optimizers/lazy_adam_by_addr.py | 255 +++++-------------------- 5 files changed, 83 insertions(+), 209 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index e29d36a9..27af7e30 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -97,7 +97,7 @@ if __name__ == "__main__": warnings.filterwarnings("ignore") mode = MxRecMode.mapping(os.getenv("MXREC_MODE")) - TRAIN_INTERVAL = 100 + TRAIN_INTERVAL = 200 EVAL_STEPS = 10 SAVING_INTERVAL = 100 USE_TIMESTAMP = False diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 3ffb468a..d56b33c3 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -3,7 +3,8 @@ from .util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from .util.tf_version_adapter import npu_ops, hccl_ops from .saver.patch import patch_for_saver from .graph.patch import patch_for_dataset - +from .optimizers.base import patch_for_optimizer patch_for_saver() patch_for_dataset() +patch_for_optimizer() diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 52f7ce9c..866a4f0f 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -406,7 +406,10 @@ class SparseEmbedding: local_embeddings = get_host_pipeline_ops().embedding_lookup_by_address(id_offsets, 
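The named constants make the dedup hash's intent visible: it XOR-folds the four 16-bit quarters of a 64-bit key into the low bits before masking down to a power-of-two bucket count. As a standalone function:

#include <cstdint>

uint64_t FoldHash(uint64_t key)
{
    // XOR-fold the 16-bit quarters so every input bit can influence the low 16 bits
    return key ^ (key >> 16) ^ (key >> 32) ^ (key >> 48);
}

uint64_t BucketOf(uint64_t key, uint64_t bucketCountPow2)
{
    return FoldHash(key) & (bucketCountPow2 - 1);  // requires a power-of-two bucket count
}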
embedding_dim=self._emb_size, embedding_type=1) - if is_training: + + from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + if is_training and use_dynamic_expansion and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) @@ -569,8 +572,6 @@ class SparseEmbedding: id_offsets = tf.identity(id_offsets, name="identity_addr") restore_vector = tf.identity(restore_vector, name="identity_restore") - if is_training and use_dynamic_expansion: - tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) use_static = get_use_static() host_pipeline_ops = get_host_pipeline_ops() @@ -629,7 +630,11 @@ class SparseEmbedding: local_embeddings = \ host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self._emb_size, embedding_type=1) - if is_training: + + from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + if is_training and use_dynamic_expansion and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name: + tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) return sparse_forward(local_embeddings) diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index a54e829c..55c02e06 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -9,6 +9,10 @@ from __future__ import print_function from collections import defaultdict +import logging +from tensorflow.python.framework import ops +from tensorflow.python.training.optimizer import _TensorProcessor + class CustomizedOptimizer: @@ -36,3 +40,18 @@ class CustomizedOptimizer: def get_slot_init_values(self): raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") + + +def my_update_op(self, opt, grad): + logging.debug("tf.compat.v1.training.optimizer._TensorProcessor has been patched, update_op.") + if isinstance(grad, ops.Tensor): + logging.debug(">>>>Enter update_op ops.Tensor") + update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access + return update_op + else: + raise RuntimeError("Only support g with type Tensor.") + + +def patch_for_optimizer(): + _TensorProcessor.update_op = my_update_op + logging.debug("Class tf.compat.v1.training.optimizer._TensorProcessor has been patched.") \ No newline at end of file diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 9bf7d2f3..6081a8e6 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -7,20 +7,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import abc import logging from collections import defaultdict import tensorflow as tf -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import optimizer -from tensorflow.python.eager import context -from tensorflow.python.framework import indexed_slices +from tensorflow.python.training import adam from mx_rec.util.initialize import get_host_pipeline_ops, 
insert_optimizer from mx_rec.optimizers.base import CustomizedOptimizer @@ -46,7 +38,7 @@ def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999 return optimizer_by_addr -class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): +class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): name_counter = defaultdict(int) def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, @@ -54,20 +46,11 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): self.optimizer_type = "LazyAdamByAddress" super(CustomizedLazyAdamByAddress, self).__init__(use_locking, name) - self._lr = learning_rate - self._beta1 = beta1 - self._beta2 = beta2 - self._epsilon = epsilon + super(CustomizedLazyAdamByAddress, self).__get_name__(name=name) + super(CustomizedLazyAdamByAddress, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=self.unique_name) - self._non_slot_dict = {} self._slot_num = 2 - - # Tensor versions of the constructor arguments, created in _prepare(). - self._lr_t = None - self._beta1_t = None - self._beta2_t = None - self._epsilon_t = None - self._check_input_param() def _check_input_param(self): @@ -86,15 +69,6 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): def slot_num(self): return self._slot_num - def _get_beta_accumulators(self): - with ops.init_scope(): - if context.executing_eagerly(): - graph = None - else: - graph = ops.get_default_graph() - return (self._get_non_slot_variable("beta1_power", graph=graph), - self._get_non_slot_variable("beta2_power", graph=graph)) - def _create_slots(self, addr_list): first_addr = addr_list[0] self._create_non_slot_variable( @@ -102,46 +76,6 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): self._create_non_slot_variable( initial_value=self._beta2, name="beta2_power", colocate_with=first_addr) - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables. - eager = context.executing_eagerly() - graph = None if eager else tf.get_default_graph() - - key = (name, graph) - var = self._non_slot_dict.get(key, None) - if var is None: - distribution_strategy = distribute_ctx.get_strategy() - with distribution_strategy.extended.colocate_vars_with(colocate_with): - var = variable_scope.variable( - initial_value, name=name, trainable=False, - use_resource=resource_variable_ops.is_resource_variable( - colocate_with)) - self._non_slot_dict[key] = var - return var - - def _prepare(self): - learn_rate = self._call_if_callable(self._lr) - beta1 = self._call_if_callable(self._beta1) - beta2 = self._call_if_callable(self._beta2) - epsilon = self._call_if_callable(self._epsilon) - - self._lr_t = ops.convert_to_tensor(learn_rate, name="learning_rate") - self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") - self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") - self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") - - def _finish(self, update_ops, name_scope): - # Update the power accumulators. 
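Deriving from adam.AdamOptimizer lets the class reuse TensorFlow's beta-power accumulators, _prepare, and _finish plumbing, keeping only the address-based sparse update. The update it still implements is bias-corrected Adam in the epsilon-hat form the code computes:

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad
\theta_t = \theta_{t-1} - \frac{\eta\,\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} \cdot \frac{m_t}{\sqrt{v_t}+\epsilon}
\]

where \(\eta\sqrt{1-\beta_2^{\,t}}/(1-\beta_1^{\,t})\) is the learning_rate the _apply_sparse_shared body derives from the beta accumulators.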
- with ops.control_dependencies(update_ops): - beta1_power, beta2_power = self._get_beta_accumulators() - with ops.colocate_with(beta1_power): - update_beta1 = beta1_power.assign( - beta1_power * self._beta1_t, use_locking=self._use_locking) - update_beta2 = beta2_power.assign( - beta2_power * self._beta2_t, use_locking=self._use_locking) - return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], name=name_scope) - def get_slot_init_values(self): # return state value list of adam that needs to initialize in ASC DDR. initial_momentum_value = 0.0 @@ -175,142 +109,57 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): host_pipeline_ops = get_host_pipeline_ops() dim = grad.shape.as_list()[-1] + combined_tensor = \ + host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) - addr_m = tf.add(addr, 4*dim) - addr_v = tf.add(addr, 8*dim) - old_m_slice = \ - host_pipeline_ops.embedding_lookup_by_address(addr_m, embedding_dim=dim, embedding_type=1) - old_v_slice = \ - host_pipeline_ops.embedding_lookup_by_address(addr_v, embedding_dim=dim, embedding_type=1) - - # combined_tensor = \ - # host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) - # split_length = [dim] + [dim] + [dim] - # split_tensors = tf.split(combined_tensor, split_length, axis=1) + split_length = [dim] + [dim] + [dim] + split_tensors = tf.split(combined_tensor, split_length, axis=1) - # old_m_slice = split_tensors[1] + old_m_slice = split_tensors[1] m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad - m_update_op = host_pipeline_ops.embedding_update_by_address(addr_m, m_t_slice - old_m_slice, update_type=0) - # old_v_slice = split_tensors[2] + old_v_slice = split_tensors[2] v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) - v_update_op = host_pipeline_ops.embedding_update_by_address(addr_v, v_t_slice - old_v_slice, update_type=0) denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon - # update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ - # [v_t_slice - old_v_slice] - # update_tensor = tf.concat(update_list, axis=1) - var_update_op = host_pipeline_ops.embedding_update_by_address(addr, tf.divide(-learning_rate * m_t_slice, - denominator_slice), update_type=0) - - return control_flow_ops.group(m_update_op, v_update_op, var_update_op) - - def _convert_grads_and_addrs(self, grads_and_vars): - converted_grads_and_addrs = [] - for grad, addr in grads_and_vars: - if grad is not None: - try: - # Convert the grad to Tensor or IndexedSlices if necessary. - grad = ops.convert_to_tensor_or_indexed_slices(grad) - except TypeError as error: - raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error - if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): - raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") - processor = _get_processor(addr) - converted_grads_and_addrs.append((grad, addr, processor)) - return converted_grads_and_addrs - - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - - # No DistributionStrategy case. - grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works. 
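The single 3*dim lookup relies on each optimizer row being laid out as [weights | m | v] in contiguous float32, which is also why the earlier per-slot variant offset the address by 4*dim and 8*dim bytes. The offset arithmetic as a sketch, under that assumed layout:

#include <cstdint>

constexpr int64_t kFloatBytes = 4;  // rows are float32

// assumed layout: [weights(dim) | m(dim) | v(dim)] starting at rowAddr
int64_t MomentAddr(int64_t rowAddr, int64_t dim)   { return rowAddr + kFloatBytes * dim; }      // == addr + 4*dim
int64_t VarianceAddr(int64_t rowAddr, int64_t dim) { return rowAddr + 2 * kFloatBytes * dim; }  // == addr + 8*dim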
- if not grads_and_vars: - raise ValueError("No variables provided.") - - converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars)) - addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None] - if not addr_list: - raise ValueError("No gradients provided for any address: %s." % - ([str(a) for _, a, _ in converted_grads_and_addrs],)) - with ops.init_scope(): - self._create_slots(addr_list) - update_ops = [] - with ops.name_scope(name, self._name) as name: - self._prepare() - for grad, addr, processor in converted_grads_and_addrs: - if grad is None: - continue - if (context.executing_eagerly() or - resource_variable_ops.is_resource_variable(addr) - and not addr._in_graph_mode): # pylint: disable=protected-access - scope_name = "" - else: - scope_name = addr.op.name - with ops.name_scope( - "update_" + scope_name), ops.colocate_with(addr): - update_ops.append(processor.update_op(self, grad)) - - apply_updates = self._finish(update_ops, name) - - if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): - logging.debug(">>>>Enter ops.Tensor") - apply_updates = apply_updates.op - train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) - if apply_updates not in train_op: - logging.debug(">>>>Enter apply_updates not in train_op") - train_op.append(apply_updates) - else: - raise RuntimeError("eager wrong.") - - return apply_updates - - -def get_filtered_grad_fn(grad_fn): - def filtered_grad_fn(*args, **kwargs): - return [(g, a) for g, a in grad_fn(*args, **kwargs) if g is not None] - - return filtered_grad_fn - - -class _OptimizableAddr(metaclass=abc.ABCMeta): - """Interface for abstracting over addresses in the optimizers.""" - - @abc.abstractmethod - def target(self): - """Returns the optimization target for this address.""" - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def update_op(self, opt, grad): - """Returns the update ops for updating the address.""" - raise NotImplementedError("Calling an abstract method.") - - -class _TensorByAddressProcessor(_OptimizableAddr): - """Processor for Tensor filled with addresses.""" - - def __init__(self, addr): - self._a = addr - - def target(self): - return self._a - - def __str__(self): - return "<_TensorByAddressProcessor(%s)>" % self._a - - def update_op(self, opt, grad): - if isinstance(grad, ops.Tensor): - logging.debug(">>>>Enter update_op ops.Tensor") - update_op = opt._apply_sparse(grad, self._a) # pylint: disable=protected-access - return update_op - else: - raise RuntimeError("Only support g with type Tensor.") - - -def _get_processor(addr): - """The processor of v.""" - if isinstance(addr, ops.Tensor): - logging.debug(">>>>Enter _get_processor tensor") - return _TensorByAddressProcessor(addr) - raise NotImplementedError("Trying to optimize unsupported type ", addr) \ No newline at end of file + update_list = [tf.divide(-learning_rate * m_t_slice, denominator_slice)] + [m_t_slice - old_m_slice] + \ + [v_t_slice - old_v_slice] + update_tensor = tf.concat(update_list, axis=1) + var_update_op = host_pipeline_ops.embedding_update_by_address(addr, update_tensor, update_type=0) + + return var_update_op + + + # def _apply_sparse_shared(self, grad, addr): + # power_b1, power_b2 = self._get_beta_accumulators() + # power_b1 = math_ops.cast(power_b1, grad.dtype.base_dtype) + # power_b2 = math_ops.cast(power_b2, grad.dtype.base_dtype) + # temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(grad) + # learning_rate = tf.divide(temp_lr * 
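Concatenating the three deltas and issuing one embedding_update_by_address replaces the previous three round trips. Reading update_type=0 as add-delta (the deltas here are differences, which strongly suggests additive semantics, though this hunk does not spell them out), the op applies per row

\[
\Delta_{\text{row}} = \Big[\; -\mathrm{lr}_t\,\frac{m_t}{\sqrt{v_t}+\epsilon} \;\Big\|\; m_t - m_{t-1} \;\Big\|\; v_t - v_{t-1} \;\Big]
\]

so the weights and both moment slots advance in a single fused write.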
math_ops.sqrt(1 - power_b2), (1 - power_b1)) + # + # host_pipeline_ops = get_host_pipeline_ops() + # dim = grad.shape.as_list()[-1] + # + # # addr_m = tf.where(tf.math.greater(addr, 0), addr + 4*dim, addr) + # # addr_v = tf.where(tf.math.greater(addr, 0), addr + 8*dim, addr) + # addr_m = tf.add(addr, 4*dim) + # addr_v = tf.add(addr, 8*dim) + # + # logging.debug(f'lazy adam by addr, addr is {addr}, addr_m is {addr_m}, addr_v is {addr_v}') + # old_m_slice = \ + # host_pipeline_ops.embedding_lookup_by_address(addr_m, embedding_dim=dim, embedding_type=1) + # old_v_slice = \ + # host_pipeline_ops.embedding_lookup_by_address(addr_v, embedding_dim=dim, embedding_type=1) + # + # m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad + # m_update_op = host_pipeline_ops.embedding_update_by_address(addr_m, m_t_slice - old_m_slice, update_type=0) + # + # v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) + # v_update_op = host_pipeline_ops.embedding_update_by_address(addr_v, v_t_slice - old_v_slice, update_type=0) + # + # denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon + # + # var_update_op = host_pipeline_ops.embedding_update_by_address(addr, tf.divide(-learning_rate * m_t_slice, + # denominator_slice), update_type=0) + # + # return control_flow_ops.group(m_update_op, v_update_op, var_update_op) \ No newline at end of file -- Gitee From 083d822b262c494da62bc29362a530551f8f3da1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 14:29:24 +0800 Subject: [PATCH 028/551] Match-id-59e87bfedb9731b6b534b00b5d14275632195021 --- cust_op/cust_op_by_addr/CMakeLists.txt | 46 + cust_op/cust_op_by_addr/CMakePresets.json | 47 + cust_op/cust_op_by_addr/build.sh | 37 + cust_op/cust_op_by_addr/cmake/config.cmake | 16 + cust_op/cust_op_by_addr/cmake/func.cmake | 217 +++++ cust_op/cust_op_by_addr/cmake/intf.cmake | 25 + cust_op/cust_op_by_addr/cmake/makeself.cmake | 14 + .../cust_op_by_addr/cmake/util/__init__.py | 8 + .../cmake/util/batch_replay_impl.temp | 117 +++ .../cmake/util/code_channel_infer.py | 142 +++ .../cust_op_by_addr/cmake/util/const_var.py | 32 + .../cmake/util/gen_impl_and_mrege_json.sh | 57 ++ .../cmake/util/gen_ops_filter.sh | 61 ++ .../cmake/util/insert_op_info.py | 36 + .../cmake/util/insert_simplified_keys.py | 248 ++++++ .../cmake/util/kernel_entry.py | 115 +++ .../cmake/util/kernel_impl.temp | 10 + .../cmake/util/makeself/COPYING | 339 ++++++++ .../cmake/util/makeself/README.md | 246 ++++++ .../cmake/util/makeself/VERSION | 1 + .../cmake/util/makeself/make-release.sh | 9 + .../cmake/util/makeself/makeself-header.sh | 660 ++++++++++++++ .../cmake/util/makeself/makeself.1 | 110 +++ .../cmake/util/makeself/makeself.lsm | 16 + .../cmake/util/makeself/makeself.sh | 822 ++++++++++++++++++ .../cmake/util/makeself/run-tests.sh | 8 + .../cmake/util/merge_aicpu_info_json.sh | 31 + .../cmake/util/opdesc_parser.py | 176 ++++ .../cmake/util/parse_ini_to_json.py | 339 ++++++++ .../cmake/util/preset_parse.py | 23 + .../cmake/util/replay_codegen.py | 105 +++ .../cmake/util/replay_impl.temp | 120 +++ .../cmake/util/tik2_bin_param_build.py | 121 +++ .../cmake/util/tik2_impl_build.py | 376 ++++++++ .../cmake/util/tik2_ops_config.py | 111 +++ .../cmake/util/tik2_replay_build.py | 65 ++ .../cmake/util/tiling_data_def_build.py | 75 ++ .../cust_op_by_addr/framework/CMakeLists.txt | 11 + .../framework/tf_plugin/CMakeLists.txt | 8 + ...flow_embedding_lookup_by_address_plugin.cc | 13 + .../cust_op_by_addr/op_host/CMakeLists.txt | 35 + 
.../op_host/embedding_lookup_by_address.cpp | 146 ++++ .../embedding_lookup_by_address_tiling.h | 14 + .../op_host/embedding_update_by_address.cpp | 137 +++ .../embedding_update_by_address_tiling.h | 15 + cust_op/cust_op_by_addr/op_host/readme.md | 218 +++++ .../cust_op_by_addr/op_kernel/CMakeLists.txt | 80 ++ .../op_kernel/embedding_lookup_by_address.cpp | 235 +++++ .../op_kernel/embedding_update_by_address.cpp | 246 ++++++ cust_op/cust_op_by_addr/readme.md | 197 +++++ cust_op/cust_op_by_addr/scripts/install.sh | 228 +++++ cust_op/cust_op_by_addr/scripts/upgrade.sh | 121 +++ 52 files changed, 6685 insertions(+) create mode 100644 cust_op/cust_op_by_addr/CMakeLists.txt create mode 100644 cust_op/cust_op_by_addr/CMakePresets.json create mode 100644 cust_op/cust_op_by_addr/build.sh create mode 100644 cust_op/cust_op_by_addr/cmake/config.cmake create mode 100644 cust_op/cust_op_by_addr/cmake/func.cmake create mode 100644 cust_op/cust_op_by_addr/cmake/intf.cmake create mode 100644 cust_op/cust_op_by_addr/cmake/makeself.cmake create mode 100644 cust_op/cust_op_by_addr/cmake/util/__init__.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp create mode 100644 cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/const_var.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/insert_op_info.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/kernel_entry.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/COPYING create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/README.md create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/VERSION create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh create mode 100644 cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/preset_parse.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/replay_codegen.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/replay_impl.temp create mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py create mode 100644 cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py create mode 100644 cust_op/cust_op_by_addr/framework/CMakeLists.txt create mode 100644 cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt create mode 100644 cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc create mode 100644 
cust_op/cust_op_by_addr/op_host/CMakeLists.txt create mode 100644 cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp create mode 100644 cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h create mode 100644 cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp create mode 100644 cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h create mode 100644 cust_op/cust_op_by_addr/op_host/readme.md create mode 100644 cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt create mode 100644 cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp create mode 100644 cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp create mode 100644 cust_op/cust_op_by_addr/readme.md create mode 100644 cust_op/cust_op_by_addr/scripts/install.sh create mode 100644 cust_op/cust_op_by_addr/scripts/upgrade.sh diff --git a/cust_op/cust_op_by_addr/CMakeLists.txt b/cust_op/cust_op_by_addr/CMakeLists.txt new file mode 100644 index 00000000..2b50f0d9 --- /dev/null +++ b/cust_op/cust_op_by_addr/CMakeLists.txt @@ -0,0 +1,46 @@ +cmake_minimum_required(VERSION 3.14.0) +project(opp) + +include(cmake/config.cmake) +include(cmake/func.cmake) +include(cmake/intf.cmake) + +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/framework) + add_subdirectory(framework) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_host) + add_subdirectory(op_host) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_kernel) + add_subdirectory(op_kernel) +endif() +if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases) + add_subdirectory(testcases) +endif() + +# modify vendor_name in install.sh and upgrade.sh +add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh + COMMAND mkdir -p ${CMAKE_BINARY_DIR}/scripts + COMMAND cp -r ${CMAKE_SOURCE_DIR}/scripts/* ${CMAKE_BINARY_DIR}/scripts/ + COMMAND sed -i "s/vendor_name=customize/vendor_name=${vendor_name}/g" ${CMAKE_BINARY_DIR}/scripts/* +) +add_custom_target(modify_vendor ALL DEPENDS ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh) +install(DIRECTORY ${CMAKE_BINARY_DIR}/scripts/ DESTINATION . 
FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ) + +install(FILES ${CMAKE_SOURCE_DIR}/custom.proto DESTINATION packages OPTIONAL) + +get_system_info(SYSTEM_INFO) + +# CPack config +set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) +set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION}) +set(CPACK_PACKAGE_DESCRIPTION "CPack opp project") +set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack opp project") +set(CPACK_PACKAGE_DIRECTORY ${CMAKE_INSTALL_PREFIX}) +set(CPACK_PACKAGE_FILE_NAME "custom_opp_${SYSTEM_INFO}.run") +set(CPACK_GENERATOR External) +set(CPACK_CMAKE_GENERATOR "Unix Makefiles") +set(CPACK_EXTERNAL_ENABLE_STAGING TRUE) +set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${CMAKE_SOURCE_DIR}/cmake/makeself.cmake) +set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME}) +include(CPack) diff --git a/cust_op/cust_op_by_addr/CMakePresets.json b/cust_op/cust_op_by_addr/CMakePresets.json new file mode 100644 index 00000000..bd4e93df --- /dev/null +++ b/cust_op/cust_op_by_addr/CMakePresets.json @@ -0,0 +1,47 @@ +{ + "version": 1, + "cmakeMinimumRequired": { + "major": 3, + "minor": 19, + "patch": 0 + }, + "configurePresets": [ + { + "name": "default", + "displayName": "Default Config", + "description": "Default build using Unix Makefiles generator", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build_out", + "cacheVariables": { + "CMAKE_BUILD_TYPE": { + "type": "STRING", + "value": "Release" + }, + "ENABLE_SOURCE_PACKAGE": { + "type": "BOOL", + "value": "True" + }, + "ENABLE_BINARY_PACKAGE": { + "type": "BOOL", + "value": "False" + }, + "ENABLE_TEST": { + "type": "BOOL", + "value": "False" + }, + "vendor_name": { + "type": "STRING", + "value": "customize" + }, + "ASCEND_CANN_PACKAGE_PATH": { + "type": "PATH", + "value": "/usr/local/Ascend/ascend-toolkit/latest" + }, + "CMAKE_INSTALL_PREFIX": { + "type": "PATH", + "value": "${sourceDir}/build_out" + } + } + } + ] +} diff --git a/cust_op/cust_op_by_addr/build.sh b/cust_op/cust_op_by_addr/build.sh new file mode 100644 index 00000000..4be96d7d --- /dev/null +++ b/cust_op/cust_op_by_addr/build.sh @@ -0,0 +1,37 @@ +#!/bin/bash +script_path=$(realpath $(dirname $0)) + + +mkdir -p build_out +rm -rf build_out/* +cd build_out + +cmake_version=$(cmake --version | grep "cmake version" | awk '{print $3}') +if [ "$cmake_version" \< "3.19.0" ] ; then + opts=$(python3 $script_path/cmake/util/preset_parse.py $script_path/CMakePresets.json) + echo $opts + cmake .. $opts +else + cmake .. --preset=default +fi +target=package +if [ "$1"x != ""x ]; then target=$1; fi + +cmake --build . --target $target -j16 +if [ $? -ne 0 ]; then exit 1; fi + +if [ $target = "package" ]; then + if test -d ./op_kernel/binary ; then + ./cust*.run + if [ $? -ne 0 ]; then exit 1; fi + cmake --build . --target binary -j16 + if [ $? -ne 0 ]; then exit 1; fi + cmake --build . 
--target $target -j16 + fi +fi + +# for debug +# cd build_out +# make +# cpack +# verbose append -v diff --git a/cust_op/cust_op_by_addr/cmake/config.cmake b/cust_op/cust_op_by_addr/cmake/config.cmake new file mode 100644 index 00000000..c6f09290 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/config.cmake @@ -0,0 +1,16 @@ + +set(CMAKE_CXX_FLAGS_DEBUG "") +set(CMAKE_CXX_FLAGS_RELEASE "") + +if (NOT DEFINED vendor_name) + set(vendor_name customize CACHE STRING "") +endif() +if (NOT DEFINED ASCEND_CANN_PACKAGE_PATH) + set(ASCEND_CANN_PACKAGE_PATH /usr/local/Ascend/latest CACHE PATH "") +endif() +set(ASCEND_TENSOR_COMPILER_PATH ${ASCEND_CANN_PACKAGE_PATH}/compiler) +set(ASCEND_CCEC_COMPILER_PATH ${ASCEND_TENSOR_COMPILER_PATH}/ccec_compiler/bin) +set(ASCEND_AUTOGEN_PATH ${CMAKE_BINARY_DIR}/autogen) +set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) +set(ASCEND_FRAMEWORK_TYPE tensorflow) +file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_PATH}) diff --git a/cust_op/cust_op_by_addr/cmake/func.cmake b/cust_op/cust_op_by_addr/cmake/func.cmake new file mode 100644 index 00000000..69f2208d --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/func.cmake @@ -0,0 +1,217 @@ + +function(get_system_info SYSTEM_INFO) + if (UNIX) + execute_process(COMMAND grep -i ^id= /etc/os-release OUTPUT_VARIABLE TEMP) + string(REGEX REPLACE "\n|id=|ID=|\"" "" SYSTEM_NAME ${TEMP}) + set(${SYSTEM_INFO} ${SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR} PARENT_SCOPE) + elseif (WIN32) + message(STATUS "System is Windows. Only for pre-build.") + else () + message(FATAL_ERROR "${CMAKE_SYSTEM_NAME} not support.") + endif () +endfunction() + +function(opbuild) + message(STATUS "Opbuild generating sources") + cmake_parse_arguments(OPBUILD "" "OUT_DIR;PROJECT_NAME;ACCESS_PREFIX" "OPS_SRC" ${ARGN}) + execute_process(COMMAND ${CMAKE_CXX_COMPILER} -g -fPIC -shared -std=c++11 ${OPBUILD_OPS_SRC} -D_GLIBCXX_USE_CXX11_ABI=0 + -I ${ASCEND_CANN_PACKAGE_PATH}/include -L ${ASCEND_CANN_PACKAGE_PATH}/lib64 -lexe_graph -lregister + -o ${OPBUILD_OUT_DIR}/libascend_all_ops.so + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR + ) + if (${EXEC_RESULT}) + message("build ops lib info: ${EXEC_INFO}") + message("build ops lib error: ${EXEC_ERROR}") + message(FATAL_ERROR "opbuild run failed!") + endif() + set(proj_env "") + set(prefix_env "") + if (NOT "${OPBUILD_PROJECT_NAME}x" STREQUAL "x") + set(proj_env "OPS_PROJECT_NAME=${OPBUILD_PROJECT_NAME}") + endif() + if (NOT "${OPBUILD_ACCESS_PREFIX}x" STREQUAL "x") + set(prefix_env "OPS_DIRECT_ACCESS_PREFIX=${OPBUILD_ACCESS_PREFIX}") + endif() + execute_process(COMMAND ${proj_env} ${prefix_env} ${ASCEND_CANN_PACKAGE_PATH}/toolkit/tools/opbuild/op_build + ${OPBUILD_OUT_DIR}/libascend_all_ops.so ${OPBUILD_OUT_DIR} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR + ) + if (${EXEC_RESULT}) + message("opbuild ops info: ${EXEC_INFO}") + message("opbuild ops error: ${EXEC_ERROR}") + endif() + message(STATUS "Opbuild generating sources - done") +endfunction() + +function(add_ops_info_target) + cmake_parse_arguments(OPINFO "" "TARGET;OPS_INFO;OUTPUT;INSTALL_DIR" "" ${ARGN}) + get_filename_component(opinfo_file_path "${OPINFO_OUTPUT}" DIRECTORY) + add_custom_command(OUTPUT ${OPINFO_OUTPUT} + COMMAND mkdir -p ${opinfo_file_path} + COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/parse_ini_to_json.py + ${OPINFO_OPS_INFO} ${OPINFO_OUTPUT} + ) + add_custom_target(${OPINFO_TARGET} ALL + DEPENDS ${OPINFO_OUTPUT} + ) + install(FILES ${OPINFO_OUTPUT} + DESTINATION 
${OPINFO_INSTALL_DIR} + ) +endfunction() + +function(add_ops_impl_target) + cmake_parse_arguments(OPIMPL "" "TARGET;OPS_INFO;IMPL_DIR;OUT_DIR;INSTALL_DIR" "OPS_BATCH;OPS_ITERATE" ${ARGN}) + add_custom_command(OUTPUT ${OPIMPL_OUT_DIR}/.impl_timestamp + COMMAND mkdir -p ${OPIMPL_OUT_DIR}/dynamic + COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_impl_build.py + ${OPIMPL_OPS_INFO} + \"${OPIMPL_OPS_BATCH}\" \"${OPIMPL_OPS_ITERATE}\" + ${OPIMPL_IMPL_DIR} + ${OPIMPL_OUT_DIR}/dynamic + COMMAND rm -rf ${OPIMPL_OUT_DIR}/.impl_timestamp + COMMAND touch ${OPIMPL_OUT_DIR}/.impl_timestamp + DEPENDS ${OPIMPL_OPS_INFO} + ${CMAKE_SOURCE_DIR}/cmake/util/tik2_impl_build.py + ) + add_custom_target(${OPIMPL_TARGET} ALL + DEPENDS ${OPIMPL_OUT_DIR}/.impl_timestamp) + if (${ENABLE_SOURCE_PACKAGE}) + install(DIRECTORY ${OPIMPL_OUT_DIR}/dynamic + DESTINATION ${OPIMPL_INSTALL_DIR} + ) + endif() +endfunction() + +function(add_ops_replay_targets) + cmake_parse_arguments(OPREPLAY "" "OPS_INFO;COMPUTE_UNIT;IMPL_DIR;OUT_DIR;INSTALL_DIR" "OPS_BATCH;OPS_ITERATE" ${ARGN}) + # ccec compile options + set(ccec_base_opts -c -O2 --cce-aicore-only -mllvm -cce-aicore-function-stack-size=16000 + -mllvm -cce-aicore-record-overflow=false -std=c++17) + set(ccec_extopts_ascend310p --cce-aicore-arch=dav-m200 -mllvm -cce-aicore-fp-ceiling=2) + set(ccec_extopts_ascend910 --cce-aicore-arch=dav-c100) + set(ccec_extopts_ascend910b --cce-aicore-arch=dav-c220-cube) + file(MAKE_DIRECTORY ${OPREPLAY_OUT_DIR}) + execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_replay_build.py + ${OPREPLAY_OPS_INFO} + "${OPREPLAY_OPS_BATCH}" "${OPREPLAY_OPS_ITERATE}" + ${OPREPLAY_IMPL_DIR} + ${OPREPLAY_OUT_DIR} + ${OPREPLAY_COMPUTE_UNIT} + ) + file(GLOB replay_kernel_entries ${OPREPLAY_OUT_DIR}/*.cce) + if (NOT "${replay_kernel_entries}x" STREQUAL "x") + foreach(replay_kernel_file ${replay_kernel_entries}) + get_filename_component(replay_kernel_file_name "${replay_kernel_file}" NAME) + string(REPLACE "_entry.cce" "" op_kerne_name ${replay_kernel_file_name}) + file(GLOB replay_lib_src ${OPREPLAY_OUT_DIR}/${op_kerne_name}*.cpp) + set(OP_TILING_DATA_H_PATH ${OPREPLAY_OUT_DIR}/${op_kerne_name}_tiling_data.h) + add_library(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} SHARED ${replay_lib_src}) + if(EXISTS ${OP_TILING_DATA_H_PATH}) + target_compile_options(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE + -include ${OP_TILING_DATA_H_PATH} + ) + endif() + target_compile_definitions(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE + ${op_kerne_name}=${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} + ) + target_link_libraries(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE intf_pub + tikreplaylib::${OPREPLAY_COMPUTE_UNIT} + register + ) + add_custom_command(OUTPUT ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o + COMMAND ccec ${ccec_base_opts} ${ccec_extopts_${OPREPLAY_COMPUTE_UNIT}} ${replay_kernel_file} + -o ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o + DEPENDS ${replay_kernel_file} + ) + add_custom_target(replay_kernel_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} ALL + DEPENDS ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o + ) + install(TARGETS replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay + ) + install(FILES ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay + ) + endforeach() + 
endif() +endfunction() + +function(add_npu_support_target) + cmake_parse_arguments(NPUSUP "" "TARGET;OPS_INFO_DIR;OUT_DIR;INSTALL_DIR" "" ${ARGN}) + get_filename_component(npu_sup_file_path "${NPUSUP_OUT_DIR}" DIRECTORY) + add_custom_command(OUTPUT ${NPUSUP_OUT_DIR}/npu_supported_ops.json + COMMAND mkdir -p ${NPUSUP_OUT_DIR} + COMMAND ${CMAKE_SOURCE_DIR}/cmake/util/gen_ops_filter.sh + ${NPUSUP_OPS_INFO_DIR} + ${NPUSUP_OUT_DIR} + ) + add_custom_target(npu_supported_ops ALL + DEPENDS ${NPUSUP_OUT_DIR}/npu_supported_ops.json + ) + install(FILES ${NPUSUP_OUT_DIR}/npu_supported_ops.json + DESTINATION ${NPUSUP_INSTALL_DIR} + ) +endfunction() + +function(add_bin_compile_target) + cmake_parse_arguments(BINCMP "" "TARGET;OPS_INFO;COMPUTE_UNIT;IMPL_DIR;ADP_DIR;OUT_DIR;INSTALL_DIR" "" ${ARGN}) + file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/src) + file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/bin) + file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/gen) + execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_bin_param_build.py + ${BINCMP_OPS_INFO} ${BINCMP_OUT_DIR}/gen ${BINCMP_COMPUTE_UNIT} + RESULT_VARIABLE EXEC_RESULT + OUTPUT_VARIABLE EXEC_INFO + ERROR_VARIABLE EXEC_ERROR + ) + if (${EXEC_RESULT}) + message("ops binary compile scripts gen info: ${EXEC_INFO}") + message("ops binary compile scripts gen error: ${EXEC_ERROR}") + message(FATAL_ERROR "ops binary compile scripts gen failed!") + endif() + if (NOT TARGET binary) + add_custom_target(binary) + endif() + add_custom_target(${BINCMP_TARGET} + COMMAND cp ${BINCMP_IMPL_DIR}/*.cpp ${BINCMP_OUT_DIR}/src + ) + add_custom_target(${BINCMP_TARGET}_gen_ops_config + COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/insert_simplified_keys.py -p ${BINCMP_OUT_DIR}/bin + COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_ops_config.py -p ${BINCMP_OUT_DIR}/bin + -s ${BINCMP_COMPUTE_UNIT} + ) + add_dependencies(binary ${BINCMP_TARGET}_gen_ops_config) + file(GLOB bin_scripts ${BINCMP_OUT_DIR}/gen/*.sh) + foreach(bin_script ${bin_scripts}) + get_filename_component(bin_file ${bin_script} NAME_WE) + string(REPLACE "-" ";" bin_sep ${bin_file}) + list(GET bin_sep 0 op_type) + list(GET bin_sep 1 op_file) + list(GET bin_sep 2 op_index) + if (NOT TARGET ${BINCMP_TARGET}_${op_file}_copy) + file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/bin/${op_file}) + add_custom_target(${BINCMP_TARGET}_${op_file}_copy + COMMAND cp ${BINCMP_ADP_DIR}/${op_file}.py ${BINCMP_OUT_DIR}/src/${op_type}.py + ) + install(DIRECTORY ${BINCMP_OUT_DIR}/bin/${op_file} + DESTINATION ${BINCMP_INSTALL_DIR}/${BINCMP_COMPUTE_UNIT} OPTIONAL + ) + install(FILES ${BINCMP_OUT_DIR}/bin/${op_file}.json + DESTINATION ${BINCMP_INSTALL_DIR}/config/${BINCMP_COMPUTE_UNIT}/ OPTIONAL + ) + endif() + add_custom_target(${BINCMP_TARGET}_${op_file}_${op_index} + COMMAND bash ${bin_script} ${BINCMP_OUT_DIR}/src/${op_type}.py ${BINCMP_OUT_DIR}/bin/${op_file} + WORKING_DIRECTORY ${BINCMP_OUT_DIR} + ) + add_dependencies(${BINCMP_TARGET}_${op_file}_${op_index} ${BINCMP_TARGET} ${BINCMP_TARGET}_${op_file}_copy) + add_dependencies(${BINCMP_TARGET}_gen_ops_config ${BINCMP_TARGET}_${op_file}_${op_index}) + endforeach() + install(FILES ${BINCMP_OUT_DIR}/bin/binary_info_config.json + DESTINATION ${BINCMP_INSTALL_DIR}/config/${BINCMP_COMPUTE_UNIT} OPTIONAL + ) +endfunction() diff --git a/cust_op/cust_op_by_addr/cmake/intf.cmake b/cust_op/cust_op_by_addr/cmake/intf.cmake new file mode 100644 index 00000000..1c54b6ea --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/intf.cmake @@ -0,0 +1,25 @@ + +add_library(intf_pub INTERFACE) 
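+# Note: intf_pub is an INTERFACE library, so everything declared below is
+# propagated to targets that link against it rather than compiled here.
+# A minimal consumer sketch (the target name is hypothetical):
+#     add_library(my_op_host SHARED my_op_host.cpp)
+#     target_link_libraries(my_op_host PRIVATE intf_pub)
+# my_op_host then inherits -fPIC, the hardening flags and the CANN include/lib paths.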
+target_compile_options(intf_pub INTERFACE
+    -fPIC
+    -fvisibility=hidden
+    -fvisibility-inlines-hidden
+    $<$<CONFIG:Release>:-O2 -s>
+    $<$<CONFIG:Debug>:-O0 -g>
+    $<$<COMPILE_LANGUAGE:CXX>:-std=c++11>
+    $<$<OR:$<CONFIG:Release>,$<CONFIG:Debug>>:-ftrapv -fstack-check>
+    $<$<COMPILE_LANGUAGE:CXX>:-pthread -Wfloat-equal -Wshadow -Wformat=2 -Wno-deprecated -Wextra>
+    $<IF:$<VERSION_GREATER:${CMAKE_C_COMPILER_VERSION},4.8.5>,-fstack-protector-strong,-fstack-protector-all>
+)
+target_compile_definitions(intf_pub INTERFACE
+    _GLIBCXX_USE_CXX11_ABI=0
+    $<$<CONFIG:Release>:_FORTIFY_SOURCE=2>
+)
+target_include_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/include)
+target_link_options(intf_pub INTERFACE
+    $<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-pie>
+    -Wl,-z,relro
+    -Wl,-z,now
+    -Wl,-z,noexecstack
+)
+target_link_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/lib64)
diff --git a/cust_op/cust_op_by_addr/cmake/makeself.cmake b/cust_op/cust_op_by_addr/cmake/makeself.cmake
new file mode 100644
index 00000000..18bdc331
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/makeself.cmake
@@ -0,0 +1,14 @@
+execute_process(COMMAND chmod +x ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself.sh)
+execute_process(COMMAND ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself.sh --gzip --complevel 4 --nomd5 --sha256
+    ./ ${CPACK_PACKAGE_FILE_NAME} "version:1.0" ./install.sh
+    WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY}
+    RESULT_VARIABLE EXEC_RESULT
+    ERROR_VARIABLE EXEC_ERROR
+)
+if (NOT "${EXEC_RESULT}x" STREQUAL "0x")
+    message(FATAL_ERROR "CPack Command error: ${EXEC_RESULT}\n${EXEC_ERROR}")
+endif()
+execute_process(COMMAND cp ${CPACK_EXTERNAL_BUILT_PACKAGES} ${CPACK_PACKAGE_DIRECTORY}/
+    COMMAND echo "Copy ${CPACK_EXTERNAL_BUILT_PACKAGES} to ${CPACK_PACKAGE_DIRECTORY}/"
+    WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY}
+)
diff --git a/cust_op/cust_op_by_addr/cmake/util/__init__.py b/cust_op/cust_op_by_addr/cmake/util/__init__.py
new file mode 100644
index 00000000..c4ddc893
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import sys
+import os
+
+PYF_PATH = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(PYF_PATH)
diff --git a/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp b/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp
new file mode 100644
index 00000000..7b4f5edf
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp
@@ -0,0 +1,117 @@
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/stat.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstdint>
+#include <ctime>
+#include "replay_def.h"
+#include "code_gen.h"
+#include "replay_fun.h"
+#include "register/op_check.h"
+#define __TIK2_REPLAY_CODE__
+#include <__CCE_FILE__>
+
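+// Illustrative note on how this template appears to be consumed by the replay
+// codegen scripts (replay_codegen.py / tik2_replay_build.py in cmake/util):
+// the __KERNEL_FUN__, __ARGS_DEF__, __KERNEL_ARGS__, __ARG_NUM__, __CORE_TYPE__,
+// __OPTYPE__ and __OPS_PRODUCT__ placeholders are substituted per kernel before
+// compilation. For a hypothetical kernel add_custom built for ascend910, the
+// entry point below would become add_custom_replay_ascend910(...).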
+using namespace std;
+using namespace optiling;
+using namespace tik2_replay;
+
+extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *);
+extern "C" int elf_batch_append(char *elf, uint32_t elfSize, char *jit, int kernum, char *atext[], int alen[],
+    int atlen, const char* kernelname[]);
+
+#define KERNEL_N 1
+#define ARG_N (__ARG_NUM__)
+#define MAX_L (1024 * 1024 * 100)
+#define MAX_E (1024 * 1024)
+
+int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type)
+{
+    // gen type 1 : direct call codes 0: load .o file
+    if (param.gentype < 0 || param.gentype > 1) {
+        printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype);
+        return 0;
+    } else if (param.gentype == 1 && param.objptr == nullptr) {
+        printf("Error: call replay with direct call mode, but code obj addr is null\n");
+        return 0;
+    } else if (param.gentype == 0 && param.output_kernel_file == nullptr) {
+        printf("Error: call replay with object file mode, but object file path is null\n");
+        return 0;
+    }
+    // core_type 0:MIX 1:CUBE 2:VEC
+    if (core_type < 0 || core_type > 2) {
+        printf("Error: call replay core type is %d !\n", core_type);
+        return 0;
+    }
+    g_coreType = __CORE_TYPE__;
+    g_taskRation = param.task_ration;
+    g_tilingKey = param.tiling_key;
+
+    unsigned char *buf, *jit;
+    char *kernel[KERNEL_N];
+    int len[KERNEL_N];
+    block_idx = 0;
+    block_num = param.block_dim;
+    g_ubBase = block_num;
+    uint8_t *code = (uint8_t *)malloc(MAX_L);
+    uint8_t *pos = code;
+    struct timespec tp1, tp2;
+
+    clock_gettime(CLOCK_MONOTONIC, &tp1);
+    if (block_num > 32) {
+        printf("Error: block_num > 32\n");
+        return 0;
+    }
+    //__OP_FOPEN__
+    for (int i = 0; i < KERNEL_N; i++) {
+        //__OP_SET_KERNEL__
+        for (int j = 0; j < ARG_N; j++)
+            AddArg(j, ARG_STEP * (j + 1));
+#ifdef FP_CEILING
+        SetCtrlFloatEnable();
+#else
+        SetCtrlFloatDisable();
+#endif
+        CodeInit(pos, true);
+        __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data);
+        CodeEnd();
+        kernel[i] = (char *)pos;
+        len[i] = CodeLen();
+        pos += len[i];
+    }
+    //__OP_FCLOSE__
+    clock_gettime(CLOCK_MONOTONIC, &tp2);
+    buf = (unsigned char *)malloc(MAX_E);
+    int fd = open(param.entry_file, O_RDONLY);
+    if (fd < 0) {
+        printf("[error]: cannot find entry.o : %s\n", param.entry_file);
+        return 0;
+    }
+    uint32_t bufSize = read(fd, buf, MAX_E);
+    if (bufSize <= 0) {
+        printf("[error]: entry.o : %s is too small ! \n", param.entry_file);
+    }
+    close(fd);
+    jit = (unsigned char *)malloc(MAX_L);
+    printf("total code generated %ld\n", pos - code);
+    int sz = elf_batch_append((char *)buf, bufSize, (char *)jit, KERNEL_N, kernel, len, pos - code, &param.kernel_name);
+    if (tp1.tv_sec != tp2.tv_sec) {
+        printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec);
+    } else {
+        printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec);
+    }
+    printf("new elf size %d\n", sz);
+    if (param.gentype == 0) {
+        fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR);
+        (void)write(fd, jit, sz);
+        close(fd);
+        free(jit);
+    } else if (param.gentype == 1) {
+        *param.objptr = (char*)jit;
+    }
+    free(buf);
+    free(code);
+    return sz;
+}
+
+REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__);
diff --git a/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py b/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py
new file mode 100644
index 00000000..49ce5e52
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+"""
+import os
+import stat
+import ctypes
+import collections
+import shutil
+import subprocess
+import copy
+
+"""CODE_* indicates which cube/vector APIs are called in the operator code.
+CODE_MIX means both cube and vector APIs are called
+CODE_CUBE means only cube APIs are called
+CODE_VEC means only vector APIs are called
+"""
+CODE_MIX = 0
+CODE_CUBE = 1
+CODE_VEC = 2
+
+
+def _is_v220(op_product: str):
+    """return whether the current SoC version is V220
+
+    Returns:
+        res: True means V220
+    """
+    if op_product in ["ascend910b", "ascend910c"]:
+        return True
+    return False
+
+CheckCoreTypeParams = collections.namedtuple('CheckCoreTypeParams',\
+['src_file', 'arch', 'kernel_name', 'compile_options', 'addrspace_list', 'outdir'])
+
+
+def _check_core_type(check_core_type_params: CheckCoreTypeParams):
+    """1. call ccec -S -emit-llvm to generate the LLVM IR file
+       2. analyze the addrspace usage to check whether a cube or vector buffer scope exists
+
+    Args:
+        CheckCoreTypeParams:
+            src_file (str): TIK2 operator code file
+            arch (str): target AI Core architecture passed to --cce-aicore-arch (e.g. dav-c220-cube)
+            kernel_name (str): kernel function name
+            compile_options (list): compile options for the ccec cmd
+            addrspace_list (list): addrspaces of cube or vector buffers
+            outdir (str): temp file output directory
+
+    Returns:
+        res (bool): True if the target addrspace exists for the given arch
+    """
+    llvm_ir_file = os.path.join(check_core_type_params.outdir, check_core_type_params.kernel_name + "_" +\
+        check_core_type_params.arch + ".ll")
+    compile_cmd = [shutil.which("ccec"), '-S', '-emit-llvm', '-std=c++17', '-x', 'cce',\
+        check_core_type_params.src_file]
+    compile_cmd += check_core_type_params.compile_options
+
+    compile_cmd += ["--cce-aicore-arch=%s" % check_core_type_params.arch,
+                    "--cce-aicore-only", "-o", llvm_ir_file]
+    proc = subprocess.Popen(
+        compile_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+    if proc.returncode != 0:
+        msg = "compile %s error :%s\n" % (check_core_type_params.src_file, out.decode())
+        print("check core type ", msg)
+        return False
+
+    def _check_exist_space(line_content, space_list):
+        if "addrspace" not in line_content:
+            return False
+        for space in space_list:
+            if space in line_content:
+                return True
+        return False
+
+    access_space = False
+    with open(llvm_ir_file) as llvm_ir:
+        line_list = llvm_ir.readlines()
+        for line in line_list:
+            if access_space:
+                break
+            access_space = _check_exist_space(line, check_core_type_params.addrspace_list)
+    os.remove(llvm_ir_file)
+    return access_space
+
+
+InfoCodeChanelParams = collections.namedtuple('InfoCodeChanelParams',\
+['src_file', 'tiling_header', 'kernel_name', 'outdir', 'op_product', 'compile_options'])
+
+
+def infer_code_channel(params: InfoCodeChanelParams):
+    """get the code channel for V220; returns CODE_MIX if the soc version is not V220
+
+    Args:
+        src_file (str): TIK2 operator code file
+        tiling_header (str): TIK2 operator tiling header file
+        kernel_name (str): kernel function name
+        op_product (str): target compute unit (soc version)
+        compile_options (list): compile options for the ccec cmd
+
+    Raises:
+        Exception: if no L1/L0/UB addrspace is found in the code, it is not an AI Core kernel
+
+    Returns:
+        res (int): CODE_MIX/CODE_CUBE/CODE_VEC
+    """
+    if not _is_v220(params.op_product):
+        return CODE_MIX
+    if params.compile_options is None:
+        compile_options = []
+    else:
+        compile_options = params.compile_options
+    ccec = shutil.which("ccec")
+    if ccec is not None:
+        ccec_path = os.path.dirname(ccec)
+        tikcpp_path = os.path.realpath(os.path.join(ccec_path, "..", "..", "tikcpp"))
+    else:
+        tikcpp_path = os.path.realpath("/usr/local/Ascend/latest/compiler/tikcpp")
+    compile_options.append("-I" + tikcpp_path)
+    compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw"))
+    compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl"))
+    compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface"))
+    compile_options.append("-D__NPU_TILING__")
+    compile_options += ["-include", params.tiling_header]
+    cube_addrspace_list = ["addrspace(2)", "addrspace(3)", "addrspace(4)", "addrspace(5)"]
+    access_l1_l0a = _check_core_type(CheckCoreTypeParams(params.src_file, "dav-c220-cube", params.kernel_name,\
+        compile_options, cube_addrspace_list, params.outdir))
+    vec_addrspace_list = ["addrspace(6)"]
+    access_ub = _check_core_type(CheckCoreTypeParams(params.src_file, "dav-c220-vec", params.kernel_name,\
+        compile_options, vec_addrspace_list, params.outdir))
+
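+    # Classification sketch: addrspace(2)-(5) appear to correspond to the
+    # cube-side buffers (L1/L0A/L0B/L0C) probed with the dav-c220-cube arch,
+    # and addrspace(6) to the vector Unified Buffer probed with dav-c220-vec.
+    # A kernel hitting both is MIX; e.g. a purely elementwise kernel would only
+    # touch addrspace(6) and be classified as CODE_VEC below.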
+    if access_l1_l0a and access_ub:
+        return CODE_MIX
+    elif access_l1_l0a:
+        return CODE_CUBE
+    elif access_ub:
+        return CODE_VEC
+    else:
+        raise Exception(f"cannot find valid addrspace in (2,3,4,5,6) in {params.src_file}")
diff --git a/cust_op/cust_op_by_addr/cmake/util/const_var.py b/cust_op/cust_op_by_addr/cmake/util/const_var.py
new file mode 100644
index 00000000..f5dde656
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/const_var.py
@@ -0,0 +1,32 @@
+
+#!/usr/bin/env python
+# coding=utf-8
+"""
+Function:
+The replay function entry
+Copyright Information:
+Huawei Technologies Co., Ltd. All Rights Reserved © 2020
+"""
+
+import os
+import stat
+
+
+REPLAY_BATCH = 'batch'
+REPLAY_ITERATE = 'iterate'
+CFG_IMPL_DIR = 'impl_dir'
+CFG_OUT_DIR = 'out_dir'
+WFLAGS = os.O_WRONLY | os.O_CREAT
+WMODES = stat.S_IWUSR | stat.S_IRUSR
+SOC_MAP_EXT = {'ascend310p': 'Ascend310P3', 'ascend310b': 'Ascend310B1',
+               'ascend910': 'Ascend910A', 'ascend910b': 'Ascend910B1'}
+BIN_CMD = 'opc $1 --main_func={fun} --input_param={param} --soc_version={soc} \
+--output=$2 --impl_mode={impl} --op_mode=dynamic\n'
+CHK_CMD = '''
+if ! test -f $2/{res_file} ; then
+  echo "$2/{res_file} not generated!"
+  exit 1
+fi
+'''
+ATTR_DEF_VAL = {'str' : '', 'int': 0, 'float': 0.0, 'bool': False, 'list_bool': [],
+                'list_int': [], 'list_float': [], 'list_list_int': [[]]}
diff --git a/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh b/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh
new file mode 100644
index 00000000..55e12e5e
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh
@@ -0,0 +1,57 @@
+#!/usr/bin/bash
+
+project_path=$1
+build_path=$2
+vendor_name=customize
+if [[ ! -d "$project_path" ]]; then
+    echo "[ERROR] No project path is provided"
+    exit 1
+fi
+
+if [[ ! -d "$build_path" ]]; then
+    echo "[ERROR] No build path is provided"
+    exit 1
+fi
+
+# copy ai_core operator implementations
+# tbe_impl_files_num=$(ls $project_path/tbe/impl/* 2> /dev/null | wc -l)
+# if [[ "$tbe_impl_files_num" -gt 0 ]];then
+#     cp -r ${project_path}/tbe/impl/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/ai_core/tbe/customize_impl
+#     cp -r ${project_path}/tbe/impl/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/vector_core/tbe/customize_impl
+# fi
+
+# copy aicpu kernel so operators
+if [[ -d "${project_path}/cpukernel/aicpu_kernel_lib" ]]; then
+    cp -f ${project_path}/cpukernel/aicpu_kernel_lib/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/cpu/aicpu_kernel/impl
+    rm -rf ${project_path}/cpukernel/aicpu_kernel_lib
+fi
+
+# merge aicpu.ini and aicore.ini to generate npu_supported_ops.json
+# mkdir -p ${build_path}/framework/op_info_cfg
+# mkdir -p ${build_path}/framework/op_info_cfg/aicpu_kernel
+# mkdir -p ${build_path}/framework/op_info_cfg/ai_core
+
+# if [[ -d "${project_path}/tbe/op_info_cfg/ai_core" ]]; then
+#     bash ${project_path}/cmake/util/gen_ops_filter.sh ${project_path}/tbe/op_info_cfg/ai_core ${build_path}/framework/op_info_cfg/ai_core
+# fi
+
+# if [[ -d "${project_path}/cpukernel/op_info_cfg/aicpu_kernel" ]]; then
+#     bash ${project_path}/cmake/util/gen_ops_filter.sh ${project_path}/cpukernel/op_info_cfg/aicpu_kernel ${build_path}/framework/op_info_cfg/aicpu_kernel
+# fi
+
+# aicpu_filter_file=${build_path}/framework/op_info_cfg/aicpu_kernel/npu_supported_ops.json
+# aicore_filter_file=${build_path}/framework/op_info_cfg/ai_core/npu_supported_ops.json
+# if [[ -f "${aicpu_filter_file}" ]] && [[ ! -f "${aicore_filter_file}" ]]; then
-f "${aicore_filter_file}" ]]; then +# cp $aicpu_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow +# fi +# if [[ -f "${aicore_filter_file}" ]] && [[ ! -f "${aicpu_filter_file}" ]]; then +# cp $aicore_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow +# fi + +# if [[ -f "${aicore_filter_file}" ]] && [[ -f "${aicpu_filter_file}" ]]; then +# chmod u+w ${aicpu_filter_file} +# python3 ${project_path}/cmake/util/insert_op_info.py ${aicore_filter_file} ${aicpu_filter_file} +# chmod u-w ${aicpu_filter_file} +# cp $aicpu_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow +# fi + diff --git a/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh b/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh new file mode 100644 index 00000000..54c7c640 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. +# Description: Generate npu_supported_ops.json +# ============================================================================== + +if [[ -z "$1" ]]; then + echo -e "[ERROR] No source dir provided" + exit 1 +fi + +if [[ -z "$2" ]]; then + echo -e "[ERROR] No destination dir provided" + exit 1 +fi + +src=$1 +dest_file=$2/npu_supported_ops.json + +if [ -f "$dest_file" ];then + chmod u+w $dest_file +fi + +echo $* + +add_ops() { + name=$1 + isHeavy=$2 + file=$3 + grep -w "\"$name\"" ${file} >/dev/null + if [ $? == 0 ];then + return + fi + echo " \"${name}\": {" >> ${file} + echo " \"isGray\": false," >> ${file} + echo " \"isHeavy\": ${isHeavy}" >> ${file} + echo " }," >> ${file} +} + +echo "{" > ${dest_file} +ini_files=$(find ${src} -name "*.ini") +for file in ${ini_files} ; do + name=$(grep '^\[' ${file} | sed 's/\[//g' | sed 's/]//g' | sed 's/\r//g') + grep 'heavyOp.flag' ${file} >/dev/null + if [ $? == 0 ];then + isHeavy=$(grep 'heavyOp.flag' ${file} | awk -F= '{print $2}') + else + isHeavy="false" + fi + for op in ${name}; do + add_ops ${op} ${isHeavy} ${dest_file} + done +done +echo "}" >> ${dest_file} +file_count=$(cat ${dest_file} | wc -l) +line=$(($file_count-1)) +sed -i "${line}{s/,//g}" ${dest_file} + +chmod 640 "${dest_file}" +echo -e "[INFO] Succed generated ${dest_file}" + +exit 0 diff --git a/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py b/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py new file mode 100644 index 00000000..28ba0875 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +""" +Created on Feb 28 20:56:45 2020 +Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
+""" +import json +import os +import sys +import stat +import const_var + + +if __name__ == '__main__': + if len(sys.argv) != 3: + print(sys.argv) + print('argv error, inert_op_info.py your_op_file lib_op_file') + sys.exit(2) + + with open(sys.argv[1], 'r') as load_f: + insert_operator = json.load(load_f) + + all_operators = {} + if os.path.exists(sys.argv[2]): + if os.path.getsize(sys.argv[2]) != 0: + with open(sys.argv[2], 'r') as load_f: + all_operators = json.load(load_f) + + for k in insert_operator.keys(): + if k in all_operators.keys(): + print('replace op:[', k, '] success') + else: + print('insert op:[', k, '] success') + all_operators[k] = insert_operator[k] + + with os.fdopen(os.open(sys.argv[2], const_var.WFLAGS, const_var.WMODES), 'w') as json_file: + json_file.write(json.dumps(all_operators, indent=4)) diff --git a/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py b/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py new file mode 100644 index 00000000..19c3820f --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +Created on Feb 28 20:56:45 2020 +Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. +""" + +import sys +import os +import re +import glob +import json +import argparse +import const_var + + +DATA_TPYE_DICT = { + 'float32': 0, + 'float16': 1, + 'int8': 2, + 'int16': 6, + 'uint16': 7, + 'uint8': 4, + 'int32': 3, + 'int64': 9, + 'uint32': 8, + 'uint64': 10, + 'bool': 12, + 'double': 11, + 'string': 13, + 'dual': 14, + 'dual': 15, + 'complex64': 16, + 'complex128': 17, + 'qint8': 18, + 'qint16': 19, + 'qint32': 20, + 'quint8': 21, + 'quint16': 22, + 'resource': 23, + 'string': 24, + 'dual': 25, + 'variant': 26, + 'bf16': 27, + 'bfloat16': 27, + 'undefined': 28, + 'int4': 29, + 'uint1': 30, + 'int2': 31 +} + +FORMAT_DICT = { + 'NCHW': 0, + 'NHWC': 1, + 'ND': 2, + 'NC1HWC0': 3, + 'FRACTAL_Z': 4, + 'NC1C0HWPAD': 5, + 'NHWC1C0': 6, + 'FSR_NCHW': 7, + 'FRACTAL_DECONV': 8, + 'C1HWNC0': 9, + 'FRACTAL_DECONV_TRANSPOSE': 10, + 'FRACTAL_DECONV_SP_STRIDE_TRANS': 11, + 'NC1HWC0_C04': 12, + 'FRACTAL_Z_C04': 13, + 'CHWN': 14, + 'FRACTAL_DECONV_SP_STRIDE8_TRANS': 15, + 'HWCN': 16, + 'NC1KHKWHWC0': 17, + 'BN_WEIGHT': 18, + 'FILTER_HWCK': 19, + 'HASHTABLE_LOOKUP_LOOKUPS': 20, + 'HASHTABLE_LOOKUP_KEYS': 21, + 'HASHTABLE_LOOKUP_VALUE': 22, + 'HASHTABLE_LOOKUP_OUTPUT': 23, + 'HASHTABLE_LOOKUP_HITS': 24, + 'C1HWNCoC0': 25, + 'MD': 26, + 'NDHWC': 27, + 'FRACTAL_ZZ': 28, + 'FRACTAL_NZ': 29, + 'NCDHW': 30, + 'DHWCN': 31, + 'NDC1HWC0': 32, + 'FRACTAL_Z_3D': 33, + 'CN': 34, + 'NC': 35, + 'DHWNC': 36, + 'FRACTAL_Z_3D_TRANSPOSE': 37, + 'FRACTAL_ZN_LSTM': 38, + 'FRACTAL_Z_G': 39, + 'RESERVED': 40, + 'ALL': 41, + 'NULL': 42, + 'ND_RNN_BIAS': 43, + 'FRACTAL_ZN_RNN': 44, + 'NYUV': 45, + 'NYUV_A': 46 +} + + +def load_json(json_file: str): + with open(json_file, encoding='utf-8') as file: + json_content = json.load(file) + return json_content + + +def get_specified_suffix_file(root_dir, suffix): + specified_suffix = os.path.join(root_dir, '**/*.{}'.format(suffix)) + all_suffix_files = glob.glob(specified_suffix, recursive=True) + return all_suffix_files + + +def get_deterministic_value(support_info): + deterministic_key = 'deterministic' + if deterministic_key not in support_info: + return 0 + deterministic_value = support_info.get(deterministic_key) + if deterministic_value == 'true': + return 1 + else: + return 0 + + +def get_precision_value(support_info): + 
+def get_precision_value(support_info):
+    precision_key = 'implMode'
+    precision_value = support_info.get(precision_key)
+    if precision_value == 'high_performance':
+        _value = 1
+    elif precision_value == 'high_precision':
+        _value = 2
+    else:
+        _value = 0
+    return _value
+
+
+def get_overflow_value(support_info):
+    return 0
+
+
+def get_parameters(info):
+    if info:
+        if 'dtype' in info:
+            data_type = info['dtype']
+            data_type_value = DATA_TYPE_DICT.get(data_type)
+        else:
+            data_type_value = 0
+        if 'format' in info:
+            _format = info['format']
+            _format_value = FORMAT_DICT.get(_format)
+        else:
+            _format_value = 0
+    else:
+        data_type_value = 0
+        _format_value = 0
+    return str(data_type_value), str(_format_value)
+
+
+def get_dynamic_parameters(info):
+    # For dynamic inputs, only the first parameter needs to be read
+    return get_parameters(info[0])
+
+
+def get_all_parameters(support_info, _type):
+    result_list = list()
+    info_lists = support_info.get(_type)
+    if info_lists:
+        for _info in info_lists:
+            # A list entry marks a dynamic input
+            if isinstance(_info, list):
+                data_type_value, _format_value = get_dynamic_parameters(_info)
+            else:
+                data_type_value, _format_value = get_parameters(_info)
+            result_list.append("{},{}".format(data_type_value, _format_value))
+    return result_list
+
+
+def get_all_input_parameters(support_info):
+    result = get_all_parameters(support_info, 'inputs')
+    return '/'.join(result)
+
+
+def insert_content_into_file(input_file, content):
+    with open(input_file, 'r+') as file:
+        lines = file.readlines()
+        for index, line in enumerate(lines):
+            match_result = re.search(r'"staticKey":', line)
+            if match_result:
+                count = len(line) - len(line.lstrip())
+                new_content = "{}{}".format(' ' * count, content)
+                # Insert before the matched line to avoid trailing-comma handling at the end
+                lines.insert(index, new_content)
+                break
+        file.seek(0)
+        file.write(''.join(lines))
+
+
+def insert_simplified_keys(json_file):
+    contents = load_json(json_file)
+    # JSON files without a 'binFileName' or 'supportInfo' field are not the ones to patch
+    if ('binFileName' not in contents) or ('supportInfo' not in contents):
+        return
+    support_info = contents.get('supportInfo')
+    bin_file_name = contents.get('binFileName')
+    bin_suffix = contents.get('binFileSuffix')
+    # If 'simplifiedKey' already exists, return without generating it again
+    if 'simplifiedKey' in support_info:
+        return
+    op_type = bin_file_name.split('_')[0]
+    deterministic = str(get_deterministic_value(support_info))
+    precision = str(get_precision_value(support_info))
+    overflow = str(get_overflow_value(support_info))
+    input_parameters = get_all_input_parameters(support_info)
+    key = '{}/d={},p={},o={}/{}/'.format(
+        op_type,
+        deterministic,
+        precision,
+        overflow,
+        input_parameters)
+    result = '"simplifiedKey": "' + key + '",\n'
+    insert_content_into_file(json_file, result)
+
+
+def insert_all_simplified_keys(root_dir):
+    suffix = 'json'
+    all_json_files = get_specified_suffix_file(root_dir, suffix)
+    for _json in all_json_files:
+        insert_simplified_keys(_json)
+
+
+def args_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-p',
+                        '--path',
+                        nargs='?',
+                        required=True,
+                        help='Parse the path of the json file.')
+    return parser.parse_args()
+
+
+def main():
+    args = args_parse()
+    insert_all_simplified_keys(args.path)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py b/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py
new file mode 100644
index 00000000..2b77c970
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies 
Co., Ltd. 2020-2021. All rights reserved. +""" + + +def gen_fun_def(title, kernel, argn, arg_type, arg_name): + entry = [] + entry.append(title) + entry.append(kernel) + entry.append('(') + args = [] + for i in range(0, argn): + args.append(arg_type + ' ' + arg_name + str(i)) + entry.append(', '.join(args)) + entry.append(')') + return ' '.join(entry) + + +def gen_batch_kernel_body(fname, argn, arg_name): + body = [] + body.append('{') + fun = [] + fun.append(fname) + fun.append('(') + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + fun.append(', '.join(args)) + fun.append(');') + body.append(' '.join(fun)) + body.append('}') + return '\n'.join(body) + + +def gen_mc_kernel_body(kn, argn, arg_name, blknum): + body = [] + body.append('{') + body.append(' switch(block_idx) {') + for blk in range(0, blknum): + fun = [] + fun.append('{}_blk{:02d}'.format(kn, blk)) + fun.append('(') + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + fun.append(', '.join(args)) + fun.append(')') + body.append(' case {}: {}; break;'.format(blk, ' '.join(fun))) + body.append(' default: break;') + body.append(' }') + body.append('}') + return '\n'.join(body) + + +def gen_proc_body(argn, arg_name): + body = [] + body.append('{') + args = [] + for i in range(0, argn): + args.append(arg_name + str(i)) + body.append('uint64_t __x = (uint64_t)' + ' + (uint64_t)'.join(args) + ';') + body.append('__asm__ ("NOP");') + body.append('__asm__ ("NOP");') + body.append('__asm__ ("NOP");') + body.append('}') + return '\n'.join(body) + + +def batch_code_gen(kn, argn, argt): + codes = [] + kernel_name = kn + proc_name = kernel_name + '_percore' + arg_num = int(argn) + data_type = argt + arg_type = '__gm__ ' + data_type + '* __restrict__' + arg_name = 'arg' + kernel_title = 'extern \"C\" __global__ __aicore__ void' + proc_title = 'extern \"C\" __attribute__((noinline)) __aicore__ void' + codes.append('#ifndef __aicore__') + codes.append('#define __aicore__ [aicore]') + codes.append('#endif') + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ';') + codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) + codes.append(gen_batch_kernel_body(proc_name, arg_num, arg_name)) + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) + codes.append(gen_proc_body(arg_num, arg_name)) + return '\n'.join(codes) + '\n' + + +def mc_code_gen(kn, argn, argt, blknum): + codes = [] + kernel_name = kn + core_num = int(blknum) + arg_num = int(argn) + data_type = argt + arg_type = '__gm__ ' + data_type + '* __restrict__' + arg_name = 'arg' + kernel_title = 'extern \"C\" __global__ __aicore__ void' + proc_title = 'extern \"C\" __attribute__((noinline)) __aicore__ void' + codes.append('#ifndef __aicore__') + codes.append('#define __aicore__ [aicore]') + codes.append('#endif') + for i in range(0, core_num): + proc_name = '{}_blk{:02d}'.format(kernel_name, i) + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ';') + codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) + codes.append(gen_mc_kernel_body(kernel_name, arg_num, arg_name, core_num)) + for i in range(0, core_num): + proc_name = '{}_blk{:02d}'.format(kernel_name, i) + codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) + codes.append(gen_proc_body(arg_num, arg_name)) + return '\n'.join(codes) + '\n' diff --git a/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp 
b/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp new file mode 100644 index 00000000..7391befa --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp @@ -0,0 +1,10 @@ +#include +#include +#include +#include +#include +#include "replay_def.h" +#include "code_gen.h" +#include "replay_fun.h" +#define __TIK2_REPLAY_CODE__ +#include "__CCE_FILE__" diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING b/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING new file mode 100644 index 00000000..d159169d --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING @@ -0,0 +1,339 @@ + GNU GENERAL PUBLIC LICENSE + Version 2, June 1991 + + Copyright (C) 1989, 1991 Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The licenses for most software are designed to take away your +freedom to share and change it. By contrast, the GNU General Public +License is intended to guarantee your freedom to share and change free +software--to make sure the software is free for all its users. This +General Public License applies to most of the Free Software +Foundation's software and to any other program whose authors commit to +using it. (Some other Free Software Foundation software is covered by +the GNU Lesser General Public License instead.) You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +this service if you wish), that you receive source code or can get it +if you want it, that you can change the software or use pieces of it +in new free programs; and that you know you can do these things. + + To protect your rights, we need to make restrictions that forbid +anyone to deny you these rights or to ask you to surrender the rights. +These restrictions translate to certain responsibilities for you if you +distribute copies of the software, or if you modify it. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must give the recipients all the rights that +you have. You must make sure that they, too, receive or can get the +source code. And you must show them these terms so they know their +rights. + + We protect your rights with two steps: (1) copyright the software, and +(2) offer you this license which gives you legal permission to copy, +distribute and/or modify the software. + + Also, for each author's protection and ours, we want to make certain +that everyone understands that there is no warranty for this free +software. If the software is modified by someone else and passed on, we +want its recipients to know that what they have is not the original, so +that any problems introduced by others will not reflect on the original +authors' reputations. + + Finally, any free program is threatened constantly by software +patents. We wish to avoid the danger that redistributors of a free +program will individually obtain patent licenses, in effect making the +program proprietary. To prevent this, we have made it clear that any +patent must be licensed for everyone's free use or not licensed at all. + + The precise terms and conditions for copying, distribution and +modification follow. 
+ + GNU GENERAL PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. This License applies to any program or other work which contains +a notice placed by the copyright holder saying it may be distributed +under the terms of this General Public License. The "Program", below, +refers to any such program or work, and a "work based on the Program" +means either the Program or any derivative work under copyright law: +that is to say, a work containing the Program or a portion of it, +either verbatim or with modifications and/or translated into another +language. (Hereinafter, translation is included without limitation in +the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of +running the Program is not restricted, and the output from the Program +is covered only if its contents constitute a work based on the +Program (independent of having been made by running the Program). +Whether that is true depends on what the Program does. + + 1. You may copy and distribute verbatim copies of the Program's +source code as you receive it, in any medium, provided that you +conspicuously and appropriately publish on each copy an appropriate +copyright notice and disclaimer of warranty; keep intact all the +notices that refer to this License and to the absence of any warranty; +and give any other recipients of the Program a copy of this License +along with the Program. + +You may charge a fee for the physical act of transferring a copy, and +you may at your option offer warranty protection in exchange for a fee. + + 2. You may modify your copy or copies of the Program or any portion +of it, thus forming a work based on the Program, and copy and +distribute such modifications or work under the terms of Section 1 +above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices + stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in + whole or in part contains or is derived from the Program or any + part thereof, to be licensed as a whole at no charge to all third + parties under the terms of this License. + + c) If the modified program normally reads commands interactively + when run, you must cause it, when started running for such + interactive use in the most ordinary way, to print or display an + announcement including an appropriate copyright notice and a + notice that there is no warranty (or else, saying that you provide + a warranty) and that users may redistribute the program under + these conditions, and telling the user how to view a copy of this + License. (Exception: if the Program itself is interactive but + does not normally print such an announcement, your work based on + the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Program, +and can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. 
But when you +distribute the same sections as part of a whole which is a work based +on the Program, the distribution of the whole must be on the terms of +this License, whose permissions for other licensees extend to the +entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program +with the Program (or with a work based on the Program) on a volume of +a storage or distribution medium does not bring the other work under +the scope of this License. + + 3. You may copy and distribute the Program (or a work based on it, +under Section 2) in object code or executable form under the terms of +Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable + source code, which must be distributed under the terms of Sections + 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three + years, to give any third party, for a charge no more than your + cost of physically performing source distribution, a complete + machine-readable copy of the corresponding source code, to be + distributed under the terms of Sections 1 and 2 above on a medium + customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer + to distribute corresponding source code. (This alternative is + allowed only for noncommercial distribution and only if you + received the program in object code or executable form with such + an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for +making modifications to it. For an executable work, complete source +code means all the source code for all modules it contains, plus any +associated interface definition files, plus the scripts used to +control compilation and installation of the executable. However, as a +special exception, the source code distributed need not include +anything that is normally distributed (in either source or binary +form) with the major components (compiler, kernel, and so on) of the +operating system on which the executable runs, unless that component +itself accompanies the executable. + +If distribution of executable or object code is made by offering +access to copy from a designated place, then offering equivalent +access to copy the source code from the same place counts as +distribution of the source code, even though third parties are not +compelled to copy the source along with the object code. + + 4. You may not copy, modify, sublicense, or distribute the Program +except as expressly provided under this License. Any attempt +otherwise to copy, modify, sublicense or distribute the Program is +void, and will automatically terminate your rights under this License. +However, parties who have received copies, or rights, from you under +this License will not have their licenses terminated so long as such +parties remain in full compliance. + + 5. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Program or its derivative works. 
These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License to do so, and +all its terms and conditions for copying, distributing or modifying +the Program or works based on it. + + 6. Each time you redistribute the Program (or any work based on the +Program), the recipient automatically receives a license from the +original licensor to copy, distribute or modify the Program subject to +these terms and conditions. You may not impose any further +restrictions on the recipients' exercise of the rights granted herein. +You are not responsible for enforcing compliance by third parties to +this License. + + 7. If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot +distribute so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you +may not distribute the Program at all. For example, if a patent +license would not permit royalty-free redistribution of the Program by +all those who receive copies directly or indirectly through you, then +the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under +any particular circumstance, the balance of the section is intended to +apply and the section as a whole is intended to apply in other +circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system, which is +implemented by public license practices. Many people have made +generous contributions to the wide range of software distributed +through that system in reliance on consistent application of that +system; it is up to the author/donor to decide if he or she is willing +to distribute software through any other system and a licensee cannot +impose that choice. + +This section is intended to make thoroughly clear what is believed to +be a consequence of the rest of this License. + + 8. If the distribution and/or use of the Program is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Program under this License +may add an explicit geographical distribution limitation excluding +those countries, so that distribution is permitted only in or among +countries not thus excluded. In such case, this License incorporates +the limitation as if written in the body of this License. + + 9. The Free Software Foundation may publish revised and/or new versions +of the General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + +Each version is given a distinguishing version number. 
If the Program +specifies a version number of this License which applies to it and "any +later version", you have the option of following the terms and conditions +either of that version or of any later version published by the Free +Software Foundation. If the Program does not specify a version number of +this License, you may choose any version ever published by the Free Software +Foundation. + + 10. If you wish to incorporate parts of the Program into other free +programs whose distribution conditions are different, write to the author +to ask for permission. For software which is copyrighted by the Free +Software Foundation, write to the Free Software Foundation; we sometimes +make exceptions for this. Our decision will be guided by the two goals +of preserving the free status of all derivatives of our free software and +of promoting the sharing and reuse of software generally. + + NO WARRANTY + + 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY +FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN +OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES +PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED +OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS +TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE +PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, +REPAIR OR CORRECTION. + + 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR +REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING +OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED +TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY +YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER +PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGES. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +convey the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +Also add information on how to contact you by electronic and paper mail. 
+ +If the program is interactive, make it output a short notice like this +when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, the commands you use may +be called something other than `show w' and `show c'; they could even be +mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the program + `Gnomovision' (which makes passes at compilers) written by James Hacker. + + , 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into +proprietary programs. If your program is a subroutine library, you may +consider it more useful to permit linking proprietary applications with the +library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/README.md b/cust_op/cust_op_by_addr/cmake/util/makeself/README.md new file mode 100644 index 00000000..b41f0168 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/README.md @@ -0,0 +1,246 @@ +[![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html) +![Build Status](https://github.com/megastep/makeself/workflows/CI/badge.svg) + +# makeself - Make self-extractable archives on Unix + +[makeself.sh][1] is a small shell script that generates a self-extractable +compressed tar archive from a directory. The resulting file appears as a shell script +(many of those have a **.run** suffix), and can be launched as is. The archive +will then uncompress itself to a temporary directory and an optional arbitrary +command will be executed (for example an installation script). This is pretty +similar to archives generated with WinZip Self-Extractor in the Windows world. +Makeself archives also include checksums for integrity self-validation (CRC +and/or MD5/SHA256 checksums). + +The makeself.sh script itself is used only to create the archives from a +directory of files. The resultant archive is actually a compressed (using +gzip, bzip2, or compress) TAR archive, with a small shell script stub at the +beginning. This small stub performs all the steps of extracting the files, +running the embedded command, and removing the temporary files when done. +All the user has to do to install the software contained in such an +archive is to "run" the archive, i.e **sh nice-software.run**. I recommend +using the ".run" (which was introduced by some Makeself archives released by +Loki Software) or ".sh" suffix for such archives not to confuse the users, +so that they will know they are actually shell scripts (with quite a lot of binary data +attached to them though!). + +I am trying to keep the code of this script as portable as possible, i.e it is +not relying on any bash-specific features and only calls commands that are +installed on any functioning UNIX-compatible system. 
This script as well as +the archives it generates should run on any Unix flavor, with any compatible +Bourne shell, provided of course that the compression programs are available. + +As of version 2.1, Makeself has been rewritten and tested on the following +platforms : + + * Linux (all distributions) + * Sun Solaris (8 and above) + * HP-UX (tested on 11.0 and 11i on HPPA RISC) + * SCO OpenUnix and OpenServer + * IBM AIX 5.1L + * macOS (Darwin) + * SGI IRIX 6.5 + * FreeBSD + * UnicOS / Cray + * Cygwin (Windows) + +If you successfully run Makeself and/or archives created with it on another +system, then please [let me know][2]! + +Examples of publicly available archives made using makeself are : + + * Game patches and installers for [Id Software][3] games like Quake 3 for Linux or Return To Castle Wolfenstein ; + * All game patches released by [Loki Software][4] for the Linux version of popular games ; + * The [nVidia drivers][5] for Linux + * The installer for the Linux version of [Google Earth][6] + * The [VirtualBox][7] installers for Linux + * The [Makeself][1] distribution itself ;-) + * and countless others... + +**Important note for Apache users:** By default, most Web servers will think that Makeself archives are regular text files and thus they may show up as text in a Web browser. The correct way to prevent this is to add a MIME type for this file format, like so (in httpd.conf) : + +`AddType application/x-makeself .run` + +**Important note for certain GNU/Linux distributions:** Archives created with Makeself prior to v2.1.2 were using an old syntax for the _head_ and _tail_ Unix commands that is being progressively obsoleted in their GNU forms. Therefore you may have problems uncompressing some of these archives. A workaround for this is to set the environment variable $_POSIX2_VERSION to enable the old syntax, i.e. : + +`export _POSIX2_VERSION=199209` + +## Usage + +The syntax of makeself is the following: + +``` +makeself.sh [args] archive_dir file_name label startup_script [script_args] +``` + + * _args_ are optional options for Makeself. The available ones are : + + * **`--version`** : Prints the version number on stdout, then exits immediately + * **`--gzip`** : Use gzip for compression (the default on platforms on which gzip is commonly available, like Linux) + * **`--bzip2`** : Use bzip2 instead of gzip for better compression. The bzip2 command must be available in the command path. It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. + * **`--pbzip2`** : Use pbzip2 instead of gzip for better and faster compression on machines having multiple CPUs. The pbzip2 command must be available in the command path. It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. + * **`--xz`** : Use xz instead of gzip for better compression. The xz command must be available in the command path. It is recommended that the archive prefix be set to something like '.xz.run' for the archive, so that potential users know that they'll need xz to extract it. + * **`--lzo`** : Use lzop instead of gzip for better compression. The lzop command must be available in the command path. It is recommended that the archive prefix be set to something like `.lzo.run` for the archive, so that potential users know that they'll need lzop to extract it. + * **`--lz4`** : Use lz4 instead of gzip for better compression. 
The lz4 command must be available in the command path. It is recommended that the archive prefix be set to something like '.lz4.run' for the archive, so that potential users know that they'll need lz4 to extract it. + * **`--zstd`** : Use zstd instead of gzip for better compression. The zstd command must be available in the command path. It is recommended that the archive prefix be set to something like '.zstd.run' for the archive, so that potential users know that they'll need zstd to extract it. + * **`--pigz`** : Use pigz for compression. + * **`--base64`** : Encode the archive to ASCII in Base64 format instead of compressing (base64 command required). + * **`--gpg-encrypt`** : Encrypt the archive using `gpg -ac -z $COMPRESS_LEVEL`. This will prompt for a password to encrypt with. Assumes that potential users have `gpg` installed. + * **`--ssl-encrypt`** : Encrypt the archive using `openssl aes-256-cbc -a -salt`. This will prompt for a password to encrypt with. Assumes that the potential users have the OpenSSL tools installed. + * **`--compress`** : Use the UNIX `compress` command to compress the data. This should be the default on all platforms that don't have gzip available. + * **`--nocomp`** : Do not use any compression for the archive, which will then be an uncompressed TAR. + * **`--complevel`** : Specify the compression level for gzip, bzip2, pbzip2, zstd, xz, lzo or lz4. (defaults to 9) + * **`--threads`** : Specify the number of threads to be used by compressors that support parallelization. Omit to use compressor's default. Most useful (and required) for opting into xz's threading, usually with `--threads=0` for all available cores. pbzip2 and pigz are parallel by default, and setting this value allows limiting the number of threads they use. + * **`--notemp`** : The generated archive will not extract the files to a temporary directory, but in a new directory created in the current directory. This is better to distribute software packages that may extract and compile by themselves (i.e. launch the compilation through the embedded script). + * **`--current`** : Files will be extracted to the current directory, instead of in a subdirectory. This option implies `--notemp` above. + * **`--follow`** : Follow the symbolic links inside of the archive directory, i.e. store the files that are being pointed to instead of the links themselves. + * **`--append`** _(new in 2.1.x)_: Append data to an existing archive, instead of creating a new one. In this mode, the settings from the original archive are reused (compression type, label, embedded script), and thus don't need to be specified again on the command line. + * **`--header`** : Makeself uses a separate file to store the header stub, called `makeself-header.sh`. By default, it is assumed that it is stored in the same location as makeself.sh. This option can be used to specify its actual location if it is stored someplace else. + * **`--cleanup`** : Specify a script that is run when execution is interrupted or finishes successfully. The script is executed with the same environment and initial `script_args` as `startup_script`. + * **`--copy`** : Upon extraction, the archive will first extract itself to a temporary directory. The main application of this is to allow self-contained installers stored in a Makeself archive on a CD, when the installer program will later need to unmount the CD and allow a new one to be inserted. This prevents "Filesystem busy" errors for installers that span multiple CDs. 
+ * **`--nox11`** : Disable the automatic spawning of a new terminal in X11.
+ * **`--nowait`** : When executed from a new X11 terminal, disable the user prompt at the end of the script execution.
+ * **`--nomd5`** and **`--nocrc`** : Disable the creation of an MD5 / CRC checksum for the archive. This speeds up the extraction process if integrity checking is not necessary.
+ * **`--sha256`** : Adds a SHA256 checksum for the archive. This is in addition to the MD5 / CRC checksums unless `--nomd5` is also used.
+ * **`--lsm` _file_** : Provide an LSM file to makeself, which will be embedded in the generated archive. LSM files describe a software package in a way that is easily parseable. The LSM entry can then be later retrieved using the `--lsm` argument to the archive. An example of an LSM file is provided with Makeself.
+ * **`--tar-format opt`** : Specify the tar archive format (default is ustar); you may use any value accepted by your tar command (such as posix, v7, etc.).
+ * **`--tar-extra opt`** : Append more options to the tar command line.
+
+ For instance, in order to exclude the `.git` directory from the packaged archive directory using GNU `tar`, one can use `makeself.sh --tar-extra "--exclude=.git" ...`
+
+ * **`--keep-umask`** : Keep the umask set to the shell default, rather than overriding it when executing the self-extracting archive.
+ * **`--packaging-date date`** : Use the provided string as the packaging date instead of the current date.
+ * **`--license`** : Append a license file.
+ * **`--nooverwrite`** : Do not extract the archive if the specified target directory already exists.
+ * **`--help-header file`** : Add a header to the archive's `--help` output.
+ * `archive_dir` is the name of the directory that contains the files to be archived
+ * `file_name` is the name of the archive to be created
+ * `label` is an arbitrary text string describing the package. It will be displayed while extracting the files.
+ * `startup_script` is the command to be executed _from within_ the directory of extracted files. Thus, if you wish to execute a program contained in this directory, you must prefix your command with `./`. For example, `./program` will be fine. The `script_args` are additional arguments for this command.
+
+Here is an example, assuming the user has a package image stored in **/home/joe/mysoft**, and he wants to generate a self-extracting package named
+**mysoft.sh**, which will launch the "setup" script initially stored in /home/joe/mysoft:
+
+`makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup`
+
+Here is also how I created the [makeself.run][9] archive which contains the Makeself distribution:
+
+`makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself"`
+
+Archives generated with Makeself can be passed the following arguments:
+
+ * **`--keep`** : Prevent the files from being extracted to a temporary directory that is removed after the embedded script's execution. The files will then be extracted in the current working directory and will stay there until you remove them.
+ * **`--verbose`** : Will prompt the user before executing the embedded command.
+ * **`--target dir`** : Allows extracting the archive to an arbitrary location.
+ * **`--nox11`** : Do not spawn an X11 terminal.
+ * **`--confirm`** : Prompt the user for confirmation before running the embedded command.
+ * **`--info`** : Print out general information about the archive (does not extract).
+ * **`--lsm`** : Print out the LSM entry, if it is present.
+ * **`--list`** : List the files in the archive.
+ * **`--check`** : Check the archive for integrity using the embedded checksums. Does not extract the archive.
+ * **`--nochown`** : By default, a `chown -R` command is run on the target directory after extraction, so that all files belong to the current user. This is mostly needed if you are running as root, as tar will then try to recreate the initial user ownerships. You may disable this behavior with this flag.
+ * **`--tar`** : Run the tar command on the contents of the archive, using the following arguments as parameters for the command.
+ * **`--noexec`** : Do not run the embedded script after extraction.
+ * **`--noexec-cleanup`** : Do not run the embedded cleanup script.
+ * **`--nodiskspace`** : Do not check for available disk space before attempting to extract.
+ * **`--cleanup-args`** : Specify arguments to be passed to the cleanup script. Wrap the value in quotes to specify multiple arguments.
+
+Any subsequent arguments to the archive will be passed as additional arguments to the embedded command. You must explicitly use the `--` special command-line construct before any such options to make sure that Makeself will not try to interpret them.
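+For instance, with a hypothetical archive named `mysoft.sh` (a placeholder, not shipped with Makeself), the inspection options from the list above can be combined as follows:
+
+```
+# Verify the embedded checksums, list the contents, then extract the
+# files without running the embedded startup script:
+sh mysoft.sh --check
+sh mysoft.sh --list
+sh mysoft.sh --noexec
+```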
+## Startup Script
+
+The startup script must be a regular Shell script.
+
+Within the startup script, you can use the `$USER_PWD` variable to get the path of the folder from which the self-extracting script is executed. This is especially useful to access files that are located in the same folder as the script, as shown in the example below.
+
+`my-self-extracting-script.sh --fooBarFileParameter foo.bar`
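+To make this more concrete, here is a minimal startup-script sketch for the hypothetical invocation above (`foo.bar` and `--fooBarFileParameter` are placeholder names; this script is not part of the Makeself distribution):
+
+```
+#!/bin/sh
+# Runs from inside the extracted directory, while $USER_PWD still points
+# to the directory the archive was launched from.
+if [ "$1" = "--fooBarFileParameter" ] && [ -n "$2" ]; then
+    # Fetch the parameter file sitting next to the archive, not next to us.
+    cp "$USER_PWD/$2" .
+fi
+```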
+## Building and Testing
+
+Clone the git repo and execute `git submodule update --init --recursive` to obtain all submodules.
+
+* To make a release: `make`
+* To run all tests: `make test`
+
+## Maven Usage
+
+Makeself is now supported by the [makeself-maven-plugin](https://github.com/hazendaz/makeself-maven-plugin) Maven plugin. Please refer to that project for usage details, and report any bugs regarding the Maven plugin there.
+
+## License
+
+Makeself itself is covered by the [GNU General Public License][8] (GPL) version 2 and above. Archives generated by Makeself don't have to be placed under this license (although I encourage it ;-)), since the archive itself is merely data for Makeself.
+
+## Contributing
+
+I will gladly consider merging your pull requests on the [GitHub][10] repository. However, please keep the following in mind:
+
+ * One of the main purposes of Makeself is portability. Do not submit patches that will break supported platforms. The more platform-agnostic, the better.
+ * Please explain clearly what the purpose of the patch is, and how you achieved it.
+
+## Download
+
+Get the latest official distribution [here][9] (version 2.4.2).
+
+The latest development version can be grabbed from [GitHub][10]. Feel free to submit any patches there through the fork and pull request process.
+
+## Version history
+
+ * **v1.0:** Initial public release
+ * **v1.1:** The archive can be passed parameters that will be passed on to the embedded script, thanks to John C. Quillan
+ * **v1.2:** Cosmetic updates, support for bzip2 compression and non-temporary archives. Many ideas thanks to Francois Petitjean.
+ * **v1.3:** More patches from Bjarni R. Einarsson and Francois Petitjean: Support for no compression (`--nocomp`), script is no longer mandatory, automatic launch in an xterm, optional verbose output, and -target archive option to indicate where to extract the files.
+ * **v1.4:** Many patches from Francois Petitjean: improved UNIX compatibility, automatic integrity checking, support of LSM files to get info on the package at run time.
+ * **v1.5.x:** A lot of bugfixes, and many other patches, including automatic verification through the usage of checksums. Version 1.5.5 was the stable release for a long time, even though the Web page didn't get updated ;-). Makeself was also officially made a part of the [Loki Setup installer][11], and its source is being maintained as part of this package.
+ * **v2.0:** Complete internal rewrite of Makeself. The command-line parsing was vastly improved, and the overall maintenance of the package was greatly improved by separating the stub from makeself.sh. Makeself was also ported to and tested on a variety of Unix platforms.
+ * **v2.0.1:** First public release of the new 2.0 branch. Prior versions are officially obsoleted. This release introduced the `--copy` argument, in response to a need for the [UT2K3][12] Linux installer.
+ * **v2.1.0:** Big change: Makeself can now support multiple embedded tarballs, each stored separately with their own checksums. An existing archive can be updated with the `--append` flag. Checksums are also better managed, and the `--nochown` option for archives appeared.
+ * **v2.1.1:** Fixes related to the Unix compression (compress command). Some Linux distributions made the insane choice to make it unavailable, even though gzip is capable of uncompressing these files, plus some more bugfixes in the extraction and checksum code.
+ * **v2.1.2:** Some bug fixes. Use head -n to avoid problems with POSIX conformance.
+ * **v2.1.3:** Bug fixes with the command line when spawning terminals. Added `--tar`, `--noexec` for archives. Added `--nomd5` and `--nocrc` to avoid creating checksums in archives. The embedded script is now run through "eval". The `--info` output now includes the command used to create the archive. A man page was contributed by Bartosz Fenski.
+ * **v2.1.4:** Fixed `--info` output. Generate a random directory name when extracting files to . to avoid problems. Better handling of errors with wrong permissions for the directory containing the files. Avoid some race conditions. Unset the $CDPATH variable to avoid problems if it is set. Better handling of dot files in the archive directory.
+ * **v2.1.5:** Made the md5sum detection consistent with the header code. Check for the presence of the archive directory. Added `--encrypt` for symmetric encryption through gpg (Eric Windisch). Added support for the digest command on Solaris 10 for MD5 checksums. Check for available disk space before extracting to the target directory (Andreas Schweitzer). Allow extraction to run asynchronously (patch by Peter Hatch). Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo).
+ * **v2.1.6:** Replaced the one-dot-per-file progress with a realtime progress percentage and a spinning cursor. Added `--noprogress` to prevent showing the progress during the decompression. Added `--target` dir to allow extracting directly to a target directory. (Guy Baconniere)
+ * **v2.2.0:** First major new release in years! Includes many bugfixes and user contributions. Please look at the [project page on Github][10] for all the details.
+ * **v2.3.0:** Support for archive encryption via GPG or OpenSSL. Added LZO and LZ4 compression support. Options to set the packaging date and stop the umask from being overridden. Optionally ignore the check for available disk space when extracting. New option to check for root permissions before extracting.
+ * **v2.3.1:** Various compatibility updates. Added unit tests for Travis CI in the GitHub repo. New `--tar-extra`, `--untar-extra`, `--gpg-extra`, `--gpg-asymmetric-encrypt-sign` options.
+ * **v2.4.0:** Added optional support for SHA256 archive integrity checksums.
+ * **v2.4.2:** New --cleanup and --cleanup-args arguments for cleanup scripts. Added threading support for supported compressors. Now supports zstd compression.
+ * **v2.4.3:** Make explicit POSIX tar archives for increased compatibility.
+ * **v2.4.4:** Fixed various compatibility issues (no longer use POSIX tar archives), added GitHub Actions to check on Solaris and FreeBSD.
+ * **v2.4.5:** Added the `--tar-format` option to set the tar archive format (default is ustar).
+
+## Links
+
+ * Check out the ["Loki Setup"][11] installer, used to install many Linux games and other applications, and of which I am the co-author. Since the demise of Loki, I am now the official maintainer of the project, and it is now being hosted here on GitHub.
+ * Bjarni R. Einarsson also wrote the **setup.sh** installer script, inspired by Makeself. [Check it out!][14]
+
+## Contact
+
+This script was written by [Stéphane Peter][15] (megastep at megastep.org). Any enhancements and suggestions are welcome.
+
+Contributions were included from John C. Quillan, Bjarni R. Einarsson,
+Francois Petitjean, Ryan C. Gordon, and many contributors on GitHub. If you think I forgot
+your name, don't hesitate to contact me.
+
+This project is now hosted on GitHub. Feel free to submit patches and bug reports on the [project page][10].
+ +* * * + +[Stephane Peter][2] + + [1]: http://makeself.io/ + [2]: mailto:megastep@megastep.org + [3]: http://www.idsoftware.com/ + [4]: http://www.lokigames.com/products/myth2/updates.php3 + [5]: http://www.nvidia.com/ + [6]: http://earth.google.com/ + [7]: http://www.virtualbox.org/ + [8]: http://www.gnu.org/copyleft/gpl.html + [9]: https://github.com/megastep/makeself/releases/download/release-2.4.5/makeself-2.4.5.run + [10]: https://github.com/megastep/makeself + [11]: https://github.com/megastep/loki_setup/ + [12]: http://www.unrealtournament2003.com/ + [13]: http://www.icculus.org/ + [14]: http://bre.klaki.net/programs/setup.sh/ + [15]: https://stephanepeter.com/ diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION b/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION new file mode 100644 index 00000000..59aa62c1 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION @@ -0,0 +1 @@ +2.4.5 diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh new file mode 100644 index 00000000..b5692d49 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh @@ -0,0 +1,9 @@ +#!/bin/sh +# +# Create a distributable archive of the current version of Makeself + +VER=`cat VERSION` +mkdir -p /tmp/makeself-$VER release +cp -pPR makeself* test README.md COPYING VERSION .gitmodules /tmp/makeself-$VER/ +./makeself.sh --notemp /tmp/makeself-$VER release/makeself-$VER.run "Makeself v$VER" echo "Makeself has extracted itself" + diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh new file mode 100644 index 00000000..94090314 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh @@ -0,0 +1,660 @@ +cat << EOF > "$archname" +#!/bin/bash +# This script was generated using Makeself $MS_VERSION +# The license covering this archive and its contents, if any, is wholly independent of the Makeself license (GPL) +# 2022.3.19-Modified the MS_Help function and some options +# Huawei Technologies Co., Ltd. + +ORIG_UMASK=\`umask\` + +CRCsum="$CRCsum" +MD5="$MD5sum" +SHA="$SHAsum" +SIGNATURE="$Signature" +TMPROOT=\${TMPDIR:="\$HOME"} +if ! test -d "\$TMPROOT"; then + TMPROOT="\$PWD" +fi +export TMPDIR="\$TMPROOT" +USER_PWD="\$PWD" +if ! 
test -d "\$USER_PWD"; then + exit 1 +fi +export USER_PWD +ARCHIVE_DIR=\`dirname "\$0"\` +export ARCHIVE_DIR + +name_of_file="\$0 " +pwd_of_file="\$PWD" +label="$LABEL" +script="$SCRIPT" +scriptargs="$SCRIPTARGS" +cleanup_script="${CLEANUP_SCRIPT}" +licensetxt="$LICENSE" +helpheader='$HELPHEADER' +targetdir="$archdirname" +filesizes="$filesizes" +totalsize="$totalsize" +keep="$KEEP" +nooverwrite="$NOOVERWRITE" +quiet="n" +accept="n" +nodiskspace="n" +export_conf="$EXPORT_CONF" +decrypt_cmd="$DECRYPT_CMD" +skip="$SKIP" + +print_cmd_arg="" +if type printf > /dev/null; then + print_cmd="printf" +elif test -x /usr/ucb/echo; then + print_cmd="/usr/ucb/echo" +else + print_cmd="echo" +fi + +if test -d /usr/xpg4/bin; then + PATH=/usr/xpg4/bin:\$PATH + export PATH +fi + +if test -d /usr/sfw/bin; then + PATH=\$PATH:/usr/sfw/bin + export PATH +fi + +unset CDPATH + +MS_Printf() +{ + \$print_cmd \$print_cmd_arg "\$1" +} + +MS_PrintLicense() +{ + PAGER=\${PAGER:=more} + if test x"\$licensetxt" != x; then + PAGER_PATH=\`exec <&- 2>&-; which \$PAGER || command -v \$PAGER || type \$PAGER\` + if test -x "\$PAGER_PATH"; then + echo "\$licensetxt" | \$PAGER + else + echo "\$licensetxt" + fi + if test x"\$accept" != xy; then + while true + do + MS_Printf "Please type y to accept, n otherwise: " + read yn + if test x"\$yn" = xn; then + keep=n + eval \$finish; exit 1 + break; + elif test x"\$yn" = xy; then + break; + fi + done + fi + fi +} + +MS_diskspace() +{ + ( + df -kP "\$1" | tail -1 | awk '{ if (\$4 ~ /%/) {print \$3} else {print \$4} }' + ) +} + +MS_dd() +{ + blocks=\`expr \$3 / 1024\` + bytes=\`expr \$3 % 1024\` + # Test for ibs, obs and conv feature + if dd if=/dev/zero of=/dev/null count=1 ibs=512 obs=512 conv=sync 2> /dev/null; then + dd if="\$1" ibs=\$2 skip=1 obs=1024 conv=sync 2> /dev/null | \\ + { test \$blocks -gt 0 && dd ibs=1024 obs=1024 count=\$blocks ; \\ + test \$bytes -gt 0 && dd ibs=1 obs=1024 count=\$bytes ; } 2> /dev/null + else + dd if="\$1" bs=\$2 skip=1 2> /dev/null + fi +} + +MS_dd_Progress() +{ + if test x"\$noprogress" = xy; then + MS_dd "\$@" + return \$? + fi + file="\$1" + offset=\$2 + length=\$3 + pos=0 + bsize=4194304 + while test \$bsize -gt \$length; do + bsize=\`expr \$bsize / 4\` + done + blocks=\`expr \$length / \$bsize\` + bytes=\`expr \$length % \$bsize\` + ( + dd ibs=\$offset skip=1 2>/dev/null + pos=\`expr \$pos \+ \$bsize\` + MS_Printf " 0%% " 1>&2 + if test \$blocks -gt 0; then + while test \$pos -le \$length; do + dd bs=\$bsize count=1 2>/dev/null + pcent=\`expr \$length / 100\` + pcent=\`expr \$pos / \$pcent\` + if test \$pcent -lt 100; then + MS_Printf "\b\b\b\b\b\b\b" 1>&2 + if test \$pcent -lt 10; then + MS_Printf " \$pcent%% " 1>&2 + else + MS_Printf " \$pcent%% " 1>&2 + fi + fi + pos=\`expr \$pos \+ \$bsize\` + done + fi + if test \$bytes -gt 0; then + dd bs=\$bytes count=1 2>/dev/null + fi + MS_Printf "\b\b\b\b\b\b\b" 1>&2 + MS_Printf " 100%% " 1>&2 + ) < "\$file" +} + +MS_Help() +{ + cat << EOH >&2 +Usage: \$0 [options] +Options: + --help | -h Print this message + --info Print embedded info : title, default target directory, embedded script ... + --list Print the list of files in the archive + --check Checks integrity and version dependency of the archive + --quiet Quiet install mode, skip human-computer interactions + --nox11 Do not spawn an xterm + --noexec Do not run embedded script + --extract= Extract directly to a target directory (absolute or relative) + Usually used with --noexec to just extract files without running + --tar arg1 [arg2 ...] 
Access the contents of the archive through the tar command +\${helpheader} +EOH +} + +MS_Verify_Sig() +{ + GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` + MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` + test -x "\$GPG_PATH" || GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` + test -x "\$MKTEMP_PATH" || MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` + offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` + temp_sig=\`mktemp -t XXXXX\` + echo \$SIGNATURE | base64 --decode > "\$temp_sig" + gpg_output=\`MS_dd "\$1" \$offset \$totalsize | LC_ALL=C "\$GPG_PATH" --verify "\$temp_sig" - 2>&1\` + gpg_res=\$? + rm -f "\$temp_sig" + if test \$gpg_res -eq 0 && test \`echo \$gpg_output | grep -c Good\` -eq 1; then + if test \`echo \$gpg_output | grep -c \$sig_key\` -eq 1; then + test x"\$quiet" = xn && echo "GPG signature is good" >&2 + else + echo "GPG Signature key does not match" >&2 + exit 2 + fi + else + test x"\$quiet" = xn && echo "GPG signature failed to verify" >&2 + exit 2 + fi +} + +MS_Check() +{ + OLD_PATH="\$PATH" + PATH=\${GUESS_MD5_PATH:-"\$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} + MD5_ARG="" + MD5_PATH=\`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum\` + test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which md5 || command -v md5 || type md5\` + test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which digest || command -v digest || type digest\` + PATH="\$OLD_PATH" + + SHA_PATH=\`exec <&- 2>&-; which shasum || command -v shasum || type shasum\` + test -x "\$SHA_PATH" || SHA_PATH=\`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum\` + + if test x"\$quiet" = xn; then + MS_Printf "Verifying archive integrity..." + fi + offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` + fsize=\`cat "\$1" | wc -c | tr -d " "\` + if test \$totalsize -ne \`expr \$fsize - \$offset\`; then + echo " Unexpected archive size." >&2 + exit 2 + fi + verb=\$2 + i=1 + for s in \$filesizes + do + crc=\`echo \$CRCsum | cut -d" " -f\$i\` + if test -x "\$SHA_PATH"; then + if test x"\`basename \$SHA_PATH\`" = xshasum; then + SHA_ARG="-a 256" + fi + sha=\`echo \$SHA | cut -d" " -f\$i\` + if test x"\$sha" = x0000000000000000000000000000000000000000000000000000000000000000; then + test x"\$verb" = xy && echo " \$1 does not contain an embedded SHA256 checksum." >&2 + else + shasum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$SHA_PATH \$SHA_ARG" | cut -b-64\`; + if test x"\$shasum" != x"\$sha"; then + echo "Error in SHA256 checksums: \$shasum is different from \$sha" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " SHA256 checksums are OK." >&2 + fi + crc="0000000000"; + fi + fi + if test -x "\$MD5_PATH"; then + if test x"\`basename \$MD5_PATH\`" = xdigest; then + MD5_ARG="-a md5" + fi + md5=\`echo \$MD5 | cut -d" " -f\$i\` + if test x"\$md5" = x00000000000000000000000000000000; then + test x"\$verb" = xy && echo " \$1 does not contain an embedded MD5 checksum." >&2 + else + md5sum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$MD5_PATH \$MD5_ARG" | cut -b-32\`; + if test x"\$md5sum" != x"\$md5"; then + echo "Error in MD5 checksums: \$md5sum is different from \$md5" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " MD5 checksums are OK." >&2 + fi + crc="0000000000"; verb=n + fi + fi + if test x"\$crc" = x0000000000; then + test x"\$verb" = xy && echo " \$1 does not contain a CRC checksum." 
>&2 + else + sum1=\`MS_dd_Progress "\$1" \$offset \$s | CMD_ENV=xpg4 cksum | awk '{print \$1}'\` + if test x"\$sum1" != x"\$crc"; then + echo "Error in checksums: \$sum1 is different from \$crc" >&2 + exit 2 + elif test x"\$quiet" = xn; then + MS_Printf " CRC checksums are OK." >&2 + fi + fi + i=\`expr \$i + 1\` + offset=\`expr \$offset + \$s\` + done + if test x"\$quiet" = xn; then + echo " All good." + fi +} + +MS_Decompress() +{ + if test x"\$decrypt_cmd" != x""; then + { eval "\$decrypt_cmd" || echo " ... Decryption failed." >&2; } | eval "$GUNZIP_CMD" + else + eval "$GUNZIP_CMD" + fi + + if test \$? -ne 0; then + echo " ... Decompression failed." >&2 + fi +} + +UnTAR() +{ + if test x"\$quiet" = xn; then + tar \$1vf - $UNTAR_EXTRA 2>&1 || { echo " ... Extraction failed." >&2; kill -15 \$$; } + else + tar \$1f - $UNTAR_EXTRA 2>&1 || { echo Extraction failed. >&2; kill -15 \$$; } + fi +} + +MS_exec_cleanup() { + if test x"\$cleanup" = xy && test x"\$cleanup_script" != x""; then + cleanup=n + cd "\$tmpdir" + eval "\"\$cleanup_script\" \$scriptargs \$cleanupargs" + fi +} + +MS_cleanup() +{ + echo 'Signal caught, cleaning up' >&2 + MS_exec_cleanup + cd "\$TMPROOT" + rm -rf "\$tmpdir" + eval \$finish; exit 15 +} + +Script_Args_Check() +{ + script_supported_args=\$(echo \${helpheader} | grep -o -E "\-\-[^ ]+" | awk -F"=" {'print \$1'}) + arg_to_test=\$(echo \$1|awk -F"=" {'print \$1'}) + + for arg in \${script_supported_args}; + do + if test x"\$arg_to_test" = x"\$arg" ;then + return + fi + done + + MS_Help + exit 1 +} + +finish=true +xterm_loop= +noprogress=$NOPROGRESS +nox11=$NOX11 +copy=$COPY +ownership=$OWNERSHIP +verbose=n +cleanup=y +cleanupargs= +sig_key= + +initargs="\$@" + +while [ -n "\$*" ] +do + case "\$1" in + -h | --help) + MS_Help + exit 0 + ;; + -q | --quiet) + quiet=y + noprogress=y + shift + ;; + --info) + echo Identification: "\$label" + echo Target directory: "\$targetdir" + echo Uncompressed size: $USIZE KB + echo Compression: $COMPRESS + if test x"$ENCRYPT" != x""; then + echo Encryption: $ENCRYPT + fi + echo Date of packaging: $DATE + echo Built with Makeself version $MS_VERSION + echo Build command was: "$MS_COMMAND" + if test x"\$script" != x; then + echo Script run after extraction: + echo " " \$script \$scriptargs + fi + if test x"$copy" = xcopy; then + echo "Archive will copy itself to a temporary location" + fi + if test x"$NEED_ROOT" = xy; then + echo "Root permissions required for extraction" + fi + if test x"$KEEP" = xy; then + echo "directory \$targetdir is permanent" + else + echo "\$targetdir will be removed after extraction" + fi + exit 0 + ;; + --list) + echo Target directory: \$targetdir + offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + for s in \$filesizes + do + MS_dd "\$0" \$offset \$s | MS_Decompress | UnTAR t + offset=\`expr \$offset + \$s\` + done + exit 0 + ;; + --tar) + offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + arg1="\$2" + shift 2 || { MS_Help; exit 1; } + for s in \$filesizes + do + MS_dd "\$0" \$offset \$s | MS_Decompress | tar "\$arg1" - "\$@" + offset=\`expr \$offset + \$s\` + done + exit 0 + ;; + --check) + MS_Check "\$0" y + scriptargs="\$scriptargs \$1" + shift + ;; + --noexec) + script="" + cleanup_script="" + shift + ;; + --extract=*) + keep=y + targetdir=\`echo \$1 | cut -d"=" -f2 \` + if ! 
shift; then MS_Help; exit 1; fi + ;; + --nox11) + nox11=y + shift + ;; + --xwin) + if test "$NOWAIT" = n; then + finish="echo Press Return to close this window...; read junk" + fi + xterm_loop=1 + shift + ;; + --phase2) + copy=phase2 + shift + ;; + --repack | --repack-path=*) + Script_Args_Check \$1 + scriptargs="\$scriptargs '\$1'" + shift + if [[ ! "\$1" =~ ^-.* ]]; then + scriptargs="\$scriptargs '\$1'" + shift + fi + ;; + *) + Script_Args_Check \$1 + scriptargs="\$scriptargs '\$1'" + shift + ;; + esac +done + +quiet_para="" +if test x"\$quiet" = xy; then + quiet_para="--quiet " +fi +scriptargs="--\$name_of_file""--\"\$pwd_of_file\""" \$quiet_para""\$scriptargs" + +if test x"\$quiet" = xy -a x"\$verbose" = xy; then + echo Cannot be verbose and quiet at the same time. >&2 + exit 1 +fi + +if test x"$NEED_ROOT" = xy -a \`id -u\` -ne 0; then + echo "Administrative privileges required for this archive (use su or sudo)" >&2 + exit 1 +fi + +if test x"\$copy" \!= xphase2; then + MS_PrintLicense +fi + +case "\$copy" in +copy) + tmpdir="\$TMPROOT"/makeself.\$RANDOM.\`date +"%y%m%d%H%M%S"\`.\$\$ + mkdir "\$tmpdir" || { + echo "Could not create temporary directory \$tmpdir" >&2 + exit 1 + } + SCRIPT_COPY="\$tmpdir/makeself" + echo "Copying to a temporary location..." >&2 + cp "\$0" "\$SCRIPT_COPY" + chmod +x "\$SCRIPT_COPY" + cd "\$TMPROOT" + exec "\$SCRIPT_COPY" --phase2 -- \$initargs + ;; +phase2) + finish="\$finish ; rm -rf \`dirname \$0\`" + ;; +esac + +if test x"\$nox11" = xn; then + if tty -s; then # Do we have a terminal? + : + else + if test x"\$DISPLAY" != x -a x"\$xterm_loop" = x; then # No, but do we have X? + if xset q > /dev/null 2>&1; then # Check for valid DISPLAY variable + GUESS_XTERMS="xterm gnome-terminal rxvt dtterm eterm Eterm xfce4-terminal lxterminal kvt konsole aterm terminology" + for a in \$GUESS_XTERMS; do + if type \$a >/dev/null 2>&1; then + XTERM=\$a + break + fi + done + chmod a+x \$0 || echo Please add execution rights on \$0 + if test \`echo "\$0" | cut -c1\` = "/"; then # Spawn a terminal! + exec \$XTERM -e "\$0 --xwin \$initargs" + else + exec \$XTERM -e "./\$0 --xwin \$initargs" + fi + fi + fi + fi +fi + +if test x"\$targetdir" = x.; then + tmpdir="." +else + if test x"\$keep" = xy; then + if test x"\$nooverwrite" = xy && test -d "\$targetdir"; then + echo "Target directory \$targetdir already exists, aborting." >&2 + exit 1 + fi + if test x"\$quiet" = xn; then + echo "Creating directory \$targetdir" >&2 + fi + tmpdir="\$targetdir" + dashp="-p" + else + tmpdir="\$TMPROOT/selfgz\$\$\$RANDOM" + dashp="" + fi + mkdir \$dashp "\$tmpdir" || { + echo 'Cannot create target directory' \$tmpdir >&2 + echo 'You should try option --extract=' >&2 + eval \$finish + exit 1 + } +fi + +location="\`pwd\`" +if test x"\$SETUP_NOCHECK" != x1; then + MS_Check "\$0" +fi +offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` + +if test x"\$verbose" = xy; then + MS_Printf "About to extract $USIZE KB in \$tmpdir ... Proceed ? [Y/n] " + read yn + if test x"\$yn" = xn; then + eval \$finish; exit 1 + fi +fi + +if test x"\$quiet" = xn; then + # Decrypting with openssl will ask for password, + # the prompt needs to start on new line + if test x"$ENCRYPT" = x"openssl"; then + echo "Decrypting and uncompressing \$label..." 
+ else + MS_Printf "Uncompressing \$label" + fi +fi +res=3 +if test x"\$keep" = xn; then + trap MS_cleanup 1 2 3 15 +fi + +if test x"\$nodiskspace" = xn; then + leftspace=\`MS_diskspace "\$tmpdir"\` + if test -n "\$leftspace"; then + if test "\$leftspace" -lt $USIZE; then + echo + echo "Not enough space left in "\`dirname \$tmpdir\`" (\$leftspace KB) to decompress \$0 ($USIZE KB)" >&2 + if test x"\$keep" = xn; then + echo "Consider setting TMPDIR to a directory with more free space." + fi + eval \$finish; exit 1 + fi + fi +fi + +for s in \$filesizes +do + if MS_dd_Progress "\$0" \$offset \$s | MS_Decompress | ( cd "\$tmpdir"; umask \$ORIG_UMASK ; UnTAR xp ) 1>/dev/null; then + if test x"\$ownership" = xy; then + (cd "\$tmpdir"; chown -R \`id -u\` .; chgrp -R \`id -g\` .) + fi + else + echo >&2 + echo "Unable to decompress \$0" >&2 + eval \$finish; exit 1 + fi + offset=\`expr \$offset + \$s\` +done +if test x"\$quiet" = xn; then + echo +fi + +cd "\$tmpdir" +res=0 +if test x"\$script" != x; then + if test x"\$export_conf" = x"y"; then + MS_BUNDLE="\$0" + MS_LABEL="\$label" + MS_SCRIPT="\$script" + MS_SCRIPTARGS="\$scriptargs" + MS_ARCHDIRNAME="\$archdirname" + MS_KEEP="\$KEEP" + MS_NOOVERWRITE="\$NOOVERWRITE" + MS_COMPRESS="\$COMPRESS" + MS_CLEANUP="\$cleanup" + export MS_BUNDLE MS_LABEL MS_SCRIPT MS_SCRIPTARGS + export MS_ARCHDIRNAME MS_KEEP MS_NOOVERWRITE MS_COMPRESS + fi + + if test x"\$verbose" = x"y"; then + yn="x" + while test x"\$yn" != x -a x"\$yn" != xy -a x"\$yn" != xY -a x"\$yn" != xn -a x"\$yn" != xN + do + MS_Printf "OK to execute: \$script \$scriptargs \$* ? [Y/n] " + read yn + if test x"\$yn" = x -o x"\$yn" = xy -o x"\$yn" = xY; then + eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$?; + elif test x"\$yn" = xn -o x"\$yn" = xN; then + echo "Unable to decompress \$script ,because of aborting! ";res=\$? + else + echo "Input value is unacceptable,please try again." + fi + done + else + eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$? + fi + if test "\$res" -ne 0; then + test x"\$verbose" = xy && echo "The program '\$script' returned an error code (\$res)" >&2 + fi +fi + +MS_exec_cleanup + +if test x"\$keep" = xn; then + cd "\$TMPROOT" + rm -rf "\$tmpdir" +fi +eval \$finish; exit \$res +EOF diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 new file mode 100644 index 00000000..81bf6e4f --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 @@ -0,0 +1,110 @@ +.TH "MAKESELF" "1" "2.4.5" +.SH "NAME" +makeself \- An utility to generate self-extractable archives. +.SH "SYNTAX" +.B makeself [\fIoptions\fP] archive_dir file_name label +.B [\fIstartup_script\fP] [\fIargs\fP] +.SH "DESCRIPTION" +This program is a free (GPL) utility designed to create self-extractable +archives from a directory. +.SH "OPTIONS" +The following options are supported. +.TP 15 +.B -v, --version +Prints out the makeself version number and exits. +.TP +.B -h, --help +Print out help information. +.TP +.B --tar-quietly +Suppress verbose output from the tar command +.TP +.B --quiet +Do not print any messages other than errors +.TP +.B --gzip +Compress using gzip (default if detected). +.TP +.B --bzip2 +Compress using bzip2. +.TP +.B --pbzip2 +Compress using pbzip2. +.TP +.B --xz +Compress using xz. +.TP +.B --lzo +Compress using lzop. +.TP +.B --lz4 +Compress using lz4. +.TP +.B --compress +Compress using the UNIX 'compress' command. +.TP +.B --nocomp +Do not compress the data. 
+.TP
+.B --complevel lvl
+Specify the compression level for gzip, bzip2, pbzip2, xz, lzo or lz4.
+.TP
+.B --notemp
+The archive will create archive_dir in the current directory and
+uncompress in ./archive_dir.
+.TP
+.B --copy
+Upon extraction, the archive will first copy itself to a temporary directory.
+.TP
+.B --append
+Append more files to an existing makeself archive. The label and startup scripts will then be ignored.
+.TP
+.B --current
+Files will be extracted to the current directory. Both --current and --target dir imply --notemp.
+.TP
+.B --target dir
+Extract directly to a target directory. Directory path can be either absolute or relative.
+.TP
+.B --header file
+Specify location of the header script.
+.TP
+.B --cleanup file
+Specify a cleanup script that executes on interrupt and when finished successfully.
+.TP
+.B --follow
+Follow the symlinks in the archive.
+.TP
+.B --noprogress
+Do not show the progress during the decompression.
+.TP
+.B --nox11
+Disable automatic spawn of an xterm if running in X11.
+.TP
+.B --nowait
+Do not wait for user input after executing embedded program from an xterm.
+.TP
+.B --nomd5
+Do not create an MD5 checksum for the archive.
+.TP
+.B --nocrc
+Do not create a CRC32 checksum for the archive.
+.TP
+.B --lsm file
+LSM file describing the package.
+.TP
+.B --packaging-date date
+Use provided string as the packaging date instead of the current date.
+.SH "EXAMPLES"
+Here is an example, assuming the user has a package image stored in /home/joe/mysoft,
+and he wants to generate a self-extracting package named mysoft.sh, which will launch
+the "setup" script initially stored in /home/joe/mysoft:
+.TP
+makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup
+.TP
+Here is also how I created the makeself.run archive which contains the Makeself distribution:
+.TP
+makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself"
+.SH "AUTHORS"
+Makeself has been written by Stéphane Peter <megastep@megastep.org>.
+.BR
+This man page was originally written by Bartosz Fenski for the
+Debian GNU/Linux distribution (but it may be used by others).
diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm
new file mode 100644
index 00000000..3c4cea8c
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm
@@ -0,0 +1,16 @@
+Begin3
+Title: makeself.sh
+Version: 2.4.5
+Description: makeself.sh is a shell script that generates a self-extractable
+ tar.gz archive from a directory. The resulting file appears as a shell
+ script, and can be launched as is. The archive will then uncompress
+ itself to a temporary directory and an arbitrary command will be
+ executed (for example an installation script). This is pretty similar
+ to archives generated with WinZip Self-Extractor in the Windows world.
+Keywords: Installation archive tar winzip
+Author: Stephane Peter (megastep@megastep.org)
+Maintained-by: Stephane Peter (megastep@megastep.org)
+Original-site: https://makeself.io/
+Platform: Unix
+Copying-policy: GPL
+End
diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh
new file mode 100644
index 00000000..c8ea5659
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh
@@ -0,0 +1,822 @@
+#!/bin/sh
+#
+# Makeself version 2.4.x
+# by Stephane Peter <megastep@megastep.org>
+#
+# Utility to create self-extracting tar.gz archives.
+# The resulting archive is a file holding the tar.gz archive with
+# a small Shell script stub that uncompresses the archive to a temporary
+# directory and then executes a given script from within that directory.
+#
+# Makeself home page: https://makeself.io/
+#
+# Version 2.0 is a rewrite of version 1.0 to make the code easier to read and maintain.
+#
+# Version history :
+# - 1.0 : Initial public release
+# - 1.1 : The archive can be passed parameters that will be passed on to
+#         the embedded script, thanks to John C. Quillan
+# - 1.2 : Package distribution, bzip2 compression, more command line options,
+#         support for non-temporary archives. Ideas thanks to Francois Petitjean
+# - 1.3 : More patches from Bjarni R. Einarsson and Francois Petitjean:
+#         Support for no compression (--nocomp), script is no longer mandatory,
+#         automatic launch in an xterm, optional verbose output, and -target
+#         archive option to indicate where to extract the files.
+# - 1.4 : Improved UNIX compatibility (Francois Petitjean)
+#         Automatic integrity checking, support of LSM files (Francois Petitjean)
+# - 1.5 : Many bugfixes. Optionally disable xterm spawning.
+# - 1.5.1 : More bugfixes, added archive options -list and -check.
+# - 1.5.2 : Cosmetic changes to inform the user of what's going on with big
+#           archives (Quake III demo)
+# - 1.5.3 : Check for validity of the DISPLAY variable before launching an xterm.
+#           More verbosity in xterms and check for embedded command's return value.
+#           Bugfix for Debian 2.0 systems that have a different "print" command.
+# - 1.5.4 : Many bugfixes. Print out a message if the extraction failed.
+# - 1.5.5 : More bugfixes. Added support for SETUP_NOCHECK environment variable to
+#           bypass checksum verification of archives.
+# - 1.6.0 : Compute MD5 checksums with the md5sum command (patch from Ryan Gordon)
+# - 2.0 : Brand new rewrite, cleaner architecture, separated header and UNIX ports.
+# - 2.0.1 : Added --copy
+# - 2.1.0 : Allow multiple tarballs to be stored in one archive, and incremental updates.
+#           Added --nochown for archives
+#           Stopped doing redundant checksums when not necessary
+# - 2.1.1 : Work around insane behavior from certain Linux distros with no 'uncompress' command
+#           Cleaned up the code to handle error codes from compress. Simplified the extraction code.
+# - 2.1.2 : Some bug fixes. Use head -n to avoid problems.
+# - 2.1.3 : Bug fixes with command line when spawning terminals.
+#           Added --tar for archives, allowing to give arbitrary arguments to tar on the contents of the archive.
+#           Added --noexec to prevent execution of embedded scripts.
+#           Added --nomd5 and --nocrc to avoid creating checksums in archives.
+#           Added command used to create the archive in --info output.
+#           Run the embedded script through eval.
+# - 2.1.4 : Fixed --info output.
+#           Generate random directory name when extracting files to . to avoid problems. (Jason Trent)
+#           Better handling of errors with wrong permissions for the directory containing the files. (Jason Trent)
+#           Avoid some race conditions (Ludwig Nussel)
+#           Unset the $CDPATH variable to avoid problems if it is set. (Debian)
+#           Better handling of dot files in the archive directory.
+# - 2.1.5 : Made the md5sum detection consistent with the header code.
+#           Check for the presence of the archive directory
+#           Added --encrypt for symmetric encryption through gpg (Eric Windisch)
+#           Added support for the digest command on Solaris 10 for MD5 checksums
+#           Check for available disk space before extracting to the target directory (Andreas Schweitzer)
+#           Allow extraction to run asynchronously (patch by Peter Hatch)
+#           Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo)
+# - 2.1.6 : Replaced one dot per file progress with a realtime progress percentage and a spinning cursor (Guy Baconniere)
+#           Added --noprogress to prevent showing the progress during the decompression (Guy Baconniere)
+#           Added --target dir to allow extracting directly to a target directory (Guy Baconniere)
+# - 2.2.0 : Many bugfixes, updates and contributions from users. Check out the project page on Github for the details.
+# - 2.3.0 : Option to specify packaging date to enable byte-for-byte reproducibility. (Marc Pawlowsky)
+# - 2.4.0 : Optional support for SHA256 checksums in archives.
+# - 2.4.2 : Add support for threads for several compressors. (M. Limber)
+#           Added zstd support.
+# - 2.4.3 : Make explicit POSIX tar archives for increased compatibility.
+# - 2.4.5 : Added --tar-format to override ustar tar archive format
+#
+# (C) 1998-2021 by Stephane Peter
+#
+# This software is released under the terms of the GNU GPL version 2 and above
+# Please read the license at http://www.gnu.org/copyleft/gpl.html
+# Self-extracting archives created with this script are explicitly NOT released under the terms of the GPL
+#
+
+MS_VERSION=2.4.5
+MS_COMMAND="$0"
+unset CDPATH
+
+for f in ${1+"$@"}; do
+    MS_COMMAND="$MS_COMMAND \\\\
+    \\\"$f\\\""
+done
+
+# For Solaris systems
+if test -d /usr/xpg4/bin; then
+    PATH=/usr/xpg4/bin:$PATH
+    export PATH
+fi
+
+# Procedures
+
+MS_Usage()
+{
+    echo "Usage: $0 [args] archive_dir file_name label startup_script [script_args]"
+    echo "args can be one or more of the following :"
+    echo "    --version | -v  : Print out Makeself version number and exit"
+    echo "    --help | -h     : Print out this help message"
+    echo "    --tar-quietly   : Suppress verbose output from the tar command"
+    echo "    --quiet | -q    : Do not print any messages other than errors."
+    echo "    --gzip          : Compress using gzip (default if detected)"
+    echo "    --pigz          : Compress with pigz"
+    echo "    --zstd          : Compress with zstd"
+    echo "    --bzip2         : Compress using bzip2 instead of gzip"
+    echo "    --pbzip2        : Compress using pbzip2 instead of gzip"
+    echo "    --xz            : Compress using xz instead of gzip"
+    echo "    --lzo           : Compress using lzop instead of gzip"
+    echo "    --lz4           : Compress using lz4 instead of gzip"
+    echo "    --compress      : Compress using the UNIX 'compress' command"
+    echo "    --complevel lvl : Compression level for gzip pigz zstd xz lzo lz4 bzip2 and pbzip2 (default 9)"
+    echo "    --threads thds  : Number of threads to be used by compressors that support parallelization."
+    echo "                      Omit to use compressor's default. Most useful (and required) for opting"
+    echo "                      into xz's threading, usually with '--threads=0' for all available cores."
+    echo "                      pbzip2 and pigz are parallel by default, and setting this value allows"
+    echo "                      limiting the number of threads they use."
+ echo " --base64 : Instead of compressing, encode the data using base64" + echo " --gpg-encrypt : Instead of compressing, encrypt the data using GPG" + echo " --gpg-asymmetric-encrypt-sign" + echo " : Instead of compressing, asymmetrically encrypt and sign the data using GPG" + echo " --gpg-extra opt : Append more options to the gpg command line" + echo " --ssl-encrypt : Instead of compressing, encrypt the data using OpenSSL" + echo " --ssl-passwd pass : Use the given password to encrypt the data using OpenSSL" + echo " --ssl-pass-src src : Use the given src as the source of password to encrypt the data" + echo " using OpenSSL. See \"PASS PHRASE ARGUMENTS\" in man openssl." + echo " If this option is not supplied, the user will be asked to enter" + echo " encryption password on the current terminal." + echo " --ssl-no-md : Do not use \"-md\" option not supported by older OpenSSL." + echo " --nochown : Do not give the target folder to the current user (default)" + echo " --chown : Give the target folder to the current user recursively" + echo " --nocomp : Do not compress the data" + echo " --notemp : The archive will create archive_dir in the" + echo " current directory and uncompress in ./archive_dir" + echo " --needroot : Check that the root user is extracting the archive before proceeding" + echo " --copy : Upon extraction, the archive will first copy itself to" + echo " a temporary directory" + echo " --append : Append more files to an existing Makeself archive" + echo " The label and startup scripts will then be ignored" + echo " --target dir : Extract directly to a target directory" + echo " directory path can be either absolute or relative" + echo " --nooverwrite : Do not extract the archive if the specified target directory exists" + echo " --current : Files will be extracted to the current directory" + echo " Both --current and --target imply --notemp" + echo " --tar-format opt : Specify a tar archive format (default is ustar)" + echo " --tar-extra opt : Append more options to the tar command line" + echo " --untar-extra opt : Append more options to the during the extraction of the tar archive" + echo " --nomd5 : Don't calculate an MD5 for archive" + echo " --nocrc : Don't calculate a CRC for archive" + echo " --sha256 : Compute a SHA256 checksum for the archive" + echo " --header file : Specify location of the header script" + echo " --cleanup file : Specify a cleanup script that executes on interrupt and when finished successfully." + echo " --follow : Follow the symlinks in the archive" + echo " --noprogress : Do not show the progress during the decompression" + echo " --nox11 : Disable automatic spawn of a xterm" + echo " --nowait : Do not wait for user input after executing embedded" + echo " program from an xterm" + echo " --sign passphrase : Signature private key to sign the package with" + echo " --lsm file : LSM file describing the package" + echo " --license file : Append a license file" + echo " --help-header file : Add a header to the archive's --help output" + echo " --packaging-date date" + echo " : Use provided string as the packaging date" + echo " instead of the current date." + echo + echo " --keep-umask : Keep the umask set to shell default, rather than overriding when executing self-extracting archive." + echo " --export-conf : Export configuration variables to startup_script" + echo + echo "Do not forget to give a fully qualified startup script name" + echo "(i.e. with a ./ prefix if inside the archive)." 
+ exit 1 +} + +# Default settings +if type gzip >/dev/null 2>&1; then + COMPRESS=gzip +elif type compress >/dev/null 2>&1; then + COMPRESS=compress +else + echo "ERROR: missing commands: gzip, compress" >&2 + MS_Usage +fi +ENCRYPT=n +PASSWD="" +PASSWD_SRC="" +OPENSSL_NO_MD=n +COMPRESS_LEVEL=9 +DEFAULT_THREADS=123456 # Sentinel value +THREADS=$DEFAULT_THREADS +KEEP=n +CURRENT=n +NOX11=n +NOWAIT=n +APPEND=n +TAR_QUIETLY=n +KEEP_UMASK=n +QUIET=n +NOPROGRESS=n +COPY=none +NEED_ROOT=n +TAR_ARGS=rvf +TAR_FORMAT=ustar +TAR_EXTRA="" +GPG_EXTRA="" +DU_ARGS=-ks +HEADER=`dirname "$0"`/makeself-header.sh +SIGNATURE="" +TARGETDIR="" +NOOVERWRITE=n +DATE=`LC_ALL=C date` +EXPORT_CONF=n +SHA256=n +OWNERSHIP=n +SIGN=n +GPG_PASSPHRASE="" + +# LSM file stuff +LSM_CMD="echo No LSM. >> \"\$archname\"" + +while true +do + case "$1" in + --version | -v) + echo Makeself version $MS_VERSION + exit 0 + ;; + --pbzip2) + COMPRESS=pbzip2 + shift + ;; + --bzip2) + COMPRESS=bzip2 + shift + ;; + --gzip) + COMPRESS=gzip + shift + ;; + --pigz) + COMPRESS=pigz + shift + ;; + --zstd) + COMPRESS=zstd + shift + ;; + --xz) + COMPRESS=xz + shift + ;; + --lzo) + COMPRESS=lzo + shift + ;; + --lz4) + COMPRESS=lz4 + shift + ;; + --compress) + COMPRESS=compress + shift + ;; + --base64) + COMPRESS=base64 + shift + ;; + --gpg-encrypt) + COMPRESS=gpg + shift + ;; + --gpg-asymmetric-encrypt-sign) + COMPRESS=gpg-asymmetric + shift + ;; + --gpg-extra) + GPG_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-encrypt) + ENCRYPT=openssl + shift + ;; + --ssl-passwd) + PASSWD=$2 + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-pass-src) + PASSWD_SRC=$2 + shift 2 || { MS_Usage; exit 1; } + ;; + --ssl-no-md) + OPENSSL_NO_MD=y + shift + ;; + --nocomp) + COMPRESS=none + shift + ;; + --complevel) + COMPRESS_LEVEL="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --threads) + THREADS="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --nochown) + OWNERSHIP=n + shift + ;; + --chown) + OWNERSHIP=y + shift + ;; + --notemp) + KEEP=y + shift + ;; + --copy) + COPY=copy + shift + ;; + --current) + CURRENT=y + KEEP=y + shift + ;; + --tar-format) + TAR_FORMAT="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --tar-extra) + TAR_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --untar-extra) + UNTAR_EXTRA="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --target) + TARGETDIR="$2" + KEEP=y + shift 2 || { MS_Usage; exit 1; } + ;; + --sign) + SIGN=y + GPG_PASSPHRASE="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --nooverwrite) + NOOVERWRITE=y + shift + ;; + --needroot) + NEED_ROOT=y + shift + ;; + --header) + HEADER="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --cleanup) + CLEANUP_SCRIPT="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --license) + # We need to escape all characters having a special meaning in double quotes + LICENSE=$(sed 's/\\/\\\\/g; s/"/\\\"/g; s/`/\\\`/g; s/\$/\\\$/g' "$2") + shift 2 || { MS_Usage; exit 1; } + ;; + --follow) + TAR_ARGS=rvhf + DU_ARGS=-ksL + shift + ;; + --noprogress) + NOPROGRESS=y + shift + ;; + --nox11) + NOX11=y + shift + ;; + --nowait) + NOWAIT=y + shift + ;; + --nomd5) + NOMD5=y + shift + ;; + --sha256) + SHA256=y + shift + ;; + --nocrc) + NOCRC=y + shift + ;; + --append) + APPEND=y + shift + ;; + --lsm) + LSM_CMD="cat \"$2\" >> \"\$archname\"" + shift 2 || { MS_Usage; exit 1; } + ;; + --packaging-date) + DATE="$2" + shift 2 || { MS_Usage; exit 1; } + ;; + --help-header) + HELPHEADER=`sed -e "s/'/'\\\\\''/g" $2` + shift 2 || { MS_Usage; exit 1; } + [ -n "$HELPHEADER" ] && HELPHEADER="$HELPHEADER +" + ;; + --tar-quietly) + 
TAR_QUIETLY=y + shift + ;; + --keep-umask) + KEEP_UMASK=y + shift + ;; + --export-conf) + EXPORT_CONF=y + shift + ;; + -q | --quiet) + QUIET=y + shift + ;; + -h | --help) + MS_Usage + ;; + -*) + echo Unrecognized flag : "$1" + MS_Usage + ;; + *) + break + ;; + esac +done + +if test $# -lt 1; then + MS_Usage +else + if test -d "$1"; then + archdir="$1" + else + echo "Directory $1 does not exist." >&2 + exit 1 + fi +fi +archname="$2" + +if test "$QUIET" = "y" || test "$TAR_QUIETLY" = "y"; then + if test "$TAR_ARGS" = "rvf"; then + TAR_ARGS="rf" + elif test "$TAR_ARGS" = "rvhf"; then + TAR_ARGS="rhf" + fi +fi + +if test "$APPEND" = y; then + if test $# -lt 2; then + MS_Usage + fi + + # Gather the info from the original archive + OLDENV=`sh "$archname" --dumpconf` + if test $? -ne 0; then + echo "Unable to update archive: $archname" >&2 + exit 1 + else + eval "$OLDENV" + OLDSKIP=`expr $SKIP + 1` + fi +else + if test "$KEEP" = n -a $# = 3; then + echo "ERROR: Making a temporary archive with no embedded command does not make sense!" >&2 + echo >&2 + MS_Usage + fi + # We don't want to create an absolute directory unless a target directory is defined + if test "$CURRENT" = y; then + archdirname="." + elif test x"$TARGETDIR" != x; then + archdirname="$TARGETDIR" + else + archdirname=`basename "$1"` + fi + + if test $# -lt 3; then + MS_Usage + fi + + LABEL="$3" + SCRIPT="$4" + test "x$SCRIPT" = x || shift 1 + shift 3 + SCRIPTARGS="$*" +fi + +if test "$KEEP" = n -a "$CURRENT" = y; then + echo "ERROR: It is A VERY DANGEROUS IDEA to try to combine --notemp and --current." >&2 + exit 1 +fi + +case $COMPRESS in +gzip) + GZIP_CMD="gzip -c$COMPRESS_LEVEL" + GUNZIP_CMD="gzip -cd" + ;; +pigz) + GZIP_CMD="pigz -$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD --processes $THREADS" + fi + GUNZIP_CMD="gzip -cd" + ;; +zstd) + GZIP_CMD="zstd -$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD --threads=$THREADS" + fi + GUNZIP_CMD="zstd -cd" + ;; +pbzip2) + GZIP_CMD="pbzip2 -c$COMPRESS_LEVEL" + if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated + GZIP_CMD="$GZIP_CMD -p$THREADS" + fi + GUNZIP_CMD="bzip2 -d" + ;; +bzip2) + GZIP_CMD="bzip2 -$COMPRESS_LEVEL" + GUNZIP_CMD="bzip2 -d" + ;; +xz) + GZIP_CMD="xz -c$COMPRESS_LEVEL" + # Must opt-in by specifying a value since not all versions of xz support threads + if test $THREADS -ne $DEFAULT_THREADS; then + GZIP_CMD="$GZIP_CMD --threads=$THREADS" + fi + GUNZIP_CMD="xz -d" + ;; +lzo) + GZIP_CMD="lzop -c$COMPRESS_LEVEL" + GUNZIP_CMD="lzop -d" + ;; +lz4) + GZIP_CMD="lz4 -c$COMPRESS_LEVEL" + GUNZIP_CMD="lz4 -d" + ;; +base64) + GZIP_CMD="base64" + GUNZIP_CMD="base64 --decode -i -" + ;; +gpg) + GZIP_CMD="gpg $GPG_EXTRA -ac -z$COMPRESS_LEVEL" + GUNZIP_CMD="gpg -d" + ENCRYPT="gpg" + ;; +gpg-asymmetric) + GZIP_CMD="gpg $GPG_EXTRA -z$COMPRESS_LEVEL -es" + GUNZIP_CMD="gpg --yes -d" + ENCRYPT="gpg" + ;; +compress) + GZIP_CMD="compress -fc" + GUNZIP_CMD="(type compress >/dev/null 2>&1 && compress -fcd || gzip -cd)" + ;; +none) + GZIP_CMD="cat" + GUNZIP_CMD="cat" + ;; +esac + +if test x"$ENCRYPT" = x"openssl"; then + if test x"$APPEND" = x"y"; then + echo "Appending to existing archive is not compatible with OpenSSL encryption." 
>&2 + fi + + ENCRYPT_CMD="openssl enc -aes-256-cbc -salt" + DECRYPT_CMD="openssl enc -aes-256-cbc -d" + + if test x"$OPENSSL_NO_MD" != x"y"; then + ENCRYPT_CMD="$ENCRYPT_CMD -md sha256" + DECRYPT_CMD="$DECRYPT_CMD -md sha256" + fi + + if test -n "$PASSWD_SRC"; then + ENCRYPT_CMD="$ENCRYPT_CMD -pass $PASSWD_SRC" + elif test -n "$PASSWD"; then + ENCRYPT_CMD="$ENCRYPT_CMD -pass pass:$PASSWD" + fi +fi + +tmpfile="${TMPDIR:-/tmp}/mkself$$" + +if test -f "$HEADER"; then + oldarchname="$archname" + archname="$tmpfile" + # Generate a fake header to count its lines + SKIP=0 + . "$HEADER" + SKIP=`cat "$tmpfile" |wc -l` + # Get rid of any spaces + SKIP=`expr $SKIP` + rm -f "$tmpfile" + if test "$QUIET" = "n"; then + echo "Header is $SKIP lines long" >&2 + fi + archname="$oldarchname" +else + echo "Unable to open header file: $HEADER" >&2 + exit 1 +fi + +if test "$QUIET" = "n"; then + echo +fi + +if test "$APPEND" = n; then + if test -f "$archname"; then + echo "WARNING: Overwriting existing file: $archname" >&2 + fi +fi + +USIZE=`du $DU_ARGS "$archdir" | awk '{print $1}'` + +if test "." = "$archdirname"; then + if test "$KEEP" = n; then + archdirname="makeself-$$-`date +%Y%m%d%H%M%S`" + fi +fi + +test -d "$archdir" || { echo "Error: $archdir does not exist."; rm -f "$tmpfile"; exit 1; } +if test "$QUIET" = "n"; then + echo "About to compress $USIZE KB of data..." + echo "Adding files to archive named \"$archname\"..." +fi + +# See if we have GNU tar +TAR=`exec <&- 2>&-; which gtar || command -v gtar || type gtar` +test -x "$TAR" || TAR=tar + +tmparch="${TMPDIR:-/tmp}/mkself$$.tar" +( + if test "$APPEND" = "y"; then + tail -n "+$OLDSKIP" "$archname" | eval "$GUNZIP_CMD" > "$tmparch" + fi + cd "$archdir" + # "Determining if a directory is empty" + # https://www.etalabs.net/sh_tricks.html + find . \ + \( \ + ! -type d \ + -o \ + \( -links 2 -exec sh -c ' + is_empty () ( + cd "$1" + set -- .[!.]* ; test -f "$1" && return 1 + set -- ..?* ; test -f "$1" && return 1 + set -- * ; test -f "$1" && return 1 + return 0 + ) + is_empty "$0"' {} \; \ + \) \ + \) -print \ + | LC_ALL=C sort \ + | sed 's/./\\&/g' \ + | xargs $TAR $TAR_EXTRA --format $TAR_FORMAT -$TAR_ARGS "$tmparch" +) || { + echo "ERROR: failed to create temporary archive: $tmparch" + rm -f "$tmparch" "$tmpfile" + exit 1 +} + +USIZE=`du $DU_ARGS "$tmparch" | awk '{print $1}'` + +eval "$GZIP_CMD" <"$tmparch" >"$tmpfile" || { + echo "ERROR: failed to create temporary file: $tmpfile" + rm -f "$tmparch" "$tmpfile" + exit 1 +} +rm -f "$tmparch" + +if test x"$ENCRYPT" = x"openssl"; then + echo "About to encrypt archive \"$archname\"..." 
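+    # Rough illustration (names and the exact offset are assumptions, not part of
+    # this script): the OpenSSL-encrypted payload appended after the header could
+    # be recovered manually with something like
+    #   tail -n +$((SKIP + 1)) archive.run | openssl enc -aes-256-cbc -d -md sha256 > payload
+    # using the same cipher/digest options assembled into ENCRYPT_CMD above.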
+ { eval "$ENCRYPT_CMD -in $tmpfile -out ${tmpfile}.enc" && mv -f ${tmpfile}.enc $tmpfile; } || \ + { echo Aborting: could not encrypt temporary file: "$tmpfile".; rm -f "$tmpfile"; exit 1; } +fi + +fsize=`cat "$tmpfile" | wc -c | tr -d " "` + +# Compute the checksums + +shasum=0000000000000000000000000000000000000000000000000000000000000000 +md5sum=00000000000000000000000000000000 +crcsum=0000000000 + +if test "$NOCRC" = y; then + if test "$QUIET" = "n"; then + echo "skipping crc at user request" + fi +else + crcsum=`CMD_ENV=xpg4 cksum < "$tmpfile" | sed -e 's/ /Z/' -e 's/ /Z/' | cut -dZ -f1` + if test "$QUIET" = "n"; then + echo "CRC: $crcsum" + fi +fi + +if test "$SHA256" = y; then + SHA_PATH=`exec <&- 2>&-; which shasum || command -v shasum || type shasum` + if test -x "$SHA_PATH"; then + shasum=`eval "$SHA_PATH -a 256" < "$tmpfile" | cut -b-64` + else + SHA_PATH=`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum` + shasum=`eval "$SHA_PATH" < "$tmpfile" | cut -b-64` + fi + if test "$QUIET" = "n"; then + if test -x "$SHA_PATH"; then + echo "SHA256: $shasum" + else + echo "SHA256: none, SHA command not found" + fi + fi +fi +if test "$NOMD5" = y; then + if test "$QUIET" = "n"; then + echo "Skipping md5sum at user request" + fi +else + # Try to locate a MD5 binary + OLD_PATH=$PATH + PATH=${GUESS_MD5_PATH:-"$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} + MD5_ARG="" + MD5_PATH=`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum` + test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which md5 || command -v md5 || type md5` + test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which digest || command -v digest || type digest` + PATH=$OLD_PATH + if test -x "$MD5_PATH"; then + if test `basename ${MD5_PATH}`x = digestx; then + MD5_ARG="-a md5" + fi + md5sum=`eval "$MD5_PATH $MD5_ARG" < "$tmpfile" | cut -b-32` + if test "$QUIET" = "n"; then + echo "MD5: $md5sum" + fi + else + if test "$QUIET" = "n"; then + echo "MD5: none, MD5 command not found" + fi + fi +fi +if test "$SIGN" = y; then + GPG_PATH=`exec <&- 2>&-; which gpg || command -v gpg || type gpg` + if test -x "$GPG_PATH"; then + SIGNATURE=`$GPG_PATH --pinentry-mode=loopback --batch --yes --passphrase "$GPG_PASSPHRASE" --output - --detach-sig $tmpfile | base64 | tr -d \\\\n` + if test "$QUIET" = "n"; then + echo "Signature: $SIGNATURE" + fi + else + echo "Missing gpg command" >&2 + fi +fi + +totalsize=0 +for size in $fsize; +do + totalsize=`expr $totalsize + $size` +done + +if test "$APPEND" = y; then + mv "$archname" "$archname".bak || exit + + # Prepare entry for new archive + filesizes="$fsize" + CRCsum="$crcsum" + MD5sum="$md5sum" + SHAsum="$shasum" + Signature="$SIGNATURE" + # Generate the header + . "$HEADER" + # Append the new data + cat "$tmpfile" >> "$archname" + + chmod +x "$archname" + rm -f "$archname".bak + if test "$QUIET" = "n"; then + echo "Self-extractable archive \"$archname\" successfully updated." + fi +else + filesizes="$fsize" + CRCsum="$crcsum" + MD5sum="$md5sum" + SHAsum="$shasum" + Signature="$SIGNATURE" + + # Generate the header + . "$HEADER" + + # Append the compressed tar data after the stub + if test "$QUIET" = "n"; then + echo + fi + cat "$tmpfile" >> "$archname" + chmod +x "$archname" + if test "$QUIET" = "n"; then + echo Self-extractable archive \"$archname\" successfully created. 
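+	# The finished file is the header stub followed directly by the compressed
+	# (and optionally encrypted) tar payload appended above.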
+	fi
+fi
+rm -f "$tmpfile"
diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh
new file mode 100644
index 00000000..31ee1651
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh
@@ -0,0 +1,8 @@
+#!/bin/sh
+# Run every available test - Bash needed
+cd test
+for test in *test;
+do
+    echo "Running test $test ..."
+    bash $test || { echo "*** ERROR: Test '$test' failed!"; exit 1; }
+done
diff --git a/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh b/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh
new file mode 100644
index 00000000..a977bd51
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+project_path=$1
+build_path=$2
+vendor_name=customize
+echo $@
+if [[ ! -d "$project_path" ]]; then
+    echo "[ERROR] No project path is provided"
+    exit 1
+fi
+
+if [[ ! -d "$build_path" ]]; then
+    echo "[ERROR] No build path is provided"
+    exit 1
+fi
+
+if [[ ! -d "$ASCEND_OPP_PATH" ]]; then
+    echo "[ERROR] No opp install path is provided"
+    exit 1
+fi
+custom_exist_info_json=$ASCEND_OPP_PATH/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json
+custom_new_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json
+temp_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/temp_cust_aicpu_kernel.json
+
+if [[ -f "$custom_exist_info_json" ]] && [[ -f "$custom_new_info_json" ]]; then
+    cp -f $custom_exist_info_json $temp_info_json
+    chmod +w $temp_info_json
+    python3 ${project_path}/cmake/util/insert_op_info.py ${custom_new_info_json} ${temp_info_json}
+    cp -f $temp_info_json $custom_new_info_json
+    rm -f $temp_info_json
+fi
diff --git a/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py b/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py
new file mode 100644
index 00000000..c58729c3
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
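+
+Parses [OpType] sections of operator description files into OpDesc objects
+used by the other cmake/util build scripts.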
+""" + +import sys +import os + + +class OpDesc: + def __init__(self: any, op_type: str): + self.op_type = op_type + self.attr_list = [] + self.attr_val = {} + self.input_name = [] + self.input_type = [] + self.input_dtype = [] + self.input_fmt = [] + self.output_name = [] + self.output_type = [] + self.output_dtype = [] + self.output_fmt = [] + self.op_fmt_sel = False + self.op_chk_support = False + self.op_intf = '' + self.kern_name = '' + self.op_file = '' + self.op_replay_flag = False + self.op_replay_batch = False + self.input_idx = 0 + self.output_idx = 0 + self.max_block_dim = 32 + self.max_shape_size = 268435456 + self.dynamic_shape = False + self.op_range_limit = '' + + @staticmethod + def _parse_digit(conf: str) -> int: + return int(conf.split('=')[1]) + + @staticmethod + def _parse_flag(conf: str) -> bool: + if 'true' == conf.split('=')[1]: + return True + return False + + @staticmethod + def _parse_str(conf: str) -> str: + return conf.split('=')[1] + + @staticmethod + def _parse_list(conf: str) -> list: + return conf.split('=')[1].split(',') + + def parse_input(self: any, conf: str): + if conf.startswith('input{}.name'.format(int(self.input_idx / 4))): + self.input_name.append(self._parse_str(conf)) + self.input_idx += 1 + elif conf.startswith('input{}.paramType'.format(int(self.input_idx / 4))): + self.input_type.append(self._parse_str(conf)) + self.input_idx += 1 + elif conf.startswith('input{}.dtype'.format(int(self.input_idx / 4))): + self.input_dtype.append(self._parse_str(conf)) + self.input_idx += 1 + elif conf.startswith('input{}.format'.format(int(self.input_idx / 4))): + self.input_fmt.append(self._parse_str(conf)) + self.input_idx += 1 + else: + return + + def parse_output(self: any, conf: str): + if conf.startswith('output{}.name'.format(int(self.output_idx / 4))): + self.output_name.append(self._parse_str(conf)) + self.output_idx += 1 + elif conf.startswith('output{}.paramType'.format(int(self.output_idx / 4))): + self.output_type.append(self._parse_str(conf)) + self.output_idx += 1 + elif conf.startswith('output{}.dtype'.format(int(self.output_idx / 4))): + self.output_dtype.append(self._parse_str(conf)) + self.output_idx += 1 + elif conf.startswith('output{}.format'.format(int(self.output_idx / 4))): + self.output_fmt.append(self._parse_str(conf)) + self.output_idx += 1 + else: + return + + def parse_op_format(self: any, conf: str): + self.op_fmt_sel = self._parse_flag(conf) + + def parse_check_support(self: any, conf: str): + self.op_chk_support = self._parse_flag(conf) + + def parse_range_limit(self: any, conf: str): + self.op_range_limit = self._parse_str(conf) + + def parse_kern_name(self: any, conf: str): + self.kern_name = self._parse_str(conf) + + def parse_op_intf(self: any, conf: str): + self.op_intf = self._parse_str(conf) + + def parse_op_file(self: any, conf: str): + self.op_file = self._parse_str(conf) + + def parse_dynamic_shape(self: any, conf: str): + self.dynamic_shape = self._parse_flag(conf) + + def parse_attr_list(self: any, conf: str): + self.attr_list = self._parse_list(conf) + + def parse_attr_val(self: any, conf: str): + for attr in self.attr_list: + if self.attr_val.get(attr) is None: + self.attr_val[attr] = {} + if conf.startswith('attr_{}.type'.format(attr)): + self.attr_val.get(attr)['type'] = self._parse_str(conf) + elif conf.startswith('attr_{}.paramType'.format(attr)): + self.attr_val.get(attr)['paramType'] = self._parse_str(conf) + elif conf.startswith('attr_{}.defaultValue'.format(attr)): + self.attr_val.get(attr)['defaultValue'] 
= self._parse_str(conf) + + def parse_replay_val(self: any, batch_list: list, iterator_list: list): + if self.op_type in batch_list: + self.op_replay_flag = True + self.op_replay_batch = True + elif self.op_type in iterator_list: + self.op_replay_flag = True + self.op_replay_batch = False + + +def get_op_desc(file: str, batch_list: list, iterator_list: list, builder: any, op_type: list) -> list: + op_descs = [] + op_match = False + with open (file, 'r') as fd: + lines = fd.readlines() + for line in lines: + line = line.strip() + if line.startswith('['): + name = line[1:-1] + if op_type is None or name in op_type: + op_match = True + op_desc = builder(name) + op_descs.append(op_desc) + else: + op_match = False + if op_type is not None and len(op_descs) == len(op_type): + return op_descs + continue + if not op_match: + continue + if line.startswith('input'): + op_desc.parse_input(line) + elif line.startswith('output'): + op_desc.parse_output(line) + elif line.startswith('dynamicFormat.flag'): + op_desc.parse_op_format(line) + elif line.startswith('needCheckSupport.flag'): + op_desc.parse_check_support(line) + elif line.startswith('rangeLimit.value'): + op_desc.parse_range_limit(line) + elif line.startswith('opInterface.value'): + op_desc.parse_op_intf(line) + elif line.startswith('kernel.name'): + op_desc.parse_kern_name(line) + elif line.startswith('opFile.value'): + op_desc.parse_op_file(line) + elif line.startswith('dynamicShapeSupport.flag'): + op_desc.parse_dynamic_shape(line) + elif line.startswith('attr.list'): + op_desc.parse_attr_list(line) + elif line.startswith('attr_'): + op_desc.parse_attr_val(line) + op_desc.parse_replay_val(batch_list, iterator_list) + return op_descs diff --git a/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py b/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py new file mode 100644 index 00000000..f75c5260 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py @@ -0,0 +1,339 @@ +# Copyright 2020-2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
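+
+# Usage sketch (file names are illustrative; see the __main__ block at the
+# bottom of this file):
+#   python3 parse_ini_to_json.py my_ops.ini my_ops.json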
+
+"""
+parse ini to json
+"""
+
+import json
+import os
+import stat
+import sys
+
+
+ATTR_TYPE_LIST = ["int", "float", "bool", "str", "listInt", "listFloat", "listBool", "listStr", "listListInt",
+                  "type", "listType", "tensor", "listTensor"]
+ATTR_PARAMTYPE_LIST = ["optional", "required"]
+BOOL_FLAG_KEY = ["dynamicFormat", "dynamicShapeSupport", "dynamicRankSupport", "precision_reduce", "heavyOp",
+                 "needCheckSupport"]
+BOOL_LIST = ["true", "false"]
+DTYPE_LIST = ["float16", "float", "float32", "int8", "int16", "int32", "uint8", "uint16", "uint32", "bool",
+              "int64", "uint64", "qint8", "qint16", "qint32", "quint8", "quint16", "double", "complex64",
+              "complex128", "string", "resource", "dual", "dual_sub_int8", "dual_sub_uint8", "string_ref",
+              "int4", "bfloat16", "uint1"]
+FORMAT_LIST = ["NCHW", "NHWC", "ND", "NC1HWC0", "FRACTAL_Z", "NC1C0HWPAD", "NHWC1C0", "FSR_NCHW", "FRACTAL_DECONV",
+               "C1HWNC0", "FRACTAL_DECONV_TRANSPOSE", "FRACTAL_DECONV_SP_STRIDE_TRANS", "NC1HWC0_C04",
+               "FRACTAL_Z_C04", "CHWN", "FRACTAL_DECONV_SP_STRIDE8_TRANS", "HWCN", "NC1KHKWHWC0", "BN_WEIGHT",
+               "FILTER_HWCK", "HASHTABLE_LOOKUP_LOOKUPS", "HASHTABLE_LOOKUP_KEYS", "HASHTABLE_LOOKUP_VALUE",
+               "HASHTABLE_LOOKUP_OUTPUT", "HASHTABLE_LOOKUP_HITS", "C1HWNCoC0", "MD", "NDHWC", "FRACTAL_ZZ",
+               "FRACTAL_NZ", "NCDHW", "DHWCN", "NDC1HWC0", "FRACTAL_Z_3D", "CN", "NC", "DHWNC",
+               "FRACTAL_Z_3D_TRANSPOSE", "FRACTAL_ZN_LSTM", "FRACTAL_ZN_RNN", "FRACTAL_Z_G", "NULL"]
+
+
+def parse_ini_files(ini_files):
+    """
+    parse ini files to json
+    Parameters:
+    ----------------
+    ini_files:input file list
+    return:ops_info
+    ----------------
+    """
+    tbe_ops_info = {}
+    for ini_file in ini_files:
+        check_file_size(ini_file)
+        parse_ini_to_obj(ini_file, tbe_ops_info)
+    return tbe_ops_info
+
+
+def check_file_size(input_file):
+    try:
+        file_size = os.path.getsize(input_file)
+    except OSError as os_error:
+        print('[ERROR] Failed to open "%s". %s' % (input_file, str(os_error)))
+        raise OSError from os_error
+    if file_size > 10*1024*1024:
+        print('[WARN] The size of %s exceeds 10MB, it may take more time to run, please wait.'
              % input_file)
+
+
+def parse_ini_to_obj(ini_file, tbe_ops_info):
+    """
+    parse ini file to json obj
+    Parameters:
+    ----------------
+    ini_file:ini file path
+    tbe_ops_info:ops_info
+    ----------------
+    """
+    with open(ini_file) as ini_file:
+        lines = ini_file.readlines()
+    op_dict = {}
+    op_name = ""
+    find_op_type = False
+    for line in lines:
+        line = line.rstrip()
+        if line == "":
+            continue
+        if line.startswith("["):
+            if line.endswith("]"):
+                op_name = line[1:-1]
+                op_dict = {}
+                tbe_ops_info[op_name] = op_dict
+                find_op_type = True
+        elif "=" in line:
+            key1 = line[:line.index("=")]
+            key2 = line[line.index("=")+1:]
+            key1_0, key1_1 = key1.split(".")
+            if key1_0 not in op_dict:
+                op_dict[key1_0] = {}
+            if key1_1 in op_dict.get(key1_0):
+                raise RuntimeError("Op:" + op_name + " " + key1_0 + " " +
+                                   key1_1 + " is repeated!")
+            dic_key = op_dict.get(key1_0)
+            dic_key[key1_1] = key2
+        else:
+            continue
+    if not find_op_type:
+        raise RuntimeError("OpType not found in .ini file.")
+
+
+def check_output_exist(op_dict, is_valid):
+    """
+    Function Description:
+    Check whether output0 exists
+    Parameter: op_dict
+    Parameter: is_valid
+    """
+    if "output0" in op_dict:
+        output0_dict = op_dict.get("output0")
+        if output0_dict.get("name", None) is None:
+            is_valid = False
+            print("output0.name is required in .ini file!")
+    else:
+        is_valid = False
+        print("output0 is required in .ini file!")
+    return is_valid
+
+
+def check_attr_dict(attr_dict, is_valid, attr):
+    """
+    Function Description:
+    Check attr_dict
+    Parameter: attr_dict
+    Parameter: is_valid
+    Parameter: attr
+    """
+    attr_type = attr_dict.get("type")
+    value = attr_dict.get("value")
+    param_type = attr_dict.get("paramType")
+    if attr_type is None or value is None:
+        is_valid = False
+        print("If attr.list exists, {0}.type and {0}.value are required".format(attr))
+    if param_type and param_type not in ATTR_PARAMTYPE_LIST:
+        is_valid = False
+        print("{0}.paramType only supports {1}.".format(attr, ATTR_PARAMTYPE_LIST))
+    if attr_type and attr_type not in ATTR_TYPE_LIST:
+        is_valid = False
+        print("{0}.type only supports {1}.".format(attr, ATTR_TYPE_LIST))
+    return is_valid
+
+
+def check_attr(op_dict, is_valid):
+    """
+    Function Description:
+    Check attr
+    Parameter: op_dict
+    Parameter: is_valid
+    """
+    if "attr" in op_dict:
+        attr_dict = op_dict.get("attr")
+        attr_list_str = attr_dict.get("list", None)
+        if attr_list_str is None:
+            is_valid = False
+            print("attr.list is required in .ini file!")
+        else:
+            attr_list = attr_list_str.split(",")
+            for attr_name in attr_list:
+                attr = "attr_" + attr_name.strip()
+                attr_dict = op_dict.get(attr)
+                if attr_dict:
+                    is_valid = check_attr_dict(attr_dict, is_valid, attr)
+                else:
+                    is_valid = False
+                    print("%s is required in .ini file, when attr.list is %s!"
                          % (attr, attr_list_str))
+    return is_valid
+
+
+def check_bool_flag(op_dict, is_valid):
+    """
+    Function Description:
+    check_bool_flag
+    Parameter: op_dict
+    Parameter: is_valid
+    """
+    for key in BOOL_FLAG_KEY:
+        if key in op_dict:
+            op_bool_key = op_dict.get(key)
+            if op_bool_key.get("flag").strip() not in BOOL_LIST:
+                is_valid = False
+                print("{0}.flag only supports {1}.".format(key, BOOL_LIST))
+    return is_valid
+
+
+def check_type_format(op_info, is_valid, op_info_key):
+    """
+    Function Description:
+    Check type and format
+    Parameter: op_info
+    Parameter: is_valid
+    Parameter: op_info_key
+    """
+    op_info_dtype_str = op_info.get("dtype")
+    op_info_dtype_num = 0
+    op_info_format_num = 0
+    if op_info_dtype_str:
+        op_info_dtype = op_info_dtype_str.split(",")
+        op_info_dtype_num = len(op_info_dtype)
+        for dtype in op_info_dtype:
+            if dtype.strip() not in DTYPE_LIST:
+                is_valid = False
+                print("{0}.dtype does not support {1}.".format(op_info_key, dtype))
+    op_info_format_str = op_info.get("format")
+    if op_info_format_str:
+        op_info_format = op_info_format_str.split(",")
+        op_info_format_num = len(op_info_format)
+        for op_format in op_info_format:
+            if op_format.strip() not in FORMAT_LIST:
+                is_valid = False
+                print("{0}.format does not support {1}.".format(op_info_key, op_format))
+    if op_info_dtype_num > 0 and op_info_format_num > 0:
+        if op_info_dtype_num != op_info_format_num:
+            is_valid = False
+            print("The number of {0}.dtype does not match the number of {0}.format.".format(op_info_key))
+    return is_valid
+
+
+def check_op_info(tbe_ops):
+    """
+    Function Description:
+    Check ops info validity.
+    Parameter: tbe_ops
+    Return Value: is_valid
+    """
+    print("\n\n==============check valid for ops info start==============")
+    required_op_input_info_keys = ["paramType", "name"]
+    required_op_output_info_keys = ["paramType", "name"]
+    param_type_valid_value = ["dynamic", "optional", "required"]
+    is_valid = True
+    for op_key in tbe_ops:
+        op_dict = tbe_ops[op_key]
+        is_valid = check_output_exist(op_dict, is_valid)
+        for op_info_key in op_dict:
+            if op_info_key.startswith("input"):
+                op_input_info = op_dict[op_info_key]
+                missing_keys = []
+                for required_op_input_info_key in required_op_input_info_keys:
+                    if required_op_input_info_key not in op_input_info:
+                        missing_keys.append(required_op_input_info_key)
+                if len(missing_keys) > 0:
+                    print("op: " + op_key + " " + op_info_key + " missing: " +
+                          ",".join(missing_keys))
+                    is_valid = False
+                else:
+                    if op_input_info["paramType"] not in param_type_valid_value:
+                        print("op: " + op_key + " " + op_info_key +
+                              " paramType not valid, valid key:[dynamic, "
+                              "optional, required]")
+                        is_valid = False
+                is_valid = check_type_format(op_input_info, is_valid, op_info_key)
+            if op_info_key.startswith("output"):
+                op_input_info = op_dict[op_info_key]
+                missing_keys = []
+                for required_op_input_info_key in required_op_output_info_keys:
+                    if required_op_input_info_key not in op_input_info:
+                        missing_keys.append(required_op_input_info_key)
+                if len(missing_keys) > 0:
+                    print("op: " + op_key + " " + op_info_key + " missing: " +
+                          ",".join(missing_keys))
+                    is_valid = False
+                else:
+                    if op_input_info["paramType"] not in param_type_valid_value:
+                        print("op: " + op_key + " " + op_info_key +
+                              " paramType not valid, valid key:[dynamic, "
+                              "optional, required]")
+                        is_valid = False
+                is_valid = check_type_format(op_input_info, is_valid, op_info_key)
+        is_valid = check_attr(op_dict, is_valid)
+        is_valid = check_bool_flag(op_dict, is_valid)
+    print("==============check valid for ops info 
end================\n\n") + return is_valid + + +def write_json_file(tbe_ops_info, json_file_path): + """ + Save info to json file + Parameters: + ---------------- + tbe_ops_info: ops_info + json_file_path: json file path + ---------------- + """ + json_file_real_path = os.path.realpath(json_file_path) + wr_flag = os.O_WRONLY | os.O_CREAT + wr_mode = stat.S_IWUSR | stat.S_IRUSR + with os.fdopen(os.open(json_file_real_path, wr_flag, wr_mode), 'w') as file_path: + # Only the owner and group have rights + os.chmod(json_file_real_path, stat.S_IWGRP + stat.S_IWUSR + stat.S_IRGRP + + stat.S_IRUSR) + json.dump(tbe_ops_info, file_path, sort_keys=True, indent=4, + separators=(',', ':')) + print("Compile op info cfg successfully.") + + +def parse_ini_to_json(ini_file_paths, outfile_path): + """ + parse ini files to json file + Parameters: + ---------------- + ini_file_paths: list of ini file path + outfile_path: output file path + ---------------- + """ + tbe_ops_info = parse_ini_files(ini_file_paths) + if not check_op_info(tbe_ops_info): + print("Compile op info cfg failed.") + return False + write_json_file(tbe_ops_info, outfile_path) + return True + + +if __name__ == '__main__': + args = sys.argv + + OUTPUT_FILE_PATH = "tbe_ops_info.json" + ini_file_path_list = [] + + for arg in args: + if arg.endswith("ini"): + ini_file_path_list.append(arg) + OUTPUT_FILE_PATH = arg.replace(".ini", ".json") + if arg.endswith("json"): + OUTPUT_FILE_PATH = arg + + if len(ini_file_path_list) == 0: + ini_file_path_list.append("tbe_ops_info.ini") + + if not parse_ini_to_json(ini_file_path_list, OUTPUT_FILE_PATH): + sys.exit(1) + sys.exit(0) diff --git a/cust_op/cust_op_by_addr/cmake/util/preset_parse.py b/cust_op/cust_op_by_addr/cmake/util/preset_parse.py new file mode 100644 index 00000000..8f1124b1 --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/preset_parse.py @@ -0,0 +1,23 @@ +import json +import sys +import os + + +def get_config_opts(file): + src_dir = os.path.abspath(os.path.dirname(file)) + opts = '' + with open(file, 'r') as fd: + config = json.load(fd) + for conf in config: + if conf == 'configurePresets': + for node in config[conf]: + macros = node.get('cacheVariables') + if macros is not None: + for key in macros: + opts += '-D{}={} '.format(key, macros[key]['value']) + opts = opts.replace('${sourceDir}', src_dir) + print(opts) + + +if __name__ == "__main__": + get_config_opts(sys.argv[1]) diff --git a/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py b/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py new file mode 100644 index 00000000..1baa364e --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +Created on Feb 28 20:56:45 2020 +Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
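+
+Generates the kernel entry, kernel impl and replay impl C++ sources for an op
+from the *.temp templates shipped in this directory.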
+""" + +import os +import stat +import collections +import kernel_entry as keb +from tiling_data_def_build import gen_tiling +import code_channel_infer +import const_var + +PYF_PATH = os.path.dirname(__file__) + +ReplayCodeGenParams = collections.namedtuple('ReplayCodeGenParams',\ +['op_type', 'impl', 'tiling_file', 'kernel', 'entry', 'argn', 'op_replay_batch', 'max_block_dim', 'max_shape_size']) + + +class ReplayCodeGen: + def __init__(self, replayCodeGenParams): + self.op_type = replayCodeGenParams.op_type + self.impl = replayCodeGenParams.impl + self.tiling_file = replayCodeGenParams.tiling_file + self.tiling_data_file = '' + self.kernel = replayCodeGenParams.kernel + self.entry = replayCodeGenParams.entry + self.argn = replayCodeGenParams.argn + self.batch = False + self.outdir = '' + self.data_type = 'uint8_t' + self.blknum = 32 + self.op_replay_batch = replayCodeGenParams.op_replay_batch + self.max_block_dim = replayCodeGenParams.max_block_dim + self.max_shape_size = replayCodeGenParams.max_shape_size + + def set_batch(self, is_batch): + self.batch = is_batch + + def set_outdir(self, outdir): + self.outdir = outdir + + def gen_replay(self, ops_product: str): + kerentry = os.path.join(self.outdir, self.kernel + '_entry.cce') + kerimpl = os.path.join(self.outdir, self.kernel + '_impl.cpp') + replayimpl = os.path.join(self.outdir, self.kernel + '_replay.cpp') + if self.batch: + reptmp = os.path.join(PYF_PATH, 'batch_replay_impl.temp') + else: + reptmp = os.path.join(PYF_PATH, 'replay_impl.temp') + kertmp = os.path.join(PYF_PATH, 'kernel_impl.temp') + self._gen_kentry(kerentry) + self._gen_kimpl_code(kerimpl, kertmp) + self._gen_tiling_data_header() + self._gen_replay_code(replayimpl, reptmp, ops_product) + + def _gen_tiling_data_header(self): + self.tiling_data_file = os.path.join(self.outdir, self.kernel + '_tiling_data.h') + gen_tiling(self.tiling_file, self.tiling_data_file) + + def _gen_kimpl_code(self, src, tmpfile): + with open(tmpfile, 'r') as fd: + temp = fd.read() + temp = temp.replace('__CCE_FILE__', self.impl) + with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: + ofd.write(temp) + + def _gen_replay_code(self, src, tmpfile, ops_product: str): + with open(tmpfile, 'r') as fd: + temp = fd.read() + temp = temp.replace('__ARG_NUM__', str(self.argn)) + argdef = [] + kargs = [] + for i in range(0, self.argn): + argdef.append('{} *'.format(self.data_type)) + kargs.append('({} *)GetArg({})'.format(self.data_type, i)) + temp = temp.replace('__ARGS_DEF__', ', '.join(argdef)) + temp = temp.replace('__KERNEL_ARGS__', ', '.join(kargs)) + temp = temp.replace('__KERNEL_FUN__', self.entry) + core_type_infer = 'core_type' + code_channel = code_channel_infer.infer_code_channel(code_channel_infer.InfoCodeChanelParams(self.impl,\ + self.tiling_data_file, self.kernel, self.outdir, ops_product, None)) + if code_channel == code_channel_infer.CODE_VEC: + core_type_infer = '0' + elif code_channel == code_channel_infer.CODE_CUBE: + core_type_infer = '1' + temp = temp.replace('__CORE_TYPE__', core_type_infer) + # regist function + temp = temp.replace('__OPS_PRODUCT__', ops_product) + temp = temp.replace('__OPTYPE__', self.op_type) + with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: + ofd.write(temp) + + def _gen_kentry(self, src): + kf = '' + pre_alloc_str = 'A' * 256 + if self.batch: + kf += keb.batch_code_gen("K{:02d}_{}{}".format(0, self.entry, pre_alloc_str), self.argn, self.data_type) + else: + kf += keb.mc_code_gen("K{:02d}_{}{}".format(0, 
self.entry, pre_alloc_str),\
+                self.argn, self.data_type, self.blknum)
+        with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd:
+            ofd.write(kf)
diff --git a/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp b/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp
new file mode 100644
index 00000000..3e0b2f44
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp
@@ -0,0 +1,120 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <fcntl.h>
+#include <unistd.h>
+#include <time.h>
+#include <sys/stat.h>
+#include "replay_def.h"
+#include "code_gen.h"
+#include "replay_fun.h"
+#include "register/op_check.h"
+#define __TIK2_REPLAY_CODE__
+using namespace std;
+using namespace optiling;
+using namespace tik2_replay;
+
+extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *);
+extern "C" int elf_append(char *elf, uint32_t elfSize, char *jit, int kernum, int blknum[], char *atext[],
+    int alen[], int atlen, const char* kernelname[]);
+
+#define KERNEL_N 1
+#define ARG_N (__ARG_NUM__)
+#define MAX_L (1024 * 1024 * 100)
+#define MAX_E (1024 * 1024)
+
+int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type)
+{
+    // gen type 1 : direct call codes 0: load .o file
+    if (param.gentype < 0 || param.gentype > 1) {
+        printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype);
+        return 0;
+    } else if (param.gentype == 1 && param.objptr == nullptr) {
+        printf("Error: call replay with direct call mode, but code obj addr is null\n");
+        return 0;
+    } else if (param.gentype == 0 && param.output_kernel_file == nullptr) {
+        printf("Error: call replay with object file mode, but object file path is null\n");
+        return 0;
+    }
+    // core_type 0:MIX 1:CUBE 2:VEC
+    if (core_type < 0 || core_type > 2) {
+        printf("Error: call replay core type is %d !\n", core_type);
+        return 0;
+    }
+    g_coreType = __CORE_TYPE__;
+    g_taskRation = param.task_ration;
+    g_tilingKey = param.tiling_key;
+
+    unsigned char *buf, *jit;
+    char *kernel[KERNEL_N * 32];
+    int len[KERNEL_N * 32];
+    int blknum[KERNEL_N];
+    int max;
+    block_num = param.block_dim;
+    g_ubBase = block_num;
+    uint8_t *code = (uint8_t *)malloc(MAX_L);
+    uint8_t *pos = code;
+    struct timespec tp1, tp2;
+
+    clock_gettime(CLOCK_MONOTONIC, &tp1);
+    if (block_num > 32) {
+        printf("Error: block_num > 32\n");
+        return 0;
+    }
+    //__OP_FOPEN__
+    for (int i = 0; i < KERNEL_N; i++) {
+        for (int j = 0; j < ARG_N; j++)
+            AddArg(j, ARG_STEP * (j + 1));
+        for (block_idx = 0; block_idx < block_num; block_idx++) {
+            //__OP_SET_KERNEL__
+            int code_idx = i * block_num + block_idx;
+#ifdef FP_CEILING
+            SetCtrlFloatEnable();
+#else
+            SetCtrlFloatDisable();
+#endif
+            CodeInit(pos, false);
+            __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data);
+            CodeEnd();
+            kernel[code_idx] = (char *)pos;
+            len[code_idx] = CodeLen();
+            pos += len[code_idx];
+            printf("kernel %d core %ld code generated len %d\n", i, block_idx, len[code_idx]);
+        }
+        blknum[i] = block_num;
+    }
+    //__OP_FCLOSE__
+    clock_gettime(CLOCK_MONOTONIC, &tp2);
+    buf = (unsigned char *)malloc(MAX_E);
+    int fd = open(param.entry_file, O_RDONLY);
+    if (fd < 0) {
+        printf("[error]: cannot find entry.o : %s\n", param.entry_file);
+        return 0;
+    }
+    uint32_t bufSize = read(fd, buf, MAX_E);
+    if (bufSize <= 0) {
+        printf("[error]: entry.o : %s is too small ! 
\n", param.entry_file);
+    }
+    close(fd);
+    jit = (unsigned char *)malloc(MAX_L);
+    printf("total code generated %ld\n", pos - code);
+    int sz = elf_append((char *)buf, bufSize, (char *)jit, KERNEL_N, blknum, kernel, len, pos - code, &param.kernel_name);
+    if (tp1.tv_sec != tp2.tv_sec) {
+        printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec);
+    } else {
+        printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec);
+    }
+    printf("new elf size %d\n", sz);
+    if (param.gentype == 0) {
+        fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR);
+        (void)write(fd, jit, sz);
+        close(fd);
+        free(jit);
+    } else if (param.gentype == 1) {
+        *param.objptr = (char*)jit;
+    }
+    free(buf);
+    free(code);
+    return sz;
+}
+
+REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__);
diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py
new file mode 100644
index 00000000..98095ab8
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
+"""
+
+import sys
+import os
+import json
+import hashlib
+import const_var
+import opdesc_parser
+
+PYF_PATH = os.path.dirname(os.path.realpath(__file__))
+
+
+class BinParamBuilder(opdesc_parser.OpDesc):
+    def __init__(self: any, op_type: str):
+        super().__init__(op_type)
+        self.soc = ''
+        self.out_path = ''
+
+    def set_soc_version(self: any, soc: str):
+        self.soc = soc
+
+    def set_out_path(self: any, out_path: str):
+        self.out_path = out_path
+
+    def gen_input_json(self: any):
+        key_map = {}
+        count = len(self.input_dtype[0].split(','))
+        for i in range(0, count):
+            inputs = []
+            outputs = []
+            attrs = []
+            op_node = {}
+            for idx in range(0, len(self.input_name)):
+                idtypes = self.input_dtype[idx].split(',')
+                ifmts = self.input_fmt[idx].split(',')
+                para = {}
+                para['name'] = self.input_name[idx]
+                para['index'] = idx
+                para['dtype'] = idtypes[i]
+                para['format'] = ifmts[i]
+                para['paramType'] = self.input_type[idx]
+                para['shape'] = [-2]
+                inputs.append(para)
+            for idx in range(0, len(self.output_name)):
+                odtypes = self.output_dtype[idx].split(',')
+                ofmts = self.output_fmt[idx].split(',')
+                para = {}
+                para['name'] = self.output_name[idx]
+                para['index'] = idx
+                para['dtype'] = odtypes[i]
+                para['format'] = ofmts[i]
+                para['paramType'] = self.output_type[idx]
+                para['shape'] = [-2]
+                outputs.append(para)
+            for attr in self.attr_list:
+                att = {}
+                att['name'] = attr
+                atype = self.attr_val.get(attr).get('type').lower()
+                atype = atype.replace('list', 'list_')
+                att['dtype'] = atype
+                att['value'] = const_var.ATTR_DEF_VAL.get(atype)
+                attrs.append(att)
+            op_node['bin_filename'] = ''
+            op_node['inputs'] = inputs
+            op_node['outputs'] = outputs
+            if len(attrs) > 0:
+                op_node['attrs'] = attrs
+            param = {}
+            param['op_type'] = self.op_type
+            param['op_list'] = [op_node]
+            objstr = json.dumps(param, indent='  ')
+            md5sum = hashlib.md5(objstr.encode('utf-8')).hexdigest()
+            while key_map.get(md5sum) is not None:
+                objstr += '1'
+                md5sum = hashlib.md5(objstr.encode('utf-8')).hexdigest()
+            key_map[md5sum] = md5sum
+            bin_file = self.op_type + '_' + md5sum
+            op_node['bin_filename'] = bin_file
+            param_file = os.path.join(self.out_path, bin_file + '_param.json')
+            param_file = os.path.realpath(param_file)
+            with os.fdopen(os.open(param_file, const_var.WFLAGS, const_var.WMODES), 'w') as fd:
+                json.dump(param, fd, indent='  ')
+            self._write_build_cmd(param_file, bin_file, i)
+
+
+    def _write_build_cmd(self: any, param_file: str, bin_file: str, index: int):
+        hard_soc = const_var.SOC_MAP_EXT.get(self.soc)
+        if not hard_soc:
+            # fall back to the capitalized raw soc string when no mapping exists
+            hard_soc = self.soc.capitalize()
+        name_com = [self.op_type, self.op_file, str(index)]
+        compile_file = os.path.join(self.out_path, '-'.join(name_com) + '.sh')
+        compile_file = os.path.realpath(compile_file)
+        with os.fdopen(os.open(compile_file, const_var.WFLAGS, const_var.WMODES), 'w') as fd:
+            fd.write('#!/bin/bash\n')
+            fd.write('echo "[{}] Generating {} ..."\n'.format(hard_soc, bin_file))
+            cmd = const_var.BIN_CMD.format(fun=self.op_intf, soc=hard_soc, param=param_file, impl='""')
+            fd.write(cmd)
+            chk = const_var.CHK_CMD.format(res_file=bin_file + '.json')
+            fd.write(chk)
+            chk = const_var.CHK_CMD.format(res_file=bin_file + '.o')
+            fd.write(chk)
+            fd.write('echo "[{}] Generating {} Done"\n'.format(hard_soc, bin_file))
+
+
+def gen_bin_param_file(cfgfile: str, out_dir: str, soc: str):
+    op_descs = opdesc_parser.get_op_desc(cfgfile, [], [], BinParamBuilder, None)
+    for op_desc in op_descs:
+        op_desc.set_soc_version(soc)
+        op_desc.set_out_path(out_dir)
+        op_desc.gen_input_json()
+
+
+if __name__ == '__main__':
+    if len(sys.argv) <= 3:
+        raise RuntimeError('arguments must be greater than 3')
+    gen_bin_param_file(sys.argv[1], sys.argv[2], sys.argv[3])
diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py
new file mode 100644
index 00000000..70a21db5
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py
@@ -0,0 +1,376 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
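+
+Generates the per-operator Python adapter (compile entry, optional replay,
+check_supported and generalization hooks) from parsed op descriptions.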
+""" + +import sys +import os +import stat +import opdesc_parser +import const_var + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + +IMPL_HEAD = ''' +import os, sys +import ctypes +import json +import shutil +from tbe.common.platform import get_soc_spec +from tbe.common.utils import para_check +from tbe.tikcpp import compile_op, replay_op, check_op_cap, generalize_op_params, get_code_channel, OpInfo +from impl.util.platform_adapter import tbe_register +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + +DTYPE_MAP = {"float32": ["DT_FLOAT", "float"], + "float16": ["DT_FLOAT16", "half"], + "int8": ["DT_INT8", "int8_t"], + "int16": ["DT_INT16", "int16_t"], + "int32": ["DT_INT32", "int32_t"], + "int64": ["DT_INT64", "int64_t"], + "uint1": ["DT_UINT1", "uint8_t"], + "uint8": ["DT_UINT8", "uint8_t"], + "uint16": ["DT_UINT16", "uint16_t"], + "uint32": ["DT_UINT32", "uint32_t"], + "uint64": ["DT_UINT64", "uint64_t"], + "bool": ["DT_BOOL", "bool"], + "double": ["DT_DOUBLE", "double"], + "dual": ["DT_DUAL", "unknown"], + "dual_sub_int8": ["DT_DUAL_SUB_INT8", "unknown"], + "dual_sub_uint8": ["DT_DUAL_SUB_UINT8", "unknown"], + "string": ["DT_STRING", "unknown"], + "complex64": ["DT_COMPLEX64", "unknown"], + "complex128": ["DT_COMPLEX128", "unknown"], + "qint8": ["DT_QINT8", "unknown"], + "qint16": ["DT_QINT16", "unknown"], + "qint32": ["DT_QINT32", "unknown"], + "quint8": ["DT_QUINT8", "unknown"], + "quint16": ["DT_QUINT16", "unknown"], + "resource": ["DT_RESOURCE", "unknown"], + "string_ref": ["DT_STRING_REF", "unknown"], + "int4": ["DT_INT4", "int8_t"], + "bfloat16": ["DT_BF16", "half"]} + +def get_dtype_fmt_options(inputs, outputs): + options = [] + for x in inputs + outputs: + x_n = x.get("param_name").upper() + x_fmt = x.get("format") + x_dtype = x.get("dtype") + options.append("-DDTYPE_{n}={t}".format(n=x_n, t=DTYPE_MAP.get(x_dtype)[1])) + options.append("-DORIG_DTYPE_{n}={ot}".format(n=x_n, ot=DTYPE_MAP.get(x_dtype)[0])) + options.append("-DFORMAT_{n}={f}".format(n=x_n, f=x_fmt)) + return options + +def load_dso(so_path): + try: + ctypes.CDLL(so_path) + except OSError as error : + print(error) + raise RuntimeError("cannot open %s" %(so_path)) + else: + print("load so succ ", so_path) + +''' + +IMPL_API = ''' +@para_check.check_op_params({}) +def {}({}, kernel_name="{}", impl_mode=""): + inputs, outputs, attrs = _build_args({}) + options = get_dtype_fmt_options(inputs, outputs) + options += ["-x", "cce"] + ccec = shutil.which("ccec") + if ccec != None: + ccec_path = os.path.dirname(ccec) + tikcpp_path = os.path.realpath(os.path.join(ccec_path, "..", "..", "tikcpp")) + else: + tikcpp_path = os.path.realpath("/usr/local/Ascend/latest/compiler/tikcpp") + options.append("-I" + tikcpp_path) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw")) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl")) + options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface")) + origin_func_name = "{}" + src = os.path.join(PYF_PATH, "{}") +''' + +REPLAY_OP_API = ''' + print("call into replay op") + soc_version = get_soc_spec("SOC_VERSION") + soc_short = get_soc_spec("SHORT_SOC_VERSION").lower() + tikreplay_codegen_path = tikcpp_path + "/tikreplaylib/lib" + tikreplay_stub_path = tikcpp_path + "/tikreplaylib/lib/" + soc_version + print("start load libtikreplaylib_codegen.so and libtikreplaylib_stub.so") + codegen_so_path = tikreplay_codegen_path + "/libtikreplaylib_codegen.so" + replaystub_so_path = tikreplay_stub_path + "/libtikreplaylib_stub.so" + if 
PYF_PATH.endswith("dynamic"):
+        op_replay_path = os.path.join(PYF_PATH, "..", "..", "op_replay")
+    else:
+        op_replay_path = os.path.join(PYF_PATH, "..", "op_replay")
+    replayapi_so_path = os.path.join(op_replay_path, "libreplay_{}_" + soc_short + ".so")
+    load_dso(codegen_so_path)
+    load_dso(replaystub_so_path)
+    load_dso(replayapi_so_path)
+    op_type = "{}"
+    entry_obj = os.path.join(op_replay_path, "{}_entry_" + soc_short + ".o")
+    code_channel = get_code_channel(src, kernel_name, op_type, options)
+    op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = inputs, outputs = outputs, attrs = attrs,\\
+        impl_mode = impl_mode)
+    res, msg = replay_op(op_info, entry_obj, code_channel, src, options)
+    if not res:
+        print("call replay op failed for %s, falling back to compile op" %(msg))
+        compile_op(src, origin_func_name, op_info, options, code_channel)
+'''

+COMPILE_OP_API = '''
+    print("call into compile op")
+    op_type = "{}"
+    code_channel = get_code_channel(src, kernel_name, op_type, options)
+    op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = inputs, outputs = outputs, attrs = attrs,\\
+        impl_mode = impl_mode)
+    compile_op(src, origin_func_name, op_info, options, code_channel)
+'''
+
+SUP_API = '''
+def {}({}, impl_mode=""):
+    inputs, outputs, attrs = _build_args({})
+    ret_str = check_op_cap("{}", "{}", inputs, outputs, attrs)
+    ret_dict = json.loads(ret_str)
+    err_code = ret_dict.get("ret_code")
+    sup = "Unknown"
+    reason = "Unknown reason"
+    if err_code is not None:
+        if err_code == 0:
+            sup = "True"
+            reason = ""
+        elif err_code == 1:
+            sup = "False"
+            reason = ret_dict.get("reason")
+        else:
+            sup = "Unknown"
+            reason = ret_dict.get("reason")
+    return sup, reason
+'''
+CAP_API = '''
+def {}({}, impl_mode=""):
+    inputs, outputs, attrs = _build_args({})
+    result = check_op_cap("{}", "{}", inputs, outputs, attrs)
+    return result.decode("utf-8")
+'''
+GLZ_API = '''
+@tbe_register.register_param_generalization("{}")
+def {}_generalization({}, generalize_config=None):
+    inputs, outputs, attrs = _build_args({})
+    ret_str = generalize_op_params("{}", inputs, outputs, attrs, generalize_config)
+    return [json.loads(ret_str)]
+'''
+
+ATTR_DEFAULT = {'bool': 'False', 'int': '0'}
+
+
+class AdpBuilder(opdesc_parser.OpDesc):
+    def __init__(self: any, op_type: str):
+        self.argsname = []
+        self.argsdefv = []
+        super().__init__(op_type)
+
+    def write_adapt(self: any, impl_path, path: str):
+        self._build_paradefault()
+        if impl_path != "":
+            src_file = os.path.join(impl_path, self.op_file + '.cpp')
+            if not os.path.exists(src_file):
+                return
+        out_path = os.path.abspath(path)
+        if self.dynamic_shape and not out_path.endswith('dynamic'):
+            out_path = os.path.join(path, 'dynamic')
+            os.makedirs(out_path, exist_ok=True)
+        adpfile = os.path.join(out_path, self.op_file + '.py')
+        with os.fdopen(os.open(adpfile, const_var.WFLAGS, const_var.WMODES), 'w') as fd:
+            self._write_head(fd)
+            self._write_argparse(fd)
+            self._write_impl(fd)
+            if self.op_chk_support:
+                self._write_cap('check_supported', fd)
+                self._write_cap('get_op_support_info', fd)
+            if self.op_fmt_sel:
+                self._write_cap('op_select_format', fd)
+                self._write_cap('get_op_specific_info', fd)
+            if self.op_range_limit == 'limited' or self.op_range_limit == 'dynamic':
+                self._write_glz(fd)
+
+    def _ip_argpack(self: any, default: bool = True) -> list:
+        args = []
+        for i in range(len(self.input_name)):
+            arg = self.input_name[i]
+            if default and self.argsdefv[i] is not None:
+                arg += '=' + self.argsdefv[i]
+            args.append(arg)
+        return args
+
+    def _op_argpack(self: any, default: bool = True) -> list:
+        args = []
+        argidx = len(self.input_name)
+        for i in range(len(self.output_name)):
+            arg = self.output_name[i]
+            if default and self.argsdefv[i + argidx] is not None:
+                arg += '=' + self.argsdefv[i + argidx]
+            args.append(arg)
+        return args
+
+    def _attr_argpack(self: any, default: bool = True) -> list:
+        args = []
+        argidx = len(self.input_name) + len(self.output_name)
+        for i in range(len(self.attr_list)):
+            att = self.attr_list[i]
+            arg = att
+            if default and self.argsdefv[i + argidx] is not None:
+                if self.attr_val.get(att).get('type') == 'str':
+                    arg += '="' + self.argsdefv[i + argidx] + '"'
+                elif self.attr_val.get(att).get('type') == 'bool':
+                    arg += '="' + self.argsdefv[i + argidx].capitalize() + '"'
+                else:
+                    arg += '=' + self.argsdefv[i + argidx]
+            args.append(arg)
+        return args
+
+    def _build_paralist(self: any, default: bool = True) -> str:
+        args = []
+        args.extend(self._ip_argpack(default))
+        args.extend(self._op_argpack(default))
+        args.extend(self._attr_argpack(default))
+        return ', '.join(args)
+
+    def _io_parachk(self: any, types: list, type_name: str) -> list:
+        chk = []
+        for iot in types:
+            if iot == 'optional':
+                ptype = 'OPTION'
+            else:
+                ptype = iot.upper()
+            chk.append('para_check.{}_{}'.format(ptype, type_name))
+        return chk
+
+    def _attr_parachk(self: any) -> list:
+        chk = []
+        for att in self.attr_list:
+            if self.attr_val.get(att).get('paramType') == 'optional':
+                pt = 'OPTION'
+            else:
+                pt = self.attr_val.get(att).get('paramType').upper()
+            att_type = self.attr_val.get(att).get('type').upper()
+            att_type = att_type.replace('LIST', 'LIST_')
+            chk.append('para_check.{}_ATTR_{}'.format(pt, att_type))
+        return chk
+
+    def _build_parachk(self: any) -> str:
+        chk = []
+        chk.extend(self._io_parachk(self.input_type, 'INPUT'))
+        chk.extend(self._io_parachk(self.output_type, 'OUTPUT'))
+        chk.extend(self._attr_parachk())
+        chk.append('para_check.KERNEL_NAME')
+        return ', '.join(chk)
+
+    def _build_paradefault(self: any):
+        optional = False
+        argtypes = []
+        argtypes.extend(self.input_type)
+        argtypes.extend(self.output_type)
+        for atype in argtypes:
+            if atype == 'optional':
+                optional = True
+            if optional:
+                self.argsdefv.append('None')
+            else:
+                self.argsdefv.append(None)
+        for attr in self.attr_list:
+            atype = self.attr_val.get(attr).get('paramType')
+            if atype == 'optional':
+                optional = True
+            attrval = self.attr_val.get(attr).get('defaultValue')
+            if attrval is not None:
+                optional = True
+                # compare the attr's declared type, not the builtin type()
+                if self.attr_val.get(attr).get('type') == "bool":
+                    attrval = attrval.capitalize()
+                elif self.attr_val.get(attr).get('type') == "str":
+                    attrval = "\"" + attrval + "\""
+                self.argsdefv.append(attrval)
+                continue
+            if optional:
+                self.argsdefv.append(ATTR_DEFAULT.get(self.attr_val.get(attr).get('type')))
+            else:
+                self.argsdefv.append(None)
+
+    def _write_head(self: any, fd: object):
+        fd.write(IMPL_HEAD)
+
+    def _write_argparse(self: any, fd: object):
+        args = self._build_paralist(False)
+        fd.write('def _build_args({}):\n'.format(args))
+        fd.write('    inputs = []\n')
+        fd.write('    for arg in [{}]:\n'.format(', '.join(self.input_name)))
+        fd.write('        if arg != None:\n')
+        fd.write('            inputs.append(arg)\n')
+        fd.write('    outputs = []\n')
+        fd.write('    for arg in [{}]:\n'.format(', '.join(self.output_name)))
+        fd.write('        if arg != None:\n')
+        fd.write('            outputs.append(arg)\n')
+        fd.write('    attrs = []\n')
+        for attr in self.attr_list:
+            fd.write('    if {} != None:\n'.format(attr))
+            fd.write('        attr = {}\n')
+            fd.write('        attr["name"] = "{}"\n'.format(attr))
+            fd.write('        attr["dtype"] = "{}"\n'.format(self.attr_val.get(attr).get('type')))
+            fd.write('        attr["value"] = {}\n'.format(attr))
+            fd.write('        attrs.append(attr)\n')
+        fd.write('    return inputs, outputs, attrs\n')
+
+    def _write_impl(self: any, fd: object):
+        argsdef = self._build_paralist()
+        argsval = self._build_paralist(False)
+        pchk = self._build_parachk()
+        if len(self.kern_name) > 0:
+            kern_name = self.kern_name
+        else:
+            kern_name = self.op_intf
+        src = self.op_file + '.cpp'
+        fd.write(IMPL_API.format(pchk, self.op_intf, argsdef, kern_name, argsval, self.op_intf, src))
+        if self.op_replay_flag:
+            fd.write(REPLAY_OP_API.format(self.op_file, self.op_type, self.op_file))
+        else:
+            fd.write(COMPILE_OP_API.format(self.op_type))
+
+    def _write_cap(self: any, cap_name: str, fd: object):
+        argsdef = self._build_paralist()
+        argsval = self._build_paralist(False)
+        if cap_name == 'check_supported':
+            fd.write(SUP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type))
+        else:
+            fd.write(CAP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type))
+
+    def _write_glz(self: any, fd: object):
+        argsdef = self._build_paralist()
+        argsval = self._build_paralist(False)
+        fd.write(GLZ_API.format(self.op_type, self.op_intf, argsdef, argsval, self.op_type))
+
+
+def write_scripts(cfgfile: str, cfgs: dict, dirs: dict, ops: list = None):
+    batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(';')
+    iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(';')
+    file_map = {}
+    op_descs = opdesc_parser.get_op_desc(cfgfile, batch_lists, iterator_lists, AdpBuilder, ops)
+    for op_desc in op_descs:
+        op_desc.write_adapt(dirs.get(const_var.CFG_IMPL_DIR), dirs.get(const_var.CFG_OUT_DIR))
+        file_map[op_desc.op_type] = op_desc.op_file
+    return file_map
+
+if __name__ == '__main__':
+    if len(sys.argv) <= 5:
+        raise RuntimeError('arguments must be greater than 5')
+    rep_cfg = {}
+    rep_cfg[const_var.REPLAY_BATCH] = sys.argv[2]
+    rep_cfg[const_var.REPLAY_ITERATE] = sys.argv[3]
+    cfg_dir = {}
+    cfg_dir[const_var.CFG_IMPL_DIR] = sys.argv[4]
+    cfg_dir[const_var.CFG_OUT_DIR] = sys.argv[5]
+    write_scripts(sys.argv[1], rep_cfg, cfg_dir)
diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py b/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py
new file mode 100644
index 00000000..2c881b67
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
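+
+Aggregates per-kernel JSON files into binary_info_config.json and into one
+<op>.json binList entry per operator directory.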
+"""
+
+import sys
+import os
+import glob
+import json
+import argparse
+import const_var
+
+
+def load_json(json_file: str):
+    with open(json_file, encoding='utf-8') as file:
+        json_content = json.load(file)
+    return json_content
+
+
+def get_specified_suffix_file(root_dir, suffix):
+    specified_suffix = os.path.join(root_dir, '**/*.{}'.format(suffix))
+    all_suffix_files = glob.glob(specified_suffix, recursive=True)
+    return all_suffix_files
+
+
+def add_simplified_config(op_type, key, objfile, config):
+    simple_cfg = config.get('binary_info_config.json')
+    op_cfg = simple_cfg.get(op_type)
+    if not op_cfg:
+        op_cfg = {}
+        op_cfg['dynamicRankSupport'] = True
+        op_cfg['simplifiedKeyMode'] = 0
+        op_cfg['binaryList'] = []
+        simple_cfg[op_type] = op_cfg
+    bin_list = op_cfg.get('binaryList')
+    bin_list.append({'simplifiedKey': key, 'binPath': objfile})
+
+
+def add_op_config(op_file, bin_info, config):
+    op_cfg = config.get(op_file)
+    if not op_cfg:
+        op_cfg = {}
+        op_cfg['binList'] = []
+        config[op_file] = op_cfg
+    op_cfg.get('binList').append(bin_info)
+
+
+def gen_ops_config(json_file, soc, config):
+    contents = load_json(json_file)
+    if ('binFileName' not in contents) or ('supportInfo' not in contents):
+        return
+    json_base_name = os.path.basename(json_file)
+    op_dir = os.path.basename(os.path.dirname(json_file))
+    support_info = contents.get('supportInfo')
+    bin_name = contents.get('binFileName')
+    bin_suffix = contents.get('binFileSuffix')
+    bin_file_name = bin_name + bin_suffix
+    op_type = bin_name.split('_')[0]
+    op_file = op_dir + '.json'
+    bin_info = {}
+    key = support_info.get('simplifiedKey')
+    if key:
+        bin_info['simplifiedKey'] = key
+        add_simplified_config(op_type, key, os.path.join(soc, op_dir, bin_file_name), config)
+    bin_info['staticKey'] = support_info.get('staticKey')
+    bin_info['int64Mode'] = support_info.get('int64Mode')
+    bin_info['inputs'] = support_info.get('inputs')
+    bin_info['outputs'] = support_info.get('outputs')
+    if support_info.get('attrs'):
+        bin_info['attrs'] = support_info.get('attrs')
+    bin_info['binInfo'] = {'jsonFilePath': os.path.join(soc, op_dir, json_base_name)}
+    add_op_config(op_file, bin_info, config)
+
+
+def gen_all_config(root_dir, soc):
+    suffix = 'json'
+    config = {}
+    config['binary_info_config.json'] = {}
+    all_json_files = get_specified_suffix_file(root_dir, suffix)
+    for _json in all_json_files:
+        gen_ops_config(_json, soc, config)
+    for cfg_key in config.keys():
+        cfg_file = os.path.join(root_dir, cfg_key)
+        with os.fdopen(os.open(cfg_file, const_var.WFLAGS, const_var.WMODES), 'w') as fd:
+            json.dump(config.get(cfg_key), fd, indent='  ')
+
+
+def args_parse():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-p',
+                        '--path',
+                        nargs='?',
+                        required=True,
+                        help='Parse the path of the json file.')
+    parser.add_argument('-s',
+                        '--soc',
+                        nargs='?',
+                        required=True,
+                        help='Parse the soc_version of ops.')
+    return parser.parse_args()
+
+
+def main():
+    args = args_parse()
+    gen_all_config(args.path, args.soc)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py
new file mode 100644
index 00000000..1cac7d91
--- /dev/null
+++ b/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+Created on Feb 28 20:56:45 2020
+Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
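+
+Drives replay C++ source generation for the ops named in the replay batch and
+iterate configuration lists.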
+""" + +import sys +import os +import opdesc_parser +import replay_codegen +import const_var +from replay_codegen import ReplayCodeGenParams + +PYF_PATH = os.path.dirname(os.path.realpath(__file__)) + + +class ReplayBuilder(opdesc_parser.OpDesc): + def __init__(self: any, op_type: str): + super().__init__(op_type) + + def gen_replay_source(self: any, impl_path: str, out_path: str, ops_product: str): + if not self.op_replay_flag: + print('{} replay not enabled'.format(self.op_type)) + return + argn = len(self.input_name) + len(self.output_name) + 1 + if self.op_replay_batch: + print('{} replay in batch mode'.format(self.op_type)) + else: + print('{} replay in normal mode'.format(self.op_type)) + if impl_path.endswith('op_kernel'): + implf = os.path.join(impl_path, self.op_file + '.cpp') + tiling_file = os.path.join(impl_path, "../op_host", self.op_file + '_tiling.h') + else: + if self.dynamic_shape: + dyn_path = 'dynamic' + else: + dyn_path = '' + implf = os.path.join(impl_path, dyn_path, self.op_file + '.cpp') + tiling_file = os.path.join(impl_path, "../../op_tiling", self.op_file + '_tiling.h') + rep_conf = replay_codegen.ReplayCodeGen(ReplayCodeGenParams(self.op_type, implf, tiling_file, self.op_file, \ + self.op_intf, argn, self.op_replay_batch, self.max_block_dim, self.max_shape_size)) + rep_conf.set_batch(self.op_replay_batch) + rep_conf.set_outdir(out_path) + rep_conf.gen_replay(ops_product) + + +def gen_replay(cfgfile: str, cfgs: dict, dirs: dict, ops_product: str, ops: list = None): + batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(';') + iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(';') + op_descs = opdesc_parser.get_op_desc(cfgfile, batch_lists, iterator_lists, ReplayBuilder, ops) + for op_desc in op_descs: + op_desc.gen_replay_source(dirs.get(const_var.CFG_IMPL_DIR), dirs.get(const_var.CFG_OUT_DIR), ops_product) + + +if __name__ == '__main__': + if len(sys.argv) <= 6: + raise RuntimeError('arguments must greater than 6') + rep_cfg = {} + rep_cfg[const_var.REPLAY_BATCH] = sys.argv[2] + rep_cfg[const_var.REPLAY_ITERATE] = sys.argv[3] + rep_dir = {} + rep_dir[const_var.CFG_IMPL_DIR] = sys.argv[4] + rep_dir[const_var.CFG_OUT_DIR] = sys.argv[5] + gen_replay(sys.argv[1], rep_cfg, rep_dir, sys.argv[6]) diff --git a/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py b/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py new file mode 100644 index 00000000..678756cb --- /dev/null +++ b/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# coding=utf-8 +""" +Function: +The replay funtion entry +Copyright Information: +Huawei Technologies Co., Ltd. 
All Rights Reserved © 2020 +""" + +import sys +import os +import stat +import re +import const_var + + +def gen_tiling(tiling_header_file: str, tiling_file_out: str): + if not os.path.exists(tiling_header_file): + print("warning: no userdef tiling header file: ", tiling_header_file) + return + print("generate tiling def header file: ", tiling_file_out) + tiling_source = '#include \n' + tiling_source += '#include \n\n' + end_source = "" + pattern = re.compile(r'[(](.*)[)]', re.S) + with open(tiling_header_file, 'r') as fd: + lines = fd.readlines() + for line in lines: + line = line.strip() + if (line.startswith('BEGIN_TILING_DATA_DEF')): + tiling_source += '#pragma pack(1)\n' + tiling_source += 'struct ' + struct_def = re.findall(pattern, line)[0] + tiling_source += struct_def + ' {\n' + elif (line.startswith('TILING_DATA_FIELD_DEF_ARR')): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(',') + tiling_source += ' {} {}[{}] = {{}};\n'.format(fds[0].strip(), fds[2].strip(), fds[1].strip()) + elif (line.startswith('TILING_DATA_FIELD_DEF_STRUCT')): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(',') + tiling_source += ' {} {};\n'.format(fds[0].strip(), fds[1].strip()) + elif (line.startswith('TILING_DATA_FIELD_DEF')): + field_params = re.findall(pattern, line)[0] + fds = field_params.split(',') + tiling_source += ' {} {} = 0;\n'.format(fds[0].strip(), fds[1].strip()) + elif (line.startswith('END_TILING_DATA_DEF')): + tiling_source += '};\n' + tiling_source += '#pragma pack()\n\n' + tiling_source += '#ifdef __NPU_TILING__\n' + tiling_source += \ + 'inline [aicore] void InitTilingData(const __gm__ uint8_t* tiling, {}* const_data)\n'\ + .format(struct_def) + tiling_source += '{\n' + tiling_source += '}\n' + tiling_source += '#else\n' + tiling_source += 'inline void InitTilingData(uint8_t* tiling, {}* const_data)\n'.format(struct_def) + tiling_source += '{\n' + tiling_source += ' memcpy(const_data, tiling, sizeof({}));\n'.format(struct_def) + tiling_source += '}\n' + tiling_source += '#endif\n\n' + end_source = ''' +#define GET_TILING_DATA(tiling_data, tiling_arg) \\ +{} tiling_data; \\ +InitTilingData(tiling_arg, &tiling_data)\n +'''.format(struct_def) + tiling_source += end_source + with os.fdopen(os.open(tiling_file_out, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: + ofd.write(tiling_source) + + +if __name__ == '__main__': + if len(sys.argv) <= 2: + raise RuntimeError('arguments must greater than 2') + gen_tiling(sys.argv[1], sys.argv[2]) diff --git a/cust_op/cust_op_by_addr/framework/CMakeLists.txt b/cust_op/cust_op_by_addr/framework/CMakeLists.txt new file mode 100644 index 00000000..b6be9b49 --- /dev/null +++ b/cust_op/cust_op_by_addr/framework/CMakeLists.txt @@ -0,0 +1,11 @@ +if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/mindspore") + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/caffe_plugin") + add_subdirectory(caffe_plugin) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/tf_plugin") + add_subdirectory(tf_plugin) + endif() + if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/onnx_plugin") + add_subdirectory(onnx_plugin) + endif() +endif() diff --git a/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt b/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt new file mode 100644 index 00000000..18b3f140 --- /dev/null +++ b/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt @@ -0,0 +1,8 @@ + +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} plugin_srcs) +add_library(cust_tf_parsers SHARED ${plugin_srcs}) 
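+# Assumption (comment only, behaviour unchanged): as in the stock CANN plugin
+# template, the define below appears to remap the protobuf 'google' namespace
+# to 'ascend_private' so the parser links against CANN's bundled protobuf
+# instead of clashing with TensorFlow's own copy.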
+target_compile_definitions(cust_tf_parsers PRIVATE google=ascend_private) +target_link_libraries(cust_tf_parsers PRIVATE intf_pub graph) +install(TARGETS cust_tf_parsers + LIBRARY DESTINATION packages/vendors/${vendor_name}/framework/tensorflow +) diff --git a/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc b/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc new file mode 100644 index 00000000..b8a49483 --- /dev/null +++ b/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc @@ -0,0 +1,13 @@ +#include "register/register.h" + +namespace domi { +// register op info to GE +REGISTER_CUSTOM_OP("EmbeddingLookupByAddress") +.FrameworkType(TENSORFLOW) // type: CAFFE, TENSORFLOW +.OriginOpType("EmbeddingLookupByAddress") // name in tf module +.ParseParamsByOperatorFn(AutoMappingByOpFn); +REGISTER_CUSTOM_OP("EmbeddingUpdateByAddress") +.FrameworkType(TENSORFLOW) // type: CAFFE, TENSORFLOW +.OriginOpType("EmbeddingUpdateByAddress") // name in tf module +.ParseParamsByOperatorFn(AutoMappingByOpFn); +} // namespace domi \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_host/CMakeLists.txt b/cust_op/cust_op_by_addr/op_host/CMakeLists.txt new file mode 100644 index 00000000..005b7d01 --- /dev/null +++ b/cust_op/cust_op_by_addr/op_host/CMakeLists.txt @@ -0,0 +1,35 @@ + +aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs) + +opbuild(OPS_SRC ${ops_srcs} + OUT_DIR ${ASCEND_AUTOGEN_PATH} +) + +add_library(cust_op_proto SHARED ${ops_srcs} ${ASCEND_AUTOGEN_PATH}/op_proto.cc) +target_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB) +target_link_libraries(cust_op_proto PRIVATE intf_pub exe_graph register) +set_target_properties(cust_op_proto PROPERTIES OUTPUT_NAME + cust_opsproto_rt2.0 +) + +add_library(cust_optiling SHARED ${ops_srcs}) +target_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB) +target_link_libraries(cust_optiling PRIVATE intf_pub exe_graph register) +set_target_properties(cust_optiling PROPERTIES OUTPUT_NAME + cust_opmaster_rt2.0 +) + +add_custom_target(optiling_compat ALL + COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$ + ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so +) + +install(TARGETS cust_op_proto + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) +install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h + DESTINATION packages/vendors/${vendor_name}/op_proto/inc) +install(TARGETS cust_optiling + LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR}) +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling) + diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp new file mode 100644 index 00000000..6e3ea142 --- /dev/null +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -0,0 +1,146 @@ + +#include "embedding_lookup_by_address_tiling.h" +#include "register/op_def_registry.h" + +namespace optiling +{ + struct TilingCompileInfo + { + int64_t ub_size; + }; + + static ge::graphStatus TilingFunc(gert::TilingContext *context) + { + TilingData1 tiling; + int32_t block_total_nums = 48; + int32_t ub_limit = 160 * 1024; + auto *attrs = context->GetAttrs(); + const auto *attr0_value = attrs->GetAttrPointer(0); + int32_t embbeding_dim = *attr0_value; 
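+        // Attr indices follow the declaration order in the OpDef below:
+        // index 0 is "embedding_dim", index 1 is "embedding_type".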
+ const auto *attr1_value = attrs->GetAttrPointer(1); + int32_t embbeding_type = *attr1_value; + int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); + + tiling.set_embbeding_type(embbeding_type); + tiling.set_update_dim(embbeding_dim); + tiling.set_addr_nums(input_shape); + tiling.set_ub_limit(ub_limit); + + context->SetBlockDim(block_total_nums); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; + } + + static ge::graphStatus TilingPrepare(gert::TilingParseContext *context) + { + return ge::GRAPH_SUCCESS; + } + + static int check_op_support(const ge::Operator &op, ge::AscendString &result) + { + std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; + result = ge::AscendString(res_json_str.c_str()); + return 1; + } +} + +namespace ge +{ + ge::graphStatus InferShape1(gert::InferShapeContext *context) + { + + gert::Shape *y_shape = context->GetOutputShape(0); + auto *attrs = context->GetAttrs(); + const auto *attr0_value = attrs->GetAttrPointer(0); + int64_t update_dim = *attr0_value; + int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); + y_shape->SetDimNum(2); + y_shape->SetDim(0, input_shape); + y_shape->SetDim(1, update_dim); + return GRAPH_SUCCESS; + } + ge::graphStatus InferShapeRange1(gert::InferShapeRangeContext *context) + { + return GRAPH_SUCCESS; + } + ge::graphStatus InferDataType1(gert::InferDataTypeContext *context) + { + + int64_t embbeding_type; + auto *attrs = context->GetAttrs(); + const auto *attr1_value = attrs->GetAttrPointer(1); + if (attr1_value == nullptr) + { + printf(" Lookup embbeding_type nullptr\n"); + } + else + { + embbeding_type = *attr1_value; + } + if (embbeding_type == 0) + { + context->SetOutputDataType(0, ge::DataType(DT_INT32)); + } + else if (embbeding_type == 1) + { + context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); + } + else if (embbeding_type == 2) + { + + context->SetOutputDataType(0, ge::DataType(DT_FLOAT16)); + } + else + { + context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); + } + + return GRAPH_SUCCESS; + } +} + +namespace ops +{ + class EmbeddingLookupByAddress : public OpDef + { + public: + EmbeddingLookupByAddress(const char *name) : OpDef(name) + { + this->Input("address") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("y") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("embedding_dim").AttrType(OPTIONAL).Int(32); + this->Attr("embedding_type").AttrType(OPTIONAL).Int(1); + this->SetInferShape(ge::InferShape1) + .SetInferDataType(ge::InferDataType1); + + this->AICore() + .SetTiling(optiling::TilingFunc) + .SetTilingParse(optiling::TilingPrepare) + .SetCheckSupport(optiling::check_op_support); + + OpAICoreConfig aicConfig; + aicConfig.AsyncFlag(true) + .DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true) + .NeedCheckSupportFlag(false) + .PrecisionReduceFlag(false) + .RangeLimitValue("limited"); + this->AICore().AddConfig("ascend910b", aicConfig); + this->AICore().AddConfig("ascend910", aicConfig); + } + }; + + 
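+    // OP_ADD registers the prototype together with its compile-info struct;
+    // the tiling-data class itself is bound separately through
+    // REGISTER_TILING_DATA_CLASS in the matching _tiling.h header.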
OP_ADD(EmbeddingLookupByAddress, optiling::TilingCompileInfo); +} diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h new file mode 100644 index 00000000..12c45086 --- /dev/null +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h @@ -0,0 +1,14 @@ +#include "register/tilingdata_base.h" + +namespace optiling +{ + BEGIN_TILING_DATA_DEF(TilingData1) + TILING_DATA_FIELD_DEF(int32_t, update_dim); + TILING_DATA_FIELD_DEF(int32_t, addr_nums); + TILING_DATA_FIELD_DEF(int32_t, ub_limit); + TILING_DATA_FIELD_DEF(int32_t, embbeding_type); + TILING_DATA_FIELD_DEF(int32_t, update_type); + END_TILING_DATA_DEF; + + REGISTER_TILING_DATA_CLASS(EmbeddingLookupByAddress, TilingData1) +} \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp new file mode 100644 index 00000000..d27b2ac2 --- /dev/null +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -0,0 +1,137 @@ + +#include "embedding_update_by_address_tiling.h" +#include "register/op_def_registry.h" + +namespace optiling +{ + struct TilingCompileInfo + { + int64_t ub_size; + }; + + static ge::graphStatus TilingFunc(gert::TilingContext *context) + { + TilingData2 tiling; + + int32_t block_total_nums = 48; + int32_t ub_limit = 160 * 1024; + + int32_t update_dim; + int32_t embbeding_type; + + int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); + int32_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; + int32_t update_type=*(context->GetAttrs()->GetAttrPointer(0)); + ge::DataType input_datatype = context->GetInputTensor(1)->GetDataType(); + + switch (input_datatype) + { + case ge::DT_FLOAT16: + embbeding_type = 2; + break; + case ge::DT_FLOAT: + embbeding_type = 1; + break; + case ge::DT_INT32: + embbeding_type = 0; + break; + default: + embbeding_type = 1; + break; + } + + update_dim = input_dim; + tiling.set_update_type(update_type); + tiling.set_embbeding_type(embbeding_type); + tiling.set_update_dim(update_dim); + tiling.set_addr_nums(input_shape); + tiling.set_ub_limit(ub_limit); + context->SetBlockDim(block_total_nums); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; + } + + static ge::graphStatus TilingPrepare(gert::TilingParseContext *context) + { + return ge::GRAPH_SUCCESS; + } + + static int check_op_support(const ge::Operator &op, ge::AscendString &result) + { + std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; + result = ge::AscendString(res_json_str.c_str()); + return 1; + } +} + +namespace ge +{ + ge::graphStatus InferShape(gert::InferShapeContext *context) + { + gert::Shape *y_shape = context->GetOutputShape(0); + int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); + int64_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; + y_shape->SetDimNum(2); + y_shape->SetDim(0, input_shape); + y_shape->SetDim(1, input_dim); + return GRAPH_SUCCESS; + } + ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context) + { + return GRAPH_SUCCESS; + } + ge::graphStatus InferDataType(gert::InferDataTypeContext *context) + { + context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); + return GRAPH_SUCCESS; + } +} + +namespace ops +{ + 
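+    // Unlike the lookup op, the update op takes no embedding_type attribute:
+    // TilingFunc above derives embbeding_type from the dtype of the
+    // "embedding" input instead.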
+    class EmbeddingUpdateByAddress : public OpDef
+    {
+    public:
+        EmbeddingUpdateByAddress(const char *name) : OpDef(name)
+        {
+            this->Input("address")
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Input("embedding")
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Output("y")
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Attr("update_type").AttrType(OPTIONAL).Int(0);
+            this->SetInferShape(ge::InferShape)
+                .SetInferDataType(ge::InferDataType);
+
+            this->AICore()
+                .SetTiling(optiling::TilingFunc)
+                .SetTilingParse(optiling::TilingPrepare)
+                .SetCheckSupport(optiling::check_op_support);
+
+            OpAICoreConfig aicConfig;
+            aicConfig.AsyncFlag(true)
+                .DynamicCompileStaticFlag(true)
+                .DynamicFormatFlag(true)
+                .DynamicRankSupportFlag(true)
+                .DynamicShapeSupportFlag(true)
+                .NeedCheckSupportFlag(false)
+                .PrecisionReduceFlag(false)
+                .RangeLimitValue("limited");
+            this->AICore().AddConfig("ascend910b", aicConfig);
+            this->AICore().AddConfig("ascend910", aicConfig);
+        }
+    };
+    OP_ADD(EmbeddingUpdateByAddress, optiling::TilingCompileInfo);
+}
diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h
new file mode 100644
index 00000000..2a28626c
--- /dev/null
+++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h
@@ -0,0 +1,15 @@
+
+#include "register/tilingdata_base.h"
+
+namespace optiling
+{
+    BEGIN_TILING_DATA_DEF(TilingData2)
+    TILING_DATA_FIELD_DEF(int32_t, update_dim);
+    TILING_DATA_FIELD_DEF(int32_t, addr_nums);
+    TILING_DATA_FIELD_DEF(int32_t, ub_limit);
+    TILING_DATA_FIELD_DEF(int32_t, embbeding_type);
+    TILING_DATA_FIELD_DEF(int32_t, update_type);
+    END_TILING_DATA_DEF;
+
+    REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2)
+}
\ No newline at end of file
diff --git a/cust_op/cust_op_by_addr/op_host/readme.md b/cust_op/cust_op_by_addr/op_host/readme.md
new file mode 100644
index 00000000..2f9b4fd9
--- /dev/null
+++ b/cust_op/cust_op_by_addr/op_host/readme.md
@@ -0,0 +1,218 @@
+# tik2 op_host
+
+
+### Directory structure
+
+```
+ op_host               // op prototype & tiling
+   embedding_update_by_address_tiling.h
+   embedding_update_by_address.cpp
+   CMakeLists.txt
+```
+
+
+### Writing the tiling.h file
+
+embedding_update_by_address_tiling.h
+
+Define one tiling data struct here.
+
+Its fields are the usual way to pass shape and attr values from the host into the kernel.
+```
+#include "register/tilingdata_base.h"
+
+namespace optiling
+{
+    BEGIN_TILING_DATA_DEF(TilingData2)
+    TILING_DATA_FIELD_DEF(int32_t, update_dim);
+    TILING_DATA_FIELD_DEF(int32_t, addr_nums);
+    TILING_DATA_FIELD_DEF(int32_t, ub_limit);
+    TILING_DATA_FIELD_DEF(int32_t, embbeding_type);
+    TILING_DATA_FIELD_DEF(int32_t, update_type);
+    END_TILING_DATA_DEF;
+
+    REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2)
+}
+```
+
+### Writing the cpp file
+
+##### Overall structure
+
+```
+static ge::graphStatus TilingFunc(gert::TilingContext *context)         // assigns the tiling fields
+
+ge::graphStatus InferShape(gert::InferShapeContext *context)            // infers the output shape
+
+ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context)  // infers the output shape range
+
+ge::graphStatus InferDataType(gert::InferDataTypeContext *context)      // infers the output dtype
+
+class EmbeddingUpdateByAddress : public OpDef                           // operator registration info
+```
+
+##### Common accessors
+
+```
+// get shapes
+gert::Shape *x_shape = context->GetInputShape(0);
+gert::Shape *y_shape = context->GetInputShape(1);
+gert::Shape *z_shape = context->GetOutputShape(0);
+
+std::vector x_dims = x_shape->GetDims();       // dims of the shape, e.g. x_dims = {232,123,2}
+size_t x_dims_len = x_shape->GetDimNum();      // number of dims, x_dims_len = 3
+int64_t x_dim2 = x_shape->GetDim(1);           // the second dim, x_dim2 = 123
+int64_t x_shapesize = x_shape->GetShapeSize(); // total element count, x_shapesize = 232*123*2
+
+// set the output shape
+*z_shape = Shape(x_dims);   // z_shape becomes {232,123,2}
+*z_shape = *x_shape;        // z_shape is identical to x_shape
+
+
+// get / set data types
+DataType input2_dtype = context->GetInputDataType(1);  // DataType of the second input
+context->SetOutputDataType(0, input2_dtype);
+context->SetOutputDataType(1, ge::DataType(DT_FLOAT));
+
+
+// get attrs
+auto *attrs = context->GetAttrs();
+const auto attr0_value = *(attrs->GetAttrPointer(0)); // attr0_value = value of the op's 1st attribute
+const auto attr1_value = *(attrs->GetAttrPointer(1)); // attr1_value = value of the op's 2nd attribute
+```
+
+
+
+
+##### Writing TilingFunc
+
+```
+TilingData1 tiling;
+
+int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); // size of the input
+tiling.set_addr_nums(input_shape); // sets the tiling field addr_nums: addr_nums = input_shape
+tiling.set_xxx(input_shape);       // xxx must be a field declared in the tiling header
+```
+
+#### Writing the inference functions
+
+```
+namespace ge
+{
+    ge::graphStatus InferShape(gert::InferShapeContext *context)
+    {
+        gert::Shape *y_shape = context->GetOutputShape(0);
+        int64_t input_shape = context->GetInputTensor(0)->GetShapeSize();
+        int64_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape;
+        y_shape->SetDimNum(2);
+        y_shape->SetDim(0, input_shape);
+        y_shape->SetDim(1, input_dim);
+        return GRAPH_SUCCESS;
+    }
+    ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context)
+    {
+        return GRAPH_SUCCESS;
+    }
+    ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
+    {
+        context->SetOutputDataType(0, ge::DataType(DT_FLOAT));
+        return GRAPH_SUCCESS;
+    }
+}
+```
+
+#### Writing the registration info
+
+```
+namespace ops
+{
+    class EmbeddingUpdateByAddress : public OpDef
+    {
+    public:
+        EmbeddingUpdateByAddress(const char *name) : OpDef(name)
+        {
+            this->Input("address")            // 1st input of the op
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Input("embedding")          // 2nd input of the op
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Output("y")                 // 1st output of the op
+                .ParamType(REQUIRED)
+                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
+                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
+                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
+            this->Attr("update_type").AttrType(OPTIONAL).Int(0); // 1st attribute, default value 0
+
+            this->SetInferShape(ge::InferShape)   // bind the inference functions
+                .SetInferDataType(ge::InferDataType);
+            //.SetInferShapeRange(ge::InferShapeRange); // not used here, may be omitted
+
+            this->AICore()
+                .SetTiling(optiling::TilingFunc)
+                .SetTilingParse(optiling::TilingPrepare)
+                .SetCheckSupport(optiling::check_op_support);
+
+            OpAICoreConfig aicConfig;
+            aicConfig.AsyncFlag(true)
+                .DynamicCompileStaticFlag(true)
+                .DynamicFormatFlag(true)
+                .DynamicRankSupportFlag(true)
+                .DynamicShapeSupportFlag(true)
+                .NeedCheckSupportFlag(false)
+                .PrecisionReduceFlag(false)   // otherwise precision may be lowered (fp32 -> fp16)
+                .RangeLimitValue("limited");
+            this->AICore().AddConfig("ascend910b", aicConfig); // supports Ascend 910B
+            this->AICore().AddConfig("ascend910", aicConfig);  // supports Ascend 910
+        }
+    };
+    OP_ADD(EmbeddingUpdateByAddress, optiling::TilingCompileInfo);
+}
+
+```
+
+
+
+### Tips
+
+#### DataType reference
+```
+ enum DataType {
+    DT_FLOAT = 0,            // float type
+    DT_FLOAT16 = 1,          // fp16 type
+    DT_INT8 = 2,             // int8 type
+    DT_INT16 = 6,            // int16 type
+    DT_UINT16 = 7,           // uint16 type
+    DT_UINT8 = 4,            // uint8 type
+    DT_INT32 = 3,            // int32 type
+    DT_INT64 = 9,            // int64 type
+    DT_UINT32 = 8,           // unsigned int32
+    DT_UINT64 = 10,          // unsigned int64
+    DT_BOOL = 12,            // bool type
+    DT_DOUBLE = 11,          // double type
+    DT_STRING = 13,          // string type
+    DT_DUAL_SUB_INT8 = 14,   /**< dual output int8 type */
+    DT_DUAL_SUB_UINT8 = 15,  /**< dual output uint8 type */
+    DT_COMPLEX64 = 16,       // complex64 type
+    DT_COMPLEX128 = 17,      // complex128 type
+    DT_QINT8 = 18,           // qint8 type
+    DT_QINT16 = 19,          // qint16 type
+    DT_QINT32 = 20,          // qint32 type
+    DT_QUINT8 = 21,          // quint8 type
+    DT_QUINT16 = 22,         // quint16 type
+    DT_RESOURCE = 23,        // resource type
+    DT_STRING_REF = 24,      // string ref type
+    DT_DUAL = 25,            // dual output type
+    DT_VARIANT = 26,         // dt_variant type
+    DT_BF16 = 27,            // bf16 type
+    DT_UNDEFINED = 28,       // Used to indicate a DataType field has not been set.
+    DT_INT4 = 29,            // int4 type
+    DT_MAX                   // Mark the boundaries of data types
+ };
+
+```
+
diff --git a/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt b/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt
new file mode 100644
index 00000000..6fe828d5
--- /dev/null
+++ b/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt
@@ -0,0 +1,80 @@
+
+set(tikreplaylib_DIR ${ASCEND_TENSOR_COMPILER_PATH}/tikcpp/tikreplaylib/lib/cmake)
+find_package(tikreplaylib REQUIRED)
+message(STATUS "PACKAGE tikreplaylib FOUND")
+
+# replay config
+set(BATCH_MODE_REPLAY_LIST )
+set(ITERATOR_MODE_REPLAY_LIST )
+
+foreach(compute_unit ${ASCEND_COMPUTE_UNIT})
+
+    # generate aic-${compute_unit}-ops-info.json
+    add_ops_info_target(TARGET ops_info_gen_${compute_unit}
+        OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core/${compute_unit}/aic-${compute_unit}-ops-info.json
+        OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
+        INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/config/${compute_unit}
+    )
+
+    # generate tik2 impl py once
+    if (NOT TARGET tik2_impl_gen)
+        add_ops_impl_target(TARGET tik2_impl_gen
+            OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
+            OPS_BATCH ${BATCH_MODE_REPLAY_LIST}
+            OPS_ITERATE ${ITERATOR_MODE_REPLAY_LIST}
+            IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
+            OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe
+            INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl
+        )
+    endif()
+
+    # dynamic shape binary compile
+    if (${ENABLE_BINARY_PACKAGE})
+        add_bin_compile_target(TARGET tik2_bin_${compute_unit}
+            OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
+            IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
+            ADP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/dynamic
+            OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit}
+            INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/kernel
+            COMPUTE_UNIT ${compute_unit}
+        )
+        add_dependencies(tik2_bin_${compute_unit} tik2_impl_gen)
+    endif()
+
+    # generate replay _impl.cpp, _replay.cpp and _entry.cce
+    add_ops_replay_targets(OPS_BATCH ${BATCH_MODE_REPLAY_LIST}
+        OPS_ITERATE ${ITERATOR_MODE_REPLAY_LIST}
+        OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
+        IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
+        OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/replay/${compute_unit}
+        INSTALL_DIR
packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay + COMPUTE_UNIT ${compute_unit} + ) +endforeach() + +# generate npu_supported_ops.json +add_npu_support_target(TARGET npu_supported_ops + OPS_INFO_DIR ${ASCEND_AUTOGEN_PATH} + OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core + INSTALL_DIR packages/vendors/${vendor_name}/framework/${ASCEND_FRAMEWORK_TYPE} +) + +if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases) + add_subdirectory(testcases) +endif() + +# install kernel file +if (${ENABLE_SOURCE_PACKAGE}) + file(GLOB KERNEL_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp + ) + install(FILES ${KERNEL_FILES} + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl/dynamic + ) + file(GLOB KERNEL_FILES + ${CMAKE_CURRENT_SOURCE_DIR}/*.py + ) + install(FILES ${KERNEL_FILES} + DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl/dynamic + ) +endif() diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp new file mode 100644 index 00000000..58fed7bb --- /dev/null +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -0,0 +1,235 @@ + +#include "kernel_operator.h" +using namespace tik2; +template +class KernelEimtable +{ +public: + __aicore__ inline KernelEimtable() + { + } + __aicore__ inline void Init(GM_ADDR address, GM_ADDR y) + { + + NeedComputeAddrLen = SingleCoreAddrLen; + if (block_idx == block_num - 1) + { + NeedComputeAddrLen = addr_nums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + } + round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); + // pipe alloc memory to queue, the unit is Bytes + pipe.InitBuffer(tbuf, roundSize * sizeof(int64_t)); + + pipe.InitBuffer(inQueue, PingpongNum, Veclen); + pipe.InitBuffer(outQueue, PingpongNum, Veclen); // + + // get start index for current core, core parallel block_indx block_dim + srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * SingleCoreAddrLen)); + dstDataGm.SetGlobalBuffer((__gm__ T *)(y)); + } + + __aicore__ inline void Init_param(GM_ADDR tiling) + { + GET_TILING_DATA(constData, tiling); + // TODO: user kernel impl + int32_t update_dim = constData.update_dim; + int32_t embbeding_type = constData.embbeding_type; + int32_t block_total_nums = block_num; + int32_t ub_limit = constData.ub_limit; + addr_nums = constData.addr_nums; + if (embbeding_type == 2) + { + single_data_size = 2; + } + else + { + single_data_size = 4; + } + PingpongNum = 1; + int min_move_num = 32 / single_data_size; + once_move_nums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); + + int addr_max_num = ((int)((int)(ub_limit / (sizeof(int64_t) + single_data_size * (once_move_nums * ((int32_t)(update_dim - 1 + once_move_nums) / once_move_nums)) * PingpongNum * 2)) / 4)) * 4; + int singlenum = (int)(addr_nums / block_total_nums); + if (singlenum % 4) + { + singlenum -= singlenum % 4; + } + roundSize = addr_max_num; // addr_max_num; + Veclen = roundSize * single_data_size * once_move_nums; + SingleCoreAddrLen = singlenum * sizeof(int64_t); + cache = roundSize; + dim = update_dim; + } + + __aicore__ inline void Process() + { + + LocalTensor srcAddrLocal = tbuf.Get(roundSize); + + if (round > 0) + { + for (int32_t i = 0; i < round; i++) + { + DataCopy(srcAddrLocal, srcAddrGlobal[i * roundSize], roundSize); + MoveProcess(srcAddrLocal, i, roundSize); + } + } + + int unprocess = (NeedComputeAddrLen / sizeof(int64_t)) % 
roundSize; + if (unprocess) + { + // 处理 addresslist 不对齐32b + int unprocess_once_copyaddr = unprocess; + if (unprocess_once_copyaddr % 4 != 0) + { + unprocess_once_copyaddr += (4 - unprocess % 4); + } + + DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocess_once_copyaddr); + MoveProcess(srcAddrLocal, round, unprocess); + } + } + +private: + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int sizes) + { + set_flag(PIPE_MTE2, PIPE_S, 0); + wait_flag(PIPE_MTE2, PIPE_S, 0); + LocalTensor dataLocal; + bool isFull = true; + int nums = 0; + int out_index = 0; + int times = once_move_nums / 8; + int tmp_cache = cache - 1; + + for (int i = 0; i < sizes; i++) + { + + dataLocal = isFull ? inQueue.AllocTensor() : dataLocal; + int64_t address = srcAddrLocal.GetValue(i); + + if (address != 0) + { + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address)); + DataCopy(dataLocal[once_move_nums * nums], srcDataBufferGm, once_move_nums); + } + else + { + + for (int j = 0; j < times; j++) + { + Duplicate(dataLocal[once_move_nums * nums + j * 8], (T)0, 8); + } + + } + + nums++; + isFull = ( i == tmp_cache || i == sizes - 1); + if (isFull) + { + inQueue.EnQue(dataLocal); + Compute(nums); + CopyOut(out_index, turns, nums); + nums = 0; + out_index = i + 1; + tmp_cache += cache; + } + } + } + + __aicore__ inline void Compute(const int nums) + { + // deque input tensors from VECIN queue + LocalTensor srcLocal = inQueue.DeQue(); + LocalTensor dstLocal = outQueue.AllocTensor(); + + DataCopyParams copyparams; + copyparams.blockCount = 1; + copyparams.blockLen = once_move_nums * sizeof(T) * nums / 32; + DataCopy(dstLocal, srcLocal, copyparams); + + outQueue.EnQue(dstLocal); + inQueue.FreeTensor(srcLocal); + } + + __aicore__ inline void CopyOut(const int index, const int turns, const int nums) + { + LocalTensor dstLocal = outQueue.DeQue(); + + int offset = block_idx * dim * SingleCoreAddrLen / sizeof(int64_t) + (turns * roundSize * dim) + dim * index; +#if defined(__DAV_C220_VEC__) + if (single_data_size == 4) + { + copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, + nums, dim * sizeof(T), 0, 0, 0, 0); + } + else if (single_data_size == 2) + { + copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, + nums, dim * sizeof(T), 0, 0, 0, 0); + } +#else + + DataCopy(dstDataGm[offset], dstLocal, once_move_nums * nums); +#endif + outQueue.FreeTensor(dstLocal); + } + +public: + int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, cache, Veclen, dim, PingpongNum; + int32_t addr_nums; + int32_t once_move_nums, single_data_size, update_type; + +private: + TPipe pipe; + TBuf tbuf; + TQue inQueue; + TQue outQueue; + GlobalTensor srcDataBufferGm, dstDataGm, outDataGm; + GlobalTensor srcAddrGlobal; +}; + +extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR address, GM_ADDR y, GM_ADDR tiling) +{ + GET_TILING_DATA(constData, tiling); + // // TODO: user kernel impl + + int32_t embbeding_type = constData.embbeding_type; + + switch (embbeding_type) + { + case 0: + { + KernelEimtable op; + op.Init_param(tiling); + op.Init(address, y); + op.Process(); + } + break; + case 1: + { + KernelEimtable op; + op.Init_param(tiling); + op.Init(address, y); + op.Process(); + } + break; + case 2: + { + KernelEimtable op; + op.Init_param(tiling); + op.Init(address, y); + op.Process(); + } + break; + default: + { + KernelEimtable op; + op.Init_param(tiling); + 
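+        // Init_param ran first on purpose: Init() below sizes its queues and
+        // GM views from the tiling-derived fields (roundSize, Veclen,
+        // SingleCoreAddrLen).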
op.Init(address, y); + op.Process(); + } + break; + } +} diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp new file mode 100644 index 00000000..b5a1e976 --- /dev/null +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -0,0 +1,246 @@ +#include "kernel_operator.h" +using namespace tik2; +template +class KernelEimtable_update +{ +public: + __aicore__ inline KernelEimtable_update() + { + } + __aicore__ inline void Init(GM_ADDR address, GM_ADDR embedding, GM_ADDR y) + { + NeedComputeAddrLen = SingleCoreAddrLen; + if (block_idx == block_num - 1) + { + NeedComputeAddrLen = addr_nums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + } + round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); + + pipe.InitBuffer(tbuf, roundSize * sizeof(int64_t)); + pipe.InitBuffer(inQueue, PingpongNum, Veclen); + pipe.InitBuffer(outQueue, PingpongNum, Veclen); + // get start index for current core, core parallel block_indx block_dim + srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * SingleCoreAddrLen)); + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(embedding + block_idx * SingleCoreAddrLen / sizeof(int64_t) * sizeof(T) * dim)); + outDataGm.SetGlobalBuffer((__gm__ T *)(y)); + } + + __aicore__ inline void Init_param(GM_ADDR tiling) + { + GET_TILING_DATA(constData, tiling); + // TODO: user kernel impl + int32_t update_dim = constData.update_dim; + int32_t embbeding_type = constData.embbeding_type; + int32_t block_total_nums = block_num; + int32_t ub_limit = constData.ub_limit; + update_type = constData.update_type; + addr_nums = constData.addr_nums; + if (embbeding_type == 2) + { + single_data_size = 2; + } + else + { + single_data_size = 4; + } + PingpongNum = 1; + int min_move_num = 32 / single_data_size; + once_move_nums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); + + int addr_max_num = ((int)((int)(ub_limit / (sizeof(int64_t) + single_data_size * (once_move_nums * ((int32_t)(update_dim - 1 + once_move_nums) / once_move_nums)) * PingpongNum * 2)) / 4)) * 4; + int singlenum = (int)(addr_nums / block_total_nums); + if (singlenum % 4) + { + singlenum -= singlenum % 4; + } + roundSize = addr_max_num; // addr_max_num; + Veclen = roundSize * single_data_size * once_move_nums; + SingleCoreAddrLen = singlenum * sizeof(int64_t); + cache = roundSize; + dim = update_dim; + } + + __aicore__ inline void Process() + { + + LocalTensor srcAddrLocal = tbuf.Get(roundSize); + + int unprocess = (NeedComputeAddrLen / sizeof(int64_t)) % roundSize; + + if (round > 0) + { + for (int32_t i = 0; i < round; i++) + { + DataCopy(srcAddrLocal, srcAddrGlobal[i * roundSize], roundSize); + MoveProcess(srcAddrLocal, i, roundSize); + } + } + + if (unprocess) + { + int unprocess_once_copyaddr = unprocess; + if (unprocess_once_copyaddr % 4 != 0) + { + unprocess_once_copyaddr += (4 - unprocess % 4); + } + + DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocess_once_copyaddr); + MoveProcess(srcAddrLocal, round, unprocess); + } + } + +private: + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int sizes) + { + set_flag(PIPE_MTE2, PIPE_S, 0); + wait_flag(PIPE_MTE2, PIPE_S, 0); + LocalTensor dataLocal; + int out_index = 0; + int offset = 0; + int64_t address = 0; + if (dim == once_move_nums) + { + dataLocal = inQueue.AllocTensor(); + DataCopy(dataLocal, srcDataBufferGm[turns * roundSize], sizes * once_move_nums); + 
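+                // Fast path: the row width already matches the 32-byte-aligned
+                // move size (dim == once_move_nums), so the DataCopy above staged
+                // every row of this round in a single transfer.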
+                inQueue.EnQue(dataLocal);
+                Compute(sizes);
+                LocalTensor dstLocal = outQueue.DeQue();
+                if (update_type == 0)
+                {
+                    SetAtomicAdd();
+                }
+                for (int i = 0; i < sizes; i++)
+                {
+                    address = srcAddrLocal.GetValue(i);
+                    if (address != 0)
+                    {
+                        dstDataGm.SetGlobalBuffer((__gm__ T*)(address));
+                        DataCopy(dstDataGm, dstLocal[i*once_move_nums], once_move_nums);
+                    }
+                }
+                if (update_type == 0)
+                {
+                    SetAtomicNone();
+                }
+                outQueue.FreeTensor(dstLocal);
+            }
+            else
+            {
+                for (int i = 0; i < sizes; i++)
+                {
+                    dataLocal = inQueue.AllocTensor();
+                    DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize], once_move_nums);
+                    inQueue.EnQue(dataLocal);
+                    Compute(1);
+                    address = srcAddrLocal.GetValue(i);
+                    CopyOut(address, turns, i);
+                }
+            }
+        }
+
+        __aicore__ inline void Compute(const int nums)
+        {
+            // deque input tensors from VECIN queue
+            LocalTensor srcLocal = inQueue.DeQue();
+            LocalTensor dstLocal = outQueue.AllocTensor();
+            DataCopyParams copyparams;
+            copyparams.blockCount = 1;
+            copyparams.blockLen = once_move_nums * sizeof(T) * nums / 32;
+            DataCopy(dstLocal, srcLocal, copyparams);
+            outQueue.EnQue(dstLocal);
+            inQueue.FreeTensor(srcLocal);
+        }
+
+        __aicore__ inline void CopyOut(const int64_t address, const int64_t turns, const int64_t index)
+        {
+            LocalTensor dstLocal = outQueue.DeQue();
+
+            int offset = block_idx * dim * SingleCoreAddrLen / sizeof(int64_t) + (turns * roundSize * dim) + dim * index;
+
+            if (address != 0)
+            {
+                dstDataGm.SetGlobalBuffer((__gm__ T *)(address));
+
+                if (update_type == 0)
+                {
+                    SetAtomicAdd();
+                }
+
+#if defined(__DAV_C220_VEC__)
+                if (single_data_size == 4)
+                {
+                    copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0,
+                        1, dim * sizeof(T), 0, 0, 0, 0);
+                }
+                else if (single_data_size == 2)
+                {
+                    copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0,
+                        1, dim * sizeof(T), 0, 0, 0, 0);
+                }
+#else
+                DataCopy(dstDataGm, dstLocal, once_move_nums);
+#endif
+            }
+            if (update_type == 0)
+            {
+                SetAtomicNone();
+            }
+            outQueue.FreeTensor(dstLocal);
+        }
+
+    public:
+        int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, addr_nums, cache, Veclen, dim, PingpongNum;
+        int32_t once_move_nums, single_data_size, update_type;
+
+    private:
+        TPipe pipe;
+        TBuf tbuf;
+        TQue inQueue;
+        TQue outQueue;
+        GlobalTensor srcDataBufferGm, dstDataGm, outDataGm;
+        GlobalTensor srcAddrGlobal;
+};
+
+extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR address, GM_ADDR embedding, GM_ADDR y, GM_ADDR tiling)
+{
+    GET_TILING_DATA(constData, tiling);
+
+    int32_t embbeding_type = constData.embbeding_type;
+
+    switch (embbeding_type)
+    {
+    case 0:
+    {
+        KernelEimtable_update op;
+        op.Init_param(tiling);
+        op.Init(address, embedding, y);
+        op.Process();
+    }
+    break;
+    case 1:
+    {
+        KernelEimtable_update op;
+        op.Init_param(tiling);
+        op.Init(address, embedding, y);
+        op.Process();
+    }
+    break;
+    case 2:
+    {
+        KernelEimtable_update op;
+        op.Init_param(tiling);
+        op.Init(address, embedding, y);
+        op.Process();
+    }
+    break;
+    default:
+    {
+        KernelEimtable_update op;
+        op.Init_param(tiling);
+        op.Init(address, embedding, y);
+        op.Process();
+    }
+    break;
+    }
+}
diff --git a/cust_op/cust_op_by_addr/readme.md b/cust_op/cust_op_by_addr/readme.md
new file mode 100644
index 00000000..a1fb963b
--- /dev/null
+++ b/cust_op/cust_op_by_addr/readme.md
@@ -0,0 +1,197 @@
+# tik2_custom_op
+
+### Introduction
+TIK2 custom operators (embedding lookup/update by address).
+
+
+### Directory structure
+
+```
+ CMakeLists.txt
+ CMakePresets.json
+ README.md
+ build.sh
+ framework            // framework plugins (TF adapter etc.)
+ op_host              // op prototype & tiling
+   xx.h
+   xx.cc
+   CMakeLists.txt
+ op_kernel            // kernel implementation
+   xx.cc
+   CMakeLists.txt
+ cmake
+   config.cmake
+   util
+     makeself
+     parse_ini_to_json.py
+     gen_ops_filter.sh
+```
+
+
+
+#### Naming conventions
+- Operator types use CamelCase, e.g. SampleTik2
+- File names under the host and kernel directories use the snake_case form of the type name, e.g. SampleTik2 -> sample_tik2
+- The tiling struct header is named _tiling.h, e.g. sample_tik2_tiling.h
+- The kernel entry function uses the snake_case name, e.g. sample_tik2
+
+
+### Creating / modifying the operator project
+
+#### With the msopgen tool
+
+1. First define the operator's IR json file; for reference:
+
+```
+[
+{
+    "op":"AddDSL",
+    "input_desc":[
+        {
+            "name":"x1",
+            "param_type":"required",
+            "format":[
+                "NCHW"
+            ],
+            "type":[
+                "fp16"
+            ]
+        },
+        {
+            "name":"x2",
+            "param_type":"required",
+            "format":[
+                "NCHW"
+            ],
+            "type":[
+                "fp16"
+            ]
+        }
+    ],
+    "output_desc":[
+        {
+            "name":"y",
+            "param_type":"required",
+            "format":[
+                "NCHW"
+            ],
+            "type":[
+                "fp16"
+            ]
+        }
+    ],
+    "attr":[
+        {
+            "name":"n",
+            "param_type":"required",
+            "type":"int"
+        }
+    ]
+}
+]
+```
+
+2. Use msopgen to turn the json file into an operator project:
+```
+# create a new operator project
+msopgen gen -i xxx.json -f tf -c ai_core-ascend910 -lan cpp -m 0 -out new_tik2_custom
+# append an operator to an existing project
+msopgen gen -i xxx.json -f tf -c ai_core-ascend910 -lan cpp -m 1 -out ./
+```
+Appending currently has a small bug: the cmake/config.cmake file must be adjusted by hand:
+
+```
+set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) # make sure the chip list covers the op's intended targets
+```
+3. Fill in the operator sources:
+
+```
+ op_host
+   xx_tiling.h // tiling struct
+   xx.cc       // prototype / tiling / infershape / registration info
+ op_kernel
+   xx.cc       // kernel implementation
+```
+
+
+#### By hand
+
+1. Copy the three finished source files into the matching project directories:
+```
+ op_host
+   new_op_tiling.h // tiling struct
+   new_op.cc       // prototype / tiling / infershape / registration info
+ op_kernel
+   new_op.cc       // kernel implementation
+```
+To delete an operator, simply remove its files from the op_host and op_kernel directories.
+
+2. Adjust the cmake/config.cmake file:
+```
+set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) # make sure the chip list covers the op's intended targets
+```
+3. Develop and adapt the operator source files as needed.
+
+
+
+
+
+### Building and installing the operator package
+
+1. Edit the CMakePresets.json file.
+
+Point ASCEND_CANN_PACKAGE_PATH at the correct CANN installation directory:
+
+```
+    "ASCEND_CANN_PACKAGE_PATH": {
+        "type": "PATH",
+        "value": "/usr/local/Ascend/ascend-toolkit/latest"
+    },
+```
+
+2. Configure the environment:
+
+
+```
+# when the toolkit is installed under the default path
+source ${HOME}/Ascend/ascend-toolkit/set_env.sh
+```
+
+3. Build the custom operator package:
+
+```
+./build.sh
+```
+
+4. Install the operator package.
+
+Running the generated package installs it under the ${ASCEND_OPP_PATH}/vendors directory:
+```
+./build_out/custom_opp__.run
+```
+
+
+
+### TIK2 development tutorials & references
+https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/63RC1alpha002/operatordevelopment/tik2opdevg/atlastik2_10_0001.html
+
+https://codehub-y.huawei.com/TIKC-Project/tikcpp_smoke/files?ref=master&filePath=external_tik2_demo
+
+
+
+### Contributing
+
+1. Fork this repository
+2. Create a Feat_xxx branch
+3. Commit your changes
+4. Open a Pull Request
+
+
+### Gitee features
+
+1. Use Readme_XXX.md to support different languages, e.g. Readme_en.md, Readme_zh.md
+2. The official Gitee blog: [blog.gitee.com](https://blog.gitee.com)
+3. Explore open-source projects on Gitee at [https://gitee.com/explore](https://gitee.com/explore)
+4. [GVP](https://gitee.com/gvp) (Gitee Most Valuable Projects) highlights outstanding open-source projects on Gitee
+5. The official Gitee user manual: [https://gitee.com/help](https://gitee.com/help)
+6. Gitee Stars showcases Gitee members and their projects: [https://gitee.com/gitee-stars](https://gitee.com/gitee-stars)
diff --git a/cust_op/cust_op_by_addr/scripts/install.sh b/cust_op/cust_op_by_addr/scripts/install.sh
new file mode 100644
index 00000000..e3988cde
--- /dev/null
+++ b/cust_op/cust_op_by_addr/scripts/install.sh
@@ -0,0 +1,228 @@
+#!/bin/bash
+vendor_name=customize
+targetdir=/usr/local/Ascend/opp
+target_custom=0
+
+sourcedir=$PWD/packages
+vendordir=vendors/$vendor_name
+
+QUIET="y"
+
+for i in "$@"
+do
+    echo $i
+    if test $i = "--quiet"; then
+        QUIET="y"
+        break
+    fi
+done
+
+log() {
+    cur_date=`date +"%Y-%m-%d %H:%M:%S"`
+    echo "[runtime] [$cur_date] "$1
+}
+
+if [ -n "${ASCEND_CUSTOM_OPP_PATH}" ]; then
+    if [ ! -d ${ASCEND_CUSTOM_OPP_PATH} ]; then
+        mkdir -p ${ASCEND_CUSTOM_OPP_PATH} >> /dev/null 2>&1
+        if [ $? -ne 0 ]; then
+            log "[ERROR] create ${ASCEND_CUSTOM_OPP_PATH} failed"
+        fi
+    fi
+    targetdir=${ASCEND_CUSTOM_OPP_PATH}
+else
+    if [ "x${ASCEND_OPP_PATH}" == "x" ]; then
+        log "[ERROR] env ASCEND_OPP_PATH does not exist"
+        exit 1
+    fi
+    targetdir="${ASCEND_OPP_PATH}"
+fi
+
+if [ ! -d $targetdir ];then
+    log "[ERROR] $targetdir does not exist"
+    exit 1
+fi
+
+upgrade()
+{
+    if [ ! -d ${sourcedir}/$vendordir/$1 ]; then
+        log "[INFO] no need to upgrade ops $1 files"
+        return 0
+    fi
+
+    if [ ! -d ${targetdir}/$vendordir/$1 ];then
+        log "[INFO] create ${targetdir}/$vendordir/$1."
+        mkdir -p ${targetdir}/$vendordir/$1
+        if [ $? -ne 0 ];then
+            log "[ERROR] create ${targetdir}/$vendordir/$1 failed"
+            return 1
+        fi
+    else
+        has_same_file=-1
+        for file_a in ${sourcedir}/$vendordir/$1/*; do
+            file_b=${file_a##*/};
+            if [ "ls ${targetdir}/$vendordir/$1" = "" ]; then
+                log "[INFO] ${targetdir}/$vendordir/$1 is empty !!"
+                return 1
+            fi
+            grep -q $file_b <<<`ls ${targetdir}/$vendordir/$1`;
+            if [[ $? -eq 0 ]]; then
+                echo -n "${file_b} "
+                has_same_file=0
+            fi
+        done
+        if [ 0 -eq $has_same_file ]; then
+            if test $QUIET = "n"; then
+                echo "[INFO]: has old version in ${targetdir}/$vendordir/$1, \
+                for overlay installation please enter: [o]; \
+                for replace directory installation please enter: [r]; \
+                to skip installation please enter: [n]."
+
+                while true
+                do
+                    read orn
+                    if [ "$orn" = n ]; then
+                        return 0
+                    elif [ "$orn" = o ]; then
+                        break;
+                    elif [ "$orn" = r ]; then
+                        [ -n "${targetdir}/$vendordir/$1/" ] && rm -rf "${targetdir}/$vendordir/$1"/*
+                        break;
+                    else
+                        echo "[ERROR] input error, please input again!"
+                    fi
+                done
+            fi
+        fi
+        log "[INFO] replace or merge old ops $1 files ......"
+    fi
+
+    log "copy new ops $1 files ......"
+    if [ -d ${targetdir}/$vendordir/$1/ ]; then
+        chmod -R +w "$targetdir/$vendordir/$1/" >/dev/null 2>&1
+    fi
+    cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/
+    if [ $? -ne 0 ];then
+        log "[ERROR] copy new $1 files failed"
+        return 1
+    fi
+
+    return 0
+}
+upgrade_proto()
+{
+    if [ ! -f ${sourcedir}/$vendordir/custom.proto ]; then
+        log "[INFO] no need to upgrade custom.proto files"
+        return 0
+    fi
+    if [ ! -d ${targetdir}/$vendordir/framework/caffe ];then
+        log "[INFO] create ${targetdir}/$vendordir/framework/caffe."
+        mkdir -p ${targetdir}/$vendordir/framework/caffe
+        if [ $? -ne 0 ];then
+            log "[ERROR] create ${targetdir}/$vendordir/framework/caffe failed"
+            return 1
+        fi
+    else
+        if [ -f ${targetdir}/$vendordir/framework/caffe/custom.proto ]; then
+            # an old version exists; ask whether it should be overwritten
+            if test $QUIET = "n"; then
+                echo "[INFO] ${targetdir}/$vendordir/framework/caffe has old version"\
+                "custom.proto file. Do you want to replace? [y/n] "
+
+                while true
+                do
+                    read yn
+                    if [ "$yn" = n ]; then
+                        return 0
+                    elif [ "$yn" = y ]; then
+                        break;
+                    else
+                        echo "[ERROR] input error, please input again!"
+                    fi
+                done
+            fi
+        fi
+        log "[INFO] replace old caffe.proto files ......"
+    fi
+    chmod -R +w "$targetdir/$vendordir/framework/caffe/" >/dev/null 2>&1
+    cp -rf ${sourcedir}/$vendordir/custom.proto ${targetdir}/$vendordir/framework/caffe/
+    if [ $? -ne 0 ];then
+        log "[ERROR] copy new custom.proto failed"
+        return 1
+    fi
+    log "[INFO] copy custom.proto success"
+
+    return 0
+}
+
+log "[INFO] copy uninstall sh success"
+
+if [ ! -d ${targetdir}/vendors ];then
+    log "[INFO] create ${targetdir}/vendors."
+    mkdir -p ${targetdir}/vendors
+    if [ $? -ne 0 ];then
+        log "[ERROR] create ${targetdir}/vendors failed"
+        return 1
+    fi
+fi
+chmod u+w ${targetdir}/vendors
+
+echo "[ops_custom]upgrade framework"
+upgrade framework
+if [ $? -ne 0 ];then
+    exit 1
+fi
+
+echo "[ops_custom]upgrade op proto"
+upgrade op_proto
+if [ $? -ne 0 ];then
+    exit 1
+fi
+
+echo "[ops_custom]upgrade op impl"
+upgrade op_impl
+if [ $? -ne 0 ];then
+    exit 1
+fi
+
+upgrade_proto
+if [ $? -ne 0 ];then
+    exit 1
+fi
+
+config_file=${targetdir}/vendors/config.ini
+if [ ! -f ${config_file} ]; then
+    touch ${config_file}
+    chmod 640 ${config_file}
+    echo "load_priority=$vendor_name" > ${config_file}
+    if [ $? -ne 0 ];then
+        echo "echo load_priority failed"
+        exit 1
+    fi
+else
+    found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)"
+    found_vendor=$(echo $found_vendors | sed "s/$vendor_name//g" | tr ',' ' ')
+    vendor=$(echo $found_vendor | tr -s ' ' ',')
+    if [ "$vendor" != "" ]; then
+        sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file"
+    fi
+fi
+
+chmod u-w ${targetdir}/vendors
+
+if [ -d ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/ ]; then
+    chmod -R 440 ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/* >/dev/null 2>&1
+fi
+if [ -f ${targetdir}/ascend_install.info ]; then
+    chmod -R 440 ${targetdir}/ascend_install.info
+fi
+if [ -f ${targetdir}/scene.info ]; then
+    chmod -R 440 ${targetdir}/scene.info
+fi
+if [ -f ${targetdir}/version.info ]; then
+    chmod -R 440 ${targetdir}/version.info
+fi
+
+echo "SUCCESS"
+exit 0
+
diff --git a/cust_op/cust_op_by_addr/scripts/upgrade.sh b/cust_op/cust_op_by_addr/scripts/upgrade.sh
new file mode 100644
index 00000000..2fa595c9
--- /dev/null
+++ b/cust_op/cust_op_by_addr/scripts/upgrade.sh
@@ -0,0 +1,121 @@
+#!/bin/bash
+vendor_name=customize
+targetdir=/usr/local/Ascend/opp
+target_custom=0
+
+sourcedir=$PWD/packages
+vendordir=vendors/$vendor_name
+
+log() {
+    cur_date=`date +"%Y-%m-%d %H:%M:%S"`
+    echo "[runtime] [$cur_date] "$1
+}
+
+if [[ "x${ASCEND_OPP_PATH}" == "x" ]];then
+    log "[ERROR] env ASCEND_OPP_PATH does not exist"
+    exit 1
+fi
+
+targetdir=${ASCEND_OPP_PATH}
+
+if [ ! -d $targetdir ];then
+    log "[ERROR] $targetdir does not exist"
+    exit 1
+fi
+
+upgrade()
+{
+    if [ ! -d ${sourcedir}/$vendordir/$1 ]; then
+        log "[INFO] no need to upgrade ops $1 files"
+        return 0
+    fi
+
+    if [ ! -d ${targetdir}/$vendordir/$1 ];then
+        log "[INFO] create ${targetdir}/$vendordir/$1."
+        mkdir -p ${targetdir}/$vendordir/$1
+        if [ $? -ne 0 ];then
+            log "[ERROR] create ${targetdir}/$vendordir/$1 failed"
+            return 1
+        fi
+    else
+        vendor_installed_dir=$(ls "$targetdir/vendors" 2> /dev/null)
+        for i in $vendor_installed_dir;do
+            vendor_installed_file=$(ls "$vendor_installed_dir/$vendor_name/$i" 2> /dev/null)
+            if [ "$i" = "$vendor_name" ] && [ "$vendor_installed_file" != "" ]; then
+                echo "[INFO]: $vendor_name custom opp package has been installed on the path $vendor_installed_dir, \
+                for overlay installation please enter: [o]; \
+                for replace directory installation please enter: [r]; \
+                to skip installation please enter: [n]."
+            fi
+            while true
+            do
+                read orn
+                if [ "$orn" = o ]; then
+                    break
+                elif [ "$orn" = r ]; then
+                    [ -n "$vendor_installed_file" ] && rm -rf "$vendor_installed_file"
+                    break
+                elif [ "$orn" = n ]; then
+                    return 0
+                else
+                    echo "[WARNING]: Input error, please input o or r or n to choose!"
+                fi
+            done
+        done
+        log "[INFO] replace old ops $1 files ......"
+    fi
+
+    log "copy new ops $1 files ......"
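+    # Unlike install.sh, this path does not chmod the target writable before
+    # copying; upgrade.sh assumes it already runs with sufficient permissions.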
+ cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/ + if [ $? -ne 0 ];then + log "[ERROR] copy new $1 files failed" + return 1 + fi + + return 0 +} +log "[INFO] copy uninstall sh success" + +echo "[ops_custom]upgrade framework" +upgrade framework +if [ $? -ne 0 ];then + exit 1 +fi + +echo "[ops_custom]upgrade op proto" +upgrade op_proto +if [ $? -ne 0 ];then + exit 1 +fi + +echo "[ops_custom]upgrade op impl" +upgrade op_impl +if [ $? -ne 0 ];then + exit 1 +fi + +config_file=${targetdir}/vendors/config.ini +found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)" +found_vendor=$(echo $found_vendors | sed "s/$vendor_name//g" | tr ',' ' ') +vendor=$(echo $found_vendor | tr -s ' ' ',') +if [ "$vendor" != "" ]; then + sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file" +fi + +changemode() +{ + if [ -d ${targetdir} ];then + chmod -R 550 ${targetdir}>/dev/null 2>&1 + fi + + return 0 +} +echo "[ops_custom]changemode..." +#changemode +if [ $? -ne 0 ];then + exit 1 +fi + +echo "SUCCESS" +exit 0 + -- Gitee From 8c412d3679ebbb8d57ee8b02a9274c73968e0fcf Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 14:44:49 +0800 Subject: [PATCH 029/551] Match-id-728c312965e8ee28318cdf71cfd5300dbf741d17 --- src/core/emb_hashmap/emb_hashmap.h | 2 +- .../{emb_mgmt/emb_mgmt.cpp => hybrid_mgmt/hybrid_mgmt.cpp} | 2 +- src/core/{emb_mgmt/emb_mgmt.h => hybrid_mgmt/hybrid_mgmt.h} | 0 src/core/key_process/key_process.cpp | 6 +++--- src/core/utils/unique.h | 6 ++---- src/pybind/module_main.cpp | 2 +- src/tests/key_process/key_process_test.cpp | 2 +- 7 files changed, 9 insertions(+), 11 deletions(-) rename src/core/{emb_mgmt/emb_mgmt.cpp => hybrid_mgmt/hybrid_mgmt.cpp} (99%) rename src/core/{emb_mgmt/emb_mgmt.h => hybrid_mgmt/hybrid_mgmt.h} (100%) diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index a2337d01..3f58bc95 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -9,9 +9,9 @@ #define MX_REC_EMB_HASHMAP_H #include -#include "absl/container/flat_hash_map.h" #include #include +#include "absl/container/flat_hash_map.h" #include "host_emb/host_emb.h" namespace MxRec { diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp similarity index 99% rename from src/core/emb_mgmt/emb_mgmt.cpp rename to src/core/hybrid_mgmt/hybrid_mgmt.cpp index a7d43c54..8ef056be 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -4,7 +4,7 @@ * Author: MindX SDK * Date: 2022/11/15 */ -#include "emb_mgmt.h" +#include "hybrid_mgmt.h" #include #include #include "checkpoint/checkpoint.h" diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h similarity index 100% rename from src/core/emb_mgmt/emb_mgmt.h rename to src/core/hybrid_mgmt/hybrid_mgmt.h diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 795bfcd2..6fd38ec4 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -66,11 +66,11 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, } } spdlog::info(KEY_PROCESS "hot emb count info:{}", hotEmbTotCount); - MPI_Group world_group; - MPI_Comm_group(MPI_COMM_WORLD, &world_group); + MPI_Group worldGroup; + MPI_Comm_group(MPI_COMM_WORLD, &worldGroup); for (auto& i: comm) { for (auto& j: i) { - MPI_Comm_create(MPI_COMM_WORLD, world_group, &j); + 
MPI_Comm_create(MPI_COMM_WORLD, worldGroup, &j); } } isRunning = true; diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 9be61498..4509df4c 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -33,8 +33,6 @@ #include "spinlock.h" #include "time_cost.h" -using namespace MxRec; - struct UniqueData { void *inputData; size_t dataSize; @@ -115,7 +113,7 @@ template class Dedup { static const int kDefaultBucketCountMask = kDefaultBucketCount - 1; template struct Meta { - static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); + static_assert(M <= MxRec::UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); SpinLock lock; volatile int8_t count; int8_t pad[3]; @@ -392,7 +390,7 @@ public: uint32_t UniqueRawForHot(int64_t *output, uint32_t priorTotal, int32_t* idCount, map &hotMap, map &hotPosMap, int &hotCount, - absl::flat_hash_map &keyCountMap) + absl::flat_hash_map &keyCountMap) { uint32_t total = priorTotal; int32_t replace_offset = priorTotal; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 6d297dc1..3839fe37 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -8,7 +8,7 @@ #include #include -#include "emb_mgmt/emb_mgmt.h" +#include "hybrid_mgmt/hybrid_mgmt.h" #include "module_main.h" namespace py = pybind11; diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 6d1e01b7..fa8ea37d 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -18,7 +18,7 @@ #include "utils/common.h" #include "host_emb/host_emb.h" #include "key_process/key_process.h" -#include "emb_mgmt/emb_mgmt.h" +#include "hybrid_mgmt/hybrid_mgmt.h" using namespace std; using namespace MxRec; -- Gitee From 61909311c9918d930eaa7dafd7026796e0ee1a53 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 14:51:26 +0800 Subject: [PATCH 030/551] Match-id-bf38f5d992ad7731dba1b6531e630dd5493f83e5 --- cust_op/cust_op_by_addr/build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cust_op/cust_op_by_addr/build.sh b/cust_op/cust_op_by_addr/build.sh index 4be96d7d..f38465f9 100644 --- a/cust_op/cust_op_by_addr/build.sh +++ b/cust_op/cust_op_by_addr/build.sh @@ -6,6 +6,8 @@ mkdir -p build_out rm -rf build_out/* cd build_out +chmod +x $script_path/cmake/util/gen_ops_filter.sh + cmake_version=$(cmake --version | grep "cmake version" | awk '{print $3}') if [ "$cmake_version" \< "3.19.0" ] ; then opts=$(python3 $script_path/cmake/util/preset_parse.py $script_path/CMakePresets.json) -- Gitee From 4c041d58e5c4a59aa70ee4bf50c1adecf671b52d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 15:09:55 +0800 Subject: [PATCH 031/551] Match-id-b51a93008f3291038d47809f59321d37ad54b182 --- src/core/emb_hashmap/emb_hashmap.cpp | 5 +- src/core/emb_hashmap/emb_hashmap.h | 2 +- src/core/emb_mgmt/emb_mgmt.cpp | 6 +- src/core/host_emb/host_emb.cpp | 5 +- src/core/host_emb/host_emb.h | 4 +- src/core/key_process/key_process.cpp | 85 +++++++++++++++------------- src/core/key_process/key_process.h | 11 ++-- src/core/utils/common.h | 23 ++++++++ 8 files changed, 85 insertions(+), 56 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 8f446a65..588a904f 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -37,7 +37,8 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } } -vector 
EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch) +void EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch, + vector& tmpData) { EASY_FUNCTION(profiler::colors::Pink) auto keepBatch = swapId - iBatch; @@ -46,7 +47,6 @@ vector EmbHashMap::Process(const string& embName, const vectortdt") auto& embHashMap = embHashMaps.at(embName); - vector tmpData; auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); @@ -73,7 +73,6 @@ vector EmbHashMap::Process(const string& embName, const vector(); swapLen(0) = swapSize; EASY_END_BLOCK - return tmpData; } /* diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 424684e5..4cd215e2 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -23,7 +23,7 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); - vector Process(const string& embName, const std::vector& keys, size_t iBatch); + void Process(const string& embName, const std::vector& keys, size_t iBatch, vector& tmpData); void FindAndUpdateOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId); diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 8cee8a08..98359de5 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -447,7 +447,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) auto restore = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); hdTransfer->Send(RESTORE, *restore, channelId, embInfo.name); - auto tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, iBatch); + vector tmpData; + hostHashMaps->Process(embInfo.name, lookupKeys, iBatch, tmpData); hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embInfo.name); tmpData.erase(tmpData.begin()); hdTransfer->Send(SWAP, tmpData, channelId, embInfo.name); @@ -506,7 +507,8 @@ void HybridMgmt::EmbHDTrans(int channelId, int batchId) TimeCost tr; for (const auto& embInfo: mgmtEmbInfo) { auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos; - auto h2dEmb = hostEmbs->GetH2DEmb(missingKeys, embInfo.name); // order! + vector h2dEmb; + hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! 
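[Editorial sketch] PATCH 031 rewrites Process() and GetH2DEmb() to fill caller-supplied vectors instead of returning them. A minimal standalone sketch of that pattern, with invented names only (Tensor below is a stand-in for tensorflow::Tensor, and BuildInto/BuildByValue are hypothetical, not mxRec APIs):

    #include <iostream>
    #include <vector>

    struct Tensor { int payload = 0; };  // placeholder for tensorflow::Tensor

    // Before: a fresh vector is constructed per call; NRVO/move keeps the
    // return cheap, but the allocation itself repeats on every batch.
    std::vector<Tensor> BuildByValue() { return {Tensor{1}, Tensor{2}}; }

    // After: the caller owns the buffer, so its capacity survives across batches.
    void BuildInto(std::vector<Tensor>& out) {
        out.push_back(Tensor{1});
        out.push_back(Tensor{2});
    }

    int main() {
        std::vector<Tensor> once = BuildByValue();
        std::cout << "by value: " << once.size() << " tensors\n";

        std::vector<Tensor> buf;
        for (int batch = 0; batch < 3; ++batch) {
            buf.clear();  // drops elements, keeps the allocated capacity
            BuildInto(buf);
            std::cout << "batch " << batch << ": " << buf.size() << " tensors\n";
        }
        return 0;
    }

Either style is defensible; the series appears to prefer out-parameters so the same Tensor buffers can be reused along the hot path.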
hdTransfer->Send(H2D, h2dEmb, channelId, embInfo.name, batchId);
     }
     for (const auto& embInfo: mgmtEmbInfo) {
diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp
index 6fdb1f54..15514529 100644
--- a/src/core/host_emb/host_emb.cpp
+++ b/src/core/host_emb/host_emb.cpp
@@ -171,10 +171,10 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI
  * Find the embeddings the host side needs to send, and hand them to the device via hdTransfer.
  * missingKeysHostPos holds the host-side positions of the embeddings to send.
  */
-vector<Tensor> HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& embName)
+void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& embName,
+                        vector<Tensor>& h2d_emb)
 {
     EASY_FUNCTION()
-    vector<Tensor> h2d_emb;
     const auto& emb = hostEmbs[embName];
     const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize;
     h2d_emb.emplace_back(Tensor(tensorflow::DT_FLOAT, {
@@ -191,7 +191,6 @@ vector HostEmb::GetH2DEmb(const vector& missingKeysHostPos, cons
         }
     }
     spdlog::info("GetH2DEmb end, missingKeys count:{}", missingKeysHostPos.size());
-    return h2d_emb;
 }
 
 auto HostEmb::GetHostEmbs() -> absl::flat_hash_map*
diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h
index 31928a27..ca022e7c 100644
--- a/src/core/host_emb/host_emb.h
+++ b/src/core/host_emb/host_emb.h
@@ -38,8 +38,8 @@ namespace MxRec {
 
         void UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName);
 
-        vector<Tensor> GetH2DEmb(const vector& missingKeysHostPos, const string& embName);
-
+        void GetH2DEmb(const vector& missingKeysHostPos, const string& embName,
+                       vector<Tensor>& h2d_emb);
         auto GetHostEmbs() -> absl::flat_hash_map*;
 
         void EvictInitEmb(const string& embName, const vector& offset);
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index c9cee4d0..d05601cf 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -244,14 +244,20 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_de
     isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() &&
                  FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE;
     TimeCost tc;
-    auto [lookupKeys, restore, hotPos, scAll, countRecv] =
-        ProcessBatchWithUniqueCompute(batch, unique, id);
-    TIME_PRINT("ProcessBatch TimeCost(ms):{}", tc.ElapsedMS());
+//    keys_t lookupKeys;
+//    vector restore;
+//    vector hotPos;
+//    vector scAll;
+//    vector countRecv;
+    UniqueInfo uniqueInfo;
+    ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo);
+    TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS());
+
     // feature admission & eviction
     if (isWithFAAE &&
-        (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys,
-            countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) {
+        (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv,
+            uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) {
         spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...",
                       rankInfo.rankId, id, channel);
         return false;
@@ -262,7 +268,7 @@
     // map key to offset directly by lookup keyOffsetMap (hashmap)
     if (rankInfo.noDDR) {
         TimeCost key2OffsetTc;
-        Key2Offset(batch->name, lookupKeys);
+        Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv);
         TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS());
     }
     if (!rankInfo.useStatic) { // Static all2all, need send count
@@ -270,24 +276,24 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch,
sharded_de if (batch->modifyGraph) { embName = batch->channelName; } - SendA2A(scAll, embName, batch->channel, batch->batchId); + SendA2A(uniqueInfo.all2AllInfo.scAll, embName, batch->channel, batch->batchId); } auto tensors = make_unique>(); - tensors->push_back(Vec2TensorI32(restore)); + tensors->push_back(Vec2TensorI32(uniqueInfo.restore)); if (rankInfo.useHot) { - hotPos.resize(hotEmbTotCount[batch->name], -1); - tensors->push_back(Vec2TensorI32(hotPos)); + uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); + tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); } if (rankInfo.noDDR) { if (rankInfo.useDynamicExpansion) { - tensors->push_back(Vec2TensorI64(lookupKeys)); + tensors->push_back(Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv)); } else { - tensors->push_back(Vec2TensorI32(lookupKeys)); + tensors->push_back(Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); } } TimeCost pushTensorTc; - PushResult(batch, move(tensors), lookupKeys, batchListId); + PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv, batchListId); TIME_PRINT("pushTensorToListTC TimeCost(ms):{}", pushTensorTc.ElapsedMS()); return true; } @@ -401,8 +407,8 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id) - -> tuple, vector, vector, vector> +void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, + UniqueInfo& uniqueInfo) { EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) @@ -417,25 +423,24 @@ auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba keySend.resize(size); vector splitSize(rankInfo.rankSize); vector uniqueVector(batch->batchSize); - vector restore(batch->batchSize); + uniqueInfo.restore.resize(batch->batchSize); vector idCount(batch->batchSize); vector keyCount(size); std::shared_lock lock(g_smut); auto hotMap = hotKey[batch->name]; lock.unlock(); - vector hotPos; int hotOffset = 0; if (rankInfo.useHot) { - hotPos.resize(hotEmbTotCount[batch->name]); + uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name]); hotOffset = hotEmbTotCount[batch->name]; } absl::flat_hash_map keyCountMap; - UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, restore.data(), uniqueVector.data(), splitSize.data(), - keySend.data(), idCount.data(), keyCount.data()}; + UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, uniqueInfo.restore.data(), uniqueVector.data(), + splitSize.data(), keySend.data(), idCount.data(), keyCount.data()}; UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; - UniqueForHot uniqueForHot = {hotOffset, hotPos.data(), hotMap, keyCountMap}; + UniqueForHot uniqueForHot = {hotOffset, uniqueInfo.hotPos.data(), hotMap, keyCountMap}; UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum); @@ -456,20 +461,21 @@ auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba sc[i] = splitSize[i]; } } - auto [keyRecv, scAll, countRecv] = All2All(sc, id, batch->channel, keySend, keyCount); + All2All(sc, id, batch->channel, keySend, keyCount, uniqueInfo.all2AllInfo); spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, - batch->channel, batch->channelName, batch->name, restore.size(), keyCount.size()); - 
return { keyRecv, restore, hotPos, scAll, countRecv};
+                  batch->channel, batch->channelName, batch->name, uniqueInfo.restore.size(),
+                  keyCount.size());
+
 }
 
-auto KeyProcess::All2All(vector<int>& sc, int id, int channel, keys_t& keySend, vector<uint32_t>& keyCount)
-    -> tuple<keys_t, vector<int>, vector<uint32_t>>
+void KeyProcess::All2All(vector<int>& sc, int id, int channel, keys_t& keySend, vector<uint32_t>& keyCount,
+                         All2AllInfo& all2AllInfo)
+
 {
-    keys_t keyRecv;
     TimeCost get_sc_all;
-    auto scAll = GetScAll(sc, id, channel); // Allgather collects the cross-rank (same thread id) traffic matrix
+    GetScAll(sc, id, channel, all2AllInfo.scAll); // Allgather collects the cross-rank (same thread id) traffic matrix
     TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS());
 
     TimeCost all2allTC;
@@ -477,22 +483,22 @@
     vector<int> rc(rankInfo.rankSize); // receive count
     for (int i = 0; i < rankInfo.rankSize; ++i) {
         // a column of the traffic matrix lists the key counts this rank receives from every other device
-        rc[i] = scAll.at(i * rankInfo.rankSize + rankInfo.rankId);
+        rc[i] = all2AllInfo.scAll.at(i * rankInfo.rankSize + rankInfo.rankId);
     }
     auto rs = Count2Start(rc); // receive displacements: start offsets of the received data
-    keyRecv.resize(rs.back() + rc.back());
+    all2AllInfo.keyRecv.resize(rs.back() + rc.back());
     EASY_BLOCK("all2all")
-    MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T,
-        comm[channel][id]);
+    MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), rs.data(),
+        MPI_INT64_T, comm[channel][id]);
 
-    vector<uint32_t> countRecv(rs.back() + rc.back());
+    all2AllInfo.countRecv.resize(rs.back() + rc.back());
     if (isWithFAAE) {
-        MPI_Alltoallv(keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), rc.data(), rs.data(),
-            MPI_UINT32_T, comm[channel][id]);
+        MPI_Alltoallv(keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), rc.data(),
+            rs.data(), MPI_UINT32_T, comm[channel][id]);
     }
     TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS());
     EASY_END_BLOCK
-    return {keyRecv, scAll, countRecv};
+
 }
 
 auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id,
@@ -519,7 +525,8 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id,
         keySend.insert(keySend.end(), i.begin(), i.end());
     }
     keys_t keyRecv;
-    auto scAll = GetScAll(sc, id, batch->channel); // Allgather collects the full inter-thread traffic matrix across ranks (same thread id)
+    vector<int> scAll;
+    GetScAll(sc, id, batch->channel, scAll); // Allgather collects the full inter-thread traffic matrix across ranks (same thread id)
     auto ss = Count2Start(sc); // send displacements: start offsets of the data to send
     vector<int> rc; // receive count
     for (int i = 0; i < rankInfo.rankSize; ++i) {
@@ -735,10 +742,9 @@ void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap,
  * Allgather the per-peer key counts this rank's batch will send, collecting the complete
  * inter-thread traffic matrix across ranks (same thread id).
  * scAll returns the full traffic matrix, flattened row by row into a 1-D vector.
  */
-vector<int> KeyProcess::GetScAll(const vector<int>& keyScLocal, int commId, int channel) const
+void KeyProcess::GetScAll(const vector<int>& keyScLocal, int commId, int channel, vector<int> &scAll) const
 {
     EASY_FUNCTION()
-    vector<int> scAll;
     scAll.resize(rankInfo.rankSize * rankInfo.rankSize);
     EASY_BLOCK("barrier");
     // stop signal for the communication: exit in sync so no thread is left hanging
     spdlog::stopwatch sw;
@@ -754,7 +760,6 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int
     MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
                   scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
     spdlog::debug("rank {} key scAll matrix:\n{}", rankInfo.rankId, scAll);
-    return scAll;
 }
 
 void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey)
diff --git
a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 3a98aa0a..ca020bf2 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -125,13 +125,14 @@ namespace MxRec { int channel, int id, spdlog::stopwatch& sw); auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector>; - auto ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique_, int id) - -> tuple, vector, vector, vector>; + + void ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, + UniqueInfo& uniqueInfo); size_t GetKeySize(const unique_ptr &batch); - auto All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount) - -> tuple, vector>; + void All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount, + All2AllInfo& all2AllInfo); auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; @@ -139,7 +140,7 @@ namespace MxRec { auto HashSplit_withFAAE(const unique_ptr& batch) const -> tuple, vector, vector>>; - [[nodiscard]] vector GetScAll(const vector& keyScLocal, int commId, int channel) const; + void GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAll) const; void Key2Offset(const emb_name_t& embName, keys_t& splitKey); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index a3555d0f..51e78e2e 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -394,6 +394,29 @@ struct BatchTask { bool HasFree(size_t i); }; + struct All2AllInfo { + + All2AllInfo() = default; + All2AllInfo(keys_t keyRecv, vector scAll, vector countRecv); + + keys_t keyRecv; + vector scAll; + vector countRecv; + }; + + struct UniqueInfo { + + UniqueInfo() = default; + UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo); + + + vector restore; + vector hotPos; + All2AllInfo all2AllInfo; + }; + + + using emb_mem_t = absl::flat_hash_map; using emb_hash_mem_t = absl::flat_hash_map; using offset_mem_t = std::map; -- Gitee From 33601f388ab00bbc70fd7fadfa6a2c15f0996d83 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 17:34:29 +0800 Subject: [PATCH 032/551] Match-id-3d8e20a5505ab44f037a6b974e47ace9cf50bef7 --- src/core/key_process/key_process.cpp | 18 +++++++++--------- src/core/key_process/key_process.h | 2 +- src/core/utils/common.h | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index d05601cf..4bdec548 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -418,14 +418,14 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba TimeCost unique_tc; SimpleThreadPool pool_; - keys_t keySend; + KeySendInfo keySendInfo; size_t size = GetKeySize(batch); - keySend.resize(size); + keySendInfo.keySend.resize(size); vector splitSize(rankInfo.rankSize); vector uniqueVector(batch->batchSize); uniqueInfo.restore.resize(batch->batchSize); vector idCount(batch->batchSize); - vector keyCount(size); + keySendInfo.keyCount.resize(size); std::shared_lock lock(g_smut); auto hotMap = hotKey[batch->name]; lock.unlock(); @@ -438,7 +438,7 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba absl::flat_hash_map keyCountMap; UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, uniqueInfo.restore.data(), uniqueVector.data(), - splitSize.data(), keySend.data(), idCount.data(), keyCount.data()}; + splitSize.data(), keySendInfo.keySend.data(), 
idCount.data(), keySendInfo.keyCount.data()}; UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; UniqueForHot uniqueForHot = {hotOffset, uniqueInfo.hotPos.data(), hotMap, keyCountMap}; UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; @@ -461,16 +461,16 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba sc[i] = splitSize[i]; } } - All2All(sc, id, batch->channel, keySend, keyCount, uniqueInfo.all2AllInfo); + All2All(sc, id, batch->channel, keySendInfo, uniqueInfo.all2AllInfo); spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, batch->channel, batch->channelName, batch->name, uniqueInfo.restore.size(), - keyCount.size()); + keySendInfo.keyCount.size()); } -void KeyProcess::All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount, +void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, All2AllInfo& all2AllInfo) { @@ -488,12 +488,12 @@ void KeyProcess::All2All(vector& sc, int id, int channel, keys_t& keySend, auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 all2AllInfo.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), rs.data(), + MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); all2AllInfo.countRecv.resize(rs.back() + rc.back()); if (isWithFAAE) { - MPI_Alltoallv(keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), rc.data(), + MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); } TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index ca020bf2..35af99e6 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -131,7 +131,7 @@ namespace MxRec { size_t GetKeySize(const unique_ptr &batch); - void All2All(vector& sc, int id, int channel, keys_t& keySend, vector& keyCount, + void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, All2AllInfo& all2AllInfo); auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 51e78e2e..90d5633c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -415,6 +415,14 @@ struct BatchTask { All2AllInfo all2AllInfo; }; + struct KeySendInfo { + + KeySendInfo() = default; + KeySendInfo(keys_t keySend, vector keyCount); + + keys_t keySend; + vector keyCount; + }; using emb_mem_t = absl::flat_hash_map; -- Gitee From 88405b7ebe3cc09d6f1ed3ecf9714e4191f7f48b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 17:44:26 +0800 Subject: [PATCH 033/551] Match-id-0e98af6bec787ed2326b58251008af62bc708526 --- src/core/emb_hashmap/emb_hashmap.cpp | 2 +- src/core/key_process/key_process.cpp | 14 +++++--------- src/core/utils/common.h | 5 ----- 3 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 588a904f..18b02dc5 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp 
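[Editorial sketch] PATCH 032 above groups the send-side buffers into a KeySendInfo struct, the same parameter-object idea behind UniqueInfo and All2AllInfo. A sketch of the rationale, under the assumption that keys_t is a vector of 64-bit keys (all other names below are invented):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    using keys_t = std::vector<int64_t>;  // assumed alias, mirroring the code above

    struct KeySendInfoSketch {
        keys_t keySend;
        std::vector<uint32_t> keyCount;
    };

    // One reference instead of two parallel containers; adding a field later
    // does not ripple through every caller's signature.
    void FillSendBuffers(std::size_t n, KeySendInfoSketch& out) {
        out.keySend.assign(n, 0);
        out.keyCount.assign(n, 1);
    }

    int main() {
        KeySendInfoSketch info;
        FillSendBuffers(4, info);
        std::cout << info.keySend.size() << " keys, "
                  << info.keyCount.size() << " counts\n";
        return 0;
    }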
@@ -38,7 +38,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } void EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch, - vector& tmpData) + vector& tmpData) { EASY_FUNCTION(profiler::colors::Pink) auto keepBatch = swapId - iBatch; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 4bdec548..93b3e738 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -253,11 +253,10 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_de ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); - // 特征准入&淘汰 if (isWithFAAE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, - uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", rankInfo.rankId, id, channel); return false; @@ -444,7 +443,6 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum); - EASY_END_BLOCK TIME_PRINT("UniqueCompute TimeCost(ms):{}", unique_tc.ElapsedMS()); @@ -467,7 +465,6 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, batch->channel, batch->channelName, batch->name, uniqueInfo.restore.size(), keySendInfo.keyCount.size()); - } void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, @@ -488,17 +485,16 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 all2AllInfo.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), rs.data(), - MPI_INT64_T, comm[channel][id]); + MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), + rs.data(), MPI_INT64_T, comm[channel][id]); all2AllInfo.countRecv.resize(rs.back() + rc.back()); if (isWithFAAE) { - MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), rc.data(), - rs.data(), MPI_UINT32_T, comm[channel][id]); + MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), + rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); } TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); EASY_END_BLOCK - } auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 90d5633c..14804760 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -395,7 +395,6 @@ struct BatchTask { }; struct All2AllInfo { - All2AllInfo() = default; All2AllInfo(keys_t keyRecv, vector scAll, vector countRecv); @@ -405,18 +404,15 @@ struct BatchTask { }; struct UniqueInfo { - UniqueInfo() = default; UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo); - vector restore; vector hotPos; All2AllInfo all2AllInfo; }; struct KeySendInfo { - 
KeySendInfo() = default; KeySendInfo(keys_t keySend, vector keyCount); @@ -424,7 +420,6 @@ struct BatchTask { vector keyCount; }; - using emb_mem_t = absl::flat_hash_map; using emb_hash_mem_t = absl::flat_hash_map; using offset_mem_t = std::map; -- Gitee From e4a5d650fd029a41d418680c46818e791995550d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 17:49:08 +0800 Subject: [PATCH 034/551] Match-id-0c36bd207dc73cef06a316d5160f145e0d99d064 --- src/core/key_process/key_process.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 93b3e738..889b30fb 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -256,7 +256,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_de // 特征准入&淘汰 if (isWithFAAE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, - uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + uniqueInfo.all2AllInfo.countRecv) + == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", rankInfo.rankId, id, channel); return false; -- Gitee From 6d4d6289d922e09e7b5e5d71d397c83b3871412d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 18:08:35 +0800 Subject: [PATCH 035/551] Match-id-7821851d8bc9e11f54c8b9889045ed21a1c4d27f --- build/build.sh | 2 +- src/tests/emb_mgmt/emb_mgmt_test.cpp | 7 ++++--- src/tests/key_process/key_process_test.cpp | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/build/build.sh b/build/build.sh index b7fc65f3..5fd7b54c 100644 --- a/build/build.sh +++ b/build/build.sh @@ -210,5 +210,5 @@ deactivate tf2_env echo "-----Build gen tar -----" gen_tar_file -clean +#clean echo "-----Done-----" diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index be4d2fbc..00bc55d9 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -136,7 +136,8 @@ TEST_F(EmbMgmtTest, Initialize) vector tmpData; vector d2h_emb; vector> tmpDatas; - tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + vector tmpData; + hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); auto missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; spdlog::info("missingKeys {}", missingKeys); hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); @@ -146,7 +147,7 @@ TEST_F(EmbMgmtTest, Initialize) hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); lookupKeys = { 2, 3, 5, 6 }; - tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; spdlog::info("missingKeys {}", missingKeys); hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); @@ -156,7 +157,7 @@ TEST_F(EmbMgmtTest, Initialize) hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos.clear(); lookupKeys = { 1, 7, 9, 10 }; - tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId); + hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; spdlog::info("missingKeys {}", missingKeys); 
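[Editorial sketch] PATCH 035's test updates above pass one tmpData buffer into several consecutive Process() calls. Because the out-parameter style appends rather than replaces, reusing a buffer across calls generally needs a clear() in between; a tiny sketch of that pitfall (Produce is hypothetical, not mxRec code):

    #include <cassert>
    #include <vector>

    void Produce(std::vector<int>& out) { out.push_back(42); }

    int main() {
        std::vector<int> out;
        Produce(out);
        assert(out.size() == 1);

        out.clear();  // without this, the next call would see size() == 2
        Produce(out);
        assert(out.size() == 1);
        return 0;
    }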
hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 6d1e01b7..3ad44c64 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -261,7 +261,8 @@ TEST_F(KeyProcessTest, GetScAll) } ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); ASSERT_EQ(process.isRunning, true); - auto scAll = process.GetScAll(keyScLocal, 0, 0); + vector scAll; + process.GetScAll(keyScLocal, 0, 0, scAll); ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); } -- Gitee From c7b51e65bd2f029ad0cb2ac4d4c550706da6f1f2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 18:09:37 +0800 Subject: [PATCH 036/551] Match-id-6df8f0f94c6c89ca47fb281fc07a5593702820b5 --- src/core/checkpoint/checkpoint.h | 2 +- src/core/hd_transfer/hd_transfer.cpp | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 22 +++++++++---------- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- .../constant_initializer.cpp | 2 +- .../constant_initializer.h | 2 +- .../random_normal_initializer.cpp | 2 +- .../random_normal_initializer.h | 2 +- .../truncated_normal_initializer.cpp | 2 +- .../truncated_normal_initializer.h | 2 +- src/core/key_process/key_process.cpp | 2 +- src/core/utils/common.h | 4 ++-- src/core/utils/spinlock.h | 13 +++++------ src/core/utils/unique.h | 15 ++++++++----- 14 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 548517fb..7539a9b6 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -80,7 +80,7 @@ namespace MxRec { void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); void WriteDataset(CkptTransData& transData, ofstream& writeFile, size_t writeSize, CkptDataType dataType, size_t idx); - void WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize); + void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); void ReadEmbedding(CkptTransData& transData, const string& dataDir); int GetEmbeddingSize(const string& embName); diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index b5c99868..39c8c76e 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -82,7 +82,7 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName } } spdlog::info("user config all2all restore lookup channel size:{}", channelSize); - for (int c = D2H; c != INVALID; c++) { + for (int c = static_cast(TransferChannel::D2H); c != static_cast(TransferChannel::INVALID); c++) { auto channel = static_cast(c); string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelNum); if (TransferChannel2Str(channel) == "all2all" || diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 8ef056be..a53ff9d5 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -374,7 +374,7 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) return true; } -bool HybridMgmt::SendLookupAndRestore(const int channelId, const int &batchId) +bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { vector names = {embInfo.name}; @@ -386,7 +386,7 @@ bool 
HybridMgmt::SendLookupAndRestore(const int channelId, const int &batchId) if (!mgmtRankInfo.useStatic) { for (const string& name: names) { auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, { *all2all }, channelId, name); + hdTransfer->Send(TransferChannel::ALL2ALL, { *all2all }, channelId, name); } } spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", @@ -401,7 +401,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, const int &batchId) TimeCost sendLookupTC; for (const string& name: names) { auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); - hdTransfer->Send(LOOKUP, lookUpKeys, channelId, name); + hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); } TIME_PRINT("LOOKUP Send TimeCost(ms):{}", sendLookupTC.ElapsedMS()); } @@ -410,7 +410,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, const int &batchId) TimeCost sendRestoreTC; for (const string& name: names) { auto restore = restoreQueue->WaitAndPop(); - hdTransfer->Send(RESTORE, restore, channelId, name); + hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); } TIME_PRINT("RESTORE Send TimeCost(ms):{}", sendRestoreTC.ElapsedMS()); } @@ -447,16 +447,16 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) break; } auto restore = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); - hdTransfer->Send(RESTORE, *restore, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embInfo.name); auto tmpData = hostHashMaps->Process(embInfo.name, lookupKeys, iBatch); - hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embInfo.name); tmpData.erase(tmpData.begin()); - hdTransfer->Send(SWAP, tmpData, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embInfo.name); if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, *all2all, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); } if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch @@ -509,7 +509,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) for (const auto& embInfo: mgmtEmbInfo) { auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos; auto h2dEmb = hostEmbs->GetH2DEmb(missingKeys, embInfo.name); // order! 
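[Editorial sketch] PATCH 036 qualifies every channel constant as TransferChannel::X and iterates the enum through static_cast, the usual pattern once an enum becomes scoped. A standalone sketch with an invented Channel enum (enumerator names and order are illustrative only; the real TransferChannel lives in mxRec):

    #include <iostream>

    enum class Channel : int { D2H, H2D, LOOKUP, RESTORE, SWAP, ALL2ALL, EVICT, INVALID };

    const char* ToStr(Channel c) {
        switch (c) {
            case Channel::D2H: return "d2h";
            case Channel::H2D: return "h2d";
            case Channel::LOOKUP: return "lookup";
            case Channel::RESTORE: return "restore";
            case Channel::SWAP: return "swap";
            case Channel::ALL2ALL: return "all2all";
            case Channel::EVICT: return "evict";
            default: return "invalid";
        }
    }

    int main() {
        // A scoped enum no longer converts implicitly, so the loop counter is
        // an int and each use casts back, exactly as the patched loop does.
        for (int c = static_cast<int>(Channel::D2H);
             c != static_cast<int>(Channel::INVALID); ++c) {
            std::cout << ToStr(static_cast<Channel>(c)) << "\n";
        }
        return 0;
    }

The payoff of the scoped enum is that a bare EVICT or H2D no longer compiles, which is what forced the explicit TransferChannel:: prefixes throughout this patch.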
- hdTransfer->Send(H2D, h2dEmb, channelId, embInfo.name, batchId); + hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); } for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); @@ -528,7 +528,7 @@ void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embI spdlog::info(MGMT + "trans emb dummy, batchId:{}, channelId:{}", batchId, channelId); auto transferName = TransferChannel::D2H; auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0]; - hdTransfer->Send(H2D, {}, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::H2D, {}, channelId, embInfo.name); } /* @@ -593,5 +593,5 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) } auto tmpData = Vec2TensorI32(evictDevOffset); - hdTransfer->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); + hdTransfer->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 1ca832fa..16cbebed 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -115,7 +115,7 @@ namespace MxRec { bool EvalParseKeys(); bool GetLookupAndRestore(const int channelId, int &batchId); - bool SendLookupAndRestore(const int channelId, const int &batchId); + bool SendLookupAndRestore(const int channelId, int &batchId); void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 09780b88..954ca98f 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -13,7 +13,7 @@ using namespace MxRec; ConstantInitializer::ConstantInitializer(int start, int len, float value) : start(start), len(len), value(value) {} -void ConstantInitializer::GenerateData(const float* emb, const int embSize) +void ConstantInitializer::GenerateData(float* const emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h index 56f95a9e..68aa0654 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.h +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -21,7 +21,7 @@ namespace MxRec { ~ConstantInitializer() override {}; - void GenerateData(const float* emb, const int embSize) override; + void GenerateData(float* const emb, const int embSize) override; int start; int len; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 56765ecf..7933555f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -18,7 +18,7 @@ RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, distribution = std::normal_distribution(mean, stddev); } -void RandomNormalInitializer::GenerateData(const float* emb, const int embSize) +void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h 
b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
index eeab2dfa..7bbf3b7b 100644
--- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h
+++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
@@ -23,7 +23,7 @@ namespace MxRec {
 
         ~RandomNormalInitializer() override {};
 
-        void GenerateData(const float* emb, const int embSize) override;
+        void GenerateData( float* const emb, const int embSize) override;
 
         int start;
         int len;
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
index 2d830490..d02ac998 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
@@ -21,7 +21,7 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float
 }
 
-void TruncatedNormalInitializer::GenerateData(const float* emb, const int embSize)
+void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSize)
 {
     if (len == 0) {
         return;
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
index 34058351..d2da1bef 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
@@ -23,7 +23,7 @@ namespace MxRec {
 
         ~TruncatedNormalInitializer() override {};
 
-        void GenerateData(const float* emb, const int embSize) override;
+        void GenerateData(float* const emb, const int embSize) override;
 
         int boundNum = 2;
 
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 6fd38ec4..3ea879fa 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -1016,7 +1016,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset
     auto trans = Singleton::GetInstance();
     // send the evicted keys to the device side, which re-initializes their embeddings
     auto tmpData = Vec2TensorI32(offset);
-    trans->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName);
+    trans->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName);
     spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]!
send offsetSize:{}", embName, offset.size()); } \ No newline at end of file diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 84a1ae50..883aca4e 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -126,8 +126,8 @@ namespace MxRec { { std::string s; constexpr size_t MAX_DISP_LEN = 20; - int max_len = std::min(sample.size(), MAX_DISP_LEN); - for (int i = 0; i < max_len; i++) { + int maxLen = std::min(sample.size(), MAX_DISP_LEN); + for (int i = 0; i < maxLen; i++) { s += std::to_string(sample[i]) + " "; } return s; diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h index d284ef05..abbb748c 100644 --- a/src/core/utils/spinlock.h +++ b/src/core/utils/spinlock.h @@ -12,11 +12,6 @@ #include #include // NOLINT -#define DISALLOW_COPY_MOVE_AND_ASSIGN_(type) \ - type(type const &) = delete; \ - type(type &&) noexcept = delete; \ - type &operator=(type const &) = delete - static __inline void cpu_pause() { #ifdef __GNUC__ @@ -61,7 +56,9 @@ class SpinLock final { public: SpinLock() = default; - DISALLOW_COPY_MOVE_AND_ASSIGN_(SpinLock); + SpinLock(SpinLock const &) = delete; + SpinLock(SpinLock &&) noexcept = delete; + SpinLock &operator=(SpinLock const &) = delete; inline void lock() noexcept { @@ -108,7 +105,9 @@ class RWSpinLock final { public: RWSpinLock() = default; - DISALLOW_COPY_MOVE_AND_ASSIGN_(RWSpinLock); + RWSpinLock(RWSpinLock const &) = delete; + RWSpinLock(RWSpinLock &&) noexcept = delete; + RWSpinLock &operator=(RWSpinLock const &) = delete; inline void r_lock() noexcept { diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 4509df4c..d0ef985f 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -33,6 +33,8 @@ #include "spinlock.h" #include "time_cost.h" +using namespace MxRec; + struct UniqueData { void *inputData; size_t dataSize; @@ -62,7 +64,7 @@ struct UniqueThreadNum { int maxThread; }; -namespace { +namespace SysytemConst{ const int LEVEL1_CACHE = 64; const int DEFAULT_DEDUPLICATION_RATE = 4; const int DEDUPLICATION_RATE = 2; @@ -131,7 +133,7 @@ public: Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1) : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) { - void *area = aligned_alloc(LEVEL1_CACHE, sizeof(Meta) * bucketCount_); + void *area = aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_); table_ = reinterpret_cast *>(area); Clear(bucketCount_); } @@ -278,7 +280,7 @@ public: // sake. 
uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; int largeCount = 0; - while (shardedTableSize > stats_.totalUniques * DEFAULT_DEDUPLICATION_RATE && largeCount_ != 1) { + while (shardedTableSize > stats_.totalUniques * SysytemConst::DEFAULT_DEDUPLICATION_RATE && largeCount_ != 1) { // too large newBucketCountPowerOf2 >>= 1; shardedTableSize >>= 1; @@ -293,7 +295,7 @@ public: } } - while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> DEDUPLICATION_RATE)) { + while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> SysytemConst::DEDUPLICATION_RATE)) { newBucketCountPowerOf2 <<= 1; shardedTableSize <<= 1; } @@ -454,7 +456,8 @@ private: static inline uint64_t hash(uint64_t val) { - return val ^ (val >> HASH_SPLIT_BUCKERT_1) ^ (val >> HASH_SPLIT_BUCKERT_2) ^ (val >> HASH_SPLIT_BUCKERT_3); + return val ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_1) ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_2) ^ + (val >> SysytemConst::HASH_SPLIT_BUCKERT_3); } void insertOverflow(uint64_t val) @@ -512,7 +515,7 @@ public: ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, int estimatedDuplicateRatio = kDefaultDuplicateRatio) - : groupMethod_(groupMethod), bucketCountPower2_(PRE_APPLY_MEMORY), send_cnt_(send_cnt) + : groupMethod_(groupMethod), bucketCountPower2_(SysytemConst::PRE_APPLY_MEMORY), send_cnt_(send_cnt) { const int numOfGroupsInShard = groupMethod_.GroupCount(); -- Gitee From a35fb269a22b863f4141f911e504093a13f1eafd Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 26 May 2023 18:12:23 +0800 Subject: [PATCH 037/551] Match-id-9bd608d2ec934caa3f521a726c487fca3fddb290 --- src/tests/emb_mgmt/emb_mgmt_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 00bc55d9..22b8655f 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -133,7 +133,6 @@ TEST_F(EmbMgmtTest, Initialize) int currentBatchId = 0; vector lookupKeys = { 1, 3, 5, 7 }; - vector tmpData; vector d2h_emb; vector> tmpDatas; vector tmpData; -- Gitee From 0ca0436947be901c270c2760a69c16bfdf1ef9cd Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 09:40:24 +0800 Subject: [PATCH 038/551] Match-id-b0e5e8a3f28df02c5817ed771d6ab141cbf88db7 --- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- .../random_normal_initializer/random_normal_initializer.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 16cbebed..3238edc3 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -9,11 +9,11 @@ #define MX_REC_EMB_MGMT_H #include -#include "absl/container/flat_hash_map.h" #include #include #include #include +#include "absl/container/flat_hash_map.h" #include "utils/common.h" #include "utils/singleton.h" #include "utils/task_queue.h" diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index 7bbf3b7b..d6c8b376 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -23,7 +23,7 @@ namespace MxRec { ~RandomNormalInitializer() override {}; - void GenerateData( float* const emb, const int embSize) override; + void GenerateData(float *const emb, const int embSize) override; int start; int len; -- Gitee From 
da53dd9d272d2e3665375555d1504d25e57a3a3c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 10:08:32 +0800 Subject: [PATCH 039/551] Match-id-3e752d7ea7b7cca853efb6700160632fd0ab5052 --- src/core/utils/common.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 14804760..7ebc7057 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -395,9 +395,6 @@ struct BatchTask { }; struct All2AllInfo { - All2AllInfo() = default; - All2AllInfo(keys_t keyRecv, vector scAll, vector countRecv); - keys_t keyRecv; vector scAll; vector countRecv; -- Gitee From c82a614517f97970b7bd837a69d1e3e4032d80ef Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 10:09:03 +0800 Subject: [PATCH 040/551] Match-id-c0f0b217eaab00846ec800331df71c74c9521481 --- mx_rec/core/asc/helper.py | 26 ++++++++--------- mx_rec/core/asc/manager.py | 6 ++-- src/core/key_process/key_process.h | 4 ++- src/ops_tf/hybrid_dataset_ops.cpp | 45 ++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 19 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 3831cafa..0c0ba694 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -47,11 +47,15 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names def find_dangling_table(table_names): def check_tensor(table_name, table_reachable_tensor): - if ''.join(["/update_", table_name]) in table_reachable_tensor.name \ - or table_reachable_tensor.op.type == 'ApplyAdam': + the_op = table_reachable_tensor.op + logging.info(f"** the_op:{the_op.outputs} {the_op.name} {the_op.type}**") + + if table_reachable_tensor.op.type == 'ApplyAdam': return True - if 'gradients/' in table_reachable_tensor.name: + + if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': return True + return False def find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor): @@ -67,13 +71,13 @@ def find_dangling_table(table_names): table_lookup_op = {} table_reachable_tensor = {} + for the_op in op_list: - logging.info(f"** the_op: {the_op}**") for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) logging.info(f"*********** find tables: {table_lookup_op}***********") - logging.info(f"looking for dangling table") + logging.info(f"opTypes{[x for x in opTypes]}") dangling_table = [] def extend(op_list, tensor, spread_tensors): @@ -128,17 +132,10 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ } get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict['table_names']}") - table_names = read_emb_key_inputs_dict["table_names"] - logging.info(f"all table_names: {table_names}") - dangling_tables = find_dangling_table(table_names) - for table_name in dangling_tables: - logging.info(f"In insert found dangling table: {table_name} " - f"which does not need to be provided to the EmbInfo.") - table_names.remove(table_name) return do_insert(args, insert_tensors=read_emb_key_inputs_dict["insert_tensors"], splits=read_emb_key_inputs_dict["splits"], - table_names=table_names, + table_names=read_emb_key_inputs_dict["table_names"], input_dict={"is_training": is_training, "dump_graph": dump_graph, "timestamp": FeatureSpec.use_timestamp(is_training), "feature_spec_names": read_emb_key_inputs_dict["feature_spec_names"], @@ 
-149,13 +146,12 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ else: if feature_counts is None or table_names is None: raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - logging.info(f"all table_names: {table_names}") + dangling_tables = find_dangling_table(table_names) for table_name in dangling_tables: logging.info(f"In insert found dangling table: {table_name} " f"which does not need to be provided to the EmbInfo.") table_names.remove(table_name) - logging.info(f"used table_names: {table_names}") def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 1e2e987c..aeb9e5ad 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -30,8 +30,9 @@ def generate_table_info_list(): optimizer = export_optimizer() # generate table info dangling_table = export_dangling_table() - # dangling_table = find_dangling_table([table_instance.table_name - # for _, table_instance in export_table_instances().items()]) + if not dangling_table: + dangling_table = find_dangling_table([table_instance.table_name + for _, table_instance in export_table_instances().items()]) for _, table_instance in export_table_instances().items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: @@ -192,6 +193,5 @@ def start_asc_pipeline(): threshold_list = generate_threshold_list() if not table_info_list: logging.warning(f"table_info_list is empty") - if not is_asc_manager_initialized() and table_info_list: initialize_emb_cache(table_info_list, threshold_list) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index d2730707..ee736bb9 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -88,7 +88,9 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); bool isRunning { false }; - + inline bool hasEmbName(const string& emb_name ){ + return embInfos.find(emb_name) != embInfos.end(); + }; GTEST_PRIVATE: template diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 48aaeee2..48d396a9 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -200,15 +200,34 @@ public: staticSw.reset(); } + void CheckEmbTables(){ + auto keyProcess = Singleton::GetInstance(); + for (size_t i = 0; i < embNames.size(); ++i) { + if (!keyProcess->hasEmbName(embNames.at(i))) { + spdlog::info("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); + tableUsed.push_back(false); + }else{ + tableUsed.push_back(true); + } + } + } + void EnqueueBatchData(std::vector ids, time_t timestamp, const Tensor& inputTensor, const TTypes::ConstFlat& splits) { + if(tableUsed.empty()){ + CheckEmbTables(); + } auto queue = SingletonQueue::getInstances(ids[1]); size_t offset = 0; if (isTimestamp) { offset += 1; // 前面8个字节是unix时间戳 } for (int i = 0; i < splits.size(); ++i) { + if(!tableUsed.at(i)){ + offset += splits(i); + continue; + } auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block batchData->name = embNames.at(i); size_t len = splits(i); @@ -218,6 +237,9 @@ public: if (isTimestamp) { batchData->timestamp = timestamp; } + spdlog::info("split size:{} {}", i, splits(i)); + spdlog::info("emb_name:{} {}", i, embNames.at(i)); + spdlog::debug("batch[{}/{}] flatten bs: {}", ids[0], i+1, len); std::unique_ptr 
batch = TensorCopy(inputTensor, move(batchData), len, offset); if (batch == nullptr) { @@ -302,6 +324,7 @@ public: int channelId {}; vector embNames {}; + vector tableUsed{}; int maxStep = 0; bool isTimestamp { false }; }; @@ -416,15 +439,36 @@ public: staticSw.reset(); } + void CheckEmbTables(){ + auto keyProcess = Singleton::GetInstance(); + for (size_t i = 0; i < splits.size(); ++i) { + if (!keyProcess->hasEmbName(embNames.at(i))) { + spdlog::info("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); + tableUsed.push_back(false); + }else{ + tableUsed.push_back(true); + } + } + } + int EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) { + if(tableUsed.empty()){ + CheckEmbTables(); + } auto queue = SingletonQueue::getInstances(batchQueueId); + size_t offset = 0; if (isTimestamp) { offset += 1; // 前面8个字节是unix时间戳 } TimeCost ctAll; for (size_t i = 0; i < splits.size(); ++i) { + if(!tableUsed.at(i)){ + offset += splits.at(i); + continue; + } + TimeCost tp; auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block TIME_PRINT("TryPopTimeCost(ms):{}", tp.ElapsedMS()); @@ -523,6 +567,7 @@ public: int channelId {}; vector splits {}; + vector tableUsed{}; int fieldNum {}; vector embNames {}; int maxStep = 0; -- Gitee From 0f7cd76504b807b63fca2b84e9692e487cdb3a3d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 11:00:11 +0800 Subject: [PATCH 041/551] Match-id-4bee77ff90eab27bb0023a6f608d4af80b02abb7 --- src/core/emb_hashmap/emb_hashmap.cpp | 14 +++++----- src/core/emb_hashmap/emb_hashmap.h | 2 +- src/core/host_emb/host_emb.cpp | 6 ++-- src/core/host_emb/host_emb.h | 2 +- src/core/key_process/key_process.cpp | 41 ++++++++++++---------------- src/core/key_process/key_process.h | 6 ++-- 6 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 18b02dc5..39400991 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -38,7 +38,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } void EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch, - vector& tmpData) + vector& tmpDataOut) { EASY_FUNCTION(profiler::colors::Pink) auto keepBatch = swapId - iBatch; @@ -48,17 +48,17 @@ void EmbHashMap::Process(const string& embName, const vector& keys, s auto& embHashMap = embHashMaps.at(embName); auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); - tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); + tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); - auto lookupTensorData = tmpData.back().flat(); + auto lookupTensorData = tmpDataOut.back().flat(); for (int i = 0; i < lookUpVecSize; i++) { lookupTensorData(i) = static_cast(embHashMap.lookUpVec[i]); } spdlog::trace("lookupTensor, {}", embHashMap.lookUpVec); auto swapSize = static_cast(embHashMap.swapPos.size()); - tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); + tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); - auto swapTensorData = tmpData.back().flat(); + auto swapTensorData = tmpDataOut.back().flat(); for (int i = 0; i < swapSize; i++) { swapTensorData(i) = static_cast(embHashMap.swapPos[i]); } @@ -69,8 +69,8 @@ void EmbHashMap::Process(const string& embName, const vector& keys, s embHashMap.swapPos.clear(); spdlog::info("current dev emb usage:{}-{}/[{}+{}]", embName, embHashMap.maxOffset, embHashMap.devVocabSize, 
embHashMap.hostVocabSize); - tmpData.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); - auto swapLen = tmpData.back().flat(); + tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); + auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; EASY_END_BLOCK } diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 4cd215e2..7581ad94 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -23,7 +23,7 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); - void Process(const string& embName, const std::vector& keys, size_t iBatch, vector& tmpData); + void Process(const string& embName, const std::vector& keys, size_t iBatch, vector& tmpDataOut); void FindAndUpdateOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 15514529..75f23936 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -172,15 +172,15 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI * missingKeysHostPos为host侧需要发送的emb的位置 */ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& embName, - vector& h2d_emb) + vector& h2dEmbOut) { EASY_FUNCTION() const auto& emb = hostEmbs[embName]; const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize; - h2d_emb.emplace_back(Tensor(tensorflow::DT_FLOAT, { + h2dEmbOut.emplace_back(Tensor(tensorflow::DT_FLOAT, { int(missingKeysHostPos.size()), embeddingSize })); - auto& tmpTensor = h2d_emb.back(); + auto& tmpTensor = h2dEmbOut.back(); auto tmpData = tmpTensor.flat(); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(missingKeysHostPos, emb, tmpData) for (size_t i = 0; i < missingKeysHostPos.size(); i++) { diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index ca022e7c..69e667fa 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -39,7 +39,7 @@ namespace MxRec { void UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName); void GetH2DEmb(const vector& missingKeysHostPos, const string& embName, - vector& h2d_emb); + vector& h2dEmbOut); auto GetHostEmbs() -> absl::flat_hash_map*; void EvictInitEmb(const string& embName, const vector& offset); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 889b30fb..215b1329 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -244,11 +244,6 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_de isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; TimeCost tc; -// keys_t lookupKeys; -// vector restore; -// vector hotPos; -// vector scAll; -// vector countRecv; UniqueInfo uniqueInfo; ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); @@ -408,7 +403,7 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) } void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, - UniqueInfo& uniqueInfo) + UniqueInfo& uniqueInfoOut) { EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) @@ -423,7 +418,7 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba 
keySendInfo.keySend.resize(size); vector splitSize(rankInfo.rankSize); vector uniqueVector(batch->batchSize); - uniqueInfo.restore.resize(batch->batchSize); + uniqueInfoOut.restore.resize(batch->batchSize); vector idCount(batch->batchSize); keySendInfo.keyCount.resize(size); std::shared_lock lock(g_smut);
@@ -432,15 +427,15 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba int hotOffset = 0; if (rankInfo.useHot) { - uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name]); + uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); hotOffset = hotEmbTotCount[batch->name]; } absl::flat_hash_map keyCountMap; - UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, uniqueInfo.restore.data(), uniqueVector.data(), + UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, uniqueInfoOut.restore.data(), uniqueVector.data(), splitSize.data(), keySendInfo.keySend.data(), idCount.data(), keySendInfo.keyCount.data()}; UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; - UniqueForHot uniqueForHot = {hotOffset, uniqueInfo.hotPos.data(), hotMap, keyCountMap}; + UniqueForHot uniqueForHot = {hotOffset, uniqueInfoOut.hotPos.data(), hotMap, keyCountMap}; UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum);
@@ -460,20 +455,20 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba sc[i] = splitSize[i]; } } - All2All(sc, id, batch->channel, keySendInfo, uniqueInfo.all2AllInfo); + All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, - batch->channel, batch->channelName, batch->name, uniqueInfo.restore.size(), + batch->channel, batch->channelName, batch->name, uniqueInfoOut.restore.size(), keySendInfo.keyCount.size()); } void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, - All2AllInfo& all2AllInfo) + All2AllInfo& all2AllInfoOut) { TimeCost get_sc_all; - GetScAll(sc, id, channel, all2AllInfo.scAll); // Allgather collects every thread's send counts (same thread id across ranks) + GetScAll(sc, id, channel, all2AllInfoOut.scAll); // Allgather collects every thread's send counts (same thread id across ranks) TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS()); TimeCost all2allTC;
@@ -481,17 +476,17 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS vector rc(rankInfo.rankSize); // receive count for (int i = 0; i < rankInfo.rankSize; ++i) { // summing one column of the traffic matrix gives the key volume this rank receives from the other devices - rc[i] = all2AllInfo.scAll.at(i * rankInfo.rankSize + rankInfo.rankId); + rc[i] = all2AllInfoOut.scAll.at(i * rankInfo.rankSize + rankInfo.rankId); } auto rs = Count2Start(rc); // receive displacements/offsets: start offset of the incoming data - all2AllInfo.keyRecv.resize(rs.back() + rc.back()); + all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfo.keyRecv.data(), rc.data(), + MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); - all2AllInfo.countRecv.resize(rs.back() + rc.back()); + all2AllInfoOut.countRecv.resize(rs.back() + rc.back()); if (isWithFAAE) { - MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfo.countRecv.data(), + MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); } TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS());
@@ -739,10 +734,10 @@ void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, * Allgather the key counts this rank's batch will send, building the inter-thread traffic matrix (same thread id across different ranks) * scAll returns the full traffic matrix, flattened row-major into a one-dimensional vector */ -void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAll) const +void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const { EASY_FUNCTION() - scAll.resize(rankInfo.rankSize * rankInfo.rankSize); + scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); EASY_BLOCK("barrier"); // communication termination signal; exit in sync so no thread gets stuck spdlog::stopwatch sw;
@@ -755,8 +750,8 @@ void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel spdlog::debug(KEY_PROCESS "barrier time:{}", TO_MS(sw)); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, - scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - spdlog::debug("rank {} key scAll matrix:\n{}", rankInfo.rankId, scAll); + scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + spdlog::debug("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, scAllOut); } void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey)
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 35af99e6..498268a7 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h
@@ -127,12 +127,12 @@ namespace MxRec { vector& splitKeys) -> tuple, vector>; void ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, - UniqueInfo& uniqueInfo); + UniqueInfo& uniqueInfoOut); size_t GetKeySize(const unique_ptr &batch); void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, - All2AllInfo& all2AllInfo); + All2AllInfo& all2AllInfoOut); auto HashSplit(const unique_ptr& batch) const -> tuple, vector>;
@@ -140,7 +140,7 @@ namespace MxRec { auto HashSplit_withFAAE(const unique_ptr& batch) const -> tuple, vector, vector>>; - void GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAll) const; + void GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const; void Key2Offset(const emb_name_t& embName, keys_t& splitKey);
-- Gitee
From 8f8f79d402d6b4210e82afead297455fe0c52e5e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 11:04:30 +0800 Subject: [PATCH 042/551] Match-id-0524e0c369ab8c810f6bb401066e9714c37708eb --- src/core/emb_hashmap/emb_hashmap.h | 3 ++- src/core/emb_mgmt/emb_mgmt.cpp | 1 - src/core/key_process/key_process.cpp | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 7581ad94..498b61fa 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h
@@ -23,7 +23,8 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); - void Process(const string& embName, const std::vector& keys, size_t iBatch, vector& tmpDataOut); + void Process(const string& embName, const std::vector& keys, size_t iBatch, + vector& tmpDataOut); void FindAndUpdateOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId);
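
The All2All/GetScAll pair in PATCH 041 above is a classic two-phase variable-length exchange: an Allgather first builds the rankSize x rankSize traffic matrix, the local column of that matrix yields the receive counts, and MPI_Alltoallv then moves the keys. A condensed, runnable sketch of the same pattern, assuming MPI has already been initialized; the free-standing function and its names are illustrative, not the mxRec API:

    #include <mpi.h>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> ExchangeKeys(const std::vector<int64_t>& sendBuf,
                                      const std::vector<int>& sendCounts,  // one entry per rank
                                      int rankId, int rankSize, MPI_Comm comm)
    {
        // Phase 1: Allgather the send counts; row r of countMatrix is rank r's sendCounts.
        std::vector<int> countMatrix(rankSize * rankSize);
        MPI_Allgather(sendCounts.data(), rankSize, MPI_INT,
                      countMatrix.data(), rankSize, MPI_INT, comm);

        // Column rankId tells this rank how much it receives from each peer.
        std::vector<int> recvCounts(rankSize);
        for (int r = 0; r < rankSize; ++r) {
            recvCounts[r] = countMatrix[r * rankSize + rankId];
        }

        // Exclusive prefix sums give the send/receive displacements (cf. Count2Start).
        std::vector<int> sendDispls(rankSize, 0), recvDispls(rankSize, 0);
        std::partial_sum(sendCounts.begin(), sendCounts.end() - 1, sendDispls.begin() + 1);
        std::partial_sum(recvCounts.begin(), recvCounts.end() - 1, recvDispls.begin() + 1);

        // Phase 2: the variable-length all-to-all delivers every key to its owning rank.
        std::vector<int64_t> recvBuf(recvDispls.back() + recvCounts.back());
        MPI_Alltoallv(sendBuf.data(), sendCounts.data(), sendDispls.data(), MPI_INT64_T,
                      recvBuf.data(), recvCounts.data(), recvDispls.data(), MPI_INT64_T, comm);
        return recvBuf;
    }

diff --git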
a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 98359de5..645e0cc8 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -457,7 +457,6 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(ALL2ALL, *all2all, channelId, embInfo.name); } - if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embInfo.name, channelId, batchId, iBatch, lookupKeys.size()); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 215b1329..f19f2263 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -481,8 +481,8 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfoOut.keyRecv.data(), rc.data(), - rs.data(), MPI_INT64_T, comm[channel][id]); + MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfoOut.keyRecv.data(), + rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); all2AllInfoOut.countRecv.resize(rs.back() + rc.back()); if (isWithFAAE) { -- Gitee From 98dc332f77548cdf512c90d66693e06ab8ee6538 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 11:39:02 +0800 Subject: [PATCH 043/551] Match-id-05f8506f26427523c8a66c05cbbdf77104e4fe64 --- src/ops_tf/hybrid_dataset_ops.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 85a7c261..57c7a3fd 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -250,13 +250,15 @@ public: if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { batchData->isInt64 = false; memSize = len * sizeof(int32_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). - data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } else { batchData->isInt64 = true; memSize = len * sizeof(int64_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). - data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data(). 
+ data()))) + offset); } batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { @@ -482,13 +484,15 @@ public: if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { batchData->isInt64 = false; memSize = len * sizeof(int32_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() - .data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } else { batchData->isInt64 = true; memSize = len * sizeof(int64_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() - .data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { -- Gitee From ed28abc8a36868d61927ba71bc1f351987010e2b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 11:52:06 +0800 Subject: [PATCH 044/551] Match-id-61f562ee3b95117dfcc11f9e39da63fdb0a5439c --- src/core/utils/unique.h | 48 +++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index d0ef985f..95e8221c 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -64,7 +64,7 @@ struct UniqueThreadNum { int maxThread; }; -namespace SysytemConst{ +namespace SysytemConst { const int LEVEL1_CACHE = 64; const int DEFAULT_DEDUPLICATION_RATE = 4; const int DEDUPLICATION_RATE = 2; @@ -264,7 +264,7 @@ public: free(table_); bucketCount_ = newBucketCountPowerOf2; bucketCountMask_ = bucketCount_ - 1; - table_ = reinterpret_cast *>(aligned_alloc(64, sizeof(Meta) * bucketCount_)); + table_ = reinterpret_cast *>(aligned_alloc(LEVEL1_CACHE, sizeof(Meta) * bucketCount_)); } bzero(table_, sizeof(Meta) * bucketCount_); overflow_.clear(); @@ -280,7 +280,8 @@ public: // sake. uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; int largeCount = 0; - while (shardedTableSize > stats_.totalUniques * SysytemConst::DEFAULT_DEDUPLICATION_RATE && largeCount_ != 1) { + while (shardedTableSize > stats_.totalUniques * SysytemConst::DEFAULT_DEDUPLICATION_RATE && + largeCount_ != 1) { // too large newBucketCountPowerOf2 >>= 1; shardedTableSize >>= 1; @@ -312,7 +313,6 @@ public: } // Warning: functions below are not thread safe! 
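
The aligned_alloc change above pins each bucket table to LEVEL1_CACHE (64-byte) boundaries, and just below, the same patch removes the CACHE_LINE_ALIGN macro in favor of writing ((size) + 63ul) & ~63ul inline. Both rest on the usual round-up identity for power-of-two alignments, shown here as a tiny self-contained sketch:

    #include <cstddef>

    // Round n up to the next multiple of alignment; alignment must be a power of two.
    constexpr std::size_t AlignUp(std::size_t n, std::size_t alignment)
    {
        return (n + alignment - 1) & ~(alignment - 1);
    }

    static_assert(AlignUp(1, 64) == 64, "rounds up into the next cache line");
    static_assert(AlignUp(64, 64) == 64, "aligned values are unchanged");
    static_assert(AlignUp(65, 64) == 128, "crosses into the following line");
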
- // // Return the unique values // Also update the hash-order base of each bucket std::vector Unique() @@ -490,8 +490,6 @@ private: } }; // Dedup -#define CACHE_LINE_ALIGN(size) (((size) + 63ul) & ~63ul) - class OneSimpleGroupMethod { public: inline int GroupCount() @@ -641,19 +639,20 @@ public: baseVector.push_back(base); base += total; - partSize = CACHE_LINE_ALIGN(partSize); + partSize = ((partSize) + 63ul) & ~63ul; int32_t *beginPtr = output; int32_t *finishPtr = beginPtr + inputSize; int32_t *partBeginPtr = beginPtr; int32_t *partEndPtr = - reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); + reinterpret_cast(((reinterpret_cast(partBeginPtr + partSize)) + 63ul) & ~63ul); - if(uniqueFlag.useStatic){ + if (uniqueFlag.useStatic) { for (int i = 0; i < groupMethod_.GroupCount(); i++) { - if (send_cnt_ < uniqueSizeVector[i]){ - spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, uniqueSizeVector[i]); + if (send_cnt_ < uniqueSizeVector[i]) { + spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, + uniqueSizeVector[i]); throw SendCntTooSmallError(); } } @@ -680,7 +679,8 @@ public: // should be +/-1 off. const int numOfGroupsInShard = groupMethod_.GroupCount(); tasks.push_back([this, input, &baseVector, beginPtr, partBeginPtr, partEndPtr, numOfGroupsInShard, - totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, hotPosMap]() -> TaskReturnType { + totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, + hotPosMap]() -> TaskReturnType { for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { auto val = isInt64 ? ((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; auto group = groupMethod_.GroupId(val); @@ -698,14 +698,16 @@ public: pool->SyncRun(tasks); } - - TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, uniqueSizeVector); + TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, + uniqueSizeVector); return 0; } - void ComputeRestore(bool useHot, int offset,const map &hotMap, int *hotPos,const map &hotPosMap, - int32_t *ptr, int64_t val, uint32_t fillOffset) const { + void ComputeRestore(bool useHot, int offset, const map &hotMap, int *hotPos, + const map &hotPosMap, + int32_t *ptr, int64_t val, uint32_t fillOffset) const + { auto hot = hotPosMap.find(val); if (!useHot) { *ptr = fillOffset; @@ -720,16 +722,20 @@ public: } uint32_t GetFillOffset(bool useStatic, const vector &baseVector, const vector &totalUniqueSize, - int64_t val, int32_t group) { + int64_t val, int32_t group) + { if (!useStatic) { return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]; } else { - return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - totalUniqueSize[group]; + return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - + totalUniqueSize[group]; } } - void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, const int32_t *idCount, - int32_t *idCountFill, bool useStatic, const std::vector &uniqueSizeVector) const { + void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, + const int32_t *idCount, int32_t *idCountFill, bool useStatic, + const std::vector &uniqueSizeVector) const + { int start = 0; int index = 0; @@ -748,7 +754,7 @@ public: 
size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); if (rc != 0) { - spdlog::error("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}",mem_size); + spdlog::error("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}", mem_size); return; } mem_size = uniqueSizeVector[i] * sizeof(int32_t);
-- Gitee
From 031102efef55c177bd145c62d3deb1b24f763b07 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 11:57:58 +0800 Subject: [PATCH 045/551] Match-id-e3c5ce02a7beaef19ad8baf7c784b4f3152349c4 --- src/core/key_process/key_process.cpp | 4 ++-- src/core/utils/common.h | 3 +-- src/core/utils/unique.h | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3ea879fa..57dcdc7b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp
@@ -552,7 +552,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - ASSERT(batchData != nullptr); + assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); vector hashSplitLens(rankInfo.rankSize); // zero-initialized; records the length of each bucket
@@ -591,7 +591,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - ASSERT(batchData != nullptr); + assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector> keyCount(rankInfo.rankSize); // frequencies of the splitKeys in the original batch vector restore(batch->Size());
diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 883aca4e..af8b15dc 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h
@@ -43,7 +43,6 @@ #endif namespace MxRec { -#define ASSERT(arg) assert(arg) #define TO_MS(arg) duration_cast((arg).elapsed()) #define INFO_PTR shared_ptr #define TIME_PRINT spdlog::info
@@ -268,7 +267,7 @@ struct BatchTask { if (randomInfo.len == 0) { return; } - ASSERT(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); + assert(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); std::uniform_real_distribution distribution(min, max); std::generate_n(vecData.begin() + randomInfo.start, randomInfo.len, [&]() { return distribution(generator); });
diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 95e8221c..240c1cef 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h
@@ -264,7 +264,7 @@ public: free(table_); bucketCount_ = newBucketCountPowerOf2; bucketCountMask_ = bucketCount_ - 1; - table_ = reinterpret_cast *>(aligned_alloc(LEVEL1_CACHE, sizeof(Meta) * bucketCount_)); + table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_)); } bzero(table_, sizeof(Meta) * bucketCount_); overflow_.clear();
-- Gitee
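
The last two patches both touch the unique.h deduplication machinery. Stripped of sharding, hot-key handling and static-shape padding, its core contract is: emit every distinct key once and, for each input position, the index of its unique representative, so a later gather can restore batch order. A single-threaded sketch of that contract (the Out-suffix parameters follow the naming convention these patches introduce; this is not the real ShardedDedup):

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    void Dedup(const std::vector<int64_t>& keys,
               std::vector<int64_t>& uniqueKeysOut,
               std::vector<int32_t>& restoreOut)
    {
        std::unordered_map<int64_t, int32_t> firstSeen;
        firstSeen.reserve(keys.size());
        uniqueKeysOut.clear();
        restoreOut.resize(keys.size());
        for (size_t i = 0; i < keys.size(); ++i) {
            // try_emplace records the key's slot only on first sight.
            auto [it, inserted] = firstSeen.try_emplace(
                keys[i], static_cast<int32_t>(uniqueKeysOut.size()));
            if (inserted) {
                uniqueKeysOut.push_back(keys[i]);
            }
            restoreOut[i] = it->second;
        }
    }

A lookup then fetches one embedding row per unique key, and rows[restore[i]] rebuilds the original batch order.

From e9380235bd290cd298d803691c1be7b48f2a8b97 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 12:40:08 +0800 Subject: [PATCH 046/551] Match-id-32e53b2a61b146502cb61082940984179c37aa84 --- mx_rec/core/asc/helper.py | 1 - 1 file changed, 1 deletion(-)
diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 0c0ba694..af61fe01 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py
@@ -77,7 +77,6 @@ def find_dangling_table(table_names): find_table_op(table_name, the_op,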
table_lookup_op, table_reachable_tensor) logging.info(f"*********** find tables: {table_lookup_op}***********") - logging.info(f"opTypes{[x for x in opTypes]}") dangling_table = [] def extend(op_list, tensor, spread_tensors): -- Gitee From b29521be879f2899cecbf76813611b5cf0caa2a9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 14:14:05 +0800 Subject: [PATCH 047/551] Match-id-1d11b9e4a69813926463e66c93387a064e672976 --- src/core/utils/common.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 7ebc7057..b9db5c8c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -401,18 +401,12 @@ struct BatchTask { }; struct UniqueInfo { - UniqueInfo() = default; - UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo); - vector restore; vector hotPos; All2AllInfo all2AllInfo; }; struct KeySendInfo { - KeySendInfo() = default; - KeySendInfo(keys_t keySend, vector keyCount); - keys_t keySend; vector keyCount; }; -- Gitee From 2db4f7d7698fb6491c02a768de27625aa0996151 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 14:29:09 +0800 Subject: [PATCH 048/551] Match-id-265c19592091ecfdf91d90eb327f81beb1e7a127 --- src/core/emb_mgmt/emb_mgmt.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 645e0cc8..d1fbdbf6 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -452,7 +452,6 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embInfo.name); tmpData.erase(tmpData.begin()); hdTransfer->Send(SWAP, tmpData, channelId, embInfo.name); - if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(ALL2ALL, *all2all, channelId, embInfo.name); -- Gitee From 3eaf819e1babb47b55be409efbcb70684a96d9ee Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 14:46:38 +0800 Subject: [PATCH 049/551] Match-id-28c0424710fc0ce70a38bc7408c5b756ba70c4b9 --- src/core/host_emb/host_emb.cpp | 2 +- src/core/key_process/key_process.cpp | 7 ++++--- src/core/utils/common.h | 1 - src/ops_tf/hybrid_dataset_ops.cpp | 21 +++++++++++++-------- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 0c316b1f..52252b75 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -89,7 +89,7 @@ void HostEmb::Join() t->join(); } procThread.clear(); - spdlog::info(HOSTEMB + "hostemb end join, cost:{}", TO_MS(sw)); + spdlog::info(HOSTEMB + "hostemb end join, cost:{}", duration_cast((sw).elapsed())); } /* diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 57dcdc7b..f9cab9fd 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -206,7 +206,7 @@ void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0 spdlog::info(KEY_PROCESS "batch is nullptr"); break; } - auto getBatchTime = TO_MS(sw); + auto getBatchTime = duration_cast((sw).elapsed()); sw.reset(); if (unique == nullptr || preSendCount != embInfos[batch->name].sendCount) { @@ -226,7 +226,8 @@ void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0 } TIME_PRINT("getAndProcesTC TimeCost(ms):{}", getAndProcesTC.ElapsedMS()); spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch 
{}[{}]:{} ", - TO_MS(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + duration_cast( + (sw).elapsed()), getBatchTime, batch->name, batch->channel, batch->batchId); free(batch->tensorAddr); batchQueue->PutDirty(move(batch)); } @@ -750,7 +751,7 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - spdlog::debug(KEY_PROCESS "barrier time:{}", TO_MS(sw)); + spdlog::debug(KEY_PROCESS "barrier time:{}", duration_cast((sw).elapsed())); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index af8b15dc..7b2fbe68 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -43,7 +43,6 @@ #endif namespace MxRec { -#define TO_MS(arg) duration_cast((arg).elapsed()) #define INFO_PTR shared_ptr #define TIME_PRINT spdlog::info #define MGMT_CPY_THREADS 4 diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 57c7a3fd..17342e4d 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -199,7 +199,7 @@ public: EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}, " "splits: {}, dataSize: {}, filedNum: {}, channelNames: {}, modifyGraph: {}", - TO_MS(sw), TO_MS(staticSw), + duration_cast((sw).elapsed()), duration_cast((staticSw).elapsed()), channelId, batchId, splits.size(), dataSize, fieldNum, channelNames, modifyGraph); staticSw.reset(); } @@ -427,8 +427,9 @@ public: EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); TIME_PRINT("EnqueueBatchData TimeCost(ms):{}", tc.ElapsedMS()); - TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", TO_MS(sw), - TO_MS(staticSw), channelId, batchId); + TIME_PRINT(KEY_PROCESS + "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", duration_cast((sw).elapsed()), + duration_cast((staticSw).elapsed()), channelId, batchId); staticSw.reset(); } @@ -603,7 +604,8 @@ public: for (int i { 0 }; i < restoreLen; ++i) { r(i) = i % lookupLen; } - spdlog::warn("dummy read batch cost: {},elapsed from last {}", TO_MS(sw), TO_MS(staticSw)); + spdlog::warn("dummy read batch cost: {},elapsed from last {}", duration_cast((sw).elapsed()), + duration_cast((staticSw).elapsed())); staticSw.reset(); } @@ -674,8 +676,10 @@ public: floatDataIndex += floatList.value_size(); } } - spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, TO_MS(sw), - TO_MS(sw_copy), TO_MS(staticReadRaw)); + spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, + duration_cast((sw).elapsed()), + duration_cast((sw_copy).elapsed()), + duration_cast((staticReadRaw).elapsed())); staticReadRaw.reset(); } @@ -727,8 +731,9 @@ public: auto input = inputTensor.flat(); int32_t batchId = input(0); - spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", TO_MS(sw), TO_MS(staticReadRaw), - batchId); + spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", + duration_cast((sw).elapsed()), + duration_cast((staticReadRaw).elapsed()), batchId); staticReadRaw.reset(); } -- Gitee From bcaf4e2327923dcdabd7829e9e39990df433392d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 15:22:58 +0800 Subject: 
[PATCH 050/551] Match-id-f94f7a40e4e8c6e29ab9ffe9ff6c06f24fd5f20e --- src/core/key_process/key_process.h | 3 ++- src/ops_tf/hybrid_dataset_ops.cpp | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 73bcb4e9..f0d50955 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -88,7 +88,8 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); bool isRunning { false }; - inline bool hasEmbName(const string& emb_name ){ + inline bool hasEmbName(const string& emb_name ) + { return embInfos.find(emb_name) != embInfos.end(); }; GTEST_PRIVATE: diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 5bd6ecc1..c1dd23e5 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -204,13 +204,14 @@ public: staticSw.reset(); } - void CheckEmbTables(){ + void CheckEmbTables() + { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < embNames.size(); ++i) { if (!keyProcess->hasEmbName(embNames.at(i))) { spdlog::info("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); - }else{ + } else { tableUsed.push_back(true); } } @@ -219,7 +220,7 @@ public: void EnqueueBatchData(std::vector ids, time_t timestamp, const Tensor& inputTensor, const TTypes::ConstFlat& splits) { - if(tableUsed.empty()){ + if (tableUsed.empty()) { CheckEmbTables(); } auto queue = SingletonQueue::getInstances(ids[1]); @@ -228,7 +229,7 @@ public: offset += 1; // 前面8个字节是unix时间戳 } for (int i = 0; i < splits.size(); ++i) { - if(!tableUsed.at(i)){ + if (!tableUsed.at(i)) { offset += splits(i); continue; } @@ -453,13 +454,14 @@ public: staticSw.reset(); } - void CheckEmbTables(){ + void CheckEmbTables() + { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < splits.size(); ++i) { if (!keyProcess->hasEmbName(embNames.at(i))) { spdlog::info("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); - }else{ + } else { tableUsed.push_back(true); } } @@ -467,7 +469,7 @@ public: int EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) { - if(tableUsed.empty()){ + if (tableUsed.empty()) { CheckEmbTables(); } auto queue = SingletonQueue::getInstances(batchQueueId); @@ -478,7 +480,7 @@ public: } TimeCost ctAll; for (size_t i = 0; i < splits.size(); ++i) { - if(!tableUsed.at(i)){ + if (!tableUsed.at(i)) { offset += splits.at(i); continue; } -- Gitee From 825a6d1490ba89ce2d2385a52ca1d035fb4997b0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 15:29:43 +0800 Subject: [PATCH 051/551] Match-id-f497a8eb23cd167243491055957152b12bf143ff --- mx_rec/core/asc/helper.py | 2 +- mx_rec/core/asc/manager.py | 2 +- src/core/key_process/key_process.h | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index fc6fc66d..19b2fd4d 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -8,7 +8,7 @@ from functools import reduce import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \ - export_table_instances,insert_dangling_table + export_table_instances, insert_dangling_table from mx_rec.core.asc.feature_spec import FeatureSpec diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index e32daf3d..fa8d6234 100644 --- 
a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -10,7 +10,7 @@ from mx_rec.util.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ - get_use_hot, get_use_dynamic_expansion, export_optimizer,export_dangling_table + get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table from mx_rec.core.asc.helper import find_dangling_table diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index f0d50955..cf078588 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -88,7 +88,8 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); bool isRunning { false }; - inline bool hasEmbName(const string& emb_name ) + + inline bool hasEmbName(const string &emb_name) { return embInfos.find(emb_name) != embInfos.end(); }; -- Gitee From 1b19593011206ef82ba320c3bb9fa35973f774f7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 15:36:44 +0800 Subject: [PATCH 052/551] Match-id-6d798a0ffcf4f4f0416a6d38fde22d01dbcd4fc6 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 5 +++++ src/core/hybrid_mgmt/hybrid_mgmt.h | 2 ++ src/core/utils/unique.h | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index a53ff9d5..6aba11b8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -475,6 +475,11 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) break; } } + return getResult(channelId, batchId, start, iBatch, parseKeyTC); +} + +bool HybridMgmt::getResult(int channelId, int& batchId, int start, int iBatch, TimeCost parseKeyTC) +{ if (!isRunning) { return false; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 3238edc3..56a4cba8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -83,6 +83,8 @@ namespace MxRec { bool ParseKeys(int channelId, int& batchId); + bool getResult(int channelId, int& batchId, int start, int iBatch, TimeCost parseKeyTC); + void EmbHDTrans(const int channelId, const int batchId); void Evict(); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 240c1cef..ff1e116d 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -264,7 +264,8 @@ public: free(table_); bucketCount_ = newBucketCountPowerOf2; bucketCountMask_ = bucketCount_ - 1; - table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_)); + table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, + sizeof(Meta) * bucketCount_)); } bzero(table_, sizeof(Meta) * bucketCount_); overflow_.clear(); -- Gitee From db4af66cde781735f806e9c2e055f35faa673fe9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 15:44:16 +0800 Subject: [PATCH 053/551] Match-id-e490c8d58c68f896cd65aff5fe3fcc408ad0bcd0 --- mx_rec/core/asc/manager.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index fa8d6234..d0da8e03 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -13,7 +13,12 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, 
get_rank_size, se get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table from mx_rec.core.asc.helper import find_dangling_table - +def check_dangling_table(): + dangling_table = export_dangling_table() + if not dangling_table: + dangling_table = find_dangling_table([table_instance.table_name + for _, table_instance in export_table_instances().items()]) + return dangling_table def generate_table_info_list(): from mxrec_pybind import EmbInfo from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN @@ -29,10 +34,8 @@ def generate_table_info_list(): optimizer = export_optimizer() # generate table info - dangling_table = export_dangling_table() - if not dangling_table: - dangling_table = find_dangling_table([table_instance.table_name - for _, table_instance in export_table_instances().items()]) + dangling_table = check_dangling_table() + for _, table_instance in export_table_instances().items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: -- Gitee From 2232e017c292b4ca0f052228ed79a798fe000f5e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 15:50:09 +0800 Subject: [PATCH 054/551] Match-id-fa69dd693a7058e8f448b24a3cb51f97efa2431a --- mx_rec/core/asc/manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index d0da8e03..eb38a9ca 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -13,12 +13,15 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table from mx_rec.core.asc.helper import find_dangling_table + def check_dangling_table(): dangling_table = export_dangling_table() if not dangling_table: dangling_table = find_dangling_table([table_instance.table_name for _, table_instance in export_table_instances().items()]) return dangling_table + + def generate_table_info_list(): from mxrec_pybind import EmbInfo from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN -- Gitee From 3653ab07e52b4ff14cdc9e8afe963f50d9a8e8e3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 16:01:17 +0800 Subject: [PATCH 055/551] Match-id-ea67aa2e211c7bac8e765d70c09f7f59e69a5355 --- src/core/utils/spinlock.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h index abbb748c..e3b9e07d 100644 --- a/src/core/utils/spinlock.h +++ b/src/core/utils/spinlock.h @@ -34,8 +34,8 @@ static constexpr uint16_t g_kMaxSpinCountBeforeThreadYield = 64; class SpinLock final { public: void lock() noexcept {} - bool try_lock() noexcept { return true; } void unlock() noexcept {} + bool try_lock() noexcept { return true; } }; #elif defined(USE_MUTEX) -- Gitee From 2a2279ae3bf631d6b0a6a92cc868b529f584ec68 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 16:30:55 +0800 Subject: [PATCH 056/551] Match-id-cce01cb1bd80ba52b6457fa0a58ddc29dd5c822c --- src/core/key_process/key_process.cpp | 30 +++++++++++++++++----------- src/core/key_process/key_process.h | 7 ++++--- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b6f6b90d..a5f644d3 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -192,8 +192,11 @@ void KeyProcess::LoadSaveUnlock() void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] 
{ unique_ptr batch; - ShardedDedup *unique = nullptr; - int preSendCount = 0; + GroupMethod groupMethod; + groupMethod.SetGroupCount(rankInfo.rankSize); + shared_ptr unique; + map> uniquePtrMap; + spdlog::stopwatch sw; try { while (true) { @@ -209,15 +212,18 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES auto getBatchTime = TO_MS(sw); sw.reset(); - if (unique == nullptr || preSendCount != embInfos[batch->name].sendCount) { - GroupMethod groupMethod; - groupMethod.SetGroupCount(rankInfo.rankSize); - auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); - unique = new ShardedDedup(groupMethod, batch->batchSize, sendCountSize); - } else { + auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); + shared_ptr uniquePtr; + if (uniquePtrMap.find(sendCountSize) == uniquePtrMap.end()) { + uniquePtr.reset(new sharded_dedup(groupMethod, batch->batchSize, sendCountSize)); + uniquePtrMap.insert(std::make_pair(sendCountSize, uniquePtr)); + } + unique = uniquePtrMap[sendCountSize]; + + if (unique != nullptr) { unique->StartNewRound(); } - preSendCount = embInfos[batch->name].sendCount; + auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { free(batch->tensorAddr); @@ -230,14 +236,13 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES free(batch->tensorAddr); batchQueue->PutDirty(move(batch)); } - delete unique; } catch (const EndRunError &e) { spdlog::debug(KEY_PROCESS "abort run: {}", e.what()); } spdlog::info(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, id, channel); } -bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_dedup unique, +bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, shared_ptr unique, int channel, int id, spdlog::stopwatch &sw) { // tuple for keyRec restore hotPos scAll countRecv @@ -402,7 +407,8 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id) +auto KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, + shared_ptr unique, int id) -> tuple, vector, vector, vector> { EASY_FUNCTION(profiler::colors::Purple) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 3a98aa0a..0456cabc 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -37,7 +37,7 @@ namespace MxRec { constexpr int MAX_UNIQUE_THREAD_NUM = 8; using a2a_info_t = vector; - using sharded_dedup = ShardedDedup*; + using sharded_dedup = ShardedDedup; template struct Cmp { bool operator () (const T &a, const T &b) @@ -121,11 +121,12 @@ namespace MxRec { void KeyProcessTask(int channel, int id); - bool KeyProcessTaskHelper(unique_ptr& batch, sharded_dedup unique_, + bool KeyProcessTaskHelper(unique_ptr& batch, shared_ptr unique, int channel, int id, spdlog::stopwatch& sw); auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector>; - auto ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique_, int id) + auto ProcessBatchWithUniqueCompute(const unique_ptr &batch, + shared_ptr unique, int id) -> tuple, vector, vector, vector>; size_t GetKeySize(const unique_ptr &batch); -- Gitee From 07f9b09356bd78aa47b2d6ac4dcf824f1c283c70 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 
2023 16:44:26 +0800 Subject: [PATCH 057/551] Match-id-6371bd584475e2e0b11347f0e21cdd3ff26ebd80 --- src/core/emb_mgmt/emb_mgmt.cpp | 62 ++++++++++++++++++---------------- src/core/emb_mgmt/emb_mgmt.h | 2 ++ 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index d1fbdbf6..e4b34dc6 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -435,37 +435,12 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) while (true) { spdlog::info(MGMT + "parse keys, [{}]:{}", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { - auto& embHashMap = hostHashMaps->embHashMaps.at(embInfo.name); - if (iBatch == 0) { - embHashMap.SetStartCount(); - } - auto lookupKeys = preprocess->GetLookupKeys(batchId, embInfo.name, channelId); - if (lookupKeys.empty()) { - remainBatch = false; - break; - } - auto restore = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); - hdTransfer->Send(RESTORE, *restore, channelId, embInfo.name); - - vector tmpData; - hostHashMaps->Process(embInfo.name, lookupKeys, iBatch, tmpData); - hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embInfo.name); - tmpData.erase(tmpData.begin()); - hdTransfer->Send(SWAP, tmpData, channelId, embInfo.name); - if (!mgmtRankInfo.useStatic) { - auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, *all2all, channelId, embInfo.name); - } - if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embInfo.name, channelId, - batchId, iBatch, lookupKeys.size()); - ifHashmapFree = false; + ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); + if (!remainBatch) { + EmbHDTransWrap(channelId, batchId, start, iBatch); + return false; } } - if (!remainBatch) { - EmbHDTransWrap(channelId, batchId, start, iBatch); - return false; - } batchId++; iBatch++; if (EndBatch(batchId, channelId) || iBatch == mgmtRankInfo.nBatch || !ifHashmapFree || !isRunning) { @@ -480,6 +455,35 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) return true; } +bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int& batchId, int channelId, int iBatch, bool& remainBatch) +{ + auto& embHashMap = hostHashMaps->embHashMaps.at(embName); + if (iBatch == 0) { + embHashMap.SetStartCount(); + } + auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); + if (lookupKeys.empty()) { + remainBatch = false; + } + auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); + hdTransfer->Send(RESTORE, *restore, channelId, embName); + vector tmpData; + hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData); + hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embName); + tmpData.erase(tmpData.begin()); + hdTransfer->Send(SWAP, tmpData, channelId, embName); + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(ALL2ALL, *all2all, channelId, embName); + } + if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch + spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, + batchId, iBatch, lookupKeys.size()); + return false; + } + return true; +} + // send h2d & recv d2h emb void HybridMgmt::EmbHDTransWrap(int channelId, const int& 
batchId, int start, int iBatch) { diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h index 31eef2f5..6815bb7d 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/emb_mgmt/emb_mgmt.h @@ -83,6 +83,8 @@ namespace MxRec { bool ParseKeys(int channelId, int& batchId); + bool ProcessEmbInfo(const std::string& embName, int& batchId, int channelId, int iBatch, bool& remainBatch); + void EmbHDTrans(int channelId, int batchId); void Evict(); -- Gitee From 04a45d275bdbb127acbd56cb44d3fb87c9b1e065 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 17:00:38 +0800 Subject: [PATCH 058/551] Match-id-36ccc1454256722a1d32685cb835fef74575e99e --- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/checkpoint/checkpoint.h | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- src/core/utils/common.h | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 675d00c2..a329e427 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -139,7 +139,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) } } -int Checkpoint::GetEmbeddingSize(const string& embName) +int Checkpoint::GetEmbeddingSize(const string& embName) const { for (const auto &embInfo: mgmtEmbInfo) { if (embInfo.name == embName) { diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 7539a9b6..27578053 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -83,7 +83,7 @@ namespace MxRec { void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); void ReadEmbedding(CkptTransData& transData, const string& dataDir); - int GetEmbeddingSize(const string& embName); + int GetEmbeddingSize(const string& embName) const; void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 6aba11b8..f91a1726 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -421,7 +421,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) return true; } -bool HybridMgmt::EndBatch(int batchId, int channelId) +bool HybridMgmt::EndBatch(int batchId, int channelId) const { return (batchId % mgmtRankInfo.maxStep[channelId] == 0 && mgmtRankInfo.maxStep[channelId] != -1); } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 56a4cba8..4d8cb8c1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -121,7 +121,7 @@ namespace MxRec { void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); - bool EndBatch(int batchId, int channelId); + bool EndBatch(int batchId, int channelId) const; void EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 7b2fbe68..638f4cb9 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -115,12 +115,12 @@ namespace MxRec { } template struct Batch { - size_t Size() + size_t Size() const { return sample.size(); } - std::string UnParse() + std::string UnParse() const { std::string s; constexpr size_t MAX_DISP_LEN = 20; -- Gitee From a4ad62f7f6aa08ce91a5b76755406f3c9998d51b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 17:06:41 +0800 Subject: [PATCH 059/551] 
Match-id-66c61eee5873a053302f6840b38dda5d15cd95b1 --- src/core/key_process/key_process.cpp | 28 +++++++++++++++++----------- src/core/key_process/key_process.h | 8 ++++---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index f19f2263..61cf960c 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -192,7 +192,10 @@ void KeyProcess::LoadSaveUnlock() void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] { unique_ptr batch; - ShardedDedup *unique = nullptr; + GroupMethod groupMethod; + groupMethod.SetGroupCount(rankInfo.rankSize); + shared_ptr unique; + map> uniquePtrMap; spdlog::stopwatch sw; try { @@ -209,14 +212,18 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES auto getBatchTime = TO_MS(sw); sw.reset(); - if (unique == nullptr) { - GroupMethod groupMethod; - groupMethod.SetGroupCount(rankInfo.rankSize); - auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); - unique = new ShardedDedup(groupMethod, batch->batchSize, sendCountSize); - } else { + auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); + shared_ptr uniquePtr; + if (uniquePtrMap.find(sendCountSize) == uniquePtrMap.end()) { + uniquePtr.reset(new sharded_dedup(groupMethod, batch->batchSize, sendCountSize)); + uniquePtrMap.insert(std::make_pair(sendCountSize, uniquePtr)); + } + unique = uniquePtrMap[sendCountSize]; + + if (unique != nullptr) { unique->StartNewRound(); } + auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { free(batch->tensorAddr); @@ -229,14 +236,13 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES free(batch->tensorAddr); batchQueue->PutDirty(move(batch)); } - delete unique; } catch (const EndRunError &e) { spdlog::debug(KEY_PROCESS "abort run: {}", e.what()); } spdlog::info(KEY_PROCESS "KeyProcessTask exit. 
rank:{} thread:{}, channel:{}", rankInfo.rankId, id, channel); } -bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, sharded_dedup unique, +bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, shared_ptr unique, int channel, int id, spdlog::stopwatch &sw) { // tuple for keyRec restore hotPos scAll countRecv
@@ -402,8 +408,8 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, - UniqueInfo& uniqueInfoOut) +void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, shared_ptr unique, + int id, UniqueInfo& uniqueInfoOut) { EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId)
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 498268a7..4698c2e5 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h
@@ -37,7 +37,7 @@ namespace MxRec { constexpr int MAX_UNIQUE_THREAD_NUM = 8; using a2a_info_t = vector; - using sharded_dedup = ShardedDedup*; + using sharded_dedup = ShardedDedup; template struct Cmp { bool operator () (const T &a, const T &b)
@@ -121,13 +121,13 @@ namespace MxRec { void KeyProcessTask(int channel, int id); - bool KeyProcessTaskHelper(unique_ptr& batch, sharded_dedup unique_, + bool KeyProcessTaskHelper(unique_ptr& batch, shared_ptr unique, int channel, int id, spdlog::stopwatch& sw); auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector>; - void ProcessBatchWithUniqueCompute(const unique_ptr &batch, sharded_dedup unique, int id, - UniqueInfo& uniqueInfoOut); + void ProcessBatchWithUniqueCompute(const unique_ptr &batch, shared_ptr unique, + int id, UniqueInfo& uniqueInfoOut); size_t GetKeySize(const unique_ptr &batch);
-- Gitee
From 181e109a3167103306d44789999975734bb62b85 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 17:37:57 +0800 Subject: [PATCH 060/551] Match-id-d3e0dba7183eee7b5d1390d9936d39b248f39380 --- .gitignore | 2 + example/little_demo/run.sh | 6 +- src/core/emb_mgmt/emb_mgmt.cpp | 62 ++++++++++++---- src/core/emb_mgmt/emb_mgmt.h | 6 ++ src/core/key_process/key_process.cpp | 68 ++++++++++++++----- src/core/key_process/key_process.h | 43 +++++++----- src/core/utils/common.cpp | 6 +- src/core/utils/common.h | 24 +++++-- src/core/utils/unique.h | 12 ++-- src/ops_tf/hybrid_dataset_ops.cpp | 16 ++--- src/tests/checkpoint/checkpoint_test.cpp | 2 +- .../ckpt_data_handler_test.cpp | 2 +- .../feature_admit_and_evict_test.cpp | 10 +-- src/tests/key_process/key_process_test.cpp | 14 ++-- 14 files changed, 186 insertions(+), 87 deletions(-) create mode 100644 .gitignore
diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..990936c2 --- /dev/null +++ b/.gitignore
@@ -0,0 +1,2 @@ +.idea/ +cmake-build-debug/
diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 76e9d1dd..3ad0a4eb 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh
@@ -18,7 +18,7 @@ export PYTHONPATH=${so_path}:$PYTHONPATH # python install path of the environment export LD_PRELOAD=/usr/lib64/libgomp.so.1 # path to the GNU OpenMP shared library export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH # collective communication file; for its format see the chapter "Preparing the Resource Configuration File" in the CANN documentation on the Ascend site -export RANK_TABLE_FILE="${cur_path}/hccl_json_8p.json" # comment this line out when using the ranktable-free scheme +export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json" # comment this line out when using the ranktable-free scheme export JOB_ID=10086 # total number of NPU cards used by the training job export RANK_SIZE=$num_process # comment this line out when using the ranktable-free scheme
@@ -33,6 +33,7 @@ export USE_DYNAMIC=0 # 0: static shape; 1: dynamic shape export USE_HOT=0 # 0: disable hot emb; 1: enable hot emb export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion +export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 ################# enable when using the ranktable-free scheme ###################### #export CM_CHIEF_IP="192.168.1.1" # chief node ip #export CM_CHIEF_PORT=6000 # chief node listening port
@@ -44,6 +45,7 @@ export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion py=$1 echo "py is $py" echo "use horovod to start tasks" +DATE=$(date +%Y-%m-%d-%H-%M-%S) horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ -python3.7 ${py} 2>&1 | tee temp.log +python3.7 ${py} 2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${DATE}.log"
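
The emb_mgmt.cpp hunk that follows reads KEY_PROCESS_THREAD_NUM from the environment and rejects values above MAX_KEY_PROCESS_THREAD before applying them to PerfConfig. A distilled sketch of that validated env-var pattern; the helper name and the optional return are illustrative, and unlike the patch (which only checks the upper bound and fails Initialize), this sketch also treats non-positive values as invalid:

    #include <cstdlib>
    #include <optional>

    // Read a positive integer from the environment, rejecting values above maxAllowed.
    std::optional<int> ReadBoundedEnvInt(const char* name, int maxAllowed)
    {
        const char* raw = std::getenv(name);
        if (raw == nullptr) {
            return std::nullopt;        // unset: caller keeps its compiled-in default
        }
        int value = std::atoi(raw);     // note: atoi yields 0 on non-numeric input
        if (value <= 0 || value > maxAllowed) {
            return std::nullopt;        // out of range: treat as invalid
        }
        return value;
    }

diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 8cee8a08..52747832 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp
@@ -5,29 +5,41 @@ * Date: 2022/11/15 */ #include "emb_mgmt.h" + #include #include + #include "checkpoint/checkpoint.h" #include "utils/time_cost.h" using namespace MxRec; using namespace std; -bool HybridMgmt::Initialize(RankInfo rankInfo, - const vector& embInfos, - int seed, - const vector& thresholdValues, - bool ifLoad) +bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, + const vector& thresholdValues, bool ifLoad, int seed) { - SetLog(rankInfo.rankId); - if (isRunning) { - return true; + if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { + int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); + if (num > MAX_KEY_PROCESS_THREAD) { + spdlog::error("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{} should be less than {}", + num, MAX_KEY_PROCESS_THREAD); + return false; + } + PerfConfig::keyProcessThreadNum = num; + spdlog::info("config KEY_PROCESS_THREAD_NUM:{}", num); } + + preprocess = Singleton::GetInstance(); + preprocess->Initialize(rankInfo, embInfos, thresholdValues, ifLoad, seed); + preprocess->Start(); + return true; +} + +void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos) +{ MPI_Comm_size(MPI_COMM_WORLD, &rankInfo.rankSize); - int localRankId = rankInfo.deviceId; - spdlog::info(MGMT + "begin initialize, localRankSize:{}, localRankId {}, rank {}", rankInfo.localRankSize, - localRankId, rankInfo.rankId); - rankInfo.localRankId = localRankId; + rankInfo.localRankId = rankInfo.deviceId; + size_t totHostVocabSize = 0; for (const auto& emb : embInfos) { totHostVocabSize += emb.hostVocabSize;
@@ -36,17 +48,36 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, rankInfo.noDDR = true; } rankInfo.useDataset = getenv("DATASET") != nullptr; +} + +bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, int seed, + const vector& thresholdValues, bool ifLoad) +{ + if (isRunning) { + return true; + } + SetLog(rankInfo.rankId); + InitRankInfo(rankInfo, embInfos); + + spdlog::info(MGMT + "begin initialize, localRankSize:{}, localRankId {}, rank {}", + rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); + + bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, ifLoad, seed); + if (!rc) { + return false; + } + mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; skipUpdate = getenv("SKIP_UPDATE") != nullptr; + hdTransfer = Singleton::GetInstance(); hdTransfer->Init(embInfos, rankInfo.deviceId); - preprocess = Singleton::GetInstance(); - preprocess->Initialize(rankInfo, embInfos, thresholdValues, ifLoad, seed); - preprocess->Start(); + lookUpKeysQueue =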
make_unique>>(); restoreQueue = make_unique>>(); isRunning = true; + if (!rankInfo.noDDR) { hostEmbs = make_unique(); hostHashMaps = make_unique(); @@ -57,6 +88,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, if (!rankInfo.useDataset && !isLoad) { Start(); } + for (const auto& info: embInfos) { spdlog::info(MGMT + "emb[{}] vocab size {}+{} sc:{}", info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h index 31eef2f5..8a843563 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/emb_mgmt/emb_mgmt.h @@ -89,6 +89,12 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); + private: + bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, + const vector& thresholdValues, bool ifLoad, int seed); + + void InitRankInfo(RankInfo& rankInfo, const vector& embInfos); + private: int currentBatchId; int trainBatchId = 0; // 0-199, 200- diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a5f644d3..a689bf7b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -6,13 +6,18 @@ */ #include "key_process.h" + #include #include + #include #include #include + #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" +#include "utils/common.h" +#include "utils/time_cost.h" using namespace std; using namespace chrono; @@ -30,6 +35,29 @@ inline vector Count2Start(const vector& count) return start; } +KeyProcess::KeyProcess() +{ + // init class members with PerfConfig::keyProcessThreadNum + for (size_t i = 0; i < MAX_CHANNEL_NUM; ++i) { + comm[i].resize(PerfConfig::keyProcessThreadNum); + for (int j = 0; j < PerfConfig::keyProcessThreadNum; ++j) { + comm[i][j] = MPI_COMM_WORLD; + } + } + + for (size_t i = 0; i < MAX_CHANNEL_NUM; ++i) { + std::vector tmp(PerfConfig::keyProcessThreadNum); + loadSaveMut[i].swap(tmp); + } + std::vector tmp(PerfConfig::keyProcessThreadNum); + getInfoMut.swap(tmp); + + storage.resize(PerfConfig::keyProcessThreadNum); + lookupKeysList.resize(PerfConfig::keyProcessThreadNum); + infoList.resize(PerfConfig::keyProcessThreadNum); + all2AllList.resize(PerfConfig::keyProcessThreadNum); +} + int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues, bool ifLoad, int seed) @@ -95,7 +123,7 @@ int KeyProcess::Start() // bind like: // 0 1 2 3 4 5 0 1 2 3 4 5 // | rank0 | | rank1 | - // each rank creates KEY_PROCESS_THREAD threads, each thread process one batchdata + // each rank creates PerfConfig::keyProcessThreadNum threads, each thread process one batchdata spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { #ifndef GTEST @@ -108,7 +136,7 @@ int KeyProcess::Start() KeyProcessTask(channel, id); }; // for clean code for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { - for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { + for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread } } @@ -174,7 +202,7 @@ void KeyProcess::Destroy() void KeyProcess::LoadSaveLock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + for (int threadId { 0 }; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { loadSaveMut[channelId][threadId].lock(); } } @@ -183,17 +211,19 @@ void 
KeyProcess::LoadSaveLock() void KeyProcess::LoadSaveUnlock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + for (int threadId { 0 }; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { loadSaveMut[channelId][threadId].unlock(); } } } -void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] +void KeyProcess::KeyProcessTask(int channel, int id) { unique_ptr batch; + GroupMethod groupMethod; groupMethod.SetGroupCount(rankInfo.rankSize); + shared_ptr unique; map> uniquePtrMap; @@ -224,7 +254,9 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES unique->StartNewRound(); } - auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = + SingletonQueue::getInstances(id + PerfConfig::keyProcessThreadNum * batch->channel); + if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { free(batch->tensorAddr); batchQueue->PutDirty(move(batch)); @@ -262,7 +294,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, shared_ptr rankInfo.rankId, id, channel); return false; } - int batchListId = batch->batchId % KEY_PROCESS_THREAD; + int batchListId = batch->batchId % PerfConfig::keyProcessThreadNum; // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) @@ -347,14 +379,14 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 - * commID为线程标识[0, KEY_PROCESS_THREAD-1],不同线程、训练或推理数据用不同的共享队列通信 + * commID为线程标识[0, PerfConfig::keyProcessThreadNum-1],不同线程、训练或推理数据用不同的共享队列通信 */ unique_ptr KeyProcess::GetBatchData(int channel, int commId) { EASY_FUNCTION() unique_ptr batch = nullptr; - // train data, queue id = thread id [0, KEY_PROCESS_THREAD-1] - auto batchQueue = SingletonQueue::getInstances(commId + KEY_PROCESS_THREAD * channel); + // train data, queue id = thread id [0, PerfConfig::keyProcessThreadNum-1] + auto batchQueue = SingletonQueue::getInstances(commId + PerfConfig::keyProcessThreadNum * channel); EASY_BLOCK("get samples") EASY_VALUE("run on CPU", sched_getcpu()) spdlog::stopwatch sw; @@ -535,8 +567,8 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, } auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 keyRecv.resize(rs.back() + rc.back()); - spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, - batch->name); + spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. 
rank {} thread {} batch {} {}", + rankInfo.rankId, id, batch->batchId, batch->name); EASY_BLOCK("all2all") MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, @@ -851,9 +883,9 @@ class WrongListTop : public std::exception { }; template -T KeyProcess::GetInfo(array, KEY_PROCESS_THREAD>& list, int batch, const string& embName, int channel) +T KeyProcess::GetInfo(std::vector>& list, int batch, const string& embName, int channel) { - int batchListId = batch % KEY_PROCESS_THREAD; + int batchListId = batch % PerfConfig::keyProcessThreadNum; std::lock_guard lockGuard(getInfoMut[batchListId]); if (list[batchListId][embName][channel].empty()) { spdlog::trace("get info list is empty."); @@ -904,7 +936,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { spdlog::stopwatch sw; - array, KEY_PROCESS_THREAD>* list; + std::vector>* list; switch (type) { case ProcessedInfo::ALL2ALL: list = &all2AllList; @@ -927,7 +959,7 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa auto ret = GetInfo(*list, batch, embName, channel); auto it = get>>::iterator>(ret); auto uTensor = move(*it); - int batchListId = batch % KEY_PROCESS_THREAD; + int batchListId = batch % PerfConfig::keyProcessThreadNum; unique_lock lockGuard(getInfoMut[batchListId]); storage[batchListId].erase(it); return uTensor; @@ -956,7 +988,7 @@ void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int } tensors->emplace_back(move(tmpTensor)); - int batchListId = batchId % KEY_PROCESS_THREAD; + int batchListId = batchId % PerfConfig::keyProcessThreadNum; std::unique_lock lockGuard(getInfoMut[batchListId]); storage[batchListId].push_front(move(tensors)); all2AllList[batchListId][embName][channel].push(make_tuple(batchId, embName, storage[batchListId].begin())); @@ -1025,4 +1057,4 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset trans->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! 
send offsetSize:{}", embName, offset.size()); -} \ No newline at end of file +} diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 0456cabc..b85d9a23 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -9,25 +9,28 @@ #define MX_REC_KEY_PROCESS_H #include -#include #include #include #include +#include #include +#include + +#include +#include #include -#include #include #include + #include "utils/common.h" #include "utils/safe_queue.h" #include "utils/unique.h" #include "utils/spinlock.h" #include "utils/task_queue.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" + #include "host_emb/host_emb.h" -#include "feature_admit_and_evict.h" #include "emb_table/emb_table.h" +#include "feature_admit_and_evict.h" namespace MxRec { using namespace std; @@ -48,8 +51,10 @@ namespace MxRec { template using heap_t = priority_queue, Cmp>; + template using info_list_t = map, MAX_QUEUE_NUM>>; + enum class ProcessedInfo { RESTORE, ALL2ALL, @@ -58,6 +63,8 @@ namespace MxRec { class KeyProcess { public: + KeyProcess(); + int Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues = {}, bool ifLoad = false, int seed = 0); @@ -90,21 +97,25 @@ namespace MxRec { bool isRunning { false }; GTEST_PRIVATE: - template - T GetInfo(array, KEY_PROCESS_THREAD>& list, int batch, const string& embName, int channel); + T GetInfo(std::vector>& list, int batch, const string& embName, int channel); RankInfo rankInfo; map embInfos; - MPI_Comm comm[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD]; + + std::vector comm[MAX_CHANNEL_NUM] {}; + vector procThread {}; std::mutex key2OffsetMut {}; - std::mutex loadSaveMut[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD] {}; - std::mutex getInfoMut[KEY_PROCESS_THREAD] {}; - array, KEY_PROCESS_THREAD> lookupKeysList; - list>> storage[KEY_PROCESS_THREAD]; - array, KEY_PROCESS_THREAD> infoList; - array, KEY_PROCESS_THREAD> all2AllList; + + std::vector loadSaveMut[MAX_CHANNEL_NUM]; + std::vector getInfoMut; + + std::vector>>> storage; + std::vector> lookupKeysList; + std::vector> infoList; + std::vector> all2AllList; + map maxOffset {}; map> keyOffsetMap {}; FeatureAdmitAndEvict m_featureAdmitAndEvict {}; @@ -168,7 +179,5 @@ namespace MxRec { vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); }; -} - - +} // end namespace MxRec #endif // MX_REC_KEY_PROCESS_H diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 9da6952c..ba77c2f3 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -7,15 +7,19 @@ */ #include "common.h" + #include #include #include + #include using namespace std; using std::chrono::system_clock; namespace MxRec { + int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; + RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), nBatch(nBatch), maxStep(maxStep) @@ -104,4 +108,4 @@ namespace MxRec { throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret)); } -} +} // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index a3555d0f..7e75a426 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -9,6 +9,10 @@ #ifndef COMMON_H #define COMMON_H +#include + +#include +#include #include #include @@ -17,11 +21,9 @@ #include #include #include -#include #include 
#include -#include -#include + #include "tensorflow/core/framework/tensor.h" #include "absl/container/flat_hash_map.h" @@ -52,13 +54,22 @@ namespace MxRec { // read batch cost // key process cost using namespace tensorflow; + constexpr int TRAIN_CHANNEL_ID = 0; constexpr int EVAL_CHANNEL_ID = 1; + constexpr int MAX_CHANNEL_NUM = 2; - constexpr int KEY_PROCESS_THREAD = 6; - constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * KEY_PROCESS_THREAD; + constexpr int MAX_KEY_PROCESS_THREAD = 10; + constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; + + constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; + struct PerfConfig { + static int keyProcessThreadNum; + }; + constexpr int KEY_PROCESS_TIMEOUT = 120; constexpr int GET_BATCH_TIMEOUT = 300; + constexpr size_t DEFAULT_RANDOM_SEED = 10086; constexpr int INVALID_KEY_VALUE = -1; constexpr int PROFILING_START_BATCH_ID = 100; @@ -440,7 +451,8 @@ struct BatchTask { HIST_REC = 8, ATTRIBUTE = 9 }; -} +} // end namespace MxRec + #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " #ifdef GTEST #define GTEST_PRIVATE public diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 198a8f25..b40115fb 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -121,7 +121,7 @@ template class Dedup { public: Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1) - : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) + : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) { void *area = aligned_alloc(64, sizeof(Meta) * bucketCount_); table_ = reinterpret_cast *>(area); @@ -419,7 +419,7 @@ public: } std::vector Replacement(const std::vector &input, std::vector *unique = nullptr, - int32_t base = 0) + int32_t base = 0) { std::vector output; if (unique) { @@ -502,8 +502,8 @@ public: using DedupT = Dedup; ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, - int estimatedDuplicateRatio = kDefaultDuplicateRatio) - : groupMethod_(groupMethod), bucketCountPower2_(256), send_cnt_(send_cnt) + int estimatedDuplicateRatio = kDefaultDuplicateRatio) + : groupMethod_(groupMethod), bucketCountPower2_(256), send_cnt_(send_cnt) { const int numOfGroupsInShard = groupMethod_.GroupCount(); @@ -636,7 +636,7 @@ public: int32_t *partBeginPtr = beginPtr; int32_t *partEndPtr = - reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); + reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); if(uniqueFlag.useStatic){ for (int i = 0; i < groupMethod_.GroupCount(); i++) { @@ -668,7 +668,7 @@ public: // should be +/-1 off. const int numOfGroupsInShard = groupMethod_.GroupCount(); tasks.push_back([this, input, &baseVector, beginPtr, partBeginPtr, partEndPtr, numOfGroupsInShard, - totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, hotPosMap]() -> TaskReturnType { + totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, hotPosMap]() -> TaskReturnType { for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { auto val = isInt64 ? 
((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; auto group = groupMethod_.GroupId(val); diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 85a7c261..b8f19779 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -6,9 +6,10 @@ * History: NA */ -#include + #include #include +#include #include #include @@ -28,7 +29,6 @@ #include "utils/common.h" #include "utils/safe_queue.h" #include "utils/singleton.h" - #include "utils/time_cost.h" using namespace tensorflow; @@ -189,9 +189,9 @@ public: // 保证所有embNames在m_embStatus中有状态记录 SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - // [batchId % KEY_PROCESS_THREAD] which thread process this batch - // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % KEY_PROCESS_THREAD + KEY_PROCESS_THREAD * channelId; + // [batchId % PerfConfig::keyProcessThreadNum] which thread process this batch + // [PerfConfig::keyProcessThreadNum * 0 or 1] train or inference + int batchQueueId = batchId % PerfConfig::keyProcessThreadNum + PerfConfig::keyProcessThreadNum * channelId; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); @@ -414,9 +414,9 @@ public: // 保证所有embNames在m_embStatus中有状态记录 SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - // [batchId % KEY_PROCESS_THREAD] which thread process this batch - // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % KEY_PROCESS_THREAD + KEY_PROCESS_THREAD * channelId; + // [batchId % PerfConfig::keyProcessThreadNum] which thread process this batch + // [PerfConfig::keyProcessThreadNum * 0 or 1] train or inference + int batchQueueId = batchId % PerfConfig::keyProcessThreadNum + PerfConfig::keyProcessThreadNum * channelId; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); out(0) = batchId; diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 9523debf..cd348bfa 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -35,7 +35,7 @@ protected: int64_t int64Min { static_cast(UINT32_MAX) }; int maxChannelNum = MAX_CHANNEL_NUM; - int keyProcessThread = KEY_PROCESS_THREAD; + int keyProcessThread = PerfConfig::keyProcessThreadNum; int embInfoNum { 10 }; diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 69b45ad3..bda4803a 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -33,7 +33,7 @@ protected: int64_t int64Min { static_cast(UINT32_MAX) }; int maxChannelNum { MAX_CHANNEL_NUM }; - int keyProcessThread { KEY_PROCESS_THREAD }; + int keyProcessThread { PerfConfig::keyProcessThreadNum }; vector testEmbInfos; valid_int_t validEmbInfo; diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index b136d984..1715ba71 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -382,16 +382,16 @@ protected: faae.ParseThresholdCfg(thresholds); StartEvictThread(); - std::thread thrs[KEY_PROCESS_THREAD]; + std::thread thrs[PerfConfig::keyProcessThreadNum]; // 测试多线程的 - for (int i = 0; i < KEY_PROCESS_THREAD; ++i) { + for (int i = 0; i 
< PerfConfig::keyProcessThreadNum; ++i) { std::string name("thread-"); name += std::to_string(i); thrs[i] = std::thread(TestMultiThread, this, std::ref(name)); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); } - for (int i = 0; i < KEY_PROCESS_THREAD; ++i) { + for (int i = 0; i < PerfConfig::keyProcessThreadNum; ++i) { if (thrs[i].joinable()) { thrs[i].join(); } @@ -411,8 +411,8 @@ protected: keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123被淘汰掉了 vector expectCnt2 = {5, 1, 1, 3, 4}; std::lock_guard lock(faae.m_syncMutexs); // 与 evict-thread 竞争资源 - CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, KEY_PROCESS_THREAD); - CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, KEY_PROCESS_THREAD); + CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, PerfConfig::keyProcessThreadNum); } WaitEvictThread(); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 6d1e01b7..ec246e05 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -52,16 +52,16 @@ protected: vector> PrepareBatch() { - vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM); - // 向共享队列中写入本进程所有线程要处理的 KEY_PROCESS_THREAD * BATCH_NUM_EACH_THREAD 个batch数据 - for (size_t threadId = 0; threadId < KEY_PROCESS_THREAD; ++threadId) { - int batchQueueId = threadId + KEY_PROCESS_THREAD * channel; + vector> result(PerfConfig::keyProcessThreadNum * MAX_CHANNEL_NUM); + // 向共享队列中写入本进程所有线程要处理的 PerfConfig::keyProcessThreadNum * BATCH_NUM_EACH_THREAD 个batch数据 + for (size_t threadId = 0; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { + int batchQueueId = threadId + PerfConfig::keyProcessThreadNum * channel; unsigned int seed = batchQueueId * 10; auto queue = SingletonQueue::getInstances(batchQueueId); for (size_t batchNum = 0; batchNum < BATCH_NUM_EACH_THREAD; ++batchNum) { size_t batchId = - batchNum * KEY_PROCESS_THREAD + threadId; + batchNum * PerfConfig::keyProcessThreadNum + threadId; for (size_t i = 0; i < embInfos.size(); i++) { // key按照不同emb表的存储切分开 auto batch = queue->GetOne(); @@ -308,7 +308,7 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) hotPos); }; // for clean code for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { + for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { process.procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread } } @@ -337,7 +337,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) batch->batchId, lookupKeys, scAll, restore); }; // for clean code for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { + for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { process.procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread } } -- Gitee From dd05bca8f282d4c86e29927d23e4efbebc81dfe2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 18:26:57 +0800 Subject: [PATCH 061/551] Match-id-abab645e2541c1577c9e5454ea07d836e53d5a48 --- mx_rec/core/asc/helper.py | 14 +++++++++----- mx_rec/util/initialize.py | 5 ++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 19b2fd4d..42d8a9eb 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -48,7 +48,7 @@ 
def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names def find_dangling_table(table_names): def check_tensor(table_name, table_reachable_tensor): the_op = table_reachable_tensor.op - logging.info(f"** the_op:{the_op.outputs} {the_op.name} {the_op.type}**") + logging.info(f"** table_reachable_op:{the_op.outputs} {the_op.name} {the_op.type}**") if table_reachable_tensor.op.type == 'ApplyAdam': return True @@ -147,10 +147,9 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") dangling_tables = find_dangling_table(table_names) - for table_name in dangling_tables: - logging.info(f"In insert found dangling table: {table_name} " + logging.info(f"In insert found dangling table(s): {dangling_tables} " f"which does not need to be provided to the EmbInfo.") - table_names.remove(table_name) + # table_names.remove(table_name) def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) @@ -168,7 +167,8 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ input_dict={"is_training": is_training, "dump_graph": dump_graph, "timestamp": FeatureSpec.use_timestamp(is_training), "feature_spec_names": None, - "auto_change_graph": True}) + "auto_change_graph": True, + "dangling_tables":dangling_tables}) insert_fn = insert_fn_for_arg_indexes @@ -277,9 +277,13 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): timestamp = input_dict["timestamp"] feature_spec_names = input_dict["feature_spec_names"] auto_change_graph = input_dict["auto_change_graph"] + dangling_tables = input_dict["dangling_tables"] new_insert_tensors, new_splits, new_table_names = [], [], [] for idx, table_name in enumerate(table_names): + if table_name in dangling_tables: + logging.info(f"do_insert skip table: {table_name}") + continue new_insert_tensors.append(insert_tensors[idx]) new_splits.append(splits[idx]) new_table_names.append(table_names[idx]) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 08c5d184..e65695a8 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -234,9 +234,8 @@ class ConfigInitializer: return self._training_mode_channel_dict.get(is_training) def insert_dangling_table(self, name): - if name in self._dangling_table: - return - self._dangling_table.append(name) + if name not in self._dangling_table: + self._dangling_table.append(name) @property def dangling_table(self): -- Gitee From da31698efa73af6bc1b37281c30481c959090fc5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 18:31:45 +0800 Subject: [PATCH 062/551] Match-id-ad9caac05852d7757b99aa1cb59066c34e528a05 --- mx_rec/core/asc/helper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 42d8a9eb..bff69fe2 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -149,7 +149,6 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ dangling_tables = find_dangling_table(table_names) logging.info(f"In insert found dangling table(s): {dangling_tables} " f"which does not need to be provided to the EmbInfo.") - # table_names.remove(table_name) def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) -- Gitee From d7a08000cf4d36c51ae141972ed50155a02b7106 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 
27 May 2023 18:39:08 +0800 Subject: [PATCH 063/551] Match-id-c99180c36578fd9636f9038cf1231a81ea8adc70 --- src/core/utils/unique.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 77b622e6..bc8a5995 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -647,7 +647,7 @@ public: int32_t *partBeginPtr = beginPtr; int32_t *partEndPtr = - reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); + reinterpret_cast(((reinterpret_cast(partBeginPtr + partSize)) + 63ul) & ~63ul); if (uniqueFlag.useStatic) { for (int i = 0; i < groupMethod_.GroupCount(); i++) { -- Gitee From fdef3fea0d534520368c6edd668133bd2b4e66ab Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 18:41:01 +0800 Subject: [PATCH 064/551] Match-id-c64292ff9a1fbff06bda27deabe47113712fb7e6 --- mx_rec/core/asc/helper.py | 41 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index bff69fe2..6415e523 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -159,15 +159,27 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ for insert_tensor in insert_tensors: split = reduce(lambda x, y: x * y, insert_tensor.shape.as_list()) splits.append(split if split is not None else tf.math.reduce_prod(tf.shape(insert_tensor))) + + new_insert_tensors, new_splits, new_table_names = [], [], [] + for idx, table_name in enumerate(table_names): + if table_name in dangling_tables: + logging.info(f"do_insert skip table: {table_name}") + continue + new_insert_tensors.append(insert_tensors[idx]) + new_splits.append(splits[idx]) + new_table_names.append(table_names[idx]) + + if FeatureSpec.use_timestamp(is_training): + new_insert_tensors = insert_tensors + return do_insert(args, - insert_tensors=insert_tensors, - splits=splits, - table_names=table_names, + insert_tensors=new_insert_tensors, + splits=new_splits, + table_names=new_table_names, input_dict={"is_training": is_training, "dump_graph": dump_graph, "timestamp": FeatureSpec.use_timestamp(is_training), "feature_spec_names": None, - "auto_change_graph": True, - "dangling_tables":dangling_tables}) + "auto_change_graph": True}) insert_fn = insert_fn_for_arg_indexes @@ -276,24 +288,11 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): timestamp = input_dict["timestamp"] feature_spec_names = input_dict["feature_spec_names"] auto_change_graph = input_dict["auto_change_graph"] - dangling_tables = input_dict["dangling_tables"] - - new_insert_tensors, new_splits, new_table_names = [], [], [] - for idx, table_name in enumerate(table_names): - if table_name in dangling_tables: - logging.info(f"do_insert skip table: {table_name}") - continue - new_insert_tensors.append(insert_tensors[idx]) - new_splits.append(splits[idx]) - new_table_names.append(table_names[idx]) - - if timestamp: - new_insert_tensors = insert_tensors pipeline_op = \ - send_feature_id_request_async(feature_id_list=new_insert_tensors, - split_list=new_splits, - table_name_list=new_table_names, + send_feature_id_request_async(feature_id_list=insert_tensors, + split_list=splits, + table_name_list=table_names, input_dict={"is_training": is_training, "timestamp": timestamp, "feature_spec_names": feature_spec_names, -- Gitee From efe62e3aa60cac99aecab160df403bab3a2a6aae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 18:54:29 +0800 Subject: 
[PATCH 065/551] Match-id-1ec6a27dacd5cb01efe53dd9b91f3f9d7e1ace50 --- src/tests/emb_mgmt/emb_mgmt_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index bc5d0ba8..0cdcc600 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -127,7 +127,7 @@ TEST_F(EmbMgmtTest, Initialize) allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep); hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false); auto hostEmbs = make_unique(); - hostEmbs->Initialize(embInfos, seed, false); + hostEmbs->Initialize(embInfos, seed); auto hostHashMaps = make_unique(); hostHashMaps->Init(allRank, embInfos, false); -- Gitee From ace174f03cc87932d6573520e51ecd8e1e956320 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 28 May 2023 19:51:48 +0800 Subject: [PATCH 066/551] Match-id-6c9b4d42bc160f31606926046328aa59c2cbff84 --- src/core/emb_mgmt/emb_mgmt.cpp | 6 +++--- src/core/emb_mgmt/emb_mgmt.h | 6 +++--- src/core/host_emb/host_emb.cpp | 8 ++++---- src/core/host_emb/host_emb.h | 2 +- src/core/key_process/key_process.cpp | 9 +++++---- src/core/key_process/key_process.h | 2 +- src/core/utils/common.h | 11 ++++++++++- src/test_ut.sh | 9 ++++++++- src/tests/key_process/key_process_test.cpp | 6 ++++-- 9 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 52747832..9eb4d908 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -239,14 +239,14 @@ void HybridMgmt::Start() spdlog::info("getInfoTask done"); return ret; }; - procThreads.emplace_back(getInfoTask); + procThreads.emplace_back(std::make_unique(getInfoTask)); auto sendInfoTask = [this]() { auto ret = SendTask(); spdlog::info("sendInfoTask done"); return ret; }; - procThreads.emplace_back(sendInfoTask); + procThreads.emplace_back(std::make_unique(sendInfoTask)); } if (!mgmtRankInfo.noDDR) { @@ -255,7 +255,7 @@ void HybridMgmt::Start() spdlog::info("parseKeysTask done"); return ret; }; - procThreads.emplace_back(parseKeysTask); + procThreads.emplace_back(std::make_unique(parseKeysTask)); } } diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h index 8a843563..b621d9e7 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/emb_mgmt/emb_mgmt.h @@ -66,8 +66,8 @@ namespace MxRec { preprocess->isRunning = false; // 停止hdTransfer,用于停止mgmt的recv中卡住状态 hdTransfer->Destroy(); - for (auto &i : procThreads) { - i.join(); + for (auto& t : procThreads) { + t->join(); } if (hostEmbs != nullptr) { hostEmbs->Join(); @@ -104,7 +104,7 @@ namespace MxRec { RankInfo mgmtRankInfo; unique_ptr hostEmbs {}; unique_ptr hostHashMaps {}; - vector procThreads {}; + vector> procThreads {}; unique_ptr>> lookUpKeysQueue; unique_ptr>> restoreQueue; map> evictKeyMap {}; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 6fdb1f54..905d0d5f 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -84,11 +84,11 @@ void HostEmb::LoadEmb(emb_mem_t& loadData) void HostEmb::Join() { spdlog::stopwatch sw; - spdlog::debug(HOSTEMB + "hostemb start join {}", procThread.size()); - for (auto& t: procThread) { + spdlog::debug(HOSTEMB + "hostemb start join {}", procThreads.size()); + for (auto& t: procThreads) { t->join(); } - procThread.clear(); + procThreads.clear(); spdlog::info(HOSTEMB + "hostemb end join, cost:{}", TO_MS(sw)); } @@ -132,7 +132,7 @@ void 
HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI { #ifndef GTEST EASY_FUNCTION(profiler::colors::Purple) - procThread.emplace_back(make_unique( + procThreads.emplace_back(make_unique( [&, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index 31928a27..b4e1b62b 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -52,7 +52,7 @@ namespace MxRec { GTEST_PRIVATE: absl::flat_hash_map hostEmbs; - std::vector> procThread; + std::vector> procThreads; void EmbDataGenerator(const vector& initializeInfos, int seed, int vocabSize, int embeddingSize, vector>& embData); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a689bf7b..673d683b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -137,7 +137,7 @@ int KeyProcess::Start() }; // for clean code for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { - procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread + procThreads.emplace_back(std::make_unique(fn, channel, id)); } } return 0; @@ -192,10 +192,10 @@ void KeyProcess::Destroy() { isRunning = false; spdlog::info(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); - for (auto& i: procThread) { - i.join(); + for (auto& t: procThreads) { + t->join(); } - procThread.clear(); + procThreads.clear(); spdlog::info(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); } @@ -266,6 +266,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", TO_MS(sw), getBatchTime, batch->name, batch->channel, batch->batchId); free(batch->tensorAddr); + batch->tensorAddr = nullptr; batchQueue->PutDirty(move(batch)); } } catch (const EndRunError &e) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index b85d9a23..7c6b5ae6 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -105,7 +105,7 @@ namespace MxRec { std::vector comm[MAX_CHANNEL_NUM] {}; - vector procThread {}; + vector> procThreads {}; std::mutex key2OffsetMut {}; std::vector loadSaveMut[MAX_CHANNEL_NUM]; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 7e75a426..ff3ffc40 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -127,7 +127,8 @@ namespace MxRec { throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); } - template struct Batch { + template + struct Batch { size_t Size() { return sample.size(); @@ -144,6 +145,14 @@ namespace MxRec { return s; } + ~Batch() + { + if (tensorAddr) { + free(tensorAddr); + tensorAddr = nullptr; + } + } + std::vector sample; void *tensorAddr = nullptr; std::string name; diff --git a/src/test_ut.sh b/src/test_ut.sh index f2bbe2ff..26a4423d 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -54,7 +54,14 @@ make -j make install # Run Test -mpirun -np 4 ./tests/test_main +DATE=$(date +%Y-%m-%d-%H-%M-%S) +if [[ "$1" == "--with-memcheck" ]]; then + echo "we are going to run test_main with memcheck via valgrind" + valgrind --tool=memcheck --leak-check=full --show-leak-kinds=all --log-file=../"memcheck_${DATE}.log" \ + ./tests/test_main 2>&1 |tee ../"test_main_${DATE}.log" +else + mpirun -np 4 ./tests/test_main +fi cd "$(dirname 
"${PWD}")" diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index ec246e05..f1b40ec1 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -309,7 +309,8 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { - process.procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread + // use lambda expression initialize thread + process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } this_thread::sleep_for(20s); @@ -338,7 +339,8 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { - process.procThread.emplace_back(fn, channel, id); // use lambda expression initialize thread + // use lambda expression initialize thread + process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } this_thread::sleep_for(20s); -- Gitee From b7e5f4ebf10fa2d353d00676d5fe2e3aa3de24c5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 09:10:59 +0800 Subject: [PATCH 067/551] Match-id-a3147e8914b2380d912d23f36dc5d94c9516b52d --- build/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/build.sh b/build/build.sh index 5fd7b54c..b7fc65f3 100644 --- a/build/build.sh +++ b/build/build.sh @@ -210,5 +210,5 @@ deactivate tf2_env echo "-----Build gen tar -----" gen_tar_file -#clean +clean echo "-----Done-----" -- Gitee From 300e11a1c268dd3113e186a3ce62d9eef53678e6 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 10:33:01 +0800 Subject: [PATCH 068/551] Match-id-b0fdba71a78c1cd6aff4961c5608ce6c62e1d214 --- src/core/emb_mgmt/emb_mgmt.cpp | 5 +++-- src/core/emb_mgmt/emb_mgmt.h | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index e4b34dc6..46a58f9b 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -455,7 +455,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) return true; } -bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int& batchId, int channelId, int iBatch, bool& remainBatch) +bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, + int channelId, int iBatch, bool& remainBatchOut) { auto& embHashMap = hostHashMaps->embHashMaps.at(embName); if (iBatch == 0) { @@ -463,7 +464,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int& batchId, int ch } auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); if (lookupKeys.empty()) { - remainBatch = false; + remainBatchOut = false; } auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); hdTransfer->Send(RESTORE, *restore, channelId, embName); diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/emb_mgmt/emb_mgmt.h index 6815bb7d..019f5b9b 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/emb_mgmt/emb_mgmt.h @@ -83,7 +83,7 @@ namespace MxRec { bool ParseKeys(int channelId, int& batchId); - bool ProcessEmbInfo(const std::string& embName, int& batchId, int channelId, int iBatch, bool& remainBatch); + bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut); void EmbHDTrans(int channelId, int batchId); -- Gitee From 
7377226cab9842e2018dbc8b022c0f9177c691f2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 27 May 2023 18:26:55 +0800 Subject: [PATCH 069/551] Match-id-df2ccd37c9ae012257967f108b4e3df4fb30f8e2 --- example/little_demo/main.py | 4 +- mx_rec/__init__.py | 3 +- mx_rec/constants/__init__.py | 3 + mx_rec/{util => constants}/constants.py | 9 + mx_rec/core/asc/build_graph.py | 2 +- mx_rec/core/asc/helper.py | 2 +- mx_rec/core/asc/manager.py | 4 +- mx_rec/core/embedding.py | 4 +- mx_rec/graph/modifier.py | 2 +- mx_rec/saver/saver.py | 2 +- mx_rec/util/initialize.py | 76 ++++-- mx_rec/util/ops.py | 22 +- mx_rec/validator/validator.py | 300 ++++++++++++++++++++++++ 13 files changed, 389 insertions(+), 44 deletions(-) create mode 100644 mx_rec/constants/__init__.py rename mx_rec/{util => constants}/constants.py (92%) create mode 100644 mx_rec/validator/validator.py diff --git a/example/little_demo/main.py b/example/little_demo/main.py index f2d0175f..116ef432 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -15,7 +15,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache -from mx_rec.util.constants import MxRecMode, ASCEND_TIMESTAMP +from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_id, get_rank_size, init, clear_channel, terminate_config_initializer, \ set_if_load, get_initializer from mx_rec.util.variable import get_dense_and_sparse_variable @@ -183,7 +183,7 @@ if __name__ == "__main__": train_ops.append(dense_optimizer.apply_gradients(avg_grads)) if use_dynamic_expansion: - from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET + from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) # do sparse optimization by addr diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index bbee9e8f..9578f5a2 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -2,8 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION -from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py new file mode 100644 index 00000000..6924f767 --- /dev/null +++ b/mx_rec/constants/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
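Patch 069 replaces set_device_dict with set_hccl_info_without_json, which derives the rank-to-device mapping from environment variables when no RANK_TABLE_FILE is configured: ASCEND_VISIBLE_DEVICES is accepted as a range ("0-7"), a comma list ("0,2,4"), or a single id, it is validated together with CM_CHIEF_DEVICE and CM_WORKER_SIZE through the new mx_rec/validator module, and the range parse is fixed so the upper bound is inclusive. As a rough, self-contained sketch of that parsing behaviour (the helper name parse_visible_devices is invented for this illustration and is not part of mx_rec):

    import os

    VALID_DEVICE_ID_LIST = [str(i) for i in range(16)]  # mirrors the new RANK INFO constant

    def parse_visible_devices(value: str) -> list:
        """Parse an ASCEND_VISIBLE_DEVICES string into a list of device ids."""
        value = value.strip()
        if "-" in value:
            # range form "0-7"; the patch adds "+ 1" so the end id is now included
            parts = value.split("-")
            return list(range(int(parts[0]), int(parts[-1]) + 1))
        if "," in value:
            # comma-list form "0,2,4,6"
            return [int(v) for v in value.split(",")]
        if value in VALID_DEVICE_ID_LIST:
            # single-device form "5"
            return [int(value)]
        raise ValueError("invalid env variable ascend_visible_devices.")

    if __name__ == "__main__":
        assert parse_visible_devices("0-3") == [0, 1, 2, 3]
        assert parse_visible_devices("0,2,4") == [0, 2, 4]
        assert parse_visible_devices("5") == [5]
        print(parse_visible_devices(os.getenv("ASCEND_VISIBLE_DEVICES", "0-7")))

In the initialize.py hunk below, rank 0 is then pinned to CM_CHIEF_DEVICE, the chief's entry is popped from the parsed device list, and each remaining device is translated to a logic id via mxrec_pybind.get_logic_id for ranks 1..N, so CM_WORKER_SIZE must match the devices visible in the container.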
diff --git a/mx_rec/util/constants.py b/mx_rec/constants/constants.py similarity index 92% rename from mx_rec/util/constants.py rename to mx_rec/constants/constants.py index 3a9be2ec..b3816d43 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/constants/constants.py @@ -33,6 +33,15 @@ DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 +# RANK INFO +VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] +MIN_SIZE = 1 +MAX_SIZE = 1024 * 1024 * 1024 * 1024 +MAX_DEVICE_NUM = 16 +MAX_RANK_SIZE = 4095 +MIN_DEVICE_NUM = 1 +MIN_RANK_SIZE = 1 + class BaseEnum(Enum): @classmethod diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 31cf30a6..4b449fc5 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -7,7 +7,7 @@ import logging import tensorflow as tf import mxrec_pybind -from mx_rec.util.constants import AVOID_TENSOR_POS +from mx_rec.constants.constants import AVOID_TENSOR_POS from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 21f497f3..dd0abce0 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -210,7 +210,7 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): # Only the tables that need to be used after table combination are retained in meituan situation. # Current solution has error in same situations. For example, a sparse table has not been auto-merged. - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN new_insert_tensors, new_splits, new_table_names = [], [], [] logging.debug(f"In do_insert function, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") for idx, table_name in enumerate(table_names): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index a8dcf6cb..944a0f9a 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -6,7 +6,7 @@ import logging import tensorflow as tf -from mx_rec.util.constants import MxRecMode +from mx_rec.constants.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ @@ -15,7 +15,7 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se def generate_table_info_list(): from mxrec_pybind import EmbInfo - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN # table_name is corresponding to channel_name which is in used in operator gen_npu_ops.get_next table_info_list = [] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 4becc98a..e42b92dc 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -15,7 +15,7 @@ from tensorflow.python.ops import array_ops from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ +from mx_rec.constants.constants import 
ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID @@ -429,7 +429,7 @@ class SparseEmbedding: @tf.custom_gradient def sparse_forward(table, feat_ids): logging.debug(f"fp rank size: {rank_size}") - if use_static: + if feat_ids.shape.as_list()[0] is not None: restore_vector = tf.ones(shape=[np.prod(feat_ids.shape.as_list()), ], dtype=tf.int32, name="restore_vector") else: diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 7ceb42dd..87eebfd7 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -13,7 +13,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding -from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ +from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 93a00810..02b70bd1 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.util.constants import DataName, DataAttr +from mx_rec.constants.constants import DataName, DataAttr from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, \ get_ascend_global_hashtable_collection diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index e9ae8c5b..1aed3a39 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -7,13 +7,14 @@ import logging import os from collections import defaultdict +import mxrec_pybind import psutil -import mxrec_pybind -import mx_rec.util.constants -from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ - ASCEND_GLOBAL_HASHTABLE_COLLECTION +import mx_rec.constants.constants +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST +from mx_rec.constants.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE from mx_rec.util.ops import import_host_pipeline_ops +from mx_rec.validator.validator import RankInfoValidator class ConfigInitializer: @@ -56,7 +57,7 @@ class ConfigInitializer: if os.getenv("RANK_TABLE_FILE"): self.parse_hccl_json() else: - self.set_device_dict() + self.set_hccl_info_without_json() self.check_parameters() self.train_interval = kwargs.get("train_interval", -1) self.eval_steps = kwargs.get("eval_steps", -1) @@ -198,31 +199,56 @@ class ConfigInitializer: raise ValueError(f"get logic id from physic id fail.") self._rank_to_device_dict[rank_id] = device_id - def set_device_dict(self): + def set_hccl_info_without_json(self): + """ + Used for no rank table file configured training situation. + Now, only less than or equal 8p training job is supported. 
+ :return: None + """ + RankInfoValidator() ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - if not ascend_visible_devices: - raise ValueError("env variable ascend_visible_devices is null.") - if "-" in ascend_visible_devices: - rank_start = int(ascend_visible_devices.strip().split("-")[0]) - device_list = [i for i in range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]))] - elif "," in ascend_visible_devices: - device_list = list(map(int, ascend_visible_devices.strip().split(","))) - elif ascend_visible_devices in ["0", "1", "2", "3", "4", "5", "6", "7"]: - device_list = [int(ascend_visible_devices.strip())] - else: - raise ValueError("invalid env variable ascend_visible_devices.") - rank_size = int(os.getenv("CM_WORKER_SIZE")) - self._rank_to_device_dict[0] = int(os.getenv("CM_CHIEF_DEVICE")) - device_list.pop(int(os.getenv("CM_CHIEF_DEVICE"))) + device_list = [] + try: + if "-" in ascend_visible_devices: + split_devices = ascend_visible_devices.strip().split("-") + if len(split_devices) >= 1: + rank_start = int(split_devices[0]) + device_list = [i for i in range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)] + elif "," in ascend_visible_devices: + device_list = list(map(int, ascend_visible_devices.strip().split(","))) + elif ascend_visible_devices in VALID_DEVICE_ID_LIST: + device_list = [int(ascend_visible_devices.strip())] + else: + raise ValueError("invalid env variable ascend_visible_devices.") + except ValueError as error: + raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. " + "Please refer to the document https://www.hiascend.com/document/detail/zh/" + "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " + "the correct configuration method.") from error + except IndexError as error: + raise IndexError( + f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ + from error + + chief_device = os.getenv("CM_CHIEF_DEVICE") + rank_size = os.getenv("CM_WORKER_SIZE") + try: + rank_size = int(rank_size) + self._rank_to_device_dict[0] = int(chief_device) + device_list.pop(int(chief_device)) + except IndexError as err: + raise IndexError( + f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {device_list}.") from err + except ValueError as err: + raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err + if rank_size: - local_rank_size = rank_size if rank_size < 8 else 8 - for device_index in range(local_rank_size - 1): + local_rank_size = len(device_list) + for device_index in range(local_rank_size): device_id = mxrec_pybind.get_logic_id(int(device_list[device_index])) if device_id > 16: raise ValueError(f"get logic id from physic id fail.") self._rank_to_device_dict[device_index + 1] = device_id - else: - raise ValueError("get CM_WORKER_SIZE failed.") def insert_training_mode_channel_id(self, is_training): if is_training not in self._training_mode_channel_dict: @@ -557,7 +583,7 @@ def set_initializer(is_training, initializer): def set_ascend_table_name_must_contain(name="merged"): - mx_rec.util.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name + mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name def set_ascend_env(): diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index 40c99f37..edb29781 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -2,20 +2,28 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. 
All rights reserved. -import os import logging +import os + import tensorflow as tf -from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH +from mx_rec.constants.constants import HOST_PIPELINE_OPS_LIB_PATH def import_host_pipeline_ops(): host_pipeline_ops_lib_path = os.getenv(HOST_PIPELINE_OPS_LIB_PATH) - if host_pipeline_ops_lib_path: + if host_pipeline_ops_lib_path and os.path.exists(host_pipeline_ops_lib_path): logging.debug(f"Using the HOST_PIPELINE_OPS_LIB_PATH '{host_pipeline_ops_lib_path}' to get ops lib.") return tf.load_op_library(host_pipeline_ops_lib_path) + elif os.path.exists( + os.path.join(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), + 'mx_rec/libasc/libasc_ops.so')): + default_so_path = os.path.join( + os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), + 'mx_rec/libasc/libasc_ops.so') + logging.debug(f"Using the DEFAULT PATH '{default_so_path}' to get ops lib.") + return tf.load_op_library(default_so_path) else: - mx_rec_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) - so_path = os.path.join(mx_rec_dir, 'mx_rec/libasc/libasc_ops.so') - logging.debug(f"Using the DEFAULT PATH '{so_path}' to get ops lib.") - return tf.load_op_library(so_path) + raise ValueError("Invalid host pipeline ops lib path. Please check if libasc_ops.so exists or corrected " + "configured") + diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py new file mode 100644 index 00000000..1922a42f --- /dev/null +++ b/mx_rec/validator/validator.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + + +import os +from typing import Callable, Any +from typing import List, Optional, Tuple + +from mx_rec.constants.constants import MIN_SIZE +from mx_rec.constants.constants import MAX_SIZE +from mx_rec.constants.constants import MAX_DEVICE_NUM +from mx_rec.constants.constants import MAX_RANK_SIZE +from mx_rec.constants.constants import MIN_DEVICE_NUM +from mx_rec.constants.constants import MIN_RANK_SIZE + + +class Validator: + """ + A validator to check the input parameters + """ + + def __init__(self, value, msg="value is invalid"): + """ + :param value: the value for validation + :param msg: default error msg + """ + self.value = value + self.msg = msg + self.checkers = [] + self.is_valid_state = None + + def register_checker(self, checker: Callable[[Any], bool], msg: str = None): + self.checkers.append((checker, msg if msg else self.msg)) + + def check(self): + if self.is_valid_state is None: + self.is_valid_state = True + for checker, msg in self.checkers: + self.is_valid_state &= checker(self.value) + if not self.is_valid_state: + self.msg = msg + break + if self.is_valid_state: + self.msg = None + return self + + def is_valid(self): + if self.is_valid_state is None: + self.check() + return self.is_valid_state + + def get_value(self, default=None): + return self.value if self.is_valid() else default + + +class ClassValidator(Validator): + """ + Check class validator. + """ + + def __init__(self, value, classes): + super().__init__(value) + self.classes = classes + + def check_isinstance(self): + """Check arg isinstance of classes""" + self.register_checker(lambda path: isinstance(self.value, self.classes), "Invalid parameter type") + return self + + +class StringValidator(Validator): + """ + String type validator. 
+ """ + + def __init__(self, value, max_len=None, min_len=0): + super().__init__(value) + self.max_len = max_len + self.min_len = min_len + self.register_checker(lambda x: isinstance(x, str), "type is not str") + + def check_string_length(self): + if self.min_len is not None: + self.register_checker(lambda x: len(x) >= self.min_len, f"length is less than {self.min_len}") + if self.max_len is not None: + self.register_checker(lambda x: len(x) <= self.max_len, f"length is bigger than {self.max_len}") + return self + + def check_not_contain_black_element(self, element): + self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1) + return self + + def can_be_transformed2int(self, min_value: int = None, max_value: int = None): + if min_value is None: + min_value = MIN_RANK_SIZE + if max_value is None: + max_value = MAX_RANK_SIZE + + can_transformed = self.value.isdigit() + try: + if can_transformed and (min_value > int(self.value) or max_value < int(self.value)): + can_transformed = False + except ValueError: + can_transformed = False + finally: + if self.is_valid_state is not None: + self.is_valid_state &= can_transformed + else: + self.is_valid_state = can_transformed + return self + + +class IntValidator(Validator): + """ + Int type validator + """ + + def __init__(self, value: int, min_value: int = None, max_value: int = None): + super().__init__(value) + self.min_value = min_value + self.max_value = max_value + self.register_checker(lambda x: isinstance(x, int), "type is not int") + + def check_value(self): + if self.min_value is not None: + self.register_checker(lambda x: x >= self.min_value, f"value is less than {self.min_value}") + if self.max_value is not None: + self.register_checker(lambda x: x <= self.max_value, f"value is bigger than {self.max_value}") + return self + + +class RankSizeValidator(IntValidator): + """ + Distributed training job size validator + """ + + def check_rank_size_valid(self): + super().__init__(self.value) + self.register_checker(lambda x: MIN_RANK_SIZE <= self.value <= MAX_RANK_SIZE, + "Invalid rank size") + return self + + def check_device_num_valid(self): + super().__init__(self.value) + self.register_checker(lambda x: MIN_DEVICE_NUM <= self.value <= MAX_DEVICE_NUM, + "Invalid device num") + return self + + +class DirectoryValidator(StringValidator): + def __init__(self, value, max_len=None, min_len=1): + """ + @param value: the path, should not be emtpy string, should not contain double dot(../) + """ + super().__init__(value, max_len, min_len) + self.register_checker(lambda x: isinstance(x, str), "type is not str") + + @staticmethod + def remove_prefix(string: Optional[str], prefix: Optional[str]) -> Tuple[bool, Optional[str]]: + if string is None or prefix is None or len(string) < len(prefix): + return False, string + if string.startswith(prefix): + return True, string[len(prefix):] + else: + return False, string + + @staticmethod + def check_is_children_path(path_: str, target_: str): + if not target_: + return False + + try: + realpath_ = os.path.realpath(path_) + except (TypeError, ValueError, OSError): + return False + + try: + realpath_target = os.path.realpath(target_) + except (TypeError, ValueError, OSError): + return False + + is_prefix, rest_part = DirectoryValidator.remove_prefix(realpath_target, realpath_) + + if rest_part.startswith(os.path.sep): + rest_part = rest_part.lstrip(os.path.sep) + if is_prefix: + joint_path = os.path.join(realpath_, rest_part) + return os.path.realpath(joint_path) == 
+
+
+class DirectoryValidator(StringValidator):
+    def __init__(self, value, max_len=None, min_len=1):
+        """
+        @param value: the path, should not be an empty string and should not contain a double dot (../)
+        """
+        super().__init__(value, max_len, min_len)
+        self.register_checker(lambda x: isinstance(x, str), "type is not str")
+
+    @staticmethod
+    def remove_prefix(string: Optional[str], prefix: Optional[str]) -> Tuple[bool, Optional[str]]:
+        if string is None or prefix is None or len(string) < len(prefix):
+            return False, string
+        if string.startswith(prefix):
+            return True, string[len(prefix):]
+        else:
+            return False, string
+
+    @staticmethod
+    def check_is_children_path(path_: str, target_: str):
+        if not target_:
+            return False
+
+        try:
+            realpath_ = os.path.realpath(path_)
+        except (TypeError, ValueError, OSError):
+            return False
+
+        try:
+            realpath_target = os.path.realpath(target_)
+        except (TypeError, ValueError, OSError):
+            return False
+
+        is_prefix, rest_part = DirectoryValidator.remove_prefix(realpath_target, realpath_)
+
+        if rest_part.startswith(os.path.sep):
+            rest_part = rest_part.lstrip(os.path.sep)
+        if is_prefix:
+            joint_path = os.path.join(realpath_, rest_part)
+            return os.path.realpath(joint_path) == realpath_target
+        else:
+            return False
+
+    @staticmethod
+    def __check_with_sensitive_words(path: str, words: List):
+        _, name = os.path.split(path)
+        if name:
+            return not any(map(lambda x: x in path, words))
+        else:
+            return True
+
+    def check_is_not_none(self):
+        self.register_checker(lambda path: self.value is not None and len(self.value) > 0,
+                              "Invalid directory parameter")
+        return self
+
+    def check_not_soft_link(self):
+        self.register_checker(lambda path: os.path.realpath(self.value) == os.path.normpath(self.value),
+                              "soft link or relative path should not be in the path parameter")
+        return self
+
+    def path_should_exist(self, is_file=True, msg=None):
+        self.register_checker(lambda path: os.path.exists(self.value),
+                              msg if msg else "path parameter does not exist")
+        if is_file:
+            self.register_checker(lambda path: os.path.isfile(self.value),
+                                  msg if msg else "path parameter is not a file")
+        return self
+
+    def path_should_not_exist(self):
+        self.register_checker(lambda path: not os.path.exists(self.value), "path parameter already exists")
+        return self
+
+    def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None):
+        if lst is None:
+            lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64"]
+        if len(lst) == 0:
+            return self
+        if msg is None:
+            msg = "path is in the blacklist"
+        if exact_compare:
+            self.register_checker(lambda path: path not in [os.path.realpath(each) for each in lst], msg)
+        else:
+            self.register_checker(
+                lambda path: not any([DirectoryValidator.check_is_children_path(each, path) for each in lst]), msg
+            )
+        return self
+
+    def should_not_contains_sensitive_words(self, words: List = None, msg=None):
+        if words is None:
+            words = ["Key", "password", "privatekey"]
+        self.register_checker(lambda path: DirectoryValidator.__check_with_sensitive_words(path, words), msg)
+        return self
+
+
+class FileValidator(StringValidator):
+    def __init__(self, value):
+        """
+        @param value: the file path, should not be an empty string and should not contain a double dot (../)
+        """
+        super().__init__(value)
+        self.register_checker(lambda x: isinstance(x, str), "type is not str")
+
+    def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE):
+        self.register_checker(lambda path: min_size < os.path.getsize(self.value) <= max_size,
+                              "file size is invalid")
+        return self
+
+    def check_not_soft_link(self):
+        self.register_checker(lambda path: os.path.realpath(self.value) == self.value,
+                              "soft link or relative path should not be in the path parameter")
+        return self
+
+    def check_user_group(self):
+        process_uid = os.geteuid()
+        process_gid = os.getegid()
+        stat_info = os.stat(self.value)
+        file_uid = stat_info.st_uid
+        file_gid = stat_info.st_gid
+        self.register_checker(
+            lambda path: process_uid == file_uid or process_gid == file_gid, "Invalid log file user or group.")
+        return self
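check_is_children_path resolves both sides with os.path.realpath before comparing prefixes, which is what defeats ../ segments and symlink escapes. A quick illustration against the class above (paths are examples only):

import os

parent = os.path.realpath("/tmp/workspace")             # illustrative paths
inside = os.path.join(parent, "ckpt", "model.bin")
outside = "/etc/passwd"

print(DirectoryValidator.check_is_children_path(parent, inside))    # True
print(DirectoryValidator.check_is_children_path(parent, outside))   # False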
+
+
+class RankInfoValidator:
+    def check_visible_devices(self):
+        visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
+        device_res = StringValidator(visible_devices).check().is_valid()
+        if not device_res:
+            raise TypeError("env variable ascend_visible_devices is null, please config ASCEND_VISIBLE_DEVICES in "
+                            "docker container start.")
+
+        rank_size = os.getenv("CM_WORKER_SIZE")
+        rank_size_res = StringValidator(rank_size).check().is_valid()
+        if not rank_size_res:
+            raise TypeError("env variable CM_WORKER_SIZE is null, please config CM_WORKER_SIZE. For example, "
+                            "CM_WORKER_SIZE=1")
+
+        try:
+            rank_size_value = int(rank_size)
+            res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid().is_valid()
+            if not res and rank_size_value not in [1, 2, 4, 8, 16]:
+                raise ValueError("Invalid rank size, rank size must be between 1 and 16 in recommendation training.")
+        except ValueError as err:
+            raise ValueError("Invalid rank size, rank size must be a valid integer.") from err
+
+        chief_device = os.getenv("CM_CHIEF_DEVICE")
+        chief_device_res = StringValidator(chief_device).check().is_valid()
+        if not chief_device_res:
+            raise TypeError("env variable CM_CHIEF_DEVICE is null, please config CM_CHIEF_DEVICE. For example, "
+                            "CM_CHIEF_DEVICE=0")
-- Gitee

From 56508f0451826ad3c194de7488abda54431066f2 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 29 May 2023 11:36:51 +0800
Subject: [PATCH 070/551] Match-id-156348ee639d272b4b2448aa555e394cf2bb01e3

---
 src/core/checkpoint/checkpoint.cpp    |  9 +++++++--
 .../feat_admit_n_evict_ckpt.cpp       |  1 +
 src/core/emb_hashmap/emb_hashmap.cpp  |  4 ++--
 src/core/emb_mgmt/emb_mgmt.cpp        |  4 +++-
 src/core/emb_table/emb_table.cpp      |  4 ++--
 src/core/host_emb/host_emb.cpp        |  4 ++--
 src/core/key_process/key_process.cpp  | 11 +++++++----
 src/core/utils/unique.h               |  8 ++++----
 src/ops_tf/hybrid_dataset_ops.cpp     | 16 ++++++++++------
 9 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index b398da94..57fed0e9 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -195,6 +195,7 @@ void Checkpoint::WriteEmbedding(CkptTransData& transData, const string& dataDir,
     auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
     if (res != ACL_ERROR_NONE) {
         spdlog::error("Set device failed, device_id:{}", deviceId);
+        throw runtime_error(fmt::format("Set device failed, device_id:{}", deviceId).c_str());
     }

     auto &transArr = transData.int64Arr;
@@ -208,6 +209,7 @@
                                ACL_MEMCPY_DEVICE_TO_HOST);
         if (ret != ACL_SUCCESS) {
             spdlog::error("aclrtMemcpy failed, ret={}", ret);
+            throw runtime_error(fmt::format("aclrtMemcpy failed, ret={}", ret).c_str());
         }

         writeFile.write((const char *) (row.data()), embeddingSize * sizeof(float));
@@ -225,6 +227,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir)
     auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
     if (res != ACL_ERROR_NONE) {
         spdlog::error("Set device failed, device_id:{}", deviceId);
+        throw runtime_error(fmt::format("Set device failed, device_id:{}", deviceId).c_str());
     }

     auto &AttributeArr = transData.attribute;
@@ -238,6 +241,7 @@
     ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST);
     if (ret != ACL_SUCCESS) {
         spdlog::error("aclrtMalloc failed, ret={}", ret);
+        throw runtime_error(fmt::format("aclrtMalloc failed, ret={}", ret).c_str());
     }

     float *floatPtr = static_cast<float *>(newBlock);
@@ -251,6 +255,7 @@
                                ACL_MEMCPY_HOST_TO_DEVICE);
         if (ret != ACL_SUCCESS) {
             spdlog::error("aclrtMemcpy failed, ret={}", ret);
+            throw runtime_error(fmt::format("aclrtMemcpy failed, ret={}", ret).c_str());
         }

         int64_t address = reinterpret_cast<int64_t>(floatPtr + i * embeddingSize);
@@ -414,7 +419,7 @@
 void Checkpoint::ReadStream(CkptTransData& transData, uint32_t dataElmtBytes)
 {
     if (dataElmtBytes == 0) {
-
spdlog::error("dataElmtBytes is 0, don't handle [/ %] operation"); + spdlog::warn("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } std::ifstream readFile; @@ -531,4 +536,4 @@ void Checkpoint::ReadDataset(CkptTransData& transData, } else if (dataType == CkptDataType::ATTRIBUTE) { readFile.read((char*)(transData.attribute.data()) + idx, readSize); } -} \ No newline at end of file +} diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 6c342253..d6a4d388 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -18,6 +18,7 @@ void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { spdlog::error("Missing Feature Admit and Evict data"); + throw std::runtime_error("Missing Feature Admit and Evict data"); } saveTens2Thresh = std::move(processData.tens2Thresh); saveHistRec = std::move(processData.histRec); diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 8f446a65..ad24048b 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -247,7 +247,7 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& size_t offset; auto key = keys[i]; if (key == -1) { - spdlog::error("evict key equal -1!"); + spdlog::warn("evict key equal -1!"); continue; } const auto& iter = embHashMap.hostHashMap.find(key); @@ -272,4 +272,4 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& spdlog::info("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {} ", embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); -} \ No newline at end of file +} diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/emb_mgmt/emb_mgmt.cpp index 9eb4d908..bd99cd14 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/emb_mgmt/emb_mgmt.cpp @@ -615,6 +615,8 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) if (evictDevOffset.size() > embInfo.devVocabSize) { spdlog::error(MGMT + "{} overflow! evict pos on dev {} bigger than dev vocabSize {}", embName, evictDevOffset.size(), embInfo.devVocabSize); + throw runtime_error(fmt::format(MGMT + "{} overflow! evict pos on dev {} bigger than dev vocabSize {}", + embName, evictDevOffset.size(), embInfo.devVocabSize).c_str()); } if (mgmtRankInfo.useStatic) { evictDevOffset.resize(embInfo.devVocabSize, -1); @@ -624,4 +626,4 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) auto tmpData = Vec2TensorI32(evictDevOffset); hdTransfer->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); -} \ No newline at end of file +} diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 09cd6aec..6d8299b2 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -136,7 +136,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali break; } default: { - spdlog::error("Device Invalid Initializer Type. Using default Constant Initializer with value 0."); + spdlog::warn("Device Invalid Initializer Type. 
Using default Constant Initializer with value 0."); ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0); initializer = &defaultInitializer; } @@ -238,4 +238,4 @@ list EmbTable::LoadEmb(const vector> &savedEmb) memoryList.push_back(newBlock); return addressList; #endif -} \ No newline at end of file +} diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 905d0d5f..2e3e10e3 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -62,7 +62,7 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in break; } default: { - spdlog::error(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); + spdlog::warn(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0); initializer = &defaultInitializer; } @@ -246,4 +246,4 @@ void HostEmb::EvictInitEmb(const string& embName, const vector& offset) EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); spdlog::info(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); -} \ No newline at end of file +} diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 673d683b..050be81c 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -548,6 +548,8 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, if (static_cast(i.size()) > embInfos[batch->name].sendCount) { spdlog::error("{}[{}]:{} overflow! set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()); + throw runtime_error(fmt::format("{}[{}]:{} overflow! set send count bigger than {}", + batch->name, batch->channel, batch->batchId, i.size()).c_str()); } i.resize(embInfos[batch->name].sendCount, -1); } @@ -894,9 +896,8 @@ T KeyProcess::GetInfo(std::vector>& list, int batch, const string } auto topBatch = get(list[batchListId][embName][channel].top()); if (topBatch < batch) { - spdlog::error("wrong batch id, top:{} expect:{}, channel:{}, embName: {}, queue_size:{}, " - "may not clear channel", - topBatch, batch, channel, embName, list[batchListId][embName][channel].size()); + spdlog::warn("wrong batch id, top:{} expect:{}, channel:{}, embName: {}, queue_size:{}, may not clear channel", + topBatch, batch, channel, embName, list[batchListId][embName][channel].size()); this_thread::sleep_for(1s); } if (topBatch != batch) { @@ -1027,7 +1028,7 @@ void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector offset if (offset.size() > embInfos[embName].devVocabSize) { spdlog::error("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", embName, offset.size(), embInfos[embName].devVocabSize); + throw runtime_error(fmt::format("{} overflow! 
init evict dev, evictOffset size {} bigger than dev vocabSize {}", + embName, offset.size(), embInfos[embName].devVocabSize).c_str()); } if (rankInfo.useStatic) { offset.resize(embInfos[embName].devVocabSize, -1); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index b40115fb..d30cc680 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -34,6 +34,7 @@ #include "time_cost.h" using namespace MxRec; +using namespace std; struct UniqueData { void *inputData; @@ -642,7 +643,6 @@ public: for (int i = 0; i < groupMethod_.GroupCount(); i++) { if (send_cnt_ < uniqueSizeVector[i]){ spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, uniqueSizeVector[i]); - throw SendCntTooSmallError(); } } } @@ -737,13 +737,13 @@ public: auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); if (rc != 0) { spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}",mem_size); - return; + throw std::runtime_error(fmt::format("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}",mem_size).c_str()); } mem_size = uniqueSizeVector[i] * sizeof(int32_t); rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); if (rc != 0) { spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size); - return; + throw std::runtime_error(fmt::format("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size).c_str()); } } @@ -772,4 +772,4 @@ private: std::vector> dedupShards_; int32_t send_cnt_; }; -#endif \ No newline at end of file +#endif diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index b8f19779..5d568e64 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -230,7 +230,7 @@ public: std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); if (batch == nullptr) { spdlog::error("batch can not be null"); - return; + throw runtime_error("batch can not be null"); } queue->Pushv(move(batch)); } @@ -241,8 +241,8 @@ public: const size_t& len, size_t& offset) { if (len == 0) { - spdlog::error("len can not be zero"); - return nullptr; + spdlog::error("the length of batchData can not be zero"); + throw runtime_error("the length of batchData can not be zero"); } TimeCost ct; void* src = nullptr; @@ -261,11 +261,13 @@ public: batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { spdlog::error("mmemory allocation failded..."); + throw runtime_error("mmemory allocation failded..."); } void* dst = reinterpret_cast(batchData->tensorAddr); auto rc = memcpy_s(dst, memSize, src, memSize); if (rc != 0) { spdlog::error("[ReadEmbKeyV2Dynamic]memcpy_s failded... memSize: {}", memSize); + throw runtime_error(fmt::format("[ReadEmbKeyV2Dynamic]memcpy_s failded... 
memSize: {}", memSize).c_str()); } TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); offset += len; @@ -461,7 +463,7 @@ public: std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); if (batch == nullptr) { spdlog::error("batch can not be null"); - return -1; + throw runtime_error("batch can not be null"); } queue->Pushv(move(batch)); } @@ -474,7 +476,7 @@ public: { if (len == 0) { spdlog::error("len can not be zero"); - return nullptr; + throw runtime_error("len can not be zero"); } TimeCost ct; void* src = nullptr; @@ -493,11 +495,13 @@ public: batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { spdlog::error("mmemory allocation failded..."); + throw runtime_error("mmemory allocation failded..."); } void* dst = reinterpret_cast(batchData->tensorAddr); auto rc = memcpy_s(dst, memSize, src, memSize); if (rc != 0) { spdlog::error("[ReadEmbKeyV2Static]memcpy_s failded... memSize: {}", memSize); + throw runtime_error(fmt::format("[ReadEmbKeyV2Static]memcpy_s failded... memSize: {}", memSize).c_str()); } TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); offset += len; @@ -783,4 +787,4 @@ REGISTER_OP("EmbeddingUpdateByAddress") return Status::OK(); }); -REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); \ No newline at end of file +REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); -- Gitee From d1ecfce0ec3004909410a5a0d77b4ff1540668cc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 11:53:25 +0800 Subject: [PATCH 071/551] Match-id-c7960a31cde5dde9f2f7e3db6e327a06b6e1fac6 --- example/little_demo/main.py | 22 ++++++++-------- mx_rec/core/embedding.py | 2 +- mx_rec/optimizers/lazy_adam_by_addr.py | 35 -------------------------- 3 files changed, 12 insertions(+), 47 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 27af7e30..671515d4 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -97,7 +97,7 @@ if __name__ == "__main__": warnings.filterwarnings("ignore") mode = MxRecMode.mapping(os.getenv("MXREC_MODE")) - TRAIN_INTERVAL = 200 + TRAIN_INTERVAL = 100 EVAL_STEPS = 10 SAVING_INTERVAL = 100 USE_TIMESTAMP = False @@ -208,7 +208,7 @@ if __name__ == "__main__": grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) - # saver = tf.compat.v1.train.Saver() + saver = tf.compat.v1.train.Saver() if MODIFY_GRAPH_FLAG: logging.info("start to modifying graph") modify_graph_and_start_emb_cache(dump_graph=True) @@ -222,10 +222,10 @@ if __name__ == "__main__": sess.run(train_iterator.initializer) sess.run(tf.compat.v1.global_variables_initializer()) EPOCH = 0 - # if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): - # saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) - # else: - # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) + if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): + saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) + else: + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) for i in range(1, 201): logging.info(f"################ training at step {i} ################") @@ -237,12 +237,12 @@ if __name__ == "__main__": else: if i % TRAIN_INTERVAL == 0: EPOCH += 1 - # evaluate() + evaluate() + + if i % SAVING_INTERVAL == 0: + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) - # if i % 
SAVING_INTERVAL == 0: - # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) - # - # saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) terminate_config_initializer() logging.info("Demo done!") diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 866a4f0f..90d9f999 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -632,7 +632,7 @@ class SparseEmbedding: embedding_type=1) from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN - if is_training and use_dynamic_expansion and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + if is_training and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 6081a8e6..8ccddbe2 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -128,38 +128,3 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): var_update_op = host_pipeline_ops.embedding_update_by_address(addr, update_tensor, update_type=0) return var_update_op - - - # def _apply_sparse_shared(self, grad, addr): - # power_b1, power_b2 = self._get_beta_accumulators() - # power_b1 = math_ops.cast(power_b1, grad.dtype.base_dtype) - # power_b2 = math_ops.cast(power_b2, grad.dtype.base_dtype) - # temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(grad) - # learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) - # - # host_pipeline_ops = get_host_pipeline_ops() - # dim = grad.shape.as_list()[-1] - # - # # addr_m = tf.where(tf.math.greater(addr, 0), addr + 4*dim, addr) - # # addr_v = tf.where(tf.math.greater(addr, 0), addr + 8*dim, addr) - # addr_m = tf.add(addr, 4*dim) - # addr_v = tf.add(addr, 8*dim) - # - # logging.debug(f'lazy adam by addr, addr is {addr}, addr_m is {addr_m}, addr_v is {addr_v}') - # old_m_slice = \ - # host_pipeline_ops.embedding_lookup_by_address(addr_m, embedding_dim=dim, embedding_type=1) - # old_v_slice = \ - # host_pipeline_ops.embedding_lookup_by_address(addr_v, embedding_dim=dim, embedding_type=1) - # - # m_t_slice = temp_b1 * old_m_slice + (1 - temp_b1) * grad - # m_update_op = host_pipeline_ops.embedding_update_by_address(addr_m, m_t_slice - old_m_slice, update_type=0) - # - # v_t_slice = temp_b2 * old_v_slice + (1 - temp_b2) * math_ops.square(grad) - # v_update_op = host_pipeline_ops.embedding_update_by_address(addr_v, v_t_slice - old_v_slice, update_type=0) - # - # denominator_slice = math_ops.sqrt(v_t_slice) + temp_epsilon - # - # var_update_op = host_pipeline_ops.embedding_update_by_address(addr, tf.divide(-learning_rate * m_t_slice, - # denominator_slice), update_type=0) - # - # return control_flow_ops.group(m_update_op, v_update_op, var_update_op) \ No newline at end of file -- Gitee From 3c5ffc4d628c4db085c6d790a1bc88e14af4c056 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 14:35:01 +0800 Subject: [PATCH 072/551] Match-id-95302267b39081123794e86d0fd6e55c43d39f52 --- mx_rec/core/embedding.py | 14 +++++++------- mx_rec/optimizers/lazy_adam_by_addr.py | 5 ++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 8c6ffba2..adaa5c01 100644 --- 
a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -18,7 +18,7 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID + DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, ASCEND_TABLE_NAME_MUST_CONTAIN from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ @@ -423,9 +423,9 @@ class SparseEmbedding: embedding_dim=self.emb_size, embedding_type=1) - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN - if is_training and use_dynamic_expansion and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ - ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name: + is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name + if is_training and use_dynamic_expansion and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) @@ -657,9 +657,9 @@ class SparseEmbedding: host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size, embedding_type=1) - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN - if is_training and ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ - ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name: + is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name + if is_training and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 31cf08bf..8e6b3c41 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -43,11 +43,10 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdamByAddress"): self.optimizer_type = "LazyAdamByAddress" - super(CustomizedLazyAdamByAddress, self).__init__(use_locking, name) - super(CustomizedLazyAdamByAddress, self).__get_name__(name=name) super(CustomizedLazyAdamByAddress, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, - epsilon=epsilon, use_locking=use_locking, name=self.unique_name) + epsilon=epsilon, use_locking=use_locking, + name=self.unique_name) self._slot_num = 2 self._check_input_param() -- Gitee From 9001af2b47fcd977133d8629f34f4e4c5d0ccefa Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 14:51:38 +0800 Subject: [PATCH 073/551] Match-id-ccde453f594dfb10ddcc1ee4be2f7ecfea112414 --- mx_rec/optimizers/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index b4582e5e..dc420733 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -6,9 +6,9 @@ from __future__ import absolute_import from __future__ 
import division from __future__ import print_function +import logging from collections import defaultdict -import logging from tensorflow.python.framework import ops from tensorflow.python.training.optimizer import _TensorProcessor @@ -42,7 +42,6 @@ class CustomizedOptimizer: def my_update_op(self, opt, grad): - logging.debug("tf.compat.v1.training.optimizer._TensorProcessor has been patched, update_op.") if isinstance(grad, ops.Tensor): logging.debug(">>>>Enter update_op ops.Tensor") update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access @@ -53,4 +52,4 @@ def my_update_op(self, opt, grad): def patch_for_optimizer(): _TensorProcessor.update_op = my_update_op - logging.debug("Class tf.compat.v1.training.optimizer._TensorProcessor has been patched.") \ No newline at end of file + logging.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file -- Gitee From c4356639e38514fcd6b57f5926e1948d13389e6a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 15:33:40 +0800 Subject: [PATCH 074/551] Match-id-f66c523a942ff97ca51d400e1f695c2ef4fd83a7 --- mx_rec/core/embedding.py | 4 ++++ mx_rec/optimizers/lazy_adam_by_addr.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index adaa5c01..a3021822 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -428,6 +428,8 @@ class SparseEmbedding: if is_training and use_dynamic_expansion and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + logging.debug(f"modify graph mode, table_name: {self.table_name}, " + f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") @tf.custom_gradient def sparse_forward(table, feat_ids): @@ -662,6 +664,8 @@ class SparseEmbedding: if is_training and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + logging.debug(f"modify graph mode, table_name: {self.table_name}, " + f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") return sparse_forward(local_embeddings) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 8e6b3c41..50c0707b 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -43,7 +43,7 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdamByAddress"): self.optimizer_type = "LazyAdamByAddress" - super(CustomizedLazyAdamByAddress, self).__get_name__(name=name) + super(CustomizedLazyAdamByAddress, self)._get_name(name=name) super(CustomizedLazyAdamByAddress, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, use_locking=use_locking, name=self.unique_name) -- Gitee From 4919a01aa5f66f15f15ad398058c62d7aa023d53 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 15:34:35 +0800 Subject: [PATCH 075/551] Match-id-6802a80d6704352156c24953d33af0f232b117eb --- mx_rec/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index a3021822..12a46a50 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -664,7 +664,7 @@ class SparseEmbedding: if is_training and 
is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) - logging.debug(f"modify graph mode, table_name: {self.table_name}, " + logging.debug(f"feature spec mode, table_name: {self.table_name}, " f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") return sparse_forward(local_embeddings) -- Gitee From 00af9d9717e8990cdb9e0e8d7a51ff90f03f3222 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 15:50:55 +0800 Subject: [PATCH 076/551] Match-id-5bb6ccf4ec6a120fd5b306235a199a79a271ffba --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 89411221..1e4b350d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -501,15 +501,15 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, remainBatchOut = false; } auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); - hdTransfer->Send(RESTORE, *restore, channelId, embName); + hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); vector tmpData; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData); - hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embName); + hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); - hdTransfer->Send(SWAP, tmpData, channelId, embName); + hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName); if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, *all2all, channelId, embName); + hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, -- Gitee From ddb1da1c1cbdd20ab3fdf3c2614f67d39b8e9c99 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 15:54:59 +0800 Subject: [PATCH 077/551] Match-id-d8043aee4089a1d9a37b9c8a70843f60e71ee431 --- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index ed6bca8f..98168d3f 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -85,7 +85,7 @@ namespace MxRec { bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut); - void EmbHDTrans(const int channelId,const int batchId); + void EmbHDTrans(const int channelId, const int batchId); void Evict(); -- Gitee From 959703a6353c8933982524262d57959e0d310bc4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 16:20:23 +0800 Subject: [PATCH 078/551] Match-id-bb74526fee261789e716ee1eac9b5a403dbfa97c --- src/core/utils/unique.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index 2b5ae171..f789475c 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -132,7 +132,7 @@ template class Dedup { public: Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1) - : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) + : 
bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) { void *area = aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_); table_ = reinterpret_cast *>(area); @@ -652,7 +652,7 @@ public: if (uniqueFlag.useStatic) { for (int i = 0; i < groupMethod_.GroupCount(); i++) { - if (send_cnt_ < uniqueSizeVector[i]){ + if (send_cnt_ < uniqueSizeVector[i]) { spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, uniqueSizeVector[i]); } @@ -757,14 +757,14 @@ public: if (rc != 0) { spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}", mem_size); throw std::runtime_error( - fmt::format("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}", mem_size).c_str()); + fmt::format("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}", mem_size).c_str()); } mem_size = uniqueSizeVector[i] * sizeof(int32_t); rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); if (rc != 0) { spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size); throw std::runtime_error(fmt::format("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", - mem_size).c_str()); + mem_size).c_str()); } } -- Gitee From bffca6e988728d1b46f15e55d0fb2f355147c4d8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 17:45:40 +0800 Subject: [PATCH 079/551] Match-id-e3bd5b975cbd24485d1c6c00a0e945eb6b55fe76 --- mx_rec/validator/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 mx_rec/validator/__init__.py diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py new file mode 100644 index 00000000..8f75c6b6 --- /dev/null +++ b/mx_rec/validator/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
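The Dedup constructor in the unique.h hunks above pairs a power-of-two bucketCount_ with bucketCountMask_ = bucketCount_ - 1, so reducing a hash to a bucket index becomes a single AND instead of a modulo. A Python sketch of the same trick (names and sizes are illustrative):

BUCKET_COUNT = 1 << 10            # must stay a power of two
BUCKET_MASK = BUCKET_COUNT - 1    # the bucketCountMask_ trick


def bucket_of(key):
    # equivalent to hash(key) % BUCKET_COUNT, without the division
    return hash(key) & BUCKET_MASK


assert 0 <= bucket_of(12345) < BUCKET_COUNT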
-- Gitee From 031d8faa5d60f060d0490c607de213be1f01970d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 18:06:45 +0800 Subject: [PATCH 080/551] Match-id-b3bb390bf4027d2ddcb7c61d6bcc4c2333e8e34e --- src/core/checkpoint/checkpoint.cpp | 13 +-- src/core/checkpoint/checkpoint.h | 4 +- src/core/emb_hashmap/emb_hashmap.cpp | 5 +- src/core/emb_hashmap/emb_hashmap.h | 5 +- src/core/hd_transfer/hd_transfer.cpp | 4 +- src/core/hd_transfer/hd_transfer.h | 18 ++-- src/core/host_emb/host_emb.cpp | 6 +- src/core/host_emb/host_emb.h | 2 +- .../hybrid_mgmt.cpp} | 40 ++++----- .../emb_mgmt.h => hybrid_mgmt/hybrid_mgmt.h} | 12 +-- .../constant_initializer.cpp | 2 +- .../constant_initializer.h | 2 +- .../random_normal_initializer.cpp | 2 +- .../random_normal_initializer.h | 2 +- .../truncated_normal_initializer.cpp | 2 +- .../truncated_normal_initializer.h | 2 +- .../key_process/feature_admit_and_evict.cpp | 4 +- .../key_process/feature_admit_and_evict.h | 4 +- src/core/key_process/key_process.cpp | 25 +++--- src/core/key_process/key_process.h | 2 +- src/core/utils/common.cpp | 4 +- src/core/utils/common.h | 24 ++---- src/core/utils/spinlock.h | 66 +++++++------- src/core/utils/task_queue.h | 11 ++- src/core/utils/unique.h | 85 ++++++++++++------- src/ops_tf/hybrid_dataset_ops.cpp | 41 +++++---- src/pybind/module_main.cpp | 2 +- src/tests/emb_mgmt/emb_mgmt_test.cpp | 4 +- src/tests/key_process/key_process_test.cpp | 2 +- 29 files changed, 205 insertions(+), 190 deletions(-) rename src/core/{emb_mgmt/emb_mgmt.cpp => hybrid_mgmt/hybrid_mgmt.cpp} (94%) rename src/core/{emb_mgmt/emb_mgmt.h => hybrid_mgmt/hybrid_mgmt.h} (92%) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 57fed0e9..7f13d379 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -23,8 +23,6 @@ using namespace MxRec; void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo) { - // TODO: check savePath - processPath = savePath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; @@ -42,8 +40,6 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, const vector& featureTypes) { - // TODO: check loadPath - processPath = loadPath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; @@ -143,7 +139,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) } } -int Checkpoint::GetEmbeddingSize(const string& embName) +int Checkpoint::GetEmbeddingSize(const string& embName) const { for (const auto &embInfo: mgmtEmbInfo) { if (embInfo.name == embName) { @@ -186,7 +182,7 @@ void Checkpoint::SaveDataset(const vector& embNames, } } -void Checkpoint::WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize) +void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize) { ofstream writeFile; writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); @@ -338,11 +334,9 @@ void Checkpoint::LoadProcess(CkptData& ckptData) void Checkpoint::GetUpperLayerLoadDir(const vector& dirNames) { innerDirPath = processPath; - // TODO: check existence for (const auto& dirName : dirNames) { innerDirPath = innerDirPath + dirSeparator + dirName; - // TODO: check existence } } @@ -360,7 +354,6 @@ vector Checkpoint::GetTableLayerLoadDir() } closedir(dir); } - // TODO: may cause 
memory problem? need to check return loadTableDir; } @@ -372,10 +365,8 @@ void Checkpoint::LoadDataset(const vector& embNames, { for (const auto& embName : embNames) { auto dataDir { innerDirPath + dirSeparator + embName }; - // TODO: check existence for (const auto& saveDataType : saveDataTypes) { auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; - // TODO: check existence auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType }; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 548517fb..27578053 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -80,10 +80,10 @@ namespace MxRec { void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); void WriteDataset(CkptTransData& transData, ofstream& writeFile, size_t writeSize, CkptDataType dataType, size_t idx); - void WriteEmbedding(CkptTransData& transData, const string& dataDir, int& embeddingSize); + void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); void ReadEmbedding(CkptTransData& transData, const string& dataDir); - int GetEmbeddingSize(const string& embName); + int GetEmbeddingSize(const string& embName) const; void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 9f74756e..f506c809 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -108,7 +108,7 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector& keys, size_t curr } } -void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, - size_t keepBatchId) +void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId) { while (num != 0) { if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 498b61fa..ab51ab13 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -9,9 +9,9 @@ #define MX_REC_EMB_HASHMAP_H #include -#include "absl/container/flat_hash_map.h" #include #include +#include "absl/container/flat_hash_map.h" #include "host_emb/host_emb.h" namespace MxRec { @@ -32,8 +32,7 @@ namespace MxRec { void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, int pos); - void FindPos(EmbHashMapInfo& embHashMap, int num, size_t currentBatchId, - size_t keepBatchId); + void FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId); auto GetHashMaps() -> absl::flat_hash_map; diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index f3e9618b..39c8c76e 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -60,7 +60,7 @@ void HDTransfer::Destroy() #endif } -void HDTransfer::CreateChannel(uint32_t localRankId, const string& embName, int channelNum) +void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum) { #ifndef GTEST int channelSize; @@ -82,7 +82,7 @@ void HDTransfer::CreateChannel(uint32_t localRankId, const string& embName, int } } spdlog::info("user config all2all restore 
lookup channel size:{}", channelSize); - for (int c = D2H; c != INVALID; c++) { + for (int c = static_cast(TransferChannel::D2H); c != static_cast(TransferChannel::INVALID); c++) { auto channel = static_cast(c); string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelNum); if (TransferChannel2Str(channel) == "all2all" || diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 2e9ff303..b840649d 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -27,7 +27,7 @@ namespace MxRec { const int PING_PONG_SIZE = 12; const int LARGE_CHANNEL_SIZE = 100; - enum TransferChannel { + enum class TransferChannel { D2H, RESTORE, ALL2ALL, @@ -41,19 +41,19 @@ namespace MxRec { inline string TransferChannel2Str(TransferChannel e) { switch (e) { - case D2H: + case TransferChannel::D2H: return "d2h"; - case RESTORE: + case TransferChannel::RESTORE: return "restore"; - case ALL2ALL: + case TransferChannel::ALL2ALL: return "all2all"; - case LOOKUP: + case TransferChannel::LOOKUP: return "lookup"; - case EVICT: + case TransferChannel::EVICT: return "evict"; - case H2D: + case TransferChannel::H2D: return "h2d"; - case SWAP: + case TransferChannel::SWAP: return "swap"; default: throw std::invalid_argument("Invalid TransferChannel"); @@ -84,7 +84,7 @@ namespace MxRec { std::unordered_map transferChannels; #endif bool running; - void CreateChannel(uint32_t localRankId, const string& embName, int channelNum); + void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum); }; } #endif // MX_REC_HD_TRANSFER_H diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index a56e32b7..eeb6aa96 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -19,7 +19,7 @@ using namespace MxRec; using namespace std; using namespace chrono; -bool HostEmb::Initialize(const vector& embInfos, int seed, bool ifLoad) +bool HostEmb::Initialize(const vector& embInfos, int seed) { for (const auto& embInfo: embInfos) { HostEmbTable hostEmb; @@ -89,7 +89,7 @@ void HostEmb::Join() t->join(); } procThreads.clear(); - spdlog::info(HOSTEMB + "hostemb end join, cost:{}", TO_MS(sw)); + spdlog::info(HOSTEMB + "hostemb end join, cost:{}", duration_cast((sw).elapsed())); } /* @@ -245,4 +245,4 @@ void HostEmb::EvictInitEmb(const string& embName, const vector& offset) EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); spdlog::info(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); -} +} \ No newline at end of file diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index b12a68bc..4464a202 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -28,7 +28,7 @@ namespace MxRec { ~HostEmb() {}; - bool Initialize(const vector& embInfos, int seed, bool ifLoad = false); + bool Initialize(const vector& embInfos, int seed); void LoadEmb(absl::flat_hash_map& loadData); diff --git a/src/core/emb_mgmt/emb_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp similarity index 94% rename from src/core/emb_mgmt/emb_mgmt.cpp rename to src/core/hybrid_mgmt/hybrid_mgmt.cpp index be446118..1e4b350d 100644 --- a/src/core/emb_mgmt/emb_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -4,7 +4,7 @@ * Author: MindX SDK * Date: 2022/11/15 */ -#include "emb_mgmt.h" +#include "hybrid_mgmt.h" #include #include @@ -81,7 +81,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const 
vector& embInfos, if (!rankInfo.noDDR) { hostEmbs = make_unique(); hostHashMaps = make_unique(); - hostEmbs->Initialize(embInfos, seed, ifLoad); + hostEmbs->Initialize(embInfos, seed); hostHashMaps->Init(rankInfo, embInfos, ifLoad); } isLoad = ifLoad; @@ -98,7 +98,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, return true; } -bool HybridMgmt::Save(string savePath) +bool HybridMgmt::Save(const string savePath) { preprocess->LoadSaveLock(); @@ -379,7 +379,7 @@ bool HybridMgmt::SendTask() return false; } -bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) +bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) { spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { @@ -406,7 +406,7 @@ bool HybridMgmt::GetLookupAndRestore(int channelId, int &batchId) return true; } -bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) +bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { vector names = {embInfo.name}; @@ -418,7 +418,7 @@ bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) if (!mgmtRankInfo.useStatic) { for (const string& name: names) { auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, { *all2all }, channelId, name); + hdTransfer->Send(TransferChannel::ALL2ALL, { *all2all }, channelId, name); } } spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", @@ -433,7 +433,7 @@ bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) TimeCost sendLookupTC; for (const string& name: names) { auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); - hdTransfer->Send(LOOKUP, lookUpKeys, channelId, name); + hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); } TIME_PRINT("LOOKUP Send TimeCost(ms):{}", sendLookupTC.ElapsedMS()); } @@ -442,7 +442,7 @@ bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) TimeCost sendRestoreTC; for (const string& name: names) { auto restore = restoreQueue->WaitAndPop(); - hdTransfer->Send(RESTORE, restore, channelId, name); + hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); } TIME_PRINT("RESTORE Send TimeCost(ms):{}", sendRestoreTC.ElapsedMS()); } @@ -453,7 +453,7 @@ bool HybridMgmt::SendLookupAndRestore(int channelId, int &batchId) return true; } -bool HybridMgmt::EndBatch(int batchId, int channelId) +bool HybridMgmt::EndBatch(int batchId, int channelId) const { return (batchId % mgmtRankInfo.maxStep[channelId] == 0 && mgmtRankInfo.maxStep[channelId] != -1); } @@ -462,8 +462,10 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) { spdlog::info(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; - int start = batchId, iBatch = 0; - bool ifHashmapFree = true, remainBatch = true; + int start = batchId; + int iBatch = 0; + bool ifHashmapFree = true; + bool remainBatch = true; while (true) { spdlog::info(MGMT + "parse keys, [{}]:{}", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { @@ -499,15 +501,15 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, remainBatchOut = false; } auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); - hdTransfer->Send(RESTORE, *restore, channelId, embName); + hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, 
embName); vector tmpData; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData); - hdTransfer->Send(LOOKUP, { tmpData.front() }, channelId, embName); + hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); - hdTransfer->Send(SWAP, tmpData, channelId, embName); + hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName); if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(ALL2ALL, *all2all, channelId, embName); + hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, @@ -534,7 +536,7 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in } } -void HybridMgmt::EmbHDTrans(int channelId, int batchId) +void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) { EASY_FUNCTION(profiler::colors::Blue) EASY_VALUE("mgmtProcess", batchId) @@ -544,7 +546,7 @@ void HybridMgmt::EmbHDTrans(int channelId, int batchId) auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos; vector h2dEmb; hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! - hdTransfer->Send(H2D, h2dEmb, channelId, embInfo.name, batchId); + hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); } for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); @@ -563,7 +565,7 @@ void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embI spdlog::info(MGMT + "trans emb dummy, batchId:{}, channelId:{}", batchId, channelId); auto transferName = TransferChannel::D2H; auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0]; - hdTransfer->Send(H2D, {}, channelId, embInfo.name); + hdTransfer->Send(TransferChannel::H2D, {}, channelId, embInfo.name); } /* @@ -630,5 +632,5 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) } auto tmpData = Vec2TensorI32(evictDevOffset); - hdTransfer->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); + hdTransfer->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); } diff --git a/src/core/emb_mgmt/emb_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h similarity index 92% rename from src/core/emb_mgmt/emb_mgmt.h rename to src/core/hybrid_mgmt/hybrid_mgmt.h index bba6420c..98168d3f 100644 --- a/src/core/emb_mgmt/emb_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -9,11 +9,11 @@ #define MX_REC_EMB_MGMT_H #include -#include "absl/container/flat_hash_map.h" #include #include #include #include +#include "absl/container/flat_hash_map.h" #include "utils/common.h" #include "utils/singleton.h" #include "utils/task_queue.h" @@ -46,7 +46,7 @@ namespace MxRec { bool Initialize(RankInfo rankInfo, const vector& embInfos, int seed, const vector& thresholdValues, bool ifLoad); - bool Save(string savePath); + bool Save(const string savePath); bool Load(const string& loadPath); @@ -85,7 +85,7 @@ namespace MxRec { bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut); - void EmbHDTrans(int channelId, int batchId); + void EmbHDTrans(const int channelId, const int batchId); void Evict(); @@ -122,12 +122,12 @@ namespace MxRec { bool TrainParseKeys(); bool EvalParseKeys(); - bool GetLookupAndRestore(int 
channelId, int &batchId); - bool SendLookupAndRestore(int channelId, int &batchId); + bool GetLookupAndRestore(const int channelId, int &batchId); + bool SendLookupAndRestore(const int channelId, int &batchId); void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); - bool EndBatch(int batchId, int channelId); + bool EndBatch(int batchId, int channelId) const; void EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch); diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 4bced738..954ca98f 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -13,7 +13,7 @@ using namespace MxRec; ConstantInitializer::ConstantInitializer(int start, int len, float value) : start(start), len(len), value(value) {} -void ConstantInitializer::GenerateData(float* emb, const int embSize) +void ConstantInitializer::GenerateData(float* const emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h index b763087b..68aa0654 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.h +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -21,7 +21,7 @@ namespace MxRec { ~ConstantInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(float* const emb, const int embSize) override; int start; int len; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index a5a31381..7933555f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -18,7 +18,7 @@ RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, distribution = std::normal_distribution(mean, stddev); } -void RandomNormalInitializer::GenerateData(float* emb, const int embSize) +void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index e0127ca2..d6c8b376 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -23,7 +23,7 @@ namespace MxRec { ~RandomNormalInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(float *const emb, const int embSize) override; int start; int len; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index a1f49c08..d02ac998 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -21,7 +21,7 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float } -void TruncatedNormalInitializer::GenerateData(float* emb, const int embSize) +void 
TruncatedNormalInitializer::GenerateData(float* const emb, const int embSize) { if (len == 0) { return; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index 3c6bb980..d2da1bef 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -23,7 +23,7 @@ namespace MxRec { ~TruncatedNormalInitializer() override {}; - void GenerateData(float* emb, const int embSize) override; + void GenerateData(float* const emb, const int embSize) override; int boundNum = 2; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index bcbb02e7..7cb2caa8 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -91,8 +91,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } -FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(int channel, const std::string& tensorName, - int64_t featureId, uint32_t featureCnt) +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tensorName, + const int64_t featureId, const uint32_t featureCnt) { // “特征准入”逻辑 uint32_t currKeyCount = 0; diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index a355711d..dfc2490c 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -84,8 +84,8 @@ namespace MxRec { // 解析m_tensor2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); std::vector GetAllNeedEvictTensorNames(); - FeatureAdmitType FeatureAdmitHelper(int channel, const std::string& tensorName, - int64_t featureId, uint32_t featureCnt); + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tensorName, + const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 2d78422e..c590681d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -94,11 +94,11 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, } } spdlog::info(KEY_PROCESS "hot emb count info:{}", hotEmbTotCount); - MPI_Group world_group; - MPI_Comm_group(MPI_COMM_WORLD, &world_group); + MPI_Group worldGroup; + MPI_Comm_group(MPI_COMM_WORLD, &worldGroup); for (auto& i: comm) { for (auto& j: i) { - MPI_Comm_create(MPI_COMM_WORLD, world_group, &j); + MPI_Comm_create(MPI_COMM_WORLD, worldGroup, &j); } } isRunning = true; @@ -217,7 +217,7 @@ void KeyProcess::LoadSaveUnlock() } } -void KeyProcess::KeyProcessTask(int channel, int id) +void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0, KEY_PROCESS_THREAD-1] { unique_ptr batch; @@ -239,7 +239,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) spdlog::info(KEY_PROCESS "batch is nullptr"); break; } - auto getBatchTime = TO_MS(sw); + auto getBatchTime = duration_cast((sw).elapsed()); sw.reset(); auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); @@ -264,7 +264,8 @@ void KeyProcess::KeyProcessTask(int channel, int id) } 
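[Editor's note: the key_process.cpp hunks around this point drop the TO_MS(arg) helper macro (deleted from utils/common.h later in this same patch) and spell out the std::chrono cast at every call site; note that the duration_cast template argument (milliseconds) appears lost to formatting in this copy of the patch. Below is a minimal, self-contained sketch of the inlined pattern. It uses a bare steady_clock rather than the spdlog::stopwatch the code holds, on the assumption that stopwatch::elapsed() likewise returns a std::chrono duration; all names are illustrative, not part of the patch.]

#include <chrono>
#include <cstdio>

using std::chrono::duration_cast;
using std::chrono::milliseconds;
using std::chrono::steady_clock;

int main()
{
    auto start = steady_clock::now();
    volatile long sink = 0;
    for (long i = 0; i < 10000000; ++i) {  // stand-in for the work being timed
        sink += i;
    }
    // was: TO_MS(sw); now written out explicitly at each call site
    auto elapsedMs = duration_cast<milliseconds>(steady_clock::now() - start);
    std::printf("key process cost(ms): %lld\n", static_cast<long long>(elapsedMs.count()));
    return 0;
}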
TIME_PRINT("getAndProcesTC TimeCost(ms):{}", getAndProcesTC.ElapsedMS()); spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", - TO_MS(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + duration_cast( + (sw).elapsed()), getBatchTime, batch->name, batch->channel, batch->batchId); free(batch->tensorAddr); batch->tensorAddr = nullptr; batchQueue->PutDirty(move(batch)); @@ -591,7 +592,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - ASSERT(batchData != nullptr); + assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 @@ -630,7 +631,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - ASSERT(batchData != nullptr); + assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector> keyCount(rankInfo.rankSize); // splitKeys在原始batch中对应的频次 vector restore(batch->Size()); @@ -788,7 +789,7 @@ void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - spdlog::debug(KEY_PROCESS "barrier time:{}", TO_MS(sw)); + spdlog::debug(KEY_PROCESS "barrier time:{}", duration_cast((sw).elapsed())); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); @@ -921,11 +922,11 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) return get(ret); } catch (EmptyList&) { spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} no input, wait and retry", - embName, channel, batch); + embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} wrong top", - embName, channel, batch); + embName, channel, batch); this_thread::sleep_for(1ms); } } @@ -1054,7 +1055,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset auto trans = Singleton::GetInstance(); // evict key发送给dev侧,dev侧初始化emb auto tmpData = Vec2TensorI32(offset); - trans->Send(EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); + trans->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! 
send offsetSize:{}", embName, offset.size()); } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index ef4a9dba..3803244e 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -130,7 +130,7 @@ namespace MxRec { auto GetSendCount(const string& name, const string& channelName, bool modifyGraph); - void KeyProcessTask(int channel, int id); + void KeyProcessTask(const int channel, const int id); bool KeyProcessTaskHelper(unique_ptr& batch, shared_ptr unique, int channel, int id, spdlog::stopwatch& sw); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index ba77c2f3..2485d5ba 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -41,8 +41,8 @@ namespace MxRec { if (localRankSize != 0) { localRankId = rankId % localRankSize; } - useStatic = option bitand HybridOption::USE_STATIC; - useHot = option bitand HybridOption::USE_HOT; + useStatic = option & HybridOption::USE_STATIC; + useHot = option & HybridOption::USE_HOT; } RandomInfo::RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 9553ec2a..41f26558 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -45,8 +45,6 @@ #endif namespace MxRec { -#define ASSERT(arg) assert(arg) -#define TO_MS(arg) duration_cast((arg).elapsed()) #define INFO_PTR shared_ptr #define TIME_PRINT spdlog::info #define MGMT_CPY_THREADS 4 @@ -129,30 +127,22 @@ namespace MxRec { template struct Batch { - size_t Size() + size_t Size() const { return sample.size(); } - std::string UnParse() + std::string UnParse() const { std::string s; constexpr size_t MAX_DISP_LEN = 20; - int max_len = std::min(sample.size(), MAX_DISP_LEN); - for (int i = 0; i < max_len; i++) { + int maxLen = std::min(sample.size(), MAX_DISP_LEN); + for (int i = 0; i < maxLen; i++) { s += std::to_string(sample[i]) + " "; } return s; } - ~Batch() - { - if (tensorAddr) { - free(tensorAddr); - tensorAddr = nullptr; - } - } - std::vector sample; void *tensorAddr = nullptr; std::string name; @@ -283,12 +273,12 @@ struct BatchTask { std::default_random_engine& generator, RandomInfo& randomInfo) { - float min = ((randomInfo.randomMin == 0) ? -0.1f : randomInfo.randomMin); - float max = ((randomInfo.randomMax == 0) ? 0.1f : randomInfo.randomMax); + float min = ((!randomInfo.randomMin) ? -0.1f : randomInfo.randomMin); + float max = ((!randomInfo.randomMax) ? 
0.1f : randomInfo.randomMax); if (randomInfo.len == 0) { return; } - ASSERT(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); + assert(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); std::uniform_real_distribution distribution(min, max); std::generate_n(vecData.begin() + randomInfo.start, randomInfo.len, [&]() { return distribution(generator); }); } diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h index 95c0a35c..e3b9e07d 100644 --- a/src/core/utils/spinlock.h +++ b/src/core/utils/spinlock.h @@ -12,12 +12,8 @@ #include #include // NOLINT -#define DISALLOW_COPY_MOVE_AND_ASSIGN_(type) \ - type(type const &) = delete; \ - type(type &&) noexcept = delete; \ - type &operator=(type const &) = delete - -static __inline void cpu_pause() { +static __inline void cpu_pause() +{ #ifdef __GNUC__ #ifdef __aarch64__ __asm volatile("yield" ::: "memory"); @@ -38,8 +34,8 @@ static constexpr uint16_t g_kMaxSpinCountBeforeThreadYield = 64; class SpinLock final { public: void lock() noexcept {} - bool try_lock() noexcept { return true; } void unlock() noexcept {} + bool try_lock() noexcept { return true; } }; #elif defined(USE_MUTEX) @@ -60,9 +56,12 @@ class SpinLock final { public: SpinLock() = default; - DISALLOW_COPY_MOVE_AND_ASSIGN_(SpinLock); + SpinLock(SpinLock const &) = delete; + SpinLock(SpinLock &&) noexcept = delete; + SpinLock &operator=(SpinLock const &) = delete; - inline void lock() noexcept { + inline void lock() noexcept + { while (true) { if (!lock_.exchange(true, std::memory_order_acquire)) { break; @@ -80,7 +79,8 @@ public: } } - inline bool try_lock() noexcept { + inline bool try_lock() noexcept + { if (lock_.load(std::memory_order_relaxed)) { return false; } @@ -105,24 +105,27 @@ class RWSpinLock final { public: RWSpinLock() = default; - DISALLOW_COPY_MOVE_AND_ASSIGN_(RWSpinLock); + RWSpinLock(RWSpinLock const &) = delete; + RWSpinLock(RWSpinLock &&) noexcept = delete; + RWSpinLock &operator=(RWSpinLock const &) = delete; - inline void r_lock() noexcept { - LockData oldData, newData; + inline void r_lock() noexcept + { + LockData oldData; + LockData newData; while (true) { uint16_t counter = 0; for (;;) { oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.lock.writer > 0) { - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } else { + if (oldData.lock.writer <= 0) { break; } + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } } newData.lock.readers = oldData.lock.readers + 1; @@ -135,22 +138,23 @@ public: } } - inline void w_lock() noexcept { - LockData oldData, newData; + inline void w_lock() noexcept + { + LockData oldData; + LockData newData; while (true) { uint16_t counter = 0; for (;;) { oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.raw != 0) { - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } else { + if (oldData.raw == 0) { break; } + cpu_pause(); + if (++counter > g_kMaxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } } newData.lock.readers = 0; diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h index 14ed1202..d44e2a11 100644 --- a/src/core/utils/task_queue.h +++ b/src/core/utils/task_queue.h @@ -38,7 +38,6 @@ public: return *this; } - void Pushv(T &t) { 
std::lock_guard lk(mut); @@ -57,14 +56,14 @@ public: { std::unique_lock lk(mut); dataCond.wait(lk, [this] { - if (!finished){ + if (!finished) { return !dataQueue.empty(); - } else{ + } else { return true; } }); T res; - if (finished){ + if (finished) { return res; } res = dataQueue.front(); @@ -72,12 +71,12 @@ public: return res; } - void DestroyQueue(){ + void DestroyQueue() + { finished = true; dataCond.notify_one(); } - bool Empty() const { std::lock_guard lk(mut); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index d30cc680..f789475c 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -65,6 +65,16 @@ struct UniqueThreadNum { int maxThread; }; +namespace SysytemConst { + const int LEVEL1_CACHE = 64; + const int DEFAULT_DEDUPLICATION_RATE = 4; + const int DEDUPLICATION_RATE = 2; + const int HASH_SPLIT_BUCKERT_1 = 16; + const int HASH_SPLIT_BUCKERT_2 = 32; + const int HASH_SPLIT_BUCKERT_3 = 48; + const int PRE_APPLY_MEMORY = 256; +} + class SendCntTooSmallError : public std::exception { }; @@ -106,7 +116,7 @@ template class Dedup { static const int kDefaultBucketCountMask = kDefaultBucketCount - 1; template struct Meta { - static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); + static_assert(M <= MxRec::UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); SpinLock lock; volatile int8_t count; int8_t pad[3]; @@ -122,9 +132,9 @@ template class Dedup { public: Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1) - : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) + : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) { - void *area = aligned_alloc(64, sizeof(Meta) * bucketCount_); + void *area = aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_); table_ = reinterpret_cast *>(area); Clear(bucketCount_); } @@ -255,7 +265,8 @@ public: free(table_); bucketCount_ = newBucketCountPowerOf2; bucketCountMask_ = bucketCount_ - 1; - table_ = reinterpret_cast *>(aligned_alloc(64, sizeof(Meta) * bucketCount_)); + table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, + sizeof(Meta) * bucketCount_)); } bzero(table_, sizeof(Meta) * bucketCount_); overflow_.clear(); @@ -271,7 +282,8 @@ public: // sake. uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; int largeCount = 0; - while (shardedTableSize > stats_.totalUniques * 4 && largeCount_ != 1) { + while (shardedTableSize > stats_.totalUniques * SysytemConst::DEFAULT_DEDUPLICATION_RATE && + largeCount_ != 1) { // too large newBucketCountPowerOf2 >>= 1; shardedTableSize >>= 1; @@ -286,7 +298,7 @@ public: } } - while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> 2)) { + while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> SysytemConst::DEDUPLICATION_RATE)) { newBucketCountPowerOf2 <<= 1; shardedTableSize <<= 1; } @@ -303,7 +315,6 @@ public: } // Warning: functions below are not thread safe! 
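[Editor's note: the unique.h hunks above and below replace bare magic numbers (the 64-byte alignment, the 256-bucket start size, the 16/32/48 hash shift widths) with named constants in a namespace the patch spells "SysytemConst" (sic). A minimal, self-contained sketch of the cache-line-aligned table allocation those constants drive follows; the size round-up is an addition in this sketch, since C11 aligned_alloc formally requires the size to be a multiple of the alignment. Names are illustrative, not part of the patch.]

#include <cstdlib>
#include <cstring>

namespace SystemConst {  // the patch's namespace, with its spelling corrected here
constexpr std::size_t LEVEL1_CACHE = 64;  // cache-line size, used as the alignment
}

// Allocate a zeroed table of `slots` fixed-size entries on cache-line boundaries.
void *AllocAlignedTable(std::size_t slots, std::size_t slotBytes)
{
    std::size_t bytes = slots * slotBytes;
    // Round up to a multiple of the alignment, as aligned_alloc requires.
    bytes = (bytes + SystemConst::LEVEL1_CACHE - 1) & ~(SystemConst::LEVEL1_CACHE - 1);
    void *area = std::aligned_alloc(SystemConst::LEVEL1_CACHE, bytes);
    if (area != nullptr) {
        std::memset(area, 0, bytes);
    }
    return area;
}

int main()
{
    void *table = AllocAlignedTable(256, 16);  // e.g. 256 buckets of 16 bytes
    std::free(table);
    return 0;
}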
- // // Return the unique values // Also update the hash-order base of each bucket std::vector Unique() @@ -366,7 +377,8 @@ public: return total - priorTotal; } - void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) { + void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) + { auto hot = hotMap.find(key); if (hot != hotMap.end()) { if (hot->second == -1) { @@ -382,7 +394,7 @@ public: uint32_t UniqueRawForHot(int64_t *output, uint32_t priorTotal, int32_t* idCount, map &hotMap, map &hotPosMap, int &hotCount, - absl::flat_hash_map &keyCountMap) + absl::flat_hash_map &keyCountMap) { uint32_t total = priorTotal; int32_t replace_offset = priorTotal; @@ -446,7 +458,8 @@ private: static inline uint64_t hash(uint64_t val) { - return val ^ (val >> 16) ^ (val >> 32) ^ (val >> 48); + return val ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_1) ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_2) ^ + (val >> SysytemConst::HASH_SPLIT_BUCKERT_3); } void insertOverflow(uint64_t val) @@ -479,8 +492,6 @@ private: } }; // Dedup -#define CACHE_LINE_ALIGN(size) (((size) + 63ul) & ~63ul) - class OneSimpleGroupMethod { public: inline int GroupCount() @@ -503,8 +514,8 @@ public: using DedupT = Dedup; ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, - int estimatedDuplicateRatio = kDefaultDuplicateRatio) - : groupMethod_(groupMethod), bucketCountPower2_(256), send_cnt_(send_cnt) + int estimatedDuplicateRatio = kDefaultDuplicateRatio) + : groupMethod_(groupMethod), bucketCountPower2_(SysytemConst::PRE_APPLY_MEMORY), send_cnt_(send_cnt) { const int numOfGroupsInShard = groupMethod_.GroupCount(); @@ -630,19 +641,20 @@ public: baseVector.push_back(base); base += total; - partSize = CACHE_LINE_ALIGN(partSize); + partSize = ((partSize) + 63ul) & ~63ul; int32_t *beginPtr = output; int32_t *finishPtr = beginPtr + inputSize; int32_t *partBeginPtr = beginPtr; int32_t *partEndPtr = - reinterpret_cast(CACHE_LINE_ALIGN(reinterpret_cast(partBeginPtr + partSize))); + reinterpret_cast(((reinterpret_cast(partBeginPtr + partSize)) + 63ul) & ~63ul); - if(uniqueFlag.useStatic){ + if (uniqueFlag.useStatic) { for (int i = 0; i < groupMethod_.GroupCount(); i++) { - if (send_cnt_ < uniqueSizeVector[i]){ - spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, uniqueSizeVector[i]); + if (send_cnt_ < uniqueSizeVector[i]) { + spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, + uniqueSizeVector[i]); } } } @@ -668,7 +680,8 @@ public: // should be +/-1 off. const int numOfGroupsInShard = groupMethod_.GroupCount(); tasks.push_back([this, input, &baseVector, beginPtr, partBeginPtr, partEndPtr, numOfGroupsInShard, - totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, hotPosMap]() -> TaskReturnType { + totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, + hotPosMap]() -> TaskReturnType { for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { auto val = isInt64 ? 
((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; auto group = groupMethod_.GroupId(val); @@ -686,14 +699,16 @@ pool->SyncRun(tasks); } - - TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, uniqueSizeVector); + TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, + uniqueSizeVector); return 0; } - void ComputeRestore(bool useHot, int offset,const map &hotMap, int *hotPos,const map &hotPosMap, - int32_t *ptr, int64_t val, uint32_t fillOffset) const { + void ComputeRestore(bool useHot, int offset, const map &hotMap, int *hotPos, + const map &hotPosMap, + int32_t *ptr, int64_t val, uint32_t fillOffset) const + { auto hot = hotPosMap.find(val); if (!useHot) { *ptr = fillOffset; @@ -708,16 +723,20 @@ } uint32_t GetFillOffset(bool useStatic, const vector &baseVector, const vector &totalUniqueSize, - int64_t val, int32_t group) { + int64_t val, int32_t group) + { if (!useStatic) { return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]; } else { - return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - totalUniqueSize[group]; + return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - + totalUniqueSize[group]; } } - void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, const int32_t *idCount, - int32_t *idCountFill, bool useStatic, const std::vector &uniqueSizeVector) const { + void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, + const int32_t *idCount, int32_t *idCountFill, bool useStatic, + const std::vector &uniqueSizeVector) const + { int start = 0; int index = 0; @@ -736,14 +755,16 @@ size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); if (rc != 0) { - spdlog::error("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}",mem_size); - throw std::runtime_error(fmt::format("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}",mem_size).c_str()); + spdlog::error("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}", mem_size); + throw std::runtime_error( + fmt::format("[TileAndFill/uniqueIds] memcpy_s failed... mem_size: {}", mem_size).c_str()); } mem_size = uniqueSizeVector[i] * sizeof(int32_t); rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); if (rc != 0) { spdlog::error("[TileAndFill/idCountFill] memcpy_s failed... mem_size: {}", mem_size); - throw std::runtime_error(fmt::format("[TileAndFill/idCountFill] memcpy_s failed... 
mem_size: {}", + mem_size).c_str()); } } @@ -772,4 +793,4 @@ private: std::vector> dedupShards_; int32_t send_cnt_; }; -#endif +#endif \ No newline at end of file diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 5d568e64..f6816ca2 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -199,7 +199,7 @@ public: EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}, " "splits: {}, dataSize: {}, filedNum: {}, channelNames: {}, modifyGraph: {}", - TO_MS(sw), TO_MS(staticSw), + duration_cast((sw).elapsed()), duration_cast((staticSw).elapsed()), channelId, batchId, splits.size(), dataSize, fieldNum, channelNames, modifyGraph); staticSw.reset(); } @@ -250,13 +250,15 @@ public: if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { batchData->isInt64 = false; memSize = len * sizeof(int32_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). - data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } else { batchData->isInt64 = true; memSize = len * sizeof(int64_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data(). - data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data(). + data()))) + offset); } batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { @@ -427,8 +429,9 @@ public: EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); TIME_PRINT("EnqueueBatchData TimeCost(ms):{}", tc.ElapsedMS()); - TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", TO_MS(sw), - TO_MS(staticSw), channelId, batchId); + TIME_PRINT(KEY_PROCESS + "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", duration_cast((sw).elapsed()), + duration_cast((staticSw).elapsed()), channelId, batchId); staticSw.reset(); } @@ -484,13 +487,15 @@ public: if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { batchData->isInt64 = false; memSize = len * sizeof(int32_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() - .data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } else { batchData->isInt64 = true; memSize = len * sizeof(int64_t); - src = reinterpret_cast(reinterpret_cast(const_cast(inputTensor.tensor_data() - .data())) + offset); + src = reinterpret_cast( + reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + + offset); } batchData->tensorAddr = malloc(memSize); if (batchData->tensorAddr == nullptr) { @@ -603,7 +608,8 @@ public: for (int i { 0 }; i < restoreLen; ++i) { r(i) = i % lookupLen; } - spdlog::warn("dummy read batch cost: {},elapsed from last {}", TO_MS(sw), TO_MS(staticSw)); + spdlog::warn("dummy read batch cost: {},elapsed from last {}", duration_cast((sw).elapsed()), + duration_cast((staticSw).elapsed())); staticSw.reset(); } @@ -674,8 +680,10 @@ public: floatDataIndex += floatList.value_size(); } } - spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, TO_MS(sw), - TO_MS(sw_copy), TO_MS(staticReadRaw)); + spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, + 
duration_cast((sw).elapsed()), + duration_cast((sw_copy).elapsed()), + duration_cast((staticReadRaw).elapsed())); staticReadRaw.reset(); } @@ -727,8 +735,9 @@ public: auto input = inputTensor.flat(); int32_t batchId = input(0); - spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", TO_MS(sw), TO_MS(staticReadRaw), - batchId); + spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", + duration_cast((sw).elapsed()), + duration_cast((staticReadRaw).elapsed()), batchId); staticReadRaw.reset(); } diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 6d297dc1..3839fe37 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -8,7 +8,7 @@ #include #include -#include "emb_mgmt/emb_mgmt.h" +#include "hybrid_mgmt/hybrid_mgmt.h" #include "module_main.h" namespace py = pybind11; diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 22b8655f..d103b144 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -9,7 +9,7 @@ #include #include #include -#include "emb_mgmt/emb_mgmt.h" +#include "hybrid_mgmt/hybrid_mgmt.h" #include "host_emb/host_emb.h" #include "utils/common.h" @@ -127,7 +127,7 @@ TEST_F(EmbMgmtTest, Initialize) allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep); hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false); auto hostEmbs = make_unique(); - hostEmbs->Initialize(embInfos, seed, false); + hostEmbs->Initialize(embInfos, seed); auto hostHashMaps = make_unique(); hostHashMaps->Init(allRank, embInfos, false); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index db23625f..19c7e0a8 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -18,7 +18,7 @@ #include "utils/common.h" #include "host_emb/host_emb.h" #include "key_process/key_process.h" -#include "emb_mgmt/emb_mgmt.h" +#include "hybrid_mgmt/hybrid_mgmt.h" using namespace std; using namespace MxRec; -- Gitee From a61e7cd20b4287fcd049d94e152932061ec35f2e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 19:48:52 +0800 Subject: [PATCH 081/551] Match-id-211efc46bb89c77b0e973ec07e01025f749d0ecf --- mx_rec/core/asc/manager.py | 2 +- mx_rec/optimizers/base.py | 18 +++++++++--------- mx_rec/util/initialize.py | 6 +++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 944a0f9a..20b8bfc7 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -54,8 +54,8 @@ def generate_table_info_list(): if len(table_instance.channel_name_list) == 1: ids_channel_name = table_instance.channel_name_list[0] table_instance.channel_name_list = [table_instance.table_name] + table_instance.send_count_map.pop(ids_channel_name) try: - table_instance.send_count_map.pop(ids_channel_name) table_instance.send_count_map[table_instance.table_name] = table_instance.send_count except KeyError as error: raise KeyError(f"ids_channel_name '{ids_channel_name}' not in send_count_map " diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 3f00978a..075cfec7 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -17,6 +17,15 @@ class CustomizedOptimizer: self.unique_name = "" self.base_name = "" + def initialize_slots(self, var, table_instance): + raise NotImplementedError(f"Please define a specific realization on 
{self.__class__.__name__}") + + def insert_slot(self, slot, named_slots_key, slot_name): + raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") + + def get_slot_init_values(self): + raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") + def _get_name(self, name="CustomizedOptimizer"): if name in CustomizedOptimizer.name_counter: CustomizedOptimizer.name_counter[name] += 1 @@ -26,12 +35,3 @@ class CustomizedOptimizer: count = CustomizedOptimizer.name_counter[name] self.unique_name = name + "_" + str(count) self.base_name = name - - def initialize_slots(self, var, table_instance): - raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") - - def insert_slot(self, slot, named_slots_key, slot_name): - raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") - - def get_slot_init_values(self): - raise NotImplementedError(f"Please define a specific realization on {self.__class__.__name__}") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 1aed3a39..94004bab 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -323,9 +323,9 @@ class ConfigInitializer: def del_asc_manager(self): self.delete_initializers() + self.unfreeze() self._asc_manager.destroy() self._asc_manager = None - self.unfreeze() logging.debug("ASC manager has been destroyed.") @train_interval.setter @@ -665,7 +665,7 @@ def get_available_cpu_num_and_range(): valid_cpu_range_list = [] if is_ok: logging.info(f"available numa node num: {len(pkg_id2cpu_list)}") - for k, part_cpu_list in pkg_id2cpu_list.items(): + for _, part_cpu_list in pkg_id2cpu_list.items(): parse_range(part_cpu_list, valid_cpu_range_list) else: parse_range(list(cpu_available), valid_cpu_range_list) @@ -718,6 +718,6 @@ def bind_cpu(rank_id: int, rank_size: int = None): process = psutil.Process() try: process.cpu_affinity(cpu_list) - logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") except IndexError: logging.error(f"failed to bind cpu for rank {rank_id}: {cpu_list}") + logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") \ No newline at end of file -- Gitee From 8787f4b5a1eb3f1000d650a7ec95f5d2f7ce1a82 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 21:09:16 +0800 Subject: [PATCH 082/551] Match-id-4bca71ebf300ac0078198da8d09c13dd985a498f --- example/little_demo/main.py | 4 ++-- mx_rec/__init__.py | 2 +- mx_rec/core/asc/build_graph.py | 2 +- mx_rec/core/asc/helper.py | 2 +- mx_rec/core/asc/manager.py | 4 ++-- mx_rec/core/embedding.py | 2 +- mx_rec/graph/modifier.py | 2 +- mx_rec/saver/saver.py | 2 +- mx_rec/{constants => util}/constants.py | 0 mx_rec/util/initialize.py | 10 +++++----- mx_rec/util/ops.py | 2 +- mx_rec/validator/validator.py | 12 ++++++------ 12 files changed, 22 insertions(+), 22 deletions(-) rename mx_rec/{constants => util}/constants.py (100%) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 116ef432..f2d0175f 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -15,7 +15,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache -from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP +from mx_rec.util.constants import MxRecMode, ASCEND_TIMESTAMP from mx_rec.util.initialize 
import get_rank_id, get_rank_size, init, clear_channel, terminate_config_initializer, \ set_if_load, get_initializer from mx_rec.util.variable import get_dense_and_sparse_variable @@ -183,7 +183,7 @@ if __name__ == "__main__": train_ops.append(dense_optimizer.apply_gradients(avg_grads)) if use_dynamic_expansion: - from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET + from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) # do sparse optimization by addr diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 9578f5a2..7a2ed887 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 4b449fc5..31cf30a6 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -7,7 +7,7 @@ import logging import tensorflow as tf import mxrec_pybind -from mx_rec.constants.constants import AVOID_TENSOR_POS +from mx_rec.util.constants import AVOID_TENSOR_POS from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index dd0abce0..21f497f3 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -210,7 +210,7 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): # Only the tables that need to be used after table combination are retained in the Meituan scenario. # The current solution has errors in some situations. For example, a sparse table may not have been auto-merged. 
- from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN new_insert_tensors, new_splits, new_table_names = [], [], [] logging.debug(f"In do_insert function, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") for idx, table_name in enumerate(table_names): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 944a0f9a..a8dcf6cb 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -6,7 +6,7 @@ import logging import tensorflow as tf -from mx_rec.constants.constants import MxRecMode +from mx_rec.util.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ @@ -15,7 +15,7 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se def generate_table_info_list(): from mxrec_pybind import EmbInfo - from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN # table_name is corresponding to channel_name which is in used in operator gen_npu_ops.get_next table_info_list = [] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index e42b92dc..50b606ef 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -15,7 +15,7 @@ from tensorflow.python.ops import array_ops from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ +from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 87eebfd7..7ceb42dd 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -13,7 +13,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding -from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ +from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 02b70bd1..93a00810 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.constants.constants import DataName, DataAttr +from mx_rec.util.constants import DataName, DataAttr from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ 
get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, \ get_ascend_global_hashtable_collection diff --git a/mx_rec/constants/constants.py b/mx_rec/util/constants.py similarity index 100% rename from mx_rec/constants/constants.py rename to mx_rec/util/constants.py diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 1aed3a39..045c7cb7 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -10,9 +10,9 @@ from collections import defaultdict import mxrec_pybind import psutil -import mx_rec.constants.constants -from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST -from mx_rec.constants.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE +import mx_rec.util.constants +from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST +from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import RankInfoValidator @@ -205,7 +205,7 @@ class ConfigInitializer: Now, only less than or equal 8p training job is supported. :return: None """ - RankInfoValidator() + RankInfoValidator().check_visible_devices() ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") device_list = [] try: @@ -583,7 +583,7 @@ def set_initializer(is_training, initializer): def set_ascend_table_name_must_contain(name="merged"): - mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name + mx_rec.util.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name def set_ascend_env(): diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index edb29781..8b5ccef0 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -7,7 +7,7 @@ import os import tensorflow as tf -from mx_rec.constants.constants import HOST_PIPELINE_OPS_LIB_PATH +from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH def import_host_pipeline_ops(): diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 1922a42f..285461db 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -7,12 +7,12 @@ import os from typing import Callable, Any from typing import List, Optional, Tuple -from mx_rec.constants.constants import MIN_SIZE -from mx_rec.constants.constants import MAX_SIZE -from mx_rec.constants.constants import MAX_DEVICE_NUM -from mx_rec.constants.constants import MAX_RANK_SIZE -from mx_rec.constants.constants import MIN_DEVICE_NUM -from mx_rec.constants.constants import MIN_RANK_SIZE +from mx_rec.util.constants import MIN_SIZE +from mx_rec.util.constants import MAX_SIZE +from mx_rec.util.constants import MAX_DEVICE_NUM +from mx_rec.util.constants import MAX_RANK_SIZE +from mx_rec.util.constants import MIN_DEVICE_NUM +from mx_rec.util.constants import MIN_RANK_SIZE class Validator: -- Gitee From face89206730923a88b1f5a3258f40ff57dadc02 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 21:27:48 +0800 Subject: [PATCH 083/551] Match-id-258ac3c3653bc3609941a878381fa28587f9b3f6 --- mx_rec/core/asc/build_graph.py | 9 ++++++- mx_rec/core/asc/feature_spec.py | 8 ++++++- mx_rec/core/asc/helper.py | 29 ++++++++++++++++------ mx_rec/core/embedding.py | 33 +++++++++++++++++--------- mx_rec/graph/modifier.py | 12 ++++++---- mx_rec/optimizers/lazy_adam.py | 14 +++++++++-- mx_rec/optimizers/lazy_adam_by_addr.py | 14 +++++++++-- mx_rec/saver/patch.py | 13 ++++++---- 8 files changed, 
100 insertions(+), 32 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 4b449fc5..1ebf9e67 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -142,4 +142,11 @@ def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) for i in range(len(table))] - return restore_vector, hot_pos, id_offsets, swap_in, all2all_args + reslt = { + 'restore_vector' : restore_vector, + 'hot_pos' : hot_pos, + 'id_offsets' : id_offsets, + 'swap_in' : swap_in, + 'all2all_args' : all2all_args, + } + return reslt diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 8e599e59..2a80f8a1 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -171,7 +171,13 @@ class FeatureSpec: f"is not {is_training}. ") insert_feature_spec(self, is_training) - return tensor, self.table_name, self.feat_cnt, self.split + reslt = { + 'tensor' : tensor, + 'table_name' : self.table_name, + 'feat_count' : self.feat_cnt, + 'split' : self.split, + } + return reslt def get_feature_spec(table_name, access_and_evict_config): diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index dd0abce0..c8c5632d 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -64,8 +64,12 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ if len(args) == 1: data_src = args[0] - read_emb_key_inputs_dict = {"insert_tensors": [], "table_names": [], - "feature_spec_names": [], "splits": []} + read_emb_key_inputs_dict = { + "insert_tensors": [], + "table_names": [], + "feature_spec_names": [], + "splits": [] + } get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) logging.debug(f"do_insert with spec for {read_emb_key_inputs_dict.get('table_names')}") return do_insert(args, @@ -138,7 +142,13 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list, featu output_tensorshape_split_list.append(last_tensorshape_split) logging.debug(f"merge request from {table_name_list} {split_list} " f" to {output_table_name_list} {output_split_list}") - return output_feature_id_list, output_split_list, output_table_name_list, output_tensorshape_split_list + list_set = { + 'output_feature_id_list' : output_feature_id_list, + 'output_split_list' : output_split_list, + 'output_table_name_list' : output_table_name_list, + 'output_tensorshape_split_list' : output_tensorshape_split_list, + } + return list_set def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict): @@ -155,9 +165,11 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, feature_id_list = feature_id_list[1:] if not auto_change_graph: # future support acg - feature_id_list, split_list, table_name_list, tensorshape_split_list = \ - merge_feature_id_request(feature_id_list, split_list, - table_name_list, feature_spec_names) + list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list, feature_spec_names) + feature_id_list = list_set.get("output_feature_id_list") + split_list = list_set.get("output_split_list") + table_name_list = list_set.get("output_table_name_list") + tensorshape_split_list = list_set.get("output_tensorshape_split_list") else: tensorshape_split_list = split_list @@ -326,7 +338,10 @@ def 
get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, rea raise ValueError(f"Encounter a invalid batch.") if feature_spec.is_timestamp is None: - tensor, table_name, feat_count, split = feature_spec.set_feat_attribute(tensor, is_training) + reslt = feature_spec.set_feat_attribute(tensor, is_training) + tensor = reslt.get("tensor") + table_name = reslt.get("table_name") + split = reslt.get("split") if tensor.dtype != tf.int64: tensor = tf.cast(tensor, dtype=tf.int64) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index e42b92dc..b9b17120 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -27,11 +27,21 @@ from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.variable import remove_saving_var -def create_table(key_dtype, dim, name, emb_initializer, device_vocabulary_size=1, host_vocabulary_size=0, - optimizer_list=None, mode=MxRecMode.ASC, value_dtype=tf.float32, shard_num=1, - fusion_optimizer_var=True, hashtable_threshold=0): - """ +def create_table(**kwargs): + key_dtype = kwargs.get("key_dtype") + dim = kwargs.get("dim") + name = kwargs.get("name") + emb_initializer = kwargs.get("emb_initializer") + device_vocabulary_size = kwargs.get("device_vocabulary_size", 1) + host_vocabulary_size = kwargs.get("host_vocabulary_size", 0) + optimizer_list = kwargs.get("optimizer_list", None) + mode = kwargs.get("mode", MxRecMode.ASC) + value_dtype = kwargs.get("value_dtype", tf.float32) + shard_num = kwargs.get("shard_num", 1) + fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) + hashtable_threshold = kwargs.get("hashtable_threshold", 0) + """ Args: key_dtype: data type for feature id dim: embedding vector size @@ -46,10 +56,8 @@ def create_table(key_dtype, dim, name, emb_initializer, device_vocabulary_size=1 shard_num: embedding partition number fusion_optimizer_var: fusion optimizer variable with embedding hashtable_threshold: choose to implement based on hash table or linear layer - - Returns: SparseEmbedding instance - """ + config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, @@ -585,12 +593,15 @@ class SparseEmbedding: use_dynamic_expansion=use_dynamic_expansion) if self.skip_emb_transfer: - restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - self.variable, config) + result = get_preprocessed_tensor_for_asc(self.variable, config) else: variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list] - restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - variable_list, config) + result = get_preprocessed_tensor_for_asc(variable_list, config) + restore_vector = result.get("restore_vector") + hot_pos = result.get("hot_pos") + id_offsets = result.get("id_offsets") + swap_in = result.get("swap_in") + all2all_matrix = result.get("all2all_matrix") control_ops = swap_in id_offsets = tf.identity(id_offsets, name="identity_addr") diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 87eebfd7..612363c3 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -389,13 +389,17 @@ def build_asc_graph(table_instance, cutting_point, config, is_training): raise ValueError(f"The length of channel_name_list must be greater than or equal to 1.") 
if skip_emb_transfer: - restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - table_instance.variable, config, ids_channel_name, table_instance.modify_graph) + result = get_preprocessed_tensor_for_asc(table_instance.variable, config, ids_channel_name, + table_instance.modify_graph) else: variable_list = [table_instance.variable] \ + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list] - restore_vector, hot_pos, id_offsets, swap_in, all2all_matrix = get_preprocessed_tensor_for_asc( - variable_list, config, ids_channel_name, table_instance.modify_graph) + result = get_preprocessed_tensor_for_asc(variable_list, config, ids_channel_name, table_instance.modify_graph) + restore_vector = result.get("restore_vector") + hot_pos = result.get("hot_pos") + id_offsets = result.get("id_offsets") + swap_in = result.get("swap_in") + all2all_matrix = result.get("all2all_args") with tf.control_dependencies(swap_in): id_offsets = tf.identity(id_offsets) diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 51cf6e6f..a578b7df 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -112,7 +112,13 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): temp_b1 = math_ops.cast(self._beta1_t, var_type) temp_b2 = math_ops.cast(self._beta2_t, var_type) temp_epsilon = math_ops.cast(self._epsilon_t, var_type) - return temp_lr, temp_b1, temp_b2, temp_epsilon + temp = { + 'temp_lr' : temp_lr, + 'temp_b1' : temp_b1, + 'temp_b2' : temp_b2, + 'temp_epsilon' : temp_epsilon, + } + return temp def _resource_apply_sparse(self, grad, handle, indices): logging.debug("Enter _resource_apply_sparse") @@ -134,7 +140,11 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): power_b1, power_b2 = self._get_beta_accumulators() power_b1 = math_ops.cast(power_b1, var.dtype.base_dtype) power_b2 = math_ops.cast(power_b2, var.dtype.base_dtype) - temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(var) + temp = self._cast_to_base_type(var) + temp_lr = temp.get("temp_lr") + temp_b1 = temp.get("temp_b1") + temp_b2 = temp.get("temp_b2") + temp_epsilon = temp.get("temp_epsilon") learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) abs_indices = tf.math.maximum(indices, 0) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 42d0e5ba..17b58a9a 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -202,7 +202,13 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): temp_b1 = math_ops.cast(self._beta1_t, var_type) temp_b2 = math_ops.cast(self._beta2_t, var_type) temp_epsilon = math_ops.cast(self._epsilon_t, var_type) - return temp_lr, temp_b1, temp_b2, temp_epsilon + temp = { + 'temp_lr' : temp_lr, + 'temp_b1' : temp_b1, + 'temp_b2' : temp_b2, + 'temp_epsilon' : temp_epsilon, + } + return temp def _apply_sparse(self, grad, addr): logging.debug(">>>> Enter _apply_sparse Lazy_adam by addr") @@ -214,7 +220,11 @@ class CustomizedLazyAdamByAddress(optimizer.Optimizer, CustomizedOptimizer): power_b1, power_b2 = self._get_beta_accumulators() power_b1 = math_ops.cast(power_b1, grad.dtype.base_dtype) power_b2 = math_ops.cast(power_b2, grad.dtype.base_dtype) - temp_lr, temp_b1, temp_b2, temp_epsilon = self._cast_to_base_type(grad) + temp = self._cast_to_base_type(grad) + temp_lr = temp.get("temp_lr") + temp_b1 = temp.get("temp_b1") + temp_b2 = 
temp.get("temp_b2") + temp_epsilon = temp.get("temp_epsilon") learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) host_pipeline_ops = get_host_pipeline_ops() diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 706f45b0..9bc25da9 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -123,8 +123,12 @@ def get_model_checkpoint_path(self, checkpoint_file, sess): return model_checkpoint_path -def update_checkpoint_state(self, model_checkpoint_path, parent_save_path, latest_file_name, suffix_meta_graph, - save_path): +def update_checkpoint_state(self, **kwargs): + model_checkpoint_path = kwargs.get("model_checkpoint_path") + parent_save_path = kwargs.get("parent_save_path") + latest_file_name = kwargs.get("latest_file_name") + suffix_meta_graph = kwargs.get("suffix_meta_graph") + save_path = kwargs.get("save_path") self._RecordLastCheckpoint(model_checkpoint_path) try: checkpoint_management.update_checkpoint_state_internal(save_dir=parent_save_path, @@ -182,8 +186,9 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra model_checkpoint_path = compat.as_str(get_model_checkpoint_path(self, checkpoint_file, sess)) if write_state: - update_checkpoint_state(self, model_checkpoint_path, save_path_parent, latest_filename, meta_graph_suffix, - save_path) + update_checkpoint_state(self, model_checkpoint_path=model_checkpoint_path, save_path_parent=save_path_parent, + latest_filename=latest_filename, meta_graph_suffix=meta_graph_suffix, + save_path=save_path) if write_meta_graph: write_meta_graph_task(self, checkpoint_file, meta_graph_suffix, sess, strip_default_attrs, save_debug_info) return model_checkpoint_path -- Gitee From 1d9a628dc6563c3dc4c1b0f58fde3081e196c2e4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 21:45:51 +0800 Subject: [PATCH 084/551] Match-id-56902871e0e2a0adf65b497fb632035454dfe61f --- src/core/utils/spinlock.h | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h index e3b9e07d..527ef2b5 100644 --- a/src/core/utils/spinlock.h +++ b/src/core/utils/spinlock.h @@ -34,17 +34,33 @@ static constexpr uint16_t g_kMaxSpinCountBeforeThreadYield = 64; class SpinLock final { public: void lock() noexcept {} + void unlock() noexcept {} - bool try_lock() noexcept { return true; } + + bool try_lock() noexcept + { + return true; + } }; #elif defined(USE_MUTEX) class SpinLock final { public: - void lock() noexcept { mt_.lock(); } - bool try_lock() noexcept { return mt_.try_lock(); } - void unlock() noexcept { mt_.unlock(); } + void lock() noexcept + { + mt_.lock(); + } + + bool try_lock() noexcept + { + return mt_.try_lock(); + } + + void unlock() noexcept + { + mt_.unlock(); + } private: std::mutex mt_; @@ -87,7 +103,10 @@ public: return !lock_.exchange(true, std::memory_order_acquire); } - inline void unlock() noexcept { lock_.store(false, std::memory_order_release); } + inline void unlock() noexcept + { + lock_.store(false, std::memory_order_release); + } private: std::atomic lock_{false}; @@ -167,9 +186,15 @@ public: } } - inline void r_unlock() noexcept { --lock_; } + inline void r_unlock() noexcept + { + --lock_; + } - inline void w_unlock() noexcept { lock_.store(0, std::memory_order_release); } + inline void w_unlock() noexcept + { + lock_.store(0, std::memory_order_release); + } private: std::atomic lock_{0}; -- Gitee From f673ed5390ee032768266b61ac0d2596658d859a Mon Sep 17 00:00:00 2001 
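[Editor's note: the spinlock.h change just above (patch 084) is purely mechanical — braces move to their own lines and one-line accessors are expanded — so the locking protocol is unchanged: exchange with acquire ordering to take the lock, store with release ordering to drop it. A self-contained reduction of that protocol is sketched below to show the class still models BasicLockable and therefore composes with std::lock_guard. All names here are illustrative, not part of the patch.]

#include <atomic>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class SpinLock final {
public:
    void lock() noexcept
    {
        // Loop until exchange returns false, i.e. this thread flipped the flag itself.
        while (lock_.exchange(true, std::memory_order_acquire)) {
            while (lock_.load(std::memory_order_relaxed)) {
                std::this_thread::yield();  // stand-in for the cpu_pause() spin hint
            }
        }
    }

    void unlock() noexcept
    {
        lock_.store(false, std::memory_order_release);
    }

private:
    std::atomic<bool> lock_{false};
};

int main()
{
    SpinLock lock;
    long counter = 0;
    std::vector<std::thread> workers;
    for (int t = 0; t < 4; ++t) {
        workers.emplace_back([&] {
            for (int i = 0; i < 100000; ++i) {
                std::lock_guard<SpinLock> guard(lock);  // lock()/unlock() suffice
                ++counter;
            }
        });
    }
    for (auto &w : workers) {
        w.join();
    }
    std::cout << "counter = " << counter << '\n';  // prints 400000
    return 0;
}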
From: mxRecTeam Date: Mon, 29 May 2023 21:55:25 +0800 Subject: [PATCH 085/551] Match-id-c70fa003177847aee66eea9c43dde9f3736433d9 --- mx_rec/core/asc/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 20b8bfc7..944a0f9a 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -54,8 +54,8 @@ def generate_table_info_list(): if len(table_instance.channel_name_list) == 1: ids_channel_name = table_instance.channel_name_list[0] table_instance.channel_name_list = [table_instance.table_name] - table_instance.send_count_map.pop(ids_channel_name) try: + table_instance.send_count_map.pop(ids_channel_name) table_instance.send_count_map[table_instance.table_name] = table_instance.send_count except KeyError as error: raise KeyError(f"ids_channel_name '{ids_channel_name}' not in send_count_map " -- Gitee From c109ee1e257eceffdcb36ed965070f42e54f10ac Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 22:02:05 +0800 Subject: [PATCH 086/551] Match-id-5df78bff87e70e0b77dffa9af72b0c3eefd06b72 --- mx_rec/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index b9b17120..c9496e3b 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -34,7 +34,7 @@ def create_table(**kwargs): emb_initializer = kwargs.get("emb_initializer") device_vocabulary_size = kwargs.get("device_vocabulary_size", 1) host_vocabulary_size = kwargs.get("host_vocabulary_size", 0) - optimizer_list = kwargs.get("optimizer_list", None) + optimizer_list = kwargs.get("optimizer_list") mode = kwargs.get("mode", MxRecMode.ASC) value_dtype = kwargs.get("value_dtype", tf.float32) shard_num = kwargs.get("shard_num", 1) -- Gitee From 2e7d3e67215c73456833d83de043851467c857ce Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 22:07:40 +0800 Subject: [PATCH 087/551] Match-id-586e4a77b04fa5e673a432b94b15816a887b584d --- mx_rec/util/initialize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 94004bab..04c546a9 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -323,9 +323,9 @@ class ConfigInitializer: def del_asc_manager(self): self.delete_initializers() - self.unfreeze() self._asc_manager.destroy() self._asc_manager = None + self.unfreeze() logging.debug("ASC manager has been destroyed.") @train_interval.setter -- Gitee From 62323d91ccc4a31f87d4bfef94429a775db52209 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 22:16:30 +0800 Subject: [PATCH 088/551] Match-id-48b79aacc93c0d2bfd2a2362a079246e40b9af4e --- mx_rec/core/asc/build_graph.py | 4 ++-- mx_rec/core/asc/feature_spec.py | 4 ++-- mx_rec/core/asc/helper.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 1ebf9e67..70bffe2f 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -142,11 +142,11 @@ def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) for i in range(len(table))] - reslt = { + result = { 'restore_vector' : restore_vector, 'hot_pos' : hot_pos, 'id_offsets' : id_offsets, 'swap_in' : swap_in, 'all2all_args' : all2all_args, } - return reslt + return result diff --git 
a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 2a80f8a1..c11515dd 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -171,13 +171,13 @@ class FeatureSpec: f"is not {is_training}. ") insert_feature_spec(self, is_training) - reslt = { + result = { 'tensor' : tensor, 'table_name' : self.table_name, 'feat_count' : self.feat_cnt, 'split' : self.split, } - return reslt + return result def get_feature_spec(table_name, access_and_evict_config): diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index c8c5632d..b57a2adb 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -338,10 +338,10 @@ def get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, rea raise ValueError(f"Encounter a invalid batch.") if feature_spec.is_timestamp is None: - reslt = feature_spec.set_feat_attribute(tensor, is_training) - tensor = reslt.get("tensor") - table_name = reslt.get("table_name") - split = reslt.get("split") + result = feature_spec.set_feat_attribute(tensor, is_training) + tensor = result.get("tensor") + table_name = result.get("table_name") + split = result.get("split") if tensor.dtype != tf.int64: tensor = tf.cast(tensor, dtype=tf.int64) -- Gitee From d5869db4eb35b685d998e5ca16034ece33369bff Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 29 May 2023 22:33:12 +0800 Subject: [PATCH 089/551] Match-id-ac9c7b20308cf8ff07bdb86649b3913a6dfc2f91 --- mx_rec/util/constants.py | 9 - mx_rec/util/initialize.py | 72 +++----- mx_rec/util/ops.py | 20 +-- mx_rec/validator/__init__.py | 3 - mx_rec/validator/validator.py | 300 ---------------------------------- 5 files changed, 29 insertions(+), 375 deletions(-) delete mode 100644 mx_rec/validator/__init__.py delete mode 100644 mx_rec/validator/validator.py diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py index b3816d43..3a9be2ec 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/util/constants.py @@ -33,15 +33,6 @@ DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 -# RANK INFO -VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] -MIN_SIZE = 1 -MAX_SIZE = 1024 * 1024 * 1024 * 1024 -MAX_DEVICE_NUM = 16 -MAX_RANK_SIZE = 4095 -MIN_DEVICE_NUM = 1 -MIN_RANK_SIZE = 1 - class BaseEnum(Enum): @classmethod diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 045c7cb7..e9ae8c5b 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -7,14 +7,13 @@ import logging import os from collections import defaultdict -import mxrec_pybind import psutil +import mxrec_pybind import mx_rec.util.constants -from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST -from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE +from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ + ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import RankInfoValidator class ConfigInitializer: @@ -57,7 +56,7 @@ class ConfigInitializer: if os.getenv("RANK_TABLE_FILE"): self.parse_hccl_json() else: - self.set_hccl_info_without_json() + self.set_device_dict() self.check_parameters() self.train_interval = kwargs.get("train_interval", -1) self.eval_steps = kwargs.get("eval_steps", -1) @@ -199,56 +198,31 @@ class ConfigInitializer: raise 
ValueError(f"get logic id from physic id fail.") self._rank_to_device_dict[rank_id] = device_id - def set_hccl_info_without_json(self): - """ - Used for no rank table file configured training situation. - Now, only less than or equal 8p training job is supported. - :return: None - """ - RankInfoValidator().check_visible_devices() + def set_device_dict(self): ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - device_list = [] - try: - if "-" in ascend_visible_devices: - split_devices = ascend_visible_devices.strip().split("-") - if len(split_devices) >= 1: - rank_start = int(split_devices[0]) - device_list = [i for i in range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)] - elif "," in ascend_visible_devices: - device_list = list(map(int, ascend_visible_devices.strip().split(","))) - elif ascend_visible_devices in VALID_DEVICE_ID_LIST: - device_list = [int(ascend_visible_devices.strip())] - else: - raise ValueError("invalid env variable ascend_visible_devices.") - except ValueError as error: - raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. " - "Please refer to the document https://www.hiascend.com/document/detail/zh/" - "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " - "the correct configuration method.") from error - except IndexError as error: - raise IndexError( - f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ - from error - - chief_device = os.getenv("CM_CHIEF_DEVICE") - rank_size = os.getenv("CM_WORKER_SIZE") - try: - rank_size = int(rank_size) - self._rank_to_device_dict[0] = int(chief_device) - device_list.pop(int(chief_device)) - except IndexError as err: - raise IndexError( - f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {device_list}.") from err - except ValueError as err: - raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err - + if not ascend_visible_devices: + raise ValueError("env variable ascend_visible_devices is null.") + if "-" in ascend_visible_devices: + rank_start = int(ascend_visible_devices.strip().split("-")[0]) + device_list = [i for i in range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]))] + elif "," in ascend_visible_devices: + device_list = list(map(int, ascend_visible_devices.strip().split(","))) + elif ascend_visible_devices in ["0", "1", "2", "3", "4", "5", "6", "7"]: + device_list = [int(ascend_visible_devices.strip())] + else: + raise ValueError("invalid env variable ascend_visible_devices.") + rank_size = int(os.getenv("CM_WORKER_SIZE")) + self._rank_to_device_dict[0] = int(os.getenv("CM_CHIEF_DEVICE")) + device_list.pop(int(os.getenv("CM_CHIEF_DEVICE"))) if rank_size: - local_rank_size = len(device_list) - for device_index in range(local_rank_size): + local_rank_size = rank_size if rank_size < 8 else 8 + for device_index in range(local_rank_size - 1): device_id = mxrec_pybind.get_logic_id(int(device_list[device_index])) if device_id > 16: raise ValueError(f"get logic id from physic id fail.") self._rank_to_device_dict[device_index + 1] = device_id + else: + raise ValueError("get CM_WORKER_SIZE failed.") def insert_training_mode_channel_id(self, is_training): if is_training not in self._training_mode_channel_dict: diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index 8b5ccef0..40c99f37 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -2,9 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei 
Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging import os - +import logging import tensorflow as tf from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH @@ -12,18 +11,11 @@ from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH def import_host_pipeline_ops(): host_pipeline_ops_lib_path = os.getenv(HOST_PIPELINE_OPS_LIB_PATH) - if host_pipeline_ops_lib_path and os.path.exists(host_pipeline_ops_lib_path): + if host_pipeline_ops_lib_path: logging.debug(f"Using the HOST_PIPELINE_OPS_LIB_PATH '{host_pipeline_ops_lib_path}' to get ops lib.") return tf.load_op_library(host_pipeline_ops_lib_path) - elif os.path.exists( - os.path.join(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), - 'mx_rec/libasc/libasc_ops.so')): - default_so_path = os.path.join( - os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), - 'mx_rec/libasc/libasc_ops.so') - logging.debug(f"Using the DEFAULT PATH '{default_so_path}' to get ops lib.") - return tf.load_op_library(default_so_path) else: - raise ValueError("Invalid host pipeline ops lib path. Please check if libasc_ops.so exists or corrected " - "configured") - + mx_rec_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) + so_path = os.path.join(mx_rec_dir, 'mx_rec/libasc/libasc_ops.so') + logging.debug(f"Using the DEFAULT PATH '{so_path}' to get ops lib.") + return tf.load_op_library(so_path) diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py deleted file mode 100644 index 8f75c6b6..00000000 --- a/mx_rec/validator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py deleted file mode 100644 index 285461db..00000000 --- a/mx_rec/validator/validator.py +++ /dev/null @@ -1,300 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. - - -import os -from typing import Callable, Any -from typing import List, Optional, Tuple - -from mx_rec.util.constants import MIN_SIZE -from mx_rec.util.constants import MAX_SIZE -from mx_rec.util.constants import MAX_DEVICE_NUM -from mx_rec.util.constants import MAX_RANK_SIZE -from mx_rec.util.constants import MIN_DEVICE_NUM -from mx_rec.util.constants import MIN_RANK_SIZE - - -class Validator: - """ - A validator to check the input parameters - """ - - def __init__(self, value, msg="value is invalid"): - """ - :param value: the value for validation - :param msg: default error msg - """ - self.value = value - self.msg = msg - self.checkers = [] - self.is_valid_state = None - - def register_checker(self, checker: Callable[[Any], bool], msg: str = None): - self.checkers.append((checker, msg if msg else self.msg)) - - def check(self): - if self.is_valid_state is None: - self.is_valid_state = True - for checker, msg in self.checkers: - self.is_valid_state &= checker(self.value) - if not self.is_valid_state: - self.msg = msg - break - if self.is_valid_state: - self.msg = None - return self - - def is_valid(self): - if self.is_valid_state is None: - self.check() - return self.is_valid_state - - def get_value(self, default=None): - return self.value if self.is_valid() else default - - -class ClassValidator(Validator): - """ - Check class validator. 
- """ - - def __init__(self, value, classes): - super().__init__(value) - self.classes = classes - - def check_isinstance(self): - """Check arg isinstance of classes""" - self.register_checker(lambda path: isinstance(self.value, self.classes), "Invalid parameter type") - return self - - -class StringValidator(Validator): - """ - String type validator. - """ - - def __init__(self, value, max_len=None, min_len=0): - super().__init__(value) - self.max_len = max_len - self.min_len = min_len - self.register_checker(lambda x: isinstance(x, str), "type is not str") - - def check_string_length(self): - if self.min_len is not None: - self.register_checker(lambda x: len(x) >= self.min_len, f"length is less than {self.min_len}") - if self.max_len is not None: - self.register_checker(lambda x: len(x) <= self.max_len, f"length is bigger than {self.max_len}") - return self - - def check_not_contain_black_element(self, element): - self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1) - return self - - def can_be_transformed2int(self, min_value: int = None, max_value: int = None): - if min_value is None: - min_value = MIN_RANK_SIZE - if max_value is None: - max_value = MAX_RANK_SIZE - - can_transformed = self.value.isdigit() - try: - if can_transformed and (min_value > int(self.value) or max_value < int(self.value)): - can_transformed = False - except ValueError: - can_transformed = False - finally: - if self.is_valid_state is not None: - self.is_valid_state &= can_transformed - else: - self.is_valid_state = can_transformed - return self - - -class IntValidator(Validator): - """ - Int type validator - """ - - def __init__(self, value: int, min_value: int = None, max_value: int = None): - super().__init__(value) - self.min_value = min_value - self.max_value = max_value - self.register_checker(lambda x: isinstance(x, int), "type is not int") - - def check_value(self): - if self.min_value is not None: - self.register_checker(lambda x: x >= self.min_value, f"value is less than {self.min_value}") - if self.max_value is not None: - self.register_checker(lambda x: x <= self.max_value, f"value is bigger than {self.max_value}") - return self - - -class RankSizeValidator(IntValidator): - """ - Distributed training job size validator - """ - - def check_rank_size_valid(self): - super().__init__(self.value) - self.register_checker(lambda x: MIN_RANK_SIZE <= self.value <= MAX_RANK_SIZE, - "Invalid rank size") - return self - - def check_device_num_valid(self): - super().__init__(self.value) - self.register_checker(lambda x: MIN_DEVICE_NUM <= self.value <= MAX_DEVICE_NUM, - "Invalid device num") - return self - - -class DirectoryValidator(StringValidator): - def __init__(self, value, max_len=None, min_len=1): - """ - @param value: the path, should not be emtpy string, should not contain double dot(../) - """ - super().__init__(value, max_len, min_len) - self.register_checker(lambda x: isinstance(x, str), "type is not str") - - @staticmethod - def remove_prefix(string: Optional[str], prefix: Optional[str]) -> Tuple[bool, Optional[str]]: - if string is None or prefix is None or len(string) < len(prefix): - return False, string - if string.startswith(prefix): - return True, string[len(prefix):] - else: - return False, string - - @staticmethod - def check_is_children_path(path_: str, target_: str): - if not target_: - return False - - try: - realpath_ = os.path.realpath(path_) - except (TypeError, ValueError, OSError): - return False - - try: - realpath_target = 
os.path.realpath(target_) - except (TypeError, ValueError, OSError): - return False - - is_prefix, rest_part = DirectoryValidator.remove_prefix(realpath_target, realpath_) - - if rest_part.startswith(os.path.sep): - rest_part = rest_part.lstrip(os.path.sep) - if is_prefix: - joint_path = os.path.join(realpath_, rest_part) - return os.path.realpath(joint_path) == realpath_target - else: - return False - - @staticmethod - def __check_with_sensitive_words(path: str, words: List): - _, name = os.path.split(path) - if name: - return not any(map(lambda x: x in path, words)) - else: - return True - - def check_is_not_none(self): - self.register_checker(lambda path: self.value is not None and len(self.value) > 0, - "Invalid directory parameter") - return self - - def check_not_soft_link(self): - self.register_checker(lambda path: os.path.realpath(self.value) == os.path.normpath(self.value), - "soft link or relative path should not be in the path parameter") - return self - - def path_should_exist(self, is_file=True, msg=None): - self.register_checker(lambda path: os.path.exists(self.value), - msg if msg else "path parameter does not exist") - if is_file: - self.register_checker(lambda path: os.path.isfile(self.value), - msg if msg else "path parameter is not a file") - return self - - def path_should_not_exist(self): - self.register_checker(lambda path: not os.path.exists(self.value), "path parameter does not exist") - return self - - def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None): - if lst is None: - lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64"] - if len(lst) == 0: - return self - if msg is None: - msg = "path should is in blacklist" - if exact_compare: - self.register_checker(lambda path: path not in [os.path.realpath(each) for each in lst], msg) - else: - self.register_checker( - lambda path: not any([DirectoryValidator.check_is_children_path(each, path) for each in lst]), msg - ) - return self - - def should_not_contains_sensitive_words(self, words: List = None, msg=None): - if words is None: - words = ["Key", "password", "privatekey"] - self.register_checker(lambda path: DirectoryValidator.__check_with_sensitive_words(path, words), msg) - return self - - -class FileValidator(StringValidator): - def __init__(self, value): - """ - @param value: the file path, should not be emtpy string, should not contain double dot(../) - """ - super().__init__(value) - self.register_checker(lambda x: isinstance(x, str), "type is not str") - - def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE): - self.register_checker(lambda path: min_size < os.path.getsize(self.value) <= max_size, - "file size is invalid") - return self - - def check_not_soft_link(self): - self.register_checker(lambda path: os.path.realpath(self.value) == self.value, - "soft link or relative path should not be in the path parameter") - return self - - def check_user_group(self): - process_uid = os.geteuid() - process_gid = os.getegid() - stat_info = os.stat(self.value) - file_uid = stat_info.st_uid - file_gid = stat_info.st_gid - self.register_checker( - lambda path: process_uid == file_uid or process_gid == file_gid, "Invalid log file user or group.") - return self - - -class RankInfoValidator: - def check_visible_devices(self): - visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - device_res = StringValidator(visible_devices).check().is_valid() - if not device_res: - raise TypeError("env variable ascend_visible_devices is null, please config ASCEND_VISIBLE_DEVICES 
in " - "docker container start.") - - rank_size = os.getenv("CM_WORKER_SIZE") - rank_size_res = StringValidator(rank_size).check().is_valid().is_valid() - if not rank_size_res: - raise TypeError("env variable CM_WORKER_SIZE is null, please config CM_WORKER_SIZE. For example, " - "CM_WORKER_SIZE=1") - - try: - rank_size_value = int(rank_size) - res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid().is_valid() - if not res and rank_size_value not in [1, 2, 4, 8, 16]: - raise ValueError("Invalid rank size, rank size must between 0 and 15 in recommendation training.") - except ValueError as err: - raise ValueError("Invalid rank size, rank size is a valid integer.") from err - - chief_device = os.getenv("CM_CHIEF_DEVICE") - chief_device_res = StringValidator(chief_device).check().is_valid() - if not chief_device_res: - raise TypeError("env variable CM_CHIEF_DEVICE is null, please config CM_CHIEF_DEVICE. For example, " - "CM_CHIEF_DEVICE=0") -- Gitee From 4fdb642ab7161f4d8431894d52816e4a43439ee3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 30 May 2023 10:11:30 +0800 Subject: [PATCH 090/551] Match-id-9cd4a71d603363ad9bb2d88f3098abbd5df02c90 --- mx_rec/saver/patch.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 9bc25da9..72fe950f 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -143,8 +143,14 @@ def update_checkpoint_state(self, **kwargs): self._MaybeDeleteOldCheckpoints(meta_graph_suffix=suffix_meta_graph) -def write_meta_graph_task(self, checkpoint_file, suffix_meta_graph, sess, strip_default_attrs, save_debug_info): - meta_graph_name = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix=suffix_meta_graph) +def write_meta_graph_task(self, **kwargs): + checkpoint_file = kwargs.get("checkpoint_file") + meta_graph_suffix = kwargs.get("meta_graph_suffix") + sess = kwargs.get("sess") + strip_default_attrs = kwargs.get("strip_default_attrs") + save_debug_info = kwargs.get("save_debug_info") + + meta_graph_name = checkpoint_management.meta_graph_filename(checkpoint_file, meta_graph_suffix=meta_graph_suffix) if not context.executing_eagerly(): with sess.graph.as_default(): self.export_meta_graph(meta_graph_name, strip_default_attrs=strip_default_attrs, @@ -190,7 +196,8 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra latest_filename=latest_filename, meta_graph_suffix=meta_graph_suffix, save_path=save_path) if write_meta_graph: - write_meta_graph_task(self, checkpoint_file, meta_graph_suffix, sess, strip_default_attrs, save_debug_info) + write_meta_graph_task(self, checkpoint_file=checkpoint_file, meta_graph_suffix=meta_graph_suffix, sess=sess, + strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) return model_checkpoint_path -- Gitee From 647a48353b4f724fc31cbdcaa913a637d4ad402f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 30 May 2023 10:15:11 +0800 Subject: [PATCH 091/551] Match-id-49d341a65b63c0a203369fc97c3f82226ce56635 --- .../cmake/util/makeself/COPYING | 339 -------- .../cmake/util/makeself/README.md | 246 ------ .../cmake/util/makeself/VERSION | 1 - .../cmake/util/makeself/make-release.sh | 9 - .../cmake/util/makeself/makeself-header.sh | 660 -------------- .../cmake/util/makeself/makeself.1 | 110 --- .../cmake/util/makeself/makeself.lsm | 16 - .../cmake/util/makeself/makeself.sh | 822 ------------------ .../cmake/util/makeself/run-tests.sh | 8 - mx_rec/validator/__init__.py | 3 + 
10 files changed, 3 insertions(+), 2211 deletions(-) delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/COPYING delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/README.md delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/VERSION delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh create mode 100644 mx_rec/validator/__init__.py diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING b/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING deleted file mode 100644 index d159169d..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/COPYING +++ /dev/null @@ -1,339 +0,0 @@ - GNU GENERAL PUBLIC LICENSE - Version 2, June 1991 - - Copyright (C) 1989, 1991 Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The licenses for most software are designed to take away your -freedom to share and change it. By contrast, the GNU General Public -License is intended to guarantee your freedom to share and change free -software--to make sure the software is free for all its users. This -General Public License applies to most of the Free Software -Foundation's software and to any other program whose authors commit to -using it. (Some other Free Software Foundation software is covered by -the GNU Lesser General Public License instead.) You can apply it to -your programs, too. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -this service if you wish), that you receive source code or can get it -if you want it, that you can change the software or use pieces of it -in new free programs; and that you know you can do these things. - - To protect your rights, we need to make restrictions that forbid -anyone to deny you these rights or to ask you to surrender the rights. -These restrictions translate to certain responsibilities for you if you -distribute copies of the software, or if you modify it. - - For example, if you distribute copies of such a program, whether -gratis or for a fee, you must give the recipients all the rights that -you have. You must make sure that they, too, receive or can get the -source code. And you must show them these terms so they know their -rights. - - We protect your rights with two steps: (1) copyright the software, and -(2) offer you this license which gives you legal permission to copy, -distribute and/or modify the software. - - Also, for each author's protection and ours, we want to make certain -that everyone understands that there is no warranty for this free -software. If the software is modified by someone else and passed on, we -want its recipients to know that what they have is not the original, so -that any problems introduced by others will not reflect on the original -authors' reputations. - - Finally, any free program is threatened constantly by software -patents. 
We wish to avoid the danger that redistributors of a free -program will individually obtain patent licenses, in effect making the -program proprietary. To prevent this, we have made it clear that any -patent must be licensed for everyone's free use or not licensed at all. - - The precise terms and conditions for copying, distribution and -modification follow. - - GNU GENERAL PUBLIC LICENSE - TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION - - 0. This License applies to any program or other work which contains -a notice placed by the copyright holder saying it may be distributed -under the terms of this General Public License. The "Program", below, -refers to any such program or work, and a "work based on the Program" -means either the Program or any derivative work under copyright law: -that is to say, a work containing the Program or a portion of it, -either verbatim or with modifications and/or translated into another -language. (Hereinafter, translation is included without limitation in -the term "modification".) Each licensee is addressed as "you". - -Activities other than copying, distribution and modification are not -covered by this License; they are outside its scope. The act of -running the Program is not restricted, and the output from the Program -is covered only if its contents constitute a work based on the -Program (independent of having been made by running the Program). -Whether that is true depends on what the Program does. - - 1. You may copy and distribute verbatim copies of the Program's -source code as you receive it, in any medium, provided that you -conspicuously and appropriately publish on each copy an appropriate -copyright notice and disclaimer of warranty; keep intact all the -notices that refer to this License and to the absence of any warranty; -and give any other recipients of the Program a copy of this License -along with the Program. - -You may charge a fee for the physical act of transferring a copy, and -you may at your option offer warranty protection in exchange for a fee. - - 2. You may modify your copy or copies of the Program or any portion -of it, thus forming a work based on the Program, and copy and -distribute such modifications or work under the terms of Section 1 -above, provided that you also meet all of these conditions: - - a) You must cause the modified files to carry prominent notices - stating that you changed the files and the date of any change. - - b) You must cause any work that you distribute or publish, that in - whole or in part contains or is derived from the Program or any - part thereof, to be licensed as a whole at no charge to all third - parties under the terms of this License. - - c) If the modified program normally reads commands interactively - when run, you must cause it, when started running for such - interactive use in the most ordinary way, to print or display an - announcement including an appropriate copyright notice and a - notice that there is no warranty (or else, saying that you provide - a warranty) and that users may redistribute the program under - these conditions, and telling the user how to view a copy of this - License. (Exception: if the Program itself is interactive but - does not normally print such an announcement, your work based on - the Program is not required to print an announcement.) - -These requirements apply to the modified work as a whole. 
If -identifiable sections of that work are not derived from the Program, -and can be reasonably considered independent and separate works in -themselves, then this License, and its terms, do not apply to those -sections when you distribute them as separate works. But when you -distribute the same sections as part of a whole which is a work based -on the Program, the distribution of the whole must be on the terms of -this License, whose permissions for other licensees extend to the -entire whole, and thus to each and every part regardless of who wrote it. - -Thus, it is not the intent of this section to claim rights or contest -your rights to work written entirely by you; rather, the intent is to -exercise the right to control the distribution of derivative or -collective works based on the Program. - -In addition, mere aggregation of another work not based on the Program -with the Program (or with a work based on the Program) on a volume of -a storage or distribution medium does not bring the other work under -the scope of this License. - - 3. You may copy and distribute the Program (or a work based on it, -under Section 2) in object code or executable form under the terms of -Sections 1 and 2 above provided that you also do one of the following: - - a) Accompany it with the complete corresponding machine-readable - source code, which must be distributed under the terms of Sections - 1 and 2 above on a medium customarily used for software interchange; or, - - b) Accompany it with a written offer, valid for at least three - years, to give any third party, for a charge no more than your - cost of physically performing source distribution, a complete - machine-readable copy of the corresponding source code, to be - distributed under the terms of Sections 1 and 2 above on a medium - customarily used for software interchange; or, - - c) Accompany it with the information you received as to the offer - to distribute corresponding source code. (This alternative is - allowed only for noncommercial distribution and only if you - received the program in object code or executable form with such - an offer, in accord with Subsection b above.) - -The source code for a work means the preferred form of the work for -making modifications to it. For an executable work, complete source -code means all the source code for all modules it contains, plus any -associated interface definition files, plus the scripts used to -control compilation and installation of the executable. However, as a -special exception, the source code distributed need not include -anything that is normally distributed (in either source or binary -form) with the major components (compiler, kernel, and so on) of the -operating system on which the executable runs, unless that component -itself accompanies the executable. - -If distribution of executable or object code is made by offering -access to copy from a designated place, then offering equivalent -access to copy the source code from the same place counts as -distribution of the source code, even though third parties are not -compelled to copy the source along with the object code. - - 4. You may not copy, modify, sublicense, or distribute the Program -except as expressly provided under this License. Any attempt -otherwise to copy, modify, sublicense or distribute the Program is -void, and will automatically terminate your rights under this License. 
-However, parties who have received copies, or rights, from you under -this License will not have their licenses terminated so long as such -parties remain in full compliance. - - 5. You are not required to accept this License, since you have not -signed it. However, nothing else grants you permission to modify or -distribute the Program or its derivative works. These actions are -prohibited by law if you do not accept this License. Therefore, by -modifying or distributing the Program (or any work based on the -Program), you indicate your acceptance of this License to do so, and -all its terms and conditions for copying, distributing or modifying -the Program or works based on it. - - 6. Each time you redistribute the Program (or any work based on the -Program), the recipient automatically receives a license from the -original licensor to copy, distribute or modify the Program subject to -these terms and conditions. You may not impose any further -restrictions on the recipients' exercise of the rights granted herein. -You are not responsible for enforcing compliance by third parties to -this License. - - 7. If, as a consequence of a court judgment or allegation of patent -infringement or for any other reason (not limited to patent issues), -conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot -distribute so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you -may not distribute the Program at all. For example, if a patent -license would not permit royalty-free redistribution of the Program by -all those who receive copies directly or indirectly through you, then -the only way you could satisfy both it and this License would be to -refrain entirely from distribution of the Program. - -If any portion of this section is held invalid or unenforceable under -any particular circumstance, the balance of the section is intended to -apply and the section as a whole is intended to apply in other -circumstances. - -It is not the purpose of this section to induce you to infringe any -patents or other property right claims or to contest validity of any -such claims; this section has the sole purpose of protecting the -integrity of the free software distribution system, which is -implemented by public license practices. Many people have made -generous contributions to the wide range of software distributed -through that system in reliance on consistent application of that -system; it is up to the author/donor to decide if he or she is willing -to distribute software through any other system and a licensee cannot -impose that choice. - -This section is intended to make thoroughly clear what is believed to -be a consequence of the rest of this License. - - 8. If the distribution and/or use of the Program is restricted in -certain countries either by patents or by copyrighted interfaces, the -original copyright holder who places the Program under this License -may add an explicit geographical distribution limitation excluding -those countries, so that distribution is permitted only in or among -countries not thus excluded. In such case, this License incorporates -the limitation as if written in the body of this License. - - 9. The Free Software Foundation may publish revised and/or new versions -of the General Public License from time to time. 
Such new versions will -be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - -Each version is given a distinguishing version number. If the Program -specifies a version number of this License which applies to it and "any -later version", you have the option of following the terms and conditions -either of that version or of any later version published by the Free -Software Foundation. If the Program does not specify a version number of -this License, you may choose any version ever published by the Free Software -Foundation. - - 10. If you wish to incorporate parts of the Program into other free -programs whose distribution conditions are different, write to the author -to ask for permission. For software which is copyrighted by the Free -Software Foundation, write to the Free Software Foundation; we sometimes -make exceptions for this. Our decision will be guided by the two goals -of preserving the free status of all derivatives of our free software and -of promoting the sharing and reuse of software generally. - - NO WARRANTY - - 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY -FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN -OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES -PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED -OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS -TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE -PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, -REPAIR OR CORRECTION. - - 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR -REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, -INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING -OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED -TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY -YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER -PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE -POSSIBILITY OF SUCH DAMAGES. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -convey the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - - Copyright (C) - - This program is free software; you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation; either version 2 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. 
- - You should have received a copy of the GNU General Public License along - with this program; if not, write to the Free Software Foundation, Inc., - 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - -Also add information on how to contact you by electronic and paper mail. - -If the program is interactive, make it output a short notice like this -when it starts in an interactive mode: - - Gnomovision version 69, Copyright (C) year name of author - Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. - This is free software, and you are welcome to redistribute it - under certain conditions; type `show c' for details. - -The hypothetical commands `show w' and `show c' should show the appropriate -parts of the General Public License. Of course, the commands you use may -be called something other than `show w' and `show c'; they could even be -mouse-clicks or menu items--whatever suits your program. - -You should also get your employer (if you work as a programmer) or your -school, if any, to sign a "copyright disclaimer" for the program, if -necessary. Here is a sample; alter the names: - - Yoyodyne, Inc., hereby disclaims all copyright interest in the program - `Gnomovision' (which makes passes at compilers) written by James Hacker. - - , 1 April 1989 - Ty Coon, President of Vice - -This General Public License does not permit incorporating your program into -proprietary programs. If your program is a subroutine library, you may -consider it more useful to permit linking proprietary applications with the -library. If this is what you want to do, use the GNU Lesser General -Public License instead of this License. diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/README.md b/cust_op/cust_op_by_addr/cmake/util/makeself/README.md deleted file mode 100644 index b41f0168..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/README.md +++ /dev/null @@ -1,246 +0,0 @@ -[![License: GPL v2](https://img.shields.io/badge/License-GPL%20v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html) -![Build Status](https://github.com/megastep/makeself/workflows/CI/badge.svg) - -# makeself - Make self-extractable archives on Unix - -[makeself.sh][1] is a small shell script that generates a self-extractable -compressed tar archive from a directory. The resulting file appears as a shell script -(many of those have a **.run** suffix), and can be launched as is. The archive -will then uncompress itself to a temporary directory and an optional arbitrary -command will be executed (for example an installation script). This is pretty -similar to archives generated with WinZip Self-Extractor in the Windows world. -Makeself archives also include checksums for integrity self-validation (CRC -and/or MD5/SHA256 checksums). - -The makeself.sh script itself is used only to create the archives from a -directory of files. The resultant archive is actually a compressed (using -gzip, bzip2, or compress) TAR archive, with a small shell script stub at the -beginning. This small stub performs all the steps of extracting the files, -running the embedded command, and removing the temporary files when done. -All the user has to do to install the software contained in such an -archive is to "run" the archive, i.e **sh nice-software.run**. 
I recommend -using the ".run" (which was introduced by some Makeself archives released by -Loki Software) or ".sh" suffix for such archives not to confuse the users, -so that they will know they are actually shell scripts (with quite a lot of binary data -attached to them though!). - -I am trying to keep the code of this script as portable as possible, i.e it is -not relying on any bash-specific features and only calls commands that are -installed on any functioning UNIX-compatible system. This script as well as -the archives it generates should run on any Unix flavor, with any compatible -Bourne shell, provided of course that the compression programs are available. - -As of version 2.1, Makeself has been rewritten and tested on the following -platforms : - - * Linux (all distributions) - * Sun Solaris (8 and above) - * HP-UX (tested on 11.0 and 11i on HPPA RISC) - * SCO OpenUnix and OpenServer - * IBM AIX 5.1L - * macOS (Darwin) - * SGI IRIX 6.5 - * FreeBSD - * UnicOS / Cray - * Cygwin (Windows) - -If you successfully run Makeself and/or archives created with it on another -system, then please [let me know][2]! - -Examples of publicly available archives made using makeself are : - - * Game patches and installers for [Id Software][3] games like Quake 3 for Linux or Return To Castle Wolfenstein ; - * All game patches released by [Loki Software][4] for the Linux version of popular games ; - * The [nVidia drivers][5] for Linux - * The installer for the Linux version of [Google Earth][6] - * The [VirtualBox][7] installers for Linux - * The [Makeself][1] distribution itself ;-) - * and countless others... - -**Important note for Apache users:** By default, most Web servers will think that Makeself archives are regular text files and thus they may show up as text in a Web browser. The correct way to prevent this is to add a MIME type for this file format, like so (in httpd.conf) : - -`AddType application/x-makeself .run` - -**Important note for certain GNU/Linux distributions:** Archives created with Makeself prior to v2.1.2 were using an old syntax for the _head_ and _tail_ Unix commands that is being progressively obsoleted in their GNU forms. Therefore you may have problems uncompressing some of these archives. A workaround for this is to set the environment variable $_POSIX2_VERSION to enable the old syntax, i.e. : - -`export _POSIX2_VERSION=199209` - -## Usage - -The syntax of makeself is the following: - -``` -makeself.sh [args] archive_dir file_name label startup_script [script_args] -``` - - * _args_ are optional options for Makeself. The available ones are : - - * **`--version`** : Prints the version number on stdout, then exits immediately - * **`--gzip`** : Use gzip for compression (the default on platforms on which gzip is commonly available, like Linux) - * **`--bzip2`** : Use bzip2 instead of gzip for better compression. The bzip2 command must be available in the command path. It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. - * **`--pbzip2`** : Use pbzip2 instead of gzip for better and faster compression on machines having multiple CPUs. The pbzip2 command must be available in the command path. It is recommended that the archive prefix be set to something like '.bz2.run', so that potential users know that they'll need bzip2 to extract it. - * **`--xz`** : Use xz instead of gzip for better compression. The xz command must be available in the command path. 
It is recommended that the archive prefix be set to something like '.xz.run' for the archive, so that potential users know that they'll need xz to extract it. - * **`--lzo`** : Use lzop instead of gzip for better compression. The lzop command must be available in the command path. It is recommended that the archive prefix be set to something like `.lzo.run` for the archive, so that potential users know that they'll need lzop to extract it. - * **`--lz4`** : Use lz4 instead of gzip for better compression. The lz4 command must be available in the command path. It is recommended that the archive prefix be set to something like '.lz4.run' for the archive, so that potential users know that they'll need lz4 to extract it. - * **`--zstd`** : Use zstd instead of gzip for better compression. The zstd command must be available in the command path. It is recommended that the archive prefix be set to something like '.zstd.run' for the archive, so that potential users know that they'll need zstd to extract it. - * **`--pigz`** : Use pigz for compression. - * **`--base64`** : Encode the archive to ASCII in Base64 format instead of compressing (base64 command required). - * **`--gpg-encrypt`** : Encrypt the archive using `gpg -ac -z $COMPRESS_LEVEL`. This will prompt for a password to encrypt with. Assumes that potential users have `gpg` installed. - * **`--ssl-encrypt`** : Encrypt the archive using `openssl aes-256-cbc -a -salt`. This will prompt for a password to encrypt with. Assumes that the potential users have the OpenSSL tools installed. - * **`--compress`** : Use the UNIX `compress` command to compress the data. This should be the default on all platforms that don't have gzip available. - * **`--nocomp`** : Do not use any compression for the archive, which will then be an uncompressed TAR. - * **`--complevel`** : Specify the compression level for gzip, bzip2, pbzip2, zstd, xz, lzo or lz4. (defaults to 9) - * **`--threads`** : Specify the number of threads to be used by compressors that support parallelization. Omit to use compressor's default. Most useful (and required) for opting into xz's threading, usually with `--threads=0` for all available cores. pbzip2 and pigz are parallel by default, and setting this value allows limiting the number of threads they use. - * **`--notemp`** : The generated archive will not extract the files to a temporary directory, but in a new directory created in the current directory. This is better to distribute software packages that may extract and compile by themselves (i.e. launch the compilation through the embedded script). - * **`--current`** : Files will be extracted to the current directory, instead of in a subdirectory. This option implies `--notemp` above. - * **`--follow`** : Follow the symbolic links inside of the archive directory, i.e. store the files that are being pointed to instead of the links themselves. - * **`--append`** _(new in 2.1.x)_: Append data to an existing archive, instead of creating a new one. In this mode, the settings from the original archive are reused (compression type, label, embedded script), and thus don't need to be specified again on the command line. - * **`--header`** : Makeself uses a separate file to store the header stub, called `makeself-header.sh`. By default, it is assumed that it is stored in the same location as makeself.sh. This option can be used to specify its actual location if it is stored someplace else. - * **`--cleanup`** : Specify a script that is run when execution is interrupted or finishes successfully. 
The script is executed with the same environment and initial `script_args` as `startup_script`. - * **`--copy`** : Upon extraction, the archive will first extract itself to a temporary directory. The main application of this is to allow self-contained installers stored in a Makeself archive on a CD, when the installer program will later need to unmount the CD and allow a new one to be inserted. This prevents "Filesystem busy" errors for installers that span multiple CDs. - * **`--nox11`** : Disable the automatic spawning of a new terminal in X11. - * **`--nowait`** : When executed from a new X11 terminal, disable the user prompt at the end of the script execution. - * **`--nomd5`** and **`--nocrc`** : Disable the creation of a MD5 / CRC checksum for the archive. This speeds up the extraction process if integrity checking is not necessary. - * **`--sha256`** : Adds a SHA256 checksum for the archive. This is in addition to the MD5 / CRC checksums unless `--nomd5` is also used. - * **`--lsm` _file_** : Provide and LSM file to makeself, that will be embedded in the generated archive. LSM files are describing a software package in a way that is easily parseable. The LSM entry can then be later retrieved using the `--lsm` argument to the archive. An example of a LSM file is provided with Makeself. - * **`--tar-format opt`** : Specify the tar archive format (default is ustar); you may use any value accepted by your tar command (such as posix, v7, etc). - * **`--tar-extra opt`** : Append more options to the tar command line. - - For instance, in order to exclude the `.git` directory from the packaged archive directory using the GNU `tar`, one can use `makeself.sh --tar-extra "--exclude=.git" ...` - - * **`--keep-umask`** : Keep the umask set to shell default, rather than overriding when executing self-extracting archive. - * **`--packaging-date date`** : Use provided string as the packaging date instead of the current date. - * **`--license`** : Append a license file. - * **`--nooverwrite`** : Do not extract the archive if the specified target directory already exists. - * **`--help-header file`** : Add a header to the archive's `--help` output. - * `archive_dir` is the name of the directory that contains the files to be archived - * `file_name` is the name of the archive to be created - * `label` is an arbitrary text string describing the package. It will be displayed while extracting the files. - * `startup_script` is the command to be executed _from within_ the directory of extracted files. Thus, if you wish to execute a program contained in this directory, you must prefix your command with `./`. For example, `./program` will be fine. The `script_args` are additional arguments for this command. - -Here is an example, assuming the user has a package image stored in a **/home/joe/mysoft**, and he wants to generate a self-extracting package named -**mysoft.sh**, which will launch the "setup" script initially stored in /home/joe/mysoft : - -`makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup -` - -Here is also how I created the [makeself.run][9] archive which contains the Makeself distribution : - -`makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself" ` - -Archives generated with Makeself can be passed the following arguments: - - * **`--keep`** : Prevent the files to be extracted in a temporary directory that will be removed after the embedded script's execution. 
The files will then be extracted in the current working directory and will stay here until you remove them. - * **`--verbose`** : Will prompt the user before executing the embedded command - * **`--target dir`** : Allows to extract the archive in an arbitrary place. - * **`--nox11`** : Do not spawn a X11 terminal. - * **`--confirm`** : Prompt the user for confirmation before running the embedded command. - * **`--info`** : Print out general information about the archive (does not extract). - * **`--lsm`** : Print out the LSM entry, if it is present. - * **`--list`** : List the files in the archive. - * **`--check`** : Check the archive for integrity using the embedded checksums. Does not extract the archive. - * **`--nochown`** : By default, a `chown -R` command is run on the target directory after extraction, so that all files belong to the current user. This is mostly needed if you are running as root, as tar will then try to recreate the initial user ownerships. You may disable this behavior with this flag. - * **`--tar`** : Run the tar command on the contents of the archive, using the following arguments as parameter for the command. - * **`--noexec`** : Do not run the embedded script after extraction. - * **`--noexec-cleanup`** : Do not run the embedded cleanup script. - * **`--nodiskspace`** : Do not check for available disk space before attempting to extract. - * **`--cleanup-args`** : Specify arguments to be passed to the cleanup script. Wrap value in quotes to specify multiple arguments. - -Any subsequent arguments to the archive will be passed as additional arguments to the embedded command. You must explicitly use the `--` special command-line construct before any such options to make sure that Makeself will not try to interpret them. - -## Startup Script - -The startup script must be a regular Shell script. - -Within the startup script, you can use the `$USER_PWD` variable to get the path of the folder from which the self-extracting script is executed. This is especially useful to access files that are located in the same folder as the script, as shown in the example below. - -`my-self-extracting-script.sh --fooBarFileParameter foo.bar` - -## Building and Testing - -Clone the git repo and execute `git submodule update --init --recursive` to obtain all submodules. - -* To make a release: `make` -* To run all tests: `make test` - -## Maven Usage - -Makeself is now supported by the following maven plugin [makeself-maven-plugin](https://github.com/hazendaz/makeself-maven-plugin). Please refer to project for usage and report any bugs in regards to maven plugin on that project. - -## License - -Makeself itself is covered by the [GNU General Public License][8] (GPL) version 2 and above. Archives generated by Makeself don't have to be placed under this license (although I encourage it ;-)), since the archive itself is merely data for Makeself. - -## Contributing - -I will gladly consider merging your pull requests on the [GitHub][10] repository. However, please keep the following in mind: - - * One of the main purposes of Makeself is portability. Do not submit patches that will break supported platforms. The more platform-agnostic, the better. - * Please explain clearly what the purpose of the patch is, and how you achieved it. - -## Download - -Get the latest official distribution [here][9] (version 2.4.2). - -The latest development version can be grabbed from [GitHub][10]. Feel free to submit any patches there through the fork and pull request process. 
-
-## Version history
-
- * **v1.0:** Initial public release
- * **v1.1:** The archive can be passed parameters that will be passed on to the embedded script, thanks to John C. Quillan
- * **v1.2:** Cosmetic updates, support for bzip2 compression and non-temporary archives. Many ideas thanks to Francois Petitjean.
- * **v1.3:** More patches from Bjarni R. Einarsson and Francois Petitjean: support for no compression (`--nocomp`), the script is no longer mandatory, automatic launch in an xterm, optional verbose output, and a -target archive option to indicate where to extract the files.
- * **v1.4:** Many patches from Francois Petitjean: improved UNIX compatibility, automatic integrity checking, support for LSM files to get info on the package at run time.
- * **v1.5.x:** A lot of bugfixes, and many other patches, including automatic verification through the use of checksums. Version 1.5.5 was the stable release for a long time, even though the Web page didn't get updated ;-). Makeself was also officially made a part of the [Loki Setup installer][11], and its source is being maintained as part of this package.
- * **v2.0:** Complete internal rewrite of Makeself. The command-line parsing was vastly improved, and the overall maintenance of the package was made much easier by separating the stub from makeself.sh. Makeself was also ported and tested on a variety of Unix platforms.
- * **v2.0.1:** First public release of the new 2.0 branch. Prior versions are officially obsoleted. This release introduced the `--copy` argument, added in response to a need from the [UT2K3][12] Linux installer.
- * **v2.1.0:** Big change: Makeself can now support multiple embedded tarballs, each stored separately with its own checksums. An existing archive can be updated with the `--append` flag. Checksums are also better managed, and the `--nochown` option for archives appeared.
- * **v2.1.1:** Fixes related to Unix compression (the compress command). Some Linux distributions made the insane choice to make it unavailable, even though gzip is capable of uncompressing these files; plus some more bugfixes in the extraction and checksum code.
- * **v2.1.2:** Some bug fixes. Use head -n to avoid problems with POSIX conformance.
- * **v2.1.3:** Bug fixes with the command line when spawning terminals. Added `--tar`, `--noexec` for archives. Added `--nomd5` and `--nocrc` to avoid creating checksums in archives. The embedded script is now run through "eval". The `--info` output now includes the command used to create the archive. A man page was contributed by Bartosz Fenski.
- * **v2.1.4:** Fixed `--info` output. Generate a random directory name when extracting files to . to avoid problems. Better handling of errors with wrong permissions for the directory containing the files. Avoid some race conditions. Unset the $CDPATH variable to avoid problems if it is set. Better handling of dot files in the archive directory.
- * **v2.1.5:** Made the md5sum detection consistent with the header code. Check for the presence of the archive directory. Added `--encrypt` for symmetric encryption through gpg (Eric Windisch). Added support for the digest command on Solaris 10 for MD5 checksums. Check for available disk space before extracting to the target directory (Andreas Schweitzer). Allow extraction to run asynchronously (patch by Peter Hatch). Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo).
- * **v2.1.6:** Replaced the one-dot-per-file progress with a realtime progress percentage and a spinning cursor. Added `--noprogress` to prevent showing the progress during the decompression. Added `--target dir` to allow extracting directly to a target directory. (Guy Baconniere)
- * **v2.2.0:** First major new release in years! Includes many bugfixes and user contributions. Please look at the [project page on GitHub][10] for all the details.
- * **v2.3.0:** Support for archive encryption via GPG or OpenSSL. Added LZO and LZ4 compression support. Options to set the packaging date and stop the umask from being overridden. Optionally ignore the check for available disk space when extracting. New option to check for root permissions before extracting.
- * **v2.3.1:** Various compatibility updates. Added unit tests for Travis CI in the GitHub repo. New `--tar-extra`, `--untar-extra`, `--gpg-extra`, `--gpg-asymmetric-encrypt-sign` options.
- * **v2.4.0:** Added optional support for SHA256 archive integrity checksums.
- * **v2.4.2:** New `--cleanup` and `--cleanup-args` arguments for cleanup scripts. Added threading support for supported compressors. Now supports zstd compression.
- * **v2.4.3:** Make explicit POSIX tar archives for increased compatibility.
- * **v2.4.4:** Fixed various compatibility issues (no longer use POSIX tar archives); GitHub Actions to check on Solaris and FreeBSD.
- * **v2.4.5:** Added the `--tar-format` option to set the tar archive format (default is ustar).
-
-## Links
-
- * Check out the ["Loki Setup"][11] installer, used to install many Linux games and other applications, and of which I am the co-author. Since the demise of Loki, I am now the official maintainer of the project, and it is now being hosted here on GitHub.
- * Bjarni R. Einarsson also wrote the **setup.sh** installer script, inspired by Makeself. [Check it out!][14]
-
-## Contact
-
-This script was written by [Stéphane Peter][15] (megastep at megastep.org). Any enhancements and suggestions are welcome.
-
-Contributions were included from John C. Quillan, Bjarni R. Einarsson,
-Francois Petitjean, Ryan C. Gordon, and many contributors on GitHub. If you think I forgot
-your name, don't hesitate to contact me.
-
-This project is now hosted on GitHub. Feel free to submit patches and bug reports on the [project page][10].
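-
-## Example: A Reproducible, Verifiable Archive
-
-To tie several of the creation options documented above together, here is one more hypothetical sketch; all paths, the label, and the date string are placeholders, and this combination is an untested illustration rather than a prescribed recipe:
-
-```sh
-# --sha256 embeds a SHA256 checksum in addition to the default MD5/CRC,
-# --packaging-date pins the embedded date for byte-for-byte reproducibility,
-# --tar-extra excludes the .git directory from the packaged tree (GNU tar).
-makeself.sh --sha256 \
-    --packaging-date "Mon Jan 3 00:00:00 UTC 2022" \
-    --tar-extra "--exclude=.git" \
-    /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup
-
-# The result can then be inspected and verified without extracting it:
-./mysoft.sh --info
-./mysoft.sh --check
-```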
- -* * * - -[Stephane Peter][2] - - [1]: http://makeself.io/ - [2]: mailto:megastep@megastep.org - [3]: http://www.idsoftware.com/ - [4]: http://www.lokigames.com/products/myth2/updates.php3 - [5]: http://www.nvidia.com/ - [6]: http://earth.google.com/ - [7]: http://www.virtualbox.org/ - [8]: http://www.gnu.org/copyleft/gpl.html - [9]: https://github.com/megastep/makeself/releases/download/release-2.4.5/makeself-2.4.5.run - [10]: https://github.com/megastep/makeself - [11]: https://github.com/megastep/loki_setup/ - [12]: http://www.unrealtournament2003.com/ - [13]: http://www.icculus.org/ - [14]: http://bre.klaki.net/programs/setup.sh/ - [15]: https://stephanepeter.com/ diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION b/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION deleted file mode 100644 index 59aa62c1..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/VERSION +++ /dev/null @@ -1 +0,0 @@ -2.4.5 diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh deleted file mode 100644 index b5692d49..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/make-release.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/sh -# -# Create a distributable archive of the current version of Makeself - -VER=`cat VERSION` -mkdir -p /tmp/makeself-$VER release -cp -pPR makeself* test README.md COPYING VERSION .gitmodules /tmp/makeself-$VER/ -./makeself.sh --notemp /tmp/makeself-$VER release/makeself-$VER.run "Makeself v$VER" echo "Makeself has extracted itself" - diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh deleted file mode 100644 index 94090314..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself-header.sh +++ /dev/null @@ -1,660 +0,0 @@ -cat << EOF > "$archname" -#!/bin/bash -# This script was generated using Makeself $MS_VERSION -# The license covering this archive and its contents, if any, is wholly independent of the Makeself license (GPL) -# 2022.3.19-Modified the MS_Help function and some options -# Huawei Technologies Co., Ltd. - -ORIG_UMASK=\`umask\` - -CRCsum="$CRCsum" -MD5="$MD5sum" -SHA="$SHAsum" -SIGNATURE="$Signature" -TMPROOT=\${TMPDIR:="\$HOME"} -if ! test -d "\$TMPROOT"; then - TMPROOT="\$PWD" -fi -export TMPDIR="\$TMPROOT" -USER_PWD="\$PWD" -if ! 
test -d "\$USER_PWD"; then - exit 1 -fi -export USER_PWD -ARCHIVE_DIR=\`dirname "\$0"\` -export ARCHIVE_DIR - -name_of_file="\$0 " -pwd_of_file="\$PWD" -label="$LABEL" -script="$SCRIPT" -scriptargs="$SCRIPTARGS" -cleanup_script="${CLEANUP_SCRIPT}" -licensetxt="$LICENSE" -helpheader='$HELPHEADER' -targetdir="$archdirname" -filesizes="$filesizes" -totalsize="$totalsize" -keep="$KEEP" -nooverwrite="$NOOVERWRITE" -quiet="n" -accept="n" -nodiskspace="n" -export_conf="$EXPORT_CONF" -decrypt_cmd="$DECRYPT_CMD" -skip="$SKIP" - -print_cmd_arg="" -if type printf > /dev/null; then - print_cmd="printf" -elif test -x /usr/ucb/echo; then - print_cmd="/usr/ucb/echo" -else - print_cmd="echo" -fi - -if test -d /usr/xpg4/bin; then - PATH=/usr/xpg4/bin:\$PATH - export PATH -fi - -if test -d /usr/sfw/bin; then - PATH=\$PATH:/usr/sfw/bin - export PATH -fi - -unset CDPATH - -MS_Printf() -{ - \$print_cmd \$print_cmd_arg "\$1" -} - -MS_PrintLicense() -{ - PAGER=\${PAGER:=more} - if test x"\$licensetxt" != x; then - PAGER_PATH=\`exec <&- 2>&-; which \$PAGER || command -v \$PAGER || type \$PAGER\` - if test -x "\$PAGER_PATH"; then - echo "\$licensetxt" | \$PAGER - else - echo "\$licensetxt" - fi - if test x"\$accept" != xy; then - while true - do - MS_Printf "Please type y to accept, n otherwise: " - read yn - if test x"\$yn" = xn; then - keep=n - eval \$finish; exit 1 - break; - elif test x"\$yn" = xy; then - break; - fi - done - fi - fi -} - -MS_diskspace() -{ - ( - df -kP "\$1" | tail -1 | awk '{ if (\$4 ~ /%/) {print \$3} else {print \$4} }' - ) -} - -MS_dd() -{ - blocks=\`expr \$3 / 1024\` - bytes=\`expr \$3 % 1024\` - # Test for ibs, obs and conv feature - if dd if=/dev/zero of=/dev/null count=1 ibs=512 obs=512 conv=sync 2> /dev/null; then - dd if="\$1" ibs=\$2 skip=1 obs=1024 conv=sync 2> /dev/null | \\ - { test \$blocks -gt 0 && dd ibs=1024 obs=1024 count=\$blocks ; \\ - test \$bytes -gt 0 && dd ibs=1 obs=1024 count=\$bytes ; } 2> /dev/null - else - dd if="\$1" bs=\$2 skip=1 2> /dev/null - fi -} - -MS_dd_Progress() -{ - if test x"\$noprogress" = xy; then - MS_dd "\$@" - return \$? - fi - file="\$1" - offset=\$2 - length=\$3 - pos=0 - bsize=4194304 - while test \$bsize -gt \$length; do - bsize=\`expr \$bsize / 4\` - done - blocks=\`expr \$length / \$bsize\` - bytes=\`expr \$length % \$bsize\` - ( - dd ibs=\$offset skip=1 2>/dev/null - pos=\`expr \$pos \+ \$bsize\` - MS_Printf " 0%% " 1>&2 - if test \$blocks -gt 0; then - while test \$pos -le \$length; do - dd bs=\$bsize count=1 2>/dev/null - pcent=\`expr \$length / 100\` - pcent=\`expr \$pos / \$pcent\` - if test \$pcent -lt 100; then - MS_Printf "\b\b\b\b\b\b\b" 1>&2 - if test \$pcent -lt 10; then - MS_Printf " \$pcent%% " 1>&2 - else - MS_Printf " \$pcent%% " 1>&2 - fi - fi - pos=\`expr \$pos \+ \$bsize\` - done - fi - if test \$bytes -gt 0; then - dd bs=\$bytes count=1 2>/dev/null - fi - MS_Printf "\b\b\b\b\b\b\b" 1>&2 - MS_Printf " 100%% " 1>&2 - ) < "\$file" -} - -MS_Help() -{ - cat << EOH >&2 -Usage: \$0 [options] -Options: - --help | -h Print this message - --info Print embedded info : title, default target directory, embedded script ... - --list Print the list of files in the archive - --check Checks integrity and version dependency of the archive - --quiet Quiet install mode, skip human-computer interactions - --nox11 Do not spawn an xterm - --noexec Do not run embedded script - --extract= Extract directly to a target directory (absolute or relative) - Usually used with --noexec to just extract files without running - --tar arg1 [arg2 ...] 
Access the contents of the archive through the tar command -\${helpheader} -EOH -} - -MS_Verify_Sig() -{ - GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` - MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` - test -x "\$GPG_PATH" || GPG_PATH=\`exec <&- 2>&-; which gpg || command -v gpg || type gpg\` - test -x "\$MKTEMP_PATH" || MKTEMP_PATH=\`exec <&- 2>&-; which mktemp || command -v mktemp || type mktemp\` - offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` - temp_sig=\`mktemp -t XXXXX\` - echo \$SIGNATURE | base64 --decode > "\$temp_sig" - gpg_output=\`MS_dd "\$1" \$offset \$totalsize | LC_ALL=C "\$GPG_PATH" --verify "\$temp_sig" - 2>&1\` - gpg_res=\$? - rm -f "\$temp_sig" - if test \$gpg_res -eq 0 && test \`echo \$gpg_output | grep -c Good\` -eq 1; then - if test \`echo \$gpg_output | grep -c \$sig_key\` -eq 1; then - test x"\$quiet" = xn && echo "GPG signature is good" >&2 - else - echo "GPG Signature key does not match" >&2 - exit 2 - fi - else - test x"\$quiet" = xn && echo "GPG signature failed to verify" >&2 - exit 2 - fi -} - -MS_Check() -{ - OLD_PATH="\$PATH" - PATH=\${GUESS_MD5_PATH:-"\$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} - MD5_ARG="" - MD5_PATH=\`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum\` - test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which md5 || command -v md5 || type md5\` - test -x "\$MD5_PATH" || MD5_PATH=\`exec <&- 2>&-; which digest || command -v digest || type digest\` - PATH="\$OLD_PATH" - - SHA_PATH=\`exec <&- 2>&-; which shasum || command -v shasum || type shasum\` - test -x "\$SHA_PATH" || SHA_PATH=\`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum\` - - if test x"\$quiet" = xn; then - MS_Printf "Verifying archive integrity..." - fi - offset=\`head -n "\$skip" "\$1" | wc -c | tr -d " "\` - fsize=\`cat "\$1" | wc -c | tr -d " "\` - if test \$totalsize -ne \`expr \$fsize - \$offset\`; then - echo " Unexpected archive size." >&2 - exit 2 - fi - verb=\$2 - i=1 - for s in \$filesizes - do - crc=\`echo \$CRCsum | cut -d" " -f\$i\` - if test -x "\$SHA_PATH"; then - if test x"\`basename \$SHA_PATH\`" = xshasum; then - SHA_ARG="-a 256" - fi - sha=\`echo \$SHA | cut -d" " -f\$i\` - if test x"\$sha" = x0000000000000000000000000000000000000000000000000000000000000000; then - test x"\$verb" = xy && echo " \$1 does not contain an embedded SHA256 checksum." >&2 - else - shasum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$SHA_PATH \$SHA_ARG" | cut -b-64\`; - if test x"\$shasum" != x"\$sha"; then - echo "Error in SHA256 checksums: \$shasum is different from \$sha" >&2 - exit 2 - elif test x"\$quiet" = xn; then - MS_Printf " SHA256 checksums are OK." >&2 - fi - crc="0000000000"; - fi - fi - if test -x "\$MD5_PATH"; then - if test x"\`basename \$MD5_PATH\`" = xdigest; then - MD5_ARG="-a md5" - fi - md5=\`echo \$MD5 | cut -d" " -f\$i\` - if test x"\$md5" = x00000000000000000000000000000000; then - test x"\$verb" = xy && echo " \$1 does not contain an embedded MD5 checksum." >&2 - else - md5sum=\`MS_dd_Progress "\$1" \$offset \$s | eval "\$MD5_PATH \$MD5_ARG" | cut -b-32\`; - if test x"\$md5sum" != x"\$md5"; then - echo "Error in MD5 checksums: \$md5sum is different from \$md5" >&2 - exit 2 - elif test x"\$quiet" = xn; then - MS_Printf " MD5 checksums are OK." >&2 - fi - crc="0000000000"; verb=n - fi - fi - if test x"\$crc" = x0000000000; then - test x"\$verb" = xy && echo " \$1 does not contain a CRC checksum." 
>&2 - else - sum1=\`MS_dd_Progress "\$1" \$offset \$s | CMD_ENV=xpg4 cksum | awk '{print \$1}'\` - if test x"\$sum1" != x"\$crc"; then - echo "Error in checksums: \$sum1 is different from \$crc" >&2 - exit 2 - elif test x"\$quiet" = xn; then - MS_Printf " CRC checksums are OK." >&2 - fi - fi - i=\`expr \$i + 1\` - offset=\`expr \$offset + \$s\` - done - if test x"\$quiet" = xn; then - echo " All good." - fi -} - -MS_Decompress() -{ - if test x"\$decrypt_cmd" != x""; then - { eval "\$decrypt_cmd" || echo " ... Decryption failed." >&2; } | eval "$GUNZIP_CMD" - else - eval "$GUNZIP_CMD" - fi - - if test \$? -ne 0; then - echo " ... Decompression failed." >&2 - fi -} - -UnTAR() -{ - if test x"\$quiet" = xn; then - tar \$1vf - $UNTAR_EXTRA 2>&1 || { echo " ... Extraction failed." >&2; kill -15 \$$; } - else - tar \$1f - $UNTAR_EXTRA 2>&1 || { echo Extraction failed. >&2; kill -15 \$$; } - fi -} - -MS_exec_cleanup() { - if test x"\$cleanup" = xy && test x"\$cleanup_script" != x""; then - cleanup=n - cd "\$tmpdir" - eval "\"\$cleanup_script\" \$scriptargs \$cleanupargs" - fi -} - -MS_cleanup() -{ - echo 'Signal caught, cleaning up' >&2 - MS_exec_cleanup - cd "\$TMPROOT" - rm -rf "\$tmpdir" - eval \$finish; exit 15 -} - -Script_Args_Check() -{ - script_supported_args=\$(echo \${helpheader} | grep -o -E "\-\-[^ ]+" | awk -F"=" {'print \$1'}) - arg_to_test=\$(echo \$1|awk -F"=" {'print \$1'}) - - for arg in \${script_supported_args}; - do - if test x"\$arg_to_test" = x"\$arg" ;then - return - fi - done - - MS_Help - exit 1 -} - -finish=true -xterm_loop= -noprogress=$NOPROGRESS -nox11=$NOX11 -copy=$COPY -ownership=$OWNERSHIP -verbose=n -cleanup=y -cleanupargs= -sig_key= - -initargs="\$@" - -while [ -n "\$*" ] -do - case "\$1" in - -h | --help) - MS_Help - exit 0 - ;; - -q | --quiet) - quiet=y - noprogress=y - shift - ;; - --info) - echo Identification: "\$label" - echo Target directory: "\$targetdir" - echo Uncompressed size: $USIZE KB - echo Compression: $COMPRESS - if test x"$ENCRYPT" != x""; then - echo Encryption: $ENCRYPT - fi - echo Date of packaging: $DATE - echo Built with Makeself version $MS_VERSION - echo Build command was: "$MS_COMMAND" - if test x"\$script" != x; then - echo Script run after extraction: - echo " " \$script \$scriptargs - fi - if test x"$copy" = xcopy; then - echo "Archive will copy itself to a temporary location" - fi - if test x"$NEED_ROOT" = xy; then - echo "Root permissions required for extraction" - fi - if test x"$KEEP" = xy; then - echo "directory \$targetdir is permanent" - else - echo "\$targetdir will be removed after extraction" - fi - exit 0 - ;; - --list) - echo Target directory: \$targetdir - offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` - for s in \$filesizes - do - MS_dd "\$0" \$offset \$s | MS_Decompress | UnTAR t - offset=\`expr \$offset + \$s\` - done - exit 0 - ;; - --tar) - offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` - arg1="\$2" - shift 2 || { MS_Help; exit 1; } - for s in \$filesizes - do - MS_dd "\$0" \$offset \$s | MS_Decompress | tar "\$arg1" - "\$@" - offset=\`expr \$offset + \$s\` - done - exit 0 - ;; - --check) - MS_Check "\$0" y - scriptargs="\$scriptargs \$1" - shift - ;; - --noexec) - script="" - cleanup_script="" - shift - ;; - --extract=*) - keep=y - targetdir=\`echo \$1 | cut -d"=" -f2 \` - if ! 
shift; then MS_Help; exit 1; fi - ;; - --nox11) - nox11=y - shift - ;; - --xwin) - if test "$NOWAIT" = n; then - finish="echo Press Return to close this window...; read junk" - fi - xterm_loop=1 - shift - ;; - --phase2) - copy=phase2 - shift - ;; - --repack | --repack-path=*) - Script_Args_Check \$1 - scriptargs="\$scriptargs '\$1'" - shift - if [[ ! "\$1" =~ ^-.* ]]; then - scriptargs="\$scriptargs '\$1'" - shift - fi - ;; - *) - Script_Args_Check \$1 - scriptargs="\$scriptargs '\$1'" - shift - ;; - esac -done - -quiet_para="" -if test x"\$quiet" = xy; then - quiet_para="--quiet " -fi -scriptargs="--\$name_of_file""--\"\$pwd_of_file\""" \$quiet_para""\$scriptargs" - -if test x"\$quiet" = xy -a x"\$verbose" = xy; then - echo Cannot be verbose and quiet at the same time. >&2 - exit 1 -fi - -if test x"$NEED_ROOT" = xy -a \`id -u\` -ne 0; then - echo "Administrative privileges required for this archive (use su or sudo)" >&2 - exit 1 -fi - -if test x"\$copy" \!= xphase2; then - MS_PrintLicense -fi - -case "\$copy" in -copy) - tmpdir="\$TMPROOT"/makeself.\$RANDOM.\`date +"%y%m%d%H%M%S"\`.\$\$ - mkdir "\$tmpdir" || { - echo "Could not create temporary directory \$tmpdir" >&2 - exit 1 - } - SCRIPT_COPY="\$tmpdir/makeself" - echo "Copying to a temporary location..." >&2 - cp "\$0" "\$SCRIPT_COPY" - chmod +x "\$SCRIPT_COPY" - cd "\$TMPROOT" - exec "\$SCRIPT_COPY" --phase2 -- \$initargs - ;; -phase2) - finish="\$finish ; rm -rf \`dirname \$0\`" - ;; -esac - -if test x"\$nox11" = xn; then - if tty -s; then # Do we have a terminal? - : - else - if test x"\$DISPLAY" != x -a x"\$xterm_loop" = x; then # No, but do we have X? - if xset q > /dev/null 2>&1; then # Check for valid DISPLAY variable - GUESS_XTERMS="xterm gnome-terminal rxvt dtterm eterm Eterm xfce4-terminal lxterminal kvt konsole aterm terminology" - for a in \$GUESS_XTERMS; do - if type \$a >/dev/null 2>&1; then - XTERM=\$a - break - fi - done - chmod a+x \$0 || echo Please add execution rights on \$0 - if test \`echo "\$0" | cut -c1\` = "/"; then # Spawn a terminal! - exec \$XTERM -e "\$0 --xwin \$initargs" - else - exec \$XTERM -e "./\$0 --xwin \$initargs" - fi - fi - fi - fi -fi - -if test x"\$targetdir" = x.; then - tmpdir="." -else - if test x"\$keep" = xy; then - if test x"\$nooverwrite" = xy && test -d "\$targetdir"; then - echo "Target directory \$targetdir already exists, aborting." >&2 - exit 1 - fi - if test x"\$quiet" = xn; then - echo "Creating directory \$targetdir" >&2 - fi - tmpdir="\$targetdir" - dashp="-p" - else - tmpdir="\$TMPROOT/selfgz\$\$\$RANDOM" - dashp="" - fi - mkdir \$dashp "\$tmpdir" || { - echo 'Cannot create target directory' \$tmpdir >&2 - echo 'You should try option --extract=' >&2 - eval \$finish - exit 1 - } -fi - -location="\`pwd\`" -if test x"\$SETUP_NOCHECK" != x1; then - MS_Check "\$0" -fi -offset=\`head -n "\$skip" "\$0" | wc -c | tr -d " "\` - -if test x"\$verbose" = xy; then - MS_Printf "About to extract $USIZE KB in \$tmpdir ... Proceed ? [Y/n] " - read yn - if test x"\$yn" = xn; then - eval \$finish; exit 1 - fi -fi - -if test x"\$quiet" = xn; then - # Decrypting with openssl will ask for password, - # the prompt needs to start on new line - if test x"$ENCRYPT" = x"openssl"; then - echo "Decrypting and uncompressing \$label..." 
- else - MS_Printf "Uncompressing \$label" - fi -fi -res=3 -if test x"\$keep" = xn; then - trap MS_cleanup 1 2 3 15 -fi - -if test x"\$nodiskspace" = xn; then - leftspace=\`MS_diskspace "\$tmpdir"\` - if test -n "\$leftspace"; then - if test "\$leftspace" -lt $USIZE; then - echo - echo "Not enough space left in "\`dirname \$tmpdir\`" (\$leftspace KB) to decompress \$0 ($USIZE KB)" >&2 - if test x"\$keep" = xn; then - echo "Consider setting TMPDIR to a directory with more free space." - fi - eval \$finish; exit 1 - fi - fi -fi - -for s in \$filesizes -do - if MS_dd_Progress "\$0" \$offset \$s | MS_Decompress | ( cd "\$tmpdir"; umask \$ORIG_UMASK ; UnTAR xp ) 1>/dev/null; then - if test x"\$ownership" = xy; then - (cd "\$tmpdir"; chown -R \`id -u\` .; chgrp -R \`id -g\` .) - fi - else - echo >&2 - echo "Unable to decompress \$0" >&2 - eval \$finish; exit 1 - fi - offset=\`expr \$offset + \$s\` -done -if test x"\$quiet" = xn; then - echo -fi - -cd "\$tmpdir" -res=0 -if test x"\$script" != x; then - if test x"\$export_conf" = x"y"; then - MS_BUNDLE="\$0" - MS_LABEL="\$label" - MS_SCRIPT="\$script" - MS_SCRIPTARGS="\$scriptargs" - MS_ARCHDIRNAME="\$archdirname" - MS_KEEP="\$KEEP" - MS_NOOVERWRITE="\$NOOVERWRITE" - MS_COMPRESS="\$COMPRESS" - MS_CLEANUP="\$cleanup" - export MS_BUNDLE MS_LABEL MS_SCRIPT MS_SCRIPTARGS - export MS_ARCHDIRNAME MS_KEEP MS_NOOVERWRITE MS_COMPRESS - fi - - if test x"\$verbose" = x"y"; then - yn="x" - while test x"\$yn" != x -a x"\$yn" != xy -a x"\$yn" != xY -a x"\$yn" != xn -a x"\$yn" != xN - do - MS_Printf "OK to execute: \$script \$scriptargs \$* ? [Y/n] " - read yn - if test x"\$yn" = x -o x"\$yn" = xy -o x"\$yn" = xY; then - eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$?; - elif test x"\$yn" = xn -o x"\$yn" = xN; then - echo "Unable to decompress \$script ,because of aborting! ";res=\$? - else - echo "Input value is unacceptable,please try again." - fi - done - else - eval "\"\$script\" \$scriptargs \"\\\$@\""; res=\$? - fi - if test "\$res" -ne 0; then - test x"\$verbose" = xy && echo "The program '\$script' returned an error code (\$res)" >&2 - fi -fi - -MS_exec_cleanup - -if test x"\$keep" = xn; then - cd "\$TMPROOT" - rm -rf "\$tmpdir" -fi -eval \$finish; exit \$res -EOF diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 deleted file mode 100644 index 81bf6e4f..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.1 +++ /dev/null @@ -1,110 +0,0 @@ -.TH "MAKESELF" "1" "2.4.5" -.SH "NAME" -makeself \- An utility to generate self-extractable archives. -.SH "SYNTAX" -.B makeself [\fIoptions\fP] archive_dir file_name label -.B [\fIstartup_script\fP] [\fIargs\fP] -.SH "DESCRIPTION" -This program is a free (GPL) utility designed to create self-extractable -archives from a directory. -.SH "OPTIONS" -The following options are supported. -.TP 15 -.B -v, --version -Prints out the makeself version number and exits. -.TP -.B -h, --help -Print out help information. -.TP -.B --tar-quietly -Suppress verbose output from the tar command -.TP -.B --quiet -Do not print any messages other than errors -.TP -.B --gzip -Compress using gzip (default if detected). -.TP -.B --bzip2 -Compress using bzip2. -.TP -.B --pbzip2 -Compress using pbzip2. -.TP -.B --xz -Compress using xz. -.TP -.B --lzo -Compress using lzop. -.TP -.B --lz4 -Compress using lz4. -.TP -.B --compress -Compress using the UNIX 'compress' command. -.TP -.B --nocomp -Do not compress the data. 
-.TP
-.B --complevel lvl
-Specify the compression level for gzip, bzip2, pbzip2, xz, lzo or lz4.
-.TP
-.B --notemp
-The archive will create archive_dir in the current directory and
-uncompress in ./archive_dir.
-.TP
-.B --copy
-Upon extraction, the archive will first copy itself to a temporary directory.
-.TP
-.B --append
-Append more files to an existing makeself archive. The label and startup scripts will then be ignored.
-.TP
-.B --current
-Files will be extracted to the current directory. Both --current and --target dir imply --notemp.
-.TP
-.B --target dir
-Extract directly to a target directory. The directory path can be either absolute or relative.
-.TP
-.B --header file
-Specify the location of the header script.
-.TP
-.B --cleanup file
-Specify a cleanup script that executes on interrupt and when finished successfully.
-.TP
-.B --follow
-Follow the symlinks in the archive.
-.TP
-.B --noprogress
-Do not show the progress during the decompression.
-.TP
-.B --nox11
-Disable automatic spawning of an xterm if running in X11.
-.TP
-.B --nowait
-Do not wait for user input after executing the embedded program from an xterm.
-.TP
-.B --nomd5
-Do not create an MD5 checksum for the archive.
-.TP
-.B --nocrc
-Do not create a CRC32 checksum for the archive.
-.TP
-.B --lsm file
-LSM file describing the package.
-.B --packaging-date date
-Use the provided string as the packaging date instead of the current date.
-.SH "EXAMPLES"
-Here is an example, assuming the user has a package image stored in /home/joe/mysoft,
-and he wants to generate a self-extracting package named mysoft.sh, which will launch
-the "setup" script initially stored in /home/joe/mysoft:
-.TP
-makeself.sh /home/joe/mysoft mysoft.sh "Joe's Nice Software Package" ./setup
-.TP
-Here is also how I created the makeself.run archive which contains the Makeself distribution:
-.TP
-makeself.sh --notemp makeself makeself.run "Makeself by Stephane Peter" echo "Makeself has extracted itself"
-.SH "AUTHORS"
-Makeself has been written by Stéphane Peter.
-.BR
-This man page was originally written by Bartosz Fenski for the
-Debian GNU/Linux distribution (but it may be used by others).
diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm
deleted file mode 100644
index 3c4cea8c..00000000
--- a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.lsm
+++ /dev/null
@@ -1,16 +0,0 @@
-Begin3
-Title: makeself.sh
-Version: 2.4.5
-Description: makeself.sh is a shell script that generates a self-extractable
- tar.gz archive from a directory. The resulting file appears as a shell
- script, and can be launched as is. The archive will then uncompress
- itself to a temporary directory and an arbitrary command will be
- executed (for example an installation script). This is pretty similar
- to archives generated with WinZip Self-Extractor in the Windows world.
-Keywords: Installation archive tar winzip
-Author: Stephane Peter (megastep@megastep.org)
-Maintained-by: Stephane Peter (megastep@megastep.org)
-Original-site: https://makeself.io/
-Platform: Unix
-Copying-policy: GPL
-End
diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh
deleted file mode 100644
index c8ea5659..00000000
--- a/cust_op/cust_op_by_addr/cmake/util/makeself/makeself.sh
+++ /dev/null
@@ -1,822 +0,0 @@
-#!/bin/sh
-#
-# Makeself version 2.4.x
-# by Stephane Peter
-#
-# Utility to create self-extracting tar.gz archives.
-# The resulting archive is a file holding the tar.gz archive with -# a small Shell script stub that uncompresses the archive to a temporary -# directory and then executes a given script from withing that directory. -# -# Makeself home page: https://makeself.io/ -# -# Version 2.0 is a rewrite of version 1.0 to make the code easier to read and maintain. -# -# Version history : -# - 1.0 : Initial public release -# - 1.1 : The archive can be passed parameters that will be passed on to -# the embedded script, thanks to John C. Quillan -# - 1.2 : Package distribution, bzip2 compression, more command line options, -# support for non-temporary archives. Ideas thanks to Francois Petitjean -# - 1.3 : More patches from Bjarni R. Einarsson and Francois Petitjean: -# Support for no compression (--nocomp), script is no longer mandatory, -# automatic launch in an xterm, optional verbose output, and -target -# archive option to indicate where to extract the files. -# - 1.4 : Improved UNIX compatibility (Francois Petitjean) -# Automatic integrity checking, support of LSM files (Francois Petitjean) -# - 1.5 : Many bugfixes. Optionally disable xterm spawning. -# - 1.5.1 : More bugfixes, added archive options -list and -check. -# - 1.5.2 : Cosmetic changes to inform the user of what's going on with big -# archives (Quake III demo) -# - 1.5.3 : Check for validity of the DISPLAY variable before launching an xterm. -# More verbosity in xterms and check for embedded command's return value. -# Bugfix for Debian 2.0 systems that have a different "print" command. -# - 1.5.4 : Many bugfixes. Print out a message if the extraction failed. -# - 1.5.5 : More bugfixes. Added support for SETUP_NOCHECK environment variable to -# bypass checksum verification of archives. -# - 1.6.0 : Compute MD5 checksums with the md5sum command (patch from Ryan Gordon) -# - 2.0 : Brand new rewrite, cleaner architecture, separated header and UNIX ports. -# - 2.0.1 : Added --copy -# - 2.1.0 : Allow multiple tarballs to be stored in one archive, and incremental updates. -# Added --nochown for archives -# Stopped doing redundant checksums when not necesary -# - 2.1.1 : Work around insane behavior from certain Linux distros with no 'uncompress' command -# Cleaned up the code to handle error codes from compress. Simplified the extraction code. -# - 2.1.2 : Some bug fixes. Use head -n to avoid problems. -# - 2.1.3 : Bug fixes with command line when spawning terminals. -# Added --tar for archives, allowing to give arbitrary arguments to tar on the contents of the archive. -# Added --noexec to prevent execution of embedded scripts. -# Added --nomd5 and --nocrc to avoid creating checksums in archives. -# Added command used to create the archive in --info output. -# Run the embedded script through eval. -# - 2.1.4 : Fixed --info output. -# Generate random directory name when extracting files to . to avoid problems. (Jason Trent) -# Better handling of errors with wrong permissions for the directory containing the files. (Jason Trent) -# Avoid some race conditions (Ludwig Nussel) -# Unset the $CDPATH variable to avoid problems if it is set. (Debian) -# Better handling of dot files in the archive directory. -# - 2.1.5 : Made the md5sum detection consistent with the header code. 
-# Check for the presence of the archive directory -# Added --encrypt for symmetric encryption through gpg (Eric Windisch) -# Added support for the digest command on Solaris 10 for MD5 checksums -# Check for available disk space before extracting to the target directory (Andreas Schweitzer) -# Allow extraction to run asynchronously (patch by Peter Hatch) -# Use file descriptors internally to avoid error messages (patch by Kay Tiong Khoo) -# - 2.1.6 : Replaced one dot per file progress with a realtime progress percentage and a spining cursor (Guy Baconniere) -# Added --noprogress to prevent showing the progress during the decompression (Guy Baconniere) -# Added --target dir to allow extracting directly to a target directory (Guy Baconniere) -# - 2.2.0 : Many bugfixes, updates and contributions from users. Check out the project page on Github for the details. -# - 2.3.0 : Option to specify packaging date to enable byte-for-byte reproducibility. (Marc Pawlowsky) -# - 2.4.0 : Optional support for SHA256 checksums in archives. -# - 2.4.2 : Add support for threads for several compressors. (M. Limber) -# Added zstd support. -# - 2.4.3 : Make explicit POSIX tar archives for increased compatibility. -# - 2.4.5 : Added --tar-format to override ustar tar archive format -# -# (C) 1998-2021 by Stephane Peter -# -# This software is released under the terms of the GNU GPL version 2 and above -# Please read the license at http://www.gnu.org/copyleft/gpl.html -# Self-extracting archives created with this script are explictly NOT released under the term of the GPL -# - -MS_VERSION=2.4.5 -MS_COMMAND="$0" -unset CDPATH - -for f in ${1+"$@"}; do - MS_COMMAND="$MS_COMMAND \\\\ - \\\"$f\\\"" -done - -# For Solaris systems -if test -d /usr/xpg4/bin; then - PATH=/usr/xpg4/bin:$PATH - export PATH -fi - -# Procedures - -MS_Usage() -{ - echo "Usage: $0 [args] archive_dir file_name label startup_script [script_args]" - echo "args can be one or more of the following :" - echo " --version | -v : Print out Makeself version number and exit" - echo " --help | -h : Print out this help message" - echo " --tar-quietly : Suppress verbose output from the tar command" - echo " --quiet | -q : Do not print any messages other than errors." - echo " --gzip : Compress using gzip (default if detected)" - echo " --pigz : Compress with pigz" - echo " --zstd : Compress with zstd" - echo " --bzip2 : Compress using bzip2 instead of gzip" - echo " --pbzip2 : Compress using pbzip2 instead of gzip" - echo " --xz : Compress using xz instead of gzip" - echo " --lzo : Compress using lzop instead of gzip" - echo " --lz4 : Compress using lz4 instead of gzip" - echo " --compress : Compress using the UNIX 'compress' command" - echo " --complevel lvl : Compression level for gzip pigz zstd xz lzo lz4 bzip2 and pbzip2 (default 9)" - echo " --threads thds : Number of threads to be used by compressors that support parallelization." - echo " Omit to use compressor's default. Most useful (and required) for opting" - echo " into xz's threading, usually with '--threads=0' for all available cores." - echo " pbzip2 and pigz are parallel by default, and setting this value allows" - echo " limiting the number of threads they use." 
- echo " --base64 : Instead of compressing, encode the data using base64" - echo " --gpg-encrypt : Instead of compressing, encrypt the data using GPG" - echo " --gpg-asymmetric-encrypt-sign" - echo " : Instead of compressing, asymmetrically encrypt and sign the data using GPG" - echo " --gpg-extra opt : Append more options to the gpg command line" - echo " --ssl-encrypt : Instead of compressing, encrypt the data using OpenSSL" - echo " --ssl-passwd pass : Use the given password to encrypt the data using OpenSSL" - echo " --ssl-pass-src src : Use the given src as the source of password to encrypt the data" - echo " using OpenSSL. See \"PASS PHRASE ARGUMENTS\" in man openssl." - echo " If this option is not supplied, the user will be asked to enter" - echo " encryption password on the current terminal." - echo " --ssl-no-md : Do not use \"-md\" option not supported by older OpenSSL." - echo " --nochown : Do not give the target folder to the current user (default)" - echo " --chown : Give the target folder to the current user recursively" - echo " --nocomp : Do not compress the data" - echo " --notemp : The archive will create archive_dir in the" - echo " current directory and uncompress in ./archive_dir" - echo " --needroot : Check that the root user is extracting the archive before proceeding" - echo " --copy : Upon extraction, the archive will first copy itself to" - echo " a temporary directory" - echo " --append : Append more files to an existing Makeself archive" - echo " The label and startup scripts will then be ignored" - echo " --target dir : Extract directly to a target directory" - echo " directory path can be either absolute or relative" - echo " --nooverwrite : Do not extract the archive if the specified target directory exists" - echo " --current : Files will be extracted to the current directory" - echo " Both --current and --target imply --notemp" - echo " --tar-format opt : Specify a tar archive format (default is ustar)" - echo " --tar-extra opt : Append more options to the tar command line" - echo " --untar-extra opt : Append more options to the during the extraction of the tar archive" - echo " --nomd5 : Don't calculate an MD5 for archive" - echo " --nocrc : Don't calculate a CRC for archive" - echo " --sha256 : Compute a SHA256 checksum for the archive" - echo " --header file : Specify location of the header script" - echo " --cleanup file : Specify a cleanup script that executes on interrupt and when finished successfully." - echo " --follow : Follow the symlinks in the archive" - echo " --noprogress : Do not show the progress during the decompression" - echo " --nox11 : Disable automatic spawn of a xterm" - echo " --nowait : Do not wait for user input after executing embedded" - echo " program from an xterm" - echo " --sign passphrase : Signature private key to sign the package with" - echo " --lsm file : LSM file describing the package" - echo " --license file : Append a license file" - echo " --help-header file : Add a header to the archive's --help output" - echo " --packaging-date date" - echo " : Use provided string as the packaging date" - echo " instead of the current date." - echo - echo " --keep-umask : Keep the umask set to shell default, rather than overriding when executing self-extracting archive." - echo " --export-conf : Export configuration variables to startup_script" - echo - echo "Do not forget to give a fully qualified startup script name" - echo "(i.e. with a ./ prefix if inside the archive)." 
- exit 1 -} - -# Default settings -if type gzip >/dev/null 2>&1; then - COMPRESS=gzip -elif type compress >/dev/null 2>&1; then - COMPRESS=compress -else - echo "ERROR: missing commands: gzip, compress" >&2 - MS_Usage -fi -ENCRYPT=n -PASSWD="" -PASSWD_SRC="" -OPENSSL_NO_MD=n -COMPRESS_LEVEL=9 -DEFAULT_THREADS=123456 # Sentinel value -THREADS=$DEFAULT_THREADS -KEEP=n -CURRENT=n -NOX11=n -NOWAIT=n -APPEND=n -TAR_QUIETLY=n -KEEP_UMASK=n -QUIET=n -NOPROGRESS=n -COPY=none -NEED_ROOT=n -TAR_ARGS=rvf -TAR_FORMAT=ustar -TAR_EXTRA="" -GPG_EXTRA="" -DU_ARGS=-ks -HEADER=`dirname "$0"`/makeself-header.sh -SIGNATURE="" -TARGETDIR="" -NOOVERWRITE=n -DATE=`LC_ALL=C date` -EXPORT_CONF=n -SHA256=n -OWNERSHIP=n -SIGN=n -GPG_PASSPHRASE="" - -# LSM file stuff -LSM_CMD="echo No LSM. >> \"\$archname\"" - -while true -do - case "$1" in - --version | -v) - echo Makeself version $MS_VERSION - exit 0 - ;; - --pbzip2) - COMPRESS=pbzip2 - shift - ;; - --bzip2) - COMPRESS=bzip2 - shift - ;; - --gzip) - COMPRESS=gzip - shift - ;; - --pigz) - COMPRESS=pigz - shift - ;; - --zstd) - COMPRESS=zstd - shift - ;; - --xz) - COMPRESS=xz - shift - ;; - --lzo) - COMPRESS=lzo - shift - ;; - --lz4) - COMPRESS=lz4 - shift - ;; - --compress) - COMPRESS=compress - shift - ;; - --base64) - COMPRESS=base64 - shift - ;; - --gpg-encrypt) - COMPRESS=gpg - shift - ;; - --gpg-asymmetric-encrypt-sign) - COMPRESS=gpg-asymmetric - shift - ;; - --gpg-extra) - GPG_EXTRA="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --ssl-encrypt) - ENCRYPT=openssl - shift - ;; - --ssl-passwd) - PASSWD=$2 - shift 2 || { MS_Usage; exit 1; } - ;; - --ssl-pass-src) - PASSWD_SRC=$2 - shift 2 || { MS_Usage; exit 1; } - ;; - --ssl-no-md) - OPENSSL_NO_MD=y - shift - ;; - --nocomp) - COMPRESS=none - shift - ;; - --complevel) - COMPRESS_LEVEL="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --threads) - THREADS="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --nochown) - OWNERSHIP=n - shift - ;; - --chown) - OWNERSHIP=y - shift - ;; - --notemp) - KEEP=y - shift - ;; - --copy) - COPY=copy - shift - ;; - --current) - CURRENT=y - KEEP=y - shift - ;; - --tar-format) - TAR_FORMAT="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --tar-extra) - TAR_EXTRA="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --untar-extra) - UNTAR_EXTRA="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --target) - TARGETDIR="$2" - KEEP=y - shift 2 || { MS_Usage; exit 1; } - ;; - --sign) - SIGN=y - GPG_PASSPHRASE="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --nooverwrite) - NOOVERWRITE=y - shift - ;; - --needroot) - NEED_ROOT=y - shift - ;; - --header) - HEADER="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --cleanup) - CLEANUP_SCRIPT="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --license) - # We need to escape all characters having a special meaning in double quotes - LICENSE=$(sed 's/\\/\\\\/g; s/"/\\\"/g; s/`/\\\`/g; s/\$/\\\$/g' "$2") - shift 2 || { MS_Usage; exit 1; } - ;; - --follow) - TAR_ARGS=rvhf - DU_ARGS=-ksL - shift - ;; - --noprogress) - NOPROGRESS=y - shift - ;; - --nox11) - NOX11=y - shift - ;; - --nowait) - NOWAIT=y - shift - ;; - --nomd5) - NOMD5=y - shift - ;; - --sha256) - SHA256=y - shift - ;; - --nocrc) - NOCRC=y - shift - ;; - --append) - APPEND=y - shift - ;; - --lsm) - LSM_CMD="cat \"$2\" >> \"\$archname\"" - shift 2 || { MS_Usage; exit 1; } - ;; - --packaging-date) - DATE="$2" - shift 2 || { MS_Usage; exit 1; } - ;; - --help-header) - HELPHEADER=`sed -e "s/'/'\\\\\''/g" $2` - shift 2 || { MS_Usage; exit 1; } - [ -n "$HELPHEADER" ] && HELPHEADER="$HELPHEADER -" - ;; - --tar-quietly) - 
TAR_QUIETLY=y - shift - ;; - --keep-umask) - KEEP_UMASK=y - shift - ;; - --export-conf) - EXPORT_CONF=y - shift - ;; - -q | --quiet) - QUIET=y - shift - ;; - -h | --help) - MS_Usage - ;; - -*) - echo Unrecognized flag : "$1" - MS_Usage - ;; - *) - break - ;; - esac -done - -if test $# -lt 1; then - MS_Usage -else - if test -d "$1"; then - archdir="$1" - else - echo "Directory $1 does not exist." >&2 - exit 1 - fi -fi -archname="$2" - -if test "$QUIET" = "y" || test "$TAR_QUIETLY" = "y"; then - if test "$TAR_ARGS" = "rvf"; then - TAR_ARGS="rf" - elif test "$TAR_ARGS" = "rvhf"; then - TAR_ARGS="rhf" - fi -fi - -if test "$APPEND" = y; then - if test $# -lt 2; then - MS_Usage - fi - - # Gather the info from the original archive - OLDENV=`sh "$archname" --dumpconf` - if test $? -ne 0; then - echo "Unable to update archive: $archname" >&2 - exit 1 - else - eval "$OLDENV" - OLDSKIP=`expr $SKIP + 1` - fi -else - if test "$KEEP" = n -a $# = 3; then - echo "ERROR: Making a temporary archive with no embedded command does not make sense!" >&2 - echo >&2 - MS_Usage - fi - # We don't want to create an absolute directory unless a target directory is defined - if test "$CURRENT" = y; then - archdirname="." - elif test x"$TARGETDIR" != x; then - archdirname="$TARGETDIR" - else - archdirname=`basename "$1"` - fi - - if test $# -lt 3; then - MS_Usage - fi - - LABEL="$3" - SCRIPT="$4" - test "x$SCRIPT" = x || shift 1 - shift 3 - SCRIPTARGS="$*" -fi - -if test "$KEEP" = n -a "$CURRENT" = y; then - echo "ERROR: It is A VERY DANGEROUS IDEA to try to combine --notemp and --current." >&2 - exit 1 -fi - -case $COMPRESS in -gzip) - GZIP_CMD="gzip -c$COMPRESS_LEVEL" - GUNZIP_CMD="gzip -cd" - ;; -pigz) - GZIP_CMD="pigz -$COMPRESS_LEVEL" - if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated - GZIP_CMD="$GZIP_CMD --processes $THREADS" - fi - GUNZIP_CMD="gzip -cd" - ;; -zstd) - GZIP_CMD="zstd -$COMPRESS_LEVEL" - if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated - GZIP_CMD="$GZIP_CMD --threads=$THREADS" - fi - GUNZIP_CMD="zstd -cd" - ;; -pbzip2) - GZIP_CMD="pbzip2 -c$COMPRESS_LEVEL" - if test $THREADS -ne $DEFAULT_THREADS; then # Leave as the default if threads not indicated - GZIP_CMD="$GZIP_CMD -p$THREADS" - fi - GUNZIP_CMD="bzip2 -d" - ;; -bzip2) - GZIP_CMD="bzip2 -$COMPRESS_LEVEL" - GUNZIP_CMD="bzip2 -d" - ;; -xz) - GZIP_CMD="xz -c$COMPRESS_LEVEL" - # Must opt-in by specifying a value since not all versions of xz support threads - if test $THREADS -ne $DEFAULT_THREADS; then - GZIP_CMD="$GZIP_CMD --threads=$THREADS" - fi - GUNZIP_CMD="xz -d" - ;; -lzo) - GZIP_CMD="lzop -c$COMPRESS_LEVEL" - GUNZIP_CMD="lzop -d" - ;; -lz4) - GZIP_CMD="lz4 -c$COMPRESS_LEVEL" - GUNZIP_CMD="lz4 -d" - ;; -base64) - GZIP_CMD="base64" - GUNZIP_CMD="base64 --decode -i -" - ;; -gpg) - GZIP_CMD="gpg $GPG_EXTRA -ac -z$COMPRESS_LEVEL" - GUNZIP_CMD="gpg -d" - ENCRYPT="gpg" - ;; -gpg-asymmetric) - GZIP_CMD="gpg $GPG_EXTRA -z$COMPRESS_LEVEL -es" - GUNZIP_CMD="gpg --yes -d" - ENCRYPT="gpg" - ;; -compress) - GZIP_CMD="compress -fc" - GUNZIP_CMD="(type compress >/dev/null 2>&1 && compress -fcd || gzip -cd)" - ;; -none) - GZIP_CMD="cat" - GUNZIP_CMD="cat" - ;; -esac - -if test x"$ENCRYPT" = x"openssl"; then - if test x"$APPEND" = x"y"; then - echo "Appending to existing archive is not compatible with OpenSSL encryption." 
>&2 - fi - - ENCRYPT_CMD="openssl enc -aes-256-cbc -salt" - DECRYPT_CMD="openssl enc -aes-256-cbc -d" - - if test x"$OPENSSL_NO_MD" != x"y"; then - ENCRYPT_CMD="$ENCRYPT_CMD -md sha256" - DECRYPT_CMD="$DECRYPT_CMD -md sha256" - fi - - if test -n "$PASSWD_SRC"; then - ENCRYPT_CMD="$ENCRYPT_CMD -pass $PASSWD_SRC" - elif test -n "$PASSWD"; then - ENCRYPT_CMD="$ENCRYPT_CMD -pass pass:$PASSWD" - fi -fi - -tmpfile="${TMPDIR:-/tmp}/mkself$$" - -if test -f "$HEADER"; then - oldarchname="$archname" - archname="$tmpfile" - # Generate a fake header to count its lines - SKIP=0 - . "$HEADER" - SKIP=`cat "$tmpfile" |wc -l` - # Get rid of any spaces - SKIP=`expr $SKIP` - rm -f "$tmpfile" - if test "$QUIET" = "n"; then - echo "Header is $SKIP lines long" >&2 - fi - archname="$oldarchname" -else - echo "Unable to open header file: $HEADER" >&2 - exit 1 -fi - -if test "$QUIET" = "n"; then - echo -fi - -if test "$APPEND" = n; then - if test -f "$archname"; then - echo "WARNING: Overwriting existing file: $archname" >&2 - fi -fi - -USIZE=`du $DU_ARGS "$archdir" | awk '{print $1}'` - -if test "." = "$archdirname"; then - if test "$KEEP" = n; then - archdirname="makeself-$$-`date +%Y%m%d%H%M%S`" - fi -fi - -test -d "$archdir" || { echo "Error: $archdir does not exist."; rm -f "$tmpfile"; exit 1; } -if test "$QUIET" = "n"; then - echo "About to compress $USIZE KB of data..." - echo "Adding files to archive named \"$archname\"..." -fi - -# See if we have GNU tar -TAR=`exec <&- 2>&-; which gtar || command -v gtar || type gtar` -test -x "$TAR" || TAR=tar - -tmparch="${TMPDIR:-/tmp}/mkself$$.tar" -( - if test "$APPEND" = "y"; then - tail -n "+$OLDSKIP" "$archname" | eval "$GUNZIP_CMD" > "$tmparch" - fi - cd "$archdir" - # "Determining if a directory is empty" - # https://www.etalabs.net/sh_tricks.html - find . \ - \( \ - ! -type d \ - -o \ - \( -links 2 -exec sh -c ' - is_empty () ( - cd "$1" - set -- .[!.]* ; test -f "$1" && return 1 - set -- ..?* ; test -f "$1" && return 1 - set -- * ; test -f "$1" && return 1 - return 0 - ) - is_empty "$0"' {} \; \ - \) \ - \) -print \ - | LC_ALL=C sort \ - | sed 's/./\\&/g' \ - | xargs $TAR $TAR_EXTRA --format $TAR_FORMAT -$TAR_ARGS "$tmparch" -) || { - echo "ERROR: failed to create temporary archive: $tmparch" - rm -f "$tmparch" "$tmpfile" - exit 1 -} - -USIZE=`du $DU_ARGS "$tmparch" | awk '{print $1}'` - -eval "$GZIP_CMD" <"$tmparch" >"$tmpfile" || { - echo "ERROR: failed to create temporary file: $tmpfile" - rm -f "$tmparch" "$tmpfile" - exit 1 -} -rm -f "$tmparch" - -if test x"$ENCRYPT" = x"openssl"; then - echo "About to encrypt archive \"$archname\"..." 
- { eval "$ENCRYPT_CMD -in $tmpfile -out ${tmpfile}.enc" && mv -f ${tmpfile}.enc $tmpfile; } || \ - { echo Aborting: could not encrypt temporary file: "$tmpfile".; rm -f "$tmpfile"; exit 1; } -fi - -fsize=`cat "$tmpfile" | wc -c | tr -d " "` - -# Compute the checksums - -shasum=0000000000000000000000000000000000000000000000000000000000000000 -md5sum=00000000000000000000000000000000 -crcsum=0000000000 - -if test "$NOCRC" = y; then - if test "$QUIET" = "n"; then - echo "skipping crc at user request" - fi -else - crcsum=`CMD_ENV=xpg4 cksum < "$tmpfile" | sed -e 's/ /Z/' -e 's/ /Z/' | cut -dZ -f1` - if test "$QUIET" = "n"; then - echo "CRC: $crcsum" - fi -fi - -if test "$SHA256" = y; then - SHA_PATH=`exec <&- 2>&-; which shasum || command -v shasum || type shasum` - if test -x "$SHA_PATH"; then - shasum=`eval "$SHA_PATH -a 256" < "$tmpfile" | cut -b-64` - else - SHA_PATH=`exec <&- 2>&-; which sha256sum || command -v sha256sum || type sha256sum` - shasum=`eval "$SHA_PATH" < "$tmpfile" | cut -b-64` - fi - if test "$QUIET" = "n"; then - if test -x "$SHA_PATH"; then - echo "SHA256: $shasum" - else - echo "SHA256: none, SHA command not found" - fi - fi -fi -if test "$NOMD5" = y; then - if test "$QUIET" = "n"; then - echo "Skipping md5sum at user request" - fi -else - # Try to locate a MD5 binary - OLD_PATH=$PATH - PATH=${GUESS_MD5_PATH:-"$OLD_PATH:/bin:/usr/bin:/sbin:/usr/local/ssl/bin:/usr/local/bin:/opt/openssl/bin"} - MD5_ARG="" - MD5_PATH=`exec <&- 2>&-; which md5sum || command -v md5sum || type md5sum` - test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which md5 || command -v md5 || type md5` - test -x "$MD5_PATH" || MD5_PATH=`exec <&- 2>&-; which digest || command -v digest || type digest` - PATH=$OLD_PATH - if test -x "$MD5_PATH"; then - if test `basename ${MD5_PATH}`x = digestx; then - MD5_ARG="-a md5" - fi - md5sum=`eval "$MD5_PATH $MD5_ARG" < "$tmpfile" | cut -b-32` - if test "$QUIET" = "n"; then - echo "MD5: $md5sum" - fi - else - if test "$QUIET" = "n"; then - echo "MD5: none, MD5 command not found" - fi - fi -fi -if test "$SIGN" = y; then - GPG_PATH=`exec <&- 2>&-; which gpg || command -v gpg || type gpg` - if test -x "$GPG_PATH"; then - SIGNATURE=`$GPG_PATH --pinentry-mode=loopback --batch --yes --passphrase "$GPG_PASSPHRASE" --output - --detach-sig $tmpfile | base64 | tr -d \\\\n` - if test "$QUIET" = "n"; then - echo "Signature: $SIGNATURE" - fi - else - echo "Missing gpg command" >&2 - fi -fi - -totalsize=0 -for size in $fsize; -do - totalsize=`expr $totalsize + $size` -done - -if test "$APPEND" = y; then - mv "$archname" "$archname".bak || exit - - # Prepare entry for new archive - filesizes="$fsize" - CRCsum="$crcsum" - MD5sum="$md5sum" - SHAsum="$shasum" - Signature="$SIGNATURE" - # Generate the header - . "$HEADER" - # Append the new data - cat "$tmpfile" >> "$archname" - - chmod +x "$archname" - rm -f "$archname".bak - if test "$QUIET" = "n"; then - echo "Self-extractable archive \"$archname\" successfully updated." - fi -else - filesizes="$fsize" - CRCsum="$crcsum" - MD5sum="$md5sum" - SHAsum="$shasum" - Signature="$SIGNATURE" - - # Generate the header - . "$HEADER" - - # Append the compressed tar data after the stub - if test "$QUIET" = "n"; then - echo - fi - cat "$tmpfile" >> "$archname" - chmod +x "$archname" - if test "$QUIET" = "n"; then - echo Self-extractable archive \"$archname\" successfully created. 
- fi -fi -rm -f "$tmpfile" diff --git a/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh b/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh deleted file mode 100644 index 31ee1651..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/makeself/run-tests.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/sh -# Run every available test - Bash needed -cd test -for test in *test; -do - echo "Running test $test ..." - bash $test || { echo "*** ERROR: Test '$test' failed!"; exit 1; } -done diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py new file mode 100644 index 00000000..6924f767 --- /dev/null +++ b/mx_rec/validator/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -- Gitee From 2509b4715e295d07d79ee954725cfe9f255c39fd Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 30 May 2023 10:16:08 +0800 Subject: [PATCH 092/551] Match-id-c02e623277eb6e6746ca18ac30379a1054ef9515 --- mx_rec/validator/__init__.py | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 mx_rec/validator/__init__.py diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py deleted file mode 100644 index 6924f767..00000000 --- a/mx_rec/validator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -- Gitee From 36d2104c96946d90826ccaee35b5a28500b0c3dd Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 30 May 2023 15:17:43 +0800 Subject: [PATCH 093/551] Match-id-6d241cc45d3f153e4acc5aeb36794eda657447de --- mx_rec/__init__.py | 1 + mx_rec/constants/__init__.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 mx_rec/constants/__init__.py diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 7a2ed887..bbee9e8f 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py deleted file mode 100644 index 6924f767..00000000 --- a/mx_rec/constants/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-- Gitee From dea8634d71ecc1b9c96d16313760c2a629d32ee4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 30 May 2023 15:35:06 +0800 Subject: [PATCH 094/551] Match-id-a866e3671d029683adf365b0746accd918830091 --- mx_rec/saver/patch.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 72fe950f..34acdaad 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -123,12 +123,8 @@ def get_model_checkpoint_path(self, checkpoint_file, sess): return model_checkpoint_path -def update_checkpoint_state(self, **kwargs): - model_checkpoint_path = kwargs.get("model_checkpoint_path") - parent_save_path = kwargs.get("parent_save_path") - latest_file_name = kwargs.get("latest_file_name") - suffix_meta_graph = kwargs.get("suffix_meta_graph") - save_path = kwargs.get("save_path") +def update_checkpoint_state(self, model_checkpoint_path, parent_save_path, latest_file_name, suffix_meta_graph, + save_path): self._RecordLastCheckpoint(model_checkpoint_path) try: checkpoint_management.update_checkpoint_state_internal(save_dir=parent_save_path, @@ -192,9 +188,8 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra model_checkpoint_path = compat.as_str(get_model_checkpoint_path(self, checkpoint_file, sess)) if write_state: - update_checkpoint_state(self, model_checkpoint_path=model_checkpoint_path, save_path_parent=save_path_parent, - latest_filename=latest_filename, meta_graph_suffix=meta_graph_suffix, - save_path=save_path) + update_checkpoint_state(self, model_checkpoint_path, save_path_parent, latest_filename, meta_graph_suffix, + save_path) if write_meta_graph: write_meta_graph_task(self, checkpoint_file=checkpoint_file, meta_graph_suffix=meta_graph_suffix, sess=sess, strip_default_attrs=strip_default_attrs, save_debug_info=save_debug_info) -- Gitee From 90499e9b6cfea57d83c8e9f93e6c288deca34819 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 10:18:45 +0800 Subject: [PATCH 095/551] Match-id-98e31e2d19c154ab417ce2f7d7db8b618183d4d9 --- mx_rec/core/embedding.py | 16 +++++++++++++--- mx_rec/util/constants.py | 1 + 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 8b63a7af..da1e2429 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -18,7 +18,7 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID + DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ @@ -301,6 +301,14 @@ class SparseEmbedding: raise RuntimeError(f"Current sparse table was config in {self.mode.value} mode, but sparse lookup method " f"for {method_mode} was in use.") + def check_multi_lookup_times(self): + if self.modify_graph: + self.lookup_result = dict() + if len(self.channel_name_list) > MULTI_LOOKUP_TIMES or len(self.lookup_result) > MULTI_LOOKUP_TIMES: + run_mode = "Modify Graph" if 
self.modify_graph else "Feature Spec" + raise RuntimeError(f"In '{run_mode}' mode, the number of sparse lookups on a single table " + f"'{self.table_name}' must not exceed {MULTI_LOOKUP_TIMES}.") + def check_and_format_lookup_params(self, feature, send_count, is_training): logging.debug(f"sparse lookup for table {self.table_name} with is_training {is_training}") @@ -416,6 +424,7 @@ class SparseEmbedding: self._send_count = send_count self.send_count_map[ids_channel_name] = send_count self.modify_graph = kwargs.get("modify_graph", True) + self.check_multi_lookup_times() logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, anchor_ids: {anchor_ids}, " f"ids_channel_name: {ids_channel_name}, use_dynamic_expansion: {use_dynamic_expansion}, " f"use_static: {use_static}, use_hot: {use_hot}") @@ -535,7 +544,6 @@ class SparseEmbedding: if len(same_table_feature_spec) == 1: lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) self.lookup_result = {spec_name: {is_training: lookup_result}} - return self.lookup_result.get(spec_name).get(is_training) else: same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) same_table_spec_count = len(same_table_feature_spec) @@ -556,7 +564,9 @@ class SparseEmbedding: lookup_result_split = tf.split(lookup_result, split_size) self.lookup_result = {k.name: {is_training: tf.reshape(v, k.dims + [self.scalar_emb_size])} for k, v in zip(same_table_feature_spec, lookup_result_split)} + + self.check_multi_lookup_times() + return self.lookup_result.get(spec_name).get(is_training) def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs): """ diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py index 3a9be2ec..4c48076c 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/util/constants.py @@ -29,6 +29,7 @@ LOCAL_RANK_SIZE = "LOCAL_RANK_SIZE" # 训练时,当前服务器使用的NPU MAX_DEVICE_NUM_LOCAL_MACHINE = 16 # 单台服务器最大的卡数 DEFAULT_DEVICE_NUM_LOCAL_MACHINE = 8 # 单台服务器默认的卡数 +MULTI_LOOKUP_TIMES = 2048 DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 -- Gitee From 2da38f3473e26e6c2cf5b670d8ebca09c0d2f9be Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 10:19:21 +0800 Subject: [PATCH 096/551] Match-id-b1da735d232f402545d4b50c4faa704d1aaf3345 --- mx_rec/util/constants.py | 1 + mx_rec/util/initialize.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py index 4c48076c..1fe759e7 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/util/constants.py @@ -33,6 +33,7 @@ MULTI_LOOKUP_TIMES = 2048 DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 +HASHTABLE_COLLECTION_NAME_LENGTH = 30 class BaseEnum(Enum): @classmethod diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 1ccefea4..906cee20 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -12,7 +12,7 @@ import psutil import mxrec_pybind import mx_rec.util.constants from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ - ASCEND_GLOBAL_HASHTABLE_COLLECTION + ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH from mx_rec.util.ops import import_host_pipeline_ops @@ -328,6 +328,9 @@ class ConfigInitializer: def ascend_global_hashtable_collection(self, name): if not isinstance(name, str): raise
TypeError(f"collection name '{name}' must be a string.") + if len(name) > HASHTABLE_COLLECTION_NAME_LENGTH: + raise ValueError(f"The length of the collection name '{name}' should be between " + f"[0, {HASHTABLE_COLLECTION_NAME_LENGTH}].") self._ascend_global_hashtable_collection = name def get_initializer(self, is_training): -- Gitee From 558218438a8700db93210f8e659d4723743e4643 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 11:15:38 +0800 Subject: [PATCH 097/551] Match-id-d873c69db44a9d713c45a25b276b2bd5b77be147 --- mx_rec/core/embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index d1ba2d83..156cdd78 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -431,7 +431,7 @@ class SparseEmbedding: embedding_dim=self.emb_size, embedding_type=1) - is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name if is_training and use_dynamic_expansion and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) @@ -670,7 +670,7 @@ class SparseEmbedding: host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size, embedding_type=1) - is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is not None and \ + is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name if is_training and is_table_name_valid: tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) -- Gitee From 6c6c58137f36584298cd1493d0e843dab8128848 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 11:48:51 +0800 Subject: [PATCH 098/551] Match-id-424f1dd88457234a48c390e7fd2209e94bc17a5b --- mx_rec/core/asc/feature_spec.py | 6 ++++++ mx_rec/core/embedding.py | 5 ++--- mx_rec/util/constants.py | 3 +++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index c11515dd..4e229014 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -9,6 +9,7 @@ import tensorflow as tf from mx_rec.util.atomic import AtomicInteger from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static +from mx_rec.util.constants import MAX_INT32 feature_spec_global_id = AtomicInteger() @@ -104,11 +105,16 @@ class FeatureSpec: if self._access_threshold is not None: check_natural_number(self._access_threshold, "access_threshold") + if self._access_threshold > MAX_INT32: + raise ValueError(f"Access_threshold is too big that exceed int32.") + elif self._eviction_threshold is not None: raise ValueError(f"Access_threshold should be configured before eviction_threshold.") if self._eviction_threshold is not None: check_natural_number(self._eviction_threshold, "eviction_threshold") + if self._eviction_threshold > MAX_INT32: + raise ValueError(f"Eviction_threshold is too big that exceed int32.") if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 156cdd78..f38b52a5 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -18,7 +18,7 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, 
ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, ASCEND_TABLE_NAME_MUST_CONTAIN + DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32 from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ @@ -336,8 +336,7 @@ class SparseEmbedding: raise ValueError("Send count must be a integer which is larger than 0.") check_params() - max_int32 = np.iinfo(np.int32).max - if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > max_int32: + if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > MAX_INT32: raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table " f"'{self.table_name}', in which slice_device_vocabulary_size was " f"{self.slice_device_vocabulary_size} and slice_host_vocabulary_size was " diff --git a/mx_rec/util/constants.py b/mx_rec/util/constants.py index 3a9be2ec..423e86fe 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/util/constants.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from enum import Enum +import numpy as np ASCEND_GLOBAL_HASHTABLE_COLLECTION = "ASCEND_GLOBAL_HASHTABLE_COLLECTION" ASCEND_CUTTING_POINT_INITIALIZER = "ASCEND_CUTTING_POINT_INITIALIZER" @@ -33,6 +34,8 @@ DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 +MAX_INT32 = np.iinfo(np.int32).max + class BaseEnum(Enum): @classmethod -- Gitee From 55af3b36505f3a3720a5c3b404fc14a75ecbe574 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 14:19:20 +0800 Subject: [PATCH 099/551] Match-id-91f052e3b044db099d021176abdfb6e3ff4f627e --- mx_rec/__init__.py | 1 - mx_rec/util/sparse.py | 198 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 mx_rec/util/sparse.py diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index bbee9e8f..7a2ed887 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -3,7 +3,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION -from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset diff --git a/mx_rec/util/sparse.py b/mx_rec/util/sparse.py new file mode 100644 index 00000000..2874d4e0 --- /dev/null +++ b/mx_rec/util/sparse.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
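+# This module exports trained sparse tables from a saved model directory:
+# it walks the device-side (HashTable/HBM) and host-side (HashTable/DDR)
+# dumps, joins each key to its embedding row through the saved hashmaps,
+# and writes one "key-emb" .npy file per table (see export_sparse_data).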
+ +import logging +import os +import json + +import tensorflow as tf +import numpy as np + +from mx_rec.util.initialize import get_ascend_global_hashtable_collection, get_table_instance, \ + get_table_instance_by_name + + +class SparseProcessor: + single_instance = None + + def __init__(self, model_dir, **kwargs): + self.sep = "/" + self.export_name = "key-emb" + self.device_dir_list = ["HashTable", "HBM"] + self.host_dir_list = ["HashTable", "DDR"] + self.device_emb_dir = "embedding" + self.host_emb_dir = "embedding_data" + self.device_hashmap_dir = "key_offset_map" + self.host_hashmap_dir = "embedding_hashmap" + self.data_suffix = ".data" + self.attrib_suffix = ".attribute" + self.json_attrib_dtype = "data_type" + self.json_attrib_shape = "shape" + self.model_dir = model_dir + if not os.path.exists(model_dir): + raise FileExistsError(f"the model_dir supported {model_dir} does not exist.") + self.table_list = kwargs.get("table_list") + self.default_table_list = get_table_list() + + if not self.table_list: + logging.debug("table list not be set, use default value : all table created ") + self.table_list = self.default_table_list + else: + self.table_list = check_table_param(self.table_list, self.default_table_list) + + @staticmethod + def set_instance(model_dir, **kwargs): + SparseProcessor.single_instance = SparseProcessor(model_dir, **kwargs) + + @staticmethod + def _get_data(data_dir, dtype, data_shape): + try: + data = np.fromfile(data_dir, dtype=dtype) + except FileNotFoundError as err: + raise FileNotFoundError(f"data dir not found.") from err + data = data.reshape(data_shape) + return data + + @staticmethod + def _get_shape_form_attrib(attribute_dir, is_json): + if is_json: + try: + with open(attribute_dir, "r") as fin: + attributes = json.load(fin) + except FileNotFoundError as err: + raise FileNotFoundError(f"attribute dir not found.") from err + else: + try: + attributes = np.fromfile(attribute_dir, np.uint64) + except FileNotFoundError as err: + raise FileNotFoundError(f"attribute dir not found.") from err + + return attributes + + def export_sparse_data(self): + logging.info("table list to be exported is %s", self.table_list) + sparse_dirs = self._get_sparse_dirs() + for sparse_dir in sparse_dirs: + ddr = False + sparse_dir = os.path.join(self.model_dir, sparse_dir) + dev_dir = set_upper_dir(sparse_dir, self.device_dir_list) + host_dir = set_upper_dir(sparse_dir, self.host_dir_list) + for table in self.table_list: + table_instance = get_table_instance_by_name(table) + device_table_dir = os.path.join(dev_dir, table) + host_table_dir = os.path.join(host_dir, table) + if table_instance.host_vocabulary_size != 0: + ddr = True + out_dir = host_table_dir + else: + out_dir = device_table_dir + key, offset = self._get_hashmap(out_dir, ddr) + emb_data = self.get_embedding(device_table_dir, host_table_dir, ddr) + emb_data = emb_data[offset] + transformed_data = dict(zip(key[:], emb_data[:])) + np.save(out_dir + self.sep + self.export_name + ".npy", transformed_data) + + + def get_embedding(self, device_table_dir, host_table_dir, ddr): + emb_dir = os.path.join(device_table_dir, self.device_emb_dir) + data_file, attribute_file = self._get_file_names(emb_dir) + if not os.path.exists(data_file): + raise FileExistsError(f"embedding data file {data_file} does not exist when reading.") + if not os.path.exists(attribute_file): + raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.") + + temp = self._get_shape_form_attrib(attribute_file, is_json=True) + data_shape 
= temp.pop(self.json_attrib_shape) + data_dtype = temp.pop(self.json_attrib_dtype) + emb_data = self._get_data(data_file, data_dtype, data_shape) + + if ddr: + emb_dir = os.path.join(host_table_dir, self.host_emb_dir) + data_file, attribute_file = self._get_file_names(emb_dir) + host_attribute = self._get_shape_form_attrib(attribute_file, is_json=False) + host_data_shape = [host_attribute[0], host_attribute[1]] + host_data = self._get_data(data_file, data_dtype, host_data_shape) + host_data = host_data[:, :data_shape[1]] + emb_data = np.append(emb_data, host_data, axis=0) + return emb_data + + def _get_hashmap(self, table_dir, ddr): + if not ddr: + hashmap_dir = os.path.join(table_dir, self.device_hashmap_dir) + else: + hashmap_dir = os.path.join(table_dir, self.host_hashmap_dir) + data_file, attribute_file = self._get_file_names(hashmap_dir) + if not os.path.exists(data_file): + raise FileExistsError(f"hashmap data file {data_file} does not exist when reading.") + if not os.path.exists(attribute_file): + raise FileExistsError(f"hashmap attribute file {attribute_file} does not exist when reading.") + + shape_data = self._get_shape_form_attrib(attribute_file, is_json=False) + if len(shape_data) < 2: + raise ValueError(f"the attribute data from file {attribute_file} is invalid") + data_shape = shape_data[:2] + raw_hashmap = self._get_data(data_file, np.uint64, data_shape) + offset = raw_hashmap[:, -1] + key = raw_hashmap[:, 0] + return key, offset + + def _get_sparse_dirs(self): + sub_dirs = [] + for _, sub_dir, _ in os.walk(self.model_dir): + sub_dirs.append(sub_dir) + if not sub_dirs: + raise FileExistsError("There is no sparse folder in the model ") + return sub_dirs[0] + + def _get_file_names(self, directory): + files = [] + data_file = None + attribute_file = None + for _, _, file in os.walk(directory): + files.append(file) + if not files: + raise FileExistsError(f"There is no files under the {directory} ") + for file in files[0]: + if file.find(self.data_suffix) != -1: + data_file = file + elif file.find(self.attrib_suffix) != -1: + attribute_file = file + data_file = os.path.join(directory, data_file) + attribute_file = os.path.join(directory, attribute_file) + return data_file, attribute_file + + +def export(model_dir, **kwargs): + empty_value = 0 + SparseProcessor.set_instance(model_dir, **kwargs) + if SparseProcessor.single_instance.table_list: + return SparseProcessor.single_instance.export_sparse_data() + else: + logging.warning("no table can be exported ,please check if you have saved or created tables") + return empty_value + + +def get_table_list(): + var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + table_list = [] + for var in var_list: + table_instance = get_table_instance(var) + table_name = table_instance.table_name + table_list.append(table_name) + return table_list + + +def check_table_param(table_list, default_table_list): + out_list = [] + for table in table_list: + if table not in default_table_list: + logging.warning(f"{table} not be created , please check your table name.") + out_list.append(table) + return out_list + + +def set_upper_dir(model_dir, dir_list): + for directory in dir_list: + model_dir = os.path.join(model_dir, directory) + return model_dir -- Gitee From 7634e598b5b1439fcdb98a8a7a9a63496f3cc4e1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 31 May 2023 16:05:32 +0800 Subject: [PATCH 100/551] Match-id-85d618a448b2771a3913448718c5717bb13f2bf2 --- mx_rec/__init__.py | 3 ++- mx_rec/core/embedding.py | 9 ++++---- 
mx_rec/graph/modifier.py | 5 +++-- mx_rec/graph/patch.py | 43 +++++++++++++++++++++++++++++++++++++++ mx_rec/util/initialize.py | 40 ++++++++++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 8 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 7a2ed887..33010774 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -4,11 +4,12 @@ from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.saver.patch import patch_for_saver -from mx_rec.graph.patch import patch_for_dataset +from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator patch_for_saver() patch_for_dataset() +patch_for_chief_session_creator() __version__ = "5.0.RC2" diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index da1e2429..e2cf6eeb 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -22,7 +22,8 @@ from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_L from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ - ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion + ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion, \ + set_modify_graph from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.variable import remove_saving_var @@ -100,8 +101,7 @@ def sparse_lookup(hashtable, ids, send_count, **kwargs): def check_modify_graph(): if not kwargs.get("modify_graph"): - logging.warning(f"MxRecMode {MxRecMode.ASC} must config with a 'True' " - f"modify_graph.") + raise ValueError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") check_lookup_kwargs() scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name")) @@ -112,6 +112,7 @@ def sparse_lookup(hashtable, ids, send_count, **kwargs): return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs) else: check_modify_graph() + set_modify_graph(True) return hashtable.lookup_for_asc(ids, send_count, **kwargs) else: raise EnvironmentError(f"Invalid MxRec Mode.") @@ -392,8 +393,6 @@ class SparseEmbedding: """ logging.debug(f"Enter ASC Branch.") - if not kwargs.get("modify_graph"): - raise ValueError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") self.check_mode(MxRecMode.ASC) is_training = kwargs.get("is_train") diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 784d4b41..0c3f4c41 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -17,7 +17,7 @@ from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARS ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ - terminate_config_initializer + terminate_config_initializer, set_is_graph_modify_hook_running from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list @@ -427,9 +427,10 @@ def replace_anchor_vec(cutting_point, attribute, 
anchor): class GraphModifierHook(tf.estimator.SessionRunHook): - def __init__(self, dump_graph=True, modify_graph=False): + def __init__(self, dump_graph=True, modify_graph=True): self.dump_graph = dump_graph self.modify_graph = modify_graph + set_is_graph_modify_hook_running(True) def begin(self): if self.modify_graph: diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 8f2add0f..ee3f988f 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -3,12 +3,15 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import weakref +import logging import tensorflow as tf from tensorflow.python.data.ops.dataset_ops import DatasetV2 from tensorflow.python.data.ops.dataset_ops import _VariantTracker from tensorflow.python.framework import ops +from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph + def init_dataset(self, input_data): """ @@ -26,3 +29,43 @@ def init_dataset(self, input_data): def patch_for_dataset(): DatasetV2.__init__ = init_dataset + + +def chief_session_creator_init(self, scaffold=None, master='', config=None, checkpoint_dir=None, + checkpoint_filename_with_path=None): + """ + Initializes a chief session creator and check if 'GraphModifierHook' is configured. + + Args: + self: An instance object of the class ChiefSessionCreator. + scaffold: A `Scaffold` used for gathering or building supportive ops. If + not specified a default one is created. It's used to finalize the graph. + master: `String` representation of the TensorFlow master to use. + config: `ConfigProto` proto used to configure the session. + checkpoint_dir: A string. Optional path to a directory where to restore variables. + checkpoint_filename_with_path: Full file name path to the checkpoint file. + Returns:None + """ + logging.debug("Enter the mxrec init function of Class 'monitored_session.ChiefSessionCreator'.") + if get_modify_graph() and not get_is_graph_modify_hook_running(): + raise RuntimeError( + f"When 'modify_graph' is True, 'GraphModifierHook' must be configured. Example: \n" + f"\t from mx_rec.graph.modifier import GraphModifierHook \n" + f"\t estimator.train(..., hooks=[GraphModifierHook()])") + + self._checkpoint_dir = checkpoint_dir + self._checkpoint_filename_with_path = checkpoint_filename_with_path + self._scaffold = scaffold or tf.compat.v1.train.Scaffold() + self._session_manager = None + self._master = master + self._config = config + + +def patch_for_chief_session_creator(): + """ + The 'train, predict, train_and_evaluate' mode in the estimator mode ultimately creates the 'ChiefSessionCreator' + class, so it can be determined whether 'GraphModifierHook' is configured in the init function of this class. 
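+    If 'modify_graph' is enabled but no GraphModifierHook was constructed, the
+    patched __init__ raises a RuntimeError before any session is created.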
+ Returns:None + """ + tf.compat.v1.train.ChiefSessionCreator.__init__ = chief_session_creator_init + logging.debug("__init__ in Class 'monitored_session.ChiefSessionCreator' has been patched.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 906cee20..979f21b6 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -41,6 +41,8 @@ class ConfigInitializer: self._rank_to_device_dict = dict() self._initializer_dict = {} self._optimizer_instance = None + self._is_graph_modify_hook_running = False + self._modify_graph = False if self._use_mpi: logging.debug(f"Using mpi to launch task.") @@ -73,6 +75,14 @@ class ConfigInitializer: def __del__(self): self.terminate() + @property + def is_graph_modify_hook_running(self): + return self._is_graph_modify_hook_running + + @property + def modify_graph(self): + return self._modify_graph + @property def feature_spec_dict(self): return self._feature_spec_dict @@ -324,6 +334,20 @@ class ConfigInitializer: self._if_load = flag + @is_graph_modify_hook_running.setter + def is_graph_modify_hook_running(self, is_hook_running): + if not isinstance(is_hook_running, bool): + raise TypeError(f"is_hook_running should be a boolean.") + + self._is_graph_modify_hook_running = is_hook_running + + @modify_graph.setter + def modify_graph(self, is_modify_graph): + if not isinstance(is_modify_graph, bool): + raise TypeError(f"is_modify_graph should be a boolean.") + + self._modify_graph = is_modify_graph + @ascend_global_hashtable_collection.setter def ascend_global_hashtable_collection(self, name): if not isinstance(name, str): @@ -370,6 +394,22 @@ def init(use_mpi, **kwargs): set_ascend_env() +def get_is_graph_modify_hook_running(): + return ConfigInitializer.get_instance().is_graph_modify_hook_running + + +def set_is_graph_modify_hook_running(is_running): + ConfigInitializer.get_instance().is_graph_modify_hook_running = is_running + + +def get_modify_graph(): + return ConfigInitializer.get_instance().modify_graph + + +def set_modify_graph(is_modify_graph): + ConfigInitializer.get_instance().modify_graph = is_modify_graph + + def is_mpi_in_use(): return ConfigInitializer.get_instance().use_mpi -- Gitee From aaf30c437623e24afcd30a63fb37c582750a3bd8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 1 Jun 2023 11:47:52 +0800 Subject: [PATCH 101/551] Match-id-bc585773e6f2fa8f180c9a1aaa99974abf152643 --- example/little_demo/main.py | 4 +- mx_rec/__init__.py | 3 +- mx_rec/constants/__init__.py | 3 + mx_rec/{util => constants}/constants.py | 9 + mx_rec/core/asc/build_graph.py | 2 +- mx_rec/core/asc/feature_spec.py | 2 +- mx_rec/core/asc/helper.py | 2 +- mx_rec/core/asc/manager.py | 4 +- mx_rec/core/embedding.py | 10 +- mx_rec/graph/modifier.py | 2 +- mx_rec/saver/saver.py | 2 +- mx_rec/util/initialize.py | 89 ++++--- mx_rec/util/ops.py | 21 +- mx_rec/validator/__init__.py | 3 + mx_rec/validator/validator.py | 304 ++++++++++++++++++++++++ 15 files changed, 404 insertions(+), 56 deletions(-) create mode 100644 mx_rec/constants/__init__.py rename mx_rec/{util => constants}/constants.py (92%) create mode 100644 mx_rec/validator/__init__.py create mode 100644 mx_rec/validator/validator.py diff --git a/example/little_demo/main.py b/example/little_demo/main.py index f2d0175f..116ef432 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -15,7 +15,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, 
sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache -from mx_rec.util.constants import MxRecMode, ASCEND_TIMESTAMP +from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_id, get_rank_size, init, clear_channel, terminate_config_initializer, \ set_if_load, get_initializer from mx_rec.util.variable import get_dense_and_sparse_variable @@ -183,7 +183,7 @@ if __name__ == "__main__": train_ops.append(dense_optimizer.apply_gradients(avg_grads)) if use_dynamic_expansion: - from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET + from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) # do sparse optimization by addr diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index ee0b04b4..bf55ffcf 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -from mx_rec.util.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator from mx_rec.optimizers.base import patch_for_optimizer diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py new file mode 100644 index 00000000..6924f767 --- /dev/null +++ b/mx_rec/constants/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
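[Note on the rename below] PATCH 101 restructures the package: mx_rec/util/constants.py moves to mx_rec/constants/constants.py, and a new mx_rec/validator package takes over parameter checking. Callers must update their imports, as the main.py hunk above already shows; a one-line sketch for user code (illustrative):

    # before: from mx_rec.util.constants import MxRecMode, ASCEND_TIMESTAMP
    # after:  from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP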
diff --git a/mx_rec/util/constants.py b/mx_rec/constants/constants.py similarity index 92% rename from mx_rec/util/constants.py rename to mx_rec/constants/constants.py index 6925dc06..c39349e9 100644 --- a/mx_rec/util/constants.py +++ b/mx_rec/constants/constants.py @@ -36,6 +36,15 @@ TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 HASHTABLE_COLLECTION_NAME_LENGTH = 30 +# RANK INFO +VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] +MIN_SIZE = 1 +MAX_SIZE = 1024 * 1024 * 1024 * 1024 +MAX_DEVICE_NUM = 16 +MAX_RANK_SIZE = 4095 +MIN_DEVICE_NUM = 1 +MIN_RANK_SIZE = 1 + MAX_INT32 = np.iinfo(np.int32).max diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index c4ea20e7..70bffe2f 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -7,7 +7,7 @@ import logging import tensorflow as tf import mxrec_pybind -from mx_rec.util.constants import AVOID_TENSOR_POS +from mx_rec.constants.constants import AVOID_TENSOR_POS from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 4e229014..fb0bbbea 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -9,7 +9,7 @@ import tensorflow as tf from mx_rec.util.atomic import AtomicInteger from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static -from mx_rec.util.constants import MAX_INT32 +from mx_rec.constants.constants import MAX_INT32 feature_spec_global_id = AtomicInteger() diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 3b2a16f5..b57a2adb 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -222,7 +222,7 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): # Only the tables that need to be used after table combination are retained in meituan situation. # Current solution has error in same situations. For example, a sparse table has not been auto-merged. 
- from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN new_insert_tensors, new_splits, new_table_names = [], [], [] logging.debug(f"In do_insert function, ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") for idx, table_name in enumerate(table_names): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index a8dcf6cb..944a0f9a 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -6,7 +6,7 @@ import logging import tensorflow as tf -from mx_rec.util.constants import MxRecMode +from mx_rec.constants.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ @@ -15,7 +15,7 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se def generate_table_info_list(): from mxrec_pybind import EmbInfo - from mx_rec.util.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN # table_name is corresponding to channel_name which is in used in operator gen_npu_ops.get_next table_info_list = [] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index d8d8dc8e..331c1248 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -15,7 +15,7 @@ from tensorflow.python.ops import array_ops from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32 @@ -441,8 +441,8 @@ class SparseEmbedding: is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name if is_training and use_dynamic_expansion and is_table_name_valid: - tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) - tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) logging.debug(f"modify graph mode, table_name: {self.table_name}, " f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") @@ -681,8 +681,8 @@ class SparseEmbedding: is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name if is_training and is_table_name_valid: - tf.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) - tf.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) logging.debug(f"feature spec mode, table_name: {self.table_name}, " f"ASCEND_TABLE_NAME_MUST_CONTAIN: 
{ASCEND_TABLE_NAME_MUST_CONTAIN}") diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 0c3f4c41..b53e00d5 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -13,7 +13,7 @@ from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding -from mx_rec.util.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ +from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 93a00810..02b70bd1 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -13,7 +13,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.util.constants import DataName, DataAttr +from mx_rec.constants.constants import DataName, DataAttr from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, \ get_ascend_global_hashtable_collection diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 979f21b6..2a1ff997 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -7,13 +7,13 @@ import logging import os from collections import defaultdict +import mxrec_pybind import psutil -import mxrec_pybind -import mx_rec.util.constants -from mx_rec.util.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, \ - ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST +from mx_rec.constants.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE from mx_rec.util.ops import import_host_pipeline_ops +from mx_rec.validator.validator import RankInfoValidator class ConfigInitializer: @@ -58,7 +58,7 @@ class ConfigInitializer: if os.getenv("RANK_TABLE_FILE"): self.parse_hccl_json() else: - self.set_device_dict() + self.set_hccl_info_without_json() self.check_parameters() self.train_interval = kwargs.get("train_interval", -1) self.eval_steps = kwargs.get("eval_steps", -1) @@ -208,31 +208,56 @@ class ConfigInitializer: raise ValueError(f"get logic id from physic id fail.") self._rank_to_device_dict[rank_id] = device_id - def set_device_dict(self): + def set_hccl_info_without_json(self): + """ + Used for no rank table file configured training situation. + Now, only less than or equal 8p training job is supported. 
+ :return: None + """ + RankInfoValidator().check_visible_devices() ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - if not ascend_visible_devices: - raise ValueError("env variable ascend_visible_devices is null.") - if "-" in ascend_visible_devices: - rank_start = int(ascend_visible_devices.strip().split("-")[0]) - device_list = list(range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]))) - elif "," in ascend_visible_devices: - device_list = list(map(int, ascend_visible_devices.strip().split(","))) - elif ascend_visible_devices in ["0", "1", "2", "3", "4", "5", "6", "7"]: - device_list = [int(ascend_visible_devices.strip())] - else: - raise ValueError("invalid env variable ascend_visible_devices.") - rank_size = int(os.getenv("CM_WORKER_SIZE")) - self._rank_to_device_dict[0] = int(os.getenv("CM_CHIEF_DEVICE")) - device_list.pop(int(os.getenv("CM_CHIEF_DEVICE"))) - if rank_size: - local_rank_size = rank_size if rank_size < 8 else 8 - for device_index in range(local_rank_size - 1): - device_id = mxrec_pybind.get_logic_id(int(device_list[device_index])) - if device_id > 16: - raise ValueError(f"get logic id from physic id fail.") - self._rank_to_device_dict[device_index + 1] = device_id - else: - raise ValueError("get CM_WORKER_SIZE failed.") + device_list = [] + try: + if "-" in ascend_visible_devices: + split_devices = ascend_visible_devices.strip().split("-") + if len(split_devices) >= 1: + rank_start = int(split_devices[0]) + device_list = list(range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)) + elif "," in ascend_visible_devices: + device_list = sorted(map(int, ascend_visible_devices.strip().split(","))) + elif ascend_visible_devices in VALID_DEVICE_ID_LIST: + device_list = [int(ascend_visible_devices.strip())] + else: + raise ValueError("invalid env variable ascend_visible_devices.") + except ValueError as error: + raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. "
" + "Please refer to the document https://www.hiascend.com/document/detail/zh/" + "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " + "the correct configuration method.") from error + except IndexError as error: + raise IndexError( + f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ + from error + + chief_device = os.getenv("CM_CHIEF_DEVICE") + rank_size = os.getenv("CM_WORKER_SIZE") + if int(rank_size) != len(device_list): + raise ValueError(f"Rank size {rank_size} is different from device num {len(device_list)}.") + try: + self._rank_to_device_dict[0] = int(chief_device) + device_list.pop(int(chief_device)) + except IndexError as err: + raise IndexError( + f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {device_list}.") from err + except ValueError as err: + raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err + + for device_idx in device_list: + device_id = mxrec_pybind.get_logic_id(int(device_idx)) + if device_id > 16: + raise ValueError(f"get logic id from physic id fail.") + index = device_list.index(device_idx) + self._rank_to_device_dict[index + 1] = device_id def insert_training_mode_channel_id(self, is_training): if is_training not in self._training_mode_channel_dict: @@ -352,9 +377,6 @@ class ConfigInitializer: def ascend_global_hashtable_collection(self, name): if not isinstance(name, str): raise TypeError(f"collection name '{name}' must be a string.") - if len(name) > HASHTABLE_COLLECTION_NAME_LENGTH: - raise ValueError(f"The length of the collection name '{name}' should be between " - f"[0, {HASHTABLE_COLLECTION_NAME_LENGTH}].") self._ascend_global_hashtable_collection = name def get_initializer(self, is_training): @@ -600,7 +622,7 @@ def set_initializer(is_training, initializer): def set_ascend_table_name_must_contain(name="merged"): - mx_rec.util.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name + mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name def set_ascend_env(): @@ -609,7 +631,6 @@ def set_ascend_env(): """ rank = get_rank_id() rank_size = get_rank_size() - local_rank_size = 8 os.environ["MOX_USE_NPU"] = "1" os.environ["FUSION_TENSOR_SIZE"] = "2000000000" diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index 40c99f37..90bea48e 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -2,20 +2,27 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-import os import logging +import os + import tensorflow as tf -from mx_rec.util.constants import HOST_PIPELINE_OPS_LIB_PATH +from mx_rec.constants.constants import HOST_PIPELINE_OPS_LIB_PATH def import_host_pipeline_ops(): host_pipeline_ops_lib_path = os.getenv(HOST_PIPELINE_OPS_LIB_PATH) - if host_pipeline_ops_lib_path: + if host_pipeline_ops_lib_path and os.path.exists(host_pipeline_ops_lib_path): logging.debug(f"Using the HOST_PIPELINE_OPS_LIB_PATH '{host_pipeline_ops_lib_path}' to get ops lib.") return tf.load_op_library(host_pipeline_ops_lib_path) + elif os.path.exists( + os.path.join(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), + 'mx_rec/libasc/libasc_ops.so')): + default_so_path = os.path.join( + os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), + 'mx_rec/libasc/libasc_ops.so') + logging.debug(f"Using the DEFAULT PATH '{default_so_path}' to get ops lib.") + return tf.load_op_library(default_so_path) else: - mx_rec_dir = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")) - so_path = os.path.join(mx_rec_dir, 'mx_rec/libasc/libasc_ops.so') - logging.debug(f"Using the DEFAULT PATH '{so_path}' to get ops lib.") - return tf.load_op_library(so_path) + raise ValueError("Invalid host pipeline ops lib path. Please check if libasc_ops.so exists or corrected " + "configured") diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py new file mode 100644 index 00000000..8f75c6b6 --- /dev/null +++ b/mx_rec/validator/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py new file mode 100644 index 00000000..c18d989c --- /dev/null +++ b/mx_rec/validator/validator.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + + +import os +from typing import Callable, Any +from typing import List, Optional, Tuple + +from mx_rec.constants.constants import MIN_SIZE +from mx_rec.constants.constants import MAX_SIZE +from mx_rec.constants.constants import MAX_DEVICE_NUM +from mx_rec.constants.constants import MAX_RANK_SIZE +from mx_rec.constants.constants import MIN_DEVICE_NUM +from mx_rec.constants.constants import MIN_RANK_SIZE + + +class Validator: + """ + A validator to check the input parameters + """ + + def __init__(self, value, msg="value is invalid"): + """ + :param value: the value for validation + :param msg: default error msg + """ + self.value = value + self.msg = msg + self.checkers = [] + self.is_valid_state = None + + def register_checker(self, checker: Callable[[Any], bool], msg: str = None): + self.checkers.append((checker, msg if msg else self.msg)) + + def check(self): + if self.is_valid_state is None: + self.is_valid_state = True + for checker, msg in self.checkers: + self.is_valid_state &= checker(self.value) + if not self.is_valid_state: + self.msg = msg + break + if self.is_valid_state: + self.msg = None + return self + + def is_valid(self): + if self.is_valid_state is None: + self.check() + return self.is_valid_state + + def get_value(self, default=None): + return self.value if self.is_valid() else default + + +class ClassValidator(Validator): + """ + Check class validator. 
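+    Registers an isinstance() check of the wrapped value against the given classes.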
+ """ + + def __init__(self, value, classes): + super().__init__(value) + self.classes = classes + + def check_isinstance(self): + """Check arg isinstance of classes""" + self.register_checker(lambda path: isinstance(self.value, self.classes), "Invalid parameter type") + return self + + +class StringValidator(Validator): + """ + String type validator. + """ + + def __init__(self, value, max_len=None, min_len=0): + super().__init__(value) + self.max_len = max_len + self.min_len = min_len + self.register_checker(lambda x: isinstance(x, str), "type is not str") + + def check_string_length(self): + if self.min_len is not None: + self.register_checker(lambda x: len(x) >= self.min_len, f"length is less than {self.min_len}") + if self.max_len is not None: + self.register_checker(lambda x: len(x) <= self.max_len, f"length is bigger than {self.max_len}") + return self + + def check_not_contain_black_element(self, element): + self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1) + return self + + def can_be_transformed2int(self, min_value: int = None, max_value: int = None): + if min_value is None: + min_value = MIN_RANK_SIZE + if max_value is None: + max_value = MAX_RANK_SIZE + + can_transformed = self.value.isdigit() + try: + if can_transformed and (min_value > int(self.value) or max_value < int(self.value)): + can_transformed = False + except ValueError: + can_transformed = False + finally: + if self.is_valid_state is not None: + self.is_valid_state &= can_transformed + else: + self.is_valid_state = can_transformed + return self + + +class IntValidator(Validator): + """ + Int type validator + """ + + def __init__(self, value: int, min_value: int = None, max_value: int = None): + super().__init__(value) + self.min_value = min_value + self.max_value = max_value + self.register_checker(lambda x: isinstance(x, int), "type is not int") + + def check_value(self): + if self.min_value is not None: + self.register_checker(lambda x: x >= self.min_value, f"value is less than {self.min_value}") + if self.max_value is not None: + self.register_checker(lambda x: x <= self.max_value, f"value is bigger than {self.max_value}") + return self + + +class RankSizeValidator(IntValidator): + """ + Distributed training job size validator + """ + + def check_rank_size_valid(self): + super().__init__(self.value) + self.register_checker(lambda x: MIN_RANK_SIZE <= self.value <= MAX_RANK_SIZE, + "Invalid rank size") + return self + + def check_device_num_valid(self): + super().__init__(self.value) + self.register_checker(lambda x: MIN_DEVICE_NUM <= self.value <= MAX_DEVICE_NUM, + "Invalid device num") + return self + + +class DirectoryValidator(StringValidator): + def __init__(self, value, max_len=None, min_len=1): + """ + @param value: the path, should not be emtpy string, should not contain double dot(../) + """ + super().__init__(value, max_len, min_len) + self.register_checker(lambda x: isinstance(x, str), "type is not str") + + @staticmethod + def remove_prefix(string: Optional[str], prefix: Optional[str]) -> Tuple[bool, Optional[str]]: + if string is None or prefix is None or len(string) < len(prefix): + return False, string + if string.startswith(prefix): + return True, string[len(prefix):] + else: + return False, string + + @staticmethod + def check_is_children_path(path_: str, target_: str): + if not target_: + return False + + try: + realpath_ = os.path.realpath(path_) + except (TypeError, ValueError, OSError): + return False + + try: + realpath_target = 
os.path.realpath(target_) + except (TypeError, ValueError, OSError): + return False + + is_prefix, rest_part = DirectoryValidator.remove_prefix(realpath_target, realpath_) + + if rest_part.startswith(os.path.sep): + rest_part = rest_part.lstrip(os.path.sep) + if is_prefix: + joint_path = os.path.join(realpath_, rest_part) + return os.path.realpath(joint_path) == realpath_target + else: + return False + + @staticmethod + def __check_with_sensitive_words(path: str, words: List): + _, name = os.path.split(path) + if name: + return not any(map(lambda x: x in path, words)) + else: + return True + + def check_is_not_none(self): + self.register_checker(lambda path: self.value is not None and len(self.value) > 0, + "Invalid directory parameter") + return self + + def check_not_soft_link(self): + self.register_checker(lambda path: os.path.realpath(self.value) == os.path.normpath(self.value), + "soft link or relative path should not be in the path parameter") + return self + + def path_should_exist(self, is_file=True, msg=None): + self.register_checker(lambda path: os.path.exists(self.value), + msg if msg else "path parameter does not exist") + if is_file: + self.register_checker(lambda path: os.path.isfile(self.value), + msg if msg else "path parameter is not a file") + return self + + def path_should_not_exist(self): + self.register_checker(lambda path: not os.path.exists(self.value), "path parameter does not exist") + return self + + def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None): + if lst is None: + lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64"] + if len(lst) == 0: + return self + if msg is None: + msg = "path should is in blacklist" + if exact_compare: + self.register_checker(lambda path: path not in [os.path.realpath(each) for each in lst], msg) + else: + self.register_checker( + lambda path: not any([DirectoryValidator.check_is_children_path(each, path) for each in lst]), msg + ) + return self + + def should_not_contains_sensitive_words(self, words: List = None, msg=None): + if words is None: + words = ["Key", "password", "privatekey"] + self.register_checker(lambda path: DirectoryValidator.__check_with_sensitive_words(path, words), msg) + return self + + +class FileValidator(StringValidator): + def __init__(self, value): + """ + @param value: the file path, should not be emtpy string, should not contain double dot(../) + """ + super().__init__(value) + self.register_checker(lambda x: isinstance(x, str), "type is not str") + + def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE): + self.register_checker(lambda path: min_size < os.path.getsize(self.value) <= max_size, + "file size is invalid") + return self + + def check_not_soft_link(self): + self.register_checker(lambda path: os.path.realpath(self.value) == self.value, + "soft link or relative path should not be in the path parameter") + return self + + def check_user_group(self): + process_uid = os.geteuid() + process_gid = os.getegid() + stat_info = os.stat(self.value) + file_uid = stat_info.st_uid + file_gid = stat_info.st_gid + self.register_checker( + lambda path: process_uid == file_uid or process_gid == file_gid, "Invalid log file user or group.") + return self + + +class RankInfoValidator: + """ + Check replace rank table system environment configuration. 
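+    Validates the ASCEND_VISIBLE_DEVICES, CM_WORKER_SIZE and CM_CHIEF_DEVICE
+    environment variables before set_hccl_info_without_json() consumes them.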
+ """ + @staticmethod + def check_visible_devices(): + visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + device_res = StringValidator(visible_devices).check() + if not device_res: + raise TypeError("env variable ascend_visible_devices is null, please config ASCEND_VISIBLE_DEVICES in " + "docker container start.") + + rank_size = os.getenv("CM_WORKER_SIZE") + rank_size_res = StringValidator(rank_size).check() + if not rank_size_res: + raise TypeError("env variable CM_WORKER_SIZE is null, please config CM_WORKER_SIZE. For example, " + "CM_WORKER_SIZE=1") + + try: + rank_size_value = int(rank_size) + res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid() + if not res and rank_size_value not in [1, 2, 4, 8, 16]: + raise ValueError("Invalid rank size, rank size must between 0 and 15 in recommendation training.") + except ValueError as err: + raise ValueError("Invalid rank size, rank size is a valid integer.") from err + + chief_device = os.getenv("CM_CHIEF_DEVICE") + chief_device_res = StringValidator(chief_device).check() + if not chief_device_res: + raise TypeError("env variable CM_CHIEF_DEVICE is null, please config CM_CHIEF_DEVICE. For example, " + "CM_CHIEF_DEVICE=0") -- Gitee From 7f7cad981cca8d69752251d40c4d8c14ae7cb7c3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 1 Jun 2023 17:36:23 +0800 Subject: [PATCH 102/551] Match-id-d35c53d1c80360b339d35017fe10d021f78dcfa9 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1e4b350d..d306c620 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -62,11 +62,6 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, spdlog::info(MGMT + "begin initialize, localRankSize:{}, localRankId {}, rank {}", rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); - bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, ifLoad, seed); - if (!rc) { - return false; - } - mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; skipUpdate = getenv("SKIP_UPDATE") != nullptr; @@ -74,6 +69,11 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hdTransfer = Singleton::GetInstance(); hdTransfer->Init(embInfos, rankInfo.deviceId); + bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, ifLoad, seed); + if (!rc) { + return false; + } + lookUpKeysQueue = make_unique>>(); restoreQueue = make_unique>>(); isRunning = true; -- Gitee From 94a6720015e4dc010dcdff16fe14f8a70d47833c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 1 Jun 2023 19:05:02 +0800 Subject: [PATCH 103/551] Match-id-aee5abbbc48f738195531e9c64b21f8d72992d62 --- mx_rec/core/asc/helper.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index e6ca9ef7..2a1bb30c 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -47,9 +47,6 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names def find_dangling_table(table_names): def check_tensor(table_name, table_reachable_tensor): - the_op = table_reachable_tensor.op - logging.info(f"** table_reachable_op:{the_op.outputs} {the_op.name} {the_op.type}**") - if table_reachable_tensor.op.type == 'ApplyAdam': return True -- Gitee From 4e2aea2ef69c44202bb5eab3596fa3a4dc67e545 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 2 Jun 2023 11:02:53 +0800 Subject: [PATCH 104/551] 
Match-id-b58bf97b34e93486a35f685f02a7c665ab32bfb4 --- mx_rec/core/asc/helper.py | 42 ++++++++++++++++++++++++++++++++++---- mx_rec/core/asc/manager.py | 4 ++++ 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 2a1bb30c..76d3e138 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -46,7 +46,20 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names def find_dangling_table(table_names): - def check_tensor(table_name, table_reachable_tensor): + """ Find the tables which are disconenct with the forward training graph. And + these table will not be backward updated. + + :param table_names: all created tables' names, which is a list + :return: A list of dangling table names. + """ + def check_tensor(table_reachable_tensor): + """Check whether the tensor op is optimizer op or backward gradient. + + Args: + table_reachable_tensor: tensor + Returns: + bool + """ if table_reachable_tensor.op.type == 'ApplyAdam': return True @@ -56,6 +69,15 @@ def find_dangling_table(table_names): return False def find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor): + """ find all the table lookup op. + :param table_name: a list of all created tables' names + :param the_op: the op to be + :param table_lookup_op: list of the table lookup ops + :param table_reachable_tensor: the tensors which table lookup op can reach ( + here we just add the table lookup op's output tensors). + The data structure is map, key is table_name, value is the output tensors of table lookup op. + :return: None + """ if table_name in the_op.name and the_op.type == "IdentityN": if table_name not in table_lookup_op: table_lookup_op[table_name] = [the_op] @@ -77,18 +99,30 @@ def find_dangling_table(table_names): dangling_table = [] def extend(op_list, tensor, spread_tensors): + """extend the tensors which table lookup op can reach + + :param op_list: all op in the graph + :param tensor: the tensor visited by bfs + :param spread_tensors: the list of tensors which table lookup op can reach + :return: + """ for the_op in op_list: if tensor in the_op.inputs: spread_tensors.extend(the_op.outputs) - def bfs_lookup(table_name, next_to_visit): + def bfs_lookup(next_to_visit): + """find all the tensors which table lookup op can reach + + :param next_to_visit: the tensor list to be visited by bfs + :return: bool value indicate whether reached optimizer op or backward gradient op + """ tensors_visited = set() while next_to_visit: spread_tensors = [] for tensor in next_to_visit: if tensor in tensors_visited: continue - if check_tensor(table_name, tensor): + if check_tensor(tensor): return True tensors_visited.add(tensor) extend(op_list, tensor, spread_tensors) @@ -96,7 +130,7 @@ def find_dangling_table(table_names): return False for table_name, table_op in table_reachable_tensor.items(): - found = bfs_lookup(table_name, table_op) + found = bfs_lookup(table_op) if not found: dangling_table.append(table_name) insert_dangling_table(table_name) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 53c6e06e..8c2ac8c9 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -15,6 +15,10 @@ from mx_rec.core.asc.helper import find_dangling_table def check_dangling_table(): + """ + If the dangling_table list is empty(maybe feature_spec mode), try to find again + :return: list of dangling_table + """ dangling_table = export_dangling_table() if not dangling_table: dangling_table = 
-- Gitee
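The docstrings added above describe a plain breadth-first reachability walk: start from each table's lookup outputs, spread through every op that consumes a visited tensor, and declare the table dangling when no optimizer or backward gradient op is ever reached. A condensed, self-contained sketch of the same idea (simplified: only the ApplyAdam check is shown, and the function name is illustrative, not the mx_rec API):

    def is_table_dangling(graph_ops, lookup_outputs):
        visited = set()
        next_to_visit = list(lookup_outputs)
        while next_to_visit:
            tensor = next_to_visit.pop(0)
            if tensor in visited:
                continue
            visited.add(tensor)
            if tensor.op.type == 'ApplyAdam':
                # the producing op is an optimizer update, so the table is trained
                return False
            for the_op in graph_ops:
                # spread: every op consuming this tensor extends the frontier
                if tensor in the_op.inputs:
                    next_to_visit.extend(the_op.outputs)
        return True

Tables for which this walk never reaches an update op are collected as dangling and skipped during backward updates.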
From f45ff7ba49d0e88973486f204da740f5d06e2aa3 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 2 Jun 2023 14:16:35 +0800
Subject: [PATCH 105/551] Match-id-65e9b779c6a4fc89b38b6522dba58ff34522838d

---
 mx_rec/core/asc/helper.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index 76d3e138..b6f2f161 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -4,8 +4,11 @@
 import logging
 
 from functools import reduce
-
 import tensorflow as tf
+from typing import List
+from typing import Dict
+from tensorflow import Tensor
+from tensorflow import Operation
 
 from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \
     export_table_instances, insert_dangling_table
@@ -45,14 +48,14 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names
                                     **kwargs)
 
 
-def find_dangling_table(table_names):
+def find_dangling_table(table_names: List[str]):
     """ Find the tables which are disconnected from the forward training graph.
     These tables will not receive backward updates.
 
     :param table_names: list of all created tables' names
     :return: a list of dangling table names.
     """
-    def check_tensor(table_reachable_tensor):
+    def check_tensor(table_reachable_tensor: Tensor):
         """Check whether the tensor's op is an optimizer op or a backward gradient op.
 
         Args:
             table_reachable_tensor: tensor
         Returns:
             bool
         """
         if table_reachable_tensor.op.type == 'ApplyAdam':
             return True
@@ -68,9 +71,12 @@
         return False
 
-    def find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor):
+    def find_table_op(table_name: str,
+                      the_op: Operation,
+                      table_lookup_op: Dict[str,List[Operation]],
+                      table_reachable_tensor: Dict[str,List[Tensor]]):
         """ find all the table lookup ops.
-        :param table_name: a list of all created tables' names
+        :param table_name: tables' names
         :param the_op: the op to be
         :param table_lookup_op: list of the table lookup ops
         :param table_reachable_tensor: the tensors which table lookup op can reach (
             here we just add the table lookup op's output tensors).
             The data structure is a map: key is table_name, value is the output tensors of the table lookup op.
         :return: None
         """
         if table_name in the_op.name and the_op.type == "IdentityN":
             if table_name not in table_lookup_op:
                 table_lookup_op[table_name] = [the_op]
@@ -98,7 +104,9 @@
     logging.info(f"*********** find tables: {table_lookup_op}***********")
     dangling_table = []
 
-    def extend(op_list, tensor, spread_tensors):
+    def extend(op_list:List[Operation],
+               tensor: Tensor,
+               spread_tensors: List[Tensor]):
         """extend the tensors which the table lookup ops can reach
 
         :param op_list: all ops in the graph
         :param tensor: the tensor visited by bfs
         :param spread_tensors: the list of tensors which the table lookup ops can reach
         :return: None
         """
         for the_op in op_list:
             if tensor in the_op.inputs:
                 spread_tensors.extend(the_op.outputs)
 
-    def bfs_lookup(next_to_visit):
+    def bfs_lookup(next_to_visit: List[Tensor]):
         """find all the tensors which the table lookup ops can reach
 
         :param next_to_visit: the tensor list to be visited by bfs
-- Gitee
From 1cbb11f43d3098ce5eef40b8134fbb67c9dcde06 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 2 Jun 2023 14:20:06 +0800
Subject: [PATCH 106/551] Match-id-3a892e367132cce2b69217485b5e9b593805af72

---
 mx_rec/core/asc/helper.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index b6f2f161..5d8cf292 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -4,9 +4,9 @@
 import logging
 
 from functools import reduce
-import tensorflow as tf
 from typing import List
 from typing import Dict
+import tensorflow as tf
 from tensorflow import Tensor
 from tensorflow import Operation
 
@@ -73,8 +73,8 @@ def find_dangling_table(table_names: List[str]):
     def find_table_op(table_name: str,
                       the_op: Operation,
-                      table_lookup_op: Dict[str,List[Operation]],
-                      table_reachable_tensor: Dict[str,List[Tensor]]):
+                      table_lookup_op: Dict[str, List[Operation]],
+                      table_reachable_tensor: Dict[str, List[Tensor]]):
         """ find all the table lookup ops.
:param table_name: tables' names :param the_op: the op to be -- Gitee From 41a937de28b92128ba1dd0ee427b7dea0639acf8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 2 Jun 2023 18:08:38 +0800 Subject: [PATCH 107/551] Match-id-02fcdfa3610aa741b8df87e56dbdad7a5216ef8a --- mx_rec/core/asc/helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 5d8cf292..0c4688a2 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -92,7 +92,7 @@ def find_dangling_table(table_names: List[str]): table_lookup_op[table_name].append(the_op) table_reachable_tensor[table_name].extend(the_op.outputs) - op_list = tf.get_default_graph().get_operations() + op_list = tf.compat.v1.get_default_graph().get_operations() table_lookup_op = {} table_reachable_tensor = {} -- Gitee From 8f3bf981f7afa66b8e6d40a430482557c39d9b36 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Jun 2023 15:25:18 +0800 Subject: [PATCH 108/551] Match-id-198905943f8b0a4ebca2eb407cfc8c58a1e67279 --- mx_rec/core/asc/manager.py | 5 ++++- mx_rec/core/embedding.py | 9 +++++++-- src/core/initializer/initializer.h | 1 + .../random_normal_initializer.cpp | 5 +++-- .../random_normal_initializer.h | 2 +- src/core/utils/common.cpp | 7 ++++--- src/core/utils/common.h | 3 ++- src/pybind/module_main.cpp | 6 ++++-- 8 files changed, 26 insertions(+), 12 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 8c2ac8c9..019d926e 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -110,11 +110,14 @@ def matched_emb_initializer(tabel_info): constant_val=tabel_info.emb_initializer.value)) elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed + init_param = tabel_info.init_param + logging.debug(f"tabel_info.initK is {init_param}.") initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, normal_initializer_info=NormalInitializerInfo( mean=tabel_info.emb_initializer.mean, stddev=tabel_info.emb_initializer.stddev, - seed=random_seed + seed=random_seed, + initK=init_param )) elif initializer_case_map.get("tf1_truncated_normal_initializer") or \ initializer_case_map.get("tf2_truncated_normal_initializer"): diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 331c1248..4445929e 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -41,6 +41,7 @@ def create_table(**kwargs): shard_num = kwargs.get("shard_num", 1) fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) hashtable_threshold = kwargs.get("hashtable_threshold", 0) + init_param = kwargs.get("init_param", 1.0) """ Args: @@ -57,12 +58,14 @@ def create_table(**kwargs): shard_num: embedding partition number fusion_optimizer_var: fusion optimizer variable with embedding hashtable_threshold: choose to implement based on hash table or linear layer + init_param: embedding init param-coefficient """ config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, - fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold) + fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold, + init_param=init_param) 
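Because init_param defaults to 1.0, existing create_table callers are unaffected; a caller that wants smaller initial values passes the coefficient explicitly. A minimal usage sketch of the new keyword (argument values are illustrative assumptions; create_table and mx_rec.core.embedding come from this patch, the other keywords follow the Args list above):

    import tensorflow as tf
    from mx_rec.core.embedding import create_table

    user_table = create_table(key_dtype=tf.int64, dim=16, name="user_emb",
                              emb_initializer=tf.compat.v1.random_normal_initializer(mean=0.0, stddev=1.0),
                              device_vocabulary_size=100000,
                              init_param=0.1)  # values drawn from N(0, 1) are scaled by 0.1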
embedding = SparseEmbedding(config) return embedding @@ -156,6 +159,7 @@ class SparseEmbedding: self.send_count_map = dict() self.channel_name_dict = {True: [], False: []} self.modify_graph = False + self.init_param = config.get("init_param") self.set_slice_vocab_size() self.set_emb_size() @@ -698,7 +702,8 @@ class SparseEmbedding: f" {self.slice_host_vocabulary_size}.") def _initialize_variables(self): - initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) + initialized_tensor = \ + self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) * self.init_param self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) # make sure sparse table variable will not be saved and restored within tf checkpoint. remove_saving_var(self.variable) diff --git a/src/core/initializer/initializer.h b/src/core/initializer/initializer.h index 0fce164c..feec8729 100644 --- a/src/core/initializer/initializer.h +++ b/src/core/initializer/initializer.h @@ -20,6 +20,7 @@ namespace MxRec { virtual void GenerateData(float* emb, int embSize)= 0; int start; int len; + float initParam = 1.0; }; } diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 7933555f..27af622c 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -11,9 +11,10 @@ using namespace MxRec; -RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed) +RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK) : start(start), len(len), mean(mean), stddev(stddev), seed(seed) { + initParam = initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); } @@ -29,5 +30,5 @@ void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) start, len, embSize); return; } - std::generate_n(emb + start, len, [&]() { return distribution(generator); }); + std::generate_n(emb + start, len, [&]() { return initParam*distribution(generator); }); } \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index d6c8b376..8f39315f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -19,7 +19,7 @@ namespace MxRec { class RandomNormalInitializer : public Initializer { public: RandomNormalInitializer() = default; - RandomNormalInitializer(int start, int len, float mean, float stddev, int seed); + RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK); ~RandomNormalInitializer() override {}; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 2485d5ba..13d22345 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -53,8 +53,8 @@ namespace MxRec { : constantValue(constantValue) {} - NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed) - : mean(mean), stddev(stddev), seed(seed) + NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed, float initK) + : mean(mean), stddev(stddev), 
seed(seed), initK(initK) {} InitializeInfo::InitializeInfo(std::string& name, int start, int len, @@ -79,7 +79,8 @@ namespace MxRec { } else if (name == "random_normal_initializer") { initializerType = InitializerType::RANDOM_NORMAL; randomNormalInitializer = RandomNormalInitializer(start, len, - normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); + normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed, + normalInitializerInfo.initK); } else { throw std::invalid_argument("Invalid Initializer Type."); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 41f26558..5bb6526d 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -298,11 +298,12 @@ struct BatchTask { struct NormalInitializerInfo { NormalInitializerInfo() = default; - NormalInitializerInfo(float mean, float stddev, int seed); + NormalInitializerInfo(float mean, float stddev, int seed, float initK); float mean; float stddev; int seed; + float initK = 1.0; }; struct InitializeInfo { diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 3839fe37..dd39bfed 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -145,10 +145,12 @@ void GetConstantInitializerInfo(pybind11::module_ &m) void GetNormalInitializerInfo(pybind11::module_ &m) { pybind11::class_(m, "NormalInitializerInfo") - .def(py::init(), py::arg("mean") = 0.0, py::arg("stddev") = 1.0, py::arg("seed") = 0) + .def(py::init(), py::arg("mean") = 0.0, py::arg("stddev") = 1.0, py::arg("seed") = 0, + py::arg("initK") = 1.0) .def_readwrite("mean", &NormalInitializerInfo::mean) .def_readwrite("stddev", &NormalInitializerInfo::stddev) - .def_readwrite("seed", &NormalInitializerInfo::seed); + .def_readwrite("seed", &NormalInitializerInfo::seed) + .def_readwrite("initK", &NormalInitializerInfo::initK); } void GetHybridMgmt(pybind11::module_& m) -- Gitee From 1e16f2da98c03047a42987e816daa6069bd567a4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Jun 2023 16:31:11 +0800 Subject: [PATCH 109/551] Match-id-4f83959ac0fa31b50aba418b7274b59e7f994f15 --- mx_rec/core/asc/manager.py | 2 +- .../random_normal_initializer.cpp | 9 +++++---- .../random_normal_initializer.h | 3 ++- src/core/utils/common.cpp | 4 +--- src/tests/initializer/initializer_test.cpp | 3 ++- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 019d926e..2dd2f238 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -111,7 +111,7 @@ def matched_emb_initializer(tabel_info): elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed init_param = tabel_info.init_param - logging.debug(f"tabel_info.initK is {init_param}.") + logging.debug(f"tabel: {tabel_info.table_name}, initK is {init_param}.") initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, normal_initializer_info=NormalInitializerInfo( mean=tabel_info.emb_initializer.mean, diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 27af622c..fbdfcb2e 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -11,10 +11,11 @@ using namespace MxRec; 
-RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK) - : start(start), len(len), mean(mean), stddev(stddev), seed(seed) +RandomNormalInitializer::RandomNormalInitializer(int start, int len, NormalInitializerInfo normalInitializerInfo) + : start(start), len(len), mean(normalInitializerInfo.mean), stddev(normalInitializerInfo.stddev), + seed(normalInitializerInfo.seed) { - initParam = initK; + initParam = normalInitializerInfo.initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); } @@ -30,5 +31,5 @@ void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) start, len, embSize); return; } - std::generate_n(emb + start, len, [&]() { return initParam*distribution(generator); }); + std::generate_n(emb + start, len, [&]() { return initParam * distribution(generator); }); } \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index 8f39315f..560f8b77 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -12,6 +12,7 @@ #include #include "initializer/initializer.h" +#include "utils/common.h" namespace MxRec { using namespace std; @@ -19,7 +20,7 @@ namespace MxRec { class RandomNormalInitializer : public Initializer { public: RandomNormalInitializer() = default; - RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK); + RandomNormalInitializer(int start, int len, NormalInitializerInfo normalInitializerInfo); ~RandomNormalInitializer() override {}; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 13d22345..fb55959d 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -78,9 +78,7 @@ namespace MxRec { normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); } else if (name == "random_normal_initializer") { initializerType = InitializerType::RANDOM_NORMAL; - randomNormalInitializer = RandomNormalInitializer(start, len, - normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed, - normalInitializerInfo.initK); + randomNormalInitializer = RandomNormalInitializer(start, len, normalInitializerInfo); } else { throw std::invalid_argument("Invalid Initializer Type."); } diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp index 8f48ec66..a177765f 100644 --- a/src/tests/initializer/initializer_test.cpp +++ b/src/tests/initializer/initializer_test.cpp @@ -75,7 +75,8 @@ TEST(InitializerTest, TruncatedNormalInitializerTest) TEST(InitializerTest, RandomNormalInitializerTest) { - RandomNormalInitializer randomNormalInitializer(1, 10, 2.0, 0.5, 1); + NormalInitializerInfo normalInitializerInfo{2.0, 0.5, 1, 0.1}; + RandomNormalInitializer randomNormalInitializer(1, 10, normalInitializerInfo); vector> embData; int vocabSize = 5; -- Gitee From 82d65ab1bcffb642c353ca48da1f31e15a626cc7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Jun 2023 17:01:42 +0800 Subject: [PATCH 110/551] Match-id-dd45b288b5bb2477522a939f0894ba09d8014fbc --- .../random_normal_initializer.cpp | 7 +++---- .../random_normal_initializer/random_normal_initializer.h | 3 +-- src/core/utils/common.cpp | 4 +++- src/tests/initializer/initializer_test.cpp | 3 +-- 4 files 
changed, 8 insertions(+), 9 deletions(-) diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index fbdfcb2e..d832d875 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -11,11 +11,10 @@ using namespace MxRec; -RandomNormalInitializer::RandomNormalInitializer(int start, int len, NormalInitializerInfo normalInitializerInfo) - : start(start), len(len), mean(normalInitializerInfo.mean), stddev(normalInitializerInfo.stddev), - seed(normalInitializerInfo.seed) +RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK) + : start(start), len(len), mean(mean), stddev(stddev), seed(seed) { - initParam = normalInitializerInfo.initK; + initParam = initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); } diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index 560f8b77..d29ebdef 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -12,7 +12,6 @@ #include #include "initializer/initializer.h" -#include "utils/common.h" namespace MxRec { using namespace std; @@ -20,7 +19,7 @@ namespace MxRec { class RandomNormalInitializer : public Initializer { public: RandomNormalInitializer() = default; - RandomNormalInitializer(int start, int len, NormalInitializerInfo normalInitializerInfo); + RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, int initK); ~RandomNormalInitializer() override {}; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index fb55959d..13d22345 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -78,7 +78,9 @@ namespace MxRec { normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); } else if (name == "random_normal_initializer") { initializerType = InitializerType::RANDOM_NORMAL; - randomNormalInitializer = RandomNormalInitializer(start, len, normalInitializerInfo); + randomNormalInitializer = RandomNormalInitializer(start, len, + normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed, + normalInitializerInfo.initK); } else { throw std::invalid_argument("Invalid Initializer Type."); } diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp index a177765f..735fbf8a 100644 --- a/src/tests/initializer/initializer_test.cpp +++ b/src/tests/initializer/initializer_test.cpp @@ -75,8 +75,7 @@ TEST(InitializerTest, TruncatedNormalInitializerTest) TEST(InitializerTest, RandomNormalInitializerTest) { - NormalInitializerInfo normalInitializerInfo{2.0, 0.5, 1, 0.1}; - RandomNormalInitializer randomNormalInitializer(1, 10, normalInitializerInfo); + RandomNormalInitializer randomNormalInitializer(1, 10, 2.0, 0.5, 1, 0.1); vector> embData; int vocabSize = 5; -- Gitee From 65dc7196adb5e25016cc5cdef70f18a543712076 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Jun 2023 17:16:43 +0800 Subject: [PATCH 111/551] Match-id-749602d310aba7087a2f350500e5d5a228cc94a8 --- .../random_normal_initializer/random_normal_initializer.h | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index d29ebdef..8f39315f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -19,7 +19,7 @@ namespace MxRec { class RandomNormalInitializer : public Initializer { public: RandomNormalInitializer() = default; - RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, int initK); + RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK); ~RandomNormalInitializer() override {}; -- Gitee From 5b2f1d5e3fd28dee657a01659441342e28387769 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Jun 2023 19:32:07 +0800 Subject: [PATCH 112/551] Match-id-1cb50bb8f3d16c64782b7235c41a08298264f620 --- cust_op/cust_op_by_addr/CMakeLists.txt | 46 --- cust_op/cust_op_by_addr/CMakePresets.json | 47 --- cust_op/cust_op_by_addr/build.sh | 39 -- cust_op/cust_op_by_addr/cmake/config.cmake | 16 - cust_op/cust_op_by_addr/cmake/func.cmake | 217 ---------- cust_op/cust_op_by_addr/cmake/intf.cmake | 25 -- cust_op/cust_op_by_addr/cmake/makeself.cmake | 14 - .../cust_op_by_addr/cmake/util/__init__.py | 8 - .../cmake/util/batch_replay_impl.temp | 117 ------ .../cmake/util/code_channel_infer.py | 142 ------- .../cust_op_by_addr/cmake/util/const_var.py | 32 -- .../cmake/util/gen_impl_and_mrege_json.sh | 57 --- .../cmake/util/gen_ops_filter.sh | 61 --- .../cmake/util/insert_op_info.py | 36 -- .../cmake/util/insert_simplified_keys.py | 248 ------------ .../cmake/util/kernel_entry.py | 115 ------ .../cmake/util/kernel_impl.temp | 10 - .../cmake/util/merge_aicpu_info_json.sh | 31 -- .../cmake/util/opdesc_parser.py | 176 -------- .../cmake/util/parse_ini_to_json.py | 339 ---------------- .../cmake/util/preset_parse.py | 23 -- .../cmake/util/replay_codegen.py | 105 ----- .../cmake/util/replay_impl.temp | 120 ------ .../cmake/util/tik2_bin_param_build.py | 121 ------ .../cmake/util/tik2_impl_build.py | 376 ------------------ .../cmake/util/tik2_ops_config.py | 111 ------ .../cmake/util/tik2_replay_build.py | 65 --- .../cmake/util/tiling_data_def_build.py | 75 ---- cust_op/cust_op_by_addr/emb_custom.json | 90 +++++ .../cust_op_by_addr/framework/CMakeLists.txt | 11 - .../framework/tf_plugin/CMakeLists.txt | 8 - ...flow_embedding_lookup_by_address_plugin.cc | 13 - .../cust_op_by_addr/op_host/CMakeLists.txt | 35 -- cust_op/cust_op_by_addr/op_host/readme.md | 218 ---------- .../cust_op_by_addr/op_kernel/CMakeLists.txt | 80 ---- cust_op/cust_op_by_addr/readme.md | 197 --------- cust_op/cust_op_by_addr/run.sh | 44 ++ cust_op/cust_op_by_addr/scripts/install.sh | 228 ----------- cust_op/cust_op_by_addr/scripts/upgrade.sh | 121 ------ 39 files changed, 134 insertions(+), 3683 deletions(-) delete mode 100644 cust_op/cust_op_by_addr/CMakeLists.txt delete mode 100644 cust_op/cust_op_by_addr/CMakePresets.json delete mode 100644 cust_op/cust_op_by_addr/build.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/config.cmake delete mode 100644 cust_op/cust_op_by_addr/cmake/func.cmake delete mode 100644 cust_op/cust_op_by_addr/cmake/intf.cmake delete mode 100644 cust_op/cust_op_by_addr/cmake/makeself.cmake delete mode 100644 cust_op/cust_op_by_addr/cmake/util/__init__.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp delete mode 100644 
cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/const_var.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/insert_op_info.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/kernel_entry.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp delete mode 100644 cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh delete mode 100644 cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/preset_parse.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/replay_codegen.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/replay_impl.temp delete mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py delete mode 100644 cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py create mode 100644 cust_op/cust_op_by_addr/emb_custom.json delete mode 100644 cust_op/cust_op_by_addr/framework/CMakeLists.txt delete mode 100644 cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt delete mode 100644 cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc delete mode 100644 cust_op/cust_op_by_addr/op_host/CMakeLists.txt delete mode 100644 cust_op/cust_op_by_addr/op_host/readme.md delete mode 100644 cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt delete mode 100644 cust_op/cust_op_by_addr/readme.md create mode 100644 cust_op/cust_op_by_addr/run.sh delete mode 100644 cust_op/cust_op_by_addr/scripts/install.sh delete mode 100644 cust_op/cust_op_by_addr/scripts/upgrade.sh diff --git a/cust_op/cust_op_by_addr/CMakeLists.txt b/cust_op/cust_op_by_addr/CMakeLists.txt deleted file mode 100644 index 2b50f0d9..00000000 --- a/cust_op/cust_op_by_addr/CMakeLists.txt +++ /dev/null @@ -1,46 +0,0 @@ -cmake_minimum_required(VERSION 3.14.0) -project(opp) - -include(cmake/config.cmake) -include(cmake/func.cmake) -include(cmake/intf.cmake) - -if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/framework) - add_subdirectory(framework) -endif() -if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_host) - add_subdirectory(op_host) -endif() -if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/op_kernel) - add_subdirectory(op_kernel) -endif() -if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases) - add_subdirectory(testcases) -endif() - -# modify vendor_name in install.sh and upgrade.sh -add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh - COMMAND mkdir -p ${CMAKE_BINARY_DIR}/scripts - COMMAND cp -r ${CMAKE_SOURCE_DIR}/scripts/* ${CMAKE_BINARY_DIR}/scripts/ - COMMAND sed -i "s/vendor_name=customize/vendor_name=${vendor_name}/g" ${CMAKE_BINARY_DIR}/scripts/* -) -add_custom_target(modify_vendor ALL DEPENDS ${CMAKE_BINARY_DIR}/scripts/install.sh ${CMAKE_BINARY_DIR}/scripts/upgrade.sh) -install(DIRECTORY ${CMAKE_BINARY_DIR}/scripts/ DESTINATION . 
FILE_PERMISSIONS OWNER_EXECUTE OWNER_READ GROUP_READ) - -install(FILES ${CMAKE_SOURCE_DIR}/custom.proto DESTINATION packages OPTIONAL) - -get_system_info(SYSTEM_INFO) - -# CPack config -set(CPACK_PACKAGE_NAME ${CMAKE_PROJECT_NAME}) -set(CPACK_PACKAGE_VERSION ${CMAKE_PROJECT_VERSION}) -set(CPACK_PACKAGE_DESCRIPTION "CPack opp project") -set(CPACK_PACKAGE_DESCRIPTION_SUMMARY "CPack opp project") -set(CPACK_PACKAGE_DIRECTORY ${CMAKE_INSTALL_PREFIX}) -set(CPACK_PACKAGE_FILE_NAME "custom_opp_${SYSTEM_INFO}.run") -set(CPACK_GENERATOR External) -set(CPACK_CMAKE_GENERATOR "Unix Makefiles") -set(CPACK_EXTERNAL_ENABLE_STAGING TRUE) -set(CPACK_EXTERNAL_PACKAGE_SCRIPT ${CMAKE_SOURCE_DIR}/cmake/makeself.cmake) -set(CPACK_EXTERNAL_BUILT_PACKAGES ${CPACK_PACKAGE_DIRECTORY}/_CPack_Packages/Linux/External/${CPACK_PACKAGE_FILE_NAME}/${CPACK_PACKAGE_FILE_NAME}) -include(CPack) diff --git a/cust_op/cust_op_by_addr/CMakePresets.json b/cust_op/cust_op_by_addr/CMakePresets.json deleted file mode 100644 index bd4e93df..00000000 --- a/cust_op/cust_op_by_addr/CMakePresets.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "version": 1, - "cmakeMinimumRequired": { - "major": 3, - "minor": 19, - "patch": 0 - }, - "configurePresets": [ - { - "name": "default", - "displayName": "Default Config", - "description": "Default build using Unix Makefiles generator", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build_out", - "cacheVariables": { - "CMAKE_BUILD_TYPE": { - "type": "STRING", - "value": "Release" - }, - "ENABLE_SOURCE_PACKAGE": { - "type": "BOOL", - "value": "True" - }, - "ENABLE_BINARY_PACKAGE": { - "type": "BOOL", - "value": "False" - }, - "ENABLE_TEST": { - "type": "BOOL", - "value": "False" - }, - "vendor_name": { - "type": "STRING", - "value": "customize" - }, - "ASCEND_CANN_PACKAGE_PATH": { - "type": "PATH", - "value": "/usr/local/Ascend/ascend-toolkit/latest" - }, - "CMAKE_INSTALL_PREFIX": { - "type": "PATH", - "value": "${sourceDir}/build_out" - } - } - } - ] -} diff --git a/cust_op/cust_op_by_addr/build.sh b/cust_op/cust_op_by_addr/build.sh deleted file mode 100644 index f38465f9..00000000 --- a/cust_op/cust_op_by_addr/build.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -script_path=$(realpath $(dirname $0)) - - -mkdir -p build_out -rm -rf build_out/* -cd build_out - -chmod +x $script_path/cmake/util/gen_ops_filter.sh - -cmake_version=$(cmake --version | grep "cmake version" | awk '{print $3}') -if [ "$cmake_version" \< "3.19.0" ] ; then - opts=$(python3 $script_path/cmake/util/preset_parse.py $script_path/CMakePresets.json) - echo $opts - cmake .. $opts -else - cmake .. --preset=default -fi -target=package -if [ "$1"x != ""x ]; then target=$1; fi - -cmake --build . --target $target -j16 -if [ $? -ne 0 ]; then exit 1; fi - -if [ $target = "package" ]; then - if test -d ./op_kernel/binary ; then - ./cust*.run - if [ $? -ne 0 ]; then exit 1; fi - cmake --build . --target binary -j16 - if [ $? -ne 0 ]; then exit 1; fi - cmake --build . 
--target $target -j16 - fi -fi - -# for debug -# cd build_out -# make -# cpack -# verbose append -v diff --git a/cust_op/cust_op_by_addr/cmake/config.cmake b/cust_op/cust_op_by_addr/cmake/config.cmake deleted file mode 100644 index c6f09290..00000000 --- a/cust_op/cust_op_by_addr/cmake/config.cmake +++ /dev/null @@ -1,16 +0,0 @@ - -set(CMAKE_CXX_FLAGS_DEBUG "") -set(CMAKE_CXX_FLAGS_RELEASE "") - -if (NOT DEFINED vendor_name) - set(vendor_name customize CACHE STRING "") -endif() -if (NOT DEFINED ASCEND_CANN_PACKAGE_PATH) - set(ASCEND_CANN_PACKAGE_PATH /usr/local/Ascend/latest CACHE PATH "") -endif() -set(ASCEND_TENSOR_COMPILER_PATH ${ASCEND_CANN_PACKAGE_PATH}/compiler) -set(ASCEND_CCEC_COMPILER_PATH ${ASCEND_TENSOR_COMPILER_PATH}/ccec_compiler/bin) -set(ASCEND_AUTOGEN_PATH ${CMAKE_BINARY_DIR}/autogen) -set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) -set(ASCEND_FRAMEWORK_TYPE tensorflow) -file(MAKE_DIRECTORY ${ASCEND_AUTOGEN_PATH}) diff --git a/cust_op/cust_op_by_addr/cmake/func.cmake b/cust_op/cust_op_by_addr/cmake/func.cmake deleted file mode 100644 index 69f2208d..00000000 --- a/cust_op/cust_op_by_addr/cmake/func.cmake +++ /dev/null @@ -1,217 +0,0 @@ - -function(get_system_info SYSTEM_INFO) - if (UNIX) - execute_process(COMMAND grep -i ^id= /etc/os-release OUTPUT_VARIABLE TEMP) - string(REGEX REPLACE "\n|id=|ID=|\"" "" SYSTEM_NAME ${TEMP}) - set(${SYSTEM_INFO} ${SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR} PARENT_SCOPE) - elseif (WIN32) - message(STATUS "System is Windows. Only for pre-build.") - else () - message(FATAL_ERROR "${CMAKE_SYSTEM_NAME} not support.") - endif () -endfunction() - -function(opbuild) - message(STATUS "Opbuild generating sources") - cmake_parse_arguments(OPBUILD "" "OUT_DIR;PROJECT_NAME;ACCESS_PREFIX" "OPS_SRC" ${ARGN}) - execute_process(COMMAND ${CMAKE_CXX_COMPILER} -g -fPIC -shared -std=c++11 ${OPBUILD_OPS_SRC} -D_GLIBCXX_USE_CXX11_ABI=0 - -I ${ASCEND_CANN_PACKAGE_PATH}/include -L ${ASCEND_CANN_PACKAGE_PATH}/lib64 -lexe_graph -lregister - -o ${OPBUILD_OUT_DIR}/libascend_all_ops.so - RESULT_VARIABLE EXEC_RESULT - OUTPUT_VARIABLE EXEC_INFO - ERROR_VARIABLE EXEC_ERROR - ) - if (${EXEC_RESULT}) - message("build ops lib info: ${EXEC_INFO}") - message("build ops lib error: ${EXEC_ERROR}") - message(FATAL_ERROR "opbuild run failed!") - endif() - set(proj_env "") - set(prefix_env "") - if (NOT "${OPBUILD_PROJECT_NAME}x" STREQUAL "x") - set(proj_env "OPS_PROJECT_NAME=${OPBUILD_PROJECT_NAME}") - endif() - if (NOT "${OPBUILD_ACCESS_PREFIX}x" STREQUAL "x") - set(prefix_env "OPS_DIRECT_ACCESS_PREFIX=${OPBUILD_ACCESS_PREFIX}") - endif() - execute_process(COMMAND ${proj_env} ${prefix_env} ${ASCEND_CANN_PACKAGE_PATH}/toolkit/tools/opbuild/op_build - ${OPBUILD_OUT_DIR}/libascend_all_ops.so ${OPBUILD_OUT_DIR} - RESULT_VARIABLE EXEC_RESULT - OUTPUT_VARIABLE EXEC_INFO - ERROR_VARIABLE EXEC_ERROR - ) - if (${EXEC_RESULT}) - message("opbuild ops info: ${EXEC_INFO}") - message("opbuild ops error: ${EXEC_ERROR}") - endif() - message(STATUS "Opbuild generating sources - done") -endfunction() - -function(add_ops_info_target) - cmake_parse_arguments(OPINFO "" "TARGET;OPS_INFO;OUTPUT;INSTALL_DIR" "" ${ARGN}) - get_filename_component(opinfo_file_path "${OPINFO_OUTPUT}" DIRECTORY) - add_custom_command(OUTPUT ${OPINFO_OUTPUT} - COMMAND mkdir -p ${opinfo_file_path} - COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/parse_ini_to_json.py - ${OPINFO_OPS_INFO} ${OPINFO_OUTPUT} - ) - add_custom_target(${OPINFO_TARGET} ALL - DEPENDS ${OPINFO_OUTPUT} - ) - install(FILES ${OPINFO_OUTPUT} - DESTINATION 
${OPINFO_INSTALL_DIR} - ) -endfunction() - -function(add_ops_impl_target) - cmake_parse_arguments(OPIMPL "" "TARGET;OPS_INFO;IMPL_DIR;OUT_DIR;INSTALL_DIR" "OPS_BATCH;OPS_ITERATE" ${ARGN}) - add_custom_command(OUTPUT ${OPIMPL_OUT_DIR}/.impl_timestamp - COMMAND mkdir -p ${OPIMPL_OUT_DIR}/dynamic - COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_impl_build.py - ${OPIMPL_OPS_INFO} - \"${OPIMPL_OPS_BATCH}\" \"${OPIMPL_OPS_ITERATE}\" - ${OPIMPL_IMPL_DIR} - ${OPIMPL_OUT_DIR}/dynamic - COMMAND rm -rf ${OPIMPL_OUT_DIR}/.impl_timestamp - COMMAND touch ${OPIMPL_OUT_DIR}/.impl_timestamp - DEPENDS ${OPIMPL_OPS_INFO} - ${CMAKE_SOURCE_DIR}/cmake/util/tik2_impl_build.py - ) - add_custom_target(${OPIMPL_TARGET} ALL - DEPENDS ${OPIMPL_OUT_DIR}/.impl_timestamp) - if (${ENABLE_SOURCE_PACKAGE}) - install(DIRECTORY ${OPIMPL_OUT_DIR}/dynamic - DESTINATION ${OPIMPL_INSTALL_DIR} - ) - endif() -endfunction() - -function(add_ops_replay_targets) - cmake_parse_arguments(OPREPLAY "" "OPS_INFO;COMPUTE_UNIT;IMPL_DIR;OUT_DIR;INSTALL_DIR" "OPS_BATCH;OPS_ITERATE" ${ARGN}) - # ccec compile options - set(ccec_base_opts -c -O2 --cce-aicore-only -mllvm -cce-aicore-function-stack-size=16000 - -mllvm -cce-aicore-record-overflow=false -std=c++17) - set(ccec_extopts_ascend310p --cce-aicore-arch=dav-m200 -mllvm -cce-aicore-fp-ceiling=2) - set(ccec_extopts_ascend910 --cce-aicore-arch=dav-c100) - set(ccec_extopts_ascend910b --cce-aicore-arch=dav-c220-cube) - file(MAKE_DIRECTORY ${OPREPLAY_OUT_DIR}) - execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_replay_build.py - ${OPREPLAY_OPS_INFO} - "${OPREPLAY_OPS_BATCH}" "${OPREPLAY_OPS_ITERATE}" - ${OPREPLAY_IMPL_DIR} - ${OPREPLAY_OUT_DIR} - ${OPREPLAY_COMPUTE_UNIT} - ) - file(GLOB replay_kernel_entries ${OPREPLAY_OUT_DIR}/*.cce) - if (NOT "${replay_kernel_entries}x" STREQUAL "x") - foreach(replay_kernel_file ${replay_kernel_entries}) - get_filename_component(replay_kernel_file_name "${replay_kernel_file}" NAME) - string(REPLACE "_entry.cce" "" op_kerne_name ${replay_kernel_file_name}) - file(GLOB replay_lib_src ${OPREPLAY_OUT_DIR}/${op_kerne_name}*.cpp) - set(OP_TILING_DATA_H_PATH ${OPREPLAY_OUT_DIR}/${op_kerne_name}_tiling_data.h) - add_library(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} SHARED ${replay_lib_src}) - if(EXISTS ${OP_TILING_DATA_H_PATH}) - target_compile_options(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE - -include ${OP_TILING_DATA_H_PATH} - ) - endif() - target_compile_definitions(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE - ${op_kerne_name}=${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} - ) - target_link_libraries(replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} PRIVATE intf_pub - tikreplaylib::${OPREPLAY_COMPUTE_UNIT} - register - ) - add_custom_command(OUTPUT ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o - COMMAND ccec ${ccec_base_opts} ${ccec_extopts_${OPREPLAY_COMPUTE_UNIT}} ${replay_kernel_file} - -o ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o - DEPENDS ${replay_kernel_file} - ) - add_custom_target(replay_kernel_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} ALL - DEPENDS ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o - ) - install(TARGETS replay_${op_kerne_name}_${OPREPLAY_COMPUTE_UNIT} - LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay - ) - install(FILES ${OPREPLAY_OUT_DIR}/${op_kerne_name}_entry_${OPREPLAY_COMPUTE_UNIT}.o - DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay - ) - endforeach() - 
endif() -endfunction() - -function(add_npu_support_target) - cmake_parse_arguments(NPUSUP "" "TARGET;OPS_INFO_DIR;OUT_DIR;INSTALL_DIR" "" ${ARGN}) - get_filename_component(npu_sup_file_path "${NPUSUP_OUT_DIR}" DIRECTORY) - add_custom_command(OUTPUT ${NPUSUP_OUT_DIR}/npu_supported_ops.json - COMMAND mkdir -p ${NPUSUP_OUT_DIR} - COMMAND ${CMAKE_SOURCE_DIR}/cmake/util/gen_ops_filter.sh - ${NPUSUP_OPS_INFO_DIR} - ${NPUSUP_OUT_DIR} - ) - add_custom_target(npu_supported_ops ALL - DEPENDS ${NPUSUP_OUT_DIR}/npu_supported_ops.json - ) - install(FILES ${NPUSUP_OUT_DIR}/npu_supported_ops.json - DESTINATION ${NPUSUP_INSTALL_DIR} - ) -endfunction() - -function(add_bin_compile_target) - cmake_parse_arguments(BINCMP "" "TARGET;OPS_INFO;COMPUTE_UNIT;IMPL_DIR;ADP_DIR;OUT_DIR;INSTALL_DIR" "" ${ARGN}) - file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/src) - file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/bin) - file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/gen) - execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_bin_param_build.py - ${BINCMP_OPS_INFO} ${BINCMP_OUT_DIR}/gen ${BINCMP_COMPUTE_UNIT} - RESULT_VARIABLE EXEC_RESULT - OUTPUT_VARIABLE EXEC_INFO - ERROR_VARIABLE EXEC_ERROR - ) - if (${EXEC_RESULT}) - message("ops binary compile scripts gen info: ${EXEC_INFO}") - message("ops binary compile scripts gen error: ${EXEC_ERROR}") - message(FATAL_ERROR "ops binary compile scripts gen failed!") - endif() - if (NOT TARGET binary) - add_custom_target(binary) - endif() - add_custom_target(${BINCMP_TARGET} - COMMAND cp ${BINCMP_IMPL_DIR}/*.cpp ${BINCMP_OUT_DIR}/src - ) - add_custom_target(${BINCMP_TARGET}_gen_ops_config - COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/insert_simplified_keys.py -p ${BINCMP_OUT_DIR}/bin - COMMAND python3 ${CMAKE_SOURCE_DIR}/cmake/util/tik2_ops_config.py -p ${BINCMP_OUT_DIR}/bin - -s ${BINCMP_COMPUTE_UNIT} - ) - add_dependencies(binary ${BINCMP_TARGET}_gen_ops_config) - file(GLOB bin_scripts ${BINCMP_OUT_DIR}/gen/*.sh) - foreach(bin_script ${bin_scripts}) - get_filename_component(bin_file ${bin_script} NAME_WE) - string(REPLACE "-" ";" bin_sep ${bin_file}) - list(GET bin_sep 0 op_type) - list(GET bin_sep 1 op_file) - list(GET bin_sep 2 op_index) - if (NOT TARGET ${BINCMP_TARGET}_${op_file}_copy) - file(MAKE_DIRECTORY ${BINCMP_OUT_DIR}/bin/${op_file}) - add_custom_target(${BINCMP_TARGET}_${op_file}_copy - COMMAND cp ${BINCMP_ADP_DIR}/${op_file}.py ${BINCMP_OUT_DIR}/src/${op_type}.py - ) - install(DIRECTORY ${BINCMP_OUT_DIR}/bin/${op_file} - DESTINATION ${BINCMP_INSTALL_DIR}/${BINCMP_COMPUTE_UNIT} OPTIONAL - ) - install(FILES ${BINCMP_OUT_DIR}/bin/${op_file}.json - DESTINATION ${BINCMP_INSTALL_DIR}/config/${BINCMP_COMPUTE_UNIT}/ OPTIONAL - ) - endif() - add_custom_target(${BINCMP_TARGET}_${op_file}_${op_index} - COMMAND bash ${bin_script} ${BINCMP_OUT_DIR}/src/${op_type}.py ${BINCMP_OUT_DIR}/bin/${op_file} - WORKING_DIRECTORY ${BINCMP_OUT_DIR} - ) - add_dependencies(${BINCMP_TARGET}_${op_file}_${op_index} ${BINCMP_TARGET} ${BINCMP_TARGET}_${op_file}_copy) - add_dependencies(${BINCMP_TARGET}_gen_ops_config ${BINCMP_TARGET}_${op_file}_${op_index}) - endforeach() - install(FILES ${BINCMP_OUT_DIR}/bin/binary_info_config.json - DESTINATION ${BINCMP_INSTALL_DIR}/config/${BINCMP_COMPUTE_UNIT} OPTIONAL - ) -endfunction() diff --git a/cust_op/cust_op_by_addr/cmake/intf.cmake b/cust_op/cust_op_by_addr/cmake/intf.cmake deleted file mode 100644 index 1c54b6ea..00000000 --- a/cust_op/cust_op_by_addr/cmake/intf.cmake +++ /dev/null @@ -1,25 +0,0 @@ - -add_library(intf_pub INTERFACE) 
-target_compile_options(intf_pub INTERFACE - -fPIC - -fvisibility=hidden - -fvisibility-inlines-hidden - $<$:-O2 -s> - $<$:-O0 -g> - $<$:-std=c++11> - $<$,$>:-ftrapv -fstack-check> - $<$:-pthread -Wfloat-equal -Wshadow -Wformat=2 -Wno-deprecated -Wextra> - $,-fstack-protector-strong,-fstack-protector-all> -) -target_compile_definitions(intf_pub INTERFACE - _GLIBCXX_USE_CXX11_ABI=0 - $<$:_FORTIFY_SOURCE=2> -) -target_include_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/include) -target_link_options(intf_pub INTERFACE - $<$,EXECUTABLE>:-pie> - -Wl,-z,relro - -Wl,-z,now - -Wl,-z,noexecstack -) -target_link_directories(intf_pub INTERFACE ${ASCEND_CANN_PACKAGE_PATH}/lib64) diff --git a/cust_op/cust_op_by_addr/cmake/makeself.cmake b/cust_op/cust_op_by_addr/cmake/makeself.cmake deleted file mode 100644 index 18bdc331..00000000 --- a/cust_op/cust_op_by_addr/cmake/makeself.cmake +++ /dev/null @@ -1,14 +0,0 @@ -execute_process(COMMAND chmod +x ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself.sh) -execute_process(COMMAND ${CMAKE_CURRENT_LIST_DIR}/util/makeself/makeself.sh --gzip --complevel 4 --nomd5 --sha256 - ./ ${CPACK_PACKAGE_FILE_NAME} "version:1.0" ./install.sh - WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} - RESULT_VARIABLE EXEC_RESULT - ERROR_VARIABLE EXEC_ERROR -) -if (NOT "${EXEC_RESULT}x" STREQUAL "0x") - message(FATAL_ERROR "CPack Command error: ${EXEC_RESULT}\n${EXEC_ERROR}") -endif() -execute_process(COMMAND cp ${CPACK_EXTERNAL_BUILT_PACKAGES} ${CPACK_PACKAGE_DIRECTORY}/ - COMMAND echo "Copy ${CPACK_EXTERNAL_BUILT_PACKAGES} to ${CPACK_PACKAGE_DIRECTORY}/" - WORKING_DIRECTORY ${CPACK_TEMPORARY_DIRECTORY} -) diff --git a/cust_op/cust_op_by_addr/cmake/util/__init__.py b/cust_op/cust_op_by_addr/cmake/util/__init__.py deleted file mode 100644 index c4ddc893..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import sys -import os - -PYF_PATH = os.path.dirname(os.path.realpath(__file__)) -sys.path.append(PYF_PATH) diff --git a/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp b/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp deleted file mode 100644 index 7b4f5edf..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/batch_replay_impl.temp +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include "replay_def.h" -#include "code_gen.h" -#include "replay_fun.h" -#include "register/op_check.h" -#define __TIK2_REPLAY_CODE__ -#include - -using namespace std; -using namespace optiling; -using namespace tik2_replay; - -extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *); -extern "C" int elf_batch_append(char *elf, uint32_t elfSize, char *jit, int kernum, char *atext[], int alen[], - int atlen, const char* kernelname[]); - -#define KERNEL_N 1 -#define ARG_N (__ARG_NUM__) -#define MAX_L (1024 * 1024 * 100) -#define MAX_E (1024 * 1024) - -int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type) -{ - // gen type 1 : direct call codes 0: load .o file - if (param.gentype < 0 || param.gentype > 1) { - printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype); - return 0; - } else if (param.gentype == 1 && param.objptr == nullptr) { - printf("Error: call replay with direct call mode, but code obj addr is null\n"); - return 0; - } else if (param.gentype == 0 && param.output_kernel_file == nullptr) { - printf("Error: call replay with object file mode, but object file path 
is null\n"); - return 0; - } - // core_type 0:MIX 1:CUBE 2:VEC - if (core_type < 0 || core_type > 2) { - printf("Error: call replay core type is %d !\n", core_type); - return 0; - } - g_coreType = __CORE_TYPE__; - g_taskRation = param.task_ration; - g_tilingKey = param.tiling_key; - - unsigned char *buf, *jit; - char *kernel[KERNEL_N]; - int len[KERNEL_N]; - block_idx = 0; - block_num = param.block_dim; - g_ubBase = block_num; - uint8_t *code = (uint8_t *)malloc(MAX_L); - uint8_t *pos = code; - struct timespec tp1, tp2; - - clock_gettime(CLOCK_MONOTONIC, &tp1); - if (block_num > 32) { - printf("Error: block_num > 32\n"); - return 0; - } - //__OP_FOPEN__ - for (int i = 0; i < KERNEL_N; i++) { - //__OP_SET_KERNEL__ - for (int j = 0; j < ARG_N; j++) - AddArg(j, ARG_STEP * (j + 1)); -#ifdef FP_CEILING - SetCtrlFloatEnable(); -#else - SetCtrlFloatDisable(); -#endif - CodeInit(pos, true); - __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data); - CodeEnd(); - kernel[i] = (char *)pos; - len[i] = CodeLen(); - pos += len[i]; - } - //__OP_FCLOSE__ - clock_gettime(CLOCK_MONOTONIC, &tp2); - buf = (unsigned char *)malloc(MAX_E); - int fd = open(param.entry_file, O_RDONLY); - if (fd < 0) { - printf("[error]: cannot find entry.o : %s\n", param.entry_file); - return 0; - } - uint32_t bufSize = read(fd, buf, MAX_E); - if (bufSize <= 0) { - printf("[error]: entry.o : %s is too small ! \n", param.entry_file); - } - close(fd); - jit = (unsigned char *)malloc(MAX_L); - printf("total code generated %ld\n", pos - code); - int sz = elf_batch_append((char *)buf, bufSize, (char *)jit, KERNEL_N, kernel, len, pos - code, ¶m.kernel_name); - if (tp1.tv_sec != tp2.tv_sec) { - printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec); - } else { - printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec); - } - printf("new elf size %d\n", sz); - if (param.gentype == 0) { - fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR); - (void)write(fd, jit, sz); - close(fd); - free(jit); - } else if (param.gentype == 1) { - *param.objptr = (char*)jit; - } - free(buf); - free(code); - return sz; -} - -REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__); diff --git a/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py b/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py deleted file mode 100644 index 49ce5e52..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/code_channel_infer.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. -""" -import os -import stat -import ctypes -import collections -import shutil -import subprocess -import copy - -"""CODE_* is used to cube/vector api is called in operator code -CODE_MIX means both cube and vector api is called -CODE_CUBE means only cube api is called -CODE_VEC means only vector api is called -""" -CODE_MIX = 0 -CODE_CUBE = 1 -CODE_VEC = 2 - - -def _is_v220(op_product: str): - """return if current soc version is V220 - - Returns: - res: True means V220 - """ - if op_product in ["ascend910b", "ascend910c"]: - return True - return False - -CheckCoreTypeParams = collections.namedtuple('CheckCoreTypeParams',\ -['src_file', 'arch', 'kernel_name', 'compile_options', 'addrspace_list', 'outdir']) - - -def _check_core_type(check_core_type_params: CheckCoreTypeParams): - """1. call ccec -S -emit-llvm to generate llvm-ir file - 2. 
analysis addrspace to check if exists cube or vector buffer scope - - Args: - CheckCoreTypeParams: - src_file (str): TIK2 operator code file - arch (str): _description_ - kernel_name (str): kernel function name - compile_options (list): compile options for ccec cmd - addrspace_list (list): addrspace of cube or vector - outdir(str): temp file output - - Returns: - res (bool): True if exists target addrspapce of arch - """ - llvm_ir_file = os.path.join(check_core_type_params.outdir, check_core_type_params.kernel_name + "_" +\ - check_core_type_params.arch + ".ll") - compile_cmd = [shutil.which("ccec"), '-S', '-emit-llvm', '-std=c++17', '-x', 'cce',\ - check_core_type_params.src_file] - compile_cmd += check_core_type_params.compile_options - - compile_cmd += ["--cce-aicore-arch=%s" % check_core_type_params.arch, - "--cce-aicore-only", "-o", llvm_ir_file] - proc = subprocess.Popen( - compile_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - (out, _) = proc.communicate() - if proc.returncode != 0: - msg = "compile %s error :%s\n" % (check_core_type_params.src_file, out.decode()) - print("check core type ", msg) - return False - - def _check_exist_space(line_content, space_list): - if "addrspace" not in line_content: - return False - for space in space_list: - if space in line_content: - return True - return False - - access_space = False - with open(llvm_ir_file) as llvm_ir: - line_list = llvm_ir.readlines() - for line in line_list: - if access_space: - break - access_space = _check_exist_space(line, check_core_type_params.addrspace_list) - os.remove(llvm_ir_file) - return access_space - - -InfoCodeChanelParams = collections.namedtuple('InfoCodeChanelParams',\ -['src_file', 'tiling_header', 'kernel_name', 'outdir', 'op_product', 'compile_options']) - - -def infer_code_channel(params: InfoCodeChanelParams): - """get code channel for v220, return CODE_MIX if soc version is not V220 - - Args: - src_file (str): TIK2 operator code file - src_file (str): TIK2 operator tiling header file - kernel_name (str): kernel function name - optype (str): operator type - compile_options (list): compile options for ccec cmd - - Raises: - Exception: if not exist L1/L0/UB if code, it's not a aicore code - - Returns: - res (int): CODE_MIX/CODE_CUBE/CODE_VEC - """ - if not _is_v220(params.op_product): - return CODE_MIX - if params.compile_options is None: - compile_options = [] - else: - compile_options = params.compile_options - ccec = shutil.which("ccec") - if ccec is not None: - ccec_path = os.path.dirname(ccec) - tikcpp_path = os.path.realpath(os.path.join(ccec_path, "..", "..", "tikcpp")) - else: - tikcpp_path = os.path.realpath("/usr/local/Ascend/latest/compiler/tikcpp") - compile_options.append("-I" + tikcpp_path) - compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw")) - compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl")) - compile_options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface")) - compile_options.append("-D__NPU_TILING__") - compile_options += ["-include", params.tiling_header] - cube_addrspace_list = ["addrspace(2)", "addrspace(3)", "addrspace(4)", "addrspace(5)"] - access_l1_l0a = _check_core_type(CheckCoreTypeParams(params.src_file, "dav-c220-cube", params.kernel_name,\ - compile_options, cube_addrspace_list, params.outdir)) - cube_addrspace_list = ["addrspace(6)"] - access_ub = _check_core_type(CheckCoreTypeParams(params.src_file, "dav-c220-vec", params.kernel_name,\ - compile_options, cube_addrspace_list, params.outdir)) - - if 
access_l1_l0a and access_ub: - return CODE_MIX - elif access_l1_l0a: - return CODE_CUBE - elif access_ub: - return CODE_VEC - else: - raise Exception(f"cannot find valid addrspace in (2,3,4,5,6) in {params.src_file}") diff --git a/cust_op/cust_op_by_addr/cmake/util/const_var.py b/cust_op/cust_op_by_addr/cmake/util/const_var.py deleted file mode 100644 index f5dde656..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/const_var.py +++ /dev/null @@ -1,32 +0,0 @@ - -#!/usr/bin/env python -# coding=utf-8 -""" -Function: -The replay funtion entry -Copyright Information: -Huawei Technologies Co., Ltd. All Rights Reserved © 2020 -""" - -import os -import stat - - -REPLAY_BATCH = 'batch' -REPLAY_ITERATE = 'iterate' -CFG_IMPL_DIR = 'impl_dir' -CFG_OUT_DIR = 'out_dir' -WFLAGS = os.O_WRONLY | os.O_CREAT -WMODES = stat.S_IWUSR | stat.S_IRUSR -SOC_MAP_EXT = {'ascend310p': 'Ascend310P3', 'ascend310b': 'Ascend310B1', - 'ascend910': 'Ascend910A', 'ascend910b': 'Ascend910B1'} -BIN_CMD = 'opc $1 --main_func={fun} --input_param={param} --soc_version={soc} \ ---output=$2 --impl_mode={impl} --op_mode=dynamic\n' -CHK_CMD = ''' -if ! test -f $2/{res_file} ; then - echo "$2/{res_file} not generated!" - exit 1 -fi -''' -ATTR_DEF_VAL = {'str' : '', 'int': 0, 'float': 0.0, 'bool': False, 'list_bool': [], - 'list_int': [], 'list_float': [], 'list_list_int': [[]]} diff --git a/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh b/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh deleted file mode 100644 index 55e12e5e..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/gen_impl_and_mrege_json.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/bash - -project_path=$1 -build_path=$2 -vendor_name=customize -if [[ ! -d "$project_path" ]]; then - echo "[ERROR] No projcet path is provided" - exit 1 -fi - -if [[ ! -d "$build_path" ]]; then - echo "[ERROR] No build path is provided" - exit 1 -fi - -# copy ai_core operators implements -# tbe_impl_files_num=$(ls $project_path/tbe/impl/* 2> /dev/null | wc -l) -# if [[ "$tbe_impl_files_num" -gt 0 ]];then -# cp -r ${project_path}/tbe/impl/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/ai_core/tbe/customize_impl -# cp -r ${project_path}/tbe/impl/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/vector_core/tbe/customize_impl -# fi - -# copy aicpu kernel so operators -if [[ -d "${project_path}/cpukernel/aicpu_kernel_lib" ]]; then - cp -f ${project_path}/cpukernel/aicpu_kernel_lib/* ${build_path}/makepkg/packages/vendors/$vendor_name/op_impl/cpu/aicpu_kernel/impl - rm -rf ${project_path}/cpukernel/aicpu_kernel_lib -fi - -# merge aicpu.ini and aicore.ini to generate npu_supported_ops.json -# mkdir -p ${build_path}/framework/op_info_cfg -# mkdir -p ${build_path}/framework/op_info_cfg/aicpu_kernel -# mkdir -p ${build_path}/framework/op_info_cfg/ai_core - -# if [[ -d "${project_path}/tbe/op_info_cfg/ai_core" ]]; then -# bash ${project_path}/cmake/util/gen_ops_filter.sh ${project_path}/tbe/op_info_cfg/ai_core ${build_path}/framework/op_info_cfg/ai_core -# fi - -# if [[ -d "${project_path}/cpukernel/op_info_cfg/aicpu_kernel" ]]; then -# bash ${project_path}/cmake/util/gen_ops_filter.sh ${project_path}/cpukernel/op_info_cfg/aicpu_kernel ${build_path}/framework/op_info_cfg/aicpu_kernel -# fi - -# aicpu_filter_file=${build_path}/framework/op_info_cfg/aicpu_kernel/npu_supported_ops.json -# aicore_filter_file=${build_path}/framework/op_info_cfg/ai_core/npu_supported_ops.json -# if [[ -f "${aicpu_filter_file}" ]] && [[ ! 
-f "${aicore_filter_file}" ]]; then -# cp $aicpu_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow -# fi -# if [[ -f "${aicore_filter_file}" ]] && [[ ! -f "${aicpu_filter_file}" ]]; then -# cp $aicore_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow -# fi - -# if [[ -f "${aicore_filter_file}" ]] && [[ -f "${aicpu_filter_file}" ]]; then -# chmod u+w ${aicpu_filter_file} -# python3 ${project_path}/cmake/util/insert_op_info.py ${aicore_filter_file} ${aicpu_filter_file} -# chmod u-w ${aicpu_filter_file} -# cp $aicpu_filter_file ${build_path}/makepkg/packages/vendors/$vendor_name/framework/tensorflow -# fi - diff --git a/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh b/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh deleted file mode 100644 index 54c7c640..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/gen_ops_filter.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. -# Description: Generate npu_supported_ops.json -# ============================================================================== - -if [[ -z "$1" ]]; then - echo -e "[ERROR] No source dir provided" - exit 1 -fi - -if [[ -z "$2" ]]; then - echo -e "[ERROR] No destination dir provided" - exit 1 -fi - -src=$1 -dest_file=$2/npu_supported_ops.json - -if [ -f "$dest_file" ];then - chmod u+w $dest_file -fi - -echo $* - -add_ops() { - name=$1 - isHeavy=$2 - file=$3 - grep -w "\"$name\"" ${file} >/dev/null - if [ $? == 0 ];then - return - fi - echo " \"${name}\": {" >> ${file} - echo " \"isGray\": false," >> ${file} - echo " \"isHeavy\": ${isHeavy}" >> ${file} - echo " }," >> ${file} -} - -echo "{" > ${dest_file} -ini_files=$(find ${src} -name "*.ini") -for file in ${ini_files} ; do - name=$(grep '^\[' ${file} | sed 's/\[//g' | sed 's/]//g' | sed 's/\r//g') - grep 'heavyOp.flag' ${file} >/dev/null - if [ $? == 0 ];then - isHeavy=$(grep 'heavyOp.flag' ${file} | awk -F= '{print $2}') - else - isHeavy="false" - fi - for op in ${name}; do - add_ops ${op} ${isHeavy} ${dest_file} - done -done -echo "}" >> ${dest_file} -file_count=$(cat ${dest_file} | wc -l) -line=$(($file_count-1)) -sed -i "${line}{s/,//g}" ${dest_file} - -chmod 640 "${dest_file}" -echo -e "[INFO] Succed generated ${dest_file}" - -exit 0 diff --git a/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py b/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py deleted file mode 100644 index 28ba0875..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/insert_op_info.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
-""" -import json -import os -import sys -import stat -import const_var - - -if __name__ == '__main__': - if len(sys.argv) != 3: - print(sys.argv) - print('argv error, inert_op_info.py your_op_file lib_op_file') - sys.exit(2) - - with open(sys.argv[1], 'r') as load_f: - insert_operator = json.load(load_f) - - all_operators = {} - if os.path.exists(sys.argv[2]): - if os.path.getsize(sys.argv[2]) != 0: - with open(sys.argv[2], 'r') as load_f: - all_operators = json.load(load_f) - - for k in insert_operator.keys(): - if k in all_operators.keys(): - print('replace op:[', k, '] success') - else: - print('insert op:[', k, '] success') - all_operators[k] = insert_operator[k] - - with os.fdopen(os.open(sys.argv[2], const_var.WFLAGS, const_var.WMODES), 'w') as json_file: - json_file.write(json.dumps(all_operators, indent=4)) diff --git a/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py b/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py deleted file mode 100644 index 19c3820f..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/insert_simplified_keys.py +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. -""" - -import sys -import os -import re -import glob -import json -import argparse -import const_var - - -DATA_TPYE_DICT = { - 'float32': 0, - 'float16': 1, - 'int8': 2, - 'int16': 6, - 'uint16': 7, - 'uint8': 4, - 'int32': 3, - 'int64': 9, - 'uint32': 8, - 'uint64': 10, - 'bool': 12, - 'double': 11, - 'string': 13, - 'dual': 14, - 'dual': 15, - 'complex64': 16, - 'complex128': 17, - 'qint8': 18, - 'qint16': 19, - 'qint32': 20, - 'quint8': 21, - 'quint16': 22, - 'resource': 23, - 'string': 24, - 'dual': 25, - 'variant': 26, - 'bf16': 27, - 'bfloat16': 27, - 'undefined': 28, - 'int4': 29, - 'uint1': 30, - 'int2': 31 -} - -FORMAT_DICT = { - 'NCHW': 0, - 'NHWC': 1, - 'ND': 2, - 'NC1HWC0': 3, - 'FRACTAL_Z': 4, - 'NC1C0HWPAD': 5, - 'NHWC1C0': 6, - 'FSR_NCHW': 7, - 'FRACTAL_DECONV': 8, - 'C1HWNC0': 9, - 'FRACTAL_DECONV_TRANSPOSE': 10, - 'FRACTAL_DECONV_SP_STRIDE_TRANS': 11, - 'NC1HWC0_C04': 12, - 'FRACTAL_Z_C04': 13, - 'CHWN': 14, - 'FRACTAL_DECONV_SP_STRIDE8_TRANS': 15, - 'HWCN': 16, - 'NC1KHKWHWC0': 17, - 'BN_WEIGHT': 18, - 'FILTER_HWCK': 19, - 'HASHTABLE_LOOKUP_LOOKUPS': 20, - 'HASHTABLE_LOOKUP_KEYS': 21, - 'HASHTABLE_LOOKUP_VALUE': 22, - 'HASHTABLE_LOOKUP_OUTPUT': 23, - 'HASHTABLE_LOOKUP_HITS': 24, - 'C1HWNCoC0': 25, - 'MD': 26, - 'NDHWC': 27, - 'FRACTAL_ZZ': 28, - 'FRACTAL_NZ': 29, - 'NCDHW': 30, - 'DHWCN': 31, - 'NDC1HWC0': 32, - 'FRACTAL_Z_3D': 33, - 'CN': 34, - 'NC': 35, - 'DHWNC': 36, - 'FRACTAL_Z_3D_TRANSPOSE': 37, - 'FRACTAL_ZN_LSTM': 38, - 'FRACTAL_Z_G': 39, - 'RESERVED': 40, - 'ALL': 41, - 'NULL': 42, - 'ND_RNN_BIAS': 43, - 'FRACTAL_ZN_RNN': 44, - 'NYUV': 45, - 'NYUV_A': 46 -} - - -def load_json(json_file: str): - with open(json_file, encoding='utf-8') as file: - json_content = json.load(file) - return json_content - - -def get_specified_suffix_file(root_dir, suffix): - specified_suffix = os.path.join(root_dir, '**/*.{}'.format(suffix)) - all_suffix_files = glob.glob(specified_suffix, recursive=True) - return all_suffix_files - - -def get_deterministic_value(support_info): - deterministic_key = 'deterministic' - if deterministic_key not in support_info: - return 0 - deterministic_value = support_info.get(deterministic_key) - if deterministic_value == 'true': - return 1 - else: - return 0 - - -def 
get_precision_value(support_info): - precision_key = 'implMode' - precision_value = support_info.get(precision_key) - if precision_value == 'high_performance': - _value = 1 - elif precision_value == 'high_precision': - _value = 2 - else: - _value = 0 - return _value - - -def get_overflow_value(support_info): - return 0 - - -def get_parameters(info): - if info: - if 'dtype' in info: - data_type = info['dtype'] - data_type_value = DATA_TYPE_DICT.get(data_type) - else: - data_type_value = 0 - if 'format' in info: - _format = info['format'] - _format_value = FORMAT_DICT.get(_format) - else: - _format_value = 0 - else: - data_type_value = 0 - _format_value = 0 - return str(data_type_value), str(_format_value) - - -def get_dynamic_parameters(info): - # For dynamic inputs, only the first parameter is needed - return get_parameters(info[0]) - - -def get_all_parameters(support_info, _type): - result_list = list() - info_lists = support_info.get(_type) - if info_lists: - for _info in info_lists: - # A list entry indicates a dynamic input - if isinstance(_info, list): - data_type_value, _format_value = get_dynamic_parameters(_info) - else: - data_type_value, _format_value = get_parameters(_info) - result_list.append("{},{}".format(data_type_value, _format_value)) - return result_list - - -def get_all_input_parameters(support_info): - result = get_all_parameters(support_info, 'inputs') - return '/'.join(result) - - -def insert_content_into_file(input_file, content): - with open(input_file, 'r+') as file: - lines = file.readlines() - for index, line in enumerate(lines): - match_result = re.search(r'"staticKey":', line) - if match_result: - count = len(line) - len(line.lstrip()) - new_content = "{}{}".format(' ' * count, content) - # Insert before the matched line, so there is no need to decide whether a trailing comma is required - lines.insert(index, new_content) - break - file.seek(0) - file.write(''.join(lines)) - - -def insert_simplified_keys(json_file): - contents = load_json(json_file) - # JSON files without a 'binFileName' or 'supportInfo' field do not need to be processed - if ('binFileName' not in contents) or ('supportInfo' not in contents): - return - support_info = contents.get('supportInfo') - bin_file_name = contents.get('binFileName') - bin_suffix = contents.get('binFileSuffix') - # If 'simplifiedKey' already exists, return directly instead of generating it again - if 'simplifiedKey' in support_info: - return - op_type = bin_file_name.split('_')[0] - deterministic = str(get_deterministic_value(support_info)) - precision = str(get_precision_value(support_info)) - overflow = str(get_overflow_value(support_info)) - input_parameters = get_all_input_parameters(support_info) - key = '{}/d={},p={},o={}/{}/'.format( - op_type, - deterministic, - precision, - overflow, - input_parameters) - result = '"simplifiedKey": "' + key + '",\n' - insert_content_into_file(json_file, result) - - -def insert_all_simplified_keys(root_dir): - suffix = 'json' - all_json_files = get_specified_suffix_file(root_dir, suffix) - for _json in all_json_files: - insert_simplified_keys(_json) - - -def args_parse(): - parser = argparse.ArgumentParser() - parser.add_argument('-p', - '--path', - nargs='?', - required=True, - help='Path of the json file to parse.') - return parser.parse_args() - - -def main(): - args = args_parse() - insert_all_simplified_keys(args.path) - - -if __name__ == '__main__': - main() diff --git a/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py b/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py deleted file mode 100644 index 2b77c970..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/kernel_entry.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 
2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. -""" - - -def gen_fun_def(title, kernel, argn, arg_type, arg_name): - entry = [] - entry.append(title) - entry.append(kernel) - entry.append('(') - args = [] - for i in range(0, argn): - args.append(arg_type + ' ' + arg_name + str(i)) - entry.append(', '.join(args)) - entry.append(')') - return ' '.join(entry) - - -def gen_batch_kernel_body(fname, argn, arg_name): - body = [] - body.append('{') - fun = [] - fun.append(fname) - fun.append('(') - args = [] - for i in range(0, argn): - args.append(arg_name + str(i)) - fun.append(', '.join(args)) - fun.append(');') - body.append(' '.join(fun)) - body.append('}') - return '\n'.join(body) - - -def gen_mc_kernel_body(kn, argn, arg_name, blknum): - body = [] - body.append('{') - body.append(' switch(block_idx) {') - for blk in range(0, blknum): - fun = [] - fun.append('{}_blk{:02d}'.format(kn, blk)) - fun.append('(') - args = [] - for i in range(0, argn): - args.append(arg_name + str(i)) - fun.append(', '.join(args)) - fun.append(')') - body.append(' case {}: {}; break;'.format(blk, ' '.join(fun))) - body.append(' default: break;') - body.append(' }') - body.append('}') - return '\n'.join(body) - - -def gen_proc_body(argn, arg_name): - body = [] - body.append('{') - args = [] - for i in range(0, argn): - args.append(arg_name + str(i)) - body.append('uint64_t __x = (uint64_t)' + ' + (uint64_t)'.join(args) + ';') - body.append('__asm__ ("NOP");') - body.append('__asm__ ("NOP");') - body.append('__asm__ ("NOP");') - body.append('}') - return '\n'.join(body) - - -def batch_code_gen(kn, argn, argt): - codes = [] - kernel_name = kn - proc_name = kernel_name + '_percore' - arg_num = int(argn) - data_type = argt - arg_type = '__gm__ ' + data_type + '* __restrict__' - arg_name = 'arg' - kernel_title = 'extern \"C\" __global__ __aicore__ void' - proc_title = 'extern \"C\" __attribute__((noinline)) __aicore__ void' - codes.append('#ifndef __aicore__') - codes.append('#define __aicore__ [aicore]') - codes.append('#endif') - codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ';') - codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) - codes.append(gen_batch_kernel_body(proc_name, arg_num, arg_name)) - codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) - codes.append(gen_proc_body(arg_num, arg_name)) - return '\n'.join(codes) + '\n' - - -def mc_code_gen(kn, argn, argt, blknum): - codes = [] - kernel_name = kn - core_num = int(blknum) - arg_num = int(argn) - data_type = argt - arg_type = '__gm__ ' + data_type + '* __restrict__' - arg_name = 'arg' - kernel_title = 'extern \"C\" __global__ __aicore__ void' - proc_title = 'extern \"C\" __attribute__((noinline)) __aicore__ void' - codes.append('#ifndef __aicore__') - codes.append('#define __aicore__ [aicore]') - codes.append('#endif') - for i in range(0, core_num): - proc_name = '{}_blk{:02d}'.format(kernel_name, i) - codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name) + ';') - codes.append(gen_fun_def(kernel_title, kernel_name, arg_num, arg_type, arg_name)) - codes.append(gen_mc_kernel_body(kernel_name, arg_num, arg_name, core_num)) - for i in range(0, core_num): - proc_name = '{}_blk{:02d}'.format(kernel_name, i) - codes.append(gen_fun_def(proc_title, proc_name, arg_num, arg_type, arg_name)) - codes.append(gen_proc_body(arg_num, arg_name)) - return '\n'.join(codes) + '\n' diff --git 
a/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp b/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp deleted file mode 100644 index 7391befa..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/kernel_impl.temp +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include -#include -#include -#include -#include "replay_def.h" -#include "code_gen.h" -#include "replay_fun.h" -#define __TIK2_REPLAY_CODE__ -#include "__CCE_FILE__" diff --git a/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh b/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh deleted file mode 100644 index a977bd51..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/merge_aicpu_info_json.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash - -project_path=$1 -build_path=$2 -vendor_name=customize -echo $@ -if [[ ! -d "$project_path" ]]; then - echo "[ERROR] No project path is provided" - exit 1 -fi - -if [[ ! -d "$build_path" ]]; then - echo "[ERROR] No build path is provided" - exit 1 -fi - -if [[ ! -d "$ASCEND_OPP_PATH" ]]; then - echo "[ERROR] No opp install path is provided" - exit 1 -fi -custom_exist_info_json=$ASCEND_OPP_PATH/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json -custom_new_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/cust_aicpu_kernel.json -temp_info_json=$build_path/makepkg/packages/vendors/$vendor_name/op_impl/cpu/config/temp_cust_aicpu_kernel.json - -if [[ -f "$custom_exist_info_json" ]] && [[ -f "$custom_new_info_json" ]]; then - cp -f $custom_exist_info_json $temp_info_json - chmod +w $temp_info_json - python3 ${project_path}/cmake/util/insert_op_info.py ${custom_new_info_json} ${temp_info_json} - cp -f $temp_info_json $custom_new_info_json - rm -f $temp_info_json -fi diff --git a/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py b/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py deleted file mode 100644 index c58729c3..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/opdesc_parser.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
-""" - -import sys -import os - - -class OpDesc: - def __init__(self: any, op_type: str): - self.op_type = op_type - self.attr_list = [] - self.attr_val = {} - self.input_name = [] - self.input_type = [] - self.input_dtype = [] - self.input_fmt = [] - self.output_name = [] - self.output_type = [] - self.output_dtype = [] - self.output_fmt = [] - self.op_fmt_sel = False - self.op_chk_support = False - self.op_intf = '' - self.kern_name = '' - self.op_file = '' - self.op_replay_flag = False - self.op_replay_batch = False - self.input_idx = 0 - self.output_idx = 0 - self.max_block_dim = 32 - self.max_shape_size = 268435456 - self.dynamic_shape = False - self.op_range_limit = '' - - @staticmethod - def _parse_digit(conf: str) -> int: - return int(conf.split('=')[1]) - - @staticmethod - def _parse_flag(conf: str) -> bool: - if 'true' == conf.split('=')[1]: - return True - return False - - @staticmethod - def _parse_str(conf: str) -> str: - return conf.split('=')[1] - - @staticmethod - def _parse_list(conf: str) -> list: - return conf.split('=')[1].split(',') - - def parse_input(self: any, conf: str): - if conf.startswith('input{}.name'.format(int(self.input_idx / 4))): - self.input_name.append(self._parse_str(conf)) - self.input_idx += 1 - elif conf.startswith('input{}.paramType'.format(int(self.input_idx / 4))): - self.input_type.append(self._parse_str(conf)) - self.input_idx += 1 - elif conf.startswith('input{}.dtype'.format(int(self.input_idx / 4))): - self.input_dtype.append(self._parse_str(conf)) - self.input_idx += 1 - elif conf.startswith('input{}.format'.format(int(self.input_idx / 4))): - self.input_fmt.append(self._parse_str(conf)) - self.input_idx += 1 - else: - return - - def parse_output(self: any, conf: str): - if conf.startswith('output{}.name'.format(int(self.output_idx / 4))): - self.output_name.append(self._parse_str(conf)) - self.output_idx += 1 - elif conf.startswith('output{}.paramType'.format(int(self.output_idx / 4))): - self.output_type.append(self._parse_str(conf)) - self.output_idx += 1 - elif conf.startswith('output{}.dtype'.format(int(self.output_idx / 4))): - self.output_dtype.append(self._parse_str(conf)) - self.output_idx += 1 - elif conf.startswith('output{}.format'.format(int(self.output_idx / 4))): - self.output_fmt.append(self._parse_str(conf)) - self.output_idx += 1 - else: - return - - def parse_op_format(self: any, conf: str): - self.op_fmt_sel = self._parse_flag(conf) - - def parse_check_support(self: any, conf: str): - self.op_chk_support = self._parse_flag(conf) - - def parse_range_limit(self: any, conf: str): - self.op_range_limit = self._parse_str(conf) - - def parse_kern_name(self: any, conf: str): - self.kern_name = self._parse_str(conf) - - def parse_op_intf(self: any, conf: str): - self.op_intf = self._parse_str(conf) - - def parse_op_file(self: any, conf: str): - self.op_file = self._parse_str(conf) - - def parse_dynamic_shape(self: any, conf: str): - self.dynamic_shape = self._parse_flag(conf) - - def parse_attr_list(self: any, conf: str): - self.attr_list = self._parse_list(conf) - - def parse_attr_val(self: any, conf: str): - for attr in self.attr_list: - if self.attr_val.get(attr) is None: - self.attr_val[attr] = {} - if conf.startswith('attr_{}.type'.format(attr)): - self.attr_val.get(attr)['type'] = self._parse_str(conf) - elif conf.startswith('attr_{}.paramType'.format(attr)): - self.attr_val.get(attr)['paramType'] = self._parse_str(conf) - elif conf.startswith('attr_{}.defaultValue'.format(attr)): - self.attr_val.get(attr)['defaultValue'] 
= self._parse_str(conf) - - def parse_replay_val(self: any, batch_list: list, iterator_list: list): - if self.op_type in batch_list: - self.op_replay_flag = True - self.op_replay_batch = True - elif self.op_type in iterator_list: - self.op_replay_flag = True - self.op_replay_batch = False - - -def get_op_desc(file: str, batch_list: list, iterator_list: list, builder: any, op_type: list) -> list: - op_descs = [] - op_match = False - with open (file, 'r') as fd: - lines = fd.readlines() - for line in lines: - line = line.strip() - if line.startswith('['): - name = line[1:-1] - if op_type is None or name in op_type: - op_match = True - op_desc = builder(name) - op_descs.append(op_desc) - else: - op_match = False - if op_type is not None and len(op_descs) == len(op_type): - return op_descs - continue - if not op_match: - continue - if line.startswith('input'): - op_desc.parse_input(line) - elif line.startswith('output'): - op_desc.parse_output(line) - elif line.startswith('dynamicFormat.flag'): - op_desc.parse_op_format(line) - elif line.startswith('needCheckSupport.flag'): - op_desc.parse_check_support(line) - elif line.startswith('rangeLimit.value'): - op_desc.parse_range_limit(line) - elif line.startswith('opInterface.value'): - op_desc.parse_op_intf(line) - elif line.startswith('kernel.name'): - op_desc.parse_kern_name(line) - elif line.startswith('opFile.value'): - op_desc.parse_op_file(line) - elif line.startswith('dynamicShapeSupport.flag'): - op_desc.parse_dynamic_shape(line) - elif line.startswith('attr.list'): - op_desc.parse_attr_list(line) - elif line.startswith('attr_'): - op_desc.parse_attr_val(line) - op_desc.parse_replay_val(batch_list, iterator_list) - return op_descs diff --git a/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py b/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py deleted file mode 100644 index f75c5260..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/parse_ini_to_json.py +++ /dev/null @@ -1,339 +0,0 @@ -# Copyright 2020-2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -parser ini to json -""" - -import json -import os -import stat -import sys - - -ATTR_TYPE_LIST = ["int", "float", "bool", "str", "listInt", "listFloat", "listBool", "listStr", "listListInt", - "type", "listType", "tensor", "listTensor"] -ATTR_PARAMTYPE_LIST = ["optional", "required"] -BOOL_FLAG_KEY = ["dynamicFormat", "dynamicShapeSupport", "dynamicRankSupport", "precision_reduce", "heavyOp", - "needCheckSupport"] -BOOL_LIST = ["true", "false"] -DTYPE_LIST = ["float16", "float", "float32", "int8", "int16", "int32", "uint8", "uint16", "uint32", "bool", - "int64", "uint64", "qint8", "qint16", "qint32", "quint8", "quint16", "double", "complex64", - "complex128", "string", "resource", "dual", "dual_sub_int8", "dual_sub_uint8", "string_ref", - "int4", "bfloat16", "uint1"] -FORMAT_LIST = ["NCHW", "NHWC", "ND", "NC1HWC0", "FRACTAL_Z", "NC1C0HWPAD", "NHWC1C0", "FSR_NCHW", "FRACTAL_DECONV", - "C1HWNC0", "FRACTAL_DECONV_TRANSPOSE", "FRACTAL_DECONV_SP_STRIDE_TRANS", "NC1HWC0_C04", - "FRACTAL_Z_C04", "CHWN", "FRACTAL_DECONV_SP_STRIDE8_TRANS", "HWCN", "NC1KHKWHWC0", "BN_WEIGHT", - "FILTER_HWCK", "HASHTABLE_LOOKUP_LOOKUPS", "HASHTABLE_LOOKUP_KEYS", "HASHTABLE_LOOKUP_VALUE", - "HASHTABLE_LOOKUP_OUTPUT", "HASHTABLE_LOOKUP_HITS", "C1HWNCoC0", "MD", "NDHWC", "FRACTAL_ZZ", - "FRACTAL_NZ", "NCDHW", "DHWCN", "NDC1HWC0", "FRACTAL_Z_3D", "CN", "NC", "DHWNC", - "FRACTAL_Z_3D_TRANSPOSE", "FRACTAL_ZN_LSTM", "FRACTAL_ZN_RNN", "FRACTAL_Z_G", "NULL"] - - -def parse_ini_files(ini_files): - """ - parse ini files to json - Parameters: - ---------------- - ini_files:input file list - return:ops_info - ---------------- - """ - tbe_ops_info = {} - for ini_file in ini_files: - check_file_size(ini_file) - parse_ini_to_obj(ini_file, tbe_ops_info) - return tbe_ops_info - - -def check_file_size(input_file): - try: - file_size = os.path.getsize(input_file) - except OSError as os_error: - print('[ERROR] Failed to open "%s". %s' % (input_file, str(os_error))) - raise OSError from os_error - if file_size > 10*1024*1024: - print('[WARN] The size of %s exceeds 10MB, it may take more time to run, please wait.' 
% input_file) - - -def parse_ini_to_obj(ini_file, tbe_ops_info): - """ - parse ini file to json obj - Parameters: - ---------------- - ini_file:ini file path - tbe_ops_info:ops_info - ---------------- - """ - with open(ini_file) as ini_file: - lines = ini_file.readlines() - op_dict = {} - op_name = "" - find_op_type = False - for line in lines: - line = line.rstrip() - if line == "": - continue - if line.startswith("["): - if line.endswith("]"): - op_name = line[1:-1] - op_dict = {} - tbe_ops_info[op_name] = op_dict - find_op_type = True - elif "=" in line: - key1 = line[:line.index("=")] - key2 = line[line.index("=")+1:] - key1_0, key1_1 = key1.split(".") - if key1_0 not in op_dict: - op_dict[key1_0] = {} - if key1_1 in op_dict.get(key1_0): - raise RuntimeError("Op:" + op_name + " " + key1_0 + " " + - key1_1 + " is repeated!") - dic_key = op_dict.get(key1_0) - dic_key[key1_1] = key2 - else: - continue - if not find_op_type: - raise RuntimeError("OpType not found in .ini file.") - - -def check_output_exist(op_dict, is_valid): - """ - Function Description: - Check whether output0 exists - Parameter: op_dict - Parameter: is_valid - """ - if "output0" in op_dict: - output0_dict = op_dict.get("output0") - if output0_dict.get("name", None) is None: - is_valid = False - print("output0.name is required in .ini file!") - else: - is_valid = False - print("output0 is required in .ini file!") - return is_valid - - -def check_attr_dict(attr_dict, is_valid, attr): - """ - Function Description: - Check attr_dict - Parameter: attr_dict - Parameter: is_valid - Parameter: attr - """ - attr_type = attr_dict.get("type") - value = attr_dict.get("value") - param_type = attr_dict.get("paramType") - if attr_type is None or value is None: - is_valid = False - print("If attr.list exists, {0}.type and {0}.value are required".format(attr)) - if param_type and param_type not in ATTR_PARAMTYPE_LIST: - is_valid = False - print("{0}.paramType only support {1}.".format(attr, ATTR_PARAMTYPE_LIST)) - if attr_type and attr_type not in ATTR_TYPE_LIST: - is_valid = False - print("{0}.type only support {1}.".format(attr, ATTR_TYPE_LIST)) - return is_valid - - -def check_attr(op_dict, is_valid): - """ - Function Description: - Check attr - Parameter: op_dict - Parameter: is_valid - """ - if "attr" in op_dict: - attr_dict = op_dict.get("attr") - attr_list_str = attr_dict.get("list", None) - if attr_list_str is None: - is_valid = False - print("attr.list is required in .ini file!") - else: - attr_list = attr_list_str.split(",") - for attr_name in attr_list: - attr = "attr_" + attr_name.strip() - attr_dict = op_dict.get(attr) - if attr_dict: - is_valid = check_attr_dict(attr_dict, is_valid, attr) - else: - is_valid = False - print("%s is required in .ini file, when attr.list is %s!"
% (attr, attr_list_str)) - return is_valid - - -def check_bool_flag(op_dict, is_valid): - """ - Function Description: - check_bool_flag - Parameter: op_dict - Parameter: is_valid - """ - for key in BOOL_FLAG_KEY: - if key in op_dict: - op_bool_key = op_dict.get(key) - if op_bool_key.get("flag").strip() not in BOOL_LIST: - is_valid = False - print("{0}.flag only support {1}.".format(key, BOOL_LIST)) - return is_valid - - -def check_type_format(op_info, is_valid, op_info_key): - """ - Function Description: - Check type and format - Parameter: op_info - Parameter: is_valid - Parameter: op_info_key - """ - op_info_dtype_str = op_info.get("dtype") - op_info_dtype_num = 0 - op_info_format_num = 0 - if op_info_dtype_str: - op_info_dtype = op_info_dtype_str.split(",") - op_info_dtype_num = len(op_info_dtype) - for dtype in op_info_dtype: - if dtype.strip() not in DTYPE_LIST: - is_valid = False - print("{0}.dtype not support {1}.".format(op_info_key, dtype)) - op_info_format_str = op_info.get("format") - if op_info_format_str: - op_info_format = op_info_format_str.split(",") - op_info_format_num = len(op_info_format) - for op_format in op_info_format: - if op_format.strip() not in FORMAT_LIST: - is_valid = False - print("{0}.format not support {1}.".format(op_info_key, op_format)) - if op_info_dtype_num > 0 and op_info_format_num > 0: - if op_info_dtype_num != op_info_format_num: - is_valid = False - print("The number of {0}.dtype not match the number of {0}.format.".format(op_info_key)) - return is_valid - - -def check_op_info(tbe_ops): - """ - Function Description: - Check info. - Parameter: tbe_ops - Return Value: is_valid - """ - print("\n\n==============check valid for ops info start==============") - required_op_input_info_keys = ["paramType", "name"] - required_op_output_info_keys = ["paramType", "name"] - param_type_valid_value = ["dynamic", "optional", "required"] - is_valid = True - for op_key in tbe_ops: - op_dict = tbe_ops[op_key] - is_valid = check_output_exist(op_dict, is_valid) - for op_info_key in op_dict: - if op_info_key.startswith("input"): - op_input_info = op_dict[op_info_key] - missing_keys = [] - for required_op_input_info_key in required_op_input_info_keys: - if required_op_input_info_key not in op_input_info: - missing_keys.append(required_op_input_info_key) - if len(missing_keys) > 0: - print("op: " + op_key + " " + op_info_key + " missing: " + - ",".join(missing_keys)) - is_valid = False - else: - if not op_input_info["paramType"] in param_type_valid_value: - print("op: " + op_key + " " + op_info_key + \ - " paramType not valid, valid key:[dynamic, " - "optional, required]") - is_valid = False - is_valid = check_type_format(op_input_info, is_valid, op_info_key) - if op_info_key.startswith("output"): - op_input_info = op_dict[op_info_key] - missing_keys = [] - for required_op_input_info_key in required_op_output_info_keys: - if required_op_input_info_key not in op_input_info: - missing_keys.append(required_op_input_info_key) - if len(missing_keys) > 0: - print("op: " + op_key + " " + op_info_key + " missing: " + - ",".join(missing_keys)) - is_valid = False - else: - if not op_input_info["paramType"] in param_type_valid_value: - print("op: " + op_key + " " + op_info_key + - " paramType not valid, valid key:[dynamic, " - "optional, required]") - is_valid = False - is_valid = check_type_format(op_input_info, is_valid, op_info_key) - is_valid = check_attr(op_dict, is_valid) - is_valid = check_bool_flag(op_dict, is_valid) - print("==============check valid for ops info 
end================\n\n") - return is_valid - - -def write_json_file(tbe_ops_info, json_file_path): - """ - Save info to json file - Parameters: - ---------------- - tbe_ops_info: ops_info - json_file_path: json file path - ---------------- - """ - json_file_real_path = os.path.realpath(json_file_path) - wr_flag = os.O_WRONLY | os.O_CREAT - wr_mode = stat.S_IWUSR | stat.S_IRUSR - with os.fdopen(os.open(json_file_real_path, wr_flag, wr_mode), 'w') as file_path: - # Only the owner and group have rights - os.chmod(json_file_real_path, stat.S_IWGRP + stat.S_IWUSR + stat.S_IRGRP - + stat.S_IRUSR) - json.dump(tbe_ops_info, file_path, sort_keys=True, indent=4, - separators=(',', ':')) - print("Compile op info cfg successfully.") - - -def parse_ini_to_json(ini_file_paths, outfile_path): - """ - parse ini files to json file - Parameters: - ---------------- - ini_file_paths: list of ini file path - outfile_path: output file path - ---------------- - """ - tbe_ops_info = parse_ini_files(ini_file_paths) - if not check_op_info(tbe_ops_info): - print("Compile op info cfg failed.") - return False - write_json_file(tbe_ops_info, outfile_path) - return True - - -if __name__ == '__main__': - args = sys.argv - - OUTPUT_FILE_PATH = "tbe_ops_info.json" - ini_file_path_list = [] - - for arg in args: - if arg.endswith("ini"): - ini_file_path_list.append(arg) - OUTPUT_FILE_PATH = arg.replace(".ini", ".json") - if arg.endswith("json"): - OUTPUT_FILE_PATH = arg - - if len(ini_file_path_list) == 0: - ini_file_path_list.append("tbe_ops_info.ini") - - if not parse_ini_to_json(ini_file_path_list, OUTPUT_FILE_PATH): - sys.exit(1) - sys.exit(0) diff --git a/cust_op/cust_op_by_addr/cmake/util/preset_parse.py b/cust_op/cust_op_by_addr/cmake/util/preset_parse.py deleted file mode 100644 index 8f1124b1..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/preset_parse.py +++ /dev/null @@ -1,23 +0,0 @@ -import json -import sys -import os - - -def get_config_opts(file): - src_dir = os.path.abspath(os.path.dirname(file)) - opts = '' - with open(file, 'r') as fd: - config = json.load(fd) - for conf in config: - if conf == 'configurePresets': - for node in config[conf]: - macros = node.get('cacheVariables') - if macros is not None: - for key in macros: - opts += '-D{}={} '.format(key, macros[key]['value']) - opts = opts.replace('${sourceDir}', src_dir) - print(opts) - - -if __name__ == "__main__": - get_config_opts(sys.argv[1]) diff --git a/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py b/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py deleted file mode 100644 index 1baa364e..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/replay_codegen.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
-""" - -import os -import stat -import collections -import kernel_entry as keb -from tiling_data_def_build import gen_tiling -import code_channel_infer -import const_var - -PYF_PATH = os.path.dirname(__file__) - -ReplayCodeGenParams = collections.namedtuple('ReplayCodeGenParams',\ -['op_type', 'impl', 'tiling_file', 'kernel', 'entry', 'argn', 'op_replay_batch', 'max_block_dim', 'max_shape_size']) - - -class ReplayCodeGen: - def __init__(self, replayCodeGenParams): - self.op_type = replayCodeGenParams.op_type - self.impl = replayCodeGenParams.impl - self.tiling_file = replayCodeGenParams.tiling_file - self.tiling_data_file = '' - self.kernel = replayCodeGenParams.kernel - self.entry = replayCodeGenParams.entry - self.argn = replayCodeGenParams.argn - self.batch = False - self.outdir = '' - self.data_type = 'uint8_t' - self.blknum = 32 - self.op_replay_batch = replayCodeGenParams.op_replay_batch - self.max_block_dim = replayCodeGenParams.max_block_dim - self.max_shape_size = replayCodeGenParams.max_shape_size - - def set_batch(self, is_batch): - self.batch = is_batch - - def set_outdir(self, outdir): - self.outdir = outdir - - def gen_replay(self, ops_product: str): - kerentry = os.path.join(self.outdir, self.kernel + '_entry.cce') - kerimpl = os.path.join(self.outdir, self.kernel + '_impl.cpp') - replayimpl = os.path.join(self.outdir, self.kernel + '_replay.cpp') - if self.batch: - reptmp = os.path.join(PYF_PATH, 'batch_replay_impl.temp') - else: - reptmp = os.path.join(PYF_PATH, 'replay_impl.temp') - kertmp = os.path.join(PYF_PATH, 'kernel_impl.temp') - self._gen_kentry(kerentry) - self._gen_kimpl_code(kerimpl, kertmp) - self._gen_tiling_data_header() - self._gen_replay_code(replayimpl, reptmp, ops_product) - - def _gen_tiling_data_header(self): - self.tiling_data_file = os.path.join(self.outdir, self.kernel + '_tiling_data.h') - gen_tiling(self.tiling_file, self.tiling_data_file) - - def _gen_kimpl_code(self, src, tmpfile): - with open(tmpfile, 'r') as fd: - temp = fd.read() - temp = temp.replace('__CCE_FILE__', self.impl) - with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: - ofd.write(temp) - - def _gen_replay_code(self, src, tmpfile, ops_product: str): - with open(tmpfile, 'r') as fd: - temp = fd.read() - temp = temp.replace('__ARG_NUM__', str(self.argn)) - argdef = [] - kargs = [] - for i in range(0, self.argn): - argdef.append('{} *'.format(self.data_type)) - kargs.append('({} *)GetArg({})'.format(self.data_type, i)) - temp = temp.replace('__ARGS_DEF__', ', '.join(argdef)) - temp = temp.replace('__KERNEL_ARGS__', ', '.join(kargs)) - temp = temp.replace('__KERNEL_FUN__', self.entry) - core_type_infer = 'core_type' - code_channel = code_channel_infer.infer_code_channel(code_channel_infer.InfoCodeChanelParams(self.impl,\ - self.tiling_data_file, self.kernel, self.outdir, ops_product, None)) - if code_channel == code_channel_infer.CODE_VEC: - core_type_infer = '0' - elif code_channel == code_channel_infer.CODE_CUBE: - core_type_infer = '1' - temp = temp.replace('__CORE_TYPE__', core_type_infer) - # regist function - temp = temp.replace('__OPS_PRODUCT__', ops_product) - temp = temp.replace('__OPTYPE__', self.op_type) - with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: - ofd.write(temp) - - def _gen_kentry(self, src): - kf = '' - pre_alloc_str = 'A' * 256 - if self.batch: - kf += keb.batch_code_gen("K{:02d}_{}{}".format(0, self.entry, pre_alloc_str), self.argn, self.data_type) - else: - kf += keb.mc_code_gen("K{:02d}_{}{}".format(0, 
self.entry, pre_alloc_str),\ - self.argn, self.data_type, self.blknum) - with os.fdopen(os.open(src, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: - ofd.write(kf) diff --git a/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp b/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp deleted file mode 100644 index 3e0b2f44..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/replay_impl.temp +++ /dev/null @@ -1,120 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "replay_def.h" -#include "code_gen.h" -#include "replay_fun.h" -#include "register/op_check.h" -#define __TIK2_REPLAY_CODE__ -using namespace std; -using namespace optiling; -using namespace tik2_replay; - -extern "C" void __KERNEL_FUN__ (__ARGS_DEF__, const char *); -extern "C" int elf_append(char *elf, uint32_t elfSize, char *jit, int kernum, int blknum[], char *atext[], - int alen[], int atlen, const char* kernelname[]); - -#define KERNEL_N 1 -#define ARG_N (__ARG_NUM__) -#define MAX_L (1024 * 1024 * 100) -#define MAX_E (1024 * 1024) - -int __KERNEL_FUN___replay___OPS_PRODUCT__(ReplayFuncParam& param, const int core_type) -{ - // gen type 1 : direct call codes 0: load .o file - if (param.gentype < 0 || param.gentype > 1) { - printf("Error: call replay gen type is %d, should only be 1 or 0\n", param.gentype); - return 0; - } else if (param.gentype == 1 && param.objptr == nullptr) { - printf("Error: call replay with direct call mode, but code obj addr is null\n"); - return 0; - } else if (param.gentype == 0 && param.output_kernel_file == nullptr) { - printf("Error: call replay with object file mode, but object file path is null\n"); - return 0; - } - // core_type 0:MIX 1:CUBE 2:VEC - if (core_type < 0 || core_type > 2) { - printf("Error: call replay core type is %d !\n", core_type); - return 0; - } - g_coreType = __CORE_TYPE__; - g_taskRation = param.task_ration; - g_tilingKey = param.tiling_key; - - unsigned char *buf, *jit; - char *kernel[KERNEL_N * 32]; - int len[KERNEL_N * 32]; - int blknum[KERNEL_N]; - int max; - block_num = param.block_dim; - g_ubBase = block_num; - uint8_t *code = (uint8_t *)malloc(MAX_L); - uint8_t *pos = code; - struct timespec tp1, tp2; - - clock_gettime(CLOCK_MONOTONIC, &tp1); - if (block_num > 32) { - printf("Error: block_num > 32\n"); - return 0; - } - //__OP_FOPEN__ - for (int i = 0; i < KERNEL_N; i++) { - for (int j = 0; j < ARG_N; j++) - AddArg(j, ARG_STEP * (j + 1)); - for (block_idx = 0; block_idx < block_num; block_idx++) { - //__OP_SET_KERNEL__ - int code_idx = i * block_num + block_idx; -#ifdef FP_CEILING - SetCtrlFloatEnable(); -#else - SetCtrlFloatDisable(); -#endif - CodeInit(pos, false); - __KERNEL_FUN__(__KERNEL_ARGS__, param.tiling_data); - CodeEnd(); - kernel[code_idx] = (char *)pos; - len[code_idx] = CodeLen(); - pos += len[code_idx]; - printf("kernel %d core %ld code generated len %d\n", i, block_idx, len[code_idx]); - } - blknum[i] = block_num; - } - //__OP_FCLOSE__ - clock_gettime(CLOCK_MONOTONIC, &tp2); - buf = (unsigned char *)malloc(MAX_E); - int fd = open(param.entry_file, O_RDONLY); - if (fd < 0) { - printf("[error]: cannot find entry.o : %s\n", param.entry_file); - return 0; - } - uint32_t bufSize = read(fd, buf, MAX_E); - if (bufSize <= 0) { - printf("[error]: entry.o : %s is too small ! 
\n", param.entry_file); - } - close(fd); - jit = (unsigned char *)malloc(MAX_L); - printf("total code generated %ld\n", pos - code); - int sz = elf_append((char *)buf, bufSize, (char *)jit, KERNEL_N, blknum, kernel, len, pos - code, ¶m.kernel_name); - if (tp1.tv_sec != tp2.tv_sec) { - printf("%ld NS\n", tp2.tv_nsec + 1000000000 - tp1.tv_nsec); - } else { - printf("%ld NS\n", tp2.tv_nsec - tp1.tv_nsec); - } - printf("new elf size %d\n", sz); - if (param.gentype == 0) { - fd = open(param.output_kernel_file, O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR); - (void)write(fd, jit, sz); - close(fd); - free(jit); - } else if (param.gentype == 1) { - *param.objptr = (char*)jit; - } - free(buf); - free(code); - return sz; -} - -REG_REPLAY_FUNC(__OPTYPE__, __OPS_PRODUCT__, __KERNEL_FUN___replay___OPS_PRODUCT__); diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py deleted file mode 100644 index 98095ab8..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/tik2_bin_param_build.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. -""" - -import sys -import os -import json -import hashlib -import const_var -import opdesc_parser - -PYF_PATH = os.path.dirname(os.path.realpath(__file__)) - - -class BinParamBuilder(opdesc_parser.OpDesc): - def __init__(self: any, op_type: str): - super().__init__(op_type) - self.soc = '' - self.out_path = '' - - def set_soc_version(self: any, soc: str): - self.soc = soc - - def set_out_path(self: any, out_path: str): - self.out_path = out_path - - def gen_input_json(self: any): - key_map = {} - count = len(self.input_dtype[0].split(',')) - for i in range(0, count): - inputs = [] - outputs = [] - attrs = [] - op_node = {} - for idx in range(0, len(self.input_name)): - idtypes = self.input_dtype[idx].split(',') - ifmts = self.input_fmt[idx].split(',') - para = {} - para['name'] = self.input_name[idx] - para['index'] = idx - para['dtype'] = idtypes[i] - para['format'] = ifmts[i] - para['paramType'] = self.input_type[idx] - para['shape'] = [-2] - inputs.append(para) - for idx in range(0, len(self.output_name)): - odtypes = self.output_dtype[idx].split(',') - ofmts = self.output_fmt[idx].split(',') - para = {} - para['name'] = self.output_name[idx] - para['index'] = idx - para['dtype'] = odtypes[i] - para['format'] = ofmts[i] - para['paramType'] = self.output_type[idx] - para['shape'] = [-2] - outputs.append(para) - for attr in self.attr_list: - att = {} - att['name'] = attr - atype = self.attr_val.get(attr).get('type').lower() - atype = atype.replace('list', 'list_') - att['dtype'] = atype - att['value'] = const_var.ATTR_DEF_VAL.get(atype) - attrs.append(att) - op_node['bin_filename'] = '' - op_node['inputs'] = inputs - op_node['outputs'] = outputs - if len(attrs) > 0: - op_node['attrs'] = attrs - param = {} - param['op_type'] = self.op_type - param['op_list'] = [op_node] - objstr = json.dumps(param, indent=' ') - md5sum = hashlib.md5(objstr.encode('utf-8')).hexdigest() - while key_map.get(md5sum) is not None: - objstr += '1' - md5sum = hashlib.md5(objstr.encode('utf-8')).hexdigest() - key_map[md5sum] = md5sum - bin_file = self.op_type + '_' + md5sum - op_node['bin_filename'] = bin_file - param_file = os.path.join(self.out_path, bin_file + '_param.json') - param_file = os.path.realpath(param_file) - with os.fdopen(os.open(param_file, const_var.WFLAGS, 
const_var.WMODES), 'w') as fd: - json.dump(param, fd, indent=' ') - self._write_build_cmd(param_file, bin_file, i) - - - def _write_build_cmd(self: any, param_file: str, bin_file: str, index: int): - hard_soc = const_var.SOC_MAP_EXT.get(self.soc) - if not hard_soc: - hard_soc = self.soc.capitalize() - name_com = [self.op_type, self.op_file, str(index)] - compile_file = os.path.join(self.out_path, '-'.join(name_com) + '.sh') - compile_file = os.path.realpath(compile_file) - with os.fdopen(os.open(compile_file, const_var.WFLAGS, const_var.WMODES), 'w') as fd: - fd.write('#!/bin/bash\n') - fd.write('echo "[{}] Generating {} ..."\n'.format(hard_soc, bin_file)) - cmd = const_var.BIN_CMD.format(fun=self.op_intf, soc=hard_soc, param=param_file, impl='""') - fd.write(cmd) - chk = const_var.CHK_CMD.format(res_file=bin_file + '.json') - fd.write(chk) - chk = const_var.CHK_CMD.format(res_file=bin_file + '.o') - fd.write(chk) - fd.write('echo "[{}] Generating {} Done"\n'.format(hard_soc, bin_file)) - - -def gen_bin_param_file(cfgfile: str, out_dir: str, soc: str): - op_descs = opdesc_parser.get_op_desc(cfgfile, [], [], BinParamBuilder, None) - for op_desc in op_descs: - op_desc.set_soc_version(soc) - op_desc.set_out_path(out_dir) - op_desc.gen_input_json() - - -if __name__ == '__main__': - if len(sys.argv) <= 3: - raise RuntimeError('arguments must be greater than 3') - gen_bin_param_file(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py deleted file mode 100644 index 70a21db5..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/tik2_impl_build.py +++ /dev/null @@ -1,376 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
-""" - -import sys -import os -import stat -import opdesc_parser -import const_var - -PYF_PATH = os.path.dirname(os.path.realpath(__file__)) - -IMPL_HEAD = ''' -import os, sys -import ctypes -import json -import shutil -from tbe.common.platform import get_soc_spec -from tbe.common.utils import para_check -from tbe.tikcpp import compile_op, replay_op, check_op_cap, generalize_op_params, get_code_channel, OpInfo -from impl.util.platform_adapter import tbe_register -PYF_PATH = os.path.dirname(os.path.realpath(__file__)) - -DTYPE_MAP = {"float32": ["DT_FLOAT", "float"], - "float16": ["DT_FLOAT16", "half"], - "int8": ["DT_INT8", "int8_t"], - "int16": ["DT_INT16", "int16_t"], - "int32": ["DT_INT32", "int32_t"], - "int64": ["DT_INT64", "int64_t"], - "uint1": ["DT_UINT1", "uint8_t"], - "uint8": ["DT_UINT8", "uint8_t"], - "uint16": ["DT_UINT16", "uint16_t"], - "uint32": ["DT_UINT32", "uint32_t"], - "uint64": ["DT_UINT64", "uint64_t"], - "bool": ["DT_BOOL", "bool"], - "double": ["DT_DOUBLE", "double"], - "dual": ["DT_DUAL", "unknown"], - "dual_sub_int8": ["DT_DUAL_SUB_INT8", "unknown"], - "dual_sub_uint8": ["DT_DUAL_SUB_UINT8", "unknown"], - "string": ["DT_STRING", "unknown"], - "complex64": ["DT_COMPLEX64", "unknown"], - "complex128": ["DT_COMPLEX128", "unknown"], - "qint8": ["DT_QINT8", "unknown"], - "qint16": ["DT_QINT16", "unknown"], - "qint32": ["DT_QINT32", "unknown"], - "quint8": ["DT_QUINT8", "unknown"], - "quint16": ["DT_QUINT16", "unknown"], - "resource": ["DT_RESOURCE", "unknown"], - "string_ref": ["DT_STRING_REF", "unknown"], - "int4": ["DT_INT4", "int8_t"], - "bfloat16": ["DT_BF16", "half"]} - -def get_dtype_fmt_options(inputs, outputs): - options = [] - for x in inputs + outputs: - x_n = x.get("param_name").upper() - x_fmt = x.get("format") - x_dtype = x.get("dtype") - options.append("-DDTYPE_{n}={t}".format(n=x_n, t=DTYPE_MAP.get(x_dtype)[1])) - options.append("-DORIG_DTYPE_{n}={ot}".format(n=x_n, ot=DTYPE_MAP.get(x_dtype)[0])) - options.append("-DFORMAT_{n}={f}".format(n=x_n, f=x_fmt)) - return options - -def load_dso(so_path): - try: - ctypes.CDLL(so_path) - except OSError as error : - print(error) - raise RuntimeError("cannot open %s" %(so_path)) - else: - print("load so succ ", so_path) - -''' - -IMPL_API = ''' -@para_check.check_op_params({}) -def {}({}, kernel_name="{}", impl_mode=""): - inputs, outputs, attrs = _build_args({}) - options = get_dtype_fmt_options(inputs, outputs) - options += ["-x", "cce"] - ccec = shutil.which("ccec") - if ccec != None: - ccec_path = os.path.dirname(ccec) - tikcpp_path = os.path.realpath(os.path.join(ccec_path, "..", "..", "tikcpp")) - else: - tikcpp_path = os.path.realpath("/usr/local/Ascend/latest/compiler/tikcpp") - options.append("-I" + tikcpp_path) - options.append("-I" + os.path.join(tikcpp_path, "tikcfw")) - options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "impl")) - options.append("-I" + os.path.join(tikcpp_path, "tikcfw", "interface")) - origin_func_name = "{}" - src = os.path.join(PYF_PATH, "{}") -''' - -REPLAY_OP_API = ''' - print("call into replay op") - soc_version = get_soc_spec("SOC_VERSION") - soc_short = get_soc_spec("SHORT_SOC_VERSION").lower() - tikreplay_codegen_path = tikcpp_path + "/tikreplaylib/lib" - tikreplay_stub_path = tikcpp_path + "/tikreplaylib/lib/" + soc_version - print("start load libtikreplaylib_codegen.so and libtikreplaylib_stub.so") - codegen_so_path = tikreplay_codegen_path + "/libtikreplaylib_codegen.so" - replaystub_so_path = tikreplay_stub_path + "/libtikreplaylib_stub.so" - if 
PYF_PATH.endswith("dynamic"): - op_replay_path = os.path.join(PYF_PATH, "..", "..", "op_replay") - else: - op_replay_path = os.path.join(PYF_PATH, "..", "op_replay") - replayapi_so_path = os.path.join(op_replay_path, "libreplay_{}_" + soc_short + ".so") - load_dso(codegen_so_path) - load_dso(replaystub_so_path) - load_dso(replayapi_so_path) - op_type = "{}" - entry_obj = os.path.join(op_replay_path, "{}_entry_" + soc_short + ".o") - code_channel = get_code_channel(src, kernel_name, op_type, options) - op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = inputs, outputs = outputs, attrs = attrs,\\ - impl_mode = impl_mode) - res, msg = replay_op(op_info, entry_obj, code_channel, src, options) - if not res: - print("call replay op failed for %s and get into call compile op" %(msg)) - compile_op(src, origin_func_name, op_info, options, code_channel) -''' - -COMPILE_OP_API = ''' - print("call into compile op") - op_type = "{}" - code_channel = get_code_channel(src, kernel_name, op_type, options) - op_info = OpInfo(kernel_name = kernel_name, op_type = op_type, inputs = inputs, outputs = outputs, attrs = attrs,\\ - impl_mode = impl_mode) - compile_op(src, origin_func_name, op_info, options, code_channel) -''' - -SUP_API = ''' -def {}({}, impl_mode=""): - inputs, outputs, attrs = _build_args({}) - ret_str = check_op_cap("{}", "{}", inputs, outputs, attrs) - ret_dict = json.loads(ret_str) - err_code = ret_dict.get("ret_code") - sup = "Unknown" - reason = "Unknown reason" - if err_code is not None: - if err_code is 0: - sup = "True" - reason = "" - elif err_code is 1: - sup = "False" - reason = ret_dict.get("reason") - else: - sup = "Unknown" - reason = ret_dict.get("reason") - return sup, reason -''' -CAP_API = ''' -def {}({}, impl_mode=""): - inputs, outputs, attrs = _build_args({}) - result = check_op_cap("{}", "{}", inputs, outputs, attrs) - return result.decode("utf-8") -''' -GLZ_API = ''' -@tbe_register.register_param_generalization("{}") -def {}_generalization({}, generalize_config=None): - inputs, outputs, attrs = _build_args({}) - ret_str = generalize_op_params("{}", inputs, outputs, attrs, generalize_config) - return [json.loads(ret_str)] -''' - -ATTR_DEFAULT = {'bool': 'False', 'int': '0'} - - -class AdpBuilder(opdesc_parser.OpDesc): - def __init__(self: any, op_type: str): - self.argsname = [] - self.argsdefv = [] - super().__init__(op_type) - - def write_adapt(self: any, impl_path, path: str): - self._build_paradefault() - if impl_path != "": - src_file = os.path.join(impl_path, self.op_file + '.cpp') - if not os.path.exists(src_file): - return - out_path = os.path.abspath(path) - if self.dynamic_shape and not out_path.endswith('dynamic'): - out_path = os.path.join(path, 'dynamic') - os.makedirs(out_path, exist_ok=True) - adpfile = os.path.join(out_path, self.op_file + '.py') - with os.fdopen(os.open(adpfile, const_var.WFLAGS, const_var.WMODES), 'w') as fd: - self._write_head(fd) - self._write_argparse(fd) - self._write_impl(fd) - if self.op_chk_support: - self._write_cap('check_supported', fd) - self._write_cap('get_op_support_info', fd) - if self.op_fmt_sel: - self._write_cap('op_select_format', fd) - self._write_cap('get_op_specific_info', fd) - if self.op_range_limit == 'limited' or self.op_range_limit == 'dynamic': - self._write_glz(fd) - - def _ip_argpack(self: any, default: bool = True) -> list: - args = [] - for i in range(len(self.input_name)): - arg = self.input_name[i] - if default and self.argsdefv[i] is not None: - arg += '=' + self.argsdefv[i] - 
args.append(arg) - return args - - def _op_argpack(self: any, default: bool = True) -> list: - args = [] - argidx = len(self.input_name) - for i in range(len(self.output_name)): - arg = self.output_name[i] - if default and self.argsdefv[i + argidx] is not None: - arg += '=' + self.argsdefv[i + argidx] - args.append(arg) - return args - - def _attr_argpack(self: any, default: bool = True) -> list: - args = [] - argidx = len(self.input_name) + len(self.output_name) - for i in range(len(self.attr_list)): - att = self.attr_list[i] - arg = att - if default and self.argsdefv[i + argidx] is not None: - if self.attr_val.get(att).get('type') == 'str': - arg += '="' + self.argsdefv[i + argidx] + '"' - elif self.attr_val.get(att).get('type') == 'bool': - arg += '="' + self.argsdefv[i + argidx].capitalize() + '"' - else: - arg += '=' + self.argsdefv[i + argidx] - args.append(arg) - return args - - def _build_paralist(self: any, default: bool = True) -> str: - args = [] - args.extend(self._ip_argpack(default)) - args.extend(self._op_argpack(default)) - args.extend(self._attr_argpack(default)) - return ', '.join(args) - - def _io_parachk(self: any, types: list, type_name: str) -> list: - chk = [] - for iot in types: - if iot == 'optional': - ptype = 'OPTION' - else: - ptype = iot.upper() - chk.append('para_check.{}_{}'.format(ptype, type_name)) - return chk - - def _attr_parachk(self: any) -> list: - chk = [] - for att in self.attr_list: - if self.attr_val.get(att).get('paramType') == 'optional': - pt = 'OPTION' - else: - pt = self.attr_val.get(att).get('paramType').upper() - att_type = self.attr_val.get(att).get('type').upper() - att_type = att_type.replace('LIST', 'LIST_') - chk.append('para_check.{}_ATTR_{}'.format(pt, att_type)) - return chk - - def _build_parachk(self: any) -> str: - chk = [] - chk.extend(self._io_parachk(self.input_type, 'INPUT')) - chk.extend(self._io_parachk(self.output_type, 'OUTPUT')) - chk.extend(self._attr_parachk()) - chk.append('para_check.KERNEL_NAME') - return ', '.join(chk) - - def _build_paradefault(self: any): - optional = False - argtypes = [] - argtypes.extend(self.input_type) - argtypes.extend(self.output_type) - for atype in argtypes: - if atype == 'optional': - optional = True - if optional: - self.argsdefv.append('None') - else: - self.argsdefv.append(None) - for attr in self.attr_list: - atype = self.attr_val.get(attr).get('paramType') - if atype == 'optional': - optional = True - attrval = self.attr_val.get(attr).get('defaultValue') - if attrval is not None: - optional = True - if type == "bool": - attrval = attrval.capitalize() - elif type == "str": - attrval = "\"" + attrval + "\"" - self.argsdefv.append(attrval) - continue - if optional: - self.argsdefv.append(ATTR_DEFAULT.get(self.attr_val.get(attr).get('type'))) - else: - self.argsdefv.append(None) - - def _write_head(self: any, fd: object): - fd.write(IMPL_HEAD) - - def _write_argparse(self: any, fd: object): - args = self._build_paralist(False) - fd.write('def _build_args({}):\n'.format(args)) - fd.write(' inputs = []\n') - fd.write(' for arg in [{}]:\n'.format(', '.join(self.input_name))) - fd.write(' if arg != None:\n') - fd.write(' inputs.append(arg)\n') - fd.write(' outputs = []\n') - fd.write(' for arg in [{}]:\n'.format(', '.join(self.output_name))) - fd.write(' if arg != None:\n') - fd.write(' outputs.append(arg)\n') - fd.write(' attrs = []\n') - for attr in self.attr_list: - fd.write(' if {} != None:\n'.format(attr)) - fd.write(' attr = {}\n') - fd.write(' attr["name"] = "{}"\n'.format(attr)) - 
fd.write(' attr["dtype"] = "{}"\n'.format(self.attr_val.get(attr).get('type'))) - fd.write(' attr["value"] = {}\n'.format(attr)) - fd.write(' attrs.append(attr)\n') - fd.write(' return inputs, outputs, attrs\n') - - def _write_impl(self: any, fd: object): - argsdef = self._build_paralist() - argsval = self._build_paralist(False) - pchk = self._build_parachk() - if len(self.kern_name) > 0: - kern_name = self.kern_name - else: - kern_name = self.op_intf - src = self.op_file + '.cpp' - fd.write(IMPL_API.format(pchk, self.op_intf, argsdef, kern_name, argsval, self.op_intf, src)) - if self.op_replay_flag: - fd.write(REPLAY_OP_API.format(self.op_file, self.op_type, self.op_file)) - else: - fd.write(COMPILE_OP_API.format(self.op_type)) - - def _write_cap(self: any, cap_name: str, fd: object): - argsdef = self._build_paralist() - argsval = self._build_paralist(False) - if cap_name == 'check_supported': - fd.write(SUP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type)) - else: - fd.write(CAP_API.format(cap_name, argsdef, argsval, cap_name, self.op_type)) - - def _write_glz(self: any, fd: object): - argsdef = self._build_paralist() - argsval = self._build_paralist(False) - fd.write(GLZ_API.format(self.op_type, self.op_intf, argsdef, argsval, self.op_type)) - - -def write_scripts(cfgfile: str, cfgs: dict, dirs: dict, ops: list = None): - batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(';') - iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(';') - file_map = {} - op_descs = opdesc_parser.get_op_desc(cfgfile, batch_lists, iterator_lists, AdpBuilder, ops) - for op_desc in op_descs: - op_desc.write_adapt(dirs.get(const_var.CFG_IMPL_DIR), dirs.get(const_var.CFG_OUT_DIR)) - file_map[op_desc.op_type] = op_desc.op_file - return file_map - -if __name__ == '__main__': - if len(sys.argv) <= 5: - raise RuntimeError('arguments must greater equal than 5') - rep_cfg = {} - rep_cfg[const_var.REPLAY_BATCH] = sys.argv[2] - rep_cfg[const_var.REPLAY_ITERATE] = sys.argv[3] - cfg_dir = {} - cfg_dir[const_var.CFG_IMPL_DIR] = sys.argv[4] - cfg_dir[const_var.CFG_OUT_DIR] = sys.argv[5] - write_scripts(sys.argv[1], rep_cfg, cfg_dir) diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py b/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py deleted file mode 100644 index 2c881b67..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/tik2_ops_config.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
-""" - -import sys -import os -import glob -import json -import argparse -import const_var - - -def load_json(json_file: str): - with open(json_file, encoding='utf-8') as file: - json_content = json.load(file) - return json_content - - -def get_specified_suffix_file(root_dir, suffix): - specified_suffix = os.path.join(root_dir, '**/*.{}'.format(suffix)) - all_suffix_files = glob.glob(specified_suffix, recursive=True) - return all_suffix_files - - -def add_simplified_config(op_type, key, objfile, config): - simple_cfg = config.get('binary_info_config.json') - op_cfg = simple_cfg.get(op_type) - if not op_cfg: - op_cfg = {} - op_cfg['dynamicRankSupport'] = True - op_cfg['simplifiedKeyMode'] = 0 - op_cfg['binaryList'] = [] - simple_cfg[op_type] = op_cfg - bin_list = op_cfg.get('binaryList') - bin_list.append({'simplifiedKey': key, 'binPath': objfile}) - - -def add_op_config(op_file, bin_info, config): - op_cfg = config.get(op_file) - if not op_cfg: - op_cfg = {} - op_cfg['binList'] = [] - config[op_file] = op_cfg - op_cfg.get('binList').append(bin_info) - - -def gen_ops_config(json_file, soc, config): - contents = load_json(json_file) - if ('binFileName' not in contents) or ('supportInfo' not in contents): - return - json_base_name = os.path.basename(json_file) - op_dir = os.path.basename(os.path.dirname(json_file)) - support_info = contents.get('supportInfo') - bin_name = contents.get('binFileName') - bin_suffix = contents.get('binFileSuffix') - bin_file_name = bin_name + bin_suffix - op_type = bin_name.split('_')[0] - op_file = op_dir + '.json' - bin_info = {} - key = support_info.get('simplifiedKey') - if key: - bin_info['simplifiedKey'] = key - add_simplified_config(op_type, key, os.path.join(soc, op_dir, bin_file_name), config) - bin_info['staticKey'] = support_info.get('staticKey') - bin_info['int64Mode'] = support_info.get('int64Mode') - bin_info['inputs'] = support_info.get('inputs') - bin_info['outputs'] = support_info.get('outputs') - if support_info.get('attrs'): - bin_info['attrs'] = support_info.get('attrs') - bin_info['binInfo'] = {'jsonFilePath': os.path.join(soc, op_dir, json_base_name)} - add_op_config(op_file, bin_info, config) - - -def gen_all_config(root_dir, soc): - suffix = 'json' - config = {} - config['binary_info_config.json'] = {} - all_json_files = get_specified_suffix_file(root_dir, suffix) - for _json in all_json_files: - gen_ops_config(_json, soc, config) - for cfg_key in config.keys(): - cfg_file = os.path.join(root_dir, cfg_key) - with os.fdopen(os.open(cfg_file, const_var.WFLAGS, const_var.WMODES), 'w') as fd: - json.dump(config.get(cfg_key), fd, indent=' ') - - -def args_prase(): - parser = argparse.ArgumentParser() - parser.add_argument('-p', - '--path', - nargs='?', - required=True, - help='Parse the path of the json file.') - parser.add_argument('-s', - '--soc', - nargs='?', - required=True, - help='Parse the soc_version of ops.') - return parser.parse_args() - - -def main(): - args = args_prase() - gen_all_config(args.path, args.soc) - - -if __name__ == '__main__': - main() diff --git a/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py b/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py deleted file mode 100644 index 1cac7d91..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/tik2_replay_build.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python -# -*- coding: UTF-8 -*- -""" -Created on Feb 28 20:56:45 2020 -Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
-""" - -import sys -import os -import opdesc_parser -import replay_codegen -import const_var -from replay_codegen import ReplayCodeGenParams - -PYF_PATH = os.path.dirname(os.path.realpath(__file__)) - - -class ReplayBuilder(opdesc_parser.OpDesc): - def __init__(self: any, op_type: str): - super().__init__(op_type) - - def gen_replay_source(self: any, impl_path: str, out_path: str, ops_product: str): - if not self.op_replay_flag: - print('{} replay not enabled'.format(self.op_type)) - return - argn = len(self.input_name) + len(self.output_name) + 1 - if self.op_replay_batch: - print('{} replay in batch mode'.format(self.op_type)) - else: - print('{} replay in normal mode'.format(self.op_type)) - if impl_path.endswith('op_kernel'): - implf = os.path.join(impl_path, self.op_file + '.cpp') - tiling_file = os.path.join(impl_path, "../op_host", self.op_file + '_tiling.h') - else: - if self.dynamic_shape: - dyn_path = 'dynamic' - else: - dyn_path = '' - implf = os.path.join(impl_path, dyn_path, self.op_file + '.cpp') - tiling_file = os.path.join(impl_path, "../../op_tiling", self.op_file + '_tiling.h') - rep_conf = replay_codegen.ReplayCodeGen(ReplayCodeGenParams(self.op_type, implf, tiling_file, self.op_file, \ - self.op_intf, argn, self.op_replay_batch, self.max_block_dim, self.max_shape_size)) - rep_conf.set_batch(self.op_replay_batch) - rep_conf.set_outdir(out_path) - rep_conf.gen_replay(ops_product) - - -def gen_replay(cfgfile: str, cfgs: dict, dirs: dict, ops_product: str, ops: list = None): - batch_lists = cfgs.get(const_var.REPLAY_BATCH).split(';') - iterator_lists = cfgs.get(const_var.REPLAY_ITERATE).split(';') - op_descs = opdesc_parser.get_op_desc(cfgfile, batch_lists, iterator_lists, ReplayBuilder, ops) - for op_desc in op_descs: - op_desc.gen_replay_source(dirs.get(const_var.CFG_IMPL_DIR), dirs.get(const_var.CFG_OUT_DIR), ops_product) - - -if __name__ == '__main__': - if len(sys.argv) <= 6: - raise RuntimeError('arguments must greater than 6') - rep_cfg = {} - rep_cfg[const_var.REPLAY_BATCH] = sys.argv[2] - rep_cfg[const_var.REPLAY_ITERATE] = sys.argv[3] - rep_dir = {} - rep_dir[const_var.CFG_IMPL_DIR] = sys.argv[4] - rep_dir[const_var.CFG_OUT_DIR] = sys.argv[5] - gen_replay(sys.argv[1], rep_cfg, rep_dir, sys.argv[6]) diff --git a/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py b/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py deleted file mode 100644 index 678756cb..00000000 --- a/cust_op/cust_op_by_addr/cmake/util/tiling_data_def_build.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -""" -Function: -The replay funtion entry -Copyright Information: -Huawei Technologies Co., Ltd. 
All Rights Reserved © 2020 -""" - -import sys -import os -import stat -import re -import const_var - - -def gen_tiling(tiling_header_file: str, tiling_file_out: str): - if not os.path.exists(tiling_header_file): - print("warning: no userdef tiling header file: ", tiling_header_file) - return - print("generate tiling def header file: ", tiling_file_out) - tiling_source = '#include \n' - tiling_source += '#include \n\n' - end_source = "" - pattern = re.compile(r'[(](.*)[)]', re.S) - with open(tiling_header_file, 'r') as fd: - lines = fd.readlines() - for line in lines: - line = line.strip() - if (line.startswith('BEGIN_TILING_DATA_DEF')): - tiling_source += '#pragma pack(1)\n' - tiling_source += 'struct ' - struct_def = re.findall(pattern, line)[0] - tiling_source += struct_def + ' {\n' - elif (line.startswith('TILING_DATA_FIELD_DEF_ARR')): - field_params = re.findall(pattern, line)[0] - fds = field_params.split(',') - tiling_source += ' {} {}[{}] = {{}};\n'.format(fds[0].strip(), fds[2].strip(), fds[1].strip()) - elif (line.startswith('TILING_DATA_FIELD_DEF_STRUCT')): - field_params = re.findall(pattern, line)[0] - fds = field_params.split(',') - tiling_source += ' {} {};\n'.format(fds[0].strip(), fds[1].strip()) - elif (line.startswith('TILING_DATA_FIELD_DEF')): - field_params = re.findall(pattern, line)[0] - fds = field_params.split(',') - tiling_source += ' {} {} = 0;\n'.format(fds[0].strip(), fds[1].strip()) - elif (line.startswith('END_TILING_DATA_DEF')): - tiling_source += '};\n' - tiling_source += '#pragma pack()\n\n' - tiling_source += '#ifdef __NPU_TILING__\n' - tiling_source += \ - 'inline [aicore] void InitTilingData(const __gm__ uint8_t* tiling, {}* const_data)\n'\ - .format(struct_def) - tiling_source += '{\n' - tiling_source += '}\n' - tiling_source += '#else\n' - tiling_source += 'inline void InitTilingData(uint8_t* tiling, {}* const_data)\n'.format(struct_def) - tiling_source += '{\n' - tiling_source += ' memcpy(const_data, tiling, sizeof({}));\n'.format(struct_def) - tiling_source += '}\n' - tiling_source += '#endif\n\n' - end_source = ''' -#define GET_TILING_DATA(tiling_data, tiling_arg) \\ -{} tiling_data; \\ -InitTilingData(tiling_arg, &tiling_data)\n -'''.format(struct_def) - tiling_source += end_source - with os.fdopen(os.open(tiling_file_out, const_var.WFLAGS, const_var.WMODES), 'w') as ofd: - ofd.write(tiling_source) - - -if __name__ == '__main__': - if len(sys.argv) <= 2: - raise RuntimeError('arguments must greater than 2') - gen_tiling(sys.argv[1], sys.argv[2]) diff --git a/cust_op/cust_op_by_addr/emb_custom.json b/cust_op/cust_op_by_addr/emb_custom.json new file mode 100644 index 00000000..29b375b3 --- /dev/null +++ b/cust_op/cust_op_by_addr/emb_custom.json @@ -0,0 +1,90 @@ +[ + { + "op": "EmbeddingLookupByAddress", + "language": "cpp", + "input_desc": [ + { + "name": "address", + "param_type": "required", + "format": [ + "ND","ND","ND" + ], + "type": [ + "int64","int64","int64" + ] + } + ], + "output_desc": [ + { + "name": "y", + "param_type": "required", + "format": [ + "ND","ND","ND" + ], + "type": [ + "fp32","fp16","int32" + ] + } + ], + "attr": [ + { + "name": "embedding_dim", + "param_type": "optional", + "type": "int" , + "default_value": 32 + }, + { + "name": "embedding_type", + "param_type": "optional", + "type": "int" , + "default_value": 1 + } + ] + }, + { + "op": "EmbeddingUpdateByAddress", + "language": "cpp", + "input_desc": [ + { + "name": "address", + "param_type": "required", + "format": [ + "ND","ND","ND" + ], + "type": [ + 
"int64","int64","int64" + ] + }, + { + "name": "embedding", + "param_type": "required", + "format": [ + "ND","ND","ND" + ], + "type": [ + "fp32","fp16","int32" + ] + } + ], + "output_desc": [ + { + "name": "y", + "param_type": "required", + "format": [ + "ND","ND","ND" + ], + "type": [ + "fp32","fp16","int32" + ] + } + ], + "attr": [ + { + "name": "update_type", + "param_type": "optional", + "type": "int" , + "default_value": 0 + } + ] + } +] \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/framework/CMakeLists.txt b/cust_op/cust_op_by_addr/framework/CMakeLists.txt deleted file mode 100644 index b6be9b49..00000000 --- a/cust_op/cust_op_by_addr/framework/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/mindspore") - if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/caffe_plugin") - add_subdirectory(caffe_plugin) - endif() - if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/tf_plugin") - add_subdirectory(tf_plugin) - endif() - if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/onnx_plugin") - add_subdirectory(onnx_plugin) - endif() -endif() diff --git a/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt b/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt deleted file mode 100644 index 18b3f140..00000000 --- a/cust_op/cust_op_by_addr/framework/tf_plugin/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ - -aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} plugin_srcs) -add_library(cust_tf_parsers SHARED ${plugin_srcs}) -target_compile_definitions(cust_tf_parsers PRIVATE google=ascend_private) -target_link_libraries(cust_tf_parsers PRIVATE intf_pub graph) -install(TARGETS cust_tf_parsers - LIBRARY DESTINATION packages/vendors/${vendor_name}/framework/tensorflow -) diff --git a/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc b/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc deleted file mode 100644 index b8a49483..00000000 --- a/cust_op/cust_op_by_addr/framework/tf_plugin/tensorflow_embedding_lookup_by_address_plugin.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "register/register.h" - -namespace domi { -// register op info to GE -REGISTER_CUSTOM_OP("EmbeddingLookupByAddress") -.FrameworkType(TENSORFLOW) // type: CAFFE, TENSORFLOW -.OriginOpType("EmbeddingLookupByAddress") // name in tf module -.ParseParamsByOperatorFn(AutoMappingByOpFn); -REGISTER_CUSTOM_OP("EmbeddingUpdateByAddress") -.FrameworkType(TENSORFLOW) // type: CAFFE, TENSORFLOW -.OriginOpType("EmbeddingUpdateByAddress") // name in tf module -.ParseParamsByOperatorFn(AutoMappingByOpFn); -} // namespace domi \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_host/CMakeLists.txt b/cust_op/cust_op_by_addr/op_host/CMakeLists.txt deleted file mode 100644 index 005b7d01..00000000 --- a/cust_op/cust_op_by_addr/op_host/CMakeLists.txt +++ /dev/null @@ -1,35 +0,0 @@ - -aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} ops_srcs) - -opbuild(OPS_SRC ${ops_srcs} - OUT_DIR ${ASCEND_AUTOGEN_PATH} -) - -add_library(cust_op_proto SHARED ${ops_srcs} ${ASCEND_AUTOGEN_PATH}/op_proto.cc) -target_compile_definitions(cust_op_proto PRIVATE OP_PROTO_LIB) -target_link_libraries(cust_op_proto PRIVATE intf_pub exe_graph register) -set_target_properties(cust_op_proto PROPERTIES OUTPUT_NAME - cust_opsproto_rt2.0 -) - -add_library(cust_optiling SHARED ${ops_srcs}) -target_compile_definitions(cust_optiling PRIVATE OP_TILING_LIB) -target_link_libraries(cust_optiling PRIVATE intf_pub exe_graph register) 
-set_target_properties(cust_optiling PROPERTIES OUTPUT_NAME
-                      cust_opmaster_rt2.0
-)
-
-add_custom_target(optiling_compat ALL
-                  COMMAND ln -sf lib/linux/${CMAKE_SYSTEM_PROCESSOR}/$<TARGET_FILE_NAME:cust_optiling>
-                  ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so
-)
-
-install(TARGETS cust_op_proto
-        LIBRARY DESTINATION packages/vendors/${vendor_name}/op_proto/lib/linux/${CMAKE_SYSTEM_PROCESSOR})
-install(FILES ${ASCEND_AUTOGEN_PATH}/op_proto.h
-        DESTINATION packages/vendors/${vendor_name}/op_proto/inc)
-install(TARGETS cust_optiling
-        LIBRARY DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling/lib/linux/${CMAKE_SYSTEM_PROCESSOR})
-install(FILES ${CMAKE_CURRENT_BINARY_DIR}/liboptiling.so
-        DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_tiling)
-
diff --git a/cust_op/cust_op_by_addr/op_host/readme.md b/cust_op/cust_op_by_addr/op_host/readme.md
deleted file mode 100644
index 2f9b4fd9..00000000
--- a/cust_op/cust_op_by_addr/op_host/readme.md
+++ /dev/null
@@ -1,218 +0,0 @@
-# tik2 op_host
-
-
-### Directory structure
-
-```
-    op_host                                     // prototype & tiling definitions
-        embedding_update_by_address_tiling.h
-        embedding_update_by_address.cpp
-        CMakeLists.txt
-```
-
-
-### Writing the tiling.h file
-
-embedding_update_by_address_tiling.h
-
-A tiling struct is defined here.
-
-These variables typically carry shape and attribute information into the kernel.
-```
-#include "register/tilingdata_base.h"
-
-namespace optiling
-{
-    BEGIN_TILING_DATA_DEF(TilingData2)
-    TILING_DATA_FIELD_DEF(int32_t, update_dim);
-    TILING_DATA_FIELD_DEF(int32_t, addr_nums);
-    TILING_DATA_FIELD_DEF(int32_t, ub_limit);
-    TILING_DATA_FIELD_DEF(int32_t, embbeding_type);
-    TILING_DATA_FIELD_DEF(int32_t, update_type);
-    END_TILING_DATA_DEF;
-
-    REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2)
-}
-```
-
-### Writing the cpp file
-
-##### Structure
-
-```
-static ge::graphStatus TilingFunc(gert::TilingContext *context)          // assigns the tiling variables
-
-ge::graphStatus InferShape(gert::InferShapeContext *context)             // shape inference
-
-ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context)   // shape range inference
-
-ge::graphStatus InferDataType(gert::InferDataTypeContext *context)       // dtype inference
-
-class EmbeddingUpdateByAddress : public OpDef   // operator info definition
-```
-
-##### Common usage
-
-```
-// get shapes
-gert::Shape *x_shape = context->GetInputShape(0);
-gert::Shape *y_shape = context->GetInputShape(1);
-gert::Shape *z_shape = context->GetOutputShape(0);
-
-std::vector<int64_t> x_dims = x_shape->GetDims();   // dims of the shape, e.g. x_dims = {232,123,2}
-size_t x_dims_len = x_shape->GetDimNum();           // number of dims, e.g. x_dims_len = 3
-int64_t x_dim2 = x_shape->GetDim(1);                // second dim of the shape, e.g. x_dim2 = 123
-int64_t x_shapesize = x_shape->GetShapeSize();      // total size of the shape, e.g. x_shapesize = 232*123*2
-
-// set shapes
-*z_shape=Shape(x_dims);   // z_shape becomes {232,123,2}
-*z_shape=*x_shape;        // z_shape is identical to x_shape
-
-
-// get and set data types
-DataType input2_dtype = context->GetInputDataType(1);   // get the DataType of the second input
-context->SetOutputDataType(0, input2_dtype);
-context->SetOutputDataType(1, ge::DataType(DT_FLOAT));
-
-
-// get attrs
-auto *attrs = context->GetAttrs();
-const auto attr0_value = *(attrs->GetAttrPointer(0));   // attr0_value = value of the op's 1st attribute
-const auto attr1_value = *(attrs->GetAttrPointer(1));   // attr1_value = value of the op's 2nd attribute
-```
-
-
-
-
-##### Writing TilingFunc
-
-```
-TilingData1 tiling;
-
-int32_t input_shape = context->GetInputTensor(0)->GetShapeSize();   // get the input size
-tiling.set_addr_nums(input_shape);   // set the tiling field addr_nums, i.e. addr_nums=input_shape
-tiling.set_xxx(input_shape);         // xxx must be a variable declared in the tiling header file
-```
-
-#### Writing the shape/type inference functions
-
-```
-namespace ge
-{
-    ge::graphStatus InferShape(gert::InferShapeContext *context)
-    {
-        gert::Shape *y_shape = context->GetOutputShape(0);
-        int64_t input_shape = context->GetInputTensor(0)->GetShapeSize();
-        int64_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape;
-        y_shape->SetDimNum(2);
-        y_shape->SetDim(0, input_shape);
-        y_shape->SetDim(1, input_dim);
-        return GRAPH_SUCCESS;
-    }
-    ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context)
-    {
-        return GRAPH_SUCCESS;
-    }
-    ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
-    {
-        context->SetOutputDataType(0, ge::DataType(DT_FLOAT));
-        return GRAPH_SUCCESS;
-    }
-}
-```
-
-#### Writing the op info registration
-
-```
-namespace ops
-{
-    class EmbeddingUpdateByAddress : public OpDef
-    {
-    public:
-        EmbeddingUpdateByAddress(const char *name) : OpDef(name)
-        {
-            this->Input("address")   // the op's 1st input
-                .ParamType(REQUIRED)
-                .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
-                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
-            this->Input("embedding")   // the op's 2nd input
-                .ParamType(REQUIRED)
-                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
-                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
-            this->Output("y")   // the op's 1st output
-                .ParamType(REQUIRED)
-                .DataType({ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_INT32})
-                .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
-                .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
-            this->Attr("update_type").AttrType(OPTIONAL).Int(0);   // the op's 1st attribute, default value 0
-
-            this->SetInferShape(ge::InferShape)   // register the inference functions
-                .SetInferDataType(ge::InferDataType);
-            //.SetInferShapeRange(ge::InferShapeRange);   // not used here
-
-            this->AICore()
-                .SetTiling(optiling::TilingFunc)
-                .SetTilingParse(optiling::TilingPrepare)
-                .SetCheckSupport(optiling::check_op_support);
-
-            OpAICoreConfig aicConfig;
-            aicConfig.AsyncFlag(true)
-                .DynamicCompileStaticFlag(true)
-                .DynamicFormatFlag(true)
-                .DynamicRankSupportFlag(true)
-                .DynamicShapeSupportFlag(true)
-                .NeedCheckSupportFlag(false)
-                .PrecisionReduceFlag(false)   // enabling this may lose precision (fp32->fp16)
-                .RangeLimitValue("limited");
-            this->AICore().AddConfig("ascend910b", aicConfig);   // support 910b
-            this->AICore().AddConfig("ascend910", aicConfig);    // support 910
-        }
-    };
-    OP_ADD(EmbeddingUpdateByAddress, optiling::TilingCompileInfo);
-}
-
-```
-
-
-
-### Tips
-
-#### Data type reference
-```
-    enum DataType {
-        DT_FLOAT = 0,            // float type
-        DT_FLOAT16 = 1,          // fp16 type
-        DT_INT8 = 2,             // int8 type
-        DT_INT16 = 6,            // int16 type
-        DT_UINT16 = 7,           // uint16 type
-        DT_UINT8 = 4,            // uint8 type
-        DT_INT32 = 3,            // int32 type
-        DT_INT64 = 9,            // int64 type
-        DT_UINT32 = 8,           // unsigned int32
-        DT_UINT64 = 10,          // unsigned int64
-        DT_BOOL = 12,            // bool type
-        DT_DOUBLE = 11,          // double type
-        DT_STRING = 13,          // string type
-        DT_DUAL_SUB_INT8 = 14,   /**< dual output int8 type */
-        DT_DUAL_SUB_UINT8 = 15,  /**< dual output uint8 type */
-        DT_COMPLEX64 = 16,       // complex64 type
-        DT_COMPLEX128 = 17,      // complex128 type
-        DT_QINT8 = 18,           // qint8 type
-        DT_QINT16 = 19,          // qint16 type
-        DT_QINT32 = 20,          // qint32 type
-        DT_QUINT8 = 21,          // quint8 type
-        DT_QUINT16 = 22,         // quint16 type
-        DT_RESOURCE = 23,        // resource type
-        DT_STRING_REF = 24,      // string ref type
-        DT_DUAL = 25,            // dual output type
-        DT_VARIANT = 26,         // dt_variant type
-        DT_BF16 = 27,            // bf16 type
-        DT_UNDEFINED = 28,       // Used to indicate a DataType field has not been set.
-        DT_INT4 = 29,            // int4 type
-        DT_MAX                   // Mark the boundaries of data types
-    };
-
-```
-
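Note how the `type` lists in the IR JSON (`emb_custom.json` above uses `"fp32"`, `"fp16"`, `"int32"`, `"int64"`) line up position-for-position with the `ge::DT_*` lists in the OpDef registration. As a quick cross-check when editing an IR JSON, here is a minimal Python sketch — not part of this patch; the string-to-enum table is transcribed from the reference enum above and only covers the types these ops actually use:

```python
import json

# IR JSON type strings -> ge::DataType values, transcribed from the enum above.
GE_DTYPE = {"fp32": 0, "fp16": 1, "int8": 2, "int32": 3, "int64": 9}


def check_ir_dtypes(ir_json_path):
    """Flag dtype strings that the ge::DataType table does not cover."""
    with open(ir_json_path, encoding="utf-8") as fd:
        ops = json.load(fd)
    for op in ops:
        for desc in op.get("input_desc", []) + op.get("output_desc", []):
            for type_name in desc.get("type", []):
                if type_name not in GE_DTYPE:
                    print(f'{op["op"]}/{desc["name"]}: unsupported type "{type_name}"')


check_ir_dtypes("emb_custom.json")
```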
diff --git a/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt b/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt
deleted file mode 100644
index 6fe828d5..00000000
--- a/cust_op/cust_op_by_addr/op_kernel/CMakeLists.txt
+++ /dev/null
@@ -1,80 +0,0 @@
-
-set(tikreplaylib_DIR ${ASCEND_TENSOR_COMPILER_PATH}/tikcpp/tikreplaylib/lib/cmake)
-find_package(tikreplaylib REQUIRED)
-message(STATUS "PACKAGE tikreplaylib FOUND")
-
-# replay config
-set(BATCH_MODE_REPLAY_LIST )
-set(ITERATOR_MODE_REPLAY_LIST )
-
-foreach(compute_unit ${ASCEND_COMPUTE_UNIT})
-
-    # generate aic-${compute_unit}-ops-info.json
-    add_ops_info_target(TARGET ops_info_gen_${compute_unit}
-        OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core/${compute_unit}/aic-${compute_unit}-ops-info.json
-        OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
-        INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/config/${compute_unit}
-    )
-
-    # generate tik2 impl py once
-    if (NOT TARGET tik2_impl_gen)
-        add_ops_impl_target(TARGET tik2_impl_gen
-            OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
-            OPS_BATCH ${BATCH_MODE_REPLAY_LIST}
-            OPS_ITERATE ${ITERATOR_MODE_REPLAY_LIST}
-            IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
-            OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe
-            INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl
-        )
-    endif()
-
-    # dynamic shape binary compile
-    if (${ENABLE_BINARY_PACKAGE})
-        add_bin_compile_target(TARGET tik2_bin_${compute_unit}
-            OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
-            IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
-            ADP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/dynamic
-            OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/binary/${compute_unit}
-            INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/kernel
-            COMPUTE_UNIT ${compute_unit}
-        )
-        add_dependencies(tik2_bin_${compute_unit} tik2_impl_gen)
-    endif()
-
-    # generate replay _impl.cpp, _replay.cpp and _entry.cce
-    add_ops_replay_targets(OPS_BATCH ${BATCH_MODE_REPLAY_LIST}
-        OPS_ITERATE ${ITERATOR_MODE_REPLAY_LIST}
-        OPS_INFO ${ASCEND_AUTOGEN_PATH}/aic-${compute_unit}-ops-info.ini
-        IMPL_DIR ${CMAKE_CURRENT_SOURCE_DIR}
-        OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/replay/${compute_unit}
-        INSTALL_DIR packages/vendors/${vendor_name}/op_impl/ai_core/tbe/op_replay
-        COMPUTE_UNIT ${compute_unit}
-    )
-endforeach()
-
-# generate npu_supported_ops.json
-add_npu_support_target(TARGET npu_supported_ops
-    OPS_INFO_DIR ${ASCEND_AUTOGEN_PATH}
-    OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/tbe/op_info_cfg/ai_core
-    INSTALL_DIR packages/vendors/${vendor_name}/framework/${ASCEND_FRAMEWORK_TYPE}
-)
-
-if(ENABLE_TEST AND EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/testcases)
-    add_subdirectory(testcases)
-endif()
-
-# install kernel file
-if (${ENABLE_SOURCE_PACKAGE})
-    file(GLOB KERNEL_FILES
-        ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
-    )
-    install(FILES ${KERNEL_FILES}
-        DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl/dynamic
-    )
-    file(GLOB KERNEL_FILES
-        ${CMAKE_CURRENT_SOURCE_DIR}/*.py
-    )
-    install(FILES ${KERNEL_FILES}
-        DESTINATION packages/vendors/${vendor_name}/op_impl/ai_core/tbe/${vendor_name}_impl/dynamic
-    )
-endif()
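The kernel reads its tiling through `GET_TILING_DATA`, and the generated `InitTilingData` (see `tiling_data_def_build.py` earlier in this patch) is a plain `memcpy` into a `#pragma pack(1)` struct. For the `TilingData2` struct shown in the op_host readme above, the packed buffer is therefore five contiguous int32 fields. A hedged Python sketch of that layout follows — little-endian byte order and the sample values are assumptions for illustration, not taken from the repository:

```python
import struct

# TilingData2 fields in declaration order (see the tiling header above).
# With #pragma pack(1) they occupy five contiguous int32 slots.
FIELDS = ("update_dim", "addr_nums", "ub_limit", "embbeding_type", "update_type")


def pack_tiling(**values):
    """Pack tiling values the way InitTilingData's memcpy expects them."""
    return struct.pack("<5i", *(values[name] for name in FIELDS))


buf = pack_tiling(update_dim=32, addr_nums=1024, ub_limit=65536,
                  embbeding_type=1, update_type=0)
assert len(buf) == 5 * 4  # five int32 fields, no padding
```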
diff --git a/cust_op/cust_op_by_addr/readme.md b/cust_op/cust_op_by_addr/readme.md
deleted file mode 100644
index a1fb963b..00000000
--- a/cust_op/cust_op_by_addr/readme.md
+++ /dev/null
@@ -1,197 +0,0 @@
-# tik2_custom_op
-
-### Introduction
-tik2 custom operator project
-
-
-### Directory structure
-
-```
-    CMakeLists.txt
-    CMakePresets.json
-    README.md
-    build.sh
-    framework                  // framework adaption plugins
-    op_host                    // prototype & tiling definitions
-        xx.h
-        xx.cc
-        CMakeLists.txt
-    op_kernel                  // kernel implementations
-        xx.cc
-        CMakeLists.txt
-    cmake
-        config.cmake
-        util
-            makeself
-            parse_ini_to_json.py
-            gen_ops_filter.sh
-```
-
-
-
-#### Naming conventions
-- Operator names use upper camel case, e.g. SampleTik2
-- File prefixes under the host and kernel directories are the snake-case form of the op name, e.g. SampleTik2 -> sample_tik2
-- The tiling struct header carries the _tiling.h suffix, e.g. sample_tik2_tiling.h
-- The operator kernel entry function uses the snake-case name, e.g. sample_tik2
-
-
-### Creating/extending an operator project
-
-#### Creating/extending with the msopgen tool
-
-1. First define the operator's IR json file; for reference:
-
-```
-[
-{
-    "op":"AddDSL",
-    "input_desc":[
-        {
-            "name":"x1",
-            "param_type":"required",
-            "format":[
-                "NCHW"
-            ],
-            "type":[
-                "fp16"
-            ]
-        },
-        {
-            "name":"x2",
-            "param_type":"required",
-            "format":[
-                "NCHW"
-            ],
-            "type":[
-                "fp16"
-            ]
-        }
-    ],
-    "output_desc":[
-        {
-            "name":"y",
-            "param_type":"required",
-            "format":[
-                "NCHW"
-            ],
-            "type":[
-                "fp16"
-            ]
-        }
-    ],
-    "attr":[
-        {
-            "name":"n",
-            "param_type":"required",
-            "type":"int"
-        }
-    ]
-}
-]
-```
-
-2. Use msopgen to turn the json file into an operator project:
-```
-# create a new operator project
-msopgen gen -i xxx.json -f tf -c ai_core-ascend910 -lan cpp -m 0 -out new_tik2_custom
-# append an operator to an existing project
-msopgen gen -i xxx.json -f tf -c ai_core-ascend910 -lan cpp -m 1 -out ./
-```
-Appending currently has a small bug: the cmake/config.cmake file has to be fixed up by hand.
-
-```
-set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) # make sure the chip list covers the op's target devices
-```
-3. Fill in the operator code:
-
-```
-    op_host
-        xx_tiling.h   // tiling struct definition
-        xx.cc         // prototype / tiling / infershape / op info definitions
-    op_kernel
-        xx.cc         // kernel implementation
-```
-
-
-#### Creating by hand
-
-1. Copy the three finished operator source files into the corresponding project directories:
-```
-    op_host
-        new_op_tiling.h   // tiling struct definition
-        new_op.cc         // prototype / tiling / infershape / op info definitions
-    op_kernel
-        new_op.cc         // kernel implementation
-```
-To delete an operator, simply remove its files from the op_host and op_kernel directories.
-
-2. Edit the cmake/config.cmake file:
-```
-set(ASCEND_COMPUTE_UNIT ascend910 ascend910b) # make sure the chip list covers the op's target devices
-```
-3. Rebuild the project (see below).
-
-
-
-
-
-### Building and installing the operator package
-
-1. Edit the CMakePresets.json file:
-
-Set ASCEND_CANN_PACKAGE_PATH here to the actual CANN installation directory:
-
-```
-    "ASCEND_CANN_PACKAGE_PATH": {
-        "type": "PATH",
-        "value": "/usr/local/Ascend/ascend-toolkit/latest"
-    },
-```
-
-2. Set up the environment:
-
-
-```
-# when the toolkit package is installed
-source ${HOME}/Ascend/ascend-toolkit/set_env.sh
-```
-
-3. Build the custom operator package:
-
-```
-./build.sh
-```
-
-4. Install the operator package:
-
-Running the generated package installs the custom operators under the ${ASCEND_OPP_PATH}/vendors directory:
-```
-./build_out/custom_opp__.run
-```
-
-
-
-### TIK2 operator development tutorials & references
-https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/63RC1alpha002/operatordevelopment/tik2opdevg/atlastik2_10_0001.html
-
-https://codehub-y.huawei.com/TIKC-Project/tikcpp_smoke/files?ref=master&filePath=external_tik2_demo
-
-
-
-### Contributing
-
-1. Fork this repository
-2. Create a Feat_xxx branch
-3. Commit your code
-4. Create a Pull Request
-
-
-### Gitee features
-
-1. Use Readme_XXX.md to support different languages, e.g. Readme_en.md, Readme_zh.md
-2. Gitee official blog: https://blog.gitee.com
-3. Explore open-source projects on Gitee: https://gitee.com/explore
-4. GVP, the most valuable open-source projects on Gitee: https://gitee.com/gvp
-5. Gitee user manual: https://gitee.com/help
-6. Gitee Stars, showcasing outstanding Gitee members: https://gitee.com/gitee-stars
diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh
new file mode 100644
index 00000000..2ed188d8
--- /dev/null
+++ b/cust_op/cust_op_by_addr/run.sh
@@ -0,0 +1,44 @@
+
+#source /usr/local/Ascend/ascend-toolkit/set_env.sh
+source /etc/profile
+msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin)
+# take the parent directory of the msopgen binary
+parent_dir=$(dirname "$msopgen_path")
+
+export PATH=$parent_dir:$PATH
+
+rm -rf ./custom_op
+msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 0 -op EmbeddingLookupByAddress
+msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 1 -op EmbeddingUpdateByAddress
+
+cp -rf op_kernel custom_op/
+cp -rf op_host custom_op/
+
+cd custom_op
+
+# check that CMakePresets.json exists in the current directory
+if [ ! -f "CMakePresets.json" ]; then
+    echo "CMakePresets.json does not exist in the current directory"
+    exit 1
+fi
+
+#jq '.configurePresets.cacheVariables.ASCEND_CANN_PACKAGE_PATH.value = "/usr/local/Ascend/ascend-toolkit/latest"' CMakePresets.json > tmp.json
+#mv tmp.json CMakePresets.json
+
+sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json
+
+cd cmake
+
+# check that config.cmake exists in the current directory
+if [ ! -f "config.cmake" ]; then
+    echo "config.cmake does not exist in the current directory"
+    exit 1
+fi
+
+
+sed -i 's:set(ASCEND_COMPUTE_UNIT ascend910b):set(ASCEND_COMPUTE_UNIT ascend910b ascend910):g' config.cmake
+
+cd ..
+
+bash build.sh
+
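run.sh patches CMakePresets.json with `sed`; the commented-out `jq` line shows the structured edit it replaces. Below is a hedged Python equivalent — a sketch, not part of this patch. It assumes the msopgen-generated preset layout, i.e. `configurePresets` entries whose `cacheVariables` contain the `ASCEND_CANN_PACKAGE_PATH` object shown in the readme above; adjust the traversal if the file nests differently:

```python
import json

PRESETS = "CMakePresets.json"
CANN_PATH = "/usr/local/Ascend/ascend-toolkit/latest"

with open(PRESETS, encoding="utf-8") as fd:
    presets = json.load(fd)

# Point every configure preset that defines ASCEND_CANN_PACKAGE_PATH
# at the installed CANN toolkit.
configure_presets = presets.get("configurePresets", [])
if isinstance(configure_presets, dict):  # some layouts use an object, not a list
    configure_presets = [configure_presets]
for preset in configure_presets:
    cache_vars = preset.get("cacheVariables", {})
    if "ASCEND_CANN_PACKAGE_PATH" in cache_vars:
        cache_vars["ASCEND_CANN_PACKAGE_PATH"]["value"] = CANN_PATH

with open(PRESETS, "w", encoding="utf-8") as fd:
    json.dump(presets, fd, indent=4)
```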
-f "CMakePresets.json" ]; then + echo "当前目录下不存在cmake.json文件" + exit 1 +fi + +#jq '.configurePresets.cacheVariables.ASCEND_CANN_PACKAGE_PATH.value = "/usr/local/Ascend/ascend-toolkit/latest"' CMakePresets.json > tmp.json +#mv tmp.json cmake.json + +sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json + +cd cmake + +# 判断当前目录下是否存在config.cmake文件 +if [ ! -f "config.cmake" ]; then + echo "当前目录下不存在cmake.json文件" + exit 1 +fi + + +sed -i 's:set(ASCEND_COMPUTE_UNIT ascend910b):set(ASCEND_COMPUTE_UNIT ascend910b ascend910):g' config.cmake + +cd .. + +bash build.sh + diff --git a/cust_op/cust_op_by_addr/scripts/install.sh b/cust_op/cust_op_by_addr/scripts/install.sh deleted file mode 100644 index e3988cde..00000000 --- a/cust_op/cust_op_by_addr/scripts/install.sh +++ /dev/null @@ -1,228 +0,0 @@ -#!/bin/bash -vendor_name=customize -targetdir=/usr/local/Ascend/opp -target_custom=0 - -sourcedir=$PWD/packages -vendordir=vendors/$vendor_name - -QUIET="y" - -for i in "$@" -do - echo $i - if test $i = "--quiet"; then - QUIET="y" - break - fi -done - -log() { - cur_date=`date +"%Y-%m-%d %H:%M:%S"` - echo "[runtime] [$cur_date] "$1 -} - -if [ -n "${ASCEND_CUSTOM_OPP_PATH}" ]; then - if [ ! -d ${ASCEND_CUSTOM_OPP_PATH} ]; then - mkdir -p ${ASCEND_CUSTOM_OPP_PATH} >> /dev/null 2>&1 - if [ $? -ne 0 ]; then - log "[ERROR] create ${ASCEND_CUSTOM_OPP_PATH} failed" - fi - fi - targetdir=${ASCEND_CUSTOM_OPP_PATH} -else - if [ "x${ASCEND_OPP_PATH}" == "x" ]; then - log "[ERROR] env ASCEND_OPP_PATH no exist" - exit 1 - fi - targetdir="${ASCEND_OPP_PATH}" -fi - -if [ ! -d $targetdir ];then - log "[ERROR] $targetdir no exist" - exit 1 -fi - -upgrade() -{ - if [ ! -d ${sourcedir}/$vendordir/$1 ]; then - log "[INFO] no need to upgrade ops $1 files" - return 0 - fi - - if [ ! -d ${targetdir}/$vendordir/$1 ];then - log "[INFO] create ${targetdir}/$vendordir/$1." - mkdir -p ${targetdir}/$vendordir/$1 - if [ $? -ne 0 ];then - log "[ERROR] create ${targetdir}/$vendordir/$1 failed" - return 1 - fi - else - has_same_file=-1 - for file_a in ${sourcedir}/$vendordir/$1/*; do - file_b=${file_a##*/}; - if [ "ls ${targetdir}/$vendordir/$1" = "" ]; then - log "[INFO] ${targetdir}/$vendordir/$1 is empty !!" - return 1 - fi - grep -q $file_b <<<`ls ${targetdir}/$vendordir/$1`; - if [[ $? -eq 0 ]]; then - echo -n "${file_b} " - has_same_file=0 - fi - done - if [ 0 -eq $has_same_file ]; then - if test $QUIET = "n"; then - echo "[INFO]: has old version in ${targetdir}/$vendordir/$1, \ - you want to Overlay Installation , please enter:[o]; \ - or replace directory installation , please enter: [r]; \ - or not install , please enter:[n]." - - while true - do - read orn - if [ "$orn" = n ]; then - return 0 - elif [ "$orn" = m ]; then - break; - elif [ "$0rn" = r ]; then - [ -n "${targetdir}/$vendordir/$1/" ] && rm -rf "${targetdir}/$vendordir/$1"/* - break; - else - echo "[ERROR] input error, please input again!" - fi - done - fi - fi - log "[INFO] replace or merge old ops $1 files .g....." - fi - - log "copy new ops $1 files ......" - if [ -d ${targetdir}/$vendordir/$1/ ]; then - chmod -R +w "$targetdir/$vendordir/$1/" >/dev/null 2>&1 - fi - cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/ - if [ $? -ne 0 ];then - log "[ERROR] copy new $1 files failed" - return 1 - fi - - return 0 -} -upgrade_proto() -{ - if [ ! -f ${sourcedir}/$vendordir/custom.proto ]; then - log "[INFO] no need to upgrade custom.proto files" - return 0 - fi - if [ ! 
-    if [ ! -d ${targetdir}/$vendordir/framework/caffe ];then
-        log "[INFO] create ${targetdir}/$vendordir/framework/caffe."
-        mkdir -p ${targetdir}/$vendordir/framework/caffe
-        if [ $? -ne 0 ];then
-            log "[ERROR] create ${targetdir}/$vendordir/framework/caffe failed"
-            return 1
-        fi
-    else
-        if [ -f ${targetdir}/$vendordir/framework/caffe/custom.proto ]; then
-            # An old version exists; decide whether to install over it
-            if test $QUIET = "n"; then
-                echo "[INFO] ${targetdir}/$vendordir/framework/caffe has old version"\
-                "custom.proto file. Do you want to replace? [y/n] "
-
-                while true
-                do
-                    read yn
-                    if [ "$yn" = n ]; then
-                        return 0
-                    elif [ "$yn" = y ]; then
-                        break;
-                    else
-                        echo "[ERROR] input error, please input again!"
-                    fi
-                done
-            fi
-        fi
-        log "[INFO] replace old caffe.proto files ......"
-    fi
-    chmod -R +w "$targetdir/$vendordir/framework/caffe/" >/dev/null 2>&1
-    cp -rf ${sourcedir}/$vendordir/custom.proto ${targetdir}/$vendordir/framework/caffe/
-    if [ $? -ne 0 ];then
-        log "[ERROR] copy new custom.proto failed"
-        return 1
-    fi
-    log "[INFO] copy custom.proto success"
-
-    return 0
-}
-
-log "[INFO] copy uninstall sh success"
-
-if [ ! -d ${targetdir}/vendors ];then
-    log "[INFO] create ${targetdir}/vendors."
-    mkdir -p ${targetdir}/vendors
-    if [ $? -ne 0 ];then
-        log "[ERROR] create ${targetdir}/vendors failed"
-        return 1
-    fi
-fi
-chmod u+w ${targetdir}/vendors
-
-echo "[ops_custom]upgrade framework"
-upgrade framework
-if [ $? -ne 0 ];then
-    exit 1
-fi
-
-echo "[ops_custom]upgrade op proto"
-upgrade op_proto
-if [ $? -ne 0 ];then
-    exit 1
-fi
-
-echo "[ops_custom]upgrade op impl"
-upgrade op_impl
-if [ $? -ne 0 ];then
-    exit 1
-fi
-
-upgrade_proto
-if [ $? -ne 0 ];then
-    exit 1
-fi
-
-config_file=${targetdir}/vendors/config.ini
-if [ ! -f ${config_file} ]; then
-    touch ${config_file}
-    chmod 640 ${config_file}
-    echo "load_priority=$vendor_name" > ${config_file}
-    if [ $? -ne 0 ];then
-        echo "echo load_priority failed"
-        exit 1
-    fi
-else
-    found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)"
-    found_vendor=$(echo $found_vendors | sed "s/$vendor_name//g" | tr ',' ' ')
-    vendor=$(echo $found_vendor | tr -s ' ' ',')
-    if [ "$vendor" != "" ]; then
-        sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file"
-    fi
-fi
-
-chmod u-w ${targetdir}/vendors
-
-if [ -d ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/ ]; then
-    chmod -R 440 ${targetdir}/$vendordir/op_impl/cpu/aicpu_kernel/impl/* >/dev/null 2>&1
-fi
-if [ -f ${targetdir}/ascend_install.info ]; then
-    chmod -R 440 ${targetdir}/ascend_install.info
-fi
-if [ -f ${targetdir}/scene.info ]; then
-    chmod -R 440 ${targetdir}/scene.info
-fi
-if [ -f ${targetdir}/version.info ]; then
-    chmod -R 440 ${targetdir}/version.info
-fi
-
-echo "SUCCESS"
-exit 0
-
diff --git a/cust_op/cust_op_by_addr/scripts/upgrade.sh b/cust_op/cust_op_by_addr/scripts/upgrade.sh
deleted file mode 100644
index 2fa595c9..00000000
--- a/cust_op/cust_op_by_addr/scripts/upgrade.sh
+++ /dev/null
@@ -1,121 +0,0 @@
-#!/bin/bash
-vendor_name=customize
-targetdir=/usr/local/Ascend/opp
-target_custom=0
-
-sourcedir=$PWD/packages
-vendordir=vendors/$vendor_name
-
-log() {
-    cur_date=`date +"%Y-%m-%d %H:%M:%S"`
-    echo "[runtime] [$cur_date] "$1
-}
-
-if [[ "x${ASCEND_OPP_PATH}" == "x" ]];then
-    log "[ERROR] env ASCEND_OPP_PATH no exist"
-    exit 1
-fi
-
-targetdir=${ASCEND_OPP_PATH}
-
-if [ ! -d $targetdir ];then
-    log "[ERROR] $targetdir no exist"
-    exit 1
-fi
-
-upgrade()
-{
-    if [ !
-d ${sourcedir}/$vendordir/$1 ]; then - log "[INFO] no need to upgrade ops $1 files" - return 0 - fi - - iif [ ! -d ${targetdir}/$vendordir/$1 ];then - log "[INFO] create ${targetdir}/$vendordir/$1." - mkdir -p ${targetdir}/$vendordir/$1 - if [ $? -ne 0 ];then - log "[ERROR] create ${targetdir}/$vendordir/$1 failed" - return 1 - fi - else - vendor_installed_dir=$(ls "$targetdir/vendors" 2> /dev/null) - for i in $vendor_installed_dir;do - vendor_installed_file=$(ls "$vendor_installed_dir/$vendor_name/$i" 2> /dev/null) - if [ "$i" = "$vendor_name" ] && [ "$vendor_installed_file" != "" ]; then - echo "[INFO]: $vendor_name custom opp package has been installed on the path $vendor_installed_dir, \ - you want to Overlay Installation , please enter:[o]; \ - or replace directory installation , please enter: [r]; \ - or not install , please enter:[n]." - fi - while true - do - read mrn - if [ "$mrn" = m ]; then - break - elif [ "$mrn" = r ]; then - [ -n "$vendor_installed_file"] && rm -rf "$vendor_installed_file" - break - elif [ "$mrn" = n ]; then - return 0 - else - echo "[WARNING]: Input error, please input m or r or n to choose!" - fi - done - done - log "[INFO] replace old ops $1 files ......" - fi - - log "copy new ops $1 files ......" - cp -rf ${sourcedir}/$vendordir/$1/* $targetdir/$vendordir/$1/ - if [ $? -ne 0 ];then - log "[ERROR] copy new $1 files failed" - return 1 - fi - - return 0 -} -log "[INFO] copy uninstall sh success" - -echo "[ops_custom]upgrade framework" -upgrade framework -if [ $? -ne 0 ];then - exit 1 -fi - -echo "[ops_custom]upgrade op proto" -upgrade op_proto -if [ $? -ne 0 ];then - exit 1 -fi - -echo "[ops_custom]upgrade op impl" -upgrade op_impl -if [ $? -ne 0 ];then - exit 1 -fi - -config_file=${targetdir}/vendors/config.ini -found_vendors="$(grep -w "load_priority" "$config_file" | cut --only-delimited -d"=" -f2-)" -found_vendor=$(echo $found_vendors | sed "s/$vendor_name//g" | tr ',' ' ') -vendor=$(echo $found_vendor | tr -s ' ' ',') -if [ "$vendor" != "" ]; then - sed -i "/load_priority=$found_vendors/s@load_priority=$found_vendors@load_priority=$vendor_name,$vendor@g" "$config_file" -fi - -changemode() -{ - if [ -d ${targetdir} ];then - chmod -R 550 ${targetdir}>/dev/null 2>&1 - fi - - return 0 -} -echo "[ops_custom]changemode..." -#changemode -if [ $? 
-ne 0 ];then - exit 1 -fi - -echo "SUCCESS" -exit 0 - -- Gitee From db10da4fb40bd94957b565c3633196a5346a4269 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 5 Jun 2023 20:01:10 +0800 Subject: [PATCH 113/551] Match-id-c08244e47165f73318268d6f23bb9068fe154f66 --- example/little_demo/main.py | 56 ++++++++++------- mx_rec/core/asc/build_graph.py | 10 +-- mx_rec/core/asc/feature_spec.py | 8 +-- mx_rec/core/asc/helper.py | 18 +++--- mx_rec/core/embedding.py | 106 ++++++++++++++++++++++++-------- 5 files changed, 132 insertions(+), 66 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 116ef432..7bf70769 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -23,7 +23,8 @@ from mx_rec.util.variable import get_dense_and_sparse_variable tf.compat.v1.disable_eager_execution() -def make_batch_and_iterator(is_training, use_timestamp=False, dump_graph=False, batch_number=100): +def make_batch_and_iterator(is_training, feature_spec_list=None, + use_timestamp=False, dump_graph=False, batch_number=100): dataset = generate_dataset(cfg, use_timestamp=use_timestamp, batch_number=batch_number) if not MODIFY_GRAPH_FLAG: insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) @@ -43,7 +44,7 @@ def model_forward(input_list, batch, is_train, modify_graph, config_dict=None): access_and_evict_config = config_dict.get(hash_table.table_name) embedding = sparse_lookup(hash_table, feature, send_count, dim=None, is_train=is_train, access_and_evict_config=access_and_evict_config, - name=hash_table.table_name + "_lookup", modify_graph=modify_graph) + name=hash_table.table_name + "_lookup", modify_graph=modify_graph, batch=batch) reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False) embedding_list.append(reduced_embedding) @@ -53,8 +54,9 @@ def model_forward(input_list, batch, is_train, modify_graph, config_dict=None): return my_model -def build_graph(hash_table_list, is_train, use_timestamp=False, config_dict=None, batch_number=100): - batch, iterator = make_batch_and_iterator(is_training=is_train, use_timestamp=use_timestamp, dump_graph=is_train, +def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=None, batch_number=100): + batch, iterator = make_batch_and_iterator(is_train, feature_spec_list=feature_spec_list, + use_timestamp=USE_TIMESTAMP, dump_graph=is_train, batch_number=batch_number) if MODIFY_GRAPH_FLAG: input_list = [ @@ -67,7 +69,9 @@ def build_graph(hash_table_list, is_train, use_timestamp=False, config_dict=None model = model_forward(input_list, batch, is_train=is_train, modify_graph=True, config_dict=config_dict) else: - model = model_forward([feature_spec_list, hash_table_list, [cfg.user_send_cnt, cfg.item_send_cnt]], batch, + hash_table_list = [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]] + send_cnt_list = [cfg.user_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt, cfg.item_send_cnt] + model = model_forward([feature_spec_list, hash_table_list, send_cnt_list], batch, is_train=is_train, modify_graph=False, config_dict=config_dict) return iterator, model @@ -88,6 +92,26 @@ def evaluate(): break +def create_feature_spec_list(use_timestamp=False): + access_threshold = cfg.access_threshold if use_timestamp else None + eviction_threshold = cfg.eviction_threshold if use_timestamp else None + feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", + 
access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)] + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + + if __name__ == "__main__": tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) warnings.filterwarnings("ignore") @@ -122,21 +146,11 @@ if __name__ == "__main__": # access_threshold unit counts; eviction_threshold unit seconds ACCESS_AND_EVICT = None if USE_TIMESTAMP: - feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", - access_threshold=cfg.access_threshold, - eviction_threshold=cfg.eviction_threshold), - FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", - access_threshold=cfg.access_threshold, - eviction_threshold=cfg.eviction_threshold), - FeatureSpec("timestamp", is_timestamp=True)] - config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table) - - else: - feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table"), - FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table")] + train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) + eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(2)] sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] @@ -160,11 +174,11 @@ if __name__ == "__main__": mode=mode) train_iterator, train_model = build_graph([user_hashtable, item_hashtable], is_train=True, - use_timestamp=USE_TIMESTAMP, config_dict=ACCESS_AND_EVICT, - batch_number=cfg.batch_number) + feature_spec_list=train_feature_spec_list, + config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number) eval_iterator, eval_model = build_graph([user_hashtable, item_hashtable], is_train=False, - use_timestamp=USE_TIMESTAMP, config_dict=ACCESS_AND_EVICT, - batch_number=cfg.batch_number) + feature_spec_list=eval_feature_spec_list, + config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number) dense_variables, sparse_variables = get_dense_and_sparse_variable() rank_size = get_rank_size() diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 70bffe2f..4f30e5b4 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -143,10 +143,10 @@ def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) for i in range(len(table))] result = { - 'restore_vector' : restore_vector, - 'hot_pos' : hot_pos, - 'id_offsets' : id_offsets, - 'swap_in' : swap_in, - 'all2all_args' : all2all_args, + 'restore_vector': restore_vector, + 'hot_pos': 
hot_pos, + 'id_offsets': id_offsets, + 'swap_in': swap_in, + 'all2all_args': all2all_args, } return result diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index fb0bbbea..750c7e8f 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -178,10 +178,10 @@ class FeatureSpec: insert_feature_spec(self, is_training) result = { - 'tensor' : tensor, - 'table_name' : self.table_name, - 'feat_count' : self.feat_cnt, - 'split' : self.split, + 'tensor': tensor, + 'table_name': self.table_name, + 'feat_count': self.feat_cnt, + 'split': self.split, } return result diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 0c4688a2..5327276e 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -55,6 +55,7 @@ def find_dangling_table(table_names: List[str]): :param table_names: list of all created tables' names :return: a list of dangling table names. """ + def check_tensor(table_reachable_tensor: Tensor): """Check whether the tensor op is optimizer op or backward gradient. @@ -104,7 +105,7 @@ def find_dangling_table(table_names: List[str]): logging.info(f"*********** find tables: {table_lookup_op}***********") dangling_table = [] - def extend(op_list:List[Operation], + def extend(op_list: List[Operation], tensor: Tensor, spread_tensors: List[Tensor]): """extend the tensors which table lookup op can reach @@ -189,7 +190,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ dangling_tables = find_dangling_table(table_names) logging.info(f"In insert found dangling table(s): {dangling_tables} " - f"which does not need to be provided to the EmbInfo.") + f"which does not need to be provided to the EmbInfo.") def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) @@ -227,7 +228,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ return insert_fn -def merge_feature_id_request(feature_id_list, split_list, table_name_list, feature_spec_names): +def merge_feature_id_request(feature_id_list, split_list, table_name_list): if not (len(feature_id_list) == len(split_list) and len(split_list) == len(table_name_list)): raise RuntimeError(f"shape not match. 
len(feature_id_list): {len(feature_id_list)}," f"len(split_list): {len(split_list)}" @@ -261,10 +262,10 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list, featu logging.debug(f"merge request from {table_name_list} {split_list} " f" to {output_table_name_list} {output_split_list}") list_set = { - 'output_feature_id_list' : output_feature_id_list, - 'output_split_list' : output_split_list, - 'output_table_name_list' : output_table_name_list, - 'output_tensorshape_split_list' : output_tensorshape_split_list, + 'output_feature_id_list': output_feature_id_list, + 'output_split_list': output_split_list, + 'output_table_name_list': output_table_name_list, + 'output_tensorshape_split_list': output_tensorshape_split_list, } return list_set @@ -272,7 +273,6 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list, featu def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict): is_training = input_dict["is_training"] timestamp = input_dict["timestamp"] - feature_spec_names = input_dict["feature_spec_names"] auto_change_graph = input_dict["auto_change_graph"] host_pipeline_ops = get_host_pipeline_ops() use_static = get_use_static() @@ -283,7 +283,7 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, feature_id_list = feature_id_list[1:] if not auto_change_graph: # future support acg - list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list, feature_spec_names) + list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list) feature_id_list = list_set.get("output_feature_id_list") split_list = list_set.get("output_split_list") table_name_list = list_set.get("output_table_name_list") diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 331c1248..d3f0972f 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -5,6 +5,7 @@ import logging import math import time +from typing import Union from collections import defaultdict import numpy as np @@ -150,7 +151,7 @@ class SparseEmbedding: self.slice_host_vocabulary_size = 0 self.variable = None self.lookup_info = set() - self.lookup_result = None + self.lookup_result = dict() self.use_dynamic_expansion = get_use_dynamic_expansion() self.channel_name_list = [] self.send_count_map = dict() @@ -264,9 +265,10 @@ class SparseEmbedding: self.check_optimizer_instance() def get_default_lookup_name(self): - logging.debug(f"getting one default lookup name") self._default_name_count += 1 - return "sparse_lookup_%d" % self._default_name_count + default_name = "sparse_lookup_%d" % self._default_name_count + logging.debug(f"getting one default lookup name {default_name}") + return default_name def set_using_feature_mapping(self): self._use_feature_mapping = True @@ -320,7 +322,6 @@ class SparseEmbedding: if isinstance(feature, FeatureSpec): if not feature.initialized: raise ValueError(f"Feature Spec has not been initialized.") - key_info = "{}_{}".format(feature.name, feature.index_key) if is_training not in feature.pipeline_mode: raise ValueError(f"You have not config feature for is training mode '{is_training}', please config " f"feature with func sparse_lookup at first.") @@ -536,37 +537,81 @@ class SparseEmbedding: """ spec_name = feature_spec.name is_training = kwargs.get("is_train") - if self.lookup_result is not None and spec_name in self.lookup_result \ - and is_training in self.lookup_result.get(spec_name): + if spec_name in self.lookup_result and is_training in 
self.lookup_result.get(spec_name): return self.lookup_result.get(spec_name).get(is_training) + if not get_use_static() and kwargs.get("batch") is None: + raise RuntimeError(f"When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") table_name = feature_spec.table_name same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training] - if len(same_table_feature_spec) == 0: + same_table_spec_count = len(same_table_feature_spec) + if same_table_spec_count == 0: raise RuntimeError(f"spec_name {spec_name} not in table {table_name}") - if len(same_table_feature_spec) == 1: + if same_table_spec_count == 1: lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) - self.lookup_result = {spec_name: {is_training: lookup_result}} + if spec_name not in self.lookup_result: + self.lookup_result[spec_name] = {} + self.lookup_result[spec_name][is_training] = lookup_result else: + def get_tensor_list() -> list: + """ + Use 'feature spec' to find the corresponding tensor from batch. + Returns: Tensor list in batch. + """ + same_table_tensor_list = [] + for feat_spec in same_table_feature_spec: + tensor = kwargs.get("batch").get(feat_spec.index_key) + if tensor is None: + raise KeyError(f"index_key '{feat_spec.index_key}' does not exist in batch.") + same_table_tensor_list.append(tensor) + return same_table_tensor_list + + def set_feature_spec_attr(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]): + """ + Set properties for a temporary feature_spec. + Args: + mock_feature_spec: A temporary feature_spec consisting of multiple feature_spec with the same table. + total_feature_count: Inner product of the shape of a tensor. + Returns: None + """ + mock_feature_spec.batch_size = total_feature_count + mock_feature_spec.dims = [total_feature_count, 1] + mock_feature_spec.initialized = True + mock_feature_spec.pipeline_mode.add(True) + mock_feature_spec.pipeline_mode.add(False) + same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) - same_table_spec_count = len(same_table_feature_spec) - feature_count = [x.feat_cnt * x.batch_size for x in same_table_feature_spec] - total_feature_count = sum(feature_count) - mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", - feat_count=total_feature_count, table_name=table_name) - mock_feature_spec.batch_size = 1 - mock_feature_spec.dims = [1, total_feature_count] - mock_feature_spec.initialized = True - mock_feature_spec.pipeline_mode.add(True) - mock_feature_spec.pipeline_mode.add(False) + mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name) + + if get_use_static(): + tensor_list = [] + tensor_split_list = [feat_spec.split for feat_spec in same_table_feature_spec] + total_feature_count = sum(tensor_split_list) + else: + tensor_list = get_tensor_list() + tensor_split_list = [tf.math.reduce_prod(array_ops.shape(tensor)) for tensor in tensor_list] + total_feature_count = tf.add_n(tensor_split_list) + set_feature_spec_attr(mock_feature_spec, total_feature_count) + + kwargs["multi_lookup"] = True lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, send_count * same_table_spec_count, **kwargs) - logging.debug(f"lookup table {table_name} via {feature_count}") - lookup_result = tf.reshape(lookup_result, [-1, self.scalar_emb_size]) - split_size = [x.feat_cnt * x.batch_size for x in same_table_feature_spec] - lookup_result_split = 
tf.split(lookup_result, split_size) - self.lookup_result = {k.name: {is_training: tf.reshape(v, k.dims + [self.scalar_emb_size])} - for k, v in zip(same_table_feature_spec, lookup_result_split)} + logging.debug(f"lookup table {table_name} via {tensor_split_list}") + + lookup_result_split = tf.split(lookup_result, tensor_split_list) + if len(lookup_result_split) != len(same_table_feature_spec) or ( + not get_use_static() and len(same_table_feature_spec) != len(tensor_list)): + raise RuntimeError(f"shape not match. len(lookup_result_split): {len(lookup_result_split)}," + f"len(same_table_feature_spec): {len(same_table_feature_spec)}" + f"len(tensor_list): {len(tensor_list)}") + for idx, (one_feature_spec, one_result) in enumerate(zip(same_table_feature_spec, lookup_result_split)): + if one_feature_spec.name not in self.lookup_result: + self.lookup_result[one_feature_spec.name] = {} + if get_use_static(): + dest_shape = one_feature_spec.dims + [self.scalar_emb_size] + else: + dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0) + self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape) self.check_multi_lookup_times() return self.lookup_result.get(spec_name).get(is_training) @@ -614,7 +659,7 @@ class SparseEmbedding: hot_pos = result.get("hot_pos") id_offsets = result.get("id_offsets") swap_in = result.get("swap_in") - all2all_matrix = result.get("all2all_matrix") + all2all_matrix = result.get("all2all_args") control_ops = swap_in id_offsets = tf.identity(id_offsets, name="identity_addr") @@ -647,7 +692,14 @@ class SparseEmbedding: if use_static: lookup_result = tf.reshape(embeddings, feature_spec.dims + [self.scalar_emb_size]) else: - lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) + if kwargs.get("multi_lookup"): + lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) + else: + tensor = kwargs.get("batch").get(feature_spec.index_key) + if tensor is None: + raise KeyError(f"index_key '{feature_spec.index_key}' does not exist in batch.") + dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0) + lookup_result = array_ops.reshape(embeddings, dest_shape) def grad(lookup_diff): embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) -- Gitee From 53cfcbcdb5cb187dc07a053266abebc60c6092d9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 7 Jun 2023 10:59:08 +0800 Subject: [PATCH 114/551] Match-id-21cae82adbd92b877431debd3b2b7a0032f34cd8 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 15 +++++++++++++-- src/core/key_process/key_process.cpp | 2 +- src/core/key_process/key_process.h | 4 ---- src/core/utils/common.cpp | 2 ++ src/core/utils/common.h | 8 +++++++- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index d306c620..a6c6a072 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -20,8 +20,8 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& { if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); - if (num > MAX_KEY_PROCESS_THREAD) { - spdlog::error("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{} should be less than {}", + if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { + spdlog::error("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{}, should in range [1, {}]", num, MAX_KEY_PROCESS_THREAD); return false; } @@ -29,6 +29,17 @@ 
bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& spdlog::info("config KEY_PROCESS_THREAD_NUM:{}", num); } + if (getenv("MAX_UNIQUE_THREAD_NUM") != nullptr) { + int num = std::atoi(getenv("MAX_UNIQUE_THREAD_NUM")); + if (num < 1 || num > DEFAULT_MAX_UNIQUE_THREAD_NUM) { + spdlog::error("[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:{}, should in range [1, {}]", + num, DEFAULT_MAX_UNIQUE_THREAD_NUM); + return false; + } + PerfConfig::maxUniqueThreadNum = num; + spdlog::info("config MAX_UNIQUE_THREAD_NUM:{}", num); + } + preprocess = Singleton::GetInstance(); preprocess->Initialize(rankInfo, embInfos, thresholdValues, ifLoad, seed); preprocess->Start(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index c590681d..9c5dec8d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -476,7 +476,7 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba splitSize.data(), keySendInfo.keySend.data(), idCount.data(), keySendInfo.keyCount.data()}; UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; UniqueForHot uniqueForHot = {hotOffset, uniqueInfoOut.hotPos.data(), hotMap, keyCountMap}; - UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, MAX_UNIQUE_THREAD_NUM}; + UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, PerfConfig::maxUniqueThreadNum}; unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum); EASY_END_BLOCK diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 77fc78e2..e6fa2b47 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -35,10 +35,6 @@ namespace MxRec { using namespace std; - constexpr int UNIQUE_BUCKET = 6; - constexpr int MIN_UNIQUE_THREAD_NUM = 1; - constexpr int MAX_UNIQUE_THREAD_NUM = 8; - using a2a_info_t = vector; using sharded_dedup = ShardedDedup; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 2485d5ba..a4d23fe7 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -19,6 +19,8 @@ using std::chrono::system_clock; namespace MxRec { int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; + int PerfConfig::maxUniqueThreadNum = DEFAULT_MAX_UNIQUE_THREAD_NUM; + RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 41f26558..27d4ff24 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -59,10 +59,16 @@ namespace MxRec { constexpr int MAX_CHANNEL_NUM = 2; constexpr int MAX_KEY_PROCESS_THREAD = 10; constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; - constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; + + // unique related config + constexpr int UNIQUE_BUCKET = 6; + constexpr int MIN_UNIQUE_THREAD_NUM = 1; + constexpr int DEFAULT_MAX_UNIQUE_THREAD_NUM = 8; + struct PerfConfig { static int keyProcessThreadNum; + static int maxUniqueThreadNum; }; constexpr int KEY_PROCESS_TIMEOUT = 120; -- Gitee From ee8bb8df8ff342f0ad2e4fb81ca89937e21b5085 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 7 Jun 2023 11:39:17 +0800 Subject: [PATCH 115/551] Match-id-75ede926ba912a93cb727ab5d11dc86ee5e3f343 --- build/build.sh | 2 +- build/package.sh | 60 ------------------------------------------------ 2 files changed, 1 insertion(+), 61 
deletions(-) delete mode 100644 build/package.sh diff --git a/build/build.sh b/build/build.sh index b7fc65f3..9d81f52c 100644 --- a/build/build.sh +++ b/build/build.sh @@ -156,7 +156,7 @@ gen_wheel_file() touch "${src_path}"/libasc/__init__.py remove "${ROOT_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - python3 setup.py bdist_wheel + python3 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" remove "${ROOT_DIR}"/mx_rec/libasc diff --git a/build/package.sh b/build/package.sh deleted file mode 100644 index 3e4a1e07..00000000 --- a/build/package.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/bin/bash -# Package script -# Copyright © Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -set -e - -CURDIR=$(dirname "$(readlink -f "$0")") -SCRIPT_NAME=$(basename "$0") -ROOT_PATH=$(readlink -f "$CURDIR"/../) -OUTPUT_PATH="$ROOT_PATH/output" -VERSION_FILE=${SCRIPT_DIR}/../../mindxsdk/build/conf/config.yaml -get_version() { - if [ -f "$VERSION_FILE" ]; then - VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") - if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then - VERSION=${VERSION%.*} - fi - else - VERSION="5.0.T104" - fi -} - -get_version -export VERSION - - -function make_zip_package() -{ - cd "${OUTPUT_PATH}" - pkg_file=$(ls "$OUTPUT_PATH"/*"${1}"*."${2}") - pkg_file="${pkg_file##*/}" - pkg_release="${pkg_file%."${2}"}" - - package_file="${OUTPUT_PATH}"/package - [ -d "$package_file" ] && rm -rf "$package_file" - mkdir "$package_file" - cp -f "${OUTPUT_PATH}"/crldata.crl "$OUTPUT_PATH/${pkg_release}.${2}.crl" - cp "$pkg_release".* "$package_file" - - cd "$package_file" - chmod 600 "$pkg_release.${2}" - chmod 600 "$pkg_release.${2}".cms - chmod 600 "$pkg_release.${2}".crl - zip_file="${3}$pkg_release.zip" - zip -r "$zip_file" "$pkg_release.${2}" "$pkg_release.${2}".cms "$pkg_release.${2}".crl - - mv "$package_file/$zip_file" "${OUTPUT_PATH}/$zip_file" - echo "zip $zip_file success !" - [ -d "$package_file" ] && rm -rf "$package_file" - return 0 -} - -function main() -{ - make_zip_package Ascend-mindxsdk-mxrec tar.gz - return 0 -} - -echo "begin to execute $SCRIPT_NAME" -main;ret="$?" 
-echo "finish exuecte $SCRIPT_NAME, result is $ret" -- Gitee From 4bd82b3b764cec74814dfe8fd83fa5c417102c9d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 7 Jun 2023 17:39:07 +0800 Subject: [PATCH 116/551] Match-id-1a704039eda66b78bf7dd589d1d1623d98aa0101 --- mx_rec/core/asc/manager.py | 2 +- mx_rec/core/embedding.py | 7 ++- mx_rec/optimizers/adagrad.py | 6 +-- mx_rec/optimizers/lazy_adam.py | 12 ++--- mx_rec/saver/patch.py | 27 ++++++++--- mx_rec/saver/saver.py | 62 +++++++++++++++++++++--- mx_rec/util/initialize.py | 17 +++++++ src/core/checkpoint/checkpoint.cpp | 13 +++++ src/core/checkpoint/checkpoint.h | 1 + src/core/utils/common.h | 4 +- src/pybind/module_main.cpp | 6 ++- src/tests/checkpoint/checkpoint_test.cpp | 1 + src/tests/emb_mgmt/emb_mgmt_test.cpp | 9 ++-- 13 files changed, 134 insertions(+), 33 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 8c2ac8c9..3f193fc2 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -76,7 +76,7 @@ def generate_table_info_list(): logging.debug(f"table_instance, table_name: {table_instance.table_name}, channel_name_list: " f"{table_instance.channel_name_list}, send_count_map: {table_instance.send_count_map}") table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size, - table_instance.ext_emb_size, table_instance.modify_graph, + table_instance.ext_emb_size, table_instance.modify_graph, table_instance.is_save, table_instance.channel_name_list, [table_instance.slice_device_vocabulary_size, table_instance.slice_host_vocabulary_size], diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index d3f0972f..de507c00 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -24,7 +24,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion, \ - set_modify_graph + set_modify_graph, insert_removing_var_list from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.variable import remove_saving_var @@ -42,6 +42,7 @@ def create_table(**kwargs): shard_num = kwargs.get("shard_num", 1) fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) hashtable_threshold = kwargs.get("hashtable_threshold", 0) + is_save = kwargs.get("is_save", True) """ Args: @@ -63,7 +64,7 @@ def create_table(**kwargs): config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, - fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold) + fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold, is_save=is_save) embedding = SparseEmbedding(config) return embedding @@ -137,6 +138,7 @@ class SparseEmbedding: self._optimizer_instance_list = config.get("optimizer_list") self.emb_initializer = config.get("emb_initializer") self._mode = config.get("mode") + self.is_save = config.get("is_save") self.optimizer_slot_info_list = [] self._slot_num = dict() self._send_count = 0 @@ -752,6 +754,7 @@ class SparseEmbedding: def 
_initialize_variables(self): initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) + insert_removing_var_list(self.variable.name) # make sure sparse table variable will not be saved and restored within tf checkpoint. remove_saving_var(self.variable) self._record() diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index f1499c25..306208ca 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -16,8 +16,8 @@ from tensorflow.python.training import adagrad, training_ops from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_param_type +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_param_type def create_hash_optimizer(learning_rate=0.001, @@ -63,7 +63,7 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): return new_slot_variable accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator") - remove_saving_var(accumulator) + insert_removing_var_list(accumulator.name) named_slot_key = (var.op.graph, var.op.name) table_instance = get_table_instance(var) if self._name in table_instance.optimizer: diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index a578b7df..42904df0 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -19,8 +19,8 @@ from tensorflow.python.training import adam from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"): @@ -66,8 +66,8 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): momentum = creat_one_single_slot(var, self._name + "/" + "momentum") velocity = creat_one_single_slot(var, self._name + "/" + "velocity") - remove_saving_var(momentum) - remove_saving_var(velocity) + insert_removing_var_list(momentum.name) + insert_removing_var_list(velocity.name) named_slot_key = (var.op.graph, var.op.name) table_instance = get_table_instance(var) if self._name in table_instance.optimizer: @@ -186,8 +186,8 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): momentum = self._zeros_slot(each_var, "m", m_state_name) velocity = self._zeros_slot(each_var, "v", v_state_name) # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
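Throughout this commit the per-variable remove_saving_var calls give way to a single registry of variable names that must never reach the TensorFlow checkpoint; the build_var_list helper added to patch.py further down then filters the saveable set against it, as the next hunks show. A self-contained sketch of that mechanism, with the registry kept as a plain module-level list rather than on ConfigInitializer:

    _removing_var_list = []  # names of sparse tables and optimizer slots to skip

    def insert_removing_var_list(name):
        # registering a name twice is harmless; keep the list duplicate-free
        if name not in _removing_var_list:
            _removing_var_list.append(name)

    def build_var_list(all_saveable_vars):
        # keep only variables that were not explicitly excluded, so dense
        # weights still flow through the stock TF saver untouched
        return [var for var in all_saveable_vars if var.name not in _removing_var_list]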
- remove_saving_var(momentum) - remove_saving_var(velocity) + insert_removing_var_list(momentum.name) + insert_removing_var_list(velocity.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 34acdaad..8c8bee45 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -26,7 +26,7 @@ from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util from mx_rec.saver.saver import Saver as SparseSaver -from mx_rec.util.initialize import get_ascend_global_hashtable_collection +from mx_rec.util.initialize import get_ascend_global_hashtable_collection, export_removing_var_list def get_sparse_vars(var_list): @@ -60,12 +60,12 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None, fid_version=0): + var_list = build_var_list(var_list) self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] self._var_list = var_list self._is_built = False self._is_empty = None - init_check(defer_build, var_list) self._write_version = write_version self._reshape = reshape @@ -92,7 +92,7 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, self._next_checkpoint_time = (time.time() + keep_time) elif not defer_build: self.build() - self._object_restore_saver = None # mxRec Patch # create sparse saver only when var_list is not None self.sparse_saver = None @@ -115,6 +115,7 @@ def get_model_checkpoint_path(self, checkpoint_file, sess): # save sparse model, only run when self.sparse_saver is not None if self.sparse_saver: self.sparse_saver.save(sess, save_path=checkpoint_file) + logging.info("Save model into dir %s", checkpoint_file) else: self._build_eager(checkpoint_file, build_save=True, build_restore=False) @@ -209,13 +210,13 @@ def restore(self, sess, save_path): tf_logging.info("Restoring parameters from %s", checkpoint_prefix) try: if not context.executing_eagerly(): - sess.run(self.saver_def.restore_op_name, - {self.saver_def.filename_tensor_name: save_path}) # mxRec Patch # restore sparse model, only run when self.sparse_saver is not None if self.sparse_saver: self.sparse_saver.restore(sess, save_path) + sess.run(self.saver_def.restore_op_name, + {self.saver_def.filename_tensor_name: save_path}) logging.info("Restore from dir %s", save_path) else: self._build_eager(save_path, build_save=False, build_restore=True) @@ -271,7 +272,8 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N raise ValueError("Checkpoint in %s not an object-based checkpoint."
% checkpoint_path) from err if var_list is None: - var_list = variables._all_saveable_objects() + var_list = build_var_list(var_list) + if builder is None: builder = BulkSaverBuilder() @@ -311,6 +313,17 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N return cached_saver +def build_var_list(var_list): + if var_list is None: + var_list = [] + tmp_list = variables._all_saveable_objects() + removing_var_list = export_removing_var_list() + for var in tmp_list: + if var.name not in removing_var_list: + var_list.append(var) + return var_list + + class BaseSaverBuilder(object): VariableSaveable = saveable_object_util.ReferenceVariableSaveable SaveSpec = saveable_object.SaveSpec @@ -357,3 +370,5 @@ def patch_for_saver(): dense_saver.save = save dense_saver.restore = restore logging.debug("Class tf.train.Saver has been patched.") + + diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 02b70bd1..5989c150 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -38,8 +38,13 @@ class Saver(object): def build(self): if self.var_list is None: + self.var_list = [] logging.debug(f"optimizer collection name: {get_ascend_global_hashtable_collection()}") - self.var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + temp_var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + for var in temp_var_list: + table_instance = get_table_instance(var) + if table_instance.is_save: + self.var_list.append(var) with tf.compat.v1.variable_scope("mx_rec_save"): self._build_save() @@ -156,15 +161,16 @@ class Saver(object): self.rank_id) def _restore(self, sess, reading_path): - if is_asc_manager_initialized(): - restore_host_data(reading_path) - logging.debug(f"host data was restored.") - restore_feed_dict = defaultdict(dict) for table_name, sub_placeholder_dict in self.placeholder_dict.items(): fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.EMBEDDING.value)) table_instance = get_table_instance_by_name(table_name) + + if is_asc_manager_initialized(): + restore_host_data(reading_path) + logging.debug(f"host data was restored.") + if table_instance.use_feature_mapping: fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.FEATURE_MAPPING.value)) @@ -199,7 +205,7 @@ def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_des else: target_path = generate_path(reading_path, "HashTable", "HBM", name_descriptor.table_name, name_descriptor.data_name) - restore_data_dict = read_binary_data(target_path, suffix, name_descriptor.data_name) + restore_data_dict = read_binary_data(target_path, suffix, name_descriptor.data_name, name_descriptor.table_name) for key, data in restore_data_dict.items(): embedding_placeholder = placeholder_dict.get(key) @@ -278,7 +284,15 @@ def write_binary_data(writing_path, suffix, data, attributes=None): file.write(json.dumps(attributes)) -def read_binary_data(reading_path, suffix, data_name): +def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: str) -> dict: + """ + Read sparse data from a binary file + :param reading_path: sparse data path + :param suffix: suffix of sparse data + :param data_name: the data type, including embedding, offset, etc.
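As the docstring continuing just below spells out, read_binary_data pairs every raw dump with a JSON attribute file carrying at least its dtype and, optionally, its shape; when the saved shape no longer matches the configured table, process_embedding_data grows the array by zero-padding the new vocabulary rows and refuses to shrink. A standalone sketch of both steps; the file layout and the "dtype"/"shape" keys are illustrative assumptions, since the real code keys them off the DataAttr enum:

    import json
    import numpy as np

    def read_sparse_dump(data_path, attr_path, target_shape):
        with open(attr_path) as f:
            attrs = json.load(f)                   # e.g. {"dtype": "float32", "shape": [3, 2]}
        data = np.fromfile(data_path, dtype=attrs["dtype"])
        saved_shape = attrs.get("shape")
        if saved_shape:
            data = data.reshape(saved_shape)
            if target_shape[0] > saved_shape[0]:   # vocabulary grew: pad new rows with zeros
                pad = np.zeros((target_shape[0] - saved_shape[0], target_shape[1]),
                               dtype=data.dtype)
                data = np.concatenate((data, pad), axis=0)
            elif target_shape[0] < saved_shape[0]: # shrinking would break key -> offset mapping
                raise ValueError("configured vocabulary smaller than the saved one")
        return data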
+ :param table_name: the sparse table name + :return: the sparse data dict + """ data_file, attribute_file = generate_file_name(suffix) target_data_dir = os.path.join(reading_path, data_file) target_attribute_dir = os.path.join(reading_path, attribute_file) @@ -294,8 +308,14 @@ def read_binary_data(reading_path, suffix, data_name): raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.") data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) + if DataAttr.SHAPE.value in attributes: - data_to_restore = data_to_restore.reshape(attributes.pop(DataAttr.SHAPE.value)) + data_shape = attributes.pop(DataAttr.SHAPE.value) + data_to_restore = data_to_restore.reshape(data_shape) + table_instance = get_table_instance_by_name(table_name) + current_data_shape = [table_instance.slice_device_vocabulary_size, table_instance.scalar_emb_size] + if data_shape != current_data_shape: + data_to_restore = process_embedding_data(data_to_restore, current_data_shape, data_shape) data_dict = {data_name: data_to_restore} for key, item in attributes.items(): @@ -304,3 +324,29 @@ logging.debug(f"Reading shape is {data_to_restore.shape}.") return data_dict + + +def process_embedding_data(data_to_restore: np.ndarray, current_data_shape: list, data_shape: list) -> np.ndarray: + """ + Process embedding data when reading a binary file + :param data_to_restore: the embedding data read from the binary file + :param current_data_shape: current embedding data shape set by user + :param data_shape: embedding data shape saved in the binary file + :return: the embedding data + """ + try: + restore_vocab_size, restore_emb_size = current_data_shape + vocab_size, emb_size = data_shape + except ValueError as err: + raise ValueError(f"The shape of a sparse table cannot exceed two dimensions.") from err + + if restore_vocab_size > vocab_size: + pad_count = restore_vocab_size - vocab_size + pad_matrix = np.zeros((pad_count, restore_emb_size)) + data_to_restore = np.concatenate((data_to_restore, pad_matrix), axis=0) + + elif restore_vocab_size < vocab_size: + raise Exception(f"restore vocabulary size {restore_vocab_size} cannot be less than " + f"saved vocabulary size {vocab_size}, which would lose the mapping between keys and embeddings") + + return data_to_restore diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 83f8effe..eeaec25f 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -34,6 +34,7 @@ class ConfigInitializer: self._if_load = None self._table_instance_dict = dict() self._dangling_table = [] + self._removing_var_list = [] self._name_to_var_dict = dict() self._table_name_set = set() self._table_name_to_feature_spec = dict() @@ -272,10 +273,18 @@ class ConfigInitializer: if name not in self._dangling_table: self._dangling_table.append(name) + def insert_removing_var_list(self, name): + if name not in self._removing_var_list: + self._removing_var_list.append(name) + @property def dangling_table(self): return self._dangling_table + @property + def removing_var_list(self): + return self._removing_var_list + def insert_table_instance(self, name, key, instance): if key in self._table_instance_dict: raise KeyError(f"Given key {key} has been used.") @@ -562,6 +571,10 @@ def insert_dangling_table(table_name): ConfigInitializer.get_instance().insert_dangling_table(table_name) +def insert_removing_var_list(var_name): + ConfigInitializer.get_instance().insert_removing_var_list(var_name) + + def insert_table_instance(name, key, instance): ConfigInitializer.get_instance().insert_table_instance(name, key, instance) @@ -574,6 +587,10 @@ def export_dangling_table(): return ConfigInitializer.get_instance().dangling_table +def export_removing_var_list(): + return ConfigInitializer.get_instance().removing_var_list + + def insert_optimizer(optimizer): ConfigInitializer.get_instance().insert_optimizer(optimizer) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 7f13d379..66bbbc15 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -149,12 +149,25 @@ int Checkpoint::GetEmbeddingSize(const string& embName) const return 0; } +bool Checkpoint::CheckEmbNames(const string& embName) +{ + for (const auto &embInfo: mgmtEmbInfo) { + if (embInfo.name == embName && embInfo.isSave == true) { + return true; + } + } + return false; +} void Checkpoint::SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler) { for (const auto& embName: embNames) { + if (!CheckEmbNames(embName)) { + continue; + } + auto dataDir{innerDirPath + dirSeparator + embName}; for (const auto& saveDataType: saveDataTypes) { auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 27578053..f2120519 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -84,6 +84,7 @@ namespace MxRec { void ReadEmbedding(CkptTransData& transData, const string& dataDir); int GetEmbeddingSize(const string& embName) const; + bool CheckEmbNames(const string& embNames); void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 27d4ff24..e66dc5a3
100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -360,12 +360,13 @@ struct BatchTask { int embeddingSize, int extEmbeddingSize, bool modifyGraph, + bool isSave, std::vector channelNames, std::vector vocabsize, std::vector initializeInfos, std::map sendCountMap) : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), - modifyGraph(modifyGraph), channelNames(channelNames), initializeInfos(initializeInfos), + modifyGraph(modifyGraph), isSave(isSave), channelNames(channelNames), initializeInfos(initializeInfos), sendCountMap(sendCountMap) { devVocabSize = vocabsize[0]; @@ -377,6 +378,7 @@ struct BatchTask { int embeddingSize; int extEmbeddingSize; bool modifyGraph; + bool isSave; size_t devVocabSize; size_t hostVocabSize; std::vector channelNames; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 3839fe37..d59a58d2 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -93,16 +93,18 @@ void GetRankInfo(pybind11::module_& m) void GetEmbInfo(pybind11::module_& m) { pybind11::class_(m, "EmbInfo") - .def(pybind11::init, std::vector, + .def(pybind11::init, std::vector, std::vector&, std::map>(), py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), - py::arg("ext_embedding_size"), py::arg("modify_graph"), py::arg("channel_name_list"), + py::arg("ext_embedding_size"), py::arg("modify_graph"), + py::arg("is_save"), py::arg("channel_name_list"), py::arg("vocab_size"), py::arg("initialize_infos"), py::arg("send_count_map")) .def_readwrite("name", &EmbInfo::name) .def_readwrite("send_count", &EmbInfo::sendCount) .def_readwrite("embedding_size", &EmbInfo::embeddingSize) .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize) .def_readwrite("modify_graph", &EmbInfo::modifyGraph) + .def_readwrite("is_save", &EmbInfo::isSave) .def_readwrite("channel_name_list", &EmbInfo::channelNames) .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index cd348bfa..59c339f9 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -74,6 +74,7 @@ protected: testEmbInfo.extEmbeddingSize = embeddingSize; testEmbInfo.devVocabSize = devVocabSize; testEmbInfo.hostVocabSize = hostVocabSize; + testEmbInfo.isSave = true; ++idx; } } diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index d103b144..52fd21c8 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -39,6 +39,7 @@ protected: int embeddingSize = 8; int extEmbeddingSize = 24; bool modifyGraph = false; + bool isSave = true; size_t devVocabSize = 5; size_t hostVocabSize = 15; vector randomInfos; @@ -116,7 +117,7 @@ protected: TEST_F(EmbMgmtTest, Initialize) { vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues = {}; @@ -174,7 +175,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, 
modifyGraph, channelNames, vocabsize, + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; @@ -193,7 +194,7 @@ TEST_F(EmbMgmtTest, Evict) size_t devVocabSize = DDR_DEVICE_SIZE; size_t hostVocabSize = DDR_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; @@ -215,7 +216,7 @@ TEST_F(EmbMgmtTest, Evict_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, channelNames, vocabsize, + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, initializeInfos, sendCountMap); embInfos.emplace_back(embInfo); vector thresholdValues; -- Gitee From 16ca86a0beebba08e273c2ab2ddaefef4b1cadb9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 8 Jun 2023 15:11:04 +0800 Subject: [PATCH 117/551] Match-id-5466f85089067720d0ceebbffe7ce688c1361138 --- src/CMakeLists.txt | 3 +- src/core/checkpoint/checkpoint.cpp | 2 +- .../ckpt_data_handler/ckpt_data_handler.cpp | 6 ++- .../feat_admit_n_evict_ckpt.cpp | 2 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 4 ++ .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 3 +- .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 7 ++- src/core/emb_hashmap/emb_hashmap.cpp | 14 +++--- src/core/emb_hashmap/emb_hashmap.h | 2 +- src/core/emb_table/emb_table.cpp | 6 +-- src/core/host_emb/host_emb.cpp | 6 +-- .../truncated_normal_initializer.cpp | 6 +-- src/core/key_process/key_process.cpp | 16 ++++--- src/core/utils/common.h | 2 +- src/core/utils/unique.h | 48 +++++++++---------- src/ops_tf/hybrid_dataset_ops.cpp | 5 +- src/pybind/CMakeLists.txt | 6 +-- src/pybind/module_main.cpp | 4 +- 18 files changed, 78 insertions(+), 64 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 42d2f2df..6e07d657 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -35,7 +35,8 @@ else () ADD_DEFINITIONS(-DBUILD_WITH_EASY_PROFILER) endif () set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb") -set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wall -DNDEBUG") +set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wall -DNDEBUG -fPIC -fstack-protector-all -Wextra -Wconversion -D_FORTIFY_SOURCE=2 -s") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") option(ENABLE_DEBUG "use debug mode" OFF) if (ENABLE_DEBUG) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 66bbbc15..c63df0c7 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -287,7 +287,7 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si int loops = 1; if (dataType == CkptDataType::EMB_DATA) { - loops = transData.floatArr.size(); + loops = static_cast(transData.floatArr.size()); } for (int i = 0; i < loops; i++) { size_t idx = 0; diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index 
6273fa66..92ffb8ba 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -4,7 +4,7 @@ * Author: MindX SDK * Create: 2022-11-12 */ - +#include #include "ckpt_data_handler.h" @@ -34,5 +34,7 @@ void CkptDataHandler::CleanTransfer() void CkptDataHandler::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) { - throw std::runtime_error("Wrong CkptDataType, only EMB_INFO and EMB_DATA supported for load host emb"); + spdlog::error("Load host emb failed. dataType:{}, embName:{}, loadedData:{}, ckptData:{}", dataType, embName, + loadedData.datasetSize, ckptData.embHashMaps.empty()); + throw runtime_error("only EMB_INFO and EMB_DATA supported for load host emb"); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index d6a4d388..2d992c24 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -136,7 +136,7 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) const auto& lastTime = transArr[i + lastTimeIdxOffset]; histRecs[featureId].featureId = featureId; - histRecs[featureId].count = count; + histRecs[featureId].count = static_cast(count); histRecs[featureId].lastTime = lastTime; histRecs[featureId].tensorName = embName; } diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index dc573a63..f236637c 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -5,6 +5,7 @@ * Create: 2022-11-12 */ +#include #include "host_emb_ckpt.h" @@ -24,6 +25,7 @@ void HostEmbCkpt::GetProcessData(CkptData& processData) { saveHostEmbs = nullptr; loadHostEmbs = nullptr; + spdlog::info("processData.embHashMaps.empty():{}", processData.embHashMaps.empty()); } vector HostEmbCkpt::GetDataTypes() @@ -58,6 +60,7 @@ CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { + spdlog::info("Parameter dataType:{}, embName:{}, loadedData:{}", dataType, embName, loadedData.datasetSize); return; } @@ -115,6 +118,7 @@ void HostEmbCkpt::SetEmbInfo(string embName, CkptData& ckptData) // load Emb data void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) { + spdlog::info("Parameter embName:{}, ckptData:{}", embName, ckptData.embHashMaps.empty()); return; } diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index 271af3cc..2d3bc5a3 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -67,7 +67,7 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) transArr.push_back(it.first); transArr.push_back(it.second); } - + spdlog::info("CkptDataType::EMB_INFO:{}, dataType{} is", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -86,4 +86,5 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran int64_t key { transArr.at(i) }; hostHashMap[key] = transArr.at(i + 1); } + 
spdlog::info("dataType{} is", dataType); } diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp index 17f19b75..1dd065ca 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -5,6 +5,7 @@ * Create: 2022-11-17 */ +#include #include "nddr_offset_ckpt.h" @@ -47,11 +48,12 @@ vector NddrOffsetCkpt::GetEmbNames() CkptTransData NddrOffsetCkpt::GetDataset(CkptDataType dataType, string embName) { CleanTransfer(); - transferData.int32Arr.push_back(saveMaxOffset.at(embName)); + transferData.int32Arr.push_back(static_cast(saveMaxOffset.at(embName))); transferData.datasetSize = fourBytes; transferData.attribute.push_back(1); transferData.attribute.push_back(fourBytes); transferData.attributeSize = transferData.attribute.size() * eightBytes; + spdlog::info("CkptDataType::EMB_INFO:{}, dataType:{} is", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -60,4 +62,5 @@ void NddrOffsetCkpt::SetDataset(CkptDataType dataType, string embName, CkptTrans CleanTransfer(); transferData = move(loadedData); loadMaxOffset[embName] = transferData.int32Arr.front(); -} \ No newline at end of file + spdlog::info("CkptDataType::EMB_INFO:{}, dataType:{} is", CkptDataType::EMB_INFO, dataType); +} diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index f506c809..d76abfed 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -108,7 +108,7 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector(swapSize), keepBatchId); EASY_END_BLOCK EASY_BLOCK("ChangeInfo") #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ @@ -137,25 +137,25 @@ void EmbHashMap::ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_ int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) { - int offset; + int32_t offset; const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { // 由于未全局去重,需要再次查询确保是新key - offset = iter->second; + offset = static_cast(iter->second); } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 - offset = embHashMap.evictDevPos.back(); + offset = static_cast(embHashMap.evictDevPos.back()); embHashMap.hostHashMap[key] = offset; spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictDevPos.size()); embHashMap.evictDevPos.pop_back(); } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 - offset = embHashMap.evictPos.back(); + offset = static_cast(embHashMap.evictPos.back()); embHashMap.hostHashMap[key] = offset; spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { embHashMap.hostHashMap[key] = embHashMap.maxOffset; - offset = embHashMap.maxOffset; + offset = static_cast(embHashMap.maxOffset); embHashMap.maxOffset++; if (embHashMap.maxOffset == embHashMap.devVocabSize) { spdlog::info("start using host vocab!"); @@ -185,7 +185,7 @@ void EmbHashMap::FindAndUpdateBatchId(const vector& keys, size_t curr embHashMap.lookUpVec[i] = offset; // convert to offset(current) spdlog::trace("key will be used, {} , offset , {}", key, offset); if (offset < static_cast(embHashMap.devVocabSize)) { - embHashMap.devOffset2Batch[offset] = currentBatchId; + 
embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); embHashMap.devOffset2Key[offset] = key; } } diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index ab51ab13..a8fcf977 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -63,4 +63,4 @@ namespace MxRec { }; } -#endif // MX_REC_EMB_HASHMAP_H \ No newline at end of file +#endif // MX_REC_EMB_HASHMAP_H diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 6d8299b2..78e9e815 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -53,7 +53,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) SplitMemoryBlock(newBlock); } } - totalCapacity = memoryList.size(); + totalCapacity = static_cast(memoryList.size()); spdlog::info("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); #endif } @@ -208,12 +208,12 @@ list EmbTable::LoadEmb(const vector> &savedEmb) { #ifndef GTEST list addressList; - int embCapacity = savedEmb.size(); + int embCapacity = static_cast(savedEmb.size()); if (savedEmb.size() == 0 || savedEmb[0].size() == 0) { spdlog::error("Load invalid savedEmb"); return addressList; } - embSize = savedEmb[0].size(); + embSize = static_cast(savedEmb[0].size()); void *newBlock = nullptr; aclError ret = aclrtMalloc(&newBlock, embCapacity * embSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index eeb6aa96..6f74b8d8 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -24,8 +24,8 @@ bool HostEmb::Initialize(const vector& embInfos, int seed) for (const auto& embInfo: embInfos) { HostEmbTable hostEmb; hostEmb.hostEmbInfo = embInfo; - EmbDataGenerator(embInfo.initializeInfos, seed, embInfo.hostVocabSize, embInfo.extEmbeddingSize, - hostEmb.embData); + EmbDataGenerator(embInfo.initializeInfos, seed, static_cast(embInfo.hostVocabSize), + embInfo.extEmbeddingSize, hostEmb.embData); hostEmbs[embInfo.name] = move(hostEmb); spdlog::info(HOSTEMB + "HostEmb Initialize End"); } @@ -231,7 +231,7 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve } for (size_t i = 0; i < offset.size(); i++) { - initializer->GenerateData(embData.at(offset.at(i)).data(), embData[0].size()); + initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); } } } diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index d02ac998..1b85e202 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -16,8 +16,8 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float { generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); - minBound = mean - boundNum * stddev; - maxBound = mean + boundNum * stddev; + minBound = mean - static_cast(boundNum) * stddev; + maxBound = mean + static_cast(boundNum) * stddev; } @@ -39,4 +39,4 @@ void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSiz } return tmp; }); -} \ No newline at end of file +} diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 9c5dec8d..bae493d5 100644 
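The key_process.cpp hunk just below rewrites InitHotEmbTotCount so every term is cast explicitly, but the formula is unchanged: UB bytes divided by sizeof(float), scaled by HOT_EMB_CACHE_PCT, divided by the embedding size. A worked example with an assumed UB size; at run time the real value comes from GetUBSize(deviceId):

    UB_SIZE_BYTES = 192 * 1024        # assumed unified-buffer size, for illustration only
    HOT_EMB_CACHE_PCT = 1.0 / 3.0     # fraction of UB reserved for the hot-emb cache
    EMBEDDING_SIZE = 16               # floats per embedding row

    floats_in_ub = UB_SIZE_BYTES // 4                                   # sizeof(float) == 4
    hot_emb_tot_count = int(floats_in_ub * HOT_EMB_CACHE_PCT / EMBEDDING_SIZE)
    print(hot_emb_tot_count)                                            # -> 1024 hot rows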
--- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -149,8 +149,8 @@ void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) if (rankInfo.useDynamicExpansion) { embeddingSize = info.embeddingSize; } - hotEmbTotCount[info.name] = static_cast(GetUBSize(rInfo.deviceId) / sizeof(float) * HOT_EMB_CACHE_PCT / - embeddingSize); + hotEmbTotCount[info.name] = static_cast(static_cast(GetUBSize(rInfo.deviceId) / sizeof(float)) * + HOT_EMB_CACHE_PCT / static_cast(embeddingSize)); } auto KeyProcess::GetSendCount(const string& name, const string& channelName, bool modifyGraph) @@ -245,7 +245,7 @@ void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0 auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); shared_ptr uniquePtr; if (uniquePtrMap.find(sendCountSize) == uniquePtrMap.end()) { - uniquePtr.reset(new sharded_dedup(groupMethod, batch->batchSize, sendCountSize)); + uniquePtr.reset(new sharded_dedup(groupMethod, static_cast(batch->batchSize), sendCountSize)); uniquePtrMap.insert(std::make_pair(sendCountSize, uniquePtr)); } unique = uniquePtrMap[sendCountSize]; @@ -287,7 +287,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, shared_ptr UniqueInfo uniqueInfo; ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); - + sw.reset(); // 特征准入&淘汰 if (isWithFAAE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, @@ -712,7 +712,8 @@ tuple, vector, vector> auto hot = hotMap.find(key); if (hot != hotMap.end()) { // is hot key if (hot->second == -1) { // is new hot key in this batch - hotPos[hotCount] = splitKeys[devId].size() - 1; // pos in lookup vec (need add ss) for hot-gather + // pos in lookup vec (need add ss) for hot-gather + hotPos[hotCount] = static_cast(splitKeys[devId].size()) - 1; hotPosDev[hotCount] = devId; // which dev, for get ss hot->second = hotCount; restore[i] = hotCount++; // get pos of hot emb @@ -720,7 +721,8 @@ tuple, vector, vector> restore[i] = hot->second; } } else { // is not hot key - restore[i] = splitKeys[devId].size() + hotOffset - 1; // restore记录去重后key在桶内偏移量(用于计算恢复向量) + // restore记录去重后key在桶内偏移量(用于计算恢复向量) + restore[i] = static_cast(splitKeys[devId].size() + (hotOffset - 1)); } uKey[key] = restore[i]; } @@ -741,7 +743,7 @@ void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& h } } else { for (auto& splitKey: splitKeys) { - splitKeysSize.push_back(splitKey.size()); + splitKeysSize.push_back(static_cast(splitKey.size())); } } auto cs = Count2Start(splitKeysSize); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index e66dc5a3..fe628307 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -81,7 +81,7 @@ namespace MxRec { constexpr int MGMT_THREAD_BIND = 48; constexpr int UNIQUE_MAX_BUCKET_WIDTH = 6; constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; - constexpr float HOT_EMB_CACHE_PCT = 1. / 3; // hot emb cache percent + constexpr float HOT_EMB_CACHE_PCT = static_cast(1. 
/ 3); // hot emb cache percent using emb_key_t = int64_t; using emb_name_t = std::string; diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index f789475c..ce0c4655 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -84,9 +84,9 @@ public: { return groupCount_; } - inline int GroupId(uint64_t val) + inline uint64_t GroupId(uint64_t val) { - return val & (groupCount_ - 1); + return val & static_cast(groupCount_ - 1); } void SetGroupCount(int count) { @@ -263,7 +263,7 @@ public: std::lock_guard lg(overflowMutex_); if (newBucketCountPowerOf2 > 0 && newBucketCountPowerOf2 != (uint64_t)bucketCount_) { free(table_); - bucketCount_ = newBucketCountPowerOf2; + bucketCount_ = static_cast(newBucketCountPowerOf2); bucketCountMask_ = bucketCount_ - 1; table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_)); @@ -356,7 +356,7 @@ public: } bucket->replace_base = replace_offset; for (int j = 0; j < bucket->count; ++j) { - idCount[total] = bucket->idCount[j]; + idCount[total] = static_cast(bucket->idCount[j]); output[total++] = bucket->data[j]; } replace_offset += bucket->count; @@ -406,10 +406,10 @@ public: } bucket->replace_base = replace_offset; for (int j = 0; j < bucket->count; ++j) { - idCount[total] = bucket->idCount[j]; + idCount[total] = static_cast(bucket->idCount[j]); output[total++] = bucket->data[j]; - handleHotKey(bucket->data[j], hotMap, hotPosMap, hotCount); - keyCountMap[bucket->data[j]] = bucket->idCount[j]; + handleHotKey(static_cast(bucket->data[j]), hotMap, hotPosMap, hotCount); + keyCountMap[bucket->data[j]] = static_cast(bucket->idCount[j]); } replace_offset += bucket->count; } @@ -419,7 +419,7 @@ public: idCount[total] = idCountOverflow_[it->first]; keyCountMap[it->first] = idCountOverflow_[it->first]; output[total++] = it->first; - handleHotKey(it->first, hotMap, hotPosMap, hotCount); + handleHotKey(static_cast(it->first), hotMap, hotPosMap, hotCount); it->second = replace_offset++; ++it; ++totalOverflow; @@ -428,7 +428,7 @@ public: // set total overflow count stats_.totalUniques = total - priorTotal; stats_.totalOverflowUniques = totalOverflow; - return total - priorTotal; + return static_cast(total - priorTotal); } std::vector Replacement(const std::vector &input, std::vector *unique = nullptr, @@ -498,8 +498,9 @@ public: { return 1; } - inline int GroupId(uint64_t val) + inline int32_t GroupId(uint64_t val) { + spdlog::info("val:{} is", val); return 0; } }; @@ -530,9 +531,9 @@ public: ~ShardedDedup() {} - const int NumOfGroupsInEachShard() + int NumOfGroupsInEachShard() const { - return groupMethod_.GroupCount(); + return static_cast(groupMethod_.GroupCount()); } /* * @@ -591,7 +592,7 @@ public: size_t inputSize = size; - uint32_t threadNum = (inputSize + kMinimalWorkloadPerWorker - 1) / kMinimalWorkloadPerWorker; + uint32_t threadNum = static_cast(inputSize + kMinimalWorkloadPerWorker - 1) / kMinimalWorkloadPerWorker; threadNum = std::min(maxThreadCount, std::max(threadNum, minThreadCount)); size_t partSize = (inputSize + threadNum - 1) / threadNum; @@ -617,14 +618,14 @@ public: std::vector baseVector; // Collect Unique and base vectors - uint64_t base = 0; - uint64_t total = 0; + uint32_t base = 0; + uint32_t total = 0; int hotCount = 0; map hotPosMap; for (int j = 0; j < groupMethod_.GroupCount(); ++j) { - uint64_t inGroupTotal = 0; + uint32_t inGroupTotal = 0; if (useHot) { inGroupTotal = dedupShards_[j]->UniqueRawForHot(uniqueVector, total, idCount, hotMap, hotPosMap, hotCount, @@ -684,7 +685,7 @@ 
public: hotPosMap]() -> TaskReturnType { for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { auto val = isInt64 ? ((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; - auto group = groupMethod_.GroupId(val); + int32_t group = static_cast(groupMethod_.GroupId(val)); uint32_t fillOffset = GetFillOffset(useStatic, baseVector, totalUniqueSize, val, group); ComputeRestore(useHot, offset, hotMap, hotPos, hotPosMap, ptr, val, fillOffset); } @@ -726,10 +727,9 @@ public: int64_t val, int32_t group) { if (!useStatic) { - return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]; + return static_cast(dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]); } else { - return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - - totalUniqueSize[group]; + return static_cast(dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - totalUniqueSize[group]); } } @@ -742,7 +742,7 @@ public: for (int i = 0; i < groupCount; i++) { if (i > 0) { - index += uniqueSizeVector[i - 1]; + index += static_cast(uniqueSizeVector[i - 1]); } if (useStatic) { @@ -768,7 +768,7 @@ public: } } - int fillLen = send_cnt_ - uniqueSizeVector[i]; + long int fillLen = send_cnt_ - uniqueSizeVector[i]; if (useStatic) { for (int j = 0; j < fillLen; j++) { uniqueIds[start + uniqueSizeVector[i] + j] = -1; @@ -776,7 +776,7 @@ public: } } - uniqueSize[i] = uniqueSizeVector[i]; + uniqueSize[i] = static_cast(uniqueSizeVector[i]); } } @@ -793,4 +793,4 @@ private: std::vector> dedupShards_; int32_t send_cnt_; }; -#endif \ No newline at end of file +#endif diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 9b7c30d4..5573b75e 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -66,7 +66,7 @@ public: void Compute(OpKernelContextPtr context) override { - spdlog::info("clear channel {}", channelId); + spdlog::info("clear channel {}, context {}", channelId, context->step_id()); batchIdsInfo.at(channelId) = 0; } @@ -780,7 +780,7 @@ public: const Tensor& inputTensor = context->input(TENSOR_INDEX_0); auto input = inputTensor.flat(); - int32_t batchId = input(0); + int32_t batchId = static_cast(input(0)); spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", duration_cast((sw).elapsed()), @@ -802,6 +802,7 @@ public: void Compute(OpKernelContextPtr context) override { + spdlog::info("context {}", context->step_id()); std::cout << " Cust opp not installed!!" 
<< std::endl; } diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt index 5bf95c14..28ec5210 100644 --- a/src/pybind/CMakeLists.txt +++ b/src/pybind/CMakeLists.txt @@ -1,8 +1,8 @@ -cmake_minimum_required(VERSION 3.12) +cmake_minimum_required(VERSION 3.20) pybind11_add_module(mxrec_pybind module_main.cpp) -set_target_properties(mxrec_pybind PROPERTIES LINK_FLAGS "-Wl,-rpath,/") +set_target_properties(mxrec_pybind PROPERTIES LINK_FLAGS "-Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") target_include_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/include) target_link_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/lib64/driver) target_link_libraries(mxrec_pybind PUBLIC ASC dcmi) -install(TARGETS mxrec_pybind LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) \ No newline at end of file +install(TARGETS mxrec_pybind LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index d59a58d2..9e1d6461 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -32,7 +32,7 @@ void GetNormalInitializerInfo(pybind11::module_& m); int GetUBHotSize(int devID) { - return static_cast(MxRec::GetUBSize(devID)/ sizeof(float) * HOT_EMB_CACHE_PCT) ; + return static_cast(static_cast(MxRec::GetUBSize(devID)) / sizeof(float) * HOT_EMB_CACHE_PCT) ; } uint32_t GetLogicID(uint32_t phyid) @@ -173,4 +173,4 @@ void GetThresholdValue(pybind11::module_& m) .def_readwrite("tensor_name", &ThresholdValue::tensorName) .def_readwrite("count_threshold", &ThresholdValue::countThreshold) .def_readwrite("time_threshold", &ThresholdValue::timeThreshold); -} \ No newline at end of file +} -- Gitee From 64a867ca14bb3d6fb3168706df7214beb1ab0286 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 8 Jun 2023 09:46:11 +0800 Subject: [PATCH 118/551] Match-id-fd4fc650176ba2b1b0ac18619da3e8a384ea804c --- build/build.sh | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/build/build.sh b/build/build.sh index 9d81f52c..4cdfaffa 100644 --- a/build/build.sh +++ b/build/build.sh @@ -16,15 +16,17 @@ then pip3 install virtualenv --force-reinstall virtualenv -p "$(which python3.7)" tf2_env source tf2_env/bin/activate - [ ! -f tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl - pip3 install tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl --no-deps + tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" + [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ + pip3 install "${tf265}" --no-deps pip3 install setuptools==49.2.1 tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env virtualenv -p "$(which python3.7)" tf1_env source tf1_env/bin/activate - [ ! -f tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl - pip3 install tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl --no-deps + tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" + [ ! 
-f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ + pip3 install "${tf115}" --no-deps pip3 install setuptools==49.2.1 tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env @@ -35,15 +37,17 @@ then pip3 install virtualenv --force-reinstall virtualenv -p "$(which python3.7)" tf2_env source tf2_env/bin/activate - [ ! -f tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl ] && wget --no-check-certificate https://cmc-hgh-artifactory.cmc.tools.huawei.com/artifactory/opensource_general/Tensorflow/2.6.5/package/tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl - pip3 install tensorflow-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl --no-deps + tf265="tensorflow_cpu-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl" + [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ + pip3 install "${tf265}" --no-deps pip3 install setuptools==49.2.1 tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env virtualenv -p "$(which python3.7)" tf1_env source tf1_env/bin/activate - [ ! -f tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl ] && wget --no-check-certificate https://cmc-szver-artifactory.cmc.tools.huawei.com/artifactory/cmc-software-release/MindX/mindx_img_tools/1.0.0/tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl - pip3 install tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl --no-deps + tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" + [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ + pip3 install "${tf115}" --no-deps pip3 install setuptools==49.2.1 tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env -- Gitee From fc4d91a884cc1d4e4ed1cd7b8cb9d71511d66778 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 8 Jun 2023 17:09:06 +0800 Subject: [PATCH 119/551] Match-id-d6506051fc2a397c7c90e3068dcf4b800e5a1232 --- mx_rec/util/initialize.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index eeaec25f..3b19d79d 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -45,6 +45,7 @@ class ConfigInitializer: self._optimizer_instance = None self._is_graph_modify_hook_running = False self._modify_graph = False + self._is_terminated = False if self._use_mpi: logging.debug(f"Using mpi to launch task.") @@ -166,6 +167,10 @@ class ConfigInitializer: ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) def terminate(self): + if self._is_terminated: + logging.warning("The initializer has already been released once, please do not release it again.") + return + if self._asc_manager is not None: self.del_asc_manager() @@ -173,6 +178,8 @@ class ConfigInitializer: self._mpi.Finalize() logging.debug("MPI has been destroyed.") + self._is_terminated = True + def insert_feature_spec(self, feature, is_training): self._feature_spec_dict[feature.name] = feature if feature.table_name not in self._table_name_to_feature_spec: -- Gitee From bcb820ca509a185b6e5de2d54cb3535462711d08 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 8 Jun 2023 19:09:31 +0800 Subject: [PATCH 120/551] Match-id-79410dc2f385ac81a8460f8763afb4b90374ad3c --- mx_rec/saver/patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 8c8bee45..89df2598 100644 --- 
a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -60,7 +60,8 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None, fid_version=0): - var_list = build_var_list(var_list) + if not defer_build: + var_list = build_var_list(var_list) self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] self._var_list = var_list -- Gitee From 9f17e739d81f0bc959bd5d888017a46b0b1c1dde Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 9 Jun 2023 09:31:54 +0800 Subject: [PATCH 121/551] Match-id-ccc851780596d7352f42eb5f2332913d2f339cf6 --- mx_rec/saver/patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 8c8bee45..89df2598 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -60,7 +60,8 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None, fid_version=0): - var_list = build_var_list(var_list) + if not defer_build: + var_list = build_var_list(var_list) self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] self._var_list = var_list -- Gitee From 94bdb128207271bd2601057b41c9f326899b7dc0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 9 Jun 2023 10:10:11 +0800 Subject: [PATCH 122/551] Match-id-df2c53c9e784228944a0d3a6663eb58b83d179b3 --- src/core/emb_hashmap/emb_hashmap.cpp | 258 +++++++++++-- src/core/emb_hashmap/emb_hashmap.h | 20 +- src/core/host_emb/host_emb.cpp | 18 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 32 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- src/core/key_process/key_process.cpp | 410 +++++++-------------- src/core/key_process/key_process.h | 65 ++-- src/core/utils/common.cpp | 3 + src/core/utils/common.h | 24 +- src/ops_tf/hybrid_dataset_ops.cpp | 356 +++++------------- src/tests/emb_mgmt/emb_mgmt_test.cpp | 9 +- src/tests/key_process/key_process_test.cpp | 25 +- 12 files changed, 556 insertions(+), 666 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index d76abfed..689e1849 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -37,16 +37,30 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } } -void EmbHashMap::Process(const string& embName, const vector& keys, size_t iBatch, +void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, vector& tmpDataOut) { +#ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) + auto& embHashMap = embHashMaps.at(embName); + embHashMap.devOffset2KeyOld.clear(); + embHashMap.oldSwap.clear(); + embHashMap.maxOffsetOld = embHashMap.maxOffset; + auto keepBatch = swapId - iBatch; - FindAndUpdateOffset(embName, keys, swapId, keepBatch); + bool findOffsetV2 = getenv("FIND_OFFSET_V2") != nullptr; + spdlog::debug("FindOffset, {}", findOffsetV2); + + if (findOffsetV2) { + FindAndUpdateOffset(embName, keys, swapId, keepBatch); + } else { + FindOffset(embName, keys, swapId, keepBatch); + } + spdlog::debug("FindOffset end"); + swapId++; EASY_BLOCK("hostHashMaps->tdt") - auto& embHashMap = embHashMaps.at(embName); auto lookUpVecSize = 
static_cast(embHashMap.lookUpVec.size()); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); @@ -67,30 +81,28 @@ void EmbHashMap::Process(const string& embName, const vector& keys, s } spdlog::trace("swapTensor, {}", embHashMap.swapPos); embHashMap.swapPos.clear(); - spdlog::info("current dev emb usage:{}-{}/[{}+{}]", embName, embHashMap.maxOffset, embHashMap.devVocabSize, + embHashMap.lookUpVec.clear(); + spdlog::info("current dev emb usage:{}/[{}+{}]", embHashMap.maxOffset, embHashMap.devVocabSize, embHashMap.hostVocabSize); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; EASY_END_BLOCK +#endif } /* * 从embHashMaps获取key对应的位置,并更新devOffset2Batch */ -void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector& keys, +#ifndef GTEST +void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, size_t keepBatchId) { EASY_FUNCTION() size_t keySize = keys.size(); auto& embHashMap = embHashMaps.at(embName); - embHashMap.lookUpVec.resize(keySize); - std::fill(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), INVALID_KEY_VALUE); - FindAndUpdateBatchId(keys, currentBatchId, keySize, embHashMap); - EASY_BLOCK("FindNewOffset") - vector> KeysAndOffset; - + const int devVocabSize = static_cast(embHashMap.devVocabSize); for (size_t i = 0; i < keySize; i++) { auto key = keys[i]; if (key == -1) { @@ -99,28 +111,17 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, const vector(currentBatchId); + } } - if (offset >= static_cast(embHashMap.devVocabSize)) { + if (offset >= devVocabSize) { embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); - KeysAndOffset.emplace_back(key, i); + offset = FindSwapPosV2(embName, key, offset, currentBatchId, keepBatchId); } } - EASY_END_BLOCK - EASY_BLOCK("FindPos") - size_t swapSize = KeysAndOffset.size(); - FindPos(embHashMap, static_cast(swapSize), keepBatchId); - EASY_END_BLOCK - EASY_BLOCK("ChangeInfo") -#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ - shared(swapSize, KeysAndOffset, embHashMap, currentBatchId) - for (size_t i = 0; i < swapSize; i++) { - auto[key, j] = KeysAndOffset[i]; - int pos = static_cast(embHashMap.swapPos[i]); - ChangeSwapInfo(embHashMap, key, embHashMap.missingKeysHostPos[i] + embHashMap.devVocabSize, - currentBatchId, pos); - embHashMap.lookUpVec[j] = pos; - } - EASY_END_BLOCK } void EmbHashMap::ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, @@ -130,6 +131,7 @@ void EmbHashMap::ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_ embHashMap.hostHashMap[key] = pos; auto& oldKey = embHashMap.devOffset2Key[pos]; if (oldKey != -1) { + embHashMap.oldSwap.emplace_back(oldKey, key); embHashMap.hostHashMap[oldKey] = hostOffset; } oldKey = key; @@ -169,25 +171,30 @@ int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashM return offset; } -void EmbHashMap::FindAndUpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, +void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const { EASY_FUNCTION() + bool findOffsetV3 = getenv("FIND_OFFSET_V3") != nullptr; for (size_t i = 0; i < keySize; i++) { int offset; - auto key = keys[i]; + auto& key = keys[i]; if (key == -1) { continue; } const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { // 
found + if (findOffsetV3) { + key = -1; + } offset = static_cast(iter->second); - embHashMap.lookUpVec[i] = offset; // convert to offset(current) - spdlog::trace("key will be used, {} , offset , {}", key, offset); + embHashMap.lookUpVec.emplace_back(offset); // convert to offset(current) + if (offset < static_cast(embHashMap.devVocabSize)) { embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); - embHashMap.devOffset2Key[offset] = key; } + } else { + embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); } } } @@ -211,9 +218,25 @@ void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId } } + auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { - return embHashMaps; + auto embHashMapsOld = embHashMaps; + for (auto& temp: embHashMapsOld) { + auto& embHashMap = temp.second; + for (auto& swapKeys: embHashMap.oldSwap) { + emb_key_t oldKey = swapKeys.first; + emb_key_t key = swapKeys.second; + int tempOffset = static_cast(embHashMap.hostHashMap[key]); + embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; + embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); + } + embHashMap.maxOffset = embHashMap.maxOffsetOld; + for (auto& Offset2Key: embHashMap.devOffset2KeyOld) { + embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; + } + } + return embHashMapsOld; } void EmbHashMap::LoadHashMap(emb_hash_mem_t& loadData) @@ -260,6 +283,7 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& if (offset < embHashMap.devVocabSize) { embHashMap.devOffset2Batch[offset] = -1; + embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]); embHashMap.devOffset2Key[offset] = -1; embHashMap.evictDevPos.emplace_back(offset); } else { @@ -271,3 +295,167 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); } + +// old version +/* + * 从embHashMaps获取key对应的位置,并更新devOffset2Batch + */ + +void EmbHashMap::FindOffset(const string& embName, const vector& keys, + size_t currentBatchId, size_t keepBatchId) +{ + EASY_FUNCTION() + size_t keySize = keys.size(); + auto& embHashMap = embHashMaps.at(embName); + UpdateBatchId(keys, currentBatchId, keySize, embHashMap); + for (size_t i = 0; i < keySize; i++) { + auto key = keys[i]; + if (key == -1) { + embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); + continue; + } + auto offset = FindOffsetHelper(key, embHashMap); + if (offset < embHashMap.devVocabSize) { + embHashMap.lookUpVec.emplace_back(offset); + embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset])); + embHashMap.devOffset2Key[offset] = key; + } else { + embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); + FindSwapPosOld(embName, key, offset, currentBatchId, keepBatchId); + } + } + if (currentBatchId == 0) { + spdlog::info("max offset {}", embHashMap.maxOffset); + } + spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); +} + + +size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap) + +{ + size_t offset; + const auto& iter = embHashMap.hostHashMap.find(key); + if (iter != embHashMap.hostHashMap.end()) { + offset = iter->second; + spdlog::trace("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset); + } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 + offset = embHashMap.evictDevPos.back(); + embHashMap.hostHashMap[key] = offset; + 
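// A compact, self-contained sketch of the allocation policy FindOffsetHelper
// implements around here: reuse evicted device (HBM) slots first, then evicted
// host (DDR) slots, and only then claim a fresh offset, failing once both
// vocabularies are exhausted. Field names follow EmbHashMapInfo; the
// SlotAllocator type itself is illustrative, not part of the patch.
#include <cstddef>
#include <stdexcept>
#include <vector>

struct SlotAllocator {
    std::vector<size_t> evictDevPos;   // recycled device-side offsets
    std::vector<size_t> evictPos;      // recycled host-side offsets
    size_t maxOffset = 0;              // next never-used offset
    size_t devVocabSize = 0;
    size_t hostVocabSize = 0;

    size_t Allocate()
    {
        if (!evictDevPos.empty()) {            // prefer recycled HBM slots
            size_t offset = evictDevPos.back();
            evictDevPos.pop_back();
            return offset;
        }
        if (!evictPos.empty()) {               // then recycled DDR slots
            size_t offset = evictPos.back();
            evictPos.pop_back();
            return offset;
        }
        if (maxOffset >= devVocabSize + hostVocabSize) {
            throw std::runtime_error("hostVocabSize too small");
        }
        return maxOffset++;                    // finally, a fresh slot
    }
};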
spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, embHashMap.evictDevPos.size()); + embHashMap.evictDevPos.pop_back(); + } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 + offset = embHashMap.evictPos.back(); + embHashMap.hostHashMap[key] = offset; + spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, embHashMap.evictPos.size()); + embHashMap.evictPos.pop_back(); + } else { + embHashMap.hostHashMap[key] = embHashMap.maxOffset; + offset = embHashMap.maxOffset; + embHashMap.maxOffset++; + if (embHashMap.maxOffset == embHashMap.devVocabSize) { + spdlog::info("start using host vocab!"); + } + if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { + spdlog::error("hostVocabSize too small! dev:{} host:{}", embHashMap.devVocabSize, + embHashMap.hostVocabSize); + throw runtime_error("hostVocabSize too small"); + } + } + return offset; +} + +void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, + EmbHashMapInfo& embHashMap) const +{ + for (size_t i = 0; i < keySize; i++) { + size_t offset; + auto key = keys[i]; + if (key == -1) { + continue; + } + const auto& iter = embHashMap.hostHashMap.find(key); + if (iter != embHashMap.hostHashMap.end()) { + offset = iter->second; + + spdlog::trace("key will be used, {} , offset , {}", key, offset); + if (offset < embHashMap.devVocabSize) { + embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); + } + } + } +} + +/* + * 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys + */ +int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, + size_t keepBatchId) +{ + bool notFind = true; + auto& embHashMap = embHashMaps.at(embName); + int newDevOffset; + while (notFind) { + if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { + embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); + embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); + newDevOffset = static_cast(embHashMap.currentUpdatePos); + embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; + embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos, + embHashMap.devOffset2Key[embHashMap.currentUpdatePos]); + auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos]; + embHashMap.oldSwap.emplace_back(oldKey, key); + embHashMap.hostHashMap[oldKey] = hostOffset; + oldKey = key; + notFind = false; + } + embHashMap.currentUpdatePos++; + embHashMap.freeSize--; + if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { + embHashMap.currentUpdatePos = 0; + } + if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { + spdlog::error("devVocabSize is too small"); + throw runtime_error("devVocabSize is too small"); + } + } + return newDevOffset; +} + +/* + * 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys + */ +bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, + size_t keepBatchId) +{ + bool notFind = true; + auto& embHashMap = embHashMaps.at(embName); + while (notFind) { + if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { + embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); + embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); + 
embHashMap.lookUpVec.emplace_back(embHashMap.currentUpdatePos); + embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; + embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos, + embHashMap.devOffset2Key[embHashMap.currentUpdatePos]); + auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos]; + embHashMap.oldSwap.emplace_back(oldKey, key); + embHashMap.hostHashMap[oldKey] = hostOffset; + oldKey = key; + notFind = false; + } + embHashMap.currentUpdatePos++; + embHashMap.freeSize--; + if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { + embHashMap.currentUpdatePos = 0; + } + if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { + spdlog::error("devVocabSize is too small"); + throw runtime_error("devVocabSize is too small"); + } + } + return true; +} +#endif diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index a8fcf977..b99e7777 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -23,10 +23,10 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); - void Process(const string& embName, const std::vector& keys, size_t iBatch, + void Process(const string& embName, std::vector& keys, size_t iBatch, vector& tmpDataOut); - void FindAndUpdateOffset(const string& embName, const vector& keys, size_t currentBatchId, + void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, size_t keepBatchId); void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, @@ -52,11 +52,25 @@ namespace MxRec { absl::flat_hash_map embHashMaps; + void FindOffset(const string& embName, const vector& keys, + size_t currentBatchId, size_t keepBatchId); + + size_t FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap); + + void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, + EmbHashMapInfo& embHashMap) const; + + int FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, + size_t keepBatchId); + + bool FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, + size_t keepBatchId); + private: RankInfo rankInfo; int swapId { 0 }; - void FindAndUpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, + void FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 6f74b8d8..374e1fb6 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -35,6 +35,7 @@ bool HostEmb::Initialize(const vector& embInfos, int seed) void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, int embeddingSize, vector> &embData) { +#ifndef GTEST spdlog::info(HOSTEMB + "GenerateEmbData Start, seed:{}", seed); embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); @@ -72,13 +73,15 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in initializer->GenerateData(embData.at(i).data(), embeddingSize); } } - spdlog::info(HOSTEMB + "GenerateEmbData End, seed:{}", seed); +#endif } void HostEmb::LoadEmb(emb_mem_t& loadData) { +#ifndef GTEST hostEmbs = std::move(loadData); +#endif } void HostEmb::Join() @@ -96,6 +99,7 @@ void HostEmb::Join() * 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 * 
missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 */ +#ifndef GTEST void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) { EASY_FUNCTION(profiler::colors::Purple) @@ -120,9 +124,8 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, auto& dst = embData[missingKeysHostPos[i]]; #pragma omp simd for (int j = 0; j < embeddingSize; j++) { - dst[j] = tensorPtr[j]; + dst[j] = tensorPtr[j + embeddingSize * i]; } - tensorPtr = tensorPtr + embeddingSize; } spdlog::info(HOSTEMB + "update emb end"); EASY_END_BLOCK @@ -130,7 +133,6 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) { -#ifndef GTEST EASY_FUNCTION(profiler::colors::Purple) procThreads.emplace_back(make_unique( [&, missingKeysHostPos, channelId, embName] { @@ -156,7 +158,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI auto& dst = embData[missingKeysHostPos[j]]; #pragma omp simd for (int k = 0; k < embeddingSize; k++) { - dst[k] = ptr[k]; + dst[k] = ptr[k + embeddingSize * j]; } } if (acltdtDestroyDataset(aclDataset) != ACL_ERROR_NONE) { @@ -164,7 +166,6 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI } spdlog::info(HOSTEMB + "update emb end"); })); -#endif } /* @@ -193,6 +194,7 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& spdlog::info("GetH2DEmb end, missingKeys count:{}", missingKeysHostPos.size()); } + auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* { return &hostEmbs; @@ -235,14 +237,16 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve } } } +#endif /* * 利用initializer初始化emb淘汰的位置 */ void HostEmb::EvictInitEmb(const string& embName, const vector& offset) { +#ifndef GTEST auto& hostEmb = GetEmb(embName); EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); - spdlog::info(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); +#endif } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index a6c6a072..15de3b73 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -16,7 +16,7 @@ using namespace MxRec; using namespace std; bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, - const vector& thresholdValues, bool ifLoad, int seed) + const vector& thresholdValues, int seed) { if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); @@ -41,7 +41,7 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& } preprocess = Singleton::GetInstance(); - preprocess->Initialize(rankInfo, embInfos, thresholdValues, ifLoad, seed); + preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); preprocess->Start(); return true; } @@ -80,7 +80,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hdTransfer = Singleton::GetInstance(); hdTransfer->Init(embInfos, rankInfo.deviceId); - bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, ifLoad, seed); + bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, seed); if (!rc) { return false; } @@ -111,6 +111,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, bool HybridMgmt::Save(const string savePath) { +#ifndef GTEST preprocess->LoadSaveLock(); CkptData saveData; @@ -136,12 +137,13 @@ 
bool HybridMgmt::Save(const string savePath) saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo); preprocess->LoadSaveUnlock(); - +#endif return true; } bool HybridMgmt::Load(const string& loadPath) { +#ifndef GTEST preprocess->LoadSaveLock(); spdlog::debug(MGMT + "Start host side load process"); @@ -189,7 +191,7 @@ bool HybridMgmt::Load(const string& loadPath) if (!mgmtRankInfo.useDataset && isLoad) { Start(); } - +#endif return true; } @@ -244,6 +246,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) void HybridMgmt::Start() { +#ifndef GTEST if (mgmtRankInfo.noDDR) { auto getInfoTask = [this]() { auto ret = GetInfoTask(); @@ -268,8 +271,10 @@ void HybridMgmt::Start() }; procThreads.emplace_back(std::make_unique(parseKeysTask)); } +#endif } +#ifndef GTEST bool HybridMgmt::TrainParseKeys() { do { @@ -463,6 +468,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) batchId++; return true; } +#endif bool HybridMgmt::EndBatch(int batchId, int channelId) const { @@ -471,6 +477,7 @@ bool HybridMgmt::EndBatch(int batchId, int channelId) const bool HybridMgmt::ParseKeys(int channelId, int& batchId) { +#ifndef GTEST spdlog::info(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; int start = batchId; @@ -497,9 +504,11 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) } EmbHDTransWrap(channelId, batchId - 1, start, iBatch); TIME_PRINT("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); +#endif return true; } +#ifndef GTEST bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut) { @@ -562,7 +571,12 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); if (!(skipUpdate && missingKeys.empty())) { - hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! + bool updateEmbV2 = getenv("UpdateEmb_V2") != nullptr; + if (updateEmbV2) { + hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! + } else { + hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! + } } // skip when skip update and empty missing keys hostHashMaps->ClearMissingKeys(embInfo.name); } @@ -578,12 +592,13 @@ void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embI auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0]; hdTransfer->Send(TransferChannel::H2D, {}, channelId, embInfo.name); } - +#endif /* * hook通过时间或者step数触发淘汰 */ void HybridMgmt::Evict() { +#ifndef GTEST auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { featAdmitNEvict.FeatureEvict(evictKeyMap); @@ -602,11 +617,13 @@ void HybridMgmt::Evict() EvictKeys(evict.first, evict.second); } } +#endif } // ddr模式淘汰->删除映射表、初始化host表、发送dev淘汰位置 void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { +#ifndef GTEST spdlog::debug(MGMT + "ddr mode, delete emb: [{}]! 
evict keySize:{}", embName, keys.size()); // 删除映射关系 if (keys.size() != 0) { @@ -644,4 +661,5 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) auto tmpData = Vec2TensorI32(evictDevOffset); hdTransfer->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); +#endif } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 98168d3f..f834b9f6 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -93,7 +93,7 @@ namespace MxRec { private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, - const vector& thresholdValues, bool ifLoad, int seed); + const vector& thresholdValues, int seed); void InitRankInfo(RankInfo& rankInfo, const vector& embInfos); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index bae493d5..ea2f6572 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -17,7 +17,7 @@ #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" #include "utils/common.h" -#include "utils/time_cost.h" + using namespace std; using namespace chrono; @@ -35,32 +35,9 @@ inline vector Count2Start(const vector& count) return start; } -KeyProcess::KeyProcess() -{ - // init class members with PerfConfig::keyProcessThreadNum - for (size_t i = 0; i < MAX_CHANNEL_NUM; ++i) { - comm[i].resize(PerfConfig::keyProcessThreadNum); - for (int j = 0; j < PerfConfig::keyProcessThreadNum; ++j) { - comm[i][j] = MPI_COMM_WORLD; - } - } - - for (size_t i = 0; i < MAX_CHANNEL_NUM; ++i) { - std::vector tmp(PerfConfig::keyProcessThreadNum); - loadSaveMut[i].swap(tmp); - } - std::vector tmp(PerfConfig::keyProcessThreadNum); - getInfoMut.swap(tmp); - - storage.resize(PerfConfig::keyProcessThreadNum); - lookupKeysList.resize(PerfConfig::keyProcessThreadNum); - infoList.resize(PerfConfig::keyProcessThreadNum); - all2AllList.resize(PerfConfig::keyProcessThreadNum); -} - int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues, - bool ifLoad, int seed) + int seed) { this->rankInfo = rInfo; if (rankInfo.useHot) { @@ -88,10 +65,6 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, embeddingTableMap[info.name].Init(info, rInfo, seed); spdlog::info(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); } - - if (rankInfo.rankId == 0 && !ifLoad) { - Key2OffsetInit(info.name); - } } spdlog::info(KEY_PROCESS "hot emb count info:{}", hotEmbTotCount); MPI_Group worldGroup; @@ -123,7 +96,7 @@ int KeyProcess::Start() // bind like: // 0 1 2 3 4 5 0 1 2 3 4 5 // | rank0 | | rank1 | - // each rank creates PerfConfig::keyProcessThreadNum threads, each thread process one batchdata + // each rank creates KEY_PROCESS_THREAD threads, each thread process one batchdata spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { #ifndef GTEST @@ -135,9 +108,20 @@ int KeyProcess::Start() #endif KeyProcessTask(channel, id); }; // for clean code + int threadNum; for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { - for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { - procThreads.emplace_back(std::make_unique(fn, channel, id)); + const char* threadNumEnv = getenv("THREAD_NUM"); + if (threadNumEnv != nullptr) { + threadNum = static_cast(*threadNumEnv) - static_cast('0'); + if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { + throw runtime_error(fmt::format("{} is not valid", 
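// As written above, the THREAD_NUM parsing reads only the first character of
// the variable (first char minus '0'), so a value such as "12" is taken as 1.
// A hedged sketch of a stricter alternative based on strtol; ParseThreadNum
// and its bounds are illustrative, not what the patch ships:
#include <cstdlib>
#include <stdexcept>
#include <string>

int ParseThreadNum(const char* env, int fallback, int maxThreads)
{
    if (env == nullptr) {
        return fallback;
    }
    char* end = nullptr;
    long v = std::strtol(env, &end, 10);
    // reject trailing junk and out-of-range values instead of silently
    // truncating the input to its first digit
    if (end == env || *end != '\0' || v < 0 || v > maxThreads) {
        throw std::runtime_error(std::string(env) + " is not valid");
    }
    return static_cast<int>(v);
}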
threadNum)); + } + } else { + threadNum = KEY_PROCESS_THREAD; + } + for (int id = 0; id < threadNum; ++id) { + procThreads.emplace_back( + std::make_unique(fn, channel, id)); // use lambda expression initialize thread } } return 0; @@ -192,8 +176,8 @@ void KeyProcess::Destroy() { isRunning = false; spdlog::info(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); - for (auto& t: procThreads) { - t->join(); + for (auto& i: procThreads) { + i->join(); } procThreads.clear(); spdlog::info(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); @@ -202,7 +186,7 @@ void KeyProcess::Destroy() void KeyProcess::LoadSaveLock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { + for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { loadSaveMut[channelId][threadId].lock(); } } @@ -211,63 +195,31 @@ void KeyProcess::LoadSaveLock() void KeyProcess::LoadSaveUnlock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { + for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { loadSaveMut[channelId][threadId].unlock(); } } } -void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0, KEY_PROCESS_THREAD-1] +void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] { unique_ptr batch; - - GroupMethod groupMethod; - groupMethod.SetGroupCount(rankInfo.rankSize); - - shared_ptr unique; - map> uniquePtrMap; - spdlog::stopwatch sw; try { while (true) { - TimeCost getAndProcesTC; - TimeCost getBatchTC; batch = GetBatchData(channel, id); // get batch data from SingletonQueue - TIME_PRINT("GetBatchData TimeCost(ms):{}", getBatchTC.ElapsedMS()); - if (batch == nullptr) { - spdlog::info(KEY_PROCESS "batch is nullptr"); break; } - auto getBatchTime = duration_cast((sw).elapsed()); + auto getBatchTime = Format2Ms(sw); sw.reset(); - auto sendCountSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); - shared_ptr uniquePtr; - if (uniquePtrMap.find(sendCountSize) == uniquePtrMap.end()) { - uniquePtr.reset(new sharded_dedup(groupMethod, static_cast(batch->batchSize), sendCountSize)); - uniquePtrMap.insert(std::make_pair(sendCountSize, uniquePtr)); - } - unique = uniquePtrMap[sendCountSize]; - - if (unique != nullptr) { - unique->StartNewRound(); - } - - auto batchQueue = - SingletonQueue::getInstances(id + PerfConfig::keyProcessThreadNum * batch->channel); - - if (!KeyProcessTaskHelper(batch, unique, channel, id, sw)) { - free(batch->tensorAddr); - batchQueue->PutDirty(move(batch)); + if (!KeyProcessTaskHelper(batch, channel, id)) { break; } - TIME_PRINT("getAndProcesTC TimeCost(ms):{}", getAndProcesTC.ElapsedMS()); spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", - duration_cast( - (sw).elapsed()), getBatchTime, batch->name, batch->channel, batch->batchId); - free(batch->tensorAddr); - batch->tensorAddr = nullptr; + Format2Ms(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } } catch (const EndRunError &e) { @@ -276,60 +228,74 @@ void KeyProcess::KeyProcessTask(const int channel, const int id) // thread id [0 spdlog::info(KEY_PROCESS "KeyProcessTask exit. 
rank:{} thread:{}, channel:{}", rankInfo.rankId, id, channel); } -bool KeyProcess::KeyProcessTaskHelper(unique_ptr &batch, shared_ptr unique, - int channel, int id, spdlog::stopwatch &sw) +void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, + vector & restore, vector & hotPos, + vector >& keyCount) { - // tuple for keyRec restore hotPos scAll countRecv - std::tuple, vector, vector, vector, vector> rets; - isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && - FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; - TimeCost tc; - UniqueInfo uniqueInfo; - ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); - TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); - sw.reset(); + if (m_featureAdmitAndEvict.GetFunctionSwitch() && + FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { + tie(splitKeys, restore, keyCount) = HashSplit_withFAAE(batch); // 按存储dev id切分并去重 + } else { + if (rankInfo.useHot) { + tie(splitKeys, restore, hotPos) = HotHashSplit(batch); // 按存储dev id切分并去重 + } else { + tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 + } + } +} + +bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int id) +{ + vector splitKeys; + vector restore; + vector hotPos; + vector> keyCount; + HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); + + auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, id, splitKeys); + vector countRecv; + if (m_featureAdmitAndEvict.GetFunctionSwitch() && + FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { + countRecv = GetCountRecv(batch, id, keyCount, scAll, ss); + } + BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size())); // 特征准入&淘汰 - if (isWithFAAE && - (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, - uniqueInfo.all2AllInfo.countRecv) - == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + if (m_featureAdmitAndEvict.GetFunctionSwitch() && + FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE && + (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys, + countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", - rankInfo.rankId, id, channel); + rankInfo.rankId, id, channel); return false; } - int batchListId = batch->batchId % PerfConfig::keyProcessThreadNum; // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) if (rankInfo.noDDR) { - TimeCost key2OffsetTc; - Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv); - TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS()); + Key2Offset(batch->name, lookupKeys); } if (!rankInfo.useStatic) { // Static all2all,need send count auto embName = batch->name; if (batch->modifyGraph) { embName = batch->channelName; } - SendA2A(uniqueInfo.all2AllInfo.scAll, embName, batch->channel, batch->batchId); + SendA2A(scAll, embName, batch->channel, batch->batchId); } auto tensors = make_unique>(); - tensors->push_back(Vec2TensorI32(uniqueInfo.restore)); + tensors->push_back(Vec2TensorI32(restore)); if (rankInfo.useHot) { - uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); - tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); + hotPos.resize(hotEmbTotCount[batch->name], 0); + tensors->push_back(Vec2TensorI32(hotPos)); } if 
(rankInfo.noDDR) { if (rankInfo.useDynamicExpansion) { - tensors->push_back(Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv)); + tensors->push_back(Vec2TensorI64(lookupKeys)); } else { - tensors->push_back(Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); + tensors->push_back(Vec2TensorI32(lookupKeys)); } } - TimeCost pushTensorTc; - PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv, batchListId); - TIME_PRINT("pushTensorToListTC TimeCost(ms):{}", pushTensorTc.ElapsedMS()); + PushResult(batch, move(tensors), lookupKeys); return true; } @@ -338,7 +304,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, { if (rankInfo.useStatic) { for (auto& cnt: keyCount) { - cnt.resize(embInfos[batch->name].sendCount, 0); + cnt.resize(GetSendCount(batch->name, batch->channelName, batch->modifyGraph), 0); } } vector countSend; @@ -362,19 +328,19 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, } void KeyProcess::PushResult(unique_ptr& batch, unique_ptr> tensors, - keys_t& lookupKeys, int id) + keys_t& lookupKeys) { - std::unique_lock lockGuard(getInfoMut[id]); - storage[id].push_front(move(tensors)); + std::unique_lock lockGuard(mut); + storage.push_front(move(tensors)); if (batch->modifyGraph) { - infoList[id][batch->channelName][batch->channel].push( - make_tuple(batch->batchId, batch->channelName, storage[id].begin())); + infoList[batch->channelName][batch->channel].push( + make_tuple(batch->batchId, batch->channelName, storage.begin())); } else { - infoList[id][batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, storage[id].begin())); + infoList[batch->name][batch->channel].push( + make_tuple(batch->batchId, batch->name, storage.begin())); } if (!rankInfo.noDDR) { - lookupKeysList[id][batch->name][batch->channel].push( + lookupKeysList[batch->name][batch->channel].push( make_tuple(batch->batchId, batch->name, move(lookupKeys))); } lockGuard.unlock(); @@ -382,14 +348,15 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 - * commID为线程标识[0, PerfConfig::keyProcessThreadNum-1],不同线程、训练或推理数据用不同的共享队列通信 + * commID为线程标识[0, KEY_PROCESS_THREAD-1],不同线程、训练或推理数据用不同的共享队列通信 */ unique_ptr KeyProcess::GetBatchData(int channel, int commId) { EASY_FUNCTION() unique_ptr batch = nullptr; - // train data, queue id = thread id [0, PerfConfig::keyProcessThreadNum-1] - auto batchQueue = SingletonQueue::getInstances(commId + PerfConfig::keyProcessThreadNum * channel); + + // train data, queue id = thread id [0, KEY_PROCESS_THREAD-1] + auto batchQueue = SingletonQueue::getInstances(commId + KEY_PROCESS_THREAD * channel); EASY_BLOCK("get samples") EASY_VALUE("run on CPU", sched_getcpu()) spdlog::stopwatch sw; @@ -403,7 +370,7 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) if (duration_cast(sw.elapsed()).count() > GET_BATCH_TIMEOUT) { if (commId == 0) { spdlog::warn(KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. " - "channel[{}] commId[{}]", channel, commId); + "channel[{}] commId[{}]", channel, commId); } this_thread::sleep_for(seconds(1)); sw.reset(); @@ -416,123 +383,19 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) } } EASY_END_BLOCK - spdlog::info(KEY_PROCESS "GetBatchData get batchId:{}, batchSize:{}, batch.channel:{}, batch.channelName:{}, " - "name:{}, channel:{}, commId:{}, ", - batch->batchId, batch->batchSize, batch->channel, batch->channelName, batch->name, channel, commId); + spdlog::debug(KEY_PROCESS "rank {} thread {} get batch {}[{}]:{} done. 
bs:{} sample:[{}]", + rankInfo.rankId, commId, batch->name, batch->channel, batch->batchId, batch->Size(), + batch->UnParse()); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (batch->batchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE } else if (batch->batchId == PROFILING_END_BATCH_ID) { - EASY_PROFILER_ENABLE ::profiler::dumpBlocksToFile(fmt::format("/home/MX_REC-profile-{}.prof", rankInfo.rankId).c_str()); } #endif return batch; } -size_t KeyProcess::GetKeySize(const unique_ptr &batch) -{ - size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; - if (batch->modifyGraph) { - size = rankInfo.rankSize * embInfos[batch->name].sendCountMap[batch->channelName]; - } - if (!rankInfo.useStatic) { - size = batch->batchSize; - } - return size; -} - -void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, shared_ptr unique, - int id, UniqueInfo& uniqueInfoOut) -{ - EASY_FUNCTION(profiler::colors::Purple) - EASY_VALUE("batchId", batch->batchId) - - EASY_BLOCK("ock-unique") - - TimeCost unique_tc; - - SimpleThreadPool pool_; - KeySendInfo keySendInfo; - size_t size = GetKeySize(batch); - keySendInfo.keySend.resize(size); - vector splitSize(rankInfo.rankSize); - vector uniqueVector(batch->batchSize); - uniqueInfoOut.restore.resize(batch->batchSize); - vector idCount(batch->batchSize); - keySendInfo.keyCount.resize(size); - std::shared_lock lock(g_smut); - auto hotMap = hotKey[batch->name]; - lock.unlock(); - int hotOffset = 0; - - if (rankInfo.useHot) { - uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); - hotOffset = hotEmbTotCount[batch->name]; - } - absl::flat_hash_map keyCountMap; - - UniqueData uniqueData = {batch->tensorAddr, batch->batchSize, uniqueInfoOut.restore.data(), uniqueVector.data(), - splitSize.data(), keySendInfo.keySend.data(), idCount.data(), keySendInfo.keyCount.data()}; - UniqueFlag uniqueFlag = {batch->isInt64, rankInfo.useStatic, rankInfo.useHot}; - UniqueForHot uniqueForHot = {hotOffset, uniqueInfoOut.hotPos.data(), hotMap, keyCountMap}; - UniqueThreadNum uniqueThreadNum = {MIN_UNIQUE_THREAD_NUM, PerfConfig::maxUniqueThreadNum}; - - unique->Compute(&pool_, uniqueData, uniqueFlag, uniqueForHot, uniqueThreadNum); - EASY_END_BLOCK - TIME_PRINT("UniqueCompute TimeCost(ms):{}", unique_tc.ElapsedMS()); - - if (rankInfo.useHot) { - UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); - } - - vector sc; // send count - if (rankInfo.useStatic) { - sc.resize(rankInfo.rankSize, GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); - } else { - sc.resize(rankInfo.rankSize); - for (int i = 0;i < rankInfo.rankSize; i++) { - sc[i] = splitSize[i]; - } - } - All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); - - spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " - "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->batchSize, - batch->channel, batch->channelName, batch->name, uniqueInfoOut.restore.size(), - keySendInfo.keyCount.size()); -} - -void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, - All2AllInfo& all2AllInfoOut) - -{ - TimeCost get_sc_all; - GetScAll(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) - TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS()); - - TimeCost all2allTC; - auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 - vector rc(rankInfo.rankSize); // receive count - 
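// A self-contained sketch of the MPI_Alltoallv bookkeeping used here: the
// all-gathered matrix scAll holds, in row r, what rank r sends to every rank,
// so this rank's receive counts are column rankId, and send/receive
// displacements are the exclusive prefix sums produced by Count2Start. Plain
// ints only, no MPI calls; function names are illustrative.
#include <vector>

std::vector<int> PrefixStarts(const std::vector<int>& count) // mirrors Count2Start
{
    std::vector<int> start(count.size(), 0);
    for (size_t i = 1; i < count.size(); ++i) {
        start[i] = start[i - 1] + count[i - 1];
    }
    return start;
}

std::vector<int> ReceiveCounts(const std::vector<int>& scAll, int rankSize, int rankId)
{
    std::vector<int> rc(rankSize);
    for (int i = 0; i < rankSize; ++i) {
        rc[i] = scAll[i * rankSize + rankId]; // column rankId of the send matrix
    }
    return rc;
}
// The receive buffer is then sized rs.back() + rc.back(), as in ProcessSplitKeys.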
for (int i = 0; i < rankInfo.rankSize; ++i) { - // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 - rc[i] = all2AllInfoOut.scAll.at(i * rankInfo.rankSize + rankInfo.rankId); - } - auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 - all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); - EASY_BLOCK("all2all") - MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfoOut.keyRecv.data(), - rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); - - all2AllInfoOut.countRecv.resize(rs.back() + rc.back()); - if (isWithFAAE) { - MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(), - rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); - } - TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); - EASY_END_BLOCK -} - auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector> { @@ -543,13 +406,13 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 if (rankInfo.useStatic) { // maybe move after all2all for (auto& i: splitKeys) { - if (static_cast(i.size()) > embInfos[batch->name].sendCount) { + if (static_cast(i.size()) > GetSendCount(batch->name, batch->channelName, batch->modifyGraph)) { spdlog::error("{}[{}]:{} overflow! set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()); throw runtime_error(fmt::format("{}[{}]:{} overflow! set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()).c_str()); } - i.resize(embInfos[batch->name].sendCount, -1); + i.resize(GetSendCount(batch->name, batch->channelName, batch->modifyGraph), -1); } } keys_t keySend; @@ -559,8 +422,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, keySend.insert(keySend.end(), i.begin(), i.end()); } keys_t keyRecv; - vector scAll; - GetScAll(sc, id, batch->channel, scAll); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 + auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc; // receive count for (int i = 0; i < rankInfo.rankSize; ++i) { @@ -569,15 +431,15 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, } auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 keyRecv.resize(rs.back() + rc.back()); - spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", - rankInfo.rankId, id, batch->batchId, batch->name); + spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, + batch->name); EASY_BLOCK("all2all") MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); EASY_END_BLOCK spdlog::trace(KEY_PROCESS "MPI_Alltoallv finish. 
rank {} thread {} batch {} {}", - rankInfo.rankId, id, batch->batchId, batch->name); + rankInfo.rankId, id, batch->batchId, batch->name); return { keyRecv, scAll, ss }; } @@ -592,7 +454,6 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 @@ -631,7 +492,6 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - assert(batchData != nullptr); vector splitKeys(rankInfo.rankSize); vector> keyCount(rankInfo.rankSize); // splitKeys在原始batch中对应的频次 vector restore(batch->Size()); @@ -728,18 +588,18 @@ tuple, vector, vector> } UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); - AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch->name); + AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch); return { splitKeys, restore, hotPos }; } void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, - const string& embName) const + const unique_ptr& batch) { vector splitKeysSize {}; if (rankInfo.useStatic) { for (size_t i = 0; i < splitKeys.size(); i++) { - splitKeysSize.push_back(embInfos.at(embName).sendCount); + splitKeysSize.push_back(GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); } } else { for (auto& splitKey: splitKeys) { @@ -778,10 +638,11 @@ void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, * 将本地(rank)batch要发送的key数据量进行Allgather通信,获取所有(不同rank相同thread id的)线程间的通信量矩阵 * scAll返回:所有线程间的通信量矩阵(按行平铺的一维向量) */ -void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const +vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel) const { EASY_FUNCTION() - scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); + vector scAll; + scAll.resize(rankInfo.rankSize * rankInfo.rankSize); EASY_BLOCK("barrier"); // 通信终止信号,同步退出,防止线程卡住 spdlog::stopwatch sw; @@ -791,17 +652,18 @@ void KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - spdlog::debug(KEY_PROCESS "barrier time:{}", duration_cast((sw).elapsed())); + spdlog::debug(KEY_PROCESS "barrier time:{}", Format2Ms(sw)); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, - scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - spdlog::debug("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, scAllOut); + scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + spdlog::debug("rank {} key scAll matrix:\n{}", rankInfo.rankId, scAll); + return scAll; } void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) { EASY_FUNCTION(profiler::colors::Blue600) - std::lock_guard lk(key2OffsetMut); // lock for PROCESS_THREAD + std::lock_guard lk(mut); // lock for PROCESS_THREAD auto& key2Offset = keyOffsetMap[embName]; auto& maxOffsetTmp = maxOffset[embName]; auto& evictPos = evictPosMap[embName]; @@ -829,30 +691,25 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) } else { // 新值 if (rankInfo.useDynamicExpansion) { +#ifndef GTEST auto addr = 
curEmbTable.GetEmbAddress(); key2Offset[key] = addr; key = addr; +#endif maxOffsetTmp++; } else { - key2Offset[key] = maxOffsetTmp; - key = maxOffsetTmp++; - } + key2Offset[key] = maxOffsetTmp; + key = maxOffsetTmp++; } } + } if (!rankInfo.useDynamicExpansion && maxOffsetTmp > embInfos[embName].devVocabSize) { spdlog::error("dev cache overflow {}>{}", maxOffsetTmp, embInfos[embName].devVocabSize); + throw std::runtime_error("dev cache overflow!"); } spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); } -void KeyProcess::Key2OffsetInit(const emb_name_t& embName) -{ - auto& key2Offset = keyOffsetMap[embName]; - auto& offset = maxOffset[embName]; - key2Offset[rankInfo.rankId] = offset; // 0 rank init feature id 0 to offset 0 - offset++; -} - /* * 构建恢复向量,以便从去重后的emb向量/key恢复回batch对应的emb向量 * 输入接收到emb块的偏移blockOffset,batch内每个key在块内的偏移restoreVec @@ -885,26 +742,25 @@ class WrongListTop : public std::exception { }; template -T KeyProcess::GetInfo(std::vector>& list, int batch, const string& embName, int channel) +T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, int channel) { - int batchListId = batch % PerfConfig::keyProcessThreadNum; - std::lock_guard lockGuard(getInfoMut[batchListId]); - if (list[batchListId][embName][channel].empty()) { + std::lock_guard lockGuard(mut); + if (list[embName][channel].empty()) { spdlog::trace("get info list is empty."); throw EmptyList(); } - auto topBatch = get(list[batchListId][embName][channel].top()); + auto topBatch = get(list[embName][channel].top()); if (topBatch < batch) { - spdlog::warn("wrong batch id, top:{} expect:{}, channel:{}, embName: {}, queue_size:{}, may not clear channel", - topBatch, batch, channel, embName, list[batchListId][embName][channel].size()); + spdlog::error("wrong batch id, top:{} getting:{}, channel:{}, may not clear channel", topBatch, + batch, channel); this_thread::sleep_for(1s); } if (topBatch != batch) { spdlog::trace("topBatch({}) is not equal batch({}).", topBatch, batch); throw WrongListTop(); } - auto t = list[batchListId][embName][channel].top(); - list[batchListId][embName][channel].pop(); + auto t = list[embName][channel].top(); + list[embName][channel].pop(); return move(t); } @@ -923,12 +779,10 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) auto ret = GetInfo(lookupKeysList, batch, embName, channel); return get(ret); } catch (EmptyList&) { - spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} no input, wait and retry", - embName, channel, batch); + spdlog::trace("getting info failed {}[{}]:{}", embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - spdlog::trace("GetLookupKeys GetInfo failed {}[{}]:{} wrong top", - embName, channel, batch); + spdlog::trace("getting info failed {}[{}]:{} wrong top", embName, channel, batch); this_thread::sleep_for(1ms); } } @@ -937,7 +791,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { spdlog::stopwatch sw; - std::vector>* list; + info_list_t* list; switch (type) { case ProcessedInfo::ALL2ALL: list = &all2AllList; @@ -960,23 +814,20 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa auto ret = GetInfo(*list, batch, embName, channel); auto it = get>>::iterator>(ret); auto uTensor = move(*it); - int batchListId = batch % PerfConfig::keyProcessThreadNum; - unique_lock lockGuard(getInfoMut[batchListId]); - 
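// A minimal sketch of the retrieval discipline GetInfo enforces: processed
// results sit in a min-heap ordered by batch id, and a consumer asking for
// batch b only pops when b is exactly at the top, otherwise it signals the
// caller to sleep briefly and retry (the EmptyList / WrongListTop exceptions
// above). Types here are illustrative stand-ins for the tensor payloads.
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

using Item = std::pair<int, std::string>; // (batchId, payload)

struct ByBatchId {
    bool operator()(const Item& a, const Item& b) const
    {
        return a.first > b.first; // smallest batch id on top
    }
};

std::optional<std::string> TryTake(std::priority_queue<Item, std::vector<Item>, ByBatchId>& q,
                                   int batch)
{
    if (q.empty() || q.top().first != batch) {
        return std::nullopt; // caller retries after a short sleep
    }
    std::string payload = q.top().second;
    q.pop();
    return payload;
}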
storage[batchListId].erase(it); + std::unique_lock lockGuard(mut); + storage.erase(it); return uTensor; } catch (EmptyList&) { - spdlog::trace("GetInfoVec GetInfo failed {}[{}]:{} type: {} no input and retry", - embName, channel, batch, type); + spdlog::trace("getting info failed {}[{}]:{}", embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - spdlog::trace("GetInfoVec GetInfo failed {}[{}]:{} type: {} wrong top", - embName, channel, batch, type); + spdlog::trace("getting info failed {}[{}]:{} wrong top", embName, channel, batch); this_thread::sleep_for(1ms); } } } -void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int channel, int batchId) +void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch) { // 数据放到队列里,在mgmt里面发送(检查发送数据量) auto tensors = make_unique>(); @@ -989,10 +840,9 @@ void KeyProcess::SendA2A(const vector& a2aInfo, const string& embName, int } tensors->emplace_back(move(tmpTensor)); - int batchListId = batchId % PerfConfig::keyProcessThreadNum; - std::unique_lock lockGuard(getInfoMut[batchListId]); - storage[batchListId].push_front(move(tensors)); - all2AllList[batchListId][embName][channel].push(make_tuple(batchId, embName, storage[batchListId].begin())); + std::unique_lock lockGuard(mut); + storage.push_front(move(tensors)); + all2AllList[embName][channel].push(make_tuple(batch, embName, storage.begin())); lockGuard.unlock(); } @@ -1017,7 +867,7 @@ void KeyProcess::EvictKeys(const string& embName, const vector& keys) void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector& keys) { EASY_FUNCTION(profiler::colors::Blue600) - std::lock_guard lk(key2OffsetMut); // lock for PROCESS_THREAD + std::lock_guard lk(mut); // lock for PROCESS_THREAD size_t keySize = keys.size(); auto& devHashMap = keyOffsetMap.at(embName); @@ -1027,7 +877,7 @@ void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector offset embName, offset.size(), embInfos[embName].devVocabSize).c_str()); } if (rankInfo.useStatic) { - offset.resize(embInfos[embName].devVocabSize, -1); + offset.resize(embInfos[embName].devVocabSize, -1); } auto trans = Singleton::GetInstance(); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index e6fa2b47..a11e4e2c 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -19,14 +19,12 @@ #include #include #include +#include #include #include #include "utils/common.h" #include "utils/safe_queue.h" -#include "utils/unique.h" -#include "utils/spinlock.h" -#include "utils/task_queue.h" #include "host_emb/host_emb.h" #include "emb_table/emb_table.h" @@ -35,11 +33,9 @@ namespace MxRec { using namespace std; - using a2a_info_t = vector; - using sharded_dedup = ShardedDedup; - - template struct Cmp { - bool operator () (const T &a, const T &b) + template + struct Cmp { + bool operator()(const T& a, const T& b) { return get(a) > get(b); // batch id order } @@ -59,10 +55,8 @@ namespace MxRec { class KeyProcess { public: - KeyProcess(); - int Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues = {}, bool ifLoad = false, int seed = 0); + const vector& thresholdValues = {}, int seed = 0); unique_ptr> GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type); @@ -98,24 +92,18 @@ namespace MxRec { }; GTEST_PRIVATE: template - T GetInfo(std::vector>& list, int batch, const string& embName, int channel); + T GetInfo(info_list_t& list, int batch, const string& 
embName, int channel); RankInfo rankInfo; map embInfos; - - std::vector comm[MAX_CHANNEL_NUM] {}; - + MPI_Comm comm[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD]; + std::mutex mut {}; vector> procThreads {}; - std::mutex key2OffsetMut {}; - - std::vector loadSaveMut[MAX_CHANNEL_NUM]; - std::vector getInfoMut; - - std::vector>>> storage; - std::vector> lookupKeysList; - std::vector> infoList; - std::vector> all2AllList; - + std::mutex loadSaveMut[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD] {}; + info_list_t lookupKeysList; + list>> storage; + info_list_t infoList; + info_list_t all2AllList; map maxOffset {}; map> keyOffsetMap {}; FeatureAdmitAndEvict m_featureAdmitAndEvict {}; @@ -130,28 +118,21 @@ namespace MxRec { auto GetSendCount(const string& name, const string& channelName, bool modifyGraph); - void KeyProcessTask(const int channel, const int id); + void KeyProcessTask(int channel, int id); + + bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int id); - bool KeyProcessTaskHelper(unique_ptr& batch, shared_ptr unique, - int channel, int id, spdlog::stopwatch& sw); auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector>; - void ProcessBatchWithUniqueCompute(const unique_ptr &batch, shared_ptr unique, - int id, UniqueInfo& uniqueInfoOut); - - size_t GetKeySize(const unique_ptr &batch); - - void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, - All2AllInfo& all2AllInfoOut); - auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; auto HashSplit_withFAAE(const unique_ptr& batch) const -> tuple, vector, vector>>; - void GetScAll(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const; + + vector GetScAll(const vector& keyScLocal, int commId, int channel) const; void Key2Offset(const emb_name_t& embName, keys_t& splitKey); @@ -162,8 +143,6 @@ namespace MxRec { void SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch); - void Key2OffsetInit(const emb_name_t& embName); - void EvictDeleteDeviceEmb(const string& embName, const vector& keys); void EvictInitDeviceEmb(const string& embName, vector offset); @@ -171,13 +150,17 @@ namespace MxRec { void UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, const string& embName); - void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys, int id); + void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys); void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, - const string& embName) const; + const unique_ptr& batch); vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); + + void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, + vector & restore, vector & hotPos, + vector >& keyCount); }; } // end namespace MxRec #endif // MX_REC_KEY_PROCESS_H diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index a4d23fe7..b63e4bd1 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -9,6 +9,7 @@ #include "common.h" #include +#include #include #include @@ -89,6 +90,8 @@ namespace MxRec { void SetLog(int rank) { + auto logger = spdlog::stderr_color_mt("console"); + spdlog::set_default_logger(logger); std::string pattern = "[%H:%M:%S.%e] [" + std::to_string(rank) + "] [%^%l%$] %v"; spdlog::default_logger()->set_pattern(pattern); auto env_val = spdlog::details::os::getenv("SPDLOG_LEVEL"); diff --git 
a/src/core/utils/common.h b/src/core/utils/common.h index fe628307..3d8b1021 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -23,7 +23,11 @@ #include #include #include - +#include +#include +#include +#include +#include #include "tensorflow/core/framework/tensor.h" #include "absl/container/flat_hash_map.h" @@ -49,10 +53,7 @@ namespace MxRec { #define TIME_PRINT spdlog::info #define MGMT_CPY_THREADS 4 #define PROFILING - // read batch cost - // key process cost - using namespace tensorflow; - + using namespace tensorflow; constexpr int TRAIN_CHANNEL_ID = 0; constexpr int EVAL_CHANNEL_ID = 1; @@ -60,6 +61,7 @@ namespace MxRec { constexpr int MAX_KEY_PROCESS_THREAD = 10; constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; + constexpr int KEY_PROCESS_THREAD = 6; // unique related config constexpr int UNIQUE_BUCKET = 6; @@ -131,6 +133,11 @@ namespace MxRec { throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); } + inline std::chrono::milliseconds::rep Format2Ms(spdlog::stopwatch& sw) + { + return std::chrono::duration_cast((sw).elapsed()).count(); + } + template struct Batch { size_t Size() const @@ -142,7 +149,7 @@ namespace MxRec { { std::string s; constexpr size_t MAX_DISP_LEN = 20; - int maxLen = std::min(sample.size(), MAX_DISP_LEN); + int maxLen = static_cast(std::min(sample.size(), MAX_DISP_LEN)); for (int i = 0; i < maxLen; i++) { s += std::to_string(sample[i]) + " "; } @@ -406,6 +413,11 @@ struct BatchTask { size_t maxOffset { 0 }; std::vector evictPos; std::vector evictDevPos; + size_t maxOffsetOld { 0 }; + std::vector evictPosChange; + std::vector evictDevPosChange; + std::vector> devOffset2KeyOld; + std::vector> oldSwap; // (old on dev, old on host) void SetStartCount(); diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 5573b75e..c2fb614f 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -21,15 +21,11 @@ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/example/example.pb.h" - -#include "securec.h" - #include "key_process/key_process.h" #include "key_process/feature_admit_and_evict.h" #include "utils/common.h" #include "utils/safe_queue.h" #include "utils/singleton.h" -#include "utils/time_cost.h" using namespace tensorflow; using shape_inference::InferenceContext; @@ -44,7 +40,17 @@ using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*; spdlog::stopwatch staticSw {}; spdlog::stopwatch staticReadRaw {}; -array, MAX_CHANNEL_NUM> batchIdsInfo {}; +array batchIdsInfo {}; + +size_t GetBatchSize(OpKernelContextPtr context, const size_t dataSize, const size_t fieldNum) +{ + if (fieldNum == 0 || dataSize / fieldNum <= 0) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", fmt::format("batchSize error. {}/{}", dataSize, fieldNum))); + return 0; + } + return dataSize / fieldNum; +} REGISTER_OP("ClearChannel").Attr("channel_id : int"); @@ -56,8 +62,8 @@ public: if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:{})", - MAX_CHANNEL_NUM))); + fmt::format("ClearChannel channelId invalid. 
It should be in range " + "[0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); return; } } @@ -144,12 +150,22 @@ public: if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ReadEmbKeyV2Dynamic channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:{})", - MAX_CHANNEL_NUM))); + fmt::format("ReadEmbKeyV2Dynamic channelId invalid. It should be in " + "range [0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); return; } batchIdsInfo.at(channelId) = 0; + const char* threadNumEnv = getenv("THREAD_NUM"); + if (threadNumEnv != nullptr) { + threadNum = static_cast(*threadNumEnv) - static_cast('0'); + if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { + throw runtime_error(fmt::format("{} is not valid", threadNum)); + } + } else { + threadNum = KEY_PROCESS_THREAD; + } + auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -164,7 +180,7 @@ public: EASY_FUNCTION(); spdlog::debug("enter ReadEmbKeyV2Dynamic"); spdlog::stopwatch sw; - int batchId = batchIdsInfo.at(channelId).fetch_add(1); + int batchId = batchIdsInfo.at(channelId)++; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { spdlog::warn("skip excess batch after {}/{}", batchId, maxStep); @@ -183,24 +199,24 @@ public: // 如果传递了时间戳,解析和校验 if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("timestamp[{}] error, skip excess batch after {}/{}", timestamp, batchId, maxStep))); + fmt::format("timestamp[{}] error, skip excess batch after {}/{}", + timestamp, batchId, maxStep))); return; } // 保证所有embNames在m_embStatus中有状态记录 SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - // [batchId % PerfConfig::keyProcessThreadNum] which thread process this batch - // [PerfConfig::keyProcessThreadNum * 0 or 1] train or inference - int batchQueueId = batchId % PerfConfig::keyProcessThreadNum + PerfConfig::keyProcessThreadNum * channelId; + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); out(0) = batchId; EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); - TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}, " - "splits: {}, dataSize: {}, filedNum: {}, channelNames: {}, modifyGraph: {}", - duration_cast((sw).elapsed()), duration_cast((staticSw).elapsed()), - channelId, batchId, splits.size(), dataSize, fieldNum, channelNames, modifyGraph); + TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", + Format2Ms(sw), Format2Ms(staticSw), + channelId, batchId); staticSw.reset(); } @@ -233,7 +249,7 @@ public: offset += splits(i); continue; } - auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block + auto batchData = queue->GetOne(); // get dirty or empty data block batchData->name = embNames.at(i); if (modifyGraph) { batchData->modifyGraph = modifyGraph; @@ -242,61 +258,21 @@ public: size_t len = splits(i); batchData->channel = channelId; batchData->batchId = ids[0]; - batchData->batchSize = len; + batchData->sample.resize(len); if (isTimestamp) 
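// A standalone sketch of the slicing loop EnqueueBatchData runs: one flat key
// tensor arrives per step, and per-table batches are carved out with a running
// offset over the `splits` sizes, skipping tables that are not in use. The
// types below are simplified stand-ins for Tensor and BatchTask.
#include <cstdint>
#include <string>
#include <vector>

struct MiniBatch {
    std::string name;
    std::vector<int64_t> sample;
};

std::vector<MiniBatch> SliceBySplits(const std::vector<int64_t>& flat,
                                     const std::vector<size_t>& splits,
                                     const std::vector<std::string>& names,
                                     const std::vector<bool>& tableUsed)
{
    std::vector<MiniBatch> out;
    size_t offset = 0;
    for (size_t i = 0; i < splits.size(); ++i) {
        if (!tableUsed[i]) {        // unused table: advance the cursor only
            offset += splits[i];
            continue;
        }
        MiniBatch b;
        b.name = names[i];
        b.sample.assign(flat.begin() + offset, flat.begin() + offset + splits[i]);
        offset += splits[i];
        out.push_back(std::move(b));
    }
    return out;
}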
{ batchData->timestamp = timestamp; } - spdlog::info("split size:{} {}", i, splits(i)); - spdlog::info("emb_name:{} {}", i, embNames.at(i)); - - spdlog::debug("batch[{}/{}] flatten bs: {}", ids[0], i+1, len); - std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); - if (batch == nullptr) { - spdlog::error("batch can not be null"); - throw runtime_error("batch can not be null"); - } - queue->Pushv(move(batch)); - } - TIME_PRINT(KEY_PROCESS "EnqueueBatchData, batchId:{}, channelId:{}", ids[0], channelId); - } - std::unique_ptr TensorCopy(const Tensor& inputTensor, std::unique_ptr batchData, - const size_t& len, size_t& offset) - { - if (len == 0) { - spdlog::error("the length of batchData can not be zero"); - throw runtime_error("the length of batchData can not be zero"); - } - TimeCost ct; - void* src = nullptr; - size_t memSize; - if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { - batchData->isInt64 = false; - memSize = len * sizeof(int32_t); - src = reinterpret_cast( - reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + - offset); - } else { - batchData->isInt64 = true; - memSize = len * sizeof(int64_t); - src = reinterpret_cast( - reinterpret_cast(const_cast((string *)(inputTensor.tensor_data(). - data()))) + offset); - } - batchData->tensorAddr = malloc(memSize); - if (batchData->tensorAddr == nullptr) { - spdlog::error("mmemory allocation failded..."); - throw runtime_error("mmemory allocation failded..."); - } - void* dst = reinterpret_cast(batchData->tensorAddr); - auto rc = memcpy_s(dst, memSize, src, memSize); - if (rc != 0) { - spdlog::error("[ReadEmbKeyV2Dynamic]memcpy_s failded... memSize: {}", memSize); - throw runtime_error(fmt::format("[ReadEmbKeyV2Dynamic]memcpy_s failded... memSize: {}", memSize).c_str()); + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + auto src = (const int32_t*)inputTensor.tensor_data().data(); + copy(src + offset, src + offset + len, batchData->sample.data()); + } else { + auto src = (const int64_t*)inputTensor.tensor_data().data(); + copy(src + offset, src + offset + len, batchData->sample.data()); + } + offset += len; + queue->Pushv(move(batchData)); } - TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); - offset += len; - return move(batchData); } bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, @@ -342,6 +318,7 @@ public: int maxStep = 0; bool isTimestamp { false }; bool modifyGraph { false }; + int threadNum = 0; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), ReadEmbKeyV2Dynamic); @@ -389,17 +366,28 @@ public: if (splits.size() != embNames.size()) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("splits & embNames size error.{} {}", splits.size(), embNames.size()))); + fmt::format("splits & embNames size error.{} {}", splits.size(), + embNames.size()))); return; } if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:{})", - MAX_CHANNEL_NUM))); + fmt::format("ReadEmbKeyV2 channelId invalid. 
It should be in range " + "[0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); return; } batchIdsInfo.at(channelId) = 0; + const char* threadNumEnv = getenv("THREAD_NUM"); + if (threadNumEnv != nullptr) { + threadNum = static_cast(*threadNumEnv) - static_cast('0'); + if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { + throw runtime_error(fmt::format("{} is not valid", threadNum)); + } + } else { + threadNum = KEY_PROCESS_THREAD; + } + auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -433,29 +421,24 @@ public: // 如果传递了时间戳,解析和校验 if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("timestamp[{}] error, skip excess batch after {}/{}", timestamp, batchId, maxStep))); - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat(); - out(0) = batchId; + fmt::format("timestamp[{}] error, skip excess batch after {}/{}", + timestamp, batchId, maxStep))); return; } // 保证所有embNames在m_embStatus中有状态记录 SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - // [batchId % PerfConfig::keyProcessThreadNum] which thread process this batch - // [PerfConfig::keyProcessThreadNum * 0 or 1] train or inference - int batchQueueId = batchId % PerfConfig::keyProcessThreadNum + PerfConfig::keyProcessThreadNum * channelId; + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); out(0) = batchId; - - TimeCost tc; EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); - TIME_PRINT("EnqueueBatchData TimeCost(ms):{}", tc.ElapsedMS()); - - TIME_PRINT(KEY_PROCESS - "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", duration_cast((sw).elapsed()), - duration_cast((staticSw).elapsed()), channelId, batchId); + TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}", + Format2Ms(sw), Format2Ms(staticSw), + channelId, batchId); staticSw.reset(); } @@ -472,7 +455,7 @@ public: } } - int EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) + void EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) { if (tableUsed.empty()) { CheckEmbTables(); @@ -483,17 +466,12 @@ public: if (isTimestamp) { offset += 1; // 前面8个字节是unix时间戳 } - TimeCost ctAll; for (size_t i = 0; i < splits.size(); ++i) { if (!tableUsed.at(i)) { offset += splits.at(i); continue; } - - TimeCost tp; - auto batchData = queue->WaitAndGetOne(); // get dirty or empty data block - TIME_PRINT("TryPopTimeCost(ms):{}", tp.ElapsedMS()); - + auto batchData = queue->GetOne(); // get dirty or empty data block batchData->name = embNames.at(i); if (modifyGraph) { batchData->modifyGraph = modifyGraph; @@ -502,61 +480,21 @@ public: size_t len = splits.at(i); batchData->channel = channelId; batchData->batchId = batchId; - batchData->batchSize = len; - TimeCost fz; + batchData->sample.resize(len); if (isTimestamp) { batchData->timestamp = timestamp; } - TIME_PRINT("fz TimeCost(ms):{}", fz.ElapsedMS()); - std::unique_ptr batch = TensorCopy(inputTensor, move(batchData), len, offset); - if (batch == nullptr) { - 
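// A minimal sketch of the copy path that replaces the deleted TensorCopy
// helper (dtype dispatch as in the added branches; names as in this kernel):
//   batchData->sample.resize(len);
//   if (inputTensor.dtype() == tensorflow::DT_INT32) {
//       auto src = reinterpret_cast<const int32_t*>(inputTensor.tensor_data().data());
//       std::copy(src + offset, src + offset + len, batchData->sample.data());
//   }
// Copying into the pre-sized vector drops the manual allocation and the
// memcpy_s length bookkeeping that the old helper needed.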
spdlog::error("batch can not be null"); - throw runtime_error("batch can not be null"); + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + auto src = (const int32_t*)inputTensor.tensor_data().data(); + copy(src + offset, src + offset + len, batchData->sample.data()); + } else { + auto src = (const int64_t*)inputTensor.tensor_data().data(); + copy(src + offset, src + offset + len, batchData->sample.data()); } - queue->Pushv(move(batch)); + offset += len; + queue->Pushv(move(batchData)); } - TIME_PRINT("all copy TimeCost(ms):{}", ctAll.ElapsedMS()); - return 0; - } - - std::unique_ptr TensorCopy(const Tensor& inputTensor, std::unique_ptr batchData, - const size_t& len, size_t& offset) - { - if (len == 0) { - spdlog::error("len can not be zero"); - throw runtime_error("len can not be zero"); - } - TimeCost ct; - void* src = nullptr; - size_t memSize; - if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { - batchData->isInt64 = false; - memSize = len * sizeof(int32_t); - src = reinterpret_cast( - reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + - offset); - } else { - batchData->isInt64 = true; - memSize = len * sizeof(int64_t); - src = reinterpret_cast( - reinterpret_cast(const_cast((string *)(inputTensor.tensor_data().data()))) + - offset); - } - batchData->tensorAddr = malloc(memSize); - if (batchData->tensorAddr == nullptr) { - spdlog::error("mmemory allocation failded..."); - throw runtime_error("mmemory allocation failded..."); - } - void* dst = reinterpret_cast(batchData->tensorAddr); - auto rc = memcpy_s(dst, memSize, src, memSize); - if (rc != 0) { - spdlog::error("[ReadEmbKeyV2Static]memcpy_s failded... memSize: {}", memSize); - throw runtime_error(fmt::format("[ReadEmbKeyV2Static]memcpy_s failded... 
memSize: {}", memSize).c_str()); - } - TIME_PRINT("copy TimeCost(ms):{}", ct.ElapsedMS()); - offset += len; - return move(batchData); } bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, @@ -603,6 +541,7 @@ public: int maxStep = 0; bool isTimestamp { false }; bool modifyGraph { false }; + int threadNum = 0; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); @@ -655,8 +594,8 @@ public: for (int i { 0 }; i < restoreLen; ++i) { r(i) = i % lookupLen; } - spdlog::warn("dummy read batch cost: {},elapsed from last {}", duration_cast((sw).elapsed()), - duration_cast((staticSw).elapsed())); + spdlog::warn("dummy read batch cost: {},elapsed from last {}", + Format2Ms(sw), Format2Ms(staticSw)); staticSw.reset(); } @@ -665,135 +604,6 @@ public: REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyDatasetDummy").Device(DEVICE_CPU), ReadEmbKeyDatasetDummy); - -// ##################### ReadRaw ####################### -REGISTER_OP("ReadRaw") - .Input("sample: string") - .Output("int_output: int64") - .Output("float_output: float") - .Attr("int_len: int") - .Attr("float_len: int") - .Attr("feat_order: list(string)") - .SetShapeFn([](InferenceContextPtr c) { - int temp; - TF_RETURN_IF_ERROR(c->GetAttr("int_len", &temp)); - c->set_output(TENSOR_INDEX_0, c->Vector(temp)); - TF_RETURN_IF_ERROR(c->GetAttr("float_len", &temp)); - c->set_output(TENSOR_INDEX_1, c->Vector(temp)); - return Status::OK(); - }); - -class ReadRaw : public OpKernel { -public: - explicit ReadRaw(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("int_len", &intLen)); - OP_REQUIRES_OK(context, context->GetAttr("float_len", &floatLen)); - OP_REQUIRES_OK(context, context->GetAttr("feat_order", &featOrder)); - sampleId = 0; - } - - ~ReadRaw() override = default; - - void Compute(OpKernelContextPtr context) override - { - spdlog::stopwatch sw; - Tensor* intTensor = nullptr; - Tensor* floatTensor = nullptr; - int intDataIndex = 0; - int floatDataIndex = 0; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { intLen }, &intTensor)); - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { floatLen }, &floatTensor)); - const Tensor& inputTensor = context->input(TENSOR_INDEX_0); - auto input = inputTensor.flat()(0); - tensorflow::Example example; - if (!example.ParseFromString(input)) { - cerr << "Failed to parse file." 
<< endl; - } - spdlog::stopwatch sw_copy; - auto all_feature_map = example.features().feature(); - for (const auto& featName: featOrder) { - auto& cur_feature_value = all_feature_map.at(featName); - if (cur_feature_value.has_int64_list()) { - auto int64List = cur_feature_value.int64_list(); - int64* flat = intTensor->flat().data() + intDataIndex; - std::copy(int64List.value().begin(), int64List.value().end(), flat); - intDataIndex += int64List.value_size(); - } - if (cur_feature_value.has_float_list()) { - auto floatList = cur_feature_value.float_list(); - float* flat = floatTensor->flat().data() + floatDataIndex; - std::copy(floatList.value().begin(), floatList.value().end(), flat); - floatDataIndex += floatList.value_size(); - } - } - spdlog::info("ReadRaw sampleId:{} cost:{} copy:{} , elapsed from last:{}", sampleId++, - duration_cast((sw).elapsed()), - duration_cast((sw_copy).elapsed()), - duration_cast((staticReadRaw).elapsed())); - staticReadRaw.reset(); - } - - int intLen; - int floatLen; - vector featOrder; - atomic sampleId; -}; - -REGISTER_KERNEL_BUILDER(Name("ReadRaw").Device(DEVICE_CPU), ReadRaw); - - -// ##################### ReadRawDummy ####################### -REGISTER_OP("ReadRawDummy") - .Input("sample: int64") - .Output("int_output: int64") - .Output("float_output: float") - .Attr("int_len: int") - .Attr("float_len: int") - .SetShapeFn([](InferenceContextPtr c) { - int temp; - TF_RETURN_IF_ERROR(c->GetAttr("int_len", &temp)); - c->set_output(TENSOR_INDEX_0, c->Vector(temp)); - TF_RETURN_IF_ERROR(c->GetAttr("float_len", &temp)); - c->set_output(TENSOR_INDEX_1, c->Vector(temp)); - return Status::OK(); - }); - -class ReadRawDummy : public OpKernel { -public: - explicit ReadRawDummy(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("int_len", &intLen)); - OP_REQUIRES_OK(context, context->GetAttr("float_len", &floatLen)); - } - - ~ReadRawDummy() override = default; - - void Compute(OpKernelContextPtr context) override - { - spdlog::stopwatch sw; - Tensor* intTensor = nullptr; - Tensor* floatTensor = nullptr; - - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { intLen }, &intTensor)); - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { floatLen }, &floatTensor)); - - const Tensor& inputTensor = context->input(TENSOR_INDEX_0); - auto input = inputTensor.flat(); - int32_t batchId = static_cast(input(0)); - - spdlog::info("ReadRawDummy cost:{}, elapsed from last:{} , batchId = {}", - duration_cast((sw).elapsed()), - duration_cast((staticReadRaw).elapsed()), batchId); - staticReadRaw.reset(); - } - - int intLen; - int floatLen; -}; - -REGISTER_KERNEL_BUILDER(Name("ReadRawDummy").Device(DEVICE_CPU), ReadRawDummy); - class CustOps : public OpKernel { public: explicit CustOps(OpKernelConstructionPtr context) : OpKernel(context) diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 52fd21c8..13a29e5e 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -113,7 +113,7 @@ protected: // delete } }; - +#ifndef GTEST TEST_F(EmbMgmtTest, Initialize) { vector vocabsize = { devVocabSize, hostVocabSize }; @@ -169,7 +169,9 @@ TEST_F(EmbMgmtTest, Initialize) hybridMgmt->Destroy(); } +#endif +#ifndef GTEST TEST_F(EmbMgmtTest, Initialize_HBM) { devVocabSize = HBM_DEVICE_SIZE; @@ -188,7 +190,9 @@ TEST_F(EmbMgmtTest, Initialize_HBM) hybridMgmt->Destroy(); } +#endif +#ifndef GTEST TEST_F(EmbMgmtTest, Evict) { size_t devVocabSize = 
DDR_DEVICE_SIZE; @@ -210,9 +214,11 @@ TEST_F(EmbMgmtTest, Evict) hybridMgmt->Destroy(); } +#endif TEST_F(EmbMgmtTest, Evict_HBM) { +#ifndef GTEST devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; @@ -232,4 +238,5 @@ TEST_F(EmbMgmtTest, Evict_HBM) hybridMgmt->EvictKeys(name, keys); hybridMgmt->Destroy(); +#endif } diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 19c7e0a8..0df6411b 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -24,7 +24,7 @@ using namespace std; using namespace MxRec; using namespace testing; -static constexpr size_t BATCH_NUM_EACH_THREAD = 5; +static constexpr size_t BATCH_NUM_EACH_THREAD = 3; class KeyProcessTest : public testing::Test { protected: @@ -52,16 +52,16 @@ protected: vector> PrepareBatch() { - vector> result(PerfConfig::keyProcessThreadNum * MAX_CHANNEL_NUM); - // 向共享队列中写入本进程所有线程要处理的 PerfConfig::keyProcessThreadNum * BATCH_NUM_EACH_THREAD 个batch数据 - for (size_t threadId = 0; threadId < PerfConfig::keyProcessThreadNum; ++threadId) { - int batchQueueId = threadId + PerfConfig::keyProcessThreadNum * channel; + vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM); + // 向共享队列中写入本进程所有线程要处理的 KEY_PROCESS_THREAD * BATCH_NUM_EACH_THREAD 个batch数据 + for (size_t threadId = 0; threadId < KEY_PROCESS_THREAD; ++threadId) { + int batchQueueId = threadId + KEY_PROCESS_THREAD * channel; unsigned int seed = batchQueueId * 10; auto queue = SingletonQueue::getInstances(batchQueueId); for (size_t batchNum = 0; batchNum < BATCH_NUM_EACH_THREAD; ++batchNum) { size_t batchId = - batchNum * PerfConfig::keyProcessThreadNum + threadId; + batchNum * KEY_PROCESS_THREAD + threadId; for (size_t i = 0; i < embInfos.size(); i++) { // key按照不同emb表的存储切分开 auto batch = queue->GetOne(); @@ -112,7 +112,7 @@ protected: { default_random_engine generator; uniform_int_distribution distribution(randMin, randMax); - int embSizeMin = 5, embSizeMax = 8, base = 2; + int embSizeMin = 5, embSizeMax = 8, base = 2, vocabSize = 100; uniform_int_distribution embSizeDistribution(embSizeMin, embSizeMax); stringstream ss; for (unsigned int i = 0; i < embNums; ++i) { @@ -123,6 +123,7 @@ protected: ss.clear(); temp.sendCount = distribution(generator); temp.extEmbeddingSize = pow(base, embSizeDistribution(generator)); + temp.devVocabSize = vocabSize; geFieldNums.push_back(sampleSize); allEmbInfos.push_back(move(temp)); } @@ -250,7 +251,7 @@ TEST_F(KeyProcessTest, HashSplit) } ASSERT_THAT(restore, ElementsAreArray(expectRestore)); } - +#ifndef GTEST TEST_F(KeyProcessTest, GetScAll) { vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 @@ -265,7 +266,7 @@ TEST_F(KeyProcessTest, GetScAll) process.GetScAll(keyScLocal, 0, 0, scAll); ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); } - +#endif TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { auto queue = SingletonQueue::getInstances(0); @@ -309,7 +310,7 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) hotPos); }; // for clean code for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { + for (int id = 0; id < 1; ++id) { // use lambda expression initialize thread process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } @@ -339,7 +340,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) batch->batchId, lookupKeys, scAll, restore); }; // for clean code for (int channel = 0; channel < 1; ++channel) { - for (int 
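// Worker/batch layout these key_process tests prepare, sketched with
// KEY_PROCESS_THREAD == 6 (an assumed value for illustration):
//   thread 0 -> batchIds 0, 6, 12, ...  (batchNum * KEY_PROCESS_THREAD + 0)
//   thread 1 -> batchIds 1, 7, 13, ...  (batchNum * KEY_PROCESS_THREAD + 1)
// PrepareBatch fills the per-thread queues in exactly this round-robin
// pattern, the inverse of the batchId % KEY_PROCESS_THREAD routing that the
// ReadEmbKeyV2 kernels use when they pick a batchQueueId.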
id = 0; id < PerfConfig::keyProcessThreadNum; ++id) { + for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { // use lambda expression initialize thread process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } @@ -354,7 +355,7 @@ TEST_F(KeyProcessTest, Key2Offset) keys_t expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); ASSERT_EQ(process.isRunning, true); - process.Key2Offset(emb_name_t(), lookupKeys); + process.Key2Offset("emb0", lookupKeys); spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); } -- Gitee From d669c0e74a52e3e20ae1827601de45c2e02299ef Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 9 Jun 2023 11:03:39 +0800 Subject: [PATCH 123/551] Match-id-eea4fdca7df23573796664176afcea61f2fef6a0 --- mx_rec/util/initialize.py | 11 ++++++----- mx_rec/validator/validator.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 3b19d79d..0c08f2fe 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -10,10 +10,10 @@ from collections import defaultdict import mxrec_pybind import psutil -from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST -from mx_rec.constants.constants import LOCAL_RANK_SIZE, MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \ + MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import RankInfoValidator +from mx_rec.validator.validator import RankInfoValidator, StringValidator class ConfigInitializer: @@ -400,8 +400,9 @@ class ConfigInitializer: @ascend_global_hashtable_collection.setter def ascend_global_hashtable_collection(self, name): - if not isinstance(name, str): - raise TypeError(f"collection name '{name}' must be a string.") + string_validator = StringValidator(name, max_len=HASHTABLE_COLLECTION_NAME_LENGTH, min_len=1) + if not string_validator.check_string_length().check_whitelist().is_valid(): + raise ValueError(string_validator.msg) self._ascend_global_hashtable_collection = name def get_initializer(self, is_training): diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index c18d989c..bc738915 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -4,6 +4,7 @@ import os +import re from typing import Callable, Any from typing import List, Optional, Tuple @@ -78,6 +79,7 @@ class StringValidator(Validator): super().__init__(value) self.max_len = max_len self.min_len = min_len + self.whitelist = "^[0-9A-Za-z_]+$" self.register_checker(lambda x: isinstance(x, str), "type is not str") def check_string_length(self): @@ -91,6 +93,13 @@ class StringValidator(Validator): self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1) return self + def check_whitelist(self): + """Perform whitelist verification on the input string""" + self.register_checker(lambda x: x is not None and re.match(self.whitelist, x) is not None, + "The string is invalid, please check the input string. 
" + "Note: It should be a string consisting of numbers, letters, and underscores.") + return self + def can_be_transformed2int(self, min_value: int = None, max_value: int = None): if min_value is None: min_value = MIN_RANK_SIZE @@ -275,6 +284,7 @@ class RankInfoValidator: """ Check replace rank table system environment configuration. """ + @staticmethod def check_visible_devices(): visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") -- Gitee From c0eb89dec8564fc880b85eebf48d7032d1cdf830 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 9 Jun 2023 15:48:12 +0800 Subject: [PATCH 124/551] Match-id-d938b44da1af84883c873b434fadd072dc950aae --- example/little_demo/main.py | 54 ++++++++++++++++++++++--------------- example/little_demo/run.sh | 5 ++++ 2 files changed, 38 insertions(+), 21 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 7bf70769..ddd51453 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -59,19 +59,26 @@ def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=N use_timestamp=USE_TIMESTAMP, dump_graph=is_train, batch_number=batch_number) if MODIFY_GRAPH_FLAG: - input_list = [ - [batch["user_ids"], batch["item_ids"], batch["user_ids"], batch["item_ids"]], - [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], - [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt], - ] + input_list = [[batch["user_ids"], batch["item_ids"]], + [hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt]] + if use_multi_lookup: + input_list = [[batch["user_ids"], batch["item_ids"], batch["user_ids"], batch["item_ids"]], + [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt]] if USE_TIMESTAMP: tf.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) model = model_forward(input_list, batch, is_train=is_train, modify_graph=True, config_dict=config_dict) else: - hash_table_list = [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]] - send_cnt_list = [cfg.user_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt, cfg.item_send_cnt] - model = model_forward([feature_spec_list, hash_table_list, send_cnt_list], batch, + input_list = [feature_spec_list, + [hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt]] + if use_multi_lookup: + input_list = [feature_spec_list, + [hash_table_list[0], hash_table_list[1], hash_table_list[0], hash_table_list[0]], + [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt]] + model = model_forward(input_list, batch, is_train=is_train, modify_graph=False, config_dict=config_dict) return iterator, model @@ -98,15 +105,16 @@ def create_feature_spec_list(use_timestamp=False): feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, eviction_threshold=eviction_threshold), - FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", - access_threshold=access_threshold, - eviction_threshold=eviction_threshold), - FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", - access_threshold=access_threshold, - eviction_threshold=eviction_threshold), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", access_threshold=access_threshold, eviction_threshold=eviction_threshold)] + if use_multi_lookup: + 
+        feature_spec_list.extend([FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table",
+                                              access_threshold=access_threshold,
+                                              eviction_threshold=eviction_threshold),
+                                  FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table",
+                                              access_threshold=access_threshold,
+                                              eviction_threshold=eviction_threshold)])
     if use_timestamp:
         feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True))
     return feature_spec_list
@@ -120,18 +128,23 @@ if __name__ == "__main__":
     TRAIN_INTERVAL = 100
     EVAL_STEPS = 10
     SAVING_INTERVAL = 100
-    USE_TIMESTAMP = False
-    # add dynamic expansion support
+    # get init configuration
+    use_mpi = bool(int(os.getenv("USE_MPI", 1)))
+    use_dynamic = int(os.getenv("USE_DYNAMIC", 0))
+    use_hot = bool(int(os.getenv("USE_HOT", 0)))
     use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
+    use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
+    MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
+    USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
     # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0
-    init(use_mpi=bool(int(os.getenv("USE_MPI"))),
+    init(use_mpi=use_mpi,
         train_interval=TRAIN_INTERVAL,
         eval_steps=EVAL_STEPS,
         prefetch_batch_number=5,
-        use_dynamic=int(os.getenv("USE_DYNAMIC", 0)),
-        use_hot=bool(int(os.getenv("USE_HOT", 0))),
+        use_dynamic=use_dynamic,
+        use_hot=use_hot,
         use_dynamic_expansion=use_dynamic_expansion)
     IF_LOAD = False
     rank_id = get_rank_id()
@@ -140,8 +153,6 @@ if __name__ == "__main__":
         IF_LOAD = True
     set_if_load(IF_LOAD)
-    MODIFY_GRAPH_FLAG = False  # ASC + use_MPI + modify_graph
-
     cfg = Config()
     # access_threshold unit counts; eviction_threshold unit seconds
     ACCESS_AND_EVICT = None
@@ -198,6 +209,7 @@ if __name__ == "__main__":
     if use_dynamic_expansion:
         from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET
+
        train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
        train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
        # do sparse optimization by addr
diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
index 3ad0a4eb..e4c1cff6 100644
--- a/example/little_demo/run.sh
+++ b/example/little_demo/run.sh
@@ -29,9 +29,14 @@ export ASCEND_GLOBAL_LOG_LEVEL=3 # see the "Set Log Level" section; 0: debug, 1: info
 export MXREC_MODE="ASC"
 export USE_MPI=1
+################# Parameter configuration ######################
 export USE_DYNAMIC=0 # 0: static shape; 1: dynamic shape
 export USE_HOT=0 # 0: disable hot emb; 1: enable hot emb
 export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion
+export USE_MULTI_LOOKUP=1 # 0: one lookup per table; 1: multiple lookups per table
+export USE_MODIFY_GRAPH=0 # 0: feature spec mode; 1: automatic graph-modification mode
+export USE_TIMESTAMP=0 # 0: disable feature admit-and-evict; 1: enable feature admit-and-evict
+################################################
 export KEY_PROCESS_THREAD_NUM=6 # default 6, max 10

################# enable when using the ranktable-free scheme ######################
-- Gitee

From 07404241eed21dcd56b8bd Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 12 Jun 2023 09:27:59 +0800
Subject: [PATCH 125/551] Match-id-02bd53d4eb218f0c0d887ef6e11e7d3325bc5331
---
 mx_rec/core/asc/manager.py                    |   3 +-
 mx_rec/optimizers/gradient_descent_by_addr.py | 127 +-----------------
 src/core/emb_table/emb_table.cpp              |  19 ++-
 src/core/emb_table/emb_table.h                |   2 +-
 src/core/host_emb/host_emb.cpp                |  19 ++-
 5 files changed, 34 insertions(+), 136 deletions(-)
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 3f193fc2..426dd525 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -140,7 +140,8 @@ def matched_opt_slot_initializers(table_instance):
     start_index = table_instance.scalar_emb_size
     slot_initializers = []
-
+    logging.debug(f"matched_opt_slot_initializers, ext emb size:{table_instance.ext_emb_size}, "
+                  f"optimizer_instance_list size:{len(table_instance.optimizer_instance_list)}")
     for optimizer in table_instance.optimizer_instance_list:
         for slot_init_value in optimizer.get_slot_init_values():
             slot_initializer = InitializeInfo(name="constant_initializer",
diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py
index 596b0375..de00dc06 100644
--- a/mx_rec/optimizers/gradient_descent_by_addr.py
+++ b/mx_rec/optimizers/gradient_descent_by_addr.py
@@ -6,15 +6,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import abc
 import logging
 from collections import defaultdict

-from tensorflow.python.framework import ops, indexed_slices
 from tensorflow.python.ops import math_ops
-from tensorflow.python.training import optimizer
-from tensorflow.python.eager import context
-from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import gradient_descent

 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer
@@ -29,90 +25,28 @@ def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_lockin
     return optimizer_by_addr

-class CustomizedGradientDescentByAddr(optimizer.Optimizer, CustomizedOptimizer):
+class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer):
     name_counter = defaultdict(int)

     def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescentByAddr"):
         self.optimizer_type = "gradient_descent_by_addr"
         self.weight_decay = weight_decay
-        super(CustomizedGradientDescentByAddr, self).__init__(use_locking, name)
+        super(CustomizedGradientDescentByAddr, self)._get_name(name=name)
+        super(CustomizedGradientDescentByAddr, self).__init__(learning_rate=learning_rate, use_locking=use_locking,
+                                                              name=self.unique_name)
-
-        self._learning_rate = learning_rate
-        self._learning_rate_tensor = None
         self._slot_num = 0

     @property
     def slot_num(self):
         return self._slot_num

-    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-        # No DistributionStrategy case.
-        grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
-        if not grads_and_vars:
-            raise ValueError("No variables provided.")
-        converted_grads_and_addrs = tuple(self._convert_grads_and_addrs(grads_and_vars))
-
-        addr_list = [a for g, a, _ in converted_grads_and_addrs if g is not None]
-        if not addr_list:
-            raise ValueError("No gradients provided for any address: %s."
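# (This hand-rolled apply_gradients machinery is exactly what the patch
# removes: the stock tf.compat.v1.train.Optimizer.apply_gradients already
# runs the same convert-gradients / create-slots / update / finish sequence,
# so subclassing GradientDescentOptimizer lets the inherited implementation
# drive _apply_sparse. A rough usage sketch, with assumed names:
#     opt = create_hash_optimizer_by_addr(learning_rate=0.01)
#     train_op = opt.apply_gradients(zip(grads, addr_tensors))  # inherited
# where addr_tensors would come from the sparse lookup address collections.)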
% - ([str(a) for _, a, _ in converted_grads_and_addrs],)) - with ops.init_scope(): - self._create_slots(addr_list) - - update_ops = [] - with ops.name_scope(name, self._name) as name: - self._prepare() - for grad, addr, processor in converted_grads_and_addrs: - if grad is None: - continue - if (context.executing_eagerly() or - resource_variable_ops.is_resource_variable(addr) - and not addr._in_graph_mode): # pylint: disable=protected-access - scope_name = "" - else: - scope_name = addr.op.name - with ops.name_scope( - "update_" + scope_name), ops.colocate_with(addr): - update_ops.append(processor.update_op(self, grad)) - - apply_updates = self._finish(update_ops, name) - - if not context.executing_eagerly(): - if isinstance(apply_updates, ops.Tensor): - logging.debug(">>>>Enter ops.Tensor") - apply_updates = apply_updates.op - train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) - if apply_updates not in train_op: - logging.debug(">>>>Enter apply_updates not in train_op") - train_op.append(apply_updates) - else: - raise RuntimeError("eager wrong.") - - return apply_updates + def initialize_slots(self, var, table_instance): + return [] def get_slot_init_values(self): return [] - def _convert_grads_and_addrs(self, grads_and_vars): - converted_grads_and_addrs = [] - for grad, addr in grads_and_vars: - if grad is not None: - try: - # Convert the grad to Tensor or IndexedSlices if necessary. - grad = ops.convert_to_tensor_or_indexed_slices(grad) - except TypeError as error: - raise TypeError("Gradient must be convertible to a Tensor or IndexedSlices, or None") from error - if not isinstance(grad, (ops.Tensor, indexed_slices.IndexedSlices)): - raise TypeError("Gradient must be a Tensor, IndexedSlices, or None") - processor = _get_processor(addr) - converted_grads_and_addrs.append((grad, addr, processor)) - return converted_grads_and_addrs - - def _prepare(self): - learning_rate = self._call_if_callable(self._learning_rate) - self._learning_rate_tensor = ops.convert_to_tensor( - learning_rate, name="learning_rate") - def _apply_sparse(self, grad, addr): logging.debug(">>>> Enter _apply_sparse SGD by addr") host_pipeline_ops = get_host_pipeline_ops() @@ -133,51 +67,4 @@ class CustomizedGradientDescentByAddr(optimizer.Optimizer, CustomizedOptimizer): raise NotImplementedError("You are using a wrong type of variable.") -def get_filtered_grad_fn(grad_fn): - def filtered_grad_fn(*args, **kwargs): - return [(g, a) for g, a in grad_fn(*args, **kwargs) if g is not None] - - return filtered_grad_fn - - -class _OptimizableAddr(metaclass=abc.ABCMeta): - """Interface for abstracting over addresses in the optimizers.""" - - @abc.abstractmethod - def target(self): - """Returns the optimization target for this address.""" - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def update_op(self, opt, grad): - """Returns the update ops for updating the address.""" - raise NotImplementedError("Calling an abstract method.") - -def _get_processor(addr): - """The processor of v.""" - if isinstance(addr, ops.Tensor): - logging.debug(">>>>Enter _get_processor tensor") - return _TensorByAddressProcessor(addr) - raise NotImplementedError("Trying to optimize unsupported type ", addr) - - -class _TensorByAddressProcessor(_OptimizableAddr): - """Processor for Tensor filled with addresses.""" - - def __init__(self, addr): - self._a = addr - - def __str__(self): - return "<_TensorByAddressProcessor(%s)>" % self._a - - def target(self): - return self._a - - def update_op(self, opt, grad): - 
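# (_TensorByAddressProcessor, whose update_op continues below, is the piece
#  that let plain address Tensors stand in for optimizer variables: target()
#  exposes the tensor and update_op() forwards the gradient to _apply_sparse.
#  It goes away together with apply_gradients now that
#  CustomizedGradientDescentByAddr inherits the stock GradientDescentOptimizer
#  machinery.)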
if isinstance(grad, ops.Tensor): - logging.debug(">>>>Enter update_op ops.Tensor") - update_op = opt._apply_sparse(grad, self._a) # pylint: disable=protected-access - return update_op - else: - raise RuntimeError("Only support g with type Tensor.") diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 78e9e815..4422abb0 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -112,26 +112,31 @@ void EmbTable::PutEmbAddress(int64_t curAddress) void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos, int seed) { #ifndef GTEST - spdlog::info("Device GenerateEmbData Start, seed:{}", seed); + spdlog::info("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); vector devEmb(blockSize); for (auto initializeInfo: initializeInfos) { Initializer* initializer; switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { - spdlog::info("Device GenerateEmbData ing using Constant Initializer by value {}.", - initializeInfo.constantInitializerInfo.constantValue); + spdlog::info("Device GenerateEmbData ing using Constant Initializer by value {}. name {}, start {}, " + "len {}.", initializeInfo.constantInitializerInfo.constantValue, + initializeInfo.name, initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.constantInitializer; break; } case InitializerType::TRUNCATED_NORMAL: { - spdlog::info("Device GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + spdlog::info("Device GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}. " + "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.truncatedNormalInitializer; break; } case InitializerType::RANDOM_NORMAL: { - spdlog::info("Device GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + spdlog::info("Device GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}. 
" + "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.randomNormalInitializer; break; } diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index 5a1c0927..34370add 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -56,7 +56,7 @@ namespace MxRec { list LoadEmb(const vector> &savedEmb); GTEST_PRIVATE: - constexpr static int BLOCK_EMB_COUNT = 1000; + constexpr static int BLOCK_EMB_COUNT = 100000; constexpr static int INIT_BLOCK_COUNT = 5; constexpr static int TEST_EMB_SIZE = 12; EmbInfo embInfo; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 374e1fb6..453797c9 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -36,7 +36,7 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in int embeddingSize, vector> &embData) { #ifndef GTEST - spdlog::info(HOSTEMB + "GenerateEmbData Start, seed:{}", seed); + spdlog::info(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); @@ -45,20 +45,25 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value {}.", - initializeInfo.constantInitializerInfo.constantValue); + spdlog::info(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value {}. name {}, " + "start {}, len {}.", initializeInfo.constantInitializerInfo.constantValue, + initializeInfo.name, initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.constantInitializer; break; } case InitializerType::TRUNCATED_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + spdlog::info(HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}. " + "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.truncatedNormalInitializer; break; } case InitializerType::RANDOM_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + spdlog::info(HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}. 
" + "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.randomNormalInitializer; break; } -- Gitee From 9919d8a7559da8e14d0f4afd9ac8384b7a5f1430 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 12 Jun 2023 10:37:57 +0800 Subject: [PATCH 126/551] Match-id-cb9d777fbc486b3c3436503e149a17a188064052 --- mx_rec/core/asc/helper.py | 2 +- mx_rec/graph/modifier.py | 19 +++++++++++++------ mx_rec/graph/utils.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 5327276e..d69fb8e0 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -401,7 +401,7 @@ def get_valid_op_key(batch_dict: dict) -> str: def get_target_tensors_with_args_indexes(args_index_list): insert_tensors = [] - graph = tf.get_default_graph() + graph = tf.compat.v1.get_default_graph() for index in args_index_list: tensor = graph.get_tensor_by_name("args_%d:0" % index) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index b53e00d5..a0a7eed6 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -60,7 +60,7 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso input_tensors.append(tensor) else: - graph = tf.get_default_graph() + graph = tf.compat.v1.get_default_graph() for index in pipeline_input_indexes: tensor = graph.get_tensor_by_name("args_%d:0" % index) input_tensors.append(tensor) @@ -94,7 +94,7 @@ def get_input_index_list(cutting_point_list, replacement_specs, mapping_name_lis def find_make_iterator_op(batch_tensor): - graph = tf.get_default_graph() + graph = tf.compat.v1.get_default_graph() operations = graph.get_operations() for each_op in operations: for input_tensor in batch_tensor.op.inputs: @@ -132,10 +132,17 @@ def get_op_before_optimize_dataset(get_next_op): # looking for the MakeIterator operator which corresponds to given batch_tensor base_op = find_make_iterator_op(get_next_op.outputs[0]) # looking for the op which is the one before OptimizeDataset operator - target_op = find_target_dataset_op(base_op, "OptimizeDataset") - if find_parent_op(target_op)[0].type == "PrefetchDataset": - target_op = find_parent_op(target_op)[0] - + if tf.__version__.startswith("1"): + optimize_dataset_op = find_target_dataset_op(base_op, "OptimizeDataset") + target_op = find_parent_op(optimize_dataset_op) + if not target_op: + raise RuntimeError(f"The parent op for 'OptimizeDataset' op was not found.") + if target_op[0].type != "PrefetchDataset": + raise TypeError(f"Op PrefetchDataset was not found.") + target_op = target_op[0] + else: + # 'OptimizeDataset' is not available in TensorFlow2.X + target_op = find_target_dataset_op(base_op, "PrefetchDataset") return target_op diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py index f2851b7a..5399f5d6 100644 --- a/mx_rec/graph/utils.py +++ b/mx_rec/graph/utils.py @@ -40,7 +40,7 @@ def check_cutting_points(cutting_point_list): def record_ops_to_replace(src_op): replacement_specs = defaultdict(list) output_list = src_op.outputs - op_list = tf.get_default_graph().get_operations() + op_list = tf.compat.v1.get_default_graph().get_operations() for tensor in output_list: for operator in op_list: if tensor in operator.inputs: -- Gitee From 087045f5acc54f9d0f8f052ca89c2a2364fcb54a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 12 Jun 2023 15:19:59 +0800 Subject: [PATCH 
127/551] Match-id-fb2e1655debdd7e44bbb959e0549201b4da6baba --- mx_rec/core/asc/build_graph.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 4f30e5b4..184210b6 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -134,9 +134,6 @@ def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify swap_out_op = npu_ops.outfeed_enqueue_op( channel_name=f'{config.get("table_name")}_d2h_{config.get("channel_id")}', inputs=[swap_out]) with tf.control_dependencies([swap_out_op]): - # fix empty nd update - swap_pos = tf.concat([swap_pos, tf.constant([AVOID_TENSOR_POS])], axis=0) - h2d_emb = tf.concat([h2d_emb, tf.constant([[0.1] * config.get("ext_emb_size")])], axis=0) nd_swap_pos = tf.expand_dims(swap_pos, 1) table_num = len(table) h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) -- Gitee From ca267d01bf9fe448a0768be793bcb11203fa8e53 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 12 Jun 2023 16:58:29 +0800 Subject: [PATCH 128/551] Match-id-95b7bd16c53e8deb5315ef87323ab416d5d71d60 --- mx_rec/core/embedding.py | 5 +++-- mx_rec/util/initialize.py | 7 +++++-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 9 +++++++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index de507c00..e43c71c2 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -886,7 +886,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] initialized_tensor = instance.emb_initializer( - evict_pos.shape.as_list()[0] + instance.embedding_size) + tf.shape(evict_pos)[0] + instance.embedding_size) logging.debug(f'evict_pos output shape {evict_pos}, and slice_device_vocabulary_size ' f'{instance.slice_device_vocabulary_size}, ' @@ -909,7 +909,8 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): if cur_time - self._start_time > self._evict_time_interval or \ (self._evict_step_interval is not None and self._global_step % self._evict_step_interval == 0): logging.info(f"_EvictHook - > evict switch on!!! 
after_run step: {self._global_step}")
-                trigger_evict()
+                if not trigger_evict():
+                    return
                 self._start_time = cur_time
             for name in self._hash_table_instance.keys():
                 run_context.session.run(self._evict_op.get(name))
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 0c08f2fe..baaba222 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -482,8 +482,11 @@ def trigger_evict():
     if not is_asc_manager_initialized():
         raise RuntimeError("ASC manager does not exist.")

-    ConfigInitializer.get_instance().get_asc_manager().evict()
-    logging.debug("Feature evict is triggered by ops.")
+    if ConfigInitializer.get_instance().get_asc_manager().evict():
+        logging.debug("Feature evict is triggered by ops.")
+        return True
+    logging.warning("Feature evict did not succeed, skip this time!")
+    return False

 def clear_channel(is_train_channel=False):
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 15de3b73..b20388cf 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -596,7 +596,7 @@ void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embI
 /*
 * The hook triggers eviction based on elapsed time or on the step count
 */
-void HybridMgmt::Evict()
+bool HybridMgmt::Evict()
 {
 #ifndef GTEST
     auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict();
@@ -604,9 +604,13 @@ void HybridMgmt::Evict()
         featAdmitNEvict.FeatureEvict(evictKeyMap);
     } else {
         spdlog::warn(MGMT + "Hook can not trigger evict, cause AdmitNEvict is not open");
-        return;
+        return false;
     }
     spdlog::debug(MGMT + "evict triggered by hook, evict TableNum {} ", evictKeyMap.size());
+    if (evictKeyMap.size() == 0) {
+        spdlog::warn(MGMT + "evict triggered by hook before dataset is injected");
+        return false;
+    }

     if (mgmtRankInfo.noDDR) {
         for (auto evict : evictKeyMap) {
@@ -617,6 +621,7 @@ void HybridMgmt::Evict()
             EvictKeys(evict.first, evict.second);
         }
     }
+    return true;
 #endif
 }
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index f834b9f6..079897ac 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -87,7 +87,7 @@ namespace MxRec {

         void EmbHDTrans(const int channelId, const int batchId);

-        void Evict();
+        bool Evict();

         void EvictKeys(const string& embName, const vector& keys);
-- Gitee

From 7eac68f0295232afb984531d4bddb6b174902a86 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sat, 10 Jun 2023 17:08:23 +0800
Subject: [PATCH 129/551] Match-id-115487bca7d97ec3d8cfcf5e81983ebed3d4e969
---
 mx_rec/core/embedding.py           |  4 +---
 mx_rec/optimizers/ftrl.py          | 12 ++++++------
 mx_rec/optimizers/ftrl_t.py        | 20 ++++++++++----------
 mx_rec/optimizers/ftrl_t_dense.py  | 13 +++++++------
 mx_rec/optimizers/momentum.py      |  9 ++++-----
 mx_rec/saver/patch.py              | 18 +++++++++---------
 mx_rec/saver/saver.py              |  8 ++++----
 mx_rec/util/variable.py            | 10 ----------
 src/core/checkpoint/checkpoint.cpp |  4 ++++
 9 files changed, 45 insertions(+), 53 deletions(-)
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index de507c00..4adc67b0 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -26,7 +26,6 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is
     ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion, \
     set_modify_graph, insert_removing_var_list
 from mx_rec.util.tf_version_adapter import npu_ops
-from mx_rec.util.variable import remove_saving_var

 def create_table(**kwargs):
@@ -754,9 +753,8 @@ class
SparseEmbedding: def _initialize_variables(self): initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) - insert_removing_var_list(self.variable.name) # make sure sparse table variable will not be saved and restored within tf checkpoint. - remove_saving_var(self.variable) + insert_removing_var_list(self.variable.name) self._record() if self.use_dynamic_expansion: diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 5498a932..d2c39367 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -20,8 +20,8 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): @@ -73,8 +73,8 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): accum = slot_creator.create_slot(var, val, self._name + "/" + "accum") linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear") - remove_saving_var(accum) - remove_saving_var(linear) + insert_removing_var_list(accum.name) + insert_removing_var_list(linear.name) named_slot_key = (var.op.graph, var.op.name) table_instance = get_table_instance(var) if self._name in table_instance.optimizer: @@ -236,8 +236,8 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): accum = self._get_or_make_slot(each_var, val, "accum", accum_state_name) linear = self._zeros_slot(each_var, "linear", linear_state_name) # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
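# (How the exclusion works after this change, sketched:
#  insert_removing_var_list records the variable *name* in a skip list, and
#  the patched saver's build_var_list later filters
#  variables._all_saveable_objects() against that list (see build_var_list in
#  mx_rec/saver/patch.py further below), so a slot no longer has to be plucked
#  out of the TF collections by object identity, which is what the deleted
#  remove_saving_var helper did.)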
- remove_saving_var(accum) - remove_saving_var(linear) + insert_removing_var_list(accum.name) + insert_removing_var_list(linear.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear}) diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py index 0dedc009..4337841a 100644 --- a/mx_rec/optimizers/ftrl_t.py +++ b/mx_rec/optimizers/ftrl_t.py @@ -20,8 +20,8 @@ from tensorflow.python.training import optimizer from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_and_get_config_via_var def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl_t", **kwargs): @@ -60,10 +60,10 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): n = slot_creator.create_zeros_slot(var, self._name + "/" + "n") g = slot_creator.create_zeros_slot(var, self._name + "/" + "g") w = slot_creator.create_zeros_slot(var, self._name + "/" + "w") - remove_saving_var(z) - remove_saving_var(n) - remove_saving_var(g) - remove_saving_var(w) + insert_removing_var_list(z.name) + insert_removing_var_list(n.name) + insert_removing_var_list(g.name) + insert_removing_var_list(w.name) named_slot_key = (var.op.graph, var.op.name) table_instance = get_table_instance(var) if self._name in table_instance.optimizer: @@ -245,10 +245,10 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): g = self._zeros_slot(each_var, "g", g_state_name) w = self._zeros_slot(each_var, "w", w_state_name) # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - remove_saving_var(z) - remove_saving_var(n) - remove_saving_var(g) - remove_saving_var(w) + insert_removing_var_list(z.name) + insert_removing_var_list(n.name) + insert_removing_var_list(g.name) + insert_removing_var_list(w.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py index 412a7617..40573f59 100644 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ b/mx_rec/optimizers/ftrl_t_dense.py @@ -20,8 +20,8 @@ from tensorflow.python.training import optimizer from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_and_get_config_via_var def create_ftrl_dense_optimizer(learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): @@ -183,8 +183,9 @@ class CustomizedFtrlTZ(optimizer.Optimizer): g_zero = self._zeros_slot(each_var, "g", g_state_name) w_zero = self._zeros_slot(each_var, "w", w_state_name) # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. 
- remove_saving_var(z_zero) - remove_saving_var(n_zero) - remove_saving_var(g_zero) - remove_saving_var(w_zero) + insert_removing_var_list(z_zero.name) + insert_removing_var_list(n_zero.name) + insert_removing_var_list(g_zero.name) + insert_removing_var_list(w_zero.name) + diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index 2424b8f6..df4adaeb 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -16,8 +16,8 @@ from tensorflow.python.training import momentum from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance -from mx_rec.util.variable import remove_saving_var, check_and_get_config_via_var, check_param_type, check_param_range +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range def create_hash_optimizer(learning_rate_input=0.001, mom=0.9, enable_locking=False, optimizer_name="momentum", @@ -71,7 +71,7 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): return new_slot_variable momentum_slot = creat_one_single_slot(var, self._name + "/" + "momentum") - remove_saving_var(momentum_slot) + insert_removing_var_list(momentum_slot.name) named_slot_key = (var.op.graph, var.op.name) table_instance = get_table_instance(var) if self._name in table_instance.optimizer: @@ -107,8 +107,7 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): for var in var_list: table_instance = check_and_get_config_via_var(var, self.optimizer_type) momentum_slot = self._zeros_slot(var, "m", m_state_name) - - remove_saving_var(momentum_slot) + insert_removing_var_list(momentum_slot.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) logging.debug(" End _create_slots") diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 89df2598..58469617 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -30,19 +30,18 @@ from mx_rec.util.initialize import get_ascend_global_hashtable_collection, expor def get_sparse_vars(var_list): + sparse_var_list = [] # build sparse saver if var_list is not None: if not isinstance(var_list, (list, tuple)): raise TypeError("A non-None var_list must be a list or tuple.") ascend_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) - sparse_var_list = [] for var in var_list: if var in ascend_variables: sparse_var_list.append(var) - var_list = sparse_var_list else: - var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) - return var_list + sparse_var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + return sparse_var_list def init_check(defer_build, var_list): @@ -60,11 +59,10 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None, fid_version=0): - if not defer_build: - var_list = build_var_list(var_list) + + self._var_list = var_list self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] - self._var_list = var_list self._is_built = False self._is_empty = None init_check(defer_build, var_list) @@ -92,6 +90,7 @@ def saver_init(self, var_list=None, 
reshape=False, sharded=False, max_to_keep=5, keep_time = self._keep_checkpoint_every_n_hours * 3600 self._next_checkpoint_time = (time.time() + keep_time) elif not defer_build: + self._var_list = build_var_list(var_list) self.build() self._object_restllore_saver = None # mxRec Patch @@ -316,12 +315,13 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N def build_var_list(var_list): if var_list is None: - var_list = [] + save_var_list = [] tmp_list = variables._all_saveable_objects() removing_var_list = export_removing_var_list() for var in tmp_list: if var.name not in removing_var_list: - var_list.append(var) + save_var_list.append(var) + return save_var_list return var_list diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 5989c150..d83dca75 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -162,15 +162,15 @@ class Saver(object): def _restore(self, sess, reading_path): restore_feed_dict = defaultdict(dict) + if is_asc_manager_initialized(): + restore_host_data(reading_path) + logging.debug(f"host data was restored.") + for table_name, sub_placeholder_dict in self.placeholder_dict.items(): fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.EMBEDDING.value)) table_instance = get_table_instance_by_name(table_name) - if is_asc_manager_initialized(): - restore_host_data(reading_path) - logging.debug(f"host data was restored.") - if table_instance.use_feature_mapping: fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.FEATURE_MAPPING.value)) diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py index 9101616a..900d1dad 100644 --- a/mx_rec/util/variable.py +++ b/mx_rec/util/variable.py @@ -14,16 +14,6 @@ def get_dense_and_sparse_variable(): return dense_variables, sparse_variables -def remove_saving_var(variable): - global_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - savable_objects = ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS) - if variable in global_variables: - global_variables.remove(variable) - - if variable in savable_objects: - savable_objects.remove(variable) - - def check_and_get_config_via_var(variable, optimizer_type: str): table_instance = get_table_instance(variable) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index c63df0c7..48891b5d 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -377,6 +377,10 @@ void Checkpoint::LoadDataset(const vector& embNames, CkptData& ckptData) { for (const auto& embName : embNames) { + if (!CheckEmbNames(embName)) { + continue; + } + auto dataDir { innerDirPath + dirSeparator + embName }; for (const auto& saveDataType : saveDataTypes) { auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; -- Gitee From 7398d4b01b59875c79270db3d0ac362ed1c565f6 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 13 Jun 2023 17:27:22 +0800 Subject: [PATCH 130/551] Match-id-cfc1dd5a4dc21e1e20687632c4ce6df7ffaf9f58 --- src/core/hd_transfer/hd_transfer.cpp | 7 ++-- src/core/hd_transfer/hd_transfer.h | 4 +-- src/core/host_emb/host_emb.cpp | 13 +++++--- src/core/key_process/key_process.cpp | 3 +- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- tools/mx_rec_perf.sh | 49 ++++++++++++++++++++++++++++ 6 files changed, 66 insertions(+), 12 deletions(-) create mode 100644 tools/mx_rec_perf.sh diff --git 
a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp
index 39c8c76e..7928f8f2 100644
--- a/src/core/hd_transfer/hd_transfer.cpp
+++ b/src/core/hd_transfer/hd_transfer.cpp
@@ -148,7 +148,7 @@ vector<Tensor> HDTransfer::Recv(TransferChannel channel, int channel
     std::vector<Tensor> tensors;
     string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId);
     spdlog::debug("hd transfer try recv:{}", recvName);
-
+    spdlog::stopwatch sw;
     tensorflow::Status status = tensorflow::RecvTensorByAcl(transferChannels[recvName], tensors);
     if (!running) {
         return {};
@@ -162,7 +162,7 @@ vector<Tensor> HDTransfer::Recv(TransferChannel channel, int channel
     for (auto& t: tensors) {
         sizes.push_back(t.NumElements());
     }
-    spdlog::info("hd transfer recv:{}, size:{}", recvName, sizes);
+    spdlog::info("hd transfer recv:{}, size:{} cost:{}ms", recvName, sizes, Format2Ms(sw));
     return tensors;
 #endif
     return {};
@@ -175,6 +175,7 @@ tuple<acltdtDataset*, size_t> HDTransfer::RecvAcl(TransferChannel channel, int c
     std::vector<Tensor> tensors;
     string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId);
     spdlog::debug("hd transfer try recv:{}", recvName);
+    spdlog::stopwatch sw;
     acltdtDataset* aclDataset = acltdtCreateDataset();
     if (aclDataset == nullptr) {
         throw runtime_error(fmt::format("Failed recv:{}.", recvName).c_str());
@@ -186,7 +187,7 @@
     if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) {
         throw runtime_error(fmt::format("Failed receive data from acl channel, acl status:{}", aclStatus).c_str());
     }
-    spdlog::info("hd transfer recv:{}", recvName);
+    spdlog::info("hd transfer recv:{} cost:{}ms", recvName, Format2Ms(sw));
     return {aclDataset, acltdtGetDatasetSize(aclDataset)};
 #endif
     return {nullptr, 0};
diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h
index b840649d..37fee0b4 100644
--- a/src/core/hd_transfer/hd_transfer.h
+++ b/src/core/hd_transfer/hd_transfer.h
@@ -24,8 +24,8 @@ namespace MxRec {
     const std::string MGMT = "\033[32m[Mgmt]\033[0m ";
     const std::string HD = "\033[32m[HD]\033[0m ";
     const std::string HOSTEMB = "\033[32m[HostEmb]\033[0m ";
-    const int PING_PONG_SIZE = 12;
-    const int LARGE_CHANNEL_SIZE = 100;
+    const int PING_PONG_SIZE = 6;
+    const int LARGE_CHANNEL_SIZE = 40;

     enum class TransferChannel {
         D2H,
diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp
index 453797c9..f62c9a67 100644
--- a/src/core/host_emb/host_emb.cpp
+++ b/src/core/host_emb/host_emb.cpp
@@ -107,7 +107,8 @@ void HostEmb::Join()
 #ifndef GTEST
 void HostEmb::UpdateEmb(const vector<int>& missingKeysHostPos, int channelId, const string& embName)
 {
-    EASY_FUNCTION(profiler::colors::Purple)
+    EASY_FUNCTION(profiler::colors::Purple);
+    spdlog::stopwatch sw;
     auto hdTransfer = Singleton<HDTransfer>::GetInstance();
     TransferChannel transferName = TransferChannel::D2H;
     spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId);
@@ -132,7 +133,7 @@
             dst[j] = tensorPtr[j + embeddingSize * i];
         }
     }
-    spdlog::info(HOSTEMB + "update emb end");
+    spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw));
     EASY_END_BLOCK
 }

@@ -149,6 +150,7 @@ void HostEmb::UpdateEmbV2(const vector<int>& missingKeysHostPos, int channelI
         spdlog::warn(HOSTEMB + "recv empty data");
         return;
     }
+    spdlog::stopwatch sw;
     spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size());
     EASY_BLOCK("Update")
     auto& embData = hostEmbs[embName].embData;
@@ -169,7 +171,7 @@ void HostEmb::UpdateEmbV2(const vector<int>& missingKeysHostPos, int channelI
         if (acltdtDestroyDataset(aclDataset) != ACL_ERROR_NONE) {
             throw runtime_error("Acl destroy tensor dataset failed.");
         }
-        spdlog::info(HOSTEMB + "update emb end");
+        spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw));
     }));
 }

@@ -181,6 +183,7 @@ void HostEmb::GetH2DEmb(const vector<int>& missingKeysHostPos, const string&
     vector<Tensor>& h2dEmbOut)
 {
     EASY_FUNCTION()
+    spdlog::stopwatch sw;
     const auto& emb = hostEmbs[embName];
     const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize;
     h2dEmbOut.emplace_back(Tensor(tensorflow::DT_FLOAT, {
@@ -190,13 +193,13 @@ void HostEmb::GetH2DEmb(const vector<int>& missingKeysHostPos, const string&
     auto tmpData = tmpTensor.flat<float>();
 #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(missingKeysHostPos, emb, tmpData)
     for (size_t i = 0; i < missingKeysHostPos.size(); i++) {
-        const auto src = emb.embData[missingKeysHostPos[i]];
+        const auto& src = emb.embData[missingKeysHostPos[i]];
 #pragma omp simd
         for (int j = 0; j < embeddingSize; j++) {
             tmpData(j + i * embeddingSize) = src[j];
         }
     }
-    spdlog::info("GetH2DEmb end, missingKeys count:{}", missingKeysHostPos.size());
+    spdlog::info("GetH2DEmb end, missingKeys count:{} cost:{}ms", missingKeysHostPos.size(), Format2Ms(sw));
 }
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index ea2f6572..75586c65 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -110,7 +110,7 @@ int KeyProcess::Start()
     }; // for clean code
     int threadNum;
     for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) {
-        const char* threadNumEnv = getenv("THREAD_NUM");
+        const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM");
         if (threadNumEnv != nullptr) {
             threadNum = static_cast<int>(*threadNumEnv) - static_cast<int>('0');
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
@@ -119,6 +119,7 @@ int KeyProcess::Start()
         } else {
             threadNum = KEY_PROCESS_THREAD;
         }
+        spdlog::info(KEY_PROCESS "key process thread num: {}", threadNum);
         for (int id = 0; id < threadNum; ++id) {
             procThreads.emplace_back(
                 std::make_unique<std::thread>(fn, channel, id)); // use lambda expression initialize thread
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index c2fb614f..20761f3b 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -156,7 +156,7 @@ public:
         }
         batchIdsInfo.at(channelId) = 0;
-        const char* threadNumEnv = getenv("THREAD_NUM");
+        const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM");
         if (threadNumEnv != nullptr) {
             threadNum = static_cast<int>(*threadNumEnv) - static_cast<int>('0');
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
diff --git a/tools/mx_rec_perf.sh b/tools/mx_rec_perf.sh
new file mode 100644
index 00000000..6fadb5e8
--- /dev/null
+++ b/tools/mx_rec_perf.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd.
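+# Usage: bash tools/mx_rec_perf.sh <spdlog_file>
+#        Summarizes batch-read and key-process costs from an MxRec spdlog log.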
+# Description: MxRec performance analysis script V1.0
+
+file="$1" # pass the spdlog log file
+
+calculate_average() {
+    awk '{
+        sum += $1;
+        count++
+    } END {
+        avg = sum / count;
+        print avg
+    }'
+}
+echo " =========MxRec performance analysis script V1.0========= "
+echo "read batch cost"
+cat ${file} | grep 'read batch cost'|tail -n 20| awk 'NR%2==1'
+echo "===================================="
+echo "key process cost"
+cat ${file} | grep 'key process cost'|tail
+avg=`cat ${file} | grep -Po '(?<=key process cost:)[^,:]+(?=ms)'|tail -n +20 |calculate_average`
+real_avg=$(echo "$avg" | awk '{ printf "%.2f", $0/6 }')
+echo "Average: $real_avg (single-thread avg is $avg)"
+echo "===================================="
+echo "Analyze the host/device pipeline: when host key processing runs ahead of the training steps, host performance is not the bottleneck"
+echo "Enter the training-step marker to filter on (default: step); press Enter to open the analysis and q to quit"
+read step
+step="${step:-step}"
+cat ${file} | grep -P "key process cost|${step}"|tail -n100|less
+
+exit
+
+# Locating GetNext timeout issues
+echo -n "Timed-out channel: "
+cat ${file} | grep -Po "aicpu_getnext.*GetNext"
+echo -n "Check whether data was sent; note the send counts are across 8 cards "
+cat ${file} | grep -P "send" |grep all2all |wc -l
+cat ${file} | grep -P "send"|grep h2d|wc -l
+
+echo -n "Check data reading; number of batches read: "
+cat ${file} | grep 'read batch cost'|wc -l
+cat ${file} | grep 'read batch cost'|tail
+
+# Check the hot-emb deduplication ratio
+echo "Table name and dedup ratio (after/before), which should be below 0.4:"
+cat op_summary_*.csv |grep gather_for_restore_vector |awk -F "," '{print $6,$14,$15}'|sed 's/"//g'|sed 's/ [0-9]*;/\//'
-- Gitee
From b2b0b269b7d1cebba1638ccb5c12079a50f5c37c Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 14 Jun 2023 10:25:53 +0800
Subject: [PATCH 131/551] Match-id-3e98bc6597cebe13d2babdfb2d67bbe1be23daf2

---
 cust_op/cust_op_by_addr/run.sh | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh
index 2ed188d8..e994c1b2 100644
--- a/cust_op/cust_op_by_addr/run.sh
+++ b/cust_op/cust_op_by_addr/run.sh
@@ -1,12 +1,18 @@
-
-#source /usr/local/Ascend/ascend-toolkit/set_env.sh
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+# Description: Build for cust_op by address
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+set -e
 source /etc/profile
+
+# Locate msopgen and add it to PATH
 msopgen_path=$(find /usr/local/Ascend/ -name msopgen | grep bin)
-# Take the parent directory
 parent_dir=$(dirname "$msopgen_path")
-
 export PATH=$parent_dir:$PATH

+# Use msopgen to generate the buildable op project
 rm -rf ./custom_op
 msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 0 -op EmbeddingLookupByAddress
 msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 1 -op EmbeddingUpdateByAddress

 cp -rf op_host custom_op/

 cd custom_op

-# Check whether a cmake.json file exists in the current directory
+# Check whether a CMakePresets.json file exists in the current directory
 if [ ! -f "CMakePresets.json" ]; then
     echo "CMakePresets.json does not exist in the current directory"
     exit 1
 fi

-#jq '.configurePresets.cacheVariables.ASCEND_CANN_PACKAGE_PATH.value = "/usr/local/Ascend/ascend-toolkit/latest"' CMakePresets.json > tmp.json
-#mv tmp.json cmake.json
-
+# Update the CANN install path
 sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json

 cd cmake

 if [ ! -f "config.cmake" ]; then

     exit 1
 fi

-
+# Extend the target compute units
 sed -i 's:set(ASCEND_COMPUTE_UNIT ascend910b):set(ASCEND_COMPUTE_UNIT ascend910b ascend910):g' config.cmake

 cd ..
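The next patch wires the AccCTR (ock_ctr_common) fast-unique library into the key-processing pipeline. For orientation, here is a minimal standalone sketch of driving its Unique interface directly, based only on the factory.h and unique.h headers that the patch adds; the configuration values, buffer sizing, and the main() harness are illustrative assumptions rather than code from the patch:

// Assumes lib_ock_ctr_common.so (built from src/platform/AccCTR) is on
// LD_LIBRARY_PATH; APIs as declared in ock_ctr_common/include/{factory.h,unique.h}.
#include <cstdint>
#include <vector>
#include "ock_ctr_common/include/factory.h"

int main()
{
    ock::ctr::FactoryPtr factory;
    if (ock::ctr::Factory::Create(factory) != 0) {  // dlopens lib_ock_ctr_common.so
        return -1;
    }
    ock::ctr::UniquePtr unique;
    if (factory->CreateUnique(unique) != 0) {
        return -1;
    }

    std::vector<int64_t> ids { 3, 1, 3, 7, 1 };

    ock::ctr::UniqueConf conf;                      // plain unique: no sharding/padding
    conf.dataType = ock::ctr::DataType::INT64;
    conf.outputType = ock::ctr::OutputType::NORMAL;
    conf.desiredSize = static_cast<uint32_t>(ids.size());
    conf.maxIdVal = INT64_MAX;
    unique->Initialize(conf);

    // Caller allocates the output buffers (worst case: every id is unique).
    std::vector<int64_t> uniqueIds(ids.size());
    std::vector<uint32_t> index(ids.size());

    ock::ctr::UniqueIn in;
    in.inputId = ids.data();
    in.inputIdCnt = static_cast<uint32_t>(ids.size());

    ock::ctr::UniqueOut out;
    out.uniqueId = uniqueIds.data();
    out.index = index.data();

    int ret = unique->DoUnique(in, out);            // expect out.uniqueIdCnt == 3 here
    unique->UnInitialize();
    return ret;
}

The enhanced variant introduced below (DoEnhancedUnique) additionally returns per-bucket ids and id counts, which is what lets KeyProcessTaskHelperWithUnique() replace the hand-rolled HashSplit and restore bookkeeping when FAST_UNIQUE is enabled.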
-- Gitee
From 746b9fc9da52cf129828bae3160d023fb059d597 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 14 Jun 2023 14:43:30 +0800
Subject: [PATCH 132/551] Match-id-3e1614d1fd43fb9f69e7a12a25fb162c4cfd2d81

---
 .gitmodules                                |   3 +
 build/build.sh                             |  14 +-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp       |   6 +
 src/core/key_process/key_process.cpp       | 322 +++++++++++++++++-
 src/core/key_process/key_process.h         |  32 ++
 src/core/ock_ctr_common/include/factory.h  |  60 ++++
 .../include/ock_ctr_common_def.h           |  58 ++++
 src/core/ock_ctr_common/include/unique.h   | 124 +++++++
 src/core/utils/common.cpp                  |   1 +
 src/core/utils/common.h                    |   3 +-
 src/core/utils/unique.h                    |   1 +
 src/platform/AccCTR                        |   1 +
 src/test_ut.sh                             |  13 +
 src/tests/CMakeLists.txt                   |   1 +
 src/tests/key_process/key_process_test.cpp | 124 +++++++
 15 files changed, 759 insertions(+), 4 deletions(-)
 create mode 100644 src/core/ock_ctr_common/include/factory.h
 create mode 100644 src/core/ock_ctr_common/include/ock_ctr_common_def.h
 create mode 100644 src/core/ock_ctr_common/include/unique.h
 create mode 160000 src/platform/AccCTR

diff --git a/.gitmodules b/.gitmodules
index 4b398c49..57a5bc65 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
 [submodule "src/thirdparty/pybind11"]
 	path = src/thirdparty/pybind11
 	url = https://codehub-dg-y.huawei.com/OpenSourceCenter/pybind11.git
+[submodule "src/platform/AccCTR"]
+	path = src/platform/AccCTR
+	url = https://szv-y.codehub.huawei.com/ComputingFoundationSoftware/ock-ascend-domain/AccCTR.git
diff --git a/build/build.sh b/build/build.sh
index 4cdfaffa..fb684738 100644
--- a/build/build.sh
+++ b/build/build.sh
@@ -92,7 +92,8 @@
 echo "${abseil_src_path}"
 abseil_install_path="${ROOT_DIR}"/install/abseil
 src_path="${ROOT_DIR}"/src
-
+acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR
+cp -rf ../platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c
 cd "${ROOT_DIR}"
 release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz
@@ -143,6 +144,13 @@ compile_so_file()
 	cd ..
} +compile_acc_ctr_so_file() +{ + cd "${acc_ctr_path}" + chmod u+x build.sh + ./build.sh "release" +} + collect_so_file() { cd "${src_path}" @@ -150,6 +158,7 @@ collect_so_file() mkdir -p "${src_path}"/libasc chmod u+x libasc + cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc cp -df "${ROOT_DIR}"/output/*.so* libasc cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc } @@ -197,6 +206,9 @@ clean() install_abseil compile_securec +echo "-----Build AccCTR -----" +compile_acc_ctr_so_file + echo "-----Build Start tf1 -----" source "${SCRIPT_DIR}"/tf1_env/bin/activate compile_so_file "${tf1_path}" diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index b20388cf..aebfd006 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -40,6 +40,12 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& spdlog::info("config MAX_UNIQUE_THREAD_NUM:{}", num); } + if (getenv("FAST_UNIQUE") != nullptr) { + bool isFastUnique = std::atoi(getenv("FAST_UNIQUE")); + PerfConfig::fastUnique = isFastUnique; + spdlog::info("config FAST_UNIQUE:{}", PerfConfig::fastUnique); + } + preprocess = Singleton::GetInstance(); preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); preprocess->Start(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 75586c65..9171d327 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -13,15 +13,17 @@ #include #include #include +#include #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" #include "utils/common.h" - +#include "utils/time_cost.h" using namespace std; using namespace chrono; using namespace MxRec; +using namespace ock::ctr; static shared_mutex g_smut; @@ -85,6 +87,10 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, spdlog::warn(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); } + if (PerfConfig::fastUnique) { + Factory::Create(factory); + } + spdlog::info(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", scInfo, rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); return 0; @@ -202,20 +208,84 @@ void KeyProcess::LoadSaveUnlock() } } +void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) +{ + if (rankInfo.rankSize > 0) { + uniqueConf.useSharding = true; + uniqueConf.shardingNum = rankInfo.rankSize; + } + + if (rankInfo.useStatic) { + uniqueConf.usePadding = true; + uniqueConf.paddingVal = -1; + } else { + uniqueConf.usePadding = false; + } + + uniqueConf.useIdCount = true; + uniqueConf.outputType = OutputType::ENHANCED; + uniqueConf.minThreadNum = MIN_UNIQUE_THREAD_NUM; + uniqueConf.maxThreadNum = PerfConfig::maxUniqueThreadNum; +} + +void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, + const unique_ptr & batch, UniquePtr& unique) +{ + uniqueConf.desiredSize = (uint32_t)batch->Size(); + if (preBatchSize != batch->Size()) { + uniqueInitialize = false; + preBatchSize = batch->Size(); + } + + if (!uniqueInitialize) { + if (rankInfo.useStatic) { + uniqueConf.paddingSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); + } + + uniqueConf.maxIdVal = INT64_MAX; + uniqueConf.dataType = ock::ctr::DataType::INT64; + + unique->Initialize(uniqueConf); + uniqueInitialize = true; + } +} + void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1] { unique_ptr batch; + 
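+    // Fast-unique path: when PerfConfig::fastUnique is set (via the FAST_UNIQUE
+    // env switch read in hybrid_mgmt.cpp), dedup and split are delegated to the
+    // AccCTR Unique instance created below; otherwise the original
+    // KeyProcessTaskHelper() path is used unchanged.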
UniquePtr unique = nullptr; + UniqueConf uniqueConf; + size_t preBatchSize = 0; + bool uniqueInitialize = false; + + if (PerfConfig::fastUnique) { + factory->CreateUnique(unique); + GetUniqueConfig(uniqueConf); + } + spdlog::stopwatch sw; try { while (true) { + TimeCost getAndProcesTC; + TimeCost getBatchTC; batch = GetBatchData(channel, id); // get batch data from SingletonQueue + TIME_PRINT("GetBatchData TimeCost(ms):{}", getBatchTC.ElapsedMS()); + if (batch == nullptr) { break; } auto getBatchTime = Format2Ms(sw); sw.reset(); - if (!KeyProcessTaskHelper(batch, channel, id)) { + bool ret = false; + if (PerfConfig::fastUnique) { + InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); + ret = KeyProcessTaskHelperWithUnique(batch, unique, channel, id); + } else { + ret = KeyProcessTaskHelper(batch, channel, id); + } + + if (!ret) { break; } spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", @@ -223,6 +293,9 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } + if (PerfConfig::fastUnique) { + unique->UnInitialize(); + } } catch (const EndRunError &e) { spdlog::debug(KEY_PROCESS "abort run: {}", e.what()); } @@ -245,12 +318,69 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < } } +bool KeyProcess::KeyProcessTaskHelperWithUnique(unique_ptr& batch, UniquePtr& unique, + int channel, int id) +{ + // tuple for keyRec restore hotPos scAll countRecv + isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && + FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; + TimeCost tc; + UniqueInfo uniqueInfo; + ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); + TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); + + // 特征准入&淘汰 + if (isWithFAAE && + (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, + uniqueInfo.all2AllInfo.countRecv) + == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + rankInfo.rankId, id, channel); + return false; + } + + // without host, just device, all embedding vectors were stored in device + // map key to offset directly by lookup keyOffsetMap (hashmap) + if (rankInfo.noDDR) { + TimeCost key2OffsetTc; + Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv); + TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS()); + } + if (!rankInfo.useStatic) { // Static all2all,need send count + auto embName = batch->name; + if (batch->modifyGraph) { + embName = batch->channelName; + } + SendA2A(uniqueInfo.all2AllInfo.scAll, embName, batch->channel, batch->batchId); + } + + auto tensors = make_unique>(); + tensors->push_back(Vec2TensorI32(uniqueInfo.restore)); + if (rankInfo.useHot) { + uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); + tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); + } + if (rankInfo.noDDR) { + if (rankInfo.useDynamicExpansion) { + tensors->push_back(Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv)); + } else { + tensors->push_back(Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); + } + } + TimeCost pushTensorTc; + PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); + TIME_PRINT("pushTensorToListTC TimeCost(ms):{}", pushTensorTc.ElapsedMS()); + return true; +} + bool 
KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int id) { vector splitKeys; vector restore; vector hotPos; vector> keyCount; + + TimeCost hashSplit_tc; HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, id, splitKeys); @@ -260,6 +390,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe countRecv = GetCountRecv(batch, id, keyCount, scAll, ss); } BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size())); + TIME_PRINT("HashSplit TimeCost(ms):{}", hashSplit_tc.ElapsedMS()); + // 特征准入&淘汰 if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE && @@ -397,6 +529,147 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) return batch; } +size_t KeyProcess::GetKeySize(const unique_ptr &batch) +{ + size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; + if (batch->modifyGraph) { + size = rankInfo.rankSize * embInfos[batch->name].sendCountMap[batch->channelName]; + } + if (!rankInfo.useStatic) { + size = batch->Size(); + } + return size; +} + +void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, UniquePtr& unique, + int id, UniqueInfo& uniqueInfoOut) +{ + EASY_FUNCTION(profiler::colors::Purple) + EASY_VALUE("batchId", batch->batchId) + + EASY_BLOCK("ock-unique") + + KeySendInfo keySendInfo; + size_t size = GetKeySize(batch); + keySendInfo.keySend.resize(size); + vector splitSize(rankInfo.rankSize); + vector uniqueVector(batch->Size()); + uniqueInfoOut.restore.resize(batch->Size()); + vector idCount(batch->Size()); + if (rankInfo.useStatic) { + keySendInfo.keyCount.resize(size); + } + + UniqueIn uniqueIn; + uniqueIn.inputIdCnt = (uint32_t)batch->Size(); + uniqueIn.inputId = reinterpret_cast(batch->sample.data()); + + EnhancedUniqueOut uniqueOut; + uniqueOut.uniqueId = reinterpret_cast(keySendInfo.keySend.data()); + uniqueOut.index = (uint32_t*)uniqueInfoOut.restore.data(); + uniqueOut.idCnt = idCount.data(); + uniqueOut.idCntFill = keySendInfo.keyCount.data(); + uniqueOut.uniqueIdCntInBucket = splitSize.data(); + uniqueOut.uniqueIdInBucket = reinterpret_cast(uniqueVector.data()); + uniqueOut.uniqueIdCnt = 0; + + TimeCost unique_tc; + int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut); + EASY_END_BLOCK + TIME_PRINT("UniqueCompute TimeCost(ms):{} ret:{}", unique_tc.ElapsedMS(), ret); + + vector sc; + HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize); + + All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); + + spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " + "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->Size(), + batch->channel, batch->channelName, batch->name, uniqueInfoOut.restore.size(), + keySendInfo.keyCount.size()); +} + +void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, + KeySendInfo& keySendInfo, vector& sc, vector& splitSize) +{ + std::shared_lock lock(g_smut); + auto hotMap = hotKey[batch->name]; + lock.unlock(); + + if (rankInfo.useHot) { + int hotOffset = 0; + uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); + hotOffset = hotEmbTotCount[batch->name]; + + TimeCost ComputeHotTc; + ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); + TIME_PRINT("ComputeHot TimeCost(ms):{}", ComputeHotTc.ElapsedMS()); + UpdateHotMapForUnique(keySendInfo.keySend, 
keySendInfo.keyCount, + hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); + } + + if (rankInfo.useStatic) { + sc.resize(rankInfo.rankSize, GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); + } else { + sc.resize(rankInfo.rankSize); + for (int i = 0;i < rankInfo.rankSize; i++) { + sc[i] = splitSize[i]; + } + } +} + +void KeyProcess::ComputeHotPos(const unique_ptr &batch, map &hotMap, + vector &hotPos, vector &restore, const int hotOffset) +{ + auto inputData = batch->sample.data(); + + int hotCount = 0; + for (size_t i = 0;i < batch->Size(); ++i) { + auto key = inputData[i]; + auto hot = hotMap.find(key); + if (hot != hotMap.end()) { + if (hot->second == -1) { + hotPos[hotCount] = restore[i]; + hot->second = hotCount; + restore[i] = hotCount++; + } else { + restore[i] = hot->second; + } + } else { + restore[i] += hotOffset; + } + } +} + +void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, + All2AllInfo& all2AllInfoOut) +{ + TimeCost get_sc_all; + GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) + TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS()); + + TimeCost all2allTC; + auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 + vector rc(rankInfo.rankSize); // receive count + for (int i = 0; i < rankInfo.rankSize; ++i) { + // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 + rc[i] = all2AllInfoOut.scAll.at(i * rankInfo.rankSize + rankInfo.rankId); + } + auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); + EASY_BLOCK("all2all") + MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, all2AllInfoOut.keyRecv.data(), + rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); + + all2AllInfoOut.countRecv.resize(rs.back() + rc.back()); + if (isWithFAAE) { + MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(), + rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); + } + TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); + EASY_END_BLOCK +} + auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector> { @@ -613,6 +886,31 @@ void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& h } } +void KeyProcess::UpdateHotMapForUnique(const keys_t &keySend, const vector &keyCount, + uint32_t count, bool refresh, const string& embName) +{ + auto& hotMap = hotKey[embName]; + if (refresh) { + priority_queue> pq; + for (size_t i = 0;i < keySend.size(); ++i) { + if (keySend[i] == -1) { + continue; + } + pq.push(pair(-keyCount[i], keySend[i])); + if (pq.size() > count) { + pq.pop(); + } + } + // gen new hot map + std::unique_lock lock(g_smut); + hotMap.clear(); + while (!pq.empty()) { + hotMap.insert(make_pair(pq.top().second, -1)); + pq.pop(); + } + } +} + void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, const string& embName) { @@ -661,6 +959,26 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int return scAll; } +void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const +{ + EASY_FUNCTION() + scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); + EASY_BLOCK("barrier"); + // 通信终止信号,同步退出,防止线程卡住 + spdlog::stopwatch sw; + int exitFlag = isRunning; + MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + if (exitFlag < rankInfo.rankSize) { + 
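+        // Every rank contributes isRunning (1 while alive); after MPI_SUM the
+        // total equals rankSize only if all ranks are still running, so a single
+        // stopped rank unwinds every peer here instead of hanging in Allgather.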
throw EndRunError("GetScAll end run."); + } + EASY_END_BLOCK; + spdlog::debug(KEY_PROCESS "barrier time:{}", duration_cast((sw).elapsed())); + // allgather keyScLocal(key all2all keyScLocal = device all2all rc) + MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, + scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + spdlog::debug("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, scAllOut); +} + void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) { EASY_FUNCTION(profiler::colors::Blue600) diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index a11e4e2c..5486b2fc 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -25,13 +25,16 @@ #include "utils/common.h" #include "utils/safe_queue.h" +#include "utils/task_queue.h" #include "host_emb/host_emb.h" #include "emb_table/emb_table.h" #include "feature_admit_and_evict.h" +#include "ock_ctr_common/include/factory.h" namespace MxRec { using namespace std; + using namespace ock::ctr; template struct Cmp { @@ -111,6 +114,8 @@ namespace MxRec { map> hotKey {}; map hotEmbTotCount; map embeddingTableMap {}; + + FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; @@ -122,9 +127,25 @@ namespace MxRec { bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int id); + bool KeyProcessTaskHelperWithUnique(unique_ptr &batch, UniquePtr& unique, + int channel, int id); + auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector>; + void GetUniqueConfig(UniqueConf& uniqueConf); + + void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, + const unique_ptr & batch, UniquePtr& unique); + + void ProcessBatchWithUniqueCompute(const unique_ptr &batch, UniquePtr& unique, + int id, UniqueInfo& uniqueInfoOut); + + size_t GetKeySize(const unique_ptr &batch); + + void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, + All2AllInfo& all2AllInfoOut); + auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; @@ -134,6 +155,8 @@ namespace MxRec { vector GetScAll(const vector& keyScLocal, int commId, int channel) const; + void GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const; + void Key2Offset(const emb_name_t& embName, keys_t& splitKey); unique_ptr GetBatchData(int channel, int commId); @@ -150,11 +173,20 @@ namespace MxRec { void UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, const string& embName); + void UpdateHotMapForUnique(const keys_t &keySend, const vector &keyCount, + uint32_t count, bool refresh, const string& embName); + + void HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, + KeySendInfo& keySendInfo, vector& sc, vector& splitSize); + void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys); void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, const unique_ptr& batch); + void ComputeHotPos(const unique_ptr &batch, map &hotMap, + vector &hotPos, vector &restore, const int hotOffset); + vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); diff --git a/src/core/ock_ctr_common/include/factory.h b/src/core/ock_ctr_common/include/factory.h new file mode 100644 index 00000000..6753462c --- /dev/null +++ b/src/core/ock_ctr_common/include/factory.h 
@@ -0,0 +1,60 @@ +/* + * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * @Description: + * @Version: 1.0 + * @Author: dev + * @Date: 2023-05-5 09:50:00 + * @LastEditors: dev + * @LastEditTime: 2023-05-5 09:50:00 + */ + +#ifndef UNIQUE_OCK_CTR_COMMON_H +#define UNIQUE_OCK_CTR_COMMON_H + +#include +#include +#include +#include "unique.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +using ExternalLog = void (*)(int level, const char *msg); + +#ifdef __cplusplus +} +#endif + +#include "ock_ctr_common_def.h" + +namespace ock { +namespace ctr { +class Factory; + +using FactoryPtr = std::shared_ptr; +using UniquePtr = std::shared_ptr; + +class Factory { +public: + virtual ~Factory() = default; + virtual int CreateUnique(UniquePtr &out) = 0; + virtual int SetExternalLogFuncInner(ExternalLog logFunc) = 0; + +public: + static int Create(FactoryPtr &out) + { + int result = 0; + uintptr_t factory = 0; + /* dynamic load function */ + if ((result = OckCtrCommonDef::CreatFactory(&factory)) == 0) { + out.reset(reinterpret_cast(factory)); + } + return result; + } +}; +} +} + +#endif // UNIQUE_OCK_CTR_COMMON_H diff --git a/src/core/ock_ctr_common/include/ock_ctr_common_def.h b/src/core/ock_ctr_common/include/ock_ctr_common_def.h new file mode 100644 index 00000000..66a50c8b --- /dev/null +++ b/src/core/ock_ctr_common/include/ock_ctr_common_def.h @@ -0,0 +1,58 @@ +/* + * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * @Description: + * @Version: 1.0 + * @Author: dev + * @Date: 2023-05-5 09:50:00 + * @LastEditors: dev + * @LastEditTime: 2023-05-5 09:50:00 + */ + +#ifndef OCK_OCK_CTR_COMMON_DEF_H +#define OCK_OCK_CTR_COMMON_DEF_H + +#include +#include +#include + +using CTR_CREATE_FACTORY_FUNCTION = int (*)(uintptr_t *); + +namespace ock { +namespace ctr { +class OckCtrCommonDef { +public: + static int CreatFactory(uintptr_t *factory) + { + static void *handle = nullptr; + static std::mutex m; + std::unique_lock lock(m); + if (handle != nullptr) { + std::cout << "can't create factory more than 1 time." << std::endl; + return -1; + } + + handle = dlopen(LIBRARY_NAME, RTLD_NOW); + if (handle == nullptr) { + std::cout << "Failed to call dlopen to load library '" << LIBRARY_NAME << "', error " << dlerror() << + std::endl; + return -1; + } + + auto fun = (CTR_CREATE_FACTORY_FUNCTION)dlsym(handle, "CTR_CreateFactory"); + if (fun == nullptr) { + std::cout << "Failed to call dlsym to load function 'CTR_CreateFactory', error " << dlerror() << std::endl; + dlclose(handle); + return -1; + } + + fun(factory); + return 0; + } + +private: + constexpr static const char *LIBRARY_NAME = "lib_ock_ctr_common.so"; +}; +} +} + +#endif // OCK_OCK_CTR_COMMON_DEF_H diff --git a/src/core/ock_ctr_common/include/unique.h b/src/core/ock_ctr_common/include/unique.h new file mode 100644 index 00000000..59ed98b5 --- /dev/null +++ b/src/core/ock_ctr_common/include/unique.h @@ -0,0 +1,124 @@ +/* + * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
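+ * @Note: implementations are resolved at runtime from lib_ock_ctr_common.so
+ *        via the dlopen-based factory declared in factory.h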
+ * @Description: + * @Version: 1.0 + * @Author: dev + * @Date: 2023-05-5 09:50:00 + * @LastEditors: dev + * @LastEditTime: 2023-05-5 09:50:00 + */ + +#ifndef OCK_UNIQUE_H +#define OCK_UNIQUE_H +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +using ExternalThread = void (*)(const std::vector> &tasks); + +#ifdef __cplusplus +} +#endif + +namespace ock { +namespace ctr { +using BucketStrategy = enum class BucketStrategy { + MODULO +}; + +using DataType = enum class DataType { + INT64 = 0, + INT32 +}; + +using OutputType = enum class OutputType { + NORMAL = 0, + ENHANCED +}; + +using UniqueConf = struct UniqueConfCTR { + BucketStrategy bucketStrategy = BucketStrategy::MODULO; + OutputType outputType = OutputType::NORMAL; // 是否为普通unique + DataType dataType = DataType::INT64; // 输入id类型 + bool usePadding = false; // 是否开启padding, 开启前要先开启sharding + bool useIdCount = false; // 是否开启id计数 + bool useSharding = false; // 是否开启sharding + int shardingNum = -1; // 分桶个数 + uint32_t desiredSize = 256; // 预估id所需内存空间 + int paddingSize = 0; // 填充后长度 + int paddingVal = -1; // 填充值 + uint32_t minThreadNum = 1; // 最小工作线程数 + uint32_t maxThreadNum = 8; // 最大工作线程数 + int64_t maxIdVal = 0; // 最大id值 + bool trace = false; // 是否开启性能检测,需要配合外部日志输出 +} __attribute__((packed)); + +using UniqueIn = struct UniqueInCTR { + void *inputId = nullptr; // 输入的ids首地址(需要用户申请)必填 + uint32_t inputIdCnt = 0; // 输入ids的个数 +}; + +using UniqueOut = struct UniqueOutCTR { + void *uniqueId = nullptr; // 去重分桶填充之后最终的的ids(需要用户申请)必选 + uint32_t *index = nullptr; // 去重后id的索引位置(需要用户申请)必选 + int uniqueIdCnt = 0; // 去重后的id个数 +}; + +using EnhancedUniqueOut = struct EnhancedUniqueOutCTR { + void *uniqueId = nullptr; // 去重分桶填充之后最终的的ids(需要用户申请)必选 + uint32_t *index = nullptr; // 去重后id的索引位置(需要用户申请)必选 + void *uniqueIdInBucket = nullptr; // 去重之后的分桶内的ids(需要用户申请) sharding开启之后必须申请 + int *uniqueIdCntInBucket = nullptr; // 每个桶去重后的id个数(需要用户申请) sharding开启之后必须申请 + int uniqueIdCnt = 0; // 去重后的id个数 + int *idCnt = nullptr; // 每个id的重复次数(需要用户申请) 开启idCnt之后必选 + int *idCntFill = nullptr; // 每个id的重复次数带了填充(需要用户申请) 开启idCnt和padding之后必选 +}; + +class Unique { +public: + virtual ~Unique() = default; + /* * + * 初始化unique 所需配置项 + * + * @param conf 输入unique所需的配置 + * @return error_code + */ + virtual int Initialize(const UniqueConf &conf) = 0; + + /* * + * 释放unique资源 + */ + virtual void UnInitialize() = 0; + + /* * + * id去重接口 + * + * @param UniqueIn 入参:unique用户输入 + * @param UniqueOut 出参:unique用户输出 + * @return errorCode + */ + virtual int DoUnique(UniqueIn &uniqueIn, UniqueOut &uniqueOut) = 0; + + /* * + * 具有额外输出的unique + * + * @param uniqueIn + * @param EnhancedUniqueOut + * @return errorCode + */ + virtual int DoEnhancedUnique(UniqueIn &uniqueIn, EnhancedUniqueOut &enhancedUniqueOut) = 0; + + /* * + * 设置外部线程池方法 + * + * @return + */ + virtual int SetExternalThreadFuncInner(ExternalThread threadFunc) = 0; +}; +} +} + +#endif // OCK_UNIQUE_H diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index b63e4bd1..01705dec 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -21,6 +21,7 @@ using std::chrono::system_clock; namespace MxRec { int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; int PerfConfig::maxUniqueThreadNum = DEFAULT_MAX_UNIQUE_THREAD_NUM; + bool PerfConfig::fastUnique = false; RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 3d8b1021..cf94660e 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ 
-50,7 +50,7 @@ namespace MxRec { #define INFO_PTR shared_ptr -#define TIME_PRINT spdlog::info +#define TIME_PRINT spdlog::debug #define MGMT_CPY_THREADS 4 #define PROFILING using namespace tensorflow; @@ -71,6 +71,7 @@ namespace MxRec { struct PerfConfig { static int keyProcessThreadNum; static int maxUniqueThreadNum; + static bool fastUnique; }; constexpr int KEY_PROCESS_TIMEOUT = 120; diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h index ce0c4655..46782131 100644 --- a/src/core/utils/unique.h +++ b/src/core/utils/unique.h @@ -110,6 +110,7 @@ public: } } }; + template class Dedup { static constexpr uint32_t kMinimalWorkloadPerWorker = 1 << 12; static const int kDefaultBucketCount = 1 << 24; diff --git a/src/platform/AccCTR b/src/platform/AccCTR new file mode 160000 index 00000000..a4c7f7e5 --- /dev/null +++ b/src/platform/AccCTR @@ -0,0 +1 @@ +Subproject commit a4c7f7e598334c24e875df13b08877155b6c0451 diff --git a/src/test_ut.sh b/src/test_ut.sh index 26a4423d..1605692d 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -15,6 +15,9 @@ source /opt/rh/devtoolset-7/enable CUR_DIR=$(dirname "$(readlink -f "$0")") ROOT_DIR=$(dirname "${CUR_DIR}") +acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR +cp -rf ../platform/securec/* /usr1/mxRec/src/platform/AccCTR/3rdparty/huawei_secure_c +export LD_LIBRARY_PATH="${acc_ctr_path}"/output/ock_ctr_common/lib:$LD_LIBRARY_PATH compile_securec() { @@ -30,6 +33,16 @@ compile_securec() } compile_securec +compile_acc_ctr_so_file() +{ + cd "${acc_ctr_path}" + chmod u+x build.sh + ./build.sh "release" +} + +echo "-----Build AccCTR -----" +compile_acc_ctr_so_file + cd "${ROOT_DIR}"/src find ./ -name "*.sh" -exec dos2unix {} \; diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 817348a2..0059273c 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -52,6 +52,7 @@ target_link_libraries(test_main PUBLIC ${TF_LIB} securec OpenMP::OpenMP_CXX ${HDF5_CXX_LIBRARIES} ${MPI_CXX_LIBRARIES} ${PYTHON_LIBRARY} drvdsmi_host dcmi + -ldl ) target_link_libraries(test_main PUBLIC diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 0df6411b..61a3745c 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -19,12 +19,40 @@ #include "host_emb/host_emb.h" #include "key_process/key_process.h" #include "hybrid_mgmt/hybrid_mgmt.h" +#include "ock_ctr_common/include/unique.h" using namespace std; using namespace MxRec; using namespace testing; static constexpr size_t BATCH_NUM_EACH_THREAD = 3; +static constexpr int DESIRED_SIZE = 1; +FactoryPtr factory; + +class SimpleThreadPool { +public: + static void SyncRun(const std::vector> &tasks) + { + std::vector> futs; + for (auto &task : tasks) { + futs.push_back(std::async(task)); + } + for (auto &fut : futs) { + fut.wait(); + } + } +}; + +static void CTRLog(int level, const char *msg) +{ + switch (level) { + case 0: + spdlog::debug("{}", msg); + break; + default: + break; + } +} class KeyProcessTest : public testing::Test { protected: @@ -222,6 +250,8 @@ TEST_F(KeyProcessTest, Initialize) for (const EmbInfo& info: embInfos) { ASSERT_NE(process.embInfos.find(info.name), process.embInfos.end()); } + + Factory::Create(factory); } TEST_F(KeyProcessTest, Start) @@ -251,6 +281,7 @@ TEST_F(KeyProcessTest, HashSplit) } ASSERT_THAT(restore, ElementsAreArray(expectRestore)); } + #ifndef GTEST TEST_F(KeyProcessTest, GetScAll) { @@ -267,6 +298,22 @@ TEST_F(KeyProcessTest, GetScAll) 
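+    // Row r of the gathered rankSize x rankSize matrix repeats rank r's per-rank
+    // send count (worldRank + 1), so entry i is expected to be i / worldSize + 1.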
ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); } #endif + +TEST_F(KeyProcessTest, GetScAllForUnique) +{ + vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 + spdlog::debug(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, keyScLocal); + vector expectScAll(worldSize * worldSize); + for (unsigned int i = 0; i < expectScAll.size(); ++i) { + expectScAll[i] = floor(i / worldSize) + 1; + } + ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + ASSERT_EQ(process.isRunning, true); + vector scAll; + process.GetScAllForUnique(keyScLocal, 0, 0, scAll); + ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); +} + TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { auto queue = SingletonQueue::getInstances(0); @@ -359,6 +406,17 @@ TEST_F(KeyProcessTest, Key2Offset) spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); } + +TEST_F(KeyProcessTest, GetUniqueConfig) +{ + UniqueConf uniqueConf; + process.rankInfo.rankSize = worldSize; + process.rankInfo.useStatic = true; + process.GetUniqueConfig(uniqueConf); + process.rankInfo.useStatic = false; + process.GetUniqueConfig(uniqueConf); +} + // 自动化测试用例 // 边界值、重复度测试 TEST_F(KeyProcessTest, ProcessPrefetchTask) @@ -375,3 +433,69 @@ TEST_F(KeyProcessTest, ProcessPrefetchTask) this_thread::sleep_for(20s); process.Destroy(); } + +TEST_F(KeyProcessTest, InitializeUnique) +{ + ASSERT_EQ(Factory::Create(factory), -1); + UniquePtr unique; + ASSERT_EQ(factory->CreateUnique(unique), 0); + + PrepareBatch(); + unique_ptr batch; + batch = process.GetBatchData(0, 0); + UniqueConf uniqueConf; + process.rankInfo.rankSize = worldSize; + process.rankInfo.useStatic = true; + bool uniqueInitialize = false; + size_t preBatchSize = 0; + process.InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); +} + +TEST_F(KeyProcessTest, GetKeySize) +{ + PrepareBatch(); + unique_ptr batch; + batch = process.GetBatchData(0, 0); + process.rankInfo.rankSize = worldSize; + process.rankInfo.useStatic = true; + process.GetKeySize(batch); +} + +TEST_F(KeyProcessTest, ProcessBatchWithUniqueCompute) +{ + PrepareBatch(); + + ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + + auto fn = [this](int channel, int id) { + UniquePtr unique; + + auto embName = embInfos[0].name; + process.hotEmbTotCount[embName] = 10; + vector splitKeys; + vector restore; + vector hotPos; + unique_ptr batch; + UniqueInfo uniqueInfo; + batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue + + ASSERT_EQ(factory->CreateUnique(unique), 0); + UniqueConf uniqueConf; + process.GetUniqueConfig(uniqueConf); + unique->Initialize(uniqueConf); + + spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + process.KeyProcessTaskHelperWithUnique(batch, unique, channel, id); + spdlog::info("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, + hotPos); + }; // for clean code + for (int channel = 0; channel < 1; ++channel) { + for (int id = 0; id < 1; ++id) { + // use lambda expression initialize thread + process.procThreads.emplace_back(std::make_unique(fn, channel, id)); + } + } + this_thread::sleep_for(20s); + process.Destroy(); +} \ No newline at end of file -- Gitee From 4625e986d1d5991c65f0d1de0fdd821dc910b1ea Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 14 Jun 2023 15:06:34 +0800 Subject: [PATCH 133/551] 
Match-id-47bb1aee1d988285da803a95f7614fd5ecc4ddf1

---
 example/little_demo/run.sh | 65 ++++++++++++++++++++++++++++++++------
 1 file changed, 56 insertions(+), 9 deletions(-)

diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
index e4c1cff6..2c0c80b9 100644
--- a/example/little_demo/run.sh
+++ b/example/little_demo/run.sh
@@ -18,10 +18,8 @@
 export PYTHONPATH=${so_path}:$PYTHONPATH # Python installation path of the environment
 export LD_PRELOAD=/usr/lib64/libgomp.so.1 # GNU OpenMP shared library path
 export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
 # Collective communication file; for the format, see "Preparing the Resource Configuration File" in the CANN documentation on the Ascend website.
-export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json" # comment this line out when using the ranktable-free solution
 export JOB_ID=10086
 # Total number of NPUs used by the training job
-export RANK_SIZE=$num_process # comment this line out when using the ranktable-free solution
 export MXREC_LOG_LEVEL="DEBUG" # framework log level
 export TF_CPP_MIN_LOG_LEVEL=3 # TensorFlow log level; 3 means FATAL
 # Set the global and per-module application log levels; see the CANN documentation on the Ascend website
@@ -39,13 +37,62 @@
 export USE_TIMESTAMP=0 # 0: disable feature admit-and-evict; 1: enable feature admit-and-evict
 ################################################
 export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10
-################# enabled when using the ranktable-free solution ######################
-#export CM_CHIEF_IP="192.168.1.1" # IP of the chief node
-#export CM_CHIEF_PORT=6000 # listening port of the chief node
-#export CM_CHIEF_DEVICE=0 # device id of the chief node
-#export CM_WORKER_IP="192.168.1.1" # IP of the current node
-#export CM_WORKER_SIZE=$num_process # number of devices participating in cluster training
-#########################################################
+
+# Help message; no need to modify
+if [[ $1 == --help || $1 == -h ]];then
+    echo "Usage: ./run.sh [OPTION]... [IP]..."
+    echo " "
+    echo "parameters:
+    [OPTION] main.py
+    [IP] IP address of the host
+    -h/--help show help message
+    "
+    exit 1
+fi
+
+# The ranktable solution
+function rankTableSolution() {
+    echo "The ranktable solution"
+    export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json"
+    export RANK_SIZE=$num_process
+    echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
+    if [ ! -f "$RANK_TABLE_FILE" ];then
+        echo "the rank table file does not exist. Please refer to {hccl_json_8p.json} to correctly configure the rank table file"
+        exit 1
+    fi
+}
+
+ip=$2
+if [ ! -n "$ip" ]; then
+    rankTableSolution
+else
+    VALID_CHECK=$(echo $ip|awk -F. '$1<=255&&$2<=255&&$3<=255&&$4<=255{print "yes"}')
+    if echo $ip|grep -E "^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$">/dev/null; then
+        if [ "$VALID_CHECK" == "yes" ]; then
+            ################# the ranktable-free solution ######################
+            echo "ip: $ip available."
+            echo "The ranktable solution is removed."
+            export CM_CHIEF_IP=$ip # IP of the chief node
+            export CM_CHIEF_PORT=6000 # listening port of the chief node
+            export CM_CHIEF_DEVICE=0 # device id of the chief node
+            export CM_WORKER_IP=$ip # IP of the current node
+            export CM_WORKER_SIZE=$num_process # number of devices participating in cluster training
+            echo "CM_CHIEF_IP=$CM_CHIEF_IP"
+            echo "CM_CHIEF_PORT=$CM_CHIEF_PORT"
+            echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
+            echo "CM_WORKER_IP=$CM_WORKER_IP"
+            echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
+            echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES"
+            #########################################################
+        else
+            echo "ip: $ip not available!" # fall back to the ranktable solution
+            rankTableSolution
+        fi
+    else
+        echo "ip: $ip not available!" # fall back to the ranktable solution
+        rankTableSolution
+    fi
+fi

 py=$1
 echo "py is $py"
-- Gitee
From 77a4af1b18ddad2003fe53c83b08d3338255fd41 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 15 Jun 2023 14:19:49 +0800
Subject: [PATCH 134/551] Match-id-44ffbb3949a821a9ba36fe95c7af9fc52df31980

---
 src/ops_tf/hybrid_dataset_ops.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index 20761f3b..34b0d7f5 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -13,6 +13,7 @@

 #include
 #include
+#include <spdlog/sinks/stdout_color_sinks.h>
 #include
 #include
@@ -129,6 +130,8 @@ class ReadEmbKeyV2Dynamic : public OpKernel {
 public:
     explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context)
     {
+        auto logger = spdlog::stderr_color_mt("console");
+        spdlog::set_default_logger(logger);
         spdlog::cfg::load_env_levels();
         spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v");
         spdlog::debug("ReadEmbKeyV2Dynamic init");
@@ -343,6 +346,8 @@ class ReadEmbKeyV2 : public OpKernel {
 public:
     explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context)
    {
+        auto logger = spdlog::stderr_color_mt("console");
+        spdlog::set_default_logger(logger);
         spdlog::cfg::load_env_levels();
         spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v");
         spdlog::debug("ReadEmbKeyV2 init");
@@ -378,7 +383,7 @@
         }
         batchIdsInfo.at(channelId) = 0;
-        const char* threadNumEnv = getenv("THREAD_NUM");
+        const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM");
         if (threadNumEnv != nullptr) {
             threadNum = static_cast<int>(*threadNumEnv) - static_cast<int>('0');
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
-- Gitee
From db2f46ed10b1a785d1548fc3f5e4b96d97bced98 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 15 Jun 2023 16:21:36 +0800
Subject: [PATCH 135/551] Match-id-0d52af2d3adf8f025d9b02b9aae6e636c2219abb

---
 src/ops_tf/hybrid_dataset_ops.cpp | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index 34b0d7f5..77fbeeab 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -130,8 +130,12 @@ class ReadEmbKeyV2Dynamic : public OpKernel {
 public:
     explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context)
     {
-        auto logger = spdlog::stderr_color_mt("console");
-        spdlog::set_default_logger(logger);
+        if (!spdlog::get("console")) {
+            auto logger = spdlog::stderr_color_mt("console");
+            spdlog::set_default_logger(logger);
+        } else {
+            spdlog::set_default_logger(spdlog::get("console"));
+        }
         spdlog::cfg::load_env_levels();
         spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v");
         spdlog::debug("ReadEmbKeyV2Dynamic init");
@@ -346,8 +350,12 @@ class ReadEmbKeyV2 : public OpKernel {
 public:
     explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context)
     {
-        auto logger = spdlog::stderr_color_mt("console");
-        spdlog::set_default_logger(logger);
+        auto logger = spdlog::get("console");
+        if (!logger) {
+            logger = spdlog::stderr_color_mt("console");
+        }
+        spdlog::set_default_logger(spdlog::get("console"));
+
         spdlog::cfg::load_env_levels();
         spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v");
         spdlog::debug("ReadEmbKeyV2 init");
@@ -389,10 +397,7 @@
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
                 throw runtime_error(fmt::format("{} is not valid", threadNum));
             }
-        } else {
-            threadNum =
KEY_PROCESS_THREAD; } - auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -546,7 +551,7 @@ public: int maxStep = 0; bool isTimestamp { false }; bool modifyGraph { false }; - int threadNum = 0; + int threadNum = KEY_PROCESS_THREAD; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); -- Gitee From 922ddcb001bb28447274fc67bcc04d94582bfdf8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 16 Jun 2023 11:04:47 +0800 Subject: [PATCH 136/551] Match-id-0056693675d89b61ad3d3eac2650bc9ed9425598 --- src/core/utils/spinlock.h | 204 ---------- src/core/utils/unique.h | 797 -------------------------------------- 2 files changed, 1001 deletions(-) delete mode 100644 src/core/utils/spinlock.h delete mode 100644 src/core/utils/unique.h diff --git a/src/core/utils/spinlock.h b/src/core/utils/spinlock.h deleted file mode 100644 index 527ef2b5..00000000 --- a/src/core/utils/spinlock.h +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: spinlock module - * Author: MindX SDK - * Create: 2022 - * History: NA - */ -#ifndef SRC_UTILS_SPINLOCK_H -#define SRC_UTILS_SPINLOCK_H - -#include -#include -#include // NOLINT - -static __inline void cpu_pause() -{ -#ifdef __GNUC__ - #ifdef __aarch64__ - __asm volatile("yield" ::: "memory"); -#elif defined(__i386__) || defined(__x86_64__) - __asm__ __volatile__("rep;nop;nop" ::: "memory"); -#else -#error "unknown architecture" -#endif -#else -#error "unknown architecture" -#endif -} - -static constexpr uint16_t g_kMaxSpinCountBeforeThreadYield = 64; - -#ifdef LOCK_NOTHING - -class SpinLock final { - public: - void lock() noexcept {} - - void unlock() noexcept {} - - bool try_lock() noexcept - { - return true; - } -}; - -#elif defined(USE_MUTEX) - -class SpinLock final { -public: - void lock() noexcept - { - mt_.lock(); - } - - bool try_lock() noexcept - { - return mt_.try_lock(); - } - - void unlock() noexcept - { - mt_.unlock(); - } - -private: - std::mutex mt_; -}; - -#else - -class SpinLock final { -public: - SpinLock() = default; - - SpinLock(SpinLock const &) = delete; - SpinLock(SpinLock &&) noexcept = delete; - SpinLock &operator=(SpinLock const &) = delete; - - inline void lock() noexcept - { - while (true) { - if (!lock_.exchange(true, std::memory_order_acquire)) { - break; - } - - uint16_t counter = 0; - while (lock_.load(std::memory_order_relaxed)) { - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } - } - } - - inline bool try_lock() noexcept - { - if (lock_.load(std::memory_order_relaxed)) { - return false; - } - return !lock_.exchange(true, std::memory_order_acquire); - } - - inline void unlock() noexcept - { - lock_.store(false, std::memory_order_release); - } - -private: - std::atomic lock_{false}; -}; - -class RWSpinLock final { - union LockData { - uint64_t raw; - struct { - uint32_t readers; - uint32_t writer; - } lock; - }; - -public: - RWSpinLock() = default; - - RWSpinLock(RWSpinLock const &) = delete; - RWSpinLock(RWSpinLock &&) noexcept = delete; - RWSpinLock &operator=(RWSpinLock const &) = delete; - - inline void r_lock() noexcept - { - LockData oldData; - LockData newData; - while (true) { - uint16_t counter = 0; - for (;;) { - oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.lock.writer <= 0) { - break; - } - 
cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } - - newData.lock.readers = oldData.lock.readers + 1; - newData.lock.writer = 0; - if (lock_.compare_exchange_weak(oldData.raw, newData.raw, - std::memory_order_acquire, - std::memory_order_relaxed)) { - break; - } - } - } - - inline void w_lock() noexcept - { - LockData oldData; - LockData newData; - while (true) { - uint16_t counter = 0; - for (;;) { - oldData.raw = lock_.load(std::memory_order_relaxed); - if (oldData.raw == 0) { - break; - } - cpu_pause(); - if (++counter > g_kMaxSpinCountBeforeThreadYield) { - std::this_thread::yield(); - // reset counter - counter = 0; - } - } - - newData.lock.readers = 0; - newData.lock.writer = 1; - if (lock_.compare_exchange_weak(oldData.raw, newData.raw, - std::memory_order_acquire, - std::memory_order_relaxed)) { - break; - } - } - } - - inline void r_unlock() noexcept - { - --lock_; - } - - inline void w_unlock() noexcept - { - lock_.store(0, std::memory_order_release); - } - -private: - std::atomic lock_{0}; -}; - -#endif -#endif diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h deleted file mode 100644 index 46782131..00000000 --- a/src/core/utils/unique.h +++ /dev/null @@ -1,797 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: unique keys module - * Author: MindX SDK - * Create: 2022 - * History: NA - */ -#ifndef SRC_UTILS_UNIQUE_H -#define SRC_UTILS_UNIQUE_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "securec.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" - -#include "common.h" -#include "spinlock.h" -#include "time_cost.h" - -using namespace MxRec; -using namespace std; - -struct UniqueData { - void *inputData; - size_t dataSize; - int32_t *restore; - int64_t *uniqueVector; - int32_t *splitSize; - int64_t *keySend; - int32_t *idCount; - int32_t *idCountFill; -}; - -struct UniqueFlag { - bool isInt64; - bool useStatic; - bool useHot; -}; - -struct UniqueForHot { - int hotOffset; - int *hotPos; - map &hotMap; - absl::flat_hash_map &keyCountMap; -}; - -struct UniqueThreadNum { - int minThread; - int maxThread; -}; - -namespace SysytemConst { - const int LEVEL1_CACHE = 64; - const int DEFAULT_DEDUPLICATION_RATE = 4; - const int DEDUPLICATION_RATE = 2; - const int HASH_SPLIT_BUCKERT_1 = 16; - const int HASH_SPLIT_BUCKERT_2 = 32; - const int HASH_SPLIT_BUCKERT_3 = 48; - const int PRE_APPLY_MEMORY = 256; -} - -class SendCntTooSmallError : public std::exception { -}; - -class GroupMethod { -public: - inline int GroupCount() - { - return groupCount_; - } - inline uint64_t GroupId(uint64_t val) - { - return val & static_cast(groupCount_ - 1); - } - void SetGroupCount(int count) - { - groupCount_ = count; - } - -private: - int groupCount_; -}; - -class SimpleThreadPool { -public: - void SyncRun(const std::vector> &tasks) - { - std::vector> futs; - for (auto &task : tasks) { - futs.push_back(std::async(task)); - } - for (auto &fut : futs) { - fut.wait(); - } - } -}; - -template class Dedup { - static constexpr uint32_t kMinimalWorkloadPerWorker = 1 << 12; - static const int kDefaultBucketCount = 1 << 24; - static const int kDefaultBucketCountMask = kDefaultBucketCount - 1; - - template struct Meta { - static_assert(M <= MxRec::UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width"); - SpinLock lock; - volatile 
int8_t count; - int8_t pad[3]; - int32_t replace_base; - volatile uint64_t data[M]; - std::atomic idCount[M]; - } __attribute__((__aligned__(64))); - - struct Statistics { - uint64_t totalUniques = 0; - uint64_t totalOverflowUniques = 0; - }; - -public: - Dedup(int bucketCountPower2 = kDefaultBucketCount, int groups = 1) - : bucketCount_(bucketCountPower2), bucketCountMask_(bucketCount_ - 1), groupCount_(groups) - { - void *area = aligned_alloc(SysytemConst::LEVEL1_CACHE, sizeof(Meta) * bucketCount_); - table_ = reinterpret_cast *>(area); - Clear(bucketCount_); - } - - ~Dedup() - { - free(table_); - } - - static size_t BucketSize() - { - return sizeof(Meta); - } - - void Insert(uint64_t val) - { - int32_t h = static_cast(hash(val) & bucketCountMask_); - Meta *bucket = &table_[h]; - - int8_t count = bucket->count; - - int totalCount = 0; - - for (int i = 0; i < count; ++i) { - if (bucket->data[totalCount] == val) { - bucket->idCount[totalCount]++; - // found one - return; - } - totalCount++; - } - // try again, this time with lock acquired - if (count < N) { - std::lock_guard lg(bucket->lock); - for (int i = totalCount; i < bucket->count; ++i) { - if (bucket->data[totalCount] == val) { - bucket->idCount[totalCount]++; - // found one - return; - } - totalCount++; - } - if (totalCount < N) { - bucket->data[totalCount] = val; - bucket->count++; - bucket->idCount[totalCount]++; - return; - } - } - // shift to the overflow reservior - insertOverflow(val); - } - - int32_t GetReplaceOffset(uint64_t val) - { - int32_t h = static_cast(hash(val) & bucketCountMask_); - Meta *bucket = &table_[h]; - - int8_t count = bucket->count; - int totalCount = 0; - for (int i = 0; i < count; ++i) { - if (bucket->data[totalCount] == val) { - // found one - return bucket->replace_base + totalCount; - } - totalCount++; - } - // try again, this time with lock acquired - if (count < N) { - std::lock_guard lg(bucket->lock); - for (int i = totalCount; i < bucket->count; ++i) { - if (bucket->data[totalCount] == val) { - return bucket->replace_base + totalCount; - } - totalCount++; - } - if (totalCount < N) { - return -1; - } - } - return getReplaceOffsetFromOverflow(val); - } - - int32_t GetReplaceOffsetUnsafe(uint64_t val) - { - int32_t h = static_cast(hash(val) & bucketCountMask_); - Meta *bucket = &table_[h]; - - int totalCount = 0; - for (int i = 0; i < bucket->count; ++i) { - if (bucket->data[totalCount] == val) { - // found one - return bucket->replace_base + totalCount; - } - totalCount++; - } - if (totalCount < N) { - return -1; - } - return getReplaceOffsetFromOverflowUnsafe(val); - } - - bool Contains(uint64_t val) - { - int32_t h = static_cast(hash(val) & bucketCountMask_); - Meta *bucket = &table_[h]; - { - std::lock_guard lg(bucket->lock); - int totalCount = 0; - for (int i = 0; i < bucket->count; ++i) { - if (bucket->data[totalCount] == val) { - return true; - } - totalCount++; - } - if (totalCount < N) { - // bucket isn't filled, no hit for sure - return false; - } - } - return checkOverflow(val); - } - - void Clear(uint64_t newBucketCountPowerOf2 = 0) - { - std::lock_guard lg(overflowMutex_); - if (newBucketCountPowerOf2 > 0 && newBucketCountPowerOf2 != (uint64_t)bucketCount_) { - free(table_); - bucketCount_ = static_cast(newBucketCountPowerOf2); - bucketCountMask_ = bucketCount_ - 1; - table_ = reinterpret_cast *>(aligned_alloc(SysytemConst::LEVEL1_CACHE, - sizeof(Meta) * bucketCount_)); - } - bzero(table_, sizeof(Meta) * bucketCount_); - overflow_.clear(); - idCountOverflow_.clear(); - } - - void 
NewParameter() - { - int32_t newBucketCountPowerOf2 = bucketCount_; - - if (stats_.totalUniques > 0 && stats_.totalOverflowUniques > kMinimalWorkloadPerWorker) { - // Time to check the proper size of sharded tables for performance - // sake. - uint64_t shardedTableSize = newBucketCountPowerOf2 * N * groupCount_; - int largeCount = 0; - while (shardedTableSize > stats_.totalUniques * SysytemConst::DEFAULT_DEDUPLICATION_RATE && - largeCount_ != 1) { - // too large - newBucketCountPowerOf2 >>= 1; - shardedTableSize >>= 1; - largeCount++; - } - - int count = ((largeCount == 1) && (largeCount != largeCount_)) ? 2 : 1; - for (int i = 0; i < count; i++) { - if (stats_.totalOverflowUniques > kMinimalWorkloadPerWorker) { - newBucketCountPowerOf2 <<= 1; - shardedTableSize <<= 1; - } - } - - while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> SysytemConst::DEDUPLICATION_RATE)) { - newBucketCountPowerOf2 <<= 1; - shardedTableSize <<= 1; - } - - if (largeCount_ != 1) { - largeCount_ = largeCount; - } - } - - Clear(newBucketCountPowerOf2); - bucketCount_ = newBucketCountPowerOf2; - stats_.totalUniques = 0; - stats_.totalOverflowUniques = 0; - } - - // Warning: functions below are not thread safe! - // Return the unique values - // Also update the hash-order base of each bucket - std::vector Unique() - { - int32_t replace_offset = 0; - std::vector output; - - for (int i = 0; i < bucketCount_; ++i) { - Meta *bucket = &table_[i]; - if (bucket->count == 0) { // 如果桶为0,则跳过 - continue; - } - bucket->replace_base = replace_offset; // 取桶的偏移量 - for (int j = 0; j < bucket->count; ++j) { - auto data = bucket->data[j]; - output.push_back(data); // 挨个桶取数据,然后填到output中去 - } - replace_offset += bucket->count; - } - auto it = overflow_.begin(); // 取overflow里面的,也添加到output中去 - while (it != overflow_.end()) { - output.push_back(it->first); - it->second = replace_offset++; // 记录偏移量++ - ++it; - } - return output; - } - - // Used by ShardedDedup Only! 
- uint32_t UniqueRaw(int64_t *output, uint32_t priorTotal, int32_t *idCount) - { - uint32_t total = priorTotal; - int32_t replace_offset = priorTotal; - - for (int i = 0; i < bucketCount_; ++i) { - Meta *bucket = &table_[i]; - if (bucket->count == 0) { - continue; - } - bucket->replace_base = replace_offset; - for (int j = 0; j < bucket->count; ++j) { - idCount[total] = static_cast(bucket->idCount[j]); - output[total++] = bucket->data[j]; - } - replace_offset += bucket->count; - } - auto it = overflow_.begin(); - int32_t totalOverflow = 0; - while (it != overflow_.end()) { - idCount[total] = idCountOverflow_[it->first]; - output[total++] = it->first; - it->second = replace_offset++; - ++it; - ++totalOverflow; - } - - // set total overflow count - stats_.totalUniques = total - priorTotal; - stats_.totalOverflowUniques = totalOverflow; - return total - priorTotal; - } - - void handleHotKey(int key, map &hotMap, map &hotPosMap, int &hotCount) - { - auto hot = hotMap.find(key); - if (hot != hotMap.end()) { - if (hot->second == -1) { - int pos = hotCount; - hotMap[key] = pos; - hotPosMap[key] = pos; - hotCount++; - } else { - hotPosMap[key] = -1; - } - } - } - - uint32_t UniqueRawForHot(int64_t *output, uint32_t priorTotal, int32_t* idCount, - map &hotMap, map &hotPosMap, int &hotCount, - absl::flat_hash_map &keyCountMap) - { - uint32_t total = priorTotal; - int32_t replace_offset = priorTotal; - - for (int i = 0; i < bucketCount_; ++i) { - Meta *bucket = &table_[i]; - if (bucket->count == 0) { - continue; - } - bucket->replace_base = replace_offset; - for (int j = 0; j < bucket->count; ++j) { - idCount[total] = static_cast(bucket->idCount[j]); - output[total++] = bucket->data[j]; - handleHotKey(static_cast(bucket->data[j]), hotMap, hotPosMap, hotCount); - keyCountMap[bucket->data[j]] = static_cast(bucket->idCount[j]); - } - replace_offset += bucket->count; - } - auto it = overflow_.begin(); - int32_t totalOverflow = 0; - while (it != overflow_.end()) { - idCount[total] = idCountOverflow_[it->first]; - keyCountMap[it->first] = idCountOverflow_[it->first]; - output[total++] = it->first; - handleHotKey(static_cast(it->first), hotMap, hotPosMap, hotCount); - it->second = replace_offset++; - ++it; - ++totalOverflow; - } - - // set total overflow count - stats_.totalUniques = total - priorTotal; - stats_.totalOverflowUniques = totalOverflow; - return static_cast(total - priorTotal); - } - - std::vector Replacement(const std::vector &input, std::vector *unique = nullptr, - int32_t base = 0) - { - std::vector output; - if (unique) { - *unique = std::move(Unique()); - } - for (auto &val : input) { - output.push_back(GetReplaceOffsetUnsafe(val) + base); - } - return output; - } - -private: - int bucketCount_; - int bucketCountMask_; - int upperRangeIndex_; - int groupCount_; - int largeCount_ { 0 }; - Meta *table_; - std::unordered_map overflow_; - std::unordered_map idCountOverflow_; - SpinLock overflowMutex_; - Statistics stats_; - - static inline uint64_t hash(uint64_t val) - { - return val ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_1) ^ (val >> SysytemConst::HASH_SPLIT_BUCKERT_2) ^ - (val >> SysytemConst::HASH_SPLIT_BUCKERT_3); - } - - void insertOverflow(uint64_t val) - { - std::lock_guard lg(overflowMutex_); - auto it = overflow_.find(val); - if (it == overflow_.end()) { - overflow_[val] = 0; - } - idCountOverflow_[val]++; - } - - bool checkOverflow(uint64_t val) - { - std::lock_guard lg(overflowMutex_); - return overflow_.find(val) != overflow_.end(); - } - - int32_t 
getReplaceOffsetFromOverflow(uint64_t val) - { - std::lock_guard lg(overflowMutex_); - auto it = overflow_.find(val); - return (it != overflow_.end()) ? it->second : -1; - } - - int32_t getReplaceOffsetFromOverflowUnsafe(uint64_t val) - { - auto it = overflow_.find(val); - return (it != overflow_.end()) ? it->second : -1; - } -}; // Dedup - -class OneSimpleGroupMethod { -public: - inline int GroupCount() - { - return 1; - } - inline int32_t GroupId(uint64_t val) - { - spdlog::info("val:{} is", val); - return 0; - } -}; - -template class ShardedDedup { - static constexpr uint32_t kMinimalWorkloadPerWorker = 1 << 13; - static constexpr int kDefaultDuplicateRatio = 4; - static constexpr int kMinimalWorkerCount = 2; - static constexpr int kMaximalWorkerCount = 32; - -public: - using DedupT = Dedup; - - ShardedDedup(const GroupMethod &groupMethod, int desiredSize, int send_cnt, - int estimatedDuplicateRatio = kDefaultDuplicateRatio) - : groupMethod_(groupMethod), bucketCountPower2_(SysytemConst::PRE_APPLY_MEMORY), send_cnt_(send_cnt) - { - const int numOfGroupsInShard = groupMethod_.GroupCount(); - - desiredSize += (desiredSize >> 1); - while (bucketCountPower2_ * BucketWidth * numOfGroupsInShard * estimatedDuplicateRatio < desiredSize) { - bucketCountPower2_ <<= 1; - } - for (int32_t i = 0; i < numOfGroupsInShard; ++i) { - dedupShards_.emplace_back(new DedupT(bucketCountPower2_, numOfGroupsInShard)); - } - } - - ~ShardedDedup() {} - - int NumOfGroupsInEachShard() const - { - return static_cast(groupMethod_.GroupCount()); - } - - /* * - * @brief given the input vector, compute unique values and partition - * them into regions delimited by the partition boundaries passed - * as ctor input (see above) - * - * - * @param pool thread pool which is used by unique task - * @param input the data input - * @param size the size of the data input - * @param uniqueVector unique values - * @param uniqueSize unique of sizes - * @param output the output vector of index values - * @param uniqueIds unique ids final - * @param idCount key count - * @param idCountFill key count and filled zero by send count - * @param isStatic output and idCount Fill isFilled - * @param isInt64 input data is int64 or int32 - * @param useHot hot embedding - * @param offset add hot map size - * @param hotMap hot key map - * @param keyCountMap record key count - * @param minThreadCount min thread number - * @param maxThreadCount max thread number - */ - template - int Compute(ThreadPool *pool, UniqueData &uniqueData, UniqueFlag &uniqueFlag, - UniqueForHot &uniqueForHot, UniqueThreadNum &uniqueThreadNum) - { - // Now kick off the computation - - void *input = uniqueData.inputData; - const size_t size = uniqueData.dataSize; - int64_t *uniqueVector = uniqueData.uniqueVector; - int32_t *uniqueSize = uniqueData.splitSize; - int32_t *output = uniqueData.restore; - int64_t *uniqueIds = uniqueData.keySend; - int32_t *idCount = uniqueData.idCount; - int32_t *idCountFill = uniqueData.idCountFill; - - map &hotMap = uniqueForHot.hotMap; - absl::flat_hash_map &keyCountMap = uniqueForHot.keyCountMap; - int offset = uniqueForHot.hotOffset; - int *hotPos = uniqueForHot.hotPos; - - bool useStatic = uniqueFlag.useStatic; - bool useHot = uniqueFlag.useHot; - bool isInt64 = uniqueFlag.isInt64; - - uint32_t minThreadCount = uniqueThreadNum.minThread; - uint32_t maxThreadCount = uniqueThreadNum.maxThread; - - std::vector uniqueSizeVector; - uniqueSizeVector.resize(groupMethod_.GroupCount()); - - size_t inputSize = size; - - uint32_t threadNum = 
static_cast(inputSize + kMinimalWorkloadPerWorker - 1) / kMinimalWorkloadPerWorker; - threadNum = std::min(maxThreadCount, std::max(threadNum, minThreadCount)); - - size_t partSize = (inputSize + threadNum - 1) / threadNum; - - std::vector> tasks; - - for (uint32_t i = 0; i < threadNum; ++i) { - const int numOfGroupsInShard = groupMethod_.GroupCount(); - tasks.push_back([this, i, input, inputSize, partSize, numOfGroupsInShard, isInt64]() -> TaskReturnType { - for (uint64_t j = i * partSize; j < std::min(inputSize, (i + 1) * partSize); ++j) { - auto val = isInt64 ? ((int64_t *)input)[j] : ((int32_t *)input)[j]; - auto group = groupMethod_.GroupId(val); - dedupShards_[group]->Insert(val); - } - return TaskReturnType {}; - }); - } - spdlog::debug("unique finish insert"); - - if (!tasks.empty()) { - pool->SyncRun(tasks); - } - - std::vector baseVector; - // Collect Unique and base vectors - uint32_t base = 0; - uint32_t total = 0; - - int hotCount = 0; - map hotPosMap; - - for (int j = 0; j < groupMethod_.GroupCount(); ++j) { - uint32_t inGroupTotal = 0; - if (useHot) { - inGroupTotal = dedupShards_[j]->UniqueRawForHot(uniqueVector, total, idCount, - hotMap, hotPosMap, hotCount, - keyCountMap); - } else { - inGroupTotal = dedupShards_[j]->UniqueRaw(uniqueVector, total, idCount); - } - uniqueSizeVector[j] = inGroupTotal; - total += inGroupTotal; - } - - spdlog::debug("unique finish uniqueRaw"); - - baseVector.push_back(base); - base += total; - - partSize = ((partSize) + 63ul) & ~63ul; - - int32_t *beginPtr = output; - int32_t *finishPtr = beginPtr + inputSize; - - int32_t *partBeginPtr = beginPtr; - int32_t *partEndPtr = - reinterpret_cast(((reinterpret_cast(partBeginPtr + partSize)) + 63ul) & ~63ul); - - if (uniqueFlag.useStatic) { - for (int i = 0; i < groupMethod_.GroupCount(); i++) { - if (send_cnt_ < uniqueSizeVector[i]) { - spdlog::error("sendCnt should not be smaller than uniqueSize, sendCnt {}, uniqueSize {}", send_cnt_, - uniqueSizeVector[i]); - } - } - } - - std::vector totalUniqueSize; - totalUniqueSize.resize(groupMethod_.GroupCount()); - - size_t totalNumber = 0; - for (int i = 0; i < groupMethod_.GroupCount(); i++) { - totalUniqueSize[i] = totalNumber; - totalNumber += uniqueSizeVector[i]; - } - spdlog::debug("uniqueSize: {}", totalNumber); - - tasks.clear(); - while (partBeginPtr < finishPtr) { - if (partEndPtr > finishPtr) { - partEndPtr = finishPtr; - } - if (partBeginPtr < partEndPtr) { - // Due to cacheline alignment computation, the actual number of - // threads created here may not match threadNum exactly but - // should be +/-1 off. - const int numOfGroupsInShard = groupMethod_.GroupCount(); - tasks.push_back([this, input, &baseVector, beginPtr, partBeginPtr, partEndPtr, numOfGroupsInShard, - totalUniqueSize, useStatic, isInt64, useHot, offset, hotMap, hotPos, - hotPosMap]() -> TaskReturnType { - for (int32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) { - auto val = isInt64 ? 
((int64_t *)input)[ptr - beginPtr] : ((int32_t *)input)[ptr - beginPtr]; - int32_t group = static_cast(groupMethod_.GroupId(val)); - uint32_t fillOffset = GetFillOffset(useStatic, baseVector, totalUniqueSize, val, group); - ComputeRestore(useHot, offset, hotMap, hotPos, hotPosMap, ptr, val, fillOffset); - } - return TaskReturnType {}; - }); - } - partBeginPtr = partEndPtr; - partEndPtr += partSize; - } - - if (!tasks.empty()) { - pool->SyncRun(tasks); - } - - TileAndFill(groupMethod_.GroupCount(), uniqueVector, uniqueSize, uniqueIds, idCount, idCountFill, useStatic, - uniqueSizeVector); - - return 0; - } - - void ComputeRestore(bool useHot, int offset, const map &hotMap, int *hotPos, - const map &hotPosMap, - int32_t *ptr, int64_t val, uint32_t fillOffset) const - { - auto hot = hotPosMap.find(val); - if (!useHot) { - *ptr = fillOffset; - } else if (hot == hotPosMap.end()) { - *ptr = offset + fillOffset; - } else if (hot->second == -1) { - *ptr = hotMap.find(val)->second; - } else { - hotPos[hot->second] = fillOffset; - *ptr = hot->second; - } - } - - uint32_t GetFillOffset(bool useStatic, const vector &baseVector, const vector &totalUniqueSize, - int64_t val, int32_t group) - { - if (!useStatic) { - return static_cast(dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0]); - } else { - return static_cast(dedupShards_[group]->GetReplaceOffsetUnsafe(val) + baseVector[0] + send_cnt_ * group - totalUniqueSize[group]); - } - } - - void TileAndFill(int groupCount, const int64_t *uniqueVector, int32_t *uniqueSize, int64_t *uniqueIds, - const int32_t *idCount, int32_t *idCountFill, bool useStatic, - const std::vector &uniqueSizeVector) const - { - int start = 0; - int index = 0; - - for (int i = 0; i < groupCount; i++) { - if (i > 0) { - index += static_cast(uniqueSizeVector[i - 1]); - } - - if (useStatic) { - start = i * send_cnt_; - } else { - start = index; - } - - if (uniqueSizeVector[i] > 0) { - size_t mem_size = uniqueSizeVector[i] * sizeof(int64_t); - auto rc = memcpy_s(uniqueIds + start, mem_size, uniqueVector + index, mem_size); - if (rc != 0) { - spdlog::error("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}", mem_size); - throw std::runtime_error( - fmt::format("[TileAndFill/uniqueIds] memcpy_s failded... mem_size: {}", mem_size).c_str()); - } - mem_size = uniqueSizeVector[i] * sizeof(int32_t); - rc = memcpy_s(idCountFill + start, mem_size, idCount + index, mem_size); - if (rc != 0) { - spdlog::error("[TileAndFill/idCountFill] memcpy_s failded... mem_size: {}", mem_size); - throw std::runtime_error(fmt::format("[TileAndFill/idCountFill] memcpy_s failded... 
mem_size: {}", - mem_size).c_str()); - } - } - - long int fillLen = send_cnt_ - uniqueSizeVector[i]; - if (useStatic) { - for (int j = 0; j < fillLen; j++) { - uniqueIds[start + uniqueSizeVector[i] + j] = -1; - idCountFill[start + uniqueSizeVector[i] + j] = 0; - } - } - - uniqueSize[i] = static_cast(uniqueSizeVector[i]); - } - } - - void StartNewRound() - { - for (auto &s : dedupShards_) { - s->NewParameter(); - } - } - -private: - GroupMethod groupMethod_; - int32_t bucketCountPower2_; - std::vector> dedupShards_; - int32_t send_cnt_; -}; -#endif -- Gitee From 761d2d927d4b56fa59eafbadbb34e7d08b18f288 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 16 Jun 2023 14:21:42 +0800 Subject: [PATCH 137/551] Match-id-8bed95c6e40a6a593e411811ac2b5bbcc06a8668 --- .gitignore | 1 + MANIFEST.in | 2 + build/build.sh | 1 + example/little_demo/run.sh | 5 +- setup.py | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 14 +- src/core/key_process/key_process.cpp | 70 +++--- src/core/key_process/key_process.h | 4 +- src/ops_tf/hybrid_dataset_ops.cpp | 19 +- src/tests/key_process/key_process_test.cpp | 4 +- tools/perf/fast.sh | 264 +++++++++++++++++++++ 11 files changed, 341 insertions(+), 45 deletions(-) create mode 100755 MANIFEST.in create mode 100755 tools/perf/fast.sh diff --git a/.gitignore b/.gitignore index 990936c2..d18f1c3f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea/ cmake-build-debug/ +.vscode/ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100755 index 00000000..3fc1dbb6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include mx_rec/tools/* +include mx_rec/tools/*/* \ No newline at end of file diff --git a/build/build.sh b/build/build.sh index fb684738..1b7a8699 100644 --- a/build/build.sh +++ b/build/build.sh @@ -169,6 +169,7 @@ gen_wheel_file() touch "${src_path}"/libasc/__init__.py remove "${ROOT_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec + cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec python3 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 2c0c80b9..f5c4e739 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -34,9 +34,10 @@ export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 export USE_MULTI_LOOKUP=1 # 0:一表一查;1:一表多查 export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 export USE_TIMESTAMP=0 # 0:关闭特征准入淘汰;1:开启特征准入淘汰 -################################################ - +################# 性能调优相关 #################### export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 +export FAST_UNIQUE=0 #if use fast unique +################################################ # 帮助信息,不需要修改 if [[ $1 == --help || $1 == -h ]];then diff --git a/setup.py b/setup.py index 36b406fe..4a169a7c 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ setup( ), package_dir={}, # other file - package_data={'': ['*.yml', '*.sh', '*.so*']}, + package_data={'': ['tools/*', 'tools/*/*', '*.yml', '*.sh', '*.so*']}, # dependency python_requires='>=3.7.5' ) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index aebfd006..0cb99214 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -422,7 +422,7 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) infoVecs->pop_back(); restoreQueue->Pushv(*infoVecs); } - TIME_PRINT("getAllTensorTC TimeCost(ms):{}", getAllTensorTC.ElapsedMS()); + TIME_PRINT("getAllTensorTC(ms):{}", getAllTensorTC.ElapsedMS()); } 
batchId++; return true; @@ -446,7 +446,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", batchId, embInfo.name, channelId); - TimeCost sendTensorTC; + TimeCost sendTensorsTC; omp_set_num_threads(SEND_TENSOR_TYPE_NUM); #pragma omp parallel sections { @@ -457,7 +457,7 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); } - TIME_PRINT("LOOKUP Send TimeCost(ms):{}", sendLookupTC.ElapsedMS()); + TIME_PRINT("sendLookupTC(ms):{}", sendLookupTC.ElapsedMS()); } #pragma omp section { @@ -466,10 +466,10 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) auto restore = restoreQueue->WaitAndPop(); hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); } - TIME_PRINT("RESTORE Send TimeCost(ms):{}", sendRestoreTC.ElapsedMS()); + TIME_PRINT("sendRestoreTC(ms):{}", sendRestoreTC.ElapsedMS()); } } - TIME_PRINT("sendTensorTC TimeCost(ms):{}", sendTensorTC.ElapsedMS()); + TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); } batchId++; return true; @@ -526,6 +526,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, if (lookupKeys.empty()) { remainBatchOut = false; } + + TimeCost getAndSendTensorsTC; auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); vector tmpData; @@ -537,6 +539,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } + TIME_PRINT("getAndSendTensorsTC(ms):{}", getAndSendTensorsTC.ElapsedMS()); + if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, batchId, iBatch, lookupKeys.size()); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 9171d327..42680f2e 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -266,11 +266,8 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES spdlog::stopwatch sw; try { while (true) { - TimeCost getAndProcesTC; - TimeCost getBatchTC; + TimeCost getAndProcessTC; batch = GetBatchData(channel, id); // get batch data from SingletonQueue - TIME_PRINT("GetBatchData TimeCost(ms):{}", getBatchTC.ElapsedMS()); - if (batch == nullptr) { break; } @@ -280,7 +277,7 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES bool ret = false; if (PerfConfig::fastUnique) { InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); - ret = KeyProcessTaskHelperWithUnique(batch, unique, channel, id); + ret = KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id); } else { ret = KeyProcessTaskHelper(batch, channel, id); } @@ -288,8 +285,8 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES if (!ret) { break; } - spdlog::info(KEY_PROCESS "key process cost:{}, get data time:{} batch {}[{}]:{} ", - Format2Ms(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + spdlog::info(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}, get data 
time:{} batch {}[{}]:{} ", + getAndProcessTC.ElapsedMS(), Format2Ms(sw), getBatchTime, batch->name, batch->channel, batch->batchId); auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } @@ -306,6 +303,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < vector & restore, vector & hotPos, vector >& keyCount) { + TimeCost UniqueTC; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { tie(splitKeys, restore, keyCount) = HashSplit_withFAAE(batch); // 按存储dev id切分并去重 @@ -316,18 +314,19 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 } } + TIME_PRINT("UniqueTC(ms):{}", UniqueTC.ElapsedMS()); } -bool KeyProcess::KeyProcessTaskHelperWithUnique(unique_ptr& batch, UniquePtr& unique, - int channel, int id) +bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, + int channel, int id) { // tuple for keyRec restore hotPos scAll countRecv isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; - TimeCost tc; + TimeCost fastUniqueTC; UniqueInfo uniqueInfo; - ProcessBatchWithUniqueCompute(batch, unique, id, uniqueInfo); - TIME_PRINT("no copy ProcessBatchWithUniqueCompute TimeCost(ms):{}", tc.ElapsedMS()); + ProcessBatchWithFastUnique(batch, unique, id, uniqueInfo); + TIME_PRINT("ProcessBatchWithFastUnique(ms):{}", fastUniqueTC.ElapsedMS()); // 特征准入&淘汰 if (isWithFAAE && @@ -342,9 +341,9 @@ bool KeyProcess::KeyProcessTaskHelperWithUnique(unique_ptr& batch, // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) if (rankInfo.noDDR) { - TimeCost key2OffsetTc; + TimeCost key2OffsetTC; Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv); - TIME_PRINT("Key2Offset TimeCost(ms):{}", key2OffsetTc.ElapsedMS()); + TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count auto embName = batch->name; @@ -367,9 +366,9 @@ bool KeyProcess::KeyProcessTaskHelperWithUnique(unique_ptr& batch, tensors->push_back(Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); } } - TimeCost pushTensorTc; + TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); - TIME_PRINT("pushTensorToListTC TimeCost(ms):{}", pushTensorTc.ElapsedMS()); + TIME_PRINT("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); return true; } @@ -380,17 +379,16 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe vector hotPos; vector> keyCount; - TimeCost hashSplit_tc; HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); - auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, id, splitKeys); + vector countRecv; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { countRecv = GetCountRecv(batch, id, keyCount, scAll, ss); } + BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size())); - TIME_PRINT("HashSplit TimeCost(ms):{}", hashSplit_tc.ElapsedMS()); // 特征准入&淘汰 if (m_featureAdmitAndEvict.GetFunctionSwitch() && @@ -407,6 +405,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe if (rankInfo.noDDR) { Key2Offset(batch->name, lookupKeys); } + if (!rankInfo.useStatic) { // Static 
all2all,need send count auto embName = batch->name; if (batch->modifyGraph) { @@ -415,6 +414,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe SendA2A(scAll, embName, batch->channel, batch->batchId); } + TimeCost pushResultTC; auto tensors = make_unique>(); tensors->push_back(Vec2TensorI32(restore)); if (rankInfo.useHot) { @@ -429,12 +429,14 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe } } PushResult(batch, move(tensors), lookupKeys); + TIME_PRINT("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); return true; } vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss) { + TimeCost getCountRecvTC; if (rankInfo.useStatic) { for (auto& cnt: keyCount) { cnt.resize(GetSendCount(batch->name, batch->channelName, batch->modifyGraph), 0); @@ -457,6 +459,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, countRecv.resize(rs.back() + rc.back()); MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); + TIME_PRINT("getCountRecvTC(ms)(with-all2all):{}", getCountRecvTC.ElapsedMS()); return countRecv; } @@ -541,8 +544,8 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &batch, UniquePtr& unique, - int id, UniqueInfo& uniqueInfoOut) +void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, + int id, UniqueInfo& uniqueInfoOut) { EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) @@ -573,17 +576,17 @@ void KeyProcess::ProcessBatchWithUniqueCompute(const unique_ptr &ba uniqueOut.uniqueIdInBucket = reinterpret_cast(uniqueVector.data()); uniqueOut.uniqueIdCnt = 0; - TimeCost unique_tc; + TimeCost uniqueTC; int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut); EASY_END_BLOCK - TIME_PRINT("UniqueCompute TimeCost(ms):{} ret:{}", unique_tc.ElapsedMS(), ret); + TIME_PRINT("FastUniqueCompute(ms):{}, ret:{}", uniqueTC.ElapsedMS(), ret); vector sc; HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize); All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); - spdlog::debug(KEY_PROCESS "ProcessBatchWithUniqueCompute get batchId:{}, batchSize:{}, channel:{}, " + spdlog::debug(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{}, channel:{}, " "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->Size(), batch->channel, batch->channelName, batch->name, uniqueInfoOut.restore.size(), keySendInfo.keyCount.size()); @@ -644,9 +647,9 @@ void KeyProcess::ComputeHotPos(const unique_ptr &batch, map& sc, int id, int channel, KeySendInfo& keySendInfo, All2AllInfo& all2AllInfoOut) { - TimeCost get_sc_all; + TimeCost getScAllTC; GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) - TIME_PRINT("GetScAll TimeCost(ms):{}", get_sc_all.ElapsedMS()); + TIME_PRINT("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS()); TimeCost all2allTC; auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 @@ -673,6 +676,7 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys) -> tuple, vector> { + TimeCost processSplitKeysTC; EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) spdlog::info(KEY_PROCESS "ProcessSplitKeys start batchId:{}, channel:{}", 
batch->batchId, batch->channel); @@ -696,7 +700,11 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, keySend.insert(keySend.end(), i.begin(), i.end()); } keys_t keyRecv; + + TimeCost getScAllTC; auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 + TIME_PRINT("getScAllTC(ms)(AllReduce-AllGather):{}", getScAllTC.ElapsedMS()); + auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc; // receive count for (int i = 0; i < rankInfo.rankSize; ++i) { @@ -708,13 +716,17 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, batch->name); EASY_BLOCK("all2all") + + TimeCost uniqueAll2AllTC; MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); + TIME_PRINT("uniqueAll2AllTC(ms):{}", uniqueAll2AllTC.ElapsedMS()); + EASY_END_BLOCK spdlog::trace(KEY_PROCESS "MPI_Alltoallv finish. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, batch->name); - + TIME_PRINT("processSplitKeysTC(ms):{}", processSplitKeysTC.ElapsedMS()); return { keyRecv, scAll, ss }; } @@ -981,6 +993,7 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) { + TimeCost key2OffsetTC; EASY_FUNCTION(profiler::colors::Blue600) std::lock_guard lk(mut); // lock for PROCESS_THREAD auto& key2Offset = keyOffsetMap[embName]; @@ -1027,6 +1040,7 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) throw std::runtime_error("dev cache overflow!"); } spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); + TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } /* @@ -1039,6 +1053,7 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, vector& restoreVec, int hotPosSize) const { + TimeCost buildRestoreVecTC; EASY_FUNCTION() int hotNum = 0; bool spdDebug = (spdlog::get_level() == spdlog::level::debug); @@ -1052,6 +1067,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec } } spdlog::debug("hot num in all:{}/{}", hotNum, batch->Size()); + TIME_PRINT("buildRestoreVecTC(ms):{}", buildRestoreVecTC.ElapsedMS()); } class EmptyList : public std::exception { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 5486b2fc..6efc4181 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -127,7 +127,7 @@ namespace MxRec { bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int id); - bool KeyProcessTaskHelperWithUnique(unique_ptr &batch, UniquePtr& unique, + bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, UniquePtr& unique, int channel, int id); auto ProcessSplitKeys(const unique_ptr& batch, int id, @@ -138,7 +138,7 @@ namespace MxRec { void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, const unique_ptr & batch, UniquePtr& unique); - void ProcessBatchWithUniqueCompute(const unique_ptr &batch, UniquePtr& unique, + void ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut); size_t GetKeySize(const unique_ptr &batch); diff --git a/src/ops_tf/hybrid_dataset_ops.cpp 
b/src/ops_tf/hybrid_dataset_ops.cpp
index 77fbeeab..072f554a 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -22,15 +22,18 @@
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/example/example.pb.h"
+
 #include "key_process/key_process.h"
 #include "key_process/feature_admit_and_evict.h"
 #include "utils/common.h"
 #include "utils/safe_queue.h"
 #include "utils/singleton.h"
+#include "utils/time_cost.h"
 
 using namespace tensorflow;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
+
 using namespace std;
 using namespace chrono;
 using namespace MxRec;
@@ -220,10 +223,12 @@ public:
         OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
         auto out = output->flat();
         out(0) = batchId;
+
+        TimeCost enqueueTC;
         EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits);
-        TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}",
-                   Format2Ms(sw), Format2Ms(staticSw),
-                   channelId, batchId);
+        TIME_PRINT(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{}, "
+                   "enqueueTC(ms):{}, batch[{}]:{}",
+                   Format2Ms(sw), Format2Ms(staticSw), enqueueTC.ElapsedMS(), channelId, batchId);
         staticSw.reset();
     }
 
@@ -445,10 +450,12 @@ public:
         OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
         auto out = output->flat();
         out(0) = batchId;
+
+        TimeCost enqueueTC;
         EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor);
-        TIME_PRINT(KEY_PROCESS "read batch cost: {}, elapsed from last:{}, batch[{}]:{}",
-                   Format2Ms(sw), Format2Ms(staticSw),
-                   channelId, batchId);
+        TIME_PRINT(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{}, "
+                   "enqueueTC(ms):{}, batch[{}]:{}",
+                   Format2Ms(sw), Format2Ms(staticSw), enqueueTC.ElapsedMS(), channelId, batchId);
         staticSw.reset();
     }
 
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index 61a3745c..87a836c9 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -461,7 +461,7 @@ TEST_F(KeyProcessTest, GetKeySize)
     process.GetKeySize(batch);
 }
 
-TEST_F(KeyProcessTest, ProcessBatchWithUniqueCompute)
+TEST_F(KeyProcessTest, ProcessBatchWithFastUnique)
 {
     PrepareBatch();
 
@@ -486,7 +486,7 @@
         unique->Initialize(uniqueConf);
 
         spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId);
-        process.KeyProcessTaskHelperWithUnique(batch, unique, channel, id);
+        process.KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id);
         spdlog::info("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, hotPos);
     };
 
 // for clean code
diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh
new file mode 100755
index 00000000..63fc2759
--- /dev/null
+++ b/tools/perf/fast.sh
@@ -0,0 +1,264 @@
+#! /bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+# -----------------------------------------ReadMe Begin--------------------------------------------
+# 1. What it does
+#    This tool breaks down the time spent in each pipe of the training pipeline, and in the
+#    sub-modules (Steps) inside each pipe, to help locate system bottlenecks.
+#    (A pipeline works best when every pipe takes roughly the same time, so the pipes can hide
+#    each other's latency; that minimizes blocking and waiting and maximizes throughput.)
+#
+# 2. Usage
+#    bash fast.sh your_log_file.log
+#
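+#    A quick illustration of the report shape (the log name and all timings below are
+#    hypothetical and shown only as a sketch; the exact pipes and steps printed depend on
+#    which TimeCost logs your run actually emitted):
+#        $ bash fast.sh worker_0.log
+#        +----------------------------------------------------------------+
+#        +                         Profile Result                         +
+#        +----------------------------------------------------------------+
+#        Pipe-1: Data Parser
+#        read batch cost: avg=15.2
+#        --|enqueueTC: avg=0.8
+#        Pipe-2: Data Preprocess
+#        getAndProcessTC(filter>2000ms): avg=48.310
+#        ...
+#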
+# 3. Notes
+#    mxRec emits its TimeCost trace logs via spdlog::info, so before running make sure run.sh
+#    sets SPDLOG_LEVEL=info or SPDLOG_LEVEL=debug (if neither is set, this tool exits with a
+#    hint).
+#
+# 4. Reading the report
+#    (1) Pipeline: the pipeline is a chain of serial pipes, and the analysis is reported per
+#        pipe, e.g. Pipe-1/Pipe-2/Pipe-3/Pipe-4.
+#    (2) Pipe: each pipe reports an overall time. (Ideally every pipe takes about the same time,
+#        so the pipes can mask each other and the pipeline runs at full efficiency.)
+#    (3) Sub-module (Step): a pipe may consist of several serial sub-modules (Steps), and a
+#        sub-module may in turn contain lower-level sub-modules (SubSteps). In the report,
+#        timings one level down are prefixed with --, the next level down with ----, and so on;
+#        a level's time already includes the time of the levels below it.
+#
+# 5. Performance tuning
+#    From the report you may find:
+#    (1) a pipe that takes exceptionally long;
+#    (2) a sub-module that takes exceptionally long.
+#    Analyze each case on its own merits and tune accordingly, or start a deeper optimization.
+#    For example: if TensorFlow data parsing (Pipe-1) is slow and starves the pipeline, tune the
+#    num_parallel parsing parameter on the TensorFlow side; if a saturated CPU blocks data
+#    preprocessing (Pipe 2: Data Preprocess), lower KEY_PROCESS_THREAD_NUM (default 6); if H2D
+#    blocks (Pipe 4: H2D Send Tensors (no DDR)), check whether GetNext on the NPU side or the
+#    DNN training itself is stalled.
+#    Deeper problems, however, may call for deep optimization: splitting pipes, turning serial
+#    steps parallel, lock optimization, or reworking the execution logic.
+#    The tool is also useful as a reference during optimization work: after tuning a sub-module,
+#    compare its time before and after, together with the end-to-end time and throughput.
+#
+# 6. This tool has to keep evolving with the code; fixes and improvements are welcome. Good Luck!
+# -----------------------------------------ReadMe End--------------------------------------------
+#set -x
+
+LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
+LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
+LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
+LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
+
+logfile=$1
+
+validate_options()
+{
+    if [ $# -ne 1 ]; then
+        LOG_ERROR "NO log_file"
+        echo "[Usage]: bash $0 log_file"
+        exit 1
+    fi
+}
+
+check_spdlog_level()
+{
+    $(grep 'ReadEmbKeyV2Static' $logfile > /dev/null 2>&1)
+    if [ $? != 0 ]; then
+        $(grep 'ReadEmbKeyV2Dynamic' $logfile > /dev/null 2>&1)
+        if [ $? != 0 ]; then
+            LOG_ERROR "No timecost-related logs, please check 'mpi_args' in your run.sh,
+            make sure SPDLOG_LEVEL=info or SPDLOG_LEVEL=debug, and run again!"
+            exit 1
+        fi
+    fi
+}
+
+parse_pipe_1_data_parser()
+{
+    LOG_NOTICE "Pipe-1: Data Parser"
+
+    $(grep 'ReadEmbKeyV2Dynamic' $logfile > /dev/null 2>&1)
+    if [ $? 
== 0 ]; then + LOG_INFO "Step-1.x ReadEmbKeyV2 Dynamic" + else + LOG_INFO "Step-1.x ReadEmbKeyV2 Static" + fi + + grep 'read batch cost(ms)' $logfile | cut -d" " -f7| \ + awk -F "[:,]" '{sum+=$2} END {printf "read batch cost: avg=%0.1f\n", sum/NR}' + + grep 'enqueueTC(ms)' $logfile | grep -v 'timeout' | cut -d" " -f11 | \ + awk -F "[:,]" '{sum+=$2} END {printf "--|enqueueTC: avg=%0.1f\n", sum/NR}' + + grep 'elapsed from last(ms)' $logfile | grep -v 'timeout' | cut -d" " -f10 | \ + awk -F "[:,]" '{print $2}' | \ + awk 'BEGIN {sum=0; count=0} {if($1<1000) {sum+=$NF; count++} } END \ + {printf "elapsed from last: avg=%0.1f\n", sum/count}' +} + +parse_pipe_2_key_process() +{ + LOG_NOTICE "Pipe-2: Data Preprocess" + + grep 'getAndProcessTC(ms)' $logfile | cut -d" " -f5 | \ + awk -F"[:,]" '{print $2}' | \ + awk 'BEGIN{count=0; total=0;} {if ($1<2000) {total+=$NF; count++;}} END \ + {printf "getAndProcessTC(filter>2000ms): avg=%0.3f\n", total/count}' + + LOG_INFO "Step-2.1 GetBatchData" + + grep 'get data time' $logfile | cut -d" " -f11 | \ + awk -F"[:,]" '{print $2}' | \ + awk 'BEGIN { max=0 } { sum+=$NF; if($NF>max) max=$NF } END \ + {printf "--|get data time: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}' + + grep 'get data time' $logfile | cut -d" " -f11 | \ + awk -F"[:,]" '{print $2}' | \ + awk 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \ + {printf "--|get data time(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}' + + grep 'get data time' $logfile | cut -d" " -f11 | \ + awk -F"[:,]" '{print $2}' | \ + awk 'BEGIN { total=0; none_zero_ms_num=0 } { total++; if($NF>0) none_zero_ms_num++ } END \ + {printf "--|get data time: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \ + total, none_zero_ms_num, none_zero_ms_num/total, (1-none_zero_ms_num/total)}' + + LOG_INFO "Step-2.2 KeyProcess" + + grep 'key process cost' $logfile | cut -d" " -f8 | cut -d ":" -f2 | cut -d"," -f1 | grep '^[0-9]' | grep '[0-9]$' | \ + awk 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \ + {printf "--|key process cost(filter>2000): avg=%0.1f\n", sum/count}' + + # fast-unique related start + $(grep 'ProcessBatchWithFastUnique(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'ProcessBatchWithFastUnique(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|ProcessBatchWithFastUnique: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'FastUniqueCompute(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'FastUniqueCompute(ms)' $logfile | cut -d' ' -f4 | \ + awk -F"[:,]" '{sum+=$2} END {printf "------|FastUniqueCompute: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'GetScAll TimeCost(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'GetScAll TimeCost(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "------|GetScAll: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'all2allTC TimeCost(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'all2allTC TimeCost(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "------|all2allTC: avg=%0.1f\n", sum/NR}' + fi + # fast-unique related end + + $(grep 'UniqueTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'UniqueTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|UniqueTC: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'processSplitKeysTC(ms)' $logfile > /dev/null 2>&1) + if [ $? 
== 0 ]; then + grep 'processSplitKeysTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|processSplitKeysTC: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'getScAllTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'getScAllTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "------|getScAllTC(AllReduce-AllGather): avg=%0.1f\n", sum/NR}' + fi + + $(grep 'uniqueAll2AllTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'uniqueAll2AllTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "------|uniqueAll2AllTC(All2allv): avg=%0.1f\n", sum/NR}' + fi + + $(grep 'buildRestoreVecTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'buildRestoreVecTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|buildRestoreVecTC: avg=%0.1f\n", sum/NR}' + fi + + # common start + $(grep 'key2OffsetTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'key2OffsetTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|key2OffsetTC: avg=%0.1f\n", sum/NR}' + fi + + $(grep 'pushResultTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'pushResultTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {printf "----|pushResultTC, avg=%0.1f\n", sum/NR}' + fi + # common end +} + +parse_pipe_3_get_tensors_no_ddr() +{ + LOG_NOTICE "Pipe-3: Get Tensors (no DDR)" + + $(grep 'getAllTensorTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'getAllTensorTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "getAllTensorTC, avg=", sum/NR}' + fi +} + +parse_pipe_4_send_tensors_no_ddr() +{ + LOG_NOTICE "Pipe-4: H2D Send Tensors (no DDR)" + + $(grep 'sendTensorsTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'sendTensorsTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "sendTensorsTC, avg=", sum/NR}' + fi + + $(grep 'sendLookupTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'sendLookupTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "--|sendLookupTC, avg=", sum/NR}' + fi + + $(grep 'sendRestoreTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'sendRestoreTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "--|sendRestoreTC, avg=", sum/NR}' + fi +} + +parse_pipe_3_get_and_send_tensors_with_ddr() +{ + LOG_NOTICE "Pipe-3: Get and Send Tensors (with DDR)" + + $(grep 'getAndSendTensorsTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'getAndSendTensorsTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "GetAndSendTensors, avg=", sum/NR}' + fi +} + +main() +{ + validate_options $@ + check_spdlog_level + + echo "+----------------------------------------------------------------+" + echo "+ Profile Result +" + echo "+----------------------------------------------------------------+" + + parse_pipe_1_data_parser + parse_pipe_2_key_process + + $(grep 'DDR mode' $logfile > /dev/null 2>&1) + if [ $? 
-eq 0 ]; then + parse_pipe_3_get_and_send_tensors_with_ddr + else + parse_pipe_3_get_tensors_no_ddr + parse_pipe_4_send_tensors_no_ddr + fi +} + +main $@ \ No newline at end of file -- Gitee From b471bd6b7c784789ba5d9449d82b97d6658ed327 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 16 Jun 2023 16:31:53 +0800 Subject: [PATCH 138/551] Match-id-5b527ce9b4192e08689e3cc999ee5c7c871df37e --- mx_rec/core/asc/manager.py | 3 +- mx_rec/optimizers/lazy_adam_by_addr.py | 20 +-- mx_rec/util/initialize.py | 29 ++-- mx_rec/validator/validator.py | 7 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 176 +++++++++------------ src/core/hybrid_mgmt/hybrid_mgmt.h | 16 +- src/core/key_process/key_process.cpp | 10 +- src/core/key_process/key_process.h | 2 +- src/core/utils/unique.h | 0 src/tests/key_process/key_process_test.cpp | 22 +-- 10 files changed, 136 insertions(+), 149 deletions(-) create mode 100644 src/core/utils/unique.h diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 426dd525..fcbe0027 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -221,6 +221,7 @@ def start_asc_pipeline(): table_info_list = generate_table_info_list() threshold_list = generate_threshold_list() if not table_info_list: - logging.warning(f"table_info_list is empty") + logging.error("table_info_list is empty!") + raise RuntimeError("table_info_list is empty!") if not is_asc_manager_initialized() and table_info_list: initialize_emb_cache(table_info_list, threshold_list) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index e9be09a7..10a785d6 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -51,6 +51,16 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): self._slot_num = 2 self._check_input_param() + @property + def slot_num(self): + return self._slot_num + + def get_slot_init_values(self): + # return state value list of adam that needs to initialize in ASC DDR. + initial_momentum_value = 0.0 + initial_velocity_value = 0.0 + return [initial_momentum_value, initial_velocity_value] + def _check_input_param(self): check_param_type("beta1", self._beta1, (int, float)) check_param_range("beta1", self._beta1, 0, 1) @@ -63,10 +73,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): check_param_type("use_locking", self._use_locking, bool) - @property - def slot_num(self): - return self._slot_num - def _create_slots(self, addr_list): first_addr = addr_list[0] self._create_non_slot_variable( @@ -74,12 +80,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): self._create_non_slot_variable( initial_value=self._beta2, name="beta2_power", colocate_with=first_addr) - def get_slot_init_values(self): - # return state value list of adam that needs to initialize in ASC DDR. 
- initial_momentum_value = 0.0 - initial_velocity_value = 0.0 - return [initial_momentum_value, initial_velocity_value] - def _apply_dense(self, grad, var): logging.debug(">>>>Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index baaba222..83ae73df 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -62,9 +62,9 @@ class ConfigInitializer: self.parse_hccl_json() else: self.set_hccl_info_without_json() - self.check_parameters() self.train_interval = kwargs.get("train_interval", -1) self.eval_steps = kwargs.get("eval_steps", -1) + self.check_parameters() self.prefetch_batch_number = kwargs.get("prefetch_batch_number", 1) self.if_load = kwargs.get("if_load", False) if_dynamic = kwargs.get("use_dynamic", 1) @@ -152,6 +152,14 @@ class ConfigInitializer: def ascend_global_hashtable_collection(self): return self._ascend_global_hashtable_collection + @property + def dangling_table(self): + return self._dangling_table + + @property + def removing_var_list(self): + return self._removing_var_list + @staticmethod def get_instance(): if ConfigInitializer._single_instance is None: @@ -254,12 +262,13 @@ class ConfigInitializer: raise ValueError(f"Rank size {rank_size} is different from device num {len(device_list)}.") try: self._rank_to_device_dict[0] = int(chief_device) + except ValueError as err: + raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err + try: device_list.pop(int(chief_device)) except IndexError as err: raise IndexError( f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {device_list}.") from err - except ValueError as err: - raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err for device_idx in device_list: device_id = mxrec_pybind.get_logic_id(int(device_idx)) @@ -284,14 +293,6 @@ class ConfigInitializer: if name not in self._removing_var_list: self._removing_var_list.append(name) - @property - def dangling_table(self): - return self._dangling_table - - @property - def removing_var_list(self): - return self._removing_var_list - def insert_table_instance(self, name, key, instance): if key in self._table_instance_dict: raise KeyError(f"Given key {key} has been used.") @@ -338,6 +339,9 @@ class ConfigInitializer: if self.rank_id >= self.rank_size: raise ValueError(f"Rank_id must be within the range from 0 to rank_size.") + if self._train_interval == 0 and self._eval_steps == 0: + raise ValueError(f"Train interval and eval steps could not both equal 0.") + def freeze(self): self._is_frozen = True @@ -433,9 +437,6 @@ def check_step(param, min_value=-1): if param < min_value: raise ValueError(f"Valid value range is larger than or equals to {min_value}.") - if param == 0: - raise ValueError("Arg train_interval or eval_steps cannot equal to 0.") - def init(use_mpi, **kwargs): ConfigInitializer.set_instance(use_mpi, **kwargs) diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index bc738915..73010f76 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -301,12 +301,13 @@ class RankInfoValidator: try: rank_size_value = int(rank_size) - res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid() - if not res and rank_size_value not in [1, 2, 4, 8, 16]: - raise ValueError("Invalid rank size, rank size must between 0 and 15 in recommendation training.") except ValueError as err: raise ValueError("Invalid rank size, 
diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py
index bc738915..73010f76 100644
--- a/mx_rec/validator/validator.py
+++ b/mx_rec/validator/validator.py
@@ -301,12 +301,13 @@ class RankInfoValidator:
 
         try:
             rank_size_value = int(rank_size)
-            res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid()
-            if not res and rank_size_value not in [1, 2, 4, 8, 16]:
-                raise ValueError("Invalid rank size, rank size must be between 0 and 15 in recommendation training.")
         except ValueError as err:
             raise ValueError("Invalid rank size, rank size must be a valid integer.") from err
 
+        res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid()
+        if not res and rank_size_value not in [1, 2, 4, 8, 16]:
+            raise ValueError("Invalid rank size, rank size must be between 0 and 15 in recommendation training.")
+
         chief_device = os.getenv("CM_CHIEF_DEVICE")
         chief_device_res = StringValidator(chief_device).check()
         if not chief_device_res:
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 0cb99214..3b8dcb8d 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -255,14 +255,14 @@ void HybridMgmt::Start()
 #ifndef GTEST
     if (mgmtRankInfo.noDDR) {
         auto getInfoTask = [this]() {
-            auto ret = GetInfoTask();
+            auto ret = Task(TaskType::GETINFO);
             spdlog::info("getInfoTask done");
             return ret;
         };
         procThreads.emplace_back(std::make_unique(getInfoTask));
 
         auto sendInfoTask = [this]() {
-            auto ret = SendTask();
+            auto ret = Task(TaskType::SEND);
             spdlog::info("sendInfoTask done");
             return ret;
         };
@@ -271,7 +271,7 @@ void HybridMgmt::Start()
 
     if (!mgmtRankInfo.noDDR) {
         auto parseKeysTask = [this]() {
-            auto ret = ParseKeysTask();
+            auto ret = Task(TaskType::DDR);
             spdlog::info("parseKeysTask done");
             return ret;
         };
@@ -281,51 +281,19 @@ void HybridMgmt::Start()
 }
 
 #ifndef GTEST
-bool HybridMgmt::TrainParseKeys()
-{
-    do {
-        if (!isRunning) {
-            return false;
-        }
-        ParseKeys(TRAIN_CHANNEL_ID, getInfoBatchId);
-        spdlog::info(MGMT + "parseKeysBatchId = {}", getInfoBatchId);
-    } while (getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 ||
-             mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1);
-
-    return true;
-}
-
-bool HybridMgmt::EvalParseKeys()
-{
-    int evalGetInfoBatchId = 0; // 0-99, 0-99
-    do {
-        if (!isRunning) {
-            return false;
-        }
-        bool status = ParseKeys(EVAL_CHANNEL_ID, evalGetInfoBatchId);
-        if (!status) {
-            mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalGetInfoBatchId;
-            break;
-        }
-    } while (evalGetInfoBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 ||
-             mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1);
-
-    return true;
-}
-
-// Tencent requires the DDR-enabled feature
-bool HybridMgmt::ParseKeysTask()
+bool HybridMgmt::Task(TaskType type)
 {
     while (isRunning) {
-        spdlog::info(MGMT + "Start Mgmt ParseKeysTask");
+        spdlog::info(MGMT + "Start Mgmt Train Task: {}", type);
         if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 ||
             mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) {
-            if (!TrainParseKeys()) {
+            if (!TrainTask(type)) {
                 return false;
             }
         }
 
+        spdlog::info(MGMT + "Start Mgmt Eval Task: {}", type);
         if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 ||
             mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) {
-            if (!EvalParseKeys()) {
+            if (!EvalTask(type)) {
                 return false;
             }
         }
@@ -334,71 +302,83 @@ bool HybridMgmt::ParseKeysTask()
     return false;
 }
 
-bool HybridMgmt::GetInfoTask()
+bool HybridMgmt::TrainTask(TaskType type)
 {
-    while (isRunning) {
-        spdlog::info(MGMT + "Start Mgmt GetInfoTask");
-        do {
-            if (!isRunning) {
-                return false;
-            }
-            GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId);
-            spdlog::info(MGMT + "getInfoBatchId = {}", getInfoBatchId);
-        } while (getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 ||
-                 mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1);
-
-        int evalGetInfoBatchId = 0; // 0-99, 0-99
-        do {
-            if (!isRunning) {
-                return false;
-            }
-            bool status = GetLookupAndRestore(EVAL_CHANNEL_ID, evalGetInfoBatchId);
-            if (!status) {
-                mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalGetInfoBatchId;
+    bool isContinue = false;
+    do {
+        if (!isRunning) {
+            return false;
+        }
+        switch (type) {
+            
case TaskType::GETINFO: + GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); + isContinue = getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + spdlog::info(MGMT + "getInfoBatchId = {}", getInfoBatchId); break; - } - } while (evalGetInfoBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); - } - return false; + case TaskType::SEND: + SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); + isContinue = sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + spdlog::info(MGMT + "sendBatchId = {}", sendBatchId); +#if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) + if (sendBatchId == PROFILING_START_BATCH_ID) { + EASY_PROFILER_ENABLE + } else if (sendBatchId == PROFILING_END_BATCH_ID) { + EASY_PROFILER_DISABLE + ::profiler::dumpBlocksToFile(fmt::format("/home/MX_REC-mgmt-profile-{}.prof", + mgmtRankInfo.rankId).c_str()); + } +#endif + break; + case TaskType::DDR: + ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); + isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + spdlog::info(MGMT + "parseKeysBatchId = {}", trainBatchId); + break; + default: + throw std::invalid_argument("Invalid TaskType Type."); + } + } while (isContinue); + + return true; } -bool HybridMgmt::SendTask() +bool HybridMgmt::EvalTask(TaskType type) { - while (isRunning) { - spdlog::info(MGMT + "Start Mgmt SendTask"); - do { - if (!isRunning) { - return false; - } - SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); -#if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) - spdlog::info(MGMT + "sendBatchId = {}", sendBatchId); - if (trainBatchId == PROFILING_START_BATCH_ID) { - EASY_PROFILER_ENABLE - } else if (trainBatchId == PROFILING_END_BATCH_ID) { - EASY_PROFILER_DISABLE - ::profiler::dumpBlocksToFile( - fmt::format("/home/MX_REC-mgmt-profile-{}.prof", mgmtRankInfo.rankId).c_str()); - } -#endif - } while (sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1); + int evalBatchId = 0; // 0-99, 0-99 + do { + if (!isRunning) { + return false; + } + bool status = false; - int evalSendBatchId = 0; - do { - if (!isRunning) { - return false; - } - bool status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalSendBatchId); - if (!status) { - mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalSendBatchId; + switch (type) { + case TaskType::GETINFO: + status = GetLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); + spdlog::info(MGMT + "GETINFO evalBatchId = {}", evalBatchId); break; - } - } while (evalSendBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); - } - return false; + case TaskType::SEND: + status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); + spdlog::info(MGMT + "SEND evalBatchId = {}", evalBatchId); + break; + case TaskType::DDR: + status = ParseKeys(EVAL_CHANNEL_ID, evalBatchId); + spdlog::info(MGMT + "DDR evalBatchId = {}", evalBatchId); + break; + default: + throw std::invalid_argument("Invalid TaskType Type."); + } + + if (!status) { + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalBatchId; + break; + } + } while (evalBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); + + return true; } bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h 
b/src/core/hybrid_mgmt/hybrid_mgmt.h index 079897ac..e2ff7a87 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -28,6 +28,12 @@ namespace MxRec { constexpr int SEND_TENSOR_TYPE_NUM = 2; + enum class TaskType { + GETINFO, + SEND, + DDR + }; + class HybridMgmt { public: HybridMgmt() = default; @@ -100,7 +106,7 @@ namespace MxRec { private: int currentBatchId; int trainBatchId = 0; // 0-199, 200- - int getInfoBatchId; + int getInfoBatchId; // 0-199, 200- int sendBatchId; vector mgmtEmbInfo; RankInfo mgmtRankInfo; @@ -116,11 +122,9 @@ namespace MxRec { bool skipUpdate; bool isLoad { false }; - bool ParseKeysTask(); - bool GetInfoTask(); - bool SendTask(); - bool TrainParseKeys(); - bool EvalParseKeys(); + bool Task(TaskType type); + bool TrainTask(TaskType type); + bool EvalTask(TaskType type); bool GetLookupAndRestore(const int channelId, int &batchId); bool SendLookupAndRestore(const int channelId, int &batchId); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 42680f2e..2763189a 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -37,9 +37,9 @@ inline vector Count2Start(const vector& count) return start; } -int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, - const vector& thresholdValues, - int seed) +bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, + const vector& thresholdValues, + int seed) { this->rankInfo = rInfo; if (rankInfo.useHot) { @@ -93,7 +93,7 @@ int KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, spdlog::info(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", scInfo, rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); - return 0; + return true; } // bind and start main process @@ -1135,7 +1135,7 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa list = &infoList; break; default: - throw runtime_error("ERROR list type"); + throw std::invalid_argument("Invalid ProcessedInfo Type."); } while (true) { if (!isRunning) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 6efc4181..3f71b1bc 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -58,7 +58,7 @@ namespace MxRec { class KeyProcess { public: - int Initialize(const RankInfo& rInfo, const vector& eInfos, + bool Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues = {}, int seed = 0); unique_ptr> GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type); diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h new file mode 100644 index 00000000..e69de29b diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 87a836c9..15b96063 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -239,7 +239,7 @@ protected: TEST_F(KeyProcessTest, Initialize) { - ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); ASSERT_EQ(process.rankInfo.rankId, rankInfo.rankId); ASSERT_EQ(process.rankInfo.rankSize, rankInfo.rankSize); @@ -256,7 +256,7 @@ TEST_F(KeyProcessTest, Initialize) TEST_F(KeyProcessTest, Start) { - ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0); + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, 
true);
     ASSERT_EQ(process.Start(), 0);
     process.Destroy();
@@ -272,7 +272,7 @@
     vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } };
     batch->sample = std::move(batchKeys);
     spdlog::debug(KEY_PROCESS "batch sample: {}", batch->sample);
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
     process.rankInfo.rankSize = rankSize;
     auto [splitKeys, restore] = process.HashSplit(batch);
@@ -291,7 +291,7 @@ TEST_F(KeyProcessTest, GetScAll)
     for (unsigned int i = 0; i < expectScAll.size(); ++i) {
         expectScAll[i] = floor(i / worldSize) + 1;
     }
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
     vector scAll;
     process.GetScAll(keyScLocal, 0, 0, scAll);
@@ -307,7 +307,7 @@ TEST_F(KeyProcessTest, GetScAllForUnique)
     for (unsigned int i = 0; i < expectScAll.size(); ++i) {
         expectScAll[i] = floor(i / worldSize) + 1;
     }
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
     vector scAll;
     process.GetScAllForUnique(keyScLocal, 0, 0, scAll);
@@ -329,7 +329,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu)
         { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } };
     batch->sample = std::move(allBatchKeys[worldRank]);
     spdlog::info(KEY_PROCESS "test BuildRestoreVec: rank {}, batchKeys {}", worldRank, batch->sample);
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
    ASSERT_EQ(process.isRunning, true);
     auto [splitKeys, restore] = process.HashSplit(batch);
     spdlog::debug("rank: {} splitKeys: {}", worldRank, splitKeys);
@@ -340,7 +340,7 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt)
 {
     PrepareBatch();
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores
 
     auto fn = [this](int channel, int id) {
@@ -369,7 +369,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt)
 {
     PrepareBatch();
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores
 
     auto fn = [this](int channel, int id) {
@@ -400,7 +400,7 @@ TEST_F(KeyProcessTest, Key2Offset)
 {
     keys_t lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 };
     keys_t expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 };
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
     process.Key2Offset("emb0", lookupKeys);
     spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap);
@@ -422,7 +422,7 @@ TEST_F(KeyProcessTest, GetUniqueConfig)
 TEST_F(KeyProcessTest, ProcessPrefetchTask)
 {
     PrepareBatch();
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     process.rankInfo.rankSize = worldSize;
     process.rankInfo.localRankId = process.rankInfo.rankId % process.rankInfo.localRankSize;
     ASSERT_EQ(process.isRunning, true);
@@ -465,7 +465,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique)
 {
     PrepareBatch();
 
-    ASSERT_EQ(process.Initialize(rankInfo, embInfos), 0);
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores
 
     auto fn = [this](int channel, int id) {
-- 
Gitee

From 5ae74bddedee3fe30c0636d150c94fe4086faa7b Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 16 Jun 2023 17:11:12 +0800
Subject: [PATCH 139/551] Match-id-9212985de97f6852ac59e13d53c589554b98c748

---
 mx_rec/optimizers/ftrl_t_dense.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py
index 40573f59..364267e3 100644
--- a/mx_rec/optimizers/ftrl_t_dense.py
+++ b/mx_rec/optimizers/ftrl_t_dense.py
@@ -77,11 +77,11 @@ class CustomizedFtrlTZ(optimizer.Optimizer):
     def _apply_dense(self, grad, var):
         if self._lambda1 > 1e-10:
             return self._apply_dense_shared(
-                grad.values,
+                grad,
                 var)
         else:
             return self._apply_dense_shared_v2(
-                grad.values,
+                grad,
                 var)
 
     def _apply_dense_shared(self, grad, var):
-- 
Gitee
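A note on the two-line fix in PATCH 139 above: in the TF1-style optimizer API, _apply_dense receives the gradient as a plain dense Tensor, and only the IndexedSlices object handed to the sparse update path carries a .values attribute, so unwrapping grad.values inside _apply_dense fails. A small illustration using the standard TensorFlow API (this is a standalone sketch, not mxRec code):

    import tensorflow as tf

    # Dense gradients are plain tensors: there is no ".values" to unwrap.
    dense_grad = tf.constant([[0.1, 0.2], [0.3, 0.4]])

    # Sparse gradients are IndexedSlices, which do carry ".values".
    sparse_grad = tf.IndexedSlices(values=tf.constant([[0.1, 0.2]]),
                                   indices=tf.constant([0]),
                                   dense_shape=tf.constant([2, 2]))

    print(hasattr(dense_grad, "values"))   # False: _apply_dense must use grad directly
    print(hasattr(sparse_grad, "values"))  # True: sparse paths may unwrap .values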
From f242e66b00867f2027fab083d3b4a188288242e1 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 16 Jun 2023 18:13:54 +0800
Subject: [PATCH 140/551] Match-id-4c29cc2472855da32e4e8482e9340a65ba3f2b78

---
 build/build.sh          | 2 +-
 src/core/utils/unique.h | 0
 src/test_ut.sh          | 2 +-
 3 files changed, 2 insertions(+), 2 deletions(-)
 delete mode 100644 src/core/utils/unique.h

diff --git a/build/build.sh b/build/build.sh
index 1b7a8699..cecc19ec 100644
--- a/build/build.sh
+++ b/build/build.sh
@@ -93,7 +93,7 @@ abseil_install_path="${ROOT_DIR}"/install/abseil
 src_path="${ROOT_DIR}"/src
 acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR
 
-cp -rf ../platform/securec/* /usr1/mxRec/src/platform/AccCTR/3rdparty/huawei_secure_c
+cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c
 
 cd "${ROOT_DIR}"
 release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz
diff --git a/src/core/utils/unique.h b/src/core/utils/unique.h
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/test_ut.sh b/src/test_ut.sh
index 1605692d..712ce65f 100644
--- a/src/test_ut.sh
+++ b/src/test_ut.sh
@@ -16,7 +16,7 @@ source /opt/rh/devtoolset-7/enable
 CUR_DIR=$(dirname "$(readlink -f "$0")")
 ROOT_DIR=$(dirname "${CUR_DIR}")
 acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR
-cp -rf ../platform/securec/* /usr1/mxRec/src/platform/AccCTR/3rdparty/huawei_secure_c
+cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c
 export LD_LIBRARY_PATH="${acc_ctr_path}"/output/ock_ctr_common/lib:$LD_LIBRARY_PATH
 
 compile_securec()
-- 
Gitee

From 542c309a12df7547f1fa0de8ecb9833e7a6e62f5 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 16 Jun 2023 19:27:14 +0800
Subject: [PATCH 141/551] Match-id-107ad0ebe5944f78bf776ed51cf3b9663b8264f7

---
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 155 +++++++++++++++++++--------
 src/core/hybrid_mgmt/hybrid_mgmt.h   |  17 ++-
 2 files changed, 123 insertions(+), 49 deletions(-)

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 3b8dcb8d..f3a72e48 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -91,8 +91,10 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos,
         return false;
     }
 
-    lookUpKeysQueue = make_unique>>();
-    restoreQueue = make_unique>>();
+    lookUpKeysQueueForTrain = make_unique>>();
+    restoreQueueForTrain = make_unique>>();
+    lookUpKeysQueueForEval = make_unique>>();
+    restoreQueueForEval = make_unique>>();
 
     isRunning = true;
     if (!rankInfo.noDDR) {
@@ -254,52 +256,70 @@ void 
HybridMgmt::Start() { #ifndef GTEST if (mgmtRankInfo.noDDR) { - auto getInfoTask = [this]() { - auto ret = Task(TaskType::GETINFO); - spdlog::info("getInfoTask done"); - return ret; + auto getInfoTaskForTrain = [this]() { + TaskForTrain(TaskType::GETINFO); + spdlog::info("getInfoTaskForTrain done"); }; - procThreads.emplace_back(std::make_unique(getInfoTask)); + procThreads.emplace_back(std::make_unique(getInfoTaskForTrain)); - auto sendInfoTask = [this]() { - auto ret = Task(TaskType::SEND); - spdlog::info("sendInfoTask done"); - return ret; + auto getInfoTaskForEval = [this]() { + TaskForEval(TaskType::GETINFO); + spdlog::info("getInfoTaskForEval done"); }; - procThreads.emplace_back(std::make_unique(sendInfoTask)); + procThreads.emplace_back(std::make_unique(getInfoTaskForEval)); + + auto sendInfoTaskForTrain = [this]() { + TaskForTrain(TaskType::SEND); + spdlog::info("sendInfoTaskForTrain done"); + }; + procThreads.emplace_back(std::make_unique(sendInfoTaskForTrain)); + + auto sendInfoTaskForEval = [this]() { + TaskForEval(TaskType::SEND); + spdlog::info("sendInfoTaskForEval done"); + }; + procThreads.emplace_back(std::make_unique(sendInfoTaskForEval)); } if (!mgmtRankInfo.noDDR) { - auto parseKeysTask = [this]() { - auto ret = Task(TaskType::DDR); - spdlog::info("parseKeysTask done"); - return ret; + auto parseKeysTaskForTrain = [this]() { + TaskForTrain(TaskType::DDR); + spdlog::info("parseKeysTaskForTrain done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); + + auto parseKeysTaskForEval = [this]() { + TaskForEval(TaskType::DDR); + spdlog::info("parseKeysTaskForEval done"); }; - procThreads.emplace_back(std::make_unique(parseKeysTask)); + procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); } #endif } #ifndef GTEST -bool HybridMgmt::Task(TaskType type) +void HybridMgmt::TaskForTrain(TaskType type) { while (isRunning) { - spdlog::info(MGMT + "Start Mgmt Train Task: {}", type); + spdlog::info(MGMT + "Start Train Task: {}", type); if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) { if (!TrainTask(type)) { - return false; + return; } } + } +} - spdlog::info(MGMT + "Start Mgmt Eval Task: {}", type); +void HybridMgmt::TaskForEval(TaskType type) +{ + while (isRunning) { + spdlog::info(MGMT + "Start Eval Task: {}", type); if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) { if (!EvalTask(type)) { - return false; + return; } } } - - return false; } bool HybridMgmt::TrainTask(TaskType type) @@ -309,15 +329,17 @@ bool HybridMgmt::TrainTask(TaskType type) if (!isRunning) { return false; } + bool status = false; + switch (type) { case TaskType::GETINFO: - GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); + status = GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); isContinue = getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; spdlog::info(MGMT + "getInfoBatchId = {}", getInfoBatchId); break; case TaskType::SEND: - SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); + status = SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); isContinue = sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; spdlog::info(MGMT + "sendBatchId = {}", sendBatchId); @@ -332,7 +354,7 @@ bool HybridMgmt::TrainTask(TaskType type) #endif break; case TaskType::DDR: - ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); + status = ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); 
isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; spdlog::info(MGMT + "parseKeysBatchId = {}", trainBatchId); @@ -340,6 +362,10 @@ bool HybridMgmt::TrainTask(TaskType type) default: throw std::invalid_argument("Invalid TaskType Type."); } + + if (!status) { + return false; + } } while (isContinue); return true; @@ -372,8 +398,7 @@ bool HybridMgmt::EvalTask(TaskType type) } if (!status) { - mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] = evalBatchId; - break; + return false; } } while (evalBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); @@ -398,9 +423,21 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); return false; } - lookUpKeysQueue->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueue->Pushv(*infoVecs); + + switch (channelId) { + case TRAIN_CHANNEL_ID: + lookUpKeysQueueForTrain->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueueForTrain->Pushv(*infoVecs); + break; + case EVAL_CHANNEL_ID: + lookUpKeysQueueForEval->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueueForEval->Pushv(*infoVecs); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } } TIME_PRINT("getAllTensorTC(ms):{}", getAllTensorTC.ElapsedMS()); } @@ -408,6 +445,46 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) return true; } +void HybridMgmt::LookupKeys(const int channelId, vector names) +{ + TimeCost sendLookupTC; + for (const string& name: names) { + vector lookUpKeys; + switch (channelId) { + case TRAIN_CHANNEL_ID: + lookUpKeys = lookUpKeysQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + lookUpKeys = lookUpKeysQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } + hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); + } + TIME_PRINT("sendLookupTC(ms):{}", sendLookupTC.ElapsedMS()); +} + +void HybridMgmt::RestoreKeys(const int channelId, vector names) +{ + TimeCost sendRestoreTC; + for (const string& name: names) { + vector restore; + switch (channelId) { + case TRAIN_CHANNEL_ID: + restore = restoreQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + restore = restoreQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } + hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); + } + TIME_PRINT("sendRestoreTC(ms):{}", sendRestoreTC.ElapsedMS()); +} + bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { @@ -432,21 +509,11 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) { #pragma omp section { - TimeCost sendLookupTC; - for (const string& name: names) { - auto lookUpKeys = lookUpKeysQueue->WaitAndPop(); - hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); - } - TIME_PRINT("sendLookupTC(ms):{}", sendLookupTC.ElapsedMS()); + LookupKeys(channelId, names); } #pragma omp section { - TimeCost sendRestoreTC; - for (const string& name: names) { - auto restore = restoreQueue->WaitAndPop(); - hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); - } - TIME_PRINT("sendRestoreTC(ms):{}", sendRestoreTC.ElapsedMS()); + 
RestoreKeys(channelId, names);
         }
     }
     TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS());
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index e2ff7a87..22421520 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -65,8 +65,10 @@
             }
             // Signal mgmt to stop first: halt new lookup queries and lift the queue limits to avoid blocking
             isRunning = false;
-            restoreQueue->DestroyQueue();
-            lookUpKeysQueue->DestroyQueue();
+            restoreQueueForTrain->DestroyQueue();
+            lookUpKeysQueueForTrain->DestroyQueue();
+            restoreQueueForEval->DestroyQueue();
+            lookUpKeysQueueForEval->DestroyQueue();
 
             // Then signal preprocess to stop, releasing any lookup that is stuck mid-query
             preprocess->isRunning = false;
@@ -113,8 +115,10 @@
         unique_ptr hostEmbs {};
         unique_ptr hostHashMaps {};
         vector> procThreads {};
-        unique_ptr>> lookUpKeysQueue;
-        unique_ptr>> restoreQueue;
+        unique_ptr>> lookUpKeysQueueForTrain;
+        unique_ptr>> restoreQueueForTrain;
+        unique_ptr>> lookUpKeysQueueForEval;
+        unique_ptr>> restoreQueueForEval;
         map> evictKeyMap {};
         KeyProcess *preprocess;
         HDTransfer *hdTransfer;
@@ -122,10 +126,13 @@
         bool skipUpdate;
         bool isLoad { false };
 
-        bool Task(TaskType type);
+        void TaskForTrain(TaskType type);
+        void TaskForEval(TaskType type);
         bool TrainTask(TaskType type);
         bool EvalTask(TaskType type);
 
+        void LookupKeys(const int channelId, vector names);
+        void RestoreKeys(const int channelId, vector names);
         bool GetLookupAndRestore(const int channelId, int &batchId);
         bool SendLookupAndRestore(const int channelId, int &batchId);
-- 
Gitee

From 37e9a4959d0156b024aaf8f9864694baea1c2f97 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sat, 17 Jun 2023 18:36:46 +0800
Subject: [PATCH 142/551] Match-id-79028c55a68f36ce3cd786b9bff5a78a30ea8148

---
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 148 ++++++++++++++++++++-----
 src/core/hybrid_mgmt/hybrid_mgmt.h   |   8 ++
 2 files changed, 130 insertions(+), 26 deletions(-)

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index f3a72e48..41ba5efb 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -255,7 +255,42 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData)
 void HybridMgmt::Start()
 {
 #ifndef GTEST
+    int mode = 0;
+    const char* envTaskMode = std::getenv("MGMT_HBM_TASK_MODE"); // read the environment variable
+    if (envTaskMode != nullptr) { // the environment variable is set
+        try {
+            mode = std::stoi(envTaskMode); // convert the string to an integer
+            spdlog::info("The value of MGMT_HBM_TASK_MODE is an integer: {}", mode);
+        } catch (const std::invalid_argument& e) { // the conversion failed
+            spdlog::info("The value of MGMT_HBM_TASK_MODE is not an integer!");
+        }
+    } else { // the environment variable is not set
+        mode = 0;
+    }
     if (mgmtRankInfo.noDDR) {
+        InsertThreadForHBM(mode);
+    }
+
+    if (!mgmtRankInfo.noDDR) {
+        auto parseKeysTaskForTrain = [this]() {
+            TaskForTrain(TaskType::DDR);
+            spdlog::info("parseKeysTaskForTrain done");
+        };
+        procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain));
+
+        auto parseKeysTaskForEval = [this]() {
+            TaskForEval(TaskType::DDR);
+            spdlog::info("parseKeysTaskForEval done");
+        };
+        procThreads.emplace_back(std::make_unique(parseKeysTaskForEval));
+    }
+#endif
+}
+
+void HybridMgmt::InsertThreadForHBM(int mode)
+{
+#ifndef GTEST
+    if (mode == 1) {
         auto getInfoTaskForTrain = [this]() {
             TaskForTrain(TaskType::GETINFO);
             spdlog::info("getInfoTaskForTrain done");
@@ -279,20 +314,18 @@ void HybridMgmt::Start()
             spdlog::info("sendInfoTaskForEval done");
         };
         procThreads.emplace_back(std::make_unique(sendInfoTaskForEval));
- } - - if (!mgmtRankInfo.noDDR) { - auto parseKeysTaskForTrain = [this]() { - TaskForTrain(TaskType::DDR); - spdlog::info("parseKeysTaskForTrain done"); + } else { + auto parseKeysTaskForHBMTrain = [this]() { + TaskForTrain(TaskType::HBM); + spdlog::info("parseKeysTaskForHBMTrain done"); }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); - auto parseKeysTaskForEval = [this]() { - TaskForEval(TaskType::DDR); - spdlog::info("parseKeysTaskForEval done"); + auto parseKeysTaskForHBMEval = [this]() { + TaskForEval(TaskType::HBM); + spdlog::info("parseKeysTaskForHBMEval done"); }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); + procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); } #endif } @@ -353,6 +386,12 @@ bool HybridMgmt::TrainTask(TaskType type) } #endif break; + case TaskType::HBM: + status = ParseKeysHBM(TRAIN_CHANNEL_ID, trainBatchId); + isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || + mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + spdlog::info(MGMT + "ParseKeysHBMBatchId = {}", trainBatchId); + break; case TaskType::DDR: status = ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || @@ -389,6 +428,10 @@ bool HybridMgmt::EvalTask(TaskType type) status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); spdlog::info(MGMT + "SEND evalBatchId = {}", evalBatchId); break; + case TaskType::HBM: + status = ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); + spdlog::info(MGMT + "HBM evalBatchId = {}", evalBatchId); + break; case TaskType::DDR: status = ParseKeys(EVAL_CHANNEL_ID, evalBatchId); spdlog::info(MGMT + "DDR evalBatchId = {}", evalBatchId); @@ -419,6 +462,19 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) embInfo.name, embInfo.modifyGraph, names); for (const string& name: names) { auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE); + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); + switch (channelId) { + case TRAIN_CHANNEL_ID: + a2aQueueForTrain->Pushv({ *all2all }); + break; + case EVAL_CHANNEL_ID: + a2aQueueForEval->Pushv({ *all2all }); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } + } if (infoVecs == nullptr) { spdlog::info(MGMT + "ParseKeys infoVecs empty ! 
batchId:{}, channelId:{}", batchId, channelId); return false; @@ -445,6 +501,26 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) return true; } +void HybridMgmt::All2AllKeys(const int channelId, vector names) +{ + TimeCost a2aKeysTC; + for (const string &name : names) { + vector all2allKeys; + switch (channelId) { + case TRAIN_CHANNEL_ID: + all2allKeys = a2aQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + all2allKeys = a2aQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } + hdTransfer->Send(TransferChannel::ALL2ALL, all2allKeys, channelId, name); + } + TIME_PRINT("All2AllKeysTC(ms):{}", a2aKeysTC.ElapsedMS()); +} + void HybridMgmt::LookupKeys(const int channelId, vector names) { TimeCost sendLookupTC; @@ -494,29 +570,49 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) } spdlog::debug(MGMT + "SendLookupAndRestore embInfoName:{}, modifyGraph:{}, names:{}", embInfo.name, embInfo.modifyGraph, names); + TimeCost sendTensorsTC; if (!mgmtRankInfo.useStatic) { - for (const string& name: names) { - auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(TransferChannel::ALL2ALL, { *all2all }, channelId, name); - } + All2AllKeys(channelId, names); } + spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", batchId, embInfo.name, channelId); - TimeCost sendTensorsTC; - omp_set_num_threads(SEND_TENSOR_TYPE_NUM); -#pragma omp parallel sections - { -#pragma omp section - { - LookupKeys(channelId, names); + LookupKeys(channelId, names); + RestoreKeys(channelId, names); + TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); + } + batchId++; + return true; +} + +bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) +{ + spdlog::info(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + for (const auto& embInfo: mgmtEmbInfo) { + TimeCost ParseKeysTC; + vector names = {embInfo.name}; + if (embInfo.modifyGraph) { + names = embInfo.channelNames; + } + spdlog::debug(MGMT + "ParseKeysHBM embInfoName:{}, modifyGraph:{}, names:{}", + embInfo.name, embInfo.modifyGraph, names); + for (const string& name: names) { + auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { + spdlog::info(MGMT + "ParseKeys infoVecs empty ! 
batchId:{}, channelId:{}", batchId, channelId); + return false; } -#pragma omp section - { - RestoreKeys(channelId, names); + + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(TransferChannel::ALL2ALL, { *all2all }, channelId, name); } + hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, name); + infoVecs->pop_back(); + hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, name); } - TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); + TIME_PRINT("ParseKeysTC HBM mode (ms):{}", ParseKeysTC.ElapsedMS()); } batchId++; return true; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 22421520..fcdae8f2 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -31,6 +31,7 @@ namespace MxRec { enum class TaskType { GETINFO, SEND, + HBM, DDR }; @@ -58,6 +59,8 @@ namespace MxRec { void Start(); + void InsertThreadForHBM(int mode); + void Destroy() { if (!isRunning) { @@ -91,6 +94,8 @@ namespace MxRec { bool ParseKeys(int channelId, int& batchId); + bool ParseKeysHBM(int channelId, int& batchId); + bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut); void EmbHDTrans(const int channelId, const int batchId); @@ -119,6 +124,8 @@ namespace MxRec { unique_ptr>> restoreQueueForTrain; unique_ptr>> lookUpKeysQueueForEval; unique_ptr>> restoreQueueForEval; + unique_ptr>> a2aQueueForTrain; + unique_ptr>> a2aQueueForEval; map> evictKeyMap {}; KeyProcess *preprocess; HDTransfer *hdTransfer; @@ -131,6 +138,7 @@ namespace MxRec { bool TrainTask(TaskType type); bool EvalTask(TaskType type); + void All2AllKeys(const int channelId, vector names); void LookupKeys(const int channelId, vector names); void RestoreKeys(const int channelId, vector names); bool GetLookupAndRestore(const int channelId, int &batchId); -- Gitee From d572de685b3d8755696869a8dab0a5bc871c30e7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 17 Jun 2023 18:57:36 +0800 Subject: [PATCH 143/551] Match-id-50a0115b2e059c60765a424b8f7e00a557c680e3 --- mx_rec/core/asc/helper.py | 55 +++++++++++++++++++++++++++++++++++--- mx_rec/core/asc/manager.py | 13 ++++----- mx_rec/util/initialize.py | 6 +++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index d69fb8e0..f33f5e3c 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -11,7 +11,7 @@ from tensorflow import Tensor from tensorflow import Operation from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \ - export_table_instances, insert_dangling_table + get_enable_table_merge, export_table_instances, insert_dangling_table from mx_rec.core.asc.feature_spec import FeatureSpec @@ -70,6 +70,13 @@ def find_dangling_table(table_names: List[str]): if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': return True + if 'logistic_loss' in table_reachable_tensor.op.name and table_reachable_tensor.op.type == 'AddV2': + return True + + if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ + and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': + return True + return False def find_table_op(table_name: str, @@ -98,13 +105,25 @@ def find_dangling_table(table_names: List[str]): table_lookup_op = {} table_reachable_tensor = {} 
+ for _, table_instance in export_table_instances().items(): + if table_instance.table_name not in table_names: + table_names.append(table_instance.table_name) + for the_op in op_list: for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - logging.info(f"*********** find tables: {table_lookup_op}***********") + logging.info(f"*********** find tables: {table_lookup_op} ***********") + dangling_table = [] + for table_name in table_names: + if table_name not in table_lookup_op: + logging.info(f"*********** created table {table_name} but never look up***********") + dangling_table.append(table_name) + insert_dangling_table(table_name) + + def extend(op_list: List[Operation], tensor: Tensor, spread_tensors: List[Tensor]): @@ -146,6 +165,23 @@ def find_dangling_table(table_names: List[str]): return dangling_table +def should_skip(table_name): + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ + and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \ + and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name: + return True + if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ + and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, list): + skip = True + for key_word in ASCEND_TABLE_NAME_MUST_CONTAIN: + if isinstance(key_word, str) and key_word in table_name: + skip = False + break + return skip + return False + + def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): both_none = tgt_key_specs is None and args_index_list is None @@ -188,7 +224,10 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ if feature_counts is None or table_names is None: raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - dangling_tables = find_dangling_table(table_names) + dangling_tables = [] + if get_enable_table_merge(): + dangling_tables = find_dangling_table(table_names) + logging.info(f"In insert found dangling table(s): {dangling_tables} " f"which does not need to be provided to the EmbInfo.") @@ -205,8 +244,16 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ new_insert_tensors, new_splits, new_table_names = [], [], [] for idx, table_name in enumerate(table_names): if table_name in dangling_tables: - logging.info(f"do_insert skip table: {table_name}") + logging.info(f"do_insert skip table : {table_name}") continue + + skip = should_skip(table_name) + if skip: + logging.info(f"do_insert skip table 2: {table_name}") + continue + + + new_insert_tensors.append(insert_tensors[idx]) new_splits.append(splits[idx]) new_table_names.append(table_names[idx]) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index fcbe0027..db51c04f 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -10,8 +10,8 @@ from mx_rec.constants.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ - get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table -from mx_rec.core.asc.helper import find_dangling_table + get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table +from 
mx_rec.core.asc.helper import find_dangling_table, should_skip
 
 
 def check_dangling_table():
@@ -20,7 +20,7 @@ def check_dangling_table():
     :return: list of dangling_table
     """
     dangling_table = export_dangling_table()
-    if not dangling_table:
+    if not dangling_table and get_enable_table_merge():
         dangling_table = find_dangling_table([table_instance.table_name
                                               for _, table_instance in export_table_instances().items()])
     return dangling_table
@@ -48,9 +48,9 @@ def generate_table_info_list():
         if optimizer is not None:
             table_instance.ext_emb_size = table_instance.scalar_emb_size * (1 + optimizer.slot_num)
             logging.debug(f"ext_emb_size is reset to be {table_instance.ext_emb_size} for EmbInfo")
-
-        if table_instance.table_name in dangling_table:
-            logging.info(f"Found dangling table: {table_instance.table_name} "
+        skip = should_skip(table_instance.table_name)
+        if table_instance.table_name in dangling_table or skip:
+            logging.info(f"skip table {skip}: {table_instance.table_name} "
                          f"which does not need to be provided to the EmbInfo.")
             continue
 
@@ -220,6 +220,7 @@ def initialize_emb_cache(table_info_list, threshold_list):
 def start_asc_pipeline():
     table_info_list = generate_table_info_list()
     threshold_list = generate_threshold_list()
+
     if not table_info_list:
         logging.error("table_info_list is empty!")
         raise RuntimeError("table_info_list is empty!")
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 83ae73df..ef5d0bc0 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -10,6 +10,7 @@ from collections import defaultdict
 import mxrec_pybind
 import psutil
 
+import mx_rec.constants.constants
 from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \
     MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH
 from mx_rec.util.ops import import_host_pipeline_ops
@@ -74,6 +75,7 @@ class ConfigInitializer:
         self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False)
         if kwargs.get("bind_cpu", True):
             bind_cpu(self._rank_id, self._rank_size)
+        self.enable_table_merge = True if os.getenv("TF_DEVICE") == "NPU" else False
 
     def __del__(self):
         self.terminate()
@@ -647,6 +649,10 @@ def get_use_hot():
     return ConfigInitializer.get_instance().use_hot
 
 
+def get_enable_table_merge():
+    return ConfigInitializer.get_instance().enable_table_merge
+
+
 def get_use_dynamic_expansion():
     return ConfigInitializer.get_instance().use_dynamic_expansion
-- 
Gitee
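Patch 143 above gates dangling-table detection behind enable_table_merge (true only when TF_DEVICE is set to NPU) and filters tables by name through ASCEND_TABLE_NAME_MUST_CONTAIN, which may be unset, a single substring, or a list of substrings. A standalone sketch of that filtering rule, re-implemented here for illustration rather than imported from mxRec:

    # Standalone sketch of the should_skip name filter added in the patch.
    # "must_contain" stands in for ASCEND_TABLE_NAME_MUST_CONTAIN.
    def should_skip(table_name, must_contain):
        if isinstance(must_contain, str):
            return must_contain not in table_name  # skip unless the substring matches
        if isinstance(must_contain, list):
            return not any(isinstance(k, str) and k in table_name for k in must_contain)
        return False  # no filter configured: never skip

    assert should_skip("user_emb_table", "item") is True
    assert should_skip("item_emb_table", ["item", "ad"]) is False
    assert should_skip("any_table", None) is False

A table is therefore dropped from the EmbInfo list either because graph traversal found no lookup reaching a loss (dangling), or because its name fails this filter.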
From edab972a512be5fc58262f31068be64399de43fb Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 14:18:57 +0800
Subject: [PATCH 144/551] Match-id-c963636ac2b60c4b5831a6dfcd3c873160b3e61d

---
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 40 +++++++++++++++++-----------
 src/core/hybrid_mgmt/hybrid_mgmt.h   |  1 +
 2 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 41ba5efb..30fbf0cf 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -95,6 +95,8 @@
     restoreQueueForTrain = make_unique>>();
     lookUpKeysQueueForEval = make_unique>>();
     restoreQueueForEval = make_unique>>();
+    a2aQueueForTrain = make_unique>>();
+    a2aQueueForEval = make_unique>>();
 
     isRunning = true;
     if (!rankInfo.noDDR) {
@@ -262,7 +264,8 @@ void HybridMgmt::Start()
             mode = std::stoi(envTaskMode); // convert the string to an integer
             spdlog::info("The value of MGMT_HBM_TASK_MODE is an integer: {}", mode);
         } catch (const std::invalid_argument& e) { // the conversion failed
-            spdlog::info("The value of MGMT_HBM_TASK_MODE is not an integer!");
+            spdlog::error("The value of MGMT_HBM_TASK_MODE is not an integer!");
+            throw std::invalid_argument("Invalid env value MGMT_HBM_TASK_MODE");
         }
     } else { // the environment variable is not set
         mode = 0;
@@ -449,6 +452,21 @@ bool HybridMgmt::EvalTask(TaskType type)
     return true;
 }
 
+void HybridMgmt::GetAll2All(const int channelId, int &batchId, const string &name)
+{
+    auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL);
+    switch (channelId) {
+        case TRAIN_CHANNEL_ID:
+            a2aQueueForTrain->Pushv(*all2all);
+            break;
+        case EVAL_CHANNEL_ID:
+            a2aQueueForEval->Pushv(*all2all);
+            break;
+        default:
+            throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]");
+    }
+}
+
 bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId)
 {
     spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId);
@@ -462,24 +480,10 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId)
             embInfo.name, embInfo.modifyGraph, names);
         for (const string& name: names) {
             auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE);
-            if (!mgmtRankInfo.useStatic) {
-                auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL);
-                switch (channelId) {
-                    case TRAIN_CHANNEL_ID:
-                        a2aQueueForTrain->Pushv({ *all2all });
-                        break;
-                    case EVAL_CHANNEL_ID:
-                        a2aQueueForEval->Pushv({ *all2all });
-                        break;
-                    default:
-                        throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]");
-                }
-            }
             if (infoVecs == nullptr) {
                 spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId);
                 return false;
             }
-
             switch (channelId) {
                 case TRAIN_CHANNEL_ID:
                     lookUpKeysQueueForTrain->Pushv({ infoVecs->back() });
@@ -494,6 +498,10 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId)
                 default:
                     throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]");
             }
+
+            if (!mgmtRankInfo.useStatic) {
+                GetAll2All(channelId, batchId, name);
+            }
         }
         TIME_PRINT("getAllTensorTC(ms):{}", getAllTensorTC.ElapsedMS());
     }
@@ -606,7 +614,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId)
 
         if (!mgmtRankInfo.useStatic) {
             auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL);
-            hdTransfer->Send(TransferChannel::ALL2ALL, { *all2all }, channelId, name);
+            hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, name);
         }
         hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, name);
         infoVecs->pop_back();
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index fcdae8f2..ff27c885 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -142,6 +142,7 @@
         void LookupKeys(const int channelId, vector names);
         void RestoreKeys(const int channelId, vector names);
         bool GetLookupAndRestore(const int channelId, int &batchId);
+        void GetAll2All(const int channelId, int &batchId, const string &name);
         bool SendLookupAndRestore(const int channelId, int &batchId);
-- 
Gitee

From 069ecdb434a0c9b1f2b4d436950edc8fb9f34f6d Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 15:15:13 +0800
Subject: [PATCH 145/551] Match-id-363050e3436632358b918c4b197820717f2e5b39

---
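This patch threads an init_param coefficient from create_table down to the C++ initializers, which multiply every generated value by it; patch 149 later in this series scales the truncation bounds by the same factor. A hedged NumPy sketch of the combined end result (the two-sigma bound mirrors the role of the boundNum constant, whose exact value is not shown in this excerpt, and all names here are illustrative):

    # Sketch: effect of init_param on a truncated-normal initializer.
    # NumPy stands in for the C++ generator; two-sigma clipping is an assumption.
    import numpy as np

    def truncated_normal(size, mean=0.0, stddev=1.0, init_param=1.0, seed=0):
        rng = np.random.default_rng(seed)
        lo = init_param * (mean - 2 * stddev)  # bounds scaled too, as in patch 149
        hi = init_param * (mean + 2 * stddev)
        out = np.empty(size, dtype=np.float32)
        for i in range(size):
            v = init_param * rng.normal(mean, stddev)
            while v < lo or v > hi:  # resample until inside the scaled bounds
                v = init_param * rng.normal(mean, stddev)
            out[i] = v
        return out

    print(truncated_normal(4, mean=0.0, stddev=0.05, init_param=0.5))

In effect, init_param shrinks or enlarges the whole initial embedding distribution without changing its shape, which is why both the samples and the clipping bounds must be scaled together.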
 mx_rec/core/asc/manager.py           |  7 +++++--
 mx_rec/core/embedding.py             | 10 ++++++++--
 .../truncated_normal_initializer.cpp |  8 +++++---
 .../truncated_normal_initializer.h   |  2 +-
 src/core/utils/common.cpp            |  3 ++-
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 7906b5e6..515ec18f 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -111,7 +111,7 @@ def matched_emb_initializer(tabel_info):
     elif initializer_case_map.get("tf1/tf2_random_normal_initializer"):
         random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed
         init_param = tabel_info.init_param
-        logging.debug(f"table: {tabel_info.table_name}, initK is {init_param}.")
+        logging.debug(f"random_normal_initializer, table: {tabel_info.table_name}, initK is {init_param}.")
         initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size,
                                      normal_initializer_info=NormalInitializerInfo(
                                          mean=tabel_info.emb_initializer.mean,
@@ -122,11 +122,14 @@ def matched_emb_initializer(tabel_info):
     elif initializer_case_map.get("tf1_truncated_normal_initializer") or \
             initializer_case_map.get("tf2_truncated_normal_initializer"):
         random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed
+        init_param = tabel_info.init_param
+        logging.debug(f"truncated_normal_initializer, table: {tabel_info.table_name}, initK is {init_param}.")
         initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size,
                                      normal_initializer_info=NormalInitializerInfo(
                                          mean=tabel_info.emb_initializer.mean,
                                          stddev=tabel_info.emb_initializer.stddev,
-                                         seed=random_seed
+                                         seed=random_seed,
+                                         initK=init_param
                                      ))
     else:
         initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size,
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 9d8358b1..8c3cdfbe 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -41,6 +41,7 @@ def create_table(**kwargs):
     shard_num = kwargs.get("shard_num", 1)
     fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True)
     hashtable_threshold = kwargs.get("hashtable_threshold", 0)
+    init_param = kwargs.get("init_param", 1.0)
     is_save = kwargs.get("is_save", True)
 
     """
@@ -58,12 +59,14 @@ def create_table(**kwargs):
         shard_num: embedding partition number
         fusion_optimizer_var: fusion optimizer variable with embedding
         hashtable_threshold: choose to implement based on hash table or linear layer
+        init_param: scaling coefficient applied to the embedding initializer output
     """
 
     config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer,
                   device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size,
                   optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num,
-                  fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold, is_save=is_save)
+                  fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold,
+                  init_param=init_param, is_save=is_save)
 
     embedding = SparseEmbedding(config)
     return embedding
@@ -158,6 +161,7 @@ class SparseEmbedding:
         self.send_count_map = dict()
         self.channel_name_dict = {True: [], False: []}
         self.modify_graph = False
+        self.init_param = config.get("init_param")
 
         self.set_slice_vocab_size()
         self.set_emb_size()
@@ -751,7 +755,9 @@ class SparseEmbedding:
                          f" {self.slice_host_vocabulary_size}.")
 
     def _initialize_variables(self):
-        
initialized_tensor = self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) + initialized_tensor = \ + self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) * self.init_param + self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) # make sure sparse table variable will not be saved and restored within tf checkpoint. insert_removing_var_list(self.variable.name) diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 1b85e202..631dc88f 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -11,9 +11,11 @@ using namespace MxRec; -TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed) +TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed, + float initK) : start(start), len(len), mean(mean), stddev(stddev), seed(seed) { + initParam = initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); minBound = mean - static_cast(boundNum) * stddev; @@ -33,9 +35,9 @@ void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSiz return; } std::generate_n(emb + start, len, [&]() { - float tmp = distribution(generator); + float tmp = initParam * distribution(generator); while (tmp < minBound || tmp > maxBound) { - tmp = distribution(generator); + tmp = initParam * distribution(generator); } return tmp; }); diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index d2da1bef..94ec6d2e 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -19,7 +19,7 @@ namespace MxRec { class TruncatedNormalInitializer : public Initializer { public: TruncatedNormalInitializer() = default; - TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed); + TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK); ~TruncatedNormalInitializer() override {}; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 3480bf58..75cbf368 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -79,7 +79,8 @@ namespace MxRec { if (name == "truncated_normal_initializer") { initializerType = InitializerType::TRUNCATED_NORMAL; truncatedNormalInitializer = TruncatedNormalInitializer(start, len, - normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed); + normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed, + normalInitializerInfo.initK); } else if (name == "random_normal_initializer") { initializerType = InitializerType::RANDOM_NORMAL; randomNormalInitializer = RandomNormalInitializer(start, len, -- Gitee From 8cfb444f83d665e124148ad842bbbe34278b4b5a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 19 Jun 2023 16:57:52 +0800 Subject: [PATCH 146/551] Match-id-524f7feb3932a2838b7cfb4ac04b8de7cb0d4a3a --- .../random_normal_initializer.cpp | 6 +++--- .../random_normal_initializer.h | 2 +- 
From 8cfb444f83d665e124148ad842bbbe34278b4b5a Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 16:57:52 +0800
Subject: [PATCH 146/551] Match-id-524f7feb3932a2838b7cfb4ac04b8de7cb0d4a3a

---
 .../random_normal_initializer.cpp          |  6 +++---
 .../random_normal_initializer.h            |  2 +-
 .../truncated_normal_initializer.cpp       | 13 +++++++------
 .../truncated_normal_initializer.h         |  2 +-
 src/core/utils/common.cpp                  | 11 +++++------
 src/tests/initializer/initializer_test.cpp |  6 ++++--
 6 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
index d832d875..b5411b50 100644
--- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
+++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
@@ -11,10 +11,10 @@
 
 using namespace MxRec;
 
-RandomNormalInitializer::RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK)
-    : start(start), len(len), mean(mean), stddev(stddev), seed(seed)
+RandomNormalInitializer::RandomNormalInitializer(int start, int len, std::tuple<float, float, int, float> ret)
+    : start(start), len(len)
 {
-    initParam = initK;
+    std::tie(mean, stddev, seed, initParam) = ret;
     generator = std::default_random_engine(seed);
     distribution = std::normal_distribution<float>(mean, stddev);
 }
diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
index 8f39315f..e4a9ef45 100644
--- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h
+++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
@@ -19,7 +19,7 @@ namespace MxRec {
 class RandomNormalInitializer : public Initializer {
 public:
     RandomNormalInitializer() = default;
-    RandomNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK);
+    RandomNormalInitializer(int start, int len, std::tuple<float, float, int, float>);
 
     ~RandomNormalInitializer() override {};
 
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
index 631dc88f..0e05ab03 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
@@ -5,17 +5,18 @@
  * Date: 2022/12/22
  */
 
-#include "truncated_normal_initializer.h"
-#include 
 #include 
+#include 
+#include 
+#include "truncated_normal_initializer.h"
 
 using namespace MxRec;
 
-TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed,
-    float initK)
-    : start(start), len(len), mean(mean), stddev(stddev), seed(seed)
+TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, std::tuple<float, float, int, float> ret)
+    : start(start), len(len)
 {
-    initParam = initK;
+    std::tie(mean, stddev, seed, initParam) = ret;
+
     generator = std::default_random_engine(seed);
     distribution = std::normal_distribution<float>(mean, stddev);
     minBound = mean - static_cast<float>(boundNum) * stddev;
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
index 94ec6d2e..39024e42 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
@@ -19,7 +19,7 @@ namespace MxRec {
 class TruncatedNormalInitializer : public Initializer {
 public:
     TruncatedNormalInitializer() = default;
-    TruncatedNormalInitializer(int start, int len, float mean, float stddev, int seed, float initK);
+    TruncatedNormalInitializer(int start, int len, std::tuple<float, float, int, float>);
 
     ~TruncatedNormalInitializer() override {};
 
diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp
index 75cbf368..90bb3b26 100644
--- a/src/core/utils/common.cpp
+++ b/src/core/utils/common.cpp
@@ -76,16 +76,15 @@ namespace MxRec {
     InitializeInfo::InitializeInfo(std::string& name, int start, int len, NormalInitializerInfo normalInitializerInfo)
         : name(name), start(start), len(len), normalInitializerInfo(normalInitializerInfo)
     {
+        std::tuple<float, float, int, float> ret(normalInitializerInfo.mean, normalInitializerInfo.stddev,
+            normalInitializerInfo.seed, normalInitializerInfo.initK);
+
         if (name == "truncated_normal_initializer") {
             initializerType = InitializerType::TRUNCATED_NORMAL;
-            truncatedNormalInitializer = TruncatedNormalInitializer(start, len,
-                normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed,
-                normalInitializerInfo.initK);
+            truncatedNormalInitializer = TruncatedNormalInitializer(start, len, ret);
         } else if (name == "random_normal_initializer") {
             initializerType = InitializerType::RANDOM_NORMAL;
-            randomNormalInitializer = RandomNormalInitializer(start, len,
-                normalInitializerInfo.mean, normalInitializerInfo.stddev, normalInitializerInfo.seed,
-                normalInitializerInfo.initK);
+            randomNormalInitializer = RandomNormalInitializer(start, len, ret);
         } else {
             throw std::invalid_argument("Invalid Initializer Type.");
         }
diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index 735fbf8a..87779ee5 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -47,7 +47,8 @@
 TEST(InitializerTest, TruncatedNormalInitializerTest)
 {
     TruncatedNormalInitializer truncatedNormalInitializer;
-    truncatedNormalInitializer = TruncatedNormalInitializer(1, 10, 1.0, 0.3, 1);
+    std::tuple<float, float, int, float> ret(1.0, 0.3, 1, 0.1);
+    truncatedNormalInitializer = TruncatedNormalInitializer(1, 10, ret);
 
     vector<vector<float>> embData;
     int vocabSize = 5;
@@ -75,7 +76,8 @@
 
 TEST(InitializerTest, RandomNormalInitializerTest)
 {
-    RandomNormalInitializer randomNormalInitializer(1, 10, 2.0, 0.5, 1, 0.1);
+    std::tuple<float, float, int, float> ret(2.0, 0.5, 1, 0.1);
+    RandomNormalInitializer randomNormalInitializer(1, 10, ret);
 
     vector<vector<float>> embData;
     int vocabSize = 5;
-- Gitee

From 6f9583316780c4ddb05b7118df7f8aecc1e13ca4 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 17:38:55 +0800
Subject: [PATCH 147/551] Match-id-a32cc357abce35984122a76cc963aa9b0f73ab80

---
 .../random_normal_initializer/random_normal_initializer.cpp | 4 ++--
 .../random_normal_initializer/random_normal_initializer.h   | 1 +
 .../truncated_normal_initializer.cpp                         | 1 -
 .../truncated_normal_initializer.h                           | 1 +
 src/tests/initializer/initializer_test.cpp                   | 1 +
 5 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
index b5411b50..c9f0b5b8 100644
--- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
+++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp
@@ -5,9 +5,9 @@
  * Date: 2022/12/23
  */
 
-#include "random_normal_initializer.h"
-#include 
 #include 
+#include 
+#include "random_normal_initializer.h"
 
 using namespace MxRec;
 
diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
index e4a9ef45..e020bd64 100644
--- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h
+++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h
@@ -10,6 +10,7 @@
 
 #include 
 #include 
+#include <tuple>
 
 #include "initializer/initializer.h"
 
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
index 0e05ab03..da45f65f 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
@@ -6,7 +6,6 @@
  */
 
 #include 
-#include 
 #include "truncated_normal_initializer.h"
 
diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
index 39024e42..67285072 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h
@@ -10,6 +10,7 @@
 
 #include 
 #include 
+#include <tuple>
 
 #include "initializer/initializer.h"
 
diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index 87779ee5..8e7ad2a1 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -6,6 +6,7 @@
  * History: NA
  */
 
+#include "tuple"
 #include 
 #include 
 
-- Gitee

From 63e9164d58bcc0549a1e2a0a67b4a1aa0ee9c8bb Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 17:41:09 +0800
Subject: [PATCH 148/551] Match-id-fbcfc2bdf6c6d4d9ea6d4cb000b39547f6d7fde7

---
 src/tests/initializer/initializer_test.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index 8e7ad2a1..45ac9222 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -6,8 +6,8 @@
  * History: NA
  */
 
-#include "tuple"
 #include 
+#include <tuple>
 #include 
 
-- Gitee

From 42ac2d449728ced3eea6b822aab599d24bccaa8d Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 19:20:56 +0800
Subject: [PATCH 149/551] Match-id-eb993b4ee407a5d4258dc202a35a87348ca96cbc

---
 .../truncated_normal_initializer.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
index da45f65f..ed151726 100644
--- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
+++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp
@@ -18,8 +18,8 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, std::
 
     generator = std::default_random_engine(seed);
     distribution = std::normal_distribution<float>(mean, stddev);
-    minBound = mean - static_cast<float>(boundNum) * stddev;
-    maxBound = mean + static_cast<float>(boundNum) * stddev;
+    minBound = initParam * (mean - static_cast<float>(boundNum) * stddev);
+    maxBound = initParam * (mean + static_cast<float>(boundNum) * stddev);
 }
 
-- Gitee
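The two-line change above is more than cosmetic. Since the earlier patch in this series, GenerateData scales each draw (tmp = initParam * distribution(generator)) but, before this fix, still rejected draws against the unscaled bounds; for a small initK the accepted interval and the interval the draws actually occupy do not even overlap, so the rejection loop can never terminate. A short check with the values from the unit tests, again assuming a truncation width boundNum = 2:

    #include <cassert>

    // With initK = 0.1, scaled draws 0.1 * N(1.0, 0.3) lie in roughly
    // [0.04, 0.16] (the two-sigma range), while the unscaled bounds were
    // [0.4, 1.6]. No draw can ever satisfy the old acceptance test.
    int main()
    {
        const float mean = 1.0f, stddev = 0.3f, initK = 0.1f, boundNum = 2.0f;

        const float unscaledMin = mean - boundNum * stddev;          // 0.4
        const float scaledMax = initK * (mean + boundNum * stddev);  // 0.16

        assert(scaledMax < unscaledMin);  // intervals are disjoint
        return 0;
    }

Scaling minBound and maxBound by initParam, as the hunk above does, restores the overlap, and the loop terminates again.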
From 401dc5cefce394d973a045fd311dd2a80748aeae Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 19 Jun 2023 20:07:21 +0800
Subject: [PATCH 150/551] Match-id-a52cc8285f7f9883d930a2b0f4dd552782c98745

---
 mx_rec/core/embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 8c3cdfbe..d2dfd29b 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -882,7 +882,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
                 output_shapes=[instance.slice_device_vocabulary_size],
                 channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0]
             initialized_tensor = instance.emb_initializer(
-                instance.slice_device_vocabulary_size + instance.embedding_size)
+                instance.slice_device_vocabulary_size + instance.embedding_size) * instance.init_param
         else:
             evict_pos = npu_ops.gen_npu_ops.get_next(
                 output_types=[tf.int32],
-- Gitee

From 9e56eaf1d50f720df10576bd5b25f19c4c183358 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 20 Jun 2023 16:29:11 +0800
Subject: [PATCH 151/551] Match-id-22923c2a5b36335c4a83d3426306682497441cf6

---
 mx_rec/util/initialize.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index ef5d0bc0..5d6addea 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -700,11 +700,6 @@ def set_ascend_env():
 
     if os.getenv("RANK_TABLE_FILE"):
         os.environ["RANK_SIZE"] = str(rank_size)
-    else:
-        import socket
-        host_name = socket.gethostname()
-        host_ip = socket.gethostbyname(host_name)
-        os.environ["CM_WORKER_IP"] = host_ip
 
     os.environ["HCCL_CONNECT_TIMEOUT"] = "1200"
     os.environ["JOB_ID"] = "10086"
-- Gitee

From 36869861d55db4b0253fed9bf3ef7e7339f12c08 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 11:40:18 +0800
Subject: [PATCH 152/551] Match-id-086f2421d36b5e2173145a3bf67dd4d6027008cb

---
 mx_rec/core/asc/manager.py                         | 4 +++-
 mx_rec/core/embedding.py                           | 3 +++
 .../constant_initializer/constant_initializer.cpp  | 8 ++++++--
 .../constant_initializer/constant_initializer.h    | 2 +-
 src/core/key_process/feature_admit_and_evict.cpp   | 4 +++-
 src/core/utils/common.cpp                          | 7 ++++---
 src/core/utils/common.h                            | 3 ++-
 src/pybind/module_main.cpp                         | 5 +++--
 src/tests/initializer/initializer_test.cpp         | 4 ++--
 9 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 515ec18f..8c71a8fa 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -105,9 +105,11 @@ def matched_emb_initializer(tabel_info):
             isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal),
     }
     if initializer_case_map.get("tf1/tf2_constant_initializer"):
+        init_param = tabel_info.init_param
+        logging.debug(f"constant_initializer, table: {tabel_info.table_name}, initK is {init_param}.")
         initializer = InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size,
                                      constant_initializer_info=ConstantInitializerInfo(
-                                         constant_val=tabel_info.emb_initializer.value))
+                                         constant_val=tabel_info.emb_initializer.value, initK=init_param))
     elif initializer_case_map.get("tf1/tf2_random_normal_initializer"):
         random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed
         init_param = tabel_info.init_param
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index d2dfd29b..1c958431 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -340,6 +340,9 @@
         if is_training not in self.lookup_info:
             self.lookup_info.add(is_training)
 
+        if not isinstance(self.init_param, int):
+            raise ValueError("Arg init_param should be an integer.")
+
         if get_use_static():
             if isinstance(send_count, int) and send_count > 0:
                 if self._send_count and self._send_count != send_count:
diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp
index 954ca98f..9f012abf 100644
--- a/src/core/initializer/constant_initializer/constant_initializer.cpp
+++ b/src/core/initializer/constant_initializer/constant_initializer.cpp
@@ -11,7 +11,11 @@
 using namespace std;
 using namespace MxRec;
 
-ConstantInitializer::ConstantInitializer(int start, int len, float value) : start(start), len(len), value(value) {}
+ConstantInitializer::ConstantInitializer(int start, int len, float value, float initK)
+: start(start), len(len), value(value)
+{
+    initParam = initK;
+}
 
 void ConstantInitializer::GenerateData(float* const emb, const int embSize)
 {
@@ -24,5 +28,5 @@
             start, len, embSize);
         return;
     }
-    std::fill_n(emb + start, len, value);
+    std::fill_n(emb + start, len, initParam * value);
 }
\ No newline at end of file
diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h
index 68aa0654..8af23170 100644
--- a/src/core/initializer/constant_initializer/constant_initializer.h
+++ b/src/core/initializer/constant_initializer/constant_initializer.h
@@ -17,7 +17,7 @@ namespace MxRec {
 class ConstantInitializer : public Initializer {
 public:
     ConstantInitializer() = default;
-    ConstantInitializer(int start, int len, float value);
+    ConstantInitializer(int start, int len, float value, float initK);
 
     ~ConstantInitializer() override {};
 
diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp
index 7cb2caa8..5e9073bf 100644
--- a/src/core/key_process/feature_admit_and_evict.cpp
+++ b/src/core/key_process/feature_admit_and_evict.cpp
@@ -64,7 +64,9 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel,
     spdlog::debug("FeatureAdmitAndEvict PrintSize, name:[{}], history key:[{}] ...",
         tensorName, m_recordsData.historyRecords[tensorName].size());
 
-    m_recordsData.timestamps[tensorName] = batch->timestamp;
+    if (batch->timestamp > m_recordsData.timestamps[tensorName]) {
+        m_recordsData.timestamps[tensorName] = batch->timestamp;
+    }
     absl::flat_hash_map visitedRecords;
     for (auto& key : splitKey) {
         if (key == -1) {
diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp
index 90bb3b26..6b4b9367 100644
--- a/src/core/utils/common.cpp
+++ b/src/core/utils/common.cpp
@@ -53,8 +53,8 @@ namespace MxRec {
         : start(start), len(len), constantVal(constantVal), randomMin(randomMin), randomMax(randomMax)
     {}
 
-    ConstantInitializerInfo::ConstantInitializerInfo(float constantValue)
-        : constantValue(constantValue)
+    ConstantInitializerInfo::ConstantInitializerInfo(float constantValue, float initK)
+        : constantValue(constantValue), initK(initK)
     {}
 
     NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed, float initK)
@@ -67,7 +67,8 @@ namespace MxRec {
     {
         if (name == "constant_initializer") {
             initializerType = InitializerType::CONSTANT;
-            constantInitializer = ConstantInitializer(start, len, constantInitializerInfo.constantValue);
+            constantInitializer = ConstantInitializer(start, len, constantInitializerInfo.constantValue,
+                constantInitializerInfo.initK);
         } else {
             throw std::invalid_argument("Invalid Initializer Type.");
         }
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 5641524d..19ac882d 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -305,9 +305,10 @@ struct BatchTask {
 
 struct ConstantInitializerInfo {
     ConstantInitializerInfo() = default;
-    explicit ConstantInitializerInfo(float constantValue);
+    explicit ConstantInitializerInfo(float constantValue, float initK);
 
     float constantValue;
+    float initK = 1.0;
 };
 
 struct NormalInitializerInfo {
diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp
index ddb757dc..b8f253d9 100644
--- a/src/pybind/module_main.cpp
+++ b/src/pybind/module_main.cpp
@@ -140,8 +140,9 @@ void GetInitializeInfo(pybind11::module_ &m)
 void GetConstantInitializerInfo(pybind11::module_ &m)
 {
     pybind11::class_<ConstantInitializerInfo>(m, "ConstantInitializerInfo")
-        .def(py::init<float>(), py::arg("constant_val") = 0)
-        .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue);
+        .def(py::init<float, float>(), py::arg("constant_val") = 0, py::arg("initK") = 1.0)
+        .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue)
+        .def_readwrite("initK", &ConstantInitializerInfo::initK);
 }
 
 void GetNormalInitializerInfo(pybind11::module_ &m)
diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index 45ac9222..a1e59688 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -22,7 +22,7 @@
 TEST(InitializerTest, ConstantInitializerTest)
 {
     ConstantInitializer constant_initializer;
     // start; end; constant_val;
-    constant_initializer = ConstantInitializer(1, 5, 7);
+    constant_initializer = ConstantInitializer(1, 5, 7, 0.1);
 
     vector<vector<float>> embData;
     int vocabSize = 5;
@@ -40,7 +40,7 @@
         std::cout << std::endl;
     }
 
-    ASSERT_EQ(embData.at(2).at(2), 7);
+    ASSERT_FLOAT_EQ(embData.at(2).at(2), 0.7);
     ASSERT_EQ(embData.at(2).at(0), 0);
 }
 
-- Gitee
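For the constant initializer introduced above, the scaling is easy to verify by hand: ConstantInitializer(1, 5, 7, 0.1), the call used in the updated unit test, fills positions [1, 6) of each row with 0.1 * 7 = 0.7 and leaves position 0 at its previous value. A sketch of that fill; note that 0.1f * 7.0f is not bit-identical to the double literal 0.7, which is why the test assertion is better expressed with a float-tolerant comparison such as ASSERT_FLOAT_EQ:

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <vector>

    // Mirrors ConstantInitializer(1, 5, 7, 0.1): fill [start, start + len)
    // with initK * value, leave everything else untouched.
    int main()
    {
        const int start = 1, len = 5;
        const float value = 7.0f, initK = 0.1f;

        std::vector<float> emb(8, 0.0f);
        std::fill_n(emb.begin() + start, len, initK * value);

        assert(emb[0] == 0.0f);                    // outside the filled range
        assert(std::fabs(emb[2] - 0.7f) < 1e-6f);  // inside, scaled value
        return 0;
    }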
From 41f4a2f04f67476052d28d71ac13276990ab1d86 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 11:44:37 +0800
Subject: [PATCH 153/551] Match-id-573f67364b2d92841f8d5ba606efaf1e45d1906b

---
 .../initializer/constant_initializer/constant_initializer.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp
index 9f012abf..5ce30a95 100644
--- a/src/core/initializer/constant_initializer/constant_initializer.cpp
+++ b/src/core/initializer/constant_initializer/constant_initializer.cpp
@@ -12,7 +12,7 @@
 using namespace MxRec;
 
 ConstantInitializer::ConstantInitializer(int start, int len, float value, float initK)
-: start(start), len(len), value(value)
+    : start(start), len(len), value(value)
 {
     initParam = initK;
 }
-- Gitee

From 61a43b41adc7ceeaf4d4d05406340b911e08189c Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 15:34:23 +0800
Subject: [PATCH 154/551] Match-id-ab2cfcf3aa682e1702ed0c493e7cc2f1ab277bf7

---
 build/build.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/build/build.sh b/build/build.sh
index cecc19ec..c0ca4449 100644
--- a/build/build.sh
+++ b/build/build.sh
@@ -182,6 +182,7 @@ gen_tar_file()
     mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}"
     mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}"
     cp -r "${src_path}"/../example ../build/"${pkg_dir}"
+    cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}"
     cd ../build
     tar -zvcf "${release_tar}" "${pkg_dir}" || {
         warn "compression failed, packages might be broken"
-- Gitee
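The next patch wraps GetBatchData in a TimeCost probe and prints its elapsed milliseconds, which the reworked fast.sh below then aggregates out of the logs. TimeCost itself is not shown anywhere in this series; a minimal equivalent, offered only as an assumption about its shape, would record a steady_clock start at construction and report the elapsed time on demand (steady_clock, unlike wall-clock time, never jumps backwards):

    #include <chrono>
    #include <iostream>

    // Assumed minimal TimeCost: not the project's actual implementation.
    class TimeCost {
    public:
        TimeCost() : start(std::chrono::steady_clock::now()) {}

        long long ElapsedMS() const
        {
            const auto now = std::chrono::steady_clock::now();
            return std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
        }

    private:
        std::chrono::steady_clock::time_point start;
    };

    int main()
    {
        TimeCost getBatchDataTC;
        // ... fetch a batch here ...
        std::cout << "getBatchDataTC(ms):" << getBatchDataTC.ElapsedMS() << '\n';
        return 0;
    }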
From e68500f98e984e09b4b00996b07752f56c82e27a Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 16:01:16 +0800
Subject: [PATCH 155/551] Match-id-56f1edbd81828e8ceddd1d0cf54bb6ef033ebe2d

---
 src/core/key_process/key_process.cpp |  2 ++
 tools/perf/fast.sh                   | 44 ++++++++++++++++++----------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 2763189a..a6263c58 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -267,7 +267,9 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES
     try {
         while (true) {
             TimeCost getAndProcessTC;
+            TimeCost getBatchDataTC;
            batch = GetBatchData(channel, id); // get batch data from SingletonQueue
+            TIME_PRINT("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS());
             if (batch == nullptr) {
                 break;
             }
diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh
index 63fc2759..79b6ec63 100755
--- a/tools/perf/fast.sh
+++ b/tools/perf/fast.sh
@@ -102,21 +102,18 @@ parse_pipe_2_key_process()
 
     LOG_INFO "Step-2.1 GetBatchData"
 
-    grep 'get data time' $logfile | cut -d" " -f11 | \
-        awk -F"[:,]" '{print $2}' | \
-        awk 'BEGIN { max=0 } { sum+=$NF; if($NF>max) max=$NF } END \
-        {printf "--|get data time: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}'
+    grep 'getBatchDataTC' $logfile | \
+        awk -F":" 'BEGIN { max=0 } { sum+=$NF; if($NF>max) max=$NF } END \
+        {printf "--|get data time: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}'
 
-    grep 'get data time' $logfile | cut -d" " -f11 | \
-        awk -F"[:,]" '{print $2}' | \
-        awk 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \
-        {printf "--|get data time(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}'
+    grep 'getBatchDataTC' $logfile | \
+        awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \
+        {printf "--|get data time(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}'
 
-    grep 'get data time' $logfile | cut -d" " -f11 | \
-        awk -F"[:,]" '{print $2}' | \
-        awk 'BEGIN { total=0; none_zero_ms_num=0 } { total++; if($NF>0) none_zero_ms_num++ } END \
-        {printf "--|get data time: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \
-        total, none_zero_ms_num, none_zero_ms_num/total, (1-none_zero_ms_num/total)}'
+    grep 'getBatchDataTC' $logfile | \
+        awk -F":" 'BEGIN { total=0; none_zero_ms_num=0 } { total++; if($NF>0) none_zero_ms_num++ } END \
+        {printf "--|get data time: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \
+        total, none_zero_ms_num, none_zero_ms_num/total, (1-none_zero_ms_num/total)}'
 
     LOG_INFO "Step-2.2 KeyProcess"
 
@@ -240,6 +237,18 @@ parse_pipe_3_get_and_send_tensors_with_ddr()
     fi
 }
 
+parse_pipe_3_get_and_send_tensors_without_ddr()
+{
+    LOG_NOTICE "Pipe-3: Get and Send Tensors (without DDR)"
+
+    $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
+    if [ $? == 0 ]; then
+        grep 'ParseKeysTC HBM mode (ms)' $logfile | \
+            awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \
+            {printf "--|ParseKeysTC(filter>2000ms): avg=%0.1f\n", sum/count}'
+    fi
+}
+
 main()
 {
     validate_options $@
@@ -256,8 +265,13 @@ main()
     if [ $? -eq 0 ]; then
         parse_pipe_3_get_and_send_tensors_with_ddr
     else
-        parse_pipe_3_get_tensors_no_ddr
-        parse_pipe_4_send_tensors_no_ddr
+        $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
+        if [ $? -eq 0 ]; then
+            parse_pipe_3_get_and_send_tensors_without_ddr
+        else
+            parse_pipe_3_get_tensors_no_ddr
+            parse_pipe_4_send_tensors_no_ddr
+        fi
     fi
 }
-- Gitee

From b2d07bd63c73af60fdc719eb135bd35bc1069f52 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 16:18:59 +0800
Subject: [PATCH 156/551] Match-id-d3f294573ff553a22577c3b2fbe8bcd1451a41f1

---
 example/little_demo/run.sh |  1 +
 tools/perf/fast.sh         | 18 +++++++++---------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
index f5c4e739..2d460786 100644
--- a/example/little_demo/run.sh
+++ b/example/little_demo/run.sh
@@ -37,6 +37,7 @@ export USE_TIMESTAMP=0            # 0: disable feature admission and eviction; 1: enable it
 
 ################# Performance tuning ####################
 export KEY_PROCESS_THREAD_NUM=6   #default 6, max 10
 export FAST_UNIQUE=0              #if use fast unique
+export MGMT_HBM_TASK_MODE=0       #if async h2d (get and send tensors)
 ################################################
 
 # Help message; no need to modify
diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh
index 79b6ec63..20aec3df 100755
--- a/tools/perf/fast.sh
+++ b/tools/perf/fast.sh
@@ -192,9 +192,9 @@ parse_pipe_2_key_process()
     # common end
 }
 
-parse_pipe_3_get_tensors_no_ddr()
+parse_pipe_3_get_tensors_async_no_ddr()
 {
-    LOG_NOTICE "Pipe-3: Get Tensors (no DDR)"
+    LOG_NOTICE "Pipe-3: Get Tensors async (no DDR)"
 
     $(grep 'getAllTensorTC(ms)' $logfile > /dev/null 2>&1)
     if [ $? == 0 ]; then
@@ -203,9 +203,9 @@
     fi
 }
 
-parse_pipe_4_send_tensors_no_ddr()
+parse_pipe_4_send_tensors_async_no_ddr()
 {
-    LOG_NOTICE "Pipe-4: H2D Send Tensors (no DDR)"
+    LOG_NOTICE "Pipe-4: H2D Send Tensors async (no DDR)"
 
     $(grep 'sendTensorsTC(ms)' $logfile > /dev/null 2>&1)
     if [ $? == 0 ]; then
@@ -237,9 +237,9 @@
     fi
 }
 
-parse_pipe_3_get_and_send_tensors_without_ddr()
+parse_pipe_3_get_and_send_tensors_sync_without_ddr()
 {
-    LOG_NOTICE "Pipe-3: Get and Send Tensors (without DDR)"
+    LOG_NOTICE "Pipe-3: Get and Send Tensors sync (no DDR)"
 
     $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
     if [ $? == 0 ]; then
@@ -267,10 +267,10 @@ main()
         $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
         if [ $? -eq 0 ]; then
-            parse_pipe_3_get_and_send_tensors_without_ddr
+            parse_pipe_3_get_and_send_tensors_sync_without_ddr
         else
-            parse_pipe_3_get_tensors_no_ddr
-            parse_pipe_4_send_tensors_no_ddr
+            parse_pipe_3_get_tensors_async_no_ddr
+            parse_pipe_4_send_tensors_async_no_ddr
         fi
     fi
 }
-- Gitee

From 638798a62038a69f52efbf93d25b58f013f62d9e Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 21 Jun 2023 18:27:22 +0800
Subject: [PATCH 157/551] Match-id-392358fab7b21095391280de9bee8aef75db587d

---
 mx_rec/constants/constants.py        |   2 +
 mx_rec/core/asc/build_graph.py       |   6 +-
 mx_rec/core/asc/feature_spec.py      |  19 +++
 mx_rec/core/asc/helper.py            |  45 +++-----
 mx_rec/core/asc/manager.py           |  21 +---
 mx_rec/core/embedding.py             | 112 +++++++++---------
 mx_rec/graph/modifier.py             | 153 +++++++++++++++++-----
 src/core/hd_transfer/hd_transfer.cpp |  11 +-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 166 +++++++++++----------------
 src/core/hybrid_mgmt/hybrid_mgmt.h   |   6 +-
 src/core/key_process/key_process.cpp |  53 ++-------
 src/core/key_process/key_process.h   |   2 -
 src/core/utils/common.h              |  13 +--
 src/ops_tf/hybrid_dataset_ops.cpp    |  22 +---
 src/pybind/module_main.cpp           |  15 +--
 src/tests/emb_mgmt/emb_mgmt_test.cpp |  15 +--
 tools/mx_rec_perf.sh                 |  88 ++++++++------
 tools/perf/fast.sh                   |   2 +-
 18 files changed, 375 insertions(+), 376 deletions(-)

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index c39349e9..cd806a3e 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -11,6 +11,7 @@
 ASCEND_CUTTING_POINT = "ASCEND_CUTTING_POINT"
 ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE"
 ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET"
 ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR = "ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR"
+ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT = "ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT"
 # dynamic shape identity
 ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX = "ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX"
 # hot embed function identity
@@ -83,6 +84,7 @@ class ASCAnchorAttr(Enum):
     FEATURE_SPEC = "feature_spec"
     ALL2ALL_MATRIX = "all2all_matrix"
     HOT_POS = "hot_pos"
+    LOOKUP_RESULT = "lookup_result"
 
 
 class MxRecMode(BaseEnum):
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
index 184210b6..f5ebee29 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -94,16 +94,12 @@ def get_all2all_args(use_static: bool, config: dict) -> list:
     return all2all_args
 
 
-def get_preprocessed_tensor_for_asc(table, config, ids_channel_name=None, modify_graph=False):
+def get_preprocessed_tensor_for_asc(table, config):
     use_static = get_use_static()
     max_lookup_vec_size = None
     if use_static:
         max_lookup_vec_size = config.get("send_count") * config.get("rank_size")
 
-    if modify_graph:
-        config["table_name"] = ids_channel_name
-        logging.debug(f"GetNext, table_name: {config.get('table_name')}, modify_graph: {modify_graph}")
-
     with tf.compat.v1.variable_scope("restore_vector"):
         restore_vector, hot_pos = get_restore_vector(config)
 
diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py
index 750c7e8f..60d93291 100644
--- a/mx_rec/core/asc/feature_spec.py
+++ b/mx_rec/core/asc/feature_spec.py
@@ -3,6 +3,7 @@
 import logging
+from typing import Union
 from functools import reduce
 
 import tensorflow as tf
@@ -193,3 +194,21 @@ def get_feature_spec(table_name, access_and_evict_config):
     access_threshold = access_and_evict_config.get("access_threshold")
     eviction_threshold = access_and_evict_config.get("eviction_threshold")
     return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold)
+
+
+def set_temporary_feature_spec_attribute(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]):
+    """
+    Set properties for a temporary feature_spec.
+
+    Args:
+        mock_feature_spec: A temporary feature_spec consisting of multiple feature_spec with the same table.
+        total_feature_count: Inner product of the shape of a tensor.
+
+    Returns: None
+
+    """
+    mock_feature_spec.batch_size = total_feature_count
+    mock_feature_spec.dims = [total_feature_count, 1]
+    mock_feature_spec.initialized = True
+    mock_feature_spec.pipeline_mode.add(True)
+    mock_feature_spec.pipeline_mode.add(False)
diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index f33f5e3c..529d78da 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -283,6 +283,7 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list):
     feature_id_requests = zip(feature_id_list, split_list, table_name_list)
     feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name))
     logging.debug(f" features to merge: {feature_id_requests}")
+
     last_table_name = None
     last_split = 0
     last_tensorshape_split = 0
@@ -302,12 +303,14 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list):
             last_table_name = table_name
             last_split = split
             last_tensorshape_split = tf.math.reduce_prod(tf.shape(feature_id))
+
     if last_table_name is not None:
         output_table_name_list.append(last_table_name)
         output_split_list.append(last_split)
         output_tensorshape_split_list.append(last_tensorshape_split)
     logging.debug(f"merge request from {table_name_list} {split_list} "
                   f" to {output_table_name_list} {output_split_list}")
+
     list_set = {
         'output_feature_id_list': output_feature_id_list,
         'output_split_list': output_split_list,
@@ -320,7 +323,6 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict):
     is_training = input_dict["is_training"]
     timestamp = input_dict["timestamp"]
-    auto_change_graph = input_dict["auto_change_graph"]
     host_pipeline_ops = get_host_pipeline_ops()
     use_static = get_use_static()
     timestamp_feature_id = []
@@ -329,14 +331,11 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict):
         timestamp_feature_id = feature_id_list[:1]
         feature_id_list = feature_id_list[1:]
 
-    if not auto_change_graph:  # future support acg
-        list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list)
-        feature_id_list = list_set.get("output_feature_id_list")
-        split_list = list_set.get("output_split_list")
-        table_name_list = list_set.get("output_table_name_list")
-        tensorshape_split_list = list_set.get("output_tensorshape_split_list")
-    else:
-        tensorshape_split_list = split_list
+    list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list)
+    feature_id_list = list_set.get("output_feature_id_list")
+    split_list = list_set.get("output_split_list")
+    table_name_list = list_set.get("output_table_name_list")
+    tensorshape_split_list = list_set.get("output_tensorshape_split_list")
 
     # check training mode order and ensure channel id
     channel_id = get_training_mode_channel_id(is_training=is_training)
@@ -345,37 +344,19 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict):
         feature_id_list = timestamp_feature_id + feature_id_list
 
     concat_tensor = tf.concat(feature_id_list, axis=0)
-    ids_channel_name_list = []
-    if auto_change_graph:
-        for _, table_instance in export_table_instances().items():
-            if table_instance.table_name not in table_name_list:
-                logging.info(f"table_name ('{table_instance.table_name}') not in table_name_list: {table_name_list}")
-                continue
-            if len(table_instance.channel_name_list) > 1:
-                ids_channel_name_list.extend(table_instance.channel_name_list)
-            else:
-                ids_channel_name_list.append(table_instance.table_name)
-        if len(ids_channel_name_list) != len(tensorshape_split_list):
-            raise RuntimeError(f"The length of ids_channel_name_list and tensorshape_split_list must be equal, "
-                               f"ids_channel_name_list: {ids_channel_name_list}, "
-                               f"tensorshape_split_list: {tensorshape_split_list}")
-
     if len(split_list) == 0 or len(tensorshape_split_list) == 0:
         raise RuntimeError(f"The length of split list can not be 0.")
 
     if use_static:
-        logging.debug(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}, "
-                      f"ids_channel_name_list: {ids_channel_name_list}")
+        logging.info(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}")
         return host_pipeline_ops.read_emb_key_v2(concat_tensor, channel_id=channel_id, splits=split_list,
-                                                 emb_name=table_name_list, timestamp=timestamp,
-                                                 channel_name=ids_channel_name_list, modify_graph=auto_change_graph)
+                                                 emb_name=table_name_list, timestamp=timestamp)
 
-    logging.debug(f"read_emb_key_v2_dynamic, table_name_list: {table_name_list}, "
-                  f"tensorshape_split_list: {tensorshape_split_list}, ids_channel_name_list: {ids_channel_name_list}")
+    logging.info(f"read_emb_key_v2(dynamic), table_name_list: {table_name_list}, "
+                 f"tensorshape_split_list: {tensorshape_split_list}")
     return host_pipeline_ops.read_emb_key_v2_dynamic(concat_tensor, tensorshape_split_list,
                                                      channel_id=channel_id, emb_name=table_name_list,
-                                                     timestamp=timestamp, channel_name=ids_channel_name_list,
-                                                     modify_graph=auto_change_graph)
+                                                     timestamp=timestamp)
 
 
 def do_insert(args, insert_tensors, splits, table_names, input_dict):
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index db51c04f..e5d9c8ca 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -60,29 +60,12 @@ def generate_table_info_list():
     if static_shape_rec_flag or dynamic_shape_rec_flag:
         logging.debug(f"table_instance.slice_device_vocabulary_size: {table_instance.slice_device_vocabulary_size}")
         logging.debug(f"table_instance.slice_host_vocabulary_size: {table_instance.slice_host_vocabulary_size}")
-        if table_instance.modify_graph and len(table_instance.channel_name_list) > 1 \
-                and table_instance.slice_host_vocabulary_size > 0:
-            raise RuntimeError(f"In the case of modify graph, multiple lookups of a table are currently "
-                               f"only compatible with HBM mode.")
-        if len(table_instance.channel_name_list) == 1:
-            ids_channel_name = table_instance.channel_name_list[0]
-            table_instance.channel_name_list = [table_instance.table_name]
-            try:
-                table_instance.send_count_map.pop(ids_channel_name)
-                table_instance.send_count_map[table_instance.table_name] = table_instance.send_count
-            except KeyError as error:
-                raise KeyError(f"ids_channel_name '{ids_channel_name}' not in send_count_map "
-                               f"'{table_instance.send_count_map}'") from error
-        logging.debug(f"table_instance, table_name: {table_instance.table_name}, channel_name_list: "
-                      f"{table_instance.channel_name_list}, send_count_map: {table_instance.send_count_map}")
         table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size,
-                             table_instance.ext_emb_size, table_instance.modify_graph, table_instance.is_save,
-                             table_instance.channel_name_list,
+                             table_instance.ext_emb_size, table_instance.is_save,
                              [table_instance.slice_device_vocabulary_size,
                               table_instance.slice_host_vocabulary_size],
                              [matched_emb_initializer(table_instance)] +
-                             matched_opt_slot_initializers(table_instance),
-                             table_instance.send_count_map)
+                             matched_opt_slot_initializers(table_instance))
         table_info_list.append(table_info)
     return table_info_list
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 9d8358b1..f2e0bee7 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -5,7 +5,6 @@
 import logging
 import math
 import time
-from typing import Union
 from collections import defaultdict
 
 import numpy as np
@@ -14,12 +13,13 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 
 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
-from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec
+from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute
 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \
     ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \
     ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \
-    DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32
+    DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, \
+    ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT
 from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \
     insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
     clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \
@@ -41,7 +41,7 @@ def create_table(**kwargs):
     shard_num = kwargs.get("shard_num", 1)
     fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True)
     hashtable_threshold = kwargs.get("hashtable_threshold", 0)
-    is_save = kwargs.get("is_save", True) 
+    is_save = kwargs.get("is_save", True)
 
     """
     Args:
@@ -141,6 +141,7 @@ class SparseEmbedding:
         self.optimizer_slot_info_list = []
         self._slot_num = dict()
         self._send_count = 0
+        self.same_table_send_count = 0
         self._use_feature_mapping = False
         self.skip_emb_transfer = True if self.host_vocabulary_size <= 0 else False
         self._default_name_count = -1
@@ -154,9 +155,7 @@
         self.lookup_info = set()
         self.lookup_result = dict()
         self.use_dynamic_expansion = get_use_dynamic_expansion()
-        self.channel_name_list = []
-        self.send_count_map = dict()
-        self.channel_name_dict = {True: [], False: []}
+        self.lookup_name_list = []
         self.modify_graph = False
 
         self.set_slice_vocab_size()
@@ -308,7 +307,7 @@
     def check_multi_lookup_times(self):
         if self.modify_graph:
             self.lookup_result = dict()
-        if len(self.channel_name_list) > MULTI_LOOKUP_TIMES or len(self.lookup_result) > MULTI_LOOKUP_TIMES:
+        if len(self.lookup_name_list) > MULTI_LOOKUP_TIMES or len(self.lookup_result) > MULTI_LOOKUP_TIMES:
             run_mode = "Modify Graph" if self.modify_graph else "Feature Spec"
             raise RuntimeError(f"In '{run_mode}' mode, the number of multiple sparse lookup for a table"
                                f"({self.table_name}) is {MULTI_LOOKUP_TIMES}.")
@@ -372,12 +371,6 @@
 
         self._optimizer[key] = state_dict
 
-    def set_channel_name(self, ids_channel_name, eval_mode):
-        self.channel_name_list.append(ids_channel_name)
-        if not eval_mode:
-            self.channel_name_dict.get(True).insert(0, ids_channel_name)
-        self.channel_name_dict.get(False).insert(0, ids_channel_name)
-
     def lookup_for_asc(self, ids: tf.Tensor, send_count, **kwargs):
         """
 
@@ -400,6 +393,7 @@
         if is_asc_frozen() and is_training:
             raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.")
 
+        self.same_table_send_count += send_count if send_count is not None else 0
         feature_spec = get_feature_spec(self.table_name, kwargs.get("access_and_evict_config"))
         feature_spec.set_feat_attribute(ids, is_training)
         # 'clear_channel()' function needs to be executed after 'set_feat_attribute()' function
@@ -414,19 +408,15 @@
         use_dynamic_expansion = get_use_dynamic_expansion()
         use_static = get_use_static()
         use_hot = get_use_hot()
-        eval_mode = not is_training and len(self.channel_name_dict.get(not is_training)) == 0
-        ids_channel_name = ""
+        eval_mode = not is_training and get_training_mode_channel_id(is_training) is None
+        ids_lookup_name = feature_spec.name + "_lookup_ids"
         # set in train mode, train and eval mode, eval mode
         if is_training or eval_mode:
-            ids_channel_name = feature_spec.name + "_lookup_ids"
-            self.set_channel_name(ids_channel_name, eval_mode)
-            send_count = send_count if send_count is not None else 0
-            self._send_count = send_count
-            self.send_count_map[ids_channel_name] = send_count
+            self.lookup_name_list.append(ids_lookup_name)
         self.modify_graph = kwargs.get("modify_graph", True)
         self.check_multi_lookup_times()
         logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, anchor_ids: {anchor_ids}, "
-                      f"ids_channel_name: {ids_channel_name}, use_dynamic_expansion: {use_dynamic_expansion}, "
+                      f"ids_lookup_name: {ids_lookup_name}, use_dynamic_expansion: {use_dynamic_expansion}, "
                       f"use_static: {use_static}, use_hot: {use_hot}")
 
         rank_size = get_rank_size()
@@ -464,7 +454,11 @@
         all2all_matrix = None
         if not use_static:
-            all2all_matrix = tf.ones(shape=[rank_size, rank_size], dtype=tf.int64, name="all2all_matrix")
+            # In the case of multiple lookups of a table, the all2all_matrix does not run the 'getnext' op
+            # to obtain the actual value. Instead, the initial value is 1. So it needs to be multiplied by
+            # 'self.scalar_emb_size' to ensure the correctness of the 'Reshape' op in the get_own_emb function.
+            all2all_matrix = tf.ones(shape=[rank_size, rank_size],
                                      dtype=tf.int64, name="all2all_matrix") * self.scalar_emb_size
             all2all_matrix = tf.identity(all2all_matrix, name=ASCAnchorAttr.ALL2ALL_MATRIX.value)
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, all2all_matrix)
             SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.ALL2ALL_MATRIX] = all2all_matrix
@@ -472,7 +466,8 @@
         hot_pos = None
         if use_hot:
             import mxrec_pybind
-            hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / self.emb_size)
+            emb_size = self.scalar_emb_size if self.skip_emb_transfer else self.ext_emb_size
+            hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / emb_size)
             hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos")
             hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value)
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_HOT_POS, hot_pos)
@@ -497,6 +492,12 @@
             dest_shape = array_ops.concat([array_ops.shape(feat_ids), [self.scalar_emb_size]], 0)
             lookup_result = array_ops.reshape(embeddings, dest_shape)
 
+            # In the case of multiple lookups of a table, the lookup result node needs to be recorded and
+            # replaced during modify graph.
+            lookup_result = tf.identity(lookup_result, name=ASCAnchorAttr.LOOKUP_RESULT.value)
+            tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT, lookup_result)
+            SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.LOOKUP_RESULT] = lookup_result
+
             def grad(lookup_diff):
                 embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
                 logging.debug(f"bp rank size: {rank_size}")
@@ -567,20 +568,6 @@
                     same_table_tensor_list.append(tensor)
             return same_table_tensor_list
 
-        def set_feature_spec_attr(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]):
-            """
-            Set properties for a temporary feature_spec.
-            Args:
-                mock_feature_spec: A temporary feature_spec consisting of multiple feature_spec with the same table.
-                total_feature_count: Inner product of the shape of a tensor.
-            Returns: None
-            """
-            mock_feature_spec.batch_size = total_feature_count
-            mock_feature_spec.dims = [total_feature_count, 1]
-            mock_feature_spec.initialized = True
-            mock_feature_spec.pipeline_mode.add(True)
-            mock_feature_spec.pipeline_mode.add(False)
-
         same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name)
         mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name)
 
@@ -592,31 +579,48 @@
         tensor_list = get_tensor_list()
         tensor_split_list = [tf.math.reduce_prod(array_ops.shape(tensor)) for tensor in tensor_list]
         total_feature_count = tf.add_n(tensor_split_list)
-        set_feature_spec_attr(mock_feature_spec, total_feature_count)
+        set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count)
 
         kwargs["multi_lookup"] = True
         lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec,
                                                                     send_count * same_table_spec_count, **kwargs)
         logging.debug(f"lookup table {table_name} via {tensor_split_list}")
-
-        lookup_result_split = tf.split(lookup_result, tensor_split_list)
-        if len(lookup_result_split) != len(same_table_feature_spec) or (
-                not get_use_static() and len(same_table_feature_spec) != len(tensor_list)):
-            raise RuntimeError(f"shape not match. len(lookup_result_split): {len(lookup_result_split)},"
-                               f"len(same_table_feature_spec): {len(same_table_feature_spec)}"
-                               f"len(tensor_list): {len(tensor_list)}")
-        for idx, (one_feature_spec, one_result) in enumerate(zip(same_table_feature_spec, lookup_result_split)):
-            if one_feature_spec.name not in self.lookup_result:
-                self.lookup_result[one_feature_spec.name] = {}
-            if get_use_static():
-                dest_shape = one_feature_spec.dims + [self.scalar_emb_size]
-            else:
-                dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0)
-            self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape)
+        self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result,
+                                 is_training)
 
         self.check_multi_lookup_times()
         return self.lookup_result.get(spec_name).get(is_training)
 
+    def split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list,
+                            lookup_result: tf.Tensor, is_training: bool):
+        """
+        Splits the result of the merge sparse lookup.
+
+        Args:
+            same_table_feature_spec: a list of feature specs in the same table
+            tensor_split_list: a list of tensor split in the same table
+            tensor_list: a list of tensor in the same table
+            lookup_result: results of the sparse lookup
+            is_training: indicates whether the training mode is used.
+
+        Returns: None
+
+        """
+        lookup_result_split = tf.split(lookup_result, tensor_split_list)
+        if len(lookup_result_split) != len(same_table_feature_spec) or (
+                not get_use_static() and len(same_table_feature_spec) != len(tensor_list)):
+            raise RuntimeError(f"shapes do not match. len(lookup_result_split): {len(lookup_result_split)},"
+                               f"len(same_table_feature_spec): {len(same_table_feature_spec)}"
+                               f"len(tensor_list): {len(tensor_list)}")
+        for idx, (one_feature_spec, one_result) in enumerate(zip(same_table_feature_spec, lookup_result_split)):
+            if one_feature_spec.name not in self.lookup_result:
+                self.lookup_result[one_feature_spec.name] = {}
+            if get_use_static():
+                dest_shape = one_feature_spec.dims + [self.scalar_emb_size]
+            else:
+                dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0)
+            self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape)
+
     def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs):
         """
         Args:
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index a0a7eed6..16c64d9e 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -4,13 +4,15 @@
 
 import logging
 from collections import defaultdict
+from functools import reduce
 
 import tensorflow as tf
+from tensorflow.python.ops import array_ops
 from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
 
 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
 from mx_rec.core.asc.helper import get_asc_insert_func
-from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.core.asc.feature_spec import FeatureSpec, set_temporary_feature_spec_attribute
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
@@ -324,6 +326,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10):
                            cutting_point in sub_cutting_point_list]
             table_names = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name
                            for cutting_point in sub_cutting_point_list]
+
             tgt_dataset = tgt_dataset.map(
                 get_asc_insert_func(feature_numbers=feature_numbers, table_names=table_names,
                                     args_index_list=input_index_list, is_training=is_training, dump_graph=dump_graph))
@@ -341,22 +344,107 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10):
             new_get_next_op_name = find_target_dataset_op(one_tensor.op, "IteratorGetNext").name
             update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name)
 
-        for _, cutting_point in enumerate(sub_cutting_point_list):
-            feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
-            table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
-            channel_id = get_training_mode_channel_id(is_training)
-            config = dict(
-                batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt,
-                send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(),
-                table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer,
-                ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.scalar_emb_size,
-                use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion())
-            build_asc_graph(table_instance, cutting_point, config, is_training)
+        # multiple lookups of a same table
+        lookup_for_same_table(sub_cutting_point_list, is_training)
+        # replace the stub node for sparse lookup from the graph
+        replace_stub_node_with_asc_graph(sub_cutting_point_list, is_training)
 
     logging.info("Graph has been revised.")
     export_pb_graph("new_graph.pb", dump_graph)
 
 
+def lookup_for_same_table(sub_cutting_point_list: list, is_training: bool):
+    """
+    Merge multiple lookups of a sparse table into one lookup.
+
+    Args:
+        sub_cutting_point_list: the feature ids list passed in by sparse lookup
+        is_training: indicates whether the training mode is used
+
+    Returns: None
+
+    """
+    same_table_feature_spec_dict = {}
+    feature_spec_ids_dict = {}
+    for cutting_point in sub_cutting_point_list:
+        feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
+        table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
+        if len(table_instance.lookup_name_list) > 1:
+            if same_table_feature_spec_dict.get(table_instance.table_name) is None:
+                same_table_feature_spec_dict[table_instance.table_name] = []
+            same_table_feature_spec_dict[table_instance.table_name].append(feature_spec)
+            feature_spec_ids_dict[feature_spec.name] = cutting_point
+
+    for table_name, same_feature_spec_list in same_table_feature_spec_dict.items():
+        same_table_feature_spec = sorted(same_feature_spec_list, key=lambda x: x.name)
+        mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name)
+
+        tensor_split_list = []
+        tensor_list = []
+        table_instance = None
+        for one_feature_spec in same_table_feature_spec:
+            feature_ids = feature_spec_ids_dict.get(one_feature_spec.name)
+            if feature_ids is None:
+                raise RuntimeError(f"In the case of multiple lookups of a table, feature ids cannot be None.")
+            tensor_list.append(feature_ids)
+            table_instance = SparseEmbedding.get_anchor_attribute(feature_ids, ASCAnchorAttr.TABLE_INSTANCE)
+
+            # dynamic shape
+            if not get_use_static():
+                tensor_split_list.append(tf.math.reduce_prod(array_ops.shape(feature_ids)))
+                continue
+
+            # static shape
+            rank = feature_ids.shape.rank
+            if rank < 1:
+                raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {rank} now.")
+            dims = feature_ids.shape.as_list()
+            feat_cnt = 1 if rank == 1 else reduce(lambda x, y: x * y, dims[1:])
+            tensor_split_list.append(dims[0] * feat_cnt)
+
+        total_feature_count = sum(tensor_split_list) if get_use_static() else tf.add_n(tensor_split_list)
+        set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count)
+
+        kwargs = {"multi_lookup": True, "is_train": is_training}
+        if table_instance is None:
+            raise RuntimeError(f"In the case of multiple lookups of a table, table instance cannot be None.")
+        lookup_result = table_instance.lookup_for_asc_with_feature_spec_inner(mock_feature_spec,
+                                                                              table_instance.same_table_send_count,
+                                                                              **kwargs)
+        table_instance.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result,
+                                           is_training)
+        logging.info(f"Multiple lookups of a table for '{table_name}' have completed.")
+
+
+def replace_stub_node_with_asc_graph(sub_cutting_point_list: list, is_training: bool):
+    """
+    Replace the stub node for sparse lookup from the graph. e.g., id_offset, restore_vector, etc.
+
+    Args:
+        sub_cutting_point_list: the feature ids list passed in by sparse lookup
+        is_training: indicates whether the training mode is used
+
+    Returns: None
+
+    """
+    for _, cutting_point in enumerate(sub_cutting_point_list):
+        feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
+        table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
+        channel_id = get_training_mode_channel_id(is_training)
+        config = dict(
+            batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt,
+            send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(),
+            table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer,
+            ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.scalar_emb_size,
+            use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion())
+
+        lookup_result = None
+        if len(table_instance.lookup_name_list) > 1 and feature_spec.name in table_instance.lookup_result and \
+                is_training in table_instance.lookup_result.get(feature_spec.name):
+            lookup_result = table_instance.lookup_result.get(feature_spec.name).get(is_training)
+        build_asc_graph(config, table_instance, cutting_point, lookup_result)
+
+
 def get_timestamp_index(get_next_op, is_training):
     timestamp_tensor_list = tf.compat.v1.get_collection(ASCEND_TIMESTAMP)
     timestamp_index = None
@@ -378,30 +466,39 @@ def get_timestamp_index(get_next_op, is_training):
     return timestamp_index
 
 
-def build_asc_graph(table_instance, cutting_point, config, is_training):
+def build_asc_graph(config: dict, table_instance: SparseEmbedding, cutting_point: tf.Tensor, lookup_result: tf.Tensor):
+    """
+    Build the GetNext node in the graph and replace the stub node with this node.
+ + Args: + config: parameters required for GetNext + table_instance: sparse embedding table + cutting_point: the feature ids passed in by sparse lookup + lookup_result: results of the sparse lookup + + Returns: None + + """ # returned results swap_pos and swap_len were not in used, will be applied for DDR mode logging.debug(f"try to replace anchors for table {config.get('table_name')} on channel {config.get('channel_id')}") - skip_emb_transfer = config.get("skip_emb_transfer") - logging.info(f"modifier build_asc_graph skip_emb_transfer: {skip_emb_transfer}") - if len(table_instance.channel_name_list) > 1: - channel_name_queue = table_instance.channel_name_dict.get(is_training) - if len(channel_name_queue) < 1: - raise ValueError(f"The length of channel_name_queue must be greater than or equal to 1.") - ids_channel_name = channel_name_queue.pop() - config["send_count"] = table_instance.send_count_map.get(ids_channel_name) - elif len(table_instance.channel_name_list) == 1: - ids_channel_name = config.get('table_name') - else: - raise ValueError(f"The length of channel_name_list must be greater than or equal to 1.") + # In the case of multiple lookups of a table, replace the stub node of the lookup result in the graph + if len(table_instance.lookup_name_list) > 1: + if lookup_result is None: + raise RuntimeError(f"In the case of multiple lookups of a table, lookup result cannot be None.") + replace_anchor_vec(cutting_point, ASCAnchorAttr.LOOKUP_RESULT, lookup_result) + logging.info(f"The lookup result corresponding to feature ids '{cutting_point}' has been replaced by " + f"'{lookup_result}'.") + return + skip_emb_transfer = config.get("skip_emb_transfer") + logging.info(f"modifier build_asc_graph skip_emb_transfer: {skip_emb_transfer}") if skip_emb_transfer: - result = get_preprocessed_tensor_for_asc(table_instance.variable, config, ids_channel_name, - table_instance.modify_graph) + result = get_preprocessed_tensor_for_asc(table_instance.variable, config) else: variable_list = [table_instance.variable] \ + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list] - result = get_preprocessed_tensor_for_asc(variable_list, config, ids_channel_name, table_instance.modify_graph) + result = get_preprocessed_tensor_for_asc(variable_list, config) restore_vector = result.get("restore_vector") hot_pos = result.get("hot_pos") id_offsets = result.get("id_offsets") diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 7928f8f2..4181fa68 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -31,14 +31,9 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) } spdlog::info(MGMT + "end Set device, rank:{}", localRankId); for (const auto& embInfo: embInfos) { - vector names = {embInfo.name}; - if (embInfo.modifyGraph) { - names = embInfo.channelNames; - } - for (const string& name: names) { - for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { - CreateChannel(localRankId, name, i); - } + auto embName = embInfo.name; + for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { + CreateChannel(localRankId, embInfo.name, i); } } running = true; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 30fbf0cf..55c42554 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -472,36 +472,28 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", 
mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost getAllTensorTC; - vector names = {embInfo.name}; - if (embInfo.modifyGraph) { - names = embInfo.channelNames; + auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { + spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + return false; + } + switch (channelId) { + case TRAIN_CHANNEL_ID: + lookUpKeysQueueForTrain->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueueForTrain->Pushv(*infoVecs); + break; + case EVAL_CHANNEL_ID: + lookUpKeysQueueForEval->Pushv({ infoVecs->back() }); + infoVecs->pop_back(); + restoreQueueForEval->Pushv(*infoVecs); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } - spdlog::debug(MGMT + "GetLookupAndRestore embInfoName:{}, modifyGraph:{}, names:{}", - embInfo.name, embInfo.modifyGraph, names); - for (const string& name: names) { - auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); - return false; - } - switch (channelId) { - case TRAIN_CHANNEL_ID: - lookUpKeysQueueForTrain->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueueForTrain->Pushv(*infoVecs); - break; - case EVAL_CHANNEL_ID: - lookUpKeysQueueForEval->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueueForEval->Pushv(*infoVecs); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - if (!mgmtRankInfo.useStatic) { - GetAll2All(channelId, batchId, name); - } + if (!mgmtRankInfo.useStatic) { + GetAll2All(channelId, batchId, embInfo.name); } TIME_PRINT("getAllTensorTC(ms):{}", getAllTensorTC.ElapsedMS()); } @@ -509,85 +501,73 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) return true; } -void HybridMgmt::All2AllKeys(const int channelId, vector names) +void HybridMgmt::All2AllKeys(const int channelId, const string &embName) { TimeCost a2aKeysTC; - for (const string &name : names) { - vector all2allKeys; - switch (channelId) { - case TRAIN_CHANNEL_ID: - all2allKeys = a2aQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - all2allKeys = a2aQueueForEval->WaitAndPop(); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::ALL2ALL, all2allKeys, channelId, name); + vector all2allKeys; + switch (channelId) { + case TRAIN_CHANNEL_ID: + all2allKeys = a2aQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + all2allKeys = a2aQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } + hdTransfer->Send(TransferChannel::ALL2ALL, all2allKeys, channelId, embName); TIME_PRINT("All2AllKeysTC(ms):{}", a2aKeysTC.ElapsedMS()); } -void HybridMgmt::LookupKeys(const int channelId, vector names) +void HybridMgmt::LookupKeys(const int channelId, const string &embName) { TimeCost sendLookupTC; - for (const string& name: names) { - vector lookUpKeys; - switch (channelId) { - case TRAIN_CHANNEL_ID: - lookUpKeys = lookUpKeysQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - lookUpKeys = lookUpKeysQueueForEval->WaitAndPop(); - break; - default: - throw 
std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, name); + vector lookUpKeys; + switch (channelId) { + case TRAIN_CHANNEL_ID: + lookUpKeys = lookUpKeysQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + lookUpKeys = lookUpKeysQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } + hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, embName); TIME_PRINT("sendLookupTC(ms):{}", sendLookupTC.ElapsedMS()); } -void HybridMgmt::RestoreKeys(const int channelId, vector names) +void HybridMgmt::RestoreKeys(const int channelId, const string &embName) { TimeCost sendRestoreTC; - for (const string& name: names) { - vector restore; - switch (channelId) { - case TRAIN_CHANNEL_ID: - restore = restoreQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - restore = restoreQueueForEval->WaitAndPop(); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, name); + vector restore; + switch (channelId) { + case TRAIN_CHANNEL_ID: + restore = restoreQueueForTrain->WaitAndPop(); + break; + case EVAL_CHANNEL_ID: + restore = restoreQueueForEval->WaitAndPop(); + break; + default: + throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } + hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, embName); TIME_PRINT("sendRestoreTC(ms):{}", sendRestoreTC.ElapsedMS()); } bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) { for (const auto& embInfo: mgmtEmbInfo) { - vector names = {embInfo.name}; - if (embInfo.modifyGraph) { - names = embInfo.channelNames; - } - spdlog::debug(MGMT + "SendLookupAndRestore embInfoName:{}, modifyGraph:{}, names:{}", - embInfo.name, embInfo.modifyGraph, names); TimeCost sendTensorsTC; if (!mgmtRankInfo.useStatic) { - All2AllKeys(channelId, names); + All2AllKeys(channelId, embInfo.name); } spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", batchId, embInfo.name, channelId); - LookupKeys(channelId, names); - RestoreKeys(channelId, names); + LookupKeys(channelId, embInfo.name); + RestoreKeys(channelId, embInfo.name); TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); } batchId++; @@ -599,27 +579,19 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) spdlog::info(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost ParseKeysTC; - vector names = {embInfo.name}; - if (embInfo.modifyGraph) { - names = embInfo.channelNames; + auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { + spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + return false; } - spdlog::debug(MGMT + "ParseKeysHBM embInfoName:{}, modifyGraph:{}, names:{}", - embInfo.name, embInfo.modifyGraph, names); - for (const string& name: names) { - auto infoVecs = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - spdlog::info(MGMT + "ParseKeys infoVecs empty ! 
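A Python stand-in for the per-channel dispatch used throughout this file: one queue pair per channel, with the last processed tensor of a batch feeding the lookup path and the remainder feeding the restore path. Queue names and the batch layout are assumptions; the C++ code uses its own blocking queues.

```python
import queue

TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID = 0, 1

lookup_queues = {TRAIN_CHANNEL_ID: queue.Queue(), EVAL_CHANNEL_ID: queue.Queue()}
restore_queues = {TRAIN_CHANNEL_ID: queue.Queue(), EVAL_CHANNEL_ID: queue.Queue()}

def route_info_vecs(info_vecs, channel_id):
    # The last processed tensor feeds the lookup path; the remainder
    # feeds the restore path -- exactly one queue pair per channel.
    if channel_id not in (TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID):
        raise ValueError("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]")
    *restore, lookup_keys = info_vecs
    lookup_queues[channel_id].put(lookup_keys)
    restore_queues[channel_id].put(restore)

route_info_vecs(["restore_a", "restore_b", "lookup_keys"], TRAIN_CHANNEL_ID)
```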
batchId:{}, channelId:{}", batchId, channelId); - return false; - } - if (!mgmtRankInfo.useStatic) { - auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); - hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, name); - } - hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, name); - infoVecs->pop_back(); - hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, name); + if (!mgmtRankInfo.useStatic) { + auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); } + hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); + infoVecs->pop_back(); + hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); TIME_PRINT("ParseKeysTC HBM mode (ms):{}", ParseKeysTC.ElapsedMS()); } batchId++; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index ff27c885..af9016ec 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -138,9 +138,9 @@ namespace MxRec { bool TrainTask(TaskType type); bool EvalTask(TaskType type); - void All2AllKeys(const int channelId, vector names); - void LookupKeys(const int channelId, vector names); - void RestoreKeys(const int channelId, vector names); + void All2AllKeys(const int channelId, const string &embName); + void LookupKeys(const int channelId, const string &embName); + void RestoreKeys(const int channelId, const string &embName); bool GetLookupAndRestore(const int channelId, int &batchId); void GetAll2All(const int channelId, int &batchId, const string &name); bool SendLookupAndRestore(const int channelId, int &batchId); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a6263c58..fd7377af 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -56,7 +56,6 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos map scInfo; for (const auto& info: eInfos) { - spdlog::debug(KEY_PROCESS "Init sendCountMap:{}, channelNames:{}", info.sendCountMap, info.channelNames); embInfos[info.name] = info; scInfo[info.name] = info.sendCount; if (rankInfo.useHot) { @@ -144,15 +143,6 @@ void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) HOT_EMB_CACHE_PCT / static_cast(embeddingSize)); } -auto KeyProcess::GetSendCount(const string& name, const string& channelName, bool modifyGraph) -{ - auto sendCountSize = embInfos[name].sendCount; - if (modifyGraph) { - sendCountSize = embInfos[name].sendCountMap[channelName]; - } - return sendCountSize; -} - auto KeyProcess::GetMaxOffset() -> offset_mem_t { return maxOffset; @@ -239,7 +229,7 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, if (!uniqueInitialize) { if (rankInfo.useStatic) { - uniqueConf.paddingSize = GetSendCount(batch->name, batch->channelName, batch->modifyGraph); + uniqueConf.paddingSize = embInfos[batch->name].sendCount; } uniqueConf.maxIdVal = INT64_MAX; @@ -348,11 +338,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count - auto embName = batch->name; - if (batch->modifyGraph) { - embName = batch->channelName; - } - SendA2A(uniqueInfo.all2AllInfo.scAll, embName, batch->channel, batch->batchId); + 
SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); } auto tensors = make_unique>(); @@ -409,11 +395,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe } if (!rankInfo.useStatic) { // Static all2all,need send count - auto embName = batch->name; - if (batch->modifyGraph) { - embName = batch->channelName; - } - SendA2A(scAll, embName, batch->channel, batch->batchId); + SendA2A(scAll, batch->name, batch->channel, batch->batchId); } TimeCost pushResultTC; @@ -441,7 +423,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, TimeCost getCountRecvTC; if (rankInfo.useStatic) { for (auto& cnt: keyCount) { - cnt.resize(GetSendCount(batch->name, batch->channelName, batch->modifyGraph), 0); + cnt.resize(embInfos[batch->name].sendCount, 0); } } vector countSend; @@ -470,16 +452,9 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); - if (batch->modifyGraph) { - infoList[batch->channelName][batch->channel].push( - make_tuple(batch->batchId, batch->channelName, storage.begin())); - } else { - infoList[batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, storage.begin())); - } + infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); if (!rankInfo.noDDR) { - lookupKeysList[batch->name][batch->channel].push( - make_tuple(batch->batchId, batch->name, move(lookupKeys))); + lookupKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(lookupKeys))); } lockGuard.unlock(); } @@ -537,9 +512,6 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) size_t KeyProcess::GetKeySize(const unique_ptr &batch) { size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; - if (batch->modifyGraph) { - size = rankInfo.rankSize * embInfos[batch->name].sendCountMap[batch->channelName]; - } if (!rankInfo.useStatic) { size = batch->Size(); } @@ -589,9 +561,8 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); spdlog::debug(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{}, channel:{}, " - "channelName:{}, name:{}, restore:{}, keyCount:{}", batch->batchId, batch->Size(), - batch->channel, batch->channelName, batch->name, uniqueInfoOut.restore.size(), - keySendInfo.keyCount.size()); + "name:{}, restore:{}, keyCount:{}", batch->batchId, batch->Size(), + batch->channel, batch->name, uniqueInfoOut.restore.size(), keySendInfo.keyCount.size()); } void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, @@ -614,7 +585,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni } if (rankInfo.useStatic) { - sc.resize(rankInfo.rankSize, GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); + sc.resize(rankInfo.rankSize, embInfos[batch->name].sendCount); } else { sc.resize(rankInfo.rankSize); for (int i = 0;i < rankInfo.rankSize; i++) { @@ -686,13 +657,13 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 if (rankInfo.useStatic) { // maybe move after all2all for (auto& i: splitKeys) { - if (static_cast(i.size()) > GetSendCount(batch->name, batch->channelName, batch->modifyGraph)) { + if (static_cast(i.size()) > embInfos[batch->name].sendCount) { spdlog::error("{}[{}]:{} overflow! 
set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()); throw runtime_error(fmt::format("{}[{}]:{} overflow! set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()).c_str()); } - i.resize(GetSendCount(batch->name, batch->channelName, batch->modifyGraph), -1); + i.resize(embInfos[batch->name].sendCount, -1); } } keys_t keySend; @@ -887,7 +858,7 @@ void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& h vector splitKeysSize {}; if (rankInfo.useStatic) { for (size_t i = 0; i < splitKeys.size(); i++) { - splitKeysSize.push_back(GetSendCount(batch->name, batch->channelName, batch->modifyGraph)); + splitKeysSize.push_back(embInfos[batch->name].sendCount); } } else { for (auto& splitKey: splitKeys) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 3f71b1bc..5e95c207 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -121,8 +121,6 @@ namespace MxRec { void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); - auto GetSendCount(const string& name, const string& channelName, bool modifyGraph); - void KeyProcessTask(int channel, int id); bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int id); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index cf94660e..cb292782 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -160,12 +160,10 @@ namespace MxRec { std::vector sample; void *tensorAddr = nullptr; std::string name; - std::string channelName; size_t batchSize; int batchId; int channel = 0; bool isInt64; // true int64 false int32 - bool modifyGraph; time_t timestamp { -1 }; }; @@ -367,15 +365,11 @@ struct BatchTask { int sendCount, int embeddingSize, int extEmbeddingSize, - bool modifyGraph, bool isSave, - std::vector channelNames, std::vector vocabsize, - std::vector initializeInfos, - std::map sendCountMap) + std::vector initializeInfos) : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), - modifyGraph(modifyGraph), isSave(isSave), channelNames(channelNames), initializeInfos(initializeInfos), - sendCountMap(sendCountMap) + isSave(isSave), initializeInfos(initializeInfos) { devVocabSize = vocabsize[0]; hostVocabSize = vocabsize[1]; @@ -385,13 +379,10 @@ struct BatchTask { int sendCount; int embeddingSize; int extEmbeddingSize; - bool modifyGraph; bool isSave; size_t devVocabSize; size_t hostVocabSize; - std::vector channelNames; std::vector initializeInfos; - std::map sendCountMap; }; struct HostEmbTable { diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 072f554a..ffac87d3 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -122,8 +122,6 @@ REGISTER_OP("ReadEmbKeyV2Dynamic") .Attr("channel_id: int") .Attr("emb_name: list(string)") // for which table to lookup .Attr("timestamp: bool") // use for feature evict, (unix timestamp) - .Attr("channel_name: list(string)") // use for multi lookup - .Attr("modify_graph: bool") // auto modify graph enabled .SetShapeFn([](InferenceContextPtr c) { c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); @@ -145,8 +143,6 @@ public: OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); - OP_REQUIRES_OK(context, 
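A sketch of the fixed-send-count padding that the static-all2all branch of `ProcessSplitKeys` performs: every per-rank key bucket leaves with exactly `sendCount` entries, padded with -1, and a bucket larger than `sendCount` is a configuration error. The bucket layout and padding key mirror the C++ code; everything else is illustrative.

```python
def pad_split_keys(split_keys, send_count, pad_key=-1):
    # Mirror of the useStatic branch: overflow is fatal, shortfall is padded.
    padded = []
    for rank, bucket in enumerate(split_keys):
        if len(bucket) > send_count:
            raise RuntimeError(
                f"rank {rank} overflow! set send count bigger than {len(bucket)}")
        padded.append(bucket + [pad_key] * (send_count - len(bucket)))
    return padded

# Three ranks, send_count=4: every bucket is exactly 4 entries long.
print(pad_split_keys([[11, 12], [13], []], send_count=4))
# [[11, 12, -1, -1], [13, -1, -1, -1], [-1, -1, -1, -1]]
```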
context->GetAttr("channel_name", &channelNames)); - OP_REQUIRES_OK(context, context->GetAttr("modify_graph", &modifyGraph)); // 特征准入&淘汰功能 相关校验 @@ -263,10 +259,6 @@ public: } auto batchData = queue->GetOne(); // get dirty or empty data block batchData->name = embNames.at(i); - if (modifyGraph) { - batchData->modifyGraph = modifyGraph; - batchData->channelName = channelNames.at(i); - } size_t len = splits(i); batchData->channel = channelId; batchData->batchId = ids[0]; @@ -326,10 +318,8 @@ public: int channelId {}; vector embNames {}; vector tableUsed{}; - vector channelNames {}; int maxStep = 0; bool isTimestamp { false }; - bool modifyGraph { false }; int threadNum = 0; }; @@ -344,8 +334,6 @@ REGISTER_OP("ReadEmbKeyV2") .Attr("splits: list(int)") .Attr("emb_name: list(string)") // for which table to lookup .Attr("timestamp: bool") // use for feature evict, (unix timestamp) - .Attr("channel_name: list(string)") // use for multi lookup - .Attr("modify_graph: bool") // auto modify graph enabled .SetShapeFn([](InferenceContextPtr c) { c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); @@ -360,7 +348,7 @@ public: logger = spdlog::stderr_color_mt("console"); } spdlog::set_default_logger(spdlog::get("console")); - + spdlog::cfg::load_env_levels(); spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v"); spdlog::debug("ReadEmbKeyV2 init"); @@ -368,8 +356,6 @@ public: OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // 每个表的field Number OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); - OP_REQUIRES_OK(context, context->GetAttr("channel_name", &channelNames)); - OP_REQUIRES_OK(context, context->GetAttr("modify_graph", &modifyGraph)); fieldNum = accumulate(splits.begin(), splits.end(), 0); // 特征准入&淘汰功能 相关校验 @@ -490,10 +476,6 @@ public: } auto batchData = queue->GetOne(); // get dirty or empty data block batchData->name = embNames.at(i); - if (modifyGraph) { - batchData->modifyGraph = modifyGraph; - batchData->channelName = channelNames.at(i); - } size_t len = splits.at(i); batchData->channel = channelId; batchData->batchId = batchId; @@ -554,10 +536,8 @@ public: vector tableUsed{}; int fieldNum {}; vector embNames {}; - vector channelNames {}; int maxStep = 0; bool isTimestamp { false }; - bool modifyGraph { false }; int threadNum = KEY_PROCESS_THREAD; }; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 9e1d6461..56c99b7d 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -93,23 +93,18 @@ void GetRankInfo(pybind11::module_& m) void GetEmbInfo(pybind11::module_& m) { pybind11::class_(m, "EmbInfo") - .def(pybind11::init, std::vector, - std::vector&, std::map>(), - py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), - py::arg("ext_embedding_size"), py::arg("modify_graph"), - py::arg("is_save"), py::arg("channel_name_list"), - py::arg("vocab_size"), py::arg("initialize_infos"), py::arg("send_count_map")) + .def(pybind11::init, + std::vector&>(), + py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("ext_embedding_size"), + py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos")) .def_readwrite("name", &EmbInfo::name) .def_readwrite("send_count", &EmbInfo::sendCount) .def_readwrite("embedding_size", &EmbInfo::embeddingSize) .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize) - .def_readwrite("modify_graph", &EmbInfo::modifyGraph) .def_readwrite("is_save", 
&EmbInfo::isSave) - .def_readwrite("channel_name_list", &EmbInfo::channelNames) .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) - .def_readwrite("initialize_infos", &EmbInfo::initializeInfos) - .def_readwrite("send_count_map", &EmbInfo::sendCountMap); + .def_readwrite("initialize_infos", &EmbInfo::initializeInfos); } void GetRandomInfo(pybind11::module_& m) diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 13a29e5e..d8208531 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -38,13 +38,10 @@ protected: int sendCount = 5; int embeddingSize = 8; int extEmbeddingSize = 24; - bool modifyGraph = false; bool isSave = true; size_t devVocabSize = 5; size_t hostVocabSize = 15; vector randomInfos; - vector channelNames = {"model_1", "model_2"}; - map sendCountMap = {{"model_1", 500}, {"model_2", 500}}; RandomInfo randomInfo; int start = 0; int len = hostVocabSize * embeddingSize; @@ -117,8 +114,7 @@ protected: TEST_F(EmbMgmtTest, Initialize) { vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, - initializeInfos, sendCountMap); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues = {}; @@ -177,8 +173,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, - initializeInfos, sendCountMap); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); @@ -198,8 +193,7 @@ TEST_F(EmbMgmtTest, Evict) size_t devVocabSize = DDR_DEVICE_SIZE; size_t hostVocabSize = DDR_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, - initializeInfos, sendCountMap); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); @@ -222,8 +216,7 @@ TEST_F(EmbMgmtTest, Evict_HBM) devVocabSize = HBM_DEVICE_SIZE; hostVocabSize = HBM_HOST_SIZE; vector vocabsize = { devVocabSize, hostVocabSize }; - embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, modifyGraph, isSave, channelNames, vocabsize, - initializeInfos, sendCountMap); + embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; thresholdValues.emplace_back(name, 1, 1); diff --git a/tools/mx_rec_perf.sh b/tools/mx_rec_perf.sh index 6fadb5e8..fe1ee706 100644 --- a/tools/mx_rec_perf.sh +++ b/tools/mx_rec_perf.sh @@ -1,6 +1,7 @@ #!/bin/bash # Copyright (c) Huawei Technologies Co., Ltd. 
# Description MxRec性能分析脚本 V1.0 +set -e file="$1" #请输入spdlog文件 @@ -9,41 +10,62 @@ calculate_average() { sum += $1; count++ } END { - avg = sum / count; - print avg + average = sum / count; + print average }' } -echo " =========MxRec性能分析脚本 V1.0========= " -echo "read batch cost" -cat ${file} | grep 'read batch cost'|tail -n 20| awk 'NR%2==1' -echo "====================================" -echo "key process cost" -cat ${file} | grep 'key process cost'|tail -avg=`cat ${file} | grep -Po '(?<=key process cost:)[^,:]+(?=ms)'|tail -n +20 |calculate_average` -real_avg=$(echo "$avg" | awk '{ printf "%.2f", $0/6 }') -echo "Average: $real_avg(single thread avg is $avg)" -echo "====================================" -echo "分析host和device流水,当 host key process 提前训练step时,host性能不为瓶颈" -echo "按输入训练step打印标志,(默认为step) Enter打开分析,按q退出" -read step -step="${step:-step}" -cat ${file} | grep -P "key process cost|${step}"|tail -n100|less - -exit - -# getnext超时问题定位 -echo -n "超时通道为:" -cat ${file} | grep -Po "aicpu_getnext.*GetNext" -echo - "检查是否发送, 发送数量为注意8卡 " -cat ${file} | grep -P "send" |grep all2all |wc -l -cat ${file} | grep -P "send"|grep h2d|wc -l - -echo -n "检查数据读取, 读取batch数量为 " -cat ${file} | grep 'read batch cost'|wc -l -cat ${file} | grep 'read batch cost'|tail +perf() { + echo "read batch cost" + cat ${file} | grep 'read batch cost'|grep -v timeout|tail -n 20| awk 'NR%2==1' + echo "====================================" + echo "key process cost" + cat ${file} | grep 'key process cost'|tail + avg=$(cat ${file} | grep -Po '(?<=key process cost:)[^,:]+(?=,)'|tail -n +20 |calculate_average) + echo "Average: $avg" + echo "====================================" + echo "分析host和device流水,当 host key process 提前训练step时,host性能不为瓶颈" + echo "按输入训练step打印标志,(默认为step) Enter打开分析,按q退出" + read step + step="${step:-step}" + cat ${file} | grep -P "key process cost|${step}"|tail -n100|less +} +echo -e "\e[45m\e[1m =========MxRec分析脚本 V1.0========= \e[0m" +echo +stuck_check() { + echo -e "\e[106m--------卡住、getnext超时问题定位----------\e[0m" + echo -n "超时通道为:" + cat ${file} | grep -Po "aicpu_getnext.*GetNext" + echo + echo "检查每张卡发送lookup数量:" + for i in {0..7} + do + line=$(cat ${file} | grep -P "send"|grep "h2d"|grep "1,${i}"|wc -l) + echo -n "$line " + done + echo + echo "检查每张卡发送h2d数量是否相同:" + for i in {0..7} + do + line=$(cat ${file} | grep "send"|grep "h2d"|grep "1,${i}"|wc -l) + echo -n "$line " + done + echo + echo "检查每张卡接收数量是否相同:" + for i in {0..7} + do + line=$(cat ${file} | grep "r recv"|grep "1,${i}"|wc -l) + echo -n "$line " + done + echo + echo "每张卡最后接收batch为:" + cat ${file}|grep "trans emb"|grep "info"|tail +} +hot_check() { + # 查看hot emb去重率 + echo "表名及去重率(去重后/去重前)为:(应该要小于0.4)" + cat op_summary_*.csv |grep gather_for_restore_vector |awk -F "," '{print $6,$14,$15}'|sed 's/"//g'|sed 's/ [0-9]*;/\//' +} -# 查看hot emb去重率 -echo "表名及去重率(去重后/去重前)为:(应该要小于0.4)" -cat op_summary_*.csv |grep gather_for_restore_vector |awk -F "," '{print $6,$14,$15}'|sed 's/"//g'|sed 's/ [0-9]*;/\//' +perf diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh index 20aec3df..474634ff 100755 --- a/tools/perf/fast.sh +++ b/tools/perf/fast.sh @@ -1,4 +1,4 @@ -#! /bin/bash +#!/bin/bash # Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. 
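The awk one-liners in these perf scripts all follow one pattern: collect `tag(ms):value` entries, drop outliers past a cutoff, and average the rest. A Python equivalent of that filter-then-average step, for post-processing logs offline; the function name and regex are illustrative.

```python
import re

def filtered_average(log_path, tag, cutoff_ms):
    # Collect 'tag(ms):<value>' entries and average them, dropping
    # values at or above cutoff_ms, like the awk filters in these scripts.
    pattern = re.compile(re.escape(tag) + r"\(ms\):([0-9.]+)")
    values = []
    with open(log_path) as log:
        for line in log:
            match = pattern.search(line)
            if match and float(match.group(1)) < cutoff_ms:
                values.append(float(match.group(1)))
    return sum(values) / len(values) if values else 0.0

# e.g. filtered_average("train.log", "sendLookupSyncTC", cutoff_ms=200)
```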
# Description: performace analysis tool # Author: MindX SDK -- Gitee From b39efbf4b427ec204fd18d480504452f00c7ff8c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 10:08:56 +0800 Subject: [PATCH 158/551] Match-id-2f35d7d29ee02c0b044fd9ab2351670e64b09ddf --- src/core/emb_table/emb_table.cpp | 2 +- src/core/host_emb/host_emb.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 4422abb0..7973e528 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -142,7 +142,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali } default: { spdlog::warn("Device Invalid Initializer Type. Using default Constant Initializer with value 0."); - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0); + ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } } diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index f62c9a67..e0d40fa3 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -69,7 +69,7 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in } default: { spdlog::warn(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0); + ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } } -- Gitee From 2dd3d5d71f6d7e609c7edaa0e368a138418e9a3f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 10:11:25 +0800 Subject: [PATCH 159/551] Match-id-2b0c94c01a664c4f0c6fc3bfae7431e8f2eca7a2 --- src/core/host_emb/host_emb.cpp | 2 +- src/tests/initializer/initializer_test.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index e0d40fa3..61db3ef8 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -235,7 +235,7 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve } default: { spdlog::error(HOSTEMB + "Invalid Initializer Type. 
Using default Constant Initializer with value 0."); - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0); + ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } } diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp index a1e59688..678ca84a 100644 --- a/src/tests/initializer/initializer_test.cpp +++ b/src/tests/initializer/initializer_test.cpp @@ -20,7 +20,7 @@ using namespace MxRec; TEST(InitializerTest, ConstantInitializerTest) { - ConstantInitializer constant_initializer; // start; end; constant_val; + ConstantInitializer constant_initializer; // start; end; constant_val; initK; constant_initializer = ConstantInitializer(1, 5, 7, 0.1); -- Gitee From 4da3c9400cd1828bf68bba2d33a883e5df3088df Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 11:30:13 +0800 Subject: [PATCH 160/551] Match-id-5a777904e74237f6ec3a26752508cfb16d778564 --- mx_rec/core/asc/manager.py | 64 +++++++++++++++++----------- src/tests/emb_mgmt/emb_mgmt_test.cpp | 2 +- 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index e2138e4d..93ba87f4 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -12,6 +12,7 @@ from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, se export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table from mx_rec.core.asc.helper import find_dangling_table, should_skip +from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo def check_dangling_table(): @@ -71,8 +72,41 @@ def generate_table_info_list(): return table_info_list +def matched_constant_initializer(tabel_info): + init_param = tabel_info.init_param + logging.debug(f"constant_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + return InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, + constant_initializer_info=ConstantInitializerInfo( + constant_val=tabel_info.emb_initializer.value, initK=init_param)) + + +def matched_random_normal_initializer(tabel_info): + random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed + init_param = tabel_info.init_param + logging.debug(f"random_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + return InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed, + initK=init_param + )) + + +def matched_truncated_normal_initializer(tabel_info): + random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed + init_param = tabel_info.init_param + logging.debug(f"truncated_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + return InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed, + initK=init_param + )) + + def matched_emb_initializer(tabel_info): - from mxrec_pybind import InitializeInfo, 
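From the test change in this series (`initK` raised from `0.1` to `1`, expected value from `0.7` to `7`), `initK` reads as a multiplier applied to the constant value. A Python model of that behavior; the start/len interpretation is assumed.

```python
def constant_fill(start, length, constant_val, init_k):
    # Assumed semantics: every slot in [start, start + length) receives
    # constant_val * init_k, matching the updated test (7 with init_k=1,
    # where the old init_k=0.1 call produced 0.7).
    return [constant_val * init_k] * length

row = [0.0] * 8
row[1:6] = constant_fill(start=1, length=5, constant_val=7, init_k=1)
print(row)  # [0.0, 7, 7, 7, 7, 7, 0.0, 0.0]
```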
ConstantInitializerInfo, NormalInitializerInfo initializer_case_map = {"tf1/tf2_constant_initializer": isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or isinstance(tabel_info.emb_initializer, tf.constant_initializer), @@ -88,34 +122,12 @@ def matched_emb_initializer(tabel_info): isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), } if initializer_case_map.get("tf1/tf2_constant_initializer"): - init_param = tabel_info.init_param - logging.debug(f"constant_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") - initializer = InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, - constant_initializer_info=ConstantInitializerInfo( - constant_val=tabel_info.emb_initializer.value, initK=init_param)) + initializer = matched_constant_initializer(tabel_info) elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): - random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed - init_param = tabel_info.init_param - logging.debug(f"random_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") - initializer = InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, - normal_initializer_info=NormalInitializerInfo( - mean=tabel_info.emb_initializer.mean, - stddev=tabel_info.emb_initializer.stddev, - seed=random_seed, - initK=init_param - )) + initializer = matched_random_normal_initializer(tabel_info) elif initializer_case_map.get("tf1_truncated_normal_initializer") or \ initializer_case_map.get("tf2_truncated_normal_initializer"): - random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed - init_param = tabel_info.init_param - logging.debug(f"truncated_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") - initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, - normal_initializer_info=NormalInitializerInfo( - mean=tabel_info.emb_initializer.mean, - stddev=tabel_info.emb_initializer.stddev, - seed=random_seed, - initK=init_param - )) + initializer = matched_truncated_normal_initializer(tabel_info) else: initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, normal_initializer_info=NormalInitializerInfo( diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index d8208531..488a853f 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -97,7 +97,7 @@ protected: void SetUp() { // init key_process (RankInfo rankInfo, const vector &embInfos) - constantInitializerInfo = ConstantInitializerInfo(constantVal); + constantInitializerInfo = ConstantInitializerInfo(constantVal, 1); initializeInfo = InitializeInfo(constantInitializerName, start, embeddingSize, constantInitializerInfo); initializeInfos.push_back(initializeInfo); -- Gitee From 6571fb5f3318ef89eaced3cb733a876c6c8540f0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 11:55:24 +0800 Subject: [PATCH 161/551] Match-id-65041dd5ae5b609f29114cbe23587c00baa2e640 --- mx_rec/core/asc/manager.py | 6 ++---- mx_rec/core/embedding.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 93ba87f4..6d6f3a79 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -3,16 +3,16 @@ # Copyright (c) Huawei Technologies Co., Ltd. 
2022-2023. All rights reserved.
 
 import logging
-
 import tensorflow as tf
 
+from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo
+
 from mx_rec.constants.constants import MxRecMode
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
     is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \
     export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \
     get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table
 from mx_rec.core.asc.helper import find_dangling_table, should_skip
-from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo
 
 
 def check_dangling_table():
@@ -139,8 +139,6 @@ def matched_emb_initializer(tabel_info):
 
 
 def matched_opt_slot_initializers(table_instance):
-    from mxrec_pybind import InitializeInfo, ConstantInitializerInfo
-
     start_index = table_instance.scalar_emb_size
     slot_initializers = []
     logging.debug(f"matched_opt_slot_initializers, scalar emb size:{table_instance.ext_emb_size}, "
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index fef8be42..2a5e1b5f 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -340,7 +340,7 @@ class SparseEmbedding:
             self.lookup_info.add(is_training)
 
         if not isinstance(self.init_param, int):
-            raise ValueError("Arg is_train should be a integer.")
+            raise ValueError("Arg init_param should be an integer.")
 
         if get_use_static():
             if isinstance(send_count, int) and send_count > 0:
-- 
Gitee

From f81dd523520641795ddbf50b12960e211b9554ab Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sun, 25 Jun 2023 11:57:54 +0800
Subject: [PATCH 162/551] Match-id-7ae5e6fe673f69897be6252e179eae468f1bbc3a

---
 mx_rec/core/embedding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 2a5e1b5f..d74e43c1 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -339,8 +339,8 @@ class SparseEmbedding:
         if is_training not in self.lookup_info:
             self.lookup_info.add(is_training)
 
-        if not isinstance(self.init_param, int):
-            raise ValueError("Arg init_param should be an integer.")
+        if not isinstance(self.init_param, float):
+            raise ValueError("Arg init_param should be a float.")
 
         if get_use_static():
             if isinstance(send_count, int) and send_count > 0:
-- 
Gitee

From 050f3ff2c8fe8658af0b293c30b7ed266a04c415 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sun, 25 Jun 2023 16:40:06 +0800
Subject: [PATCH 163/551] Match-id-693fc4b325bafc8cb8d7c87415c3ef0ff00987c5

---
 src/tests/initializer/initializer_test.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index 678ca84a..db2c7aee 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -22,7 +22,7 @@ TEST(InitializerTest, ConstantInitializerTest)
 {
     ConstantInitializer constant_initializer; // start; end; constant_val; initK;
 
-    constant_initializer = ConstantInitializer(1, 5, 7, 0.1);
+    constant_initializer = ConstantInitializer(1, 5, 7, 1);
 
     vector<vector<float>> embData;
     int vocabSize = 5;
@@ -40,7 +40,7 @@ TEST(InitializerTest, ConstantInitializerTest)
         std::cout << std::endl;
     }
 
-    ASSERT_EQ(embData.at(2).at(2), 0.7);
+    ASSERT_EQ(embData.at(2).at(2), 7);
     ASSERT_EQ(embData.at(2).at(0), 0);
 }
 
-- 
Gitee

From a64ddeb6bc9f6719fd296dd5e35d0e81e56be0cb 
Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 26 Jun 2023 09:38:30 +0800 Subject: [PATCH 164/551] Match-id-e4a8bc1359b97c101f54d422cbb87c2e9c9656ce --- mx_rec/util/initialize.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 5d6addea..cf07b256 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -243,7 +243,7 @@ class ConfigInitializer: rank_start = int(split_devices[0]) device_list = list(range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)) elif "," in ascend_visible_devices: - device_list = list(map(int, ascend_visible_devices.strip().split(","))).sort() + device_list = list(map(int, ascend_visible_devices.strip().split(","))) elif ascend_visible_devices in VALID_DEVICE_ID_LIST: device_list = [int(ascend_visible_devices.strip())] else: @@ -260,23 +260,25 @@ class ConfigInitializer: chief_device = os.getenv("CM_CHIEF_DEVICE") rank_size = os.getenv("CM_WORKER_SIZE") - if int(rank_size) != len(device_list): - raise ValueError(f"Rank size {rank_size} is different from device num {len(device_list)}.") + sorted_device_list = sorted(device_list) + if int(rank_size) != len(sorted_device_list): + raise ValueError(f"Rank size {rank_size} is different from device num {len(sorted_device_list)}.") try: self._rank_to_device_dict[0] = int(chief_device) except ValueError as err: raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err try: - device_list.pop(int(chief_device)) + sorted_device_list.pop(int(chief_device)) except IndexError as err: raise IndexError( - f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {device_list}.") from err + f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \ + from err - for device_idx in device_list: + for device_idx in sorted_device_list: device_id = mxrec_pybind.get_logic_id(int(device_idx)) if device_id > 16: raise ValueError(f"get logic id from physic id fail.") - index = device_list.index(device_idx) + index = sorted_device_list.index(device_idx) self._rank_to_device_dict[index + 1] = device_id def insert_training_mode_channel_id(self, is_training): -- Gitee From 4649b7538c94cc0ca07ccc1e388d9379a54ed694 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 26 Jun 2023 14:12:36 +0800 Subject: [PATCH 165/551] Match-id-c32efebb8099bfea32fc956d2c1b7bfdb70a2b5c --- src/core/key_process/key_process.cpp | 11 ++++++----- src/core/key_process/key_process.h | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index fd7377af..07c3b11b 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -525,6 +525,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch EASY_VALUE("batchId", batch->batchId) EASY_BLOCK("ock-unique") + TimeCost uniqueTC; KeySendInfo keySendInfo; size_t size = GetKeySize(batch); @@ -550,7 +551,6 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch uniqueOut.uniqueIdInBucket = reinterpret_cast(uniqueVector.data()); uniqueOut.uniqueIdCnt = 0; - TimeCost uniqueTC; int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut); EASY_END_BLOCK TIME_PRINT("FastUniqueCompute(ms):{}, ret:{}", uniqueTC.ElapsedMS(), ret); @@ -594,14 +594,15 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni } } -void KeyProcess::ComputeHotPos(const 
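The device-list fix in patch 164 above also removes a classic pitfall: `list(...).sort()` returns `None`, so the sorted result was silently discarded. A self-contained sketch of the range/comma parsing; the `VALID_DEVICE_ID_LIST` contents are assumed, and sorting is done inline here where the patched code defers it to `sorted_device_list`.

```python
VALID_DEVICE_ID_LIST = [str(i) for i in range(16)]  # assumed contents

def parse_visible_devices(spec: str):
    # '0-7' style ranges and '3,1,2' style lists, as in the patched code.
    spec = spec.strip()
    if "-" in spec:
        start = int(spec.split("-")[0])
        end = int(spec.split("-")[-1])
        return list(range(start, end + 1))
    if "," in spec:
        # The bug being fixed: list(map(int, ...)).sort() returns None,
        # replacing the device list. Keep the list, sort separately.
        device_list = list(map(int, spec.split(",")))
        return sorted(device_list)
    if spec in VALID_DEVICE_ID_LIST:
        return [int(spec)]
    raise ValueError(f"invalid device spec: {spec}")

print(parse_visible_devices("3,1,2"))  # [1, 2, 3], not None
```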
unique_ptr &batch, map &hotMap, +void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, vector &hotPos, vector &restore, const int hotOffset) { - auto inputData = batch->sample.data(); + auto* inputData = batch->sample.data(); + size_t miniBs = batch->Size(); int hotCount = 0; - for (size_t i = 0;i < batch->Size(); ++i) { - auto key = inputData[i]; + for (size_t i = 0;i < miniBs; i++) { + const emb_key_t& key = inputData[i]; auto hot = hotMap.find(key); if (hot != hotMap.end()) { if (hot->second == -1) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 5e95c207..384365be 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -111,7 +111,7 @@ namespace MxRec { map> keyOffsetMap {}; FeatureAdmitAndEvict m_featureAdmitAndEvict {}; map> evictPosMap {}; - map> hotKey {}; + map> hotKey {}; map hotEmbTotCount; map embeddingTableMap {}; @@ -182,7 +182,7 @@ namespace MxRec { void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, const unique_ptr& batch); - void ComputeHotPos(const unique_ptr &batch, map &hotMap, + void ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, vector &hotPos, vector &restore, const int hotOffset); vector GetCountRecv(const unique_ptr& batch, int id, -- Gitee From 29786fe1ba39d8e46963b4af32588c13fe9e5592 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 16:04:09 +0800 Subject: [PATCH 166/551] Match-id-6e23104be69b03fdb7dc0e38ade4a7bcdb095ae1 --- mx_rec/saver/patch.py | 27 ++++++++++++++++----------- src/core/checkpoint/checkpoint.cpp | 7 +++++-- src/core/checkpoint/checkpoint.h | 2 +- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 58469617..4fa7b168 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -90,7 +90,6 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_time = self._keep_checkpoint_every_n_hours * 3600 self._next_checkpoint_time = (time.time() + keep_time) elif not defer_build: - self._var_list = build_var_list(var_list) self.build() self._object_restllore_saver = None # mxRec Patch @@ -163,6 +162,13 @@ def get_checkpoint_file(self, global_step, sess, save_path): return checkpoint_file +def build(self): + self._var_list = build_var_list() + if context.executing_eagerly(): + raise RuntimeError("Use save/restore instead of build in eager mode.") + self._build(self._filename, build_save=True, build_restore=True) + + def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False): if not self._is_built and not context.executing_eagerly(): @@ -313,16 +319,14 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N return cached_saver -def build_var_list(var_list): - if var_list is None: - save_var_list = [] - tmp_list = variables._all_saveable_objects() - removing_var_list = export_removing_var_list() - for var in tmp_list: - if var.name not in removing_var_list: - save_var_list.append(var) - return save_var_list - return var_list +def build_var_list(): + save_var_list = [] + tmp_list = variables._all_saveable_objects() + removing_var_list = export_removing_var_list() + for var in tmp_list: + if var.name not in removing_var_list: + save_var_list.append(var) + return save_var_list class BaseSaverBuilder(object): @@ -370,6 +374,7 @@ 
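The reworked `build_var_list` in this saver patch always recomputes the saveable set and drops the variables mx_rec manages itself. A sketch of that filter, using the public TF collection API in place of the private `variables._all_saveable_objects()` helper the real code calls.

```python
import tensorflow as tf

def build_var_list(removing_var_names):
    # Stand-in for variables._all_saveable_objects(): the public
    # GLOBAL_VARIABLES collection is used here instead.
    saveable = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
    return [var for var in saveable if var.name not in removing_var_names]
```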
def patch_for_saver(): dense_saver.__init__ = saver_init dense_saver.save = save dense_saver.restore = restore + dense_saver.build = build logging.debug("Class tf.train.Saver has been patched.") diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 48891b5d..02343b7a 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "ckpt_data_handler//emb_hash_ckpt/emb_hash_ckpt.h" #include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" @@ -134,8 +135,10 @@ void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, void Checkpoint::MakeSaveDir(const string& dirName) { - if (mkdir(dirName.c_str(), dirMode) == -1) { - spdlog::debug("Unable to create directory: {}", dirName); + if (access(dirName.c_str(), F_OK) == -1) { + if (mkdir(dirName.c_str(), dirMode) == -1) { + spdlog::debug("Unable to create directory: {}", dirName); + } } } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index f2120519..1da11b0b 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -31,7 +31,7 @@ namespace MxRec { const string dataFileType { ".data" }; const string attribFileType { ".attribute" }; const string dirSeparator { "/" }; - const mode_t dirMode { 0777 }; + const mode_t dirMode { 0755 }; const string currDir { "." }; const string prevDir { ".." }; -- Gitee From a6335dc3aa20fff469c0fba094bc75b627582156 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 27 Jun 2023 10:46:47 +0800 Subject: [PATCH 167/551] Match-id-eb2f0e140f5870c7fe50e646c650f532e13bc5dd --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 21 +++++++++- tools/perf/fast.sh | 29 ++++++++++++-- tools/perf/host_set.sh | 17 ++++++++ tools/perf/msprof.sh | 24 +++++++++++ tools/perf/mt_1207.sh | 60 ++++++++++++++++++++++++++++ tools/perf/perf_flame_graph.sh | 37 +++++++++++++++++ 6 files changed, 184 insertions(+), 4 deletions(-) create mode 100755 tools/perf/host_set.sh create mode 100755 tools/perf/msprof.sh create mode 100755 tools/perf/mt_1207.sh create mode 100755 tools/perf/perf_flame_graph.sh diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 55c42554..00fb3127 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -579,19 +579,38 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) spdlog::info(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost ParseKeysTC; + // get + TimeCost getTensorsSyncTC; auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { spdlog::info(MGMT + "ParseKeys infoVecs empty ! 
batchId:{}, channelId:{}", batchId, channelId); return false; } + unique_ptr> all2all = nullptr; + if (!mgmtRankInfo.useStatic) { + all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + } + TIME_PRINT("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); + // send + TimeCost sendTensorsSyncTC; if (!mgmtRankInfo.useStatic) { - auto all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + TimeCost sendAll2AllScSyncTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); + TIME_PRINT("sendAll2AllScSyncTC(ms):{}", sendAll2AllScSyncTC.ElapsedMS()); } + + TimeCost sendLookupSyncTC; hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); + TIME_PRINT("sendLookupSyncTC(ms):{}", sendLookupSyncTC.ElapsedMS()); + + TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); + TIME_PRINT("sendRestoreSyncTC(ms):{}", sendRestoreSyncTC.ElapsedMS()); + + TIME_PRINT("sendTensorsSyncTC(ms):{}", sendTensorsSyncTC.ElapsedMS()); + TIME_PRINT("ParseKeysTC HBM mode (ms):{}", ParseKeysTC.ElapsedMS()); } batchId++; diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh index 474634ff..5ad5098f 100755 --- a/tools/perf/fast.sh +++ b/tools/perf/fast.sh @@ -15,7 +15,7 @@ # # 3. 注意事项 # 基于spdlog::info,mxRec中添加了TimeCost打点日志,因此,在执行前务必确保run.sh中设置 -# SPDLOG_LEVEL=info 或者 SPDLOG_LEEL=debug (如果没有设置,本工具会退出,并给予提示) +# SPDLOG_LEEL=debug (如果没有设置,本工具会退出,并给予提示) # # 4. 解读结果 # (1) Pipeline: 整个Pipeline由多个Pipe串行构成,性能分析结果分Pipe呈现,例如Pipe-1/Pipe-2/Pipe-3/Pipe-4等; @@ -62,7 +62,7 @@ check_spdlog_level() $(grep 'ReadEmbKeyV2Dynamic' $logfile > /dev/null 2>&1) if [ $? != 0 ]; then LOG_ERROR "No timecost-related logs, please check 'mpi_args' in your run.sh, - make sure SPDLOG_LEVEL=info or SPDLOG_LEEL=debug, and run again!" + make sure SPDLOG_LEEL=debug, and run again!" exit 1 fi fi @@ -245,8 +245,31 @@ parse_pipe_3_get_and_send_tensors_sync_without_ddr() if [ $? == 0 ]; then grep 'ParseKeysTC HBM mode (ms)' $logfile | \ awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \ - {printf "--|ParseKeysTC(filter>2000ms): avg=%0.1f\n", sum/count}' + {printf "ParseKeysTC(filter>2000ms): avg=%0.1f\n", sum/count}' fi + + grep 'getTensorsSyncTC(ms)' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<1000) {sum+=$NF; count++;}} END \ + {printf "--|getTensorsSyncTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'sendTensorsSyncTC(ms)' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<1000) {sum+=$NF; count++;}} END \ + {printf "--|sendTensorsSyncTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + $(grep 'sendAll2AllScSyncTC(ms)' $logfile > /dev/null 2>&1) + if [ $? 
+    if [ $? == 0 ]; then
+        grep 'sendAll2AllScSyncTC(ms)' $logfile | \
+            awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
+            {printf "----|sendAll2AllScSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
+    fi
+
+    grep 'sendLookupSyncTC(ms)' $logfile | \
+        awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
+        {printf "----|sendLookupSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
+
+    grep 'sendRestoreSyncTC(ms)' $logfile | \
+        awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
+        {printf "----|sendRestoreSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
 }
 
 main()
diff --git a/tools/perf/host_set.sh b/tools/perf/host_set.sh
new file mode 100755
index 00000000..0120ebb9
--- /dev/null
+++ b/tools/perf/host_set.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+# cpu with high-performance
+cpupower frequency-set -g performance
+cat /proc/cpuinfo|grep MHz
+
+# clear cache
+echo 3 > /proc/sys/vm/drop_caches
+free -h
+
+# swap off
+swapoff -a
diff --git a/tools/perf/msprof.sh b/tools/perf/msprof.sh
new file mode 100755
index 00000000..c1821c83
--- /dev/null
+++ b/tools/perf/msprof.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+
+curr_path=$(cd $(dirname $0); pwd)
+
+# ---------------config start---------------------
+model_run_path=/path/to/model/run
+run_cmd="bash run.sh"
+# ---------------config end---------------------
+
+# ------------------------------+
+# msprof                        +
+# ------------------------------+
+output_path="${model_run_path}"/msprof_out
+
+cd "${model_run_path}"
+rm -rf "${output_path}"
+
+msprof --application="${run_cmd}" --output="${output_path}"
diff --git a/tools/perf/mt_1207.sh b/tools/perf/mt_1207.sh
new file mode 100755
index 00000000..fc0af5db
--- /dev/null
+++ b/tools/perf/mt_1207.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+#set -x
+
+LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
+LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
+LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
+LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
+
+logfile=$1
+
+# ---------------config start---------------------
+batchsize=9600
+parallel=8
+nv_throughput=820000
+# ---------------config end---------------------
+
+validate_options()
+{
+    if [ $# -ne 1 ]; then
+        LOG_ERROR "NO log_file"
+        echo "[Usage]: bash $0 your_file.log"
+        exit 1
+    fi
+}
+
+print_throughput()
+{
+    LOG_INFO "=========Throughput====================="
+    nv_sps=$(awk 'BEGIN{printf "%.2f\n",('${nv_throughput}'/'$batchsize'/'$parallel')}')
+    LOG_NOTICE "batchsize:${batchsize}, parallel:${parallel}"
+    LOG_NOTICE "nv_throughput:${nv_throughput}, nv_sps:${nv_sps}"
+
+    grep 'tensorflow:global_step/sec' $logfile | \
+        awk -F" " '{sum+=$NF} END \
+        {printf "Throughput: avg=%0.3f, xA100:%0.3f\n", \
+        sum/NR, sum/NR/'${nv_sps}'}'
+
+    grep 'tensorflow:global_step/sec' $logfile | \
+        awk -F" " 'BEGIN {sum=0; count=0;} {if ($NF > 3) {sum+=$NF; count++;}} END \
+        {printf "Throughput: after filter(<3), avg=%0.3f, xA100:%0.3f\n", \
+        sum/count, sum/count/'${nv_sps}'}'
+
+    grep 'tensorflow:global_step/sec' $logfile | \
+        awk -F" " 'BEGIN {max=0} {if($2>max) max=$2} END \
+        {printf "Throughput: max=%0.3f, xA100:%0.3f\n", max, max/'${nv_sps}'}'
+}
+
+main()
+{
+    validate_options $@
+    print_throughput
+}
+
+main $@
diff --git a/tools/perf/perf_flame_graph.sh b/tools/perf/perf_flame_graph.sh
new file mode 100755
index 00000000..dce91600
--- /dev/null
+++ b/tools/perf/perf_flame_graph.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+#set -x
+
+curr_path=$(cd $(dirname $0); pwd)
+
+LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
+LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
+LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
+LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
+
+# ---------------config start---------------------
+model_run_path=/path/to/model/run
+run_cmd="bash run.sh"
+flame_graph_path=/home/FlameGraph
+# ---------------config end---------------------
+
+cd "${model_run_path}"
+rm -rf perf*
+
+#---- perf cpu-clock on all workers and build flame graph------------
+perf record -F 99 -a -g "${run_cmd}"
+wait $!
+
+perf script -i perf.data | \
+    "${flame_graph_path}"/stackcollapse-perf.pl | \
+    "${flame_graph_path}"/flamegraph.pl > perf_mxRec.svg
+wait $!
+
+LOG_INFO "perf_mxRec.svg is created, please check!"
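The `xA100` columns printed by `mt_1207.sh` come from one normalization: the reference box's samples/sec divided by per-step samples across all parallel workers gives reference global steps/sec, and the measured `global_step/sec` is expressed as a fraction of it. The same arithmetic in Python, using the script's config values:

```python
# Config values from mt_1207.sh.
batchsize = 9600
parallel = 8
nv_throughput = 820000                          # reference samples/sec

# Reference global steps/sec: samples/sec over per-step samples
# across all parallel workers.
nv_sps = nv_throughput / batchsize / parallel   # ~10.68 steps/sec

def x_reference(measured_steps_per_sec):
    # The 'xA100' column: measured global_step/sec over the reference.
    return measured_steps_per_sec / nv_sps

print(round(x_reference(9.5), 3))               # 0.89 of the reference
```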
+ + -- Gitee From bbea15a99b34b73c1a7f7d470c987b3d5391b02f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 28 Jun 2023 10:21:56 +0800 Subject: [PATCH 168/551] Match-id-d2f8358b93bde74cc4fd6ba60998d75a77ea4706 --- mx_rec/core/embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index d74e43c1..fbd72ca2 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -474,6 +474,8 @@ class SparseEmbedding: if use_hot: import mxrec_pybind emb_size = self.scalar_emb_size if self.skip_emb_transfer else self.ext_emb_size + if emb_size == 0: + raise RuntimeError("emb_size is 0, please set a valid value.") hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / emb_size) hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos") hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value) -- Gitee From e93f92eea89950d6bd0e3e1fd36e57f569659bb0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 28 Jun 2023 14:56:07 +0800 Subject: [PATCH 169/551] Match-id-070d65c6fe13aba58ca3d1aba40e7bf8a488b71b --- src/core/key_process/key_process.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 07c3b11b..927f2fef 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -534,9 +534,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch vector uniqueVector(batch->Size()); uniqueInfoOut.restore.resize(batch->Size()); vector idCount(batch->Size()); - if (rankInfo.useStatic) { - keySendInfo.keyCount.resize(size); - } + keySendInfo.keyCount.resize(size); UniqueIn uniqueIn; uniqueIn.inputIdCnt = (uint32_t)batch->Size(); @@ -545,8 +543,12 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch EnhancedUniqueOut uniqueOut; uniqueOut.uniqueId = reinterpret_cast(keySendInfo.keySend.data()); uniqueOut.index = (uint32_t*)uniqueInfoOut.restore.data(); - uniqueOut.idCnt = idCount.data(); - uniqueOut.idCntFill = keySendInfo.keyCount.data(); + if (rankInfo.useStatic) { + uniqueOut.idCnt = idCount.data(); + uniqueOut.idCntFill = keySendInfo.keyCount.data(); + } else { + uniqueOut.idCnt = keySendInfo.keyCount.data(); + } uniqueOut.uniqueIdCntInBucket = splitSize.data(); uniqueOut.uniqueIdInBucket = reinterpret_cast(uniqueVector.data()); uniqueOut.uniqueIdCnt = 0; -- Gitee From b6f5eeb8b66b2795bf668ac9fd1f8c3d5deede77 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 29 Jun 2023 21:23:51 +0800 Subject: [PATCH 170/551] Match-id-f2b1d873342d4a5bd45850643327975ee78741e6 --- example/little_demo/main.py | 2 +- mx_rec/core/asc/helper.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index ddd51453..de2f18f2 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -67,7 +67,7 @@ def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=N [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt]] if USE_TIMESTAMP: - tf.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) model = model_forward(input_list, batch, is_train=is_train, modify_graph=True, config_dict=config_dict) else: diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 
529d78da..7b628fc2 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -252,14 +252,16 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ logging.info(f"do_insert skip table 2: {table_name}") continue - - new_insert_tensors.append(insert_tensors[idx]) new_splits.append(splits[idx]) new_table_names.append(table_names[idx]) if FeatureSpec.use_timestamp(is_training): new_insert_tensors = insert_tensors + if len(splits) < 1: + raise ValueError(f"When use_timestamp is set to True, " + f"the length of the splits list must be greater than or equal to 1.") + new_splits = splits[1:] return do_insert(args, insert_tensors=new_insert_tensors, -- Gitee From 506eafdbf0ad8d7c0e50a3e5841d9ad9a6af2197 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 1 Jul 2023 10:18:42 +0800 Subject: [PATCH 171/551] Match-id-bdd265362a9e8877f8e14693f6bcfb25b1df517a --- src/pybind/module_main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index f15bae97..b3ed0604 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -6,7 +6,7 @@ */ #include #include -#include +#include #include "hybrid_mgmt/hybrid_mgmt.h" #include "module_main.h" @@ -39,7 +39,7 @@ uint32_t GetLogicID(uint32_t phyid) { int32_t ret = 0; uint32_t logicId; - ret = dcmi_get_device_logicid_from_phyid(phyid, &logicId); + ret = dsmi_get_logicid_from_phyid(phyid, &logicId); if (ret != 0) { return ret; } -- Gitee From dae0bc49fe3f7e295cd58909399703436230879c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 1 Jul 2023 18:39:42 +0800 Subject: [PATCH 172/551] Match-id-6763f96a18310b942418e61ea32250081271f977 --- mx_rec/core/embedding.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index fbd72ca2..6eb180fe 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -6,6 +6,7 @@ import logging import math import time from collections import defaultdict +from typing import Optional import numpy as np import tensorflow as tf @@ -442,8 +443,7 @@ class SparseEmbedding: if is_training and use_dynamic_expansion and is_table_name_valid: tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) - logging.debug(f"modify graph mode, table_name: {self.table_name}, " - f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") + logging.debug(f"modify graph, table_name: {self.table_name}, contain: {ASCEND_TABLE_NAME_MUST_CONTAIN}") @tf.custom_gradient def sparse_forward(table, feat_ids): @@ -475,7 +475,7 @@ class SparseEmbedding: import mxrec_pybind emb_size = self.scalar_emb_size if self.skip_emb_transfer else self.ext_emb_size if emb_size == 0: - raise RuntimeError("emb_size is 0, please set a valid value.") + raise ValueError("emb_size is 0, please set a valid value.") hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / emb_size) hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos") hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value) @@ -485,6 +485,7 @@ class SparseEmbedding: if not use_dynamic_expansion: id_offsets_abs = tf.abs(id_offsets) local_emb = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets") + local_emb = set_zero_for_non_valid_key(id_offsets, local_emb) else: local_emb = tf.identity(table, name="identity_local_emb") all2all_args = 
send_count if use_static else all2all_matrix
@@ -512,8 +513,7 @@ class SparseEmbedding:
             logging.debug(f"bp rank size: {rank_size}")
             unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static \
                 else tf.shape(unique_embeddings)
-            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
-                                                             restore_vector,
+            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector,
                                                              unique_embeddings_shape[0])
             bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args)
             if hot_pos is not None:
@@ -688,6 +688,7 @@ class SparseEmbedding:
         if not use_dynamic_expansion:
             id_offsets_abs = tf.abs(id_offsets)
             local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
+            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings)
         else:
             local_embeddings = tf.identity(table, name="identity_local_emb")
 
@@ -947,3 +948,15 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
             check_type(self._evict_time_interval, int, "evict_time_interval")
         if self._evict_step_interval is not None:
             check_type(self._evict_step_interval, int, "evict_time_interval")
+
+
+def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor]):
+    """
+    Set the embedding of any feature whose key is -1 to zero.
+    :param id_offsets: feature indices
+    :param embeddings: the sparse table
+    :return:
+    """
+    id_offsets_expand = tf.expand_dims(id_offsets >= 0, axis=-1)
+    embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings))
+    return embeddings
-- Gitee

From 6f22dc6b89882f6f96100b63b370a1c71ce2f215 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 3 Jul 2023 14:18:33 +0800
Subject: [PATCH 173/551] Match-id-9ef0a305af515c26767efc7a8a8febab03844d3d
---
 src/core/key_process/key_process.cpp     | 1 +
 src/core/ock_ctr_common/include/unique.h | 1 +
 src/platform/AccCTR                      | 2 +-
 3 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 927f2fef..3cf54b9a 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -216,6 +216,7 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf)
     uniqueConf.outputType = OutputType::ENHANCED;
     uniqueConf.minThreadNum = MIN_UNIQUE_THREAD_NUM;
     uniqueConf.maxThreadNum = PerfConfig::maxUniqueThreadNum;
+    uniqueConf.performance = true;
 }
 
 void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize,
diff --git a/src/core/ock_ctr_common/include/unique.h b/src/core/ock_ctr_common/include/unique.h
index 59ed98b5..9f01ec96 100644
--- a/src/core/ock_ctr_common/include/unique.h
+++ b/src/core/ock_ctr_common/include/unique.h
@@ -54,6 +54,7 @@ using UniqueConf = struct UniqueConfCTR {
     uint32_t maxThreadNum = 8;  // maximum number of worker threads
     int64_t maxIdVal = 0;       // maximum id value
     bool trace = false;         // whether to enable performance tracing; requires external log output
+    bool performance = false;   // whether to enable the enhanced interface; it requires shardingNum to be a power of two, otherwise modulo bucketing is used by default
 } __attribute__((packed));
 
 using UniqueIn = struct UniqueInCTR {
diff --git a/src/platform/AccCTR b/src/platform/AccCTR
index a4c7f7e5..bc9dc810 160000
--- a/src/platform/AccCTR
+++ b/src/platform/AccCTR
@@ -1 +1 @@
-Subproject commit a4c7f7e598334c24e875df13b08877155b6c0451
+Subproject commit bc9dc8103109eb8b77c09e7c4028a992a66d5015
-- Gitee
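
The new performance switch turns on AccCTR's enhanced unique interface; per the header
comment it requires shardingNum to be a power of two, with modulo bucketing as the
fallback. A likely motivation, not stated in the header and offered here as an
assumption, is that a power-of-two bucket count lets the modulo reduce to a bitmask:

    # For n a power of two, "key % n" equals "key & (n - 1)", which avoids an
    # integer division on the hot bucketing path. Values are illustrative.
    n = 8
    for key in (0, 5, 13, 123457, 2 ** 31):
        assert key % n == (key & (n - 1))
    print("modulo and bitmask bucketing agree for n =", n)
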
From f5adc4997cf346e9dcb5c84597e2863b80e76687 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 3 Jul 2023 15:35:54 +0800
Subject: [PATCH 174/551] Match-id-aeda057d79d8000517251d93e98cb296d23355e7
---
 mx_rec/core/embedding.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 6eb180fe..fcddb0f9 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -485,7 +485,7 @@ class SparseEmbedding:
             if not use_dynamic_expansion:
                 id_offsets_abs = tf.abs(id_offsets)
                 local_emb = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
-                local_emb = set_zero_for_non_valid_key(id_offsets, local_emb)
+                local_emb = set_zero_for_non_valid_key(id_offsets, local_emb, feature_spec.access_threshold)
             else:
                 local_emb = tf.identity(table, name="identity_local_emb")
             all2all_args = send_count if use_static else all2all_matrix
@@ -688,7 +688,8 @@ class SparseEmbedding:
         if not use_dynamic_expansion:
             id_offsets_abs = tf.abs(id_offsets)
             local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
-            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings)
+            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings,
+                                                          feature_spec.access_threshold)
         else:
             local_embeddings = tf.identity(table, name="identity_local_emb")
 
@@ -950,13 +951,19 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
             check_type(self._evict_step_interval, int, "evict_time_interval")
 
 
-def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor]):
+def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor],
+                               access_threshold: bool):
     """
     Set the embedding of any feature whose key is -1 to zero.
     :param id_offsets: feature indices
    :param embeddings: the sparse table
+    :param access_threshold: admission threshold
     :return:
     """
+    if access_threshold is None or access_threshold <= 0:
+        return embeddings
     id_offsets_expand = tf.expand_dims(id_offsets >= 0, axis=-1)
+    if get_use_static():
+        id_offsets_expand = tf.compat.v1.broadcast_to(id_offsets_expand, embeddings.shape.as_list())
     embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings))
     return embeddings
-- Gitee

From 89077323cab85536ca9fd86e45778a1a38a47486 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 3 Jul 2023 17:33:23 +0800
Subject: [PATCH 175/551] Match-id-701d79e2e1f2daacdfa186a78d3d6387a37c1651
---
 mx_rec/util/initialize.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index cf07b256..2ab75c96 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -268,11 +268,13 @@ class ConfigInitializer:
         except ValueError as err:
             raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE incorrectly configured.") from err
         try:
-            sorted_device_list.pop(int(chief_device))
+            sorted_device_list.pop(int(chief_device) % len(sorted_device_list))
         except IndexError as err:
             raise IndexError(
                 f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \
                 from err
+        except ZeroDivisionError as err:
+            raise ZeroDivisionError("sorted_device_list length cannot be 0.") from err
 
         for device_idx in sorted_device_list:
             device_id = mxrec_pybind.get_logic_id(int(device_idx))
-- Gitee
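
Patches 172 and 174 above zero out gathered rows whose id offset is -1, i.e. keys
rejected by the feature-admission threshold. A standalone sketch of the idea with toy
tensors (this is not the mxRec API); the extra broadcast_to under static shapes is
presumably there because tf.compat.v1.where does not broadcast its condition the way
TF2's tf.where does:

    import tensorflow as tf

    id_offsets = tf.constant([0, -1, 2])                 # -1 marks a non-admitted key
    table = tf.constant([[1., 1.], [2., 2.], [3., 3.]])

    rows = tf.gather(table, tf.abs(id_offsets), axis=0)  # -1 still gathers a row via abs()
    valid = tf.expand_dims(id_offsets >= 0, axis=-1)     # [[True], [False], [True]]
    masked = tf.where(valid, rows, tf.zeros_like(rows))  # the -1 row becomes [0., 0.]
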
From 363c9c6c2be107c5c74f67e1e74cfc6f1699b193 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 4 Jul 2023 15:33:56 +0800
Subject: [PATCH 176/551] Match-id-f339a25815596012407c8082c936477f3f63e827
---
 tools/parse_data/data_parser.py | 124 ++++++++++++++++++++++++++++
 tools/parse_data/run.sh         |  11 +++
 tools/perf/fast.sh              |  10 +--
 3 files changed, 140 insertions(+), 5 deletions(-)
 create mode 100755 tools/parse_data/data_parser.py
 create mode 100755 tools/parse_data/run.sh

diff --git a/tools/parse_data/data_parser.py b/tools/parse_data/data_parser.py
new file mode 100755
index 00000000..735c8faa
--- /dev/null
+++ b/tools/parse_data/data_parser.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+# Description:
+# Author: MindX SDK
+# Create: 2023-06-28
+
+# -----------------------------------------ReadMe Begin--------------------------------------------
+# 1. Purpose
+#    This tool benchmarks the TensorFlow data-parsing stage in isolation, to help decide whether
+#    data parsing is the bottleneck that stalls the rest of the pipeline.
+# 2. Notes
+#    The parsing logic lives in make_dataset, which defaults to the criteo dataset. To measure the
+#    parsing cost of another dataset, redefine make_dataset as needed.
+# 3. CPU binding
+#    To mimic the real deployment, bind_cpu assumes an 80-core CPU shared evenly by 8 workers. If
+#    the worker count or core count differs, redefine bind_cpu as needed.
+# 4. How to run
+#    4.1 single worker: python3 data_parser.py
+#    4.2 multiple workers: bash run.sh data_parser.py
+# -----------------------------------------ReadMe End--------------------------------------------
+
+import os
+import sys
+import time
+
+import logging
+import psutil
+
+import tensorflow as tf
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+def make_dataset(data_path, batch_size=102400, line_per_sample=1024):
+    def extract_fn(data_record):
+        features = {
+            # Extract features using the keys set during creation
+            'label': tf.FixedLenFeature(shape=(line_per_sample,), dtype=tf.int64),
+            'sparse_feature': tf.FixedLenFeature(shape=(26 * line_per_sample,), dtype=tf.int64),
+            'dense_feature': tf.FixedLenFeature(shape=(13 * line_per_sample,), dtype=tf.float32),
+        }
+        sample = tf.parse_single_example(data_record, features)
+        return sample
+
+    def feat_cast(feat):
+        for name, tensor in feat.items():
+            if tensor.dtype == tf.int64:
+                feat[name] = tf.cast(tensor, tf.int32)
+        return feat
+
+    def reshape_fn(batch):
+        batch['label'] = tf.reshape(batch['label'], [-1, 1])
+        batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 13])
+        batch['dense_feature'] = tf.math.log(batch['dense_feature'] + 3.0)
+        batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 26])
+        return batch
+
+    file_list = sorted([os.path.join(data_path, file) for file in os.listdir(data_path)])
+    dataset = tf.data.TFRecordDataset(file_list, num_parallel_reads=4)
+
+    num_parallel = 8
+    dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel)
+
+    line_cnt = batch_size // line_per_sample
+    dataset = dataset.batch(line_cnt, drop_remainder=True)
+
+    dataset = dataset.map(feat_cast, num_parallel_calls=num_parallel)
+    dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel)
+
+    dataset = dataset.prefetch(10)
+    return dataset
+
+
+def bind_cpu(rank_id):
+    process = psutil.Process()
+    cpu_kernels = {
+        0: 0,
+        1: 10,
+        2: 40,
+        3: 50,
+        4: 20,
+        5: 30,
+        6: 60,
+        7: 70
+    }
+    try:
+        process.cpu_affinity([cpu_kernels.get(rank_id) + x for x in range(10)])
+    except IndexError:
+        logging.error("invalid cpu bind info, skipped.")
+
+
+if __name__ == '__main__':
+    RANK_ID = 0
+    if len(sys.argv) > 1:
+        RANK_ID = int(sys.argv[1])
+    bind_cpu(RANK_ID)
+
+    DATA_PATH = "/media/mxRec/data/criteo_tfrecord_small/train"
+    train_dataset = make_dataset(DATA_PATH)
+    iterator = train_dataset.make_initializable_iterator()
+    next_batch = iterator.get_next()
+
+    input_data = []
+    for example in next_batch:
+        input_data.append(next_batch[example])
+
+    COUNT = 0
+    TOTAL_TIME = 0.0
+
+    with tf.Session() as sess:
+        sess.run(iterator.initializer)
+        while True:
+            try:
+                start_time = time.time()
+                result = sess.run(input_data[0])
+                end_time = time.time()
+
+                COUNT += 1
+
+                if COUNT > 1:
+                    TOTAL_TIME += end_time - start_time
+                logging.info("StepId:%d, StepTimeCost(ms):%f", COUNT, (end_time - start_time) * 1000)
+            except tf.errors.OutOfRangeError as e:
+                logging.error("End of Training Dataset")
+                break
+        logging.info("StepTimeCost avg(ms):%f", TOTAL_TIME / (COUNT - 1) * 1000)
\ No newline at end of file
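
data_parser.py expects criteo-style TFRecords: each record packs line_per_sample=1024
lines, with 26 sparse and 13 dense features per line. To exercise the tool without the
real dataset, a hypothetical generator for one schema-compatible file (file name, record
count and value ranges are made up):

    import numpy as np
    import tensorflow as tf

    line_per_sample = 1024
    rng = np.random.default_rng(0)

    def make_example():
        feature = {
            'label': tf.train.Feature(int64_list=tf.train.Int64List(
                value=rng.integers(0, 2, line_per_sample))),
            'sparse_feature': tf.train.Feature(int64_list=tf.train.Int64List(
                value=rng.integers(0, 100000, 26 * line_per_sample))),
            'dense_feature': tf.train.Feature(float_list=tf.train.FloatList(
                value=rng.random(13 * line_per_sample, dtype=np.float32))),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))

    # 100 records = 102400 lines, exactly one default batch for make_dataset
    with tf.io.TFRecordWriter("fake_criteo.tfrecord") as writer:
        for _ in range(100):
            writer.write(make_example().SerializeToString())
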
diff --git a/tools/parse_data/run.sh b/tools/parse_data/run.sh
new file mode 100755
index 00000000..b3ab73bb
--- /dev/null
+++ b/tools/parse_data/run.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
+# Description: performance analysis tool
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+for i in {0..7}
+do
+    nohup python3 data_parser.py $i > rank_$i.log 2>&1 &
+done
\ No newline at end of file
diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh
index 5ad5098f..00dce55c 100755
--- a/tools/perf/fast.sh
+++ b/tools/perf/fast.sh
@@ -14,8 +14,8 @@
 #    bash fast.sh your_log_file.log
 #
 # 3. Notes
-#    mxRec adds TimeCost logging via spdlog::info, so before running make sure run.sh sets
-#    SPDLOG_LEEL=debug (if not set, this tool exits with a hint)
+#    mxRec adds TimeCost logging via spdlog::debug, so before running make sure run.sh sets
+#    SPDLOG_LEVEL=debug (if not set, this tool exits with a hint)
 #
 # 4. Interpreting the results
 # (1) Pipeline: the whole pipeline is a serial chain of pipes; results are reported per pipe, e.g. Pipe-1/Pipe-2/Pipe-3/Pipe-4;
@@ -104,15 +104,15 @@ parse_pipe_2_key_process()
 
     grep 'getBatchDataTC' $logfile | \
         awk -F":" 'BEGIN { max=0 } { sum+=$NF; if($NF>max) max=$NF } END \
-        {printf "--|get data time: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}'
+        {printf "--|getBatchDataTC: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}'
 
     grep 'getBatchDataTC' $logfile | \
         awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \
-        {printf "--|get data time(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}'
+        {printf "--|getBatchDataTC(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}'
 
     grep 'getBatchDataTC' $logfile | \
         awk -F":" 'BEGIN { total=0; none_zero_ms_num=0 } { total++; if($NF>0) none_zero_ms_num++ } END \
-        {printf "--|get data time: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \
+        {printf "--|getBatchDataTC: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \
         total, none_zero_ms_num, none_zero_ms_num/total, (1-none_zero_ms_num/total)}'
 
     LOG_INFO "Step-2.2 KeyProcess"
-- Gitee
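
fast.sh's renamed getBatchDataTC report includes a zero-ms rate, which works as a
back-pressure probe: a zero_ms_rate near 1.0 means a batch was almost always already
buffered when requested, so data supply is unlikely to be the pipeline bottleneck. The
same statistic in Python, under the same assumption that TimeCost lines end with
":<milliseconds>":

    def zero_ms_rate(log_path, marker="getBatchDataTC"):
        total = zeros = 0
        with open(log_path) as fin:
            for line in fin:
                if marker in line:
                    total += 1
                    zeros += float(line.rsplit(":", 1)[-1]) == 0.0
        return zeros / total if total else float("nan")
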
From f50a4d1aef41b6b10e8716afbe379232b3dfd4cf Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 4 Jul 2023 15:36:31 +0800
Subject: [PATCH 177/551] Match-id-ccf2d30289364063d400bf4a28910d4b0282d7f7
---
 example/little_demo/main.py                     |  2 +-
 .../feat_admit_n_evict_ckpt.cpp                 |  4 +---
 .../key_process/feature_admit_and_evict.cpp     | 17 ++++++++++++-----
 src/core/utils/common.h                         | 13 ++-----------
 src/tests/checkpoint/checkpoint_test.cpp        |  4 ----
 .../ckpt_data_handler_test.cpp                  |  4 ----
 .../feature_admit_and_evict_test.cpp            | 13 +++++--------
 7 files changed, 21 insertions(+), 36 deletions(-)

diff --git a/example/little_demo/main.py b/example/little_demo/main.py
index de2f18f2..2ee01a49 100644
--- a/example/little_demo/main.py
+++ b/example/little_demo/main.py
@@ -163,7 +163,7 @@ if __name__ == "__main__":
 
     train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
     eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
-    optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(2)]
+    optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(1)]
     sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in
                              optimizer_list]
 
     user_hashtable = create_table(key_dtype=tf.int64,
diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
index 2d992c24..985ed29f 100644
--- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
+++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
@@ -104,7 +104,7 @@ void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName)
         transArr.push_back(timeStamp);
 
         for (const auto& histRec : histRecs) {
-            transArr.push_back(histRec.second.featureId);
+            transArr.push_back(histRec.first);
             transArr.push_back(static_cast<int64_t>(histRec.second.count));
             transArr.push_back(static_cast<int64_t>(histRec.second.lastTime));
         }
@@ -135,10 +135,8 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName)
         const auto& count = transArr[i + countIdxOffset];
         const auto& lastTime = transArr[i + lastTimeIdxOffset];
 
-        histRecs[featureId].featureId = featureId;
         histRecs[featureId].count = static_cast<uint32_t>(count);
         histRecs[featureId].lastTime = lastTime;
-        histRecs[featureId].tensorName = embName;
     }
 }
 
diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp
index 5e9073bf..a36347ae 100644
--- a/src/core/key_process/feature_admit_and_evict.cpp
+++ b/src/core/key_process/feature_admit_and_evict.cpp
@@ -68,6 +68,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel,
         m_recordsData.timestamps[tensorName] = batch->timestamp;
     }
     absl::flat_hash_map<int64_t, uint32_t> visitedRecords;
+    spdlog::trace("FeatureAdmit, name:[{}], channel:[{}], before admit, splitKey:[{}] ...", tensorName, channel,
+                  splitKey);
     for (auto& key : splitKey) {
         if (key == -1) {
             continue;
@@ -89,6 +91,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel,
             key = -1;
         }
     }
+    spdlog::trace("FeatureAdmit, name:[{}], channel:[{}], after admit, splitKey:[{}] ...", tensorName, channel,
+                  splitKey);
     return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK;
 }
 
@@ -104,7 +108,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con
     if (channel == TRAIN_CHANNEL_ID) {
         if (innerIt == historyRecordInfos.end()) {
             // maintain m_historyRecords
-            FeatureItemInfo info(featureId, featureCnt, tensorName, m_recordsData.timestamps[tensorName]);
+            FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tensorName]);
             historyRecordInfos[featureId] = info;
             currKeyCount = featureCnt;
         } else {
@@ -153,15 +157,17 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::vector<int64_t>& evictKey)
     // evict expired entries from m_historyRecords
     time_t currTime = m_recordsData.timestamps[embName];
     // collect the featureIds to evict (previously fetched via m_tensor2SortedLastTime)
-    SortedRecords lastTimePriority;
+    auto cmp = [](const auto& a, const auto& b) { return a.second.lastTime > b.second.lastTime; };
+    std::priority_queue<std::pair<int64_t, FeatureItemInfo>,
+        std::vector<std::pair<int64_t, FeatureItemInfo>>, decltype(cmp)> lastTimePriority(cmp);
     for (auto& item : m_recordsData.historyRecords[embName]) {
-        lastTimePriority.push(item.second);
+        lastTimePriority.emplace(item);
     }
     while (!lastTimePriority.empty()) {
-        if (currTime - lastTimePriority.top().lastTime < m_tensor2Threshold[embName].timeThreshold) {
+        if (currTime - lastTimePriority.top().second.lastTime < m_tensor2Threshold[embName].timeThreshold) {
             break;
         }
-        evictKey.emplace_back(lastTimePriority.top().featureId);
+        evictKey.emplace_back(lastTimePriority.top().first);
         lastTimePriority.pop();
    }
 
@@ -171,6 +177,7 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::vector<int64_t>& evictKey)
     }
     spdlog::info("tensor-name[{}]'s lastTime[{}], had size[{}] keys to delete ...", embName, currTime,
                  evictKey.size());
+    spdlog::trace("tensor-name[{}]'s lastTime[{}], evictKey:[{}] ...", embName, currTime, evictKey);
 
     // actually erase the evicted keys from m_historyRecords
     absl::flat_hash_map<int64_t, FeatureItemInfo>& historyRecords = m_recordsData.historyRecords[embName];
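
The FeatureEvictHelper rewrite above keeps a min-heap of (key, info) pairs ordered by
lastTime and pops until the first entry younger than the time threshold, avoiding a full
sort of the history table. The same loop as a short Python sketch (names and numbers are
made up):

    import heapq

    def evict_stale(history, curr_time, time_threshold):
        """history maps feature_id -> last_time; returns the ids evicted."""
        heap = [(last_time, fid) for fid, last_time in history.items()]
        heapq.heapify(heap)                   # min-heap: stalest feature on top
        evicted = []
        while heap and curr_time - heap[0][0] >= time_threshold:
            _, fid = heapq.heappop(heap)
            evicted.append(fid)
            del history[fid]                  # the real code erases m_historyRecords
        return evicted

    print(evict_stale({1: 100, 2: 500, 3: 900}, curr_time=1000, time_threshold=300))
    # -> [1, 2]; feature 3 (age 100) is inside the threshold and survives
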
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 85cde647..ab988ebd 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -256,23 +256,14 @@
 
 struct FeatureItemInfo {
     FeatureItemInfo() = default;
-    FeatureItemInfo(int64_t id, uint32_t cnt, std::string name, time_t lastT)
-        : featureId(id), count(cnt), tensorName(name), lastTime(lastT)
+    FeatureItemInfo(uint32_t cnt, time_t lastT)
+        : count(cnt), lastTime(lastT)
     {}
 
-    bool operator > (const FeatureItemInfo& item) const
-    {
-        return lastTime > item.lastTime;
-    }
-
-    int64_t featureId { -1 };
     uint32_t count { 0 };
-    std::string tensorName { "" };
     time_t lastTime { 0 };
 };
 
-using SortedRecords =
-    std::priority_queue<FeatureItemInfo, std::vector<FeatureItemInfo>, std::greater<FeatureItemInfo>>;
 using HistoryRecords = absl::flat_hash_map<std::string, absl::flat_hash_map<int64_t, FeatureItemInfo>>;
 struct AdmitAndEvictData {
     HistoryRecords historyRecords;  // embName ---> {id, FeatureItemInfo} mapping
diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp
index 59c339f9..f5940db5 100644
--- a/src/tests/checkpoint/checkpoint_test.cpp
+++ b/src/tests/checkpoint/checkpoint_test.cpp
@@ -210,9 +210,7 @@ protected:
         timestamps = timeStamp;
 
         for (int i = 0; i < count; ++i) {
-            historyRecords[featureId].featureId = featureId;
             historyRecords[featureId].count = count;
-            historyRecords[featureId].tensorName = testEmbInfo.name;
             historyRecords[featureId].lastTime = lastTime;
 
             featureId++;
@@ -463,9 +461,7 @@ TEST_F(CheckpointTest, FeatAdmitNEvict)
 
     for (const auto& validHR : validHistRec) {
         const auto& testHR = historyRecords.at(validHR.first);
 
-        EXPECT_EQ(validHR.second.featureId, testHR.featureId);
         EXPECT_EQ(validHR.second.count, testHR.count);
-        EXPECT_EQ(validHR.second.tensorName, testHR.tensorName);
         EXPECT_EQ(validHR.second.lastTime, testHR.lastTime);
     }
 }
diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
index bda4803a..f4338561 100644
--- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
+++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
@@ -111,9 +111,7 @@ protected:
         validA.push_back(timeStamp);
 
         for (int i = 0; i < count; ++i) {
-            historyRecords[featureId].featureId = featureId;
             historyRecords[featureId].count = count;
-            historyRecords[featureId].tensorName = testEmbInfo.name;
             historyRecords[featureId].lastTime = lastTime;
 
             validA.push_back(featureId);
@@ -216,9 +214,7 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict)
 
     for (const auto& validHR : validHistRec) {
         const auto& testHR = historyRecords.at(validHR.first);
 
-        EXPECT_EQ(validHR.second.featureId, testHR.featureId);
         EXPECT_EQ(validHR.second.count, testHR.count);
-        EXPECT_EQ(validHR.second.tensorName, testHR.tensorName);
         EXPECT_EQ(validHR.second.lastTime, testHR.lastTime);
     }
 }
diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp
index 1715ba71..4dd721a0 100644
--- a/src/tests/key_process/feature_admit_and_evict_test.cpp
+++ b/src/tests/key_process/feature_admit_and_evict_test.cpp
@@ -73,7 +73,7 @@ protected:
            oldInfos.erase(ele.first);
         }
 
-        FeatureItemInfo info = {ele.first, ele.second +
oldCnt, ts}; newInfos.insert(std::pair(ele.first, info)); } @@ -83,8 +83,8 @@ protected: printf("now, expect history info: \n"); for (auto& ele : newInfos) { - printf("\t{featureId[%ld], count[%d], embName[%s], lastTime[%ld]}\n", ele.second.featureId, - ele.second.count, ele.second.tensorName.c_str(), ele.second.lastTime); + printf("\t{featureId[%ld], count[%d], lastTime[%ld]}\n", ele.first, + ele.second.count, ele.second.lastTime); } printf("\n"); @@ -106,9 +106,7 @@ protected: } FeatureItemInfo& info2 = records2[ele1.first]; - if (info1.featureId != info2.featureId || - info1.count != info2.count || - info1.tensorName != info2.tensorName || + if (info1.count != info2.count || info1.lastTime != info2.lastTime) { printf("IsAllTheSameMap() 333333\n"); return false; @@ -367,8 +365,7 @@ protected: for (auto& ele : mergeKeys) { auto it = history.find(ele.first); ASSERT_EQ(it != history.end(), true); - ASSERT_EQ((history[ele.first].featureId == ele.first && - history[ele.first].count == ele.second), true); + ASSERT_EQ(history[ele.first].count == ele.second, true); } } static void TestMultiThread(FeatureAdmitAndEvictTest* testObj, std::string& thrName) -- Gitee From 56446651ec087a65c6aca913d66263875236ea74 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 4 Jul 2023 15:38:46 +0800 Subject: [PATCH 178/551] Match-id-677a05f949279fe9331ad100eb65504d641cd6f4 --- mx_rec/__init__.py | 6 ++- mx_rec/graph/modifier.py | 19 ++++++++- mx_rec/graph/patch.py | 89 ++++++++++++++++++++++++++++++++++++++- mx_rec/util/initialize.py | 58 +++++++++++++++++++++++-- 4 files changed, 164 insertions(+), 8 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index bf55ffcf..c516b6ab 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -5,13 +5,17 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops from mx_rec.saver.patch import patch_for_saver -from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator +from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \ + patch_for_end, patch_for_assert_eval_spec from mx_rec.optimizers.base import patch_for_optimizer patch_for_saver() patch_for_dataset() patch_for_chief_session_creator() +patch_for_assert_eval_spec() +patch_for_bool_gauge() +patch_for_end() patch_for_optimizer() __version__ = "5.0.RC2" diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 16c64d9e..172e11bc 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -19,7 +19,8 @@ from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_ ASCAnchorAttr, ASCEND_TIMESTAMP from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ - terminate_config_initializer, set_is_graph_modify_hook_running + terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \ + get_is_last_round from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list @@ -547,4 +548,18 @@ class GraphModifierHook(tf.estimator.SessionRunHook): session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) def end(self, session): - terminate_config_initializer() + 
bool_gauge_set = get_bool_gauge_set() + logging.debug(f"GraphModifierHook, bool_gauge_set: {bool_gauge_set}") + + # In eval or predict mode, the initializer can be directly terminated. + if 'train' not in bool_gauge_set: + logging.debug(f"In evaluate or predict case, GraphModifierHook call 'terminate_config_initializer'...") + terminate_config_initializer() + return + + if 'train_and_evaluate' in bool_gauge_set: + increase_run_times() + # In 'train_and_evaluate' mode, the terminate function should be executed last. + if get_is_last_round(): + logging.debug(f"In train_and_evaluate case, GraphModifierHook call 'terminate_config_initializer'...") + terminate_config_initializer() diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index ee3f988f..26a044be 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -4,13 +4,20 @@ import weakref import logging +from typing import Any import tensorflow as tf +import tensorflow_estimator as tensorflow_estimator_lib +from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.data.ops.dataset_ops import DatasetV2 from tensorflow.python.data.ops.dataset_ops import _VariantTracker from tensorflow.python.framework import ops +from tensorflow_estimator.python.estimator.training import EvalSpec +from tensorflow.python.eager.monitoring import BoolGauge, BoolGaugeCell +from npu_bridge.estimator.npu.npu_hook import NPUCheckpointSaverHook -from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph +from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \ + get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round def init_dataset(self, input_data): @@ -69,3 +76,83 @@ def patch_for_chief_session_creator(): """ tf.compat.v1.train.ChiefSessionCreator.__init__ = chief_session_creator_init logging.debug("__init__ in Class 'monitored_session.ChiefSessionCreator' has been patched.") + + +def get_cell(self: BoolGauge, *labels: Any) -> Any: + """ + Retrieves the cell. + Args: + self: An `BoolGauge` instance. + *labels: The label list of the new metric. + + Returns: Obtains the cell value set by the user. + """ + + logging.debug(f"Enter patch 'BoolGauge.get_cell'.") + if len(labels) > 0: + insert_bool_gauge(labels[0]) + return BoolGaugeCell(super(BoolGauge, self).get_cell(*labels)) + + +def patch_for_bool_gauge(): + """Patch for 'BoolGauge.get_cell'.""" + + BoolGauge.get_cell = get_cell + logging.debug(f"Function 'get_cell' in Class 'BoolGauge' has been patched.") + + +def end(self: NPUCheckpointSaverHook, session: tf.Session): + """ + Call at the end of session hook. + + Args: + self: An `NPUCheckpointSaverHook` instance. + session: A TensorFlow Session that will be soon closed. + + Returns: None + + """ + + logging.debug(f"Enter patch 'NPUCheckpointSaverHook.end'.") + logging.info("NPUCheckpointSaverHook end...") + basic_session_run_hooks.CheckpointSaverHook.end(self, session) + + if 'train_and_evaluate' in get_bool_gauge_set() and get_run_times() == 1: + set_is_last_round(True) + return + logging.debug(f"NPUCheckpointSaverHook call 'terminate_config_initializer'...") + terminate_config_initializer() + + +def patch_for_end(): + """Patch for 'NPUCheckpointSaverHook.end'.""" + + NPUCheckpointSaverHook.end = end + logging.debug(f"Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.") + + +def assert_eval_spec(eval_spec: EvalSpec): + """ + Raise error if `eval_spec` is not of the right type. 
+ + Args: + eval_spec: A `TrainSpec` instance to specify the training specification. + + Returns: None + + """ + + logging.debug(f"Enter patch 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.") + if not isinstance(eval_spec, EvalSpec): + raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. Got: {}'.format(type(eval_spec))) + + if 'train_and_evaluate' not in get_bool_gauge_set(): + insert_bool_gauge('train_and_evaluate') + logging.debug("assert_eval_spec: add 'train_and_evaluate' to BoolGaugeCell.") + + +def patch_for_assert_eval_spec(): + """Patch for 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.""" + + tensorflow_estimator_lib.python.estimator.training._assert_eval_spec = assert_eval_spec + logging.debug(f"Function '_assert_eval_spec' in 'tensorflow_estimator.python.estimator.training' has been patched.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 2ab75c96..549df87e 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -15,6 +15,7 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import RankInfoValidator, StringValidator +from mx_rec.util.atomic import AtomicInteger class ConfigInitializer: @@ -43,10 +44,13 @@ class ConfigInitializer: self._training_mode_channel_dict = dict() self._rank_to_device_dict = dict() self._initializer_dict = {} + self._bool_gauge_set = set() self._optimizer_instance = None self._is_graph_modify_hook_running = False self._modify_graph = False self._is_terminated = False + self._is_last_round = False + self._run_times = AtomicInteger() if self._use_mpi: logging.debug(f"Using mpi to launch task.") @@ -59,10 +63,7 @@ class ConfigInitializer: self._rank_id = kwargs.get("rank_id") self._rank_size = kwargs.get("rank_size") - if os.getenv("RANK_TABLE_FILE"): - self.parse_hccl_json() - else: - self.set_hccl_info_without_json() + self.parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else self.set_hccl_info_without_json() self.train_interval = kwargs.get("train_interval", -1) self.eval_steps = kwargs.get("eval_steps", -1) self.check_parameters() @@ -80,6 +81,18 @@ class ConfigInitializer: def __del__(self): self.terminate() + @property + def is_last_round(self): + return self._is_last_round + + @property + def run_times(self): + return self._run_times + + @property + def bool_gauge_set(self): + return self._bool_gauge_set + @property def is_graph_modify_hook_running(self): return self._is_graph_modify_hook_running @@ -313,6 +326,12 @@ class ConfigInitializer: self._name_to_var_dict[name] = key self._table_instance_dict[key] = instance + def insert_bool_gauge(self, name): + if not isinstance(name, str): + raise TypeError(f"bool gauge name '{name}' should be str.") + + self._bool_gauge_set.add(name) + def get_table_instance(self, key): if key not in self._table_instance_dict: raise KeyError(f"Given key does not exist.") @@ -408,6 +427,13 @@ class ConfigInitializer: self._modify_graph = is_modify_graph + @is_last_round.setter + def is_last_round(self, last_round): + if not isinstance(last_round, bool): + raise TypeError(f"last_round should be a boolean.") + + self._is_last_round = last_round + @ascend_global_hashtable_collection.setter def ascend_global_hashtable_collection(self, name): string_validator = StringValidator(name, 
max_len=HASHTABLE_COLLECTION_NAME_LENGTH, min_len=1) @@ -457,6 +483,30 @@ def set_is_graph_modify_hook_running(is_running): ConfigInitializer.get_instance().is_graph_modify_hook_running = is_running +def get_run_times(): + return ConfigInitializer.get_instance().run_times + + +def increase_run_times(): + ConfigInitializer.get_instance().run_times.increase() + + +def get_is_last_round(): + return ConfigInitializer.get_instance().is_last_round + + +def set_is_last_round(last_round): + ConfigInitializer.get_instance().is_last_round = last_round + + +def get_bool_gauge_set(): + return ConfigInitializer.get_instance().bool_gauge_set + + +def insert_bool_gauge(name): + ConfigInitializer.get_instance().insert_bool_gauge(name) + + def get_modify_graph(): return ConfigInitializer.get_instance().modify_graph -- Gitee From a65a48efa637c5f12189345d1069485fc399f5f9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 29 Jun 2023 11:07:28 +0800 Subject: [PATCH 179/551] Match-id-6735609b0174a978a796e6e6898daaa0f88f7a71 --- mx_rec/constants/constants.py | 1 + mx_rec/core/embedding.py | 1 + mx_rec/saver/patch.py | 2 +- mx_rec/saver/saver.py | 127 ++++++++++++++++++++------- mx_rec/util/initialize.py | 14 +++ mx_rec/util/sparse.py | 15 +--- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 57 ++++++++++++ src/core/hybrid_mgmt/hybrid_mgmt.h | 4 + src/core/utils/common.h | 2 + src/pybind/module_main.cpp | 4 +- 10 files changed, 179 insertions(+), 48 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index cd806a3e..cde18ae9 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -63,6 +63,7 @@ class BaseEnum(Enum): class DataName(Enum): + KEY = "key" EMBEDDING = "embedding" FEATURE_MAPPING = "feature_mapping" OFFSET = "offset" diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index fbd72ca2..0893a8ef 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -59,6 +59,7 @@ def create_table(**kwargs): shard_num: embedding partition number fusion_optimizer_var: fusion optimizer variable with embedding hashtable_threshold: choose to implement based on hash table or linear layer + is_save: switch whether to store sparse table data. init_param: embedding init param-coefficient """ diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 4fa7b168..c9853850 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -278,7 +278,7 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N raise ValueError("Checkpoint in %s not an object-based checkpoint." 
% checkpoint_path) from err if var_list is None: - var_list = build_var_list(var_list) + var_list = build_var_list() if builder is None: builder = BulkSaverBuilder() diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d83dca75..a09b5455 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -4,9 +4,7 @@ import json import os -import shutil import logging -import stat from collections import defaultdict import numpy as np @@ -15,8 +13,8 @@ from tensorflow.python.util import compat from mx_rec.constants.constants import DataName, DataAttr from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ - get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, \ - get_ascend_global_hashtable_collection + get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ + send_host_data, get_ascend_global_hashtable_collection from mx_rec.util.perf import performance @@ -34,6 +32,8 @@ class Saver(object): self.save_op_dict = defaultdict(dict) self.restore_fetch_list = [] self.placeholder_dict = defaultdict(dict) + # save_easy_mode : only save the embedding and key data of sparse tables + self.save_easy_mode = os.getenv("SAVE_EASY", 0) self.build() def build(self): @@ -66,6 +66,7 @@ class Saver(object): logging.debug(f"======== Start saving for rank id {self.rank_id} ========") save_path = save_path if save_path else self._prefix_name directory, base_name = os.path.split(save_path) + if global_step: if not isinstance(global_step, compat.integral_types): global_step = int(sess.run(global_step)) @@ -75,13 +76,11 @@ class Saver(object): integrated_path = os.path.join(directory, ckpt_name) saving_path = integrated_path - if integrated_path.startswith("/"): - saving_path = os.path.abspath(integrated_path) - if os.path.exists(saving_path): - shutil.rmtree(saving_path, ignore_errors=True) + if tf.io.gfile.exists(saving_path): + tf.io.gfile.rmtree(saving_path) logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") - os.makedirs(saving_path, exist_ok=True) + tf.io.gfile.makedirs(saving_path) logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") self._save(sess, saving_path) @@ -94,9 +93,8 @@ class Saver(object): directory, base_name = os.path.split(reading_path) ckpt_name = "sparse-%s" % base_name - integrated_path = os.path.join(directory, ckpt_name) - reading_path = os.path.abspath(integrated_path) - if not os.path.exists(reading_path): + reading_path = os.path.join(directory, ckpt_name) + if not tf.io.gfile.exists(reading_path): raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") self._restore(sess, reading_path) @@ -145,11 +143,17 @@ class Saver(object): def _save(self, sess, root_dir): result = sess.run(self.save_op_dict) for table_name, dump_data_dict in result.items(): - save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id) - table_instance = get_table_instance_by_name(table_name) - if is_asc_manager_initialized(): + if is_asc_manager_initialized() and self.save_easy_mode: + host_data = get_host_data(table_name) + key = np.array(list(host_data.keys())) + offset = list(host_data.values()) + get_valid_dict_data(dump_data_dict, offset) + save_key_data(root_dir, table_name, key, self.rank_id) + if is_asc_manager_initialized() and not self.save_easy_mode: save_host_data(root_dir) logging.debug(f"host data was saved.") + 
save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id) + table_instance = get_table_instance_by_name(table_name) if table_instance.use_feature_mapping: save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id) @@ -162,13 +166,12 @@ class Saver(object): def _restore(self, sess, reading_path): restore_feed_dict = defaultdict(dict) - if is_asc_manager_initialized(): - restore_host_data(reading_path) - logging.debug(f"host data was restored.") - + key_offset_dict = defaultdict(dict) for table_name, sub_placeholder_dict in self.placeholder_dict.items(): fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.EMBEDDING.value)) + if self.save_easy_mode: + fill_key_offset_dict(reading_path, self.rank_id, table_name, key_offset_dict) table_instance = get_table_instance_by_name(table_name) if table_instance.use_feature_mapping: @@ -188,6 +191,12 @@ class Saver(object): name_descriptor=NameDescriptor(table_name, state_key, optimizer_name=optimizer_name)) + if is_asc_manager_initialized() and self.save_easy_mode: + send_host_data(key_offset_dict) + logging.debug(f"host data was sent to the host pipeline.") + if is_asc_manager_initialized() and not self.save_easy_mode: + restore_host_data(reading_path) + logging.debug(f"host data was restored.") sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict) @@ -198,6 +207,41 @@ class NameDescriptor: self.optimizer_name = optimizer_name +def get_valid_dict_data(dump_data_dict: dict, offset: list): + """ + Extract embedding and optimizer data from the dict based on offset. + :param dump_data_dict: sparse data dict to be saved + :param offset: offset of the sparse table + """ + embedding_data = dump_data_dict.get(DataName.EMBEDDING.value)[offset, :] + dump_data_dict[DataName.EMBEDDING.value] = embedding_data + if "optimizer" in dump_data_dict: + dump_optimizer_data_dict = dump_data_dict.get("optimizer") + for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items(): + for state_key, state in dump_optimizer_data.items(): + state = state[offset, :] + dump_optimizer_data[state_key] = state + dump_optimizer_data_dict[optimizer_name] = dump_optimizer_data + dump_data_dict["optimizer"] = dump_optimizer_data_dict + + +def fill_key_offset_dict(reading_path: str, rank_id: int, table_name: str, key_offset_dict: dict): + """ + Filling data in the key-offset dictionary , which is sent to the host pipeline. 
+ :param reading_path: the path restoring the model + :param rank_id: rank id + :param table_name: the sparse table name + :param key_offset_dict: key-offset dictionary saving mapping relationship + """ + target_path = generate_path(reading_path, "HashTable", "HBM", table_name, + DataName.KEY.value) + key = read_binary_data(target_path, rank_id, DataName.KEY.value, table_name) + key = key.get(DataName.KEY.value) + offsets = list(range(key.shape[0])) + key_offset_map = dict(zip(key, offsets)) + key_offset_dict[table_name] = key_offset_map + + def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_descriptor): if name_descriptor.optimizer_name: target_path = generate_path(reading_path, "Optimizer", name_descriptor.optimizer_name, "HBM", @@ -222,6 +266,21 @@ def save_embedding_data(root_dir, table_name, dump_data_dict, suffix): write_binary_data(target_path, suffix, data_to_write, attributes=attribute) +def save_key_data(root_dir: str, table_name: str, data_to_write: np.ndarray, suffix: int): + """ + Save the keys of the sparse table + :param root_dir: the root path saving the model + :param table_name: the sparse table name + :param data_to_write: the key array to be written + :param suffix: suffix of sparse data + """ + target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.KEY.value) + attribute = dict() + attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name + attribute[DataAttr.SHAPE.value] = data_to_write.shape + write_binary_data(target_path, suffix, data_to_write, attributes=attribute) + + def save_feature_mapping_data(root_dir, table_name, dump_data_dict, suffix): target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.FEATURE_MAPPING.value) data_to_write = dump_data_dict.get(DataName.FEATURE_MAPPING.value) @@ -265,22 +324,24 @@ def generate_file_name(suffix): def write_binary_data(writing_path, suffix, data, attributes=None): - os.makedirs(writing_path, exist_ok=True) + tf.io.gfile.makedirs(writing_path) data_file, attribute_file = generate_file_name(suffix) target_data_dir = os.path.join(writing_path, data_file) target_attribute_dir = os.path.join(writing_path, attribute_file) - if os.path.exists(target_data_dir): + if tf.io.gfile.exists(target_data_dir): raise FileExistsError(f"Target_data_dir {target_data_dir} exists before writing.") - if os.path.exists(target_attribute_dir): + if tf.io.gfile.exists(target_attribute_dir): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} exists before writing.") - data.tofile(target_data_dir) + + with tf.io.gfile.GFile(target_data_dir, "wb") as file: + data = json.dumps(data.flatten().tolist()) + file.write(data) if attributes is not None: if not isinstance(attributes, dict): raise TypeError(f"Parameter 'attributes' must be one dict instance, instead of {type(attributes)}") - flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL - mode = stat.S_IRUSR | stat.S_IWUSR - with os.fdopen(os.open(target_attribute_dir, flags, mode), 'w') as file: + + with tf.io.gfile.GFile(target_attribute_dir, "w") as file: file.write(json.dumps(attributes)) @@ -296,20 +357,22 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: data_file, attribute_file = generate_file_name(suffix) target_data_dir = os.path.join(reading_path, data_file) target_attribute_dir = os.path.join(reading_path, attribute_file) - if not os.path.exists(target_data_dir): + if not tf.io.gfile.exists(target_data_dir): raise FileExistsError(f"Target_data_dir {target_data_dir} does not 
exist when reading.") - if not os.path.exists(target_attribute_dir): + if not tf.io.gfile.exists(target_attribute_dir): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} does not exist when reading.") - with open(target_attribute_dir, "r") as fin: + with tf.io.gfile.GFile(target_attribute_dir, "r") as fin: attributes = json.load(fin) if DataAttr.DATATYPE.value not in attributes: raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.") - data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) + with tf.io.gfile.GFile(target_data_dir, "rb") as file: + data_to_restore = file.read() + data_to_restore = np.array(json.loads(data_to_restore)) - if DataAttr.SHAPE.value in attributes: + if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) data_to_restore = data_to_restore.reshape(data_shape) table_instance = get_table_instance_by_name(table_name) @@ -318,8 +381,6 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: data_to_restore = process_embedding_data(data_to_restore, current_data_shape, data_shape) data_dict = {data_name: data_to_restore} - for key, item in attributes.items(): - data_dict[key] = item logging.debug(f"Attribute: '{target_attribute_dir}' and data file: '{target_data_dir}' have been read.") logging.debug(f"Reading shape is {data_to_restore.shape}.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index cf07b256..4f50305c 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -507,6 +507,20 @@ def is_asc_manager_initialized(): return ConfigInitializer.get_instance().get_asc_manager() is not None +def get_host_data(table_name): + if not is_asc_manager_initialized(): + raise RuntimeError("ASC manager does not exist.") + logging.debug("start to get host data.") + return ConfigInitializer.get_instance().get_asc_manager().send(table_name) + + +def send_host_data(key_offset_map): + if not is_asc_manager_initialized(): + raise RuntimeError("ASC manager does not exist.") + ConfigInitializer.get_instance().get_asc_manager().receive(key_offset_map) + logging.debug("Data has been send to the host pipeline.") + + def save_host_data(root_dir): if not is_asc_manager_initialized(): raise RuntimeError("ASC manager does not exist.") diff --git a/mx_rec/util/sparse.py b/mx_rec/util/sparse.py index 2874d4e0..9ba8392b 100644 --- a/mx_rec/util/sparse.py +++ b/mx_rec/util/sparse.py @@ -9,8 +9,7 @@ import json import tensorflow as tf import numpy as np -from mx_rec.util.initialize import get_ascend_global_hashtable_collection, get_table_instance, \ - get_table_instance_by_name +from mx_rec.util.initialize import get_table_instance, get_table_instance_by_name, export_table_name_set class SparseProcessor: @@ -33,7 +32,7 @@ class SparseProcessor: if not os.path.exists(model_dir): raise FileExistsError(f"the model_dir supported {model_dir} does not exist.") self.table_list = kwargs.get("table_list") - self.default_table_list = get_table_list() + self.default_table_list = list(export_table_name_set()) if not self.table_list: logging.debug("table list not be set, use default value : all table created ") @@ -173,16 +172,6 @@ def export(model_dir, **kwargs): return empty_value -def get_table_list(): - var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) - table_list = [] - for var in var_list: - table_instance = get_table_instance(var) - table_name = table_instance.table_name - 
table_list.append(table_name) - return table_list - - def check_table_param(table_list, default_table_list): out_list = [] for table in table_list: diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 00fb3127..e159c5d8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -11,6 +11,7 @@ #include "checkpoint/checkpoint.h" #include "utils/time_cost.h" +#include "utils/common.h" using namespace MxRec; using namespace std; @@ -205,6 +206,62 @@ bool HybridMgmt::Load(const string& loadPath) return true; } +key_offset_map_t HybridMgmt::SendHostMap(const string tableName) +{ +#ifndef GTEST + preprocess->LoadSaveLock(); + key_offset_mem_t keyOffsetMap; + key_offset_map_t sendKeyOffsetMap; + + if (!mgmtRankInfo.noDDR) { + spdlog::debug(MGMT + "Start send sparse data: ddr mode hashmap"); + } else { + spdlog::debug(MGMT + "Start send sparse data: no ddr mode hashmap"); + keyOffsetMap = preprocess->GetKeyOffsetMap(); + } + + if ((!keyOffsetMap.empty()) && keyOffsetMap.count(tableName)) { + for (const auto& it : keyOffsetMap.at(tableName)) { + sendKeyOffsetMap[it.first] = it.second; + } + } + + preprocess->LoadSaveUnlock(); + return sendKeyOffsetMap; +#endif +} + +void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) +{ +#ifndef GTEST + preprocess->LoadSaveLock(); + key_offset_mem_t loadKeyOffsetMap; + offset_mem_t loadMaxOffset; + if (!ReceiveKeyOffsetMap.empty()) { + for (const auto& KeyOffsetMap : ReceiveKeyOffsetMap) { + auto& SingleHashMap = loadKeyOffsetMap[KeyOffsetMap.first]; + auto& MaxOffset = loadMaxOffset[KeyOffsetMap.first]; + for (const auto& it : KeyOffsetMap.second) { + SingleHashMap[it.first] = it.second; + } + MaxOffset = KeyOffsetMap.second.size(); + } + } + if (!mgmtRankInfo.noDDR) { + spdlog::debug(MGMT + "Start receive sparse data: ddr mode hashmap"); + } else { + spdlog::debug(MGMT + "Start receive sparse data: no ddr mode hashmap"); + preprocess->LoadKeyOffsetMap(loadKeyOffsetMap); + preprocess->LoadMaxOffset(loadMaxOffset); + } + + preprocess->LoadSaveUnlock(); + if (!mgmtRankInfo.useDataset && isLoad) { + Start(); + } +#endif +} + bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) { bool loadDataMatches { true }; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index af9016ec..d734f1dd 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -57,6 +57,10 @@ namespace MxRec { bool Load(const string& loadPath); + key_offset_map_t SendHostMap(const string tableName); + + void ReceiveHostMap(all_key_offset_map_t keyOffsetMap); + void Start(); void InsertThreadForHBM(int mode); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 85cde647..c9474df1 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -441,6 +441,8 @@ struct BatchTask { using key_offset_mem_t = std::map>; using tensor_2_thresh_mem_t = absl::flat_hash_map; using trans_serialize_t = uint8_t; + using key_offset_map_t = std::map; + using all_key_offset_map_t = std::map>; enum class CkptFeatureType { HOST_EMB = 0, diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index f15bae97..b70cb47b 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -161,7 +161,9 @@ void GetHybridMgmt(pybind11::module_& m) .def("save", &MxRec::HybridMgmt::Save, py::arg("save_path") = "") .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "") .def("destroy", 
&MxRec::HybridMgmt::Destroy) - .def("evict", &MxRec::HybridMgmt::Evict); + .def("evict", &MxRec::HybridMgmt::Evict) + .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") + .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")); } void GetThresholdValue(pybind11::module_& m) -- Gitee From 4b9d9f812e008a5c92558d7ad5c0567f4c4350df Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 5 Jul 2023 11:45:29 +0800 Subject: [PATCH 180/551] Match-id-d01c79e2cc03257d0fae3fb8665a316912f2bb73 --- mx_rec/__init__.py | 2 +- mx_rec/graph/patch.py | 4 ++-- mx_rec/util/tf_version_adapter.py | 5 +++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index c516b6ab..3b1de2b6 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -3,7 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION -from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops +from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \ patch_for_end, patch_for_assert_eval_spec diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 26a044be..3c3bca3c 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -14,10 +14,10 @@ from tensorflow.python.data.ops.dataset_ops import _VariantTracker from tensorflow.python.framework import ops from tensorflow_estimator.python.estimator.training import EvalSpec from tensorflow.python.eager.monitoring import BoolGauge, BoolGaugeCell -from npu_bridge.estimator.npu.npu_hook import NPUCheckpointSaverHook from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \ get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round +from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook def init_dataset(self, input_data): @@ -101,7 +101,7 @@ def patch_for_bool_gauge(): logging.debug(f"Function 'get_cell' in Class 'BoolGauge' has been patched.") -def end(self: NPUCheckpointSaverHook, session: tf.Session): +def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session): """ Call at the end of session hook. 
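
The end function shown above is installed over NPUCheckpointSaverHook.end at import time; patch_for_end, patch_for_saver, patch_for_bool_gauge and the other patch_for_* helpers imported in mx_rec/__init__.py all rely on the same monkey-patching technique. A minimal, runnable Python sketch of that technique follows; the Hook class and the extra teardown step are illustrative stand-ins, not mx_rec code:

class Hook:
    """Stand-in for a framework hook such as NPUCheckpointSaverHook."""

    def end(self, session):
        print("framework teardown")


_original_end = Hook.end  # keep a reference so the patch can still delegate


def patched_end(self, session):
    _original_end(self, session)  # preserve the framework behaviour first
    print("extra teardown")       # e.g. flush sparse-table state before exit


Hook.end = patched_end  # install once, at import time
Hook().end(session=None)  # prints "framework teardown", then "extra teardown"

Patching the class attribute rather than subclassing means every instance the framework constructs picks up the new behaviour, including instances that mx_rec never creates itself.
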
diff --git a/mx_rec/util/tf_version_adapter.py b/mx_rec/util/tf_version_adapter.py index d071c5c4..7d13f96e 100644 --- a/mx_rec/util/tf_version_adapter.py +++ b/mx_rec/util/tf_version_adapter.py @@ -13,3 +13,8 @@ if tf.__version__.startswith("1"): from npu_bridge.estimator import npu_ops else: from npu_device.compat.v1.estimator import npu_ops + +if tf.__version__.startswith("1"): + from npu_bridge.estimator.npu.npu_hook import NPUCheckpointSaverHook +else: + from npu_device.compat.v1.estimator.npu.npu_hook import NPUCheckpointSaverHook -- Gitee From 15c6019855b02f9ed5e3130360181c0774d2e980 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 5 Jul 2023 15:20:22 +0800 Subject: [PATCH 181/551] Match-id-fb92a97714c61ac65dc7c0dd52e0947afeb978a2 --- src/core/key_process/key_process.cpp | 1 - src/core/ock_ctr_common/include/unique.h | 1 - src/platform/AccCTR | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3cf54b9a..927f2fef 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -216,7 +216,6 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) uniqueConf.outputType = OutputType::ENHANCED; uniqueConf.minThreadNum = MIN_UNIQUE_THREAD_NUM; uniqueConf.maxThreadNum = PerfConfig::maxUniqueThreadNum; - uniqueConf.performance = true; } void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, diff --git a/src/core/ock_ctr_common/include/unique.h b/src/core/ock_ctr_common/include/unique.h index 9f01ec96..59ed98b5 100644 --- a/src/core/ock_ctr_common/include/unique.h +++ b/src/core/ock_ctr_common/include/unique.h @@ -54,7 +54,6 @@ using UniqueConf = struct UniqueConfCTR { uint32_t maxThreadNum = 8; // 最大工作线程数 int64_t maxIdVal = 0; // 最大id值 bool trace = false; // 是否开启性能检测,需要配合外部日志输出 - bool performance = false; // 是否开启增强接口,增强接口shardingNum必须是2的幂次方,默认用取模分桶 } __attribute__((packed)); using UniqueIn = struct UniqueInCTR { diff --git a/src/platform/AccCTR b/src/platform/AccCTR index bc9dc810..62ab674f 160000 --- a/src/platform/AccCTR +++ b/src/platform/AccCTR @@ -1 +1 @@ -Subproject commit bc9dc8103109eb8b77c09e7c4028a992a66d5015 +Subproject commit 62ab674f0a42d8de8398eafb4799e506fe99549d -- Gitee From b6ff9221f0cb5855eb5059f33c7bc76e6b774d1d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 5 Jul 2023 15:44:29 +0800 Subject: [PATCH 182/551] Match-id-f97d88955bdae38c30736556a1a83b39b33ecdee --- mx_rec/saver/saver.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index a09b5455..cd36b8dc 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -35,6 +35,9 @@ class Saver(object): # save_easy_mode : only save the embedding and key data of sparse tables self.save_easy_mode = os.getenv("SAVE_EASY", 0) self.build() + # since tf 2.6.0, tf needs tensorflow_io to support hdfs path + if tf.__version__.startswith("2"): + import tensorflow_io as tfio def build(self): if self.var_list is None: @@ -333,9 +336,14 @@ def write_binary_data(writing_path, suffix, data, attributes=None): if tf.io.gfile.exists(target_attribute_dir): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} exists before writing.") - with tf.io.gfile.GFile(target_data_dir, "wb") as file: - data = json.dumps(data.flatten().tolist()) - file.write(data) + if target_data_dir.find("://") != -1: + logging.debug(f"use hdfs path {target_data_dir} to save 
sparse data.") + with tf.io.gfile.GFile(target_data_dir, "w") as file: + data = json.dumps(data.flatten().tolist()) + file.write(data) + else: + logging.debug(f"use local file path {target_data_dir} to save sparse data.") + data.tofile(target_data_dir) if attributes is not None: if not isinstance(attributes, dict): @@ -368,9 +376,14 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: if DataAttr.DATATYPE.value not in attributes: raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.") - with tf.io.gfile.GFile(target_data_dir, "rb") as file: - data_to_restore = file.read() - data_to_restore = np.array(json.loads(data_to_restore)) + if target_data_dir.find("://") != -1: + logging.debug(f"use hdfs path {target_data_dir} to restore sparse data.") + with tf.io.gfile.GFile(target_data_dir, "r") as file: + data_to_restore = file.read() + data_to_restore = np.array(json.loads(data_to_restore)) + else: + logging.debug(f"use local file path {target_data_dir} to restore sparse data.") + data_to_restore = np.fromfile(target_data_dir) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) -- Gitee From 374f72121f2ff3f97cb41f4d58ab69c3b463135c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 6 Jul 2023 16:32:44 +0800 Subject: [PATCH 183/551] Match-id-330f51dfbd681bda17ececc90a41ddbc893a33cb --- mx_rec/saver/saver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index cd36b8dc..de260160 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -383,7 +383,7 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: data_to_restore = np.array(json.loads(data_to_restore)) else: logging.debug(f"use local file path {target_data_dir} to restore sparse data.") - data_to_restore = np.fromfile(target_data_dir) + data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) -- Gitee From b792479a0812e83d65b53bd247e14bd461a8273b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 6 Jul 2023 19:21:53 +0800 Subject: [PATCH 184/551] Match-id-0e4e79083bcc2ce337b704e3b40c41bc526a6d96 --- .../cust_op_by_addr/op_host/embedding_lookup_by_address.cpp | 6 ++++++ .../cust_op_by_addr/op_host/embedding_update_by_address.cpp | 5 +++++ .../op_kernel/embedding_lookup_by_address.cpp | 3 ++- .../op_kernel/embedding_update_by_address.cpp | 3 ++- mx_rec/core/embedding.py | 2 +- 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 6e3ea142..4c935417 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -12,6 +12,12 @@ namespace optiling static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData1 tiling; + + size_t usrSize = 256; + size_t sysWorkspaceSize = 16 * 1024 * 1024; + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = sysWorkspaceSize + usrSize; + int32_t block_total_nums = 48; int32_t ub_limit = 160 * 1024; auto *attrs = context->GetAttrs(); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp 
b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index d27b2ac2..430b167f 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -13,6 +13,11 @@ namespace optiling { TilingData2 tiling; + size_t usrSize = 256; + size_t sysWorkspaceSize = 16 * 1024 * 1024; + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = sysWorkspaceSize + usrSize; + int32_t block_total_nums = 48; int32_t ub_limit = 160 * 1024; diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 58fed7bb..c2edb18f 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -190,7 +190,8 @@ private: GlobalTensor srcAddrGlobal; }; -extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR address, GM_ADDR y, GM_ADDR tiling) +extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR address, GM_ADDR y, GM_ADDR usrWorkspace, + GM_ADDR tiling) { GET_TILING_DATA(constData, tiling); // // TODO: user kernel impl diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index b5a1e976..04f1da19 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -202,7 +202,8 @@ private: GlobalTensor srcAddrGlobal; }; -extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR address, GM_ADDR embedding, GM_ADDR y, GM_ADDR tiling) +extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR address, GM_ADDR embedding, GM_ADDR y, + GM_ADDR usrWorkspace, GM_ADDR tiling) { GET_TILING_DATA(constData, tiling); diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 0225b653..1d674441 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -902,7 +902,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] initialized_tensor = instance.emb_initializer( - tf.shape(evict_pos)[0] + instance.embedding_size) + tf.shape(evict_pos)[0] + instance.embedding_size) * instance.init_param logging.debug(f'evict_pos output shape {evict_pos}, and slice_device_vocabulary_size ' f'{instance.slice_device_vocabulary_size}, ' -- Gitee From a3bea4bd044afe47224f238a055cedf7ade4d903 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 6 Jul 2023 19:23:09 +0800 Subject: [PATCH 185/551] Match-id-99ebbba82dd309189c62c1339d5bd6be43e5690b --- example/little_demo/main.py | 6 ++++-- mx_rec/core/embedding.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 2ee01a49..71d26584 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -166,12 +166,14 @@ if __name__ == "__main__": optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(1)] sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] + # 如需验证DDR模式,请按照key数量、batch unique数量合理设置device与host表大小。 + # 验证DDR的配置参考:数据集key总量大于device表,小于device+host;一个batch的unique key数量小于device表。 user_hashtable = create_table(key_dtype=tf.int64, dim=tf.TensorShape([cfg.user_hashtable_dim]), name='user_table', 
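                                  # (Gloss of the Chinese note above: to exercise DDR mode, size the
                                  # tables so the dataset's total key count exceeds the device table
                                  # but fits within device + host, while one batch's unique keys still
                                  # fit in the device table.)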
emb_initializer=tf.compat.v1.truncated_normal_initializer(), device_vocabulary_size=cfg.user_vocab_size * 10, - host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + host_vocabulary_size=0, optimizer_list=sparse_optimizer_list, mode=mode) @@ -180,7 +182,7 @@ if __name__ == "__main__": name='item_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), device_vocabulary_size=cfg.item_vocab_size * 10, - host_vocabulary_size=0, # cfg.user_vocab_size * 100, # for h2d test + host_vocabulary_size=0, optimizer_list=sparse_optimizer_list, mode=mode) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 1d674441..60ec218b 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -964,7 +964,7 @@ def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Opti if access_threshold is None or access_threshold <= 0: return embeddings id_offsets_expand = tf.expand_dims(id_offsets >= 0, axis=-1) - if get_use_static(): - id_offsets_expand = tf.compat.v1.broadcast_to(id_offsets_expand, embeddings.shape.as_list()) + if tf.__version__.startswith("1"): + id_offsets_expand = tf.repeat(id_offsets_expand, [tf.shape(embeddings)[-1]], axis=-1) embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings)) return embeddings -- Gitee From 1474c610c85e9d49f84fb3bb7a99735f3d6485a0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 25 Jun 2023 15:36:36 +0800 Subject: [PATCH 186/551] Match-id-825ee41874993a04d1013e28af3644b86faa420f --- example/little_demo/main.py | 106 +++++--------------- example/little_demo/optimizer.py | 2 +- example/little_demo/run.sh | 7 +- example/little_demo/run_mode.py | 141 +++++++++++++++++++++++++++ mx_rec/core/asc/manager.py | 14 +-- mx_rec/util/initialize.py | 39 ++++---- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 14 ++- 7 files changed, 209 insertions(+), 114 deletions(-) create mode 100644 example/little_demo/run_mode.py diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 2ee01a49..c1c09b63 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -1,23 +1,26 @@ -# coding: UTF-8 +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
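
The set_zero_for_non_valid_key hunk in the previous patch switches the TF1 path from broadcast_to to tf.repeat when expanding the validity mask. Below is a NumPy sketch of the masking it performs, so the behaviour can be checked offline; the helper name is illustrative and NumPy stands in for the TF ops:

import numpy as np


def mask_invalid_embeddings(id_offsets, embeddings):
    """Zero every embedding row whose id_offset is negative (key not admitted)."""
    valid = np.expand_dims(id_offsets >= 0, axis=-1)         # shape (N, 1), like tf.expand_dims
    valid = np.repeat(valid, embeddings.shape[-1], axis=-1)  # shape (N, dim), like tf.repeat
    return np.where(valid, embeddings, np.zeros_like(embeddings))


embeddings = np.ones((4, 3), dtype=np.float32)
id_offsets = np.array([7, -1, 3, -1])
print(mask_invalid_embeddings(id_offsets, embeddings))  # rows 1 and 3 come back all zeros

Keys rejected by the access threshold therefore contribute exactly zero to whatever is computed downstream instead of reading stale table rows.
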
+ import logging import os import warnings from glob import glob -import tensorflow as tf -from config import sess_config, Config +import tensorflow as tf +from config import Config from dataset import generate_dataset -from optimizer import get_dense_and_sparse_optimizer +from optimizer import create_dense_and_sparse_optimizer from model import MyModel -from mx_rec.util.tf_version_adapter import hccl_ops +from run_mode import RunMode, UseMode + from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP -from mx_rec.util.initialize import get_rank_id, get_rank_size, init, clear_channel, terminate_config_initializer, \ - set_if_load, get_initializer +from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load from mx_rec.util.variable import get_dense_and_sparse_variable tf.compat.v1.disable_eager_execution() @@ -84,21 +87,6 @@ def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=N return iterator, model -def evaluate(): - if MODIFY_GRAPH_FLAG: - sess.run(get_initializer(False)) - else: - sess.run(eval_iterator.initializer) - clear_channel(is_train_channel=False) - for j in range(1, EVAL_STEPS + 1): - logging.info(f"################ eval at step {j} epoch {EPOCH} ################") - try: - sess.run(eval_model.loss_list) - except tf.errors.OutOfRangeError: - logging.info(f"Encounter the end of Sequence for eval.") - break - - def create_feature_spec_list(use_timestamp=False): access_threshold = cfg.access_threshold if use_timestamp else None eviction_threshold = cfg.eviction_threshold if use_timestamp else None @@ -124,8 +112,10 @@ if __name__ == "__main__": tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) warnings.filterwarnings("ignore") + use_mode = UseMode.mapping(os.getenv("USE_MODE")) mode = MxRecMode.mapping(os.getenv("MXREC_MODE")) - TRAIN_INTERVAL = 100 + TRAIN_STEPS = 100 + EVAL_INTERVAL = 100 EVAL_STEPS = 10 SAVING_INTERVAL = 100 @@ -140,7 +130,7 @@ if __name__ == "__main__": # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0 init(use_mpi=use_mpi, - train_interval=TRAIN_INTERVAL, + train_steps=TRAIN_STEPS, eval_steps=EVAL_STEPS, prefetch_batch_number=5, use_dynamic=use_dynamic, @@ -163,7 +153,7 @@ if __name__ == "__main__": train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) - optimizer_list = [get_dense_and_sparse_optimizer(cfg) for _ in range(1)] + optimizer_list = [create_dense_and_sparse_optimizer(cfg)] sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] user_hashtable = create_table(key_dtype=tf.int64, @@ -192,71 +182,21 @@ if __name__ == "__main__": config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number) dense_variables, sparse_variables = get_dense_and_sparse_variable() - rank_size = get_rank_size() - train_ops = [] - # multi task training - for loss, (dense_optimizer, sparse_optimizer) in zip(train_model.loss_list, optimizer_list): - # do dense optimization - grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables) - avg_grads = [] - for grad, var in grads: - if rank_size > 1: - grad = 
hccl_ops.allreduce(grad, "sum") if grad is not None else None - if grad is not None: - avg_grads.append((grad, var)) - # apply gradients: update variables - train_ops.append(dense_optimizer.apply_gradients(avg_grads)) - - if use_dynamic_expansion: - from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET - - train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) - train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) - # do sparse optimization by addr - local_grads = tf.gradients(loss, train_emb_list) # local_embedding - grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)] - train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) - else: - # do sparse optimization - sparse_grads = tf.gradients(loss, sparse_variables) - grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] - train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + run_mode = RunMode( + MODIFY_GRAPH_FLAG, optimizer_list, train_model, eval_model, train_iterator, eval_iterator, + TRAIN_STEPS, EVAL_STEPS + ) - saver = tf.compat.v1.train.Saver() if MODIFY_GRAPH_FLAG: logging.info("start to modifying graph") modify_graph_and_start_emb_cache(dump_graph=True) else: start_asc_pipeline() - with tf.compat.v1.Session(config=sess_config(dump_data=False)) as sess: - if MODIFY_GRAPH_FLAG: - sess.run(get_initializer(True)) - else: - sess.run(train_iterator.initializer) - sess.run(tf.compat.v1.global_variables_initializer()) - EPOCH = 0 - if os.path.exists(f"./saved-model/sparse-model-{rank_id}-%d" % 0): - saver.restore(sess, f"./saved-model/model-{rank_id}-%d" % 0) - else: - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=0) - - for i in range(1, 201): - logging.info(f"################ training at step {i} ################") - try: - sess.run([train_ops, train_model.loss_list]) - except tf.errors.OutOfRangeError: - logging.info(f"Encounter the end of Sequence for training.") - break - else: - if i % TRAIN_INTERVAL == 0: - EPOCH += 1 - evaluate() - - if i % SAVING_INTERVAL == 0: - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) - - saver.save(sess, f"./saved-model/model-{rank_id}", global_step=i) + if use_mode == UseMode.TRAIN: + run_mode.train(EVAL_INTERVAL, SAVING_INTERVAL) + elif use_mode == UseMode.PREDICT: + run_mode.predict() terminate_config_initializer() logging.info("Demo done!") diff --git a/example/little_demo/optimizer.py b/example/little_demo/optimizer.py index 7a2e2a4b..43294764 100644 --- a/example/little_demo/optimizer.py +++ b/example/little_demo/optimizer.py @@ -9,7 +9,7 @@ from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address from mx_rec.util.initialize import get_use_dynamic_expansion -def get_dense_and_sparse_optimizer(cfg): +def create_dense_and_sparse_optimizer(cfg): dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate) use_dynamic_expansion = get_use_dynamic_expansion() if use_dynamic_expansion: diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 2d460786..81692814 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -1,6 +1,5 @@ kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1 rm -rf /root/ascend/log/* -rm -rf ./saved-model/* rm -rf ./kernel* rm -rf ./export_graph/* @@ -26,6 +25,12 @@ export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL export 
ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL export MXREC_MODE="ASC" export USE_MPI=1 +export USE_MODE="train" # 支持[train, predict] + +if [ $USE_MODE = "train" ];then + echo "train mode: saved-model will be deleted" + rm -rf ./saved-model +fi ################# 参数配置 ###################### export USE_DYNAMIC=0 # 0:静态shape;1:动态shape diff --git a/example/little_demo/run_mode.py b/example/little_demo/run_mode.py new file mode 100644 index 00000000..bd068b5c --- /dev/null +++ b/example/little_demo/run_mode.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +import os +import logging + +import tensorflow as tf +from config import sess_config + +from mx_rec.util.initialize import get_initializer, get_rank_id, get_rank_size, clear_channel +from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.util.tf_version_adapter import hccl_ops +from mx_rec.constants.constants import BaseEnum + + +class UseMode(BaseEnum): + TRAIN = "train" + PREDICT = "predict" + + +class RunMode: + + def __init__( + self, is_modify_graph: bool, optimizer_list: list, train_model, eval_model, train_iterator, eval_iterator, + train_steps: int, infer_steps: int): + self.is_modify_graph = is_modify_graph + self.session = tf.compat.v1.Session(config=sess_config(dump_data=False)) + self.train_model = train_model + self.train_iterator = train_iterator + self.eval_model = eval_model + self.eval_iterator = eval_iterator + self.saver = tf.compat.v1.train.Saver() + self.rank_id = get_rank_id() + self.train_ops = [] + self.optimizer_list = optimizer_list + self.epoch = 1 + self.train_steps = train_steps + self.infer_steps = infer_steps + + def _infer(self): + if self.is_modify_graph: + self.session.run(get_initializer(False)) + else: + self.session.run(self.eval_iterator.initializer) + clear_channel(is_train_channel=False) + for i in range(1, self.infer_steps + 1): + logging.info("############### infer at step %d ################", i) + try: + self.session.run(self.eval_model.loss_list) + except tf.errors.OutOfRangeError: + logging.info(f"Encounter the end of Sequence for eval.") + break + + def set_train_ops(self): + dense_variables, sparse_variables = get_dense_and_sparse_variable() + + # multi task training + for loss, (dense_optimizer, sparse_optimizer) in zip(self.train_model.loss_list, self.optimizer_list): + # do dense optimization + grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables) + avg_grads = [] + for grad, var in grads: + if get_rank_size() > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad, var)) + # apply gradients: update variables + self.train_ops.append(dense_optimizer.apply_gradients(avg_grads)) + + if bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))): + from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET + + train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + # do sparse optimization by addr + local_grads = tf.gradients(loss, train_emb_list) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)] + self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars)) + else: + # do sparse optimization + sparse_grads = tf.gradients(loss, sparse_variables) + 
grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
+                self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+
+    def train(self, eval_interval: int, saving_interval: int):
+        self.set_train_ops()
+        if self.is_modify_graph:
+            self.session.run(get_initializer(True))
+        else:
+            self.session.run(self.train_iterator.initializer)
+        self.session.run(tf.compat.v1.global_variables_initializer())
+
+        for i in range(1, self.train_steps + 1):
+            logging.info("################ training at step %d ################", i)
+            try:
+                self.session.run([self.train_ops, self.train_model.loss_list])
+            except tf.errors.OutOfRangeError:
+                logging.info("Encountered the end of sequence for training.")
+                break
+            else:
+                if i % eval_interval == 0:
+                    self.evaluate()
+
+                if i % saving_interval == 0:
+                    self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
+
+        self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
+        logging.info("################ training end ################")
+
+    def evaluate(self):
+        logging.info("############### start evaluate, epoch:%d ################", self.epoch)
+        self._infer()
+        logging.info("############### evaluate end, epoch:%d ################", self.epoch)
+        self.epoch += 1
+
+    def predict(self):
+        logging.info("############### start predict ################")
+        import glob
+        import re
+
+        model_file = glob.glob(f"./saved-model/sparse-model-{self.rank_id}-*")
+        if len(model_file) == 0:
+            raise ValueError("model file does not exist")
+
+        # get the latest model
+        pattern = f".*sparse-model-{self.rank_id}-([0-9]+).*"
+        latest_step = -1
+        for file_path in model_file:
+            match = re.match(pattern, file_path)
+            if match and match.groups():
+                step = int(match.groups()[0])
+
+                if step > latest_step:
+                    latest_step = step
+        if latest_step == -1:
+            raise RuntimeError("latest model not found")
+
+        self.saver.restore(self.session, f"./saved-model/model-{self.rank_id}-{latest_step}")
+        self._infer()
+        logging.info("############### predict end ################")
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 6d6f3a79..dfc8eb7a 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -9,7 +9,7 @@ from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitiali
 
 from mx_rec.constants.constants import MxRecMode
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
-    is_asc_manager_initialized, get_train_interval, get_eval_steps, get_prefetch_batch_number, \
+    is_asc_manager_initialized, get_train_steps, get_eval_steps, get_prefetch_batch_number, \
     export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \
     get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table
 from mx_rec.core.asc.helper import find_dangling_table, should_skip
@@ -183,7 +183,7 @@ def initialize_emb_cache(table_info_list, threshold_list):
     rank_id = get_rank_id()
     device_id = get_device_id()
     rank_size = get_rank_size()
-    evaluate_stride = get_train_interval()
+    train_steps = get_train_steps()
     eval_steps = get_eval_steps()
     n_batch_to_prefetch = get_prefetch_batch_number()
     if_load = get_if_load()
@@ -195,12 +195,8 @@ def initialize_emb_cache(table_info_list, threshold_list):
     if get_use_dynamic_expansion():
         option = option | USE_DYNAMIC_EXPANSION
 
-    if get_training_mode_channel_id(is_training=False) == 0:
-        rank_info = RankInfo(rank_id,
device_id, rank_size, option, n_batch_to_prefetch, - [eval_steps, evaluate_stride]) - else: - rank_info = RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, - [evaluate_stride, eval_steps]) + # [train_steps, eval_steps] pass step information to HybridMgmt for data process loop + rank_info = RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, [train_steps, eval_steps]) emb_cache = HybridMgmt() if threshold_list: @@ -213,7 +209,7 @@ def initialize_emb_cache(table_info_list, threshold_list): logging.info("Preprocessing has been sunk into the host pipeline.") logging.debug(f"Flag if load is {if_load}.") logging.debug(f"n_batch_to_prefetch is {n_batch_to_prefetch}.") - logging.debug(f"evaluate_stride is {evaluate_stride}.") + logging.debug(f"train_steps is {train_steps}.") logging.debug(f"eval_steps is {eval_steps}.") logging.debug(f"threshold_values are {threshold_list}.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 80d87448..40ba638a 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -12,7 +12,8 @@ import psutil import mx_rec.constants.constants from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \ - MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH + MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH,\ + TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import RankInfoValidator, StringValidator from mx_rec.util.atomic import AtomicInteger @@ -30,7 +31,7 @@ class ConfigInitializer: self._asc_manager = None self._mpi = None self._is_frozen = False - self._train_interval = None + self._train_steps = None self._eval_steps = None self._prefetch_batch_number = None self._if_load = None @@ -64,7 +65,7 @@ class ConfigInitializer: self._rank_size = kwargs.get("rank_size") self.parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else self.set_hccl_info_without_json() - self.train_interval = kwargs.get("train_interval", -1) + self.train_steps = kwargs.get("train_steps", -1) self.eval_steps = kwargs.get("eval_steps", -1) self.check_parameters() self.prefetch_batch_number = kwargs.get("prefetch_batch_number", 1) @@ -148,8 +149,8 @@ class ConfigInitializer: return self._rank_to_device_dict[self._rank_id] @property - def train_interval(self): - return self._train_interval + def train_steps(self): + return self._train_steps @property def eval_steps(self): @@ -298,8 +299,10 @@ class ConfigInitializer: def insert_training_mode_channel_id(self, is_training): if is_training not in self._training_mode_channel_dict: - # mx_rec has 2 channel for data input. it would bind channel_id to training mode recorded in dict. - self._training_mode_channel_dict[is_training] = len(self._training_mode_channel_dict) + # mx_rec has 2 channel for data input. 
+ # train_model bind to channel TRAIN_CHANNEL_ID + # eval_model bind to channel EVAL_CHANNEL_ID + self._training_mode_channel_dict[is_training] = TRAIN_CHANNEL_ID if is_training else EVAL_CHANNEL_ID def get_training_mode_channel_id(self, is_training): return self._training_mode_channel_dict.get(is_training) @@ -364,8 +367,8 @@ class ConfigInitializer: if self.rank_id >= self.rank_size: raise ValueError(f"Rank_id must be within the range from 0 to rank_size.") - if self._train_interval == 0 and self._eval_steps == 0: - raise ValueError(f"Train interval and eval steps could not both equal 0.") + if self._train_steps == 0 and self._eval_steps == 0: + raise ValueError(f"Train steps and eval steps could not both equal 0.") def freeze(self): self._is_frozen = True @@ -391,10 +394,10 @@ class ConfigInitializer: self.unfreeze() logging.debug("ASC manager has been destroyed.") - @train_interval.setter - def train_interval(self, interval): - check_step(interval) - self._train_interval = interval + @train_steps.setter + def train_steps(self, step: int): + check_step(step) + self._train_steps = step @eval_steps.setter def eval_steps(self, steps): @@ -617,19 +620,19 @@ def get_customized_ops(): return ConfigInitializer.customized_ops -def get_train_interval(): - return ConfigInitializer.get_instance().train_interval +def get_train_steps(): + return ConfigInitializer.get_instance().train_steps def get_eval_steps(): return ConfigInitializer.get_instance().eval_steps -def set_train_interval(interval): - ConfigInitializer.get_instance().train_interval = interval +def set_train_steps(steps: int): + ConfigInitializer.get_instance().train_steps = steps -def set_eval_steps(steps): +def set_eval_steps(steps: int): ConfigInitializer.get_instance().eval_steps = steps diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e159c5d8..50428eed 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -393,25 +393,35 @@ void HybridMgmt::InsertThreadForHBM(int mode) #ifndef GTEST void HybridMgmt::TaskForTrain(TaskType type) { + bool isFirstIn = true; while (isRunning) { - spdlog::info(MGMT + "Start Train Task: {}", type); + if (isFirstIn) { + spdlog::info(MGMT + "Start Train Task: {}", type); + isFirstIn = false; + } if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) { if (!TrainTask(type)) { return; } } + this_thread::sleep_for(1ms); } } void HybridMgmt::TaskForEval(TaskType type) { + bool isFirstIn = true; while (isRunning) { - spdlog::info(MGMT + "Start Eval Task: {}", type); + if (isFirstIn) { + spdlog::info(MGMT + "Start Eval Task: {}", type); + isFirstIn = false; + } if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) { if (!EvalTask(type)) { return; } } + this_thread::sleep_for(1ms); } } -- Gitee From ed099c02931ba2edc9657edb495c838c798d3ac3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 10 Jul 2023 10:41:29 +0800 Subject: [PATCH 187/551] Match-id-88d0e6d1b92861ccbbfc10536a210198d6799bcd --- mx_rec/saver/patch.py | 8 ++++++++ mx_rec/saver/saver.py | 3 --- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index c9853850..1d7cbc10 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -171,6 +171,10 @@ def build(self): def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, 
strip_default_attrs=False, save_debug_info=False): + # since tf 2.6.0, tf needs tensorflow_io to support hdfs path + if tf.__version__.startswith("2") and save_path.find("://") != -1: + import tensorflow_io as tfio + if not self._is_built and not context.executing_eagerly(): raise RuntimeError("`build()` should be called before save if defer_build==True") if latest_filename is None: @@ -206,6 +210,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra def restore(self, sess, save_path): if save_path is None: raise ValueError("Can't load save_path when it is None.") + # since tf 2.6.0, tf needs tensorflow_io to support hdfs path + if tf.__version__.startswith("2") and save_path.find("://") != -1: + import tensorflow_io as tfio + checkpoint_prefix = compat.as_text(save_path) if self._is_empty: return diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index de260160..dcfd4fca 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -35,9 +35,6 @@ class Saver(object): # save_easy_mode : only save the embedding and key data of sparse tables self.save_easy_mode = os.getenv("SAVE_EASY", 0) self.build() - # since tf 2.6.0, tf needs tensorflow_io to support hdfs path - if tf.__version__.startswith("2"): - import tensorflow_io as tfio def build(self): if self.var_list is None: -- Gitee From 04ea0d5199b470e9cd58a5cd88dfe08c6254ac41 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 10 Jul 2023 15:17:27 +0800 Subject: [PATCH 188/551] Match-id-317c415ad1a04a3c75df8e748c3da1ab8212e4e1 --- build/build.sh | 6 ------ 1 file changed, 6 deletions(-) diff --git a/build/build.sh b/build/build.sh index c0ca4449..14e6f269 100644 --- a/build/build.sh +++ b/build/build.sh @@ -13,13 +13,11 @@ ROOT_DIR=$(dirname "${SCRIPT_DIR}") cd "$SCRIPT_DIR" if [ "$(uname -m)" = "aarch64" ] then - pip3 install virtualenv --force-reinstall virtualenv -p "$(which python3.7)" tf2_env source tf2_env/bin/activate tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ pip3 install "${tf265}" --no-deps - pip3 install setuptools==49.2.1 tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env virtualenv -p "$(which python3.7)" tf1_env @@ -27,20 +25,17 @@ then tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ pip3 install "${tf115}" --no-deps - pip3 install setuptools==49.2.1 tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env fi if [ "$(uname -m)" = "x86_64" ] then - pip3 install virtualenv --force-reinstall virtualenv -p "$(which python3.7)" tf2_env source tf2_env/bin/activate tf265="tensorflow_cpu-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl" [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ pip3 install "${tf265}" --no-deps - pip3 install setuptools==49.2.1 tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env virtualenv -p "$(which python3.7)" tf1_env @@ -48,7 +43,6 @@ then tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" [ ! 
-f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ pip3 install "${tf115}" --no-deps - pip3 install setuptools==49.2.1 tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env fi -- Gitee From 116799121073b1370e503fc8d78c0323eb2f1832 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 10 Jul 2023 19:29:38 +0800 Subject: [PATCH 189/551] Match-id-2c93b5ca7f6a2b1f4fbd55c74d1e0e64ad0e5f19 --- src/core/emb_hashmap/emb_hashmap.cpp | 44 ++++++++------- src/core/emb_hashmap/emb_hashmap.h | 8 +-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- src/core/key_process/key_process.cpp | 64 ++++++++++++++++------ src/core/key_process/key_process.h | 4 +- src/tests/key_process/key_process_test.cpp | 19 ++++++- 6 files changed, 96 insertions(+), 45 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 689e1849..b239f292 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -38,7 +38,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, - vector& tmpDataOut) + vector& tmpDataOut, int channelId) { #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) @@ -52,9 +52,9 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t spdlog::debug("FindOffset, {}", findOffsetV2); if (findOffsetV2) { - FindAndUpdateOffset(embName, keys, swapId, keepBatch); + FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); } else { - FindOffset(embName, keys, swapId, keepBatch); + FindOffset(embName, keys, swapId, keepBatch, channelId); } spdlog::debug("FindOffset end"); @@ -96,7 +96,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t */ #ifndef GTEST void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& keys, - size_t currentBatchId, size_t keepBatchId) + size_t currentBatchId, size_t keepBatchId, int channelId) { EASY_FUNCTION() size_t keySize = keys.size(); @@ -109,7 +109,7 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& k continue; } auto& offset = embHashMap.lookUpVec[i]; - if (offset == INVALID_KEY_VALUE) { + if (offset == INVALID_KEY_VALUE && channelId == TRAIN_CHANNEL_ID) { offset = FindNewOffset(key, embHashMap); if (offset < devVocabSize) { embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]); @@ -302,7 +302,7 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& */ void EmbHashMap::FindOffset(const string& embName, const vector& keys, - size_t currentBatchId, size_t keepBatchId) + size_t currentBatchId, size_t keepBatchId, int channelId) { EASY_FUNCTION() size_t keySize = keys.size(); @@ -314,7 +314,7 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); continue; } - auto offset = FindOffsetHelper(key, embHashMap); + auto offset = FindOffsetHelper(key, embHashMap, channelId); if (offset < embHashMap.devVocabSize) { embHashMap.lookUpVec.emplace_back(offset); embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset])); @@ -331,7 +331,7 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } -size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap) +size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId) { 
size_t offset; @@ -339,29 +339,33 @@ size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHas if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; spdlog::trace("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset); - } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 + } else if (embHashMap.evictDevPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // 优先复用hbm表 offset = embHashMap.evictDevPos.back(); embHashMap.hostHashMap[key] = offset; spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictDevPos.size()); embHashMap.evictDevPos.pop_back(); - } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 + } else if (embHashMap.evictPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // hbm不足,再复用ddr表 offset = embHashMap.evictPos.back(); embHashMap.hostHashMap[key] = offset; spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { - embHashMap.hostHashMap[key] = embHashMap.maxOffset; - offset = embHashMap.maxOffset; - embHashMap.maxOffset++; - if (embHashMap.maxOffset == embHashMap.devVocabSize) { - spdlog::info("start using host vocab!"); - } - if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - spdlog::error("hostVocabSize too small! dev:{} host:{}", embHashMap.devVocabSize, - embHashMap.hostVocabSize); - throw runtime_error("hostVocabSize too small"); + if (channelId == TRAIN_CHANNEL_ID) { + embHashMap.hostHashMap[key] = embHashMap.maxOffset; + offset = embHashMap.maxOffset; + embHashMap.maxOffset++; + if (embHashMap.maxOffset == embHashMap.devVocabSize) { + spdlog::info("start using host vocab!"); + } + if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { + spdlog::error("hostVocabSize too small! 
dev:{} host:{}", embHashMap.devVocabSize, + embHashMap.hostVocabSize); + throw runtime_error("hostVocabSize too small"); + } + } else { + offset = -1; } } return offset; diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index b99e7777..c9b88c09 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -24,10 +24,10 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); void Process(const string& embName, std::vector& keys, size_t iBatch, - vector& tmpDataOut); + vector& tmpDataOut, int channelId); void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, - size_t keepBatchId); + size_t keepBatchId, int channelId); void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, int pos); @@ -53,9 +53,9 @@ namespace MxRec { absl::flat_hash_map embHashMaps; void FindOffset(const string& embName, const vector& keys, - size_t currentBatchId, size_t keepBatchId); + size_t currentBatchId, size_t keepBatchId, int channelId); - size_t FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap); + size_t FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId); void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 50428eed..1c80778d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -740,7 +740,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); vector tmpData; - hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData); + hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 927f2fef..818d21ae 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -334,7 +334,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat // map key to offset directly by lookup keyOffsetMap (hashmap) if (rankInfo.noDDR) { TimeCost key2OffsetTC; - Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv); + Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count @@ -391,7 +391,11 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) if (rankInfo.noDDR) { - Key2Offset(batch->name, lookupKeys); + if (rankInfo.useDynamicExpansion) { + Key2OffsetDynamicExpansion(batch->name, lookupKeys, channel); + } else { + Key2Offset(batch->name, lookupKeys, channel); + } } if (!rankInfo.useStatic) { // Static all2all,need send count @@ -967,7 +971,7 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in spdlog::debug("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, scAllOut); } -void 
KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) +void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel) { TimeCost key2OffsetTC; EASY_FUNCTION(profiler::colors::Blue600) @@ -975,19 +979,14 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) auto& key2Offset = keyOffsetMap[embName]; auto& maxOffsetTmp = maxOffset[embName]; auto& evictPos = evictPosMap[embName]; - auto& curEmbTable = embeddingTableMap[embName]; // empty when not use dynamic expansion for (long& key : splitKey) { if (key == -1) { - if (rankInfo.useDynamicExpansion) { - key = 0; - } continue; } const auto& iter = key2Offset.find(key); if (iter != key2Offset.end()) { - // 老值 key = iter->second; - } else if (evictPos.size() != 0) { + } else if (evictPos.size() != 0 && channel == TRAIN_CHANNEL_ID) { size_t offset; // 新值, emb有pos可复用 offset = evictPos.back(); @@ -998,23 +997,52 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey) evictPos.pop_back(); } else { // 新值 - if (rankInfo.useDynamicExpansion) { + if (channel == TRAIN_CHANNEL_ID) { + key2Offset[key] = maxOffsetTmp; + key = maxOffsetTmp++; + } else { + key = INVALID_KEY_VALUE; + } + } + } + if (maxOffsetTmp > embInfos[embName].devVocabSize) { + spdlog::error("dev cache overflow {}>{}", maxOffsetTmp, embInfos[embName].devVocabSize); + throw std::runtime_error("dev cache overflow!"); + } + spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); + TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); +} + +void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel) +{ + TimeCost key2OffsetTC; + EASY_FUNCTION(profiler::colors::Blue600) + std::lock_guard lk(mut); // lock for PROCESS_THREAD + auto& key2Offset = keyOffsetMap[embName]; + auto& maxOffsetTmp = maxOffset[embName]; + auto& curEmbTable = embeddingTableMap[embName]; // empty when not use dynamic expansion + for (long& key : splitKey) { + if (key == -1) { + key = 0; + continue; + } + const auto& iter = key2Offset.find(key); + if (iter != key2Offset.end()) { + key = iter->second; + } else { + // 新值 + if (channel == TRAIN_CHANNEL_ID) { #ifndef GTEST auto addr = curEmbTable.GetEmbAddress(); key2Offset[key] = addr; key = addr; #endif maxOffsetTmp++; - } else { - key2Offset[key] = maxOffsetTmp; - key = maxOffsetTmp++; + continue; + } + key = 0; } } - } - if (!rankInfo.useDynamicExpansion && maxOffsetTmp > embInfos[embName].devVocabSize) { - spdlog::error("dev cache overflow {}>{}", maxOffsetTmp, embInfos[embName].devVocabSize); - throw std::runtime_error("dev cache overflow!"); - } spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 384365be..4c2745b2 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -155,7 +155,9 @@ namespace MxRec { void GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const; - void Key2Offset(const emb_name_t& embName, keys_t& splitKey); + void Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel); + + void Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel); unique_ptr GetBatchData(int channel, int commId); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 
15b96063..643d6d96 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -402,7 +402,24 @@ TEST_F(KeyProcessTest, Key2Offset)
     keys_t expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 };
     ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
-    process.Key2Offset("emb0", lookupKeys);
+    process.Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID);
     spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap);
     ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset));
+
+    keys_t lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 };
+    keys_t expectOffset2 = { -1, -1, -1, -1, -1, -1, -1, -1 };
+    process.Key2Offset("emb0", lookupKeys2, EVAL_CHANNEL_ID);
+    spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys2, process.keyOffsetMap);
+    ASSERT_THAT(lookupKeys2, ElementsAreArray(expectOffset2));
+}
+
+TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion)
+{
+    keys_t lookupKeys = { 4, 16, 28, -1, 24, -1, 20, 24 };
+    keys_t expectOffset = { 0, 0, 0, 0, 0, 0, 0, 0 };
+    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
+    ASSERT_EQ(process.isRunning, true);
+    process.Key2OffsetDynamicExpansion("emb0", lookupKeys, EVAL_CHANNEL_ID);
     spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap);
     ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset));
 }
-- Gitee

From b64a257636d17de0e5f0c4821cdcddce4d463f66 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 10 Jul 2023 22:07:55 +0800
Subject: [PATCH 190/551] Match-id-36f0843125e59aa4b2e4c4fe60300d96644871d4
---
 example/little_demo/run.sh | 45 +++++++++++++++++++++++++++++++++++---
 1 file changed, 42 insertions(+), 3 deletions(-)

diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
index 81692814..7694d271 100644
--- a/example/little_demo/run.sh
+++ b/example/little_demo/run.sh
@@ -3,6 +3,48 @@ rm -rf /root/ascend/log/*
 rm -rf ./kernel*
 rm -rf ./export_graph/*
 
+# Parse the input arguments: py, ip
+if [ $# -ge 1 ]; then
+    py=$1
+    ip=$2
+else
+    echo "for example: bash run.sh main.py 10.10.10.10 or bash run.sh main.py"
+    exit 1
+fi
+
+# Check that the given Python file name is valid
+if [[ $py =~ ^[a-z0-9_]+\.py$ ]]; then
+    echo "File $py is a valid Python file"
+else
+    echo "File $py is not a Python file"
+    exit 1
+fi
+
+# Check whether the IP address is valid
+if [ -n "$ip" ]; then
+    if [[ $ip =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]]; then
+        # Split the IP address into four numbers
+        ip_array=(${ip//./ })
+        # Check that each number is between 0 and 255
+        valid=true
+        for i in "${ip_array[@]}"; do
+            if ((i < 0 || i > 255)); then
+                valid=false
+                break
+            fi
+        done
+        if $valid; then
+            echo "ip: $ip is valid"
+        else
+            echo "ip: $ip is not valid"
+            exit 1
+        fi
+    else
+        echo "ip: $ip is not valid."
+        exit 1
+    fi
+fi
+
 cur_path=`pwd`
 mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec" # please config
 so_path=${mx_rec_package_path}/libasc
@@ -69,7 +111,6 @@ function rankTableSolution() {
    fi
 }
 
-ip=$2
 if [ ! 
-n "$ip" ]; then rankTableSolution else @@ -101,8 +142,6 @@ else fi fi -py=$1 -echo "py is $py" echo "use horovod to start tasks" DATE=$(date +%Y-%m-%d-%H-%M-%S) horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ -- Gitee From 2d5c3b4662c8ba85a5e5253942bc4c66d964ba6c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 11 Jul 2023 19:16:31 +0800 Subject: [PATCH 191/551] Match-id-660fb307cc8660ac0cb19b9a05f9af3203bc40f4 --- cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp | 2 +- cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 4c935417..3e50e6a4 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -44,7 +44,7 @@ namespace optiling return ge::GRAPH_SUCCESS; } - static int check_op_support(const ge::Operator &op, ge::AscendString &result) + static ge::graphStatus check_op_support(const ge::Operator &op, ge::AscendString &result) { std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; result = ge::AscendString(res_json_str.c_str()); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index 430b167f..b03d1a6d 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -63,7 +63,7 @@ namespace optiling return ge::GRAPH_SUCCESS; } - static int check_op_support(const ge::Operator &op, ge::AscendString &result) + static ge::graphStatus check_op_support(const ge::Operator &op, ge::AscendString &result) { std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; result = ge::AscendString(res_json_str.c_str()); -- Gitee From 679341459c1674a0bc7c726a143d6da3396a868e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 11 Jul 2023 20:39:02 +0800 Subject: [PATCH 192/551] Match-id-045598c11d6d8a58512b2144134fadee29aff585 --- src/core/hd_transfer/hd_transfer.cpp | 30 +++++++++++++++++++++------- src/core/hd_transfer/hd_transfer.h | 5 ++++- src/core/host_emb/host_emb.cpp | 7 ++----- src/pybind/CMakeLists.txt | 2 +- src/tests/CMakeLists.txt | 2 +- 5 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 4181fa68..8a4872d6 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -35,7 +35,19 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { CreateChannel(localRankId, embInfo.name, i); } + aclDatasets[embInfo.name] = acltdtCreateDataset(); } + const char* timeoutEnv = getenv("AclTimeout"); + if (timeoutEnv != nullptr) { + int32_t timeoutEnvCast = static_cast(std::atoi(timeoutEnv)); + spdlog::debug("timeoutEnv:{}", timeoutEnvCast); + if (timeoutEnvCast > INT32_MAX || timeoutEnvCast < -1) { + spdlog::warn("AclTimeout={} is not valid", timeoutEnvCast); + } else { + timeout = timeoutEnvCast; + } + } + spdlog::debug("hd transfer timeout:{}", timeout); running = true; spdlog::info("hd_transfer init"); #endif @@ -51,6 +63,11 @@ void HDTransfer::Destroy() 
tensorflow::StopRecvTensorByAcl(&c.second, c.first); spdlog::info(HD + "destroy channel:{}", c.first); } + for (auto& d: aclDatasets) { + if (acltdtDestroyDataset(d.second) != ACL_ERROR_NONE) { + throw runtime_error("Acl destroy tensor dataset failed."); + } + } aclFinalize(); #endif } @@ -163,7 +180,7 @@ vector<Tensor> HDTransfer::Recv(TransferChannel channel, int channel return {}; } -tuple<acltdtDataset*, size_t> HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) +size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) { EASY_FUNCTION() #ifndef GTEST @@ -171,21 +188,20 @@ tuple<acltdtDataset*, size_t> HDTransfer::RecvAcl(TransferChannel channel, int c string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); spdlog::debug("hd transfer try recv:{}", recvName); spdlog::stopwatch sw; - acltdtDataset* aclDataset = acltdtCreateDataset(); - if (aclDataset == nullptr) { + if (aclDatasets[embName] == nullptr) { throw runtime_error(fmt::format("Failed recv:{}.", recvName).c_str()); } - auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDataset, -1 /* no timeout */); + auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], timeout /* -1: no timeout */); if (!running) { - return {nullptr, 0}; + return 0; } if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { throw runtime_error(fmt::format("Failed receive data from acl channel, acl status:{}", aclStatus).c_str()); } spdlog::info("hd transfer recv:{} cost:{}ms", recvName, Format2Ms(sw)); - return {aclDataset, acltdtGetDatasetSize(aclDataset)}; + return acltdtGetDatasetSize(aclDatasets[embName]); #endif - return {nullptr, 0}; + return 0; } size_t HDTransfer::QueryChannelSize(const string& channelName) diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 37fee0b4..fea3c43e 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -62,6 +62,8 @@ namespace MxRec { class HDTransfer { public: + std::unordered_map<std::string, acltdtDataset*> aclDatasets; + HDTransfer() = default; int Init(const vector<EmbInfo>& embInfos, uint32_t localRankId); @@ -71,7 +73,7 @@ namespace MxRec { vector<Tensor> Recv(TransferChannel channel, int channelId, const string& embName); - tuple<acltdtDataset*, size_t> RecvAcl(TransferChannel channel, int channelId, const string& embName); + size_t RecvAcl(TransferChannel channel, int channelId, const string& embName); size_t QueryChannelSize(const string& channelName); @@ -84,6 +86,7 @@ namespace MxRec { std::unordered_map<std::string, acltdtChannelHandle*> transferChannels; #endif bool running; + int32_t timeout{-1}; void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum); }; } diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 61db3ef8..7772ab13 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -145,7 +145,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI auto hdTransfer = Singleton<HDTransfer>::GetInstance(); TransferChannel transferName = TransferChannel::D2H; spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - auto [aclDataset, size] = hdTransfer->RecvAcl(transferName, channelId, embName); + auto size = hdTransfer->RecvAcl(transferName, channelId, embName); if (size == 0) { spdlog::warn(HOSTEMB + "recv empty data"); return; @@ -155,7 +155,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI EASY_BLOCK("Update") auto& embData = hostEmbs[embName].embData; auto embeddingSize =
hostEmbs[embName].hostEmbInfo.extEmbeddingSize; - auto aclData = acltdtGetDataItem(aclDataset, 0); + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[embName], 0); if (aclData == nullptr) { throw runtime_error("Acl get tensor data from dataset failed."); } @@ -168,9 +168,6 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI dst[k] = ptr[k + embeddingSize * j]; } } - if (acltdtDestroyDataset(aclDataset) != ACL_ERROR_NONE) { - throw runtime_error("Acl destroy tensor dataset failed."); - } spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw)); })); } diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt index 28ec5210..63131cd3 100644 --- a/src/pybind/CMakeLists.txt +++ b/src/pybind/CMakeLists.txt @@ -4,5 +4,5 @@ pybind11_add_module(mxrec_pybind module_main.cpp) set_target_properties(mxrec_pybind PROPERTIES LINK_FLAGS "-Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") target_include_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/include) target_link_directories(mxrec_pybind PUBLIC ${ASCEND_DRIVER_PATH}/lib64/driver) -target_link_libraries(mxrec_pybind PUBLIC ASC dcmi) +target_link_libraries(mxrec_pybind PUBLIC ASC) install(TARGETS mxrec_pybind LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 0059273c..8c73e748 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -51,7 +51,7 @@ target_link_directories(test_main target_link_libraries(test_main PUBLIC ${TF_LIB} securec OpenMP::OpenMP_CXX ${HDF5_CXX_LIBRARIES} ${MPI_CXX_LIBRARIES} - ${PYTHON_LIBRARY} drvdsmi_host dcmi + ${PYTHON_LIBRARY} drvdsmi_host -ldl ) -- Gitee From 021af0058d086288fdfb3b056f302299c1ed9d28 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 11 Jul 2023 21:42:48 +0800 Subject: [PATCH 193/551] Match-id-887f1bfaad77579ee31333e53e424610c058a466 --- example/little_demo/main.py | 2 +- example/little_demo/run.sh | 1 + src/core/host_emb/host_emb.cpp | 55 ++++++++++++++++++++++++---- src/core/host_emb/host_emb.h | 5 ++- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 12 +++--- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- src/tests/host_emb/host_emb_test.cpp | 12 ++++++ 7 files changed, 72 insertions(+), 18 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 0842a9a4..37c13eb8 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -132,7 +132,7 @@ if __name__ == "__main__": init(use_mpi=use_mpi, train_steps=TRAIN_STEPS, eval_steps=EVAL_STEPS, - prefetch_batch_number=5, + prefetch_batch_number=1, use_dynamic=use_dynamic, use_hot=use_hot, use_dynamic_expansion=use_dynamic_expansion) diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 81692814..967a5356 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -39,6 +39,7 @@ export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion export USE_MULTI_LOOKUP=1 # 0: one lookup per table; 1: multiple lookups per table export USE_MODIFY_GRAPH=0 # 0: feature spec mode; 1: automatic graph modification mode export USE_TIMESTAMP=0 # 0: disable feature admission and eviction; 1: enable feature admission and eviction +export UpdateEmb_V2=0 # 0: synchronous UpdateEmb; 1: asynchronous UpdateEmb_V2 ################# Performance tuning #################### export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 export FAST_UNIQUE=0 #if use fast unique diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 61db3ef8..2e111b8a 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -89,15 +89,41 @@ void HostEmb::LoadEmb(emb_mem_t& loadData) #endif } -void HostEmb::Join() +void HostEmb::Join(int
channelId) { spdlog::stopwatch sw; - spdlog::debug(HOSTEMB + "hostemb start join {}", procThreads.size()); - for (auto& t: procThreads) { - t->join(); - } - procThreads.clear(); - spdlog::info(HOSTEMB + "hostemb end join, cost:{}", duration_cast<milliseconds>((sw).elapsed())); + switch (channelId) { + case TRAIN_CHANNEL_ID: + spdlog::debug( + HOSTEMB + "start join, channelId:{}, procThreadsForTrain num:{}", + channelId, procThreadsForTrain.size() + ); + for (auto& t: procThreadsForTrain) { + t->join(); + } + procThreadsForTrain.clear(); + spdlog::debug( + HOSTEMB + "end join, channelId:{}, cost:{}", + channelId, duration_cast<milliseconds>((sw).elapsed()) + ); + break; + case EVAL_CHANNEL_ID: + spdlog::debug( + HOSTEMB + "start join, channelId:{}, procThreadsForEval num:{}", + channelId, procThreadsForEval.size() + ); + for (auto& t: procThreadsForEval) { + t->join(); + } + procThreadsForEval.clear(); + spdlog::debug( + HOSTEMB + "end join, channelId:{}, cost:{}", + channelId, duration_cast<milliseconds>((sw).elapsed()) + ); + break; + default: + throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } } /* @@ -107,6 +133,7 @@ void HostEmb::Join() #ifndef GTEST void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) { + spdlog::info(HOSTEMB + "UpdateEmb, channelId:{}, embName:{}", channelId, embName); EASY_FUNCTION(profiler::colors::Purple); spdlog::stopwatch sw; auto hdTransfer = Singleton<HDTransfer>::GetInstance(); @@ -139,8 +166,9 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) { + spdlog::info(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); EASY_FUNCTION(profiler::colors::Purple) - procThreads.emplace_back(make_unique<thread>( + auto updateThread = [&, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton<HDTransfer>::GetInstance(); TransferChannel transferName = TransferChannel::D2H; spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId); @@ -172,7 +200,18 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI throw runtime_error("Acl destroy tensor dataset failed."); } spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw)); - })); + }; + + switch (channelId) { + case TRAIN_CHANNEL_ID: + procThreadsForTrain.emplace_back(make_unique<thread>(updateThread)); + break; + case EVAL_CHANNEL_ID: + procThreadsForEval.emplace_back(make_unique<thread>(updateThread)); + break; + default: + throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); + } } /* diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index 4464a202..9ffd9bd5 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -32,7 +32,7 @@ namespace MxRec { void LoadEmb(absl::flat_hash_map& loadData); - void Join(); + void Join(int channelId); void UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName); @@ -52,7 +52,8 @@ GTEST_PRIVATE: absl::flat_hash_map hostEmbs; - std::vector<std::unique_ptr<std::thread>> procThreads; + std::vector<std::unique_ptr<std::thread>> procThreadsForTrain; + std::vector<std::unique_ptr<std::thread>> procThreadsForEval; void EmbDataGenerator(const vector& initializeInfos, int seed, int vocabSize, int embeddingSize, vector<vector<float>>& embData); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 50428eed..d1a441d7 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -748,7 +748,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto all2all =
preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } - TIME_PRINT("getAndSendTensorsTC(ms):{}", getAndSendTensorsTC.ElapsedMS()); + TIME_PRINT("getAndSendTensorsTC(ms):{}, channelId:{}", getAndSendTensorsTC.ElapsedMS(), channelId); if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, @@ -764,8 +764,8 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in if (iBatch == 0) { return; } - spdlog::info(MGMT + "trans emb, batchId:[{}-{}]", start, batchId); - hostEmbs->Join(); + spdlog::info(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId); + hostEmbs->Join(channelId); EmbHDTrans(channelId, batchId); for (int i = 0; i < iBatch - 1; ++i) { @@ -790,8 +790,8 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); if (!(skipUpdate && missingKeys.empty())) { - bool updateEmbV2 = getenv("UpdateEmb_V2") != nullptr; - if (updateEmbV2) { + auto updateEmbV2 = getenv("UpdateEmb_V2"); + if (updateEmbV2 != nullptr and atoi(updateEmbV2) == 1) { hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! } else { hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! @@ -799,7 +799,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) } // skip when skip update and empty missing keys hostHashMaps->ClearMissingKeys(embInfo.name); } - TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId: {} ", tr.ElapsedMS(), batchId); + TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId:{} channelId:{}", tr.ElapsedMS(), batchId, channelId); } void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index d734f1dd..5b8ee6e7 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -85,7 +85,8 @@ namespace MxRec { t->join(); } if (hostEmbs != nullptr) { - hostEmbs->Join(); + hostEmbs->Join(TRAIN_CHANNEL_ID); + hostEmbs->Join(EVAL_CHANNEL_ID); hostEmbs = nullptr; } procThreads.clear(); diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp index de74f400..1f0c9673 100644 --- a/src/tests/host_emb/host_emb_test.cpp +++ b/src/tests/host_emb/host_emb_test.cpp @@ -51,3 +51,15 @@ TEST(HostEmb, Tensor2Float) std::cout << q[1].flat<float>()(0) << std::endl; ASSERT_EQ(1, 1); } + +TEST(HostEmb, DefaultConstructor) +{ + HostEmb h; + h.procThreadsForTrain.emplace_back(make_unique<thread>([] {})); + h.Join(TRAIN_CHANNEL_ID); + ASSERT_EQ(h.procThreadsForTrain.size(), 0); + + h.procThreadsForEval.emplace_back(make_unique<thread>([] {})); + h.Join(EVAL_CHANNEL_ID); + ASSERT_EQ(h.procThreadsForEval.size(), 0); +} \ No newline at end of file -- Gitee From 81fa57c935de0ce535d4800d4941c96aa562a620 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 12 Jul 2023 11:12:25 +0800 Subject: [PATCH 194/551] Match-id-f631f79a5b1a2a1513d563aae0ee5e76db908b91 --- mx_rec/util/initialize.py | 19 ++++++++++++++++++- mx_rec/util/sparse.py | 18 +++++++++++++++++- mx_rec/validator/validator.py | 18 +++++++++--------- tools/python/key_2_emb_formatter.py | 9 +++++++++ 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 40ba638a..46763197 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -15,7 +15,7 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH,\ TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import RankInfoValidator, StringValidator +from mx_rec.validator.validator import RankInfoValidator, StringValidator, FileValidator from mx_rec.util.atomic import AtomicInteger @@ -217,7 +217,16 @@ class ConfigInitializer: rank_table_path = os.path.realpath(os.getenv("RANK_TABLE_FILE")) if not os.path.exists(rank_table_path): raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") + with open(rank_table_path, "r", encoding="utf-8") as file: + # check whether json file is valid + file_validator = FileValidator(rank_table_path) + # 1.check whether rank_table_path is soft link + file_validator.check_not_soft_link() + # 2.check json file size + file_validator.check_file_size(file) + file_validator.check() + table_hccl = json.load(file) if "server_list" not in table_hccl: raise AttributeError(f"Lack of attribute server_list.") @@ -800,7 +809,15 @@ def get_available_cpu_num_and_range(): logging.warning(f"failed to get numa node of cpu: {cpu}") is_ok = False break + with open(f_path, "r", encoding="utf-8") as f_in: + # check whether file is valid + file_validator = FileValidator(f_path) + # 1.check whether f_path is soft link + file_validator.check_not_soft_link() + # 2.check file size + file_validator.check_file_size(f_in) + file_validator.check() pkg_id = f_in.readline().strip() pkg_id2cpu_list[pkg_id].append(cpu) diff --git a/mx_rec/util/sparse.py b/mx_rec/util/sparse.py index 9ba8392b..67811cfa 100644 --- a/mx_rec/util/sparse.py +++ b/mx_rec/util/sparse.py @@ -10,6 +10,7 @@ import tensorflow as tf import numpy as np from mx_rec.util.initialize import get_table_instance, get_table_instance_by_name, export_table_name_set +from mx_rec.validator.validator import FileValidator class SparseProcessor: @@ -46,6 +47,15 @@ class SparseProcessor: @staticmethod def _get_data(data_dir, dtype, data_shape): + with open(data_dir, "rb") as file: + # check whether data file is valid (binary mode, so no encoding argument) + file_validator = FileValidator(data_dir) + # 1.check whether data_dir is soft link + file_validator.check_not_soft_link() + # 2.check data file size + file_validator.check_file_size(file) + file_validator.check() + try: data = np.fromfile(data_dir, dtype=dtype) except FileNotFoundError as err: @@ -58,6 +68,13 @@ class SparseProcessor: if is_json: try: with open(attribute_dir, "r") as fin: + # check whether attribute file is valid + file_validator = FileValidator(attribute_dir) + # 1.check whether attribute_dir is soft link + file_validator.check_not_soft_link() + # 2.check attribute file size + file_validator.check_file_size(fin) + file_validator.check() attributes = json.load(fin) except FileNotFoundError as err: raise FileNotFoundError(f"attribute dir not found.") from err @@ -92,7 +109,6 @@ class SparseProcessor: transformed_data = dict(zip(key[:], emb_data[:])) np.save(out_dir + self.sep + self.export_name + ".npy", transformed_data) - def get_embedding(self, device_table_dir, host_table_dir, ddr): emb_dir = os.path.join(device_table_dir, self.device_emb_dir) data_file, attribute_file = self._get_file_names(emb_dir) diff --git a/mx_rec/validator/validator.py
b/mx_rec/validator/validator.py index 73010f76..80d3edb6 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -38,10 +38,9 @@ class Validator: if self.is_valid_state is None: self.is_valid_state = True for checker, msg in self.checkers: - self.is_valid_state &= checker(self.value) - if not self.is_valid_state: + if not checker(self.value): self.msg = msg - break + raise ValueError(self.msg) if self.is_valid_state: self.msg = None return self @@ -257,16 +256,17 @@ class FileValidator(StringValidator): @param value: the file path, should not be empty string, should not contain double dot(../) """ super().__init__(value) - self.register_checker(lambda x: isinstance(x, str), "type is not str") + self.register_checker(lambda x: isinstance(x, str), "parameter value's type is not str") - def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE): - self.register_checker(lambda path: min_size < os.path.getsize(self.value) <= max_size, - "file size is invalid") + def check_file_size(self, file_obj, max_size=MAX_SIZE, min_size=MIN_SIZE): + file_info = os.stat(file_obj.fileno()) + self.register_checker(lambda path: min_size < file_info.st_size <= max_size, + f"file size: {file_info.st_size} is invalid, not in [{min_size}, {max_size}]") return self def check_not_soft_link(self): - self.register_checker(lambda path: os.path.realpath(self.value) == self.value, - "soft link or relative path should not be in the path parameter") + self.register_checker(lambda path: not os.path.islink(self.value), - f"soft link or relative path: {self.value} should not be in the path parameter") return self def check_user_group(self): diff --git a/tools/python/key_2_emb_formatter.py b/tools/python/key_2_emb_formatter.py index 467bd5c3..4f5ca6b3 100644 --- a/tools/python/key_2_emb_formatter.py +++ b/tools/python/key_2_emb_formatter.py @@ -12,6 +12,8 @@ import os import re import numpy as np +from mx_rec.validator.validator import FileValidator + parser = argparse.ArgumentParser() parser.add_argument('--path', type=str, required=True, help='path of the root dir of saved file') @@ -185,6 +187,13 @@ class Formatter: file_dir = os.path.join(directory, file_name) if is_json: with open(file_dir, "r") as fin: + # check whether attribute file is valid + file_validator = FileValidator(file_dir) + # 1.check whether file_dir is soft link + file_validator.check_not_soft_link() + # 2.check attribute file size + file_validator.check_file_size(fin) + file_validator.check() attributes = json.load(fin) return attributes else: -- Gitee From 67a7272e2f6bd7f39a038fad53b1445ebe7559bc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 12 Jul 2023 18:43:50 +0800 Subject: [PATCH 195/551] Match-id-915fd32f9413b0047489259b5edc7bcb603812e1 --- mx_rec/core/embedding.py | 30 ++++++++++++++++++++++++++++++ mx_rec/validator/validator.py | 3 ++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 60ec218b..7648a986 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -12,6 +12,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops.init_ops import Initializer from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute @@ -27,6 +28,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size,
is_mpi_in_use, is ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion, \ set_modify_graph, insert_removing_var_list from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.validator.validator import ClassValidator, StringValidator def create_table(**kwargs): @@ -63,6 +65,7 @@ def create_table(**kwargs): is_save: switch whether to store sparse table data. init_param: embedding init param-coefficient """ + check_create_table_params(key_dtype, dim, name, emb_initializer) config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, @@ -968,3 +971,30 @@ def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Opti id_offsets_expand = tf.repeat(id_offsets_expand, [tf.shape(embeddings)[-1]], axis=-1) embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings)) return embeddings + + +def check_create_table_params(key_dtype, dim, name, emb_initializer): + """ + Validate the required parameters of the create_table interface: key_dtype, dim, name and emb_initializer (optimizer_list already has its own validation) + :param key_dtype: data type for feature id, tf.int64 or tf.int32 or tf.string + :param dim: embedding vector size, dim's type: int or tf.TensorShape + :param name: hash table name, name's type: str + :param emb_initializer: the initializer for embedding values + :return: + """ + # check key_dtype + if key_dtype not in [tf.int64, tf.int32, tf.string]: + raise ValueError(f"key_dtype: {key_dtype} not in [tf.int64, tf.int32, tf.string]") + # check dim + dim_validator = ClassValidator(value=dim, classes=(int, tf.TensorShape)) + dim_validator.check_isinstance() + dim_validator.check() + # check name + name_validator = StringValidator(name) + name_validator.check_string_length() + name_validator.check_whitelist() + name_validator.check() + # check emb_initializer + emb_initializer_validator = ClassValidator(value=emb_initializer, classes=Initializer) + emb_initializer_validator.check_isinstance() + emb_initializer_validator.check() diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 80d3edb6..524e8381 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -65,7 +65,8 @@ class ClassValidator(Validator): def check_isinstance(self): """Check arg isinstance of classes""" - self.register_checker(lambda path: isinstance(self.value, self.classes), "Invalid parameter type") + self.register_checker(lambda path: isinstance(self.value, self.classes), f"Invalid parameter type, not " + f"in {self.classes}") return self -- Gitee From b962c212a1bafe6787bebabf1a0e9705603bc727 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 12 Jul 2023 19:01:57 +0800 Subject: [PATCH 196/551] Match-id-85f8833c61c87d17588fded0477ace97c58b9342 --- build/build.sh | 78 ++++++---------- build/build_all.sh | 221 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 50 deletions(-) create mode 100644 build/build_all.sh diff --git a/build/build.sh b/build/build.sh index 14e6f269..0a975d8f 100644 --- a/build/build.sh +++ b/build/build.sh @@ -11,39 +11,14 @@ ARCH="$(uname -m)" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") ROOT_DIR=$(dirname "${SCRIPT_DIR}") cd "$SCRIPT_DIR" -if [ "$(uname -m)" = "aarch64" ] -then - virtualenv -p "$(which python3.7)" tf2_env - source tf2_env/bin/activate - tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" - [ !
-f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ - pip3 install "${tf265}" --no-deps - tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env - virtualenv -p "$(which python3.7)" tf1_env - source tf1_env/bin/activate - tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" - [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ - pip3 install "${tf115}" --no-deps - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env -fi if [ "$(uname -m)" = "x86_64" ] then - virtualenv -p "$(which python3.7)" tf2_env - source tf2_env/bin/activate - tf265="tensorflow_cpu-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl" - [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ - pip3 install "${tf265}" --no-deps - tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow + source /opt/buildtools/tf2_env/bin/activate + tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env - virtualenv -p "$(which python3.7)" tf1_env - source tf1_env/bin/activate - tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" - [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ - pip3 install "${tf115}" --no-deps - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + source /opt/buildtools/tf1_env/bin/activate + tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env fi @@ -164,7 +139,7 @@ gen_wheel_file() remove "${ROOT_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec - python3 setup.py bdist_wheel --plat-name=linux_$(arch) + python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" remove "${ROOT_DIR}"/mx_rec/libasc @@ -199,28 +174,31 @@ clean() remove "${ROOT_DIR}"/build/mindxsdk-mxrec } -install_abseil -compile_securec +if [ "$(uname -m)" = "x86_64" ] +then + install_abseil + compile_securec -echo "-----Build AccCTR -----" -compile_acc_ctr_so_file + echo "-----Build AccCTR -----" + compile_acc_ctr_so_file -echo "-----Build Start tf1 -----" -source "${SCRIPT_DIR}"/tf1_env/bin/activate -compile_so_file "${tf1_path}" -collect_so_file -gen_wheel_file "${ROOT_DIR}"/tf1_whl -deactivate tf1_env + echo "-----Build Start tf1 -----" + source /opt/buildtools/tf1_env/bin/activate + compile_so_file "${tf1_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf1_whl + deactivate tf1_env -echo "-----Build Start tf2 -----" -source "${SCRIPT_DIR}"/tf2_env/bin/activate -compile_so_file "${tf2_path}" -collect_so_file -gen_wheel_file "${ROOT_DIR}"/tf2_whl -deactivate tf2_env + echo "-----Build Start tf2 -----" + source /opt/buildtools/tf2_env/bin/activate + compile_so_file "${tf2_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf2_whl + deactivate tf2_env -echo "-----Build gen tar -----" -gen_tar_file + echo "-----Build gen tar -----" + gen_tar_file -clean -echo "-----Done-----" + clean + echo "-----Done-----" +fi \ No newline at end of file diff --git a/build/build_all.sh b/build/build_all.sh new file mode 100644 index 00000000..926e6977 --- /dev/null +++ b/build/build_all.sh @@ -0,0 +1,221 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 
2021-2023. All rights reserved. +# Description: build script. +# Author: MindX SDK +# Create: 2021 +# History: NA + +set -e +warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } +ARCH="$(uname -m)" +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +ROOT_DIR=$(dirname "${SCRIPT_DIR}") +cd "$SCRIPT_DIR" +if [ "$(uname -m)" = "aarch64" ] +then + virtualenv -p "$(which python3.7)" tf2_env + source tf2_env/bin/activate + tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" + [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ + tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env + virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate + tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" + [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + +if [ "$(uname -m)" = "x86_64" ] +then + source tf2_env/bin/activate + tf265="tensorflow_cpu-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl" + [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ + tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env + virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate + tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" + [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + +VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml +get_version() { + if [ -f "$VERSION_FILE" ]; then + VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") + if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then + VERSION=${VERSION%.*} + fi + else + VERSION="5.0.T104" + fi +} + +remove() +{ + if [ -d "$1" ]; then + rm -rf "$1" + elif [ -f "$1" ]; then + rm -f "$1" + fi +} + +project_output_path="${ROOT_DIR}"/output/ +remove "${project_output_path}" +remove "${SCRIPT_DIR}/lib" +get_version +export VERSION +echo "MindX SDK mxrec: ${VERSION}" >> ./version.info + +pkg_dir=mindxsdk-mxrec +remove "${pkg_dir}" +mkdir "${pkg_dir}" +mv version.info "${pkg_dir}" + +opensource_path="${ROOT_DIR}"/../opensource/opensource +abseil_src_path=${opensource_path}/abseil +echo "${abseil_src_path}" +abseil_install_path="${ROOT_DIR}"/install/abseil + +src_path="${ROOT_DIR}"/src +acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR +cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c +cd "${ROOT_DIR}" + +release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz + +install_abseil() +{ + remove "${abseil_install_path}" + echo "${abseil_install_path}" + if [[ ! -d "${abseil_install_path}" ]] + then mkdir -p "${abseil_install_path}" + fi + + cd "${abseil_src_path}" + echo "${abseil_src_path}" + remove CMakeCache.txt + cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . 
&& make -j8 && make install + + echo "${project_output_path}"/abseil + mkdir -p "${project_output_path}"/abseil + if [ -d "${abseil_install_path}"/lib64/ ]; then + cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil + elif [ -d "${abseil_install_path}"/lib/ ]; then + cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil + else + echo "${abseil_install_path}"/lib64/ not exist + exit 1 + fi +} + +compile_securec() +{ + if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then + echo "securec is not exist" + exit 1 + fi + + if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then + cd "${ROOT_DIR}"/platform/securec/src + make -j + fi +} + +compile_so_file() +{ + cd "${src_path}" + chmod u+x build.sh + ./build.sh "$1" "${ROOT_DIR}" + cd .. +} + +compile_acc_ctr_so_file() +{ + cd "${acc_ctr_path}" + chmod u+x build.sh + ./build.sh "release" +} + +collect_so_file() +{ + cd "${src_path}" + remove "${src_path}"/libasc + mkdir -p "${src_path}"/libasc + chmod u+x libasc + + cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc + cp -df "${ROOT_DIR}"/output/*.so* libasc + cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc +} + +gen_wheel_file() +{ + cd "${ROOT_DIR}" + touch "${src_path}"/libasc/__init__.py + remove "${ROOT_DIR}"/mx_rec/libasc + mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec + cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec + python3 setup.py bdist_wheel --plat-name=linux_$(arch) + mkdir -p "$1" + mv dist/mx_rec*.whl "$1" + remove "${ROOT_DIR}"/mx_rec/libasc +} + +gen_tar_file() +{ + cd "${src_path}" + mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" + mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" + cp -r "${src_path}"/../example ../build/"${pkg_dir}" + cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" + cd ../build + tar -zvcf "${release_tar}" "${pkg_dir}" || { + warn "compression failed, packages might be broken" + } + + mv "${release_tar}" "${SCRIPT_DIR}"/../output/ + +} + +clean() +{ + remove "${ROOT_DIR}"/dist + remove "${ROOT_DIR}"/install + remove "${ROOT_DIR}"/mx_rec.egg-info + remove "${ROOT_DIR}"/src/build + remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" + remove "${ROOT_DIR}"/build/tf1_env + remove "${ROOT_DIR}"/build/tf2_env + remove "${ROOT_DIR}"/build/lib + remove "${ROOT_DIR}"/build/mindxsdk-mxrec +} + +install_abseil +compile_securec + +echo "-----Build AccCTR -----" +compile_acc_ctr_so_file + +echo "-----Build Start tf1 -----" +source "${SCRIPT_DIR}"/tf1_env/bin/activate +compile_so_file "${tf1_path}" +collect_so_file +gen_wheel_file "${ROOT_DIR}"/tf1_whl +deactivate tf1_env + +echo "-----Build Start tf2 -----" +source "${SCRIPT_DIR}"/tf2_env/bin/activate +compile_so_file "${tf2_path}" +collect_so_file +gen_wheel_file "${ROOT_DIR}"/tf2_whl +deactivate tf2_env + +echo "-----Build gen tar -----" +gen_tar_file + +clean +echo "-----Done-----" -- Gitee From d9f046a5a667f0c10e12d53128902b43e20d5cb0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 12 Jul 2023 15:38:26 +0800 Subject: [PATCH 197/551] Match-id-932ef472ad874068eb6417782e5475706fc09103 --- example/little_demo/main.py | 13 +++- example/little_demo/run.sh | 2 +- mx_rec/constants/constants.py | 20 ++++- mx_rec/core/asc/helper.py | 1 - mx_rec/core/embedding.py | 73 +++++++++++++------ mx_rec/optimizers/adagrad.py | 1 - mx_rec/optimizers/base.py | 1 - mx_rec/optimizers/ftrl.py | 7 -- mx_rec/optimizers/ftrl_t.py | 7 -- mx_rec/optimizers/ftrl_t_dense.py | 3 - mx_rec/optimizers/gradient_descent.py | 2 - mx_rec/optimizers/gradient_descent_by_addr.py | 2 - 
mx_rec/optimizers/lazy_adam.py | 6 -- mx_rec/optimizers/lazy_adam_by_addr.py | 2 - mx_rec/optimizers/momentum.py | 2 - 15 files changed, 79 insertions(+), 63 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 37c13eb8..e6b28a82 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -20,8 +20,9 @@ from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import create_table, sparse_lookup from mx_rec.graph.modifier import modify_graph_and_start_emb_cache from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP -from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load +from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load, get_rank_size from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.constants.constants import ApplyGradientsStrategy tf.compat.v1.disable_eager_execution() @@ -165,7 +166,9 @@ if __name__ == "__main__": device_vocabulary_size=cfg.user_vocab_size * 10, host_vocabulary_size=0, optimizer_list=sparse_optimizer_list, - mode=mode) + mode=mode, + all2all_gradients_op="sum_gradients_and_div_by_ranksize", + apply_gradients_strategy=ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY) item_hashtable = create_table(key_dtype=tf.int64, dim=tf.TensorShape([cfg.item_hashtable_dim]), @@ -178,10 +181,12 @@ if __name__ == "__main__": train_iterator, train_model = build_graph([user_hashtable, item_hashtable], is_train=True, feature_spec_list=train_feature_spec_list, - config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number) + config_dict=ACCESS_AND_EVICT, + batch_number=TRAIN_STEPS * get_rank_size()) eval_iterator, eval_model = build_graph([user_hashtable, item_hashtable], is_train=False, feature_spec_list=eval_feature_spec_list, - config_dict=ACCESS_AND_EVICT, batch_number=cfg.batch_number) + config_dict=ACCESS_AND_EVICT, + batch_number=EVAL_STEPS * get_rank_size()) dense_variables, sparse_variables = get_dense_and_sparse_variable() run_mode = RunMode( diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index ecc70ade..6cf25dee 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -56,7 +56,7 @@ num_process=$((${num_server} * ${local_rank_size})) # total number of training processes,
export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH # 集合通信文件,格式请参考昇腾官网CANN文档,“准备资源配置文件”章节。 export JOB_ID=10086 diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index cde18ae9..57b7800c 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -53,11 +53,15 @@ class BaseEnum(Enum): @classmethod def mapping(cls, key): for mode in cls: - if key == mode.value: + if isinstance(key, BaseEnum): + key_value = key.value + else: + key_value = key + if key_value == mode.value: return mode raise KeyError(f"Cannot find a corresponding mode in current Enum " - f"class {cls}, given parameter '{key}' is illegal, " + f"class {cls}, given parameter '{key}[{key.__class__}]' is illegal, " f"please choose a valid one from " f"'{list(map(lambda c: c.value, cls))}'.") @@ -106,3 +110,15 @@ class OptimizerType(Enum): OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], OptimizerType.SGD: []} + + +class All2allGradientsOp(BaseEnum): + SUM_GRADIENTS = "sum_gradients" + SUM_GRADIENTS_AND_DIV_BY_RANKSIZE = "sum_gradients_and_div_by_ranksize" + + +class ApplyGradientsStrategy(BaseEnum): + DIRECT_APPLY = "direct_apply" + SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply" + + diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 7b628fc2..dbc93eae 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -558,7 +558,6 @@ def get_asc_read_raw_func(cfg_list): float_split_res = tf.split(raw_float_sample, [i * line_per_sample_list[0] for i in float_len_list]) - logging.debug(f"############ Enter read_raw_fn ########") for name_id, name in enumerate(int_name_order): batch[name] = int_split_res[name_id] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 7648a986..7c977432 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -21,7 +21,7 @@ from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPA ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, \ - ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT + ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT, All2allGradientsOp, ApplyGradientsStrategy from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ @@ -32,21 +32,6 @@ from mx_rec.validator.validator import ClassValidator, StringValidator def create_table(**kwargs): - key_dtype = kwargs.get("key_dtype") - dim = kwargs.get("dim") - name = kwargs.get("name") - emb_initializer = kwargs.get("emb_initializer") - device_vocabulary_size = kwargs.get("device_vocabulary_size", 1) - host_vocabulary_size = kwargs.get("host_vocabulary_size", 0) - optimizer_list = kwargs.get("optimizer_list") - mode = kwargs.get("mode", MxRecMode.ASC) - value_dtype = kwargs.get("value_dtype", tf.float32) - shard_num = kwargs.get("shard_num", 1) - fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) - hashtable_threshold = kwargs.get("hashtable_threshold", 0) - is_save = kwargs.get("is_save", True) - init_param = kwargs.get("init_param", 1.0) - """ Args: key_dtype: data type for feature id @@ -64,14 
+49,34 @@ def create_table(**kwargs): hashtable_threshold: choose to implement based on hash table or linear layer is_save: switch whether to store sparse table data. init_param: embedding init param-coefficient + all2all_gradients_op: sum_gradients (default) or sum_gradients_and_div_by_ranksize. + apply_gradients_strategy: direct_apply (default) or sum_same_id_gradients_and_apply. + + """ + key_dtype = kwargs.get("key_dtype") + dim = kwargs.get("dim") + name = kwargs.get("name") + emb_initializer = kwargs.get("emb_initializer") + device_vocabulary_size = kwargs.get("device_vocabulary_size", 1) + host_vocabulary_size = kwargs.get("host_vocabulary_size", 0) + optimizer_list = kwargs.get("optimizer_list") + mode = kwargs.get("mode", MxRecMode.ASC) + value_dtype = kwargs.get("value_dtype", tf.float32) + shard_num = kwargs.get("shard_num", 1) + fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) + hashtable_threshold = kwargs.get("hashtable_threshold", 0) + is_save = kwargs.get("is_save", True) + init_param = kwargs.get("init_param", 1.0) + all2all_gradients_op = kwargs.get("all2all_gradients_op", All2allGradientsOp.SUM_GRADIENTS) + apply_gradients_strategy = kwargs.get("apply_gradients_strategy", ApplyGradientsStrategy.DIRECT_APPLY) check_create_table_params(key_dtype, dim, name, emb_initializer) config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold, - init_param=init_param, is_save=is_save) + init_param=init_param, is_save=is_save, all2all_gradients_op=all2all_gradients_op, + apply_gradients_strategy=apply_gradients_strategy) embedding = SparseEmbedding(config) return embedding @@ -166,6 +171,9 @@ class SparseEmbedding: self.lookup_name_list = [] self.modify_graph = False self.init_param = config.get("init_param") + self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op")) + self.apply_gradients_strategy = ApplyGradientsStrategy.mapping( + config.get("apply_gradients_strategy")) self.set_slice_vocab_size() self.set_emb_size() @@ -514,7 +522,6 @@ class SparseEmbedding: def grad(lookup_diff): embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - logging.debug(f"bp rank size: {rank_size}") unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static \ else tf.shape(unique_embeddings) unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, @@ -525,9 +532,20 @@ class SparseEmbedding: tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) - + if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: + local_grad = local_grad / get_rank_size() + if use_dynamic_expansion: return local_grad, feat_ids + + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) + unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) + return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, +
dense_shape=tf.shape(table)), feat_ids + return ops.IndexedSlices(values=local_grad, indices=id_offsets, dense_shape=tf.shape(table)), feat_ids return lookup_result, grad @@ -723,7 +741,6 @@ class SparseEmbedding: def grad(lookup_diff): embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - logging.debug(f"bp rank size: {rank_size}") unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, unique_embeddings_shape[0]) @@ -733,11 +750,23 @@ class SparseEmbedding: tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) + if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: + local_grad = local_grad / get_rank_size() + if use_dynamic_expansion: update_grad = local_grad else: - update_grad = ops.IndexedSlices(values=local_grad, indices=id_offsets, - dense_shape=tf.shape(table)) + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) + unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) + update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, + dense_shape=tf.shape(table)) + else: + + update_grad = ops.IndexedSlices(values=local_grad, indices=id_offsets, + dense_shape=tf.shape(table)) return update_grad return lookup_result, grad diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 306208ca..5430373a 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -91,7 +91,6 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): check_param_type("use_locking", self._use_locking, bool) def _create_slots(self, var_list): - logging.debug(" Start _create_slots") for var in var_list: dtype = var.dtype.base_dtype if var.get_shape().is_fully_defined(): diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index b4db8856..9ad62ec6 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -43,7 +43,6 @@ class CustomizedOptimizer: def my_update_op(self, opt, grad): if isinstance(grad, ops.Tensor): - logging.debug(">>>>Enter update_op ops.Tensor") update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access return update_op else: diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index d2c39367..0b4efb66 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -98,15 +98,12 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): return [self._initial_accumulator_value, initial_linear_value] def _apply_sparse_duplicate_indices(self, grad, var): - logging.debug(f"######### _apply_sparse_duplicate_indices {var}") return self._apply_sparse(grad, var) def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - logging.debug(f"######### _resource_apply_sparse_duplicate_indices {indices}") return self._resource_apply_sparse(grad, handle, indices) def _resource_apply_sparse(self, grad, handle, indices): - logging.debug("Enter _resource_apply_sparse") if self._l2_shrinkage_regularization_strength <= 0.0: return self._apply_sparse_shared( grad, @@ -121,7 +118,6 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): self._resource_scatter_nd_update) def 
_apply_sparse(self, grad, var): - logging.debug("Enter _apply_sparse") if self._l2_shrinkage_regularization_strength <= 0.0: return self._apply_sparse_shared( grad.values, @@ -136,7 +132,6 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update): - logging.debug("Enter _apply_sparse_shared") accum = self.get_slot(var, "accum") linear = self.get_slot(var, "linear") lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) @@ -174,7 +169,6 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): return control_flow_ops.group(accum_update_op, linear_update_op, var_update_op) def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update): - logging.debug("Enter _apply_sparse_shared_v2") accum = self.get_slot(var, "accum") linear = self.get_slot(var, "linear") lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) @@ -221,7 +215,6 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): return x.value() def _create_slots(self, var_list): - logging.debug(" Enter _create_slots") # Create slots for the first and second moments. accum_state_name = self._name + "/" + "accum" diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py index 4337841a..12710c57 100644 --- a/mx_rec/optimizers/ftrl_t.py +++ b/mx_rec/optimizers/ftrl_t.py @@ -101,15 +101,12 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): self._grad_factor_tensor = ops.convert_to_tensor(self._grad_factor, name="grad_factor") def _apply_sparse_duplicate_indices(self, grad, var): - logging.debug(f"######### _apply_sparse_duplicate_indices {var}") return self._apply_sparse(grad, var) def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - logging.debug(f"######### _resource_apply_sparse_duplicate_indices {indices}") return self._resource_apply_sparse(grad, handle, indices) def _resource_apply_sparse(self, grad, handle, indices): - logging.debug("Enter _resource_apply_sparse") if self._lambda1 > 1e-10: return self._apply_sparse_shared( grad, @@ -124,7 +121,6 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): self._resource_scatter_nd_update) def _apply_sparse(self, grad, var): - logging.debug("Enter _apply_sparse") if self._lambda1 > 1e-10: return self._apply_sparse_shared( grad.values, @@ -139,7 +135,6 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update): - logging.debug("Enter _apply_sparse_shared") z = self.get_slot(var, "z") n = self.get_slot(var, "n") g = self.get_slot(var, "g") @@ -183,7 +178,6 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update) def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update): - logging.debug("Enter _apply_sparse_shared_v2") z = self.get_slot(var, "z") n = self.get_slot(var, "n") g = self.get_slot(var, "g") @@ -229,7 +223,6 @@ class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): return x.value() def _create_slots(self, var_list): - logging.debug(" Enter _create_slots") # Create slots for the first and second moments. 
z_state_name = self._name + "/" + "z" diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py index 364267e3..4dbbdb50 100644 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ b/mx_rec/optimizers/ftrl_t_dense.py @@ -85,7 +85,6 @@ class CustomizedFtrlTZ(optimizer.Optimizer): var) def _apply_dense_shared(self, grad, var): - logging.debug("Enter _apply_dense_shared") z_var = self.get_slot(var, "z") n_var = self.get_slot(var, "n") g_var = self.get_slot(var, "g") @@ -126,7 +125,6 @@ class CustomizedFtrlTZ(optimizer.Optimizer): return control_flow_ops.group(g_update, z_update, n_update, w_update, var_updata) def _apply_dense_shared_v2(self, grad, var): - logging.debug("Enter _apply_dense_shared_v2") z_var = self.get_slot(var, "z") n_var = self.get_slot(var, "n") g_var = self.get_slot(var, "g") @@ -169,7 +167,6 @@ class CustomizedFtrlTZ(optimizer.Optimizer): return x_input.value() def _create_slots(self, var_list): - logging.debug(" Enter _create_slots") # Create slots for the first and second moments. z_state_name = self._name + "/" + "z" diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index ac2206a1..25747aed 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -42,12 +42,10 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo return [] def _apply_sparse_duplicate_indices(self, grad, var): - logging.debug(" Enter _apply_sparse_duplicate_indices") nd_indices = tf.expand_dims(grad.indices, 1) nd_value = grad.values * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) var_update_op = tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking) return var_update_op def _apply_dense(self, grad, var): - logging.debug(" Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index de00dc06..0fa3cb6e 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -48,7 +48,6 @@ class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, return [] def _apply_sparse(self, grad, addr): - logging.debug(">>>> Enter _apply_sparse SGD by addr") host_pipeline_ops = get_host_pipeline_ops() dim = grad.shape.as_list()[-1] if self.weight_decay is None: @@ -63,7 +62,6 @@ class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, return var_update_op def _apply_dense(self, grad, var): - logging.debug(">>>> Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 42904df0..4f1ee9e9 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -95,15 +95,12 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): def _apply_sparse_duplicate_indices(self, grad, var): # The _apply_sparse_duplicate_indices method includes tf.unique and unsorted_segment_sum operations, which may # introduce a dynamic-shape problem; if you encounter that, please uncomment the method below.
- logging.debug(f"_apply_sparse_duplicate_indices {var}") return self._apply_sparse(grad, var) def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - logging.debug(f"_resource_apply_sparse_duplicate_indices {indices}") return self._resource_apply_sparse(grad, handle, indices) def _apply_dense(self, grad, var): - logging.debug("Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") def _cast_to_base_type(self, var): @@ -121,7 +118,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): return temp def _resource_apply_sparse(self, grad, handle, indices): - logging.debug("Enter _resource_apply_sparse") return self._apply_sparse_shared( grad, handle, @@ -129,7 +125,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): self._resource_scatter_nd_add) def _apply_sparse(self, grad, var): - logging.debug("Enter _apply_sparse") return self._apply_sparse_shared( grad.values, var, @@ -170,7 +165,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): return x.value() def _create_slots(self, var_list): - logging.debug(" Enter _create_slots") first_var = min(var_list, key=lambda x: x.name) self._create_non_slot_variable( initial_value=self._beta1, name="beta1_power", colocate_with=first_var) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 10a785d6..76f6f72d 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -81,7 +81,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): initial_value=self._beta2, name="beta2_power", colocate_with=first_addr) def _apply_dense(self, grad, var): - logging.debug(">>>>Enter _apply_dense") raise NotImplementedError("You are using a wrong type of variable.") def _cast_to_base_type(self, var): @@ -99,7 +98,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): return temp def _apply_sparse(self, grad, addr): - logging.debug(">>>> Enter _apply_sparse Lazy_adam by addr") return self._apply_sparse_shared( grad, addr) diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index df4adaeb..daae8119 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -102,7 +102,6 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): check_param_range("momentum", self._momentum, 0.0, 1.0) def _create_slots(self, var_list): - logging.debug(" Start _create_slots") m_state_name = self._name + "/" + "momentum" for var in var_list: table_instance = check_and_get_config_via_var(var, self.optimizer_type) @@ -110,7 +109,6 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): insert_removing_var_list(momentum_slot.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) - logging.debug(" End _create_slots") def _apply_sparse(self, grad, var): mom = self.get_slot(var, "m") -- Gitee From e781219a2c8c6f55959b1bad9d2d4d6f8fbbb2ae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 17:03:18 +0800 Subject: [PATCH 198/551] Match-id-e656ced39406df725adcf0e35885f8ed1afb56bf --- mx_rec/core/asc/build_graph.py | 74 ++++++++++++++++--------------- mx_rec/util/synchronizer.py | 79 ---------------------------------- 2 files changed, 39 insertions(+), 114 deletions(-) delete mode 100644 mx_rec/util/synchronizer.py diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 
f5ebee29..81bb27a8 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -7,7 +7,6 @@ import logging
 import tensorflow as tf

 import mxrec_pybind
-from mx_rec.constants.constants import AVOID_TENSOR_POS
 from mx_rec.util.initialize import get_use_static
 from mx_rec.util.tf_version_adapter import npu_ops

@@ -35,18 +34,20 @@ def get_restore_vector(config):
     else:
         restore_size = None

-    if use_hot:
-        device_id = int(config.get("device_id"))
-        hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size)
-        restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next(
-            output_types=[tf.int32, tf.int32],
-            output_shapes=[restore_size, [hot_size]],
-            channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')
-    else:
-        restore_vector = npu_ops.gen_npu_ops.get_next(
-            output_types=[tf.int32],
-            output_shapes=[restore_size],
-            channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0]
+    with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE):
+        if use_hot:
+            device_id = int(config.get("device_id"))
+            hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size)
+            restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int32, tf.int32],
+                output_shapes=[restore_size, [hot_size]],
+                channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}'
+            )
+        else:
+            restore_vector = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int32],
+                output_shapes=[restore_size],
+                channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0]

     return restore_vector, hot_pos

@@ -54,23 +55,24 @@ def get_id_offsets(max_lookup_vec_size, config):
     logging.debug(f'Channel {config.get("table_name")}_lookup_{config.get("channel_id")} was built for getnext')
     # Dynamic expansion currently supports only HBM mode; by default there is no swap-in/swap-out
-    if config.get("use_dynamic_expansion"):
+    with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE):
+        if config.get("use_dynamic_expansion"):
+            [id_offsets] = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int64],
+                output_shapes=[[max_lookup_vec_size]],
+                channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
+            return id_offsets, [], 0
+
         [id_offsets] = npu_ops.gen_npu_ops.get_next(
-            output_types=[tf.int64],
+            output_types=[tf.int32],
             output_shapes=[[max_lookup_vec_size]],
             channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
-        return id_offsets, [], 0
-
-    [id_offsets] = npu_ops.gen_npu_ops.get_next(
-        output_types=[tf.int32],
-        output_shapes=[[max_lookup_vec_size]],
-        channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}')
-    if config.get("skip_emb_transfer"):
-        return id_offsets, [], 0
-    swap_pos, swap_len = npu_ops.gen_npu_ops.get_next(
-        output_types=[tf.int32, tf.int32],
-        output_shapes=[[max_lookup_vec_size], []],
-        channel_name=f'{config.get("table_name")}_swap_{config.get("channel_id")}')
+        if config.get("skip_emb_transfer"):
+            return id_offsets, [], 0
+        swap_pos, swap_len = npu_ops.gen_npu_ops.get_next(
+            output_types=[tf.int32, tf.int32],
+            output_shapes=[[max_lookup_vec_size], []],
+            channel_name=f'{config.get("table_name")}_swap_{config.get("channel_id")}')

     return id_offsets, swap_pos, swap_len

@@ -82,14 +84,16 @@ def get_all2all_args(use_static: bool, config: dict) -> list:
     :return: all2all parameters
     """
     all2all_args = None
-    if not use_static:
-        with tf.compat.v1.variable_scope("all2all"):
-            logging.debug(f'Channel
{config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') - all2all_args = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int64], - output_shapes=[[config.get("rank_size"), config.get("rank_size")]], - channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', - name="a2a_get_next")[0] * config.get("emb_size") + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + if not use_static: + with tf.compat.v1.variable_scope("all2all"): + logging.debug( + f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') + all2all_args = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int64], + output_shapes=[[config.get("rank_size"), config.get("rank_size")]], + channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', + name="a2a_get_next")[0] * config.get("emb_size") return all2all_args diff --git a/mx_rec/util/synchronizer.py b/mx_rec/util/synchronizer.py deleted file mode 100644 index 9fba25dd..00000000 --- a/mx_rec/util/synchronizer.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -import logging -import socket -from time import sleep - -from mx_rec.util.initialize import get_rank_id - - -class Communicator: - def __init__(self): - self.socket = socket.socket() - self.host = socket.gethostname() - logging.debug(f"host: {self.host}") - self.port = 12345 - self.rank_id = get_rank_id() - self.local_rank_id = self.rank_id % 8 - self.build_connection() - - def build_connection(self): - if self.local_rank_id == 0: - self.socket.bind((self.host, self.port)) - self.socket.listen(8) - - else: - i = 0 - while True: - try: - self.socket.connect((self.host, self.port)) - break - except ConnectionRefusedError: - sleep(0.01) - - i += 1 - logging.debug(f"Connection failed at the NO.{i} time for local rank id {self.local_rank_id}, " - f"rank id {self.rank_id}") - if i > 200: - raise EnvironmentError(f"Socket connecting over time.") - - logging.debug(f"Connection was build for local rank id {self.local_rank_id}, rank id {self.rank_id}") - - - def server_reply(self): - conn, address = self.socket.accept() - client_data = conn.recv(1024).decode() - logging.debug(f"connecting address:{address}") - logging.debug(f"Receive client msg: {client_data}") - conn.send(b"Acknowledged!") - conn.close() - return client_data - - def client_connect(self): - info = str(self.local_rank_id).encode() - self.socket.send(info) - server_reply = self.socket.recv(1024).decode() - if server_reply != "Acknowledged!": - raise IOError("Got a unexpected string.") - - logging.debug(f"Got the reply from local rank 0 for local rank id {self.local_rank_id}, " - f"rank id {self.rank_id}.") - - self.socket.close() - - -if __name__ == "__main__": - communicator = Communicator() - if communicator.local_rank_id != 0: - communicator.client_connect() - - else: - synchronizer_check_list = [i for i in range(1, 8)] - while synchronizer_check_list: - idx = int(communicator.server_reply()) - synchronizer_check_list.remove(idx) - logging.info(f"Remove NO.{idx} element for synchronizer_check_list.") - - logging.info(f"Saver synchronized.") -- Gitee From a1b3b1e8b5c94b2da8c006e690c67a3da438f362 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 17:25:43 +0800 Subject: [PATCH 199/551] Match-id-e355af09ea63e9dbab35ce0391d4f257dd05977e --- example/little_demo/main.py | 18 +++++++++++------- 
 mx_rec/core/embedding.py          | 7 ++++++-
 mx_rec/util/initialize.py         | 4 ++--
 src/ops_tf/hybrid_dataset_ops.cpp | 5 +++++
 4 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/example/little_demo/main.py b/example/little_demo/main.py
index 37c13eb8..e4ef5e83 100644
--- a/example/little_demo/main.py
+++ b/example/little_demo/main.py
@@ -120,13 +120,17 @@ if __name__ == "__main__":
     SAVING_INTERVAL = 100

     # get init configuration
-    use_mpi = bool(int(os.getenv("USE_MPI", 1)))
-    use_dynamic = int(os.getenv("USE_DYNAMIC", 0))
-    use_hot = bool(int(os.getenv("USE_HOT", 0)))
-    use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
-    use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
-    MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
-    USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
+    try:
+        use_mpi = bool(int(os.getenv("USE_MPI", 1)))
+        use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
+        use_hot = bool(int(os.getenv("USE_HOT", 0)))
+        use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
+        use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
+        MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
+        USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
+    except ValueError as err:
+        raise ValueError("USE_MPI, USE_DYNAMIC, USE_HOT, USE_DYNAMIC_EXPANSION, USE_MULTI_LOOKUP, "
+                         "USE_MODIFY_GRAPH and USE_TIMESTAMP must each be set to 0 or 1.") from err

     # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0
     init(use_mpi=use_mpi,
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 7648a986..3bc52ed7 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -76,13 +76,14 @@ def create_table(**kwargs):
     return embedding

-def sparse_lookup(hashtable, ids, send_count, **kwargs):
+def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs):
     """

     Args:
         hashtable: SparseEmbedding instance to be looked up
         ids: Tensor to lookup from hashtable
         send_count: used to config all2all communication parameters
+        is_train: indicates whether the model is in training mode.
         kwargs:
             dim: not in use
             is_train: not in use
@@ -102,6 +103,9 @@ def sparse_lookup(hashtable, ids, send_count, **kwargs):
         if not isinstance(kwargs.get("modify_graph"), bool):
             raise TypeError("Given name must be a boolean.")

+        if not isinstance(kwargs.get("is_train"), bool):
+            raise TypeError("Given is_train must be a boolean.")
+
     def check_table_legality_for_feature_spec(table, feature_spec):
         # check whether the name of the table exists with FeatureSpec.
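# A stricter variant of the environment-flag parsing added in main.py above,
# sketched for illustration (hypothetical helper, not part of mx_rec): it
# reports the offending variable by name instead of listing every candidate.
import os

def get_bool_env_sketch(name, default="1"):
    raw = os.getenv(name, default)
    if raw not in ("0", "1"):
        raise ValueError(f"{name} must be 0 or 1, got {raw!r}.")
    return raw == "1"

# e.g. use_mpi = get_bool_env_sketch("USE_MPI"); use_hot = get_bool_env_sketch("USE_HOT", "0")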
if table.table_name != feature_spec.table_name: @@ -112,6 +116,7 @@ def sparse_lookup(hashtable, ids, send_count, **kwargs): if not kwargs.get("modify_graph"): raise ValueError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") + kwargs["is_train"] = is_train check_lookup_kwargs() scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name")) with tf.compat.v1.variable_scope(scope_name): diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 46763197..0c4d13f0 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -70,9 +70,9 @@ class ConfigInitializer: self.check_parameters() self.prefetch_batch_number = kwargs.get("prefetch_batch_number", 1) self.if_load = kwargs.get("if_load", False) - if_dynamic = kwargs.get("use_dynamic", 1) + if_dynamic = kwargs.get("use_dynamic", True) - self.use_static = 0 if if_dynamic == 1 else 1 + self.use_static = not if_dynamic self.use_hot = kwargs.get("use_hot", True) self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False) if kwargs.get("bind_cpu", True): diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index ffac87d3..3d071e10 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -584,6 +584,11 @@ public: auto l = lookupVec->flat(); auto r = restoreVecTensor->flat(); + // check whether lookupLen is zero + if (lookupLen == 0) { + throw runtime_error("lookupLen is 0, it causes the denominator to be 0 during division"); + } + // dummy data for (int i { 0 }; i < lookupLen; ++i) { l(i) = i; -- Gitee From e5f8f0979dab995d16b3b14ee59ff8160a34b7e1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 19:31:36 +0800 Subject: [PATCH 200/551] Match-id-2add46ccc11d551124c848cf24f01c0d33f1d363 --- mx_rec/core/asc/helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 7b628fc2..a27a28c2 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -70,8 +70,8 @@ def find_dangling_table(table_names: List[str]): if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': return True - if 'logistic_loss' in table_reachable_tensor.op.name and table_reachable_tensor.op.type == 'AddV2': - return True + # if 'logistic_loss' in table_reachable_tensor.op.name and table_reachable_tensor.op.type == 'AddV2': + # return True if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': -- Gitee From 39e1a8e9f136449b5650c7cb99c3371ef723ae1b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 19:32:35 +0800 Subject: [PATCH 201/551] Match-id-1ef974df1990cdf32562bc55aac682227475fcfc --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 27 +++++++++++++-- tools/perf/fast.sh | 51 ++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 1e775117..32a8beac 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -704,7 +704,9 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) for (const auto& embInfo : mgmtEmbInfo) { ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); if (!remainBatch) { + TimeCost embHdTrans1; EmbHDTransWrap(channelId, batchId, start, iBatch); + TIME_PRINT("embHdTrans1TC TimeCost(ms):{}", 
embHdTrans1.ElapsedMS()); return false; } } @@ -717,7 +719,9 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) if (!isRunning) { return false; } + TimeCost embHdTrans2TC; EmbHDTransWrap(channelId, batchId - 1, start, iBatch); + TIME_PRINT("embHdTrans2TC TimeCost(ms):{}", embHdTrans2TC.ElapsedMS()); TIME_PRINT("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); #endif return true; @@ -727,6 +731,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut) { + TimeCost getAndSendTensorsTC; + TimeCost getTensorsTC; auto& embHashMap = hostHashMaps->embHashMaps.at(embName); if (iBatch == 0) { embHashMap.SetStartCount(); @@ -736,11 +742,17 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, remainBatchOut = false; } - TimeCost getAndSendTensorsTC; auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); + TIME_PRINT("getTensorsTC(ms):{}", getTensorsTC.ElapsedMS()); + hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); vector tmpData; + + TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); + TIME_PRINT("hostHashMapProcessTC(ms):{}", hostHashMapProcessTC.ElapsedMS()); + + TimeCost sendTensorsTC; hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName); @@ -748,6 +760,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } + TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); + TIME_PRINT("getAndSendTensorsTC(ms):{}, channelId:{}", getAndSendTensorsTC.ElapsedMS(), channelId); if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch @@ -765,7 +779,10 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in return; } spdlog::info(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId); + TimeCost hostEmbsTC; hostEmbs->Join(channelId); + TIME_PRINT("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS()); + EmbHDTrans(channelId, batchId); for (int i = 0; i < iBatch - 1; ++i) { @@ -781,12 +798,16 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) EASY_VALUE("mgmtProcess", batchId) spdlog::debug(MGMT + "trans emb, batchId:{}, channelId:{}", batchId, channelId); TimeCost tr; + TimeCost h2dTC; for (const auto& embInfo: mgmtEmbInfo) { auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos; vector h2dEmb; hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! 
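The instrumentation above brackets each pipeline stage (tensor fetch, host hashmap processing, channel sends, H2D/D2H transfer) with a TimeCost stopwatch and a TIME_PRINT when the stage finishes. The same pattern, expressed as a Python context manager purely for illustration (the real timer is the C++ TimeCost class used in these hunks):

    import logging
    import time
    from contextlib import contextmanager

    @contextmanager
    def time_cost(tag):
        # Log the wall-clock duration of the enclosed stage, mirroring TIME_PRINT.
        start = time.perf_counter()
        try:
            yield
        finally:
            logging.info("%s TimeCost(ms): %.1f", tag, (time.perf_counter() - start) * 1000)

    # usage sketch: with time_cost("hostHashMapProcessTC"): process_host_hashmap_batch()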
hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); } + TIME_PRINT("h2dTC(ms):{}", h2dTC.ElapsedMS()); + + TimeCost d2hTC; for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); if (!(skipUpdate && missingKeys.empty())) { @@ -799,7 +820,9 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) } // skip when skip update and empty missing keys hostHashMaps->ClearMissingKeys(embInfo.name); } - TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId:{} channelId:{}", tr.ElapsedMS(), batchId, channelId); + TIME_PRINT("d2hTC(ms):{}", d2hTC.ElapsedMS()); + + TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId: {} channelId:{}", tr.ElapsedMS(), batchId, channelId); } void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo) diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh index 00dce55c..aae6f9d4 100755 --- a/tools/perf/fast.sh +++ b/tools/perf/fast.sh @@ -230,11 +230,56 @@ parse_pipe_3_get_and_send_tensors_with_ddr() { LOG_NOTICE "Pipe-3: Get and Send Tensors (with DDR)" - $(grep 'getAndSendTensorsTC(ms)' $logfile > /dev/null 2>&1) + grep 'parseKeyTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "parseKeyTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + + grep 'getAndSendTensorsTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "--getAndSendTensorsTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'getTensorsTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "----getTensorsTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + $(grep 'hostHashMapProcessTC(ms)' $logfile > /dev/null 2>&1) + if [ $? == 0 ]; then + grep 'hostHashMapProcessTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "----hostHashMapProcessTC, avg=", sum/NR}' + fi + + $(grep 'sendTensorsTC(ms)' $logfile > /dev/null 2>&1) if [ $? == 0 ]; then - grep 'getAndSendTensorsTC(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {print "GetAndSendTensors, avg=", sum/NR}' + grep 'sendTensorsTC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "----sendTensorsTC, avg=", sum/NR}' fi + + $(grep 'embHdTrans1TC(ms)' $logfile > /dev/null 2>&1) + if [ $? 
== 0 ]; then + grep 'embHdTrans1TC(ms)' $logfile | \ + awk -F":" '{sum+=$NF} END {print "--embHdTrans1TC, avg=", sum/NR}' + fi + + grep 'embHdTrans2TC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "--embHdTrans2TC(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'hostEmbsTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "----hostEmbsTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'EmbHDTrans' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "----EmbHDTrans(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'h2dTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "------h2dTC(filter>1000ms): avg=%0.1f\n", sum/count}' + + grep 'd2hTC' $logfile | \ + awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \ + {printf "------d2hTC(filter>1000ms): avg=%0.1f\n", sum/count}' } parse_pipe_3_get_and_send_tensors_sync_without_ddr() -- Gitee From ffe24c10cc0d6a491db0476153164a6f7e7d7993 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 19:37:37 +0800 Subject: [PATCH 202/551] Match-id-400b656f39b7141c6cff1f4e59cc547d6b5f31a4 --- mx_rec/core/asc/helper.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index a27a28c2..ac957463 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -70,9 +70,6 @@ def find_dangling_table(table_names: List[str]): if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': return True - # if 'logistic_loss' in table_reachable_tensor.op.name and table_reachable_tensor.op.type == 'AddV2': - # return True - if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': return True -- Gitee From 2e70c56031fa48a7452638f211cd8b020a42cacb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 20:25:27 +0800 Subject: [PATCH 203/551] Match-id-6bd76afab99f98b039a32b195dd70f04083ec7cc --- mx_rec/core/asc/helper.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 7b628fc2..ac957463 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -70,9 +70,6 @@ def find_dangling_table(table_names: List[str]): if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': return True - if 'logistic_loss' in table_reachable_tensor.op.name and table_reachable_tensor.op.type == 'AddV2': - return True - if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': return True -- Gitee From 6fc96b73d294911a030f772092e3613ef7372d7b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 13 Jul 2023 20:52:15 +0800 Subject: [PATCH 204/551] Match-id-72de4c4179ac6292f1c7d80ab297e0ce64d8a303 --- mx_rec/core/embedding.py | 23 ++++++++--------------- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 26 +++++++++----------------- src/core/key_process/key_process.cpp | 17 +++++++++++------ 3 files changed, 28 insertions(+), 38 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 3bc52ed7..ab5c19aa 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -895,22 +895,15 @@ class 
_EvictHook(tf.compat.v1.train.SessionRunHook): logging.debug(f'Channel {instance.table_name}_evict_{TRAIN_CHANNEL_ID} was built for op ' f'getnext') - use_static = get_use_static() - if use_static: - evict_pos = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[instance.slice_device_vocabulary_size], - channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] - initialized_tensor = instance.emb_initializer( - instance.slice_device_vocabulary_size + instance.embedding_size) * instance.init_param - else: - evict_pos = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[None], - channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}')[0] + evict_pos, evict_len = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[[None], []], + channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}') + + initialized_tensor = instance.emb_initializer( + instance.slice_device_vocabulary_size + instance.embedding_size) * instance.init_param - initialized_tensor = instance.emb_initializer( - tf.shape(evict_pos)[0] + instance.embedding_size) * instance.init_param + initialized_tensor = initialized_tensor[0:evict_len, :] logging.debug(f'evict_pos output shape {evict_pos}, and slice_device_vocabulary_size ' f'{instance.slice_device_vocabulary_size}, ' diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 32a8beac..36424db0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -890,23 +890,15 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) auto evictDevOffset = hostHashMaps->embHashMaps.at(embName).evictDevPos; spdlog::debug(MGMT + "ddr mode, init dev emb: [{}]! evict size on dev :{}", embName, evictDevOffset.size()); - for (const auto& embInfo : mgmtEmbInfo) { - if (embInfo.name != embName) { - continue; - } - if (evictDevOffset.size() > embInfo.devVocabSize) { - spdlog::error(MGMT + "{} overflow! evict pos on dev {} bigger than dev vocabSize {}", - embName, evictDevOffset.size(), embInfo.devVocabSize); - throw runtime_error(fmt::format(MGMT + "{} overflow! evict pos on dev {} bigger than dev vocabSize {}", - embName, evictDevOffset.size(), embInfo.devVocabSize).c_str()); - } - if (mgmtRankInfo.useStatic) { - evictDevOffset.resize(embInfo.devVocabSize, -1); - } - break; - } + vector tmpDataOut; + Tensor tmpData = Vec2TensorI32(evictDevOffset); + tmpDataOut.emplace_back(tmpData); + tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); + + auto evictLen = tmpDataOut.back().flat(); + auto evictSize = static_cast(evictDevOffset.size()); + evictLen(0) = evictSize; - auto tmpData = Vec2TensorI32(evictDevOffset); - hdTransfer->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName); + hdTransfer->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); #endif } diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 818d21ae..82d11a72 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1239,14 +1239,19 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset throw runtime_error(fmt::format("{} overflow! 
init evict dev, evictOffset size {} bigger than dev vocabSize {}",
            embName, offset.size(), embInfos[embName].devVocabSize).c_str());
     }

-    if (rankInfo.useStatic) {
-        offset.resize(embInfos[embName].devVocabSize, -1);
-    }
-    auto trans = Singleton::GetInstance();
+    vector tmpDataOut;
+    Tensor tmpData = Vec2TensorI32(offset);
+    tmpDataOut.emplace_back(tmpData);
+    tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 }));
+
+    auto evictLen = tmpDataOut.back().flat();
+    auto evictSize = static_cast(offset.size());
+    evictLen(0) = evictSize;
+
     // send the evict keys to the device side, where the device initializes the embeddings
-    auto tmpData = Vec2TensorI32(offset);
-    trans->Send(TransferChannel::EVICT, { tmpData }, TRAIN_CHANNEL_ID, embName);
+    auto trans = Singleton::GetInstance();
+    trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName);

     spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size());
 }

-- Gitee

From 7c09a8b6eecfc29aecf794d237011588d137aa55 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 13 Jul 2023 21:04:14 +0800
Subject: [PATCH 205/551] Match-id-5440bfa245f8218fc13af376e1555ff0c683196f

---
 example/little_demo/config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/example/little_demo/config.py b/example/little_demo/config.py
index 122b168f..6dd5d183 100644
--- a/example/little_demo/config.py
+++ b/example/little_demo/config.py
@@ -108,5 +108,6 @@ def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
         custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")

     session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF

     return session_config

-- Gitee

From 123bbc104d821d0ff37ebcf783315e6aa277f08f Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 13 Jul 2023 22:36:57 +0800
Subject: [PATCH 206/551] Match-id-6caca50025f5da902611e64c9b37032a2451ec6c

---
 build/build.sh | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/build/build.sh b/build/build.sh
index 0a975d8f..e1ad0905 100644
--- a/build/build.sh
+++ b/build/build.sh
@@ -148,7 +148,6 @@ gen_wheel_file()
 gen_tar_file()
 {
     cd "${src_path}"
-    mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}"
     mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}"
     cp -r "${src_path}"/../example ../build/"${pkg_dir}"
     cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}"
@@ -168,7 +167,6 @@ clean()
     remove "${ROOT_DIR}"/mx_rec.egg-info
     remove "${ROOT_DIR}"/src/build
     remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)"
-    remove "${ROOT_DIR}"/build/tf1_env
     remove "${ROOT_DIR}"/build/tf2_env
     remove "${ROOT_DIR}"/build/lib
     remove "${ROOT_DIR}"/build/mindxsdk-mxrec
@@ -182,13 +180,6 @@ then
     echo "-----Build AccCTR -----"
     compile_acc_ctr_so_file

-    echo "-----Build Start tf1 -----"
-    source /opt/buildtools/tf1_env/bin/activate
-    compile_so_file "${tf1_path}"
-    collect_so_file
-    gen_wheel_file "${ROOT_DIR}"/tf1_whl
-    deactivate tf1_env
-
     echo "-----Build Start tf2 -----"
     source /opt/buildtools/tf2_env/bin/activate
     compile_so_file "${tf2_path}"

-- Gitee

From 261dc02a30c24cce1b531cc826dcfa21f8c15160 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 14 Jul 2023 11:13:03 +0800
Subject: [PATCH 207/551] Match-id-496097ac7daeced6aeba2688ff103e4a5d6bd43b

---
 mx_rec/core/embedding.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 3f08d105..553acb48 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -53,7 +53,6 @@ def
create_table(**kwargs): apply_gradients_strategy: direct_apply (default) or sum_same_id_gradients_and_apply. """ - check_create_table_params(key_dtype, dim, name, emb_initializer) key_dtype = kwargs.get("key_dtype") dim = kwargs.get("dim") name = kwargs.get("name") @@ -71,6 +70,8 @@ def create_table(**kwargs): all2all_gradients_op = kwargs.get("all2all_gradients_op", All2allGradientsOp.SUM_GRADIENTS) apply_gradients_strategy = kwargs.get("apply_gradients_strategy", ApplyGradientsStrategy.DIRECT_APPLY) + check_create_table_params(key_dtype, dim, name, emb_initializer) + config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, -- Gitee From e97073ca42593ea057c62bf540d2baadfbe5a37c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 14:35:24 +0800 Subject: [PATCH 208/551] Match-id-0a8f830dd65c39ed1228d629b8a74de4f51a4b60 --- src/ops_tf/hybrid_dataset_ops.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 3d071e10..6642be39 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -64,6 +64,11 @@ public: { OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + throw runtime_error(fmt::format("channelId is invalid, It should be in range [0, {})", + MAX_CHANNEL_NUM)); + } + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", fmt::format("ClearChannel channelId invalid. It should be in range " -- Gitee From f11ad26ded4e98ff0fe8a93c9eb7d5004a3c7537 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 17:12:02 +0800 Subject: [PATCH 209/551] Match-id-7b830957814c6930788b179582df581fdb7cc513 --- build/build.sh | 113 ++----------------------- build/build_tf1.sh | 178 +++++++++++++++++++++++++++++++++++++++ build/build_tf2.sh | 163 +++++++++++++++++++++++++++++++++++ mx_rec/core/embedding.py | 16 ++-- src/core/CMakeLists.txt | 7 +- 5 files changed, 361 insertions(+), 116 deletions(-) create mode 100644 build/build_tf1.sh create mode 100644 build/build_tf2.sh diff --git a/build/build.sh b/build/build.sh index e1ad0905..ce0ef514 100644 --- a/build/build.sh +++ b/build/build.sh @@ -12,15 +12,6 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")") ROOT_DIR=$(dirname "${SCRIPT_DIR}") cd "$SCRIPT_DIR" -if [ "$(uname -m)" = "x86_64" ] -then - source /opt/buildtools/tf2_env/bin/activate - tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env - source /opt/buildtools/tf1_env/bin/activate - tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env -fi VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml get_version() { @@ -55,99 +46,15 @@ remove "${pkg_dir}" mkdir "${pkg_dir}" mv version.info "${pkg_dir}" -opensource_path="${ROOT_DIR}"/../opensource/opensource -abseil_src_path=${opensource_path}/abseil -echo "${abseil_src_path}" -abseil_install_path="${ROOT_DIR}"/install/abseil - src_path="${ROOT_DIR}"/src -acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR -cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c cd "${ROOT_DIR}" 
release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz -install_abseil() -{ - remove "${abseil_install_path}" - echo "${abseil_install_path}" - if [[ ! -d "${abseil_install_path}" ]] - then mkdir -p "${abseil_install_path}" - fi - - cd "${abseil_src_path}" - echo "${abseil_src_path}" - remove CMakeCache.txt - cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install - - echo "${project_output_path}"/abseil - mkdir -p "${project_output_path}"/abseil - if [ -d "${abseil_install_path}"/lib64/ ]; then - cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil - elif [ -d "${abseil_install_path}"/lib/ ]; then - cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil - else - echo "${abseil_install_path}"/lib64/ not exist - exit 1 - fi -} - -compile_securec() -{ - if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then - echo "securec is not exist" - exit 1 - fi - - if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${ROOT_DIR}"/platform/securec/src - make -j - fi -} - -compile_so_file() -{ - cd "${src_path}" - chmod u+x build.sh - ./build.sh "$1" "${ROOT_DIR}" - cd .. -} - -compile_acc_ctr_so_file() -{ - cd "${acc_ctr_path}" - chmod u+x build.sh - ./build.sh "release" -} - -collect_so_file() -{ - cd "${src_path}" - remove "${src_path}"/libasc - mkdir -p "${src_path}"/libasc - chmod u+x libasc - - cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc - cp -df "${ROOT_DIR}"/output/*.so* libasc - cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc -} - -gen_wheel_file() -{ - cd "${ROOT_DIR}" - touch "${src_path}"/libasc/__init__.py - remove "${ROOT_DIR}"/mx_rec/libasc - mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec - python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) - mkdir -p "$1" - mv dist/mx_rec*.whl "$1" - remove "${ROOT_DIR}"/mx_rec/libasc -} - gen_tar_file() { cd "${src_path}" + mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" cp -r "${src_path}"/../example ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" @@ -168,27 +75,19 @@ clean() remove "${ROOT_DIR}"/src/build remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" remove "${ROOT_DIR}"/build/tf2_env + remove "${ROOT_DIR}"/build/tf1_env remove "${ROOT_DIR}"/build/lib remove "${ROOT_DIR}"/build/mindxsdk-mxrec } + if [ "$(uname -m)" = "x86_64" ] then - install_abseil - compile_securec - - echo "-----Build AccCTR -----" - compile_acc_ctr_so_file - - echo "-----Build Start tf2 -----" - source /opt/buildtools/tf2_env/bin/activate - compile_so_file "${tf2_path}" - collect_so_file - gen_wheel_file "${ROOT_DIR}"/tf2_whl - deactivate tf2_env - echo "-----Build gen tar -----" + bash ${ROOT_DIR}/build/build_tf1.sh + bash ${ROOT_DIR}/build/build_tf2.sh gen_tar_file + echo "-----Build gen tar finished-----" clean echo "-----Done-----" diff --git a/build/build_tf1.sh b/build/build_tf1.sh new file mode 100644 index 00000000..ef2a9797 --- /dev/null +++ b/build/build_tf1.sh @@ -0,0 +1,178 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. +# Description: build script. 
+# Author: MindX SDK +# Create: 2021 +# History: NA + +set -e +warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } +ARCH="$(uname -m)" +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +ROOT_DIR=$(dirname "${SCRIPT_DIR}") +cd "$SCRIPT_DIR" + +if [ "$(uname -m)" = "x86_64" ] +then + virtualenv -p "$(which python3.7)" tf1_env + source /opt/buildtools/tf1_env/bin/activate + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + +VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml +get_version() { + if [ -f "$VERSION_FILE" ]; then + VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") + if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then + VERSION=${VERSION%.*} + fi + else + VERSION="5.0.T104" + fi +} + +remove() +{ + if [ -d "$1" ]; then + rm -rf "$1" + elif [ -f "$1" ]; then + rm -f "$1" + fi +} + +project_output_path="${ROOT_DIR}"/output/ +remove "${project_output_path}" +remove "${SCRIPT_DIR}/lib" +get_version +export VERSION +echo "MindX SDK mxrec: ${VERSION}" >> ./version.info + +pkg_dir=mindxsdk-mxrec +remove "${pkg_dir}" +mkdir "${pkg_dir}" +mv version.info "${pkg_dir}" + +opensource_path="${ROOT_DIR}"/../opensource/opensource +abseil_src_path=${opensource_path}/abseil +echo "${abseil_src_path}" +abseil_install_path="${ROOT_DIR}"/install/abseil + +src_path="${ROOT_DIR}"/src +acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR +cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c +cd "${ROOT_DIR}" + +install_abseil() +{ + remove "${abseil_install_path}" + echo "${abseil_install_path}" + if [[ ! -d "${abseil_install_path}" ]] + then mkdir -p "${abseil_install_path}" + fi + + cd "${abseil_src_path}" + echo "${abseil_src_path}" + remove CMakeCache.txt + cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install + + echo "${project_output_path}"/abseil + mkdir -p "${project_output_path}"/abseil + if [ -d "${abseil_install_path}"/lib64/ ]; then + cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil + elif [ -d "${abseil_install_path}"/lib/ ]; then + cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil + else + echo "${abseil_install_path}"/lib64/ not exist + exit 1 + fi +} + +compile_securec() +{ + if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then + echo "securec is not exist" + exit 1 + fi + + if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then + cd "${ROOT_DIR}"/platform/securec/src + make -j + fi +} + +compile_so_file() +{ + cd "${src_path}" + chmod u+x build.sh + ./build.sh "$1" "${ROOT_DIR}" + cd .. 
+} + +compile_acc_ctr_so_file() +{ + cd "${acc_ctr_path}" + chmod u+x build.sh + ./build.sh "release" +} + +collect_so_file() +{ + cd "${src_path}" + remove "${src_path}"/libasc + mkdir -p "${src_path}"/libasc + chmod u+x libasc + + cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc + cp -df "${ROOT_DIR}"/output/*.so* libasc + cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc +} + +gen_wheel_file() +{ + cd "${ROOT_DIR}" + touch "${src_path}"/libasc/__init__.py + remove "${ROOT_DIR}"/mx_rec/libasc + mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec + cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec + python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) + mkdir -p "$1" + mv dist/mx_rec*.whl "$1" + remove "${ROOT_DIR}"/mx_rec/libasc +} + +gen_tar_file() +{ + cd "${src_path}" + mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" + mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" + cp -r "${src_path}"/../example ../build/"${pkg_dir}" + cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" + cd ../build + tar -zvcf "${release_tar}" "${pkg_dir}" || { + warn "compression failed, packages might be broken" + } + + mv "${release_tar}" "${SCRIPT_DIR}"/../output/ + +} + + +if [ "$(uname -m)" = "x86_64" ] +then + install_abseil + compile_securec + + echo "-----Build AccCTR -----" + compile_acc_ctr_so_file + + echo "-----Build Start tf1 -----" + virtualenv -p "$(which python3.7)" tf1_env + echo "--tf1 env ${env}---" + source /opt/buildtools/tf1_env/bin/activate + compile_so_file "${tf1_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf1_whl + deactivate tf1_env + echo "-----Build tf1 finished-----" +fi \ No newline at end of file diff --git a/build/build_tf2.sh b/build/build_tf2.sh new file mode 100644 index 00000000..481e3eb0 --- /dev/null +++ b/build/build_tf2.sh @@ -0,0 +1,163 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. +# Description: build script. 
+# Author: MindX SDK +# Create: 2023 +# History: NA + +set -e +warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } +ARCH="$(uname -m)" +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +ROOT_DIR=$(dirname "${SCRIPT_DIR}") +cd "$SCRIPT_DIR" + +if [ "$(uname -m)" = "x86_64" ] +then + virtualenv -p "$(which python3.7)" tf2_env + source /opt/buildtools/tf2_env/bin/activate + tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env +fi + +VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml +get_version() { + if [ -f "$VERSION_FILE" ]; then + VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") + if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then + VERSION=${VERSION%.*} + fi + else + VERSION="5.0.T104" + fi +} + +remove() +{ + if [ -d "$1" ]; then + rm -rf "$1" + elif [ -f "$1" ]; then + rm -f "$1" + fi +} + +project_output_path="${ROOT_DIR}"/output/ +remove "${project_output_path}" +remove "${SCRIPT_DIR}/lib" +get_version +export VERSION +echo "MindX SDK mxrec: ${VERSION}" >> ./version.info + +pkg_dir=mindxsdk-mxrec +remove "${pkg_dir}" +mkdir "${pkg_dir}" +mv version.info "${pkg_dir}" + +opensource_path="${ROOT_DIR}"/../opensource/opensource +abseil_src_path=${opensource_path}/abseil +echo "${abseil_src_path}" +abseil_install_path="${ROOT_DIR}"/install/abseil + +src_path="${ROOT_DIR}"/src +acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR +cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c +cd "${ROOT_DIR}" + +release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz + +install_abseil() +{ + remove "${abseil_install_path}" + echo "${abseil_install_path}" + if [[ ! -d "${abseil_install_path}" ]] + then mkdir -p "${abseil_install_path}" + fi + + cd "${abseil_src_path}" + echo "${abseil_src_path}" + remove CMakeCache.txt + cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install + + echo "${project_output_path}"/abseil + mkdir -p "${project_output_path}"/abseil + if [ -d "${abseil_install_path}"/lib64/ ]; then + cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil + elif [ -d "${abseil_install_path}"/lib/ ]; then + cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil + else + echo "${abseil_install_path}"/lib64/ not exist + exit 1 + fi +} + +compile_securec() +{ + if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then + echo "securec is not exist" + exit 1 + fi + + if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then + cd "${ROOT_DIR}"/platform/securec/src + make -j + fi +} + +compile_so_file() +{ + cd "${src_path}" + chmod u+x build.sh + ./build.sh "$1" "${ROOT_DIR}" + cd .. 
+} + +compile_acc_ctr_so_file() +{ + cd "${acc_ctr_path}" + chmod u+x build.sh + ./build.sh "release" +} + +collect_so_file() +{ + cd "${src_path}" + remove "${src_path}"/libasc + mkdir -p "${src_path}"/libasc + chmod u+x libasc + + cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc + cp -df "${ROOT_DIR}"/output/*.so* libasc + cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc +} + +gen_wheel_file() +{ + cd "${ROOT_DIR}" + touch "${src_path}"/libasc/__init__.py + remove "${ROOT_DIR}"/mx_rec/libasc + mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec + cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec + python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) + mkdir -p "$1" + mv dist/mx_rec*.whl "$1" + remove "${ROOT_DIR}"/mx_rec/libasc +} + +if [ "$(uname -m)" = "x86_64" ] +then + install_abseil + compile_securec + + echo "-----Build AccCTR -----" + compile_acc_ctr_so_file + + echo "-----Build Start tf2 -----" + virtualenv -p "$(which python3.7)" tf2_env + source /opt/buildtools/tf2_env/bin/activate + compile_so_file "${tf2_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf2_whl + + deactivate tf2_env + echo "-----Build tf2 finished -----" +fi \ No newline at end of file diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 553acb48..7f59694a 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -540,17 +540,17 @@ class SparseEmbedding: local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: local_grad = local_grad / get_rank_size() - + if use_dynamic_expansion: return local_grad, feat_ids - + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) - return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, - dense_shape=tf.shape(table)), feat_ids + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) + return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, + dense_shape=tf.shape(table)), feat_ids return ops.IndexedSlices(values=local_grad, indices=id_offsets, dense_shape=tf.shape(table)), feat_ids @@ -765,8 +765,8 @@ class SparseEmbedding: if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, dense_shape=tf.shape(table)) else: diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 3b59fb7a..1afdb368 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -12,7 +12,12 @@ message("SECUREC_PATH: " ${SECUREC_PATH}) include_directories(${ABSEIL_PATH}/include) link_directories(${ABSEIL_PATH}/lib) -include_directories(${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/include) + +if (${TF_PATH} MATCHES "tensorflow_core") + include_directories(${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/include) +else() + 
include_directories(${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow/include) +endif() file(GLOB_RECURSE MXREC_SRC ./*.cpp) add_library(ASC SHARED ${MXREC_SRC}) -- Gitee From 0bdd3b0e508e6449e3635bdeb6cbe11bd4c04f0d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 17:37:10 +0800 Subject: [PATCH 210/551] Match-id-4ea99acd634aaf85d1af7d5587d7fac7bcaae094 --- mx_rec/core/embedding.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 553acb48..48c0fce8 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -12,7 +12,8 @@ import numpy as np import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops.init_ops import Initializer +from tensorflow.python.ops.init_ops import Initializer as InitializerV1 +from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute @@ -540,17 +541,17 @@ class SparseEmbedding: local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: local_grad = local_grad / get_rank_size() - + if use_dynamic_expansion: return local_grad, feat_ids - + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) - return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, - dense_shape=tf.shape(table)), feat_ids + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) + return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, + dense_shape=tf.shape(table)), feat_ids return ops.IndexedSlices(values=local_grad, indices=id_offsets, dense_shape=tf.shape(table)), feat_ids @@ -765,8 +766,8 @@ class SparseEmbedding: if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) + unique_id_offsets_position, + array_ops.shape(unique_id_offsets)[0]) update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, dense_shape=tf.shape(table)) else: @@ -1023,6 +1024,6 @@ def check_create_table_params(key_dtype, dim, name, emb_initializer): name_validator.check_whitelist() name_validator.check() # check emb_initializer - emb_initializer_validator = ClassValidator(value=emb_initializer, classes=Initializer) + emb_initializer_validator = ClassValidator(value=emb_initializer, classes=(InitializerV1, InitializerV2)) emb_initializer_validator.check_isinstance() emb_initializer_validator.check() -- Gitee From 4448ac3c490f7bcaa39a5f89c72614e07173ff24 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 18:33:21 +0800 Subject: [PATCH 211/551] Match-id-19a9dbc3632d3a073c9ae07713bd857635216538 --- .../op_host/embedding_lookup_by_address.cpp | 43 +++++-------------- .../embedding_lookup_by_address_tiling.h | 4 +- 
.../op_host/embedding_update_by_address.cpp | 41 ++++-------------- .../embedding_update_by_address_tiling.h | 4 +- .../op_kernel/embedding_lookup_by_address.cpp | 2 +- .../op_kernel/embedding_update_by_address.cpp | 2 +- 6 files changed, 25 insertions(+), 71 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 3e50e6a4..117efaa3 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -4,11 +4,6 @@ namespace optiling { - struct TilingCompileInfo - { - int64_t ub_size; - }; - static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData1 tiling; @@ -19,9 +14,9 @@ namespace optiling currentWorkspace[0] = sysWorkspaceSize + usrSize; int32_t block_total_nums = 48; - int32_t ub_limit = 160 * 1024; + int32_t ub_limit = 175 * 1024; auto *attrs = context->GetAttrs(); - const auto *attr0_value = attrs->GetAttrPointer(0); + const auto *attr0_value = attrs->GetAttrPointer(0); int32_t embbeding_dim = *attr0_value; const auto *attr1_value = attrs->GetAttrPointer(1); int32_t embbeding_type = *attr1_value; @@ -38,23 +33,11 @@ namespace optiling return ge::GRAPH_SUCCESS; } - - static ge::graphStatus TilingPrepare(gert::TilingParseContext *context) - { - return ge::GRAPH_SUCCESS; - } - - static ge::graphStatus check_op_support(const ge::Operator &op, ge::AscendString &result) - { - std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; - result = ge::AscendString(res_json_str.c_str()); - return 1; - } } namespace ge { - ge::graphStatus InferShape1(gert::InferShapeContext *context) + static ge::graphStatus InferShape1(gert::InferShapeContext *context) { gert::Shape *y_shape = context->GetOutputShape(0); @@ -67,11 +50,7 @@ namespace ge y_shape->SetDim(1, update_dim); return GRAPH_SUCCESS; } - ge::graphStatus InferShapeRange1(gert::InferShapeRangeContext *context) - { - return GRAPH_SUCCESS; - } - ge::graphStatus InferDataType1(gert::InferDataTypeContext *context) + static ge::graphStatus InferDataType1(gert::InferDataTypeContext *context) { int64_t embbeding_type; @@ -126,27 +105,25 @@ namespace ops .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("embedding_dim").AttrType(OPTIONAL).Int(32); this->Attr("embedding_type").AttrType(OPTIONAL).Int(1); + this->SetInferShape(ge::InferShape1) .SetInferDataType(ge::InferDataType1); this->AICore() - .SetTiling(optiling::TilingFunc) - .SetTilingParse(optiling::TilingPrepare) - .SetCheckSupport(optiling::check_op_support); + .SetTiling(optiling::TilingFunc); OpAICoreConfig aicConfig; - aicConfig.AsyncFlag(true) - .DynamicCompileStaticFlag(true) + aicConfig.DynamicCompileStaticFlag(true) .DynamicFormatFlag(true) .DynamicRankSupportFlag(true) .DynamicShapeSupportFlag(true) .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(false) - .RangeLimitValue("limited"); + .PrecisionReduceFlag(false); + this->AICore().AddConfig("ascend910b", aicConfig); this->AICore().AddConfig("ascend910", aicConfig); } }; - OP_ADD(EmbeddingLookupByAddress, optiling::TilingCompileInfo); + OP_ADD(EmbeddingLookupByAddress); } diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h index 12c45086..b91f759b 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h +++ 
b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h @@ -2,13 +2,13 @@ namespace optiling { - BEGIN_TILING_DATA_DEF(TilingData1) +BEGIN_TILING_DATA_DEF(TilingData1) TILING_DATA_FIELD_DEF(int32_t, update_dim); TILING_DATA_FIELD_DEF(int32_t, addr_nums); TILING_DATA_FIELD_DEF(int32_t, ub_limit); TILING_DATA_FIELD_DEF(int32_t, embbeding_type); TILING_DATA_FIELD_DEF(int32_t, update_type); - END_TILING_DATA_DEF; +END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingLookupByAddress, TilingData1) } \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index b03d1a6d..ca43db22 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -4,10 +4,6 @@ namespace optiling { - struct TilingCompileInfo - { - int64_t ub_size; - }; static ge::graphStatus TilingFunc(gert::TilingContext *context) { @@ -19,14 +15,14 @@ namespace optiling currentWorkspace[0] = sysWorkspaceSize + usrSize; int32_t block_total_nums = 48; - int32_t ub_limit = 160 * 1024; + int32_t ub_limit = 175 * 1024; int32_t update_dim; int32_t embbeding_type; int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); int32_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; - int32_t update_type=*(context->GetAttrs()->GetAttrPointer(0)); + int32_t update_type = *(context->GetAttrs()->GetAttrPointer(0)); ge::DataType input_datatype = context->GetInputTensor(1)->GetDataType(); switch (input_datatype) @@ -57,23 +53,11 @@ namespace optiling return ge::GRAPH_SUCCESS; } - - static ge::graphStatus TilingPrepare(gert::TilingParseContext *context) - { - return ge::GRAPH_SUCCESS; - } - - static ge::graphStatus check_op_support(const ge::Operator &op, ge::AscendString &result) - { - std::string res_json_str = "{\"ret_code\": \"0\",\"reason\": \"check_supported_stub\"}"; - result = ge::AscendString(res_json_str.c_str()); - return 1; - } } namespace ge { - ge::graphStatus InferShape(gert::InferShapeContext *context) + static ge::graphStatus InferShape(gert::InferShapeContext *context) { gert::Shape *y_shape = context->GetOutputShape(0); int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); @@ -83,11 +67,7 @@ namespace ge y_shape->SetDim(1, input_dim); return GRAPH_SUCCESS; } - ge::graphStatus InferShapeRange(gert::InferShapeRangeContext *context) - { - return GRAPH_SUCCESS; - } - ge::graphStatus InferDataType(gert::InferDataTypeContext *context) + static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) { context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); return GRAPH_SUCCESS; @@ -121,22 +101,19 @@ namespace ops .SetInferDataType(ge::InferDataType); this->AICore() - .SetTiling(optiling::TilingFunc) - .SetTilingParse(optiling::TilingPrepare) - .SetCheckSupport(optiling::check_op_support); + .SetTiling(optiling::TilingFunc); OpAICoreConfig aicConfig; - aicConfig.AsyncFlag(true) - .DynamicCompileStaticFlag(true) + aicConfig.DynamicCompileStaticFlag(true) .DynamicFormatFlag(true) .DynamicRankSupportFlag(true) .DynamicShapeSupportFlag(true) .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(false) - .RangeLimitValue("limited"); + .PrecisionReduceFlag(false); + this->AICore().AddConfig("ascend910b", aicConfig); this->AICore().AddConfig("ascend910", aicConfig); } }; - OP_ADD(EmbeddingUpdateByAddress, optiling::TilingCompileInfo); + OP_ADD(EmbeddingUpdateByAddress); } diff --git 
a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h index 2a28626c..323014d3 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h @@ -3,13 +3,13 @@ namespace optiling { - BEGIN_TILING_DATA_DEF(TilingData2) +BEGIN_TILING_DATA_DEF(TilingData2) TILING_DATA_FIELD_DEF(int32_t, update_dim); TILING_DATA_FIELD_DEF(int32_t, addr_nums); TILING_DATA_FIELD_DEF(int32_t, ub_limit); TILING_DATA_FIELD_DEF(int32_t, embbeding_type); TILING_DATA_FIELD_DEF(int32_t, update_type); - END_TILING_DATA_DEF; +END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2) } \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index c2edb18f..437e5b25 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -1,6 +1,6 @@ #include "kernel_operator.h" -using namespace tik2; +using namespace AscendC; template class KernelEimtable { diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index 04f1da19..4e45d2ac 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -1,5 +1,5 @@ #include "kernel_operator.h" -using namespace tik2; +using namespace AscendC; template class KernelEimtable_update { -- Gitee From 4f342fd628b4cd64c9d1d458cbd3e01940cb694b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 19:00:29 +0800 Subject: [PATCH 212/551] Match-id-33104cfdec259dc6548c68062869d612f0255479 --- build/build_tf1.sh | 29 ----------------------------- build/build_tf2.sh | 29 ----------------------------- src/build.sh | 7 +------ src/core/CMakeLists.txt | 4 +++- src/test_ut.sh | 2 +- 5 files changed, 5 insertions(+), 66 deletions(-) diff --git a/build/build_tf1.sh b/build/build_tf1.sh index ef2a9797..255a6922 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -54,40 +54,12 @@ mkdir "${pkg_dir}" mv version.info "${pkg_dir}" opensource_path="${ROOT_DIR}"/../opensource/opensource -abseil_src_path=${opensource_path}/abseil -echo "${abseil_src_path}" -abseil_install_path="${ROOT_DIR}"/install/abseil src_path="${ROOT_DIR}"/src acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c cd "${ROOT_DIR}" -install_abseil() -{ - remove "${abseil_install_path}" - echo "${abseil_install_path}" - if [[ ! -d "${abseil_install_path}" ]] - then mkdir -p "${abseil_install_path}" - fi - - cd "${abseil_src_path}" - echo "${abseil_src_path}" - remove CMakeCache.txt - cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install - - echo "${project_output_path}"/abseil - mkdir -p "${project_output_path}"/abseil - if [ -d "${abseil_install_path}"/lib64/ ]; then - cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil - elif [ -d "${abseil_install_path}"/lib/ ]; then - cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil - else - echo "${abseil_install_path}"/lib64/ not exist - exit 1 - fi -} - compile_securec() { if [[ ! 
-d "${ROOT_DIR}"/platform/securec ]]; then @@ -160,7 +132,6 @@ gen_tar_file() if [ "$(uname -m)" = "x86_64" ] then - install_abseil compile_securec echo "-----Build AccCTR -----" diff --git a/build/build_tf2.sh b/build/build_tf2.sh index 481e3eb0..b2397050 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -54,9 +54,6 @@ mkdir "${pkg_dir}" mv version.info "${pkg_dir}" opensource_path="${ROOT_DIR}"/../opensource/opensource -abseil_src_path=${opensource_path}/abseil -echo "${abseil_src_path}" -abseil_install_path="${ROOT_DIR}"/install/abseil src_path="${ROOT_DIR}"/src acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR @@ -65,31 +62,6 @@ cd "${ROOT_DIR}" release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz -install_abseil() -{ - remove "${abseil_install_path}" - echo "${abseil_install_path}" - if [[ ! -d "${abseil_install_path}" ]] - then mkdir -p "${abseil_install_path}" - fi - - cd "${abseil_src_path}" - echo "${abseil_src_path}" - remove CMakeCache.txt - cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install - - echo "${project_output_path}"/abseil - mkdir -p "${project_output_path}"/abseil - if [ -d "${abseil_install_path}"/lib64/ ]; then - cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil - elif [ -d "${abseil_install_path}"/lib/ ]; then - cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil - else - echo "${abseil_install_path}"/lib64/ not exist - exit 1 - fi -} - compile_securec() { if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then @@ -145,7 +117,6 @@ gen_wheel_file() if [ "$(uname -m)" = "x86_64" ] then - install_abseil compile_securec echo "-----Build AccCTR -----" diff --git a/src/build.sh b/src/build.sh index c5728012..94e9707b 100644 --- a/src/build.sh +++ b/src/build.sh @@ -18,18 +18,13 @@ else exit 1 fi -if [ ! -d "$2"/install/abseil/ ]; then - echo "ERROR: $2/install/abseil/ not exist" - exit 1 -fi - cmake -DCMAKE_BUILD_TYPE=Release \ -DTF_PATH="$1" \ -DOMPI_PATH=/usr/local/openmpi/ \ -DPYTHON_PATH="$python_path" \ -DEASY_PROFILER_PATH=/ \ -DASCEND_PATH="$ascend_path" \ - -DABSEIL_PATH="$2"/install/abseil/ \ + -DABSEIL_PATH="$python_path"/lib/python3.7/site-packages/tensorflow_core/ \ -DSECUREC_PATH="$2"/platform/securec \ -DCMAKE_INSTALL_PREFIX="$2"/output .. 
make -j diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 1afdb368..5ff198d8 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -1,8 +1,10 @@ cmake_minimum_required(VERSION 3.12) set(CMAKE_CXX_STANDARD 17) + if(NOT ABSEIL_PATH) - set(ABSEIL_PATH ${PROJECT_SOURCE_DIR}/../install/abseil/) + set(ABSEIL_PATH ${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/) endif() + message("ABSEIL_PATH: " ${ABSEIL_PATH}) if(NOT SECUREC_PATH) diff --git a/src/test_ut.sh b/src/test_ut.sh index 712ce65f..54fde4ee 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -59,7 +59,7 @@ cmake -DCMAKE_BUILD_TYPE=Debug \ -DPYTHON_PATH="$(dirname "$(dirname "$(which python3.7)")")" \ -DEASY_PROFILER_PATH=/opt/buildtools/ \ -DASCEND_PATH=/usr/local/Ascend/ascend-toolkit/latest \ - -DABSEIL_PATH="$(dirname "$(dirname "${PWD}")")"/install/abseil/ \ + -DABSEIL_PATH="$python_path"/lib/python3.7/site-packages/tensorflow_core/ \ -DSECUREC_PATH="${ROOT_DIR}"/platform/securec \ -DBUILD_TESTS=on -DCOVERAGE=on "$(dirname "${PWD}")" -- Gitee From a3e9996bbc62599d5acb00c335ee44ca37879b73 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 13:22:47 +0800 Subject: [PATCH 213/551] Match-id-6a83aaa000db67b6f9b7867d1d1bb26711d92ddd --- build/build.sh | 2 + example/little_demo/run.sh | 4 +- src/CMakeLists.txt | 8 +- src/core/CMakeLists.txt | 4 +- src/core/checkpoint/checkpoint.cpp | 71 +++-- .../ckpt_data_handler/ckpt_data_handler.cpp | 7 +- .../feat_admit_n_evict_ckpt.cpp | 3 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 9 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 7 +- .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 5 +- src/core/emb_hashmap/emb_hashmap.cpp | 99 +++--- src/core/emb_table/emb_table.cpp | 81 +++-- src/core/emb_table/emb_table.h | 2 + src/core/hd_transfer/hd_transfer.cpp | 87 ++--- src/core/host_emb/host_emb.cpp | 117 +++---- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 300 ++++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 + .../constant_initializer.cpp | 7 +- .../random_normal_initializer.cpp | 7 +- .../truncated_normal_initializer.cpp | 7 +- .../key_process/feature_admit_and_evict.cpp | 59 ++-- .../key_process/feature_admit_and_evict.h | 3 +- src/core/key_process/key_process.cpp | 270 +++++++++------- src/core/key_process/key_process.h | 5 +- src/core/utils/common.cpp | 32 +- src/core/utils/common.h | 80 ++++- src/ops_tf/hybrid_dataset_ops.cpp | 148 ++++----- src/platform/AccCTR | 2 +- src/test_ut.sh | 2 +- src/tests/CMakeLists.txt | 2 +- src/tests/checkpoint/checkpoint_test.cpp | 7 - .../ckpt_data_handler_test.cpp | 2 - src/tests/emb_mgmt/emb_mgmt_test.cpp | 29 +- src/tests/emb_table/emb_table_test.cpp | 15 +- src/tests/host_emb/host_emb_test.cpp | 1 - src/tests/initializer/initializer_test.cpp | 1 - .../feature_admit_and_evict_test.cpp | 20 +- src/tests/key_process/key_process_test.cpp | 143 ++++++--- src/tests/utils/common_test.cpp | 50 +++ 39 files changed, 997 insertions(+), 703 deletions(-) create mode 100644 src/tests/utils/common_test.cpp diff --git a/build/build.sh b/build/build.sh index ce0ef514..5eff8d3a 100644 --- a/build/build.sh +++ b/build/build.sh @@ -5,6 +5,8 @@ # Create: 2021 # History: NA +export GLOG_CUSTOM_PREFIX_SUPPORT=1 + set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } ARCH="$(uname -m)" diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 6cf25dee..2cd84a80 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -48,7 +48,9 @@ fi cur_path=`pwd` 
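Patch 213 migrates the C++ core's logging from spdlog to glog: src/CMakeLists.txt builds and installs glog in place of spdlog, run.sh swaps SPDLOG_LEVEL for the GLOG_stderrthreshold/GLOG_logtostderr/GLOG_v variables, and every spdlog:: call site in the hunks that follow becomes a LOG()/VLOG() stream. As a rough sketch of how those variables interact, assuming verbosity constants that follow the "1:DEBUG 2:TRACE" mapping documented in the run.sh comment below (the constants' real definitions are not shown in this patch):

#include <glog/logging.h>

// Assumed constants; per the run.sh comment, GLOG_v=1 enables DEBUG and
// GLOG_v=2 enables TRACE, and glog prints both at INFO severity.
constexpr int GLOG_DEBUG = 1;
constexpr int GLOG_TRACE = 2;

int main(int argc, char* argv[])
{
    // glog reads GLOG_stderrthreshold, GLOG_logtostderr and GLOG_v from
    // the environment when initialized, so the mpi_args exports below
    // control what this program emits.
    google::InitGoogleLogging(argv[0]);
    LOG(INFO) << "visible on stderr with GLOG_logtostderr=true or GLOG_stderrthreshold=0";
    VLOG(GLOG_DEBUG) << "emitted only when GLOG_v >= 1";
    VLOG(GLOG_TRACE) << "emitted only when GLOG_v >= 2";
    return 0;
}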
mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec" # please config so_path=${mx_rec_package_path}/libasc -mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x SPDLOG_LEVEL=debug -bind-to none' +# GLOG_stderrthreshold 0:INFO 1:WARNING 2:ERROR 3:FATAL +# GLOG_v 1:DEBUG(print as INFO) 2:TRACE(print as INFO) +mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -x GLOG_v=0 -bind-to none' interface="lo" local_rank_size=8 # 每个节点使用的NPU卡数 num_server=1 # 训练节点数 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6e07d657..3f4118bd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -94,13 +94,13 @@ endif() if(IS_DIRECTORY ${OPENSOURCE_DIR}) add_subdirectory(${OPENSOURCE_DIR}/pybind11 pybind11.out) - add_subdirectory(${OPENSOURCE_DIR}/spdlog spdlog.out) + add_subdirectory(${OPENSOURCE_DIR}/glog glog.out) else() - message(FATAL_ERROR "INVALID FOLDER") + message(FATAL_ERROR "INVALID FOLDER, ${OPENSOURCE_DIR}") endif() -include_directories(${PROJECT_SOURCE_DIR}/../../opensource/opensource/spdlog/include) -install(TARGETS spdlog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) +include_directories(${PROJECT_SOURCE_DIR}/../../opensource/opensource/glog/include) +install(TARGETS glog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) add_subdirectory(core) add_subdirectory(ops_tf) add_subdirectory(pybind) diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 5ff198d8..266f4335 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -1,6 +1,8 @@ cmake_minimum_required(VERSION 3.12) set(CMAKE_CXX_STANDARD 17) +set(WITH_CUSTOM_PREFIX ON) + if(NOT ABSEIL_PATH) set(ABSEIL_PATH ${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/) endif() @@ -46,7 +48,7 @@ target_link_libraries(ASC PUBLIC -l:_tf_adapter.so OpenMP::OpenMP_CXX ${MPI_CXX_LIBRARIES} ${PYTHON_LIBRARY} - PRIVATE spdlog::spdlog + PUBLIC glog::glog ) find_package(easy_profiler PATHS ${EASY_PROFILER_PATH} NO_DEFAULT_PATH) if (easy_profiler_FOUND) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 02343b7a..222f8565 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -5,7 +5,6 @@ * Create: 2022-11-15 */ -#include #include #include #include @@ -30,12 +29,12 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; mgmtEmbInfo = EmbInfo; - spdlog::info("Start host side saving data."); - spdlog::debug("==Start to create save data handler."); + LOG(INFO) << "Start host side saving data."; + VLOG(GLOG_DEBUG) << "==Start to create save data handler."; SetDataHandler(ckptData); - spdlog::debug("==Start save data process."); + VLOG(GLOG_DEBUG) << "==Start save data process."; SaveProcess(ckptData); - spdlog::info("Finish host side saving data."); + LOG(INFO) << "Finish host side saving data."; } void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, @@ -47,12 +46,12 @@ void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRa useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; mgmtEmbInfo = EmbInfo; - spdlog::info("Start host side loading data."); - spdlog::debug("==Start to create load data handler."); + LOG(INFO) << "Start host side loading data."; + VLOG(GLOG_DEBUG) << "==Start to create load data handler."; SetDataHandler(featureTypes); - spdlog::debug("==Start load data process."); + VLOG(GLOG_DEBUG) << "==Start load 
data process."; LoadProcess(ckptData); - spdlog::info("Finish host side loading data."); + LOG(INFO) << "Finish host side loading data."; } void Checkpoint::SetDataHandler(CkptData& ckptData) @@ -137,7 +136,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) { if (access(dirName.c_str(), F_OK) == -1) { if (mkdir(dirName.c_str(), dirMode) == -1) { - spdlog::debug("Unable to create directory: {}", dirName); + VLOG(GLOG_DEBUG) << StringFormat("Unable to create directory: %s", dirName.c_str()); } } } @@ -177,7 +176,7 @@ void Checkpoint::SaveDataset(const vector& embNames, auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType }; - spdlog::debug("====Start getting data from handler to: {}", datasetDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start getting data from handler to: %s", datasetDir.c_str()); auto transData { dataHandler->GetDataset(saveDataType, embName) }; // save embedding when dynamic expansion is open @@ -186,13 +185,13 @@ void Checkpoint::SaveDataset(const vector& embNames, auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto embeddingSize = GetEmbeddingSize(embName); MakeSaveDir(embedPath); - spdlog::debug("====Start saving embedding data to: {}", datasetDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start saving embedding data to: %s", datasetDir.c_str()); WriteEmbedding(transData, embedDatasetDir, embeddingSize); } - spdlog::debug("====Start saving data to: {}", datasetDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start saving data to: %s", datasetDir.c_str()); WriteStream(transData, datasetDir, transData.datasetSize, saveDataType); - spdlog::debug("====Start saving data to: {}", attributeDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start saving data to: %s", attributeDir.c_str()); WriteStream(transData, attributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE); } } @@ -206,8 +205,8 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { - spdlog::error("Set device failed, device_id:{}", deviceId); - throw runtime_error(fmt::format("Set device failed, device_id:{}", deviceId).c_str()); + LOG(ERROR) << StringFormat("Set device failed, device_id:%d", deviceId); + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); } auto &transArr = transData.int64Arr; @@ -220,8 +219,8 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMemcpy failed, ret={}", ret); - throw runtime_error(fmt::format("aclrtMemcpy failed, ret={}", ret).c_str()); + LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } writeFile.write((const char *) (row.data()), embeddingSize * sizeof(float)); @@ -238,8 +237,8 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { - spdlog::error("Set device failed, device_id:{}", deviceId); - throw runtime_error(fmt::format("Set device failed, device_id:{}", deviceId).c_str()); + LOG(ERROR) << StringFormat("Set device failed, device_id:%d", deviceId); + throw 
runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); } auto &AttributeArr = transData.attribute; @@ -252,8 +251,8 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) void *newBlock = nullptr; ret = aclrtMalloc(&newBlock, static_cast(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMalloc failed, ret={}", ret); - throw runtime_error(fmt::format("aclrtMemcpy failed, ret={}", ret).c_str()); + LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } float *floatPtr = static_cast(newBlock); @@ -266,8 +265,8 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMemcpy failed, ret={}", ret); - throw runtime_error(fmt::format("aclrtMemcpy failed, ret={}", ret).c_str()); + LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } int64_t address = reinterpret_cast(floatPtr + i * embeddingSize); @@ -283,7 +282,7 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); if (!writeFile.is_open()) { - spdlog::debug("unable to open save file: {}", dataDir); + VLOG(GLOG_DEBUG) << StringFormat("unable to open save file: %s", dataDir.c_str()); writeFile.close(); return; } @@ -393,7 +392,7 @@ void Checkpoint::LoadDataset(const vector& embNames, CkptTransData transData; - spdlog::debug("====Start reading data from: {}", attributeDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start reading data from: %s", attributeDir.c_str()); auto dataElmtBytes { dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE) }; ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes); @@ -402,7 +401,7 @@ void Checkpoint::LoadDataset(const vector& embNames, ReadStreamForEmbData(transData, datasetDir, dataElmtBytes, ckptData, embName); continue; } else { - spdlog::debug("====Start reading data from: {}", datasetDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start reading data from: %s", datasetDir.c_str()); ReadStream(transData, datasetDir, saveDataType, dataElmtBytes); } @@ -410,11 +409,13 @@ void Checkpoint::LoadDataset(const vector& embNames, if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { auto embedPath { dataDir + dirSeparator + "key_embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; - spdlog::debug("====Start loading embedding data from: {}", datasetDir); + VLOG(GLOG_DEBUG) << StringFormat("====Start loading embedding data from: %s", datasetDir.c_str()); ReadEmbedding(transData, embedDatasetDir); } - spdlog::debug("====Start loading data from: {} to data handler.", attributeDir); + VLOG(GLOG_DEBUG) << StringFormat( + "====Start loading data from: %s to data handler.", attributeDir.c_str() + ); if ((saveDataType == CkptDataType::EMB_INFO)) { dataHandler->SetDatasetForLoadEmb(saveDataType, embName, transData, ckptData); } else { @@ -430,7 +431,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, uint32_t dataElmtBytes) { if (dataElmtBytes == 0) { - spdlog::warn("dataElmtBytes is 0, don't handle [/ %] operation"); + LOG(WARNING) << "dataElmtBytes is 0, don't handle [/ %] operation"; return ; } 
std::ifstream readFile; @@ -440,7 +441,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, readFile.seekg(0, std::ios::beg); if (datasetSize % dataElmtBytes > 0) { - spdlog::debug("data is missing or incomplete in load file: {}", dataDir); + VLOG(GLOG_DEBUG) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); } auto resizeSize { datasetSize / dataElmtBytes }; SetTransDataSize(transData, resizeSize, dataType); @@ -459,7 +460,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, idx += readSize; } } else { - spdlog::debug("unable to open load file: {}", dataDir); + VLOG(GLOG_DEBUG) << StringFormat("unable to open load file: %s", dataDir.c_str()); } readFile.close(); @@ -472,7 +473,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, string embName) { if (dataElmtBytes == 0) { - spdlog::error("dataElmtBytes is 0, don't handle [/ %] operation"); + LOG(ERROR) << "dataElmtBytes is 0, don't handle [/ %] operation"; return ; } @@ -489,13 +490,13 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, readFile.seekg(0, std::ios::beg); if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { - spdlog::error("data is missing or incomplete in load file: {}", dataDir); + LOG(ERROR) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } auto onceReadByteSize { datasetSize / embDataOuterSize }; if (!readFile.is_open()) { - spdlog::debug("unable to open load file: {}", dataDir); + VLOG(GLOG_DEBUG) << StringFormat("unable to open load file: %s", dataDir.c_str()); readFile.close(); return; } diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index 92ffb8ba..9f5a5522 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -4,7 +4,6 @@ * Author: MindX SDK * Create: 2022-11-12 */ -#include #include "ckpt_data_handler.h" @@ -34,7 +33,9 @@ void CkptDataHandler::CleanTransfer() void CkptDataHandler::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) { - spdlog::error("Load host emb failed. dataType:{}, embName:{}, loadedData:{}, ckptData:{}", dataType, embName, - loadedData.datasetSize, ckptData.embHashMaps.empty()); + LOG(ERROR) << StringFormat( + "Load host emb failed. 
dataType:%d, embName:%s, loadedData:%d, ckptData:%d", + dataType, embName.c_str(), loadedData.datasetSize, ckptData.embHashMaps.empty() + ); throw runtime_error("only EMB_INFO and EMB_DATA supported for load host emb"); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 985ed29f..1e4e9d69 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -5,7 +5,6 @@ * Create: 2022-11-22 */ -#include #include "feat_admit_n_evict_ckpt.h" @@ -17,7 +16,7 @@ void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) ClearData(); if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { - spdlog::error("Missing Feature Admit and Evict data"); + LOG(ERROR) << "Missing Feature Admit and Evict data"; throw std::runtime_error("Missing Feature Admit and Evict data"); } saveTens2Thresh = std::move(processData.tens2Thresh); diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index f236637c..9b9ffd6f 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -5,7 +5,6 @@ * Create: 2022-11-12 */ -#include #include "host_emb_ckpt.h" @@ -25,7 +24,7 @@ void HostEmbCkpt::GetProcessData(CkptData& processData) { saveHostEmbs = nullptr; loadHostEmbs = nullptr; - spdlog::info("processData.embHashMaps.empty():{}", processData.embHashMaps.empty()); + LOG(INFO) << StringFormat("processData.embHashMaps.empty():%d", processData.embHashMaps.empty()); } vector HostEmbCkpt::GetDataTypes() @@ -60,7 +59,9 @@ CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - spdlog::info("Parameter dataType:{}, embName:{}, loadedData:{}", dataType, embName, loadedData.datasetSize); + LOG(INFO) << StringFormat( + "Parameter dataType:%d, embName:%s, loadedData:%d", dataType, embName.c_str(), loadedData.datasetSize + ); return; } @@ -118,7 +119,7 @@ void HostEmbCkpt::SetEmbInfo(string embName, CkptData& ckptData) // load Emb data void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) { - spdlog::info("Parameter embName:{}, ckptData:{}", embName, ckptData.embHashMaps.empty()); + LOG(INFO) << StringFormat("Parameter embName:%s, ckptData:%d", embName.c_str(), ckptData.embHashMaps.empty()); return; } diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index 2d3bc5a3..6379ed18 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -4,9 +4,6 @@ * Author: MindX SDK * Create: 2022-11-17 */ -#include -#include - #include "nddr_feat_map_ckpt.h" @@ -67,7 +64,7 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) transArr.push_back(it.first); transArr.push_back(it.second); } - spdlog::info("CkptDataType::EMB_INFO:{}, dataType{} is", CkptDataType::EMB_INFO, dataType); + LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType%d is", CkptDataType::EMB_INFO, dataType); return 
move(transferData); } @@ -86,5 +83,5 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran int64_t key { transArr.at(i) }; hostHashMap[key] = transArr.at(i + 1); } - spdlog::info("dataType{} is", dataType); + LOG(INFO) << StringFormat("dataType%d is", dataType); } diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp index 1dd065ca..356fb80d 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -5,7 +5,6 @@ * Create: 2022-11-17 */ -#include #include "nddr_offset_ckpt.h" @@ -53,7 +52,7 @@ CkptTransData NddrOffsetCkpt::GetDataset(CkptDataType dataType, string embName) transferData.attribute.push_back(1); transferData.attribute.push_back(fourBytes); transferData.attributeSize = transferData.attribute.size() * eightBytes; - spdlog::info("CkptDataType::EMB_INFO:{}, dataType:{} is", CkptDataType::EMB_INFO, dataType); + LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d is", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -62,5 +61,5 @@ void NddrOffsetCkpt::SetDataset(CkptDataType dataType, string embName, CkptTrans CleanTransfer(); transferData = move(loadedData); loadMaxOffset[embName] = transferData.int32Arr.front(); - spdlog::info("CkptDataType::EMB_INFO:{}, dataType:{} is", CkptDataType::EMB_INFO, dataType); + LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d is", CkptDataType::EMB_INFO, dataType); } diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index b239f292..0bd54d63 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -7,21 +7,21 @@ #include "emb_hashmap.h" #include -#include #include #include -#include #include "hd_transfer/hd_transfer.h" #include "checkpoint/checkpoint.h" +#include "utils/common.h" using namespace MxRec; void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad) { +#ifndef GTEST this->rankInfo = rankInfo; if (!ifLoad) { EmbHashMapInfo embHashMap; - spdlog::info("init emb hash map from scratch"); + LOG(INFO) << "init emb hash map from scratch"; for (const auto& embInfo: embInfos) { embHashMap.devOffset2Batch.resize(embInfo.devVocabSize); embHashMap.devOffset2Key.resize(embInfo.devVocabSize); @@ -31,10 +31,18 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, fill(embHashMap.devOffset2Batch.begin(), embHashMap.devOffset2Batch.end(), -1); fill(embHashMap.devOffset2Key.begin(), embHashMap.devOffset2Key.end(), -1); embHashMaps[embInfo.name] = embHashMap; - spdlog::trace("devOffset2Key, {}", embHashMaps.at(embInfo.name).devOffset2Key); - spdlog::trace("devOffset2Batch, {}", embHashMaps.at(embInfo.name).devOffset2Batch); + + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat( + "devOffset2Key, %s", VectorToString(embHashMaps.at(embInfo.name).devOffset2Key).c_str() + ); + VLOG(GLOG_TRACE) << StringFormat( + "devOffset2Batch, %s", VectorToString(embHashMaps.at(embInfo.name).devOffset2Batch).c_str() + ); + } } } +#endif } void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, @@ -49,14 +57,14 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t auto keepBatch = swapId - iBatch; bool findOffsetV2 = getenv("FIND_OFFSET_V2") != nullptr; - spdlog::debug("FindOffset, {}", findOffsetV2); + VLOG(GLOG_DEBUG) << 
StringFormat("FindOffset, %s", findOffsetV2); if (findOffsetV2) { FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); } else { FindOffset(embName, keys, swapId, keepBatch, channelId); } - spdlog::debug("FindOffset end"); + VLOG(GLOG_DEBUG) << "FindOffset end"; swapId++; EASY_BLOCK("hostHashMaps->tdt") @@ -68,7 +76,9 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t for (int i = 0; i < lookUpVecSize; i++) { lookupTensorData(i) = static_cast(embHashMap.lookUpVec[i]); } - spdlog::trace("lookupTensor, {}", embHashMap.lookUpVec); + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat("lookupTensor, %s", VectorToString(embHashMap.lookUpVec).c_str()); + } auto swapSize = static_cast(embHashMap.swapPos.size()); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); @@ -77,13 +87,15 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t swapTensorData(i) = static_cast(embHashMap.swapPos[i]); } if (swapSize > 0) { - spdlog::debug("swap num: {}", swapSize); + VLOG(GLOG_DEBUG) << StringFormat("swap num: %d", swapSize); + } + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat("swapTensor, %s", VectorToString(embHashMap.swapPos).c_str()); } - spdlog::trace("swapTensor, {}", embHashMap.swapPos); embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); - spdlog::info("current dev emb usage:{}/[{}+{}]", embHashMap.maxOffset, embHashMap.devVocabSize, - embHashMap.hostVocabSize); + LOG(INFO) << StringFormat("current dev emb usage:%d/[%d+%d]", embHashMap.maxOffset, embHashMap.devVocabSize, + embHashMap.hostVocabSize); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; @@ -146,25 +158,27 @@ int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashM } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 offset = static_cast(embHashMap.evictDevPos.back()); embHashMap.hostHashMap[key] = offset; - spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictDevPos.size()); + VLOG(GLOG_TRACE) << StringFormat( + "ddr mode, dev evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + key, offset, embHashMap.evictDevPos.size()); embHashMap.evictDevPos.pop_back(); } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 offset = static_cast(embHashMap.evictPos.back()); embHashMap.hostHashMap[key] = offset; - spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictPos.size()); + VLOG(GLOG_TRACE) << StringFormat( + "ddr mode, host evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { embHashMap.hostHashMap[key] = embHashMap.maxOffset; offset = static_cast(embHashMap.maxOffset); embHashMap.maxOffset++; if (embHashMap.maxOffset == embHashMap.devVocabSize) { - spdlog::info("start using host vocab!"); + LOG(INFO) << "start using host vocab!"; } if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - spdlog::error("hostVocabSize too small! dev:{} host:{}", embHashMap.devVocabSize, - embHashMap.hostVocabSize); + LOG(ERROR) << StringFormat("hostVocabSize too small! 
dev:%d host:%d", embHashMap.devVocabSize, + embHashMap.hostVocabSize); throw runtime_error("hostVocabSize too small"); } } @@ -212,7 +226,7 @@ void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId embHashMap.currentUpdatePos = 0; } if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - spdlog::error("devVocabSize is too small"); + LOG(ERROR) << "devVocabSize is too small"; throw runtime_error("devVocabSize is too small"); } } @@ -268,14 +282,14 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& size_t offset; auto key = keys[i]; if (key == -1) { - spdlog::warn("evict key equal -1!"); + LOG(WARNING) << "evict key equal -1!"; continue; } const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; embHashMap.hostHashMap.erase(iter); - spdlog::trace("evict embName {} , offset , {}", embName, offset); + VLOG(GLOG_TRACE) << StringFormat("evict embName %s , offset , %d", embName.c_str(), offset); } else { // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 continue; @@ -291,9 +305,13 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& } } - spdlog::info("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {} ", - embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); - spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); + LOG(INFO) << StringFormat( + "ddr EvictDeleteEmb, emb: [%s], hostEvictSize: %d, devEvictSize: %d ", + embName.c_str(), embHashMap.evictPos.size(), embHashMap.evictDevPos.size() + ); + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat("hostHashMap, %s", MapToString(embHashMaps[embName].hostHashMap).c_str()); + } } // old version @@ -325,9 +343,11 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } } if (currentBatchId == 0) { - spdlog::info("max offset {}", embHashMap.maxOffset); + LOG(INFO) << StringFormat("max offset %d", embHashMap.maxOffset); + } + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat("hostHashMap, %s", MapToString(embHashMaps[embName].hostHashMap).c_str()); } - spdlog::trace("hostHashMap, {}", embHashMaps[embName].hostHashMap); } @@ -338,18 +358,21 @@ size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHas const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; - spdlog::trace("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset); + VLOG(GLOG_TRACE) << StringFormat("devVocabSize, %d , offset , %d", embHashMap.devVocabSize, offset); } else if (embHashMap.evictDevPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // 优先复用hbm表 offset = embHashMap.evictDevPos.back(); embHashMap.hostHashMap[key] = offset; - spdlog::trace("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictDevPos.size()); + VLOG(GLOG_TRACE) << StringFormat( + "ddr mode, dev evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + key, offset, embHashMap.evictDevPos.size() + ); embHashMap.evictDevPos.pop_back(); } else if (embHashMap.evictPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // hbm不足,再复用ddr表 offset = embHashMap.evictPos.back(); embHashMap.hostHashMap[key] = offset; - spdlog::trace("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictPos.size()); + VLOG(GLOG_TRACE) << StringFormat( + "ddr mode, 
host evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { if (channelId == TRAIN_CHANNEL_ID) { @@ -357,11 +380,11 @@ size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHas offset = embHashMap.maxOffset; embHashMap.maxOffset++; if (embHashMap.maxOffset == embHashMap.devVocabSize) { - spdlog::info("start using host vocab!"); + LOG(INFO) << ("start using host vocab!"); } if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - spdlog::error("hostVocabSize too small! dev:{} host:{}", embHashMap.devVocabSize, - embHashMap.hostVocabSize); + LOG(ERROR) << StringFormat( + "hostVocabSize too small! dev:%d host:%d", embHashMap.devVocabSize, embHashMap.hostVocabSize); throw runtime_error("hostVocabSize too small"); } } else { @@ -384,7 +407,7 @@ void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatc if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; - spdlog::trace("key will be used, {} , offset , {}", key, offset); + VLOG(GLOG_TRACE) << StringFormat("key will be used, %d , offset , %d", key, offset); if (offset < embHashMap.devVocabSize) { embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); } @@ -421,7 +444,7 @@ int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostO embHashMap.currentUpdatePos = 0; } if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - spdlog::error("devVocabSize is too small"); + LOG(ERROR) << "devVocabSize is too small"; throw runtime_error("devVocabSize is too small"); } } @@ -456,7 +479,7 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos embHashMap.currentUpdatePos = 0; } if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - spdlog::error("devVocabSize is too small"); + LOG(ERROR) << "devVocabSize is too small"; throw runtime_error("devVocabSize is too small"); } } diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 7973e528..3ccbe345 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include #include #include "acl/acl_base.h" #include "utils/common.h" @@ -25,11 +23,12 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) #ifndef GTEST this->rankInfo = rInfo; this->seed = seed; - spdlog::info("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); + LOG(INFO) << StringFormat( + "EmbTable init, deviceID %d, embSize %d running", rInfo.deviceId, embInfo.extEmbeddingSize); // 计算embedding table需要分配的内存块数 auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); if (ret != ACL_ERROR_NONE) { - spdlog::error("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret); + LOG(ERROR) << StringFormat("Set device failed, device_id:%d, ret=%d", rInfo.deviceId, ret); throw AclError(); } embSize = embInfo.extEmbeddingSize; @@ -39,7 +38,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) void *newBlock = nullptr; aclError ret = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMalloc failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); throw AclError(); } if (newBlock == nullptr) { @@ -54,7 +53,9 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) } } 
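Because glog cannot format containers the way spdlog's "{}" could, the emb_hashmap.cpp hunks above serialize them explicitly with VectorToString/MapToString and wrap those call sites in VLOG_IS_ON(GLOG_TRACE) guards, so the serialization cost is only paid when trace verbosity is actually enabled. A minimal sketch of the pattern; the VectorToString body and the GLOG_TRACE constant are assumptions, as utils/common is not shown:

#include <glog/logging.h>

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

constexpr int GLOG_TRACE = 2;  // assumed, per the run.sh comment above

// Assumed shape of the utils/common helper: join elements with commas.
template <typename T>
std::string VectorToString(const std::vector<T>& values)
{
    std::ostringstream oss;
    for (size_t i = 0; i < values.size(); ++i) {
        if (i != 0) {
            oss << ", ";
        }
        oss << values[i];
    }
    return oss.str();
}

void TraceOffsets(const std::vector<int64_t>& devOffset2Key)
{
    // Guard first: building the string for a large vector is wasted work
    // unless GLOG_v is high enough for the message to be emitted at all.
    if (VLOG_IS_ON(GLOG_TRACE)) {
        VLOG(GLOG_TRACE) << "devOffset2Key, " << VectorToString(devOffset2Key);
    }
}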
totalCapacity = static_cast(memoryList.size()); - spdlog::info("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); + LOG(INFO) << StringFormat( + "aclrtMalloc success, emb name:%s, total capacity:%d", embInfo.name.c_str(), totalCapacity + ); #endif } @@ -65,7 +66,7 @@ EmbTable::~EmbTable() // 释放内存块 aclError ret = aclrtFree(block); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtFree failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtFree failed, ret=%d", ret); } } #endif @@ -77,11 +78,11 @@ int64_t EmbTable::GetEmbAddress() #ifndef GTEST if (embeddingList.empty()) { PrintStatus(); - spdlog::debug("GetEmbAddress, embedding_list size: empty! Add block!"); + VLOG(GLOG_DEBUG) << "GetEmbAddress, embedding_list size: empty! Add block!"; void *addBlock = nullptr; aclError ret = aclrtMalloc(&addBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMalloc failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); throw AclError(); } if (addBlock == nullptr) { @@ -105,43 +106,49 @@ int64_t EmbTable::GetEmbAddress() // 将一个emb地址放入embeddingList中 void EmbTable::PutEmbAddress(int64_t curAddress) { +#ifndef GTEST embeddingList.push_back(reinterpret_cast(curAddress)); usedCapacity--; +#endif } void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos, int seed) { #ifndef GTEST - spdlog::info("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); + LOG(INFO) << StringFormat( + "Device GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size()); vector devEmb(blockSize); for (auto initializeInfo: initializeInfos) { Initializer* initializer; switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { - spdlog::info("Device GenerateEmbData ing using Constant Initializer by value {}. name {}, start {}, " - "len {}.", initializeInfo.constantInitializerInfo.constantValue, - initializeInfo.name, initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + "Device GenerateEmbData ing using Constant Initializer by value %d. name %s, start %d, len %d.", + initializeInfo.constantInitializerInfo.constantValue, + initializeInfo.name.c_str(), initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.constantInitializer; break; } case InitializerType::TRUNCATED_NORMAL: { - spdlog::info("Device GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}. " - "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, - initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + "Device GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. " + "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.truncatedNormalInitializer; break; } case InitializerType::RANDOM_NORMAL: { - spdlog::info("Device GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}. " - "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, - initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + "Device GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. 
" + "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.randomNormalInitializer; break; } default: { - spdlog::warn("Device Invalid Initializer Type. Using default Constant Initializer with value 0."); + LOG(WARNING) << "Device Invalid Initializer Type. Using default Constant Initializer with value 0."; ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } @@ -150,12 +157,18 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali initializer->GenerateData(&devEmb[i * embSize], embSize); } } - spdlog::info("Device GenerateEmbData End, seed:{}", seed); - aclError ret = aclrtMemcpy(newBlock, blockSize * sizeof(float), - devEmb.data(), blockSize * sizeof(float), - ACL_MEMCPY_HOST_TO_DEVICE); + LOG(INFO) << StringFormat("Device GenerateEmbData End, seed:%d", seed); + ExecuteAclMemcpy(newBlock, devEmb); +#endif +} + +void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) +{ +#ifndef GTEST + aclError ret = aclrtMemcpy( + newBlock, blockSize * sizeof(float), devEmb.data(), blockSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMemcpy failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); throw AclError(); } #endif @@ -164,6 +177,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali void EmbTable::SplitMemoryBlock(void *newBlock) { +#ifndef GTEST if (embSize == 0) { throw std::runtime_error("SplitMemoryBlock by embSize=0!"); } @@ -171,14 +185,15 @@ void EmbTable::SplitMemoryBlock(void *newBlock) float *embPtr = static_cast(newBlock) + i * embSize; embeddingList.push_back(embPtr); } +#endif } void EmbTable::PrintStatus() { // 输出embedding table的总容量 - spdlog::info("Total capacity:{}", totalCapacity * blockSize); + LOG(INFO) << StringFormat("Total capacity:%d", totalCapacity * blockSize); // 输出embedding table的未使用的使用容量 - spdlog::info("Unused capacity:{}", totalCapacity * blockSize - usedCapacity * embSize); + LOG(INFO) << StringFormat("Unused capacity:%d", totalCapacity * blockSize - usedCapacity * embSize); } // 用于保存 @@ -198,7 +213,7 @@ map> EmbTable::SaveEmb() floatPtr + i * embSize, embSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMemcpy failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); throw AclError(); } savedEmb[reinterpret_cast(floatPtr + i * embSize)] = move(row); @@ -215,14 +230,14 @@ list EmbTable::LoadEmb(const vector> &savedEmb) list addressList; int embCapacity = static_cast(savedEmb.size()); if (savedEmb.size() == 0 || savedEmb[0].size() == 0) { - spdlog::error("Load invalid savedEmb"); + LOG(ERROR) << "Load invalid savedEmb"; return addressList; } embSize = static_cast(savedEmb[0].size()); void *newBlock = nullptr; aclError ret = aclrtMalloc(&newBlock, embCapacity * embSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMalloc failed, ret={}", ret); + LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); throw AclError(); } if (newBlock == nullptr) { @@ -235,7 +250,7 @@ list EmbTable::LoadEmb(const vector> &savedEmb) savedEmb[i].data(), embSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - spdlog::error("aclrtMemcpy failed, ret={}", ret); + LOG(ERROR) << 
StringFormat("aclrtMemcpy failed, ret=%d", ret); throw AclError(); } addressList.push_back(floatPtr + i * embSize); diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index 34370add..500a324c 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -49,6 +49,8 @@ namespace MxRec { EmbTable& operator=(EmbTable&&) = delete; + void ExecuteAclMemcpy(void* newBlock, vector devEmb); + // 用于保存 map> SaveEmb(); diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 8a4872d6..c748507e 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -6,9 +6,8 @@ */ #include "hd_transfer.h" #include -#include -#include #include "utils/common.h" +#include "utils/time_cost.h" using namespace MxRec; using namespace std; @@ -16,20 +15,20 @@ using namespace std; int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) { #ifndef GTEST - spdlog::info(MGMT + "begin hd_transfer initialize, rank:{}", localRankId); + LOG(INFO) << StringFormat(MGMT + "begin hd_transfer initialize, rank:%d", localRankId); aclError retOk = aclInit(nullptr); - spdlog::info(MGMT + "end aclInit, rank:{}", localRankId); + LOG(INFO) << StringFormat(MGMT + "end aclInit, rank:%d", localRankId); if (retOk != ACL_SUCCESS) { - spdlog::error(MGMT + "aclInit fail, rank:{}, errno:{}", localRankId, retOk); + LOG(ERROR) << StringFormat(MGMT + "aclInit fail, rank:%d, errno:%d", localRankId, retOk); return false; } - spdlog::info(MGMT + "start Set device, rank:{}", localRankId); + LOG(INFO) << StringFormat(MGMT + "start Set device, rank:%d", localRankId); auto ret = aclrtSetDevice(static_cast(localRankId)); if (ret != ACL_ERROR_NONE) { - spdlog::error("Set device failed, device_id:{}", localRankId); + LOG(ERROR) << StringFormat("Set device failed, device_id:%d", localRankId); return false; } - spdlog::info(MGMT + "end Set device, rank:{}", localRankId); + LOG(INFO) << StringFormat(MGMT + "end Set device, rank:%d", localRankId); for (const auto& embInfo: embInfos) { auto embName = embInfo.name; for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { @@ -40,16 +39,16 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) const char* timeoutEnv = getenv("AclTimeout"); if (timeoutEnv != nullptr) { int32_t timeoutEnvCast = static_cast(std::atoi(timeoutEnv)); - spdlog::debug("timeoutEnv:{}", timeoutEnvCast); + VLOG(GLOG_DEBUG) << StringFormat("timeoutEnv:%d", timeoutEnvCast); if (timeoutEnvCast > INT32_MAX || timeoutEnvCast < -1) { - spdlog::warn("AclTimeout={} is not valid", timeoutEnvCast); + LOG(WARNING) << StringFormat("AclTimeout=%d is not valid", timeoutEnvCast); } else { timeout = timeoutEnvCast; } } - spdlog::debug("hd transfer timeout:{}", timeout); + VLOG(GLOG_DEBUG) << StringFormat("hd transfer timeout:%d", timeout); running = true; - spdlog::info("hd_transfer init"); + LOG(INFO) << "hd_transfer init"; #endif return true; } @@ -58,10 +57,10 @@ void HDTransfer::Destroy() { #ifndef GTEST running = false; - spdlog::info(HD + "destroy channel start"); + LOG(INFO) << (HD + "destroy channel start"); for (auto& c: transferChannels) { tensorflow::StopRecvTensorByAcl(&c.second, c.first); - spdlog::info(HD + "destroy channel:{}", c.first); + LOG(INFO) << StringFormat(HD + "destroy channel:%s", c.first.c_str()); } for (auto& d: aclDatasets) { if (acltdtDestroyDataset(d.second) != ACL_ERROR_NONE) { @@ -83,20 +82,22 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName try { channelSize = 
stoi(env); } catch (const std::invalid_argument& e) { - spdlog::warn("wrong HD_CHANNEL_SIZE env {}", e.what()); + LOG(WARNING) << StringFormat("wrong HD_CHANNEL_SIZE env %s", e.what()); channelSize = LARGE_CHANNEL_SIZE; } catch (const std::out_of_range& e) { - spdlog::warn("wrong HD_CHANNEL_SIZE env {}", e.what()); + LOG(WARNING) << StringFormat("wrong HD_CHANNEL_SIZE env %s", e.what()); channelSize = LARGE_CHANNEL_SIZE; } if (channelSize <= 0) { channelSize = LARGE_CHANNEL_SIZE; } } - spdlog::info("user config all2all restore lookup channel size:{}", channelSize); + LOG(INFO) << StringFormat("user config all2all restore lookup channel size:%d", channelSize); for (int c = static_cast(TransferChannel::D2H); c != static_cast(TransferChannel::INVALID); c++) { auto channel = static_cast(c); - string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelNum); + string sendName = StringFormat( + "%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelNum + ); if (TransferChannel2Str(channel) == "all2all" || TransferChannel2Str(channel) == "restore" || TransferChannel2Str(channel) == "lookup" || @@ -106,7 +107,9 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName } else { transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), PING_PONG_SIZE); } - spdlog::info("create channel:{} {}", sendName, static_cast(transferChannels[sendName])); + LOG(INFO) << StringFormat( + "create channel:%s %d", sendName.c_str(), static_cast(transferChannels[sendName]) + ); } #endif } @@ -123,13 +126,17 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in for (auto& t: tensors) { sizes.push_back(t.NumElements()); } - string sendName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); + string sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - spdlog::info(HD + "hd transfer send {}, send count is {}, size list:{}", sendName, sizes.size(), - sizes); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + HD + "hd transfer send %s, send count is %d, size list:%s", + sendName.c_str(), sizes.size(), VectorToString(sizes).c_str() + ); + } if (sizes.size() == 0) { - spdlog::warn("tensors num can not be zero"); + LOG(WARNING) << "tensors num can not be zero"; return; } bool isNeedResend = false; @@ -142,11 +149,15 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in return; } if (status != tensorflow::Status::OK()) { - spdlog::error(MGMT + "hd send {} error '{}'", sendName, status.error_message()); + LOG(ERROR) << StringFormat( + MGMT + "hd send %s error '%s'", sendName.c_str(), status.error_message().c_str() + ); throw runtime_error("hd send error"); } if (batchId != -1 && resendTime != 0) { - spdlog::warn(MGMT + "hd send {} batch: {} failed, retry: {} ", sendName, batchId, resendTime); + LOG(WARNING) << StringFormat( + MGMT + "hd send %s batch: %d failed, retry: %d ", sendName.c_str(), batchId, resendTime + ); } resendTime++; } while (isNeedResend); @@ -158,15 +169,15 @@ vector HDTransfer::Recv(TransferChannel channel, int channel EASY_FUNCTION() #ifndef GTEST std::vector tensors; - string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); - spdlog::debug("hd transfer try recv:{}", recvName); - spdlog::stopwatch sw; + string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + VLOG(GLOG_DEBUG) << StringFormat("hd 
transfer try recv:%s", recvName.c_str()); + TimeCost tc = TimeCost(); tensorflow::Status status = tensorflow::RecvTensorByAcl(transferChannels[recvName], tensors); if (!running) { return {}; } if (status != tensorflow::Status::OK()) { - spdlog::error(MGMT + "{} hd recv error '{}'", recvName, status.error_message()); + LOG(ERROR) << StringFormat(MGMT + "%s hd recv error '%s'", recvName.c_str(), status.error_message().c_str()); throw runtime_error("hd recv error"); } @@ -174,7 +185,11 @@ vector HDTransfer::Recv(TransferChannel channel, int channel for (auto& t: tensors) { sizes.push_back(t.NumElements()); } - spdlog::info("hd transfer recv:{}, size:{} cost:{}ms", recvName, sizes, Format2Ms(sw)); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + "hd transfer recv:%s, size:%d cost:%dms", recvName.c_str(), VectorToString(sizes).c_str(), tc.ElapsedMS() + ); + } return tensors; #endif return {}; @@ -185,20 +200,20 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& EASY_FUNCTION() #ifndef GTEST std::vector tensors; - string recvName = fmt::format("{}_{}_{}", embName, TransferChannel2Str(channel), channelId); - spdlog::debug("hd transfer try recv:{}", recvName); - spdlog::stopwatch sw; + string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); + VLOG(GLOG_DEBUG) << StringFormat("hd transfer try recv:%s", recvName.c_str()); + TimeCost tc = TimeCost(); if (aclDatasets[embName] == nullptr) { - throw runtime_error(fmt::format("Failed recv:{}.", recvName).c_str()); + throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str()); } auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], timeout /*-1 no timeout */); if (!running) { return 0; } if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { - throw runtime_error(fmt::format("Failed receive data from acl channel, acl status:{}", aclStatus).c_str()); + throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); } - spdlog::info("hd transfer recv:{} cost:{}ms", recvName, Format2Ms(sw)); + LOG(INFO) << StringFormat("hd transfer recv:%s cost:%dms", recvName.c_str(), tc.ElapsedMS()); return acltdtGetDatasetSize(aclDatasets[embName]); #endif return 0; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 6f882a24..5b1e9bc6 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -6,14 +6,11 @@ */ #include "host_emb.h" -#include -#include -#include -#include #include #include "hd_transfer/hd_transfer.h" #include "checkpoint/checkpoint.h" #include "initializer/initializer.h" +#include "utils/time_cost.h" using namespace MxRec; using namespace std; @@ -27,7 +24,7 @@ bool HostEmb::Initialize(const vector& embInfos, int seed) EmbDataGenerator(embInfo.initializeInfos, seed, static_cast(embInfo.hostVocabSize), embInfo.extEmbeddingSize, hostEmb.embData); hostEmbs[embInfo.name] = move(hostEmb); - spdlog::info(HOSTEMB + "HostEmb Initialize End"); + LOG(INFO) << (HOSTEMB + "HostEmb Initialize End"); } return true; } @@ -36,7 +33,9 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in int embeddingSize, vector> &embData) { #ifndef GTEST - spdlog::info(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); + LOG(INFO) << StringFormat( + HOSTEMB + "GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size() + ); 
embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); @@ -45,30 +44,34 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value {}. name {}, " - "start {}, len {}.", initializeInfo.constantInitializerInfo.constantValue, - initializeInfo.name, initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + HOSTEMB + "GenerateEmbData ing using Constant Initializer by value %f. name %s, " + "start %d, len %d.", initializeInfo.constantInitializerInfo.constantValue, + initializeInfo.name.c_str(), initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.constantInitializer; break; } case InitializerType::TRUNCATED_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}. " - "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, - initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. " + "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.truncatedNormalInitializer; break; } case InitializerType::RANDOM_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}. " - "name {}, start {}, len {}.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name, - initializeInfo.start, initializeInfo.len); + LOG(INFO) << StringFormat( + HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. " + "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, + initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), + initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.randomNormalInitializer; break; } default: { - spdlog::warn(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); + LOG(WARNING) << ( + HOSTEMB + "Invalid Initializer Type. 
Using default Constant Initializer with value 0."); ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } @@ -78,7 +81,7 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in initializer->GenerateData(embData.at(i).data(), embeddingSize); } } - spdlog::info(HOSTEMB + "GenerateEmbData End, seed:{}", seed); + LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData End, seed:%d", seed); #endif } @@ -91,35 +94,29 @@ void HostEmb::LoadEmb(emb_mem_t& loadData) void HostEmb::Join(int channelId) { - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); switch (channelId) { case TRAIN_CHANNEL_ID: - spdlog::debug( - HOSTEMB + "start join, channelId:{}, procThreadsForTrain num:{}", - channelId, procThreadsForTrain.size() - ); + VLOG(GLOG_DEBUG) << StringFormat( + HOSTEMB + "start join, channelId:%d, procThreadsForTrain num:%d", + channelId, procThreadsForTrain.size()); for (auto& t: procThreadsForTrain) { t->join(); } procThreadsForTrain.clear(); - spdlog::debug( - HOSTEMB + "end join, channelId:{}, cost:{}", - channelId, duration_cast((sw).elapsed()) - ); + VLOG(GLOG_DEBUG) << StringFormat( + HOSTEMB + "end join, channelId:%d, cost:%dms", channelId, tc.ElapsedMS()); break; case EVAL_CHANNEL_ID: - spdlog::debug( - HOSTEMB + "start join, channelId:{}, procThreadsForEval num:{}", - channelId, procThreadsForEval.size() - ); + VLOG(GLOG_DEBUG) << StringFormat( + HOSTEMB + "start join, channelId:%d, procThreadsForEval num:%d", + channelId, procThreadsForEval.size()); for (auto& t: procThreadsForEval) { t->join(); } procThreadsForEval.clear(); - spdlog::debug( - HOSTEMB + "end join, channelId:{}, cost:{}", - channelId, duration_cast((sw).elapsed()) - ); + VLOG(GLOG_DEBUG) << StringFormat( + HOSTEMB + "end join, channelId:%d, cost:%dms", channelId, tc.ElapsedMS()); break; default: throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); @@ -133,19 +130,19 @@ void HostEmb::Join(int channelId) #ifndef GTEST void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) { - spdlog::info(HOSTEMB + "UpdateEmb, channelId:{}, embName:{}", channelId, embName); + LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb, channelId:%d, embName:%s", channelId, embName.c_str()); EASY_FUNCTION(profiler::colors::Purple); - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; - spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId); + LOG(INFO) << StringFormat(HOSTEMB + "wait D2H embs, channelId:%d", channelId); const auto tensors = hdTransfer->Recv(transferName, channelId, embName); if (tensors.empty()) { - spdlog::warn(HOSTEMB + "recv empty data"); + LOG(WARNING) << (HOSTEMB + "recv empty data"); return; } const Tensor& d2hEmb = tensors[0]; - spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size()); + LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb End missingkeys len = %d", missingKeysHostPos.size()); EASY_BLOCK("Update") const float* tensorPtr = d2hEmb.flat().data(); auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; @@ -160,26 +157,26 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, dst[j] = tensorPtr[j + embeddingSize * i]; } } - spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw)); + LOG(INFO) << StringFormat(HOSTEMB + "update emb end cost: %dms", tc.ElapsedMS()); EASY_END_BLOCK } void 
HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) { - spdlog::info(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); + LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmbV2, channelId:%d, embName:%s", channelId, embName.c_str()); EASY_FUNCTION(profiler::colors::Purple) auto updateThread = [&, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; - spdlog::info(HOSTEMB + "wait D2H embs, channelId:{}", channelId); + LOG(INFO) << StringFormat(HOSTEMB + "wait D2H embs, channelId:%d", channelId); auto size = hdTransfer->RecvAcl(transferName, channelId, embName); if (size == 0) { - spdlog::warn(HOSTEMB + "recv empty data"); + LOG(WARNING) << (HOSTEMB + "recv empty data"); return; } - spdlog::stopwatch sw; - spdlog::info(HOSTEMB + "UpdateEmb End missingkeys len = {}", missingKeysHostPos.size()); + TimeCost tc = TimeCost(); + LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb End missingkeys len = %d", missingKeysHostPos.size()); EASY_BLOCK("Update") auto& embData = hostEmbs[embName].embData; auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; @@ -196,7 +193,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI dst[k] = ptr[k + embeddingSize * j]; } } - spdlog::info(HOSTEMB + "update emb end cost: {}ms", Format2Ms(sw)); + LOG(INFO) << StringFormat(HOSTEMB + "update emb end cost: %dms", tc.ElapsedMS()); }; switch (channelId) { @@ -219,7 +216,7 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& vector& h2dEmbOut) { EASY_FUNCTION() - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); const auto& emb = hostEmbs[embName]; const int embeddingSize = emb.hostEmbInfo.extEmbeddingSize; h2dEmbOut.emplace_back(Tensor(tensorflow::DT_FLOAT, { @@ -235,7 +232,8 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& tmpData(j + i * embeddingSize) = src[j]; } } - spdlog::info("GetH2DEmb end, missingKeys count:{} cost:{}ms", missingKeysHostPos.size(), Format2Ms(sw)); + LOG(INFO) << StringFormat( + "GetH2DEmb end, missingKeys count:%d cost:%dms", missingKeysHostPos.size(), tc.ElapsedMS()); } @@ -252,25 +250,27 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value {}.", - initializeInfo.constantInitializerInfo.constantValue); + LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value %f.", + initializeInfo.constantInitializerInfo.constantValue); initializer = &initializeInfo.constantInitializer; break; } case InitializerType::TRUNCATED_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + LOG(INFO) << StringFormat( + HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f.", + initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); initializer = &initializeInfo.truncatedNormalInitializer; break; } case InitializerType::RANDOM_NORMAL: { - spdlog::info(HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: {} stddev: {}.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); + LOG(INFO) << StringFormat( + HOSTEMB + 
"GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f.", + initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); initializer = &initializeInfo.randomNormalInitializer; break; } default: { - spdlog::error(HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); + LOG(ERROR) << (HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); initializer = &defaultInitializer; } @@ -291,6 +291,7 @@ void HostEmb::EvictInitEmb(const string& embName, const vector& offset) #ifndef GTEST auto& hostEmb = GetEmb(embName); EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); - spdlog::info(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); + LOG(INFO) << StringFormat( + HOSTEMB + "ddr EvictInitEmb!host embName %s, init offsets size: %d", embName.c_str(), offset.size()); #endif } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 36424db0..89247a2f 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -6,9 +6,6 @@ */ #include "hybrid_mgmt.h" -#include -#include - #include "checkpoint/checkpoint.h" #include "utils/time_cost.h" #include "utils/common.h" @@ -19,42 +16,47 @@ using namespace std; bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed) { +#ifndef GTEST if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { - spdlog::error("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{}, should in range [1, {}]", - num, MAX_KEY_PROCESS_THREAD); + LOG(ERROR) << StringFormat( + "[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:%d, should in range [1, %d]", + num, MAX_KEY_PROCESS_THREAD); return false; } PerfConfig::keyProcessThreadNum = num; - spdlog::info("config KEY_PROCESS_THREAD_NUM:{}", num); + LOG(INFO) << StringFormat("config KEY_PROCESS_THREAD_NUM:%d", num); } if (getenv("MAX_UNIQUE_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("MAX_UNIQUE_THREAD_NUM")); if (num < 1 || num > DEFAULT_MAX_UNIQUE_THREAD_NUM) { - spdlog::error("[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:{}, should in range [1, {}]", - num, DEFAULT_MAX_UNIQUE_THREAD_NUM); + LOG(ERROR) << StringFormat( + "[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:%d, should in range [1, %d]", + num, DEFAULT_MAX_UNIQUE_THREAD_NUM); return false; } PerfConfig::maxUniqueThreadNum = num; - spdlog::info("config MAX_UNIQUE_THREAD_NUM:{}", num); + LOG(INFO) << StringFormat("config MAX_UNIQUE_THREAD_NUM:%d", num); } if (getenv("FAST_UNIQUE") != nullptr) { bool isFastUnique = std::atoi(getenv("FAST_UNIQUE")); PerfConfig::fastUnique = isFastUnique; - spdlog::info("config FAST_UNIQUE:{}", PerfConfig::fastUnique); + LOG(INFO) << StringFormat("config FAST_UNIQUE:%d", PerfConfig::fastUnique); } preprocess = Singleton::GetInstance(); preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); preprocess->Start(); +#endif return true; } void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos) { +#ifndef GTEST MPI_Comm_size(MPI_COMM_WORLD, &rankInfo.rankSize); rankInfo.localRankId = rankInfo.deviceId; @@ -66,19 +68,22 @@ void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const 
vector& embInfo rankInfo.noDDR = true; } rankInfo.useDataset = getenv("DATASET") != nullptr; +#endif } bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, int seed, const vector& thresholdValues, bool ifLoad) { +#ifndef GTEST if (isRunning) { return true; } - SetLog(rankInfo.rankId); + SetLog(); InitRankInfo(rankInfo, embInfos); - spdlog::info(MGMT + "begin initialize, localRankSize:{}, localRankId {}, rank {}", - rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); + LOG(INFO) << StringFormat( + MGMT + "begin initialize, localRankSize:%d, localRankId:%d, rank:%d", + rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; @@ -112,11 +117,15 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, } for (const auto& info: embInfos) { - spdlog::info(MGMT + "emb[{}] vocab size {}+{} sc:{}", info.name, info.devVocabSize, info.hostVocabSize, - info.sendCount); - } - spdlog::info(MGMT + "end initialize, useDataset:{}, noDDR:{}, maxStep:{}, rank:{}", - rankInfo.useDataset, rankInfo.noDDR, rankInfo.maxStep, rankInfo.rankId); + LOG(INFO) << StringFormat( + MGMT + "emb[%s] vocab size %d+%d sc:%d", + info.name.c_str(), info.devVocabSize, info.hostVocabSize, info.sendCount); + } + LOG(INFO) << StringFormat( + MGMT + "end initialize, useDataset:%d, noDDR:%d, maxStep:[%d, %d], rank:%d", + rankInfo.useDataset, rankInfo.noDDR, + rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); +#endif return true; } @@ -128,18 +137,18 @@ bool HybridMgmt::Save(const string savePath) CkptData saveData; Checkpoint saveCkpt; if (!mgmtRankInfo.noDDR) { - spdlog::debug(MGMT + "Start host side save: ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: ddr mode hashmap"); saveData.hostEmbs = hostEmbs->GetHostEmbs(); saveData.embHashMaps = hostHashMaps->GetHashMaps(); } else { - spdlog::debug(MGMT + "Start host side save: no ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: no ddr mode hashmap"); saveData.maxOffset = preprocess->GetMaxOffset(); saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); } auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { - spdlog::debug(MGMT + "Start host side save: feature admit and evict"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); saveData.tens2Thresh = featAdmitNEvict.GetTensorThresholds(); saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; @@ -157,7 +166,7 @@ bool HybridMgmt::Load(const string& loadPath) #ifndef GTEST preprocess->LoadSaveLock(); - spdlog::debug(MGMT + "Start host side load process"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side load process"); CkptData loadData; Checkpoint loadCkpt; @@ -182,20 +191,20 @@ bool HybridMgmt::Load(const string& loadPath) } if (!mgmtRankInfo.noDDR) { - spdlog::debug(MGMT + "Start host side load: ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: ddr mode hashmap"); hostHashMaps->LoadHashMap(loadData.embHashMaps); } else { - spdlog::debug(MGMT + "Start host side load: no ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: no ddr mode hashmap"); preprocess->LoadMaxOffset(loadData.maxOffset); preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap); } if (featAdmitNEvict.GetFunctionSwitch()) { - spdlog::debug(MGMT + "Start host 
side load: feature admit and evict"); + VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); featAdmitNEvict.LoadTensorThresholds(loadData.tens2Thresh); featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } - spdlog::debug(MGMT + "Finish host side load process"); + VLOG(GLOG_DEBUG) << (MGMT + "Finish host side load process"); preprocess->LoadSaveUnlock(); @@ -214,9 +223,9 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) key_offset_map_t sendKeyOffsetMap; if (!mgmtRankInfo.noDDR) { - spdlog::debug(MGMT + "Start send sparse data: ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start send sparse data: ddr mode hashmap"); } else { - spdlog::debug(MGMT + "Start send sparse data: no ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start send sparse data: no ddr mode hashmap"); keyOffsetMap = preprocess->GetKeyOffsetMap(); } @@ -248,9 +257,9 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) } } if (!mgmtRankInfo.noDDR) { - spdlog::debug(MGMT + "Start receive sparse data: ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start receive sparse data: ddr mode hashmap"); } else { - spdlog::debug(MGMT + "Start receive sparse data: no ddr mode hashmap"); + VLOG(GLOG_DEBUG) << (MGMT + "Start receive sparse data: no ddr mode hashmap"); preprocess->LoadKeyOffsetMap(loadKeyOffsetMap); preprocess->LoadMaxOffset(loadMaxOffset); } @@ -262,50 +271,63 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) #endif } +bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t* embTableCount) +{ + bool loadDataMatches = { true }; + const auto& loadEmbTable { loadHostEmbs->find(setupHostEmbs->name) }; + if (loadEmbTable != loadHostEmbs->end()) { + (*embTableCount)++; + + const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; + if (setupHostEmbs->sendCount != loadEmbInfo.sendCount) { + LOG(ERROR) << StringFormat( + MGMT + "Load data sendCount %d for table %s does not match setup sendCount %d", + setupHostEmbs->sendCount, setupHostEmbs->name.c_str(), loadEmbInfo.sendCount); + loadDataMatches = false; + } + if (setupHostEmbs->extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { + LOG(ERROR) << StringFormat( + MGMT + "Load data extEmbeddingSize %d for table %s does not match setup extEmbeddingSize %d", + setupHostEmbs->extEmbeddingSize, setupHostEmbs->name.c_str(), loadEmbInfo.extEmbeddingSize); + loadDataMatches = false; + } + if (setupHostEmbs->devVocabSize != loadEmbInfo.devVocabSize) { + LOG(ERROR) << StringFormat( + MGMT + "Load data devVocabSize %d for table %s does not match setup devVocabSize %d", + setupHostEmbs->devVocabSize, setupHostEmbs->name.c_str(), loadEmbInfo.devVocabSize); + loadDataMatches = false; + } + if (setupHostEmbs->hostVocabSize != loadEmbInfo.hostVocabSize) { + LOG(ERROR) << StringFormat( + MGMT + "Load data hostVocabSize %d for table %s does not match setup hostVocabSize %d", + setupHostEmbs->hostVocabSize, setupHostEmbs->name.c_str(), loadEmbInfo.hostVocabSize); + loadDataMatches = false; + } + if (!loadDataMatches) { + return false; + } + } else { + LOG(ERROR) << StringFormat( + MGMT + "Load data does not contain table with table name: %s", setupHostEmbs->name.c_str() + ); + return false; + } + return true; +} + bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) { - bool loadDataMatches { true }; size_t embTableCount { 0 }; auto loadHostEmbs { loadData.hostEmbs }; - for (const auto& setupHostEmbs : mgmtEmbInfo) { - const auto& loadEmbTable { 
loadHostEmbs->find(setupHostEmbs.name) }; - if (loadEmbTable != loadHostEmbs->end()) { - embTableCount++; - - const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; - if (setupHostEmbs.sendCount != loadEmbInfo.sendCount) { - spdlog::error(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", - setupHostEmbs.sendCount, setupHostEmbs.name, loadEmbInfo.sendCount); - loadDataMatches = false; - } - if (setupHostEmbs.extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { - spdlog::error(MGMT + "Load data extEmbeddingSize {} for table {} does not match " - "setup extEmbeddingSize {}", - setupHostEmbs.extEmbeddingSize, setupHostEmbs.name, loadEmbInfo.extEmbeddingSize); - loadDataMatches = false; - } - if (setupHostEmbs.devVocabSize != loadEmbInfo.devVocabSize) { - spdlog::error(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}", - setupHostEmbs.devVocabSize, setupHostEmbs.name, loadEmbInfo.devVocabSize); - loadDataMatches = false; - } - if (setupHostEmbs.hostVocabSize != loadEmbInfo.hostVocabSize) { - spdlog::error(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}", - setupHostEmbs.hostVocabSize, setupHostEmbs.name, loadEmbInfo.hostVocabSize); - loadDataMatches = false; - } - if (!loadDataMatches) { - return loadDataMatches; - } - } else { - spdlog::error(MGMT + "Load data does not contain table with table name: {}", setupHostEmbs.name); + for (EmbInfo setupHostEmbs : mgmtEmbInfo) { + if (!IsLoadDataMatches(loadHostEmbs, &setupHostEmbs, &embTableCount)) { return false; } } if (embTableCount < loadHostEmbs->size()) { - spdlog::error(MGMT + "Load data has {} tables more than setup table num {}", - loadHostEmbs->size(), embTableCount); + LOG(ERROR) << StringFormat(MGMT + "Load data has %d tables more than setup table num %d", + loadHostEmbs->size(), embTableCount); return false; } return true; @@ -319,9 +341,9 @@ void HybridMgmt::Start() if (envTaskMode != nullptr) { // the env var is set try { mode = std::stoi(envTaskMode); // convert the string to an integer - spdlog::info("The value of MGMT_HBM_TASK_MODE is an integer: {}", mode); + LOG(INFO) << StringFormat("The value of MGMT_HBM_TASK_MODE is an integer: %d", mode); } catch (const std::invalid_argument& e) { // conversion failed - spdlog::error("The value of MGMT_HBM_TASK_MODE is not an integer!"); + LOG(ERROR) << "The value of MGMT_HBM_TASK_MODE is not an integer!"; throw std::invalid_argument("Invalid env value MGMT_HBM_TASK_MODE"); } } else { // the env var is not set @@ -334,13 +356,13 @@ void HybridMgmt::Start() if (!mgmtRankInfo.noDDR) { auto parseKeysTaskForTrain = [this]() { TaskForTrain(TaskType::DDR); - spdlog::info("parseKeysTaskForTrain done"); + LOG(INFO) << StringFormat("parseKeysTaskForTrain done"); }; procThreads.emplace_back(std::make_unique<thread>(parseKeysTaskForTrain)); auto parseKeysTaskForEval = [this]() { TaskForEval(TaskType::DDR); - spdlog::info("parseKeysTaskForEval done"); + LOG(INFO) << StringFormat("parseKeysTaskForEval done"); }; procThreads.emplace_back(std::make_unique<thread>(parseKeysTaskForEval)); } @@ -353,37 +375,37 @@ void HybridMgmt::InsertThreadForHBM(int mode) if (mode == 1) { auto getInfoTaskForTrain = [this]() { TaskForTrain(TaskType::GETINFO); - spdlog::info("getInfoTaskForTrain done"); + LOG(INFO) << "getInfoTaskForTrain done"; }; procThreads.emplace_back(std::make_unique<thread>(getInfoTaskForTrain)); auto getInfoTaskForEval = [this]() { TaskForEval(TaskType::GETINFO); - spdlog::info("getInfoTaskForEval done"); + LOG(INFO) << "getInfoTaskForEval done"; }; 
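+    // each task lambda above runs on its own thread; in mode 1 the HBM pipeline is split into
+    // separate get-info and send stages for the train and eval channels.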
procThreads.emplace_back(std::make_unique<thread>(getInfoTaskForEval)); auto sendInfoTaskForTrain = [this]() { TaskForTrain(TaskType::SEND); - spdlog::info("sendInfoTaskForTrain done"); + LOG(INFO) << "sendInfoTaskForTrain done"; }; procThreads.emplace_back(std::make_unique<thread>(sendInfoTaskForTrain)); auto sendInfoTaskForEval = [this]() { TaskForEval(TaskType::SEND); - spdlog::info("sendInfoTaskForEval done"); + LOG(INFO) << "sendInfoTaskForEval done"; }; procThreads.emplace_back(std::make_unique<thread>(sendInfoTaskForEval)); } else { auto parseKeysTaskForHBMTrain = [this]() { TaskForTrain(TaskType::HBM); - spdlog::info("parseKeysTaskForHBMTrain done"); + LOG(INFO) << "parseKeysTaskForHBMTrain done"; }; procThreads.emplace_back(std::make_unique<thread>(parseKeysTaskForHBMTrain)); auto parseKeysTaskForHBMEval = [this]() { TaskForEval(TaskType::HBM); - spdlog::info("parseKeysTaskForHBMEval done"); + LOG(INFO) << "parseKeysTaskForHBMEval done"; }; procThreads.emplace_back(std::make_unique<thread>(parseKeysTaskForHBMEval)); } @@ -396,7 +418,7 @@ void HybridMgmt::TaskForTrain(TaskType type) bool isFirstIn = true; while (isRunning) { if (isFirstIn) { - spdlog::info(MGMT + "Start Train Task: {}", type); + LOG(INFO) << StringFormat(MGMT + "Start Train Task: %d", type); isFirstIn = false; } if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) { @@ -413,7 +435,7 @@ void HybridMgmt::TaskForEval(TaskType type) bool isFirstIn = true; while (isRunning) { if (isFirstIn) { - spdlog::info(MGMT + "Start Eval Task: {}", type); + LOG(INFO) << StringFormat(MGMT + "Start Eval Task: %d", type); isFirstIn = false; } if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) { @@ -439,20 +461,20 @@ bool HybridMgmt::TrainTask(TaskType type) status = GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); isContinue = getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - spdlog::info(MGMT + "getInfoBatchId = {}", getInfoBatchId); + LOG(INFO) << StringFormat(MGMT + "getInfoBatchId = %d", getInfoBatchId); break; case TaskType::SEND: status = SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); isContinue = sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - spdlog::info(MGMT + "sendBatchId = {}", sendBatchId); + LOG(INFO) << StringFormat(MGMT + "sendBatchId = %d", sendBatchId); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (sendBatchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE } else if (sendBatchId == PROFILING_END_BATCH_ID) { EASY_PROFILER_DISABLE - ::profiler::dumpBlocksToFile(fmt::format("/home/MX_REC-mgmt-profile-{}.prof", - mgmtRankInfo.rankId).c_str()); + ::profiler::dumpBlocksToFile( + StringFormat("/home/MX_REC-mgmt-profile-%d.prof", mgmtRankInfo.rankId).c_str()); } #endif break; @@ -460,13 +482,13 @@ bool HybridMgmt::TrainTask(TaskType type) status = ParseKeysHBM(TRAIN_CHANNEL_ID, trainBatchId); isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - spdlog::info(MGMT + "ParseKeysHBMBatchId = {}", trainBatchId); + LOG(INFO) << StringFormat(MGMT + "ParseKeysHBMBatchId = %d", trainBatchId); break; case TaskType::DDR: status = ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - spdlog::info(MGMT + "parseKeysBatchId = {}", trainBatchId); + LOG(INFO) << 
StringFormat(MGMT + "parseKeysBatchId = %d", trainBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); @@ -492,19 +514,19 @@ bool HybridMgmt::EvalTask(TaskType type) switch (type) { case TaskType::GETINFO: status = GetLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); - spdlog::info(MGMT + "GETINFO evalBatchId = {}", evalBatchId); + LOG(INFO) << StringFormat(MGMT + "GETINFO evalBatchId = %d", evalBatchId); break; case TaskType::SEND: status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); - spdlog::info(MGMT + "SEND evalBatchId = {}", evalBatchId); + LOG(INFO) << StringFormat(MGMT + "SEND evalBatchId = %d", evalBatchId); break; case TaskType::HBM: status = ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); - spdlog::info(MGMT + "HBM evalBatchId = {}", evalBatchId); + LOG(INFO) << StringFormat(MGMT + "HBM evalBatchId = %d", evalBatchId); break; case TaskType::DDR: status = ParseKeys(EVAL_CHANNEL_ID, evalBatchId); - spdlog::info(MGMT + "DDR evalBatchId = {}", evalBatchId); + LOG(INFO) << StringFormat(MGMT + "DDR evalBatchId = %d", evalBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); @@ -536,12 +558,13 @@ void HybridMgmt::GetAll2All(const int channelId, int &batchId, const string &nam bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) { - spdlog::info(MGMT + "start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + LOG(INFO) << StringFormat(MGMT + "start parse keys, nBatch:%d , [%d]:%d", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost getAllTensorTC; auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { - spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + LOG(INFO) << StringFormat( + MGMT + "ParseKeys infoVecs empty ! 
batchId:%d, channelId:%d", batchId, channelId); return false; } switch (channelId) { @@ -562,7 +585,7 @@ bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) if (!mgmtRankInfo.useStatic) { GetAll2All(channelId, batchId, embInfo.name); } - TIME_PRINT("getAllTensorTC(ms):{}", getAllTensorTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getAllTensorTC(ms):%d", getAllTensorTC.ElapsedMS()); } batchId++; return true; @@ -583,7 +606,7 @@ void HybridMgmt::All2AllKeys(const int channelId, const string &embName) throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } hdTransfer->Send(TransferChannel::ALL2ALL, all2allKeys, channelId, embName); - TIME_PRINT("All2AllKeysTC(ms):{}", a2aKeysTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("All2AllKeysTC(ms):%d", a2aKeysTC.ElapsedMS()); } void HybridMgmt::LookupKeys(const int channelId, const string &embName) @@ -601,7 +624,7 @@ void HybridMgmt::LookupKeys(const int channelId, const string &embName) throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, embName); - TIME_PRINT("sendLookupTC(ms):{}", sendLookupTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendLookupTC(ms):%d", sendLookupTC.ElapsedMS()); } void HybridMgmt::RestoreKeys(const int channelId, const string &embName) @@ -619,7 +642,7 @@ void HybridMgmt::RestoreKeys(const int channelId, const string &embName) throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); } hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, embName); - TIME_PRINT("sendRestoreTC(ms):{}", sendRestoreTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreTC(ms):%d", sendRestoreTC.ElapsedMS()); } bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) @@ -630,12 +653,14 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) All2AllKeys(channelId, embInfo.name); } - spdlog::info("SendLookupAndRestore batchId: {}, name: {}, channelId: {}", - batchId, embInfo.name, channelId); + LOG(INFO) << StringFormat( + "SendLookupAndRestore batchId:%d, name:%s, channelId:%d", + batchId, embInfo.name.c_str(), channelId + ); LookupKeys(channelId, embInfo.name); RestoreKeys(channelId, embInfo.name); - TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendTensorsTC(ms):%d", sendTensorsTC.ElapsedMS()); } batchId++; return true; @@ -643,42 +668,44 @@ bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) { - spdlog::info(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + LOG(INFO) << StringFormat( + MGMT + "start parse keys HBM, nBatch:%d , [%d]:%d", mgmtRankInfo.nBatch, channelId, batchId); for (const auto& embInfo: mgmtEmbInfo) { TimeCost ParseKeysTC; // get TimeCost getTensorsSyncTC; auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { - spdlog::info(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + LOG(INFO) << StringFormat( + MGMT + "ParseKeys infoVecs empty ! 
batchId:%d, channelId:%d", batchId, channelId); return false; } unique_ptr> all2all = nullptr; if (!mgmtRankInfo.useStatic) { all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); } - TIME_PRINT("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getTensorsSyncTC(ms):%d", getTensorsSyncTC.ElapsedMS()); // send TimeCost sendTensorsSyncTC; if (!mgmtRankInfo.useStatic) { TimeCost sendAll2AllScSyncTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); - TIME_PRINT("sendAll2AllScSyncTC(ms):{}", sendAll2AllScSyncTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendAll2AllScSyncTC(ms):%d", sendAll2AllScSyncTC.ElapsedMS()); } TimeCost sendLookupSyncTC; hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); - TIME_PRINT("sendLookupSyncTC(ms):{}", sendLookupSyncTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendLookupSyncTC(ms):%d", sendLookupSyncTC.ElapsedMS()); TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); - TIME_PRINT("sendRestoreSyncTC(ms):{}", sendRestoreSyncTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); - TIME_PRINT("sendTensorsSyncTC(ms):{}", sendTensorsSyncTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendTensorsSyncTC(ms):%d", sendTensorsSyncTC.ElapsedMS()); - TIME_PRINT("ParseKeysTC HBM mode (ms):{}", ParseKeysTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("ParseKeysTC HBM mode (ms):%d", ParseKeysTC.ElapsedMS()); } batchId++; return true; @@ -693,20 +720,22 @@ bool HybridMgmt::EndBatch(int batchId, int channelId) const bool HybridMgmt::ParseKeys(int channelId, int& batchId) { #ifndef GTEST - spdlog::info(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + LOG(INFO) << StringFormat( + MGMT + "DDR mode, start parse keys, nBatch:%d , [%d]:%d", + mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; int start = batchId; int iBatch = 0; bool ifHashmapFree = true; bool remainBatch = true; while (true) { - spdlog::info(MGMT + "parse keys, [{}]:{}", channelId, batchId); + LOG(INFO) << StringFormat(MGMT + "parse keys, [%d]:%d", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); if (!remainBatch) { TimeCost embHdTrans1; EmbHDTransWrap(channelId, batchId, start, iBatch); - TIME_PRINT("embHdTrans1TC TimeCost(ms):{}", embHdTrans1.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("embHdTrans1TC TimeCost(ms):%d", embHdTrans1.ElapsedMS()); return false; } } @@ -721,8 +750,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) } TimeCost embHdTrans2TC; EmbHDTransWrap(channelId, batchId - 1, start, iBatch); - TIME_PRINT("embHdTrans2TC TimeCost(ms):{}", embHdTrans2TC.ElapsedMS()); - TIME_PRINT("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("embHdTrans2TC TimeCost(ms):%d", embHdTrans2TC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("[%d]-%d, parseKeyTC TimeCost(ms):%d", channelId, batchId, parseKeyTC.ElapsedMS()); #endif return true; } @@ -743,14 +772,14 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, } auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); - TIME_PRINT("getTensorsTC(ms):{}", 
getTensorsTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getTensorsTC(ms):%d", getTensorsTC.ElapsedMS()); hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); vector tmpData; TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); - TIME_PRINT("hostHashMapProcessTC(ms):{}", hostHashMapProcessTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS()); TimeCost sendTensorsTC; hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); @@ -760,13 +789,16 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } - TIME_PRINT("sendTensorsTC(ms):{}", sendTensorsTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("sendTensorsTC(ms):%d", sendTensorsTC.ElapsedMS()); - TIME_PRINT("getAndSendTensorsTC(ms):{}, channelId:{}", getAndSendTensorsTC.ElapsedMS(), channelId); + VLOG(GLOG_DEBUG) << StringFormat( + "getAndSendTensorsTC(ms):%d, channelId:%d", getAndSendTensorsTC.ElapsedMS(), channelId); if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - spdlog::warn(MGMT + "embName {}[{}]{},iBatch:{} freeSize not enough, {}", embName, channelId, - batchId, iBatch, lookupKeys.size()); + LOG(WARNING) << StringFormat( + MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d", embName.c_str(), channelId, + batchId, iBatch, lookupKeys.size() + ); return false; } return true; @@ -778,16 +810,16 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in if (iBatch == 0) { return; } - spdlog::info(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId); + LOG(INFO) << StringFormat(MGMT + "trans emb, batchId:[%d-%d], channelId:%d", start, batchId, channelId); TimeCost hostEmbsTC; hostEmbs->Join(channelId); - TIME_PRINT("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("hostEmbsTC(ms):%d", hostEmbsTC.ElapsedMS()); EmbHDTrans(channelId, batchId); for (int i = 0; i < iBatch - 1; ++i) { // need send empty - spdlog::info(MGMT + "trans emb dummy, batchId:{}, ", start + 1 + i); + LOG(INFO) << StringFormat(MGMT + "trans emb dummy, batchId:%d, ", start + 1 + i); EmbHDTrans(channelId, batchId); } } @@ -796,7 +828,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) { EASY_FUNCTION(profiler::colors::Blue) EASY_VALUE("mgmtProcess", batchId) - spdlog::debug(MGMT + "trans emb, batchId:{}, channelId:{}", batchId, channelId); + VLOG(GLOG_DEBUG) << StringFormat(MGMT + "trans emb, batchId:%d, channelId:%d", batchId, channelId); TimeCost tr; TimeCost h2dTC; for (const auto& embInfo: mgmtEmbInfo) { @@ -805,7 +837,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! 
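+        // the H2D send below must directly follow GetH2DEmb for the same table (hence the
+        // "order!" note); batchId tags the transfer so the device can match it to its step.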
hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); } - TIME_PRINT("h2dTC(ms):{}", h2dTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("h2dTC(ms):%d", h2dTC.ElapsedMS()); TimeCost d2hTC; for (const auto& embInfo: mgmtEmbInfo) { @@ -820,16 +852,18 @@ } // skip when skip update and empty missing keys hostHashMaps->ClearMissingKeys(embInfo.name); } - TIME_PRINT("d2hTC(ms):{}", d2hTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("d2hTC(ms):%d", d2hTC.ElapsedMS()); - TIME_PRINT("EmbHDTrans TimeCost(ms):{} batchId: {} channelId:{}", tr.ElapsedMS(), batchId, channelId); + VLOG(GLOG_DEBUG) << StringFormat( + "EmbHDTrans TimeCost(ms):%d batchId: %d channelId:%d", tr.ElapsedMS(), batchId, channelId + ); } void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo) { EASY_FUNCTION(profiler::colors::Blue) EASY_VALUE("mgmtProcess", batchId) - spdlog::info(MGMT + "trans emb dummy, batchId:{}, channelId:{}", batchId, channelId); + LOG(INFO) << StringFormat(MGMT + "trans emb dummy, batchId:%d, channelId:%d", batchId, channelId); auto transferName = TransferChannel::D2H; auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0]; hdTransfer->Send(TransferChannel::H2D, {}, channelId, embInfo.name); @@ -845,12 +879,12 @@ bool HybridMgmt::Evict() if (featAdmitNEvict.GetFunctionSwitch()) { featAdmitNEvict.FeatureEvict(evictKeyMap); } else { - spdlog::warn(MGMT + "Hook can not trigger evict, cause AdmitNEvict is not open"); + LOG(WARNING) << (MGMT + "Hook can not trigger evict, cause AdmitNEvict is not open"); return false; } - spdlog::debug(MGMT + "evict triggered by hook, evict TableNum {} ", evictKeyMap.size()); + VLOG(GLOG_DEBUG) << StringFormat(MGMT + "evict triggered by hook, evict TableNum %d ", evictKeyMap.size()); if (evictKeyMap.size() == 0) { - spdlog::warn(MGMT + "evict triggered by hook before dataset in injected"); + LOG(WARNING) << (MGMT + "evict triggered by hook before dataset in injected"); return false; } @@ -871,7 +905,9 @@ bool HybridMgmt::Evict() void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { #ifndef GTEST - spdlog::debug(MGMT + "ddr mode, delete emb: [{}]! evict keySize:{}", embName, keys.size()); + VLOG(GLOG_DEBUG) << StringFormat( + MGMT + "ddr mode, delete emb: [%s]! evict keySize:%d", embName.c_str(), keys.size() + ); // remove the key-to-offset mappings if (keys.size() != 0) { hostHashMaps->EvictDeleteEmb(embName, keys); } @@ -880,15 +916,19 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) // re-initialize the evicted embeddings on the host side auto& evictOffset = hostHashMaps->embHashMaps.at(embName).evictPos; if (evictOffset.size() != 0) { - spdlog::debug(MGMT + "ddr mode, delete emb: [{}]! evict size on host:{}", embName, evictOffset.size()); + VLOG(GLOG_DEBUG) << StringFormat( + MGMT + "ddr mode, delete emb: [%s]! evict size on host:%d", embName.c_str(), evictOffset.size() + ); hostEmbs->EvictInitEmb(embName, evictOffset); } else { - spdlog::info(MGMT + "ddr mode, evict size on host is empty"); + LOG(INFO) << StringFormat(MGMT + "ddr mode, evict size on host is empty"); } // send the evicted positions on the device side so the device can re-initialize its embeddings auto evictDevOffset = hostHashMaps->embHashMaps.at(embName).evictDevPos; - spdlog::debug(MGMT + "ddr mode, init dev emb: [{}]! evict size on dev :{}", embName, evictDevOffset.size()); + VLOG(GLOG_DEBUG) << StringFormat( + MGMT + "ddr mode, init dev emb: [%s]! 
evict size on dev :%d", embName.c_str(), evictDevOffset.size() + ); vector tmpDataOut; Tensor tmpData = Vec2TensorI32(evictDevOffset); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 5b8ee6e7..6db27f56 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -109,6 +109,8 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); + bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t* embTableCount); + private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed); diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 5ce30a95..4327b638 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -6,7 +6,7 @@ */ #include "constant_initializer.h" -#include +#include "utils/common.h" using namespace std; using namespace MxRec; @@ -23,9 +23,8 @@ void ConstantInitializer::GenerateData(float* const emb, const int embSize) return; } if (embSize < (start + len)) { - spdlog::warn( - "InitializeInfo start {} + len {} is larger than embedding size {}.", - start, len, embSize); + LOG(WARNING) << StringFormat( + "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); return; } std::fill_n(emb + start, len, initParam * value); diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index c9f0b5b8..0cff333d 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -6,7 +6,7 @@ */ #include -#include +#include "utils/common.h" #include "random_normal_initializer.h" using namespace MxRec; @@ -25,9 +25,8 @@ void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) return; } if (embSize < (start + len)) { - spdlog::warn( - "InitializeInfo start {} + len {} is larger than embedding size {}.", - start, len, embSize); + LOG(WARNING) << StringFormat( + "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); return; } std::generate_n(emb + start, len, [&]() { return initParam * distribution(generator); }); diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index ed151726..85fb4a45 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -6,7 +6,7 @@ */ #include -#include +#include "utils/common.h" #include "truncated_normal_initializer.h" using namespace MxRec; @@ -29,9 +29,8 @@ void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSiz return; } if (embSize < (start + len)) { - spdlog::warn( - "InitializeInfo start {} + len {} is larger than embedding size {}.", - start, len, embSize); + LOG(WARNING) << StringFormat( + "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); return; } std::generate_n(emb + start, len, [&]() { diff --git a/src/core/key_process/feature_admit_and_evict.cpp 
b/src/core/key_process/feature_admit_and_evict.cpp index a36347ae..29b2da1b 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -7,11 +7,6 @@ #include "feature_admit_and_evict.h" #include -#include -#include -#include -#include -#include "checkpoint/checkpoint.h" using namespace MxRec; @@ -33,7 +28,7 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu { if (!ParseThresholdCfg(thresholdValues)) { m_isEnableFunction = false; - spdlog::error("Config is error, feature admin-and-evict function is not available ...\n"); + LOG(ERROR) << "Config is error, feature admin-and-evict function is not available ...\n"; return false; } @@ -45,7 +40,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, const std::unique_ptr& batch, keys_t& splitKey, std::vector& keyCount) { if (splitKey.size() != keyCount.size()) { - spdlog::error("splitKey.size {} != keyCount.size {}", splitKey.size(), keyCount.size()); + LOG(ERROR) << StringFormat("splitKey.size %d != keyCount.size %d", splitKey.size(), keyCount.size()); return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } @@ -61,15 +56,14 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, absl::flat_hash_map records(m_recordsInitSize); m_recordsData.historyRecords[tensorName] = records; } - spdlog::debug("FeatureAdmitAndEvict PrintSize, name:[{}], history key:[{}] ...", tensorName, - m_recordsData.historyRecords[tensorName].size()); + VLOG(GLOG_DEBUG) << StringFormat( + "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tensorName.c_str(), + m_recordsData.historyRecords[tensorName].size()); if (batch->timestamp > m_recordsData.timestamps[tensorName]) { m_recordsData.timestamps[tensorName] = batch->timestamp; } absl::flat_hash_map visitedRecords; - spdlog::trace("FeatureAdmit, name:[{}], channel:[{}], before admit, splitKey:[{}] ...", tensorName, channel, - splitKey); for (auto& key : splitKey) { if (key == -1) { continue; @@ -87,12 +81,15 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, continue; } - if (visitedRecords[key] == false) { + if (!visitedRecords[key]) { key = -1; } } - spdlog::trace("FeatureAdmit, name:[{}], channel:[{}], after admit, splitKey:[{}] ...", tensorName, channel, - splitKey); + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat( + "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tensorName.c_str(), channel, + VectorToString(splitKey).c_str()); + } return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } @@ -137,11 +134,11 @@ void FeatureAdmitAndEvict::FeatureEvict(map> { std::vector tensorNames = GetAllNeedEvictTensorNames(); if (tensorNames.empty()) { - spdlog::info("EmbNames is empty, no evict function ..."); + LOG(INFO) << "EmbNames is empty, no evict function ..."; return ; } if (!m_isEnableFunction) { - spdlog::warn("m_isEnableFunction switch is false, no evict function ..."); + LOG(WARNING) << "m_isEnableFunction switch is false, no evict function ..."; return ; } std::lock_guard lock(m_syncMutexs); @@ -172,12 +169,12 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v } if (evictKey.size() == 0) { - spdlog::info("tensor-name[{}]'s lastTime[{}], had no key to delete ...", embName, currTime); + LOG(INFO) << StringFormat( + "tensor-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); return; } - spdlog::info("tensor-name[{}]'s lastTime[{}], had size[{}] keys to delete 
...", embName, currTime, - evictKey.size()); - spdlog::trace("tensor-name[{}]'s lastTime[{}], evictKey:[{}] ...", embName, currTime, evictKey); + LOG(INFO) << StringFormat( + "tensor-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); // actually evict the keys from m_historyRecords absl::flat_hash_map& historyRecords = m_recordsData.historyRecords[embName]; @@ -193,9 +190,9 @@ void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) { if (isEnableEvict) { - spdlog::info("feature admit-and-evict switch is opened ..."); + LOG(INFO) << "feature admit-and-evict switch is opened ..."; } else { - spdlog::info("feature admit-and-evict switch is closed ..."); + LOG(INFO) << "feature admit-and-evict switch is closed ..."; } m_isEnableFunction = isEnableEvict; } @@ -227,15 +224,16 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t for (size_t i = 0; i < thresholds.size(); ++i) { auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tensorName); if (it == embNames.end()) { // a configured table that does not exist in the running model is also an error - spdlog::error("embName[{}] is not exist at current model ...", thresholds[i].tensorName); + LOG(ERROR) << StringFormat( + "embName[%s] is not exist at current model ...", thresholds[i].tensorName.c_str()); return false; } else { // both admit & evict are enabled, but no timestamp was passed if (m_embStatus[*it] == SingleEmbTableStatus::SETS_ERROR) { - spdlog::error("embName[{}] config error, please check ...", embNames[i]); + LOG(ERROR) << StringFormat("embName[%s] config error, please check ...", embNames[i].c_str()); return false; } else if (m_embStatus[*it] == SingleEmbTableStatus::SETS_BOTH && !isTimestamp) { - spdlog::error("embName[{}] admit and evict, but no timestamp", embNames[i]); + LOG(ERROR) << StringFormat("embName[%s] admit and evict, but no timestamp", embNames[i].c_str()); return false; } } @@ -272,18 +270,19 @@ void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) { if (thresholdValues.empty()) { - spdlog::error("thresholdValues is empty ..."); + LOG(ERROR) << "thresholdValues is empty ..."; return false; } m_cfgThresholds = thresholdValues; for (const auto& value : thresholdValues) { - spdlog::info("embName[{}], count[{}], time[{}] ...", - value.tensorName, value.countThreshold, value.timeThreshold); + LOG(INFO) << StringFormat( + "embName[%s], count[%d], time[%d] ...", + value.tensorName.c_str(), value.countThreshold, value.timeThreshold); auto it = m_tensor2Threshold.find(value.tensorName); if (it != m_tensor2Threshold.end()) { // when train and eval are both enabled, a table can be configured twice - spdlog::info("[{}] is repeated configuration ...", value.tensorName); + LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tensorName.c_str()); return true; } m_tensor2Threshold[value.tensorName] = value; @@ -293,7 +292,7 @@ } else if (value.countThreshold != -1 && value.timeThreshold == -1) { m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; } else { - spdlog::error("[{}] config error, have evict but no admit ...", value.tensorName); + LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tensorName.c_str()); m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ERROR; return false; } diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h 
index dfc2490c..5dfcc3e0 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -17,7 +17,6 @@ #include #include #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "host_emb/host_emb.h" #include "utils/common.h" #include "utils/safe_queue.h" @@ -58,6 +57,8 @@ namespace MxRec { // feature eviction interface void FeatureEvict(map>& evictKeyMap); + void ExecuteFeatureAdmit( + const string& tensorName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); // enable switch for feature eviction void SetFunctionSwitch(bool isEnableEvict); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 82d11a72..0ca9d507 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -10,10 +10,7 @@ #include #include -#include -#include #include -#include #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" @@ -37,21 +34,26 @@ inline vector Count2Start(const vector& count) return start; } +void KeyProcess::SetupHotEmbUpdateStep() +{ + const char* env = getenv("HOT_EMB_UPDATE_STEP"); + if (env == nullptr) { + hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; + } else { + hotEmbUpdateStep = stoi(env); + if (hotEmbUpdateStep == 0) { + hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; + } + } +} + bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, const vector& thresholdValues, int seed) { this->rankInfo = rInfo; if (rankInfo.useHot) { - const char* env = getenv("HOT_EMB_UPDATE_STEP"); - if (env == nullptr) { - hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; - } else { - hotEmbUpdateStep = stoi(env); - if (hotEmbUpdateStep == 0) { - hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; - } - } + SetupHotEmbUpdateStep(); } map scInfo; @@ -64,10 +66,13 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos if (rankInfo.useDynamicExpansion) { // dynamic expansion embeddingTableMap[info.name].Init(info, rInfo, seed); - spdlog::info(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); + LOG(INFO) << StringFormat(KEY_PROCESS "EmbeddingTableMap:%s init success", info.name.c_str()); } } - spdlog::info(KEY_PROCESS "hot emb count info:{}", hotEmbTotCount); + + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat(KEY_PROCESS "hot emb count info:%s", MapToString(hotEmbTotCount).c_str()); + } MPI_Group worldGroup; MPI_Comm_group(MPI_COMM_WORLD, &worldGroup); for (auto& i: comm) { @@ -83,15 +88,19 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos m_featureAdmitAndEvict.Init(thresholdValues); } else { m_featureAdmitAndEvict.SetFunctionSwitch(false); - spdlog::warn(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); + LOG(WARNING) << KEY_PROCESS "Feature admit-and-evict function is unavailable ..."; } if (PerfConfig::fastUnique) { Factory::Create(factory); } - spdlog::info(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", scInfo, - rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + KEY_PROCESS "scInfo:%s, localRankSize:%d, rankSize:%d, useStatic:%d, useHot:%d", + MapToString(scInfo).c_str(), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot + ); + } return true; } @@ -102,12 +111,12 @@ int KeyProcess::Start() // 0 1 2 3 4 5 0 1 2 3 4 5 // | rank0 | | rank1 | // each rank creates KEY_PROCESS_THREAD threads, each thread processes one batch of data - spdlog::info("CPU Core Num: {}", 
sysconf(_SC_NPROCESSORS_CONF)); // check the CPU core count + LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // check the CPU core count auto fn = [this](int channel, int id) { #ifndef GTEST auto ret = aclrtSetDevice(static_cast(rankInfo.deviceId)); if (ret != ACL_ERROR_NONE) { - spdlog::error("Set device failed, device_id:{}", rankInfo.deviceId); + LOG(ERROR) << StringFormat("Set device failed, device_id:%d", rankInfo.deviceId); return; } #endif @@ -119,12 +128,12 @@ int KeyProcess::Start() if (threadNumEnv != nullptr) { threadNum = static_cast<int>(*threadNumEnv) - static_cast<int>('0'); if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(fmt::format("{} is not valid", threadNum)); + throw runtime_error(StringFormat("%d is not valid", threadNum)); } } else { threadNum = KEY_PROCESS_THREAD; } - spdlog::info(KEY_PROCESS "key process thread num: {}", threadNum); + LOG(INFO) << StringFormat(KEY_PROCESS "key process thread num: %d", threadNum); for (int id = 0; id < threadNum; ++id) { procThreads.emplace_back( std::make_unique<thread>(fn, channel, id)); // use lambda expression to initialize the thread } @@ -172,12 +181,12 @@ void KeyProcess::LoadKeyOffsetMap(key_offset_mem_t& loadData) void KeyProcess::Destroy() { isRunning = false; - spdlog::info(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); + LOG(INFO) << StringFormat(KEY_PROCESS "rank %d begin destroy.", rankInfo.rankId); for (auto& i: procThreads) { i->join(); } procThreads.clear(); - spdlog::info(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); + LOG(INFO) << StringFormat(KEY_PROCESS "rank %d destroy success.", rankInfo.rankId); } void KeyProcess::LoadSaveLock() @@ -253,18 +262,18 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES GetUniqueConfig(uniqueConf); } - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); try { while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; batch = GetBatchData(channel, id); // get batch data from SingletonQueue - TIME_PRINT("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getBatchDataTC(ms):%d", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; } - auto getBatchTime = Format2Ms(sw); - sw.reset(); + auto getBatchTime = tc.ElapsedMS(); + tc = TimeCost(); bool ret = false; if (PerfConfig::fastUnique) { @@ -277,8 +286,11 @@ if (!ret) { break; } - spdlog::info(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}, get data time:{} batch {}[{}]:{} ", - getAndProcessTC.ElapsedMS(), Format2Ms(sw), getBatchTime, batch->name, batch->channel, batch->batchId); + LOG(INFO) << StringFormat( + KEY_PROCESS "getAndProcessTC(ms):%d, key process cost:%d, get data time:%d batch %s[%d]:%d", + getAndProcessTC.ElapsedMS(), tc.ElapsedMS(), getBatchTime, + batch->name.c_str(), batch->channel, batch->batchId + ); auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } if (PerfConfig::fastUnique) { unique->UnInitialize(); } } catch (const EndRunError &e) { - spdlog::debug(KEY_PROCESS "abort run: {}", e.what()); + VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); } - spdlog::info(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, id, channel); + LOG(INFO) << StringFormat( + KEY_PROCESS "KeyProcessTask exit. 
rank:%d thread:%d, channel:%d", rankInfo.rankId, id, channel); } void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, @@ -306,7 +319,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 } } - TIME_PRINT("UniqueTC(ms):{}", UniqueTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("UniqueTC(ms):%d", UniqueTC.ElapsedMS()); } bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, @@ -318,14 +331,14 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat TimeCost fastUniqueTC; UniqueInfo uniqueInfo; ProcessBatchWithFastUnique(batch, unique, id, uniqueInfo); - TIME_PRINT("ProcessBatchWithFastUnique(ms):{}", fastUniqueTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("ProcessBatchWithFastUnique(ms):%d", fastUniqueTC.ElapsedMS()); // 特征准入&淘汰 if (isWithFAAE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { - spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", rankInfo.rankId, id, channel); return false; } @@ -335,7 +348,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat if (rankInfo.noDDR) { TimeCost key2OffsetTC; Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); - TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); @@ -356,7 +369,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat } TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); - TIME_PRINT("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); return true; } @@ -383,7 +396,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys, countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { - spdlog::error(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", rankInfo.rankId, id, channel); return false; } @@ -417,7 +430,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe } } PushResult(batch, move(tensors), lookupKeys); - TIME_PRINT("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); return true; } @@ -447,7 +460,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, countRecv.resize(rs.back() + rc.back()); MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); - TIME_PRINT("getCountRecvTC(ms)(with-all2all):{}", getCountRecvTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getCountRecvTC(ms)(with-all2all):%d", getCountRecvTC.ElapsedMS()); return countRecv; } @@ -476,7 +489,7 @@ unique_ptr 
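// ==== editor's sketch (not part of the patch) ====
// GetBatchData below is a polling consumer with a periodic timeout warning.
// The generic shape of that loop, independent of the SingletonQueue details
// (names hypothetical; the real code also sleeps a second after warning):
#include <chrono>
#include <thread>

template <typename Queue>
auto PollUntilReadyOrShutdown(Queue& queue, const bool& running, long timeoutSec)
    -> decltype(queue.TryPop())
{
    auto lastWarn = std::chrono::steady_clock::now();
    while (running) {
        if (auto item = queue.TryPop()) {
            return item;  // got a batch, hand it to the caller
        }
        std::this_thread::sleep_for(std::chrono::microseconds(100));
        if (std::chrono::steady_clock::now() - lastWarn > std::chrono::seconds(timeoutSec)) {
            // warn once per timeout window, then keep polling, as the real loop does
            lastWarn = std::chrono::steady_clock::now();
        }
    }
    return nullptr;  // shutdown signalled: exit cooperatively instead of blocking
}
// ==== end editor's sketch ====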
KeyProcess::GetBatchData(int channel, int commId) auto batchQueue = SingletonQueue::getInstances(commId + KEY_PROCESS_THREAD * channel); EASY_BLOCK("get samples") EASY_VALUE("run on CPU", sched_getcpu()) - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); while (true) { batch = batchQueue->TryPop(); if (batch != nullptr) { @@ -484,13 +497,15 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) } else { this_thread::sleep_for(100us); } - if (duration_cast(sw.elapsed()).count() > GET_BATCH_TIMEOUT) { + if (tc.ElapsedSec() > GET_BATCH_TIMEOUT) { if (commId == 0) { - spdlog::warn(KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. " - "channel[{}] commId[{}]", channel, commId); + LOG(WARNING) << StringFormat( + KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. channel[%d] commId[%d]", + channel, commId + ); } this_thread::sleep_for(seconds(1)); - sw.reset(); + tc = TimeCost(); } if (!isRunning) { // 通信终止信号,同步退出,防止线程卡住 @@ -500,14 +515,16 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) } } EASY_END_BLOCK - spdlog::debug(KEY_PROCESS "rank {} thread {} get batch {}[{}]:{} done. bs:{} sample:[{}]", - rankInfo.rankId, commId, batch->name, batch->channel, batch->batchId, batch->Size(), - batch->UnParse()); + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "rank %d thread %d get batch %s[%d]:%d done. bs:%d sample:[%s]", + rankInfo.rankId, commId, + batch->name.c_str(), batch->channel, batch->batchId, batch->Size(), batch->UnParse().c_str() + ); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (batch->batchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE } else if (batch->batchId == PROFILING_END_BATCH_ID) { - ::profiler::dumpBlocksToFile(fmt::format("/home/MX_REC-profile-{}.prof", rankInfo.rankId).c_str()); + ::profiler::dumpBlocksToFile(StringFormat("/home/MX_REC-profile-%d.prof", rankInfo.rankId).c_str()); } #endif return batch; @@ -559,16 +576,19 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut); EASY_END_BLOCK - TIME_PRINT("FastUniqueCompute(ms):{}, ret:{}", uniqueTC.ElapsedMS(), ret); + VLOG(GLOG_DEBUG) << StringFormat("FastUniqueCompute(ms):%d, ret:%d", uniqueTC.ElapsedMS(), ret); vector sc; HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize); All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); - spdlog::debug(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{}, channel:{}, " - "name:{}, restore:{}, keyCount:{}", batch->batchId, batch->Size(), - batch->channel, batch->name, uniqueInfoOut.restore.size(), keySendInfo.keyCount.size()); + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "ProcessBatchWithFastUnique get batchId:%d, batchSize:%d," + " channel:%d, name:%s, restore:%d, keyCount:%d", + batch->batchId, batch->Size(), batch->channel, batch->name.c_str(), + uniqueInfoOut.restore.size(), keySendInfo.keyCount.size() + ); } void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, @@ -582,10 +602,10 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni int hotOffset = 0; uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); hotOffset = hotEmbTotCount[batch->name]; - + TimeCost ComputeHotTc; ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); - TIME_PRINT("ComputeHot TimeCost(ms):{}", ComputeHotTc.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("ComputeHot TimeCost(ms):%d", 
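// ==== editor's sketch (not part of the patch) ====
// The patch replaces spdlog::stopwatch with a TimeCost helper whose definition
// is not shown in these hunks. Inferred from the call sites (ElapsedMS() /
// ElapsedSec(), reset by re-assignment), a minimal steady_clock-based version
// could look like this; the real class may differ in return types and extras.
#include <chrono>

class TimeCostSketch {
public:
    TimeCostSketch() : start(std::chrono::steady_clock::now()) {}

    long long ElapsedMS() const
    {
        return std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start).count();
    }

    long long ElapsedSec() const
    {
        return std::chrono::duration_cast<std::chrono::seconds>(
            std::chrono::steady_clock::now() - start).count();
    }

private:
    std::chrono::steady_clock::time_point start;
};
// ==== end editor's sketch ====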
ComputeHotTc.ElapsedMS()); UpdateHotMapForUnique(keySendInfo.keySend, keySendInfo.keyCount, hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); } @@ -629,7 +649,7 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS { TimeCost getScAllTC; GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) - TIME_PRINT("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("GetScAll TimeCost(ms):%d", getScAllTC.ElapsedMS()); TimeCost all2allTC; auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 @@ -649,7 +669,7 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); } - TIME_PRINT("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("all2allTC TimeCost(ms):%d", all2allTC.ElapsedMS()); EASY_END_BLOCK } @@ -659,16 +679,18 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, TimeCost processSplitKeysTC; EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) - spdlog::info(KEY_PROCESS "ProcessSplitKeys start batchId:{}, channel:{}", batch->batchId, batch->channel); + LOG(INFO) << StringFormat( + KEY_PROCESS "ProcessSplitKeys start batchId:%d, channel:%d", batch->batchId, batch->channel); // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 if (rankInfo.useStatic) { // maybe move after all2all for (auto& i: splitKeys) { if (static_cast(i.size()) > embInfos[batch->name].sendCount) { - spdlog::error("{}[{}]:{} overflow! set send count bigger than {}", - batch->name, batch->channel, batch->batchId, i.size()); - throw runtime_error(fmt::format("{}[{}]:{} overflow! set send count bigger than {}", - batch->name, batch->channel, batch->batchId, i.size()).c_str()); + LOG(ERROR) << StringFormat("%s[%d]:%d overflow! set send count bigger than %d", + batch->name.c_str(), batch->channel, batch->batchId, i.size()); + throw runtime_error( + StringFormat("%s[%d]:%d overflow! set send count bigger than %d", + batch->name.c_str(), batch->channel, batch->batchId, i.size()).c_str()); } i.resize(embInfos[batch->name].sendCount, -1); } @@ -683,7 +705,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, TimeCost getScAllTC; auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 - TIME_PRINT("getScAllTC(ms)(AllReduce-AllGather):{}", getScAllTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("getScAllTC(ms)(AllReduce-AllGather):%d", getScAllTC.ElapsedMS()); auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc; // receive count @@ -693,20 +715,19 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, } auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 keyRecv.resize(rs.back() + rc.back()); - spdlog::trace(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", rankInfo.rankId, id, batch->batchId, - batch->name); + VLOG(GLOG_TRACE) << StringFormat(KEY_PROCESS "MPI_Alltoallv begin. 
rank %d thread %d batch %d %s", + rankInfo.rankId, id, batch->batchId, batch->name.c_str()); EASY_BLOCK("all2all") TimeCost uniqueAll2AllTC; MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, - keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, - comm[batch->channel][id]); - TIME_PRINT("uniqueAll2AllTC(ms):{}", uniqueAll2AllTC.ElapsedMS()); + keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); + VLOG(GLOG_DEBUG) << StringFormat("uniqueAll2AllTC(ms):%d", uniqueAll2AllTC.ElapsedMS()); EASY_END_BLOCK - spdlog::trace(KEY_PROCESS "MPI_Alltoallv finish. rank {} thread {} batch {} {}", - rankInfo.rankId, id, batch->batchId, batch->name); - TIME_PRINT("processSplitKeysTC(ms):{}", processSplitKeysTC.ElapsedMS()); + VLOG(GLOG_TRACE) << StringFormat(KEY_PROCESS "MPI_Alltoallv finish. rank %d thread %d batch %d %s", + rankInfo.rankId, id, batch->batchId, batch->name.c_str()); + VLOG(GLOG_DEBUG) << StringFormat("processSplitKeysTC(ms):%d", processSplitKeysTC.ElapsedMS()); return { keyRecv, scAll, ss }; } @@ -738,7 +759,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } } EASY_END_BLOCK - if (spdlog::get_level() == spdlog::level::trace) { + if (VLOG_IS_ON(GLOG_TRACE)) { stringstream ssTrace; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { ssTrace << '|' << devId << ":"; @@ -747,7 +768,9 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } ssTrace << '|'; } - spdlog::trace("dump splitKeys\n{}", ssTrace.str()); + VLOG(GLOG_TRACE) << StringFormat( + "dump splitKeys\n%s", ssTrace.str().c_str() + ); } return { splitKeys, restore }; } @@ -789,7 +812,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } EASY_END_BLOCK - if (spdlog::get_level() == spdlog::level::trace) { + if (VLOG_IS_ON(GLOG_TRACE)) { stringstream ssTrace; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { ssTrace << '|' << devId << ":"; @@ -798,7 +821,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } ssTrace << '|'; } - spdlog::trace("dump splitKeys\n{}", ssTrace.str()); + VLOG(GLOG_TRACE) << StringFormat("dump splitKeys\n%s", ssTrace.str().c_str()); } return { splitKeys, restore, keyCount }; @@ -936,18 +959,22 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int scAll.resize(rankInfo.rankSize * rankInfo.rankSize); EASY_BLOCK("barrier"); // 通信终止信号,同步退出,防止线程卡住 - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); int exitFlag = isRunning; MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); if (exitFlag < rankInfo.rankSize) { throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - spdlog::debug(KEY_PROCESS "barrier time:{}", Format2Ms(sw)); + VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "barrier time:%d", tc.ElapsedMS()); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - spdlog::debug("rank {} key scAll matrix:\n{}", rankInfo.rankId, scAll); + if (VLOG_IS_ON(GLOG_DEBUG)) { + VLOG(GLOG_DEBUG) << StringFormat( + "rank %d key scAll matrix:\n%s", rankInfo.rankId, VectorToString(scAll).c_str() + ); + } return scAll; } @@ -957,18 +984,22 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); EASY_BLOCK("barrier"); // 通信终止信号,同步退出,防止线程卡住 - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); int exitFlag = isRunning; 
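// ==== editor's sketch (not part of the patch) ====
// The exitFlag all-reduce just below is a cooperative shutdown barrier: every
// rank contributes its isRunning flag (0/1), and a sum smaller than rankSize
// means at least one rank is exiting, so all ranks throw EndRunError together
// instead of blocking in the next collective. Standalone shape of the pattern;
// note the MPI standard requires MPI_IN_PLACE rather than passing the same
// buffer as both send and receive, as the patched code does.
#include <mpi.h>

bool AllRanksStillRunning(bool isRunning, MPI_Comm comm, int rankSize)
{
    int flag = isRunning ? 1 : 0;
    MPI_Allreduce(MPI_IN_PLACE, &flag, 1, MPI_INT, MPI_SUM, comm);
    return flag == rankSize;  // true only if no rank has requested shutdown
}
// ==== end editor's sketch ====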
MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); if (exitFlag < rankInfo.rankSize) { throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - spdlog::debug(KEY_PROCESS "barrier time:{}", duration_cast((sw).elapsed())); + VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "barrier time:%d", tc.ElapsedMS()); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - spdlog::debug("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, scAllOut); + if (VLOG_IS_ON(GLOG_DEBUG)) { + VLOG(GLOG_DEBUG) << StringFormat( + "rank %d key scAllOut matrix:\n%s", rankInfo.rankId, VectorToString(scAllOut).c_str() + ); + } } void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel) @@ -990,8 +1021,10 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha size_t offset; // 新值, emb有pos可复用 offset = evictPos.back(); - spdlog::trace("HBM mode, evictPos is not null, name[{}] key [{}] reuse offset [{}], evictSize [{}]!!!", - embName, key, offset, evictPos.size()); + VLOG(GLOG_TRACE) << StringFormat( + "HBM mode, evictPos is not null, name[%s] key [%d] reuse offset [%d], evictSize [%d]!!!", + embName.c_str(), key, offset, evictPos.size() + ); key2Offset[key] = offset; key = offset; evictPos.pop_back(); @@ -1006,11 +1039,11 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha } } if (maxOffsetTmp > embInfos[embName].devVocabSize) { - spdlog::error("dev cache overflow {}>{}", maxOffsetTmp, embInfos[embName].devVocabSize); + LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, embInfos[embName].devVocabSize); throw std::runtime_error("dev cache overflow!"); } - spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); - TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel) @@ -1043,8 +1076,8 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s key = 0; } } - spdlog::debug("current dev emb usage:{}/{}", maxOffsetTmp, embInfos[embName].devVocabSize); - TIME_PRINT("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } /* @@ -1060,18 +1093,17 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec TimeCost buildRestoreVecTC; EASY_FUNCTION() int hotNum = 0; - bool spdDebug = (spdlog::get_level() == spdlog::level::debug); for (size_t i = 0; i < batch->Size(); ++i) { const emb_key_t d = batch->sample[i]; int devId = static_cast(d) & (rankInfo.rankSize - 1); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; - } else if (spdDebug) { + } else if (VLOG_IS_ON(GLOG_DEBUG)) { hotNum += 1; } } - spdlog::debug("hot num in all:{}/{}", hotNum, batch->Size()); - TIME_PRINT("buildRestoreVecTC(ms):{}", buildRestoreVecTC.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("hot num in all:%d/%d", hotNum, batch->Size()); + VLOG(GLOG_DEBUG) << StringFormat("buildRestoreVecTC(ms):%d", 
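// ==== editor's sketch (not part of the patch) ====
// BuildRestoreVec above routes each key with `key & (rankSize - 1)`, which is
// equivalent to `key % rankSize` only when rankSize is a power of two. A
// one-line guard makes that precondition explicit wherever the mask trick is
// relied on:
#include <cassert>

inline bool IsPowerOfTwo(unsigned int n)
{
    return n != 0 && (n & (n - 1)) == 0;
}
// usage: assert(IsPowerOfTwo(static_cast<unsigned int>(rankInfo.rankSize)));
// ==== end editor's sketch ====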
buildRestoreVecTC.ElapsedMS()); } class EmptyList : public std::exception { @@ -1085,17 +1117,17 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in { std::lock_guard lockGuard(mut); if (list[embName][channel].empty()) { - spdlog::trace("get info list is empty."); + VLOG(GLOG_TRACE) << "get info list is empty."; throw EmptyList(); } auto topBatch = get(list[embName][channel].top()); if (topBatch < batch) { - spdlog::error("wrong batch id, top:{} getting:{}, channel:{}, may not clear channel", topBatch, - batch, channel); + LOG(ERROR) << StringFormat( + "wrong batch id, top:%d getting:%d, channel:%d, may not clear channel", topBatch, batch, channel); this_thread::sleep_for(1s); } if (topBatch != batch) { - spdlog::trace("topBatch({}) is not equal batch({}).", topBatch, batch); + VLOG(GLOG_TRACE) << StringFormat("topBatch(%d) is not equal batch(%d).", topBatch, batch); throw WrongListTop(); } auto t = list[embName][channel].top(); @@ -1105,23 +1137,25 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); while (true) { if (!isRunning) { return {}; } - if (batch != 0 && channel != 0 && duration_cast(sw.elapsed()).count() > KEY_PROCESS_TIMEOUT) { - spdlog::warn(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); + if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { + LOG(WARNING) << StringFormat( + KEY_PROCESS "getting lookup keys timeout! %s[%d]:%d", embName.c_str(), channel, batch); return {}; } try { auto ret = GetInfo(lookupKeysList, batch, embName, channel); return get(ret); } catch (EmptyList&) { - spdlog::trace("getting info failed {}[{}]:{}", embName, channel, batch); + VLOG(GLOG_TRACE) << StringFormat("getting info failed %s[%d]:%d", embName.c_str(), channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - spdlog::trace("getting info failed {}[{}]:{} wrong top", embName, channel, batch); + VLOG(GLOG_TRACE) << StringFormat( + "getting info failed %s[%d]:%d wrong top", embName.c_str(), channel, batch); this_thread::sleep_for(1ms); } } @@ -1129,7 +1163,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { - spdlog::stopwatch sw; + TimeCost tc = TimeCost(); info_list_t* list; switch (type) { case ProcessedInfo::ALL2ALL: @@ -1145,8 +1179,9 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa if (!isRunning) { return nullptr; } - if (batch != 0 && channel != 0 && duration_cast(sw.elapsed()).count() > KEY_PROCESS_TIMEOUT) { - spdlog::warn(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); + if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { + LOG(WARNING) << StringFormat( + KEY_PROCESS "getting lookup keys timeout! 
%s[%d]:%d", embName.c_str(), channel, batch); return nullptr; } try { @@ -1157,10 +1192,11 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa storage.erase(it); return uTensor; } catch (EmptyList&) { - spdlog::trace("getting info failed {}[{}]:{}", embName, channel, batch); + VLOG(GLOG_TRACE) << StringFormat("getting info failed %s[%d]:%d", embName.c_str(), channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - spdlog::trace("getting info failed {}[{}]:{} wrong top", embName, channel, batch); + VLOG(GLOG_TRACE) << StringFormat( + "getting info failed %s[%d]:%d wrong top", embName.c_str(), channel, batch); this_thread::sleep_for(1ms); } } @@ -1192,7 +1228,7 @@ int KeyProcess::GetMaxStep(int channelId) const void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm { - spdlog::info(KEY_PROCESS "hbm funEvictCall: [{}]! keySize:{}", embName, keys.size()); + LOG(INFO) << StringFormat(KEY_PROCESS "hbm funEvictCall: [%s]! keySize:%d", embName.c_str(), keys.size()); // 删除映射关系 if (keys.size() != 0) { @@ -1216,7 +1252,7 @@ void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vectorsecond; devHashMap.erase(iter); evictPos.emplace_back(offset); - spdlog::trace("evict embName {} , offset , {}", embName, offset); + VLOG(GLOG_TRACE) << StringFormat("evict embName:%s, offset:%d", embName.c_str(), offset); } - spdlog::info(KEY_PROCESS "hbm EvictDeleteDeviceEmb: [{}]! evict size on dev:{}", embName, evictPos.size()); + LOG(INFO) << StringFormat( + KEY_PROCESS "hbm EvictDeleteDeviceEmb: [%s]! evict size on dev:%d", embName.c_str(), evictPos.size()); } void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset) { if (offset.size() > embInfos[embName].devVocabSize) { - spdlog::error("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", - embName, offset.size(), embInfos[embName].devVocabSize); - throw runtime_error(fmt::format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", - embName, offset.size(), embInfos[embName].devVocabSize).c_str()); + LOG(ERROR) << StringFormat( + "%s overflow! init evict dev, evictOffset size %s bigger than dev vocabSize %d", + embName.c_str(), offset.size(), embInfos[embName].devVocabSize); + throw runtime_error( + StringFormat( + "%s overflow! init evict dev, evictOffset size %d bigger than dev vocabSize %d", + embName.c_str(), offset.size(), embInfos[embName].devVocabSize + ).c_str()); } vector tmpDataOut; @@ -1253,5 +1294,6 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset auto trans = Singleton::GetInstance(); trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); - spdlog::info(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size()); + LOG(INFO) << StringFormat( + KEY_PROCESS "hbm EvictInitDeviceEmb: [%s]! 
send offsetSize:%d", embName.c_str(), offset.size()); } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 4c2745b2..b3170541 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -17,11 +17,8 @@ #include #include -#include #include #include -#include -#include #include "utils/common.h" #include "utils/safe_queue.h" @@ -87,6 +84,8 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); + void SetupHotEmbUpdateStep(); + bool isRunning { false }; inline bool hasEmbName(const string &emb_name) diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 6b4b9367..d048cdb3 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -8,12 +8,15 @@ #include "common.h" -#include -#include -#include +#include +#include +#include + #include +#include #include +#include using namespace std; using std::chrono::system_clock; @@ -22,6 +25,8 @@ namespace MxRec { int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; int PerfConfig::maxUniqueThreadNum = DEFAULT_MAX_UNIQUE_THREAD_NUM; bool PerfConfig::fastUnique = false; + string g_rankId; + int g_glogLevel; RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, @@ -91,14 +96,15 @@ namespace MxRec { } } - void SetLog(int rank) + void SetLog() { - auto logger = spdlog::stderr_color_mt("console"); - spdlog::set_default_logger(logger); - std::string pattern = "[%H:%M:%S.%e] [" + std::to_string(rank) + "] [%^%l%$] %v"; - spdlog::default_logger()->set_pattern(pattern); - auto env_val = spdlog::details::os::getenv("SPDLOG_LEVEL"); - spdlog::cfg::load_env_levels(); + // glog 0.5.0 can't pass any args into, and not support custom format + auto logLevel = getenv("GLOG_stderrthreshold"); + if (logLevel == nullptr) { + g_glogLevel = 0; // default as INFO + } else { + g_glogLevel = atoi(logLevel); + } } string GetChipName(int devID) @@ -109,8 +115,10 @@ namespace MxRec { { 0 }}; ret = dsmi_get_chip_info(devID, &info); if (ret == 0) { - spdlog::debug("dsmi_get_chip_info successful, ret = {}, chip_name = {}", ret, - reinterpret_cast(info.chip_name)); + VLOG(GLOG_DEBUG) << StringFormat( + "dsmi_get_chip_info successful, ret = %d, chip_name = %s", ret, + reinterpret_cast(info.chip_name) + ); return reinterpret_cast(info.chip_name); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 0bd8d3e5..2b8e3b6a 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -23,13 +23,10 @@ #include #include #include -#include -#include -#include -#include -#include +#include #include "tensorflow/core/framework/tensor.h" #include "absl/container/flat_hash_map.h" +#include "securec.h" #include "initializer/initializer.h" #include "initializer/constant_initializer/constant_initializer.h" @@ -50,9 +47,10 @@ namespace MxRec { #define INFO_PTR shared_ptr -#define TIME_PRINT spdlog::debug #define MGMT_CPY_THREADS 4 #define PROFILING + extern int g_glogLevel; + using namespace tensorflow; constexpr int TRAIN_CHANNEL_ID = 0; constexpr int EVAL_CHANNEL_ID = 1; @@ -63,6 +61,9 @@ namespace MxRec { constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; constexpr int KEY_PROCESS_THREAD = 6; + // for GLOG + constexpr int GLOG_MAX_BUF_SIZE = 2048; + // unique related config constexpr int UNIQUE_BUCKET = 6; constexpr int MIN_UNIQUE_THREAD_NUM = 1; @@ -134,11 +135,6 @@ namespace MxRec { throw std::runtime_error("unknown chip ub size" + GetChipName(devID)); } - inline std::chrono::milliseconds::rep 
Format2Ms(spdlog::stopwatch& sw)
-    {
-        return std::chrono::duration_cast<std::chrono::milliseconds>((sw).elapsed()).count();
-    }
-
     template <typename T> struct Batch {
         size_t Size() const
@@ -270,7 +266,67 @@ struct BatchTask {
         absl::flat_hash_map timestamps;  // timestamps for feature admit & evict
     };
 
-    void SetLog(int rank);
+    void SetLog();
+
+    template <typename... Args>
+    string StringFormat(const string& format, Args ... args)
+    {
+        auto size = static_cast<size_t>(GLOG_MAX_BUF_SIZE);
+        unique_ptr<char[]> buf(new char[size]);
+        memset_s(buf.get(), size, 0, size);
+        int len = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...);  // count must stay below destMax
+        return len < 0 ? string(buf.get()) : string(buf.get(), static_cast<size_t>(len));  // formatted length only, not the whole zero-padded buffer
+    }
+
+    // The environment variable GLOG_v decides whether verbose logs are shown.
+    // Default 0: debug messages are not displayed.
+    // 1 enables debug (GLOG_DEBUG), 2 enables trace (GLOG_TRACE).
+    const int GLOG_DEBUG = 1, GLOG_TRACE = 2;
+
+    template <typename T>
+    std::string VectorToString(const std::vector<T>& vec)
+    {
+        std::stringstream ss;
+        ss << "[";
+        for (size_t i = 0; i < vec.size(); ++i) {
+            ss << vec[i];
+            if (i != vec.size() - 1) {
+                ss << ", ";
+            }
+        }
+        ss << "]";
+        return ss.str();
+    }
+
+    template <typename K, typename V>
+    std::string MapToString(const std::map<K, V>& map)
+    {
+        std::stringstream ss;
+        ss << "{";
+        for (auto it = map.begin(); it != map.end(); ++it) {
+            ss << it->first << ": " << it->second;
+            if (std::next(it) != map.end()) {
+                ss << ", ";
+            }
+        }
+        ss << "}";
+        return ss.str();
+    }
+
+    template <typename K, typename V>
+    std::string MapToString(const absl::flat_hash_map<K, V>& map)
+    {
+        std::stringstream ss;
+        ss << "{";
+        for (auto it = map.begin(); it != map.end(); ++it) {
+            ss << it->first << ": " << it->second;
+            if (std::next(it) != map.end()) {
+                ss << ", ";
+            }
+        }
+        ss << "}";
+        return ss.str();
+    }
 
     inline void GenerateRandomValue(std::vector& vecData,
                                     std::default_random_engine& generator,
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index 6642be39..a9ef8f2e 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -11,17 +11,9 @@
 #include
 #include
-#include
-#include
-#include
-#include
-#include
-#include
-
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/example/example.pb.h"
 #include "key_process/key_process.h"
 #include "key_process/feature_admit_and_evict.h"
@@ -42,15 +34,15 @@
 using OpKernelConstructionPtr = OpKernelConstruction*;
 using OpKernelContextPtr = OpKernelContext*;
 using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*;
 
-spdlog::stopwatch staticSw {};
-spdlog::stopwatch staticReadRaw {};
+TimeCost staticSw {};
+TimeCost staticReadRaw {};
 array batchIdsInfo {};
 
 size_t GetBatchSize(OpKernelContextPtr context, const size_t dataSize, const size_t fieldNum)
 {
     if (fieldNum == 0 || dataSize / fieldNum <= 0) {
         context->SetStatus(
-            errors::Aborted(__FILE__, ":", __LINE__, " ", fmt::format("batchSize error. {}/{}", dataSize, fieldNum)));
+            errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat("batchSize error. 
%d/%d", dataSize, fieldNum))); return 0; } return dataSize / fieldNum; @@ -65,14 +57,14 @@ public: OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - throw runtime_error(fmt::format("channelId is invalid, It should be in range [0, {})", - MAX_CHANNEL_NUM)); + throw runtime_error(StringFormat( + "channelId is invalid, It should be in range [0, %d)", MAX_CHANNEL_NUM)); } if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ClearChannel channelId invalid. It should be in range " - "[0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", + MAX_CHANNEL_NUM))); return; } } @@ -81,7 +73,7 @@ public: void Compute(OpKernelContextPtr context) override { - spdlog::info("clear channel {}, context {}", channelId, context->step_id()); + LOG(INFO) << StringFormat("clear channel %d, context %d", channelId, context->step_id()); batchIdsInfo.at(channelId) = 0; } @@ -136,15 +128,7 @@ class ReadEmbKeyV2Dynamic : public OpKernel { public: explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context) { - if (!spdlog::get("console")) { - auto logger = spdlog::stderr_color_mt("console"); - spdlog::set_default_logger(logger); - } else { - spdlog::set_default_logger(spdlog::get("console")); - } - spdlog::cfg::load_env_levels(); - spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v"); - spdlog::debug("ReadEmbKeyV2Dynamic init"); + VLOG(GLOG_DEBUG) << "ReadEmbKeyV2Dynamic init"; OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); @@ -153,16 +137,18 @@ public: // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && - !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("threshold config, or timestamp error ..."))); + !FeatureAdmitAndEvict::IsThresholdCfgOK( + FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp) + ) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...")); return; } if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ReadEmbKeyV2Dynamic channelId invalid. It should be in " - "range [0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ReadEmbKeyV2Dynamic channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:%d)",
+                MAX_CHANNEL_NUM)));
             return;
         }
         batchIdsInfo.at(channelId) = 0;
@@ -171,7 +157,7 @@ public:
         if (threadNumEnv != nullptr) {
             threadNum = static_cast<int>(*threadNumEnv) - static_cast<int>('0');
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
-                throw runtime_error(fmt::format("{} is not valid", threadNum));
+                throw runtime_error(StringFormat("%d is not valid", threadNum));
             }
         } else {
             threadNum = KEY_PROCESS_THREAD;
         }
@@ -189,12 +175,12 @@ public:
     void Compute(OpKernelContextPtr context) override
     {
         EASY_FUNCTION();
-        spdlog::debug("enter ReadEmbKeyV2Dynamic");
-        spdlog::stopwatch sw;
+        VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2Dynamic";
+        TimeCost tc = TimeCost();
         int batchId = batchIdsInfo.at(channelId)++;
         if (channelId == 1) {
             if (maxStep != -1 && batchId >= maxStep) {
-                spdlog::warn("skip excess batch after {}/{}", batchId, maxStep);
+                LOG(WARNING) << StringFormat("skip excess batch after %d/%d", batchId, maxStep);
                 return;
             }
         }
@@ -209,9 +195,8 @@ public:
         time_t timestamp = -1;
         // if a timestamp was passed, parse and validate it
         if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) {
-            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ",
-                fmt::format("timestamp[{}] error, skip excess batch after {}/{}",
-                    timestamp, batchId, maxStep)));
+            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
+                "timestamp[%ld] error, skip excess batch after %d/%d", timestamp, batchId, maxStep)));
             return;
         }
         // ensure every embName has a status record in m_embStatus
@@ -227,10 +212,12 @@ public:
         TimeCost enqueueTC;
         EnqueueBatchData(std::vector<int>{batchId, batchQueueId}, timestamp, inputTensor, splits);
-        TIME_PRINT(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{}, "
-            "enqueueTC(ms):{}, batch[{}]:{}",
-            Format2Ms(sw), Format2Ms(staticSw), enqueueTC.ElapsedMS(), channelId, batchId);
-        staticSw.reset();
+        VLOG(GLOG_DEBUG) << StringFormat(
+            KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):%d, elapsed from last(ms):%d,"
+            " enqueueTC(ms):%d, batch[%d]:%d",
+            tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId
+        );
+        staticSw = TimeCost();
     }
 
     void CheckEmbTables()
     {
         auto keyProcess = Singleton::GetInstance();
         for (size_t i = 0; i < embNames.size(); ++i) {
             if (!keyProcess->hasEmbName(embNames.at(i))) {
-                spdlog::info("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i));
+                LOG(INFO) << StringFormat(
+                    "ReadEmbKeyV2Dynamic not found emb_name:%zu %s", i, embNames.at(i).c_str()
+                );
                 tableUsed.push_back(false);
             } else {
                 tableUsed.push_back(true);
             }
         }
     }
@@ -288,18 +277,18 @@ public:
                                 size_t& dataSize)
     {
         if (dataSize - fieldNumTmp != 1) {  // no timestamp was passed
-            spdlog::error("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp);
+            LOG(ERROR) << StringFormat("dataSize[%zu], fieldNum[%zu] ...", dataSize, fieldNumTmp);
             return false;
         }
 
         // the leading 8 bytes (one featureId slot) are a unix timestamp
         auto src = (const time_t*)inputTensor.tensor_data().data();
         std::copy(src, src + 1, &timestamp);
-        spdlog::info("current batchId[{}] timestamp[{}]", batchId, timestamp);
+        LOG(INFO) << StringFormat("current batchId[%d] timestamp[%ld]", batchId, timestamp);
         dataSize -= 1;
 
         if (timestamp <= 0) {
-            spdlog::error("timestamp[{}] <= 0 ", timestamp);
+            LOG(ERROR) << StringFormat("timestamp[%ld] <= 0 ", timestamp);
             return false;
         }
@@ -348,15 +337,7 @@ class ReadEmbKeyV2 : public OpKernel {
 public:
     explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context)
     {
-        auto logger = spdlog::get("console");
-        if (!logger) {
-            
logger = spdlog::stderr_color_mt("console"); - } - spdlog::set_default_logger(spdlog::get("console")); - - spdlog::cfg::load_env_levels(); - spdlog::default_logger()->set_pattern("[%H:%M:%S.%e] [%^%l%$] %v"); - spdlog::debug("ReadEmbKeyV2 init"); + VLOG(GLOG_DEBUG) << "ReadEmbKeyV2 init"; OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // 每个表的field Number @@ -368,21 +349,20 @@ public: // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("threshold config, or timestamp error ..."))); + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...") + ); return; } if (splits.size() != embNames.size()) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("splits & embNames size error.{} {}", splits.size(), - embNames.size()))); + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "splits & embNames size error.%d %d", splits.size(), embNames.size()))); return; } if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("ReadEmbKeyV2 channelId invalid. It should be in range " - "[0, MAX_CHANNEL_NUM:{})", MAX_CHANNEL_NUM))); + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", MAX_CHANNEL_NUM))); return; } batchIdsInfo.at(channelId) = 0; @@ -391,7 +371,7 @@ public: if (threadNumEnv != nullptr) { threadNum = static_cast(*threadNumEnv) - static_cast('0'); if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(fmt::format("{} is not valid", threadNum)); + throw runtime_error(StringFormat("%d is not valid", threadNum)); } } auto keyProcess = Singleton::GetInstance(); @@ -407,13 +387,13 @@ public: void Compute(OpKernelContextPtr context) override { EASY_FUNCTION(); - spdlog::debug("enter ReadEmbKeyV2"); - spdlog::stopwatch sw; + VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2"; + TimeCost tc = TimeCost(); int batchId = batchIdsInfo.at(channelId)++; Tensor* output = nullptr; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - spdlog::warn("skip excess batch after {}/{}", batchId, maxStep); + LOG(WARNING) << StringFormat("skip excess batch after %d/%d", batchId, maxStep); OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); out(0) = batchId; @@ -426,9 +406,8 @@ public: time_t timestamp = -1; // 如果传递了时间戳,解析和校验 if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", - fmt::format("timestamp[{}] error, skip excess batch after {}/{}", - timestamp, batchId, maxStep))); + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep))); return; } // 保证所有embNames在m_embStatus中有状态记录 @@ -444,10 +423,11 @@ public: TimeCost enqueueTC; EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); - TIME_PRINT(KEY_PROCESS "ReadEmbKeyV2Static 
read batch cost(ms):{}, elapsed from last(ms):{}, "
-            "enqueueTC(ms):{}, batch[{}]:{}",
-            Format2Ms(sw), Format2Ms(staticSw), enqueueTC.ElapsedMS(), channelId, batchId);
-        staticSw.reset();
+        VLOG(GLOG_DEBUG) << StringFormat(
+            KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):%d, elapsed from last(ms):%d,"
+            " enqueueTC(ms):%d, batch[%d]:%d",
+            tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId);
+        staticSw = TimeCost();
     }
 
     void CheckEmbTables()
     {
         auto keyProcess = Singleton::GetInstance();
         for (size_t i = 0; i < splits.size(); ++i) {
             if (!keyProcess->hasEmbName(embNames.at(i))) {
-                spdlog::info("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i));
+                LOG(INFO) << StringFormat("ReadEmbKeyV2 not found emb_name:%zu %s", i, embNames.at(i).c_str());
                 tableUsed.push_back(false);
             } else {
                 tableUsed.push_back(true);
             }
         }
     }
@@ -505,18 +485,18 @@ public:
                                 size_t& dataSize)
     {
         if (dataSize - fieldNumTmp != 1) {  // no timestamp was passed
-            spdlog::error("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp);
+            LOG(ERROR) << StringFormat("dataSize[%zu], fieldNum[%zu] ...", dataSize, fieldNumTmp);
             return false;
         }
 
         // the leading 8 bytes (one featureId slot) are a unix timestamp
         auto src = (const time_t*)inputTensor.tensor_data().data();
         std::copy(src, src + 1, &timestamp);
-        spdlog::info("current batchId[{}] timestamp[{}]", batchId, timestamp);
+        LOG(INFO) << StringFormat("current batchId[%d] timestamp[%ld]", batchId, timestamp);
         dataSize -= 1;
 
         if (timestamp <= 0) {
-            spdlog::error("timestamp[{}] <= 0 ", timestamp);
+            LOG(ERROR) << StringFormat("timestamp[%ld] <= 0 ", timestamp);
             return false;
         }
@@ -575,7 +555,7 @@ public:
     void Compute(OpKernelContextPtr context) override
     {
         EASY_FUNCTION();
-        spdlog::stopwatch sw;
+        TimeCost tc = TimeCost();
         const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0);
         auto input = inputTensor.flat();
         const int restoreLen = static_cast<int>(input.size());
@@ -601,9 +581,9 @@ public:
         for (int i { 0 }; i < restoreLen; ++i) {
             r(i) = i % lookupLen;
         }
-        spdlog::warn("dummy read batch cost: {},elapsed from last {}",
-            Format2Ms(sw), Format2Ms(staticSw));
-        staticSw.reset();
+        LOG(WARNING) << StringFormat("dummy read batch cost: %d, elapsed from last %d",
+            tc.ElapsedMS(), staticSw.ElapsedMS());
+        staticSw = TimeCost();
     }
 
     int lookupLen {};
@@ -619,7 +599,7 @@ public:
 
     void Compute(OpKernelContextPtr context) override
     {
-        spdlog::info("context {}", context->step_id());
+        LOG(INFO) << StringFormat("context %lld", static_cast<long long>(context->step_id()));
         std::cout << " Cust opp not installed!!" 
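// ==== editor's sketch (not part of the patch) ====
// ParseTimestampAndCheck above treats the first 8 bytes of the flattened batch
// as a unix timestamp occupying one feature-id slot. A self-contained version
// of that convention (the raw byte buffer stands in for tensor_data()); memcpy
// sidesteps the alignment/aliasing questions raised by the (const time_t*) cast:
#include <cstring>
#include <ctime>

bool ReadLeadingTimestamp(const char* data, std::size_t bytes, time_t& ts)
{
    if (bytes < sizeof(time_t)) {
        return false;  // batch too small to carry the leading timestamp slot
    }
    std::memcpy(&ts, data, sizeof(time_t));
    return ts > 0;  // mirrors the validity check in the kernels above
}
// ==== end editor's sketch ====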
<< std::endl; } diff --git a/src/platform/AccCTR b/src/platform/AccCTR index 62ab674f..77af967d 160000 --- a/src/platform/AccCTR +++ b/src/platform/AccCTR @@ -1 +1 @@ -Subproject commit 62ab674f0a42d8de8398eafb4799e506fe99549d +Subproject commit 77af967dcc81f4f8f7e2affd015cb760db5e4d9f diff --git a/src/test_ut.sh b/src/test_ut.sh index 54fde4ee..aa8edf3c 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -81,7 +81,7 @@ cd "$(dirname "${PWD}")" COVERAGE_FILE=coverage.info REPORT_FOLDER=coverage_report lcov --rc lcov_branch_coverage=1 -c -d build -o "${COVERAGE_FILE}"_tmp -lcov --rc lcov_branch_coverage=1 -e "${COVERAGE_FILE}"_tmp "*src*" -o "${COVERAGE_FILE}" +lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '/usr1/mxRec/src/core/emb_table*' '7/ext*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" genhtml --rc genhtml_branch_coverage=1 "${COVERAGE_FILE}" -o "${REPORT_FOLDER}" [ -d "${COVERAGE_FILE}"_tmp ] && rm -rf "${COVERAGE_FILE}"_tmp [ -d "${COVERAGE_FILE}" ] && rm -rf "${COVERAGE_FILE}" diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 8c73e748..1b672523 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -60,4 +60,4 @@ target_link_libraries(test_main PUBLIC MPI::MPI_CXX) target_link_libraries(test_main PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf - profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) \ No newline at end of file + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host glog::glog) \ No newline at end of file diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index f5940db5..33df7b23 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -7,15 +7,9 @@ #include #include -#include -#include #include "checkpoint/checkpoint.h" -#include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" -#include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h" -#include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h" #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" -#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" using namespace std; @@ -54,7 +48,6 @@ protected: void SetUp() { - spdlog::set_level(spdlog::level::trace); int claimed; MPI_Query_thread(&claimed); diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index f4338561..6caace33 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -6,8 +6,6 @@ */ #include -#include -#include #include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" #include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h" diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 488a853f..752db8e2 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -7,8 +7,6 @@ */ #include -#include -#include #include "hybrid_mgmt/hybrid_mgmt.h" #include "host_emb/host_emb.h" #include "utils/common.h" @@ -60,9 +58,9 @@ protected: void UpdateEmb(vector 
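// ==== editor's sketch (not part of the patch) ====
// The converted code relies on glog's two runtime knobs: GLOG_stderrthreshold
// selects the minimum severity that is printed (SetLog mirrors it into
// g_glogLevel earlier in this patch), and GLOG_v enables VLOG levels, which
// this codebase maps to 1 = GLOG_DEBUG and 2 = GLOG_TRACE. Minimal usage:
#include <glog/logging.h>

void LoggingDemo()
{
    LOG(INFO) << "printed whenever the INFO threshold is enabled";
    VLOG(1) << "printed only when GLOG_v >= 1 (GLOG_DEBUG)";
    VLOG(2) << "printed only when GLOG_v >= 2 (GLOG_TRACE)";
    if (VLOG_IS_ON(2)) {
        // guard expensive message construction here, as the VLOG_IS_ON(GLOG_TRACE)
        // call sites do before building their dump strings
    }
}
// ==== end editor's sketch ====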
&missingKeysHostPos, int channelId, const string &embName,
                   std::unique_ptr &hostEmb, vector &d2h_emb)
    {
-        spdlog::info(HD + "update emb start");
+        LOG(INFO) << (HD + "update emb start");
        if (d2h_emb.size() == 0) {
-            spdlog::info(HD + "emb is none", channelId);
+            LOG(INFO) << StringFormat(HD + "emb is none channelId:%d", channelId);
            return;
        }
@@ -74,9 +72,14 @@ protected:
            tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize;
        }
        for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) {
-            spdlog::info("hostEmb: embName {}, {} is: {}", embName, i, hostEmb->GetEmb(embName).embData[i]);
+            if (g_glogLevel >= INFO) {
+                LOG(INFO) << StringFormat(
+                    "hostEmb: embName %s, %zu is: %s", embName.c_str(), i,
+                    VectorToString(hostEmb->GetEmb(embName).embData[i]).c_str()
+                );
+            }
        }
-        spdlog::info(HD + "update emb end");
+        LOG(INFO) << (HD + "update emb end");
        d2h_emb.clear();
    }
@@ -121,7 +124,7 @@ TEST_F(EmbMgmtTest, Initialize)
    auto hybridMgmt = Singleton::GetInstance();
    cout << "setup..." << endl;
-    allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
+    allRank = RankInfo(g_rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
    auto hostEmbs = make_unique();
    hostEmbs->Initialize(embInfos, seed);
@@ -135,7 +138,7 @@ TEST_F(EmbMgmtTest, Initialize)
    vector tmpData;
    hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData);
    auto missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos;
-    spdlog::info("missingKeys {}", missingKeys);
+    LOG(INFO) << StringFormat("missingKeys %s", VectorToString(missingKeys).c_str());
    hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas);
    auto status = Float2TensorVec(tmpDatas, d2h_emb);
    ASSERT_EQ(status, true);
@@ -145,7 +148,7 @@ TEST_F(EmbMgmtTest, Initialize)
    lookupKeys = { 2, 3, 5, 6 };
    hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData);
    missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos;
-    spdlog::info("missingKeys {}", missingKeys);
+    LOG(INFO) << StringFormat("missingKeys %s", VectorToString(missingKeys).c_str());
    hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas);
    status = Float2TensorVec(tmpDatas, d2h_emb);
    ASSERT_EQ(status, true);
@@ -155,7 +158,7 @@ TEST_F(EmbMgmtTest, Initialize)
    lookupKeys = { 1, 7, 9, 10 };
    hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData);
    missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos;
-    spdlog::info("missingKeys {}", missingKeys);
+    LOG(INFO) << StringFormat("missingKeys %s", VectorToString(missingKeys).c_str());
    hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas);
    Float2TensorVec(tmpDatas, d2h_emb);
    status = Float2TensorVec(tmpDatas, d2h_emb);
@@ -180,7 +183,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM)
    auto hybridMgmt = Singleton::GetInstance();
    cout << "setup..." << endl;
-    allRank = RankInfo(rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
+    allRank = RankInfo(g_rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);

    hybridMgmt->Destroy();
@@ -200,7 +203,7 @@ TEST_F(EmbMgmtTest, Evict)
    auto hybridMgmt = Singleton::GetInstance();
    cout << "setup..." 
<< endl;
-    allRank = RankInfo(rankId, deviceId, localRankSize, true, nBatch, maxStep);
+    allRank = RankInfo(g_rankId, deviceId, localRankSize, true, nBatch, maxStep);
    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);

    // evict test, ddr
@@ -223,7 +226,7 @@ TEST_F(EmbMgmtTest, Evict_HBM)
    auto hybridMgmt = Singleton::GetInstance();
    cout << "setup..." << endl;
-    allRank = RankInfo(rankId, deviceId, localRankSize, true, nBatch, maxStep);
+    allRank = RankInfo(g_rankId, deviceId, localRankSize, true, nBatch, maxStep);
    hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);

    // evict test, hbm
diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp
index 7c2febfd..92a71c77 100644
--- a/src/tests/emb_table/emb_table_test.cpp
+++ b/src/tests/emb_table/emb_table_test.cpp
@@ -9,7 +9,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -25,11 +24,10 @@ class EmbTableTest : public testing::Test {
protected:
    void SetUp()
    {
-        spdlog::set_level(spdlog::level::debug);
        // set up the EmbInfo used by the tests
        embInfo.extEmbeddingSize = embTable.TEST_EMB_SIZE;
-        spdlog::info("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}",
-            embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT);
+        LOG(INFO) << StringFormat(
+            "EmbTable BLOCK_EMB_COUNT %d INIT_BLOCK_COUNT %d", embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT);
        rankInfo.rankId = 0;
        rankInfo.rankSize = 1;
        rankInfo.localRankSize = 1;
@@ -40,7 +38,7 @@ protected:
        rankInfo.deviceId = 0;
        // initialize the EmbeddingTable
#ifndef GTEST
-        spdlog::info("rank {} running", rankInfo.deviceId);
+        LOG(INFO) << StringFormat("rank %d running", rankInfo.deviceId);
        aclInit(nullptr);
#endif
    }
@@ -60,14 +58,15 @@ TEST_F(EmbTableTest, Init)
#ifndef GTEST
    // check that initialization does not throw
    EXPECT_NO_THROW(embTable.Init(embInfo, rankInfo, 0));
-    spdlog::info("embTable Init succeed!");
+    LOG(INFO) << "embTable Init succeed!";
    ASSERT_EQ(embTable.rankInfo.rankId, rankInfo.rankId);
    ASSERT_EQ(embTable.rankInfo.rankSize, rankInfo.rankSize);
    ASSERT_EQ(embTable.rankInfo.localRankSize, rankInfo.localRankSize);
    ASSERT_EQ(embTable.rankInfo.useStatic, rankInfo.useStatic);
    ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId);
    // check that the capacity is correct
-    spdlog::info("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
+    LOG(INFO) << StringFormat(
+        "totalCapacity %d, INIT_BLOCK_COUNT %d", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
    EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT);
#endif
}
diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp
index 1f0c9673..dc7baf29 100644
@@ -7,7 +7,6 @@
 */

 #include
-#include
 #include "host_emb/host_emb.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index db2c7aee..99ffce4d 100644
@@ -8,7 +8,6 @@

 #include
 #include
-#include
 #include "initializer/initializer.h"
 #include "initializer/constant_initializer/constant_initializer.h"
diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp
index 4dd721a0..31bfdd1e 100644
--- a/src/tests/key_process/feature_admit_and_evict_test.cpp
+++ 
b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -12,8 +12,6 @@ #include #include #include -#include -#include #include "utils/common.h" #include "key_process/feature_admit_and_evict.h" @@ -149,7 +147,7 @@ protected: batch->name = embName; batch->timestamp = ts; printf("\n"); - spdlog::info("current admit embName[{}] at time[{}] ...", embName.c_str(), ts); + LOG(INFO) << StringFormat("current admit embName[%s] at time[%d] ...", embName.c_str(), ts); // 校验调接口不出错 ASSERT_EQ(faae.FeatureAdmit(channel, batch, args.keys, args.cnt) != @@ -172,7 +170,7 @@ protected: batch->name = embName; batch->timestamp = ts; printf("\n"); - spdlog::info("current admit embName[{}] at time[{}] ...", embName.c_str(), ts); + LOG(INFO) << StringFormat("current admit embName[%s] at time[%d] ...", embName.c_str(), ts); // 校验调接口不出错 faae.FeatureAdmit(channel, batch, args.keys, args.cnt); @@ -233,7 +231,7 @@ protected: void StartEvictThread() { evictThr = std::thread([&]() { - spdlog::info("Evict-thread start ..."); + LOG(INFO) << "Evict-thread start ..."; time_t currTime = 0; time_t lastTime = 0; @@ -241,13 +239,13 @@ protected: std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); currTime = time(nullptr); if (currTime - lastTime >= SleepTime::SLEEP_SECOND_4) { - spdlog::info("Evict-thread doing at currTime[{}] ...", currTime); + LOG(INFO) << StringFormat("Evict-thread doing at currTime[%d] ...", currTime); map> evictPosMap {}; faae.FeatureEvict(evictPosMap); lastTime = currTime; } } - spdlog::info("Evict-thread exit ..."); + LOG(INFO) << "Evict-thread exit ..."; }); } void WaitEvictThread() @@ -319,7 +317,7 @@ protected: FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args5); WaitEvictThread(); - spdlog::info("TestCase1(): single thread test over ..."); + LOG(INFO) << "TestCase1: single thread test over ..."; } // 进行“准入”逻辑时,若(splitKey.size() != keyCount.size()),则业务报错退出;(说明是前面all2all通信数据错误) @@ -339,7 +337,7 @@ protected: // 校验调接口,出错 ASSERT_EQ(faae.FeatureAdmit(0, batch, tmpKeys, tmpCnt) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR, true); - spdlog::info("TestCase2() over ..."); + LOG(INFO) << "TestCase2 over ..."; } // 准入、淘汰阈值可单独配置;只配置“准入”阈值、却不配置“淘汰”阈值,功能正常; @@ -413,7 +411,7 @@ protected: } WaitEvictThread(); - spdlog::info("TestCase5(): multi thread test over ..."); + LOG(INFO) << "TestCase5: multi thread test over ..."; } // 同时不配置“准入、淘汰”阈值,特征准入&淘汰功能“不支持”; @@ -428,7 +426,7 @@ protected: batch->timestamp = time(nullptr); // 校验调接口,不支持 - spdlog::info("TestCase6() over ..."); + LOG(INFO) << "TestCase6 over ..."; } bool isExitFlag { false }; diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 643d6d96..668c5c33 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -7,18 +7,13 @@ */ #include -#include #include #include #include -#include -#include #include "utils/common.h" -#include "host_emb/host_emb.h" #include "key_process/key_process.h" -#include "hybrid_mgmt/hybrid_mgmt.h" #include "ock_ctr_common/include/unique.h" using namespace std; @@ -26,7 +21,6 @@ using namespace MxRec; using namespace testing; static constexpr size_t BATCH_NUM_EACH_THREAD = 3; -static constexpr int DESIRED_SIZE = 1; FactoryPtr factory; class SimpleThreadPool { @@ -47,7 +41,7 @@ static void CTRLog(int level, const char *msg) { switch (level) { case 0: - spdlog::debug("{}", msg); + VLOG(GLOG_DEBUG) << StringFormat("%s", msg); break; default: break; @@ -58,13 +52,12 @@ class KeyProcessTest : 
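// ==== editor's sketch (not part of the patch) ====
// StringFormat is printf-style and type-unchecked, so call sites must pass
// std::string through .c_str() and containers through VectorToString /
// MapToString; "%d" with a char* (as in a few conversions in the surrounding
// tests) is undefined behavior. Correct shape, assuming the helpers from
// utils/common.h in this patch:
#include <string>
#include <vector>
#include <glog/logging.h>
#include "utils/common.h"

void FormatDemo()
{
    std::string name = "embA";
    std::vector<int> keys {1, 2, 3};
    LOG(INFO) << MxRec::StringFormat("table %s keys %s", name.c_str(), MxRec::VectorToString(keys).c_str());
}
// ==== end editor's sketch ====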
public testing::Test {
protected:
    void SetUp()
    {
-        spdlog::set_level(spdlog::level::debug);
        int claimed;
        MPI_Query_thread(&claimed);
        ASSERT_EQ(claimed, MPI_THREAD_MULTIPLE);
        MPI_Comm_rank(MPI_COMM_WORLD, &worldRank);
        MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
-        spdlog::info(KEY_PROCESS "wordRank: {}, worldSize: {}", worldRank, worldSize);
+        LOG(INFO) << StringFormat(KEY_PROCESS "worldRank: %d, worldSize: %d", worldRank, worldSize);
        // initialize rank info
        rankInfo.rankId = worldRank;
        rankInfo.rankSize = worldSize;
@@ -98,9 +91,11 @@ protected:
            batch->name = embInfos[i].name;
            batch->batchId = batchId;
            batch->channel = channel;
-            spdlog::debug("[{}/{}]"
-                KEY_PROCESS "PrepareBatch: batchQueueId: {}, {}[{}]{}, sampleSize:{}", worldRank, worldSize,
-                batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size());
+            VLOG(GLOG_DEBUG) << StringFormat(
+                "[%d/%d]" KEY_PROCESS "PrepareBatch: batchQueueId: %d, %s[%d]%d, sampleSize:%zu",
+                worldRank, worldSize,
+                batchQueueId, batch->name.c_str(), batch->channel, batch->batchId, batch->sample.size()
+            );
            emb_batch_t temp;
            temp.sample = batch->sample;
            temp.name = batch->name;
@@ -183,12 +178,18 @@ protected:
    {
        for (int i = 0; i < rankSize; ++i) {
            std::cout << "splitKeys dev" << i << std::endl;
-            spdlog::info("{}", splitKeys[i]);
+            if (g_glogLevel >= INFO) {
+                LOG(INFO) << StringFormat("%s", VectorToString(splitKeys[i]).c_str());
+            }
        }
        std::cout << "restore" << std::endl;
-        spdlog::info("{}", restore);
+        if (g_glogLevel >= INFO) {
+            LOG(INFO) << StringFormat("%s", VectorToString(restore).c_str());
+        }
        std::cout << "hotPos" << std::endl;
-        spdlog::info("{}", hotPos);
+        if (g_glogLevel >= INFO) {
+            LOG(INFO) << StringFormat("%s", VectorToString(hotPos).c_str());
+        }
    }

    void GetExpectRestore(keys_t& sample, vector& blockOffset, vector& restoreVec)
@@ -271,7 +272,9 @@ TEST_F(KeyProcessTest, HashSplit)
    vector expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 };
    vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } };
    batch->sample = std::move(batchKeys);
-    spdlog::debug(KEY_PROCESS "batch sample: {}", batch->sample);
+    if (VLOG_IS_ON(GLOG_DEBUG)) {
+        VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "batch sample: %s", VectorToString(batch->sample).c_str());
+    }
    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
    ASSERT_EQ(process.isRunning, true);
    process.rankInfo.rankSize = rankSize;
@@ -286,7 +289,11 @@ TEST_F(KeyProcessTest, GetScAll)
{
    vector keyScLocal(worldSize, worldRank + 1);  // initialize send counts with worldRank + 1
-    spdlog::debug(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, keyScLocal);
+    if (VLOG_IS_ON(GLOG_DEBUG)) {
+        VLOG(GLOG_DEBUG) << StringFormat(
+            KEY_PROCESS "rank %d keyScLocal: %s", worldRank, VectorToString(keyScLocal).c_str()
+        );
+    }
    vector expectScAll(worldSize * worldSize);
    for (unsigned int i = 0; i < expectScAll.size(); ++i) {
        expectScAll[i] = floor(i / worldSize) + 1;
@@ -302,7 +309,11 @@ TEST_F(KeyProcessTest, GetScAllForUnique)
{
    vector keyScLocal(worldSize, worldRank + 1);  // initialize send counts with worldRank + 1
-    spdlog::debug(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, keyScLocal);
+    if (VLOG_IS_ON(GLOG_DEBUG)) {
+        VLOG(GLOG_DEBUG) << StringFormat(
+            KEY_PROCESS "rank %d keyScLocal: %s", worldRank, VectorToString(keyScLocal).c_str()
+        );
+    }
    vector expectScAll(worldSize * worldSize);
    for (unsigned int i = 0; i < expectScAll.size(); ++i) {
        expectScAll[i] = floor(i / worldSize) + 1;
    }
@@ -328,11 +339,22 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu)
    { 5, 0, 6, 2, 
1, 3, 1, 7, 4, 8 }, { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } }; batch->sample = std::move(allBatchKeys[worldRank]); - spdlog::info(KEY_PROCESS "test BuildRestoreVec: rank {}, batchKeys {}", worldRank, batch->sample); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + KEY_PROCESS "test BuildRestoreVec: rank %d, batchKeys %s", + worldRank, VectorToString(batch->sample).c_str() + ); + } ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); auto [splitKeys, restore] = process.HashSplit(batch); - spdlog::debug("rank: {} splitKeys: {}", worldRank, splitKeys); + if (VLOG_IS_ON(GLOG_DEBUG)) { + vector tmp; + for (const auto& i : splitKeys) { + tmp.emplace_back(VectorToString(i)); + } + VLOG(GLOG_DEBUG) << StringFormat("rank: %d splitKeys: %s", worldRank, VectorToString(tmp).c_str()); + } process.BuildRestoreVec(batch, allExpectSs[worldRank], restore); ASSERT_THAT(restore, ElementsAreArray(allExpectRestore[worldRank])); } @@ -341,7 +363,7 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) { PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores + LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; @@ -351,10 +373,13 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) vector hotPos; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - spdlog::info("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, - hotPos); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() + ); + } }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { @@ -370,7 +395,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) { PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores + LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; vector hotPos; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); process.BuildRestoreVec(batch, ss, restore, hotPos.size()); - spdlog::info("rankid :{},batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", rankInfo.rankId, - batch->batchId, lookupKeys, scAll, restore); + if (g_glogLevel >= INFO) { + LOG(INFO) << StringFormat( + "rankid :%d,batchid: %d, lookupKeys: %s, scAll: %s, restore after build %s", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys).c_str(), + VectorToString(scAll).c_str(), VectorToString(restore).c_str() + ); + } }; // for clean code for (int channel = 0; channel <
1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { @@ -403,13 +433,27 @@ TEST_F(KeyProcessTest, Key2Offset) ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID); - spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap); +// map> keyOffsetMap {}; + map tmp; + for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { + tmp.insert(pair(it->first, MapToString(it->second).c_str())); + } + + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", + VectorToString(lookupKeys).c_str(), MapToString(tmp).c_str()); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); keys_t lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 }; keys_t expectOffset2 = { -1, -1, -1, -1, -1, -1, -1, -1 }; process.Key2Offset("emb0", lookupKeys2, EVAL_CHANNEL_ID); - spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys2, process.keyOffsetMap); + map tmp2; + for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { + tmp2.insert(pair(it->first, MapToString(it->second).c_str())); + } + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", VectorToString(lookupKeys2).c_str(), + MapToString(tmp2).c_str()); ASSERT_THAT(lookupKeys2, ElementsAreArray(expectOffset2)); } @@ -420,7 +464,17 @@ TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2OffsetDynamicExpansion("emb0", lookupKeys, EVAL_CHANNEL_ID); - spdlog::debug(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", lookupKeys, process.keyOffsetMap); + if (VLOG_IS_ON(GLOG_DEBUG)) { + map tmp; + for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { + tmp.insert(pair(it->first, MapToString(it->second).c_str())); + } + + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", + VectorToString(lookupKeys).c_str(), MapToString(tmp).c_str() + ); + } ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); } @@ -446,7 +500,7 @@ TEST_F(KeyProcessTest, ProcessPrefetchTask) ASSERT_EQ(process.Start(), 0); // called after all threads have finished processing (i.e. training is done) this_thread::sleep_for(5s); - spdlog::info("wait 20s for thread running"); + LOG(INFO) << "wait 20s for thread running"; this_thread::sleep_for(20s); process.Destroy(); } @@ -483,7 +537,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - spdlog::info("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores + LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // check the number of CPU cores auto fn = [this](int channel, int id) { UniquePtr unique; @@ -502,10 +556,11 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) process.GetUniqueConfig(uniqueConf); unique->Initialize(uniqueConf); - spdlog::info("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); process.KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id); - spdlog::info("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, - hotPos); + LOG(INFO) << StringFormat( + "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() +
); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { @@ -515,4 +570,20 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) } this_thread::sleep_for(20s); process.Destroy(); +} + +TEST(KeyProcess, SetupHotEmbUpdateStep) +{ + KeyProcess kp; + + kp.SetupHotEmbUpdateStep(); + ASSERT_EQ(kp.hotEmbUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); + + putenv("HOT_EMB_UPDATE_STEP=1"); + kp.SetupHotEmbUpdateStep(); + ASSERT_EQ(kp.hotEmbUpdateStep, 1); + + putenv("HOT_EMB_UPDATE_STEP=0"); + kp.SetupHotEmbUpdateStep(); + ASSERT_EQ(kp.hotEmbUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); } \ No newline at end of file diff --git a/src/tests/utils/common_test.cpp b/src/tests/utils/common_test.cpp new file mode 100644 index 00000000..90d60c3f --- /dev/null +++ b/src/tests/utils/common_test.cpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. + * Description: key process test + * Author: MindX SDK + * Create: 2022 + * History: NA + */ + +#include +#include +#include + +#include "utils/common.h" + +using namespace std; +using namespace MxRec; +using namespace testing; + +TEST(common, SetLog) +{ + SetLog(); + ASSERT_EQ(g_glogLevel, 0); + + putenv("GLOG_stderrthreshold=1"); + SetLog(); + ASSERT_EQ(g_glogLevel, 1); +} + +TEST(common, InitializeInfo) +{ + NormalInitializerInfo nInfoTruncatedNormal; + string nameTruncatedNormal = "truncated_normal_initializer"; + InitializeInfo iInfo = InitializeInfo(nameTruncatedNormal, 0, 1, nInfoTruncatedNormal); + ASSERT_EQ(iInfo.initializerType, InitializerType::TRUNCATED_NORMAL); + + NormalInitializerInfo nInfoRandomNormal; + string nameRandomNormal = "random_normal_initializer"; + iInfo = InitializeInfo(nameRandomNormal, 0, 1, nInfoRandomNormal); + ASSERT_EQ(iInfo.initializerType, InitializerType::RANDOM_NORMAL); + + NormalInitializerInfo nInfoInvalid; + string nameInvalid = "x"; + bool isExceptionThrow = { false }; + try { + iInfo = InitializeInfo(nameInvalid, 0, 1, nInfoInvalid); + } catch (const std::invalid_argument& e) { + isExceptionThrow = true; + } + ASSERT_EQ(isExceptionThrow, true); +} \ No newline at end of file -- Gitee From 2b9df4c2ae3e20ba8a39185b7606bb0d1ade208f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 15 Jul 2023 17:47:19 +0800 Subject: [PATCH 214/551] Match-id-b0e229b7ff945cfb28737d6ee672ca32217bcc0b --- mx_rec/constants/constants.py | 2 + mx_rec/core/embedding.py | 155 ++++++++++++++++--------------- mx_rec/{util => saver}/sparse.py | 3 +- mx_rec/util/atomic.py | 7 +- mx_rec/validator/validator.py | 3 + 5 files changed, 93 insertions(+), 77 deletions(-) rename mx_rec/{util => saver}/sparse.py (98%) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 57b7800c..e2c38e8d 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -46,6 +46,8 @@ MAX_RANK_SIZE = 4095 MIN_DEVICE_NUM = 1 MIN_RANK_SIZE = 1 +LOG_MAX_SIZE = 1024 * 1024 + MAX_INT32 = np.iinfo(np.int32).max diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 48c0fce8..02473a83 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -242,6 +242,42 @@ class SparseEmbedding: optimizer.insert_slot(slot, named_slot_key, slot_name) + @staticmethod + def _get_own_emb(emb, all2all_args, emb_size, use_static): + """ + obtain embedding of source data + :param emb: origin embeddding + :param all2all_args: dynamic shape condition parameters + :param emb_size: size of embedding table + :param use_static: enable static shape 
training or not + :return: local embedding after all2all + """ + from mx_rec.util.tf_version_adapter import hccl_ops + rank_size = get_rank_size() + rank_id = get_rank_id() + + src_emb = emb + + reshape_info = [all2all_args * rank_size, emb_size] if use_static else [-1, emb_size] + + if rank_size == 1 and use_static: + return tf.reshape(src_emb, reshape_info) + + if use_static: + emb_send_cnt = tf.constant([all2all_args * emb_size] * rank_size, dtype=tf.int64) + emb_send_offset = tf.constant([all2all_args * emb_size * i for i in range(rank_size)], dtype=tf.int64) + src_emb = hccl_ops.all_to_all_v(send_data=emb, + send_counts=emb_send_cnt, + send_displacements=emb_send_offset, + recv_counts=emb_send_cnt, + recv_displacements=emb_send_offset) + else: + src_emb = hccl_ops.all_to_all_v_c(send_data=emb, + send_count_matrix=all2all_args, + rank=rank_id) + + return tf.reshape(src_emb, reshape_info) + def check_optimizer_instance(self): for optimizer_instance in self._optimizer_instance_list: if tf.__version__.startswith("1"): @@ -482,7 +518,7 @@ class SparseEmbedding: if not use_static: # In the case of multiple lookups of a table, the all2all_matrix does not run the 'getnext' op # to obtain the actual value. Instead, the initial value is 1. So it needs to be multiplied by - # 'self.scalar_emb_size' to ensure the correctness of the 'Reshape' op in the get_own_emb function. + # 'self.scalar_emb_size' to ensure the correctness of the 'Reshape' op in the _get_own_emb function. all2all_matrix = tf.ones(shape=[rank_size, rank_size], dtype=tf.int64, name="all2all_matrix") * self.scalar_emb_size all2all_matrix = tf.identity(all2all_matrix, name=ASCAnchorAttr.ALL2ALL_MATRIX.value) @@ -508,7 +544,7 @@ class SparseEmbedding: else: local_emb = tf.identity(table, name="identity_local_emb") all2all_args = send_count if use_static else all2all_matrix - unique_embeddings = get_own_emb(local_emb, all2all_args, self.scalar_emb_size, use_static) + unique_embeddings = self._get_own_emb(local_emb, all2all_args, self.scalar_emb_size, use_static) if hot_pos is not None: unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), @@ -529,8 +565,8 @@ class SparseEmbedding: def grad(lookup_diff): embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static \ - else tf.shape(unique_embeddings) + unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static else tf.shape( + unique_embeddings) unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, unique_embeddings_shape[0]) bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) @@ -538,7 +574,7 @@ class SparseEmbedding: hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) - local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) + local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: local_grad = local_grad / get_rank_size() @@ -723,7 +759,7 @@ class SparseEmbedding: local_embeddings = tf.identity(table, name="identity_local_emb") all2all_args = send_count if use_static else all2all_matrix - unique_embeddings = get_own_emb(local_embeddings, all2all_args, self.scalar_emb_size, use_static) + 
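The static-shape branch of _get_own_emb above sends a fixed all2all_args * emb_size elements to every peer, places each peer's slice at consecutive multiples of that count, and reshapes the received buffer back to [all2all_args * rank_size, emb_size]. Below is a minimal numpy sketch of just that count/displacement arithmetic; the world size, per-peer key count, and embedding dimension are made-up illustration values, and numpy stands in for the HCCL collective, so this is a shape check rather than the real exchange.

import numpy as np

# Illustrative stand-ins for get_rank_size(), all2all_args, and self.scalar_emb_size.
rank_size = 4
send_count = 3          # keys exchanged with each peer in the static path
emb_size = 8            # embedding dimension

elems_per_peer = send_count * emb_size
# Mirrors emb_send_cnt / emb_send_offset as built for hccl_ops.all_to_all_v.
send_counts = np.full(rank_size, elems_per_peer, dtype=np.int64)
send_displacements = np.arange(rank_size, dtype=np.int64) * elems_per_peer

# After the exchange every rank holds rank_size * send_count rows again.
flat_recv = np.zeros(rank_size * elems_per_peer, dtype=np.float32)
own_emb = flat_recv.reshape(send_count * rank_size, emb_size)
print(send_counts, send_displacements, own_emb.shape)

Because every rank exchanges the same element count in this path, counts and displacements can be precomputed as constants, which is what makes it cheaper than the dynamic all_to_all_v_c variant driven by the all2all_matrix.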
unique_embeddings = self._get_own_emb(local_embeddings, all2all_args, self.scalar_emb_size, use_static) if hot_pos is not None: unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), @@ -756,7 +792,7 @@ class SparseEmbedding: hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) - local_grad = get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) + local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: local_grad = local_grad / get_rank_size() @@ -829,70 +865,39 @@ class SparseEmbedding: self.set_optimizer_slot(slot_info) -def get_own_ids(unique_ids, origin_id_lens, send_cnt, self): - from mx_rec.util.tf_version_adapter import hccl_ops - rank_size = get_rank_size() - if rank_size > 1: - ids_send_cnt = tf.constant([send_cnt] * rank_size, dtype=tf.int64) - ids_send_offset = tf.constant([send_cnt * i for i in range(rank_size)], dtype=tf.int64) - own_ids = hccl_ops.all_to_all_v(send_data=unique_ids, - send_counts=ids_send_cnt, - send_displacements=ids_send_offset, - recv_counts=ids_send_cnt, - recv_displacements=ids_send_offset) - - lens_sc = tf.constant([1] * rank_size, dtype=tf.int64) - lens_sd = tf.constant([i for i in range(rank_size)], dtype=tf.int64) - local_id_lens = hccl_ops.all_to_all_v(send_data=origin_id_lens, - send_counts=lens_sc, - send_displacements=lens_sd, - recv_counts=lens_sc, - recv_displacements=lens_sd) - - else: - own_ids = unique_ids - local_id_lens = origin_id_lens - - def feature_mapping(): - self.set_using_feature_mapping() - id_offsets = SparseEmbedding.customized_ops.feature_mapping(own_ids, table_name=self.table_name) - return id_offsets - - id_offsets = feature_mapping() - id_offsets.set_shape([send_cnt * rank_size]) - - return id_offsets, local_id_lens - - -def get_own_emb(emb, all2all_args, emb_size, use_static): - ''' - obtain embedding of source data - ''' - from mx_rec.util.tf_version_adapter import hccl_ops - rank_size = get_rank_size() - rank_id = get_rank_id() - - src_emb = emb - - reshape_info = [all2all_args * rank_size, emb_size] if use_static else [-1, emb_size] - - if rank_size == 1 and use_static: - return tf.reshape(src_emb, reshape_info) - - if use_static: - emb_send_cnt = tf.constant([all2all_args * emb_size] * rank_size, dtype=tf.int64) - emb_send_offset = tf.constant([all2all_args * emb_size * i for i in range(rank_size)], dtype=tf.int64) - src_emb = hccl_ops.all_to_all_v(send_data=emb, - send_counts=emb_send_cnt, - send_displacements=emb_send_offset, - recv_counts=emb_send_cnt, - recv_displacements=emb_send_offset) - else: - src_emb = hccl_ops.all_to_all_v_c(send_data=emb, - send_count_matrix=all2all_args, - rank=rank_id) - - return tf.reshape(src_emb, reshape_info) +# def get_own_ids(unique_ids, origin_id_lens, send_cnt, self): +# from mx_rec.util.tf_version_adapter import hccl_ops +# rank_size = get_rank_size() +# if rank_size > 1: +# ids_send_cnt = tf.constant([send_cnt] * rank_size, dtype=tf.int64) +# ids_send_offset = tf.constant([send_cnt * i for i in range(rank_size)], dtype=tf.int64) +# own_ids = hccl_ops.all_to_all_v(send_data=unique_ids, +# send_counts=ids_send_cnt, +# send_displacements=ids_send_offset, +# recv_counts=ids_send_cnt, +# recv_displacements=ids_send_offset) +# +# lens_sc = tf.constant([1] * 
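In the backward closure above, restore_vector maps every looked-up position back to its unique embedding row, so unsorted_segment_sum folds the gradients of duplicate positions onto a single row before they are exchanged. A toy numpy restatement of that aggregation, where np.add.at plays the role of tf.compat.v1.unsorted_segment_sum and the shapes are purely illustrative:

import numpy as np

# Toy shapes: 6 looked-up positions collapse onto 4 unique embedding rows.
emb_size = 2
restore_vector = np.array([0, 2, 1, 0, 3, 2])   # position -> unique row
lookup_diff = np.arange(6 * emb_size, dtype=np.float32).reshape(6, emb_size)

# Positions that share a unique row have their gradients accumulated into it,
# exactly the segment-sum step performed before the all-to-all of gradients.
unique_grads = np.zeros((4, emb_size), dtype=np.float32)
np.add.at(unique_grads, restore_vector, lookup_diff)
print(unique_grads)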
rank_size, dtype=tf.int64) +# lens_sd = tf.constant([i for i in range(rank_size)], dtype=tf.int64) +# local_id_lens = hccl_ops.all_to_all_v(send_data=origin_id_lens, +# send_counts=lens_sc, +# send_displacements=lens_sd, +# recv_counts=lens_sc, +# recv_displacements=lens_sd) +# +# else: +# own_ids = unique_ids +# local_id_lens = origin_id_lens +# +# def feature_mapping(): +# self.set_using_feature_mapping() +# id_offsets = SparseEmbedding.customized_ops.feature_mapping(own_ids, table_name=self.table_name) +# return id_offsets +# +# id_offsets = feature_mapping() +# id_offsets.set_shape([send_cnt * rank_size]) +# +# return id_offsets, local_id_lens class _EvictHook(tf.compat.v1.train.SessionRunHook): @@ -995,9 +1000,13 @@ def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Opti """ if access_threshold is None or access_threshold <= 0: return embeddings - id_offsets_expand = tf.expand_dims(id_offsets >= 0, axis=-1) + if tf.__version__.startswith("1"): - id_offsets_expand = tf.repeat(id_offsets_expand, [tf.shape(embeddings)[-1]], axis=-1) + id_offsets_expand = tf.math.greater_equal(id_offsets, 0) + embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like()) + return embeddings + + id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1) embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings)) return embeddings diff --git a/mx_rec/util/sparse.py b/mx_rec/saver/sparse.py similarity index 98% rename from mx_rec/util/sparse.py rename to mx_rec/saver/sparse.py index 67811cfa..1649bec4 100644 --- a/mx_rec/util/sparse.py +++ b/mx_rec/saver/sparse.py @@ -6,10 +6,9 @@ import logging import os import json -import tensorflow as tf import numpy as np -from mx_rec.util.initialize import get_table_instance, get_table_instance_by_name, export_table_name_set +from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set from mx_rec.validator.validator import FileValidator diff --git a/mx_rec/util/atomic.py b/mx_rec/util/atomic.py index bca4c660..4c7242dc 100644 --- a/mx_rec/util/atomic.py +++ b/mx_rec/util/atomic.py @@ -5,7 +5,10 @@ import threading -class AtomicInteger(): +class AtomicInteger: + """ + counter atomic increment/decrement + """ def __init__(self, value=0): self._value = int(value) self._lock = threading.Lock() @@ -19,7 +22,7 @@ class AtomicInteger(): return self._value def decrease(self, num=1): - return self.inc(-num) + return self.increase(-num) def value(self): with self._lock: diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 524e8381..1e2f45d0 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -252,6 +252,9 @@ class DirectoryValidator(StringValidator): class FileValidator(StringValidator): + """ + Check if file is valid. 
+ """ def __init__(self, value): """ @param value: the file path, should not be emtpy string, should not contain double dot(../) -- Gitee From c59690a21ba2417a3a1f4572724b1db37b76f34e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 16 Jul 2023 14:07:08 +0800 Subject: [PATCH 215/551] Match-id-1347d121bc4d5b1a87309aa06fd50dd26576c0ca --- mx_rec/constants/constants.py | 1 + mx_rec/util/initialize.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 57b7800c..ed0d65f8 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -40,6 +40,7 @@ HASHTABLE_COLLECTION_NAME_LENGTH = 30 # RANK INFO VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] MIN_SIZE = 1 +MAX_CONFIG_SIZE = 10 * 1024 * 1024 MAX_SIZE = 1024 * 1024 * 1024 * 1024 MAX_DEVICE_NUM = 16 MAX_RANK_SIZE = 4095 diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 0c4d13f0..e7a180cf 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -13,7 +13,7 @@ import psutil import mx_rec.constants.constants from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \ MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH,\ - TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID + TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import RankInfoValidator, StringValidator, FileValidator from mx_rec.util.atomic import AtomicInteger @@ -224,7 +224,7 @@ class ConfigInitializer: # 1.check whether rank_table_path is soft link file_validator.check_not_soft_link() # 2.check json file size - file_validator.check_file_size(file) + file_validator.check_file_size(file, MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() table_hccl = json.load(file) @@ -816,7 +816,7 @@ def get_available_cpu_num_and_range(): # 1.check whether f_path is soft link file_validator.check_not_soft_link() # 2.check file size - file_validator.check_file_size(f_in) + file_validator.check_file_size(f_in, MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() pkg_id = f_in.readline().strip() pkg_id2cpu_list[pkg_id].append(cpu) -- Gitee From a45558486abdb65ba728d731186fbf83ee9f50e0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 16 Jul 2023 18:10:11 +0800 Subject: [PATCH 216/551] Match-id-e4ee26549baa13231976b61415d454b4c43224d9 --- src/tests/key_process/key_process_test.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 668c5c33..83cb9cee 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -433,7 +433,6 @@ TEST_F(KeyProcessTest, Key2Offset) ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID); -// map> keyOffsetMap {}; map tmp; for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { tmp.insert(pair(it->first, MapToString(it->second).c_str())); -- Gitee From 42b467cbbf9bd6729c8369cbbc3465e3efa40ac5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 15 Jul 2023 18:09:59 +0800 Subject: [PATCH 217/551] Match-id-62e94faa4ee22bf9fd935ea4a0956b13bdb97499 --- mx_rec/core/asc/build_graph.py | 2 +- mx_rec/core/embedding.py | 144 
+++--------------------------- mx_rec/core/feature_process.py | 100 +++++++++++++++++++++ mx_rec/graph/modifier.py | 12 +-- mx_rec/logger/log.py | 74 +++++++++++++++ mx_rec/logger/logger.yaml | 18 ++++ mx_rec/optimizers/ftrl_t_dense.py | 11 +-- mx_rec/saver/patch.py | 5 +- mx_rec/saver/saver.py | 6 +- mx_rec/saver/sparse.py | 4 +- 10 files changed, 222 insertions(+), 154 deletions(-) create mode 100644 mx_rec/core/feature_process.py create mode 100644 mx_rec/logger/log.py create mode 100644 mx_rec/logger/logger.yaml diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 81bb27a8..c25eb6be 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -35,7 +35,7 @@ def get_restore_vector(config): restore_size = None with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if use_hot: + if use_hot and emb_size: device_id = int(config.get("device_id")) hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 02473a83..183640b5 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -4,7 +4,6 @@ import logging import math -import time from collections import defaultdict from typing import Optional @@ -21,14 +20,13 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, \ + MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, \ ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT, All2allGradientsOp, ApplyGradientsStrategy from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ - clear_channel, trigger_evict, get_table_instance_by_name, get_use_hot, get_device_id, export_feature_spec, \ - ConfigInitializer, get_ascend_global_hashtable_collection, get_host_pipeline_ops, get_use_dynamic_expansion, \ + clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ + get_host_pipeline_ops, get_use_dynamic_expansion, \ set_modify_graph, insert_removing_var_list -from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.validator.validator import ClassValidator, StringValidator @@ -576,7 +574,10 @@ class SparseEmbedding: unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: - local_grad = local_grad / get_rank_size() + try: + local_grad = local_grad / get_rank_size() + except ZeroDivisionError as exp: + raise ZeroDivisionError("Rank size cannot be zero.") from exp if use_dynamic_expansion: return local_grad, feat_ids @@ -794,7 +795,10 @@ class SparseEmbedding: unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == 
All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: - local_grad = local_grad / get_rank_size() + try: + local_grad = local_grad / get_rank_size() + except ZeroDivisionError as exp: + raise ZeroDivisionError("Rank size cannot be zero.") from exp if use_dynamic_expansion: update_grad = local_grad @@ -865,130 +869,6 @@ class SparseEmbedding: self.set_optimizer_slot(slot_info) -# def get_own_ids(unique_ids, origin_id_lens, send_cnt, self): -# from mx_rec.util.tf_version_adapter import hccl_ops -# rank_size = get_rank_size() -# if rank_size > 1: -# ids_send_cnt = tf.constant([send_cnt] * rank_size, dtype=tf.int64) -# ids_send_offset = tf.constant([send_cnt * i for i in range(rank_size)], dtype=tf.int64) -# own_ids = hccl_ops.all_to_all_v(send_data=unique_ids, -# send_counts=ids_send_cnt, -# send_displacements=ids_send_offset, -# recv_counts=ids_send_cnt, -# recv_displacements=ids_send_offset) -# -# lens_sc = tf.constant([1] * rank_size, dtype=tf.int64) -# lens_sd = tf.constant([i for i in range(rank_size)], dtype=tf.int64) -# local_id_lens = hccl_ops.all_to_all_v(send_data=origin_id_lens, -# send_counts=lens_sc, -# send_displacements=lens_sd, -# recv_counts=lens_sc, -# recv_displacements=lens_sd) -# -# else: -# own_ids = unique_ids -# local_id_lens = origin_id_lens -# -# def feature_mapping(): -# self.set_using_feature_mapping() -# id_offsets = SparseEmbedding.customized_ops.feature_mapping(own_ids, table_name=self.table_name) -# return id_offsets -# -# id_offsets = feature_mapping() -# id_offsets.set_shape([send_cnt * rank_size]) -# -# return id_offsets, local_id_lens - - -class _EvictHook(tf.compat.v1.train.SessionRunHook): - """Sets evict based on global step or time.""" - - def __init__(self, - evict_enable=False, - evict_time_interval=DEFAULT_EVICT_TIME_INTERVAL, - evict_step_interval=None): - self._evict_enable = evict_enable - self._evict_time_interval = evict_time_interval - self._evict_step_interval = evict_step_interval - self._hash_table_instance = dict() - self._start_time = time.time() - self._global_step = 0 - self._evict_op = dict() - self._global_step_tensor = None - - self.check_evict_init_params() - logging.info(f"_EvictHook - > evict_time_interval: {self._evict_time_interval}, " - f"evict_step_interval: {self._evict_step_interval}") - - def begin(self): - self._global_step_tensor = tf.compat.v1.train.get_or_create_global_step() - if self._global_step_tensor is None: - raise RuntimeError("Global step should be created to use _EvictHook.") - self.check_name_and_get_hashtable() - for name, instance in self._hash_table_instance.items(): - scope_name = "{0}//{1}".format(instance.table_name, "evict") - with tf.compat.v1.variable_scope(scope_name): - logging.debug(f'Channel {instance.table_name}_evict_{TRAIN_CHANNEL_ID} was built for op ' - f'getnext') - - evict_pos, evict_len = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32, tf.int32], - output_shapes=[[None], []], - channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}') - - initialized_tensor = instance.emb_initializer( - instance.slice_device_vocabulary_size + instance.embedding_size) * instance.init_param - - initialized_tensor = initialized_tensor[0:evict_len, :] - - logging.debug(f'evict_pos output shape {evict_pos}, and slice_device_vocabulary_size ' - f'{instance.slice_device_vocabulary_size}, ' - f'initialized_tensor shape: {initialized_tensor}') - - nd_evict_pos = tf.expand_dims(evict_pos, 1) - self._evict_op[name] = tf.compat.v1.scatter_nd_update(instance.variable, nd_evict_pos, - initialized_tensor) - 
- def after_create_session(self, session, coord): - self._global_step = session.run(self._global_step_tensor) - logging.debug(f"_EvictHook - > after_create_session, step: {self._global_step}") - - def after_run(self, run_context, run_values): - if not self._evict_enable: - return - - self._global_step = run_context.session.run(self._global_step_tensor) - cur_time = time.time() - if cur_time - self._start_time > self._evict_time_interval or \ - (self._evict_step_interval is not None and self._global_step % self._evict_step_interval == 0): - logging.info(f"_EvictHook - > evict switch on!!! after_run step: {self._global_step}") - if not trigger_evict(): - return - self._start_time = cur_time - for name in self._hash_table_instance.keys(): - run_context.session.run(self._evict_op.get(name)) - - def check_name_and_get_hashtable(self): - for _, feature_spec in export_feature_spec().items(): - if feature_spec.eviction_threshold: - logging.debug(f"_EvictHook - > check and get instance: table_names {feature_spec.table_name}") - self._hash_table_instance[feature_spec.table_name] = get_table_instance_by_name(feature_spec.table_name) - - def check_evict_init_params(self): - def check_type(arg, n_type, param_name): - if not isinstance(arg, n_type): - raise TypeError(f"{param_name} should be type '{n_type}', whose value is {arg} with type " - f"'{type(arg)}' in fact.") - if type(arg) == int and arg < 1: - raise ValueError(f"{param_name} should be bigger than 0, whose value is {arg} in fact") - - check_type(self._evict_enable, bool, "evict_enable") - if self._evict_time_interval is not None: - check_type(self._evict_time_interval, int, "evict_time_interval") - if self._evict_step_interval is not None: - check_type(self._evict_step_interval, int, "evict_time_interval") - - def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor], access_threshold: bool): """ @@ -1003,7 +883,7 @@ def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Opti if tf.__version__.startswith("1"): id_offsets_expand = tf.math.greater_equal(id_offsets, 0) - embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like()) + embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings)) return embeddings id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1) diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py new file mode 100644 index 00000000..489a2946 --- /dev/null +++ b/mx_rec/core/feature_process.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ +import logging +import time + +import tensorflow as tf + +from mx_rec.constants.constants import DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID +from mx_rec.util.initialize import trigger_evict, get_table_instance_by_name, export_feature_spec + + +class _EvictHook(tf.compat.v1.train.SessionRunHook): + """Sets evict based on global step or time.""" + + def __init__(self, + evict_enable=False, + evict_time_interval=DEFAULT_EVICT_TIME_INTERVAL, + evict_step_interval=None): + self._evict_enable = evict_enable + self._evict_time_interval = evict_time_interval + self._evict_step_interval = evict_step_interval + self._hash_table_instance = dict() + self._start_time = time.time() + self._global_step = 0 + self._evict_op = dict() + self._global_step_tensor = None + + self.check_evict_init_params() + logging.info(f"_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", self._evict_time_interval, + self._evict_step_interval) + + def begin(self): + self._global_step_tensor = tf.compat.v1.train.get_or_create_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step should be created to use _EvictHook.") + self.check_name_and_get_hashtable() + for name, instance in self._hash_table_instance.items(): + scope_name = f"{instance.table_name}//evict" + with tf.compat.v1.variable_scope(scope_name): + logging.debug('Channel %s_evict_%d was built for op getnext', instance.table_name, TRAIN_CHANNEL_ID) + + from mx_rec.util.tf_version_adapter import npu_ops + evict_pos, evict_len = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[[None], []], + channel_name=f'{instance.table_name}_evict_{TRAIN_CHANNEL_ID}') + + initialized_tensor = instance.emb_initializer( + instance.slice_device_vocabulary_size + instance.embedding_size) * instance.init_param + + initialized_tensor = initialized_tensor[0:evict_len, :] + + logging.debug( + 'evict_pos output shape %r, and slice_device_vocabulary_size %d, initialized_tensor shape: %r', + evict_pos, instance.slice_device_vocabulary_size, initialized_tensor) + + nd_evict_pos = tf.expand_dims(evict_pos, 1) + self._evict_op[name] = tf.compat.v1.scatter_nd_update(instance.variable, nd_evict_pos, + initialized_tensor) + + def after_create_session(self, session, coord): + self._global_step = session.run(self._global_step_tensor) + logging.debug("_EvictHook - > after_create_session, step: %d", self._global_step) + + def after_run(self, run_context, run_values): + if not self._evict_enable: + return + + self._global_step = run_context.session.run(self._global_step_tensor) + cur_time = time.time() + if cur_time - self._start_time > self._evict_time_interval or \ + (self._evict_step_interval is not None and self._global_step % self._evict_step_interval == 0): + logging.info("_EvictHook - > evict switch on!!! 
after_run step: %d", self._global_step) + if not trigger_evict(): + return + self._start_time = cur_time + for name in self._hash_table_instance.keys(): + run_context.session.run(self._evict_op.get(name)) + + def check_name_and_get_hashtable(self): + for _, feature_spec in export_feature_spec().items(): + if feature_spec.eviction_threshold: + logging.debug("_EvictHook - > check and get instance: table_names %s", feature_spec.table_name) + self._hash_table_instance[feature_spec.table_name] = get_table_instance_by_name(feature_spec.table_name) + + def check_evict_init_params(self): + def check_type(arg, n_type, param_name): + if not isinstance(arg, n_type): + raise TypeError(f"{param_name} should be type '{n_type}', whose value is {arg} with type " + f"'{type(arg)}' in fact.") + if type(arg) == int and arg < 1: + raise ValueError(f"{param_name} should be bigger than 0, whose value is {arg} in fact") + + check_type(self._evict_enable, bool, "evict_enable") + if self._evict_time_interval is not None: + check_type(self._evict_time_interval, int, "evict_time_interval") + if self._evict_step_interval is not None: + check_type(self._evict_step_interval, int, "evict_time_interval") diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 172e11bc..4095d750 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -230,7 +230,7 @@ def update_input_tensor_with_new_batch(replacement_specs, new_get_next_op_name): for idx, operator in item: old_tensor_name = old_tensor.name output_index = old_tensor_name.split(":")[-1] - new_tensor_name = "%s:%s" % (new_get_next_op_name, output_index) + new_tensor_name = f"{new_get_next_op_name}:{output_index}" new_tensor = graph.get_tensor_by_name(new_tensor_name) operator._update_input(idx, new_tensor) @@ -285,7 +285,7 @@ def generate_get_next_op_specs(cutting_point_list, dump_graph): get_next_op_map[get_next_op]["is_training"] = \ SparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING) - export_pb_graph("cut_graph_%s.pb" % get_next_op.name, dump_graph, graph_def=sub_graph_def) + export_pb_graph(f"cut_graph_{get_next_op.name}.pb", dump_graph, graph_def=sub_graph_def) return get_next_op_map @@ -341,7 +341,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): try: one_tensor = [v for _, v in new_batch.items()][0] except IndexError as err: - raise IndexError(f"Cannot find a tensor from given batch.") from err + raise IndexError("Cannot find a tensor from given batch.") from err new_get_next_op_name = find_target_dataset_op(one_tensor.op, "IteratorGetNext").name update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) @@ -386,7 +386,7 @@ def lookup_for_same_table(sub_cutting_point_list: list, is_training: bool): for one_feature_spec in same_table_feature_spec: feature_ids = feature_spec_ids_dict.get(one_feature_spec.name) if feature_ids is None: - raise RuntimeError(f"In the case of multiple lookups of a table, feature ids cannot be None.") + raise RuntimeError("In the case of multiple lookups of a table, feature ids cannot be None.") tensor_list.append(feature_ids) table_instance = SparseEmbedding.get_anchor_attribute(feature_ids, ASCAnchorAttr.TABLE_INSTANCE) @@ -408,7 +408,7 @@ def lookup_for_same_table(sub_cutting_point_list: list, is_training: bool): kwargs = {"multi_lookup": True, "is_train": is_training} if table_instance is None: - raise RuntimeError(f"In the case of multiple lookups of a table, table instance cannot be None.") + raise RuntimeError("In the case of multiple 
lookups of a table, table instance cannot be None.") lookup_result = table_instance.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, table_instance.same_table_send_count, **kwargs) @@ -486,7 +486,7 @@ def build_asc_graph(config: dict, table_instance: SparseEmbedding, cutting_point # In the case of multiple lookups of a table, replace the stub node of the lookup result in the graph if len(table_instance.lookup_name_list) > 1: if lookup_result is None: - raise RuntimeError(f"In the case of multiple lookups of a table, lookup result cannot be None.") + raise RuntimeError("In the case of multiple lookups of a table, lookup result cannot be None.") replace_anchor_vec(cutting_point, ASCAnchorAttr.LOOKUP_RESULT, lookup_result) logging.info(f"The lookup result corresponding to feature ids '{cutting_point}' has been replaced by " f"'{lookup_result}'.") diff --git a/mx_rec/logger/log.py b/mx_rec/logger/log.py new file mode 100644 index 00000000..e24a7efe --- /dev/null +++ b/mx_rec/logger/log.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2023 Huawei Technologies Co., Ltd + +import logging.config +import os +import yaml + +from mx_rec.constants.constants import MAX_SIZE, LOG_MAX_SIZE +from mx_rec.validator.validator import FileValidator +from mx_rec.validator.validator import DirectoryValidator + + +def init_sys_log(): + work_dir = os.path.dirname(os.path.dirname(__file__)) + log_cfg_file = os.path.join(work_dir, "logger.yaml") + real_config_path = os.path.realpath(log_cfg_file) + + if not FileValidator(log_cfg_file).check_file_size(real_config_path).check().is_valid(): + raise ValueError("Config file size is not valid.") + + with open(real_config_path, 'r', encoding='utf-8') as open_file: + if not FileValidator(real_config_path).\ + check_file_size(LOG_MAX_SIZE).\ + check_not_soft_link().\ + check_user_group().\ + is_valid(): + raise ValueError("Log config file is not valid.") + + data = open_file.read(LOG_MAX_SIZE) + log_cfg = yaml.safe_load(data) + + logging.config.dictConfig(log_cfg) + + +def init_log_dir_for_dt(log_cfg): + """Create log directory for local environment dt test. + + :param log_cfg: log configuration dictionary from yml file. 
+ :return: None + """ + handlers = log_cfg.get('handlers') + if not handlers: + return + + for handler_name in handlers: + handler_dict = handlers.get(handler_name) + log_file = handler_dict.get('filename') + + if not log_file: + continue + + log_file_standard = os.path.realpath(log_file) + if log_file_standard != log_file: + continue + + log_dir = os.path.dirname(log_file_standard) + if not DirectoryValidator(log_dir) \ + .check_is_not_none() \ + .check_dir_name() \ + .should_not_contains_sensitive_words() \ + .with_blacklist() \ + .check() \ + .is_valid(): + continue + + +init_sys_log() +srv_stream_log = logging.getLogger("logStream") +env_log_level = os.getenv("MXREC_LOG_LEVEL") +srv_log = srv_stream_log +if env_log_level: + srv_log.setLevel(env_log_level) + diff --git a/mx_rec/logger/logger.yaml b/mx_rec/logger/logger.yaml new file mode 100644 index 00000000..13eb3158 --- /dev/null +++ b/mx_rec/logger/logger.yaml @@ -0,0 +1,18 @@ +version: 1 +formatters: + simpleFmt: + format: '[%(asctime)s][%(levelname)s][%(message)s]' + wholeFmt: + format: '[%(asctime)s][%(levelname)s][%(message)s][%(filename)s, %(funcName)s:%(lineno)d][%(process)d, %(thread)d]' +handlers: + runStreamHandler: + class: logging.StreamHandler + level: INFO + formatter: wholeFmt + stream: ext://sys.stdout + +loggers: + logStream: + level: INFO + handlers: [runStreamHandler] + propagate: no \ No newline at end of file diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py index 4dbbdb50..271da078 100644 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ b/mx_rec/optimizers/ftrl_t_dense.py @@ -25,7 +25,6 @@ from mx_rec.util.variable import check_and_get_config_via_var def create_ftrl_dense_optimizer(learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): - return CustomizedFtrlTZ(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) @@ -169,10 +168,10 @@ class CustomizedFtrlTZ(optimizer.Optimizer): def _create_slots(self, var_list): # Create slots for the first and second moments. - z_state_name = self._name + "/" + "z" - n_state_name = self._name + "/" + "n" - g_state_name = self._name + "/" + "g" - w_state_name = self._name + "/" + "w" + z_state_name = f"{self._name}/z" + n_state_name = f"{self._name}/n" + g_state_name = f"{self._name}/g" + w_state_name = f"{self._name}/w" for each_var in var_list: with ops.colocate_with(each_var): z_zero = self._zeros_slot(each_var, "z", z_state_name) @@ -184,5 +183,3 @@ class CustomizedFtrlTZ(optimizer.Optimizer): insert_removing_var_list(n_zero.name) insert_removing_var_list(g_zero.name) insert_removing_var_list(w_zero.name) - - diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 1d7cbc10..f60a0c36 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -103,7 +103,7 @@ def save_check(latest_filename, sess): if os.path.split(latest_filename)[0]: raise ValueError("'latest_filename' must not contain path components") if not context.executing_eagerly() and not isinstance(sess, session.SessionInterface): - raise TypeError("'sess' must be a Session; %s" % sess) + raise TypeError(f"'sess' must be a Session; {sess}") def get_model_checkpoint_path(self, checkpoint_file, sess): @@ -283,8 +283,7 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N try: names_to_keys = object_graph_key_mapping(checkpoint_path) except errors.NotFoundError as err: - raise ValueError("Checkpoint in %s not an object-based checkpoint."
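logger.yaml above is consumed through logging.config.dictConfig in init_sys_log. A self-contained sketch of the same pattern with an inline dict, skipping the yaml loading and file validation for brevity; note that the stream key is an argument of logging.StreamHandler, which is why that handler class goes with ext://sys.stdout here.

import logging
import logging.config

# Inline equivalent of a minimal logger.yaml: one stdout handler, one named logger.
LOG_CFG = {
    "version": 1,
    "formatters": {
        "wholeFmt": {"format": "[%(asctime)s][%(levelname)s][%(message)s]"},
    },
    "handlers": {
        "runStreamHandler": {
            "class": "logging.StreamHandler",
            "level": "INFO",
            "formatter": "wholeFmt",
            "stream": "ext://sys.stdout",
        },
    },
    "loggers": {
        "logStream": {"level": "INFO", "handlers": ["runStreamHandler"], "propagate": False},
    },
}

logging.config.dictConfig(LOG_CFG)
logging.getLogger("logStream").info("logger configured via dictConfig")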
% - checkpoint_path) from err + raise ValueError(f"Checkpoint in {checkpoint_path} not an object-based checkpoint.") from err if var_list is None: var_list = build_var_list() diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index dcfd4fca..5af58e98 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -70,9 +70,9 @@ class Saver(object): if global_step: if not isinstance(global_step, compat.integral_types): global_step = int(sess.run(global_step)) - ckpt_name = "sparse-%s-%d" % (base_name, global_step) + ckpt_name = f"sparse-{base_name}-{global_step}" else: - ckpt_name = "sparse-%s" % base_name + ckpt_name = f"sparse-{base_name}" integrated_path = os.path.join(directory, ckpt_name) saving_path = integrated_path @@ -91,7 +91,7 @@ class Saver(object): def restore(self, sess, reading_path): logging.debug("======== Start restoring ========") directory, base_name = os.path.split(reading_path) - ckpt_name = "sparse-%s" % base_name + ckpt_name = f"sparse-{base_name}" reading_path = os.path.join(directory, ckpt_name) if not tf.io.gfile.exists(reading_path): diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 1649bec4..b4edb2cd 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -76,12 +76,12 @@ class SparseProcessor: file_validator.check() attributes = json.load(fin) except FileNotFoundError as err: - raise FileNotFoundError(f"attribute dir not found.") from err + raise FileNotFoundError("attribute dir not found.") from err else: try: attributes = np.fromfile(attribute_dir, np.uint64) except FileNotFoundError as err: - raise FileNotFoundError(f"attribute dir not found.") from err + raise FileNotFoundError("attribute dir not found.") from err return attributes -- Gitee From f26101f446e27d7f673045ccffb16d47f2528199 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 18 Jul 2023 09:38:20 +0800 Subject: [PATCH 218/551] Match-id-09adc66ae3b98c21cd461f566a74e32a0af06342 --- mx_rec/core/embedding.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 183640b5..ec9bbde5 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -563,10 +563,9 @@ class SparseEmbedding: def grad(lookup_diff): embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - unique_embeddings_shape = unique_embeddings.shape.as_list() if use_static else tf.shape( - unique_embeddings) + unique_embed_shape = unique_embeddings.shape.as_list() if use_static else tf.shape(unique_embeddings) unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, - unique_embeddings_shape[0]) + unique_embed_shape[0]) bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) if hot_pos is not None: hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], -- Gitee From 670b5e39b5f4833bf651a345b563d709428f5f58 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 14 Jul 2023 13:22:47 +0800 Subject: [PATCH 219/551] Match-id-b578cb6cf3c751d3f5e40e9ee0d527cfa9f6f075 --- src/CMakeLists.txt | 6 ++++-- src/core/CMakeLists.txt | 6 ++---- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 3 ++- src/core/utils/common.cpp | 24 +++++++++++++++++++++--- src/core/utils/common.h | 20 +++++++++++++------- src/ops_tf/hybrid_dataset_ops.cpp | 8 ++++---- src/tests/utils/common_test.cpp | 5 +++-- 7 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3f4118bd..1a8ad26a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -94,13 +94,15 
@@ endif() if(IS_DIRECTORY ${OPENSOURCE_DIR}) add_subdirectory(${OPENSOURCE_DIR}/pybind11 pybind11.out) + + option(WITH_CUSTOM_PREFIX "use for glog v0.5.0 to enable custom log format" ON) add_subdirectory(${OPENSOURCE_DIR}/glog glog.out) + include_directories(glog.out) + install(TARGETS glog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) else() message(FATAL_ERROR "INVALID FOLDER, ${OPENSOURCE_DIR}") endif() -include_directories(${PROJECT_SOURCE_DIR}/../../opensource/opensource/glog/include) -install(TARGETS glog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) add_subdirectory(core) add_subdirectory(ops_tf) add_subdirectory(pybind) diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 266f4335..27fbc4be 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -1,8 +1,6 @@ cmake_minimum_required(VERSION 3.12) set(CMAKE_CXX_STANDARD 17) -set(WITH_CUSTOM_PREFIX ON) - if(NOT ABSEIL_PATH) set(ABSEIL_PATH ${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow_core/) endif() @@ -48,8 +46,8 @@ target_link_libraries(ASC PUBLIC -l:_tf_adapter.so OpenMP::OpenMP_CXX ${MPI_CXX_LIBRARIES} ${PYTHON_LIBRARY} - PUBLIC glog::glog - ) + glog::glog +) find_package(easy_profiler PATHS ${EASY_PROFILER_PATH} NO_DEFAULT_PATH) if (easy_profiler_FOUND) message("==link with easy_profiler==") diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 89247a2f..ebcc55f7 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -78,7 +78,8 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, if (isRunning) { return true; } - SetLog(); + + SetLog(rankInfo.rankId); InitRankInfo(rankInfo, embInfos); LOG(INFO) << StringFormat( diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index d048cdb3..c21f0652 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include @@ -27,6 +26,7 @@ namespace MxRec { bool PerfConfig::fastUnique = false; string g_rankId; int g_glogLevel; + bool g_isGlogInit = false; RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, @@ -96,15 +96,33 @@ namespace MxRec { } } - void SetLog() + void SetLog(int rank) { - // glog 0.5.0 can't pass any args into, and not support custom format auto logLevel = getenv("GLOG_stderrthreshold"); if (logLevel == nullptr) { g_glogLevel = 0; // default as INFO } else { g_glogLevel = atoi(logLevel); } + if (!g_isGlogInit) { + google::InitGoogleLogging("mxRec", &CustomGlogFormat, &rank); + g_isGlogInit = true; + } + } + + void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void* rank) + { + if (g_rankId.empty()) { + g_rankId = std::to_string(*static_cast(rank)); + } + + s << "[" + << setw(GLOG_TIME_WIDTH_2) << l.time.hour() << ':' + << setw(GLOG_TIME_WIDTH_2) << l.time.min() << ':' + << setw(GLOG_TIME_WIDTH_2) << l.time.sec() << "." 
+ << setw(GLOG_TIME_WIDTH_6) << l.time.usec() << "]" + << " [" + g_rankId + "]" + << " [" << l.severity << "] "; } string GetChipName(int devID) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 2b8e3b6a..4fc2d6a0 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -23,8 +23,8 @@ #include #include #include -#include #include "tensorflow/core/framework/tensor.h" +#include "glog/logging.h" // note: must set behind any tensorflow reference, otherwise will overwrite logging.h #include "absl/container/flat_hash_map.h" #include "securec.h" @@ -49,8 +49,6 @@ namespace MxRec { #define INFO_PTR shared_ptr #define MGMT_CPY_THREADS 4 #define PROFILING - extern int g_glogLevel; - using namespace tensorflow; constexpr int TRAIN_CHANNEL_ID = 0; constexpr int EVAL_CHANNEL_ID = 1; @@ -62,7 +60,10 @@ namespace MxRec { constexpr int KEY_PROCESS_THREAD = 6; // for GLOG - constexpr int GLOG_MAX_BUF_SIZE = 2048; + extern int g_glogLevel; + constexpr int GLOG_MAX_BUF_SIZE = 1024; + constexpr int GLOG_TIME_WIDTH_2 = 2; + constexpr int GLOG_TIME_WIDTH_6 = 6; // unique related config constexpr int UNIQUE_BUCKET = 6; @@ -266,7 +267,9 @@ struct BatchTask { absl::flat_hash_map timestamps; // 用于特征准入&淘汰的时间戳 }; - void SetLog(); + void SetLog(int rank); + + void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void* rank); template string StringFormat(const string& format, Args ... args) @@ -274,8 +277,11 @@ struct BatchTask { auto size = static_cast(GLOG_MAX_BUF_SIZE); unique_ptr buf(new char[size]); memset_s(buf.get(), size, 0, size); - snprintf_s(buf.get(), size, SECUREC_STRING_MAX_LEN-1, format.c_str(), args ...); - return string(buf.get(), buf.get() + size); + int nChar = snprintf_s(buf.get(), size, size-1, format.c_str(), args ...); + if (nChar == -1) { + throw invalid_argument("StringFormat failed"); + } + return string(buf.get(), buf.get() + nChar); } // use environment variable GLOG_v to decide if showing debug log. diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index a9ef8f2e..cd5209ad 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -485,18 +485,18 @@ public: size_t& dataSize) { if (dataSize - fieldNumTmp != 1) { // 说明没有传时间戳 - LOG(ERROR) << StringFormat("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); + LOG(ERROR) << StringFormat("dataSize[%d], fieldNum[%d] ...", dataSize, fieldNumTmp); return false; } // 前面8个字节、即占一个featureId位,是unix时间戳 auto src = (const time_t*)inputTensor.tensor_data().data(); std::copy(src, src + 1, ×tamp); - LOG(INFO) << StringFormat("current batchId[{}] timestamp[{}]", batchId, timestamp); + LOG(INFO) << StringFormat("current batchId[%d] timestamp[%d]", batchId, timestamp); dataSize -= 1; if (timestamp <= 0) { - LOG(ERROR) << StringFormat("timestamp[{}] <= 0 ", timestamp); + LOG(ERROR) << StringFormat("timestamp[%d] <= 0 ", timestamp); return false; } @@ -599,7 +599,7 @@ public: void Compute(OpKernelContextPtr context) override { - LOG(INFO) << StringFormat("context {}", context->step_id()); + LOG(INFO) << StringFormat("context %d", context->step_id()); std::cout << " Cust opp not installed!!" 
<< std::endl; } diff --git a/src/tests/utils/common_test.cpp b/src/tests/utils/common_test.cpp index 90d60c3f..b9e835b9 100644 --- a/src/tests/utils/common_test.cpp +++ b/src/tests/utils/common_test.cpp @@ -18,11 +18,12 @@ using namespace testing; TEST(common, SetLog) { - SetLog(); + int rankId = 0; + SetLog(rankId); ASSERT_EQ(g_glogLevel, 0); putenv("GLOG_stderrthreshold=1"); - SetLog(); + SetLog(rankId); ASSERT_EQ(g_glogLevel, 1); } -- Gitee From a49914ed72737acd28ca013e197b0730358f6800 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 20 Jul 2023 10:52:22 +0800 Subject: [PATCH 220/551] Match-id-eae3f2f9e447b063691385bb0531321d393f5461 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 ++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ebcc55f7..0f8b72f0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -272,7 +272,7 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) #endif } -bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t* embTableCount) +bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount) { bool loadDataMatches = { true }; const auto& loadEmbTable { loadHostEmbs->find(setupHostEmbs->name) }; @@ -321,7 +321,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) size_t embTableCount { 0 }; auto loadHostEmbs { loadData.hostEmbs }; for (EmbInfo setupHostEmbs : mgmtEmbInfo) { - if (!IsLoadDataMatches(loadHostEmbs, &setupHostEmbs, &embTableCount)) { + if (!IsLoadDataMatches(loadHostEmbs, &setupHostEmbs, embTableCount)) { return false; } } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 6db27f56..03ff37b2 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -109,7 +109,7 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); - bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t* embTableCount); + bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount); private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, -- Gitee From ad7a8361d6e45186d4d8cb72437d026eb62b88d5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 25 Jul 2023 20:32:25 +0800 Subject: [PATCH 221/551] Match-id-d4a3dfe581ee7453785068fda01d29f24562a165 --- src/core/emb_hashmap/emb_hashmap.cpp | 17 +++++++++++------ src/core/emb_hashmap/emb_hashmap.h | 2 +- src/core/emb_table/emb_table.cpp | 2 +- src/core/utils/common.cpp | 11 +++++------ src/core/utils/common.h | 2 +- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 6 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 0bd54d63..d1918cde 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -57,7 +57,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t auto keepBatch = swapId - iBatch; bool findOffsetV2 = getenv("FIND_OFFSET_V2") != nullptr; - VLOG(GLOG_DEBUG) << StringFormat("FindOffset, %s", findOffsetV2); + VLOG(GLOG_DEBUG) << StringFormat("FindOffset version:%d", findOffsetV2); if (findOffsetV2) { FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); @@ -332,7 +332,13 @@ void EmbHashMap::FindOffset(const string& 
embName, const vector& keys embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); continue; } - auto offset = FindOffsetHelper(key, embHashMap, channelId); + size_t offset; + auto isOffsetValid = FindOffsetHelper(key, embHashMap, channelId, offset); + if (!isOffsetValid) { + embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); + continue; + } + if (offset < embHashMap.devVocabSize) { embHashMap.lookUpVec.emplace_back(offset); embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset])); @@ -351,10 +357,9 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } -size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId) +bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) { - size_t offset; const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; @@ -388,10 +393,10 @@ size_t EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHas throw runtime_error("hostVocabSize too small"); } } else { - offset = -1; + return false; } } - return offset; + return true; } void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index c9b88c09..e2216e43 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -55,7 +55,7 @@ namespace MxRec { void FindOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId); - size_t FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId); + bool FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset); void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 3ccbe345..8bde2888 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -123,7 +123,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali switch (initializeInfo.initializerType) { case InitializerType::CONSTANT: { LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Constant Initializer by value %d. name %s, start %d, len %d.", + "Device GenerateEmbData ing using Constant Initializer by value %f. 
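// The FindOffsetHelper rewrite above fixes a sentinel bug: the old helper
// returned size_t and assigned -1 for "not found", which wraps to SIZE_MAX on an
// unsigned type. The patch switches to bool plus a size_t& out-parameter, with
// the caller emitting INVALID_KEY_VALUE itself. An equivalent C++17 shape, with
// a hypothetical map alias standing in for hostHashMap, is std::optional:
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_map>

using HostHashMap = std::unordered_map<std::int64_t, std::size_t>;

std::optional<std::size_t> LookUpOffset(const HostHashMap& map, std::int64_t key)
{
    auto it = map.find(key);
    if (it == map.end()) {
        return std::nullopt;     // absence lives in the type, not in a magic value
    }
    return it->second;           // every size_t, including SIZE_MAX, is a valid offset
}

// Caller: if (auto off = LookUpOffset(m, k)) { use(*off); } else { mark invalid; }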
name %s, start %d, len %d.", initializeInfo.constantInitializerInfo.constantValue, initializeInfo.name.c_str(), initializeInfo.start, initializeInfo.len); initializer = &initializeInfo.constantInitializer; diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index c21f0652..706c16ed 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -104,18 +104,17 @@ namespace MxRec { } else { g_glogLevel = atoi(logLevel); } + if (g_rankId.empty()) { + g_rankId = std::to_string(rank); + } if (!g_isGlogInit) { - google::InitGoogleLogging("mxRec", &CustomGlogFormat, &rank); + google::InitGoogleLogging("mxRec", &CustomGlogFormat); g_isGlogInit = true; } } - void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void* rank) + void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void*) { - if (g_rankId.empty()) { - g_rankId = std::to_string(*static_cast(rank)); - } - s << "[" << setw(GLOG_TIME_WIDTH_2) << l.time.hour() << ':' << setw(GLOG_TIME_WIDTH_2) << l.time.min() << ':' diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 4fc2d6a0..f51d50c0 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -269,7 +269,7 @@ struct BatchTask { void SetLog(int rank); - void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void* rank); + void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void*); template string StringFormat(const string& format, Args ... args) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index cd5209ad..4891a126 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -214,7 +214,7 @@ public: EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); VLOG(GLOG_DEBUG) << StringFormat( KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):%d, elapsed from last(ms):%d," - " enqueueTC(ms):%d, batch[%d}]:%d", + " enqueueTC(ms):%d, batch[%d]:%d", tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId ); staticSw = TimeCost(); -- Gitee From 54bf604e849903b67b409aecc080688769a0333e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 27 Jul 2023 16:53:27 +0800 Subject: [PATCH 222/551] Match-id-d24c2cc1c67d7a30264329202f5bd2a9d8c02a6c --- tools/python/key_2_emb_formatter.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tools/python/key_2_emb_formatter.py b/tools/python/key_2_emb_formatter.py index 4f5ca6b3..0c838b29 100644 --- a/tools/python/key_2_emb_formatter.py +++ b/tools/python/key_2_emb_formatter.py @@ -12,7 +12,8 @@ import os import re import numpy as np -from mx_rec.validator.validator import FileValidator +MIN_SIZE = 1 +MAX_SIZE = 1024 * 1024 * 1024 * 1024 parser = argparse.ArgumentParser() @@ -56,7 +57,7 @@ class Formatter: self._json_attrib_dtype = "data_type" self._json_attrib_shape = "shape" self._host_attrib_dtype = np.uint64 - self._hashmap_dtype = np.uint32 + self._hashmap_dtype = np.uint64 self._raw_key_dtype = np.uint64 self._key_dtype = np.int64 self._raw_key_offset = np.iinfo(np.uint32).max @@ -125,12 +126,15 @@ class Formatter: if self._is_ddr_mode: data_file, attribute_file = self._get_file_names(host_hashmap_dir) + attribute = self._get_attribute(host_hashmap_dir, attribute_file, is_json=False) + data_shape = attribute[:2] + raw_hashmap = self._get_data(host_hashmap_dir, data_file, self._hashmap_dtype, data_shape) else: data_file, attribute_file = self._get_file_names(dev_hashmap_dir) + attribute = 
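// Patch 221 above also simplifies the glog wiring: SetLog(rank) now stores the
// rank string in g_rankId once at init time, and CustomGlogFormat no longer
// dereferences a void* whose lifetime it did not own. A self-contained sketch of
// the same custom-prefix mechanism (assumes a glog build with custom-prefix
// support, roughly 0.4.0 or newer; names below are illustrative):
#include <iomanip>
#include <ostream>
#include <string>
#include <glog/logging.h>

static std::string g_rank = "0";    // resolved once at init, like g_rankId above

static void RankPrefix(std::ostream& s, const google::LogMessageInfo& l, void*)
{
    s << '[' << std::setw(2) << l.time.hour() << ':'
      << std::setw(2) << l.time.min() << ':'
      << std::setw(2) << l.time.sec() << '.'
      << std::setw(6) << l.time.usec() << ']'
      << " [" << g_rank << ']'
      << " [" << l.severity << ']';
}

int main(int argc, char** argv)
{
    (void)argc;
    g_rank = "3";                                      // e.g. from the rank argument
    google::InitGoogleLogging(argv[0], &RankPrefix);   // register the formatter once
    LOG(INFO) << "prefix now carries the rank";
}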
self._get_attribute(dev_hashmap_dir, attribute_file, is_json=False) + data_shape = attribute[:2] + raw_hashmap = self._get_data(dev_hashmap_dir, data_file, self._hashmap_dtype, data_shape) - attribute = self._get_attribute(host_hashmap_dir, attribute_file, is_json=False) - data_shape = attribute[:2] - raw_hashmap = self._get_data(host_hashmap_dir, data_file, self._hashmap_dtype, data_shape) offset = raw_hashmap[:, -1] raw_key = raw_hashmap[:, :2].astype(self._raw_key_dtype) key = raw_key[:, 0] * self._raw_key_offset + raw_key[:, 1] @@ -187,13 +191,12 @@ class Formatter: file_dir = os.path.join(directory, file_name) if is_json: with open(file_dir, "r") as fin: - # check whether attribute file is valid - file_validator = FileValidator(file_dir) - # 1.check whether file_dir is soft link - file_validator.check_not_soft_link() - # 2.check attribute file size - file_validator.check_file_size(fin) - file_validator.check() + # check file whether is valid + file_info = os.stat(fin.fileno()) + if file_info.st_size < MIN_SIZE or file_info.st_size > MAX_SIZE: + raise ValueError(f"file size {file_info.st_size} is not in range[{MIN_SIZE}, {MAX_SIZE}]") + if os.path.islink(file_dir): + raise ValueError(f"file dir {file_dir} is soft link or relative path") attributes = json.load(fin) return attributes else: -- Gitee From de251a3df20e2fe28b23a531bf25d514885f6928 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 27 Jul 2023 20:53:38 +0800 Subject: [PATCH 223/551] Match-id-c2dcffd7405a3fdc0b525589c224f2313e83ae5d --- example/little_demo/main.py | 18 +- example/little_demo/run.sh | 1 + mx_rec/core/asc/feature_spec.py | 15 +- mx_rec/core/asc/manager.py | 7 +- mx_rec/optimizers/base.py | 4 +- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/checkpoint/checkpoint.h | 2 +- .../ckpt_data_handler/ckpt_data_handler.h | 2 +- .../feat_admit_n_evict_ckpt.cpp | 69 ++++--- .../feat_admit_n_evict_ckpt.h | 12 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +- .../key_process/feature_admit_and_evict.cpp | 109 ++++++---- .../key_process/feature_admit_and_evict.h | 18 +- src/core/utils/common.cpp | 16 ++ src/core/utils/common.h | 17 +- src/pybind/module_main.cpp | 7 +- src/tests/checkpoint/checkpoint_test.cpp | 63 ++++-- .../ckpt_data_handler_test.cpp | 195 ++++++++++++------ src/tests/emb_mgmt/emb_mgmt_test.cpp | 6 +- .../feature_admit_and_evict_test.cpp | 69 ++++--- 20 files changed, 419 insertions(+), 217 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index d2840d5f..b8abf09e 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -93,17 +93,21 @@ def create_feature_spec_list(use_timestamp=False): eviction_threshold = cfg.eviction_threshold if use_timestamp else None feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold), + eviction_threshold=eviction_threshold, + faae_coefficient=1), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold)] + eviction_threshold=eviction_threshold, + faae_coefficient=4)] if use_multi_lookup: feature_spec_list.extend([FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold), + eviction_threshold=eviction_threshold, + faae_coefficient=1), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", 
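// Patch 222 above replaces the FileValidator dependency in the Python formatter
// with inline checks: the attribute file is stat'ed through the already-open
// descriptor (which also avoids a check-then-open race), its size must fall in
// [MIN_SIZE, MAX_SIZE], and symlinks are rejected. A C++ analogue of the same
// policy, assuming std::filesystem and the same 1 byte .. 1 TiB bounds:
#include <cstdint>
#include <filesystem>
#include <stdexcept>
#include <string>

namespace fs = std::filesystem;

void ValidateAttributeFile(const std::string& path)
{
    constexpr std::uintmax_t minSize = 1;
    constexpr std::uintmax_t maxSize = 1024ULL * 1024 * 1024 * 1024;   // 1 TiB
    if (fs::is_symlink(fs::symlink_status(path))) {     // does not follow the link
        throw std::invalid_argument(path + " is a symbolic link");
    }
    const std::uintmax_t size = fs::file_size(path);
    if (size < minSize || size > maxSize) {
        throw std::invalid_argument(path + " size out of bounds");
    }
}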
access_threshold=access_threshold, - eviction_threshold=eviction_threshold)]) + eviction_threshold=eviction_threshold, + faae_coefficient=4)]) if use_timestamp: feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) return feature_spec_list @@ -152,8 +156,10 @@ if __name__ == "__main__": # access_threshold unit counts; eviction_threshold unit seconds ACCESS_AND_EVICT = None if USE_TIMESTAMP: - config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) - config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold, + faae_coefficient=1) + config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold, + faae_coefficient=4) ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table) train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 2cd84a80..ba99348f 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -84,6 +84,7 @@ export USE_MULTI_LOOKUP=1 # 0:一表一查;1:一表多查 export USE_MODIFY_GRAPH=0 # 0:feature spec模式;1:自动改图模式 export USE_TIMESTAMP=0 # 0:关闭特征准入淘汰;1:开启特征准入淘汰 export UpdateEmb_V2=0 # 0: UpdateEmb同步更新;1:UpdateEmb_V2异步更新 +export FAAE_MODE=0 # 0: combine history when faae; 1: separate history when faae ################# 性能调优相关 #################### export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 export FAST_UNIQUE=0 #if use fast unique diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 60d93291..8351a532 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -30,6 +30,7 @@ class FeatureSpec: self._feat_cnt = kwargs.get("feat_count") self._access_threshold = kwargs.get("access_threshold") self._eviction_threshold = kwargs.get("eviction_threshold") + self._faae_coefficient = kwargs.get("faae_coefficient", 1) self._is_timestamp = kwargs.get("is_timestamp") self.feat_pos_train = None self.feat_pos_eval = None @@ -53,6 +54,10 @@ class FeatureSpec: def eviction_threshold(self): return self._eviction_threshold + @property + def faae_coefficient(self): + return self._faae_coefficient + @property def index_key(self): return self._index_key @@ -117,6 +122,11 @@ class FeatureSpec: if self._eviction_threshold > MAX_INT32: raise ValueError(f"Eviction_threshold is too big that exceed int32.") + if self._faae_coefficient is not None: + check_natural_number(self._faae_coefficient, "eviction_threshold") + if self._faae_coefficient > MAX_INT32: + raise ValueError(f"Eviction_threshold is too big that exceed int32.") + if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") @@ -190,10 +200,13 @@ class FeatureSpec: def get_feature_spec(table_name, access_and_evict_config): access_threshold = None eviction_threshold = None + faae_coefficient = None if access_and_evict_config: access_threshold = access_and_evict_config.get("access_threshold") eviction_threshold = access_and_evict_config.get("eviction_threshold") - return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold) + faae_coefficient = access_and_evict_config.get("faae_coefficient", 1) + return FeatureSpec(table_name, 
access_threshold=access_threshold, eviction_threshold=eviction_threshold, + faae_coefficient=faae_coefficient) def set_temporary_feature_spec_attribute(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index dfc8eb7a..9fb1976f 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -162,16 +162,19 @@ def generate_threshold_list(): threshold_list = [] for _, feature_spec in export_feature_spec().items(): + coef = 1 if feature_spec.faae_coefficent is None else feature_spec.faae_coefficent if feature_spec.eviction_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, - feature_spec.eviction_threshold) + feature_spec.eviction_threshold, + coef) threshold_list.append(threshold) continue if feature_spec.access_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, - -1) + -1, + coef) threshold_list.append(threshold) return threshold_list diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 9ad62ec6..632e8ba3 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -41,7 +41,7 @@ class CustomizedOptimizer: self.base_name = name -def my_update_op(self, opt, grad): +def custom_update_op(self, opt, grad): if isinstance(grad, ops.Tensor): update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access return update_op @@ -50,5 +50,5 @@ def my_update_op(self, opt, grad): def patch_for_optimizer(): - _TensorProcessor.update_op = my_update_op + _TensorProcessor.update_op = custom_update_op logging.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 222f8565..df4b2115 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -69,7 +69,7 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) if (!ckptData.keyOffsetMap.empty()) { dataHandlers.push_back(make_unique()); } - if (!ckptData.tens2Thresh.empty() && !ckptData.histRec.timestamps.empty() && + if (!ckptData.table2Thresh.empty() && !ckptData.histRec.timestamps.empty() && !ckptData.histRec.historyRecords.empty()) { dataHandlers.push_back(make_unique()); } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 1da11b0b..c9843ecf 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -42,7 +42,7 @@ namespace MxRec { CkptDataType::EMB_INFO, CkptDataType::EMB_CURR_STAT, CkptDataType::NDDR_OFFSET, - CkptDataType::TENSOR_2_THRESH + CkptDataType::TABLE_2_THRESH }; const set int64TransSet{ CkptDataType::EMB_HASHMAP, diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index ecc7907c..aea1d2b7 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -47,7 +47,7 @@ namespace MxRec { "embedding_current_status", "max_offset", "key_offset_map", - "tensor_2_threshold", + "table_2_threshold", "history_record" }; const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8 }; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 1e4e9d69..a2d6f96e 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ 
b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -14,18 +14,18 @@ using namespace MxRec; void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) { ClearData(); - if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || + if (processData.table2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { LOG(ERROR) << "Missing Feature Admit and Evict data"; throw std::runtime_error("Missing Feature Admit and Evict data"); } - saveTens2Thresh = std::move(processData.tens2Thresh); + saveTable2Thresh = std::move(processData.table2Thresh); saveHistRec = std::move(processData.histRec); } void FeatAdmitNEvictCkpt::GetProcessData(CkptData& processData) { - processData.tens2Thresh = std::move(loadTens2Thresh); + processData.table2Thresh = std::move(loadTable2Thresh); processData.histRec = std::move(loadHistRec); ClearData(); } @@ -43,7 +43,7 @@ vector FeatAdmitNEvictCkpt::GetDirNames() vector FeatAdmitNEvictCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveTens2Thresh) { + for (const auto& item : saveTable2Thresh) { embNames.push_back(item.first); } return embNames; @@ -51,8 +51,8 @@ vector FeatAdmitNEvictCkpt::GetEmbNames() CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::TENSOR_2_THRESH, - [=] { SetTens2ThreshTrans(embName); } }, + map> dataTransMap { { CkptDataType::TABLE_2_THRESH, + [=] { SetTable2ThreshTrans(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRecTrans(embName); } } }; CleanTransfer(); @@ -62,8 +62,8 @@ CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embN void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::TENSOR_2_THRESH, - [=] { SetTens2Thresh(embName); } }, + map> dataLoadMap { { CkptDataType::TABLE_2_THRESH, + [=] { SetTable2Thresh(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRec(embName); } } }; CleanTransfer(); @@ -73,27 +73,31 @@ void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, Ckpt void FeatAdmitNEvictCkpt::ClearData() { - saveTens2Thresh.clear(); - loadTens2Thresh.clear(); + saveTable2Thresh.clear(); + loadTable2Thresh.clear(); saveHistRec.timestamps.clear(); saveHistRec.historyRecords.clear(); loadHistRec.timestamps.clear(); loadHistRec.historyRecords.clear(); } -void FeatAdmitNEvictCkpt::SetTens2ThreshTrans(string embName) +void FeatAdmitNEvictCkpt::SetTable2ThreshTrans(string embName) { - auto tens2ThreshSize = GetTens2ThreshSize(); + auto table2ThreshSize = GetTable2ThreshSize(); auto& transArr = transferData.int32Arr; - const auto& tens2Thresh = saveTens2Thresh.at(embName); + const auto& table2Thresh = saveTable2Thresh.at(embName); - transArr.reserve(tens2ThreshSize); - transArr.push_back(tens2Thresh.countThreshold); - transArr.push_back(tens2Thresh.timeThreshold); + transArr.reserve(table2ThreshSize); + transArr.push_back(table2Thresh.countThreshold); + transArr.push_back(table2Thresh.timeThreshold); } +// save void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) { + if (GetCombineSwitch()) { + embName = COMBINE_HISTORY_NAME; + } auto histRecSize = GetHistRecSize(embName); auto& transArr = transferData.int64Arr; const auto& timeStamp = saveHistRec.timestamps.at(embName); @@ -109,18 +113,22 @@ void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) } } -void FeatAdmitNEvictCkpt::SetTens2Thresh(string 
embName) +void FeatAdmitNEvictCkpt::SetTable2Thresh(string embName) { const auto& transArr = transferData.int32Arr; - auto& tens2Thresh = loadTens2Thresh[embName]; + auto& tens2Thresh = loadTable2Thresh[embName]; - tens2Thresh.tensorName = embName; + tens2Thresh.tableName = embName; tens2Thresh.countThreshold = transArr[countThresholdIdx]; tens2Thresh.timeThreshold = transArr[timeThresholdIdx]; } +// load void FeatAdmitNEvictCkpt::SetHistRec(string embName) { + if (GetCombineSwitch()) { + embName = COMBINE_HISTORY_NAME; + } const auto& transArr = transferData.int64Arr; const auto& attribute = transferData.attribute; auto& timestamp = loadHistRec.timestamps[embName]; @@ -129,17 +137,26 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) timestamp = transArr.front(); size_t featItemInfoTotalSize = attribute.front() * static_cast(featItemInfoSaveNum); - for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { - const auto& featureId = transArr[i + featureIdIdxOffset]; - const auto& count = transArr[i + countIdxOffset]; - const auto& lastTime = transArr[i + lastTimeIdxOffset]; + VLOG(GLOG_DEBUG) << StringFormat("====Start SetHistRec, name: %s, featItemInfoTotalSize: %ld", embName.c_str(), + featItemInfoTotalSize); - histRecs[featureId].count = static_cast(count); - histRecs[featureId].lastTime = lastTime; + size_t process = 0; + size_t printPerStep = ((featItemInfoTotalSize / 100) > 0 ? (featItemInfoTotalSize / 100) : 1); + for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { + process = i % printPerStep; + if (process == 1) { + VLOG(GLOG_DEBUG) << StringFormat("====in SetHistRec, process : %f", i/featItemInfoTotalSize); + } + auto featureId = transArr[i + featureIdIdxOffset]; + auto count = transArr[i + countIdxOffset]; + auto lastTime = transArr[i + lastTimeIdxOffset]; + + histRecs.emplace(featureId, FeatureItemInfo(static_cast(count), lastTime)); } + VLOG(GLOG_DEBUG) << StringFormat("====End SetHistRec, name: %s", embName.c_str()); } -int FeatAdmitNEvictCkpt::GetTens2ThreshSize() +int FeatAdmitNEvictCkpt::GetTable2ThreshSize() { auto& attribute = transferData.attribute; auto& attribSize = transferData.attributeSize; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index 2b12d315..268120c3 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -31,7 +31,7 @@ namespace MxRec { private: const vector fileDirNames { "HashTable", "DDR" }; - const vector saveDataTypes { CkptDataType::TENSOR_2_THRESH, CkptDataType::HIST_REC }; + const vector saveDataTypes { CkptDataType::TABLE_2_THRESH, CkptDataType::HIST_REC }; const int featItemInfoSaveNum { 3 }; const int threshValSaveNum { 2 }; @@ -45,21 +45,21 @@ namespace MxRec { const int countIdxOffset { 1 }; const int lastTimeIdxOffset { 2 }; - tensor_2_thresh_mem_t saveTens2Thresh; - tensor_2_thresh_mem_t loadTens2Thresh; + table_2_thresh_mem_t saveTable2Thresh; + table_2_thresh_mem_t loadTable2Thresh; AdmitAndEvictData saveHistRec; AdmitAndEvictData loadHistRec; void ClearData(); - void SetTens2ThreshTrans(string embName); + void SetTable2ThreshTrans(string embName); void SetHistRecTrans(string embName); - void SetTens2Thresh(string embName); + void SetTable2Thresh(string embName); void 
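// One nit in the progress trace added above: i and featItemInfoTotalSize are
// both size_t, so i/featItemInfoTotalSize is integer division (0 for the whole
// loop), and %f expects a double rather than a size_t. A minimal sketch of the
// intended computation with explicit casts (function name is illustrative):
#include <cstddef>
#include <cstdio>

void LogSetHistRecProgress(std::size_t i, std::size_t total)
{
    if (total == 0) {
        return;                                       // nothing to report
    }
    double fraction = static_cast<double>(i) / static_cast<double>(total);
    std::printf("SetHistRec progress: %.1f%%\n", fraction * 100.0);
}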
SetHistRec(string embName); - int GetTens2ThreshSize(); + int GetTable2ThreshSize(); size_t GetHistRecSize(string embName); }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 0f8b72f0..cd400a52 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -150,7 +150,7 @@ bool HybridMgmt::Save(const string savePath) auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); - saveData.tens2Thresh = featAdmitNEvict.GetTensorThresholds(); + saveData.table2Thresh = featAdmitNEvict.GetTableThresholds(); saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; } @@ -201,7 +201,7 @@ bool HybridMgmt::Load(const string& loadPath) } if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); - featAdmitNEvict.LoadTensorThresholds(loadData.tens2Thresh); + featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 29b2da1b..80db3003 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -31,6 +31,7 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu LOG(ERROR) << "Config is error, feature admin-and-evict function is not available ...\n"; return false; } + SetCombineSwitch(); return true; } @@ -44,24 +45,27 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } - // 如果当前 tensorName 不在准入范围之内,则不进行“特征准入”逻辑 - std::string tensorName = batch->name; + std::string tableName = batch->name; + if (m_isCombine) { + tableName = COMBINE_HISTORY_NAME; + } absl::flat_hash_map mergeKeys; mergeKeys.reserve(splitKey.size()); PreProcessKeys(splitKey, keyCount, mergeKeys); std::lock_guard lock(m_syncMutexs); - auto iter = m_recordsData.historyRecords.find(tensorName); - if (iter == m_recordsData.historyRecords.end()) { // 之前tensorName没出现过时,数据初始化 + // 如果当前 tableName 不在准入范围之内,则不进行“特征准入”逻辑 + auto iter = m_recordsData.historyRecords.find(tableName); + if (iter == m_recordsData.historyRecords.end()) { // 之前tableName没出现过时,数据初始化 absl::flat_hash_map records(m_recordsInitSize); - m_recordsData.historyRecords[tensorName] = records; + m_recordsData.historyRecords[tableName] = records; } VLOG(GLOG_DEBUG) << StringFormat( - "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tensorName.c_str(), - m_recordsData.historyRecords[tensorName].size()); + "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tableName.c_str(), + m_recordsData.historyRecords[tableName].size()); - if (batch->timestamp > m_recordsData.timestamps[tensorName]) { - m_recordsData.timestamps[tensorName] = batch->timestamp; + if (batch->timestamp > m_recordsData.timestamps[tableName]) { + m_recordsData.timestamps[tableName] = batch->timestamp; } absl::flat_hash_map visitedRecords; for (auto& key : splitKey) { @@ -73,7 +77,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, auto it = visitedRecords.find(key); if (it == visitedRecords.end()) { visitedRecords[key] = true; - if (FeatureAdmitHelper(channel, tensorName, key, 
mergeKeys[key]) == + if (FeatureAdmitHelper(channel, batch->name, key, mergeKeys[key]) == FeatureAdmitType::FEATURE_ADMIT_FAILED) { visitedRecords[key] = false; key = -1; // 被淘汰的Feature ID @@ -87,32 +91,38 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, } if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat( - "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tensorName.c_str(), channel, + "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tableName.c_str(), channel, VectorToString(splitKey).c_str()); } return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } -FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tensorName, +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tableNameOrigin, const int64_t featureId, const uint32_t featureCnt) { // “特征准入”逻辑 uint32_t currKeyCount = 0; - absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tensorName]; + std::string tableName = tableNameOrigin; + if (m_isCombine) { + tableName = COMBINE_HISTORY_NAME; + } + + absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tableName]; auto innerIt = historyRecordInfos.find(featureId); if (channel == TRAIN_CHANNEL_ID) { if (innerIt == historyRecordInfos.end()) { // 维护 m_historyRecords - FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tensorName]); + FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tableName]); + info.count *= m_table2Threshold[tableNameOrigin].faaeCoefficient; historyRecordInfos[featureId] = info; - currKeyCount = featureCnt; + currKeyCount = info.count; } else { // 维护 m_historyRecords FeatureItemInfo &info = historyRecordInfos[featureId]; - info.count += featureCnt; - info.lastTime = m_recordsData.timestamps[tensorName]; + info.count += m_table2Threshold[tableNameOrigin].faaeCoefficient * featureCnt; + info.lastTime = m_recordsData.timestamps[tableName]; currKeyCount = info.count; } } else if (channel == EVAL_CHANNEL_ID) { // eval @@ -122,7 +132,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con } // 准入条件判断 - if (currKeyCount >= static_cast(m_tensor2Threshold[tensorName].countThreshold)) { + if (currKeyCount >= static_cast(m_table2Threshold[tableNameOrigin].countThreshold)) { return FeatureAdmitType::FEATURE_ADMIT_OK; } @@ -132,8 +142,8 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con // 特征淘汰接口 void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) { - std::vector tensorNames = GetAllNeedEvictTensorNames(); - if (tensorNames.empty()) { + std::vector tableNames = GetAllNeedEvictTableNames(); + if (tableNames.empty()) { LOG(INFO) << "EmbNames is empty, no evict function ..."; return ; } @@ -143,9 +153,9 @@ void FeatureAdmitAndEvict::FeatureEvict(map> } std::lock_guard lock(m_syncMutexs); // 从 m_historyRecords 中淘汰删除 - size_t tensorCnt = tensorNames.size(); - for (size_t i = 0; i < tensorCnt; ++i) { - FeatureEvictHelper(tensorNames[i], evictKeyMap[tensorNames[i]]); + size_t tableCnt = tableNames.size(); + for (size_t i = 0; i < tableCnt; ++i) { + FeatureEvictHelper(tableNames[i], evictKeyMap[tableNames[i]]); } } @@ -153,7 +163,7 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v { // 从 m_historyRecords 中淘汰删除 time_t currTime = m_recordsData.timestamps[embName]; - // 从 m_tensor2SortedLastTime 获取当前要淘汰的featureId + // 从 m_table2SortedLastTime 获取当前要淘汰的featureId auto cmp = 
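// Worked example of the weighted admission rule above: every occurrence of a
// key adds faaeCoefficient to its history count, and the key is admitted once
// the count reaches countThreshold. With the demo's settings (coefficient 1 for
// user_table, 4 for item_table) and an assumed threshold of 8, an item key is
// admitted on its second occurrence, while a user key would need eight:
#include <cstdint>
#include <iostream>

int main()
{
    const std::uint32_t coefficient = 4;   // item_table in the demo config
    const std::uint32_t threshold = 8;     // illustrative countThreshold
    std::uint32_t count = 0;
    for (int occurrence = 1; occurrence <= 3; ++occurrence) {
        count += coefficient;              // mirrors info.count += coefficient * featureCnt
        std::cout << "occurrence " << occurrence << ": count=" << count
                  << (count >= threshold ? "  -> admitted" : "") << '\n';
    }
}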
[](const auto& a, const auto& b) { return a.second.lastTime > b.second.lastTime; }; std::priority_queue, std::vector>, decltype(cmp)> lastTimePriority(cmp); @@ -161,7 +171,7 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v lastTimePriority.emplace(item); } while (!lastTimePriority.empty()) { - if (currTime - lastTimePriority.top().second.lastTime < m_tensor2Threshold[embName].timeThreshold) { + if (currTime - lastTimePriority.top().second.lastTime < m_table2Threshold[embName].timeThreshold) { break; } evictKey.emplace_back(lastTimePriority.top().first); @@ -170,11 +180,11 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v if (evictKey.size() == 0) { LOG(INFO) << StringFormat( - "tensor-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); + "table-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); return; } LOG(INFO) << StringFormat( - "tensor-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); + "table-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); // 真正从 m_historyRecords 中淘汰 absl::flat_hash_map& historyRecords = m_recordsData.historyRecords[embName]; @@ -196,6 +206,12 @@ void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) } m_isEnableFunction = isEnableEvict; } + +void FeatureAdmitAndEvict::SetCombineSwitch() +{ + m_isCombine = GetCombineSwitch(); +} + bool FeatureAdmitAndEvict::GetFunctionSwitch() const { return m_isEnableFunction; @@ -222,10 +238,10 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t const std::vector& embNames, bool isTimestamp) { for (size_t i = 0; i < thresholds.size(); ++i) { - auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tensorName); + auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tableName); if (it == embNames.end()) { // 配置不存在于当前跑的模型,也要报错 LOG(ERROR) << StringFormat( - "embName[%s] is not exist at current model ...", thresholds[i].tensorName.c_str()); + "embName[%s] is not exist at current model ...", thresholds[i].tableName.c_str()); return false; } else { // 同时支持“准入&淘汰”,却没有传时间戳 @@ -242,10 +258,10 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t return true; } -auto FeatureAdmitAndEvict::GetTensorThresholds() -> tensor_2_thresh_mem_t +auto FeatureAdmitAndEvict::GetTableThresholds() -> table_2_thresh_mem_t { std::lock_guard lock(m_syncMutexs); - return m_tensor2Threshold; + return m_table2Threshold; } auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& @@ -254,10 +270,10 @@ auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& return m_recordsData; } -void FeatureAdmitAndEvict::LoadTensorThresholds(tensor_2_thresh_mem_t& loadData) +void FeatureAdmitAndEvict::LoadTableThresholds(table_2_thresh_mem_t& loadData) { std::lock_guard lock(m_syncMutexs); - m_tensor2Threshold = std::move(loadData); + m_table2Threshold = std::move(loadData); } void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) @@ -266,7 +282,7 @@ void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) m_recordsData = std::move(loadData); } -// 解析m_tensor2Threshold +// 解析m_table2Threshold bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) { if (thresholdValues.empty()) { @@ -277,23 +293,26 @@ bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& m_cfgThresholds = 
thresholdValues; for (const auto& value : thresholdValues) { LOG(INFO) << StringFormat( - "embName[%s], count[%d], time[%d] ...", - value.tensorName.c_str(), value.countThreshold, value.timeThreshold); - auto it = m_tensor2Threshold.find(value.tensorName); - if (it != m_tensor2Threshold.end()) { + "embName[%s], count[%d], time[%d], coefficient[%d] ...", + value.tableName.c_str(), value.countThreshold, value.timeThreshold, value.faaeCoefficient); + auto it = m_table2Threshold.find(value.tableName); + if (it != m_table2Threshold.end()) { // train和eval同时开启,会出现表重复配置 - LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tensorName.c_str()); + LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tableName.c_str()); return true; } - m_tensor2Threshold[value.tensorName] = value; + m_table2Threshold[value.tableName] = value; + if (value.faaeCoefficient < 1) { + LOG(ERROR) << StringFormat("[%s] config error, coefficient smaller than 1 ...", value.tableName.c_str()); + } if (value.countThreshold != -1 && value.timeThreshold != -1) { - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_BOTH; + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_BOTH; } else if (value.countThreshold != -1 && value.timeThreshold == -1) { - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; } else { - LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tensorName.c_str()); - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ERROR; + LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tableName.c_str()); + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ERROR; return false; } } @@ -301,7 +320,7 @@ bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& return true; } -std::vector FeatureAdmitAndEvict::GetAllNeedEvictTensorNames() +std::vector FeatureAdmitAndEvict::GetAllNeedEvictTableNames() { std::vector names; std::lock_guard lock(m_syncMutexs); diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 5dfcc3e0..83e5a392 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -58,10 +58,13 @@ namespace MxRec { // 特征淘汰接口 void FeatureEvict(map>& evictKeyMap); void ExecuteFeatureAdmit( - const string& tensorName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); + const string& tableName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); // 特征淘汰的使能接口 void SetFunctionSwitch(bool isEnableEvict); + // 特征准入合表统计 + void SetCombineSwitch(); + bool GetFunctionSwitch() const; void PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, absl::flat_hash_map& mergeKeys); @@ -71,10 +74,10 @@ namespace MxRec { const std::vector& embNames, bool isTimestamp); // 与模型保存加载交互的接口 - auto GetTensorThresholds() -> tensor_2_thresh_mem_t; + auto GetTableThresholds() -> table_2_thresh_mem_t; auto GetHistoryRecords() -> AdmitAndEvictData&; - void LoadTensorThresholds(tensor_2_thresh_mem_t& loadData); + void LoadTableThresholds(table_2_thresh_mem_t& loadData); void LoadHistoryRecords(AdmitAndEvictData& loadData); static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 @@ -82,17 +85,18 @@ namespace MxRec { GTEST_PRIVATE : - // 解析m_tensor2Threshold + // 解析m_table2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); - std::vector GetAllNeedEvictTensorNames(); - 
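// Note that ParseThresholdCfg above logs a coefficient below 1 as an error but
// keeps going, unlike the evict-without-admit case, which returns false. A
// stricter variant, assuming an invalid coefficient should reject the whole
// configuration (uses the surrounding file's LOG and StringFormat helpers):
bool IsCoefficientValid(const ThresholdValue& value)
{
    if (value.faaeCoefficient < 1) {
        LOG(ERROR) << StringFormat("[%s] config error, coefficient smaller than 1 ...",
                                   value.tableName.c_str());
        return false;    // caller would propagate this as a config failure
    }
    return true;
}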
FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tensorName, + std::vector GetAllNeedEvictTableNames(); + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableName, const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); bool m_isEnableFunction { true }; // “特征淘汰”的使能开关 bool m_isExit { false }; // 淘汰线程退出的标识 - absl::flat_hash_map m_tensor2Threshold; // tensor-X ---> ThresholdValue 映射 + bool m_isCombine { true }; // 是否合并统计history + absl::flat_hash_map m_table2Threshold; // table-X ---> ThresholdValue 映射 AdmitAndEvictData m_recordsData; std::mutex m_syncMutexs; // 特征准入与特征淘汰竞争的同步锁 int m_recordsInitSize { DEFAULT_RECORDS_INIT_SIZE }; // m_historyRecords表初始容量 diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 706c16ed..d3ddbc2a 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -141,4 +141,20 @@ namespace MxRec { throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret)); } + + bool GetCombineSwitch() + { + const char* faaeMode = std::getenv("FAAE_MODE"); // 获取环境变量 + bool isCombine = true; + if (faaeMode != nullptr) { + try { + isCombine = (std::stoi(faaeMode) == 0); + LOG(INFO) << StringFormat("If combine history table: %d", isCombine); + } catch (const std::invalid_argument& e) { + LOG(ERROR) << "The value of FAAE_MODE is invalid!"; + throw std::invalid_argument("Invalid env value FAAE_MODE"); + } + } + return isCombine; + } } // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index f51d50c0..60d4a97b 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -88,6 +88,8 @@ namespace MxRec { constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; constexpr float HOT_EMB_CACHE_PCT = static_cast(1. 
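// Small driver showing how the FAAE_MODE switch read by GetCombineSwitch above
// behaves: unset or "0" keeps one combined history table (the default), any
// other integer separates histories per table, and a non-numeric value throws
// via std::stoi, matching the catch block in the patch. The setenv call is
// illustrative:
#include <cstdlib>
#include <iostream>
#include <string>

int main()
{
    setenv("FAAE_MODE", "1", /*overwrite=*/1);       // separate-history mode
    const char* v = std::getenv("FAAE_MODE");
    bool combine = (v == nullptr) || (std::stoi(v) == 0);
    std::cout << std::boolalpha << "combine=" << combine << '\n';   // prints false
}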
/ 3); // hot emb cache percent + const string COMBINE_HISTORY_NAME = "combine_table_history"; + using emb_key_t = int64_t; using emb_name_t = std::string; using keys_t = std::vector; @@ -102,6 +104,7 @@ namespace MxRec { }; string GetChipName(int devID); + bool GetCombineSwitch(); namespace UBSize { const int ASCEND910_PREMIUM_A = 262144; @@ -239,16 +242,18 @@ struct BatchTask { struct ThresholdValue { ThresholdValue() = default; - ThresholdValue(emb_name_t name, int countThre, int timeThre) + ThresholdValue(emb_name_t name, int countThre, int timeThre, int faaeCoef) { - tensorName = name; + tableName = name; countThreshold = countThre; timeThreshold = timeThre; + faaeCoefficient = faaeCoef; } - emb_name_t tensorName { "" }; // embName + emb_name_t tableName { "" }; // embName int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 + int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 }; struct FeatureItemInfo { @@ -492,7 +497,7 @@ struct BatchTask { using emb_hash_mem_t = absl::flat_hash_map; using offset_mem_t = std::map; using key_offset_mem_t = std::map>; - using tensor_2_thresh_mem_t = absl::flat_hash_map; + using table_2_thresh_mem_t = absl::flat_hash_map; using trans_serialize_t = uint8_t; using key_offset_map_t = std::map; using all_key_offset_map_t = std::map>; @@ -510,7 +515,7 @@ struct BatchTask { emb_hash_mem_t embHashMaps; offset_mem_t maxOffset; key_offset_mem_t keyOffsetMap; - tensor_2_thresh_mem_t tens2Thresh; + table_2_thresh_mem_t table2Thresh; AdmitAndEvictData histRec; }; @@ -532,7 +537,7 @@ struct BatchTask { EMB_CURR_STAT = 4, NDDR_OFFSET = 5, NDDR_FEATMAP = 6, - TENSOR_2_THRESH = 7, + TABLE_2_THRESH = 7, HIST_REC = 8, ATTRIBUTE = 9 }; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index b81e5c6d..9c1b5fb2 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -169,8 +169,9 @@ void GetHybridMgmt(pybind11::module_& m) void GetThresholdValue(pybind11::module_& m) { pybind11::class_(m, "ThresholdValue") - .def(pybind11::init()) - .def_readwrite("tensor_name", &ThresholdValue::tensorName) + .def(pybind11::init()) + .def_readwrite("table_name", &ThresholdValue::tableName) .def_readwrite("count_threshold", &ThresholdValue::countThreshold) - .def_readwrite("time_threshold", &ThresholdValue::timeThreshold); + .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) + .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 33df7b23..1cf27378 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -175,17 +175,18 @@ protected: } } - void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold) + void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold) { for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tensorName = testEmbInfo.name; + val.tableName = testEmbInfo.name; val.countThreshold = offsetMem; val.timeThreshold = offsetMem; + val.faaeCoefficient = 1; offsetMem++; - testTens2Threshold[testEmbInfo.name] = move(val); + testTable2Threshold[testEmbInfo.name] = move(val); } } @@ -214,6 +215,30 @@ protected: timeStamp++; } } + + void SetHistRecCombine(AdmitAndEvictData& histRec) + { + int64_t featureId { int64Min }; + int count { 1 }; + time_t lastTime { 1000 }; + time_t timeStamp { 10000 }; 
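// Construction examples matching the four-argument ThresholdValue above and its
// pybind11 binding; the values are illustrative. timeThreshold -1 means
// admit-only (SETS_ONLY_ADMIT); admit and evict together set both thresholds.
ThresholdValue both("user_table", /*countThre=*/2, /*timeThre=*/5, /*faaeCoef=*/1);
ThresholdValue admitOnly("item_table", /*countThre=*/5, /*timeThre=*/-1, /*faaeCoef=*/4);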
+ + auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; + auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; + + timestamps = timeStamp; + + for (int i = 0; i < count; ++i) { + historyRecords[featureId].count = count; + historyRecords[featureId].lastTime = lastTime; + + featureId++; + } + + count++; + lastTime++; + timeStamp++; + } }; TEST_F(CheckpointTest, HostEmbs) @@ -403,25 +428,31 @@ TEST_F(CheckpointTest, AllMgmt) TEST_F(CheckpointTest, FeatAdmitNEvict) { - tensor_2_thresh_mem_t testTrens2Thresh; - tensor_2_thresh_mem_t validTrens2Thresh; + table_2_thresh_mem_t testTrens2Thresh; + table_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; SetEmbInfo(); - SetTens2Threshold(testTrens2Thresh); + SetTable2Threshold(testTrens2Thresh); validTrens2Thresh = testTrens2Thresh; - SetHistRec(testHistRec); + + if (GetCombineSwitch()) { + SetHistRecCombine(testHistRec); + } else { + SetHistRec(testHistRec); + } + validHistRec = testHistRec; CkptData testSaveData; CkptData validLoadData; CkptData testLoadData; - testSaveData.tens2Thresh = testTrens2Thresh; + testSaveData.table2Thresh = testTrens2Thresh; testSaveData.histRec.timestamps = testHistRec.timestamps; testSaveData.histRec.historyRecords = testHistRec.historyRecords; - validLoadData.tens2Thresh = validTrens2Thresh; + validLoadData.table2Thresh = validTrens2Thresh; validLoadData.histRec = validHistRec; validLoadData.histRec.timestamps = validHistRec.timestamps; validLoadData.histRec.historyRecords = validHistRec.historyRecords; @@ -430,16 +461,16 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::FEAT_ADMIT_N_EVICT }); - EXPECT_EQ(validLoadData.tens2Thresh.size(), testLoadData.tens2Thresh.size()); + EXPECT_EQ(validLoadData.table2Thresh.size(), testLoadData.table2Thresh.size()); EXPECT_EQ(validLoadData.histRec.historyRecords.size(), testLoadData.histRec.historyRecords.size()); - for (const auto& it : validLoadData.tens2Thresh) { - EXPECT_EQ(1, testLoadData.tens2Thresh.count(it.first)); + for (const auto& it : validLoadData.table2Thresh) { + EXPECT_EQ(1, testLoadData.table2Thresh.count(it.first)); - const auto& tens2Thresh = testLoadData.tens2Thresh.at(it.first); + const auto& table2Thresh = testLoadData.table2Thresh.at(it.first); - EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); - EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); + EXPECT_EQ(it.second.tableName, table2Thresh.tableName); + EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); } for (const auto& it : validLoadData.histRec.timestamps) { diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 6caace33..55babf89 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -22,6 +22,17 @@ using valid_int64_t = absl::flat_hash_map>; using valie_dataset_t = absl::flat_hash_map>; using valid_attrib_t = absl::flat_hash_map>; +struct InputArgs { + vector& embNames; + CkptData& validData; + FeatAdmitNEvictCkpt& testCkpt; + valid_int_t& validTrens2ThreshArr; + valid_attrib_t& validTrens2ThreshAttrib; + valid_attrib_t& validHistRecAttrib; + valid_int64_t& 
validHistRecArr; + CkptData& testData; +}; + class CkptDataHandlerTest : public testing::Test { protected: int floatBytes { 4 }; @@ -68,7 +79,7 @@ protected: } } - void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold, + void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold, valid_int_t& validArr, valid_attrib_t& validAttrib) { @@ -77,7 +88,7 @@ protected: for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tensorName = testEmbInfo.name; + val.tableName = testEmbInfo.name; val.countThreshold = countThreshold; val.timeThreshold = timeThreshold; @@ -86,7 +97,7 @@ protected: countThreshold++; timeThreshold++; - testTens2Threshold[testEmbInfo.name] = move(val); + testTable2Threshold[testEmbInfo.name] = move(val); validArr[testEmbInfo.name] = move(valid); validAttrib[testEmbInfo.name].push_back(2); // 2 is element num in one vector validAttrib[testEmbInfo.name].push_back(int32Bytes); @@ -128,12 +139,114 @@ protected: timeStamp++; } } + + void SetHistRecCombine(AdmitAndEvictData& histRec, valid_int64_t& validArr, valid_attrib_t& validAttrib) + { + int64_t featureId { int64Min }; + int count { 1 }; + time_t lastTime { 1000 }; + time_t timeStamp { 10000 }; + + auto& validA { validArr[COMBINE_HISTORY_NAME] }; + auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; + auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; + + timestamps = timeStamp; + validA.push_back(timeStamp); + + for (int i = 0; i < count; ++i) { + historyRecords[featureId].count = count; + historyRecords[featureId].lastTime = lastTime; + + validA.push_back(featureId); + validA.push_back(count); + validA.push_back(lastTime); + + featureId++; + } + + auto& attribute = validAttrib[COMBINE_HISTORY_NAME]; + attribute.push_back(count); + attribute.push_back(int64Bytes); + + count++; + lastTime++; + timeStamp++; + } + + void TestForSave(InputArgs& args) + { + for (const auto& embName : args.embNames) { + EXPECT_EQ(1, args.validData.table2Thresh.count(embName)); + + CkptTransData testSaveData = args.testCkpt.GetDataset(CkptDataType::TABLE_2_THRESH, embName); + EXPECT_EQ(args.validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method + EXPECT_EQ(args.validTrens2ThreshAttrib.at(embName), testSaveData.attribute); + testSaveData = args.testCkpt.GetDataset(CkptDataType::HIST_REC, embName); + + if (!GetCombineSwitch()) { + EXPECT_EQ(1, args.validData.histRec.timestamps.count(embName)); + EXPECT_EQ(1, args.validData.histRec.historyRecords.count(embName)); + EXPECT_EQ(args.validHistRecAttrib.at(embName), testSaveData.attribute); + } else { + EXPECT_EQ(1, args.validData.histRec.timestamps.count(COMBINE_HISTORY_NAME)); + EXPECT_EQ(1, args.validData.histRec.historyRecords.count(COMBINE_HISTORY_NAME)); + EXPECT_EQ(args.validHistRecAttrib.at(COMBINE_HISTORY_NAME), testSaveData.attribute); + } + } + } + void TestForLoad(InputArgs& args) + { + CkptTransData testLoadData; + for (const auto& embName : args.embNames) { + testLoadData.int32Arr = args.validTrens2ThreshArr.at(embName); + testLoadData.attribute = args.validTrens2ThreshAttrib.at(embName); + args.testCkpt.SetDataset(CkptDataType::TABLE_2_THRESH, embName, testLoadData); + + if (!GetCombineSwitch()) { + testLoadData.int64Arr = args.validHistRecArr.at(embName); + testLoadData.attribute = args.validHistRecAttrib.at(embName); + } else { + testLoadData.int64Arr = args.validHistRecArr.at(COMBINE_HISTORY_NAME); + testLoadData.attribute = args.validHistRecAttrib.at(COMBINE_HISTORY_NAME); + } + 
args.testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); + } + args.testCkpt.GetProcessData(args.testData); + + EXPECT_EQ(args.validData.table2Thresh.size(), args.testData.table2Thresh.size()); + EXPECT_EQ(args.validData.histRec.historyRecords.size(), args.testData.histRec.historyRecords.size()); + for (const auto& it : args.validData.table2Thresh) { + EXPECT_EQ(1, args.testData.table2Thresh.count(it.first)); + + const auto& table2Thresh = args.testData.table2Thresh.at(it.first); + + EXPECT_EQ(it.second.tableName, table2Thresh.tableName); + EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); + } + + for (const auto& it : args.validData.histRec.timestamps) { + EXPECT_EQ(1, args.testData.histRec.timestamps.count(it.first)); + EXPECT_EQ(1, args.testData.histRec.historyRecords.count(it.first)); + + const auto& historyRecords = args.testData.histRec.historyRecords.at(it.first); + const auto& validHistRec = args.validData.histRec.historyRecords.at(it.first); + + for (const auto& validHR : validHistRec) { + const auto& testHR = historyRecords.at(validHR.first); + + EXPECT_EQ(validHR.second.count, testHR.count); + EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); + } + } + } }; TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) { - tensor_2_thresh_mem_t testTrens2Thresh; - tensor_2_thresh_mem_t validTrens2Thresh; + table_2_thresh_mem_t testTrens2Thresh; + table_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; @@ -143,77 +256,37 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) valid_attrib_t validHistRecAttrib; SetEmbInfo(); - SetTens2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); + SetTable2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); validTrens2Thresh = testTrens2Thresh; - SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); + + if (GetCombineSwitch()) { + SetHistRecCombine(testHistRec, validHistRecArr, validHistRecAttrib); + } else { + SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); + } validHistRec = testHistRec; CkptData testData; CkptData validData; FeatAdmitNEvictCkpt testCkpt; - testData.tens2Thresh = testTrens2Thresh; + testData.table2Thresh = testTrens2Thresh; testData.histRec.timestamps = testHistRec.timestamps; testData.histRec.historyRecords = testHistRec.historyRecords; - validData.tens2Thresh = validTrens2Thresh; + validData.table2Thresh = validTrens2Thresh; validData.histRec.timestamps = validHistRec.timestamps; validData.histRec.historyRecords = validHistRec.historyRecords; testCkpt.SetProcessData(testData); vector embNames { testCkpt.GetEmbNames() }; - CkptTransData testSaveData; - EXPECT_EQ(validData.tens2Thresh.size(), embNames.size()); - - for (const auto& embName : embNames) { - EXPECT_EQ(1, validData.tens2Thresh.count(embName)); - - EXPECT_EQ(1, validData.histRec.timestamps.count(embName)); - EXPECT_EQ(1, validData.histRec.historyRecords.count(embName)); + EXPECT_EQ(validData.table2Thresh.size(), embNames.size()); - testSaveData = testCkpt.GetDataset(CkptDataType::TENSOR_2_THRESH, embName); - EXPECT_EQ(validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method - EXPECT_EQ(validTrens2ThreshAttrib.at(embName), testSaveData.attribute); - testSaveData = testCkpt.GetDataset(CkptDataType::HIST_REC, embName); - EXPECT_EQ(validHistRecAttrib.at(embName), testSaveData.attribute); - } - - CkptTransData testLoadData; - for (const auto& embName 
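// The InputArgs aggregate above bundles references so TestForSave and
// TestForLoad can share one argument pack. References in an aggregate bind at
// initialization and the referents must outlive the struct; a minimal sketch of
// the pattern with hypothetical names:
#include <string>
#include <vector>

struct Refs {
    std::vector<std::string>& names;
    int& counter;
};

void Demo()
{
    std::vector<std::string> names { "user_table" };
    int counter = 0;
    Refs r { names, counter };    // aggregate init binds both references
    r.counter += 1;               // writes through to counter
}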
: embNames) { - testLoadData.int32Arr = validTrens2ThreshArr.at(embName); - testLoadData.attribute = validTrens2ThreshAttrib.at(embName); - testCkpt.SetDataset(CkptDataType::TENSOR_2_THRESH, embName, testLoadData); - - testLoadData.int64Arr = validHistRecArr.at(embName); - testLoadData.attribute = validHistRecAttrib.at(embName); - testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); - } - testCkpt.GetProcessData(testData); + InputArgs args = {embNames, validData, testCkpt, validTrens2ThreshArr, validTrens2ThreshAttrib, + validHistRecAttrib, validHistRecArr, testData}; + // 测试save + TestForSave(args); - EXPECT_EQ(validData.tens2Thresh.size(), testData.tens2Thresh.size()); - EXPECT_EQ(validData.histRec.historyRecords.size(), testData.histRec.historyRecords.size()); - for (const auto& it : validData.tens2Thresh) { - EXPECT_EQ(1, testData.tens2Thresh.count(it.first)); - - const auto& tens2Thresh = testData.tens2Thresh.at(it.first); - - EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); - EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); - } - - for (const auto& it : validData.histRec.timestamps) { - EXPECT_EQ(1, testData.histRec.timestamps.count(it.first)); - EXPECT_EQ(1, testData.histRec.historyRecords.count(it.first)); - - const auto& historyRecords = testData.histRec.historyRecords.at(it.first); - const auto& validHistRec = validData.histRec.historyRecords.at(it.first); - - for (const auto& validHR : validHistRec) { - const auto& testHR = historyRecords.at(validHR.first); - - EXPECT_EQ(validHR.second.count, testHR.count); - EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); - } - } + // 测试load + TestForLoad(args); } \ No newline at end of file diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 752db8e2..a570915d 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -179,7 +179,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -199,7 +199,7 @@ TEST_F(EmbMgmtTest, Evict) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -222,7 +222,7 @@ TEST_F(EmbMgmtTest, Evict_HBM) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." 
<< endl; diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index 31bfdd1e..e94fe455 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -181,48 +181,48 @@ protected: printf("\t############# [%s] tid[%lu] ############# begin ...\n", thrName.c_str(), std::hash{}(std::this_thread::get_id())); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys1 = {11, 11, 33, 44, 11, 55, 88, 55} cnt1 = 1 2 1 3 1 1 4 1 */ InputArgs args1 = {keys1, cnt1, {}, initHistory, {}}; // the first record of each table must be appended using initHistory - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args1); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args1); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys2 = {11, 12, 33, 21, 11, 12} cnt2 = 1 2 1 1 2 3 */ InputArgs args2 = {keys2, cnt2, {}, args1.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args2); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args2); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys3 = {123, 121, 121, 212, 211} cnt3 = 1 2 1 1 2 */ InputArgs args3 = {keys3, cnt3, {}, initHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args3); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args3); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys4 = {11, 11, 33, 44, 55, 88, 55} cnt4 = 1 2 3 2 1 2 1 */ InputArgs args4 = {keys4, cnt4, {}, args2.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args4); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args4); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys5 = {125, 121, 122, 212, 211} cnt5 = 1 2 1 3 1 */ InputArgs args5 = {keys5, cnt5, {}, args3.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args5); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args5); printf("\t############# [%s] tid[%lu] ############# end ...\n", thrName.c_str(), std::hash{}(std::this_thread::get_id())); @@ -263,58 +263,59 @@ protected: { faae.ResetAllRecords(); faae.ParseThresholdCfg(thresholds); + faae.SetCombineSwitch(); StartEvictThread(); printf("Current test single-thread is [%lu]\n", std::hash{}(std::this_thread::get_id())); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys1 = {11, 11, 33, 44, 11, 55, 88, 55} cnt1 = 1 2 1 3 1 1 4 1 */ keys_t expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}; InputArgs args1 = {keys1, cnt1, expectRet1, initHistory, {}}; // the first record of each table must be appended using initHistory - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args1); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args1); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys2 = {11, 12, 33, 21, 11, 12} cnt2 = 1 2 1 1 2 3 */ keys_t expectRet2 = {11, 12, 33, -1, 11, 12}; InputArgs args2 = {keys2, cnt2, expectRet2, args1.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args2); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args2);
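// A minimal sketch of the admission rule that the expectRet vectors in these test
// cases encode (countThreshold = 2 for tableAAA, 3 for tableBBB): counts are first
// merged per key across the whole batch (cf. PreProcessKeys), then accumulated into
// the per-table history, and keys whose accumulated count is still below the table's
// countThreshold are rewritten to -1. admitBatch is a hypothetical stand-in, not the
// real FeatureAdmitAndEvict API, and the timestamp/eviction dimension is ignored here.
#include <cstdint>
#include <unordered_map>
#include <vector>

void admitBatch(std::vector<int64_t>& keys, const std::vector<uint32_t>& cnts,
                std::unordered_map<int64_t, uint32_t>& history, uint32_t countThreshold)
{
    std::unordered_map<int64_t, uint32_t> merged;  // per-batch merge of key counts
    for (size_t i = 0; i < keys.size(); ++i) {
        merged[keys[i]] += cnts[i];
    }
    for (const auto& kv : merged) {                // accumulate once per unique key
        history[kv.first] += kv.second;
    }
    for (auto& key : keys) {                       // reject keys below the threshold
        if (history[key] < countThreshold) {
            key = -1;                              // same convention as expectRet
        }
    }
}
// With keys1/cnt1 above and countThreshold = 2, only key 33 (merged count 1) becomes
// -1, matching expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}.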
std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys3 = {123, 121, 121, 212, 211} cnt3 = 1 2 1 1 2 */ keys_t expectRet3 = {-1, 121, 121, -1, -1}; InputArgs args3 = {keys3, cnt3, expectRet3, initHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args3); + FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args3); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys4 = {11, 11, 33, 44, 55, 88, 55} cnt4 = 1 2 3 2 1 2 1 */ keys_t expectRet4 = {11, 11, 33, 44, 55, 88, 55}; InputArgs args4 = {keys4, cnt4, expectRet4, args2.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args4); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args4); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys5 = {125, 121, 122, 212, 211} cnt5 = 1 2 1 3 1 */ keys_t expectRet5 = {-1, 121, -1, 212, 211}; InputArgs args5 = {keys5, cnt5, expectRet5, args3.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args5); + FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args5); WaitEvictThread(); LOG(INFO) << "TestCase1: single thread test over ..."; @@ -331,7 +332,7 @@ protected: vector tmpCnt = {1, 2, 1, 3, 1, 1, 4}; std::unique_ptr batch = make_unique(); - batch->name = thresholds[0].tensorName; + batch->name = thresholds[0].tableName; batch->timestamp = time(nullptr); // call the interface to verify; expect an error @@ -375,6 +376,7 @@ protected: { faae.ResetAllRecords(); faae.ParseThresholdCfg(thresholds); + faae.SetCombineSwitch(); StartEvictThread(); std::thread thrs[PerfConfig::keyProcessThreadNum]; @@ -396,9 +398,9 @@ protected: { /* if there were no eviction - tensorAAA data would be {11, 12, 21, 33, 44, 55, 88} + tableAAA data would be {11, 12, 21, 33, 44, 55, 88} 10 5 1 5 5 4 6 - tensorBBB data would be {121, 122, 123, 125, 211, 212}; + tableBBB data would be {121, 122, 123, 125, 211, 212}; 5 1 1 1 3 4 */ keys_t expectKeys1 = {11, 33, 44, 55, 88}; // 12 and 21 were evicted vector expectCnt1 = {10, 5, 5, 4, 6}; keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123 was evicted vector expectCnt2 = {5, 1, 1, 3, 4}; std::lock_guard lock(faae.m_syncMutexs); // contends with the evict thread for this resource - CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, PerfConfig::keyProcessThreadNum); - CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tableName, PerfConfig::keyProcessThreadNum); } WaitEvictThread(); @@ -421,8 +423,8 @@ protected: faae.ParseThresholdCfg(thresholds); std::unique_ptr batch = make_unique(); - // test point: table tensorDDD has no threshold configured, so it is not supported - batch->name = std::string("tensorDDD"); + // test point: table tableDDD has no threshold configured, so it is not supported + batch->name = std::string("tableDDD"); batch->timestamp = time(nullptr); // call the interface to verify; expect not supported @@ -443,11 +445,21 @@ protected: vector cnt4 = {1, 2, 3, 2, 1, 2, 1}; keys_t keys5 = {125, 121, 122, 212, 211}; vector cnt5 = {1, 2, 1, 3, 1}; - std::vector thresholds = {{"tensorAAA", 2, 5}, {"tensorBBB", 3, 7}, {"tensorCCC", 5, 9}}; + std::vector thresholds = {{"tableAAA", 2, 5, 1}, {"tableBBB", 3, 7, 1}, {"tableCCC", 5, 9, 1}}; }; +void SetEnv() +{ + const char* name = "FAAE_MODE"; + const char* mode = "1"; + int overwrite = 1; + + ASSERT_EQ(setenv(name, mode, overwrite), 0); +} + TEST_F(FeatureAdmitAndEvictTest,
TestAdmitAndEvict1) { + SetEnv(); TestCase1(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict2) @@ -464,6 +476,7 @@ TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict4) } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict5) { + SetEnv(); TestCase5(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict6) -- Gitee From b682150c94fa313f1ec34939f2cfbdf3e433e636 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 28 Jul 2023 16:02:03 +0800 Subject: [PATCH 224/551] Match-id-3ce2f03c0d183757f17f9633192d313238e1e534 --- src/core/key_process/key_process.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 0ca9d507..b0f3edbf 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -768,9 +768,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } ssTrace << '|'; } - VLOG(GLOG_TRACE) << StringFormat( - "dump splitKeys\n%s", ssTrace.str().c_str() - ); + VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); } return { splitKeys, restore }; } @@ -821,7 +819,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } ssTrace << '|'; } - VLOG(GLOG_TRACE) << StringFormat("dump splitKeys\n%s", ssTrace.str().c_str()); + VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); } return { splitKeys, restore, keyCount }; -- Gitee From ef7e3ed513db85bd31de6e61b6d49867b5c9a4d7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 28 Jul 2023 16:43:24 +0800 Subject: [PATCH 225/551] Match-id-e524a9485dc8e27af31cf5859c0add76b6e48bc8 --- mx_rec/core/asc/manager.py | 2 +- src/core/emb_hashmap/emb_hashmap.cpp | 4 ++-- src/core/key_process/key_process.cpp | 8 ++++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 9fb1976f..3892837e 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -162,7 +162,7 @@ def generate_threshold_list(): threshold_list = [] for _, feature_spec in export_feature_spec().items(): - coef = 1 if feature_spec.faae_coefficent is None else feature_spec.faae_coefficent + coef = 1 if feature_spec.faae_coefficient is None else feature_spec.faae_coefficient if feature_spec.eviction_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index d1918cde..5e3b3dbb 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -94,8 +94,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t } embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); - LOG(INFO) << StringFormat("current dev emb usage:%d/[%d+%d]", embHashMap.maxOffset, embHashMap.devVocabSize, - embHashMap.hostVocabSize); + LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), embHashMap.maxOffset, + embHashMap.devVocabSize, embHashMap.hostVocabSize); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 0ca9d507..f8b37f76 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1042,7 +1042,8 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, 
embInfos[embName].devVocabSize); throw std::runtime_error("dev cache overflow!"); } - VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current hbm emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, + embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1076,7 +1077,8 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s key = 0; } } - VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current expansion emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, + embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1135,6 +1137,7 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in return move(t); } +// DDR keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { TimeCost tc = TimeCost(); @@ -1161,6 +1164,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } } +// HBM unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { TimeCost tc = TimeCost(); -- Gitee From 4c409884391794a120b5f02d850544fdf8bbe5c0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 29 Jul 2023 17:58:36 +0800 Subject: [PATCH 226/551] Match-id-93fa3986f9cea862ee84f6d044c0ebbe9126e44c --- mx_rec/core/asc/feature_spec.py | 7 +++++++ mx_rec/core/embedding.py | 19 ++++++++++++++++++- src/core/checkpoint/checkpoint.cpp | 24 +++++++----------------- src/core/checkpoint/checkpoint.h | 2 +- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 8351a532..50eede22 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
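// The Key2Offset hunks in patch 225 above guard the invariant their logs report:
// each previously unseen key consumes one offset slot in the device cache, and the
// running maxOffset must never exceed devVocabSize. A minimal sketch of that pattern,
// with hypothetical names (DevCache/Assign are illustrative, not the KeyProcess API):
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

struct DevCache {
    std::unordered_map<int64_t, int> offsets;  // key -> offset into device memory
    int maxOffset { 0 };                       // next free slot, logged as "usage"
    int devVocabSize { 0 };                    // capacity of the device-side table

    int Assign(int64_t key)
    {
        auto it = offsets.find(key);
        if (it != offsets.end()) {
            return it->second;                 // key already resident, reuse its slot
        }
        if (maxOffset >= devVocabSize) {
            // mirrors the "dev cache overflow" LOG(ERROR) + throw shown above
            throw std::runtime_error("dev cache overflow!");
        }
        return offsets[key] = maxOffset++;
    }
};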
import logging +import re from typing import Union from functools import reduce @@ -40,6 +41,8 @@ class FeatureSpec: self.split = None # usually split == batch_size * feature_count self.initialized = False self._pipeline_mode = set() + + self.fix_invalid_table_name() self.check_params() @property @@ -130,6 +133,10 @@ class FeatureSpec: if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") + def fix_invalid_table_name(self): + if not re.match("^[0-9A-Za-z_]+$", self._table_name): + self._table_name = re.sub(r'\W+', '_', self._table_name) + def set_feat_pos(self, is_training): if is_training: self.feat_pos_train = FeatureSpec.instance_count_train diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index ec9bbde5..dc599fa8 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -4,6 +4,7 @@ import logging import math +import re from collections import defaultdict from typing import Optional @@ -69,6 +70,7 @@ def create_table(**kwargs): all2all_gradients_op = kwargs.get("all2all_gradients_op", All2allGradientsOp.SUM_GRADIENTS) apply_gradients_strategy = kwargs.get("apply_gradients_strategy", ApplyGradientsStrategy.DIRECT_APPLY) + name = fix_invalid_table_name(name) check_create_table_params(key_dtype, dim, name, emb_initializer) config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, @@ -907,7 +909,7 @@ def check_create_table_params(key_dtype, dim, name, emb_initializer): dim_validator.check_isinstance() dim_validator.check() # check name - name_validator = StringValidator(name) + name_validator = StringValidator(value=name, max_len=255) name_validator.check_string_length() name_validator.check_whitelist() name_validator.check() @@ -915,3 +917,18 @@ def check_create_table_params(key_dtype, dim, name, emb_initializer): emb_initializer_validator = ClassValidator(value=emb_initializer, classes=(InitializerV1, InitializerV2)) emb_initializer_validator.check_isinstance() emb_initializer_validator.check() + + +def fix_invalid_table_name(name): + """ + Check whether the table name contains special characters; if so, replace them with underscores (_). + :param name: table name + :return : the fixed table name + """ + if re.match("^[0-9A-Za-z_]+$", name): + return name + fix_name = re.sub(r'\W+', '_', name) + logging.warning(f"The table name {name} contains invalid characters." + f"The system automatically replaces invalid characters with underscores (_). 
" + f"The table name was changed to {fix_name}") + return fix_name diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index df4b2115..dfedfe5c 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -338,7 +338,7 @@ void Checkpoint::LoadProcess(CkptData& ckptData) vector saveDataTypes { dataHandler->GetDataTypes() }; GetUpperLayerLoadDir(dirNames); - embNames = GetTableLayerLoadDir(); + embNames = GetEmbedTableNames(); LoadDataset(embNames, saveDataTypes, dataHandler, ckptData); @@ -355,22 +355,16 @@ void Checkpoint::GetUpperLayerLoadDir(const vector& dirNames) } } -vector Checkpoint::GetTableLayerLoadDir() +vector Checkpoint::GetEmbedTableNames() { - vector loadTableDir; - auto dir { opendir(innerDirPath.c_str()) }; - struct dirent* en; - if (dir != nullptr) { - while ((en = readdir(dir)) != nullptr) { - if (strcmp(en->d_name, currDir.c_str()) != 0 && - strcmp(en->d_name, prevDir.c_str()) != 0) { - loadTableDir.emplace_back(en->d_name); - } + vector loadTableNames; + for (const auto& embInfo : mgmtEmbInfo) { + if (embInfo.isSave == true) { + loadTableNames.push_back(embInfo.name); } - closedir(dir); } - return loadTableDir; + return loadTableNames; } void Checkpoint::LoadDataset(const vector& embNames, @@ -379,10 +373,6 @@ void Checkpoint::LoadDataset(const vector& embNames, CkptData& ckptData) { for (const auto& embName : embNames) { - if (!CheckEmbNames(embName)) { - continue; - } - auto dataDir { innerDirPath + dirSeparator + embName }; for (const auto& saveDataType : saveDataTypes) { auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) }; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index c9843ecf..4c3abc53 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -88,7 +88,7 @@ namespace MxRec { void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); - vector GetTableLayerLoadDir(); + vector GetEmbedTableNames(); void LoadDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler, CkptData& ckptData); void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); -- Gitee From 0b78320d5af28e921b6f3992d4a554031c49375a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 2 Aug 2023 22:43:53 +0800 Subject: [PATCH 227/551] Match-id-c4f69ec36203822f26562fef4384775d69792ca3 --- mx_rec/core/asc/manager.py | 2 +- src/core/emb_hashmap/emb_hashmap.cpp | 4 ++-- src/core/key_process/key_process.cpp | 8 ++------ 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 3892837e..9fb1976f 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -162,7 +162,7 @@ def generate_threshold_list(): threshold_list = [] for _, feature_spec in export_feature_spec().items(): - coef = 1 if feature_spec.faae_coefficient is None else feature_spec.faae_coefficient + coef = 1 if feature_spec.faae_coefficent is None else feature_spec.faae_coefficent if feature_spec.eviction_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 5e3b3dbb..d1918cde 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -94,8 +94,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, 
size_t } embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); - LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), embHashMap.maxOffset, - embHashMap.devVocabSize, embHashMap.hostVocabSize); + LOG(INFO) << StringFormat("current dev emb usage:%d/[%d+%d]", embHashMap.maxOffset, embHashMap.devVocabSize, + embHashMap.hostVocabSize); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index f24f58ff..b0f3edbf 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1040,8 +1040,7 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, embInfos[embName].devVocabSize); throw std::runtime_error("dev cache overflow!"); } - VLOG(GLOG_DEBUG) << StringFormat("current hbm emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, - embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1075,8 +1074,7 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s key = 0; } } - VLOG(GLOG_DEBUG) << StringFormat("current expansion emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, - embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1135,7 +1133,6 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in return move(t); } -// DDR keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { TimeCost tc = TimeCost(); @@ -1162,7 +1159,6 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } } -// HBM unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { TimeCost tc = TimeCost(); -- Gitee From 9bf5505b9d68ced4fbaec2946f8fb88ef126e80e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 2 Aug 2023 22:53:14 +0800 Subject: [PATCH 228/551] Match-id-298e1f05d9069b3083513be66112afeda47e55e2 --- example/little_demo/main.py | 18 +- example/little_demo/run.sh | 1 - mx_rec/core/asc/feature_spec.py | 15 +- mx_rec/core/asc/manager.py | 7 +- mx_rec/optimizers/base.py | 4 +- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/checkpoint/checkpoint.h | 2 +- .../ckpt_data_handler/ckpt_data_handler.h | 2 +- .../feat_admit_n_evict_ckpt.cpp | 69 +++---- .../feat_admit_n_evict_ckpt.h | 12 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +- .../key_process/feature_admit_and_evict.cpp | 109 ++++------ .../key_process/feature_admit_and_evict.h | 18 +- src/core/utils/common.cpp | 16 -- src/core/utils/common.h | 17 +- src/pybind/module_main.cpp | 7 +- src/tests/checkpoint/checkpoint_test.cpp | 63 ++---- .../ckpt_data_handler_test.cpp | 195 ++++++------------ src/tests/emb_mgmt/emb_mgmt_test.cpp | 6 +- .../feature_admit_and_evict_test.cpp | 69 +++---- 20 files changed, 217 insertions(+), 419 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index b8abf09e..d2840d5f 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -93,21 +93,17 @@ def 
create_feature_spec_list(use_timestamp=False): eviction_threshold = cfg.eviction_threshold if use_timestamp else None feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold, - faae_coefficient=1), + eviction_threshold=eviction_threshold), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold, - faae_coefficient=4)] + eviction_threshold=eviction_threshold)] if use_multi_lookup: feature_spec_list.extend([FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold, - faae_coefficient=1), + eviction_threshold=eviction_threshold), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold, - faae_coefficient=4)]) + eviction_threshold=eviction_threshold)]) if use_timestamp: feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) return feature_spec_list @@ -156,10 +152,8 @@ if __name__ == "__main__": # access_threshold unit counts; eviction_threshold unit seconds ACCESS_AND_EVICT = None if USE_TIMESTAMP: - config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold, - faae_coefficient=1) - config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold, - faae_coefficient=4) + config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table) train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index ba99348f..2cd84a80 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -84,7 +84,6 @@ export USE_MULTI_LOOKUP=1 # 0: one lookup per table; 1: multiple lookups per table export USE_MODIFY_GRAPH=0 # 0: feature spec mode; 1: automatic graph-modification mode export USE_TIMESTAMP=0 # 0: disable feature admission and eviction; 1: enable feature admission and eviction export UpdateEmb_V2=0 # 0: synchronous UpdateEmb; 1: asynchronous UpdateEmb_V2 -export FAAE_MODE=0 # 0: combine history when faae; 1: separate history when faae ################# performance tuning #################### export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 export FAST_UNIQUE=0 #if use fast unique diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 50eede22..dd1bff70 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -31,7 +31,6 @@ class FeatureSpec: self._feat_cnt = kwargs.get("feat_count") self._access_threshold = kwargs.get("access_threshold") self._eviction_threshold = kwargs.get("eviction_threshold") - self._faae_coefficient = kwargs.get("faae_coefficient", 1) self._is_timestamp = kwargs.get("is_timestamp") self.feat_pos_train = None self.feat_pos_eval = None @@ -57,10 +56,6 @@ class FeatureSpec: def eviction_threshold(self): return self._eviction_threshold - @property - def faae_coefficient(self): - return self._faae_coefficient - @property def index_key(self): return self._index_key @@ -125,11 +120,6 @@ class FeatureSpec: if self._eviction_threshold > MAX_INT32: raise 
ValueError(f"Eviction_threshold is too big that exceed int32.") - if self._faae_coefficient is not None: - check_natural_number(self._faae_coefficient, "eviction_threshold") - if self._faae_coefficient > MAX_INT32: - raise ValueError(f"Eviction_threshold is too big that exceed int32.") - if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") @@ -207,13 +197,10 @@ class FeatureSpec: def get_feature_spec(table_name, access_and_evict_config): access_threshold = None eviction_threshold = None - faae_coefficient = None if access_and_evict_config: access_threshold = access_and_evict_config.get("access_threshold") eviction_threshold = access_and_evict_config.get("eviction_threshold") - faae_coefficient = access_and_evict_config.get("faae_coefficient", 1) - return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold, - faae_coefficient=faae_coefficient) + return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold) def set_temporary_feature_spec_attribute(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 9fb1976f..dfc8eb7a 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -162,19 +162,16 @@ def generate_threshold_list(): threshold_list = [] for _, feature_spec in export_feature_spec().items(): - coef = 1 if feature_spec.faae_coefficent is None else feature_spec.faae_coefficent if feature_spec.eviction_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, - feature_spec.eviction_threshold, - coef) + feature_spec.eviction_threshold) threshold_list.append(threshold) continue if feature_spec.access_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, - -1, - coef) + -1) threshold_list.append(threshold) return threshold_list diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 632e8ba3..9ad62ec6 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -41,7 +41,7 @@ class CustomizedOptimizer: self.base_name = name -def custom_update_op(self, opt, grad): +def my_update_op(self, opt, grad): if isinstance(grad, ops.Tensor): update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access return update_op @@ -50,5 +50,5 @@ def custom_update_op(self, opt, grad): def patch_for_optimizer(): - _TensorProcessor.update_op = custom_update_op + _TensorProcessor.update_op = my_update_op logging.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index dfedfe5c..3c5e92df 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -69,7 +69,7 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) if (!ckptData.keyOffsetMap.empty()) { dataHandlers.push_back(make_unique()); } - if (!ckptData.table2Thresh.empty() && !ckptData.histRec.timestamps.empty() && + if (!ckptData.tens2Thresh.empty() && !ckptData.histRec.timestamps.empty() && !ckptData.histRec.historyRecords.empty()) { dataHandlers.push_back(make_unique()); } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 4c3abc53..4baa6f0c 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -42,7 +42,7 @@ namespace MxRec { CkptDataType::EMB_INFO, 
CkptDataType::EMB_CURR_STAT, CkptDataType::NDDR_OFFSET, - CkptDataType::TABLE_2_THRESH + CkptDataType::TENSOR_2_THRESH }; const set int64TransSet{ CkptDataType::EMB_HASHMAP, diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index aea1d2b7..ecc7907c 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -47,7 +47,7 @@ namespace MxRec { "embedding_current_status", "max_offset", "key_offset_map", - "table_2_threshold", + "tensor_2_threshold", "history_record" }; const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8 }; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index a2d6f96e..1e4e9d69 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -14,18 +14,18 @@ using namespace MxRec; void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) { ClearData(); - if (processData.table2Thresh.empty() || processData.histRec.timestamps.empty() || + if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { LOG(ERROR) << "Missing Feature Admit and Evict data"; throw std::runtime_error("Missing Feature Admit and Evict data"); } - saveTable2Thresh = std::move(processData.table2Thresh); + saveTens2Thresh = std::move(processData.tens2Thresh); saveHistRec = std::move(processData.histRec); } void FeatAdmitNEvictCkpt::GetProcessData(CkptData& processData) { - processData.table2Thresh = std::move(loadTable2Thresh); + processData.tens2Thresh = std::move(loadTens2Thresh); processData.histRec = std::move(loadHistRec); ClearData(); } @@ -43,7 +43,7 @@ vector FeatAdmitNEvictCkpt::GetDirNames() vector FeatAdmitNEvictCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveTable2Thresh) { + for (const auto& item : saveTens2Thresh) { embNames.push_back(item.first); } return embNames; @@ -51,8 +51,8 @@ vector FeatAdmitNEvictCkpt::GetEmbNames() CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::TABLE_2_THRESH, - [=] { SetTable2ThreshTrans(embName); } }, + map> dataTransMap { { CkptDataType::TENSOR_2_THRESH, + [=] { SetTens2ThreshTrans(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRecTrans(embName); } } }; CleanTransfer(); @@ -62,8 +62,8 @@ CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embN void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::TABLE_2_THRESH, - [=] { SetTable2Thresh(embName); } }, + map> dataLoadMap { { CkptDataType::TENSOR_2_THRESH, + [=] { SetTens2Thresh(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRec(embName); } } }; CleanTransfer(); @@ -73,31 +73,27 @@ void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, Ckpt void FeatAdmitNEvictCkpt::ClearData() { - saveTable2Thresh.clear(); - loadTable2Thresh.clear(); + saveTens2Thresh.clear(); + loadTens2Thresh.clear(); saveHistRec.timestamps.clear(); saveHistRec.historyRecords.clear(); loadHistRec.timestamps.clear(); loadHistRec.historyRecords.clear(); } -void FeatAdmitNEvictCkpt::SetTable2ThreshTrans(string embName) +void FeatAdmitNEvictCkpt::SetTens2ThreshTrans(string embName) { - 
auto table2ThreshSize = GetTable2ThreshSize(); + auto tens2ThreshSize = GetTens2ThreshSize(); auto& transArr = transferData.int32Arr; - const auto& table2Thresh = saveTable2Thresh.at(embName); + const auto& tens2Thresh = saveTens2Thresh.at(embName); - transArr.reserve(table2ThreshSize); - transArr.push_back(table2Thresh.countThreshold); - transArr.push_back(table2Thresh.timeThreshold); + transArr.reserve(tens2ThreshSize); + transArr.push_back(tens2Thresh.countThreshold); + transArr.push_back(tens2Thresh.timeThreshold); } -// save void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) { - if (GetCombineSwitch()) { - embName = COMBINE_HISTORY_NAME; - } auto histRecSize = GetHistRecSize(embName); auto& transArr = transferData.int64Arr; const auto& timeStamp = saveHistRec.timestamps.at(embName); @@ -113,22 +109,18 @@ void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) } } -void FeatAdmitNEvictCkpt::SetTable2Thresh(string embName) +void FeatAdmitNEvictCkpt::SetTens2Thresh(string embName) { const auto& transArr = transferData.int32Arr; - auto& tens2Thresh = loadTable2Thresh[embName]; + auto& tens2Thresh = loadTens2Thresh[embName]; - tens2Thresh.tableName = embName; + tens2Thresh.tensorName = embName; tens2Thresh.countThreshold = transArr[countThresholdIdx]; tens2Thresh.timeThreshold = transArr[timeThresholdIdx]; } -// load void FeatAdmitNEvictCkpt::SetHistRec(string embName) { - if (GetCombineSwitch()) { - embName = COMBINE_HISTORY_NAME; - } const auto& transArr = transferData.int64Arr; const auto& attribute = transferData.attribute; auto& timestamp = loadHistRec.timestamps[embName]; @@ -137,26 +129,17 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) timestamp = transArr.front(); size_t featItemInfoTotalSize = attribute.front() * static_cast(featItemInfoSaveNum); - VLOG(GLOG_DEBUG) << StringFormat("====Start SetHistRec, name: %s, featItemInfoTotalSize: %ld", embName.c_str(), - featItemInfoTotalSize); - - size_t process = 0; - size_t printPerStep = ((featItemInfoTotalSize / 100) > 0 ? 
(featItemInfoTotalSize / 100) : 1); for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { - process = i % printPerStep; - if (process == 1) { - VLOG(GLOG_DEBUG) << StringFormat("====in SetHistRec, process : %f", i/featItemInfoTotalSize); - } - auto featureId = transArr[i + featureIdIdxOffset]; - auto count = transArr[i + countIdxOffset]; - auto lastTime = transArr[i + lastTimeIdxOffset]; - - histRecs.emplace(featureId, FeatureItemInfo(static_cast(count), lastTime)); + const auto& featureId = transArr[i + featureIdIdxOffset]; + const auto& count = transArr[i + countIdxOffset]; + const auto& lastTime = transArr[i + lastTimeIdxOffset]; + + histRecs[featureId].count = static_cast(count); + histRecs[featureId].lastTime = lastTime; } - VLOG(GLOG_DEBUG) << StringFormat("====End SetHistRec, name: %s", embName.c_str()); } -int FeatAdmitNEvictCkpt::GetTable2ThreshSize() +int FeatAdmitNEvictCkpt::GetTens2ThreshSize() { auto& attribute = transferData.attribute; auto& attribSize = transferData.attributeSize; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index 268120c3..2b12d315 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -31,7 +31,7 @@ namespace MxRec { private: const vector fileDirNames { "HashTable", "DDR" }; - const vector saveDataTypes { CkptDataType::TABLE_2_THRESH, CkptDataType::HIST_REC }; + const vector saveDataTypes { CkptDataType::TENSOR_2_THRESH, CkptDataType::HIST_REC }; const int featItemInfoSaveNum { 3 }; const int threshValSaveNum { 2 }; @@ -45,21 +45,21 @@ namespace MxRec { const int countIdxOffset { 1 }; const int lastTimeIdxOffset { 2 }; - table_2_thresh_mem_t saveTable2Thresh; - table_2_thresh_mem_t loadTable2Thresh; + tensor_2_thresh_mem_t saveTens2Thresh; + tensor_2_thresh_mem_t loadTens2Thresh; AdmitAndEvictData saveHistRec; AdmitAndEvictData loadHistRec; void ClearData(); - void SetTable2ThreshTrans(string embName); + void SetTens2ThreshTrans(string embName); void SetHistRecTrans(string embName); - void SetTable2Thresh(string embName); + void SetTens2Thresh(string embName); void SetHistRec(string embName); - int GetTable2ThreshSize(); + int GetTens2ThreshSize(); size_t GetHistRecSize(string embName); }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index cd400a52..0f8b72f0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -150,7 +150,7 @@ bool HybridMgmt::Save(const string savePath) auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); - saveData.table2Thresh = featAdmitNEvict.GetTableThresholds(); + saveData.tens2Thresh = featAdmitNEvict.GetTensorThresholds(); saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; } @@ -201,7 +201,7 @@ bool HybridMgmt::Load(const string& loadPath) } if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); - featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); + featAdmitNEvict.LoadTensorThresholds(loadData.tens2Thresh); 
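// The SetHistRecTrans/SetHistRec pair above implies a flat int64 layout for the
// history records: the table's timestamp first, then one (featureId, count, lastTime)
// triple per key, matching featItemInfoSaveNum = 3 and the *IdxOffset constants in the
// header. A minimal sketch of that encode/decode round trip under that layout
// assumption (Rec, Encode and Decode are illustrative names, not the checkpoint API):
#include <cstdint>
#include <ctime>
#include <map>
#include <vector>

struct Rec { uint32_t count; time_t lastTime; };

std::vector<int64_t> Encode(time_t ts, const std::map<int64_t, Rec>& recs)
{
    std::vector<int64_t> arr;
    arr.reserve(1 + 3 * recs.size());
    arr.push_back(ts);                         // record data starts at featItemInfoOffset = 1
    for (const auto& kv : recs) {
        arr.push_back(kv.first);               // featureIdIdxOffset = 0
        arr.push_back(kv.second.count);        // countIdxOffset = 1
        arr.push_back(kv.second.lastTime);     // lastTimeIdxOffset = 2
    }
    return arr;
}

// numRecs plays the role of the attribute header (attribute.front() in SetHistRec).
std::map<int64_t, Rec> Decode(const std::vector<int64_t>& arr, size_t numRecs, time_t& ts)
{
    ts = static_cast<time_t>(arr.front());
    std::map<int64_t, Rec> recs;
    for (size_t i = 0; i < numRecs; ++i) {
        size_t base = 1 + 3 * i;
        recs[arr[base]] = Rec { static_cast<uint32_t>(arr[base + 1]),
                                static_cast<time_t>(arr[base + 2]) };
    }
    return recs;
}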
featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 80db3003..29b2da1b 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -31,7 +31,6 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu LOG(ERROR) << "Config is error, feature admin-and-evict function is not available ...\n"; return false; } - SetCombineSwitch(); return true; } @@ -45,27 +44,24 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } - std::string tableName = batch->name; - if (m_isCombine) { - tableName = COMBINE_HISTORY_NAME; - } + // if the current tensorName is outside the admission scope, skip the feature-admission logic + std::string tensorName = batch->name; absl::flat_hash_map mergeKeys; mergeKeys.reserve(splitKey.size()); PreProcessKeys(splitKey, keyCount, mergeKeys); std::lock_guard lock(m_syncMutexs); - // if the current tableName is outside the admission scope, skip the feature-admission logic - auto iter = m_recordsData.historyRecords.find(tableName); - if (iter == m_recordsData.historyRecords.end()) { // initialize data when this tableName has not been seen before + auto iter = m_recordsData.historyRecords.find(tensorName); + if (iter == m_recordsData.historyRecords.end()) { // initialize data when this tensorName has not been seen before absl::flat_hash_map records(m_recordsInitSize); - m_recordsData.historyRecords[tableName] = records; + m_recordsData.historyRecords[tensorName] = records; } VLOG(GLOG_DEBUG) << StringFormat( - "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tableName.c_str(), - m_recordsData.historyRecords[tableName].size()); + "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tensorName.c_str(), + m_recordsData.historyRecords[tensorName].size()); - if (batch->timestamp > m_recordsData.timestamps[tableName]) { - m_recordsData.timestamps[tableName] = batch->timestamp; + if (batch->timestamp > m_recordsData.timestamps[tensorName]) { + m_recordsData.timestamps[tensorName] = batch->timestamp; } absl::flat_hash_map visitedRecords; for (auto& key : splitKey) { @@ -77,7 +73,7 @@ auto it = visitedRecords.find(key); if (it == visitedRecords.end()) { visitedRecords[key] = true; - if (FeatureAdmitHelper(channel, batch->name, key, mergeKeys[key]) == + if (FeatureAdmitHelper(channel, tensorName, key, mergeKeys[key]) == FeatureAdmitType::FEATURE_ADMIT_FAILED) { visitedRecords[key] = false; key = -1; // feature ID that was filtered out @@ -91,38 +87,32 @@ } if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat( - "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tableName.c_str(), channel, + "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tensorName.c_str(), channel, VectorToString(splitKey).c_str()); } return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } -FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tableNameOrigin, +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tensorName, const int64_t featureId, const uint32_t featureCnt) { // feature-admission logic uint32_t currKeyCount = 0; - std::string tableName = tableNameOrigin; - if (m_isCombine) { - tableName = COMBINE_HISTORY_NAME; - } - - absl::flat_hash_map& 
historyRecordInfos = m_recordsData.historyRecords[tensorName]; auto innerIt = historyRecordInfos.find(featureId); if (channel == TRAIN_CHANNEL_ID) { if (innerIt == historyRecordInfos.end()) { // maintain m_historyRecords - FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tableName]); - info.count *= m_table2Threshold[tableNameOrigin].faaeCoefficient; + FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tensorName]); historyRecordInfos[featureId] = info; - currKeyCount = info.count; + currKeyCount = featureCnt; } else { // maintain m_historyRecords FeatureItemInfo &info = historyRecordInfos[featureId]; - info.count += m_table2Threshold[tableNameOrigin].faaeCoefficient * featureCnt; - info.lastTime = m_recordsData.timestamps[tableName]; + info.count += featureCnt; + info.lastTime = m_recordsData.timestamps[tensorName]; currKeyCount = info.count; } } else if (channel == EVAL_CHANNEL_ID) { // eval @@ -132,7 +122,7 @@ } // admission-condition check - if (currKeyCount >= static_cast(m_table2Threshold[tableNameOrigin].countThreshold)) { + if (currKeyCount >= static_cast(m_tensor2Threshold[tensorName].countThreshold)) { return FeatureAdmitType::FEATURE_ADMIT_OK; } @@ -142,8 +132,8 @@ // feature-eviction interface void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) { - std::vector tableNames = GetAllNeedEvictTableNames(); - if (tableNames.empty()) { + std::vector tensorNames = GetAllNeedEvictTensorNames(); + if (tensorNames.empty()) { LOG(INFO) << "EmbNames is empty, no evict function ..."; return ; } @@ -153,9 +143,9 @@ } std::lock_guard lock(m_syncMutexs); // evict and remove from m_historyRecords - size_t tableCnt = tableNames.size(); - for (size_t i = 0; i < tableCnt; ++i) { - FeatureEvictHelper(tableNames[i], evictKeyMap[tableNames[i]]); + size_t tensorCnt = tensorNames.size(); + for (size_t i = 0; i < tensorCnt; ++i) { + FeatureEvictHelper(tensorNames[i], evictKeyMap[tensorNames[i]]); } } @@ -163,7 +153,7 @@ { // evict and remove from m_historyRecords time_t currTime = m_recordsData.timestamps[embName]; - // fetch the featureIds to evict from m_table2SortedLastTime + // fetch the featureIds to evict from m_tensor2SortedLastTime auto cmp = [](const auto& a, const auto& b) { return a.second.lastTime > b.second.lastTime; }; std::priority_queue, std::vector>, decltype(cmp)> lastTimePriority(cmp); @@ -171,7 +161,7 @@ lastTimePriority.emplace(item); } while (!lastTimePriority.empty()) { - if (currTime - lastTimePriority.top().second.lastTime < m_table2Threshold[embName].timeThreshold) { + if (currTime - lastTimePriority.top().second.lastTime < m_tensor2Threshold[embName].timeThreshold) { break; } evictKey.emplace_back(lastTimePriority.top().first); @@ -180,11 +170,11 @@ if (evictKey.size() == 0) { LOG(INFO) << StringFormat( - "table-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); + "tensor-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); return; } LOG(INFO) << StringFormat( - "table-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, + "tensor-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, 
evictKey.size()); // actually remove the evicted keys from m_historyRecords absl::flat_hash_map& historyRecords = m_recordsData.historyRecords[embName]; @@ -206,12 +196,6 @@ void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) } m_isEnableFunction = isEnableEvict; } - -void FeatureAdmitAndEvict::SetCombineSwitch() -{ - m_isCombine = GetCombineSwitch(); -} - bool FeatureAdmitAndEvict::GetFunctionSwitch() const { return m_isEnableFunction; @@ -238,10 +222,10 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t const std::vector& embNames, bool isTimestamp) { for (size_t i = 0; i < thresholds.size(); ++i) { - auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tableName); + auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tensorName); if (it == embNames.end()) { // a configured table that does not exist in the current model is also an error LOG(ERROR) << StringFormat( - "embName[%s] is not exist at current model ...", thresholds[i].tableName.c_str()); + "embName[%s] is not exist at current model ...", thresholds[i].tensorName.c_str()); return false; } else { // admission & eviction are both enabled, but no timestamp was passed in @@ -258,10 +242,10 @@ return true; } -auto FeatureAdmitAndEvict::GetTableThresholds() -> table_2_thresh_mem_t +auto FeatureAdmitAndEvict::GetTensorThresholds() -> tensor_2_thresh_mem_t { std::lock_guard lock(m_syncMutexs); - return m_table2Threshold; + return m_tensor2Threshold; } auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& @@ -270,10 +254,10 @@ return m_recordsData; } -void FeatureAdmitAndEvict::LoadTableThresholds(table_2_thresh_mem_t& loadData) +void FeatureAdmitAndEvict::LoadTensorThresholds(tensor_2_thresh_mem_t& loadData) { std::lock_guard lock(m_syncMutexs); - m_table2Threshold = std::move(loadData); + m_tensor2Threshold = std::move(loadData); } void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) @@ -282,7 +266,7 @@ m_recordsData = std::move(loadData); } -// parse m_table2Threshold +// parse m_tensor2Threshold bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) { if (thresholdValues.empty()) { @@ -293,26 +277,23 @@ m_cfgThresholds = thresholdValues; for (const auto& value : thresholdValues) { LOG(INFO) << StringFormat( - "embName[%s], count[%d], time[%d], coefficient[%d] ...", - value.tableName.c_str(), value.countThreshold, value.timeThreshold, value.faaeCoefficient); + "embName[%s], count[%d], time[%d] ...", + value.tensorName.c_str(), value.countThreshold, value.timeThreshold); - auto it = m_table2Threshold.find(value.tableName); - if (it != m_table2Threshold.end()) { + auto it = m_tensor2Threshold.find(value.tensorName); + if (it != m_tensor2Threshold.end()) { // when train and eval are both enabled, a table may be configured twice - LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tableName.c_str()); + LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tensorName.c_str()); return true; } - m_table2Threshold[value.tableName] = value; + m_tensor2Threshold[value.tensorName] = value; - if (value.faaeCoefficient < 1) { - LOG(ERROR) << StringFormat("[%s] config error, coefficient smaller than 1 ...", value.tableName.c_str()); - } if (value.countThreshold != -1 && value.timeThreshold != -1) { - m_embStatus[value.tableName] = 
SingleEmbTableStatus::SETS_BOTH; } else if (value.countThreshold != -1 && value.timeThreshold == -1) { - m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; + m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; } else { - LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tableName.c_str()); - m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ERROR; + LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tensorName.c_str()); + m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ERROR; return false; } } @@ -320,7 +301,7 @@ return true; } -std::vector FeatureAdmitAndEvict::GetAllNeedEvictTableNames() +std::vector FeatureAdmitAndEvict::GetAllNeedEvictTensorNames() { std::vector names; std::lock_guard lock(m_syncMutexs); diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 83e5a392..5dfcc3e0 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -58,13 +58,10 @@ namespace MxRec { // feature-eviction interface void FeatureEvict(map>& evictKeyMap); void ExecuteFeatureAdmit( - const string& tableName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); + const string& tensorName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); // enable/disable interface for feature eviction void SetFunctionSwitch(bool isEnableEvict); - // combined-table statistics for feature admission - void SetCombineSwitch(); - bool GetFunctionSwitch() const; void PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, absl::flat_hash_map& mergeKeys); @@ -74,10 +71,10 @@ const std::vector& embNames, bool isTimestamp); // interfaces for interacting with model save/load - auto GetTableThresholds() -> table_2_thresh_mem_t; + auto GetTensorThresholds() -> tensor_2_thresh_mem_t; auto GetHistoryRecords() -> AdmitAndEvictData&; - void LoadTableThresholds(table_2_thresh_mem_t& loadData); + void LoadTensorThresholds(tensor_2_thresh_mem_t& loadData); void LoadHistoryRecords(AdmitAndEvictData& loadData); static std::vector m_cfgThresholds; // used to validate the threshold configuration GTEST_PRIVATE : - // parse m_table2Threshold + // parse m_tensor2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); - std::vector GetAllNeedEvictTableNames(); - FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableName, + std::vector GetAllNeedEvictTensorNames(); + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tensorName, const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); bool m_isEnableFunction { true }; // enable switch for feature eviction bool m_isExit { false }; // flag signalling the evict thread to exit - bool m_isCombine { true }; // whether history statistics are merged across tables - absl::flat_hash_map m_table2Threshold; // table-X ---> ThresholdValue mapping + absl::flat_hash_map m_tensor2Threshold; // tensor-X ---> ThresholdValue mapping AdmitAndEvictData m_recordsData; std::mutex m_syncMutexs; // sync lock arbitrating between feature admission and eviction int m_recordsInitSize { DEFAULT_RECORDS_INIT_SIZE }; // initial capacity of the m_historyRecords table diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index d3ddbc2a..706c16ed 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -141,20 +141,4 @@ namespace MxRec { throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret)); } - - bool GetCombineSwitch() - { - const char* faaeMode = 
std::getenv("FAAE_MODE"); // 获取环境变量 - bool isCombine = true; - if (faaeMode != nullptr) { - try { - isCombine = (std::stoi(faaeMode) == 0); - LOG(INFO) << StringFormat("If combine history table: %d", isCombine); - } catch (const std::invalid_argument& e) { - LOG(ERROR) << "The value of FAAE_MODE is invalid!"; - throw std::invalid_argument("Invalid env value FAAE_MODE"); - } - } - return isCombine; - } } // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 60d4a97b..f51d50c0 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -88,8 +88,6 @@ namespace MxRec { constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; constexpr float HOT_EMB_CACHE_PCT = static_cast(1. / 3); // hot emb cache percent - const string COMBINE_HISTORY_NAME = "combine_table_history"; - using emb_key_t = int64_t; using emb_name_t = std::string; using keys_t = std::vector; @@ -104,7 +102,6 @@ namespace MxRec { }; string GetChipName(int devID); - bool GetCombineSwitch(); namespace UBSize { const int ASCEND910_PREMIUM_A = 262144; @@ -242,18 +239,16 @@ struct BatchTask { struct ThresholdValue { ThresholdValue() = default; - ThresholdValue(emb_name_t name, int countThre, int timeThre, int faaeCoef) + ThresholdValue(emb_name_t name, int countThre, int timeThre) { - tableName = name; + tensorName = name; countThreshold = countThre; timeThreshold = timeThre; - faaeCoefficient = faaeCoef; } - emb_name_t tableName { "" }; // embName + emb_name_t tensorName { "" }; // embName int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 - int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 }; struct FeatureItemInfo { @@ -497,7 +492,7 @@ struct BatchTask { using emb_hash_mem_t = absl::flat_hash_map; using offset_mem_t = std::map; using key_offset_mem_t = std::map>; - using table_2_thresh_mem_t = absl::flat_hash_map; + using tensor_2_thresh_mem_t = absl::flat_hash_map; using trans_serialize_t = uint8_t; using key_offset_map_t = std::map; using all_key_offset_map_t = std::map>; @@ -515,7 +510,7 @@ struct BatchTask { emb_hash_mem_t embHashMaps; offset_mem_t maxOffset; key_offset_mem_t keyOffsetMap; - table_2_thresh_mem_t table2Thresh; + tensor_2_thresh_mem_t tens2Thresh; AdmitAndEvictData histRec; }; @@ -537,7 +532,7 @@ struct BatchTask { EMB_CURR_STAT = 4, NDDR_OFFSET = 5, NDDR_FEATMAP = 6, - TABLE_2_THRESH = 7, + TENSOR_2_THRESH = 7, HIST_REC = 8, ATTRIBUTE = 9 }; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 9c1b5fb2..b81e5c6d 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -169,9 +169,8 @@ void GetHybridMgmt(pybind11::module_& m) void GetThresholdValue(pybind11::module_& m) { pybind11::class_(m, "ThresholdValue") - .def(pybind11::init()) - .def_readwrite("table_name", &ThresholdValue::tableName) + .def(pybind11::init()) + .def_readwrite("tensor_name", &ThresholdValue::tensorName) .def_readwrite("count_threshold", &ThresholdValue::countThreshold) - .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) - .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); + .def_readwrite("time_threshold", &ThresholdValue::timeThreshold); } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 1cf27378..33df7b23 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -175,18 +175,17 @@ 
protected: } } - void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold) + void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold) { for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tableName = testEmbInfo.name; + val.tensorName = testEmbInfo.name; val.countThreshold = offsetMem; val.timeThreshold = offsetMem; - val.faaeCoefficient = 1; offsetMem++; - testTable2Threshold[testEmbInfo.name] = move(val); + testTens2Threshold[testEmbInfo.name] = move(val); } } @@ -215,30 +214,6 @@ protected: timeStamp++; } } - - void SetHistRecCombine(AdmitAndEvictData& histRec) - { - int64_t featureId { int64Min }; - int count { 1 }; - time_t lastTime { 1000 }; - time_t timeStamp { 10000 }; - - auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; - auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; - - timestamps = timeStamp; - - for (int i = 0; i < count; ++i) { - historyRecords[featureId].count = count; - historyRecords[featureId].lastTime = lastTime; - - featureId++; - } - - count++; - lastTime++; - timeStamp++; - } }; TEST_F(CheckpointTest, HostEmbs) @@ -428,31 +403,25 @@ TEST_F(CheckpointTest, AllMgmt) TEST_F(CheckpointTest, FeatAdmitNEvict) { - table_2_thresh_mem_t testTrens2Thresh; - table_2_thresh_mem_t validTrens2Thresh; + tensor_2_thresh_mem_t testTrens2Thresh; + tensor_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; SetEmbInfo(); - SetTable2Threshold(testTrens2Thresh); + SetTens2Threshold(testTrens2Thresh); validTrens2Thresh = testTrens2Thresh; - - if (GetCombineSwitch()) { - SetHistRecCombine(testHistRec); - } else { - SetHistRec(testHistRec); - } - + SetHistRec(testHistRec); validHistRec = testHistRec; CkptData testSaveData; CkptData validLoadData; CkptData testLoadData; - testSaveData.table2Thresh = testTrens2Thresh; + testSaveData.tens2Thresh = testTrens2Thresh; testSaveData.histRec.timestamps = testHistRec.timestamps; testSaveData.histRec.historyRecords = testHistRec.historyRecords; - validLoadData.table2Thresh = validTrens2Thresh; + validLoadData.tens2Thresh = validTrens2Thresh; validLoadData.histRec = validHistRec; validLoadData.histRec.timestamps = validHistRec.timestamps; validLoadData.histRec.historyRecords = validHistRec.historyRecords; @@ -461,16 +430,16 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::FEAT_ADMIT_N_EVICT }); - EXPECT_EQ(validLoadData.table2Thresh.size(), testLoadData.table2Thresh.size()); + EXPECT_EQ(validLoadData.tens2Thresh.size(), testLoadData.tens2Thresh.size()); EXPECT_EQ(validLoadData.histRec.historyRecords.size(), testLoadData.histRec.historyRecords.size()); - for (const auto& it : validLoadData.table2Thresh) { - EXPECT_EQ(1, testLoadData.table2Thresh.count(it.first)); + for (const auto& it : validLoadData.tens2Thresh) { + EXPECT_EQ(1, testLoadData.tens2Thresh.count(it.first)); - const auto& table2Thresh = testLoadData.table2Thresh.at(it.first); + const auto& tens2Thresh = testLoadData.tens2Thresh.at(it.first); - EXPECT_EQ(it.second.tableName, table2Thresh.tableName); - EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); + EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); + EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); } for 
(const auto& it : validLoadData.histRec.timestamps) { diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 55babf89..6caace33 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -22,17 +22,6 @@ using valid_int64_t = absl::flat_hash_map>; using valie_dataset_t = absl::flat_hash_map>; using valid_attrib_t = absl::flat_hash_map>; -struct InputArgs { - vector& embNames; - CkptData& validData; - FeatAdmitNEvictCkpt& testCkpt; - valid_int_t& validTrens2ThreshArr; - valid_attrib_t& validTrens2ThreshAttrib; - valid_attrib_t& validHistRecAttrib; - valid_int64_t& validHistRecArr; - CkptData& testData; -}; - class CkptDataHandlerTest : public testing::Test { protected: int floatBytes { 4 }; @@ -79,7 +68,7 @@ protected: } } - void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold, + void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold, valid_int_t& validArr, valid_attrib_t& validAttrib) { @@ -88,7 +77,7 @@ protected: for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tableName = testEmbInfo.name; + val.tensorName = testEmbInfo.name; val.countThreshold = countThreshold; val.timeThreshold = timeThreshold; @@ -97,7 +86,7 @@ protected: countThreshold++; timeThreshold++; - testTable2Threshold[testEmbInfo.name] = move(val); + testTens2Threshold[testEmbInfo.name] = move(val); validArr[testEmbInfo.name] = move(valid); validAttrib[testEmbInfo.name].push_back(2); // 2 is element num in one vector validAttrib[testEmbInfo.name].push_back(int32Bytes); @@ -139,114 +128,12 @@ protected: timeStamp++; } } - - void SetHistRecCombine(AdmitAndEvictData& histRec, valid_int64_t& validArr, valid_attrib_t& validAttrib) - { - int64_t featureId { int64Min }; - int count { 1 }; - time_t lastTime { 1000 }; - time_t timeStamp { 10000 }; - - auto& validA { validArr[COMBINE_HISTORY_NAME] }; - auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; - auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; - - timestamps = timeStamp; - validA.push_back(timeStamp); - - for (int i = 0; i < count; ++i) { - historyRecords[featureId].count = count; - historyRecords[featureId].lastTime = lastTime; - - validA.push_back(featureId); - validA.push_back(count); - validA.push_back(lastTime); - - featureId++; - } - - auto& attribute = validAttrib[COMBINE_HISTORY_NAME]; - attribute.push_back(count); - attribute.push_back(int64Bytes); - - count++; - lastTime++; - timeStamp++; - } - - void TestForSave(InputArgs& args) - { - for (const auto& embName : args.embNames) { - EXPECT_EQ(1, args.validData.table2Thresh.count(embName)); - - CkptTransData testSaveData = args.testCkpt.GetDataset(CkptDataType::TABLE_2_THRESH, embName); - EXPECT_EQ(args.validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method - EXPECT_EQ(args.validTrens2ThreshAttrib.at(embName), testSaveData.attribute); - testSaveData = args.testCkpt.GetDataset(CkptDataType::HIST_REC, embName); - - if (!GetCombineSwitch()) { - EXPECT_EQ(1, args.validData.histRec.timestamps.count(embName)); - EXPECT_EQ(1, args.validData.histRec.historyRecords.count(embName)); - EXPECT_EQ(args.validHistRecAttrib.at(embName), testSaveData.attribute); - } else { - EXPECT_EQ(1, args.validData.histRec.timestamps.count(COMBINE_HISTORY_NAME)); - EXPECT_EQ(1, args.validData.histRec.historyRecords.count(COMBINE_HISTORY_NAME)); - 
EXPECT_EQ(args.validHistRecAttrib.at(COMBINE_HISTORY_NAME), testSaveData.attribute); - } - } - } - void TestForLoad(InputArgs& args) - { - CkptTransData testLoadData; - for (const auto& embName : args.embNames) { - testLoadData.int32Arr = args.validTrens2ThreshArr.at(embName); - testLoadData.attribute = args.validTrens2ThreshAttrib.at(embName); - args.testCkpt.SetDataset(CkptDataType::TABLE_2_THRESH, embName, testLoadData); - - if (!GetCombineSwitch()) { - testLoadData.int64Arr = args.validHistRecArr.at(embName); - testLoadData.attribute = args.validHistRecAttrib.at(embName); - } else { - testLoadData.int64Arr = args.validHistRecArr.at(COMBINE_HISTORY_NAME); - testLoadData.attribute = args.validHistRecAttrib.at(COMBINE_HISTORY_NAME); - } - args.testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); - } - args.testCkpt.GetProcessData(args.testData); - - EXPECT_EQ(args.validData.table2Thresh.size(), args.testData.table2Thresh.size()); - EXPECT_EQ(args.validData.histRec.historyRecords.size(), args.testData.histRec.historyRecords.size()); - for (const auto& it : args.validData.table2Thresh) { - EXPECT_EQ(1, args.testData.table2Thresh.count(it.first)); - - const auto& table2Thresh = args.testData.table2Thresh.at(it.first); - - EXPECT_EQ(it.second.tableName, table2Thresh.tableName); - EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); - } - - for (const auto& it : args.validData.histRec.timestamps) { - EXPECT_EQ(1, args.testData.histRec.timestamps.count(it.first)); - EXPECT_EQ(1, args.testData.histRec.historyRecords.count(it.first)); - - const auto& historyRecords = args.testData.histRec.historyRecords.at(it.first); - const auto& validHistRec = args.validData.histRec.historyRecords.at(it.first); - - for (const auto& validHR : validHistRec) { - const auto& testHR = historyRecords.at(validHR.first); - - EXPECT_EQ(validHR.second.count, testHR.count); - EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); - } - } - } }; TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) { - table_2_thresh_mem_t testTrens2Thresh; - table_2_thresh_mem_t validTrens2Thresh; + tensor_2_thresh_mem_t testTrens2Thresh; + tensor_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; @@ -256,37 +143,77 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) valid_attrib_t validHistRecAttrib; SetEmbInfo(); - SetTable2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); + SetTens2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); validTrens2Thresh = testTrens2Thresh; - - if (GetCombineSwitch()) { - SetHistRecCombine(testHistRec, validHistRecArr, validHistRecAttrib); - } else { - SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); - } + SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); validHistRec = testHistRec; CkptData testData; CkptData validData; FeatAdmitNEvictCkpt testCkpt; - testData.table2Thresh = testTrens2Thresh; + testData.tens2Thresh = testTrens2Thresh; testData.histRec.timestamps = testHistRec.timestamps; testData.histRec.historyRecords = testHistRec.historyRecords; - validData.table2Thresh = validTrens2Thresh; + validData.tens2Thresh = validTrens2Thresh; validData.histRec.timestamps = validHistRec.timestamps; validData.histRec.historyRecords = validHistRec.historyRecords; testCkpt.SetProcessData(testData); vector embNames { testCkpt.GetEmbNames() }; - EXPECT_EQ(validData.table2Thresh.size(), embNames.size()); + CkptTransData 
testSaveData;
+    EXPECT_EQ(validData.tens2Thresh.size(), embNames.size());
+
+    for (const auto& embName : embNames) {
+        EXPECT_EQ(1, validData.tens2Thresh.count(embName));
+
+        EXPECT_EQ(1, validData.histRec.timestamps.count(embName));
+        EXPECT_EQ(1, validData.histRec.historyRecords.count(embName));
 
-    InputArgs args = {embNames, validData, testCkpt, validTrens2ThreshArr, validTrens2ThreshAttrib,
-                      validHistRecAttrib, validHistRecArr, testData};
-    // test save
-    TestForSave(args);
+        testSaveData = testCkpt.GetDataset(CkptDataType::TENSOR_2_THRESH, embName);
+        EXPECT_EQ(validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method
+        EXPECT_EQ(validTrens2ThreshAttrib.at(embName), testSaveData.attribute);
+        testSaveData = testCkpt.GetDataset(CkptDataType::HIST_REC, embName);
+        EXPECT_EQ(validHistRecAttrib.at(embName), testSaveData.attribute);
+    }
+
+    CkptTransData testLoadData;
+    for (const auto& embName : embNames) {
+        testLoadData.int32Arr = validTrens2ThreshArr.at(embName);
+        testLoadData.attribute = validTrens2ThreshAttrib.at(embName);
+        testCkpt.SetDataset(CkptDataType::TENSOR_2_THRESH, embName, testLoadData);
+
+        testLoadData.int64Arr = validHistRecArr.at(embName);
+        testLoadData.attribute = validHistRecAttrib.at(embName);
+        testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData);
+    }
+    testCkpt.GetProcessData(testData);
 
-    // test load
-    TestForLoad(args);
+    EXPECT_EQ(validData.tens2Thresh.size(), testData.tens2Thresh.size());
+    EXPECT_EQ(validData.histRec.historyRecords.size(), testData.histRec.historyRecords.size());
+    for (const auto& it : validData.tens2Thresh) {
+        EXPECT_EQ(1, testData.tens2Thresh.count(it.first));
+
+        const auto& tens2Thresh = testData.tens2Thresh.at(it.first);
+
+        EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName);
+        EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold);
+        EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold);
+    }
+
+    for (const auto& it : validData.histRec.timestamps) {
+        EXPECT_EQ(1, testData.histRec.timestamps.count(it.first));
+        EXPECT_EQ(1, testData.histRec.historyRecords.count(it.first));
+
+        const auto& historyRecords = testData.histRec.historyRecords.at(it.first);
+        const auto& validHistRec = validData.histRec.historyRecords.at(it.first);
+
+        for (const auto& validHR : validHistRec) {
+            const auto& testHR = historyRecords.at(validHR.first);
+
+            EXPECT_EQ(validHR.second.count, testHR.count);
+            EXPECT_EQ(validHR.second.lastTime, testHR.lastTime);
+        }
+    }
 }
\ No newline at end of file
diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp
index a570915d..752db8e2 100644
--- a/src/tests/emb_mgmt/emb_mgmt_test.cpp
+++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp
@@ -179,7 +179,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM)
     embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
     embInfos.emplace_back(embInfo);
     vector thresholdValues;
-    thresholdValues.emplace_back(name, 1, 1, 1);
+    thresholdValues.emplace_back(name, 1, 1);
 
     auto hybridMgmt = Singleton::GetInstance();
     cout << "setup..." << endl;
@@ -199,7 +199,7 @@ TEST_F(EmbMgmtTest, Evict)
     embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
     embInfos.emplace_back(embInfo);
     vector thresholdValues;
-    thresholdValues.emplace_back(name, 1, 1, 1);
+    thresholdValues.emplace_back(name, 1, 1);
 
     auto hybridMgmt = Singleton::GetInstance();
     cout << "setup..."
<< endl;
@@ -222,7 +222,7 @@ TEST_F(EmbMgmtTest, Evict_HBM)
     embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
     embInfos.emplace_back(embInfo);
     vector thresholdValues;
-    thresholdValues.emplace_back(name, 1, 1, 1);
+    thresholdValues.emplace_back(name, 1, 1);
 
     auto hybridMgmt = Singleton::GetInstance();
     cout << "setup..." << endl;
diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp
index e94fe455..31bfdd1e 100644
--- a/src/tests/key_process/feature_admit_and_evict_test.cpp
+++ b/src/tests/key_process/feature_admit_and_evict_test.cpp
@@ -181,48 +181,48 @@ protected:
         printf("\t############# [%s] tid[%lu] ############# begin ...\n", thrName.c_str(),
               std::hash{}(std::this_thread::get_id()));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2, 5}
             keys1 = {11, 11, 33, 44, 11, 55, 88, 55}
             cnt1  =   1   2   1   3   1   1   4   1
         */
        InputArgs args1 = {keys1, cnt1, {}, initHistory, {}}; // the first record for each table must be appended via initHistory
-        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args1);
+        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args1);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2, 5}
             keys2 = {11, 12, 33, 21, 11, 12}
             cnt2  =   1   2   1   1   2   3
         */
         InputArgs args2 = {keys2, cnt2, {}, args1.expectHistory, {}};
-        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args2);
+        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args2);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2));
 
         /*
-            {"tableBBB", 3, 7}
+            {"tensorBBB", 3, 7}
             keys3 = {123, 121, 121, 212, 211}
             cnt3  =    1    2    1    1    2
         */
         InputArgs args3 = {keys3, cnt3, {}, initHistory, {}};
-        FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args3);
+        FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args3);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2, 5}
             keys4 = {11, 11, 33, 44, 55, 88, 55}
             cnt4  =   1   2   3   2   1   2   1
         */
         InputArgs args4 = {keys4, cnt4, {}, args2.expectHistory, {}};
-        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args4);
+        FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args4);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2));
 
         /*
-            {"tableBBB", 3, 7}
+            {"tensorBBB", 3, 7}
             keys5 = {125, 121, 122, 212, 211}
             cnt5  =    1    2    1    3    1
        */
        InputArgs args5 = {keys5, cnt5, {}, args3.expectHistory, {}};
-        FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args5);
+        FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args5);
 
         printf("\t############# [%s] tid[%lu] ############# end ...\n", thrName.c_str(),
               std::hash{}(std::this_thread::get_id()));
@@ -263,59 +263,58 @@ protected:
     {
         faae.ResetAllRecords();
         faae.ParseThresholdCfg(thresholds);
-        faae.SetCombineSwitch();
         StartEvictThread();
 
         printf("Current test single-thread is [%lu]\n", std::hash{}(std::this_thread::get_id()));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2, 5}
             keys1 = {11, 11, 33, 44, 11, 55, 88, 55}
             cnt1  =   1   2   1   3   1   1   4   1
         */
         keys_t expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55};
         InputArgs args1 = {keys1, cnt1, expectRet1, initHistory, {}}; // the first record for each table must be appended via initHistory
-        FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args1);
+        FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args1);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2,
5}
             keys2 = {11, 12, 33, 21, 11, 12}
             cnt2  =   1   2   1   1   2   3
         */
         keys_t expectRet2 = {11, 12, 33, -1, 11, 12};
         InputArgs args2 = {keys2, cnt2, expectRet2, args1.expectHistory, {}};
-        FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args2);
+        FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args2);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2));
 
         /*
-            {"tableBBB", 3, 7}
+            {"tensorBBB", 3, 7}
             keys3 = {123, 121, 121, 212, 211}
             cnt3  =    1    2    1    1    2
         */
         keys_t expectRet3 = {-1, 121, 121, -1, -1};
         InputArgs args3 = {keys3, cnt3, expectRet3, initHistory, {}};
-        FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args3);
+        FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args3);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6));
 
         /*
-            {"tableAAA", 2, 5}
+            {"tensorAAA", 2, 5}
             keys4 = {11, 11, 33, 44, 55, 88, 55}
             cnt4  =   1   2   3   2   1   2   1
         */
         keys_t expectRet4 = {11, 11, 33, 44, 55, 88, 55};
         InputArgs args4 = {keys4, cnt4, expectRet4, args2.expectHistory, {}};
-        FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args4);
+        FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args4);
         std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2));
 
         /*
-            {"tableBBB", 3, 7}
+            {"tensorBBB", 3, 7}
             keys5 = {125, 121, 122, 212, 211}
             cnt5  =    1    2    1    3    1
         */
         keys_t expectRet5 = {-1, 121, -1, 212, 211};
         InputArgs args5 = {keys5, cnt5, expectRet5, args3.expectHistory, {}};
-        FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args5);
+        FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args5);
 
         WaitEvictThread();
         LOG(INFO) << "TestCase1: single thread test over ...";
@@ -332,7 +331,7 @@ protected:
         vector tmpCnt = {1, 2, 1, 3, 1, 1, 4};
 
         std::unique_ptr batch = make_unique();
-        batch->name = thresholds[0].tableName;
+        batch->name = thresholds[0].tensorName;
         batch->timestamp = time(nullptr);
 
        // verify that calling the interface reports an error
@@ -376,7 +375,6 @@ protected:
     {
         faae.ResetAllRecords();
         faae.ParseThresholdCfg(thresholds);
-        faae.SetCombineSwitch();
         StartEvictThread();
 
         std::thread thrs[PerfConfig::keyProcessThreadNum];
@@ -398,9 +396,9 @@ protected:
         {
             /*
                without the eviction feature,
-               tableAAA would hold  {11, 12, 21, 33, 44, 55, 88}
+               tensorAAA would hold {11, 12, 21, 33, 44, 55, 88}
                                      10   5   1   5   5   4   6
-               tableBBB would hold  {121, 122, 123, 125, 211, 212};
+               tensorBBB would hold {121, 122, 123, 125, 211, 212};
                                        5    1    1    1    3    4
             */
            keys_t expectKeys1 = {11, 33, 44, 55, 88}; // 12 and 21 were evicted
             vector expectCnt1 = {10, 5, 5, 4, 6};
            keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123 was evicted
             vector expectCnt2 = {5, 1, 1, 3, 4};
            std::lock_guard lock(faae.m_syncMutexs); // contends with the evict thread for this resource
-            CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, PerfConfig::keyProcessThreadNum);
-            CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tableName, PerfConfig::keyProcessThreadNum);
+            CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, PerfConfig::keyProcessThreadNum);
+            CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, PerfConfig::keyProcessThreadNum);
         }
 
         WaitEvictThread();
@@ -423,8 +421,8 @@ protected:
         faae.ParseThresholdCfg(thresholds);
 
         std::unique_ptr batch = make_unique();
-        // test point: tableDDD has no configured threshold, so it is not supported
-        batch->name = std::string("tableDDD");
+        // test point: tensorDDD has no configured threshold, so it is not supported
+        batch->name = std::string("tensorDDD");
         batch->timestamp = time(nullptr);
 
        // verify that calling the interface reports "not supported"
@@ -445,21 +443,11 @@ protected:
     vector cnt4 = {1, 2, 3, 2, 1, 2, 1};
     keys_t keys5 = {125, 121, 122, 212, 211};
     vector cnt5 = {1, 2, 1, 3, 1};
-    std::vector thresholds = {{"tableAAA", 2, 5, 1}, {"tableBBB", 3, 7, 1}, {"tableCCC",
5, 9, 1}}; + std::vector thresholds = {{"tensorAAA", 2, 5}, {"tensorBBB", 3, 7}, {"tensorCCC", 5, 9}}; }; -void SetEnv() -{ - const char* name = "FAAE_MODE"; - const char* mode = "1"; - int overwrite = 1; - - ASSERT_EQ(setenv(name, mode, overwrite), 0); -} - TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict1) { - SetEnv(); TestCase1(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict2) @@ -476,7 +464,6 @@ TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict4) } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict5) { - SetEnv(); TestCase5(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict6) -- Gitee From cabfd28606242d879bee5e7760a0427613a16ea7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 24 Jul 2023 19:31:49 +0800 Subject: [PATCH 229/551] Match-id-51c50b0a9c91726665dc4803de6b0845bb98eba8 --- tests/mx_rec/core/initializer_mock.py | 13 + tests/mx_rec/core/mxrec_pybind_mock.py | 11 + tests/mx_rec/core/test_build_graph.py | 292 ++++++++++++++++++++++ tests/mx_rec/validator/test_validators.py | 161 ++++++++++++ 4 files changed, 477 insertions(+) create mode 100644 tests/mx_rec/core/initializer_mock.py create mode 100644 tests/mx_rec/core/mxrec_pybind_mock.py create mode 100644 tests/mx_rec/core/test_build_graph.py create mode 100644 tests/mx_rec/validator/test_validators.py diff --git a/tests/mx_rec/core/initializer_mock.py b/tests/mx_rec/core/initializer_mock.py new file mode 100644 index 00000000..a5c840da --- /dev/null +++ b/tests/mx_rec/core/initializer_mock.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +import os + + +class InitializerMock: + """ + initializer mock module + """ + @staticmethod + def get_use_static(): + return os.getenv("use_static", True) diff --git a/tests/mx_rec/core/mxrec_pybind_mock.py b/tests/mx_rec/core/mxrec_pybind_mock.py new file mode 100644 index 00000000..f65356c2 --- /dev/null +++ b/tests/mx_rec/core/mxrec_pybind_mock.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + + +class MxRecPybindMock: + """ + mxrec_pybind mock module + """ + def get_ub_hot_size(self): + return 21845 diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py new file mode 100644 index 00000000..291d297e --- /dev/null +++ b/tests/mx_rec/core/test_build_graph.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. +import os +import sys +import unittest +from dataclasses import dataclass +from unittest import mock + +import tensorflow as tf + +from tests.mx_rec.core.mxrec_pybind_mock import MxRecPybindMock +from tests.mx_rec.core.initializer_mock import InitializerMock +from mx_rec.core.asc import build_graph +from mx_rec.util.tf_version_adapter import npu_ops + +sys.modules['mxrec_pybind'] = MxRecPybindMock +sys.modules['mx_rec.util.initialize'] = InitializerMock + +os.environ[ + "HOST_PIPELINE_OPS_LIB_PATH"] = f"{os.getenv('so_path')}/libasc/libasc_ops.so" + + +@dataclass +class InputConfig: + batch_size: int + feat_cnt: int + send_count: int + rank_size: int + channel_id: int + table_name: str + skip_emb_transfer: bool + ext_emb_size: int + emb_size: int + use_hot: bool + device_id: int + use_dynamic_expansion: bool + + +class TestBuildGraph(unittest.TestCase): + """ + Test Suite for Exception Checkpoint. 
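+
+    These cases exercise the mx_rec.core.asc.build_graph helpers
+    (get_restore_vector, get_id_offsets, get_all2all_args and
+    get_preprocessed_tensor_for_asc) with the npu_bridge ops mocked out.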
+ """ + + def setUp(self): + """ + 准备步骤 + :return:无 + """ + super().setUp() + + def tearDown(self): + """ + 销毁步骤 + :return: 无 + """ + super().tearDown() + + @staticmethod + def get_next_mock(): + return tf.constant(value=1, name="inference/asecnd_lookup_one_big_embedding/all2all/mul", shape=[8, 8], + dtype=tf.int64) + + @staticmethod + def get_id_offsets_mock(): + return tf.constant(value=1, shape=[270412, 8], dtype=tf.float32, name="inference/gather_for_id_offsets"), [], 0 + + @staticmethod + def get_all2all_mock(): + return tf.constant(value=1, shape=[8, ], dtype=tf.int64, name="mul") + + @staticmethod + def get_restore_vector_mock(): + return [tf.constant(value=1, shape=[2908800], dtype=tf.int32, name="aicpu_getnext_restore_vector/GetNext"), + None] + + @staticmethod + def get_input_config(input_config_init: InputConfig): + batch_size = input_config_init.batch_size + feat_cnt = input_config_init.feat_cnt + send_count = input_config_init.send_count + rank_size = input_config_init.rank_size + channel_id = input_config_init.channel_id + table_name = input_config_init.table_name + skip_emb_transfer = input_config_init.skip_emb_transfer + ext_emb_size = input_config_init.ext_emb_size + emb_size = input_config_init.emb_size + use_hot = input_config_init.use_hot + device_id = input_config_init.device_id + use_dynamic_expansion = input_config_init.use_dynamic_expansion + + input_config = {'batch_size': batch_size, + 'feat_cnt': feat_cnt, + 'send_count': send_count, + 'rank_size': rank_size, + 'channel_id': channel_id, + 'table_name': table_name, + 'skip_emb_transfer': skip_emb_transfer, + 'ext_emb_size': ext_emb_size, + 'emb_size': emb_size, + 'use_hot': use_hot, + 'device_id': device_id, + 'use_dynamic_expansion': use_dynamic_expansion} + return input_config + + @staticmethod + def get_input_table(): + input_table = tf.Variable(tf.zeros([875000, 8]), name="inference/one_ascend_hash_embedding:0", + dtype=tf.float32) + return input_table + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector(self, tf1_hccl_ops_mock, + tf1_save_mock): + with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: + from mx_rec.core.asc.build_graph import get_restore_vector + restore_vector_mock = tf.constant(value=1, shape=[2908800], dtype=tf.int32, + name="aicpu_getnext_restore_vector/GetNext") + hot_pos_mock = tf.constant( + value=1, + name="restore_vector/one_ascend_hash_embedding/GetNext", + shape=[2730, ], + dtype=tf.int32) + mock_npu_ops.get_next.return_value = [restore_vector_mock, hot_pos_mock] + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + res_restore_vector, res_hot_emb = get_restore_vector(input_config) + self.assertEqual(res_restore_vector, restore_vector_mock) + self.assertEqual(res_hot_emb, hot_pos_mock) + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector_no_hot_embed(self, tf1_hccl_ops_mock, + tf1_save_mock): + with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: + from mx_rec.core.asc.build_graph import get_restore_vector + restore_vector_mock = tf.constant(value=1, shape=[2908800], dtype=tf.int32, + name="aicpu_getnext_restore_vector/GetNext") + mock_npu_ops.get_next.return_value = 
[restore_vector_mock] + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, + False, 6, False) + input_config = self.get_input_config(input_config_instance) + res_restore_vector, hot_pos_vector = get_restore_vector(input_config) + self.assertEqual(res_restore_vector, restore_vector_mock) + self.assertIsNone(hot_pos_vector) + + @mock.patch('npu_bridge.estimator.npu_ops') + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector_use_hot(self, tf1_npu_ops_mock, tf1_hccl_ops_mock, + tf1_save_mock): + from mx_rec.core.asc.build_graph import get_restore_vector + tf1_npu_ops_mock.return_value = None + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, "8", True, 6, + False) + input_config = self.get_input_config(input_config_instance) + try: + get_restore_vector(input_config) + except TypeError as exp: + self.assertEqual(type(exp), TypeError) + else: + self.fail("TypeError not raised.") + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch('npu_bridge.estimator.npu_ops') + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector_emb_size_value_error(self, tf1_hccl_ops_mock, tf1_npu_ops_mock, + tf1_save_mock): + from mx_rec.core.asc.build_graph import get_restore_vector + tf1_save_mock.return_value = None + tf1_npu_ops_mock.return_value = None + tf1_hccl_ops_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, -1, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + try: + get_restore_vector(input_config) + except TypeError as exp: + self.assertEqual(type(exp), TypeError) + else: + self.fail("ValueError not raised.") + + @mock.patch('npu_bridge.estimator.npu_ops') + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector_ext_emb_size_type_error(self, tf1_npu_ops_mock, tf1_hccl_ops_mock, + tf1_save_mock): + from mx_rec.core.asc.build_graph import get_restore_vector + tf1_npu_ops_mock.return_value = None + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", False, "8", 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + try: + get_restore_vector(input_config) + except TypeError as exp: + self.assertEqual(type(exp), TypeError) + else: + self.fail("TypeError not raised.") + + @mock.patch('npu_bridge.estimator.npu_ops') + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_restore_vector_ext_emb_size_value_error(self, tf1_npu_ops_mock, tf1_hccl_ops_mock, + tf1_save_mock): + from mx_rec.core.asc.build_graph import get_restore_vector + tf1_npu_ops_mock.return_value = None + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", False, -1, 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + try: + get_restore_vector(input_config) + except TypeError as exp: + self.assertEqual(type(exp), TypeError) + 
else: + self.fail("ValueError not raised.") + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_id_offsets(self, tf1_hccl_ops_mock, + tf1_save_mock): + with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: + from mx_rec.core.asc.build_graph import get_id_offsets + id_offset_mock = tf.constant(value=1, shape=[270412, 8], dtype=tf.float32, + name="inference/gather_for_id_offsets") + mock_npu_ops.get_next.return_value = [id_offset_mock] + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + max_lookup_vec_size = None + res_id_offsets = get_id_offsets(max_lookup_vec_size, input_config) + self.assertEqual(res_id_offsets[0], id_offset_mock) + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_all2all_args(self, tf1_hccl_ops_mock, + tf1_save_mock): + with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: + from mx_rec.core.asc.build_graph import get_all2all_args + all2all_mock = tf.constant( + value=1, + name='mul', + shape=[8, 8], dtype=tf.int64) + mock_npu_ops.get_next.return_value = all2all_mock + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + use_static = False + res_all2all_args = get_all2all_args(use_static, input_config) + self.assertEqual(res_all2all_args.shape, tf.constant(value=1, shape=[8, ], dtype=tf.int64, + name="mul").shape) + self.assertEqual(res_all2all_args.dtype, tf.constant(value=1, shape=[8, 8], dtype=tf.int64, + name="mul").dtype) + + @mock.patch("npu_bridge.hccl.hccl_ops") + @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") + def test_get_preprocessed_tensor_for_asc(self, tf1_hccl_ops_mock, tf1_save_mock): + with mock.patch.object(npu_ops, "gen_npu_ops", return_value=self.get_next_mock()), \ + mock.patch.object(build_graph, "get_id_offsets", return_value=self.get_id_offsets_mock()), \ + mock.patch.object(build_graph, "get_all2all_args", return_value=self.get_all2all_mock()), \ + mock.patch.object(build_graph, "get_restore_vector", return_value=self.get_restore_vector_mock()): + from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc + os.environ["use_static"] = "False" + tf1_hccl_ops_mock.return_value = None + tf1_save_mock.return_value = None + + input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, + False) + input_config = self.get_input_config(input_config_instance) + sparse_table = tf.Variable(tf.zeros([875000, 8]), name='inference/one_ascend_hash_embedding', + dtype=tf.float32) + res = get_preprocessed_tensor_for_asc(sparse_table, input_config) + self.assertEqual(res.get("hot_pos"), self.get_restore_vector_mock()[1]) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/mx_rec/validator/test_validators.py b/tests/mx_rec/validator/test_validators.py new file mode 100644 index 00000000..1b44c84b --- /dev/null +++ b/tests/mx_rec/validator/test_validators.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022. 
All rights reserved.
+import os
+import sys
+import tempfile
+import unittest
+
+from mx_rec.validator.validator import Validator, StringValidator, DirectoryValidator
+
+sys.modules['mxrec_pybind'] = __import__('os')
+
+os.environ[
+    "HOST_PIPELINE_OPS_LIB_PATH"] = f"{os.getenv('so_path')}/libasc/libasc_ops.so"
+
+
+class ParameterCheckerTest(unittest.TestCase):
+    def setUp(self):
+        """
+        Set-up step.
+        :return: None
+        """
+        super().setUp()
+
+    def tearDown(self):
+        """
+        Tear-down step.
+        :return: None
+        """
+        super().tearDown()
+
+    def test_validator_should_return_default_if_invalid(self):
+        validation = Validator('aa')
+        validation.register_checker(lambda x: len(x) < 5, 'length of string should be less than 5')
+        self.assertTrue(validation.is_valid())
+        try:
+            validation = Validator('123456')
+            validation.register_checker(lambda x: len(x) < 5, 'length of string should be less than 5')
+            validation.is_valid()
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+
+    def test_string_validator_max_len_parameter(self):
+        try:
+            StringValidator('aa.1245', max_len=3).check_string_length().check().is_valid()
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+
+        self.assertTrue(StringValidator('aa.1245', max_len=30).check().is_valid())
+        # default infinity
+        self.assertTrue(StringValidator('aa.124512132456').check().is_valid())
+
+    def test_string_validator_min_len_parameter(self):
+        try:
+            StringValidator('aa', min_len=3).check_string_length().check().is_valid()
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+
+        self.assertTrue(StringValidator('aa.', min_len=3).check().is_valid())
+        # default 0
+        self.assertTrue(StringValidator('1').check().is_valid())
+
+    def test_string_validator_can_be_transformed2int(self):
+        self.assertFalse(StringValidator('9' * 20).can_be_transformed2int().check().is_valid())
+        self.assertFalse(StringValidator('1,2').can_be_transformed2int().check().is_valid())
+        self.assertTrue(StringValidator('12').can_be_transformed2int().check().is_valid())
+        self.assertFalse(StringValidator('12').can_be_transformed2int(min_value=100, max_value=200).check().is_valid())
+
+    def test_directory_black_list(self):
+        try:
+            DirectoryValidator('/abc/d/e').with_blacklist(lst=['/abc/d/e']).check().is_valid()
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+
+        self.assertTrue(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/']).check().is_valid())
+        self.assertTrue(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=True).check()
+                        .is_valid())
+        # without exact compare, /abc/d/e is a child path of /abc/d/, so it is invalid
+        try:
+            self.assertFalse(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=False)
+                             .check().is_valid())
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+        self.assertTrue(DirectoryValidator('/usr/bin/bash').with_blacklist().check().is_valid())
+
+        try:
+            DirectoryValidator('/usr/bin/bash').with_blacklist(exact_compare=False).check().is_valid()
+        except ValueError as exp:
+            self.assertEqual(type(exp), ValueError)
+        else:
+            self.fail("ValueError not raised.")
+
+    def test_remove_prefix(self):
+        self.assertEqual(DirectoryValidator.remove_prefix('/usr/bin', None)[1], '/usr/bin')
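+        # Note (inferred from the assertions in this test): remove_prefix returns a
+        # pair whose second element is the path with the prefix stripped; a None or
+        # empty prefix leaves the path unchanged, e.g.
+        #   DirectoryValidator.remove_prefix('/usr/bin/python', '/usr/bin')[1] == '/python'
+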
self.assertEqual(DirectoryValidator.remove_prefix('/usr/bin', '')[1], '/usr/bin') + self.assertIsNone(DirectoryValidator.remove_prefix(None, 'abc')[1]) + self.assertEqual(DirectoryValidator.remove_prefix('/usr/bin/python', '/usr/bin')[1], "/python") + + def test_directory_white_list(self): + self.assertTrue(DirectoryValidator.check_is_children_path('/abc/d', '/abc/d/e')) + self.assertTrue(DirectoryValidator.check_is_children_path('/abc/d', '/abc/d/')) + self.assertFalse(DirectoryValidator.check_is_children_path('/abc/d', '/abc/de')) + self.assertTrue(DirectoryValidator.check_is_children_path('/usr/bin/', '/usr/bin/bash')) + + def test_directory_soft_link(self): + tmp = tempfile.NamedTemporaryFile(delete=True) + temp_dir = tempfile.mkdtemp() + path = os.path.join(temp_dir, 'link.ink') + # make a soft link + os.symlink(tmp.name, path) + + try: + # do stuff with temp + tmp.write(b'stuff') + DirectoryValidator(path).check_not_soft_link().check().is_valid() + except ValueError as exp: + self.assertEqual(type(exp), ValueError) + else: + self.fail("ValueError not raised.") + finally: + tmp.close() # close means remove + os.remove(path) + os.removedirs(temp_dir) + + def test_directory_check(self): + + try: + DirectoryValidator('a/b/.././c/a.txt').check_not_soft_link().check().is_valid() + except ValueError as exp: + self.assertEqual(type(exp), ValueError) + else: + self.fail("ValueError not raised.") + + try: + DirectoryValidator("").check_is_not_none().check().is_valid() + except ValueError as exp: + self.assertEqual(type(exp), ValueError) + else: + self.fail("ValueError not raised.") + + try: + DirectoryValidator(None).check_is_not_none().check().is_valid() + except ValueError as exp: + self.assertEqual(type(exp), ValueError) + else: + self.fail("ValueError not raised.") + self.assertTrue(DirectoryValidator("a/bc/d").check().is_valid()) + + +if __name__ == '__main__': + unittest.main() -- Gitee From 524d306c7a9c31c91e21b74e81e91efc5b12bde8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 4 Aug 2023 16:18:49 +0800 Subject: [PATCH 230/551] Match-id-08ba26c1d59f2f04346b883a5c62a716973f6083 --- mx_rec/saver/sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index b4edb2cd..fbc54d4c 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -46,7 +46,7 @@ class SparseProcessor: @staticmethod def _get_data(data_dir, dtype, data_shape): - with open(data_dir, "rb", encoding="utf-8") as file: + with open(data_dir, "rb") as file: # check whether data file is valid file_validator = FileValidator(data_dir) # 1.check whether data_dir is soft link -- Gitee From 8b6cd1c2f33333569b6773f9d7fde605a663144b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 4 Aug 2023 11:39:35 +0800 Subject: [PATCH 231/551] Match-id-23f338b74522337c6220ca88b8fb8c5138bdcb0b --- example/little_demo/main.py | 8 +++++--- example/little_demo/run_mode.py | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index d2840d5f..35569f2f 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -198,11 +198,13 @@ if __name__ == "__main__": TRAIN_STEPS, EVAL_STEPS ) - if MODIFY_GRAPH_FLAG: + # start host pipeline + if not MODIFY_GRAPH_FLAG: + start_asc_pipeline() + # start modify graph + if MODIFY_GRAPH_FLAG and use_mode != UseMode.TRAIN: logging.info("start to modifying graph") modify_graph_and_start_emb_cache(dump_graph=True) - else: - start_asc_pipeline() if 
use_mode == UseMode.TRAIN:
         run_mode.train(EVAL_INTERVAL, SAVING_INTERVAL)
diff --git a/example/little_demo/run_mode.py b/example/little_demo/run_mode.py
index bd068b5c..04b6ecc2 100644
--- a/example/little_demo/run_mode.py
+++ b/example/little_demo/run_mode.py
@@ -12,6 +12,7 @@ from mx_rec.util.initialize import get_initializer, get_rank_id, get_rank_size,
 from mx_rec.util.variable import get_dense_and_sparse_variable
 from mx_rec.util.tf_version_adapter import hccl_ops
 from mx_rec.constants.constants import BaseEnum
+from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
 
 
 class UseMode(BaseEnum):
@@ -85,7 +86,11 @@ class RunMode:
 
     def train(self, eval_interval: int, saving_interval: int):
         self.set_train_ops()
+
+        # In train mode, graph modification must be performed after the gradients are computed
         if self.is_modify_graph:
+            logging.info("start to modifying graph")
+            modify_graph_and_start_emb_cache(dump_graph=True)
             self.session.run(get_initializer(True))
         else:
             self.session.run(self.train_iterator.initializer)
-- 
Gitee
From 115a6abe9dc9f7bf3b170753e6ab7d6ce468c1a2 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 4 Aug 2023 16:20:51 +0800
Subject: [PATCH 232/551] Match-id-b1320ed77b5fe55d544887914207206bb1cf7477
---
 src/core/key_process/key_process.cpp       | 7 +++++--
 src/tests/key_process/key_process_test.cpp | 3 +++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index b0f3edbf..36c6c054 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -9,6 +9,7 @@
 #include
 #include
+#include <regex>
 
 #include
 
@@ -125,13 +126,15 @@ int KeyProcess::Start()
     int threadNum;
     for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) {
         const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM");
-        if (threadNumEnv != nullptr) {
-            threadNum = static_cast(*threadNumEnv) - static_cast('0');
+        if (threadNumEnv != nullptr && regex_match(threadNumEnv, regex("[0-9]+"))) {
+            threadNum = std::atoi(threadNumEnv);
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
                 throw runtime_error(StringFormat("%d is not valid", threadNum));
             }
         } else {
             threadNum = KEY_PROCESS_THREAD;
+            LOG(WARNING) << StringFormat("env KEY_PROCESS_THREAD_NUM is unset or invalid,"
+                                         " falling back to default thread num %d", threadNum);
         }
         LOG(INFO) << StringFormat(KEY_PROCESS "key process thread num: %d", threadNum);
         for (int id = 0; id < threadNum; ++id) {
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index 83cb9cee..82bfc048 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -259,6 +259,9 @@ TEST_F(KeyProcessTest, Start)
 {
     ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
+    setenv("KEY_PROCESS_THREAD_NUM", "2", 1);
+    ASSERT_EQ(process.Start(), 0);
+    setenv("KEY_PROCESS_THREAD_NUM", "abc", 1);
     ASSERT_EQ(process.Start(), 0);
     process.Destroy();
 }
-- 
Gitee
From f7ee352de4bd69c3ce25c9bbde04480b3aa134a2 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 7 Aug 2023 09:26:04 +0800
Subject: [PATCH 233/551] Match-id-25c900e4e4567a8a0e6b563a5a62dca2dea4d05c
---
 mx_rec/core/asc/helper.py      | 146 +---------------------
 mx_rec/core/asc/manager.py     |   6 +-
 mx_rec/core/asc/merge_table.py | 171 +++++++++++++++++++++++++++
 3 files changed, 179 insertions(+), 144 deletions(-)
 create mode 100644 mx_rec/core/asc/merge_table.py

diff --git a/mx_rec/core/asc/helper.py
index ac957463..e030f01e 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -4,15 +4,11 @@ import logging from functools import reduce -from typing import List -from typing import Dict import tensorflow as tf -from tensorflow import Tensor -from tensorflow import Operation -from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \ - get_enable_table_merge, export_table_instances, insert_dangling_table +from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static from mx_rec.core.asc.feature_spec import FeatureSpec +from mx_rec.core.asc.merge_table import find_dangling_table, should_skip def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None, @@ -48,136 +44,6 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names **kwargs) -def find_dangling_table(table_names: List[str]): - """ Find the tables which are disconenct with the forward training graph. And - these table will not be backward updated. - - :param table_names: list of all created tables' names - :return: a list of dangling table names. - """ - - def check_tensor(table_reachable_tensor: Tensor): - """Check whether the tensor op is optimizer op or backward gradient. - - Args: - table_reachable_tensor: tensor - Returns: - bool - """ - if table_reachable_tensor.op.type == 'ApplyAdam': - return True - - if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': - return True - - if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ - and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': - return True - - return False - - def find_table_op(table_name: str, - the_op: Operation, - table_lookup_op: Dict[str, List[Operation]], - table_reachable_tensor: Dict[str, List[Tensor]]): - """ find all the table lookup op. - :param table_name: tables' names - :param the_op: the op to be - :param table_lookup_op: list of the table lookup ops - :param table_reachable_tensor: the tensors which table lookup op can reach ( - here we just add the table lookup op's output tensors). - The data structure is map, key is table_name, value is the output tensors of table lookup op. 
- :return: None - """ - if table_name in the_op.name and the_op.type == "IdentityN": - if table_name not in table_lookup_op: - table_lookup_op[table_name] = [the_op] - table_reachable_tensor[table_name] = the_op.outputs - else: - table_lookup_op[table_name].append(the_op) - table_reachable_tensor[table_name].extend(the_op.outputs) - - op_list = tf.compat.v1.get_default_graph().get_operations() - - table_lookup_op = {} - table_reachable_tensor = {} - - for _, table_instance in export_table_instances().items(): - if table_instance.table_name not in table_names: - table_names.append(table_instance.table_name) - - for the_op in op_list: - for table_name in table_names: - find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - - logging.info(f"*********** find tables: {table_lookup_op} ***********") - - dangling_table = [] - - for table_name in table_names: - if table_name not in table_lookup_op: - logging.info(f"*********** created table {table_name} but never look up***********") - dangling_table.append(table_name) - insert_dangling_table(table_name) - - - def extend(op_list: List[Operation], - tensor: Tensor, - spread_tensors: List[Tensor]): - """extend the tensors which table lookup op can reach - - :param op_list: all op in the graph - :param tensor: the tensor visited by bfs - :param spread_tensors: the list of tensors which table lookup op can reach - :return: - """ - for the_op in op_list: - if tensor in the_op.inputs: - spread_tensors.extend(the_op.outputs) - - def bfs_lookup(next_to_visit: List[Tensor]): - """find all the tensors which table lookup op can reach - - :param next_to_visit: the tensor list to be visited by bfs - :return: bool value indicate whether reached optimizer op or backward gradient op - """ - tensors_visited = set() - while next_to_visit: - spread_tensors = [] - for tensor in next_to_visit: - if tensor in tensors_visited: - continue - if check_tensor(tensor): - return True - tensors_visited.add(tensor) - extend(op_list, tensor, spread_tensors) - next_to_visit = spread_tensors - return False - - for table_name, table_op in table_reachable_tensor.items(): - found = bfs_lookup(table_op) - if not found: - dangling_table.append(table_name) - insert_dangling_table(table_name) - return dangling_table - - -def should_skip(table_name): - from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ - and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \ - and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name: - return True - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ - and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, list): - skip = True - for key_word in ASCEND_TABLE_NAME_MUST_CONTAIN: - if isinstance(key_word, str) and key_word in table_name: - skip = False - break - return skip - return False - def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): @@ -221,9 +87,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ if feature_counts is None or table_names is None: raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - dangling_tables = [] - if get_enable_table_merge(): - dangling_tables = find_dangling_table(table_names) + dangling_tables = find_dangling_table(table_names) logging.info(f"In insert found dangling table(s): {dangling_tables} " f"which does not need to be provided to the EmbInfo.") @@ -241,12 +105,12 @@ def 
get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None,
                               table_names=None, **kwargs):
@@ -221,9 +87,7 @@
     if feature_counts is None or table_names is None:
         raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.")
 
-    dangling_tables = []
-    if get_enable_table_merge():
-        dangling_tables = find_dangling_table(table_names)
+    dangling_tables = find_dangling_table(table_names)
 
     logging.info(f"In insert found dangling table(s): {dangling_tables} "
                  f"which does not need to be provided to the EmbInfo.")
@@ -241,12 +105,12 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_
         new_insert_tensors, new_splits, new_table_names = [], [], []
         for idx, table_name in enumerate(table_names):
             if table_name in dangling_tables:
-                logging.info(f"do_insert skip table : {table_name}")
+                logging.info(f"do_insert skip table by graph : {table_name}")
                 continue
 
             skip = should_skip(table_name)
             if skip:
-                logging.info(f"do_insert skip table 2: {table_name}")
+                logging.info(f"do_insert skip table by keyword: {table_name}")
                 continue
 
             new_insert_tensors.append(insert_tensors[idx])
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index dfc8eb7a..95b1b846 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -11,8 +11,8 @@ from mx_rec.constants.constants import MxRecMode
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
     is_asc_manager_initialized, get_train_steps, get_eval_steps, get_prefetch_batch_number, \
     export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \
-    get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table
-from mx_rec.core.asc.helper import find_dangling_table, should_skip
+    get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table
+from mx_rec.core.asc.merge_table import find_dangling_table, should_skip
 
 
 def check_dangling_table():
@@ -21,7 +21,7 @@ def check_dangling_table():
     :return: list of dangling_table
     """
     dangling_table = export_dangling_table()
-    if not dangling_table and get_enable_table_merge():
+    if not dangling_table:
         dangling_table = find_dangling_table([table_instance.table_name
                                               for _, table_instance in export_table_instances().items()])
     return dangling_table
diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py
new file mode 100644
index 00000000..6fde2eff
--- /dev/null
+++ b/mx_rec/core/asc/merge_table.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import logging
+from typing import List
+from typing import Dict
+import tensorflow as tf
+from tensorflow import Tensor
+from tensorflow import Operation
+
+from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \
+    get_bool_gauge_set
+
+
+def check_op(table_reachable_op: Operation):
+    """Check whether the op is an optimizer op or a backward-gradient op.
+
+    Args:
+        table_reachable_op: the operation to check
+    Returns:
+        bool
+    """
+    if table_reachable_op.type == 'ApplyAdam':
+        return True
+
+    if 'gradients' in table_reachable_op.name and \
+            table_reachable_op.type in ['UnsortedSegmentSum', 'TensorScatterUpdate']:
+        return True
+
+    return False
+
+def find_dangling_table(table_names: List[str]):
+    """ Find the tables that are disconnected from the forward training graph.
+    These tables will not be updated in the backward pass.
+
+    :param table_names: list of all created tables' names
+    :return: a list of dangling table names.
+    """
+    if not is_train_task():
+        logging.info("merge table is only available in train tasks.")
+        return []
+    if not get_enable_table_merge():
+        return []
+
+
+    def find_table_op(table_name: str,
+                      the_op: Operation,
+                      table_lookup_op: Dict[str, List[Operation]],
+                      table_reachable_tensor: Dict[str, List[Tensor]]):
+        """ Find all the table lookup ops.
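+        A lookup op is assumed here to be an IdentityN op whose name contains the
+        table name (this mirrors the type/name check in the body below).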
+        :param table_name: the table name to match
+        :param the_op: the op to be checked
+        :param table_lookup_op: list of the table lookup ops
+        :param table_reachable_tensor: the tensors reachable from the table lookup ops (here
+            we just add the lookup ops' output tensors). The data structure is a map whose
+            key is table_name and whose value is the output tensors of that table's lookup ops.
+        :return: None
+        """
+        if table_name in the_op.name and the_op.type == "IdentityN":
+            if table_name not in table_lookup_op:
+                table_lookup_op[table_name] = [the_op]
+                table_reachable_tensor[table_name] = []
+                table_reachable_tensor[table_name].extend(the_op.outputs)
+            elif the_op not in table_lookup_op[table_name]:
+                table_lookup_op[table_name].append(the_op)
+                table_reachable_tensor[table_name].extend(the_op.outputs)
+
+    op_list = tf.compat.v1.get_default_graph().get_operations()
+
+    table_lookup_op = {}
+    table_reachable_tensor = {}
+
+    for _, table_instance in export_table_instances().items():
+        if table_instance.table_name not in table_names:
+            table_names.append(table_instance.table_name)
+
+    for the_op in op_list:
+        for table_name in table_names:
+            find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor)
+
+    logging.debug(f"*********** find tables: {table_lookup_op} ***********")
+
+    dangling_table = []
+
+    for table_name in table_names:
+        if table_name not in table_lookup_op:
+            logging.debug(f"*********** created table {table_name} but it is never looked up***********")
+            dangling_table.append(table_name)
+            insert_dangling_table(table_name)
+
+
+    def extend(op_list: List[Operation],
+               tensor: Tensor,
+               spread_tensors: List[Tensor]):
+        """Extend the set of tensors reachable from the table lookup ops.
+
+        :param op_list: all ops in the graph
+        :param tensor: the tensor visited by bfs
+        :param spread_tensors: the list of tensors reachable from the table lookup ops
+        :return: None
+        """
+        for the_op in op_list:
+            if tensor in the_op.inputs:
+                spread_tensors.extend(the_op.outputs)
+
+    def bfs_lookup(next_to_visit: List[Tensor]):
+        """Find all the tensors reachable from the table lookup ops.
+
+        :param next_to_visit: the tensor list to be visited by bfs
+        :return: a tuple (visited ops, bool), where the bool indicates whether an
+            optimizer op or a backward-gradient op was reached
+        """
+        tensors_visited = set()
+        op_visited = set()
+        while next_to_visit:
+            spread_tensors = []
+            for tensor in next_to_visit:
+                if tensor in tensors_visited:
+                    continue
+                if check_op(tensor.op):
+                    return op_visited, True
+                tensors_visited.add(tensor)
+                op_visited.add(tensor.op)
+                extend(op_list, tensor, spread_tensors)
+            next_to_visit = spread_tensors
+        return op_visited, False
+
+    for table_name, table_op in table_reachable_tensor.items():
+        reach_op, found = bfs_lookup(table_op)
+        affirm = False
+        if not found:
+            # for/else: affirm only when every reached op is a trivial IdentityN/Reshape
+            for node in reach_op:
+                if node.type not in ["IdentityN", "Reshape"]:
+                    break
+            else:
+                affirm = True
+        if affirm:
+            dangling_table.append(table_name)
+            insert_dangling_table(table_name)
+    return dangling_table
+
+
+def should_skip(table_name):
+    from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN
+    if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \
+            and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \
+            and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name:
+        return True
+    if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \
+            and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, list):
+        skip = True
+        for key_word in ASCEND_TABLE_NAME_MUST_CONTAIN:
+            if isinstance(key_word, str) and key_word in table_name:
+                skip = False
+                break
+        return skip
+    return False
+
+def is_train_task():
+    bool_gauge_set =
get_bool_gauge_set() + if len(bool_gauge_set) > 0: + if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set: + return True + if 'predict' in bool_gauge_set: + return False + else: + op_list = tf.compat.v1.get_default_graph().get_operations() + for op in op_list: + if check_op(op): + return True + return False -- Gitee From 55f42104c2e871dfa0440f89fa5abea55955419f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 7 Aug 2023 16:04:56 +0800 Subject: [PATCH 234/551] Match-id-9a941402d35956e5edd4178d307492f097614ac5 --- example/little_demo/main.py | 18 +- example/little_demo/run.sh | 3 +- mx_rec/core/asc/feature_spec.py | 15 +- mx_rec/core/asc/manager.py | 7 +- mx_rec/optimizers/base.py | 4 +- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/checkpoint/checkpoint.h | 2 +- .../ckpt_data_handler/ckpt_data_handler.h | 2 +- .../feat_admit_n_evict_ckpt.cpp | 69 ++++--- .../feat_admit_n_evict_ckpt.h | 12 +- src/core/emb_hashmap/emb_hashmap.cpp | 4 +- src/core/host_emb/host_emb.cpp | 14 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +- .../key_process/feature_admit_and_evict.cpp | 109 ++++++---- .../key_process/feature_admit_and_evict.h | 18 +- src/core/key_process/key_process.cpp | 8 +- src/core/utils/common.cpp | 16 ++ src/core/utils/common.h | 17 +- src/pybind/module_main.cpp | 7 +- src/tests/checkpoint/checkpoint_test.cpp | 63 ++++-- .../ckpt_data_handler_test.cpp | 195 ++++++++++++------ src/tests/emb_mgmt/emb_mgmt_test.cpp | 6 +- .../feature_admit_and_evict_test.cpp | 69 ++++--- 23 files changed, 440 insertions(+), 224 deletions(-) diff --git a/example/little_demo/main.py b/example/little_demo/main.py index 35569f2f..44d867b5 100644 --- a/example/little_demo/main.py +++ b/example/little_demo/main.py @@ -93,17 +93,21 @@ def create_feature_spec_list(use_timestamp=False): eviction_threshold = cfg.eviction_threshold if use_timestamp else None feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold), + eviction_threshold=eviction_threshold, + faae_coefficient=1), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold)] + eviction_threshold=eviction_threshold, + faae_coefficient=4)] if use_multi_lookup: feature_spec_list.extend([FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold), + eviction_threshold=eviction_threshold, + faae_coefficient=1), FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table", access_threshold=access_threshold, - eviction_threshold=eviction_threshold)]) + eviction_threshold=eviction_threshold, + faae_coefficient=4)]) if use_timestamp: feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) return feature_spec_list @@ -152,8 +156,10 @@ if __name__ == "__main__": # access_threshold unit counts; eviction_threshold unit seconds ACCESS_AND_EVICT = None if USE_TIMESTAMP: - config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) - config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold, + faae_coefficient=1) + config_for_item_table = dict(access_threshold=cfg.access_threshold, 
eviction_threshold=cfg.eviction_threshold, + faae_coefficient=4) ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table) train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP) diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh index 2cd84a80..1d5b8b57 100644 --- a/example/little_demo/run.sh +++ b/example/little_demo/run.sh @@ -83,7 +83,8 @@ export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion export USE_MULTI_LOOKUP=1 # 0: one lookup per table; 1: multiple lookups per table export USE_MODIFY_GRAPH=0 # 0: feature spec mode; 1: automatic graph modification mode export USE_TIMESTAMP=0 # 0: disable feature admit-and-evict; 1: enable feature admit-and-evict -export UpdateEmb_V2=0 # 0: synchronous UpdateEmb update; 1: asynchronous UpdateEmb_V2 update +export UpdateEmb_V2=1 # 0: synchronous UpdateEmb update; 1: asynchronous UpdateEmb_V2 update +export USE_COMBINE_FAAE=0 # 0: keep per-table FAAE history; 1: combine FAAE history across tables ################# Performance tuning #################### export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 export FAST_UNIQUE=0 #if use fast unique diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index dd1bff70..50eede22 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -31,6 +31,7 @@ class FeatureSpec: self._feat_cnt = kwargs.get("feat_count") self._access_threshold = kwargs.get("access_threshold") self._eviction_threshold = kwargs.get("eviction_threshold") + self._faae_coefficient = kwargs.get("faae_coefficient", 1) self._is_timestamp = kwargs.get("is_timestamp") self.feat_pos_train = None self.feat_pos_eval = None @@ -56,6 +57,10 @@ class FeatureSpec: def eviction_threshold(self): return self._eviction_threshold + @property + def faae_coefficient(self): + return self._faae_coefficient + @property def index_key(self): return self._index_key @@ -120,6 +125,11 @@ class FeatureSpec: if self._eviction_threshold > MAX_INT32: raise ValueError(f"Eviction_threshold is too big that exceed int32.") + if self._faae_coefficient is not None: + check_natural_number(self._faae_coefficient, "faae_coefficient") + if self._faae_coefficient > MAX_INT32: + raise ValueError("faae_coefficient is too big; it exceeds int32.") + if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") @@ -197,10 +207,13 @@ class FeatureSpec: def get_feature_spec(table_name, access_and_evict_config): access_threshold = None eviction_threshold = None + faae_coefficient = None if access_and_evict_config: access_threshold = access_and_evict_config.get("access_threshold") eviction_threshold = access_and_evict_config.get("eviction_threshold") - return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold) + faae_coefficient = access_and_evict_config.get("faae_coefficient", 1) + return FeatureSpec(table_name, access_threshold=access_threshold, eviction_threshold=eviction_threshold, + faae_coefficient=faae_coefficient) def set_temporary_feature_spec_attribute(mock_feature_spec: FeatureSpec, total_feature_count: Union[int, tf.Tensor]): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index dfc8eb7a..3892837e 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -162,16 +162,19 @@ def generate_threshold_list(): threshold_list = [] for _, feature_spec in export_feature_spec().items(): + coef = 1 if feature_spec.faae_coefficient is None else feature_spec.faae_coefficient if feature_spec.eviction_threshold: threshold = ThresholdValue(feature_spec.table_name, 
feature_spec.access_threshold, - feature_spec.eviction_threshold) + feature_spec.eviction_threshold, + coef) threshold_list.append(threshold) continue if feature_spec.access_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, - -1) + -1, + coef) threshold_list.append(threshold) return threshold_list diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 9ad62ec6..632e8ba3 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -41,7 +41,7 @@ class CustomizedOptimizer: self.base_name = name -def my_update_op(self, opt, grad): +def custom_update_op(self, opt, grad): if isinstance(grad, ops.Tensor): update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access return update_op @@ -50,5 +50,5 @@ def my_update_op(self, opt, grad): def patch_for_optimizer(): - _TensorProcessor.update_op = my_update_op + _TensorProcessor.update_op = custom_update_op logging.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 3c5e92df..dfedfe5c 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -69,7 +69,7 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) if (!ckptData.keyOffsetMap.empty()) { dataHandlers.push_back(make_unique()); } - if (!ckptData.tens2Thresh.empty() && !ckptData.histRec.timestamps.empty() && + if (!ckptData.table2Thresh.empty() && !ckptData.histRec.timestamps.empty() && !ckptData.histRec.historyRecords.empty()) { dataHandlers.push_back(make_unique()); } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 4baa6f0c..4c3abc53 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -42,7 +42,7 @@ namespace MxRec { CkptDataType::EMB_INFO, CkptDataType::EMB_CURR_STAT, CkptDataType::NDDR_OFFSET, - CkptDataType::TENSOR_2_THRESH + CkptDataType::TABLE_2_THRESH }; const set int64TransSet{ CkptDataType::EMB_HASHMAP, diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index ecc7907c..aea1d2b7 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -47,7 +47,7 @@ namespace MxRec { "embedding_current_status", "max_offset", "key_offset_map", - "tensor_2_threshold", + "table_2_threshold", "history_record" }; const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8 }; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 1e4e9d69..a2d6f96e 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -14,18 +14,18 @@ using namespace MxRec; void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) { ClearData(); - if (processData.tens2Thresh.empty() || processData.histRec.timestamps.empty() || + if (processData.table2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { LOG(ERROR) << "Missing Feature Admit and Evict data"; throw std::runtime_error("Missing Feature Admit and Evict data"); } - saveTens2Thresh = std::move(processData.tens2Thresh); + saveTable2Thresh = std::move(processData.table2Thresh); saveHistRec = std::move(processData.histRec); } void 
FeatAdmitNEvictCkpt::GetProcessData(CkptData& processData) { - processData.tens2Thresh = std::move(loadTens2Thresh); + processData.table2Thresh = std::move(loadTable2Thresh); processData.histRec = std::move(loadHistRec); ClearData(); } @@ -43,7 +43,7 @@ vector FeatAdmitNEvictCkpt::GetDirNames() vector FeatAdmitNEvictCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveTens2Thresh) { + for (const auto& item : saveTable2Thresh) { embNames.push_back(item.first); } return embNames; @@ -51,8 +51,8 @@ vector FeatAdmitNEvictCkpt::GetEmbNames() CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::TENSOR_2_THRESH, - [=] { SetTens2ThreshTrans(embName); } }, + map> dataTransMap { { CkptDataType::TABLE_2_THRESH, + [=] { SetTable2ThreshTrans(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRecTrans(embName); } } }; CleanTransfer(); @@ -62,8 +62,8 @@ CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embN void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::TENSOR_2_THRESH, - [=] { SetTens2Thresh(embName); } }, + map> dataLoadMap { { CkptDataType::TABLE_2_THRESH, + [=] { SetTable2Thresh(embName); } }, { CkptDataType::HIST_REC, [=] { SetHistRec(embName); } } }; CleanTransfer(); @@ -73,27 +73,31 @@ void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, Ckpt void FeatAdmitNEvictCkpt::ClearData() { - saveTens2Thresh.clear(); - loadTens2Thresh.clear(); + saveTable2Thresh.clear(); + loadTable2Thresh.clear(); saveHistRec.timestamps.clear(); saveHistRec.historyRecords.clear(); loadHistRec.timestamps.clear(); loadHistRec.historyRecords.clear(); } -void FeatAdmitNEvictCkpt::SetTens2ThreshTrans(string embName) +void FeatAdmitNEvictCkpt::SetTable2ThreshTrans(string embName) { - auto tens2ThreshSize = GetTens2ThreshSize(); + auto table2ThreshSize = GetTable2ThreshSize(); auto& transArr = transferData.int32Arr; - const auto& tens2Thresh = saveTens2Thresh.at(embName); + const auto& table2Thresh = saveTable2Thresh.at(embName); - transArr.reserve(tens2ThreshSize); - transArr.push_back(tens2Thresh.countThreshold); - transArr.push_back(tens2Thresh.timeThreshold); + transArr.reserve(table2ThreshSize); + transArr.push_back(table2Thresh.countThreshold); + transArr.push_back(table2Thresh.timeThreshold); } +// save void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) { + if (GetCombineSwitch()) { + embName = COMBINE_HISTORY_NAME; + } auto histRecSize = GetHistRecSize(embName); auto& transArr = transferData.int64Arr; const auto& timeStamp = saveHistRec.timestamps.at(embName); @@ -109,18 +113,22 @@ void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) } } -void FeatAdmitNEvictCkpt::SetTens2Thresh(string embName) +void FeatAdmitNEvictCkpt::SetTable2Thresh(string embName) { const auto& transArr = transferData.int32Arr; - auto& tens2Thresh = loadTens2Thresh[embName]; + auto& tens2Thresh = loadTable2Thresh[embName]; - tens2Thresh.tensorName = embName; + tens2Thresh.tableName = embName; tens2Thresh.countThreshold = transArr[countThresholdIdx]; tens2Thresh.timeThreshold = transArr[timeThresholdIdx]; } +// load void FeatAdmitNEvictCkpt::SetHistRec(string embName) { + if (GetCombineSwitch()) { + embName = COMBINE_HISTORY_NAME; + } const auto& transArr = transferData.int64Arr; const auto& attribute = transferData.attribute; auto& timestamp = loadHistRec.timestamps[embName]; @@ -129,17 +137,26 @@ void 
FeatAdmitNEvictCkpt::SetHistRec(string embName) timestamp = transArr.front(); size_t featItemInfoTotalSize = attribute.front() * static_cast<size_t>(featItemInfoSaveNum); - for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { - const auto& featureId = transArr[i + featureIdIdxOffset]; - const auto& count = transArr[i + countIdxOffset]; - const auto& lastTime = transArr[i + lastTimeIdxOffset]; + VLOG(GLOG_DEBUG) << StringFormat("====Start SetHistRec, name: %s, featItemInfoTotalSize: %ld", embName.c_str(), + featItemInfoTotalSize); - histRecs[featureId].count = static_cast<uint32_t>(count); - histRecs[featureId].lastTime = lastTime; + size_t process = 0; + size_t printPerStep = ((featItemInfoTotalSize / 100) > 0 ? (featItemInfoTotalSize / 100) : 1); + for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { + process = i % printPerStep; + if (process == 1) { + VLOG(GLOG_DEBUG) << StringFormat("====in SetHistRec, progress: %f", static_cast<double>(i) / featItemInfoTotalSize); + } + auto featureId = transArr[i + featureIdIdxOffset]; + auto count = transArr[i + countIdxOffset]; + auto lastTime = transArr[i + lastTimeIdxOffset]; + + histRecs.emplace(featureId, FeatureItemInfo(static_cast<uint32_t>(count), lastTime)); } + VLOG(GLOG_DEBUG) << StringFormat("====End SetHistRec, name: %s", embName.c_str()); } -int FeatAdmitNEvictCkpt::GetTens2ThreshSize() +int FeatAdmitNEvictCkpt::GetTable2ThreshSize() { auto& attribute = transferData.attribute; auto& attribSize = transferData.attributeSize; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index 2b12d315..268120c3 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -31,7 +31,7 @@ namespace MxRec { private: const vector<string> fileDirNames { "HashTable", "DDR" }; - const vector<CkptDataType> saveDataTypes { CkptDataType::TENSOR_2_THRESH, CkptDataType::HIST_REC }; + const vector<CkptDataType> saveDataTypes { CkptDataType::TABLE_2_THRESH, CkptDataType::HIST_REC }; const int featItemInfoSaveNum { 3 }; const int threshValSaveNum { 2 }; @@ -45,21 +45,21 @@ namespace MxRec { const int countIdxOffset { 1 }; const int lastTimeIdxOffset { 2 }; - tensor_2_thresh_mem_t saveTens2Thresh; - tensor_2_thresh_mem_t loadTens2Thresh; + table_2_thresh_mem_t saveTable2Thresh; + table_2_thresh_mem_t loadTable2Thresh; AdmitAndEvictData saveHistRec; AdmitAndEvictData loadHistRec; void ClearData(); - void SetTens2ThreshTrans(string embName); + void SetTable2ThreshTrans(string embName); void SetHistRecTrans(string embName); - void SetTens2Thresh(string embName); + void SetTable2Thresh(string embName); void SetHistRec(string embName); - int GetTens2ThreshSize(); + int GetTable2ThreshSize(); size_t GetHistRecSize(string embName); }; } diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index d1918cde..5e3b3dbb 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -94,8 +94,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t } embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); - LOG(INFO) << StringFormat("current dev emb usage:%d/[%d+%d]", embHashMap.maxOffset, embHashMap.devVocabSize, - embHashMap.hostVocabSize); + LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), 
embHashMap.maxOffset, + embHashMap.devVocabSize, embHashMap.hostVocabSize); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 5b1e9bc6..b9a82787 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -142,12 +142,15 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, return; } const Tensor& d2hEmb = tensors[0]; - LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb End missingkeys len = %d", missingKeysHostPos.size()); EASY_BLOCK("Update") const float* tensorPtr = d2hEmb.flat().data(); auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; auto& embData = hostEmbs[embName].embData; + VLOG(GLOG_DEBUG) << StringFormat(HOSTEMB + "embName:%s, UpdateEmb missingKeys len = %d, embeddingSize = %d, " + "embData.size = %d", embName.c_str(), missingKeysHostPos.size(), embeddingSize, + embData.size()); + #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ shared(missingKeysHostPos, tensorPtr, embData, embeddingSize) for (size_t i = 0; i < missingKeysHostPos.size(); i++) { @@ -176,7 +179,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI return; } TimeCost tc = TimeCost(); - LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb End missingkeys len = %d", missingKeysHostPos.size()); + EASY_BLOCK("Update") auto& embData = hostEmbs[embName].embData; auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; @@ -185,6 +188,13 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI throw runtime_error("Acl get tensor data from dataset failed."); } float* ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + + size_t elementSize = acltdtGetDataSizeFromItem(aclData); + size_t dimNum = acltdtGetDimNumFromItem(aclData); + VLOG(GLOG_DEBUG) << StringFormat(HOSTEMB + "embName:%s, UpdateEmb missingKeys len = %d, embeddingSize = %d," + " embData.size = %d, RecvAcl = %d, elementSize = %d, dimNum = %d", + embName.c_str(), missingKeysHostPos.size(), embeddingSize, embData.size(), + size, elementSize, dimNum); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) for (size_t j = 0; j < missingKeysHostPos.size(); j++) { auto& dst = embData[missingKeysHostPos[j]]; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 0f8b72f0..cd400a52 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -150,7 +150,7 @@ bool HybridMgmt::Save(const string savePath) auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); - saveData.tens2Thresh = featAdmitNEvict.GetTensorThresholds(); + saveData.table2Thresh = featAdmitNEvict.GetTableThresholds(); saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; } @@ -201,7 +201,7 @@ bool HybridMgmt::Load(const string& loadPath) } if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); - featAdmitNEvict.LoadTensorThresholds(loadData.tens2Thresh); + featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } diff --git 
a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 29b2da1b..80db3003 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -31,6 +31,7 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu LOG(ERROR) << "Config is error, feature admin-and-evict function is not available ...\n"; return false; } + SetCombineSwitch(); return true; } @@ -44,24 +45,27 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } - // 如果当前 tensorName 不在准入范围之内,则不进行“特征准入”逻辑 - std::string tensorName = batch->name; + std::string tableName = batch->name; + if (m_isCombine) { + tableName = COMBINE_HISTORY_NAME; + } absl::flat_hash_map mergeKeys; mergeKeys.reserve(splitKey.size()); PreProcessKeys(splitKey, keyCount, mergeKeys); std::lock_guard lock(m_syncMutexs); - auto iter = m_recordsData.historyRecords.find(tensorName); - if (iter == m_recordsData.historyRecords.end()) { // 之前tensorName没出现过时,数据初始化 + // 如果当前 tableName 不在准入范围之内,则不进行“特征准入”逻辑 + auto iter = m_recordsData.historyRecords.find(tableName); + if (iter == m_recordsData.historyRecords.end()) { // 之前tableName没出现过时,数据初始化 absl::flat_hash_map records(m_recordsInitSize); - m_recordsData.historyRecords[tensorName] = records; + m_recordsData.historyRecords[tableName] = records; } VLOG(GLOG_DEBUG) << StringFormat( - "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tensorName.c_str(), - m_recordsData.historyRecords[tensorName].size()); + "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tableName.c_str(), + m_recordsData.historyRecords[tableName].size()); - if (batch->timestamp > m_recordsData.timestamps[tensorName]) { - m_recordsData.timestamps[tensorName] = batch->timestamp; + if (batch->timestamp > m_recordsData.timestamps[tableName]) { + m_recordsData.timestamps[tableName] = batch->timestamp; } absl::flat_hash_map visitedRecords; for (auto& key : splitKey) { @@ -73,7 +77,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, auto it = visitedRecords.find(key); if (it == visitedRecords.end()) { visitedRecords[key] = true; - if (FeatureAdmitHelper(channel, tensorName, key, mergeKeys[key]) == + if (FeatureAdmitHelper(channel, batch->name, key, mergeKeys[key]) == FeatureAdmitType::FEATURE_ADMIT_FAILED) { visitedRecords[key] = false; key = -1; // 被淘汰的Feature ID @@ -87,32 +91,38 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, } if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat( - "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tensorName.c_str(), channel, + "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tableName.c_str(), channel, VectorToString(splitKey).c_str()); } return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } -FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tensorName, +FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, const std::string& tableNameOrigin, const int64_t featureId, const uint32_t featureCnt) { // “特征准入”逻辑 uint32_t currKeyCount = 0; - absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tensorName]; + std::string tableName = tableNameOrigin; + if (m_isCombine) { + tableName = COMBINE_HISTORY_NAME; + } + + absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tableName]; auto innerIt = 
historyRecordInfos.find(featureId); if (channel == TRAIN_CHANNEL_ID) { if (innerIt == historyRecordInfos.end()) { // 维护 m_historyRecords - FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tensorName]); + FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tableName]); + info.count *= m_table2Threshold[tableNameOrigin].faaeCoefficient; historyRecordInfos[featureId] = info; - currKeyCount = featureCnt; + currKeyCount = info.count; } else { // 维护 m_historyRecords FeatureItemInfo &info = historyRecordInfos[featureId]; - info.count += featureCnt; - info.lastTime = m_recordsData.timestamps[tensorName]; + info.count += m_table2Threshold[tableNameOrigin].faaeCoefficient * featureCnt; + info.lastTime = m_recordsData.timestamps[tableName]; currKeyCount = info.count; } } else if (channel == EVAL_CHANNEL_ID) { // eval @@ -122,7 +132,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con } // 准入条件判断 - if (currKeyCount >= static_cast(m_tensor2Threshold[tensorName].countThreshold)) { + if (currKeyCount >= static_cast(m_table2Threshold[tableNameOrigin].countThreshold)) { return FeatureAdmitType::FEATURE_ADMIT_OK; } @@ -132,8 +142,8 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con // 特征淘汰接口 void FeatureAdmitAndEvict::FeatureEvict(map>& evictKeyMap) { - std::vector tensorNames = GetAllNeedEvictTensorNames(); - if (tensorNames.empty()) { + std::vector tableNames = GetAllNeedEvictTableNames(); + if (tableNames.empty()) { LOG(INFO) << "EmbNames is empty, no evict function ..."; return ; } @@ -143,9 +153,9 @@ void FeatureAdmitAndEvict::FeatureEvict(map> } std::lock_guard lock(m_syncMutexs); // 从 m_historyRecords 中淘汰删除 - size_t tensorCnt = tensorNames.size(); - for (size_t i = 0; i < tensorCnt; ++i) { - FeatureEvictHelper(tensorNames[i], evictKeyMap[tensorNames[i]]); + size_t tableCnt = tableNames.size(); + for (size_t i = 0; i < tableCnt; ++i) { + FeatureEvictHelper(tableNames[i], evictKeyMap[tableNames[i]]); } } @@ -153,7 +163,7 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v { // 从 m_historyRecords 中淘汰删除 time_t currTime = m_recordsData.timestamps[embName]; - // 从 m_tensor2SortedLastTime 获取当前要淘汰的featureId + // 从 m_table2SortedLastTime 获取当前要淘汰的featureId auto cmp = [](const auto& a, const auto& b) { return a.second.lastTime > b.second.lastTime; }; std::priority_queue, std::vector>, decltype(cmp)> lastTimePriority(cmp); @@ -161,7 +171,7 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v lastTimePriority.emplace(item); } while (!lastTimePriority.empty()) { - if (currTime - lastTimePriority.top().second.lastTime < m_tensor2Threshold[embName].timeThreshold) { + if (currTime - lastTimePriority.top().second.lastTime < m_table2Threshold[embName].timeThreshold) { break; } evictKey.emplace_back(lastTimePriority.top().first); @@ -170,11 +180,11 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v if (evictKey.size() == 0) { LOG(INFO) << StringFormat( - "tensor-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); + "table-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); return; } LOG(INFO) << StringFormat( - "tensor-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); + "table-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); // 真正从 m_historyRecords 中淘汰 absl::flat_hash_map& 
historyRecords = m_recordsData.historyRecords[embName]; @@ -196,6 +206,12 @@ void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) } m_isEnableFunction = isEnableEvict; } + +void FeatureAdmitAndEvict::SetCombineSwitch() +{ + m_isCombine = GetCombineSwitch(); +} + bool FeatureAdmitAndEvict::GetFunctionSwitch() const { return m_isEnableFunction; @@ -222,10 +238,10 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t const std::vector& embNames, bool isTimestamp) { for (size_t i = 0; i < thresholds.size(); ++i) { - auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tensorName); + auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tableName); if (it == embNames.end()) { // 配置不存在于当前跑的模型,也要报错 LOG(ERROR) << StringFormat( - "embName[%s] is not exist at current model ...", thresholds[i].tensorName.c_str()); + "embName[%s] is not exist at current model ...", thresholds[i].tableName.c_str()); return false; } else { // 同时支持“准入&淘汰”,却没有传时间戳 @@ -242,10 +258,10 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t return true; } -auto FeatureAdmitAndEvict::GetTensorThresholds() -> tensor_2_thresh_mem_t +auto FeatureAdmitAndEvict::GetTableThresholds() -> table_2_thresh_mem_t { std::lock_guard lock(m_syncMutexs); - return m_tensor2Threshold; + return m_table2Threshold; } auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& @@ -254,10 +270,10 @@ auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& return m_recordsData; } -void FeatureAdmitAndEvict::LoadTensorThresholds(tensor_2_thresh_mem_t& loadData) +void FeatureAdmitAndEvict::LoadTableThresholds(table_2_thresh_mem_t& loadData) { std::lock_guard lock(m_syncMutexs); - m_tensor2Threshold = std::move(loadData); + m_table2Threshold = std::move(loadData); } void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) @@ -266,7 +282,7 @@ void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) m_recordsData = std::move(loadData); } -// 解析m_tensor2Threshold +// 解析m_table2Threshold bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) { if (thresholdValues.empty()) { @@ -277,23 +293,26 @@ bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& m_cfgThresholds = thresholdValues; for (const auto& value : thresholdValues) { LOG(INFO) << StringFormat( - "embName[%s], count[%d], time[%d] ...", - value.tensorName.c_str(), value.countThreshold, value.timeThreshold); - auto it = m_tensor2Threshold.find(value.tensorName); - if (it != m_tensor2Threshold.end()) { + "embName[%s], count[%d], time[%d], coefficient[%d] ...", + value.tableName.c_str(), value.countThreshold, value.timeThreshold, value.faaeCoefficient); + auto it = m_table2Threshold.find(value.tableName); + if (it != m_table2Threshold.end()) { // train和eval同时开启,会出现表重复配置 - LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tensorName.c_str()); + LOG(INFO) << StringFormat("[%s] is repeated configuration ...", value.tableName.c_str()); return true; } - m_tensor2Threshold[value.tensorName] = value; + m_table2Threshold[value.tableName] = value; + if (value.faaeCoefficient < 1) { + LOG(ERROR) << StringFormat("[%s] config error, coefficient smaller than 1 ...", value.tableName.c_str()); + } if (value.countThreshold != -1 && value.timeThreshold != -1) { - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_BOTH; + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_BOTH; } else if (value.countThreshold != -1 && 
value.timeThreshold == -1) { - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; } else { - LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tensorName.c_str()); - m_embStatus[value.tensorName] = SingleEmbTableStatus::SETS_ERROR; + LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tableName.c_str()); + m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ERROR; return false; } } @@ -301,7 +320,7 @@ bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& return true; } -std::vector FeatureAdmitAndEvict::GetAllNeedEvictTensorNames() +std::vector FeatureAdmitAndEvict::GetAllNeedEvictTableNames() { std::vector names; std::lock_guard lock(m_syncMutexs); diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 5dfcc3e0..85219c8a 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -58,10 +58,13 @@ namespace MxRec { // 特征淘汰接口 void FeatureEvict(map>& evictKeyMap); void ExecuteFeatureAdmit( - const string& tensorName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); + const string& tableName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); // 特征淘汰的使能接口 void SetFunctionSwitch(bool isEnableEvict); + // 特征准入合表统计 + void SetCombineSwitch(); + bool GetFunctionSwitch() const; void PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, absl::flat_hash_map& mergeKeys); @@ -71,10 +74,10 @@ namespace MxRec { const std::vector& embNames, bool isTimestamp); // 与模型保存加载交互的接口 - auto GetTensorThresholds() -> tensor_2_thresh_mem_t; + auto GetTableThresholds() -> table_2_thresh_mem_t; auto GetHistoryRecords() -> AdmitAndEvictData&; - void LoadTensorThresholds(tensor_2_thresh_mem_t& loadData); + void LoadTableThresholds(table_2_thresh_mem_t& loadData); void LoadHistoryRecords(AdmitAndEvictData& loadData); static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 @@ -82,17 +85,18 @@ namespace MxRec { GTEST_PRIVATE : - // 解析m_tensor2Threshold + // 解析m_table2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); - std::vector GetAllNeedEvictTensorNames(); - FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tensorName, + std::vector GetAllNeedEvictTableNames(); + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableName, const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); bool m_isEnableFunction { true }; // “特征淘汰”的使能开关 bool m_isExit { false }; // 淘汰线程退出的标识 - absl::flat_hash_map m_tensor2Threshold; // tensor-X ---> ThresholdValue 映射 + bool m_isCombine { false }; // 是否合并统计history + absl::flat_hash_map m_table2Threshold; // table-X ---> ThresholdValue 映射 AdmitAndEvictData m_recordsData; std::mutex m_syncMutexs; // 特征准入与特征淘汰竞争的同步锁 int m_recordsInitSize { DEFAULT_RECORDS_INIT_SIZE }; // m_historyRecords表初始容量 diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b0f3edbf..f24f58ff 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1040,7 +1040,8 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, embInfos[embName].devVocabSize); throw 
std::runtime_error("dev cache overflow!"); } - VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current hbm emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, + embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1074,7 +1075,8 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s key = 0; } } - VLOG(GLOG_DEBUG) << StringFormat("current dev emb usage:%d/%d", maxOffsetTmp, embInfos[embName].devVocabSize); + VLOG(GLOG_DEBUG) << StringFormat("current expansion emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, + embInfos[embName].devVocabSize); VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); } @@ -1133,6 +1135,7 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in return move(t); } +// DDR keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { TimeCost tc = TimeCost(); @@ -1159,6 +1162,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } } +// HBM unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { TimeCost tc = TimeCost(); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 706c16ed..822febea 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -141,4 +141,20 @@ namespace MxRec { throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret)); } + + bool GetCombineSwitch() + { + const char* faaeMode = std::getenv("USE_COMBINE_FAAE"); // 获取环境变量 + bool isCombine = false; + if (faaeMode != nullptr) { + try { + isCombine = (std::stoi(faaeMode) == 1); + LOG(INFO) << StringFormat("If combine history table: %d", isCombine); + } catch (const std::invalid_argument& e) { + LOG(ERROR) << "The value of USE_COMBINE_FAAE is invalid!"; + throw std::invalid_argument("Invalid env value USE_COMBINE_FAAE"); + } + } + return isCombine; + } } // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index f51d50c0..60d4a97b 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -88,6 +88,8 @@ namespace MxRec { constexpr int HOT_EMB_UPDATE_STEP_DEFAULT = 1000; constexpr float HOT_EMB_CACHE_PCT = static_cast(1. 
/ 3); // hot emb cache percent + const string COMBINE_HISTORY_NAME = "combine_table_history"; + using emb_key_t = int64_t; using emb_name_t = std::string; using keys_t = std::vector; @@ -102,6 +104,7 @@ namespace MxRec { }; string GetChipName(int devID); + bool GetCombineSwitch(); namespace UBSize { const int ASCEND910_PREMIUM_A = 262144; @@ -239,16 +242,18 @@ struct BatchTask { struct ThresholdValue { ThresholdValue() = default; - ThresholdValue(emb_name_t name, int countThre, int timeThre) + ThresholdValue(emb_name_t name, int countThre, int timeThre, int faaeCoef) { - tensorName = name; + tableName = name; countThreshold = countThre; timeThreshold = timeThre; + faaeCoefficient = faaeCoef; } - emb_name_t tensorName { "" }; // embName + emb_name_t tableName { "" }; // embName int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 + int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 }; struct FeatureItemInfo { @@ -492,7 +497,7 @@ struct BatchTask { using emb_hash_mem_t = absl::flat_hash_map; using offset_mem_t = std::map; using key_offset_mem_t = std::map>; - using tensor_2_thresh_mem_t = absl::flat_hash_map; + using table_2_thresh_mem_t = absl::flat_hash_map; using trans_serialize_t = uint8_t; using key_offset_map_t = std::map; using all_key_offset_map_t = std::map>; @@ -510,7 +515,7 @@ struct BatchTask { emb_hash_mem_t embHashMaps; offset_mem_t maxOffset; key_offset_mem_t keyOffsetMap; - tensor_2_thresh_mem_t tens2Thresh; + table_2_thresh_mem_t table2Thresh; AdmitAndEvictData histRec; }; @@ -532,7 +537,7 @@ struct BatchTask { EMB_CURR_STAT = 4, NDDR_OFFSET = 5, NDDR_FEATMAP = 6, - TENSOR_2_THRESH = 7, + TABLE_2_THRESH = 7, HIST_REC = 8, ATTRIBUTE = 9 }; diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index b81e5c6d..9c1b5fb2 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -169,8 +169,9 @@ void GetHybridMgmt(pybind11::module_& m) void GetThresholdValue(pybind11::module_& m) { pybind11::class_(m, "ThresholdValue") - .def(pybind11::init()) - .def_readwrite("tensor_name", &ThresholdValue::tensorName) + .def(pybind11::init()) + .def_readwrite("table_name", &ThresholdValue::tableName) .def_readwrite("count_threshold", &ThresholdValue::countThreshold) - .def_readwrite("time_threshold", &ThresholdValue::timeThreshold); + .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) + .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 33df7b23..1cf27378 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -175,17 +175,18 @@ protected: } } - void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold) + void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold) { for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tensorName = testEmbInfo.name; + val.tableName = testEmbInfo.name; val.countThreshold = offsetMem; val.timeThreshold = offsetMem; + val.faaeCoefficient = 1; offsetMem++; - testTens2Threshold[testEmbInfo.name] = move(val); + testTable2Threshold[testEmbInfo.name] = move(val); } } @@ -214,6 +215,30 @@ protected: timeStamp++; } } + + void SetHistRecCombine(AdmitAndEvictData& histRec) + { + int64_t featureId { int64Min }; + int count { 1 }; + time_t lastTime { 1000 }; + time_t timeStamp { 10000 }; 
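+        // With USE_COMBINE_FAAE=1, GetCombineSwitch() returns true and the admit/evict
+        // history of every table is accumulated under the single COMBINE_HISTORY_NAME key
+        // ("combine_table_history"), which is why this helper seeds only that one
+        // timestamp/record entry below.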
+ + auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; + auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; + + timestamps = timeStamp; + + for (int i = 0; i < count; ++i) { + historyRecords[featureId].count = count; + historyRecords[featureId].lastTime = lastTime; + + featureId++; + } + + count++; + lastTime++; + timeStamp++; + } }; TEST_F(CheckpointTest, HostEmbs) @@ -403,25 +428,31 @@ TEST_F(CheckpointTest, AllMgmt) TEST_F(CheckpointTest, FeatAdmitNEvict) { - tensor_2_thresh_mem_t testTrens2Thresh; - tensor_2_thresh_mem_t validTrens2Thresh; + table_2_thresh_mem_t testTrens2Thresh; + table_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; SetEmbInfo(); - SetTens2Threshold(testTrens2Thresh); + SetTable2Threshold(testTrens2Thresh); validTrens2Thresh = testTrens2Thresh; - SetHistRec(testHistRec); + + if (GetCombineSwitch()) { + SetHistRecCombine(testHistRec); + } else { + SetHistRec(testHistRec); + } + validHistRec = testHistRec; CkptData testSaveData; CkptData validLoadData; CkptData testLoadData; - testSaveData.tens2Thresh = testTrens2Thresh; + testSaveData.table2Thresh = testTrens2Thresh; testSaveData.histRec.timestamps = testHistRec.timestamps; testSaveData.histRec.historyRecords = testHistRec.historyRecords; - validLoadData.tens2Thresh = validTrens2Thresh; + validLoadData.table2Thresh = validTrens2Thresh; validLoadData.histRec = validHistRec; validLoadData.histRec.timestamps = validHistRec.timestamps; validLoadData.histRec.historyRecords = validHistRec.historyRecords; @@ -430,16 +461,16 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::FEAT_ADMIT_N_EVICT }); - EXPECT_EQ(validLoadData.tens2Thresh.size(), testLoadData.tens2Thresh.size()); + EXPECT_EQ(validLoadData.table2Thresh.size(), testLoadData.table2Thresh.size()); EXPECT_EQ(validLoadData.histRec.historyRecords.size(), testLoadData.histRec.historyRecords.size()); - for (const auto& it : validLoadData.tens2Thresh) { - EXPECT_EQ(1, testLoadData.tens2Thresh.count(it.first)); + for (const auto& it : validLoadData.table2Thresh) { + EXPECT_EQ(1, testLoadData.table2Thresh.count(it.first)); - const auto& tens2Thresh = testLoadData.tens2Thresh.at(it.first); + const auto& table2Thresh = testLoadData.table2Thresh.at(it.first); - EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); - EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); + EXPECT_EQ(it.second.tableName, table2Thresh.tableName); + EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); } for (const auto& it : validLoadData.histRec.timestamps) { diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 6caace33..55babf89 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -22,6 +22,17 @@ using valid_int64_t = absl::flat_hash_map>; using valie_dataset_t = absl::flat_hash_map>; using valid_attrib_t = absl::flat_hash_map>; +struct InputArgs { + vector& embNames; + CkptData& validData; + FeatAdmitNEvictCkpt& testCkpt; + valid_int_t& validTrens2ThreshArr; + valid_attrib_t& validTrens2ThreshAttrib; + valid_attrib_t& validHistRecAttrib; + valid_int64_t& 
validHistRecArr; + CkptData& testData; +}; + class CkptDataHandlerTest : public testing::Test { protected: int floatBytes { 4 }; @@ -68,7 +79,7 @@ protected: } } - void SetTens2Threshold(tensor_2_thresh_mem_t& testTens2Threshold, + void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold, valid_int_t& validArr, valid_attrib_t& validAttrib) { @@ -77,7 +88,7 @@ protected: for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; - val.tensorName = testEmbInfo.name; + val.tableName = testEmbInfo.name; val.countThreshold = countThreshold; val.timeThreshold = timeThreshold; @@ -86,7 +97,7 @@ protected: countThreshold++; timeThreshold++; - testTens2Threshold[testEmbInfo.name] = move(val); + testTable2Threshold[testEmbInfo.name] = move(val); validArr[testEmbInfo.name] = move(valid); validAttrib[testEmbInfo.name].push_back(2); // 2 is element num in one vector validAttrib[testEmbInfo.name].push_back(int32Bytes); @@ -128,12 +139,114 @@ protected: timeStamp++; } } + + void SetHistRecCombine(AdmitAndEvictData& histRec, valid_int64_t& validArr, valid_attrib_t& validAttrib) + { + int64_t featureId { int64Min }; + int count { 1 }; + time_t lastTime { 1000 }; + time_t timeStamp { 10000 }; + + auto& validA { validArr[COMBINE_HISTORY_NAME] }; + auto& historyRecords { histRec.historyRecords[COMBINE_HISTORY_NAME] }; + auto& timestamps { histRec.timestamps[COMBINE_HISTORY_NAME] }; + + timestamps = timeStamp; + validA.push_back(timeStamp); + + for (int i = 0; i < count; ++i) { + historyRecords[featureId].count = count; + historyRecords[featureId].lastTime = lastTime; + + validA.push_back(featureId); + validA.push_back(count); + validA.push_back(lastTime); + + featureId++; + } + + auto& attribute = validAttrib[COMBINE_HISTORY_NAME]; + attribute.push_back(count); + attribute.push_back(int64Bytes); + + count++; + lastTime++; + timeStamp++; + } + + void TestForSave(InputArgs& args) + { + for (const auto& embName : args.embNames) { + EXPECT_EQ(1, args.validData.table2Thresh.count(embName)); + + CkptTransData testSaveData = args.testCkpt.GetDataset(CkptDataType::TABLE_2_THRESH, embName); + EXPECT_EQ(args.validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method + EXPECT_EQ(args.validTrens2ThreshAttrib.at(embName), testSaveData.attribute); + testSaveData = args.testCkpt.GetDataset(CkptDataType::HIST_REC, embName); + + if (!GetCombineSwitch()) { + EXPECT_EQ(1, args.validData.histRec.timestamps.count(embName)); + EXPECT_EQ(1, args.validData.histRec.historyRecords.count(embName)); + EXPECT_EQ(args.validHistRecAttrib.at(embName), testSaveData.attribute); + } else { + EXPECT_EQ(1, args.validData.histRec.timestamps.count(COMBINE_HISTORY_NAME)); + EXPECT_EQ(1, args.validData.histRec.historyRecords.count(COMBINE_HISTORY_NAME)); + EXPECT_EQ(args.validHistRecAttrib.at(COMBINE_HISTORY_NAME), testSaveData.attribute); + } + } + } + void TestForLoad(InputArgs& args) + { + CkptTransData testLoadData; + for (const auto& embName : args.embNames) { + testLoadData.int32Arr = args.validTrens2ThreshArr.at(embName); + testLoadData.attribute = args.validTrens2ThreshAttrib.at(embName); + args.testCkpt.SetDataset(CkptDataType::TABLE_2_THRESH, embName, testLoadData); + + if (!GetCombineSwitch()) { + testLoadData.int64Arr = args.validHistRecArr.at(embName); + testLoadData.attribute = args.validHistRecAttrib.at(embName); + } else { + testLoadData.int64Arr = args.validHistRecArr.at(COMBINE_HISTORY_NAME); + testLoadData.attribute = args.validHistRecAttrib.at(COMBINE_HISTORY_NAME); + } + 
args.testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); + } + args.testCkpt.GetProcessData(args.testData); + + EXPECT_EQ(args.validData.table2Thresh.size(), args.testData.table2Thresh.size()); + EXPECT_EQ(args.validData.histRec.historyRecords.size(), args.testData.histRec.historyRecords.size()); + for (const auto& it : args.validData.table2Thresh) { + EXPECT_EQ(1, args.testData.table2Thresh.count(it.first)); + + const auto& table2Thresh = args.testData.table2Thresh.at(it.first); + + EXPECT_EQ(it.second.tableName, table2Thresh.tableName); + EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); + EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); + } + + for (const auto& it : args.validData.histRec.timestamps) { + EXPECT_EQ(1, args.testData.histRec.timestamps.count(it.first)); + EXPECT_EQ(1, args.testData.histRec.historyRecords.count(it.first)); + + const auto& historyRecords = args.testData.histRec.historyRecords.at(it.first); + const auto& validHistRec = args.validData.histRec.historyRecords.at(it.first); + + for (const auto& validHR : validHistRec) { + const auto& testHR = historyRecords.at(validHR.first); + + EXPECT_EQ(validHR.second.count, testHR.count); + EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); + } + } + } }; TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) { - tensor_2_thresh_mem_t testTrens2Thresh; - tensor_2_thresh_mem_t validTrens2Thresh; + table_2_thresh_mem_t testTrens2Thresh; + table_2_thresh_mem_t validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; @@ -143,77 +256,37 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) valid_attrib_t validHistRecAttrib; SetEmbInfo(); - SetTens2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); + SetTable2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); validTrens2Thresh = testTrens2Thresh; - SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); + + if (GetCombineSwitch()) { + SetHistRecCombine(testHistRec, validHistRecArr, validHistRecAttrib); + } else { + SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); + } validHistRec = testHistRec; CkptData testData; CkptData validData; FeatAdmitNEvictCkpt testCkpt; - testData.tens2Thresh = testTrens2Thresh; + testData.table2Thresh = testTrens2Thresh; testData.histRec.timestamps = testHistRec.timestamps; testData.histRec.historyRecords = testHistRec.historyRecords; - validData.tens2Thresh = validTrens2Thresh; + validData.table2Thresh = validTrens2Thresh; validData.histRec.timestamps = validHistRec.timestamps; validData.histRec.historyRecords = validHistRec.historyRecords; testCkpt.SetProcessData(testData); vector embNames { testCkpt.GetEmbNames() }; - CkptTransData testSaveData; - EXPECT_EQ(validData.tens2Thresh.size(), embNames.size()); - - for (const auto& embName : embNames) { - EXPECT_EQ(1, validData.tens2Thresh.count(embName)); - - EXPECT_EQ(1, validData.histRec.timestamps.count(embName)); - EXPECT_EQ(1, validData.histRec.historyRecords.count(embName)); + EXPECT_EQ(validData.table2Thresh.size(), embNames.size()); - testSaveData = testCkpt.GetDataset(CkptDataType::TENSOR_2_THRESH, embName); - EXPECT_EQ(validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method - EXPECT_EQ(validTrens2ThreshAttrib.at(embName), testSaveData.attribute); - testSaveData = testCkpt.GetDataset(CkptDataType::HIST_REC, embName); - EXPECT_EQ(validHistRecAttrib.at(embName), testSaveData.attribute); - } - - CkptTransData testLoadData; - for (const auto& embName 
: embNames) { - testLoadData.int32Arr = validTrens2ThreshArr.at(embName); - testLoadData.attribute = validTrens2ThreshAttrib.at(embName); - testCkpt.SetDataset(CkptDataType::TENSOR_2_THRESH, embName, testLoadData); - - testLoadData.int64Arr = validHistRecArr.at(embName); - testLoadData.attribute = validHistRecAttrib.at(embName); - testCkpt.SetDataset(CkptDataType::HIST_REC, embName, testLoadData); - } - testCkpt.GetProcessData(testData); + InputArgs args = {embNames, validData, testCkpt, validTrens2ThreshArr, validTrens2ThreshAttrib, + validHistRecAttrib, validHistRecArr, testData}; + // 测试save + TestForSave(args); - EXPECT_EQ(validData.tens2Thresh.size(), testData.tens2Thresh.size()); - EXPECT_EQ(validData.histRec.historyRecords.size(), testData.histRec.historyRecords.size()); - for (const auto& it : validData.tens2Thresh) { - EXPECT_EQ(1, testData.tens2Thresh.count(it.first)); - - const auto& tens2Thresh = testData.tens2Thresh.at(it.first); - - EXPECT_EQ(it.second.tensorName, tens2Thresh.tensorName); - EXPECT_EQ(it.second.countThreshold, tens2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, tens2Thresh.timeThreshold); - } - - for (const auto& it : validData.histRec.timestamps) { - EXPECT_EQ(1, testData.histRec.timestamps.count(it.first)); - EXPECT_EQ(1, testData.histRec.historyRecords.count(it.first)); - - const auto& historyRecords = testData.histRec.historyRecords.at(it.first); - const auto& validHistRec = validData.histRec.historyRecords.at(it.first); - - for (const auto& validHR : validHistRec) { - const auto& testHR = historyRecords.at(validHR.first); - - EXPECT_EQ(validHR.second.count, testHR.count); - EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); - } - } + // 测试load + TestForLoad(args); } \ No newline at end of file diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 752db8e2..a570915d 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -179,7 +179,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -199,7 +199,7 @@ TEST_F(EmbMgmtTest, Evict) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -222,7 +222,7 @@ TEST_F(EmbMgmtTest, Evict_HBM) embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." 
<< endl; diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index 31bfdd1e..f09eb3e2 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -181,48 +181,48 @@ protected: printf("\t############# [%s] tid[%lu] ############# begin ...\n", thrName.c_str(), std::hash{}(std::this_thread::get_id())); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys1 = {11, 11, 33, 44, 11, 55, 88, 55} cnt1 = 1 2 1 3 1 1 4 1 */ InputArgs args1 = {keys1, cnt1, {}, initHistory, {}}; // 每个表的第一次记录,要用initHistory追加 - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args1); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args1); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys2 = {11, 12, 33, 21, 11, 12} cnt2 = 1 2 1 1 2 3 */ InputArgs args2 = {keys2, cnt2, {}, args1.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args2); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args2); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys3 = {123, 121, 121, 212, 211} cnt3 = 1 2 1 1 2 */ InputArgs args3 = {keys3, cnt3, {}, initHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args3); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args3); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys4 = {11, 11, 33, 44, 55, 88, 55} cnt4 = 1 2 3 2 1 2 1 */ InputArgs args4 = {keys4, cnt4, {}, args2.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tensorName, args4); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[0].tableName, args4); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys5 = {125, 121, 122, 212, 211} cnt5 = 1 2 1 3 1 */ InputArgs args5 = {keys5, cnt5, {}, args3.expectHistory, {}}; - FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tensorName, args5); + FeatureAdmitCommonMultiThr(faae, 0, thresholds[1].tableName, args5); printf("\t############# [%s] tid[%lu] ############# end ...\n", thrName.c_str(), std::hash{}(std::this_thread::get_id())); @@ -263,58 +263,59 @@ protected: { faae.ResetAllRecords(); faae.ParseThresholdCfg(thresholds); + faae.SetCombineSwitch(); StartEvictThread(); printf("Current test single-thread is [%lu]\n", std::hash{}(std::this_thread::get_id())); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys1 = {11, 11, 33, 44, 11, 55, 88, 55} cnt1 = 1 2 1 3 1 1 4 1 */ keys_t expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}; InputArgs args1 = {keys1, cnt1, expectRet1, initHistory, {}}; // 每个表的第一次记录,要用initHistory追加 - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args1); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args1); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys2 = {11, 12, 33, 21, 11, 12} cnt2 = 1 2 1 1 2 3 */ keys_t expectRet2 = {11, 12, 33, -1, 11, 12}; InputArgs args2 = {keys2, cnt2, expectRet2, args1.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args2); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args2); 
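+        // The sleeps between these batches advance the wall clock carried in batch->timestamp;
+        // once currTime - lastTime reaches tableAAA's 5s timeThreshold (thresholds[0] =
+        // {"tableAAA", 2, 5, 1}), the background thread started by StartEvictThread()
+        // evicts the stale feature IDs from the history records.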
std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys3 = {123, 121, 121, 212, 211} cnt3 = 1 2 1 1 2 */ keys_t expectRet3 = {-1, 121, 121, -1, -1}; InputArgs args3 = {keys3, cnt3, expectRet3, initHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args3); + FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args3); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); /* - {"tensorAAA", 2, 5} + {"tableAAA", 2, 5} keys4 = {11, 11, 33, 44, 55, 88, 55} cnt4 = 1 2 3 2 1 2 1 */ keys_t expectRet4 = {11, 11, 33, 44, 55, 88, 55}; InputArgs args4 = {keys4, cnt4, expectRet4, args2.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[0].tensorName, args4); + FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args4); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); /* - {"tensorBBB", 3, 7} + {"tableBBB", 3, 7} keys5 = {125, 121, 122, 212, 211} cnt5 = 1 2 1 3 1 */ keys_t expectRet5 = {-1, 121, -1, 212, 211}; InputArgs args5 = {keys5, cnt5, expectRet5, args3.expectHistory, {}}; - FeatureAdmitCommon(faae, 0, thresholds[1].tensorName, args5); + FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args5); WaitEvictThread(); LOG(INFO) << "TestCase1: single thread test over ..."; @@ -331,7 +332,7 @@ protected: vector tmpCnt = {1, 2, 1, 3, 1, 1, 4}; std::unique_ptr batch = make_unique(); - batch->name = thresholds[0].tensorName; + batch->name = thresholds[0].tableName; batch->timestamp = time(nullptr); // 校验调接口,出错 @@ -375,6 +376,7 @@ protected: { faae.ResetAllRecords(); faae.ParseThresholdCfg(thresholds); + faae.SetCombineSwitch(); StartEvictThread(); std::thread thrs[PerfConfig::keyProcessThreadNum]; @@ -396,9 +398,9 @@ protected: { /* 如果没有淘汰功能 - tensorAAA数据将会是 {11, 12, 21, 33, 44, 55, 88} + tableAAA数据将会是 {11, 12, 21, 33, 44, 55, 88} 10 5 1 5 5 4 6 - tensorBBB数据将会是 {121, 122, 123, 125, 211, 212}; + tableBBB数据将会是 {121, 122, 123, 125, 211, 212}; 5 1 1 1 3 4 */ keys_t expectKeys1 = {11, 33, 44, 55, 88}; // 12,21被淘汰掉了 @@ -406,8 +408,8 @@ protected: keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123被淘汰掉了 vector expectCnt2 = {5, 1, 1, 3, 4}; std::lock_guard lock(faae.m_syncMutexs); // 与 evict-thread 竞争资源 - CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tensorName, PerfConfig::keyProcessThreadNum); - CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tensorName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tableName, PerfConfig::keyProcessThreadNum); } WaitEvictThread(); @@ -421,8 +423,8 @@ protected: faae.ParseThresholdCfg(thresholds); std::unique_ptr batch = make_unique(); - // 测试点:tensorDDD表没有配置阈值,则不支持 - batch->name = std::string("tensorDDD"); + // 测试点:tableDDD表没有配置阈值,则不支持 + batch->name = std::string("tableDDD"); batch->timestamp = time(nullptr); // 校验调接口,不支持 @@ -443,11 +445,21 @@ protected: vector cnt4 = {1, 2, 3, 2, 1, 2, 1}; keys_t keys5 = {125, 121, 122, 212, 211}; vector cnt5 = {1, 2, 1, 3, 1}; - std::vector thresholds = {{"tensorAAA", 2, 5}, {"tensorBBB", 3, 7}, {"tensorCCC", 5, 9}}; + std::vector thresholds = {{"tableAAA", 2, 5, 1}, {"tableBBB", 3, 7, 1}, {"tableCCC", 5, 9, 1}}; }; +void SetEnv() +{ + const char* name = "USE_COMBINE_FAAE"; + const char* mode = "0"; + int overwrite = 1; + + ASSERT_EQ(setenv(name, mode, overwrite), 0); +} + TEST_F(FeatureAdmitAndEvictTest, 
TestAdmitAndEvict1) { + SetEnv(); TestCase1(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict2) @@ -464,6 +476,7 @@ TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict4) } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict5) { + SetEnv(); TestCase5(); } TEST_F(FeatureAdmitAndEvictTest, TestAdmitAndEvict6) -- Gitee From 6c6d8b6c3fbc7d476d1749dd73da0ae3ae957105 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 7 Aug 2023 16:39:58 +0800 Subject: [PATCH 235/551] Match-id-c87aa8308206ca5214bab1bbadf297a3b8356237 --- mx_rec/constants/constants.py | 12 ++++++------ mx_rec/graph/utils.py | 5 +++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index dd46e7a4..544ea2a4 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -51,6 +51,8 @@ LOG_MAX_SIZE = 1024 * 1024 MAX_INT32 = np.iinfo(np.int32).max +DUMP_MIDIFY_GRAPH_FILE_MODE = 0o550 + class BaseEnum(Enum): @classmethod @@ -116,12 +118,10 @@ OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], class All2allGradientsOp(BaseEnum): - SUM_GRADIENTS = "sum_gradients" - SUM_GRADIENTS_AND_DIV_BY_RANKSIZE = "sum_gradients_and_div_by_ranksize" + SUM_GRADIENTS = "sum_gradients" + SUM_GRADIENTS_AND_DIV_BY_RANKSIZE = "sum_gradients_and_div_by_ranksize" class ApplyGradientsStrategy(BaseEnum): - DIRECT_APPLY = "direct_apply" - SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply" - - + DIRECT_APPLY = "direct_apply" + SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply" diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py index 5399f5d6..3b5f8148 100644 --- a/mx_rec/graph/utils.py +++ b/mx_rec/graph/utils.py @@ -3,9 +3,12 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
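As an aside on the new DUMP_MIDIFY_GRAPH_FILE_MODE constant above: 0o550 is an octal permission mask. A one-line way to decode it (plain Python, independent of mx_rec):

import stat

# 0o550 -> read/execute for owner and group, no access for others; note that
# os.makedirs(..., mode=0o550) is additionally subject to the process umask.
print(stat.filemode(0o550))  # -r-xr-x---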
from collections import defaultdict +import os import tensorflow as tf +from mx_rec.constants.constants import DUMP_MIDIFY_GRAPH_FILE_MODE + def check_input_list(objs, obj_type): if isinstance(objs, obj_type): @@ -72,6 +75,8 @@ def export_pb_graph(file_name, dump_graph, graph_def=None, export_path="./export :return: None """ if dump_graph: + dir_path = os.path.dirname(os.path.join(export_path, file_name)) + os.makedirs(dir_path, mode=DUMP_MIDIFY_GRAPH_FILE_MODE, exist_ok=True) graph_def = graph_def if graph_def else tf.compat.v1.get_default_graph().as_graph_def() tf.io.write_graph(graph_def, export_path, file_name, as_text) -- Gitee From 1befc6c28d33cc341172789d87478dcba8777405 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 7 Aug 2023 17:31:15 +0800 Subject: [PATCH 236/551] Match-id-96e3c57d82a1b44133bf565958f023fb605c1b21 --- mx_rec/core/asc/helper.py | 1 + mx_rec/core/asc/merge_table.py | 28 ++++++++++++++++------------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index a810ce81..32641a75 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -43,6 +43,7 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names table_names=table_names, **kwargs) + def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): both_none = tgt_key_specs is None and args_index_list is None diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 6fde2eff..55924ec6 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -25,11 +25,13 @@ def check_op(table_reachable_op: Operation): return True if 'gradients' in table_reachable_op.name and \ - table_reachable_op.type in ['UnsortedSegmentSum','TensorScatterUpdate']: + table_reachable_op.type in ['UnsortedSegmentSum', 'TensorScatterUpdate']: return True return False + + def find_dangling_table(table_names: List[str]): """ Find the tables which are disconenct with the forward training graph. And these table will not be backward updated. 
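The dangling-table search documented above is essentially a reachability walk: start from each table's lookup op and see whether any optimizer or gradient op can be reached. A toy, self-contained illustration of the same idea (TF1-style graph mode; the names toy_table and reaches_backward are ours for illustration, not mx_rec API):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

emb = tf.compat.v1.get_variable("toy_table", shape=[4, 2])
looked_up = tf.identity(emb, name="toy_table_lookup")  # stand-in for a table lookup
loss = tf.reduce_sum(looked_up)
tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(loss)


def reaches_backward(start_tensors):
    # BFS over tensor -> consumer edges; the table is "live" if a gradient op is reachable.
    visited, frontier = set(), list(start_tensors)
    while frontier:
        next_frontier = []
        for tensor in frontier:
            if tensor in visited:
                continue
            visited.add(tensor)
            for consumer in tensor.consumers():
                if "gradients" in consumer.name:  # stand-in for the check_op() test above
                    return True
                next_frontier.extend(consumer.outputs)
        frontier = next_frontier
    return False


print(reaches_backward([looked_up]))  # True: toy_table feeds the backward graph

A table whose lookup output never reaches a gradient op would return False here and, in mx_rec's version, be reported as dangling and skipped during insertion.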
@@ -79,13 +81,12 @@ def find_dangling_table(table_names: List[str]): for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - logging.debug(f"*********** find tables: {table_lookup_op} ***********") - + logging.debug("*********** find tables: %s ***********",table_lookup_op) dangling_table = [] for table_name in table_names: if table_name not in table_lookup_op: - logging.debug(f"*********** created table {table_name} but never look up***********") + logging.debug("*********** created table %s but never look up***********",table_name) dangling_table.append(table_name) insert_dangling_table(table_name) @@ -125,15 +126,17 @@ def find_dangling_table(table_names: List[str]): next_to_visit = spread_tensors return op_visited, False + def _affirm(reach_op:List[Operation]): + for node in reach_op: + if node.type not in ["IdentityN", "Reshape"]: + return False + return True + for table_name, table_op in table_reachable_tensor.items(): - reach_op,found = bfs_lookup(table_op) + reach_op, found = bfs_lookup(table_op) affirm = False if not found: - for node in reach_op: - if node.type not in ["IdentityN","Reshape"]: - break - else: - affirm = True + affirm = _affirm(reach_op) if affirm: dangling_table.append(table_name) insert_dangling_table(table_name) @@ -156,6 +159,7 @@ def should_skip(table_name): return skip return False + def is_train_task(): bool_gauge_set = get_bool_gauge_set() if len(bool_gauge_set) > 0: @@ -165,7 +169,7 @@ def is_train_task(): return False else: op_list = tf.compat.v1.get_default_graph().get_operations() - for op in op_list: - if check_op(op): + for t_op in op_list: + if check_op(t_op): return True return False -- Gitee From 35301188fff1f036acd5ff2faed2d0ca2a5ac4b0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 7 Aug 2023 17:40:18 +0800 Subject: [PATCH 237/551] Match-id-80afb6235b9bf8965b387af60ec0ae75a8609655 --- mx_rec/core/asc/merge_table.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 55924ec6..39cbfdf7 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -81,12 +81,12 @@ def find_dangling_table(table_names: List[str]): for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - logging.debug("*********** find tables: %s ***********",table_lookup_op) + logging.debug("*********** find tables: %s ***********", table_lookup_op) dangling_table = [] for table_name in table_names: if table_name not in table_lookup_op: - logging.debug("*********** created table %s but never look up***********",table_name) + logging.debug("*********** created table %s but never look up***********", table_name) dangling_table.append(table_name) insert_dangling_table(table_name) -- Gitee From 9854cbb07131ba7f93690e78ba7b991e4e393040 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 8 Aug 2023 09:34:58 +0800 Subject: [PATCH 238/551] Match-id-985989843bbabd332e885d31a7710f3e9e1cf6f6 --- src/core/key_process/key_process.cpp | 9 ++++++++- src/tests/key_process/key_process_test.cpp | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index f24f58ff..d2c7889d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -126,12 +126,19 @@ int KeyProcess::Start() for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { const char* threadNumEnv = 
getenv("KEY_PROCESS_THREAD_NUM"); if (threadNumEnv != nullptr) { - threadNum = static_cast(*threadNumEnv) - static_cast('0'); + try { + threadNum = std::stoi(threadNumEnv); + } catch (const std::invalid_argument& e) { + threadNum = KEY_PROCESS_THREAD; + LOG(WARNING) << StringFormat("error value of threadNum, use default KEY_PROCESS_THREAD: %d", + threadNum); + } if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { throw runtime_error(StringFormat("%d is not valid", threadNum)); } } else { threadNum = KEY_PROCESS_THREAD; + LOG(INFO) << StringFormat("use default KEY_PROCESS_THREAD: %d", threadNum); } LOG(INFO) << StringFormat(KEY_PROCESS "key process thread num: %d", threadNum); for (int id = 0; id < threadNum; ++id) { diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 83cb9cee..e7e97364 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -259,7 +259,11 @@ TEST_F(KeyProcessTest, Start) { ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); + setenv("KEY_PROCESS_THREAD_NUM", "2", 1); ASSERT_EQ(process.Start(), 0); + setenv("KEY_PROCESS_THREAD_NUM", "abc", 1); + ASSERT_EQ(process.Start(), 0); + CTRLog(0, "key process start successful"); process.Destroy(); } -- Gitee From 79f03ce66b63f6c5b8c0ac6ae54b580e1a8486da Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 8 Aug 2023 09:44:39 +0800 Subject: [PATCH 239/551] Match-id-5b96a136698f1b4bacc1bae1059d33c174c07fc4 --- mx_rec/constants/constants.py | 1 + mx_rec/util/communication/hccl_mgmt.py | 124 +++++++++++++++++++++++++ mx_rec/util/initialize.py | 107 +-------------------- 3 files changed, 130 insertions(+), 102 deletions(-) create mode 100644 mx_rec/util/communication/hccl_mgmt.py diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 544ea2a4..611457e5 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -52,6 +52,7 @@ LOG_MAX_SIZE = 1024 * 1024 MAX_INT32 = np.iinfo(np.int32).max DUMP_MIDIFY_GRAPH_FILE_MODE = 0o550 +MAX_DEVICE_ID = 15 class BaseEnum(Enum): diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py new file mode 100644 index 00000000..952671ca --- /dev/null +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ +import json +import os + +from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID +from mx_rec.validator.validator import RankInfoValidator, FileValidator + + +def parse_hccl_json(): + rank_table_path = os.path.realpath(os.getenv("RANK_TABLE_FILE")) + if not os.path.exists(rank_table_path): + raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") + + with open(rank_table_path, "r", encoding="utf-8") as file: + # check whether json file is valid + file_validator = FileValidator(rank_table_path) + # 1.check whether rank_table_path is soft link + file_validator.check_not_soft_link() + # 2.check json file size + file_validator.check_file_size(file, MAX_CONFIG_SIZE, MIN_SIZE) + file_validator.check() + + table_hccl = json.load(file) + if "server_list" not in table_hccl: + raise AttributeError(f"Lack of attribute server_list.") + if not table_hccl.get("server_list"): + raise ValueError(f"Server_list is empty.") + if "device" not in table_hccl.get("server_list")[0]: + raise AttributeError(f"Lack of attribute device.") + + rank_to_device_dict = dict() + for server_list in table_hccl.get("server_list"): + devices = server_list.get("device") + if devices is None: + raise ValueError("device is empty") + + for device in devices: + if "rank_id" not in device or not device.get("rank_id").isdigit(): + raise ValueError(f"hccl_json rank_id wrong.") + rank_id = int(device.get("rank_id")) + if "device_id" not in device or not device.get("device_id").isdigit(): + raise ValueError(f"hccl_json device_id wrong.") + + import mxrec_pybind + device_id = mxrec_pybind.get_logic_id(int(device.get("device_id"))) + if device_id > MAX_DEVICE_ID: + raise ValueError(f"get logic id from physic id fail, the device id is invalid.") + rank_to_device_dict[rank_id] = device_id + + return rank_to_device_dict + + +def set_hccl_info_without_json(): + """ + Used for no rank table file configured training situation. + Now, only less than or equal 8p training job is supported. + :return: None + """ + RankInfoValidator().check_visible_devices() + ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") + device_list = get_device_list(ascend_visible_devices) + + chief_device = os.getenv("CM_CHIEF_DEVICE") + rank_size = os.getenv("CM_WORKER_SIZE") + sorted_device_list = sorted(device_list) + if int(rank_size) != len(sorted_device_list): + raise ValueError(f"Rank size {rank_size} is different from device num {len(sorted_device_list)}.") + rank_to_device_dict = dict() + try: + rank_to_device_dict[0] = int(chief_device) + except ValueError as err: + raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err + try: + sorted_device_list.pop(int(chief_device) % len(sorted_device_list)) + except IndexError as err: + raise IndexError( + f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \ + from err + except ZeroDivisionError as err: + raise ZeroDivisionError("sorted_device_list length can not equal to 0.") from err + + for device_idx in sorted_device_list: + import mxrec_pybind + + try: + device_id = mxrec_pybind.get_logic_id(int(device_idx)) + if device_id > MAX_DEVICE_ID: + raise ValueError(f"get logic id from physic id fail.") + index = sorted_device_list.index(device_idx) + rank_to_device_dict[index + 1] = device_id + except RuntimeError as exp: + raise RuntimeError(f"get logic id from physic id fail. 
Possible reasons: 1) running user permission " + f"is not enough to call dsmi api 2) driver has been used by other process") from \ + exp + return rank_to_device_dict + + +def get_device_list(ascend_visible_devices): + device_list = [] + try: + if "-" in ascend_visible_devices: + split_devices = ascend_visible_devices.strip().split("-") + if split_devices: + rank_start = int(split_devices[0]) + device_list = list(range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)) + elif "," in ascend_visible_devices: + device_list = list(map(int, ascend_visible_devices.strip().split(","))) + elif ascend_visible_devices in VALID_DEVICE_ID_LIST: + device_list = [int(ascend_visible_devices.strip())] + else: + raise ValueError("invalid env variable ascend_visible_devices.") + except ValueError as error: + raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. " + "Please refer to the document https://www.hiascend.com/document/detail/zh/" + "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " + "the correct configuration method.") from error + except IndexError as error: + raise IndexError( + f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ + from error + return device_list diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index e7a180cf..c5971b8f 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -1,21 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -import json import logging import os from collections import defaultdict -import mxrec_pybind import psutil import mx_rec.constants.constants from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \ MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH,\ TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE +from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import RankInfoValidator, StringValidator, FileValidator +from mx_rec.validator.validator import StringValidator, FileValidator from mx_rec.util.atomic import AtomicInteger @@ -64,7 +62,7 @@ class ConfigInitializer: self._rank_id = kwargs.get("rank_id") self._rank_size = kwargs.get("rank_size") - self.parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else self.set_hccl_info_without_json() + self._rank_to_device_dict = parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else set_hccl_info_without_json() self.train_steps = kwargs.get("train_steps", -1) self.eval_steps = kwargs.get("eval_steps", -1) self.check_parameters() @@ -191,16 +189,14 @@ class ConfigInitializer: ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) def terminate(self): + logging.info("python process run into terminate") if self._is_terminated: logging.warning("The initializer has already been released once, please do not release it again.") return if self._asc_manager is not None: self.del_asc_manager() - - if self._mpi: - self._mpi.Finalize() - logging.debug("MPI has been destroyed.") + logging.info("python process run terminate success") self._is_terminated = True @@ -213,99 +209,6 @@ class ConfigInitializer: def get_feature_spec(self, key): return self._feature_spec_dict.get(key) - def parse_hccl_json(self): - 
rank_table_path = os.path.realpath(os.getenv("RANK_TABLE_FILE")) - if not os.path.exists(rank_table_path): - raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") - - with open(rank_table_path, "r", encoding="utf-8") as file: - # check whether json file is valid - file_validator = FileValidator(rank_table_path) - # 1.check whether rank_table_path is soft link - file_validator.check_not_soft_link() - # 2.check json file size - file_validator.check_file_size(file, MAX_CONFIG_SIZE, MIN_SIZE) - file_validator.check() - - table_hccl = json.load(file) - if "server_list" not in table_hccl: - raise AttributeError(f"Lack of attribute server_list.") - if not table_hccl["server_list"]: - raise ValueError(f"Server_list is empty.") - if "device" not in table_hccl["server_list"][0]: - raise AttributeError(f"Lack of attribute device.") - - for server_list in table_hccl.get("server_list"): - devices = server_list.get("device") - if devices is None: - raise ValueError("device is empty") - for device in devices: - if "rank_id" not in device or not device["rank_id"].isdigit(): - raise ValueError(f"hccl_json rank_id wrong.") - rank_id = int(device["rank_id"]) - if "device_id" not in device or not device["device_id"].isdigit(): - raise ValueError(f"hccl_json device_id wrong.") - device_id = mxrec_pybind.get_logic_id(int(device["device_id"])) - if device_id > 16: - raise ValueError(f"get logic id from physic id fail.") - self._rank_to_device_dict[rank_id] = device_id - - def set_hccl_info_without_json(self): - """ - Used for no rank table file configured training situation. - Now, only less than or equal 8p training job is supported. - :return: None - """ - RankInfoValidator().check_visible_devices() - ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - device_list = [] - try: - if "-" in ascend_visible_devices: - split_devices = ascend_visible_devices.strip().split("-") - if len(split_devices) >= 1: - rank_start = int(split_devices[0]) - device_list = list(range(rank_start, int(ascend_visible_devices.strip().split("-")[-1]) + 1)) - elif "," in ascend_visible_devices: - device_list = list(map(int, ascend_visible_devices.strip().split(","))) - elif ascend_visible_devices in VALID_DEVICE_ID_LIST: - device_list = [int(ascend_visible_devices.strip())] - else: - raise ValueError("invalid env variable ascend_visible_devices.") - except ValueError as error: - raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. 
" - "Please refer to the document https://www.hiascend.com/document/detail/zh/" - "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " - "the correct configuration method.") from error - except IndexError as error: - raise IndexError( - f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ - from error - - chief_device = os.getenv("CM_CHIEF_DEVICE") - rank_size = os.getenv("CM_WORKER_SIZE") - sorted_device_list = sorted(device_list) - if int(rank_size) != len(sorted_device_list): - raise ValueError(f"Rank size {rank_size} is different from device num {len(sorted_device_list)}.") - try: - self._rank_to_device_dict[0] = int(chief_device) - except ValueError as err: - raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err - try: - sorted_device_list.pop(int(chief_device) % len(sorted_device_list)) - except IndexError as err: - raise IndexError( - f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \ - from err - except ZeroDivisionError as err: - raise ZeroDivisionError("sorted_device_list length can not equal to 0.") from err - - for device_idx in sorted_device_list: - device_id = mxrec_pybind.get_logic_id(int(device_idx)) - if device_id > 16: - raise ValueError(f"get logic id from physic id fail.") - index = sorted_device_list.index(device_idx) - self._rank_to_device_dict[index + 1] = device_id - def insert_training_mode_channel_id(self, is_training): if is_training not in self._training_mode_channel_dict: # mx_rec has 2 channel for data input. -- Gitee From 8a6a62fb1f85d6fb1f02159d28667241e9507d08 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 8 Aug 2023 10:44:32 +0800 Subject: [PATCH 240/551] Match-id-adf49498454be2a3e0fc04ef3feb18c95e4d315b --- mx_rec/saver/saver.py | 10 ++++++++-- mx_rec/validator/validator.py | 5 +++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 5af58e98..d358dde8 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -16,6 +16,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_op get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ send_host_data, get_ascend_global_hashtable_collection from mx_rec.util.perf import performance +from mx_rec.validator.validator import DirectoryValidator class Saver(object): @@ -74,8 +75,13 @@ class Saver(object): else: ckpt_name = f"sparse-{base_name}" - integrated_path = os.path.join(directory, ckpt_name) - saving_path = integrated_path + saving_path = os.path.join(directory, ckpt_name) + try: + if save_path.find("://") == -1: + DirectoryValidator(saving_path).with_blacklist(exact_compare=False).check() + except ValueError as err: + raise ValueError(f"The saving path {saving_path} cannot be a system directory " + f"or a subdirectory of the system directory.") from err if tf.io.gfile.exists(saving_path): tf.io.gfile.rmtree(saving_path) diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 1e2f45d0..a2eaede4 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -231,11 +231,11 @@ class DirectoryValidator(StringValidator): def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None): if lst is None: - lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64"] + lst = ["/usr/bin", "/usr/sbin", "/etc", "/usr/lib", "/usr/lib64", 
"/usr/local"] if len(lst) == 0: return self if msg is None: - msg = "path should is in blacklist" + msg = "path should not in blacklist" if exact_compare: self.register_checker(lambda path: path not in [os.path.realpath(each) for each in lst], msg) else: @@ -255,6 +255,7 @@ class FileValidator(StringValidator): """ Check if file is valid. """ + def __init__(self, value): """ @param value: the file path, should not be emtpy string, should not contain double dot(../) -- Gitee From becb0fbb25fe5dee20f9e5d61d0cbe363cf40add Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 8 Aug 2023 15:55:42 +0800 Subject: [PATCH 241/551] Match-id-71b22afa1379689de3821b3f0d2f6573d606a598 --- mx_rec/core/asc/helper.py | 1 + mx_rec/core/asc/manager.py | 1 + mx_rec/core/asc/merge_table.py | 122 ++++++++++++++++----------------- 3 files changed, 62 insertions(+), 62 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 32641a75..64b78f2c 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -4,6 +4,7 @@ import logging from functools import reduce + import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index d1e6030c..832bf53d 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import logging + import tensorflow as tf from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 39cbfdf7..5ebb0d87 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -3,17 +3,23 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import logging -from typing import List -from typing import Dict +from typing import Dict, List + import tensorflow as tf -from tensorflow import Tensor -from tensorflow import Operation +from tensorflow import Operation, Tensor from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \ get_bool_gauge_set -def check_op(table_reachable_op: Operation): +def affirm(reach_op:List[Operation]) -> bool: + for node in reach_op: + if node.type not in ("IdentityN", "Reshape"): + return False + return True + + +def check_op(table_reachable_op: Operation) -> bool: """Check whether the tensor op is optimizer op or backward gradient. Args: @@ -31,25 +37,33 @@ def check_op(table_reachable_op: Operation): return False +def is_train_task(): + bool_gauge_set = get_bool_gauge_set() + if bool_gauge_set: + if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set: + return True + if 'predict' in bool_gauge_set: + return False + else: + op_list = tf.compat.v1.get_default_graph().get_operations() + for t_op in op_list: + if check_op(t_op): + return True + return False + -def find_dangling_table(table_names: List[str]): +def find_dangling_table(table_names: List[str]) -> List[str]: """ Find the tables which are disconenct with the forward training graph. And these table will not be backward updated. :param table_names: list of all created tables' names :return: a list of dangling table names. 
""" - if not is_train_task(): - logging.info(f"!!merge table only available in train task.") - return [] - if not get_enable_table_merge(): - return [] - def find_table_op(table_name: str, the_op: Operation, table_lookup_op: Dict[str, List[Operation]], - table_reachable_tensor: Dict[str, List[Tensor]]): + table_reachable_tensor: Dict[str, List[Tensor]]) -> None: """ find all the table lookup op. :param table_name: tables' names :param the_op: the op to be @@ -68,32 +82,10 @@ def find_dangling_table(table_names: List[str]): table_lookup_op[table_name].append(the_op) table_reachable_tensor[table_name].extend(the_op.outputs) - op_list = tf.compat.v1.get_default_graph().get_operations() - - table_lookup_op = {} - table_reachable_tensor = {} - - for _, table_instance in export_table_instances().items(): - if table_instance.table_name not in table_names: - table_names.append(table_instance.table_name) - - for the_op in op_list: - for table_name in table_names: - find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - - logging.debug("*********** find tables: %s ***********", table_lookup_op) - dangling_table = [] - - for table_name in table_names: - if table_name not in table_lookup_op: - logging.debug("*********** created table %s but never look up***********", table_name) - dangling_table.append(table_name) - insert_dangling_table(table_name) - def extend(op_list: List[Operation], tensor: Tensor, - spread_tensors: List[Tensor]): + spread_tensors: List[Tensor]) -> None: """extend the tensors which table lookup op can reach :param op_list: all op in the graph @@ -105,7 +97,8 @@ def find_dangling_table(table_names: List[str]): if tensor in the_op.inputs: spread_tensors.extend(the_op.outputs) - def bfs_lookup(next_to_visit: List[Tensor]): + + def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool): """find all the tensors which table lookup op can reach :param next_to_visit: the tensor list to be visited by bfs @@ -126,24 +119,44 @@ def find_dangling_table(table_names: List[str]): next_to_visit = spread_tensors return op_visited, False - def _affirm(reach_op:List[Operation]): - for node in reach_op: - if node.type not in ["IdentityN", "Reshape"]: - return False - return True + + if not is_train_task(): + logging.info(f"!!merge table only available in train task.") + return [] + if not get_enable_table_merge(): + return [] + + op_list = tf.compat.v1.get_default_graph().get_operations() + + table_lookup_op = {} + table_reachable_tensor = {} + + for _, table_instance in export_table_instances().items(): + if table_instance.table_name not in table_names: + table_names.append(table_instance.table_name) + + for the_op in op_list: + for table_name in table_names: + find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) + + logging.debug("*********** find tables: %s ***********", table_lookup_op) + dangling_table = [] + + for table_name in table_names: + if table_name not in table_lookup_op: + logging.debug("*********** created table %s but never look up***********", table_name) + dangling_table.append(table_name) + insert_dangling_table(table_name) for table_name, table_op in table_reachable_tensor.items(): reach_op, found = bfs_lookup(table_op) - affirm = False - if not found: - affirm = _affirm(reach_op) - if affirm: + if not found and affirm(reach_op): dangling_table.append(table_name) insert_dangling_table(table_name) return dangling_table -def should_skip(table_name): +def should_skip(table_name) -> bool: from mx_rec.constants.constants import 
ASCEND_TABLE_NAME_MUST_CONTAIN if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \ @@ -158,18 +171,3 @@ def should_skip(table_name): break return skip return False - - -def is_train_task(): - bool_gauge_set = get_bool_gauge_set() - if len(bool_gauge_set) > 0: - if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set: - return True - if 'predict' in bool_gauge_set: - return False - else: - op_list = tf.compat.v1.get_default_graph().get_operations() - for t_op in op_list: - if check_op(t_op): - return True - return False -- Gitee From aaff5bbc14df46870b8146706ec4bf3b9456c9a5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 8 Aug 2023 16:15:38 +0800 Subject: [PATCH 242/551] Match-id-ebf829c9e1e0457a67d6522b6f5a96ebea844404 --- mx_rec/util/communication/hccl_mgmt.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 952671ca..ccb31eff 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -113,10 +113,7 @@ def get_device_list(ascend_visible_devices): else: raise ValueError("invalid env variable ascend_visible_devices.") except ValueError as error: - raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured. " - "Please refer to the document https://www.hiascend.com/document/detail/zh/" - "CANNCommunityEdition/63RC2alpha002/ptmoddevg/ptmigr/ptmigr_0151.html for " - "the correct configuration method.") from error + raise ValueError("Invalid env variable ascend_visible_devices, no valid device id is configured.") from error except IndexError as error: raise IndexError( f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ -- Gitee From 0558b84b1b85113923b7a09f01bb982491b445e3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 09:22:11 +0800 Subject: [PATCH 243/551] Match-id-5b77a232857cd4714c8dfd254a397c3ddb692a6c --- mx_rec/core/asc/merge_table.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 5ebb0d87..1d314e34 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -39,16 +39,15 @@ def check_op(table_reachable_op: Operation) -> bool: def is_train_task(): bool_gauge_set = get_bool_gauge_set() - if bool_gauge_set: - if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set: - return True - if 'predict' in bool_gauge_set: - return False - else: + if not bool_gauge_set: op_list = tf.compat.v1.get_default_graph().get_operations() for t_op in op_list: if check_op(t_op): return True + + if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set: + return True + return False -- Gitee From c4decb7d71d4528760386b7f649ed18f0d361156 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 09:50:24 +0800 Subject: [PATCH 244/551] Match-id-15e550de2a91a3864deae53abe1bfdf0d9f231c2 --- mx_rec/util/communication/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 mx_rec/util/communication/__init__.py diff --git a/mx_rec/util/communication/__init__.py b/mx_rec/util/communication/__init__.py new file mode 100644 index 00000000..d9fc9564 --- /dev/null +++ b/mx_rec/util/communication/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. 
All rights reserved. \ No newline at end of file -- Gitee From 0028cdec4dfa4a8bba4033a28a203bbb6ad46b9b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 14:45:54 +0800 Subject: [PATCH 245/551] Match-id-c45caa380c7707efdff57244406e49a5adc7cbaa --- mx_rec/core/asc/helper.py | 148 ++-------------------------- mx_rec/core/asc/manager.py | 7 +- mx_rec/core/asc/merge_table.py | 172 +++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 145 deletions(-) create mode 100644 mx_rec/core/asc/merge_table.py diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 3b24efa9..64b78f2c 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -4,15 +4,12 @@ import logging from functools import reduce -from typing import List -from typing import Dict + import tensorflow as tf -from tensorflow import Tensor -from tensorflow import Operation -from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, \ - get_enable_table_merge, export_table_instances, insert_dangling_table +from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static from mx_rec.core.asc.feature_spec import FeatureSpec +from mx_rec.core.asc.merge_table import find_dangling_table, should_skip def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None, @@ -48,137 +45,6 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names **kwargs) -def find_dangling_table(table_names: List[str]): - """ Find the tables which are disconenct with the forward training graph. And - these table will not be backward updated. - - :param table_names: list of all created tables' names - :return: a list of dangling table names. - """ - - def check_tensor(table_reachable_tensor: Tensor): - """Check whether the tensor op is optimizer op or backward gradient. - - Args: - table_reachable_tensor: tensor - Returns: - bool - """ - if table_reachable_tensor.op.type == 'ApplyAdam': - return True - - if 'gradients/' in table_reachable_tensor.name and table_reachable_tensor.op.type == 'Identity': - return True - - if 'SparseSoftmaxCrossEntropyWithLogits' in table_reachable_tensor.op.name \ - and table_reachable_tensor.op.type == 'SparseSoftmaxCrossEntropyWithLogits': - return True - - return False - - def find_table_op(table_name: str, - the_op: Operation, - table_lookup_op: Dict[str, List[Operation]], - table_reachable_tensor: Dict[str, List[Tensor]]): - """ find all the table lookup op. - :param table_name: tables' names - :param the_op: the op to be - :param table_lookup_op: list of the table lookup ops - :param table_reachable_tensor: the tensors which table lookup op can reach ( - here we just add the table lookup op's output tensors). - The data structure is map, key is table_name, value is the output tensors of table lookup op. 
- :return: None - """ - if table_name in the_op.name and the_op.type == "IdentityN": - if table_name not in table_lookup_op: - table_lookup_op[table_name] = [the_op] - table_reachable_tensor[table_name] = the_op.outputs - else: - table_lookup_op[table_name].append(the_op) - table_reachable_tensor[table_name].extend(the_op.outputs) - - op_list = tf.compat.v1.get_default_graph().get_operations() - - table_lookup_op = {} - table_reachable_tensor = {} - - for _, table_instance in export_table_instances().items(): - if table_instance.table_name not in table_names: - table_names.append(table_instance.table_name) - - for the_op in op_list: - for table_name in table_names: - find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - - logging.info(f"*********** find tables: {table_lookup_op} ***********") - - dangling_table = [] - - for table_name in table_names: - if table_name not in table_lookup_op: - logging.info(f"*********** created table {table_name} but never look up***********") - dangling_table.append(table_name) - insert_dangling_table(table_name) - - - def extend(op_list: List[Operation], - tensor: Tensor, - spread_tensors: List[Tensor]): - """extend the tensors which table lookup op can reach - - :param op_list: all op in the graph - :param tensor: the tensor visited by bfs - :param spread_tensors: the list of tensors which table lookup op can reach - :return: - """ - for the_op in op_list: - if tensor in the_op.inputs: - spread_tensors.extend(the_op.outputs) - - def bfs_lookup(next_to_visit: List[Tensor]): - """find all the tensors which table lookup op can reach - - :param next_to_visit: the tensor list to be visited by bfs - :return: bool value indicate whether reached optimizer op or backward gradient op - """ - tensors_visited = set() - while next_to_visit: - spread_tensors = [] - for tensor in next_to_visit: - if tensor in tensors_visited: - continue - if check_tensor(tensor): - return True - tensors_visited.add(tensor) - extend(op_list, tensor, spread_tensors) - next_to_visit = spread_tensors - return False - - for table_name, table_op in table_reachable_tensor.items(): - found = bfs_lookup(table_op) - if not found: - dangling_table.append(table_name) - insert_dangling_table(table_name) - return dangling_table - - -def should_skip(table_name): - from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ - and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \ - and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name: - return True - if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ - and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, list): - skip = True - for key_word in ASCEND_TABLE_NAME_MUST_CONTAIN: - if isinstance(key_word, str) and key_word in table_name: - skip = False - break - return skip - return False - - def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): both_none = tgt_key_specs is None and args_index_list is None @@ -221,9 +87,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ if feature_counts is None or table_names is None: raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - dangling_tables = [] - if get_enable_table_merge(): - dangling_tables = find_dangling_table(table_names) + dangling_tables = find_dangling_table(table_names) logging.info(f"In insert found dangling table(s): {dangling_tables} " f"which does not need to be provided 
to the EmbInfo.") @@ -241,12 +105,12 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ new_insert_tensors, new_splits, new_table_names = [], [], [] for idx, table_name in enumerate(table_names): if table_name in dangling_tables: - logging.info(f"do_insert skip table : {table_name}") + logging.info(f"do_insert skip table by graph : {table_name}") continue skip = should_skip(table_name) if skip: - logging.info(f"do_insert skip table 2: {table_name}") + logging.info(f"do_insert skip table by keyword: {table_name}") continue new_insert_tensors.append(insert_tensors[idx]) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 3892837e..832bf53d 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -3,6 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import logging + import tensorflow as tf from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo @@ -11,8 +12,8 @@ from mx_rec.constants.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_steps, get_eval_steps, get_prefetch_batch_number, \ export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ - get_use_hot, get_use_dynamic_expansion, get_enable_table_merge, export_optimizer, export_dangling_table -from mx_rec.core.asc.helper import find_dangling_table, should_skip + get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table +from mx_rec.core.asc.merge_table import find_dangling_table, should_skip def check_dangling_table(): @@ -21,7 +22,7 @@ def check_dangling_table(): :return: list of dangling_table """ dangling_table = export_dangling_table() - if not dangling_table and get_enable_table_merge(): + if not dangling_table: dangling_table = find_dangling_table([table_instance.table_name for _, table_instance in export_table_instances().items()]) return dangling_table diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py new file mode 100644 index 00000000..1d314e34 --- /dev/null +++ b/mx_rec/core/asc/merge_table.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +import logging +from typing import Dict, List + +import tensorflow as tf +from tensorflow import Operation, Tensor + +from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \ + get_bool_gauge_set + + +def affirm(reach_op:List[Operation]) -> bool: + for node in reach_op: + if node.type not in ("IdentityN", "Reshape"): + return False + return True + + +def check_op(table_reachable_op: Operation) -> bool: + """Check whether the tensor op is optimizer op or backward gradient. 
+
+    Args:
+        table_reachable_op: the operation to check
+    Returns:
+        bool
+    """
+    if table_reachable_op.type == 'ApplyAdam':
+        return True
+
+    if 'gradients' in table_reachable_op.name and \
+            table_reachable_op.type in ['UnsortedSegmentSum', 'TensorScatterUpdate']:
+        return True
+
+    return False
+
+
+def is_train_task():
+    bool_gauge_set = get_bool_gauge_set()
+    if not bool_gauge_set:
+        op_list = tf.compat.v1.get_default_graph().get_operations()
+        for t_op in op_list:
+            if check_op(t_op):
+                return True
+
+    if 'train' in bool_gauge_set or 'train_and_evaluate' in bool_gauge_set:
+        return True
+
+    return False
+
+
+def find_dangling_table(table_names: List[str]) -> List[str]:
+    """ Find the tables which are disconnected from the forward training graph; these
+    tables will not be updated in the backward pass.
+
+    :param table_names: list of all created tables' names
+    :return: a list of dangling table names.
+    """
+
+    def find_table_op(table_name: str,
+                      the_op: Operation,
+                      table_lookup_op: Dict[str, List[Operation]],
+                      table_reachable_tensor: Dict[str, List[Tensor]]) -> None:
+        """ find all the table lookup op.
+        :param table_name: tables' names
+        :param the_op: the op to be checked
+        :param table_lookup_op: list of the table lookup ops
+        :param table_reachable_tensor: the tensors which table lookup op can reach (
+        here we just add the table lookup op's output tensors).
+        The data structure is map, key is table_name, value is the output tensors of table lookup op.
+        :return: None
+        """
+        if table_name in the_op.name and the_op.type == "IdentityN":
+            if table_name not in table_lookup_op:
+                table_lookup_op[table_name] = [the_op]
+                table_reachable_tensor[table_name] = []
+                table_reachable_tensor[table_name].extend(the_op.outputs)
+            elif the_op not in table_lookup_op[table_name]:
+                table_lookup_op[table_name].append(the_op)
+                table_reachable_tensor[table_name].extend(the_op.outputs)
+
+
+    def extend(op_list: List[Operation],
+               tensor: Tensor,
+               spread_tensors: List[Tensor]) -> None:
+        """extend the tensors which table lookup op can reach
+
+        :param op_list: all op in the graph
+        :param tensor: the tensor visited by bfs
+        :param spread_tensors: the list of tensors which table lookup op can reach
+        :return:
+        """
+        for the_op in op_list:
+            if tensor in the_op.inputs:
+                spread_tensors.extend(the_op.outputs)
+
+
+    def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool):
+        """find all the tensors which table lookup op can reach
+
+        :param next_to_visit: the tensor list to be visited by bfs
+        :return: bool value indicating whether an optimizer op or backward gradient op was reached
+        """
+        tensors_visited = set()
+        op_visited = set()
+        while next_to_visit:
+            spread_tensors = []
+            for tensor in next_to_visit:
+                if tensor in tensors_visited:
+                    continue
+                if check_op(tensor.op):
+                    return op_visited, True
+                tensors_visited.add(tensor)
+                op_visited.add(tensor.op)
+                extend(op_list, tensor, spread_tensors)
+            next_to_visit = spread_tensors
+        return op_visited, False
+
+
+    if not is_train_task():
+        logging.info("merge table is only available in train tasks.")
+        return []
+    if not get_enable_table_merge():
+        return []
+
+    op_list = tf.compat.v1.get_default_graph().get_operations()
+
+    table_lookup_op = {}
+    table_reachable_tensor = {}
+
+    for _, table_instance in export_table_instances().items():
+        if table_instance.table_name not in table_names:
+            table_names.append(table_instance.table_name)
+
+    for the_op in op_list:
+        for table_name in table_names:
+            find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor)
+
logging.debug("*********** find tables: %s ***********", table_lookup_op) + dangling_table = [] + + for table_name in table_names: + if table_name not in table_lookup_op: + logging.debug("*********** created table %s but never look up***********", table_name) + dangling_table.append(table_name) + insert_dangling_table(table_name) + + for table_name, table_op in table_reachable_tensor.items(): + reach_op, found = bfs_lookup(table_op) + if not found and affirm(reach_op): + dangling_table.append(table_name) + insert_dangling_table(table_name) + return dangling_table + + +def should_skip(table_name) -> bool: + from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN + if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ + and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \ + and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name: + return True + if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \ + and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, list): + skip = True + for key_word in ASCEND_TABLE_NAME_MUST_CONTAIN: + if isinstance(key_word, str) and key_word in table_name: + skip = False + break + return skip + return False -- Gitee From 5dc890663ff0ed8ff4d33c1d896078e47a5d2e22 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 16:47:06 +0800 Subject: [PATCH 246/551] Match-id-243d4df11d76fed5e189720a516a04c6ce611cd8 --- mx_rec/__init__.py | 2 ++ mx_rec/constants/__init__.py | 2 ++ mx_rec/core/__init__.py | 2 ++ mx_rec/core/asc/__init__.py | 2 ++ mx_rec/graph/__init__.py | 2 ++ mx_rec/util/__init__.py | 3 ++- mx_rec/util/communication/__init__.py | 4 +++- 7 files changed, 15 insertions(+), 2 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 3b1de2b6..465a6833 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +__all__ = ["constants", "core", "graph", "util"] + from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook from mx_rec.saver.patch import patch_for_saver diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py index 6924f767..0270daf3 100644 --- a/mx_rec/constants/__init__.py +++ b/mx_rec/constants/__init__.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +__all__ = ["constants"] \ No newline at end of file diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py index 6924f767..43336ebe 100644 --- a/mx_rec/core/__init__.py +++ b/mx_rec/core/__init__.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +__all__ = ["asc", "embedding"] \ No newline at end of file diff --git a/mx_rec/core/asc/__init__.py b/mx_rec/core/asc/__init__.py index 6924f767..b9575dd5 100644 --- a/mx_rec/core/asc/__init__.py +++ b/mx_rec/core/asc/__init__.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +__all__ = ["feature_spec", "helper", "manager"] \ No newline at end of file diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py index 6924f767..ee63132c 100644 --- a/mx_rec/graph/__init__.py +++ b/mx_rec/graph/__init__.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. 
All rights reserved. + +__all__ = ["modifier"] \ No newline at end of file diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py index 6b6497b8..f46049eb 100644 --- a/mx_rec/util/__init__.py +++ b/mx_rec/util/__init__.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -from mx_rec.util.log import get_log_level +__all__ = ["initialize", "variable"] +from mx_rec.util.log import get_log_level get_log_level() diff --git a/mx_rec/util/communication/__init__.py b/mx_rec/util/communication/__init__.py index d9fc9564..9da1fcf6 100644 --- a/mx_rec/util/communication/__init__.py +++ b/mx_rec/util/communication/__init__.py @@ -1,3 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. \ No newline at end of file +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +__all__ = ["hccl_mgmt"] \ No newline at end of file -- Gitee From 90372f32f2168a5b075037c3f6819a2f45c759e9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 17:37:59 +0800 Subject: [PATCH 247/551] Match-id-a7aa6fbd134d66a734c2d47054d7ffdfc6ecc4a1 --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 4a169a7c..d1a975d7 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,6 @@ setup( name='mx_rec', version=VERSION, author='HUAWEI Inc', - url='https://www.hiascend.com/zh/software/mindx-sdk', description='MindX SDK Recommend', long_description=LONG_DESCRIPTION, # include mx_rec -- Gitee From 70a3881bea1cb62e2f430cccda366a549b03f283 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 17:42:08 +0800 Subject: [PATCH 248/551] Match-id-8b73f49079aaf99bc39bede9196e1fa7d85e0f6f --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 4a169a7c..d1a975d7 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,6 @@ setup( name='mx_rec', version=VERSION, author='HUAWEI Inc', - url='https://www.hiascend.com/zh/software/mindx-sdk', description='MindX SDK Recommend', long_description=LONG_DESCRIPTION, # include mx_rec -- Gitee From 0e21ca38bfd676070cec064ee3f73673b62e33ee Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 9 Aug 2023 21:57:26 +0800 Subject: [PATCH 249/551] Match-id-e350be412c08f5f2408ce358daa669c9a557c870 --- mx_rec/__init__.py | 5 ++--- mx_rec/optimizers/__init__.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 465a6833..a5a58739 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
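The __all__ lists introduced in these patches pin down each package's public, star-importable surface. As a reminder of the semantics (toy module for illustration, not mx_rec itself):

# toy_pkg/__init__.py
__all__ = ["foo"]  # "from toy_pkg import *" binds only these names

def foo():
    return "public"

def _bar():
    return "importable explicitly, but not star-exported"

Note that __all__ only filters star-imports and documents intent; it is not access control. One consequence visible in the optimizers package below: several modules are imported for the same name create_hash_optimizer, and each successive "from ... import create_hash_optimizer" rebinds that name, so after import only the last binding (the one from momentum) remains reachable as mx_rec.optimizers.create_hash_optimizer.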
-
-__all__ = ["constants", "core", "graph", "util"]
+__all__ = ["constants", "core", "graph", "util",
+           "version", "__version__"]
 from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION
 from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook
@@ -11,7 +11,6 @@ from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creato
     patch_for_end, patch_for_assert_eval_spec
 from mx_rec.optimizers.base import patch_for_optimizer
-
 patch_for_saver()
 patch_for_dataset()
 patch_for_chief_session_creator()
diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py
index 6924f767..22b891b6 100644
--- a/mx_rec/optimizers/__init__.py
+++ b/mx_rec/optimizers/__init__.py
@@ -1,3 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer",
+           "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address"]
+
+from mx_rec.optimizers.adagrad import create_hash_optimizer
+from mx_rec.optimizers.ftrl import create_hash_optimizer
+from mx_rec.optimizers.ftrl_t import create_hash_optimizer
+from mx_rec.optimizers.ftrl_t_dense import create_ftrl_dense_optimizer
+from mx_rec.optimizers.gradient_descent import create_hash_optimizer
+from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr
+from mx_rec.optimizers.lazy_adam import create_hash_optimizer
+from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address
+from mx_rec.optimizers.momentum import create_hash_optimizer
\ No newline at end of file
-- Gitee
From 1fb8b89ce860f193e9f397a69808e4af6264a535 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 10 Aug 2023 12:40:33 +0800
Subject: [PATCH 250/551] Match-id-a55475559f76111a267beb0e76b12b1d4faae036
---
 example/little_demo/run.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
index 1d5b8b57..e03f3fe5 100644
--- a/example/little_demo/run.sh
+++ b/example/little_demo/run.sh
@@ -50,7 +50,7 @@ mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec"
 so_path=${mx_rec_package_path}/libasc
 # GLOG_stderrthreshold 0:INFO 1:WARNING 2:ERROR 3:FATAL
 # GLOG_v 1:DEBUG(print as INFO) 2:TRACE(print as INFO)
-mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -x GLOG_v=0 -bind-to none'
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -x GLOG_v=0 -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0 '
 interface="lo"
 local_rank_size=8  # number of NPU devices used per node
 num_server=1  # number of training nodes
-- Gitee
From 006b0e730ab44592cd699f1636256be68b080175 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 10 Aug 2023 20:54:23 +0800
Subject: [PATCH 251/551] Match-id-f76339a03811fcaf66b717cd94e81bd438a7fcd7
---
 MANIFEST.in                                  |   2 -
 build/build.sh                               |   1 -
 build/build_all.sh                           |   2 -
 build/build_tf1.sh                           |   2 -
 build/build_tf2.sh                           |   1 -
 example/__init__.py                          |   0
 example/little_demo/config.py                | 113 ------
 example/little_demo/dataset.py               |  60 ----
 example/little_demo/main.py                  | 221 ------------
 example/little_demo/model.py                 |  93 -----
 example/little_demo/op_impl_mode.ini         |   3 -
 example/little_demo/optimizer.py             |  22 --
 example/little_demo/random_data_generator.py |  52 ---
 example/little_demo/run.sh                   | 153 --------
 example/little_demo/run_mode.py              | 146 --------
 tools/mx_rec_perf.sh                         |  71 ----
tools/parse_data/data_parser.py | 124 ------- tools/parse_data/run.sh | 11 - tools/perf/fast.sh | 346 ------------------- tools/perf/host_set.sh | 17 - tools/perf/msprof.sh | 24 -- tools/perf/mt_1207.sh | 60 ---- tools/perf/perf_flame_graph.sh | 37 -- tools/python/key_2_emb_formatter.py | 216 ------------ 24 files changed, 1777 deletions(-) delete mode 100755 MANIFEST.in delete mode 100644 example/__init__.py delete mode 100644 example/little_demo/config.py delete mode 100644 example/little_demo/dataset.py delete mode 100644 example/little_demo/main.py delete mode 100644 example/little_demo/model.py delete mode 100644 example/little_demo/op_impl_mode.ini delete mode 100644 example/little_demo/optimizer.py delete mode 100644 example/little_demo/random_data_generator.py delete mode 100644 example/little_demo/run.sh delete mode 100644 example/little_demo/run_mode.py delete mode 100644 tools/mx_rec_perf.sh delete mode 100755 tools/parse_data/data_parser.py delete mode 100755 tools/parse_data/run.sh delete mode 100755 tools/perf/fast.sh delete mode 100755 tools/perf/host_set.sh delete mode 100755 tools/perf/msprof.sh delete mode 100755 tools/perf/mt_1207.sh delete mode 100755 tools/perf/perf_flame_graph.sh delete mode 100644 tools/python/key_2_emb_formatter.py diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100755 index 3fc1dbb6..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -include mx_rec/tools/* -include mx_rec/tools/*/* \ No newline at end of file diff --git a/build/build.sh b/build/build.sh index 5eff8d3a..ab676ad2 100644 --- a/build/build.sh +++ b/build/build.sh @@ -58,7 +58,6 @@ gen_tar_file() cd "${src_path}" mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../example ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { diff --git a/build/build_all.sh b/build/build_all.sh index 926e6977..8492b989 100644 --- a/build/build_all.sh +++ b/build/build_all.sh @@ -158,7 +158,6 @@ gen_wheel_file() touch "${src_path}"/libasc/__init__.py remove "${ROOT_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec python3 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" @@ -170,7 +169,6 @@ gen_tar_file() cd "${src_path}" mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../example ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 255a6922..06317e52 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -106,7 +106,6 @@ gen_wheel_file() touch "${src_path}"/libasc/__init__.py remove "${ROOT_DIR}"/mx_rec/libasc mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" @@ -118,7 +117,6 @@ gen_tar_file() cd "${src_path}" mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../example ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { diff --git a/build/build_tf2.sh b/build/build_tf2.sh index b2397050..af8318e4 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ 
-108,7 +108,6 @@ gen_wheel_file()
     touch "${src_path}"/libasc/__init__.py
     remove "${ROOT_DIR}"/mx_rec/libasc
     mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec
-    cp -rf "${ROOT_DIR}"/tools "${ROOT_DIR}"/mx_rec
     python3.7 setup.py bdist_wheel --plat-name=linux_$(arch)
     mkdir -p "$1"
     mv dist/mx_rec*.whl "$1"
diff --git a/example/__init__.py b/example/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/example/little_demo/config.py b/example/little_demo/config.py
deleted file mode 100644
index 6dd5d183..00000000
--- a/example/little_demo/config.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# coding: UTF-8
-import logging
-import math
-import tensorflow as tf
-
-from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
-
-from mx_rec.util.initialize import get_rank_size
-
-
-class Config:
-    def __init__(self, mode="simple", task_name="default"):
-        self.task_name = task_name
-        if mode == "simple":
-            self.generate_simple_config()
-        else:
-            self.generate_large_scale_config()
-
-    def generate_simple_config(self):
-        self.batch_number = 8192
-        self.batch_size = 4096
-
-        self.key_type = tf.int64
-        self.label_type = tf.float32
-        self.value_type = tf.float32
-
-        self.item_range = 80000
-        self.user_range = 200000
-        self.category_range = 5000
-        self.item_feat_cnt = 16
-        self.user_feat_cnt = 8
-        self.category_feat_cnt = 3
-        self.access_threshold = 100
-        self.eviction_threshold = 60
-
-        rank_size = get_rank_size()
-        coefficient = 1.1
-        if rank_size != 0:
-            max_ui_send_cnt = max(self.item_feat_cnt, self.user_feat_cnt)
-            max_ui_range = max(self.item_range, self.user_range)
-            self.item_send_cnt = min(int(self.batch_size * self.item_feat_cnt * coefficient),
-                                     math.ceil(self.item_range / rank_size))
-            self.item_vocab_size = max(self.item_send_cnt * rank_size * rank_size, self.item_range)
-            self.user_send_cnt = min(int(self.batch_size * max_ui_send_cnt * coefficient),
-                                     math.ceil(max_ui_range / rank_size))
-            self.user_vocab_size = max(self.user_send_cnt * rank_size * rank_size, self.user_range)
-            self.category_send_cnt = min(int(self.batch_size * self.category_feat_cnt * coefficient),
-                                         math.ceil(self.category_range / rank_size))
-        else:
-            raise ZeroDivisionError("rank size must be an integer greater than zero.")
-
-        self.user_hashtable_dim = 32
-        self.user_hashtable_threshold = 1
-        self.item_hashtable_dim = 8
-        self.item_hashtable_threshold = 1
-
-        self.learning_rate = 0.01
-
-    def generate_large_scale_config(self):
-        self.lookup_count = 40
-        self.tensor_name_list = ["sparse_tensor_%d" % i for i in range(self.lookup_count)]
-        self.hashtable_name_list = ["hashtable_%d" % i for i in range(self.lookup_count)]
-        self.batch_size = 9600
-
-        self.key_type = tf.int64
-        self.label_type = tf.float32
-        self.value_type = tf.float32
-
-        self.vocabulary_size = 500000
-        self.feat_cnt = 1
-
-        rank_size = get_rank_size()
-        coefficient = 1.1
-        if rank_size != 0:
-            self.send_cnt = min(int(self.batch_size * self.feat_cnt * coefficient),
-                                math.ceil(self.vocabulary_size / rank_size))
-        else:
-            raise ZeroDivisionError("rank size must be an integer greater than zero.")
-
-        self.hashtable_dim = 8
-        self.learning_rate = 0.01
-
-
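-# Worked example of the send-count sizing above, for the simple config on an
-# assumed rank size of 8 (the numbers follow from the constants in this file):
-#   item_send_cnt   = min(int(4096 * 16 * 1.1), ceil(80000 / 8))
-#                   = min(72089, 10000) = 10000
-#   item_vocab_size = max(10000 * 8 * 8, 80000) = 640000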
-def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
-    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False,
-                                              log_device_placement=False)
-
-    session_config.gpu_options.allow_growth = True
-    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
-    custom_op.name = "NpuOptimizer"
-    custom_op.parameter_map["mix_compile_mode"].b = False
-    custom_op.parameter_map["use_off_line"].b = True
-    custom_op.parameter_map["min_group_size"].b = 1
-    custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise")
-    custom_op.parameter_map["enable_data_pre_proc"].b = True
-    custom_op.parameter_map["iterations_per_loop"].i = 1
-    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
-    custom_op.parameter_map["hcom_parallel"].b = False
-    custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini")
-
-    if dump_data:
-        """
-        For details, please refer to the data dump documentation on the Ascend official website.
-        """
-        custom_op.parameter_map["enable_dump"].b = True
-        custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path)
-        custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps)
-        custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")
-
-    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
-    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
-
-    return session_config
diff --git a/example/little_demo/dataset.py b/example/little_demo/dataset.py
deleted file mode 100644
index 7a9cf82c..00000000
--- a/example/little_demo/dataset.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# coding: UTF-8
-import tensorflow as tf
-
-from random_data_generator import get_data_generator, get_large_scale_data_generator
-from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops
-
-
-def generate_dataset(cfg, use_timestamp=False, batch_number=100):
-    dataset = tf.compat.v1.data.Dataset.from_generator(
-        generator=get_data_generator(cfg, batch_number=batch_number),
-        output_types={"item_ids": cfg.key_type,
-                      "user_ids": cfg.key_type,
-                      "category_ids": cfg.key_type,
-                      "label_0": cfg.label_type,
-                      "label_1": cfg.label_type},
-        output_shapes={"item_ids": tf.TensorShape([cfg.batch_size, cfg.item_feat_cnt]),
-                       "user_ids": tf.TensorShape([cfg.batch_size, cfg.user_feat_cnt]),
-                       "category_ids": tf.TensorShape([cfg.batch_size, cfg.category_feat_cnt]),
-                       "label_0": tf.TensorShape([cfg.batch_size]),
-                       "label_1": tf.TensorShape([cfg.batch_size])})
-    if use_timestamp:
-        dataset = dataset.map(add_timestamp_func)
-
-    rank_size = get_rank_size()
-    rank_id = get_rank_id()
-    if rank_size > 1:
-        dataset = dataset.shard(rank_size, rank_id)
-
-    return dataset
-
-
-def add_timestamp_func(batch):
-    host_pipeline_ops = get_host_pipeline_ops()
-    timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64))
-    batch["timestamp"] = timestamp
-    return batch
-
-
-def generate_large_scale_data(cfg):
-    key_type_list = [cfg.key_type for _ in range(cfg.lookup_count)]
-    output_type_dict = dict(zip(cfg.tensor_name_list, key_type_list))
-    output_type_dict["label_0"] = cfg.label_type
-    output_type_dict["label_1"] = cfg.label_type
-
-    tensor_shape_list = [tf.TensorShape([cfg.batch_size]) for _ in range(cfg.lookup_count)]
-    output_shape_dict = dict(zip(cfg.tensor_name_list, tensor_shape_list))
-    output_shape_dict["label_0"] = tf.TensorShape([cfg.batch_size])
-    output_shape_dict["label_1"] = tf.TensorShape([cfg.batch_size])
-
-    dataset = tf.data.Dataset.from_generator(generator=get_large_scale_data_generator(cfg),
-                                             output_types=output_type_dict,
-                                             output_shapes=output_shape_dict)
-    rank_size = get_rank_size()
-    rank_id = get_rank_id()
-    if rank_size > 1:
-        dataset = dataset.shard(rank_size, rank_id)
-
-    iterator =
dataset.make_initializable_iterator() - batch = iterator.get_next() - return batch, iterator diff --git a/example/little_demo/main.py b/example/little_demo/main.py deleted file mode 100644 index 44d867b5..00000000 --- a/example/little_demo/main.py +++ /dev/null @@ -1,221 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -import logging -import os -import warnings -from glob import glob - -import tensorflow as tf -from config import Config -from dataset import generate_dataset -from optimizer import create_dense_and_sparse_optimizer -from model import MyModel -from run_mode import RunMode, UseMode - -from mx_rec.core.asc.feature_spec import FeatureSpec -from mx_rec.core.asc.helper import get_asc_insert_func -from mx_rec.core.asc.manager import start_asc_pipeline -from mx_rec.core.embedding import create_table, sparse_lookup -from mx_rec.graph.modifier import modify_graph_and_start_emb_cache -from mx_rec.constants.constants import MxRecMode, ASCEND_TIMESTAMP -from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load, get_rank_size -from mx_rec.util.variable import get_dense_and_sparse_variable -from mx_rec.constants.constants import ApplyGradientsStrategy - -tf.compat.v1.disable_eager_execution() - - -def make_batch_and_iterator(is_training, feature_spec_list=None, - use_timestamp=False, dump_graph=False, batch_number=100): - dataset = generate_dataset(cfg, use_timestamp=use_timestamp, batch_number=batch_number) - if not MODIFY_GRAPH_FLAG: - insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) - dataset = dataset.map(insert_fn) - dataset = dataset.prefetch(100) - iterator = dataset.make_initializable_iterator() - batch = iterator.get_next() - return batch, iterator - - -def model_forward(input_list, batch, is_train, modify_graph, config_dict=None): - embedding_list = [] - feature_list, hash_table_list, send_count_list = input_list - for feature, hash_table, send_count in zip(feature_list, hash_table_list, send_count_list): - access_and_evict_config = None - if isinstance(config_dict, dict): - access_and_evict_config = config_dict.get(hash_table.table_name) - embedding = sparse_lookup(hash_table, feature, send_count, dim=None, is_train=is_train, - access_and_evict_config=access_and_evict_config, - name=hash_table.table_name + "_lookup", modify_graph=modify_graph, batch=batch) - - reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False) - embedding_list.append(reduced_embedding) - - my_model = MyModel() - my_model(embedding_list, batch["label_0"], batch["label_1"]) - return my_model - - -def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=None, batch_number=100): - batch, iterator = make_batch_and_iterator(is_train, feature_spec_list=feature_spec_list, - use_timestamp=USE_TIMESTAMP, dump_graph=is_train, - batch_number=batch_number) - if MODIFY_GRAPH_FLAG: - input_list = [[batch["user_ids"], batch["item_ids"]], - [hash_table_list[0], hash_table_list[1]], - [cfg.user_send_cnt, cfg.item_send_cnt]] - if use_multi_lookup: - input_list = [[batch["user_ids"], batch["item_ids"], batch["user_ids"], batch["item_ids"]], - [hash_table_list[0], hash_table_list[0], hash_table_list[0], hash_table_list[1]], - [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt]] - if USE_TIMESTAMP: - tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) - model = 
model_forward(input_list, batch,
-                              is_train=is_train, modify_graph=True, config_dict=config_dict)
-    else:
-        input_list = [feature_spec_list,
-                      [hash_table_list[0], hash_table_list[1]],
-                      [cfg.user_send_cnt, cfg.item_send_cnt]]
-        if use_multi_lookup:
-            input_list = [feature_spec_list,
-                          [hash_table_list[0], hash_table_list[1], hash_table_list[0], hash_table_list[0]],
-                          [cfg.user_send_cnt, cfg.item_send_cnt, cfg.user_send_cnt, cfg.item_send_cnt]]
-        model = model_forward(input_list, batch,
-                              is_train=is_train, modify_graph=False, config_dict=config_dict)
-
-    return iterator, model
-
-
-def create_feature_spec_list(use_timestamp=False):
-    access_threshold = cfg.access_threshold if use_timestamp else None
-    eviction_threshold = cfg.eviction_threshold if use_timestamp else None
-    feature_spec_list = [FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table",
-                                     access_threshold=access_threshold,
-                                     eviction_threshold=eviction_threshold,
-                                     faae_coefficient=1),
-                         FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="item_table",
-                                     access_threshold=access_threshold,
-                                     eviction_threshold=eviction_threshold,
-                                     faae_coefficient=4)]
-    if use_multi_lookup:
-        feature_spec_list.extend([FeatureSpec("user_ids", feat_count=cfg.user_feat_cnt, table_name="user_table",
-                                              access_threshold=access_threshold,
-                                              eviction_threshold=eviction_threshold,
-                                              faae_coefficient=1),
-                                  FeatureSpec("item_ids", feat_count=cfg.item_feat_cnt, table_name="user_table",
-                                              access_threshold=access_threshold,
-                                              eviction_threshold=eviction_threshold,
-                                              faae_coefficient=4)])
-    if use_timestamp:
-        feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True))
-    return feature_spec_list
-
-
-if __name__ == "__main__":
-    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-    warnings.filterwarnings("ignore")
-
-    use_mode = UseMode.mapping(os.getenv("USE_MODE"))
-    mode = MxRecMode.mapping(os.getenv("MXREC_MODE"))
-    TRAIN_STEPS = 100
-    EVAL_INTERVAL = 100
-    EVAL_STEPS = 10
-    SAVING_INTERVAL = 100
-
-    # get init configuration
-    try:
-        use_mpi = bool(int(os.getenv("USE_MPI", 1)))
-        use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
-        use_hot = bool(int(os.getenv("USE_HOT", 0)))
-        use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
-        use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
-        MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
-        USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
-    except ValueError as err:
-        raise ValueError("please configure USE_MPI, USE_DYNAMIC, USE_HOT, USE_DYNAMIC_EXPANSION, "
-                         "USE_MULTI_LOOKUP, USE_MODIFY_GRAPH and USE_TIMESTAMP correctly; "
-                         "only 0 or 1 is supported.") from err
-
-    # the nbatch function needs to be used together with prefetch and host_vocabulary_size != 0
-    init(use_mpi=use_mpi,
-         train_steps=TRAIN_STEPS,
-         eval_steps=EVAL_STEPS,
-         prefetch_batch_number=1,
-         use_dynamic=use_dynamic,
-         use_hot=use_hot,
-         use_dynamic_expansion=use_dynamic_expansion)
-    IF_LOAD = False
-    rank_id = get_rank_id()
-    filelist = glob(f"./saved-model/sparse-model-{rank_id}-0")
-    if filelist:
-        IF_LOAD = True
-    set_if_load(IF_LOAD)
-
-    cfg = Config()
-    # access_threshold is a count; eviction_threshold is in seconds
-    ACCESS_AND_EVICT = None
-    if USE_TIMESTAMP:
-        config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold,
-                                     faae_coefficient=1)
-        config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold,
-                                     faae_coefficient=4)
-        ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table)
-    train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
-    eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
-
-    optimizer_list = [create_dense_and_sparse_optimizer(cfg)]
-    sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]
-
-    # To verify DDR mode, size the device and host tables according to the total key count and the
-    # number of unique keys per batch. Reference DDR configuration: the dataset's total key count is
-    # larger than the device table but smaller than device + host, and the unique keys of one batch
-    # fit in the device table.
-    user_hashtable = create_table(key_dtype=tf.int64,
-                                  dim=tf.TensorShape([cfg.user_hashtable_dim]),
-                                  name='user_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  device_vocabulary_size=cfg.user_vocab_size * 10,
-                                  host_vocabulary_size=0,
-                                  optimizer_list=sparse_optimizer_list,
-                                  mode=mode,
-                                  all2all_gradients_op="sum_gradients_and_div_by_ranksize",
-                                  apply_gradients_strategy=ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY)
-
-    item_hashtable = create_table(key_dtype=tf.int64,
-                                  dim=tf.TensorShape([cfg.item_hashtable_dim]),
-                                  name='item_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  device_vocabulary_size=cfg.item_vocab_size * 10,
-                                  host_vocabulary_size=0,
-                                  optimizer_list=sparse_optimizer_list,
-                                  mode=mode)
-
-    train_iterator, train_model = build_graph([user_hashtable, item_hashtable], is_train=True,
-                                              feature_spec_list=train_feature_spec_list,
-                                              config_dict=ACCESS_AND_EVICT,
-                                              batch_number=TRAIN_STEPS * get_rank_size())
-    eval_iterator, eval_model = build_graph([user_hashtable, item_hashtable], is_train=False,
-                                            feature_spec_list=eval_feature_spec_list,
-                                            config_dict=ACCESS_AND_EVICT,
-                                            batch_number=EVAL_STEPS * get_rank_size())
-    dense_variables, sparse_variables = get_dense_and_sparse_variable()
-
-    run_mode = RunMode(
-        MODIFY_GRAPH_FLAG, optimizer_list, train_model, eval_model, train_iterator, eval_iterator,
-        TRAIN_STEPS, EVAL_STEPS
-    )
-
-    # start host pipeline
-    if not MODIFY_GRAPH_FLAG:
-        start_asc_pipeline()
-    # start modify graph
-    if MODIFY_GRAPH_FLAG and use_mode != UseMode.TRAIN:
-        logging.info("start to modify graph")
-        modify_graph_and_start_emb_cache(dump_graph=True)
-
-    if use_mode == UseMode.TRAIN:
-        run_mode.train(EVAL_INTERVAL, SAVING_INTERVAL)
-    elif use_mode == UseMode.PREDICT:
-        run_mode.predict()
-
-    terminate_config_initializer()
-    logging.info("Demo done!")
diff --git a/example/little_demo/model.py b/example/little_demo/model.py
deleted file mode 100644
index 18ab98ca..00000000
--- a/example/little_demo/model.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# coding: UTF-8
-from __future__ import print_function
-
-import tensorflow as tf
-
-
-class MyModel:
-    def __init__(self):
-        self.layer_dims = [1024, 512, 256, 128]
-        self.act_func = 'relu'
-        self.keep_prob = 0.8
-        self._lambda = 4.91e-7
-        self.emb_dim = None
-        self.loss_list = []
-        self.predict_list = []
-        self.all_layer_dims = None
-        self.h_w, self.h_b = [], []
-        self.h_w_head_0, self.h_w_head_1, self.h_b_head_0, self.h_b_head_1 = None, None, None, None
-
-    def __call__(self, embedding_list, label_0, label_1, is_training=True):
-        with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE):
-            embedding = tf.concat(embedding_list, axis=1)
-            self.emb_dim = embedding.shape.as_list()[-1]
-            self.all_layer_dims = [self.emb_dim] + self.layer_dims + [1]
-
-        with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE):
-            for i in range(len(self.all_layer_dims) - 2):
-
self.h_w.append(tf.compat.v1.get_variable('h%d_w' % (i + 1), shape=self.all_layer_dims[i: i + 2], - initializer=tf.random_uniform_initializer(-0.01, 0.01), - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"])) - self.h_b.append( - tf.compat.v1.get_variable('h%d_b' % (i + 1), shape=[self.all_layer_dims[i + 1]], - initializer=tf.zeros_initializer, - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"])) - i += 1 - self.h_w_head_0 = tf.compat.v1.get_variable('h_w_head_0', shape=self.all_layer_dims[i: i + 2], - initializer=tf.random_uniform_initializer(-0.01, 0.01), - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) - self.h_b_head_0 = tf.compat.v1.get_variable('h_b_head_0', shape=[self.all_layer_dims[i + 1]], - initializer=tf.zeros_initializer, - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) - self.h_w_head_1 = tf.compat.v1.get_variable('h_w_head_1', shape=self.all_layer_dims[i: i + 2], - initializer=tf.random_uniform_initializer(-0.01, 0.01), - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) - self.h_b_head_1 = tf.compat.v1.get_variable('h_b_head_1', shape=[self.all_layer_dims[i + 1]], - initializer=tf.zeros_initializer, - dtype=tf.float32, - collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) - - logit_list = self.forward(embedding, self.act_func, self.keep_prob, training=is_training) - - for logit, label in zip(logit_list, (label_0, label_1)): - train_preds = tf.sigmoid(logit) - - basic_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label) - - deep_loss = tf.reduce_mean(basic_loss) # + _lambda * tf.nn.l2_loss(embedding) - self.predict_list.append(train_preds) - self.loss_list.append(deep_loss) - - - def forward(self, embedding, act_func, keep_prob, training): - hidden_output = tf.reshape(embedding, [-1, self.emb_dim]) # *512 - for i, h_w_var in enumerate(self.h_w): - hidden_output = tf.matmul(self.activate(act_func, hidden_output), h_w_var) - hidden_output = hidden_output + self.h_b[i] - - def output_head(hidden_output, h_w, h_b): - hidden_output_branch = tf.matmul(self.activate(act_func, hidden_output), h_w) - logit = hidden_output_branch + h_b - logit = tf.reshape(logit, [-1, ]) - - return logit - - logit_0 = output_head(hidden_output, self.h_w_head_0, self.h_b_head_0) - logit_1 = output_head(hidden_output, self.h_w_head_1, self.h_b_head_1) - logit_list = [logit_0, logit_1] - - return logit_list - - @staticmethod - def activate(act_func, input_x): - if act_func == 'tanh': - return tf.tanh(input_x) - elif act_func == 'relu': - return tf.nn.relu(input_x) - else: - return tf.sigmoid(input_x) diff --git a/example/little_demo/op_impl_mode.ini b/example/little_demo/op_impl_mode.ini deleted file mode 100644 index 4a744500..00000000 --- a/example/little_demo/op_impl_mode.ini +++ /dev/null @@ -1,3 +0,0 @@ -ScatterNdAdd=support_out_of_bound_index -GatherV2=high_performance -UnsortedSegmentSum=high_performance \ No newline at end of file diff --git a/example/little_demo/optimizer.py b/example/little_demo/optimizer.py deleted file mode 100644 index 43294764..00000000 --- a/example/little_demo/optimizer.py +++ /dev/null @@ -1,22 +0,0 @@ -# coding: UTF-8 - -import logging - -import tensorflow as tf - -from mx_rec.optimizers.lazy_adam import create_hash_optimizer -from mx_rec.optimizers.lazy_adam_by_addr import 
create_hash_optimizer_by_address
-from mx_rec.util.initialize import get_use_dynamic_expansion
-
-
-def create_dense_and_sparse_optimizer(cfg):
-    dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate)
-    use_dynamic_expansion = get_use_dynamic_expansion()
-    if use_dynamic_expansion:
-        sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate)
-        logging.info("optimizer lazy_adam_by_addr")
-    else:
-        sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate)
-        logging.info("optimizer lazy_adam")
-
-    return dense_optimizer, sparse_optimizer
diff --git a/example/little_demo/random_data_generator.py b/example/little_demo/random_data_generator.py
deleted file mode 100644
index 473ac690..00000000
--- a/example/little_demo/random_data_generator.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# coding: UTF-8
-import logging
-
-import numpy as np
-
-from mx_rec.util.initialize import get_rank_id
-
-
-def get_data_generator(config, batch_number):
-    rank_id = get_rank_id()
-
-    def data_generator():
-        i = 0
-        while i < batch_number:
-            item_ids = np.random.randint(0, config.item_range, (config.batch_size, config.item_feat_cnt))
-            user_ids = np.random.randint(0, config.user_range, (config.batch_size, config.user_feat_cnt))
-            category_ids = np.random.randint(0, config.category_range, (config.batch_size, config.category_feat_cnt))
-            label_0 = np.random.randint(0, 2, (config.batch_size,))
-            label_1 = np.random.randint(0, 2, (config.batch_size,))
-
-            yield {"item_ids": item_ids,
-                   "user_ids": user_ids,
-                   "category_ids": category_ids,
-                   "label_0": label_0,
-                   "label_1": label_1}
-            i += 1
-
-        logging.debug(f"================ end of data generator for {config.task_name} task | rank id {rank_id} "
-                      f"================")
-
-    return data_generator
-
-
-def get_large_scale_data_generator(config):
-    def data_generator():
-        i = 0
-        while True:
-            id_list = [np.random.randint(0, config.vocabulary_size, (config.batch_size,))
-                       for _ in range(config.lookup_count)]
-
-            data_block = dict(zip(config.tensor_name_list, id_list))
-
-            label_0 = np.random.randint(0, 2, (config.batch_size,))
-            label_1 = np.random.randint(0, 2, (config.batch_size,))
-            data_block["label_0"] = label_0
-            data_block["label_1"] = label_1
-
-            logging.debug(f"================ generate NO.{i} step ================")
-            yield data_block
-            i += 1
-
-    return data_generator
diff --git a/example/little_demo/run.sh b/example/little_demo/run.sh
deleted file mode 100644
index e03f3fe5..00000000
--- a/example/little_demo/run.sh
+++ /dev/null
@@ -1,153 +0,0 @@
-kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
-rm -rf /root/ascend/log/*
-rm -rf ./kernel*
-rm -rf ./export_graph/*
-
-# Read the input arguments: py (script name) and ip
-if [ $# -ge 1 ]; then
-    py=$1
-    ip=$2
-else
-    echo "for example: bash run.sh main.py 10.10.10.10 or bash run.sh main.py"
-    exit 1
-fi
-
-# Check that the given Python file name is valid
-if [[ $py =~ ^[a-z0-9_]+\.py$ ]]; then
-    echo "File $py is a valid Python file"
-else
-    echo "File $py is not a Python file"
-    exit 1
-fi
-
-# Check whether the IP address is valid
-if [ -n "$ip" ]; then
-    if [[ $ip =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]]; then
-        # Split the IP address into four octets
-        ip_array=(${ip//./ })
-        # Check that every octet is within 0-255
-        valid=true
-        for i in "${ip_array[@]}"; do
-            if ((i < 0 || i > 255)); then
-                valid=false
-                break
-            fi
-        done
-        if $valid; then
-            echo "ip: $ip is valid"
-        else
-            echo "ip: $ip is not valid"
-            exit 1
-        fi
-    else
-        echo "ip: $ip is not valid."
-        exit 1
-    fi
-fi
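-# For example, "10.10.10.10" passes both checks above, while "10.10.10.256"
-# matches the format regex but fails the per-octet range check (256 > 255),
-# so the script exits.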
-
-cur_path=`pwd`
-mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec" # please config
-so_path=${mx_rec_package_path}/libasc
-# GLOG_stderrthreshold 0:INFO 1:WARNING 2:ERROR 3:FATAL
-# GLOG_v 1:DEBUG(print as INFO) 2:TRACE(print as INFO)
-mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -x GLOG_v=0 -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0 '
-interface="lo"
-local_rank_size=8 # number of NPUs used per node
-num_server=1 # number of training nodes
-num_process=$((${num_server} * ${local_rank_size})) # total number of training processes, i.e. the total number of NPUs used
-
-export HCCL_CONNECT_TIMEOUT=1200 # HCCL collective-communication link-setup timeout, valid range [120,7200]
-export PYTHONPATH=${so_path}:$PYTHONPATH # Python installation path of this environment
-#export LD_PRELOAD=/usr/lib64/libgomp.so.1 # path of the GNU OpenMP shared library; it should NOT be loaded via LD_PRELOAD!
-export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
-# Collective-communication rank table file; for the format, see the "Preparing the Resource Configuration File" section of the Ascend CANN documentation.
-export JOB_ID=10086
-# Total number of NPUs used by the training job
-export MXREC_LOG_LEVEL="DEBUG" # framework log level
-export TF_CPP_MIN_LOG_LEVEL=3 # TensorFlow log level; 3 means FATAL
-# Set the global log level and per-module log levels for application logs; see the Ascend CANN documentation,
-export ASCEND_GLOBAL_LOG_LEVEL=3 # "Setting the Log Level" section. 0:debug, 1:info, 2:warning, 3:error, 4:NULL
-export MXREC_MODE="ASC"
-export USE_MPI=1
-export USE_MODE="train" # supported values: [train, predict]
-
-if [ $USE_MODE = "train" ];then
-    echo "train mode: saved-model will be deleted"
-    rm -rf ./saved-model
-fi
-
-################# parameter configuration ######################
-export USE_DYNAMIC=0 # 0: static shape; 1: dynamic shape
-export USE_HOT=0 # 0: disable hot emb; 1: enable hot emb
-export USE_DYNAMIC_EXPANSION=0 # 0: disable dynamic expansion; 1: enable dynamic expansion
-export USE_MULTI_LOOKUP=1 # 0: one lookup per table; 1: multiple lookups per table
-export USE_MODIFY_GRAPH=0 # 0: feature spec mode; 1: automatic graph-modification mode
-export USE_TIMESTAMP=0 # 0: disable feature admission and eviction; 1: enable feature admission and eviction
-export UpdateEmb_V2=1 # 0: synchronous UpdateEmb; 1: asynchronous UpdateEmb_V2
-export USE_COMBINE_FAAE=0 # 0: separate history when faae; 1: combine history when faae
-################# performance tuning ####################
-export KEY_PROCESS_THREAD_NUM=6 # default 6, max 10
-export FAST_UNIQUE=0 # whether to use fast unique
-export MGMT_HBM_TASK_MODE=0 # whether to get and send tensors via async h2d
-################################################
-
-# Help message; no need to modify
-if [[ $1 == --help || $1 == -h ]];then
-    echo "Usage: ./run.sh [OPTION]... [IP]..."
-    echo " "
-    echo "parameter explanation:
-    [OPTION]        main.py
-    [IP]            IP address of the host
-    -h/--help       show help message
-    "
-    exit 1
-fi
-
-# Use the rank table solution
-function rankTableSolution() {
-    echo "The ranktable solution"
-    export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json"
-    export RANK_SIZE=$num_process
-    echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
-    if [ ! -f "$RANK_TABLE_FILE" ];then
-        echo "the rank table file does not exist. Please refer to {hccl_json_8p.json} to correctly configure the rank table file"
-        exit 1
-    fi
-}
-
-if [ ! -n "$ip" ]; then
-    rankTableSolution
-else
-    VALID_CHECK=$(echo $ip|awk -F. '$1<=255&&$2<=255&&$3<=255&&$4<=255{print "yes"}')
-    if echo $ip|grep -E "^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$">/dev/null; then
-        if [ "$VALID_CHECK" == "yes" ]; then
-            ################# enabled when using the rank-table-free solution ######################
-            echo "ip: $ip available."
-            echo "The ranktable solution is removed."
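-            # A sketch of what these variables would look like on the second
-            # node of a two-node cluster (the addresses are illustrative, not
-            # taken from this repository):
-            #   export CM_CHIEF_IP=10.10.10.10    # chief stays on node 0
-            #   export CM_WORKER_IP=10.10.10.11   # this node's own IP
-            #   export CM_WORKER_SIZE=16          # 2 nodes x 8 NPUs each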
-            export CM_CHIEF_IP=$ip # chief node IP
-            export CM_CHIEF_PORT=6000 # port the chief node listens on
-            export CM_CHIEF_DEVICE=0 # device id of the chief node
-            export CM_WORKER_IP=$ip # IP of the current node
-            export CM_WORKER_SIZE=$num_process # number of devices participating in cluster training
-            echo "CM_CHIEF_IP=$CM_CHIEF_IP"
-            echo "CM_CHIEF_PORT=$CM_CHIEF_PORT"
-            echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
-            echo "CM_WORKER_IP=$CM_WORKER_IP"
-            echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
-            echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES"
-            #########################################################
-        else
-            echo "ip: $ip not available!" # fall back to the rank table solution
-            rankTableSolution
-        fi
-    else
-        echo "ip: $ip not available!" # fall back to the rank table solution
-        rankTableSolution
-    fi
-fi
-
-echo "use horovod to start tasks"
-DATE=$(date +%Y-%m-%d-%H-%M-%S)
-horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
-python3.7 ${py} 2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${DATE}.log"
-
diff --git a/example/little_demo/run_mode.py b/example/little_demo/run_mode.py
deleted file mode 100644
index 04b6ecc2..00000000
--- a/example/little_demo/run_mode.py
+++ /dev/null
@@ -1,146 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-
-import os
-import logging
-
-import tensorflow as tf
-from config import sess_config
-
-from mx_rec.util.initialize import get_initializer, get_rank_id, get_rank_size, clear_channel
-from mx_rec.util.variable import get_dense_and_sparse_variable
-from mx_rec.util.tf_version_adapter import hccl_ops
-from mx_rec.constants.constants import BaseEnum
-from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
-
-
-class UseMode(BaseEnum):
-    TRAIN = "train"
-    PREDICT = "predict"
-
-
-class RunMode:
-
-    def __init__(
-            self, is_modify_graph: bool, optimizer_list: list, train_model, eval_model, train_iterator, eval_iterator,
-            train_steps: int, infer_steps: int):
-        self.is_modify_graph = is_modify_graph
-        self.session = tf.compat.v1.Session(config=sess_config(dump_data=False))
-        self.train_model = train_model
-        self.train_iterator = train_iterator
-        self.eval_model = eval_model
-        self.eval_iterator = eval_iterator
-        self.saver = tf.compat.v1.train.Saver()
-        self.rank_id = get_rank_id()
-        self.train_ops = []
-        self.optimizer_list = optimizer_list
-        self.epoch = 1
-        self.train_steps = train_steps
-        self.infer_steps = infer_steps
-
-    def _infer(self):
-        if self.is_modify_graph:
-            self.session.run(get_initializer(False))
-        else:
-            self.session.run(self.eval_iterator.initializer)
-        clear_channel(is_train_channel=False)
-        for i in range(1, self.infer_steps + 1):
-            logging.info("############### infer at step %d ################", i)
-            try:
-                self.session.run(self.eval_model.loss_list)
-            except tf.errors.OutOfRangeError:
-                logging.info("Encountered the end of sequence for eval.")
-                break
-
-    def set_train_ops(self):
-        dense_variables, sparse_variables = get_dense_and_sparse_variable()
-
-        # multi task training
-        for loss, (dense_optimizer, sparse_optimizer) in zip(self.train_model.loss_list, self.optimizer_list):
-            # do dense optimization
-            grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables)
-            avg_grads = []
-            for grad, var in grads:
-                if get_rank_size() > 1:
-                    grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None
-                if grad is not None:
-                    avg_grads.append((grad, var))
-            # apply gradients: update variables
-            self.train_ops.append(dense_optimizer.apply_gradients(avg_grads))
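-            # Note: the allreduce above uses "sum" without dividing by the rank
-            # count, so "avg_grads" actually holds summed gradients; divide by
-            # get_rank_size() here if a true average is intended (assumption:
-            # the dense learning rate was tuned for the summed variant).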
-            if bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))):
-                from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET
-
-                train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
-                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
-                # do sparse optimization by addr
-                local_grads = tf.gradients(loss, train_emb_list) # local_embedding
-                grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)]
-                self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
-            else:
-                # do sparse optimization
-                sparse_grads = tf.gradients(loss, sparse_variables)
-                grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
-                self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
-
-    def train(self, eval_interval: int, saving_interval: int):
-        self.set_train_ops()
-
-        # In train mode, graph modification needs to be performed after computing gradients
-        if self.is_modify_graph:
-            logging.info("start to modify graph")
-            modify_graph_and_start_emb_cache(dump_graph=True)
-            self.session.run(get_initializer(True))
-        else:
-            self.session.run(self.train_iterator.initializer)
-        self.session.run(tf.compat.v1.global_variables_initializer())
-
-        for i in range(1, self.train_steps + 1):
-            logging.info("################ training at step %d ################", i)
-            try:
-                self.session.run([self.train_ops, self.train_model.loss_list])
-            except tf.errors.OutOfRangeError:
-                logging.info("Encountered the end of sequence for training.")
-                break
-            else:
-                if i % eval_interval == 0:
-                    self.evaluate()
-
-                if i % saving_interval == 0:
-                    self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
-
-        self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
-        logging.info("################ training end ################")
-
-    def evaluate(self):
-        logging.info("############### start evaluate, epoch:%d ################", self.epoch)
-        self._infer()
-        logging.info("############### evaluate end, epoch:%d ################", self.epoch)
-        self.epoch += 1
-
-    def predict(self):
-        logging.info("############### start predict ################")
-        import glob
-        import re
-
-        model_file = glob.glob(f"./saved-model/sparse-model-{self.rank_id}-*")
-        if len(model_file) == 0:
-            raise ValueError("model file does not exist")
-
-        # get the latest model
-        pattern = f".*sparse-model-{self.rank_id}-([0-9]+).*"
-        latest_step = -1
-        for file_path in model_file:
-            match = re.match(pattern, file_path)
-            if match and match.groups():
-                step = int(match.groups()[0])
-
-                if step > latest_step:
-                    latest_step = step
-        if latest_step == -1:
-            raise RuntimeError("latest model not found")
-
-        self.saver.restore(self.session, f"./saved-model/model-{self.rank_id}-{latest_step}")
-        self._infer()
-        logging.info("############### predict end ################")
diff --git a/tools/mx_rec_perf.sh b/tools/mx_rec_perf.sh
deleted file mode 100644
index fe1ee706..00000000
--- a/tools/mx_rec_perf.sh
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd.
-# Description: mxRec performance analysis script V1.0
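-# Usage sketch (the argument is the spdlog debug log captured by run.sh;
-# the log file name below is illustrative):
-#   bash mx_rec_perf.sh temp_8p_6t_2024-01-30.log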
-set -e
-
-file="$1" # the spdlog log file to analyze
-
-calculate_average() {
-    awk '{
-        sum += $1;
-        count++
-    } END {
-        average = sum / count;
-        print average
-    }'
-}
-perf() {
-    echo "read batch cost"
-    cat ${file} | grep 'read batch cost'|grep -v timeout|tail -n 20| awk 'NR%2==1'
-    echo "===================================="
-    echo "key process cost"
-    cat ${file} | grep 'key process cost'|tail
-    avg=$(cat ${file} | grep -Po '(?<=key process cost:)[^,:]+(?=,)'|tail -n +20 |calculate_average)
-    echo "Average: $avg"
-    echo "===================================="
-    echo "Analyze the host and device pipelines: if host key process runs ahead of the training step, host performance is not the bottleneck"
-    echo "Enter a training-step marker to grep for (default: step); press Enter to open the analysis and q to quit"
-    read step
-    step="${step:-step}"
-    cat ${file} | grep -P "key process cost|${step}"|tail -n100|less
-}
-echo -e "\e[45m\e[1m =========mxRec analysis script V1.0========= \e[0m"
-echo
-
-stuck_check() {
-    echo -e "\e[106m--------Locating hangs and GetNext timeouts----------\e[0m"
-    echo -n "Timed-out channel: "
-    cat ${file} | grep -Po "aicpu_getnext.*GetNext"
-    echo
-    echo "Check the number of lookup sends per card:"
-    for i in {0..7}
-    do
-        line=$(cat ${file} | grep -P "send"|grep "h2d"|grep "1,${i}"|wc -l)
-        echo -n "$line "
-    done
-    echo
-    echo "Check whether every card sends the same number of h2d transfers:"
-    for i in {0..7}
-    do
-        line=$(cat ${file} | grep "send"|grep "h2d"|grep "1,${i}"|wc -l)
-        echo -n "$line "
-    done
-    echo
-    echo "Check whether every card receives the same number of transfers:"
-    for i in {0..7}
-    do
-        line=$(cat ${file} | grep "r recv"|grep "1,${i}"|wc -l)
-        echo -n "$line "
-    done
-    echo
-    echo "Last batch received by each card:"
-    cat ${file}|grep "trans emb"|grep "info"|tail
-}
-
-hot_check() {
-    # Check the hot emb deduplication ratio
-    echo "Table names and dedup ratios (after/before dedup), expected to be below 0.4:"
-    cat op_summary_*.csv |grep gather_for_restore_vector |awk -F "," '{print $6,$14,$15}'|sed 's/"//g'|sed 's/ [0-9]*;/\//'
-}
-
-perf
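-# Note that only perf runs by default; stuck_check and hot_check are defined
-# but never invoked. To use them, append the calls here, e.g.:
-#   stuck_check
-#   hot_check   # assumes op_summary_*.csv from an msprof run in the cwd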
diff --git a/tools/parse_data/data_parser.py b/tools/parse_data/data_parser.py
deleted file mode 100755
index 735c8faa..00000000
--- a/tools/parse_data/data_parser.py
+++ /dev/null
@@ -1,124 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
-# Description:
-# Author: MindX SDK
-# Create: 2023-06-28
-
-# -----------------------------------------ReadMe Begin--------------------------------------------
-# 1. Purpose
-# This tool benchmarks the TensorFlow data-parsing stage on its own, to help decide whether data
-# parsing is the bottleneck of the whole pipeline and is blocking it from flowing smoothly.
-# 2. Notes
-# The parsing logic lives in the make_dataset function, which uses the criteo dataset by default.
-# To measure the parsing cost of another dataset, redefine make_dataset as needed.
-# 3. CPU binding
-# To mimic the real deployment, bind_cpu by default assumes an 80-core CPU split evenly across
-# 8 workers; if the worker count or the actual core count differs, redefine bind_cpu as needed.
-# 4. Launching
-# 4.1 Single worker: python3 data_parser.py
-# 4.2 Multiple workers: bash run.sh data_parser.py
-# -----------------------------------------ReadMe End--------------------------------------------
-
-import os
-import sys
-import time
-
-import logging
-import psutil
-
-import tensorflow as tf
-
-logging.basicConfig(level=logging.DEBUG)
-
-
-def make_dataset(data_path, batch_size=102400, line_per_sample=1024):
-    def extract_fn(data_record):
-        features = {
-            # Extract features using the keys set during creation
-            'label': tf.FixedLenFeature(shape=(line_per_sample,), dtype=tf.int64),
-            'sparse_feature': tf.FixedLenFeature(shape=(26 * line_per_sample,), dtype=tf.int64),
-            'dense_feature': tf.FixedLenFeature(shape=(13 * line_per_sample,), dtype=tf.float32),
-        }
-        sample = tf.parse_single_example(data_record, features)
-        return sample
-
-    def feat_cast(feat):
-        for name, tensor in feat.items():
-            if tensor.dtype == tf.int64:
-                feat[name] = tf.cast(tensor, tf.int32)
-        return feat
-
-    def reshape_fn(batch):
-        batch['label'] = tf.reshape(batch['label'], [-1, 1])
-        batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 13])
-        batch['dense_feature'] = tf.math.log(batch['dense_feature'] + 3.0)
-        batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 26])
-        return batch
-
-    file_list = sorted([os.path.join(data_path, file) for file in os.listdir(data_path)])
-    dataset = tf.data.TFRecordDataset(file_list, num_parallel_reads=4)
-
-    num_parallel = 8
-    dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel)
-
-    line_cnt = batch_size // line_per_sample
-    dataset = dataset.batch(line_cnt, drop_remainder=True)
-
-    dataset = dataset.map(feat_cast, num_parallel_calls=num_parallel)
-    dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel)
-
-    dataset = dataset.prefetch(10)
-    return dataset
-
-
-def bind_cpu(rank_id):
-    process = psutil.Process()
-    cpu_kernels = {
-        0: 0,
-        1: 10,
-        2: 40,
-        3: 50,
-        4: 20,
-        5: 30,
-        6: 60,
-        7: 70
-    }
-    try:
-        process.cpu_affinity([cpu_kernels.get(rank_id) + x for x in range(10)])
-    except IndexError:
-        logging.error("error cpu bind info, skipped.")
-
-
-if __name__ == '__main__':
-    RANK_ID = 0
-    if len(sys.argv) > 1:
-        RANK_ID = int(sys.argv[1])
-    bind_cpu(RANK_ID)
-
-    DATA_PATH = "/media/mxRec/data/criteo_tfrecord_small/train"
-    train_dataset = make_dataset(DATA_PATH)
-    iterator = train_dataset.make_initializable_iterator()
-    next_batch = iterator.get_next()
-
-    input_data = []
-    for example in next_batch:
-        input_data.append(next_batch[example])
-
-    COUNT = 0
-    TOTAL_TIME = 0.0
-
-    with tf.Session() as sess:
-        sess.run(iterator.initializer)
-        while True:
-            try:
-                start_time = time.time()
-                result = sess.run(input_data[0])
-                end_time = time.time()
-
-                COUNT += 1
-
-                if COUNT > 1:
-                    TOTAL_TIME += end_time - start_time
-                logging.info("StepId:%d, StepTimeCost(ms):%f", COUNT, (end_time - start_time) * 1000)
-            except tf.errors.OutOfRangeError:
-                logging.error("End of Training Dataset")
-                break
-        logging.info("StepTimeCost avg(ms):%f", TOTAL_TIME / (COUNT - 1) * 1000)
\ No newline at end of file
diff --git a/tools/parse_data/run.sh b/tools/parse_data/run.sh
deleted file mode 100755
index b3ab73bb..00000000
--- a/tools/parse_data/run.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
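-# Each worker passes its rank to data_parser.py, which pins it to a 10-core
-# slice of an assumed 80-core host (e.g. rank 2 starts at core 40 and uses
-# cores 40-49 via the cpu_kernels map). A sketch for checking the affinity
-# that was actually applied (the pgrep pattern is illustrative):
-#   taskset -cp $(pgrep -f "data_parser.py 2" | head -n 1)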
-# Author: MindX SDK
-# Create: 2023
-# History: NA
-
-for i in {0..7}
-do
-    nohup python3 data_parser.py $i > rank_$i.log 2>&1 &
-done
\ No newline at end of file
diff --git a/tools/perf/fast.sh b/tools/perf/fast.sh
deleted file mode 100755
index aae6f9d4..00000000
--- a/tools/perf/fast.sh
+++ /dev/null
@@ -1,346 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
-# Author: MindX SDK
-# Create: 2023
-# History: NA
-
-# -----------------------------------------ReadMe Begin--------------------------------------------
-# 1. Purpose
-# This tool analyzes the time spent in each pipe of the pipeline during model execution, and in the
-# sub-modules (Steps) within each pipe, to help locate system bottlenecks.
-# (The basic principle of a pipeline is that every pipe takes roughly the same time, so the pipes
-# can hide one another's latency; only then are stalls and waits minimized and throughput maximized.)
-#
-# 2. Usage
-# bash fast.sh your_log_file.log
-#
-# 3. Notes
-# mxRec emits TimeCost trace logs via spdlog::debug, so before running make sure run.sh sets
-# SPDLOG_LEVEL=debug (if it is not set, this tool exits with a hint).
-#
-# 4. Reading the results
-# (1) Pipeline: the whole pipeline consists of several serial pipes; results are reported per pipe,
-#     e.g. Pipe-1/Pipe-2/Pipe-3/Pipe-4.
-# (2) Pipe: each pipe reports an overall time. (Ideally every pipe takes about the same time, so the
-#     pipes overlap each other and the pipeline runs at peak efficiency.)
-# (3) Sub-module (Step): a pipe may consist of several serial sub-modules (Steps), and a sub-module
-#     may contain further sub-modules (SubSteps). In the report, the next level's timings are
-#     prefixed with -- and the level below that with ----, and so on; a level's time includes the
-#     levels beneath it.
-#
-# 5. Performance tuning
-# From the report you may find:
-# (1) a pipe that takes unusually long;
-# (2) a sub-module that takes unusually long.
-# Each case needs its own analysis and targeted tuning or deeper optimization.
-# For example: if TensorFlow data parsing is slow (Pipe-1) and starves the pipeline, increase the
-# num_parallel parameters on the TensorFlow side; if data preprocessing blocks because the CPU is
-# saturated (Pipe 2: Data Preprocess), lower KEY_PROCESS_THREAD_NUM (default 6); if H2D blocks
-# (Pipe 4: H2D Send Tensors (no DDR)), check whether GetNext on the NPU side or the DNN training
-# is stalled. Deeper problems may require structural work: splitting pipes, turning serial code
-# parallel, lock optimization, or changes to the execution logic.
-# This tool is also useful as a reference while optimizing: after improving a sub-module, compare
-# its time before and after, together with the end-to-end time and throughput.
-#
-# 6. This tool needs to evolve with the code; feel free to improve it. Good luck!
-# -----------------------------------------ReadMe End--------------------------------------------
-#set -x
-
-LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
-LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
-LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
-LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
-
-logfile=$1
-
-validate_options()
-{
-    if [ $# -ne 1 ]; then
-        LOG_ERROR "NO log_file"
-        echo "[Usage]: bash $0 log_file"
-        exit 1
-    fi
-}
-
-check_spdlog_level()
-{
-    $(grep 'ReadEmbKeyV2Static' $logfile > /dev/null 2>&1)
-    if [ $? != 0 ]; then
-        $(grep 'ReadEmbKeyV2Dynamic' $logfile > /dev/null 2>&1)
-        if [ $? != 0 ]; then
-            LOG_ERROR "No timecost-related logs, please check 'mpi_args' in your run.sh,
-            make sure SPDLOG_LEVEL=debug, and run again!"
-            exit 1
-        fi
-    fi
-}
-
-parse_pipe_1_data_parser()
-{
-    LOG_NOTICE "Pipe-1: Data Parser"
-
-    $(grep 'ReadEmbKeyV2Dynamic' $logfile > /dev/null 2>&1)
-    if [ $?
== 0 ]; then - LOG_INFO "Step-1.x ReadEmbKeyV2 Dynamic" - else - LOG_INFO "Step-1.x ReadEmbKeyV2 Static" - fi - - grep 'read batch cost(ms)' $logfile | cut -d" " -f7| \ - awk -F "[:,]" '{sum+=$2} END {printf "read batch cost: avg=%0.1f\n", sum/NR}' - - grep 'enqueueTC(ms)' $logfile | grep -v 'timeout' | cut -d" " -f11 | \ - awk -F "[:,]" '{sum+=$2} END {printf "--|enqueueTC: avg=%0.1f\n", sum/NR}' - - grep 'elapsed from last(ms)' $logfile | grep -v 'timeout' | cut -d" " -f10 | \ - awk -F "[:,]" '{print $2}' | \ - awk 'BEGIN {sum=0; count=0} {if($1<1000) {sum+=$NF; count++} } END \ - {printf "elapsed from last: avg=%0.1f\n", sum/count}' -} - -parse_pipe_2_key_process() -{ - LOG_NOTICE "Pipe-2: Data Preprocess" - - grep 'getAndProcessTC(ms)' $logfile | cut -d" " -f5 | \ - awk -F"[:,]" '{print $2}' | \ - awk 'BEGIN{count=0; total=0;} {if ($1<2000) {total+=$NF; count++;}} END \ - {printf "getAndProcessTC(filter>2000ms): avg=%0.3f\n", total/count}' - - LOG_INFO "Step-2.1 GetBatchData" - - grep 'getBatchDataTC' $logfile | \ - awk -F":" 'BEGIN { max=0 } { sum+=$NF; if($NF>max) max=$NF } END \ - {printf "--|getBatchDataTC: total=%d, max=%0.1f, avg=%0.1f\n", NR, max, sum/NR}' - - grep 'getBatchDataTC' $logfile | \ - awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \ - {printf "--|getBatchDataTC(filter>2000ms): count=%d, avg=%0.1f\n", count, sum/count}' - - grep 'getBatchDataTC' $logfile | \ - awk -F":" 'BEGIN { total=0; none_zero_ms_num=0 } { total++; if($NF>0) none_zero_ms_num++ } END \ - {printf "--|getBatchDataTC: total=%d, none_zero_ms_num=%d, none_zero_ms_rate=%0.3f, zero_ms_rate=%0.3f\n", \ - total, none_zero_ms_num, none_zero_ms_num/total, (1-none_zero_ms_num/total)}' - - LOG_INFO "Step-2.2 KeyProcess" - - grep 'key process cost' $logfile | cut -d" " -f8 | cut -d ":" -f2 | cut -d"," -f1 | grep '^[0-9]' | grep '[0-9]$' | \ - awk 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \ - {printf "--|key process cost(filter>2000): avg=%0.1f\n", sum/count}' - - # fast-unique related start - $(grep 'ProcessBatchWithFastUnique(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'ProcessBatchWithFastUnique(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {printf "----|ProcessBatchWithFastUnique: avg=%0.1f\n", sum/NR}' - fi - - $(grep 'FastUniqueCompute(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'FastUniqueCompute(ms)' $logfile | cut -d' ' -f4 | \ - awk -F"[:,]" '{sum+=$2} END {printf "------|FastUniqueCompute: avg=%0.1f\n", sum/NR}' - fi - - $(grep 'GetScAll TimeCost(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'GetScAll TimeCost(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {printf "------|GetScAll: avg=%0.1f\n", sum/NR}' - fi - - $(grep 'all2allTC TimeCost(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'all2allTC TimeCost(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {printf "------|all2allTC: avg=%0.1f\n", sum/NR}' - fi - # fast-unique related end - - $(grep 'UniqueTC(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'UniqueTC(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {printf "----|UniqueTC: avg=%0.1f\n", sum/NR}' - fi - - $(grep 'processSplitKeysTC(ms)' $logfile > /dev/null 2>&1) - if [ $? == 0 ]; then - grep 'processSplitKeysTC(ms)' $logfile | \ - awk -F":" '{sum+=$NF} END {printf "----|processSplitKeysTC: avg=%0.1f\n", sum/NR}' - fi - - $(grep 'getScAllTC(ms)' $logfile > /dev/null 2>&1) - if [ $? 
== 0 ]; then
-        grep 'getScAllTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {printf "------|getScAllTC(AllReduce-AllGather): avg=%0.1f\n", sum/NR}'
-    fi
-
-    $(grep 'uniqueAll2AllTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'uniqueAll2AllTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {printf "------|uniqueAll2AllTC(All2allv): avg=%0.1f\n", sum/NR}'
-    fi
-
-    $(grep 'buildRestoreVecTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'buildRestoreVecTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {printf "----|buildRestoreVecTC: avg=%0.1f\n", sum/NR}'
-    fi
-
-    # common start
-    $(grep 'key2OffsetTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'key2OffsetTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {printf "----|key2OffsetTC: avg=%0.1f\n", sum/NR}'
-    fi
-
-    $(grep 'pushResultTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'pushResultTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {printf "----|pushResultTC, avg=%0.1f\n", sum/NR}'
-    fi
-    # common end
-}
-
-parse_pipe_3_get_tensors_async_no_ddr()
-{
-    LOG_NOTICE "Pipe-3: Get Tensors async (no DDR)"
-
-    $(grep 'getAllTensorTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'getAllTensorTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "getAllTensorTC, avg=", sum/NR}'
-    fi
-}
-
-parse_pipe_4_send_tensors_async_no_ddr()
-{
-    LOG_NOTICE "Pipe-4: H2D Send Tensors async (no DDR)"
-
-    $(grep 'sendTensorsTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'sendTensorsTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "sendTensorsTC, avg=", sum/NR}'
-    fi
-
-    $(grep 'sendLookupTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'sendLookupTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "--|sendLookupTC, avg=", sum/NR}'
-    fi
-
-    $(grep 'sendRestoreTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'sendRestoreTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "--|sendRestoreTC, avg=", sum/NR}'
-    fi
-}
-
-parse_pipe_3_get_and_send_tensors_with_ddr()
-{
-    LOG_NOTICE "Pipe-3: Get and Send Tensors (with DDR)"
-
-    grep 'parseKeyTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "parseKeyTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-
-    grep 'getAndSendTensorsTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "--getAndSendTensorsTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'getTensorsTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "----getTensorsTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    $(grep 'hostHashMapProcessTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'hostHashMapProcessTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "----hostHashMapProcessTC, avg=", sum/NR}'
-    fi
-
-    $(grep 'sendTensorsTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'sendTensorsTC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "----sendTensorsTC, avg=", sum/NR}'
-    fi
-
-    $(grep 'embHdTrans1TC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'embHdTrans1TC(ms)' $logfile | \
-        awk -F":" '{sum+=$NF} END {print "--embHdTrans1TC, avg=", sum/NR}'
-    fi
-
-    grep 'embHdTrans2TC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "--embHdTrans2TC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'hostEmbsTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "----hostEmbsTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'EmbHDTrans' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "----EmbHDTrans(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'h2dTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "------h2dTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'd2hTC' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<10000) {sum+=$NF; count++;}} END \
-    {printf "------d2hTC(filter>10000ms): avg=%0.1f\n", sum/count}'
-}
-
-parse_pipe_3_get_and_send_tensors_sync_without_ddr()
-{
-    LOG_NOTICE "Pipe-3: Get and Send Tensors sync (no DDR)"
-
-    $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'ParseKeysTC HBM mode (ms)' $logfile | \
-        awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<2000) {sum+=$NF; count++;}} END \
-        {printf "ParseKeysTC(filter>2000ms): avg=%0.1f\n", sum/count}'
-    fi
-
-    grep 'getTensorsSyncTC(ms)' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<1000) {sum+=$NF; count++;}} END \
-    {printf "--|getTensorsSyncTC(filter>1000ms): avg=%0.1f\n", sum/count}'
-
-    grep 'sendTensorsSyncTC(ms)' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<1000) {sum+=$NF; count++;}} END \
-    {printf "--|sendTensorsSyncTC(filter>1000ms): avg=%0.1f\n", sum/count}'
-
-    $(grep 'sendAll2AllScSyncTC(ms)' $logfile > /dev/null 2>&1)
-    if [ $? == 0 ]; then
-        grep 'sendAll2AllScSyncTC(ms)' $logfile | \
-        awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
-        {printf "----|sendAll2AllScSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
-    fi
-
-    grep 'sendLookupSyncTC(ms)' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
-    {printf "----|sendLookupSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
-
-    grep 'sendRestoreSyncTC(ms)' $logfile | \
-    awk -F":" 'BEGIN {sum=0; count=0;} {if($NF<200) {sum+=$NF; count++;}} END \
-    {printf "----|sendRestoreSyncTC(filter>200ms): avg=%0.1f\n", sum/count}'
-}
-
-main()
-{
-    validate_options $@
-    check_spdlog_level
-
-    echo "+----------------------------------------------------------------+"
-    echo "+                          Profile Result                         +"
-    echo "+----------------------------------------------------------------+"
-
-    parse_pipe_1_data_parser
-    parse_pipe_2_key_process
-
-    $(grep 'DDR mode' $logfile > /dev/null 2>&1)
-    if [ $? -eq 0 ]; then
-        parse_pipe_3_get_and_send_tensors_with_ddr
-    else
-        $(grep 'ParseKeysTC HBM mode (ms)' $logfile > /dev/null 2>&1)
-        if [ $? -eq 0 ]; then
-            parse_pipe_3_get_and_send_tensors_sync_without_ddr
-        else
-            parse_pipe_3_get_tensors_async_no_ddr
-            parse_pipe_4_send_tensors_async_no_ddr
-        fi
-    fi
-}
-
-main $@
\ No newline at end of file
diff --git a/tools/perf/host_set.sh b/tools/perf/host_set.sh
deleted file mode 100755
index 0120ebb9..00000000
--- a/tools/perf/host_set.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
-# Author: MindX SDK
-# Create: 2023
-# History: NA
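-# Quick checks that the settings below actually took effect (the sysfs path
-# follows the common cpufreq layout; verify it on your host):
-#   cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_governor   # expect "performance"
-#   free -h   # the swap line should show 0 used after swapoff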
-
-# cpu with high-performance
-cpupower frequency-set -g performance
-cat /proc/cpuinfo|grep MHz
-
-# clear cache
-echo 3 > /proc/sys/vm/drop_caches
-free -h
-
-# swap off
-swapoff -a
diff --git a/tools/perf/msprof.sh b/tools/perf/msprof.sh
deleted file mode 100755
index c1821c83..00000000
--- a/tools/perf/msprof.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
-# Author: MindX SDK
-# Create: 2023
-# History: NA
-
-
-curr_path=$(cd $(dirname $0); pwd)
-
-# ---------------config start---------------------
-model_run_path=/path/to/model/run
-run_cmd="bash run.sh"
-# ---------------config end---------------------
-
-# ------------------------------+
-# msprof                        +
-# ------------------------------+
-output_path="${model_run_path}"/msprof_out
-
-cd "${model_run_path}"
-rm -rf "${output_path}"
-
-msprof --application="${run_cmd}" --output="${output_path}"
diff --git a/tools/perf/mt_1207.sh b/tools/perf/mt_1207.sh
deleted file mode 100755
index fc0af5db..00000000
--- a/tools/perf/mt_1207.sh
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
-# Author: MindX SDK
-# Create: 2023
-# History: NA
-
-#set -x
-
-LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
-LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
-LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
-LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
-
-logfile=$1
-
-# ---------------config start---------------------
-batchsize=9600
-parallel=8
-nv_throughput=820000  # baseline samples/sec; nv_sps = 820000 / 9600 / 8 ≈ 10.68 steps/sec per worker
-# ---------------config end---------------------
-
-validate_options()
-{
-    if [ $# -ne 1 ]; then
-        LOG_ERROR "NO log_file"
-        echo "[Usage]: bash $0 your_file.log"
-        exit 1
-    fi
-}
-
-print_throughput()
-{
-    LOG_INFO "=========Throughput====================="
-    nv_sps=$(awk 'BEGIN{printf "%.2f\n",('${nv_throughput}'/'$batchsize'/'$parallel')}')
-    LOG_NOTICE "batchsize:${batchsize}, parallel:${parallel}"
-    LOG_NOTICE "nv_throughput:${nv_throughput}, nv_sps:${nv_sps}"
-
-    grep 'tensorflow:global_step/sec' $logfile | \
-    awk -F" " '{sum+=$NF} END \
-    {printf "Throughput: avg=%0.3f, xA100:%0.3f\n", \
-    sum/NR, sum/NR/'${nv_sps}'}'
-
-    grep 'tensorflow:global_step/sec' $logfile | \
-    awk -F" " 'BEGIN {sum=0; count=0;} {if ($NF > 3) {sum+=$NF; count++;}} END \
-    {printf "Throughput: after filter(<3), avg=%0.3f, xA100:%0.3f\n", \
-    sum/count, sum/count/'${nv_sps}'}'
-
-    grep 'tensorflow:global_step/sec' $logfile | \
-    awk -F" " 'BEGIN {max=0} {if($2>max) max=$2} END \
-    {printf "Throughput: max=%0.3f, xA100:%0.3f\n", max, max/'${nv_sps}'}'
-}
-
-main()
-{
-    validate_options $@
-    print_throughput
-}
-
-main $@
diff --git a/tools/perf/perf_flame_graph.sh b/tools/perf/perf_flame_graph.sh
deleted file mode 100755
index dce91600..00000000
--- a/tools/perf/perf_flame_graph.sh
+++ /dev/null
@@ -1,37 +0,0 @@
-#!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: performance analysis tool
-# Author: MindX SDK
-# Create: 2023
-# History: NA
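-# The FlameGraph helper scripts (stackcollapse-perf.pl, flamegraph.pl) must
-# already be present under flame_graph_path; a typical way to fetch them
-# (upstream project URL, network access assumed):
-#   git clone https://github.com/brendangregg/FlameGraph /home/FlameGraph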
-
-#set -x
-
-curr_path=$(cd $(dirname $0); pwd)
-
-LOG_INFO() { echo -e "\033[1;4;32m$1\033[0m" ; }
-LOG_NOTICE() { echo -e "\033[1;4;45m$1\033[0m" ; }
-LOG_WARN() { echo -e "\033[1;31m[WARN]$1\033[0m" ; }
-LOG_ERROR() { echo -e "\033[1;31m[Error]$1\033[0m" ; }
-
-# ---------------config start---------------------
-model_run_path=/path/to/model/run
-run_cmd="bash run.sh"
-flame_graph_path=/home/FlameGraph
-# ---------------config end---------------------
-
-cd "${model_run_path}"
-rm -rf perf*
-
-#---- perf cpu-clock on all workers and build flame graph------------
-perf record -F 99 -a -g -- ${run_cmd}
-wait $!
-
-perf script -i perf.data | \
-    "${flame_graph_path}"/stackcollapse-perf.pl | \
-    "${flame_graph_path}"/flamegraph.pl > perf_mxRec.svg
-wait $!
-
-LOG_INFO "perf_mxRec.svg is created, please check!"
-
-
diff --git a/tools/python/key_2_emb_formatter.py b/tools/python/key_2_emb_formatter.py
deleted file mode 100644
index 0c838b29..00000000
--- a/tools/python/key_2_emb_formatter.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
-# Description:
-# Author: MindX SDK
-# Create: 2023-01-29
-
-import argparse
-import json
-import logging
-import os
-import re
-import numpy as np
-
-MIN_SIZE = 1
-MAX_SIZE = 1024 * 1024 * 1024 * 1024
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--path', type=str, required=True, help='path of the root dir of saved file')
-parser.add_argument('--name', type=str, default="key_2_embedding", help='name of output file')
-parser.add_argument('--ddr', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=False,
-                    help='whether the saved data came from DDR mode, default False')
-parser.add_argument('--step', type=int, default=0, help='the step when the data was saved, default 0')
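-# Usage sketch (the path and step are illustrative; with the parser above,
-# --ddr accepts 1/true/yes):
-#   python3 key_2_emb_formatter.py --path ./saved-model --name key_2_embedding --ddr 1 --step 0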
self._set_upper_dir(sub_dir, self._device_dir_list) - host_dir = self._set_upper_dir(sub_dir, self._host_dir_list) - emb_data = self._data_process(dev_dir, host_dir, table_name) - key, offset = self._hashmap_process(dev_dir, host_dir, table_name) - emb_data = emb_data[offset] - - if combined_key is not None: - combined_key = np.append(combined_key, key, axis=0) - else: - combined_key = key - if combined_emb is not None: - combined_emb = np.append(combined_emb, emb_data, axis=0) - else: - combined_emb = emb_data - - logging.debug(f"{table_name} has combined key {combined_key.shape}" - f" and combined emb {combined_emb.shape}") - - transformed_data.append(table_name) - transformed_data.append(combined_key) - transformed_data.append(combined_emb) - - np.save("./" + self._out_file_name + self._out_file_suffix, transformed_data) - - def _data_process(self, dev_dir, host_dir, table_name): - dev_emb_dir = os.path.join(dev_dir, table_name, self._device_emb_dir) - host_emb_dir = os.path.join(host_dir, table_name, self._host_emb_dir) - - data_file, attribute_file = self._get_file_names(dev_emb_dir) - dev_attribute = self._get_attribute(dev_emb_dir, attribute_file, is_json=True) - if not self._data_dtype: - self._data_dtype = dev_attribute.pop(self._json_attrib_dtype) - - dev_data_shape = dev_attribute.pop(self._json_attrib_shape) - emb_data = self._get_data(dev_emb_dir, data_file, self._data_dtype, dev_data_shape) - - if self._is_ddr_mode: - data_file, attribute_file = self._get_file_names(host_emb_dir) - host_attribute = self._get_attribute(host_emb_dir, attribute_file, is_json=False) - host_data_shape = [host_attribute[0], host_attribute[1]] - host_data = self._get_data(host_emb_dir, data_file, self._data_dtype, host_data_shape) - host_data = host_data[:, :dev_data_shape[1]] - emb_data = np.append(emb_data, host_data, axis=0) - - return emb_data - - def _hashmap_process(self, dev_dir, host_dir, table_name): - dev_hashmap_dir = os.path.join(dev_dir, table_name, self._device_hashmap_dir) - host_hashmap_dir = os.path.join(host_dir, table_name, self._host_hashmap_dir) - - if self._is_ddr_mode: - data_file, attribute_file = self._get_file_names(host_hashmap_dir) - attribute = self._get_attribute(host_hashmap_dir, attribute_file, is_json=False) - data_shape = attribute[:2] - raw_hashmap = self._get_data(host_hashmap_dir, data_file, self._hashmap_dtype, data_shape) - else: - data_file, attribute_file = self._get_file_names(dev_hashmap_dir) - attribute = self._get_attribute(dev_hashmap_dir, attribute_file, is_json=False) - data_shape = attribute[:2] - raw_hashmap = self._get_data(dev_hashmap_dir, data_file, self._hashmap_dtype, data_shape) - - offset = raw_hashmap[:, -1] - raw_key = raw_hashmap[:, :2].astype(self._raw_key_dtype) - key = raw_key[:, 0] * self._raw_key_offset + raw_key[:, 1] - key = key.astype(self._key_dtype) - - return key, offset - - def _get_sub_dirs(self, step): - sub_dirs = [] - for _, sub_dir, _ in os.walk(self._saved_file_path): - sub_dirs.append(sub_dir) - - if not sub_dirs or not sub_dirs[0]: - raise FileNotFoundError(f"There is no sparse checkpoint for given root directory.") - - picked_sub_dirs = [] - for sub_dir in sub_dirs[0]: - if int(sub_dir.split("-")[-1]) == step: - picked_sub_dirs.append(sub_dir) - - if not picked_sub_dirs: - raise FileNotFoundError(f"There is no sparse checkpoint for given training step {step}.") - return picked_sub_dirs - - def _set_upper_dir(self, sub_dir, dir_list): - temp_dir = os.path.join(self._saved_file_path, sub_dir) - for directory in dir_list: 
- temp_dir = os.path.join(temp_dir, directory) - return temp_dir - - def _get_table_names(self, directory): - if os.path.exists(directory): - table_names = [] - for _, table_name, _ in os.walk(directory): - table_names.append(table_name) - return table_names[0] - else: - raise ValueError("given directory does not contain required subdirectories, cannot search for table names") - - def _get_file_names(self, directory): - files = [] - data_file = None - attribute_file = None - for _, _, file in os.walk(directory): - files.append(file) - for file in files[0]: - if file.find(self._data_suffix) != -1: - data_file = file - elif file.find(self._attrib_suffix) != -1: - attribute_file = file - return data_file, attribute_file - - def _get_attribute(self, directory, file_name, is_json): - file_dir = os.path.join(directory, file_name) - if is_json: - with open(file_dir, "r") as fin: - # check file whether is valid - file_info = os.stat(fin.fileno()) - if file_info.st_size < MIN_SIZE or file_info.st_size > MAX_SIZE: - raise ValueError(f"file size {file_info.st_size} is not in range[{MIN_SIZE}, {MAX_SIZE}]") - if os.path.islink(file_dir): - raise ValueError(f"file dir {file_dir} is soft link or relative path") - attributes = json.load(fin) - return attributes - else: - attributes = np.fromfile(file_dir, self._host_attrib_dtype) - return attributes - - def _get_data(self, directory, file_name, dtype, shape): - file_dir = os.path.join(directory, file_name) - data = np.fromfile(file_dir, dtype=dtype) - data = data.reshape(shape) - return data - - -if __name__ == "__main__": - args = parser.parse_args() - formatter = Formatter(saved_file_path=args.path, out_file_name=args.name, is_ddr_mode=args.ddr, step=args.step) - formatter.process() -- Gitee From af3a0737946358abf7faddcf2f791c0c4248ac1c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 11 Aug 2023 15:20:14 +0800 Subject: [PATCH 252/551] Match-id-dae96f37d7e4ad9472cbe4e8fbec76f8c311c18f --- build/build.sh | 11 +++++++++++ build/build_tf1.sh | 3 ++- build/build_tf2.sh | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/build/build.sh b/build/build.sh index ab676ad2..2276e117 100644 --- a/build/build.sh +++ b/build/build.sh @@ -90,6 +90,17 @@ then gen_tar_file echo "-----Build gen tar finished-----" + clean + echo "-----Done-----" +fi + +if [ "$(uname -m)" = "aarch64" ] +then + echo "-----Build gen tar -----" + bash ${ROOT_DIR}/build/build_tf2.sh + gen_tar_file + echo "-----Build gen tar finished-----" + clean echo "-----Done-----" fi \ No newline at end of file diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 06317e52..3b8547b5 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -60,6 +60,8 @@ acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c cd "${ROOT_DIR}" +release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz + compile_securec() { if [[ ! 
-d "${ROOT_DIR}"/platform/securec ]]; then @@ -116,7 +118,6 @@ gen_tar_file() { cd "${src_path}" mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" - mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { diff --git a/build/build_tf2.sh b/build/build_tf2.sh index af8318e4..0765dfca 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -20,6 +20,13 @@ then deactivate tf2_env fi +if [ "$(uname -m)" = "aarch64" ] +then + source /opt/buildtools/tf2_env/bin/activate + tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow + deactivate tf2_env +fi + VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml get_version() { if [ -f "$VERSION_FILE" ]; then @@ -114,6 +121,20 @@ gen_wheel_file() remove "${ROOT_DIR}"/mx_rec/libasc } +gen_tar_file() +{ + cd "${src_path}" + mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" + cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" + cd ../build + tar -zvcf "${release_tar}" "${pkg_dir}" || { + warn "compression failed, packages might be broken" + } + + mv "${release_tar}" "${SCRIPT_DIR}"/../output/ + +} + if [ "$(uname -m)" = "x86_64" ] then compile_securec @@ -130,4 +151,23 @@ then deactivate tf2_env echo "-----Build tf2 finished -----" +fi + +if [ "$(uname -m)" = "aarch64" ] +then + compile_securec + + echo "-----Build AccCTR -----" + compile_acc_ctr_so_file + + echo "-----Build Start tf2 -----" + source /opt/buildtools/tf2_env/bin/activate + compile_so_file "${tf2_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf2_whl + + deactivate tf2_env + echo "-----Build tf2 finished -----" + gen_tar_file + echo "-----Build gen tar finished-----" fi \ No newline at end of file -- Gitee From accce890cb9034593b554902da74b3414f6201e4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 11 Aug 2023 18:07:11 +0800 Subject: [PATCH 253/551] Match-id-ebfa78c7d778857ef94111c778cbf50080a49b23 --- mx_rec/graph/patch.py | 1 - mx_rec/graph/utils.py | 1 - mx_rec/optimizers/base.py | 2 +- mx_rec/util/communication/hccl_mgmt.py | 8 ++++---- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 3c3bca3c..3379c361 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -24,7 +24,6 @@ def init_dataset(self, input_data): """ input_data: A DT_VARIANT tensor that represents the dataset. 
""" - # pylint: disable=W tf.compat.v1.add_to_collection("dataset_group", self) self._variant_tensor_attr = input_data # get obj diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py index 3b5f8148..9d867fd8 100644 --- a/mx_rec/graph/utils.py +++ b/mx_rec/graph/utils.py @@ -54,7 +54,6 @@ def record_ops_to_replace(src_op): def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list): - # pylint: disable=W0212 if len(replacement_specs) != len(new_tensor_list): raise ValueError("Given replacement_specs and new_tensor_list must have the same length.") diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index 632e8ba3..ad654cd9 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -43,7 +43,7 @@ class CustomizedOptimizer: def custom_update_op(self, opt, grad): if isinstance(grad, ops.Tensor): - update_op = opt._apply_sparse(grad, self._v) # pylint: disable=protected-access + update_op = opt._apply_sparse(grad, self._v) return update_op else: raise RuntimeError("Only support g with type Tensor.") diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index ccb31eff..8fc8f61f 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -87,14 +87,14 @@ def set_hccl_info_without_json(): try: device_id = mxrec_pybind.get_logic_id(int(device_idx)) - if device_id > MAX_DEVICE_ID: - raise ValueError(f"get logic id from physic id fail.") - index = sorted_device_list.index(device_idx) - rank_to_device_dict[index + 1] = device_id except RuntimeError as exp: raise RuntimeError(f"get logic id from physic id fail. Possible reasons: 1) running user permission " f"is not enough to call dsmi api 2) driver has been used by other process") from \ exp + if device_id > MAX_DEVICE_ID: + raise ValueError(f"get logic id from physic id fail.") + index = sorted_device_list.index(device_idx) + rank_to_device_dict[index + 1] = device_id return rank_to_device_dict -- Gitee From 6a0221bc6eab176a7efb0970a2ff625b4a599d66 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 11 Aug 2023 16:23:50 +0800 Subject: [PATCH 254/551] Match-id-69c42b65f7fe4a68b489adf0f43ea0e8fbb6fb44 --- mx_rec/core/asc/manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 832bf53d..50d69df4 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -203,11 +203,12 @@ def initialize_emb_cache(table_info_list, threshold_list): rank_info = RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, [train_steps, eval_steps]) emb_cache = HybridMgmt() - if threshold_list: - emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load, + + is_initialized = emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load, threshold_values=threshold_list) - else: - emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load) + if is_initialized is False: + logging.error("Failed to init emb_cache!") + raise RuntimeError("emb_cache has not been initialized successfully.") set_asc_manager(emb_cache) logging.info("Preprocessing has been sunk into the host pipeline.") -- Gitee From 1538ac925bbbc63552aa5d2ebea7b893cc94d21c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 14 Aug 2023 14:27:30 +0800 Subject: [PATCH 255/551] Match-id-955371774af2925e254df72eaffe021b1786f7b6 --- mx_rec/constants/constants.py | 3 + mx_rec/core/asc/build_graph.py | 41 
++++++++ mx_rec/core/embedding.py | 145 +++------------------------ mx_rec/graph/modifier.py | 7 +- src/core/hd_transfer/hd_transfer.cpp | 3 + src/core/hd_transfer/hd_transfer.h | 6 ++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 38 ++++++- src/core/key_process/key_process.cpp | 62 +++++++++--- src/core/key_process/key_process.h | 4 + src/core/utils/common.cpp | 1 + src/core/utils/common.h | 2 + src/platform/AccCTR | 2 +- 12 files changed, 162 insertions(+), 152 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 611457e5..92c55790 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -96,6 +96,9 @@ class ASCAnchorAttr(Enum): ALL2ALL_MATRIX = "all2all_matrix" HOT_POS = "hot_pos" LOOKUP_RESULT = "lookup_result" + RESTORE_VECTOR_SECOND = "restore_vector_second" + UNIQUE_KEYS = "unique_keys" + GRADIENTS_STRATEGY = "gradients_strategy" class MxRecMode(BaseEnum): diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index c25eb6be..4b85317a 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -9,6 +9,7 @@ import tensorflow as tf import mxrec_pybind from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.constants.constants import ApplyGradientsStrategy def get_restore_vector(config): @@ -76,6 +77,38 @@ def get_id_offsets(max_lookup_vec_size, config): return id_offsets, swap_pos, swap_len +def get_restore_vector_second(max_lookup_vec_size: int, config: dict) -> tf.Tensor: + """ + Get restore vector which is calculated after the second all2all + :param max_lookup_vec_size: the size of restore_vector_second + :param config: embedding config + :return: the restore vector calculated after the second all2all + """ + logging.debug(f'Channel {config.get("table_name")}_restore_second_{config.get("channel_id")} was built for getnext') + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + restore_vector_second = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[[max_lookup_vec_size]], + channel_name=f'{config.get("table_name")}_restore_second_{config.get("channel_id")}')[0] + return restore_vector_second + + +def get_unique_keys(max_lookup_vec_size: int, config: dict) -> tf.Tensor: + """ + Get the global unique keys which is calculated after the second all2all + :param max_lookup_vec_size: the size of global unique keys + :param config: embedding config + :return: the global unique keys calculated after the second all2all + """ + logging.debug(f'Channel {config.get("table_name")}_uniquekeys_{config.get("channel_id")} was built for getnext') + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + unique_keys = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[[max_lookup_vec_size]], + channel_name=f'{config.get("table_name")}_uniquekeys_{config.get("channel_id")}')[0] + return unique_keys + + def get_all2all_args(use_static: bool, config: dict) -> list: """ Get all2all parameters for dynamic condition @@ -139,6 +172,7 @@ def get_preprocessed_tensor_for_asc(table, config): h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) for i in range(len(table))] + result = { 'restore_vector': restore_vector, 'hot_pos': hot_pos, @@ -146,4 +180,11 @@ def get_preprocessed_tensor_for_asc(table, config): 'swap_in': swap_in, 'all2all_args': 
all2all_args, } + + if config.get("gradients_strategy") == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + with tf.compat.v1.variable_scope("restore_vector_second"): + restore_vector_second = get_restore_vector_second(max_lookup_vec_size, config) + with tf.compat.v1.variable_scope("unique_keys"): + unique_keys = get_unique_keys(max_lookup_vec_size, config) + result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) return result diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index dc599fa8..115848c4 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -85,7 +85,6 @@ def create_table(**kwargs): def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs): """ - Args: hashtable: SparseEmbedding instance to be looked up ids: Tensor to lookup from hashtable @@ -179,8 +178,7 @@ class SparseEmbedding: self.modify_graph = False self.init_param = config.get("init_param") self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op")) - self.apply_gradients_strategy = ApplyGradientsStrategy.mapping( - config.get("apply_gradients_strategy")) + self.apply_gradients_strategy = ApplyGradientsStrategy.mapping(config.get("apply_gradients_strategy")) self.set_slice_vocab_size() self.set_emb_size() @@ -358,6 +356,8 @@ class SparseEmbedding: SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train") SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.GRADIENTS_STRATEGY] = \ + self.apply_gradients_strategy def check_mode(self, method_mode): if self.mode != method_mode: @@ -467,10 +467,6 @@ class SparseEmbedding: anchor_ids = tf.identity(ids, name="ids") tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids) self.register_anchor_attribute(anchor_ids, feature_spec, kwargs) - - use_dynamic_expansion = get_use_dynamic_expansion() - use_static = get_use_static() - use_hot = get_use_hot() eval_mode = not is_training and get_training_mode_channel_id(is_training) is None ids_lookup_name = feature_spec.name + "_lookup_ids" # set in train mode, train and eval mode, eval mode @@ -478,126 +474,9 @@ class SparseEmbedding: self.lookup_name_list.append(ids_lookup_name) self.modify_graph = kwargs.get("modify_graph", True) self.check_multi_lookup_times() - logging.debug(f"In lookup_for_asc function, table name: {self.table_name}, anchor_ids: {anchor_ids}, " - f"ids_lookup_name: {ids_lookup_name}, use_dynamic_expansion: {use_dynamic_expansion}, " - f"use_static: {use_static}, use_hot: {use_hot}") - - rank_size = get_rank_size() - id_offsets = tf.ones(shape=[send_count * rank_size if use_static else 1 * rank_size, ], - dtype=tf.int64 if get_use_dynamic_expansion() else tf.int32, name="id_offsets") - id_offsets = tf.identity(id_offsets, name=ASCAnchorAttr.ID_OFFSETS.value) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.ID_OFFSETS] = id_offsets - local_embeddings = None - if use_dynamic_expansion: - local_embeddings = get_host_pipeline_ops().embedding_lookup_by_address(id_offsets, - embedding_dim=self.emb_size, - embedding_type=1) - - is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ - ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name - if is_training and use_dynamic_expansion and is_table_name_valid: - 
tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) - logging.debug(f"modify graph, table_name: {self.table_name}, contain: {ASCEND_TABLE_NAME_MUST_CONTAIN}") - - @tf.custom_gradient - def sparse_forward(table, feat_ids): - logging.debug(f"fp rank size: {rank_size}") - if feat_ids.shape.as_list()[0] is not None: - restore_vector = tf.ones(shape=[np.prod(feat_ids.shape.as_list()), ], dtype=tf.int32, - name="restore_vector") - else: - restore_vector = tf.ones(shape=[tf.math.reduce_prod(array_ops.shape(feat_ids)[0]), ], dtype=tf.int32, - name="restore_vector") - - restore_vector = tf.identity(restore_vector, name=ASCAnchorAttr.RESTORE_VECTOR.value) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, restore_vector) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.RESTORE_VECTOR] = restore_vector - - all2all_matrix = None - if not use_static: - # In the case of multiple lookups of a table, the all2all_matrix does not run the 'getnext' op - # to obtain the actual value. Instead, the initial value is 1. So it needs to be multiplied by - # 'self.scalar_emb_size' to ensure the correctness of the 'Reshape' op in the _get_own_emb function. - all2all_matrix = tf.ones(shape=[rank_size, rank_size], - dtype=tf.int64, name="all2all_matrix") * self.scalar_emb_size - all2all_matrix = tf.identity(all2all_matrix, name=ASCAnchorAttr.ALL2ALL_MATRIX.value) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, all2all_matrix) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.ALL2ALL_MATRIX] = all2all_matrix - - hot_pos = None - if use_hot: - import mxrec_pybind - emb_size = self.scalar_emb_size if self.skip_emb_transfer else self.ext_emb_size - if emb_size == 0: - raise ValueError("emb_size is 0, please set a valid value.") - hot_size = int(mxrec_pybind.get_ub_hot_size(get_device_id()) / emb_size) - hot_pos = tf.ones(shape=[hot_size, ], dtype=tf.int32, name="hot_pos") - hot_pos = tf.identity(hot_pos, name=ASCAnchorAttr.HOT_POS.value) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_HOT_POS, hot_pos) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.HOT_POS] = hot_pos - - if not use_dynamic_expansion: - id_offsets_abs = tf.abs(id_offsets) - local_emb = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets") - local_emb = set_zero_for_non_valid_key(id_offsets, local_emb, feature_spec.access_threshold) - else: - local_emb = tf.identity(table, name="identity_local_emb") - all2all_args = send_count if use_static else all2all_matrix - unique_embeddings = self._get_own_emb(local_emb, all2all_args, self.scalar_emb_size, use_static) - - if hot_pos is not None: - unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), - unique_embeddings], axis=0) - - embeddings = tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector") - if use_static: - lookup_result = tf.reshape(embeddings, feat_ids.shape.as_list() + [self.scalar_emb_size]) - else: - dest_shape = array_ops.concat([array_ops.shape(feat_ids), [self.scalar_emb_size]], 0) - lookup_result = array_ops.reshape(embeddings, dest_shape) - - # In the case of multiple lookups of a table, the lookup result node needs to be recorded and - # replaced during modify graph. 
- lookup_result = tf.identity(lookup_result, name=ASCAnchorAttr.LOOKUP_RESULT.value) - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT, lookup_result) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.LOOKUP_RESULT] = lookup_result - - def grad(lookup_diff): - embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - unique_embed_shape = unique_embeddings.shape.as_list() if use_static else tf.shape(unique_embeddings) - unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, - unique_embed_shape[0]) - bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) - if hot_pos is not None: - hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], - tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) - unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) - local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) - if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: - try: - local_grad = local_grad / get_rank_size() - except ZeroDivisionError as exp: - raise ZeroDivisionError("Rank size cannot be zero.") from exp - - if use_dynamic_expansion: - return local_grad, feat_ids - - if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: - unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) - unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) - return ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, - dense_shape=tf.shape(table)), feat_ids - - return ops.IndexedSlices(values=local_grad, indices=id_offsets, dense_shape=tf.shape(table)), feat_ids - - return lookup_result, grad + kwargs["ids"] = ids - if use_dynamic_expansion: - return sparse_forward(local_embeddings, ids) - return sparse_forward(self.variable, ids) + return self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) def lookup_for_asc_with_feature_spec(self, feature_spec: FeatureSpec, send_count: int, **kwargs): """ @@ -729,7 +608,7 @@ class SparseEmbedding: rank_size=rank_size, channel_id=channel_id, table_name=self.table_name, skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size, emb_size=self.emb_size, use_hot=use_hot, device_id=device_id, - use_dynamic_expansion=use_dynamic_expansion) + use_dynamic_expansion=use_dynamic_expansion, gradients_strategy=self.apply_gradients_strategy) if self.skip_emb_transfer: result = get_preprocessed_tensor_for_asc(self.variable, config) @@ -737,8 +616,10 @@ class SparseEmbedding: variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list] result = get_preprocessed_tensor_for_asc(variable_list, config) restore_vector = result.get("restore_vector") + restore_vector_second = result.get("restore_vector_second") hot_pos = result.get("hot_pos") id_offsets = result.get("id_offsets") + unique_keys = result.get("unique_keys") swap_in = result.get("swap_in") all2all_matrix = result.get("all2all_args") control_ops = swap_in @@ -778,7 +659,8 @@ class SparseEmbedding: if kwargs.get("multi_lookup"): lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) else: - tensor = kwargs.get("batch").get(feature_spec.index_key) + tensor = kwargs.get("batch").get(feature_spec.index_key) \ + if not self.modify_graph else kwargs.get("ids") if tensor is None: raise 
KeyError(f"index_key '{feature_spec.index_key}' does not exist in batch.") dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0) @@ -805,11 +687,10 @@ class SparseEmbedding: update_grad = local_grad else: if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: - unique_id_offsets, unique_id_offsets_position = array_ops.unique(id_offsets) unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - unique_id_offsets_position, - array_ops.shape(unique_id_offsets)[0]) - update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_id_offsets, + restore_vector_second, + array_ops.shape(unique_keys)[0]) + update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_keys, dense_shape=tf.shape(table)) else: diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 4095d750..49743f26 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -16,7 +16,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec, set_temporary_feature_spec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ - ASCAnchorAttr, ASCEND_TIMESTAMP + ASCAnchorAttr, ASCEND_TIMESTAMP, ApplyGradientsStrategy from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \ @@ -348,7 +348,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): # multiple lookups of a same table lookup_for_same_table(sub_cutting_point_list, is_training) # replace the stub node for sparse lookup from the graph - replace_stub_node_with_asc_graph(sub_cutting_point_list, is_training) + logging.info("Graph has been revised.") export_pb_graph("new_graph.pb", dump_graph) @@ -437,7 +437,8 @@ def replace_stub_node_with_asc_graph(sub_cutting_point_list: list, is_training: send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.scalar_emb_size, - use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion()) + use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion(), + is_training=is_training) lookup_result = None if len(table_instance.lookup_name_list) > 1 and feature_spec.name in table_instance.lookup_result and \ diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index c748507e..8cc92597 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -59,6 +59,7 @@ void HDTransfer::Destroy() running = false; LOG(INFO) << (HD + "destroy channel start"); for (auto& c: transferChannels) { + LOG(INFO) << StringFormat(HD + "start destroy channel:%s", c.first.c_str()); tensorflow::StopRecvTensorByAcl(&c.second, c.first); LOG(INFO) << StringFormat(HD + "destroy channel:%s", c.first.c_str()); } @@ -101,6 +102,8 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName if (TransferChannel2Str(channel) == "all2all" || TransferChannel2Str(channel) == "restore" || 
TransferChannel2Str(channel) == "lookup" || + TransferChannel2Str(channel) == "restore_second" || + TransferChannel2Str(channel) == "uniquekeys" || TransferChannel2Str(channel) == "evict" /* for noDDR */ ) { transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), channelSize); diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index fea3c43e..d0ddaa73 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -30,7 +30,9 @@ namespace MxRec { enum class TransferChannel { D2H, RESTORE, + RESTORE_SECOND, ALL2ALL, + UNIQKEYS, LOOKUP, EVICT, H2D, @@ -41,12 +43,16 @@ namespace MxRec { inline string TransferChannel2Str(TransferChannel e) { switch (e) { + case TransferChannel::RESTORE_SECOND: + return "restore_second"; case TransferChannel::D2H: return "d2h"; case TransferChannel::RESTORE: return "restore"; case TransferChannel::ALL2ALL: return "all2all"; + case TransferChannel::UNIQKEYS: + return "uniquekeys"; case TransferChannel::LOOKUP: return "lookup"; case TransferChannel::EVICT: diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index cd400a52..0ba105ac 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -17,6 +17,12 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& const vector& thresholdValues, int seed) { #ifndef GTEST + if (getenv("APPLY_GRADIENTS_STRATEGY") != nullptr) { + bool strategy = (!strcmp(getenv("APPLY_GRADIENTS_STRATEGY"), SUM_SAME_ID)); + PerfConfig::gradientStrategy = strategy; + LOG(INFO) << StringFormat("config GRADIENTS_STRATEGY:%d", strategy); + } + if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { @@ -700,6 +706,18 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) infoVecs->pop_back(); VLOG(GLOG_DEBUG) << StringFormat("sendLookupSyncTC(ms):%d", sendLookupSyncTC.ElapsedMS()); + if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID) { + TimeCost sendUnikeysSyncTC; + hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name); + infoVecs->pop_back(); + VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); + + TimeCost sendRestoreVecSecSyncTC; + hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embInfo.name); + infoVecs->pop_back(); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); + } + TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); @@ -772,12 +790,26 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, remainBatchOut = false; } - auto restore = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); + auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); VLOG(GLOG_DEBUG) << StringFormat("getTensorsTC(ms):%d", getTensorsTC.ElapsedMS()); - hdTransfer->Send(TransferChannel::RESTORE, *restore, channelId, embName); - vector tmpData; + if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { + TimeCost sendUnikeysSyncTC; + hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embName); + infoVecs->pop_back(); + 
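Both HBM paths in this hunk peel the extra tensors off the back of the per-batch vector: unique_keys was appended last by the key-process side, so it is sent first over the uniquekeys channel, then restore_second, and whatever remains goes down the restore channel. An illustrative Python model of that back-of-vector handshake, with a stub send function standing in for the real tdt transfer (channel names follow the hd_transfer strings; the noDDR lookup channel is omitted for brevity):

def send_processed_batch(send, info_vecs, gradient_strategy, is_train_channel):
    # info_vecs was appended in order: [restore parts ..., restore_second, unique_keys],
    # so popping from the back yields unique_keys first, then restore_second.
    if gradient_strategy and is_train_channel:
        send("uniquekeys", info_vecs.pop())       # mirrors infoVecs->back() + pop_back()
        send("restore_second", info_vecs.pop())
    send("restore", info_vecs)                    # everything that remains

sent = []
send_processed_batch(lambda name, data: sent.append((name, data)),
                     ["restore_vec", "restore_second_vec", "unique_keys_vec"],
                     gradient_strategy=True, is_train_channel=True)
assert sent[0] == ("uniquekeys", "unique_keys_vec")
assert sent[1] == ("restore_second", "restore_second_vec")
assert sent[2] == ("restore", ["restore_vec"])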
VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); + + TimeCost sendRestoreVecSecSyncTC; + hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embName); + infoVecs->pop_back(); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); + } + + TimeCost sendRestoreSyncTC; + hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); + vector tmpData; TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS()); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index d2c7889d..a67f13d6 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -367,13 +367,9 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat uniqueInfo.hotPos.resize(hotEmbTotCount[batch->name], -1); tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); } - if (rankInfo.noDDR) { - if (rankInfo.useDynamicExpansion) { - tensors->push_back(Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv)); - } else { - tensors->push_back(Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); - } - } + + PushGlobalUniqueTensors(move(tensors), uniqueInfo.all2AllInfo.keyRecv, channel); + TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); @@ -429,16 +425,56 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe hotPos.resize(hotEmbTotCount[batch->name], 0); tensors->push_back(Vec2TensorI32(hotPos)); } + + PushGlobalUniqueTensors(tensors, lookupKeys, channel); + + PushResult(batch, move(tensors), lookupKeys); + VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); + return true; +} + +void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel) +{ + if (PerfConfig::gradientStrategy && channel == TRAIN_CHANNEL_ID) { + keys_t uniqueKeys; + vector restoreVecSec; + GlobalUnique(lookupKeys, uniqueKeys, restoreVecSec); + tensors->push_back(Vec2TensorI32(restoreVecSec)); + tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)); + } + if (rankInfo.noDDR) { - if (rankInfo.useDynamicExpansion) { - tensors->push_back(Vec2TensorI64(lookupKeys)); + tensors->push_back(rankInfo.useDynamicExpansion ? 
Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); + } +} + +void KeyProcess::GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector<int32_t>& restoreVecSec) +{ + absl::flat_hash_map<int64_t, int32_t> umap; + restoreVecSec.resize(lookupKeys.size(), -1); + int32_t length = 0; + + for (size_t i = 0; i < lookupKeys.size(); ++i) { + int64_t key = lookupKeys[i]; + if (key == -1) { + continue; + } + auto result = umap.find(key); + if (result == umap.end()) { + uniqueKeys.push_back(lookupKeys[i]); + umap[key] = length; + restoreVecSec[i] = length; + length++; } else { - tensors->push_back(Vec2TensorI32(lookupKeys)); + restoreVecSec[i] = result->second; } } - PushResult(batch, move(tensors), lookupKeys); - VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); - return true; + + if (rankInfo.useStatic) { + uniqueKeys.resize(lookupKeys.size(), -1); + } else { + restoreVecSec.erase(std::remove(restoreVecSec.begin(), restoreVecSec.end(), -1), restoreVecSec.end()); + } } vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index b3170541..2b573831 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -132,6 +132,8 @@ namespace MxRec { void GetUniqueConfig(UniqueConf& uniqueConf); + void GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector<int32_t>& restoreVecSec); + void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, const unique_ptr & batch, UniquePtr& unique); @@ -180,6 +182,8 @@ namespace MxRec { void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys); + void PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel); + void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, const unique_ptr& batch); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 822febea..385d3143 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -24,6 +24,7 @@ namespace MxRec { int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; int PerfConfig::maxUniqueThreadNum = DEFAULT_MAX_UNIQUE_THREAD_NUM; bool PerfConfig::fastUnique = false; + bool PerfConfig::gradientStrategy = false; string g_rankId; int g_glogLevel; bool g_isGlogInit = false; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 60d4a97b..a4f33452 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -58,6 +58,7 @@ namespace MxRec { constexpr int MAX_QUEUE_NUM = MAX_CHANNEL_NUM * MAX_KEY_PROCESS_THREAD; constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; constexpr int KEY_PROCESS_THREAD = 6; + constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; // for GLOG extern int g_glogLevel; @@ -74,6 +75,7 @@ namespace MxRec { static int keyProcessThreadNum; static int maxUniqueThreadNum; static bool fastUnique; + static bool gradientStrategy; }; constexpr int KEY_PROCESS_TIMEOUT = 120; diff --git a/src/platform/AccCTR b/src/platform/AccCTR index 77af967d..62ab674f 160000 --- a/src/platform/AccCTR +++ b/src/platform/AccCTR @@ -1 +1 @@ -Subproject commit 77af967dcc81f4f8f7e2affd015cb760db5e4d9f +Subproject commit 62ab674f0a42d8de8398eafb4799e506fe99549d -- Gitee From 80b9957285147430484e9e2389e6170a19237c35 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 14 Aug 2023 19:07:15 +0800 Subject: [PATCH 256/551] Match-id-911db1c234340f18d563dc8fb13b9bf7a8a45fde --- mx_rec/saver/saver.py | 31 ++++++++++---- 
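The GlobalUnique pass added by the patch above is a first-occurrence dedup over the already-exchanged lookup keys: invalid slots (-1) are skipped, restore_vector_second records the index of each key's first occurrence, and in static-shape mode unique_keys is padded back to the input length with -1. On the device side, summing gradients with unsorted_segment_sum over restore_vector_second then yields one aggregated row per unique key before apply. A NumPy sketch of the pair, offered as a behavioral model rather than the C++ implementation:

import numpy as np

def global_unique(lookup_keys):
    # First-occurrence dedup; -1 marks invalid slots, mirroring the host pass.
    index, unique_keys, restore = {}, [], []
    for key in lookup_keys:
        if key == -1:
            continue
        if key not in index:
            index[key] = len(unique_keys)
            unique_keys.append(key)
        restore.append(index[key])
    return np.array(unique_keys), np.array(restore)

keys = np.array([7, 3, 7, -1, 3, 9])
unique_keys, restore_second = global_unique(keys)    # [7 3 9] and [0 1 0 1 2]
grads = np.ones((5, 4))                              # one gradient row per valid lookup
summed = np.zeros((len(unique_keys), 4))
np.add.at(summed, restore_second, grads)             # unsorted_segment_sum equivalent
# summed now holds one aggregated gradient row per unique key.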
mx_rec/saver/sparse.py | 4 +- mx_rec/util/communication/hccl_mgmt.py | 2 +- mx_rec/util/initialize.py | 2 +- mx_rec/validator/validator.py | 10 +++-- src/core/checkpoint/checkpoint.cpp | 59 ++++++++++++++++++++++---- src/core/checkpoint/checkpoint.h | 4 +- src/core/utils/common.h | 4 ++ 8 files changed, 90 insertions(+), 26 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d358dde8..9033b679 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -11,12 +11,12 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.constants.constants import DataName, DataAttr +from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ send_host_data, get_ascend_global_hashtable_collection from mx_rec.util.perf import performance -from mx_rec.validator.validator import DirectoryValidator +from mx_rec.validator.validator import DirectoryValidator, FileValidator class Saver(object): @@ -374,19 +374,21 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} does not exist when reading.") with tf.io.gfile.GFile(target_attribute_dir, "r") as fin: + validate_read_file(target_attribute_dir) attributes = json.load(fin) if DataAttr.DATATYPE.value not in attributes: raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.") - if target_data_dir.find("://") != -1: - logging.debug(f"use hdfs path {target_data_dir} to restore sparse data.") - with tf.io.gfile.GFile(target_data_dir, "r") as file: + with tf.io.gfile.GFile(target_data_dir, "r") as file: + validate_read_file(target_data_dir) + if target_data_dir.find("://") != -1: + logging.debug("use hdfs path %s to restore sparse data.", target_data_dir) data_to_restore = file.read() data_to_restore = np.array(json.loads(data_to_restore)) - else: - logging.debug(f"use local file path {target_data_dir} to restore sparse data.") - data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) + else: + logging.debug("use local file path %s to restore sparse data.", target_data_dir) + data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) @@ -403,6 +405,19 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: return data_dict +def validate_read_file(read_file_path): + """ + Validate file before reading, including validating soft link and file size + :param read_file_path: the file path to be validated + """ + file_validator = FileValidator(read_file_path) + file_validator.check_file_size(MAX_SIZE, MIN_SIZE) + # local files need the soft link check + if read_file_path.find("://") == -1: + file_validator.check_not_soft_link() + file_validator.check() + + def process_embedding_data(data_to_restore: np.ndarray, current_data_shape: list, data_shape: list) -> np.ndarray: """ Process embedding data when reading binary file diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index fbc54d4c..88f5316c 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -52,7 +52,7 @@ class SparseProcessor: # 1.check whether data_dir is soft link 
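The validator change in PATCH 256 explains the call-site churn above: check_file_size no longer takes an open file object but stats the path through tf.io.gfile, so one code path covers both local and hdfs:// targets, while the soft-link check stays local-only. A minimal standalone sketch of the two checks for local paths (the bounds and names are illustrative, not the mx_rec API):

import os

MIN_SIZE, MAX_SIZE = 1, 1 << 40   # accept sizes in (MIN_SIZE, MAX_SIZE] bytes

def validate_read_path(path):
    # Reject symlinks first; islink() inspects the link itself, not its target.
    if os.path.islink(path):
        raise ValueError(f"{path} is a soft link, refusing to read it")
    size = os.stat(path).st_size
    if not MIN_SIZE < size <= MAX_SIZE:
        raise ValueError(f"file size {size} is not in ({MIN_SIZE}, {MAX_SIZE}]")
    return path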
file_validator.check_not_soft_link() # 2.check data file size - file_validator.check_file_size(file) + file_validator.check_file_size() file_validator.check() try: @@ -72,7 +72,7 @@ class SparseProcessor: # 1.check whether attribute_dir is soft link file_validator.check_not_soft_link() # 2.check attribute file size - file_validator.check_file_size(fin) + file_validator.check_file_size() file_validator.check() attributes = json.load(fin) except FileNotFoundError as err: diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 8fc8f61f..866c6682 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -20,7 +20,7 @@ def parse_hccl_json(): # 1.check whether rank_table_path is soft link file_validator.check_not_soft_link() # 2.check json file size - file_validator.check_file_size(file, MAX_CONFIG_SIZE, MIN_SIZE) + file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() table_hccl = json.load(file) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index c5971b8f..048af669 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -719,7 +719,7 @@ def get_available_cpu_num_and_range(): # 1.check whether f_path is soft link file_validator.check_not_soft_link() # 2.check file size - file_validator.check_file_size(f_in, MAX_CONFIG_SIZE, MIN_SIZE) + file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() pkg_id = f_in.readline().strip() pkg_id2cpu_list[pkg_id].append(cpu) diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index a2eaede4..5ce9a26a 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -8,6 +8,8 @@ import re from typing import Callable, Any from typing import List, Optional, Tuple +import tensorflow as tf + from mx_rec.constants.constants import MIN_SIZE from mx_rec.constants.constants import MAX_SIZE from mx_rec.constants.constants import MAX_DEVICE_NUM @@ -263,10 +265,10 @@ class FileValidator(StringValidator): super().__init__(value) self.register_checker(lambda x: isinstance(x, str), "parameter value's type is not str") - def check_file_size(self, file_obj, max_size=MAX_SIZE, min_size=MIN_SIZE): - file_info = os.stat(file_obj.fileno()) - self.register_checker(lambda path: min_size < file_info.st_size <= max_size, - f"file size: {file_info.st_size} is invalid, not in [{min_size}, {max_size}]") + def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE): + file_stat = tf.io.gfile.stat(self.value) + self.register_checker(lambda path: min_size < file_stat.length <= max_size, + f"file size: {file_stat.length} is invalid, not in ({min_size}, {max_size}]") return self def check_not_soft_link(self): diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index dfedfe5c..73b54255 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -9,12 +9,14 @@ #include #include #include +#include #include "ckpt_data_handler//emb_hash_ckpt/emb_hash_ckpt.h" #include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" #include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h" #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" #include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" +#include "utils/time_cost.h" #include "checkpoint.h" @@ -233,18 +235,25 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) { std::ifstream readFile; 
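The Python saver and this C++ reader agree on one on-disk contract: each tensor is a raw .data blob accompanied by a small .attribute file describing it, JSON with data_type and shape on the Python side and a binary attribute array on the host side. A short sketch of the JSON-described decode in Python, assuming the attribute keys shown in the saver hunk; the paths are placeholders:

import json
import numpy as np

def load_saved_tensor(data_path, attribute_path):
    # The .attribute side is tiny JSON, e.g. {"data_type": "float32", "shape": [n, dim]}.
    with open(attribute_path) as fin:
        attributes = json.load(fin)
    data = np.fromfile(data_path, dtype=attributes["data_type"])
    return data.reshape(attributes["shape"])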
readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); + size_t datasetSize = static_cast(readFile.tellg()); + readFile.seekg(0, std::ios::beg); + try { + ValidateReadFile(dataDir, datasetSize); + } catch (const std::invalid_argument& e) { + readFile.close(); + throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); + } #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { LOG(ERROR) << StringFormat("Set device failed, device_id:%d", deviceId); + readFile.close(); throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); } auto &AttributeArr = transData.attribute; auto embHashMapSize = AttributeArr.at(0); - size_t datasetSize = readFile.tellg(); - readFile.seekg(0, std::ios::beg); auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize); aclError ret; @@ -252,6 +261,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) ret = aclrtMalloc(&newBlock, static_cast(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); + readFile.close(); throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } @@ -266,6 +276,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); + readFile.close(); throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } @@ -424,17 +435,25 @@ void Checkpoint::ReadStream(CkptTransData& transData, LOG(WARNING) << "dataElmtBytes is 0, don't handle [/ %] operation"; return ; } + std::ifstream readFile; readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - - size_t datasetSize = readFile.tellg(); + size_t datasetSize = static_cast(readFile.tellg()); readFile.seekg(0, std::ios::beg); + try { + ValidateReadFile(dataDir, datasetSize); + } catch (const std::invalid_argument& e) { + readFile.close(); + throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); + } if (datasetSize % dataElmtBytes > 0) { VLOG(GLOG_DEBUG) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); } + auto resizeSize { datasetSize / dataElmtBytes }; SetTransDataSize(transData, resizeSize, dataType); + if (readFile.is_open()) { size_t idx = 0; size_t readSize = 0; @@ -445,7 +464,6 @@ void Checkpoint::ReadStream(CkptTransData& transData, readSize = datasetSize; } ReadDataset(transData, readFile, readSize, dataType, idx); - datasetSize -= readSize; idx += readSize; } @@ -468,7 +486,6 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); - auto loadHostEmbs = ckptData.hostEmbs; auto& dst = (*loadHostEmbs)[embName].embData; dst.reserve(embDataOuterSize); @@ -476,11 +493,19 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, std::ifstream readFile; readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - size_t datasetSize = readFile.tellg(); + size_t datasetSize = static_cast(readFile.tellg()); readFile.seekg(0, std::ios::beg); + try { + ValidateReadFile(dataDir, datasetSize); + } catch (const std::invalid_argument& e) { + readFile.close(); + throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); + } + if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 
0) { LOG(ERROR) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); + readFile.close(); throw runtime_error("unable to load EMB_DATA due to wrong-format saved emb data"); } auto onceReadByteSize { datasetSize / embDataOuterSize }; @@ -500,9 +525,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } else { readSize = dataCol; } - readFile.read((char*)(dst[i].data()) + idx, readSize); - dataCol -= readSize; idx += readSize; } @@ -539,3 +562,21 @@ void Checkpoint::ReadDataset(CkptTransData& transData, readFile.read((char*)(transData.attribute.data()) + idx, readSize); } } + +void Checkpoint::ValidateReadFile(const string& dataDir, size_t datasetSize) +{ + // validate soft link + struct stat fileInfo; + if (lstat(dataDir.c_str(), &fileInfo) != -1) { + if (S_ISLNK(fileInfo.st_mode)) { + LOG(ERROR) << StringFormat("soft link %s should not be in the path parameter", dataDir.c_str()); + throw invalid_argument(StringFormat("soft link should not be the path parameter")); + } + } + // validate file size + if (datasetSize <= FILE_MIN_SIZE || datasetSize > FILE_MAX_SIZE) { + LOG(ERROR) << StringFormat("the reading file size is invalid, " + "not in (%d, %lld]", FILE_MIN_SIZE, FILE_MAX_SIZE); + throw invalid_argument(StringFormat("file size invalid")); + } +} diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 4c3abc53..17c22f17 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -31,7 +31,7 @@ namespace MxRec { const string dataFileType { ".data" }; const string attribFileType { ".attribute" }; const string dirSeparator { "/" }; - const mode_t dirMode { 0755 }; + const mode_t dirMode { 0400 }; const string currDir { "." }; const string prevDir { ".." 
}; @@ -97,6 +97,8 @@ namespace MxRec { void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, size_t idx); + + void ValidateReadFile(const string& dataDir, size_t datasetSize); }; } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index a4f33452..b6c64416 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -71,6 +71,10 @@ namespace MxRec { constexpr int MIN_UNIQUE_THREAD_NUM = 1; constexpr int DEFAULT_MAX_UNIQUE_THREAD_NUM = 8; + // validate file + constexpr long long FILE_MAX_SIZE = 1LL << 40; + constexpr int FILE_MIN_SIZE = 1; + struct PerfConfig { static int keyProcessThreadNum; static int maxUniqueThreadNum; -- Gitee From ddafdc5b45c024f00824ac7a5484cacca8a19fde Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 14 Aug 2023 19:54:41 +0800 Subject: [PATCH 257/551] Match-id-65e29a6e82ebd56dea41089503c3f59f53104e16 --- mx_rec/graph/modifier.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 49743f26..8e771426 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -136,12 +136,12 @@ def get_op_before_optimize_dataset(get_next_op): base_op = find_make_iterator_op(get_next_op.outputs[0]) # looking for the op which is the one before OptimizeDataset operator if tf.__version__.startswith("1"): - optimize_dataset_op = find_target_dataset_op(base_op, "OptimizeDataset") + optimize_dataset_op = find_target_dataset_op(base_op, "ModelDataset") target_op = find_parent_op(optimize_dataset_op) if not target_op: - raise RuntimeError(f"The parent op for 'OptimizeDataset' op was not found.") - if target_op[0].type != "PrefetchDataset": - raise TypeError(f"Op PrefetchDataset was not found.") + raise RuntimeError(f"The parent op for 'ModelDataset' op was not found.") + if target_op[0].type != "OptimizeDataset": + raise TypeError(f"Op OptimizeDataset was not found.") target_op = target_op[0] else: # 'OptimizeDataset' is not available in TensorFlow2.X -- Gitee From e031af37543edeed342db7a3cb4031572b1b1c37 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 14 Aug 2023 20:13:58 +0800 Subject: [PATCH 258/551] Match-id-d47f8cc1be8b97aaffd508930efd8d95654bd3fc --- build/build.sh | 2 +- build/build_all.sh | 2 +- build/build_tf1.sh | 2 +- build/build_tf2.sh | 2 +- setup.py | 2 +- src/CMakeLists.txt | 1 - src/core/utils/common.cpp | 4 ++-- src/core/utils/common.h | 2 +- 8 files changed, 8 insertions(+), 9 deletions(-) diff --git a/build/build.sh b/build/build.sh index 2276e117..5f77ec2c 100644 --- a/build/build.sh +++ b/build/build.sh @@ -23,7 +23,7 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.T104" + VERSION="5.0.rc3" fi } diff --git a/build/build_all.sh b/build/build_all.sh index 8492b989..26719538 100644 --- a/build/build_all.sh +++ b/build/build_all.sh @@ -50,7 +50,7 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.T104" + VERSION="5.0.rc3" fi } diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 3b8547b5..481c4429 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -28,7 +28,7 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.T104" + VERSION="5.0.rc3" fi } diff --git a/build/build_tf2.sh b/build/build_tf2.sh index 0765dfca..e9e7d05c 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -35,7 +35,7 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.T104" + 
VERSION="5.0.rc3" fi } diff --git a/setup.py b/setup.py index d1a975d7..55573ac2 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,7 @@ except IOError: LONG_DESCRIPTION = "" env_version = os.getenv("VERSION") -VERSION = env_version if env_version is not None else '5.0.T104' +VERSION = env_version if env_version is not None else '5.0.rc3' INIT_FILE = "mx_rec/__init__.py" with open(INIT_FILE, 'r') as file: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1a8ad26a..590d7e58 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,7 +95,6 @@ endif() if(IS_DIRECTORY ${OPENSOURCE_DIR}) add_subdirectory(${OPENSOURCE_DIR}/pybind11 pybind11.out) - option(WITH_CUSTOM_PREFIX "use for glog v0.5.0 to enable custom log format" ON) add_subdirectory(${OPENSOURCE_DIR}/glog glog.out) include_directories(glog.out) install(TARGETS glog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 385d3143..5483b6eb 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -109,12 +109,12 @@ namespace MxRec { g_rankId = std::to_string(rank); } if (!g_isGlogInit) { - google::InitGoogleLogging("mxRec", &CustomGlogFormat); + InitGoogleLogging("mxRec", &CustomGlogFormat); g_isGlogInit = true; } } - void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void*) + void CustomGlogFormat(std::ostream &s, const google::LogMessageInfo &l, void*) { s << "[" << setw(GLOG_TIME_WIDTH_2) << l.time.hour() << ':' diff --git a/src/core/utils/common.h b/src/core/utils/common.h index b6c64416..e1c7fc75 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -280,7 +280,7 @@ struct BatchTask { void SetLog(int rank); - void CustomGlogFormat(std::ostream &s, const LogMessageInfo &l, void*); + void CustomGlogFormat(std::ostream &s, const google::LogMessageInfo &l, void*); template string StringFormat(const string& format, Args ... 
args)
-- Gitee

From a4e51c612cbbe02eda6a7b0f9cd32925eea02ef4 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 14 Aug 2023 21:01:53 +0800
Subject: [PATCH 259/551] Match-id-cede0472ec28977796e651175d00a8323b0499eb

---
 mx_rec/__init__.py            |   3 +-
 mx_rec/constants/constants.py |   3 +-
 mx_rec/core/asc/helper.py     |  11 +-
 mx_rec/core/embedding.py      |  99 +++++++++++-------
 mx_rec/graph/merge_lookup.py  |  81 +++++++++++++++
 mx_rec/graph/modifier.py      | 188 ++++------------------------------
 mx_rec/graph/patch.py         |  56 ++++++++--
 mx_rec/graph/utils.py         |  30 +++++-
 mx_rec/util/initialize.py     |  78 +++++++++++++-
 9 files changed, 321 insertions(+), 228 deletions(-)
 create mode 100644 mx_rec/graph/merge_lookup.py

diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py
index a5a58739..f5d02cd0 100644
--- a/mx_rec/__init__.py
+++ b/mx_rec/__init__.py
@@ -8,11 +8,12 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION
 from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook
 from mx_rec.saver.patch import patch_for_saver
 from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \
-    patch_for_end, patch_for_assert_eval_spec
+    patch_for_end, patch_for_assert_eval_spec, patch_for_scale_loss
 from mx_rec.optimizers.base import patch_for_optimizer
 
 patch_for_saver()
 patch_for_dataset()
+patch_for_scale_loss()
 patch_for_chief_session_creator()
 patch_for_assert_eval_spec()
 patch_for_bool_gauge()
diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 92c55790..222862d9 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -31,7 +31,7 @@ LOCAL_RANK_SIZE = "LOCAL_RANK_SIZE"  # NPUs used by this server during training
 MAX_DEVICE_NUM_LOCAL_MACHINE = 16  # maximum number of NPU cards on a single server
 DEFAULT_DEVICE_NUM_LOCAL_MACHINE = 8  # default number of NPU cards on a single server
 
-MULTI_LOOKUP_TIMES = 2048
+MULTI_LOOKUP_TIMES = 128
 DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24
 TRAIN_CHANNEL_ID = 0
 EVAL_CHANNEL_ID = 1
@@ -96,6 +96,7 @@ class ASCAnchorAttr(Enum):
     ALL2ALL_MATRIX = "all2all_matrix"
     HOT_POS = "hot_pos"
     LOOKUP_RESULT = "lookup_result"
+    MOCK_LOOKUP_RESULT = "mock_lookup_result"
     RESTORE_VECTOR_SECOND = "restore_vector_second"
     UNIQUE_KEYS = "unique_keys"
     GRADIENTS_STRATEGY = "gradients_strategy"
diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index 64b78f2c..4963fec6 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -94,7 +94,6 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_
 
     def insert_fn_for_arg_indexes(*args):
         insert_tensors = get_target_tensors_with_args_indexes(args_index_list)
-        # config timestamp later
         logging.debug(f"do_insert without spec for {table_names}")
 
         splits = []
@@ -293,7 +292,6 @@ def get_valid_op_key(batch_dict: dict) -> str:
 def get_target_tensors_with_args_indexes(args_index_list):
     insert_tensors = []
     graph = tf.compat.v1.get_default_graph()
-
    for index in args_index_list:
         tensor = graph.get_tensor_by_name("args_%d:0" % index)
         if tensor.dtype != tf.int64:
@@ -329,6 +327,9 @@ def get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, rea
         else:
             raise ValueError(f"Encounter an invalid batch.")
 
+        # Ensure that the sequence of the `read emb key` op input tensor is the same as that of the split result
+        # of the multi lookup in the same table.
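
A side note on the convention above, since the rest of this patch leans on it: a tensor created with an explicit op name can be fetched back from the TF1 graph as "<op_name>:0", which is how the args_%d and reshape_* tensors are recovered after the dataset is rebuilt. A minimal illustrative sketch, not mxRec code; the names "args_0" and "reshape_user_id" are assumptions:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
ids = tf.compat.v1.placeholder(tf.int64, shape=[None, 4], name="args_0")
flat = tf.reshape(ids, [-1], name="reshape_user_id")

graph = tf.compat.v1.get_default_graph()
recovered = graph.get_tensor_by_name("reshape_user_id:0")
assert recovered is flat  # the graph hands back the identical tensor object
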
+ reshape_name = "reshape_" + feature_spec.name if feature_spec.is_timestamp is None: result = feature_spec.set_feat_attribute(tensor, is_training) tensor = result.get("tensor") @@ -337,15 +338,15 @@ def get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, rea if tensor.dtype != tf.int64: tensor = tf.cast(tensor, dtype=tf.int64) - read_emb_key_inputs_dict["insert_tensors"].append(tf.reshape(tensor, [-1, ])) + read_emb_key_inputs_dict["insert_tensors"].append(tf.reshape(tensor, [-1, ], name=reshape_name)) read_emb_key_inputs_dict["table_names"].append(table_name) read_emb_key_inputs_dict["splits"].append(split) read_emb_key_inputs_dict["feature_spec_names"].append(feature_spec.name) elif feature_spec.is_timestamp: if len(tensor.shape.as_list()) != 0: raise ValueError(f"Given TimeStamp Tensor must be a scalar.") - read_emb_key_inputs_dict["insert_tensors"] = [tf.reshape(tensor, [-1, ])] + \ - read_emb_key_inputs_dict["insert_tensors"] + read_emb_key_inputs_dict["insert_tensors"] = [tf.reshape( + tensor, [-1, ], name=reshape_name)] + read_emb_key_inputs_dict.get("insert_tensors", []) feature_spec.include_timestamp(is_training) elif tensor is not None: raise ValueError(f"Spec timestamp should be true when batch contains timestamp.") diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 115848c4..a506143f 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -8,7 +8,6 @@ import re from collections import defaultdict from typing import Optional -import numpy as np import tensorflow as tf from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -18,11 +17,9 @@ from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_HOT_POS, \ - ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR, MxRecMode, \ - ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \ - MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, \ - ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT, All2allGradientsOp, ApplyGradientsStrategy +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, MxRecMode, \ + ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, \ + MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ @@ -107,10 +104,13 @@ def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs): kwargs["modify_graph"] = kwargs.get("modify_graph", False) if not isinstance(kwargs.get("modify_graph"), bool): - raise TypeError("Given name must be a boolean.") + raise TypeError("Given modify_graph must be a boolean.") if not isinstance(kwargs.get("is_train"), bool): - raise TypeError("Given name must be a boolean.") + raise TypeError("Given is_train must be a boolean.") + + if send_count is not None and not isinstance(send_count, int): + raise TypeError("Given 
send_count must be an int.")
 
 
     def check_table_legality_for_feature_spec(table, feature_spec):
         # check whether the name of the table exists with FeatureSpec.
@@ -126,16 +126,18 @@ def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs):
     check_lookup_kwargs()
     scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
     with tf.compat.v1.variable_scope(scope_name):
-        if hashtable.mode == MxRecMode.ASC:
-            if isinstance(ids, FeatureSpec):
-                check_table_legality_for_feature_spec(hashtable, ids)
-                return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs)
-            else:
-                check_modify_graph()
-                set_modify_graph(True)
-                return hashtable.lookup_for_asc(ids, send_count, **kwargs)
-        else:
-            raise EnvironmentError(f"Invalid MxRec Mode.")
+        if hashtable.mode != MxRecMode.ASC:
+            raise EnvironmentError("Invalid MxRec Mode.")
+        if not isinstance(ids, (FeatureSpec, tf.Tensor)):
+            raise ValueError(f"Invalid ids type, it should be: `FeatureSpec` or `tf.Tensor`, but got `{type(ids)}`.")
+
+        if isinstance(ids, FeatureSpec):
+            check_table_legality_for_feature_spec(hashtable, ids)
+            return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs)
+
+        check_modify_graph()
+        set_modify_graph(True)
+        return hashtable.lookup_for_asc(ids, send_count, **kwargs)
 
 
 class SparseEmbedding:
@@ -365,12 +367,14 @@ class SparseEmbedding:
                              f"for {method_mode} was in use.")
 
     def check_multi_lookup_times(self):
-        if self.modify_graph:
-            self.lookup_result = dict()
-        if len(self.lookup_name_list) > MULTI_LOOKUP_TIMES or len(self.lookup_result) > MULTI_LOOKUP_TIMES:
+        lookup_times = len(self.lookup_name_list) if self.modify_graph else len(self.lookup_result)
+        if not self.modify_graph and get_training_mode_channel_id(True) is not None and \
+                get_training_mode_channel_id(False) is not None:
+            lookup_times = int(lookup_times / 2)
+        if lookup_times > MULTI_LOOKUP_TIMES:
             run_mode = "Modify Graph" if self.modify_graph else "Feature Spec"
             raise RuntimeError(f"In '{run_mode}' mode, the maximum number of sparse lookups for a table "
                                f"({self.table_name}) is {MULTI_LOOKUP_TIMES}, and the current count is {lookup_times}.")
 
     def check_and_format_lookup_params(self, feature, send_count, is_training):
         logging.debug(f"sparse lookup for table {self.table_name} with is_training {is_training}")
@@ -450,33 +454,43 @@ class SparseEmbedding:
         """
 
         logging.debug(f"Enter ASC Branch.")
-
+        # check params
         self.check_mode(MxRecMode.ASC)
         is_training = kwargs.get("is_train")
+        self.check_and_format_lookup_params(ids, send_count, is_training)
+        self.same_table_send_count += send_count if send_count is not None and is_training else 0
 
         if is_asc_frozen() and is_training:
             raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.")
-        self.same_table_send_count += send_count if send_count is not None else 0
 
+        # create feature spec
         feature_spec = get_feature_spec(self.table_name, kwargs.get("access_and_evict_config"))
         feature_spec.set_feat_attribute(ids, is_training)
         # 'clear_channel()' function needs to be executed after 'set_feat_attribute()' function
         if is_asc_frozen() and not is_training:
             clear_channel(is_train_channel=False)
-        self.check_and_format_lookup_params(ids, send_count, is_training)
 
+        # record anchor ids
         anchor_ids = tf.identity(ids, name="ids")
         tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids)
        self.register_anchor_attribute(anchor_ids, feature_spec, kwargs)
-        eval_mode = not is_training and
get_training_mode_channel_id(is_training) is None + + # record multi lookup info + eval_mode = not is_training and get_training_mode_channel_id(True) is None ids_lookup_name = feature_spec.name + "_lookup_ids" # set in train mode, train and eval mode, eval mode if is_training or eval_mode: self.lookup_name_list.append(ids_lookup_name) self.modify_graph = kwargs.get("modify_graph", True) self.check_multi_lookup_times() - kwargs["ids"] = ids - return self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) + # return the stub tensor of the lookup result + result_shape = ids.shape.as_list() + [self.scalar_emb_size] if get_use_static() else \ + array_ops.concat([array_ops.shape(ids), [self.scalar_emb_size]], 0) + mock_lookup_result = tf.ones(shape=result_shape, dtype=tf.float32, name="mock_lookup_result") + mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result + logging.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name) + return mock_lookup_result def lookup_for_asc_with_feature_spec(self, feature_spec: FeatureSpec, send_count: int, **kwargs): """ @@ -497,13 +511,13 @@ class SparseEmbedding: if spec_name in self.lookup_result and is_training in self.lookup_result.get(spec_name): return self.lookup_result.get(spec_name).get(is_training) - if not get_use_static() and kwargs.get("batch") is None: - raise RuntimeError(f"When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") + if not get_use_static() and not self.modify_graph and kwargs.get("batch") is None: + raise RuntimeError("When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") table_name = feature_spec.table_name same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training] same_table_spec_count = len(same_table_feature_spec) if same_table_spec_count == 0: - raise RuntimeError(f"spec_name {spec_name} not in table {table_name}") + raise RuntimeError(f"spec_name {spec_name} not in table {table_name}.") if same_table_spec_count == 1: lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) if spec_name not in self.lookup_result: @@ -517,12 +531,21 @@ class SparseEmbedding: """ same_table_tensor_list = [] for feat_spec in same_table_feature_spec: - tensor = kwargs.get("batch").get(feat_spec.index_key) + batch_tensor_dict = kwargs.get("batch") if not self.modify_graph else \ + kwargs.get("feature_spec_name_ids_dict") + if batch_tensor_dict is None: + raise KeyError(f"The tensor dict of batch does not exist in kwargs, " + f"and modify graph is `{self.modify_graph}`.") + tensor = batch_tensor_dict.get(feat_spec.index_key) if not self.modify_graph else \ + batch_tensor_dict.get(feat_spec.name) if tensor is None: - raise KeyError(f"index_key '{feat_spec.index_key}' does not exist in batch.") + tensor_key = feat_spec.index_key if not self.modify_graph else feat_spec.name + raise KeyError(f"Key `{tensor_key}` does not exist in batch_tensor_dict.") same_table_tensor_list.append(tensor) return same_table_tensor_list + # Ensure that tensors in the same table are sorted according to the lookup sequence (modify graph mode) or + # the sequence in which feature specs are created (feature spec mode). 
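
To make the guarantee in the comment above concrete, here is a hedged, plain-Python sketch (toy dicts instead of the real FeatureSpec class) of why sorting by name matters: the same ordering is derived on every run, so the merged lookup result can be split back to the right feature:

specs = [{"name": "spec_b", "ids": [4, 5]},
         {"name": "spec_a", "ids": [1, 2, 3]}]

ordered = sorted(specs, key=lambda s: s["name"])      # spec_a always precedes spec_b
merged_ids = [i for s in ordered for i in s["ids"]]   # [1, 2, 3, 4, 5]
split_sizes = [len(s["ids"]) for s in ordered]        # [3, 2], same order as merged_ids
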
same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name)
         mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name)
 
@@ -537,13 +560,14 @@ class SparseEmbedding:
             set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count)
 
         kwargs["multi_lookup"] = True
-        lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec,
-                                                                    send_count * same_table_spec_count, **kwargs)
+        total_send_count = self.same_table_send_count if self.modify_graph else send_count * same_table_spec_count
+        lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, total_send_count, **kwargs)
 
         logging.debug(f"lookup table {table_name} via {tensor_split_list}")
         self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, is_training)
-        self.check_multi_lookup_times()
+        if not self.modify_graph:
+            self.check_multi_lookup_times()
         return self.lookup_result.get(spec_name).get(is_training)
 
     def split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list,
@@ -662,11 +686,12 @@ class SparseEmbedding:
             tensor = kwargs.get("batch").get(feature_spec.index_key) \
                 if not self.modify_graph else kwargs.get("ids")
             if tensor is None:
-                raise KeyError(f"index_key '{feature_spec.index_key}' does not exist in batch.")
+                raise KeyError(f"key or ids does not exist in batch, now modify graph is {self.modify_graph}.")
             dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0)
             lookup_result = array_ops.reshape(embeddings, dest_shape)
 
         def grad(lookup_diff):
+            logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
             embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
             unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
                                                              restore_vector,
diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py
new file mode 100644
index 00000000..46dfcbd4
--- /dev/null
+++ b/mx_rec/graph/merge_lookup.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import logging
+
+import tensorflow as tf
+
+from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE
+from mx_rec.core.embedding import SparseEmbedding
+from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec
+from mx_rec.util.initialize import get_modify_graph, get_merged_multi_lookup, insert_merged_multi_lookup
+
+
+def do_merge_lookup(is_train: bool = True):
+    """
+    In modify-graph mode, handle tables looked up once or multiple times and add forward and backward nodes:
+    1. If a table is looked up multiple times, merge those lookups and replace the mock lookup result stub with the merged lookup result.
+    2. If a table is looked up only once, no merge is needed; the mock lookup result stub is replaced with the lookup result of the sparse forward pass.
+    3. This function must be executed in modify-graph mode; in feature spec mode it returns immediately.
+    4. It is executed via the patch on Optimizer.compute_gradients() to ensure correct gradients and graph during training; in eval mode it runs during the graph modification phase.
+
+    Args:
+        is_train: True in training mode, False otherwise
+
+    Returns: None
+
+    """
+
+    if not get_modify_graph():
+        logging.debug("The `do_merge_multi_lookup` function is called only for `modify graph` mode.")
+        return
+    if get_merged_multi_lookup(is_train):
+        logging.debug("The merge multi lookup has been executed once and does not need to be executed again.")
+        return
+    logging.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train)
+
+    # get anchor ids
+    cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
+    if not cutting_point_list:
+        raise RuntimeError("The sparse table does not have sparse lookup.")
+    check_cutting_points(cutting_point_list)
+
+    # get lookup info
+    sub_cutting_points_dict = dict()
+    feature_spec_name_ids_dict = dict()
+    for cutting_point in cutting_point_list:
+        is_training = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING)
+        if is_training != is_train:
+            logging.debug("Skip! The current mode(train: True, eval: False) is %s, but the mode of %s is %s.",
+                          is_train, cutting_point, is_training)
+            continue
+
+        table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
+        if len(table_instance.lookup_name_list) > 1:
+            feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
+            feature_spec_name_ids_dict[feature_spec.name] = cutting_point
+        if sub_cutting_points_dict.get(is_training) is None:
+            sub_cutting_points_dict[is_training] = []
+        sub_cutting_points_dict[is_training].append(cutting_point)
+
+    # merge or restore lookup
+    sub_cutting_point_list = sub_cutting_points_dict.get(is_train)
+    if not sub_cutting_point_list:
+        raise RuntimeError(f"The current mode(train: True, eval: False) is {is_train}, and the sparse table does not "
+                           f"have anchor ids.")
+    for cutting_point in sub_cutting_point_list:
+        feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
+        table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
+        send_count = table_instance.send_count
+        kwargs = dict(is_train=is_train, ids=cutting_point)
+        if len(table_instance.lookup_name_list) > 1:
+            kwargs["multi_lookup"] = True
+            kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict
+        lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs)
+        replace_anchor_vec(cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
+        logging.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name)
+
+    # records whether the current mode has been merged or restored lookup
+    insert_merged_multi_lookup(is_train, True)
+    logging.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train)
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 8e771426..5eff82e8 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -4,26 +4,23 @@
 import logging
 from collections import defaultdict
-from functools import reduce
 
 import tensorflow as tf
-from tensorflow.python.ops import array_ops
 from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
 
-from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
 from mx_rec.core.asc.helper import get_asc_insert_func
-from mx_rec.core.asc.feature_spec import FeatureSpec, set_temporary_feature_spec_attribute
+from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.manager
import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ ASCAnchorAttr, ASCEND_TIMESTAMP, ApplyGradientsStrategy -from mx_rec.util.initialize import get_rank_size, get_training_mode_channel_id, get_feature_spec, \ - insert_feature_spec, set_initializer, get_use_static, get_use_hot, get_device_id, get_use_dynamic_expansion, \ +from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \ - get_is_last_round + get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch from mx_rec.util.perf import performance -from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, replace_anchor, \ - record_ops_to_replace, export_pb_graph, make_sorted_key_to_tensor_list +from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ + export_pb_graph, make_sorted_key_to_tensor_list +from mx_rec.graph.merge_lookup import do_merge_lookup def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tensor_names=None, @@ -337,6 +334,7 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): new_batch = new_iterator.get_next() tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) set_initializer(is_training, new_iterator.initializer) + set_target_batch(is_training, new_batch) try: one_tensor = [v for _, v in new_batch.items()][0] @@ -345,108 +343,22 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): new_get_next_op_name = find_target_dataset_op(one_tensor.op, "IteratorGetNext").name update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) - # multiple lookups of a same table - lookup_for_same_table(sub_cutting_point_list, is_training) - # replace the stub node for sparse lookup from the graph - + # In eval mode, backward is not required. In addition, compute gradients is not executed when + # only eval is used. Therefore, `do_merge_lookup` needs to be invoked during modify graph. + if not is_training: + do_merge_lookup(is_train=False) + if 'evaluate' in get_bool_gauge_set(): + logging.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") + insert_merged_multi_lookup(is_training, False) + # In training mode, `do_merge_lookup` should have been executed in compute gradients phase. + if is_training and not get_merged_multi_lookup(True): + raise RuntimeError("In training mode, `do_merge_lookup` should have been executed in compute gradients " + "phase. Please check whether compute gradients is performed.") logging.info("Graph has been revised.") export_pb_graph("new_graph.pb", dump_graph) -def lookup_for_same_table(sub_cutting_point_list: list, is_training: bool): - """ - Merge multiple lookups of a sparse table into one lookup. 
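
As a reference for the merge this removed function performed (the logic now lives in merge_lookup.py), a simplified sketch with made-up shapes, not the real mxRec API: flatten each feature's ids, query the table once over the concatenation, then split the result back by per-feature element counts:

import tensorflow as tf

table = tf.random.normal([100, 8])  # toy table: 100 rows, embedding size 8
ids_per_feature = [tf.constant([[1, 2, 3]], dtype=tf.int64),
                   tf.constant([[4], [5]], dtype=tf.int64)]

flat_ids = [tf.reshape(t, [-1]) for t in ids_per_feature]
split_sizes = [t.shape.num_elements() for t in flat_ids]   # [3, 2], static here

merged = tf.nn.embedding_lookup(table, tf.concat(flat_ids, axis=0))
emb_a, emb_b = tf.split(merged, split_sizes, axis=0)       # per-feature results again
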
- - Args: - sub_cutting_point_list: the feature ids list passed in by sparse lookup - is_training: indicates whether the training mode is used - - Returns: None - - """ - same_table_feature_spec_dict = {} - feature_spec_ids_dict = {} - for cutting_point in sub_cutting_point_list: - feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) - table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) - if len(table_instance.lookup_name_list) > 1: - if same_table_feature_spec_dict.get(table_instance.table_name) is None: - same_table_feature_spec_dict[table_instance.table_name] = [] - same_table_feature_spec_dict[table_instance.table_name].append(feature_spec) - feature_spec_ids_dict[feature_spec.name] = cutting_point - - for table_name, same_feature_spec_list in same_table_feature_spec_dict.items(): - same_table_feature_spec = sorted(same_feature_spec_list, key=lambda x: x.name) - mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name) - - tensor_split_list = [] - tensor_list = [] - table_instance = None - for one_feature_spec in same_table_feature_spec: - feature_ids = feature_spec_ids_dict.get(one_feature_spec.name) - if feature_ids is None: - raise RuntimeError("In the case of multiple lookups of a table, feature ids cannot be None.") - tensor_list.append(feature_ids) - table_instance = SparseEmbedding.get_anchor_attribute(feature_ids, ASCAnchorAttr.TABLE_INSTANCE) - - # dynamic shape - if not get_use_static(): - tensor_split_list.append(tf.math.reduce_prod(array_ops.shape(feature_ids))) - continue - - # static shape - rank = feature_ids.shape.rank - if rank < 1: - raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {rank} now.") - dims = feature_ids.shape.as_list() - feat_cnt = 1 if rank == 1 else reduce(lambda x, y: x * y, dims[1:]) - tensor_split_list.append(dims[0] * feat_cnt) - - total_feature_count = sum(tensor_split_list) if get_use_static() else tf.add_n(tensor_split_list) - set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count) - - kwargs = {"multi_lookup": True, "is_train": is_training} - if table_instance is None: - raise RuntimeError("In the case of multiple lookups of a table, table instance cannot be None.") - lookup_result = table_instance.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, - table_instance.same_table_send_count, - **kwargs) - table_instance.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, - is_training) - logging.info(f"Multiple lookups of a table for '{table_name}' have completed.") - - -def replace_stub_node_with_asc_graph(sub_cutting_point_list: list, is_training: bool): - """ - Replace the stub node for sparse lookup from the graph. e.g., id_offset, restore_vector, etc. 
- - Args: - sub_cutting_point_list: the feature ids list passed in by sparse lookup - is_training: indicates whether the training mode is used - - Returns: None - - """ - for _, cutting_point in enumerate(sub_cutting_point_list): - feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) - table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) - channel_id = get_training_mode_channel_id(is_training) - config = dict( - batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, - send_count=table_instance.send_count, channel_id=channel_id, rank_size=get_rank_size(), - table_name=table_instance.table_name, skip_emb_transfer=table_instance.skip_emb_transfer, - ext_emb_size=table_instance.ext_emb_size, emb_size=table_instance.scalar_emb_size, - use_hot=get_use_hot(), device_id=get_device_id(), use_dynamic_expansion=get_use_dynamic_expansion(), - is_training=is_training) - - lookup_result = None - if len(table_instance.lookup_name_list) > 1 and feature_spec.name in table_instance.lookup_result and \ - is_training in table_instance.lookup_result.get(feature_spec.name): - lookup_result = table_instance.lookup_result.get(feature_spec.name).get(is_training) - build_asc_graph(config, table_instance, cutting_point, lookup_result) - - def get_timestamp_index(get_next_op, is_training): timestamp_tensor_list = tf.compat.v1.get_collection(ASCEND_TIMESTAMP) timestamp_index = None @@ -468,70 +380,6 @@ def get_timestamp_index(get_next_op, is_training): return timestamp_index -def build_asc_graph(config: dict, table_instance: SparseEmbedding, cutting_point: tf.Tensor, lookup_result: tf.Tensor): - """ - Build the GetNext node in the graph and replace the stub node with this node. 
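
The removed body below relied on the control-dependency idiom: id_offsets was wrapped in tf.identity under tf.control_dependencies(swap_in) so that the swap-in ops run before the value is consumed. A self-contained sketch of that idiom with illustrative names (counter, bump), not mxRec's tensors:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
counter = tf.Variable(0, dtype=tf.int32)
bump = tf.compat.v1.assign_add(counter, 1)  # side effect that must happen first

with tf.control_dependencies([bump]):
    value = tf.identity(counter)            # reading `value` forces `bump` to run

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    print(sess.run(value))  # should print 1: the assign_add ran before the read
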
- - Args: - config: parameters required for GetNext - table_instance: sparse embedding table - cutting_point: the feature ids passed in by sparse lookup - lookup_result: results of the sparse lookup - - Returns: None - - """ - # returned results swap_pos and swap_len were not in used, will be applied for DDR mode - logging.debug(f"try to replace anchors for table {config.get('table_name')} on channel {config.get('channel_id')}") - - # In the case of multiple lookups of a table, replace the stub node of the lookup result in the graph - if len(table_instance.lookup_name_list) > 1: - if lookup_result is None: - raise RuntimeError("In the case of multiple lookups of a table, lookup result cannot be None.") - replace_anchor_vec(cutting_point, ASCAnchorAttr.LOOKUP_RESULT, lookup_result) - logging.info(f"The lookup result corresponding to feature ids '{cutting_point}' has been replaced by " - f"'{lookup_result}'.") - return - - skip_emb_transfer = config.get("skip_emb_transfer") - logging.info(f"modifier build_asc_graph skip_emb_transfer: {skip_emb_transfer}") - if skip_emb_transfer: - result = get_preprocessed_tensor_for_asc(table_instance.variable, config) - else: - variable_list = [table_instance.variable] \ - + [slot_info.get("slot") for slot_info in table_instance.optimizer_slot_info_list] - result = get_preprocessed_tensor_for_asc(variable_list, config) - restore_vector = result.get("restore_vector") - hot_pos = result.get("hot_pos") - id_offsets = result.get("id_offsets") - swap_in = result.get("swap_in") - all2all_matrix = result.get("all2all_args") - - with tf.control_dependencies(swap_in): - id_offsets = tf.identity(id_offsets) - - logging.info(f"build_asc_graph -> id_offsets: {id_offsets}") - replace_anchor_vec(cutting_point, ASCAnchorAttr.ID_OFFSETS, id_offsets) - logging.info(f"build_asc_graph -> restore_vector: {restore_vector}") - replace_anchor_vec(cutting_point, ASCAnchorAttr.RESTORE_VECTOR, restore_vector) - - logging.info(f"build_asc_graph -> all2all_matrix: {all2all_matrix}") - if not get_use_static(): - replace_anchor_vec(cutting_point, ASCAnchorAttr.ALL2ALL_MATRIX, all2all_matrix) - - logging.info(f"build_asc_graph -> hot_pos: {hot_pos}") - if get_use_hot(): - replace_anchor_vec(cutting_point, ASCAnchorAttr.HOT_POS, hot_pos) - - logging.debug(f"has replace anchors for table {config.get('table_name')} on channel {config.get('channel_id')}") - - -def replace_anchor_vec(cutting_point, attribute, anchor): - anchor_vec = SparseEmbedding.get_anchor_attribute(cutting_point, attribute) - replacement_specs_for_anchor_vec = record_ops_to_replace(anchor_vec.op) - replace_anchor(replacement_specs_for_anchor_vec, [anchor]) - - class GraphModifierHook(tf.estimator.SessionRunHook): def __init__(self, dump_graph=True, modify_graph=True): self.dump_graph = dump_graph diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 3379c361..65683f5b 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -14,10 +14,15 @@ from tensorflow.python.data.ops.dataset_ops import _VariantTracker from tensorflow.python.framework import ops from tensorflow_estimator.python.estimator.training import EvalSpec from tensorflow.python.eager.monitoring import BoolGauge, BoolGaugeCell +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx +from tensorflow.python.distribute import reduce_util as ds_reduce_util +from tensorflow.python.training.optimizer import Optimizer from mx_rec.util.initialize import 
get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \
     get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round
 from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook
+from mx_rec.graph.merge_lookup import do_merge_lookup
 
 
 def init_dataset(self, input_data):
@@ -87,8 +92,9 @@ def get_cell(self: BoolGauge, *labels: Any) -> Any:
 
     Returns: Obtains the cell value set by the user.
     """
-    logging.debug(f"Enter patch 'BoolGauge.get_cell'.")
+    logging.debug("Enter patch 'BoolGauge.get_cell'.")
     if len(labels) > 0:
+        logging.debug("BoolGauge insert: %s.", labels[0])
         insert_bool_gauge(labels[0])
     return BoolGaugeCell(super(BoolGauge, self).get_cell(*labels))
 
@@ -97,7 +103,7 @@ def patch_for_bool_gauge():
     """Patch for 'BoolGauge.get_cell'."""
 
     BoolGauge.get_cell = get_cell
-    logging.debug(f"Function 'get_cell' in Class 'BoolGauge' has been patched.")
+    logging.debug("Function 'get_cell' in Class 'BoolGauge' has been patched.")
 
 
 def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session):
@@ -112,14 +118,14 @@ def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session):
 
     """
 
-    logging.debug(f"Enter patch 'NPUCheckpointSaverHook.end'.")
+    logging.debug("Enter patch 'NPUCheckpointSaverHook.end'.")
     logging.info("NPUCheckpointSaverHook end...")
     basic_session_run_hooks.CheckpointSaverHook.end(self, session)
     if 'train_and_evaluate' in get_bool_gauge_set() and get_run_times() == 1:
         set_is_last_round(True)
         return
 
-    logging.debug(f"NPUCheckpointSaverHook call 'terminate_config_initializer'...")
+    logging.debug("NPUCheckpointSaverHook call 'terminate_config_initializer'...")
     terminate_config_initializer()
 
 
@@ -127,7 +133,7 @@ def patch_for_end():
     """Patch for 'NPUCheckpointSaverHook.end'."""
 
     NPUCheckpointSaverHook.end = end
-    logging.debug(f"Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.")
+    logging.debug("Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.")
 
 
 def assert_eval_spec(eval_spec: EvalSpec):
@@ -141,7 +147,7 @@ def assert_eval_spec(eval_spec: EvalSpec):
 
     """
 
-    logging.debug(f"Enter patch 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.")
+    logging.debug("Enter patch 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.")
     if not isinstance(eval_spec, EvalSpec):
         raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. Got: {}'.format(type(eval_spec)))
 
@@ -154,4 +160,40 @@ def patch_for_assert_eval_spec():
     """Patch for 'tensorflow_estimator.python.estimator.training._assert_eval_spec'."""
 
     tensorflow_estimator_lib.python.estimator.training._assert_eval_spec = assert_eval_spec
-    logging.debug(f"Function '_assert_eval_spec' in 'tensorflow_estimator.python.estimator.training' has been patched.")
+    logging.debug("Function '_assert_eval_spec' in 'tensorflow_estimator.python.estimator.training' has been patched.")
+
+
+def scale_loss(self: Optimizer, loss_value: tf.Tensor) -> tf.Tensor:
+    """
+    Multiply the loss value by a scalar factor.
+
+    Args:
+        self: An `Optimizer` instance.
+        loss_value: A Tensor containing the value to minimize or a callable taking no arguments which returns the value
+            to minimize. When eager execution is enabled it must be a callable.
+
+    Returns: loss_value
+
+    """
+
+    logging.debug("Enter patch 'Optimizer._scale_loss'.")
+    # In train mode, merge lookup must be completed during compute gradients.
+    # Ensure that the backward of graph is constructed and the gradient calculation is correct.
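
For readers unfamiliar with this style of patching: the mechanism is plain attribute assignment on the class, so every call site of the method picks up the wrapper. A generic, hedged sketch (Trainer and fit are made-up names, not TensorFlow or mxRec APIs) of running a hook before delegating to the original method, which is what the patched _scale_loss does with do_merge_lookup:

class Trainer:
    def fit(self, loss):
        return loss * 0.5

_original_fit = Trainer.fit

def patched_fit(self, loss):
    print("hook runs before the original method")  # e.g. merge lookups here
    return _original_fit(self, loss)

Trainer.fit = patched_fit   # existing and future instances now use the wrapper
print(Trainer().fit(2.0))   # prints the hook message, then 1.0
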
+    do_merge_lookup(is_train=True)
+
+    # original code
+    ops.get_default_graph()._is_loss_scaled_by_optimizer = False
+    if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
+        # originally named num_replicas
+        loss_num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
+        if loss_num_replicas > 1:
+            loss_value *= (1. / loss_num_replicas)
+            ops.get_default_graph()._is_loss_scaled_by_optimizer = True
+    return loss_value
+
+
+def patch_for_scale_loss():
+    """Patch for 'Optimizer._scale_loss'."""
+
+    Optimizer._scale_loss = scale_loss
+    logging.debug("Function '_scale_loss' in Class 'Optimizer' has been patched.")
diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py
index 9d867fd8..e0e81b70 100644
--- a/mx_rec/graph/utils.py
+++ b/mx_rec/graph/utils.py
@@ -7,7 +7,8 @@ import os
 
 import tensorflow as tf
 
-from mx_rec.constants.constants import DUMP_MIDIFY_GRAPH_FILE_MODE
+from mx_rec.constants.constants import ASCAnchorAttr, DUMP_MIDIFY_GRAPH_FILE_MODE
+from mx_rec.core.embedding import SparseEmbedding
 
 
 def check_input_list(objs, obj_type):
@@ -55,7 +56,8 @@ def record_ops_to_replace(src_op):
 
 def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list):
     if len(replacement_specs) != len(new_tensor_list):
-        raise ValueError("Given replacement_specs and new_tensor_list must have the same length.")
+        raise ValueError(f"Given replacement_specs and new_tensor_list must have the same length. "
+                         f"replacement_specs: {replacement_specs}, new_tensor_list: {new_tensor_list}")
 
     for tensor_idx, (_, items) in enumerate(replacement_specs.items()):
         for input_idx, operator in items:
@@ -103,3 +105,27 @@ def make_sorted_key_to_tensor_list(element_spec, sorted_keys, prefix=""):
         return sorted_keys
 
     raise TypeError(f"Given element_spec, whose type is {type(element_spec)}, is invalid.")
+
+
+def replace_anchor_vec(cutting_point: tf.Tensor, attribute: ASCAnchorAttr, anchor: tf.Tensor):
+    """
+    Find the ops that take the named stub node as input and replace that input with the given anchor.
+
+    Args:
+        cutting_point: the ids queried by sparse lookup
+        attribute: the name of the stub node to be replaced
+        anchor: the tensor that replaces the stub node
+
+    Returns: None
+
+    """
+
+    # get stub node
+    anchor_vec = SparseEmbedding.get_anchor_attribute(cutting_point, attribute)
+    if anchor_vec is None:
+        raise RuntimeError(f"Node `{attribute.value}` does not exist. Check whether the sparse lookup interface "
+                           f"is correctly invoked.")
+    # find the op with stub node as the input
+    replacement_specs_for_anchor_vec = record_ops_to_replace(anchor_vec.op)
+    # replace anchor_vec with anchor
+    replace_anchor(replacement_specs_for_anchor_vec, [anchor])
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 048af669..5b030c11 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -9,7 +9,7 @@ import psutil
 
 import mx_rec.constants.constants
 from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \
-    MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH,\
+    MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH, \
     TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE
 from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json
 from mx_rec.util.ops import import_host_pipeline_ops
@@ -31,7 +31,6 @@ class ConfigInitializer:
         self._is_frozen = False
         self._train_steps = None
         self._eval_steps = None
-        self._prefetch_batch_number = None
         self._if_load = None
         self._table_instance_dict = dict()
         self._dangling_table = []
@@ -50,6 +49,8 @@ class ConfigInitializer:
         self._is_terminated = False
         self._is_last_round = False
         self._run_times = AtomicInteger()
+        self._merged_multi_lookup = dict()
+        self._target_batch = dict()
 
         if self._use_mpi:
             logging.debug(f"Using mpi to launch task.")
@@ -66,11 +67,10 @@ class ConfigInitializer:
         self.train_steps = kwargs.get("train_steps", -1)
         self.eval_steps = kwargs.get("eval_steps", -1)
         self.check_parameters()
-        self.prefetch_batch_number = kwargs.get("prefetch_batch_number", 1)
+        self._prefetch_batch_number = kwargs.get("prefetch_batch_number", 1)
         self.if_load = kwargs.get("if_load", False)
 
-        if_dynamic = kwargs.get("use_dynamic", True)
-        self.use_static = not if_dynamic
+        self.use_static = not kwargs.get("use_dynamic", True)
         self.use_hot = kwargs.get("use_hot", True)
         self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False)
         if kwargs.get("bind_cpu", True):
@@ -80,6 +80,14 @@ class ConfigInitializer:
     def __del__(self):
         self.terminate()
 
+    @property
+    def merged_multi_lookup(self):
+        return self._merged_multi_lookup
+
+    @property
+    def target_batch(self):
+        return self._target_batch
+
     @property
     def is_last_round(self):
         return self._is_last_round
@@ -365,6 +373,24 @@ class ConfigInitializer:
 
         self._initializer_dict[is_training] = initializer
 
+    def insert_merged_multi_lookup(self, is_training, value=True):
+        if not isinstance(is_training, bool):
+            raise TypeError(f"Given key must be a boolean, but got {is_training} for `merged_multi_lookup`.")
+
+        self._merged_multi_lookup[is_training] = value
+
+    def get_merged_multi_lookup(self, is_training):
+        return self._merged_multi_lookup.get(is_training)
+
+    def set_target_batch(self, is_training, batch):
+        if not isinstance(is_training, bool):
+            raise TypeError(f"Given key must be a boolean, but got {is_training} for `target_batch`.")
+
+        self._target_batch[is_training] = batch
+
+    def get_target_batch(self, is_training):
+        return self._target_batch.get(is_training)
+
     def delete_initializers(self):
         self._initializer_dict = {}
 
@@ -660,6 +686,48 @@ def set_ascend_table_name_must_contain(name="merged"):
     mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name
 
 
+def insert_merged_multi_lookup(is_training: bool, value: bool = True):
+    """
+    Record whether the lookup-merging function has been called in modify-graph mode.
+    Args:
+        is_training: True in training mode, False otherwise
+        value: whether the lookup-merging function has been called, True if so, False otherwise
+    Returns: None
+    """
+    ConfigInitializer.get_instance().insert_merged_multi_lookup(is_training, value)
+
+
+def get_merged_multi_lookup(is_training: bool) -> bool:
+    """
+    Return the record of whether the lookup-merging function has been called in modify-graph mode.
+    Args:
+        is_training: True in training mode, False otherwise
+    Returns: the call record, True if it has been called, False otherwise
+    """
+    return ConfigInitializer.get_instance().get_merged_multi_lookup(is_training)
+
+
+def set_target_batch(is_training: bool, batch: dict):
+    """
+    Record the batch of the newly generated dataset in modify-graph mode.
+    Args:
+        is_training: True in training mode, False otherwise
+        batch: the batch of the dataset
+    Returns: None
+    """
+    ConfigInitializer.get_instance().set_target_batch(is_training, batch)
+
+
+def get_target_batch(is_training: bool) -> dict:
+    """
+    Return the recorded batch of the newly generated dataset in modify-graph mode.
+    Args:
+        is_training: True in training mode, False otherwise
+    Returns: the batch of the new dataset
+    """
+    return ConfigInitializer.get_instance().get_target_batch(is_training)
+
+
 def set_ascend_env():
     """
     Configure Ascend-related parameters and environment variables, and generate the hccl configuration
-- Gitee

From 20d587e0c8993b7f922ffc0ae93977fef0bcece0 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 15 Aug 2023 16:36:22 +0800
Subject: [PATCH 260/551] Match-id-ebe8beb29e65a2b0603ca0126355f547db01a645

---
 mx_rec/core/embedding.py     |  8 +++-----
 mx_rec/graph/merge_lookup.py | 16 ++++++++++------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index a506143f..a3c153c6 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -358,8 +358,6 @@ class SparseEmbedding:
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train")
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec
-        SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.GRADIENTS_STRATEGY] = \
-            self.apply_gradients_strategy
 
     def check_mode(self, method_mode):
         if self.mode != method_mode:
@@ -484,9 +482,9 @@ class SparseEmbedding:
         self.check_multi_lookup_times()
 
         # return the stub tensor of the lookup result
-        result_shape = ids.shape.as_list() + [self.scalar_emb_size] if get_use_static() else \
-            array_ops.concat([array_ops.shape(ids), [self.scalar_emb_size]], 0)
-        mock_lookup_result = tf.ones(shape=result_shape, dtype=tf.float32, name="mock_lookup_result")
+        if not get_use_static():
+            kwargs["ids"] = ids
+        mock_lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs)
         mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value)
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result
         logging.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name)
diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py
index 46dfcbd4..b0149587 100644
--- a/mx_rec/graph/merge_lookup.py
+++ b/mx_rec/graph/merge_lookup.py
@@ -9,7 +9,7 @@ import tensorflow as tf
 
 from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec
-from mx_rec.util.initialize import get_modify_graph, get_merged_multi_lookup, insert_merged_multi_lookup
+from mx_rec.util.initialize import get_modify_graph, get_merged_multi_lookup, insert_merged_multi_lookup, get_use_static
 
 
 def do_merge_lookup(is_train: bool =
True):
@@ -52,7 +52,7 @@ def do_merge_lookup(is_train: bool = True):
             continue
 
         table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
-        if len(table_instance.lookup_name_list) > 1:
+        if not get_use_static() and len(table_instance.lookup_name_list) > 1:
             feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
             feature_spec_name_ids_dict[feature_spec.name] = cutting_point
         if sub_cutting_points_dict.get(is_training) is None:
@@ -65,12 +65,16 @@ def do_merge_lookup(is_train: bool = True):
         raise RuntimeError(f"The current mode(train: True, eval: False) is {is_train}, and the sparse table does not "
                            f"have anchor ids.")
     for cutting_point in sub_cutting_point_list:
-        feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
         table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE)
+        feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC)
+        if len(table_instance.lookup_name_list) == 1:
+            logging.debug("The original lookup result of %s for %s does not need to be replaced.", feature_spec.name,
+                          table_instance.table_name)
+            continue
+
         send_count = table_instance.send_count
-        kwargs = dict(is_train=is_train, ids=cutting_point)
-        if len(table_instance.lookup_name_list) > 1:
-            kwargs["multi_lookup"] = True
+        kwargs = dict(is_train=is_train, ids=cutting_point, multi_lookup=True)
+        if not get_use_static():
             kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict
         lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs)
         replace_anchor_vec(cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result)
-- Gitee

From c49c8254a2e79fc5b18b4a3c8f47a75910058de6 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 17 Aug 2023 09:23:11 +0800
Subject: [PATCH 261/551] Match-id-1d9b8cb1686d75429e90e912a029ad87c95d050f

---
 mx_rec/constants/constants.py |   3 +
 mx_rec/graph/modifier.py      | 222 +++++++++++++++++++++++++---------
 mx_rec/util/initialize.py     |  35 +++++-
 3 files changed, 203 insertions(+), 57 deletions(-)

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 222862d9..665f686b 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -21,6 +21,9 @@ CUSTOMIZED_OPS_LIB_PATH = "CUSTOMIZED_OPS_LIB_PATH"
 HOST_PIPELINE_OPS_LIB_PATH = "HOST_PIPELINE_OPS_LIB_PATH"
 ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
 
+# anchor op name used to locate the dataset in the graph under modify-graph mode
+ANCHOR_DATASET_NAME = "PrefetchDataset"
+
 # the name of the embedding table merged by third party
 ASCEND_TABLE_NAME_MUST_CONTAIN = None
 
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 5eff82e8..8532c5ef 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -7,16 +7,18 @@ from collections import defaultdict
 
 import tensorflow as tf
 from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
+from tensorflow.python.framework.ops import Operation
 
 from mx_rec.core.asc.helper import get_asc_insert_func
 from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
-    ASCAnchorAttr, ASCEND_TIMESTAMP, ApplyGradientsStrategy
+    ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME
 from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
     terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \
-    get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch
+    get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \
+    set_iterator_type
 from mx_rec.util.perf import performance
 from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \
     export_pb_graph, make_sorted_key_to_tensor_list
 from mx_rec.graph.merge_lookup import do_merge_lookup
@@ -125,7 +127,17 @@ def find_target_dataset_op(base_ops, op_type):
     raise ValueError(f"Op {op_type} was not found.")
 
 
-def get_op_before_optimize_dataset(get_next_op):
+def get_dataset_op(get_next_op: Operation) -> Operation:
+    """
+    Find the `OptimizeDataset` dataset op in the graph based on the `IteratorGetNext` op. Note: TF2 has no
+    `OptimizeDataset`, so the default dataset anchor is searched instead.
+
+    Args:
+        get_next_op: the `IteratorGetNext` op
+
+    Returns: the `OptimizeDataset` op for TF1, or the default dataset anchor op for TF2
+
+    """
+
     if get_next_op.type != "IteratorGetNext":
         raise TypeError(f"Op '{get_next_op}' must be an instance of IteratorGetNext.")
@@ -142,7 +154,7 @@ def get_op_before_optimize_dataset(get_next_op):
             target_op = target_op[0]
     else:
         # 'OptimizeDataset' is not available in TensorFlow2.X
-        target_op = find_target_dataset_op(base_op, "PrefetchDataset")
+        target_op = find_target_dataset_op(base_op, ANCHOR_DATASET_NAME)
 
     return target_op
 
@@ -232,29 +244,25 @@ def update_input_tensor_with_new_batch(replacement_specs, new_get_next_op_name):
             operator._update_input(idx, new_tensor)
 
 
-def make_src_to_tgt_mapping(src_element_spec, tgt_element_spec):
-    # adding '_0' to the prefix
-    if not isinstance(src_element_spec, (list, tuple)):
-        src_element_spec = [src_element_spec]
-    src_sorted_keys = make_sorted_key_to_tensor_list(src_element_spec, [])
-    tgt_sorted_keys = make_sorted_key_to_tensor_list(tgt_element_spec, [])
-    index_to_src_key_mapping = dict([(idx, key) for idx, key in enumerate(src_sorted_keys)])
-    tgt_key_to_index_mapping = dict([(key, idx) for idx, key in enumerate(tgt_sorted_keys)])
+def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int:
+    """
+    Get the number of tensors in a batch of the dataset.
+
+    Args:
+        dataset: the dataset instance
 
-    original_tensor_count = len(src_sorted_keys)
+    Returns: the number of tensors in a batch of the dataset
 
-    def mapping_func(src_idx):
-        key = index_to_src_key_mapping.get(src_idx)
-        if key is None:
-            raise ValueError("Given src_idx is out of range.")
+    """
 
-        tgt_idx = tgt_key_to_index_mapping.get(key)
-        return tgt_idx
+    src_element_spec = dataset.element_spec
+    if not isinstance(src_element_spec, (list, tuple)):
+        src_element_spec = [src_element_spec]
+    src_sorted_keys = make_sorted_key_to_tensor_list(src_element_spec, [])
 
-    return mapping_func, original_tensor_count
+    return len(src_sorted_keys)
 
 
-@performance("graph_modifier")
 def modify_graph_and_start_emb_cache(dump_graph=False):
     modify_graph_for_asc(dump_graph=dump_graph)
     start_asc_pipeline()
@@ -287,18 +295,129 @@ def generate_get_next_op_specs(cutting_point_list, dump_graph):
     return get_next_op_map
 
 
-def get_src_and_generate_tgt_dataset(get_next_op, records):
-    target_op = get_op_before_optimize_dataset(get_next_op)
+def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapter:
+    """
+    Find the original dataset in the graph based on the `IteratorGetNext` op.
+
+    Args:
+        get_next_op: the `IteratorGetNext` op
+        is_training: True in training mode, False otherwise
+
+    Returns: the original dataset
+
+    """
+
+    try:
+        target_op = get_dataset_op(get_next_op)
+    except (ValueError, TypeError, RuntimeError) as err:
+        logging.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err)
+        dataset_op_list = [op for op in tf.compat.v1.get_default_graph().get_operations()
+                           if ANCHOR_DATASET_NAME in op.name]
+        logging.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.",
+                      is_training, dataset_op_list)
+
+        if len(dataset_op_list) == 1:
+            target_op = dataset_op_list[0]
+        elif is_training and len(dataset_op_list) == 2:
+            prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name)
+            target_op = prefetch_dataset_op_list[0]
+        elif not is_training and len(dataset_op_list) == 3:
+            prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name)
+            target_op = prefetch_dataset_op_list[1]
+        else:
+            raise RuntimeError(f"The `{ANCHOR_DATASET_NAME}` was not found from the operations, dataset_op_list: "
+                               f"{dataset_op_list}.") from err
+    except Exception as err:
+        raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err
+
+    if not target_op.outputs:
+        raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.")
+    logging.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs)
     src_dataset = find_target_instance_dataset(target_op.outputs[0])
+    return src_dataset
+
+
+def get_tgt_dataset(src_dataset: DatasetV1Adapter, sub_cutting_point_list: list, records: dict,
+                    dump_graph: bool = False, prefetch: int = 10) -> DatasetV1Adapter:
+    """
+    Generate a new dataset instance from the original dataset.
+
+    Args:
+        src_dataset: the original dataset instance
+        sub_cutting_point_list: the list of stubbed lookup ids
+        records: a dict recording the input/output ops and sub-graph relations of the stubbed ids
+        dump_graph: whether to dump the graph, False by default
+        prefetch: the dataset prefetch amount, 10 by default
+
+    Returns: the new dataset instance
+
+    """
+
     tgt_dataset = src_dataset.map(get_preprocessing_map_func(records.get("sub_graph_def"),
                                                              records.get("input_name_list"),
                                                              records.get("output_name_list"),
                                                              pipeline_input_indexes=records.get(
                                                                  "batch_tensor_index_list")))
-    return src_dataset, tgt_dataset
+    feature_numbers = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for
+                       cutting_point in sub_cutting_point_list]
+    table_names = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for
+                   cutting_point in sub_cutting_point_list]
+    tgt_dataset = tgt_dataset.map(get_asc_insert_func(feature_numbers=feature_numbers,
+                                                      table_names=table_names,
+                                                      args_index_list=records.get("input_index_list"),
+                                                      is_training=records.get("is_training"),
+                                                      dump_graph=dump_graph))
+
+    tgt_dataset = tgt_dataset.prefetch(prefetch)
+    return tgt_dataset
+
+
+def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapter, is_training: bool, records: dict):
+    """
+    Replace the original dataset's `IteratorGetNext` op in the graph with the one from the new dataset, i.e. replace
+    the original batch with the batch of the new dataset.
+
+    Args:
+        get_next_op: the `IteratorGetNext` op
+        tgt_dataset: the new dataset
+        is_training: True in training mode, False otherwise
+        records: a dict recording the input/output ops and sub-graph relations of the stubbed ids
+
+    Returns: None
+
+    """
+
+    if not get_next_op.outputs:
+        raise RuntimeError("There is no tensor in the dataset.
Please check the dataset and data processing.") + iterator_type = "" + if get_next_op.outputs[0].op.inputs: + iterator_type = get_next_op.outputs[0].op.inputs[0].op.type + if iterator_type == "IteratorV2": + iterator_type = find_make_iterator_op(get_next_op.outputs[0]).type + if iterator_type not in ("MakeIterator", "OneShotIterator"): + raise RuntimeError(f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " + f"but the current iterator is `{iterator_type}`.") + set_iterator_type(iterator_type) + logging.info("The iterator type of dataset is `%s`.", iterator_type) + + if iterator_type == "MakeIterator": + new_iterator = tgt_dataset.make_initializable_iterator() + tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) + set_initializer(is_training, new_iterator.initializer) + else: + new_iterator = tgt_dataset.make_one_shot_iterator() + new_batch = new_iterator.get_next() + set_target_batch(is_training, new_batch) + + try: + new_batch_tensor = list(new_batch.values())[0] + except IndexError as err: + raise IndexError("Cannot find a tensor from given batch.") from err + new_get_next_op_name = find_target_dataset_op(new_batch_tensor.op, "IteratorGetNext").name + update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) +@performance("graph_modifier") def modify_graph_for_asc(dump_graph=False, prefetch=10): cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) check_cutting_points(cutting_point_list) @@ -308,40 +427,29 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): export_pb_graph("old_graph.pb", dump_graph) get_next_op_map = generate_get_next_op_specs(cutting_point_list, dump_graph) + logging.debug("In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", + len(get_next_op_map), get_next_op_map.keys()) for get_next_op, records in get_next_op_map.items(): is_training = records.get("is_training") + + # get source dataset + src_dataset = get_src_dataset(get_next_op, is_training) + + # generate target dataset timestamp_index = get_timestamp_index(get_next_op, is_training) - src_dataset, tgt_dataset = get_src_and_generate_tgt_dataset(get_next_op, records) - mapping_func, original_tensor_count = make_src_to_tgt_mapping(src_dataset.element_spec, - tgt_dataset.element_spec) + original_batch_tensor_count = get_dataset_tensor_count(src_dataset) sub_cutting_point_list = records.get("sub_cutting_point_list") input_index_list = get_input_index_list(sub_cutting_point_list, records.get("replacement_specs"), records.get("output_name_list"), - original_tensor_count, timestamp_index=timestamp_index) - feature_numbers = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for - cutting_point in sub_cutting_point_list] - table_names = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for - cutting_point in sub_cutting_point_list] - - tgt_dataset = tgt_dataset.map( - get_asc_insert_func(feature_numbers=feature_numbers, table_names=table_names, - args_index_list=input_index_list, is_training=is_training, dump_graph=dump_graph)) + original_batch_tensor_count, timestamp_index=timestamp_index) + records["input_index_list"] = input_index_list + tgt_dataset = get_tgt_dataset(src_dataset, sub_cutting_point_list, records, + dump_graph=dump_graph, prefetch=prefetch) - tgt_dataset = tgt_dataset.prefetch(prefetch) - new_iterator = tgt_dataset.make_initializable_iterator() 
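
The two iterator kinds distinguished above behave differently at session time; a minimal tf.compat.v1 sketch, independent of mxRec, of why the initializable variant needs the extra initializer run that GraphModifierHook performs:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
dataset = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])

one_shot = tf.compat.v1.data.make_one_shot_iterator(dataset)            # "OneShotIterator" op
initializable = tf.compat.v1.data.make_initializable_iterator(dataset)  # paired "MakeIterator" op

with tf.compat.v1.Session() as sess:
    print(sess.run(one_shot.get_next()))       # usable immediately
    sess.run(initializable.initializer)        # must be run before get_next()
    print(sess.run(initializable.get_next()))
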
- new_batch = new_iterator.get_next() - tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) - set_initializer(is_training, new_iterator.initializer) - set_target_batch(is_training, new_batch) - - try: - one_tensor = [v for _, v in new_batch.items()][0] - except IndexError as err: - raise IndexError("Cannot find a tensor from given batch.") from err - new_get_next_op_name = find_target_dataset_op(one_tensor.op, "IteratorGetNext").name - update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) + # update the batch of dataset + update_iterator_getnext(get_next_op, tgt_dataset, is_training, records) # In eval mode, backward is not required. In addition, compute gradients is not executed when # only eval is used. Therefore, `do_merge_lookup` needs to be invoked during modify graph. @@ -382,18 +490,24 @@ def get_timestamp_index(get_next_op, is_training): class GraphModifierHook(tf.estimator.SessionRunHook): def __init__(self, dump_graph=True, modify_graph=True): - self.dump_graph = dump_graph - self.modify_graph = modify_graph + self._dump_graph = dump_graph + self._modify_graph = modify_graph + self._iterator_type = "" set_is_graph_modify_hook_running(True) def begin(self): - if self.modify_graph: - modify_graph_and_start_emb_cache(dump_graph=self.dump_graph) + if self._modify_graph: + modify_graph_and_start_emb_cache(dump_graph=self._dump_graph) else: start_asc_pipeline() + self._iterator_type = get_iterator_type() + if self._iterator_type not in ("MakeIterator", "OneShotIterator"): + raise ValueError("The value of iterator type should be like `MakeIterator` or `OneShotIterator`.") + logging.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) + def after_create_session(self, session, coord): - if self.modify_graph: + if self._modify_graph and self._iterator_type == "MakeIterator": session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) def end(self, session): diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 5b030c11..5da668c7 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -24,6 +24,8 @@ class ConfigInitializer: def __init__(self, use_mpi, **kwargs): self._use_mpi = use_mpi + self._rank_id = kwargs.get("rank_id", 0) + self._rank_size = kwargs.get("rank_size", 1) self._ascend_global_hashtable_collection = ASCEND_GLOBAL_HASHTABLE_COLLECTION self._comm = None self._asc_manager = None @@ -51,6 +53,7 @@ class ConfigInitializer: self._run_times = AtomicInteger() self._merged_multi_lookup = dict() self._target_batch = dict() + self._iterator_type = "" if self._use_mpi: logging.debug(f"Using mpi to launch task.") @@ -59,9 +62,6 @@ class ConfigInitializer: self._comm = MPI.COMM_WORLD self._rank_id = self._comm.Get_rank() self._rank_size = self._comm.Get_size() - else: - self._rank_id = kwargs.get("rank_id") - self._rank_size = kwargs.get("rank_size") self._rank_to_device_dict = parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else set_hccl_info_without_json() self.train_steps = kwargs.get("train_steps", -1) @@ -80,6 +80,10 @@ class ConfigInitializer: def __del__(self): self.terminate() + @property + def iterator_type(self): + return self._iterator_type + @property def merged_multi_lookup(self): return self._merged_multi_lookup @@ -314,6 +318,13 @@ class ConfigInitializer: self.unfreeze() logging.debug("ASC manager has been destroyed.") + @iterator_type.setter + def iterator_type(self, iterator_type): + if not isinstance(iterator_type, 
str):
+            raise TypeError(f"iterator_type `{iterator_type}` should be str.")
+
+        self._iterator_type = iterator_type
+
     @train_steps.setter
     def train_steps(self, step: int):
         check_step(step)
@@ -728,6 +739,24 @@ def get_target_batch(is_training: bool) -> dict:
     return ConfigInitializer.get_instance().get_target_batch(is_training)
 
 
+def get_iterator_type() -> str:
+    """
+    Return the iterator type of the dataset.
+    Returns: the iterator type of the dataset
+    """
+    return ConfigInitializer.get_instance().iterator_type
+
+
+def set_iterator_type(iterator_type: str):
+    """
+    Record the iterator type of the dataset.
+    Args:
+        iterator_type: the iterator type of the dataset
+    Returns: None
+    """
+    ConfigInitializer.get_instance().iterator_type = iterator_type
+
+
 def set_ascend_env():
     """
     Configure the Ascend-related parameters and environment variables, and generate the hccl configuration
-- Gitee

From 0fc2f63e82caacd1226e86c274c89bd24e45ffbf Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 17 Aug 2023 17:06:10 +0800
Subject: [PATCH 262/551] Match-id-44b63974d2220fa0f599763a91b2c8bdb5fc508c

---
 mx_rec/graph/modifier.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 8532c5ef..efb06af2 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -502,7 +502,7 @@ class GraphModifierHook(tf.estimator.SessionRunHook):
             start_asc_pipeline()
 
         self._iterator_type = get_iterator_type()
-        if self._iterator_type not in ("MakeIterator", "OneShotIterator"):
+        if self._modify_graph and self._iterator_type not in ("MakeIterator", "OneShotIterator"):
            raise ValueError("The value of iterator type should be like `MakeIterator` or `OneShotIterator`.")
         logging.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type)
 
-- Gitee

From 8cc998d60b039265304dc1ba8ccc4c7cc1d34868 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 17 Aug 2023 19:30:09 +0800
Subject: [PATCH 263/551] Match-id-59f9d1a26faa109c010654d61d5104aca7b1c688

---
 src/core/checkpoint/checkpoint.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index 73b54255..27d8783e 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -254,6 +254,9 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir)
     auto &AttributeArr = transData.attribute;
     auto embHashMapSize = AttributeArr.at(0);
+    if (embHashMapSize <= 0) {
+        throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str());
+    }
     auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize);
 
     aclError ret;
@@ -272,14 +275,12 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir)
         readFile.read((char *) (row.data()), embeddingSize * sizeof(float));
 
         aclError ret = aclrtMemcpy(floatPtr + i * embeddingSize, embeddingSize * sizeof(float),
-                                   row.data(), embeddingSize * sizeof(float),
-                                   ACL_MEMCPY_HOST_TO_DEVICE);
+                                   row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
         if (ret != ACL_SUCCESS) {
             LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret);
             readFile.close();
             throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str());
         }
-
         int64_t address = reinterpret_cast(floatPtr + i * embeddingSize);
         transArr.at(i + 1) = address;
     }
@@ -484,18 +485,14 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData,
         LOG(ERROR) << "dataElmtBytes is 0, don't handle [/ %] operation";
         return ;
     }
-
     auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx);
-    auto loadHostEmbs = 
ckptData.hostEmbs; - auto& dst = (*loadHostEmbs)[embName].embData; - dst.reserve(embDataOuterSize); - + if (embDataOuterSize <= 0) { + throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str()); + } std::ifstream readFile; readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - size_t datasetSize = static_cast(readFile.tellg()); readFile.seekg(0, std::ios::beg); - try { ValidateReadFile(dataDir, datasetSize); } catch (const std::invalid_argument& e) { @@ -508,6 +505,9 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, readFile.close(); throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } + auto loadHostEmbs = ckptData.hostEmbs; + auto& dst = (*loadHostEmbs)[embName].embData; + dst.reserve(embDataOuterSize); auto onceReadByteSize { datasetSize / embDataOuterSize }; if (!readFile.is_open()) { -- Gitee From ecf5c4309faafa92324bf6617ce4fcbfd5e6e8da Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 17 Aug 2023 22:20:28 +0800 Subject: [PATCH 264/551] Match-id-11205680f8e0935ae102d36dd146af4869af0067 --- mx_rec/util/initialize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 048af669..5e6c526d 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -199,6 +199,7 @@ class ConfigInitializer: logging.info("python process run terminate success") self._is_terminated = True + ConfigInitializer._single_instance = None def insert_feature_spec(self, feature, is_training): self._feature_spec_dict[feature.name] = feature -- Gitee From 97d5dcf1a70e3c0b87e1e92dae8acd1a42b78cf3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 19 Aug 2023 09:22:25 +0800 Subject: [PATCH 265/551] Match-id-e5f85a0e0dce3e534874d9dbb269263d6c714b1e --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 206 +-------------------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 35 ++--- src/core/key_process/key_process.cpp | 7 - src/core/key_process/key_process.h | 12 +- src/core/utils/task_queue.h | 102 ------------- 5 files changed, 21 insertions(+), 341 deletions(-) delete mode 100644 src/core/utils/task_queue.h diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 0ba105ac..ff102852 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -6,9 +6,8 @@ */ #include "hybrid_mgmt.h" -#include "checkpoint/checkpoint.h" #include "utils/time_cost.h" -#include "utils/common.h" +#include "checkpoint/checkpoint.h" using namespace MxRec; using namespace std; @@ -104,12 +103,6 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, return false; } - lookUpKeysQueueForTrain = make_unique>>(); - restoreQueueForTrain = make_unique>>(); - lookUpKeysQueueForEval = make_unique>>(); - restoreQueueForEval = make_unique>>(); - a2aQueueForTrain = make_unique>>(); - a2aQueueForEval = make_unique>>(); isRunning = true; if (!rankInfo.noDDR) { @@ -343,21 +336,8 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) void HybridMgmt::Start() { #ifndef GTEST - int mode = 0; - const char* envTaskMode = std::getenv("MGMT_HBM_TASK_MODE"); // 获取环境变量 - if (envTaskMode != nullptr) { // 如果环境变量存在 - try { - mode = std::stoi(envTaskMode); // 将字符串转换为整数 - LOG(INFO) << StringFormat("The value of MGMT_HBM_TASK_MODE is an integer: %d", mode); - } catch (const std::invalid_argument& e) { // 如果转换失败 - LOG(ERROR) << "The value of MGMT_HBM_TASK_MODE is not an integer!"; - throw 
std::invalid_argument("Invalid env value MGMT_HBM_TASK_MODE"); - } - } else { // 如果环境变量不存在 - mode = 0; - } if (mgmtRankInfo.noDDR) { - InsertThreadForHBM(mode); + InsertThreadForHBM(); } if (!mgmtRankInfo.noDDR) { @@ -376,34 +356,9 @@ void HybridMgmt::Start() #endif } -void HybridMgmt::InsertThreadForHBM(int mode) +void HybridMgmt::InsertThreadForHBM() { #ifndef GTEST - if (mode == 1) { - auto getInfoTaskForTrain = [this]() { - TaskForTrain(TaskType::GETINFO); - LOG(INFO) << "getInfoTaskForTrain done"; - }; - procThreads.emplace_back(std::make_unique(getInfoTaskForTrain)); - - auto getInfoTaskForEval = [this]() { - TaskForEval(TaskType::GETINFO); - LOG(INFO) << "getInfoTaskForEval done"; - }; - procThreads.emplace_back(std::make_unique(getInfoTaskForEval)); - - auto sendInfoTaskForTrain = [this]() { - TaskForTrain(TaskType::SEND); - LOG(INFO) << "sendInfoTaskForTrain done"; - }; - procThreads.emplace_back(std::make_unique(sendInfoTaskForTrain)); - - auto sendInfoTaskForEval = [this]() { - TaskForEval(TaskType::SEND); - LOG(INFO) << "sendInfoTaskForEval done"; - }; - procThreads.emplace_back(std::make_unique(sendInfoTaskForEval)); - } else { auto parseKeysTaskForHBMTrain = [this]() { TaskForTrain(TaskType::HBM); LOG(INFO) << "parseKeysTaskForHBMTrain done"; @@ -415,7 +370,6 @@ void HybridMgmt::InsertThreadForHBM(int mode) LOG(INFO) << "parseKeysTaskForHBMEval done"; }; procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); - } #endif } @@ -464,27 +418,6 @@ bool HybridMgmt::TrainTask(TaskType type) bool status = false; switch (type) { - case TaskType::GETINFO: - status = GetLookupAndRestore(TRAIN_CHANNEL_ID, getInfoBatchId); - isContinue = getInfoBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - LOG(INFO) << StringFormat(MGMT + "getInfoBatchId = %d", getInfoBatchId); - break; - case TaskType::SEND: - status = SendLookupAndRestore(TRAIN_CHANNEL_ID, sendBatchId); - isContinue = sendBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; - LOG(INFO) << StringFormat(MGMT + "sendBatchId = %d", sendBatchId); -#if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) - if (sendBatchId == PROFILING_START_BATCH_ID) { - EASY_PROFILER_ENABLE - } else if (sendBatchId == PROFILING_END_BATCH_ID) { - EASY_PROFILER_DISABLE - ::profiler::dumpBlocksToFile( - StringFormat("/home/MX_REC-mgmt-profile-%s.prof", mgmtRankInfo.rankId).c_str()); - } -#endif - break; case TaskType::HBM: status = ParseKeysHBM(TRAIN_CHANNEL_ID, trainBatchId); isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || @@ -519,14 +452,6 @@ bool HybridMgmt::EvalTask(TaskType type) bool status = false; switch (type) { - case TaskType::GETINFO: - status = GetLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); - LOG(INFO) << StringFormat(MGMT + "GETINFO evalBatchId = %d", evalBatchId); - break; - case TaskType::SEND: - status = SendLookupAndRestore(EVAL_CHANNEL_ID, evalBatchId); - LOG(INFO) << StringFormat(MGMT + "SEND evalBatchId = %d", evalBatchId); - break; case TaskType::HBM: status = ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); LOG(INFO) << StringFormat(MGMT + "HBM evalBatchId = %d", evalBatchId); @@ -548,131 +473,6 @@ bool HybridMgmt::EvalTask(TaskType type) return true; } -void HybridMgmt::GetAll2All(const int channelId, int &batchId, const string &name) -{ - auto all2all = preprocess->GetInfoVec(batchId, name, channelId, ProcessedInfo::ALL2ALL); - switch (channelId) { - case TRAIN_CHANNEL_ID: - 
a2aQueueForTrain->Pushv(*all2all); - break; - case EVAL_CHANNEL_ID: - a2aQueueForEval->Pushv(*all2all); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } -} - -bool HybridMgmt::GetLookupAndRestore(const int channelId, int &batchId) -{ - LOG(INFO) << StringFormat(MGMT + "start parse keys, nBatch:%d , [%d]:%d", mgmtRankInfo.nBatch, channelId, batchId); - for (const auto& embInfo: mgmtEmbInfo) { - TimeCost getAllTensorTC; - auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { - LOG(INFO) << StringFormat( - MGMT + "ParseKeys infoVecs empty ! batchId:%d, channelId:%d", batchId, channelId); - return false; - } - switch (channelId) { - case TRAIN_CHANNEL_ID: - lookUpKeysQueueForTrain->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueueForTrain->Pushv(*infoVecs); - break; - case EVAL_CHANNEL_ID: - lookUpKeysQueueForEval->Pushv({ infoVecs->back() }); - infoVecs->pop_back(); - restoreQueueForEval->Pushv(*infoVecs); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - - if (!mgmtRankInfo.useStatic) { - GetAll2All(channelId, batchId, embInfo.name); - } - VLOG(GLOG_DEBUG) << StringFormat("getAllTensorTC(ms):%d", getAllTensorTC.ElapsedMS()); - } - batchId++; - return true; -} - -void HybridMgmt::All2AllKeys(const int channelId, const string &embName) -{ - TimeCost a2aKeysTC; - vector all2allKeys; - switch (channelId) { - case TRAIN_CHANNEL_ID: - all2allKeys = a2aQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - all2allKeys = a2aQueueForEval->WaitAndPop(); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::ALL2ALL, all2allKeys, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("All2AllKeysTC(ms):%d", a2aKeysTC.ElapsedMS()); -} - -void HybridMgmt::LookupKeys(const int channelId, const string &embName) -{ - TimeCost sendLookupTC; - vector lookUpKeys; - switch (channelId) { - case TRAIN_CHANNEL_ID: - lookUpKeys = lookUpKeysQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - lookUpKeys = lookUpKeysQueueForEval->WaitAndPop(); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::LOOKUP, lookUpKeys, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("sendLookupTC(ms):%d", sendLookupTC.ElapsedMS()); -} - -void HybridMgmt::RestoreKeys(const int channelId, const string &embName) -{ - TimeCost sendRestoreTC; - vector restore; - switch (channelId) { - case TRAIN_CHANNEL_ID: - restore = restoreQueueForTrain->WaitAndPop(); - break; - case EVAL_CHANNEL_ID: - restore = restoreQueueForEval->WaitAndPop(); - break; - default: - throw std::invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); - } - hdTransfer->Send(TransferChannel::RESTORE, restore, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreTC(ms):%d", sendRestoreTC.ElapsedMS()); -} - -bool HybridMgmt::SendLookupAndRestore(const int channelId, int &batchId) -{ - for (const auto& embInfo: mgmtEmbInfo) { - TimeCost sendTensorsTC; - if (!mgmtRankInfo.useStatic) { - All2AllKeys(channelId, embInfo.name); - } - - LOG(INFO) << StringFormat( - "SendLookupAndRestore batchId:%d, name:%s, channelId:%d", - batchId, embInfo.name.c_str(), channelId - ); - - LookupKeys(channelId, embInfo.name); - 
RestoreKeys(channelId, embInfo.name); - VLOG(GLOG_DEBUG) << StringFormat("sendTensorsTC(ms):%d", sendTensorsTC.ElapsedMS()); - } - batchId++; - return true; -} - bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) { LOG(INFO) << StringFormat( diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 03ff37b2..70cd007c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -8,18 +8,22 @@ #ifndef MX_REC_EMB_MGMT_H #define MX_REC_EMB_MGMT_H -#include -#include -#include #include + #include + +#include +#include +#include + #include "absl/container/flat_hash_map.h" + #include "utils/common.h" #include "utils/singleton.h" -#include "utils/task_queue.h" -#include "hd_transfer/hd_transfer.h" + #include "host_emb/host_emb.h" #include "emb_hashmap/emb_hashmap.h" +#include "hd_transfer/hd_transfer.h" #include "key_process/key_process.h" namespace MxRec { @@ -29,8 +33,6 @@ namespace MxRec { constexpr int SEND_TENSOR_TYPE_NUM = 2; enum class TaskType { - GETINFO, - SEND, HBM, DDR }; @@ -63,7 +65,7 @@ namespace MxRec { void Start(); - void InsertThreadForHBM(int mode); + void InsertThreadForHBM(); void Destroy() { @@ -72,10 +74,6 @@ namespace MxRec { } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; - restoreQueueForTrain->DestroyQueue(); - lookUpKeysQueueForTrain->DestroyQueue(); - restoreQueueForEval->DestroyQueue(); - lookUpKeysQueueForEval->DestroyQueue(); // 先发送停止信号给preprocess,用于停止查询中lookup卡住状态 preprocess->isRunning = false; @@ -127,12 +125,6 @@ namespace MxRec { unique_ptr hostEmbs {}; unique_ptr hostHashMaps {}; vector> procThreads {}; - unique_ptr>> lookUpKeysQueueForTrain; - unique_ptr>> restoreQueueForTrain; - unique_ptr>> lookUpKeysQueueForEval; - unique_ptr>> restoreQueueForEval; - unique_ptr>> a2aQueueForTrain; - unique_ptr>> a2aQueueForEval; map> evictKeyMap {}; KeyProcess *preprocess; HDTransfer *hdTransfer; @@ -145,13 +137,6 @@ namespace MxRec { bool TrainTask(TaskType type); bool EvalTask(TaskType type); - void All2AllKeys(const int channelId, const string &embName); - void LookupKeys(const int channelId, const string &embName); - void RestoreKeys(const int channelId, const string &embName); - bool GetLookupAndRestore(const int channelId, int &batchId); - void GetAll2All(const int channelId, int &batchId, const string &name); - bool SendLookupAndRestore(const int channelId, int &batchId); - void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); bool EndBatch(int batchId, int channelId) const; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a67f13d6..b0e9cc75 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -7,15 +7,8 @@ #include "key_process.h" -#include -#include - -#include - #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" -#include "utils/common.h" -#include "utils/time_cost.h" using namespace std; using namespace chrono; diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 2b573831..7f8ac1df 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -15,19 +15,23 @@ #include #include #include +#include +#include -#include #include -#include +#include + +#include "ock_ctr_common/include/factory.h" #include "utils/common.h" +#include "utils/time_cost.h" #include "utils/safe_queue.h" -#include "utils/task_queue.h" #include "host_emb/host_emb.h" #include "emb_table/emb_table.h" + #include 
"feature_admit_and_evict.h" -#include "ock_ctr_common/include/factory.h" + namespace MxRec { using namespace std; diff --git a/src/core/utils/task_queue.h b/src/core/utils/task_queue.h deleted file mode 100644 index d44e2a11..00000000 --- a/src/core/utils/task_queue.h +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: task queue module - * Author: MindX SDK - * Create: 2022 - * History: NA - */ - -#ifndef TASK_QUEUE_H -#define TASK_QUEUE_H - -#include -#include -#include -#include - -namespace MxRec { -namespace Common { -template class TaskQueue { -public: - TaskQueue() = default; - - ~TaskQueue() = default; - - TaskQueue(TaskQueue const & other) - { - std::lock_guard lk(other.mut); - dataQueue = other.dataQueue; - } - - TaskQueue &operator = (TaskQueue const & other) - { - if (this == &other) { - return *this; - } - std::lock_guard lk(other.mut); - dataQueue = other.dataQueue; - return *this; - } - - void Pushv(T &t) - { - std::lock_guard lk(mut); - dataQueue.push_back(std::move(t)); - dataCond.notify_one(); - } - - void Pushv(T &&t) - { - std::lock_guard lk(mut); - dataQueue.emplace_back(t); - dataCond.notify_one(); - } - - T WaitAndPop() - { - std::unique_lock lk(mut); - dataCond.wait(lk, [this] { - if (!finished) { - return !dataQueue.empty(); - } else { - return true; - } - }); - T res; - if (finished) { - return res; - } - res = dataQueue.front(); - dataQueue.pop_front(); - return res; - } - - void DestroyQueue() - { - finished = true; - dataCond.notify_one(); - } - - bool Empty() const - { - std::lock_guard lk(mut); - return dataQueue.empty(); - } - - size_t Size() const - { - std::lock_guard lk(mut); - return dataQueue.size(); - } - -private: - mutable std::mutex mut; - std::list dataQueue; - std::condition_variable dataCond; - bool finished = false; -}; -} -} - - -#endif -- Gitee From e193b51ac8318c36cb8666c780f0ce5702e2032a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 19 Aug 2023 11:51:00 +0800 Subject: [PATCH 266/551] Match-id-1d10eaa31fbba7f69e995edff6af38415364a4ca --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 0ba105ac..43817194 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -194,6 +194,7 @@ bool HybridMgmt::Load(const string& loadPath) loadData.hostEmbs = hostEmbs->GetHostEmbs(); loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) { + preprocess->LoadSaveUnlock(); return false; } @@ -219,6 +220,7 @@ bool HybridMgmt::Load(const string& loadPath) Start(); } #endif + preprocess->LoadSaveUnlock(); return true; } -- Gitee From 4553b49e510aa86aa5d71691a83ae5e0668c0a57 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 19 Aug 2023 16:52:42 +0800 Subject: [PATCH 267/551] Match-id-0257be2694ef9af58d11554ca06f1cd6a59a3839 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 43817194..48c8d53a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -220,7 +220,6 @@ bool HybridMgmt::Load(const string& loadPath) Start(); } #endif - preprocess->LoadSaveUnlock(); return true; } -- Gitee From eccafae267b61fef8f703d8623f1f2b59dbbde6c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 
21 Aug 2023 15:22:03 +0800 Subject: [PATCH 268/551] Match-id-51a115d4d895dce62785fe69a6355c95e0f21292 --- build/build.sh | 1 + build/build_tf1.sh | 23 +++++++++++++++++++++++ build/build_tf2.sh | 2 -- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/build/build.sh b/build/build.sh index 5f77ec2c..b9544a58 100644 --- a/build/build.sh +++ b/build/build.sh @@ -97,6 +97,7 @@ fi if [ "$(uname -m)" = "aarch64" ] then echo "-----Build gen tar -----" + bash ${ROOT_DIR}/build/build_tf1.sh bash ${ROOT_DIR}/build/build_tf2.sh gen_tar_file echo "-----Build gen tar finished-----" diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 481c4429..6d7deabf 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -20,6 +20,13 @@ then deactivate tf1_env fi +if [ "$(uname -m)" = "aarch64" ] +then + source /opt/buildtools/tf1_env/bin/activate + tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core + deactivate tf1_env +fi + VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml get_version() { if [ -f "$VERSION_FILE" ]; then @@ -145,4 +152,20 @@ then gen_wheel_file "${ROOT_DIR}"/tf1_whl deactivate tf1_env echo "-----Build tf1 finished-----" +fi + +if [ "$(uname -m)" = "aarch64" ] +then + compile_securec + + echo "-----Build AccCTR -----" + compile_acc_ctr_so_file + + echo "-----Build Start tf1 -----" + source /opt/buildtools/tf1_env/bin/activate + compile_so_file "${tf1_path}" + collect_so_file + gen_wheel_file "${ROOT_DIR}"/tf1_whl + deactivate tf1_env + echo "-----Build tf1 finished-----" fi \ No newline at end of file diff --git a/build/build_tf2.sh b/build/build_tf2.sh index e9e7d05c..a42aeeb2 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -168,6 +168,4 @@ then deactivate tf2_env echo "-----Build tf2 finished -----" - gen_tar_file - echo "-----Build gen tar finished-----" fi \ No newline at end of file -- Gitee From eca2b080799fa4e98f8f5fc23d700a52cdff1fb0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 21 Aug 2023 16:38:55 +0800 Subject: [PATCH 269/551] Match-id-d804961cb7140fc9f3ed16470e1e7ef6fb19cfbb --- src/core/hd_transfer/hd_transfer.cpp | 23 +++++++++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 27 ++++++++++++++++++++++----- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 ++ src/core/key_process/key_process.cpp | 23 ++++++++++++++++------- 4 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 8cc92597..8b889486 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -36,14 +36,21 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) } aclDatasets[embInfo.name] = acltdtCreateDataset(); } - const char* timeoutEnv = getenv("AclTimeout"); - if (timeoutEnv != nullptr) { - int32_t timeoutEnvCast = static_cast(std::atoi(timeoutEnv)); - VLOG(GLOG_DEBUG) << StringFormat("timeoutEnv:%d", timeoutEnvCast); - if (timeoutEnvCast > INT32_MAX || timeoutEnvCast < -1) { - LOG(WARNING) << StringFormat("AclTimeout=%d is not valid", timeoutEnvCast); - } else { - timeout = timeoutEnvCast; + const int defaultAclTimeout = -1; + this->timeout = defaultAclTimeout; + const char *envTimeout = getenv("AclTimeout"); + if (envTimeout != nullptr) { + try { + int32_t tmp = std::stoi(envTimeout); + if (tmp >= -1 && tmp <= INT32_MAX) { + this->timeout = tmp; + LOG(INFO) << StringFormat("Succeed to parse ${env:AclTimeout}: %d", tmp); + } else { + LOG(ERROR) << StringFormat("Failed to parse 
${env:AclTimeout}: %d, expected in [-1, INT32_MAX]", tmp);
+            }
+        } catch (const std::invalid_argument &e) {
+            LOG(ERROR) << StringFormat("Failed to parse ${env:AclTimeout}: %s, expected an integer, set to default: %d",
+                                       envTimeout, defaultAclTimeout);
         }
     }
     VLOG(GLOG_DEBUG) << StringFormat("hd transfer timeout:%d", timeout);
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 6252cc8d..a7ffdb50 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -46,11 +46,10 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
         LOG(INFO) << StringFormat("config MAX_UNIQUE_THREAD_NUM:%d", num);
     }
 
-    if (getenv("FAST_UNIQUE") != nullptr) {
-        bool isFastUnique = std::atoi(getenv("FAST_UNIQUE"));
-        PerfConfig::fastUnique = isFastUnique;
-        LOG(INFO) << StringFormat("config FAST_UNIQUE:%d", PerfConfig::fastUnique);
-    }
+    const bool defaultFastUnique = false;
+    PerfConfig::fastUnique = defaultFastUnique;
+    const char* envFastUnique = getenv("FAST_UNIQUE");
+    HybridMgmt::CheckFastUnique(envFastUnique);
 
     preprocess = Singleton::GetInstance();
     preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed);
@@ -59,6 +58,24 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
     return true;
 }
 
+void HybridMgmt::CheckFastUnique(const char *envFastUnique)
+{
+    if (envFastUnique != nullptr) {
+        try {
+            int tmp = std::stoi(envFastUnique);
+            if (tmp == 0 || tmp == 1) {
+                PerfConfig::fastUnique = (tmp == 1);
+                LOG(INFO) << StringFormat("Succeed to parse ${env:FAST_UNIQUE}: %d.", PerfConfig::fastUnique);
+            } else {
+                LOG(ERROR) << StringFormat("Invalid ${env:FAST_UNIQUE}: %s, which should be 0 or 1.", envFastUnique);
+            }
+        } catch (const std::invalid_argument &e) {
+            LOG(ERROR) <<
+                StringFormat("Failed to parse ${env:FAST_UNIQUE}: %s, which should be an integer.", envFastUnique);
+        }
+    }
+}
+
 void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos)
 {
 #ifndef GTEST
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 70cd007c..d084122a 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -112,6 +112,8 @@
     private:
         bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
                             const vector& thresholdValues, int seed);
+
+        void CheckFastUnique(const char* envFastUnique);
 
         void InitRankInfo(RankInfo& rankInfo, const vector& embInfos);
 
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index b0e9cc75..0f404f81 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -29,13 +29,22 @@ inline vector Count2Start(const vector& count)
 
 void KeyProcess::SetupHotEmbUpdateStep()
 {
-    const char* env = getenv("HOT_EMB_UPDATE_STEP");
-    if (env == nullptr) {
-        hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT;
-    } else {
-        hotEmbUpdateStep = stoi(env);
-        if (hotEmbUpdateStep == 0) {
-            hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT;
+    const auto maxUpdateStep = 1000;
+    this->hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT;
+    const char *envUpdateStep = getenv("HOT_EMB_UPDATE_STEP");
+    if (envUpdateStep != nullptr) {
+        try {
+            int tmp = std::stoi(envUpdateStep);
+            if (tmp >= 1 && tmp <= maxUpdateStep) {
+                this->hotEmbUpdateStep = tmp;
+                LOG(INFO) << StringFormat("Succeed to parse ${env:HOT_EMB_UPDATE_STEP}: %d.", this->hotEmbUpdateStep);
+            } else {
+                LOG(ERROR) << StringFormat("${env:HOT_EMB_UPDATE_STEP}: %d should be 
in [1, 1000], set default: %d.", + tmp, HOT_EMB_UPDATE_STEP_DEFAULT); + } + } catch (const std::invalid_argument &e) { + LOG(ERROR) << StringFormat("Failed to parse ${env:HOT_EMB_UPDATE_STEP}: %s, set default: %d.", + envUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); } } } -- Gitee From 5bcc8f0c3721375b3245d66422deebd06bd8c58c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 21 Aug 2023 19:44:16 +0800 Subject: [PATCH 270/551] Match-id-342322cbf9026bd46b878c89b091176f3d9d2c7f --- mx_rec/core/embedding.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index a3c153c6..abbecfaf 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -23,8 +23,7 @@ from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPA from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ - get_host_pipeline_ops, get_use_dynamic_expansion, \ - set_modify_graph, insert_removing_var_list + get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set from mx_rec.validator.validator import ClassValidator, StringValidator @@ -456,10 +455,14 @@ class SparseEmbedding: self.check_mode(MxRecMode.ASC) is_training = kwargs.get("is_train") self.check_and_format_lookup_params(ids, send_count, is_training) - self.same_table_send_count += send_count if send_count is not None and is_training else 0 if is_asc_frozen() and is_training: raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.") + # record send count + eval_mode = not is_training and get_training_mode_channel_id(True) is None + if is_training or eval_mode or "train_and_evaluate" in get_bool_gauge_set(): + self.same_table_send_count += send_count if send_count is not None else 0 + # create feature spec feature_spec = get_feature_spec(self.table_name, kwargs.get("access_and_evict_config")) feature_spec.set_feat_attribute(ids, is_training) @@ -473,7 +476,6 @@ class SparseEmbedding: self.register_anchor_attribute(anchor_ids, feature_spec, kwargs) # record multi lookup info - eval_mode = not is_training and get_training_mode_channel_id(True) is None ids_lookup_name = feature_spec.name + "_lookup_ids" # set in train mode, train and eval mode, eval mode if is_training or eval_mode: -- Gitee From d26b3a6b79662cf801e37ffce6cea31da063dfec Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 21 Aug 2023 19:58:09 +0800 Subject: [PATCH 271/551] Match-id-d96f9db35e9190a191eb15d0f17573f33074aa02 --- src/core/key_process/key_process.cpp | 18 +----------------- src/core/utils/common.cpp | 22 ++++++++++++++++++++++ src/core/utils/common.h | 1 + src/ops_tf/hybrid_dataset_ops.cpp | 19 ++----------------- 4 files changed, 26 insertions(+), 34 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 0f404f81..af297fd4 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -124,24 +124,8 @@ int KeyProcess::Start() #endif KeyProcessTask(channel, id); }; // for clean code - int threadNum; + int threadNum = GetThreadNumEnv(); for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { - const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM"); - if 
(threadNumEnv != nullptr) { - try { - threadNum = std::stoi(threadNumEnv); - } catch (const std::invalid_argument& e) { - threadNum = KEY_PROCESS_THREAD; - LOG(WARNING) << StringFormat("error value of threadNum, use default KEY_PROCESS_THREAD: %d", - threadNum); - } - if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(StringFormat("%d is not valid", threadNum)); - } - } else { - threadNum = KEY_PROCESS_THREAD; - LOG(INFO) << StringFormat("use default KEY_PROCESS_THREAD: %d", threadNum); - } LOG(INFO) << StringFormat(KEY_PROCESS "key process thread num: %d", threadNum); for (int id = 0; id < threadNum; ++id) { procThreads.emplace_back( diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 5483b6eb..dd15e8b2 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -158,4 +158,26 @@ namespace MxRec { } return isCombine; } + + int GetThreadNumEnv() + { + int threadNum = 0; + const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM"); + if (threadNumEnv != nullptr) { + try { + threadNum = std::stoi(threadNumEnv); + } catch (const std::invalid_argument& e) { + threadNum = KEY_PROCESS_THREAD; + LOG(INFO) << StringFormat("error value of threadNum, use default KEY_PROCESS_THREAD: %d", + threadNum); + } + if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { + throw runtime_error(StringFormat("%d is not valid", threadNum)); + } + } else { + threadNum = KEY_PROCESS_THREAD; + LOG(INFO) << StringFormat("use default KEY_PROCESS_THREAD: %d", threadNum); + } + return threadNum; + } } // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index e1c7fc75..a38a797f 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -111,6 +111,7 @@ namespace MxRec { string GetChipName(int devID); bool GetCombineSwitch(); + int GetThreadNumEnv(); namespace UBSize { const int ASCEND910_PREMIUM_A = 262144; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 4891a126..06000536 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -153,16 +153,7 @@ public: } batchIdsInfo.at(channelId) = 0; - const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM"); - if (threadNumEnv != nullptr) { - threadNum = static_cast(*threadNumEnv) - static_cast('0'); - if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(StringFormat("%d is not valid", threadNum)); - } - } else { - threadNum = KEY_PROCESS_THREAD; - } - + threadNum = GetThreadNumEnv(); auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); @@ -367,13 +358,7 @@ public: } batchIdsInfo.at(channelId) = 0; - const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM"); - if (threadNumEnv != nullptr) { - threadNum = static_cast(*threadNumEnv) - static_cast('0'); - if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(StringFormat("%d is not valid", threadNum)); - } - } + threadNum = GetThreadNumEnv(); auto keyProcess = Singleton::GetInstance(); if (!keyProcess->isRunning) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); -- Gitee From 2be338b3eebe315705c33d96f4238d1d4c0793ac Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 22 Aug 2023 18:05:51 +0800 Subject: [PATCH 272/551] Match-id-1676d9e88c632b637b1da16bf8d4ecd86961090e --- src/core/key_process/key_process.cpp | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a67f13d6..6a16e3df 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -456,7 +456,7 @@ void KeyProcess::GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vect for (size_t i = 0; i < lookupKeys.size(); ++i) { int64_t key = lookupKeys[i]; - if (key == -1) { + if (rankInfo.useStatic && key == -1) { continue; } auto result = umap.find(key); -- Gitee From 68ac2a1ac54cbbf658c6dc3f7aeb9a05b800f9a8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 23 Aug 2023 03:48:43 +0800 Subject: [PATCH 273/551] Match-id-e34d14aa6948cf6722fa2283c44b54460438109e --- mx_rec/core/embedding.py | 2 ++ src/core/emb_hashmap/emb_hashmap.cpp | 5 +++- src/core/emb_hashmap/emb_hashmap.h | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 32 ++++++++++++++---------- src/core/key_process/key_process.cpp | 37 +++++++++++++++++++++++++--- src/core/key_process/key_process.h | 6 +++-- 6 files changed, 63 insertions(+), 21 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index a3c153c6..4e065ac1 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -691,6 +691,8 @@ class SparseEmbedding: def grad(lookup_diff): logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name) embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) + self.unique_id_offsets = unique_keys + self.unique_id_offsets_position = restore_vector_second unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector, unique_embeddings_shape[0]) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 5e3b3dbb..6f626cea 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -46,7 +46,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, - vector& tmpDataOut, int channelId) + vector& tmpDataOut, int channelId, vector& offsetsOut) { #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) @@ -69,6 +69,9 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t swapId++; EASY_BLOCK("hostHashMaps->tdt") +// std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), offsetsOut.begin()); + std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(offsetsOut)); + auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index e2216e43..7c4a2d52 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -24,7 +24,7 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); void Process(const string& embName, std::vector& keys, size_t iBatch, - vector& tmpDataOut, int channelId); + vector& tmpDataOut, int channelId, vector& offsetsOut); void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 0ba105ac..00761034 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -793,27 +793,33 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto 
infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); VLOG(GLOG_DEBUG) << StringFormat("getTensorsTC(ms):%d", getTensorsTC.ElapsedMS()); - if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - TimeCost sendUnikeysSyncTC; - hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embName); - infoVecs->pop_back(); - VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); - - TimeCost sendRestoreVecSecSyncTC; - hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embName); - infoVecs->pop_back(); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); - } - TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); vector tmpData; + vector offsetsOut; TimeCost hostHashMapProcessTC; - hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); + hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId, offsetsOut); + VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS()); + if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { + vector uniqueKeys; + vector restoreVecSec; + preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); + + TimeCost sendUnikeysSyncTC; + hdTransfer->Send(TransferChannel::UNIQKEYS, + { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys) }, + channelId, embName); + VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); + + TimeCost sendRestoreVecSecSyncTC; + hdTransfer->Send(TransferChannel::RESTORE_SECOND, { Vec2TensorI32(restoreVecSec) }, channelId, embName); + VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); + } + TimeCost sendTensorsTC; hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6a16e3df..eb7ce157 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -426,7 +426,10 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe tensors->push_back(Vec2TensorI32(hotPos)); } - PushGlobalUniqueTensors(tensors, lookupKeys, channel); + if (rankInfo.noDDR) { //HBM + PushGlobalUniqueTensors(tensors, lookupKeys, channel); + tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); + } PushResult(batch, move(tensors), lookupKeys); VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); @@ -442,13 +445,39 @@ void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tenso tensors->push_back(Vec2TensorI32(restoreVecSec)); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)); } +} - if (rankInfo.noDDR) { - tensors->push_back(rankInfo.useDynamicExpansion ? 
Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); + +void KeyProcess::GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec) +{ + absl::flat_hash_map umap; + restoreVecSec.resize(lookupKeys.size(), -1); + int32_t length = 0; + + for (size_t i = 0; i < lookupKeys.size(); ++i) { + int64_t key = lookupKeys[i]; + if (rankInfo.useStatic && key == -1) { + continue; + } + auto result = umap.find(key); + if (result == umap.end()) { + uniqueKeys.push_back(lookupKeys[i]); + umap[key] = length; + restoreVecSec[i] = length; + length++; + } else { + restoreVecSec[i] = result->second; + } + } + + if (rankInfo.useStatic) { + uniqueKeys.resize(lookupKeys.size(), -1); + } else { + restoreVecSec.erase(std::remove(restoreVecSec.begin(), restoreVecSec.end(), -1), restoreVecSec.end()); } } -void KeyProcess::GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec) +void KeyProcess::GlobalUnique(const vector& lookupKeys, vector& uniqueKeys, vector& restoreVecSec) { absl::flat_hash_map umap; restoreVecSec.resize(lookupKeys.size(), -1); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 2b573831..75a280e6 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -86,6 +86,10 @@ namespace MxRec { void SetupHotEmbUpdateStep(); + void GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec); + + void GlobalUnique(const vector& lookupKeys, vector& uniqueKeys, vector& restoreVecSec); + bool isRunning { false }; inline bool hasEmbName(const string &emb_name) @@ -132,8 +136,6 @@ namespace MxRec { void GetUniqueConfig(UniqueConf& uniqueConf); - void GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec); - void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, const unique_ptr & batch, UniquePtr& unique); -- Gitee From b96a9a2de3055bcb81b3338acadbebc6a7bb6a51 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 23 Aug 2023 15:32:44 +0800 Subject: [PATCH 274/551] Match-id-3cf3a4a068c99703da0e031a1413d64d4e16064f --- mx_rec/core/embedding.py | 20 ++++++++++++-------- mx_rec/graph/merge_lookup.py | 2 +- mx_rec/graph/modifier.py | 4 ++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index abbecfaf..16745e78 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -531,13 +531,16 @@ class SparseEmbedding: """ same_table_tensor_list = [] for feat_spec in same_table_feature_spec: - batch_tensor_dict = kwargs.get("batch") if not self.modify_graph else \ - kwargs.get("feature_spec_name_ids_dict") + feature_spec_tensor_dict = kwargs.get("batch") + modify_graph_tensor_dict = kwargs.get("feature_spec_name_ids_dict") + batch_tensor_dict = feature_spec_tensor_dict if not self.modify_graph else modify_graph_tensor_dict if batch_tensor_dict is None: - raise KeyError(f"The tensor dict of batch does not exist in kwargs, " - f"and modify graph is `{self.modify_graph}`.") - tensor = batch_tensor_dict.get(feat_spec.index_key) if not self.modify_graph else \ - batch_tensor_dict.get(feat_spec.name) + raise KeyError(f"The tensor dict of batch does not exist in kwargs, and modify graph " + f"is `{self.modify_graph}`.") + + feature_spec_tensor = batch_tensor_dict.get(feat_spec.index_key) + modify_graph_tensor = batch_tensor_dict.get(feat_spec.name) + tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor if tensor is None: 
tensor_key = feat_spec.index_key if not self.modify_graph else feat_spec.name raise KeyError(f"Key `{tensor_key}` does not exist in batch_tensor_dict.") @@ -683,8 +686,9 @@ class SparseEmbedding: if kwargs.get("multi_lookup"): lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) else: - tensor = kwargs.get("batch").get(feature_spec.index_key) \ - if not self.modify_graph else kwargs.get("ids") + feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key) + modify_graph_tensor = kwargs.get("ids") + tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor if tensor is None: raise KeyError(f"key or ids does not exist in batch, now modify graph is {self.modify_graph}.") dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0) diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index b0149587..11bc09b5 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -57,7 +57,7 @@ def do_merge_lookup(is_train: bool = True): feature_spec_name_ids_dict[feature_spec.name] = cutting_point if sub_cutting_points_dict.get(is_training) is None: sub_cutting_points_dict[is_training] = [] - sub_cutting_points_dict[is_training].append(cutting_point) + sub_cutting_points_dict.get(is_training).append(cutting_point) # merge or restore lookup sub_cutting_point_list = sub_cutting_points_dict.get(is_train) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index efb06af2..4d30a7e5 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -311,8 +311,8 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt target_op = get_dataset_op(get_next_op) except (ValueError, TypeError, RuntimeError) as err: logging.warning("The dataset op was not found, the error is `%s`. 
Start to traverse the operations.", err) - dataset_op_list = [op for op in tf.compat.v1.get_default_graph().get_operations() - if ANCHOR_DATASET_NAME in op.name] + graph = tf.compat.v1.get_default_graph() + dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] logging.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.", is_training, dataset_op_list) -- Gitee From e2f9319055aa9588594b39b85491f176c7d2dfb0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 23 Aug 2023 17:36:21 +0800 Subject: [PATCH 275/551] Match-id-60ba2ef791f2fea2182050fedb97fd00534b9f17 --- mx_rec/constants/constants.py | 1 + mx_rec/core/embedding.py | 12 ++++- src/core/emb_hashmap/emb_hashmap.cpp | 1 - src/core/key_process/key_process.cpp | 67 +++------------------------- src/core/key_process/key_process.h | 34 ++++++++++++-- 5 files changed, 49 insertions(+), 66 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 222862d9..66a3a606 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -11,6 +11,7 @@ ASCEND_CUTTING_POINT = "ASCEND_CUTTING_POINT" ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE" ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET" ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR = "ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR" +ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS = "ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS" ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT = "ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT" # dynamic shape identity ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX = "ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX" diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 4e065ac1..7abca307 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -709,7 +709,12 @@ class SparseEmbedding: raise ZeroDivisionError("Rank size cannot be zero.") from exp if use_dynamic_expansion: - update_grad = local_grad + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + update_grad = tf.compat.v1.unsorted_segment_sum(local_grad, + restore_vector_second, + array_ops.shape(unique_keys)[0]) + else: + update_grad = local_grad else: if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, @@ -738,6 +743,11 @@ class SparseEmbedding: if is_training and is_table_name_valid: tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) + if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, unique_keys) + else: + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) + logging.debug(f"feature spec mode, table_name: {self.table_name}, " f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}") diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 6f626cea..a13ae0f1 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -69,7 +69,6 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t swapId++; EASY_BLOCK("hostHashMaps->tdt") -// std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), offsetsOut.begin()); std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(offsetsOut)); auto 
lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index eb7ce157..5cb730d1 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -368,7 +368,11 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); } - PushGlobalUniqueTensors(move(tensors), uniqueInfo.all2AllInfo.keyRecv, channel); + if (rankInfo.noDDR) { + PushGlobalUniqueTensors(move(tensors), uniqueInfo.all2AllInfo.keyRecv, channel); + tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv) : + Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); + } TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); @@ -426,7 +430,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe tensors->push_back(Vec2TensorI32(hotPos)); } - if (rankInfo.noDDR) { //HBM + if (rankInfo.noDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); } @@ -447,65 +451,6 @@ void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tenso } } - -void KeyProcess::GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec) -{ - absl::flat_hash_map umap; - restoreVecSec.resize(lookupKeys.size(), -1); - int32_t length = 0; - - for (size_t i = 0; i < lookupKeys.size(); ++i) { - int64_t key = lookupKeys[i]; - if (rankInfo.useStatic && key == -1) { - continue; - } - auto result = umap.find(key); - if (result == umap.end()) { - uniqueKeys.push_back(lookupKeys[i]); - umap[key] = length; - restoreVecSec[i] = length; - length++; - } else { - restoreVecSec[i] = result->second; - } - } - - if (rankInfo.useStatic) { - uniqueKeys.resize(lookupKeys.size(), -1); - } else { - restoreVecSec.erase(std::remove(restoreVecSec.begin(), restoreVecSec.end(), -1), restoreVecSec.end()); - } -} - -void KeyProcess::GlobalUnique(const vector& lookupKeys, vector& uniqueKeys, vector& restoreVecSec) -{ - absl::flat_hash_map umap; - restoreVecSec.resize(lookupKeys.size(), -1); - int32_t length = 0; - - for (size_t i = 0; i < lookupKeys.size(); ++i) { - int64_t key = lookupKeys[i]; - if (rankInfo.useStatic && key == -1) { - continue; - } - auto result = umap.find(key); - if (result == umap.end()) { - uniqueKeys.push_back(lookupKeys[i]); - umap[key] = length; - restoreVecSec[i] = length; - length++; - } else { - restoreVecSec[i] = result->second; - } - } - - if (rankInfo.useStatic) { - uniqueKeys.resize(lookupKeys.size(), -1); - } else { - restoreVecSec.erase(std::remove(restoreVecSec.begin(), restoreVecSec.end(), -1), restoreVecSec.end()); - } -} - vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 75a280e6..23dcbc64 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -86,9 +86,37 @@ namespace MxRec { void SetupHotEmbUpdateStep(); - void GlobalUnique(const keys_t& lookupKeys, keys_t& uniqueKeys, vector& restoreVecSec); - - void GlobalUnique(const vector& lookupKeys, vector& uniqueKeys, vector& restoreVecSec); + template + void GlobalUnique(T& lookupKeys, T& uniqueKeys, vector& restoreVecSec) + { + absl::flat_hash_map umap; + restoreVecSec.resize(lookupKeys.size(), -1); + int32_t length = 0; + + 
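+            // Order-preserving de-dup: the first occurrence of each key is appended to
+            // uniqueKeys, and restoreVecSec maps every input position to its key's slot,
+            // so gradients of duplicated ids can be segment-summed later; padding keys
+            // (-1) are skipped when useStatic is set.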
From 4d973189ab68a43e58e618e6d95d7e027b275fbf Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 24 Aug 2023 15:00:41 +0800
Subject: [PATCH 276/551] Match-id-5829773ae06f4bb074ddf7116130da32ef0370f6

---
 mx_rec/core/embedding.py             | 11 +++--------
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 15 +++++----------
 2 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 7abca307..014f6990 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -608,9 +608,6 @@ class SparseEmbedding:
             is_train:
             name: not in use
             modify_graph: if True, the original graph will be modified before building a Session instance
-
-        Returns: Tensor for lookup result
-
         """
         logging.debug(f"Enter ASC Branch, looking up with FeatureSpec.")
         self.check_mode(MxRecMode.ASC)
@@ -691,8 +688,7 @@ class SparseEmbedding:
         def grad(lookup_diff):
             logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
             embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
-            self.unique_id_offsets = unique_keys
-            self.unique_id_offsets_position = restore_vector_second
+
             unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector,
                                                              unique_embeddings_shape[0])
@@ -704,14 +700,13 @@ class SparseEmbedding:
             local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static)
             if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE:
                 try:
                     local_grad = local_grad / get_rank_size()
                 except ZeroDivisionError as exp:
                     raise ZeroDivisionError("Rank size cannot be zero.") from exp
 
             if use_dynamic_expansion:
                 if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
-                    update_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
-                                                                    restore_vector_second,
+                    update_grad = tf.compat.v1.unsorted_segment_sum(local_grad, restore_vector_second,
                                                                     array_ops.shape(unique_keys)[0])
                 else:
                     update_grad = local_grad
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 00761034..b3a26728 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -779,8 +779,7 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId)
 bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId,
     int channelId, int iBatch, bool& remainBatchOut)
 {
-    TimeCost getAndSendTensorsTC;
-    TimeCost getTensorsTC;
+    TimeCost getAndSendTensorsTC, getTensorsTC;
     auto& embHashMap = hostHashMaps->embHashMaps.at(embName);
     if (iBatch == 0) {
         embHashMap.SetStartCount();
@@ -801,18 +800,15 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId,
     vector offsetsOut;
     TimeCost hostHashMapProcessTC;
     hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId, offsetsOut);
-    VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS());
 
     if
(PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - vector uniqueKeys; - vector restoreVecSec; + vector uniqueKeys, restoreVecSec; preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; - hdTransfer->Send(TransferChannel::UNIQKEYS, - { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys) }, - channelId, embName); + hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : + Vec2TensorI32(uniqueKeys) }, channelId, embName); VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); TimeCost sendRestoreVecSecSyncTC; @@ -836,8 +832,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch LOG(WARNING) << StringFormat( MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d", embName.c_str(), channelId, - batchId, iBatch, lookupKeys.size() - ); + batchId, iBatch, lookupKeys.size()); return false; } return true; -- Gitee From aab1d0ad17788c2a66bd4aea3e66b6fdcb5c56bc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 24 Aug 2023 15:22:31 +0800 Subject: [PATCH 277/551] Match-id-cb4665b5c15dfa67fe09b6ed789abf7d6a572baa --- src/core/emb_hashmap/emb_hashmap.cpp | 74 +++++++--- src/core/emb_hashmap/emb_hashmap.h | 5 + src/core/hd_transfer/hd_transfer.cpp | 39 +++-- src/core/hd_transfer/hd_transfer.h | 4 - src/core/host_emb/host_emb.cpp | 56 +++++--- src/core/host_emb/host_emb.h | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 207 +++++++++++++++++++++------ src/core/hybrid_mgmt/hybrid_mgmt.h | 5 - src/core/key_process/key_process.cpp | 22 ++- src/core/utils/common.h | 5 +- 10 files changed, 312 insertions(+), 109 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 5e3b3dbb..00f72d34 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -45,6 +45,12 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, #endif } +/// DDR模型下处理特征的offset、swap信息等 +/// \param embName 表名 +/// \param keys 查询向量 +/// \param iBatch 预取数据处理计数 +/// \param tmpDataOut 临时向量 +/// \param channelId 通道索引(训练/推理) void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, vector& tmpDataOut, int channelId) { @@ -55,10 +61,11 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t embHashMap.oldSwap.clear(); embHashMap.maxOffsetOld = embHashMap.maxOffset; - auto keepBatch = swapId - iBatch; + auto keepBatch = swapId - iBatch; // 处理batch的次数,多个预取一起处理算一次 bool findOffsetV2 = getenv("FIND_OFFSET_V2") != nullptr; VLOG(GLOG_DEBUG) << StringFormat("FindOffset version:%d", findOffsetV2); + // 找到所有key的偏移;dev和host需要交换的位置 if (findOffsetV2) { FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); } else { @@ -69,6 +76,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t swapId++; EASY_BLOCK("hostHashMaps->tdt") + // 构造查询向量tensor auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); @@ -79,6 +87,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat("lookupTensor, %s", VectorToString(embHashMap.lookUpVec).c_str()); } + + // 构造交换向量tensor auto swapSize = static_cast(embHashMap.swapPos.size()); 
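    // The lookup-vector packing above and the swap-vector packing below repeat
    // one pattern: size a flat DT_INT32 tensor, then copy the host vector into
    // its flat view. A minimal sketch of that step as a reusable lambda,
    // assuming only the TensorFlow C++ Tensor API (packI32 is an illustrative
    // name, not a helper that exists in this repo):
    auto packI32 = [](const std::vector<int32_t>& vec) {
        Tensor t(tensorflow::DT_INT32, { static_cast<int32_t>(vec.size()) });
        auto flat = t.flat<int32_t>();  // contiguous 1-D view of the buffer
        for (size_t i = 0; i < vec.size(); ++i) {
            flat(i) = vec[i];           // element-by-element host copy
        }
        return t;
    };
    // (With such a helper, a vector of int32 offsets could be packed and
    //  appended to tmpDataOut in a single call.)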
tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); @@ -92,10 +102,14 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat("swapTensor, %s", VectorToString(embHashMap.swapPos).c_str()); } + + // 清空本次记录的查询偏移和交换偏移 embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), embHashMap.maxOffset, embHashMap.devVocabSize, embHashMap.hostVocabSize); + + // 构造交换数量的tensor tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = tmpDataOut.back().flat(); swapLen(0) = swapSize; @@ -258,12 +272,16 @@ void EmbHashMap::LoadHashMap(emb_hash_mem_t& loadData) embHashMaps = std::move(loadData); } +/// 对HBM剩余空间和更新位置进行初始化 void EmbHashMapInfo::SetStartCount() { currentUpdatePosStart = currentUpdatePos; freeSize = devVocabSize; } +/// 判断HBM是否有剩余空间 +/// \param i 查询向量的大小 +/// \return bool EmbHashMapInfo::HasFree(size_t i) { return freeSize < i; @@ -314,11 +332,12 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& } } -// old version -/* - * 从embHashMaps获取key对应的位置,并更新devOffset2Batch - */ - +/// 从embHashMaps获取key对应的位置,构造查询向量;更新devOffset2Batch;记录dev与host需要交换的偏移 +/// \param embName 表名 +/// \param keys 查询向量 +/// \param currentBatchId 已处理的batch数 +/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 +/// \param channelId 通道索引(训练/推理) void EmbHashMap::FindOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId) { @@ -340,10 +359,12 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } if (offset < embHashMap.devVocabSize) { + // 偏移小于等于HBM容量:直接放入查询向量;更新偏移之前关联的key和当前关联的key embHashMap.lookUpVec.emplace_back(offset); embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset])); embHashMap.devOffset2Key[offset] = key; } else { + // 偏移大于HBM容量:记录在host emb上的偏移;找到需要交换的HBM偏移 embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); FindSwapPosOld(embName, key, offset, currentBatchId, keepBatchId); } @@ -357,6 +378,12 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } +/// 查找key对应的偏移;1. 已在hash map中,直接返回对应的offset;2. 开启淘汰的情况下,复用淘汰的位置;3. 
没有则新分配 +/// \param key 输入特征 +/// \param embHashMap hash map实例 +/// \param channelId 通道索引(训练/推理) +/// \param offset 未初始化变量,用于记录 +/// \return bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) { @@ -399,6 +426,11 @@ bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashM return true; } +/// 更新HBM中的key相应offset最近出现的batch步数,用于跟踪哪些offset是最近在使用的 +/// \param keys 查询向量 +/// \param currentBatchId 已处理的batch数 +/// \param keySize 查询向量长度 +/// \param embHashMap hash map实例 void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const { @@ -456,33 +488,43 @@ int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostO return newDevOffset; } -/* - * 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys - */ +/// 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys +/// \param embName 表名 +/// \param key 输入特征 +/// \param hostOffset 全局偏移 +/// \param currentBatchId 已处理的batch数 +/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 +/// \return 是否找到需要交换的位置 bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, size_t keepBatchId) { bool notFind = true; auto& embHashMap = embHashMaps.at(embName); while (notFind) { + // 找到本次预取之前的偏移(保证所有预取batch的key都在HBM中) if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); - embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); - embHashMap.lookUpVec.emplace_back(embHashMap.currentUpdatePos); - embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; + embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); // 记录需要被换出的HBM偏移 + embHashMap.lookUpVec.emplace_back(embHashMap.currentUpdatePos); // 交换的位置就是该key查询的偏移 + embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; // 更新key对应的HBM偏移 + // 记录HBM偏移之前的key embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos, embHashMap.devOffset2Key[embHashMap.currentUpdatePos]); auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos]; - embHashMap.oldSwap.emplace_back(oldKey, key); - embHashMap.hostHashMap[oldKey] = hostOffset; + embHashMap.oldSwap.emplace_back(oldKey, key); // 记录交换的两个key + embHashMap.hostHashMap[oldKey] = hostOffset; // 更新被替换的key的偏移 oldKey = key; notFind = false; } - embHashMap.currentUpdatePos++; - embHashMap.freeSize--; + embHashMap.currentUpdatePos++; // 查找位置+1 + embHashMap.freeSize--; // HBM可用空间-1 + + // 遍历完一遍整个HBM表后,从头开始遍历 if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { embHashMap.currentUpdatePos = 0; } + + // 已经找完了整个HBM空间,没有找到可用位置,表示HBM空间不足以放下整个batch(预取batch数)的key,无法正常执行训练,固运行时错误退出 if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { LOG(ERROR) << "devVocabSize is too small"; throw runtime_error("devVocabSize is too small"); diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index e2216e43..dd189723 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -66,6 +66,11 @@ namespace MxRec { bool FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, size_t keepBatchId); + std::vector& GetEvictPos(const string& embName) + { + return embHashMaps.at(embName).evictPos; + } + private: RankInfo rankInfo; int swapId { 0 }; diff --git a/src/core/hd_transfer/hd_transfer.cpp 
b/src/core/hd_transfer/hd_transfer.cpp index 8b889486..f8937dac 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -12,10 +12,15 @@ using namespace MxRec; using namespace std; +/// 1. acl初始化 2. 设置device 3. 为每张表创建数据传输通道 +/// \param embInfos 稀疏表元信息类的list +/// \param localRankId 设备逻辑ID +/// \return int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) { #ifndef GTEST LOG(INFO) << StringFormat(MGMT + "begin hd_transfer initialize, rank:%d", localRankId); + // 使用AscendCL接口开发应用时,必须先调用aclInit接口,否则可能会导致后续系统内部资源初始化出错,进而导致其它业务异常。 aclError retOk = aclInit(nullptr); LOG(INFO) << StringFormat(MGMT + "end aclInit, rank:%d", localRankId); if (retOk != ACL_SUCCESS) { @@ -23,6 +28,7 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) return false; } LOG(INFO) << StringFormat(MGMT + "start Set device, rank:%d", localRankId); + // 指定当前进程或线程中用于运算的Device,同时隐式创建默认Context auto ret = aclrtSetDevice(static_cast(localRankId)); if (ret != ACL_ERROR_NONE) { LOG(ERROR) << StringFormat("Set device failed, device_id:%d", localRankId); @@ -34,6 +40,7 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { CreateChannel(localRankId, embInfo.name, i); } + // 创建acltdtDataset类型的数据,对等一个Vector。同步接口。 aclDatasets[embInfo.name] = acltdtCreateDataset(); } const int defaultAclTimeout = -1; @@ -60,6 +67,7 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) return true; } +/// 删除所有通道和TDT dataset void HDTransfer::Destroy() { #ifndef GTEST @@ -79,6 +87,10 @@ void HDTransfer::Destroy() #endif } +/// 为每张表创建相应的数据传输通道(all2ll、restore、lookup等) +/// \param localRankId 设备逻辑ID +/// \param embName 表名 +/// \param channelNum 通道索引 void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum) { #ifndef GTEST @@ -124,6 +136,12 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName #endif } +/// 将tensor发送到channel +/// \param channel 通道实例 +/// \param tensors 待发送数据 +/// \param channelId 通道索引(训练/推理) +/// \param embName 表名 +/// \param batchId 已处理的batch数 void HDTransfer::Send(TransferChannel channel, const vector &tensors, int channelId, const string &embName, int batchId) { @@ -174,6 +192,11 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in #endif } +/// 接收从device发送过来的数据(D2H);使用tfa封装的接口 +/// \param channel 通道实例 +/// \param channelId 通道索引(训练/推理) +/// \param embName 表名 +/// \return vector HDTransfer::Recv(TransferChannel channel, int channelId, const string& embName) { EASY_FUNCTION() @@ -202,9 +225,13 @@ vector HDTransfer::Recv(TransferChannel channel, int channel } return tensors; #endif - return {}; } +/// 接收从device发送过来的数据(D2H), updateEmbV2函数使用;使用原生的aclTDT接口 +/// \param channel 通道实例 +/// \param channelId 通道索引(训练/推理) +/// \param embName 表名 +/// \return size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) { EASY_FUNCTION() @@ -226,14 +253,4 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& LOG(INFO) << StringFormat("hd transfer recv:%s cost:%dms", recvName.c_str(), tc.ElapsedMS()); return acltdtGetDatasetSize(aclDatasets[embName]); #endif - return 0; -} - -size_t HDTransfer::QueryChannelSize(const string& channelName) -{ - size_t size = -1; -#ifndef GTEST - acltdtQueryChannelSize(transferChannels[channelName], &size); -#endif - return size; } diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 
d0ddaa73..ac09e6e6 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -81,10 +81,6 @@ namespace MxRec { size_t RecvAcl(TransferChannel channel, int channelId, const string& embName); - size_t QueryChannelSize(const string& channelName); - - auto Vec2Tensor(const vector& tmpVec) const -> vector; - void Destroy(); private: diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index b9a82787..649fae66 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -16,7 +16,11 @@ using namespace MxRec; using namespace std; using namespace chrono; -bool HostEmb::Initialize(const vector& embInfos, int seed) +/// 初始化host emb +/// \param embInfos 表信息列表 +/// \param seed 随机种子 +/// \return +void HostEmb::Initialize(const vector& embInfos, int seed) { for (const auto& embInfo: embInfos) { HostEmbTable hostEmb; @@ -26,9 +30,14 @@ bool HostEmb::Initialize(const vector& embInfos, int seed) hostEmbs[embInfo.name] = move(hostEmb); LOG(INFO) << (HOSTEMB + "HostEmb Initialize End"); } - return true; } +/// 根据指定的初始化器对emb进行初始化 +/// \param initializeInfos emb初始化信息列表 +/// \param seed 随机种子 +/// \param vocabSize host表大小 +/// \param embeddingSize emb维度 +/// \param embData emb数据 void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, int embeddingSize, vector> &embData) { @@ -85,13 +94,8 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in #endif } -void HostEmb::LoadEmb(emb_mem_t& loadData) -{ -#ifndef GTEST - hostEmbs = std::move(loadData); -#endif -} - +/// 停止用于异步更新D2H emb的线程 +/// \param channelId 通道索引(训练/推理) void HostEmb::Join(int channelId) { TimeCost tc = TimeCost(); @@ -123,11 +127,12 @@ void HostEmb::Join(int channelId) } } -/* - * 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 - * missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 - */ #ifndef GTEST +/// 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 +/// missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 +/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 +/// \param channelId 通道索引(训练/推理) +/// \param embName 表名 void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) { LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb, channelId:%d, embName:%s", channelId, embName.c_str()); @@ -164,6 +169,10 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, EASY_END_BLOCK } +/// 用从device获取的数据更新host的emb(使用aclTDT原生接口) +/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 +/// \param channelId 通道索引(训练/推理) +/// \param embName 表名 void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) { LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmbV2, channelId:%d, embName:%s", channelId, embName.c_str()); @@ -218,10 +227,10 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI } } -/* - * 找到host侧需要发送的emb,通过hdTransfer发送给device。 - * missingKeysHostPos为host侧需要发送的emb的位置 - */ +/// 查找host侧需要发送给device的emb数据。 +/// \param missingKeysHostPos 当前batch在host上需要换出的偏移 +/// \param embName +/// \param h2dEmbOut void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& embName, vector& h2dEmbOut) { @@ -246,12 +255,17 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& "GetH2DEmb end, missingKeys count:%d cost:%dms", missingKeysHostPos.size(), tc.ElapsedMS()); } - +/// 获取hostEmbs的指针 +/// \return auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* { return &hostEmbs; } +/// 对指定offset的emb进行初始化 +/// \param initializeInfos 
emb初始化信息列表 +/// \param embData emb数据 +/// \param offset 偏移列表 void HostEmb::EmbPartGenerator(const vector &initializeInfos, vector> &embData, const vector& offset) { @@ -293,9 +307,9 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve } #endif -/* - * 利用initializer初始化emb淘汰的位置 - */ +/// 利用initializer初始化emb淘汰的位置 +/// \param embName 表名 +/// \param offset 淘汰的偏移列表 void HostEmb::EvictInitEmb(const string& embName, const vector& offset) { #ifndef GTEST diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index 9ffd9bd5..cfe9339a 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -28,9 +28,7 @@ namespace MxRec { ~HostEmb() {}; - bool Initialize(const vector& embInfos, int seed); - - void LoadEmb(absl::flat_hash_map& loadData); + void Initialize(const vector& embInfos, int seed); void Join(int channelId); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index a7ffdb50..b63f7744 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -12,16 +12,24 @@ using namespace MxRec; using namespace std; +/// 启动数据处理线程 +/// \param rankInfo 当前rank基本配置信息 +/// \param embInfos 表信息list +/// \param thresholdValues 准入淘汰相关配置 +/// \param seed 随机种子 +/// \return bool类型 启动成功/失败 bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed) { #ifndef GTEST + // 是否设置全局去重(相同key的梯度先累加),默认为false if (getenv("APPLY_GRADIENTS_STRATEGY") != nullptr) { bool strategy = (!strcmp(getenv("APPLY_GRADIENTS_STRATEGY"), SUM_SAME_ID)); PerfConfig::gradientStrategy = strategy; LOG(INFO) << StringFormat("config GRADIENTS_STRATEGY:%d", strategy); } + // 设置当前进程用于数据处理的线程数,默认为6,取值1-10;取值不在范围内,则数据处理线程启动失败退出 if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { @@ -34,6 +42,7 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& LOG(INFO) << StringFormat("config KEY_PROCESS_THREAD_NUM:%d", num); } + // 设置AccCTR去重线程数,默认为8,取值1-8;取值不在范围内,则数据处理线程启动失败退出 if (getenv("MAX_UNIQUE_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("MAX_UNIQUE_THREAD_NUM")); if (num < 1 || num > DEFAULT_MAX_UNIQUE_THREAD_NUM) { @@ -46,11 +55,13 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& LOG(INFO) << StringFormat("config MAX_UNIQUE_THREAD_NUM:%d", num); } + // 设置是否使用AccCTR库提供的去重、分桶功能,默认关闭 const int defaultFastUnique = false; PerfConfig::fastUnique = defaultFastUnique; const char* envFastUnique = getenv("FAST_UNIQUE"); HybridMgmt::CheckFastUnique(envFastUnique); + // 初始化数据处理类,配置相关信息,启动处理线程 preprocess = Singleton::GetInstance(); preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); preprocess->Start(); @@ -76,31 +87,45 @@ void HybridMgmt::CheckFastUnique(const char *envFastUnique) } } +/// Openmpi通信域进程数设置、计算所有表host特征数量总数、设置训练模式(HBM/DDR) +/// \param rankInfo +/// \param embInfos void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos) { #ifndef GTEST MPI_Comm_size(MPI_COMM_WORLD, &rankInfo.rankSize); rankInfo.localRankId = rankInfo.deviceId; + // 计算训练任务涉及的所有表在DDR中需要分配的key数量 size_t totHostVocabSize = 0; for (const auto& emb : embInfos) { totHostVocabSize += emb.hostVocabSize; } + + // 根据DDR的key数量,配置存储模式HBM/DDR if (totHostVocabSize == 0) { rankInfo.noDDR = true; } - rankInfo.useDataset = getenv("DATASET") != nullptr; #endif } +/// 处理进程初始化入口,由python侧调用 +/// \param rankInfo 当前rank基本配置信息 +/// \param 
embInfos 表信息list +/// \param seed 随机种子 +/// \param thresholdValues 准入淘汰相关配置 +/// \param ifLoad 是否断点续训 +/// \return bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, int seed, const vector& thresholdValues, bool ifLoad) { #ifndef GTEST + // 判断是否已经拉起特征处理线程(key process) if (isRunning) { return true; } + // 设置日志的级别,对日志格式进行配置 SetLog(rankInfo.rankId); InitRankInfo(rankInfo, embInfos); @@ -110,11 +135,12 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; - skipUpdate = getenv("SKIP_UPDATE") != nullptr; + // 进行acl资源初始化,设置当前训练进程的device,为每张表创建数据传输通道 hdTransfer = Singleton::GetInstance(); hdTransfer->Init(embInfos, rankInfo.deviceId); + // 启动数据处理线程 bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, seed); if (!rc) { return false; @@ -122,14 +148,17 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, isRunning = true; + // DDR模式,初始化hashmap和host emb if (!rankInfo.noDDR) { hostEmbs = make_unique(); hostHashMaps = make_unique(); hostEmbs->Initialize(embInfos, seed); hostHashMaps->Init(rankInfo, embInfos, ifLoad); } + + // 非断点续训模式,启动数据传输 isLoad = ifLoad; - if (!rankInfo.useDataset && !isLoad) { + if (!isLoad) { Start(); } @@ -139,30 +168,36 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, info.name.c_str(), info.devVocabSize, info.hostVocabSize, info.sendCount); } LOG(INFO) << StringFormat( - MGMT + "end initialize, useDataset:%d, noDDR:%d, maxStep:[%d, %d], rank:%d", - rankInfo.useDataset, rankInfo.noDDR, + MGMT + "end initialize, noDDR:%d, maxStep:[%d, %d], rank:%d", rankInfo.noDDR, rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); #endif return true; } +/// 保存模型 +/// \param savePath 保存路径 +/// \return bool HybridMgmt::Save(const string savePath) { #ifndef GTEST + // 数据处理线程上锁 preprocess->LoadSaveLock(); CkptData saveData; Checkpoint saveCkpt; if (!mgmtRankInfo.noDDR) { + // DDR模式保存host的emb表以及hashmap VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: ddr mode hashmap"); saveData.hostEmbs = hostEmbs->GetHostEmbs(); saveData.embHashMaps = hostHashMaps->GetHashMaps(); } else { + // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: no ddr mode hashmap"); saveData.maxOffset = preprocess->GetMaxOffset(); saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); } + // 保存特征准入淘汰相关的数据 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); @@ -171,16 +206,22 @@ bool HybridMgmt::Save(const string savePath) saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; } + // 执行保存操作 saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo); + // 数据处理线程释放锁 preprocess->LoadSaveUnlock(); #endif return true; } +/// 加载模型 +/// \param loadPath +/// \return bool HybridMgmt::Load(const string& loadPath) { #ifndef GTEST + // 数据处理线程上锁 preprocess->LoadSaveLock(); VLOG(GLOG_DEBUG) << (MGMT + "Start host side load process"); @@ -189,33 +230,44 @@ bool HybridMgmt::Load(const string& loadPath) Checkpoint loadCkpt; vector loadFeatures; if (!mgmtRankInfo.noDDR) { + // DDR模式加载的类型为host的emb表以及hashmap loadFeatures.push_back(CkptFeatureType::HOST_EMB); loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP); } else { + // HBM模式加载的类型为最大偏移(真正使用了多少vocab容量),特征到偏移的映射 loadFeatures.push_back(CkptFeatureType::MAX_OFFSET); 
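    // Save() and Load() both bracket their work with preprocess->LoadSaveLock()
    // and preprocess->LoadSaveUnlock(), and every early-return error path has
    // to remember the unlock. A hedged sketch of an RAII guard that would pair
    // the two calls by construction (LoadSaveGuard is illustrative only, not
    // part of this patch; KeyProcess and its lock methods are the ones defined
    // in key_process.cpp):
    class LoadSaveGuard {
    public:
        explicit LoadSaveGuard(KeyProcess* p) : proc(p) { proc->LoadSaveLock(); }
        ~LoadSaveGuard() { proc->LoadSaveUnlock(); }
        LoadSaveGuard(const LoadSaveGuard&) = delete;
        LoadSaveGuard& operator=(const LoadSaveGuard&) = delete;
    private:
        KeyProcess* proc;
    };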
loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP); } + // 添加特征准入淘汰相关的数据类型的加载 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); } - loadData.hostEmbs = hostEmbs->GetHostEmbs(); + loadData.hostEmbs = hostEmbs->GetHostEmbs(); // 获取已经初始化好的host emb + + // 执行加载操作 loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); + + // 检查DDR模式保存的模型和当前训练配置是否一致,不一致则退出 if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) { preprocess->LoadSaveUnlock(); return false; } if (!mgmtRankInfo.noDDR) { + // DDR模式 将加载的hash map进行赋值 VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: ddr mode hashmap"); hostHashMaps->LoadHashMap(loadData.embHashMaps); } else { + // HBM模式 将加载的最大偏移(真正使用了多少vocab容量)、特征到偏移的映射,进行赋值 VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: no ddr mode hashmap"); preprocess->LoadMaxOffset(loadData.maxOffset); preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap); } + + // 将加载的特征准入淘汰记录进行赋值 if (featAdmitNEvict.GetFunctionSwitch()) { VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); @@ -226,13 +278,17 @@ bool HybridMgmt::Load(const string& loadPath) preprocess->LoadSaveUnlock(); - if (!mgmtRankInfo.useDataset && isLoad) { + // 执行训练 + if (isLoad) { Start(); } #endif return true; } +/// 获取key对应的offset,python侧调用 +/// \param tableName 表名 +/// \return key_offset_map_t HybridMgmt::SendHostMap(const string tableName) { #ifndef GTEST @@ -258,6 +314,8 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) #endif } +/// 加载key对应的offset,python侧调用;启动数据处理线程 +/// \param ReceiveKeyOffsetMap void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) { #ifndef GTEST @@ -283,12 +341,17 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) } preprocess->LoadSaveUnlock(); - if (!mgmtRankInfo.useDataset && isLoad) { + if (isLoad) { Start(); } #endif } +/// 对加载的数据和训练配置进行一致性校验 +/// \param loadHostEmbs +/// \param setupHostEmbs +/// \param embTableCount +/// \return bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount) { bool loadDataMatches = { true }; @@ -333,6 +396,9 @@ bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEm return true; } +/// 对DDR模式保存的模型和训练配置进行一致性校验 +/// \param loadData +/// \return 是否一致 bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) { size_t embTableCount { 0 }; @@ -351,6 +417,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) return true; } +/// 根据HBM/DDR模式,启动数据处理线程 void HybridMgmt::Start() { #ifndef GTEST @@ -374,6 +441,7 @@ void HybridMgmt::Start() #endif } +/// 启动HBM模式数据处理线程 void HybridMgmt::InsertThreadForHBM() { #ifndef GTEST @@ -392,6 +460,8 @@ void HybridMgmt::InsertThreadForHBM() } #ifndef GTEST +/// 启动训练数据处理线程 +/// \param type 存储模式 void HybridMgmt::TaskForTrain(TaskType type) { bool isFirstIn = true; @@ -409,6 +479,8 @@ void HybridMgmt::TaskForTrain(TaskType type) } } +/// 启动推理数据处理线程 +/// \param type 存储模式 void HybridMgmt::TaskForEval(TaskType type) { bool isFirstIn = true; @@ -426,26 +498,27 @@ void HybridMgmt::TaskForEval(TaskType type) } } +/// 训练数据处理:数据处理状态正常,处理的batch数小于用户预设值或者设为-1时,会循环处理; +/// \param type 存储模式 +/// \return bool HybridMgmt::TrainTask(TaskType type) { - bool isContinue = false; + bool isContinue; + bool status; do { if (!isRunning) { return false; } - bool status = false; switch (type) { 
case TaskType::HBM: status = ParseKeysHBM(TRAIN_CHANNEL_ID, trainBatchId); - isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + isContinue = !EndBatch(trainBatchId, TRAIN_CHANNEL_ID); LOG(INFO) << StringFormat(MGMT + "ParseKeysHBMBatchId = %d", trainBatchId); break; case TaskType::DDR: status = ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); - isContinue = trainBatchId % mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1; + isContinue = !EndBatch(trainBatchId, TRAIN_CHANNEL_ID); LOG(INFO) << StringFormat(MGMT + "parseKeysBatchId = %d", trainBatchId); break; default: @@ -460,6 +533,9 @@ bool HybridMgmt::TrainTask(TaskType type) return true; } +/// 推理数据处理:数据处理状态正常,处理的batch数小于用户预设值或者设为-1时,会循环处理; +/// \param type 存储模式 +/// \return bool HybridMgmt::EvalTask(TaskType type) { int evalBatchId = 0; // 0-99, 0-99 @@ -485,33 +561,42 @@ bool HybridMgmt::EvalTask(TaskType type) if (!status) { return false; } - } while (evalBatchId % mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] != 0 || - mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1); + } while (!EndBatch(evalBatchId, EVAL_CHANNEL_ID)); return true; } +/// HBM模式下,发送key process线程已处理好的各类型向量到指定通道中 +/// \param channelId 通道索引(训练/推理) +/// \param batchId 已处理的batch数 +/// \return bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) { LOG(INFO) << StringFormat( MGMT + "start parse keys HBM, nBatch:%d , [%d]:%d", mgmtRankInfo.nBatch, channelId, batchId); + + // 循环处理每个表的数据 for (const auto& embInfo: mgmtEmbInfo) { TimeCost ParseKeysTC; // get TimeCost getTensorsSyncTC; + + // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { LOG(INFO) << StringFormat( MGMT + "ParseKeys infoVecs empty ! 
batchId:%d, channelId:%d", batchId, channelId); return false; } + + // 动态shape场景下,获取all2all向量(通信量矩阵) unique_ptr> all2all = nullptr; if (!mgmtRankInfo.useStatic) { all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); } VLOG(GLOG_DEBUG) << StringFormat("getTensorsSyncTC(ms):%d", getTensorsSyncTC.ElapsedMS()); - // send + // 动态shape场景下,发送all2all向量(通信量矩阵) TimeCost sendTensorsSyncTC; if (!mgmtRankInfo.useStatic) { TimeCost sendAll2AllScSyncTC; @@ -519,11 +604,13 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) VLOG(GLOG_DEBUG) << StringFormat("sendAll2AllScSyncTC(ms):%d", sendAll2AllScSyncTC.ElapsedMS()); } + // 发送查询向量 TimeCost sendLookupSyncTC; hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); VLOG(GLOG_DEBUG) << StringFormat("sendLookupSyncTC(ms):%d", sendLookupSyncTC.ElapsedMS()); + // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID) { TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name); @@ -536,6 +623,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); } + // 发送恢复向量 TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); @@ -549,11 +637,19 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) } #endif +/// 当前处理的batch是否是最后一个batch +/// \param batchId 已处理的batch数 +/// \param channelId 通道索引(训练/推理) +/// \return bool HybridMgmt::EndBatch(int batchId, int channelId) const { return (batchId % mgmtRankInfo.maxStep[channelId] == 0 && mgmtRankInfo.maxStep[channelId] != -1); } +/// DDR模式下,发送key process线程已处理好的各类型向量到指定通道中 +/// \param channelId 通道索引(训练/推理) +/// \param batchId 已处理的batch数 +/// \return bool HybridMgmt::ParseKeys(int channelId, int& batchId) { #ifndef GTEST @@ -562,13 +658,15 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; int start = batchId; - int iBatch = 0; + int iBatch = 0; // 预取数据处理计数 bool ifHashmapFree = true; - bool remainBatch = true; + bool remainBatch = true; // 是否从通道获取了数据 while (true) { LOG(INFO) << StringFormat(MGMT + "parse keys, [%d]:%d", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); + + // 通道数据已空 if (!remainBatch) { TimeCost embHdTrans1; EmbHDTransWrap(channelId, batchId, start, iBatch); @@ -594,23 +692,37 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) } #ifndef GTEST +/// 构造训练所需的各种向量数据 +/// \param embName 表名 +/// \param batchId 已处理的batch数 +/// \param channelId 通道索引(训练/推理) +/// \param iBatch 预取数据处理计数 +/// \param remainBatchOut 是否从通道获取了数据 +/// \return HBM是否还有剩余空间 bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut) { TimeCost getAndSendTensorsTC; TimeCost getTensorsTC; auto& embHashMap = hostHashMaps->embHashMaps.at(embName); + + // 进行新一批预取数据时,计数初始化 if (iBatch == 0) { embHashMap.SetStartCount(); } + + // 获取查询向量 auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); if (lookupKeys.empty()) { remainBatchOut = false; } + // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, 
ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { return false; } VLOG(GLOG_DEBUG) << StringFormat("getTensorsTC(ms):%d", getTensorsTC.ElapsedMS()); + // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embName); @@ -623,19 +735,24 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); } + // 发送恢复向量 TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); + // 计算查询向量;记录需要被换出的HBM偏移 vector tmpData; TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId); VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS()); + // 发送查询、换出向量 TimeCost sendTensorsTC; hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName); tmpData.erase(tmpData.begin()); hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName); + + // 动态shape场景下,获取与发送all2all向量(通信量矩阵) if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); @@ -655,7 +772,11 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, return true; } -// send h2d & recv d2h emb +/// 发送H2D和接收D2H向量 +/// \param channelId 通道索引(训练/推理) +/// \param batchId 已处理的batch数 +/// \param start +/// \param iBatch 预取数据处理计数 void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch) { if (iBatch == 0) { @@ -675,6 +796,9 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in } } +/// 发送H2D和接收D2H向量,并更新host emb +/// \param channelId 通道索引(训练/推理) +/// \param batchId 已处理的batch数 void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) { EASY_FUNCTION(profiler::colors::Blue) @@ -682,8 +806,9 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) VLOG(GLOG_DEBUG) << StringFormat(MGMT + "trans emb, batchId:%d, channelId:%d", batchId, channelId); TimeCost tr; TimeCost h2dTC; + // 发送host需要换出的emb for (const auto& embInfo: mgmtEmbInfo) { - auto& missingKeys = hostHashMaps->embHashMaps.at(embInfo.name).missingKeysHostPos; + auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); vector h2dEmb; hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); @@ -691,16 +816,15 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) VLOG(GLOG_DEBUG) << StringFormat("h2dTC(ms):%d", h2dTC.ElapsedMS()); TimeCost d2hTC; + // 接收device换出的emb,并更新到host上 for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); - if (!(skipUpdate && missingKeys.empty())) { - auto updateEmbV2 = getenv("UpdateEmb_V2"); - if (updateEmbV2 != nullptr and atoi(updateEmbV2) == 1) { - hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! - } else { - hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! 
- } - } // skip when skip update and empty missing keys + auto updateEmbV2 = getenv("UpdateEmb_V2"); + if (updateEmbV2 != nullptr and atoi(updateEmbV2) == 1) { + hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! + } else { + hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! + } hostHashMaps->ClearMissingKeys(embInfo.name); } VLOG(GLOG_DEBUG) << StringFormat("d2hTC(ms):%d", d2hTC.ElapsedMS()); @@ -709,23 +833,14 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) "EmbHDTrans TimeCost(ms):%d batchId: %d channelId:%d", tr.ElapsedMS(), batchId, channelId ); } - -void HybridMgmt::EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo) -{ - EASY_FUNCTION(profiler::colors::Blue) - EASY_VALUE("mgmtProcess", batchId) - LOG(INFO) << StringFormat(MGMT + "trans emb dummy, batchId:%d, channelId:%d", batchId, channelId); - auto transferName = TransferChannel::D2H; - auto d2hEmb = hdTransfer->Recv(transferName, channelId, embInfo.name)[0]; - hdTransfer->Send(TransferChannel::H2D, {}, channelId, embInfo.name); -} #endif -/* -* hook通过时间或者step数触发淘汰 -*/ + +/// hook通过时间或者step数触发淘汰 +/// \return bool HybridMgmt::Evict() { #ifndef GTEST + // 配置了淘汰选项,则触发 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { featAdmitNEvict.FeatureEvict(evictKeyMap); @@ -734,6 +849,8 @@ bool HybridMgmt::Evict() return false; } VLOG(GLOG_DEBUG) << StringFormat(MGMT + "evict triggered by hook, evict TableNum %d ", evictKeyMap.size()); + + // 表为空,淘汰触发失败 if (evictKeyMap.size() == 0) { LOG(WARNING) << (MGMT + "evict triggered by hook before dataset in injected"); return false; @@ -752,7 +869,9 @@ bool HybridMgmt::Evict() #endif } -// ddr模式淘汰->删除映射表、初始化host表、发送dev淘汰位置 +/// DDR模式下的淘汰:删除映射表、初始化host表、发送dev淘汰位置 +/// \param embName +/// \param keys void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { #ifndef GTEST @@ -765,7 +884,7 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) } // 初始化host侧的emb - auto& evictOffset = hostHashMaps->embHashMaps.at(embName).evictPos; + auto& evictOffset = hostHashMaps->GetEvictPos(embName); if (evictOffset.size() != 0) { VLOG(GLOG_DEBUG) << StringFormat( MGMT + "ddr mode, delete emb: [%s]! 
evict size on host:%d", embName.c_str(), evictOffset.size() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index d084122a..4c77d6d9 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -30,8 +30,6 @@ namespace MxRec { using namespace std; using namespace tensorflow; - constexpr int SEND_TENSOR_TYPE_NUM = 2; - enum class TaskType { HBM, DDR @@ -131,7 +129,6 @@ namespace MxRec { KeyProcess *preprocess; HDTransfer *hdTransfer; bool isRunning; - bool skipUpdate; bool isLoad { false }; void TaskForTrain(TaskType type); @@ -139,8 +136,6 @@ namespace MxRec { bool TrainTask(TaskType type); bool EvalTask(TaskType type); - void EmbHDTransDummy(int channelId, int batchId, const EmbInfo& embInfo); - bool EndBatch(int batchId, int channelId) const; void EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index af297fd4..14d420de 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -165,6 +165,8 @@ void KeyProcess::LoadMaxOffset(offset_mem_t& loadData) maxOffset = std::move(loadData); } +/// 加载每张表key到offset的映射 +/// \param loadData void KeyProcess::LoadKeyOffsetMap(key_offset_mem_t& loadData) { keyOffsetMap = std::move(loadData); @@ -182,6 +184,7 @@ void KeyProcess::Destroy() LOG(INFO) << StringFormat(KEY_PROCESS "rank %d destroy success.", rankInfo.rankId); } +/// 每个数据通道的所有数据处理线程上锁 void KeyProcess::LoadSaveLock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { @@ -191,6 +194,7 @@ void KeyProcess::LoadSaveLock() } } +/// 每个数据通道的所有数据处理线程释放锁 void KeyProcess::LoadSaveUnlock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { @@ -1164,10 +1168,15 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in return move(t); } -// DDR +/// DDR模式下,从list中获取查询tensor向量 +/// \param batch 已处理的batch数 +/// \param embName 表名 +/// \param channel 通道索引(训练/推理) +/// \return keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { TimeCost tc = TimeCost(); + // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector while (true) { if (!isRunning) { return {}; @@ -1191,11 +1200,18 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } } -// HBM +/// HBM模式下,从list中获取指定类型的tensor向量 +/// \param batch 已处理的batch数 +/// \param embName 表名 +/// \param channel 通道索引(训练/推理) +/// \param type 数据类型 +/// \return unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { TimeCost tc = TimeCost(); info_list_t* list; + + // 根据数据类型,选择对应的list switch (type) { case ProcessedInfo::ALL2ALL: list = &all2AllList; @@ -1206,6 +1222,8 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa default: throw std::invalid_argument("Invalid ProcessedInfo Type."); } + + // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空指针 while (true) { if (!isRunning) { return nullptr; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index a38a797f..9c9cc2f3 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -206,7 +206,6 @@ struct BatchTask { bool useHot {}; uint32_t option {}; int nBatch {}; - bool useDataset { false }; // deprecated bool noDDR { false }; bool useDynamicExpansion {false}; std::vector maxStep; @@ -458,7 +457,7 @@ struct BatchTask { }; struct EmbHashMapInfo { - absl::flat_hash_map hostHashMap; + absl::flat_hash_map hostHashMap; // 
key在HBM中的偏移 std::vector devOffset2Batch; // has -1 std::vector devOffset2Key; size_t currentUpdatePos; @@ -467,7 +466,7 @@ struct BatchTask { size_t devVocabSize; size_t freeSize; std::vector lookUpVec; - std::vector missingKeysHostPos; + std::vector missingKeysHostPos; // 用于记录当前batch在host上需要换出的偏移 std::vector swapPos; size_t maxOffset { 0 }; std::vector evictPos; -- Gitee From 1b6ae984495796b8c71e0edad4f377dc72e8c11f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 24 Aug 2023 16:08:51 +0800 Subject: [PATCH 278/551] Match-id-1e6a04010ce85c35ab956d0ccfe125c761cd42f8 --- src/core/CMakeLists.txt | 2 + src/core/checkpoint/checkpoint.cpp | 20 +- src/core/checkpoint/checkpoint.h | 2 - src/core/ssd_engine/file.cpp | 282 +++++++++++++++++++++++++++++ src/core/ssd_engine/file.h | 69 +++++++ src/core/utils/common.cpp | 19 ++ src/core/utils/common.h | 3 + src/tests/CMakeLists.txt | 2 + src/tests/ssd_engine/file_test.cpp | 133 ++++++++++++++ 9 files changed, 511 insertions(+), 21 deletions(-) create mode 100644 src/core/ssd_engine/file.cpp create mode 100644 src/core/ssd_engine/file.h create mode 100644 src/tests/ssd_engine/file_test.cpp diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 27fbc4be..725d1b14 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -21,6 +21,8 @@ else() include_directories(${PYTHON_PATH}/lib/python3.7/site-packages/tensorflow/include) endif() +link_libraries(stdc++fs) + file(GLOB_RECURSE MXREC_SRC ./*.cpp) add_library(ASC SHARED ${MXREC_SRC}) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 27d8783e..4349331f 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -7,7 +7,6 @@ #include #include -#include #include #include @@ -17,6 +16,7 @@ #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" #include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" #include "utils/time_cost.h" +#include "utils/common.h" #include "checkpoint.h" @@ -562,21 +562,3 @@ void Checkpoint::ReadDataset(CkptTransData& transData, readFile.read((char*)(transData.attribute.data()) + idx, readSize); } } - -void Checkpoint::ValidateReadFile(const string& dataDir, size_t datasetSize) -{ - // validate soft link - struct stat fileInfo; - if (lstat(dataDir.c_str(), &fileInfo) != -1) { - if (S_ISLNK(fileInfo.st_mode)) { - LOG(ERROR) << StringFormat("soft link %s should not in the path parameter", dataDir.c_str()); - throw invalid_argument(StringFormat("soft link should not be the path parameter")); - } - } - // validate file size - if (datasetSize <= FILE_MIN_SIZE || datasetSize > FILE_MAX_SIZE) { - LOG(ERROR) << StringFormat("the reading file size is invalid, " - "not in not in (%d,%d]", FILE_MIN_SIZE, FILE_MAX_SIZE); - throw invalid_argument(StringFormat("file size invalid")); - } -} diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 17c22f17..481a8380 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -97,8 +97,6 @@ namespace MxRec { void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, size_t idx); - - void ValidateReadFile(const string& dataDir, size_t datasetSize); }; } diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp new file mode 100644 index 00000000..ff1c67a6 --- /dev/null +++ b/src/core/ssd_engine/file.cpp @@ -0,0 
+1,282 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + */ +#include "file.h" + +#include + +#include "utils/common.h" + +using namespace MxRec; + +/// 创建新文件实例,包含元数据文件、数据文件 +/// \param fileID 文件ID +/// \param saveDir 保存文件夹的路径 +File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) +{ + VLOG(GLOG_DEBUG) << StringFormat("start init file, fileID:%llu", fileID); + + if (!fs::exists(fs::absolute(saveDir))) { + if (!fs::create_directories(fs::absolute(saveDir))) { + throw runtime_error("fail to create Save directory"); + } + } + + metaFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta.latest"); + dataFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".data.latest"); + localFileMeta.open(metaFilePath, ios::out | ios::trunc | ios::binary); + if (!localFileMeta.is_open()) { + throw runtime_error("fail to create meta file"); + } + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); + localFileData.open(dataFilePath, ios::out | ios::in | ios::trunc | ios::binary); + if (!localFileData.is_open()) { + throw runtime_error("fail to create data file"); + } + fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); + + VLOG(GLOG_DEBUG) << StringFormat("end init file, fileID:%llu", fileID); +} + +/// 创建文件实例并加载,从保存路径中读取元数据文件、数据文件 +/// \param fileID 文件ID +/// \param saveDir 保存文件夹的路径 +/// \param step 加载的步数 +File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir(saveDir) +{ + VLOG(GLOG_DEBUG) << StringFormat("start init file with load, fileID:%llu", fileID); + + fs::path metaFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta." + to_string(step)); + fs::path dataFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".data." + to_string(step)); + if (!fs::exists(metaFileToLoad)) { + throw invalid_argument("meta file not found while loading"); + } + if (!fs::exists(dataFileToLoad)) { + throw invalid_argument("data file not found while loading"); + } + + ValidateReadFile(metaFileToLoad, fs::file_size(metaFileToLoad)); + ValidateReadFile(dataFileToLoad, fs::file_size(dataFileToLoad)); + + metaFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta.latest"); + dataFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".data.latest"); + fs::remove(metaFilePath); + fs::remove(dataFilePath); + + if (!fs::copy_file(metaFileToLoad, metaFilePath)) { + throw runtime_error("fail to create latest meta file"); + } + if (!fs::copy_file(dataFileToLoad, dataFilePath)) { + throw runtime_error("fail to create latest data file"); + } + + localFileMeta.open(metaFilePath, ios::in | ios::binary); + if (!localFileMeta.is_open()) { + throw runtime_error("fail to Load latest meta file"); + } + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); + localFileData.open(dataFilePath, ios::out | ios::in | ios::binary); + if (!localFileData.is_open()) { + throw runtime_error("fail to Load latest data file"); + } + fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); + Load(); + + VLOG(GLOG_DEBUG) << StringFormat("end init file with load, fileID:%llu", fileID); +} + +File::~File() +{ + localFileMeta.close(); + localFileData.close(); + + // should call Save before exit program, temporary data will be removed. 
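+    // The ".latest" pair is only a working copy: Save(step) flushes and copies
+    // it to "<fileID>.meta.<step>" / "<fileID>.data.<step>", and the loading
+    // constructor copies a chosen snapshot back, so removing the working files
+    // here loses nothing that Save() has already persisted.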
+ fs::remove(metaFilePath); + fs::remove(dataFilePath); +} + +bool File::IsKeyExist(emb_key_t key) +{ + auto it = keyToOffset.find(key); + return !(it == keyToOffset.end()); +} + +void File::InsertEmbeddings(vector &keys, vector> &embeddings) +{ + if (keys.size() != embeddings.size()) { + throw invalid_argument("keys' length not equal to embeddings' length"); + } + + localFileData.seekp(lastWriteOffset); // always set pointer to buffer end in case reading happened before + + size_t dLen = keys.size(); + for (size_t i = 0; i < dLen; ++i) { + if (IsKeyExist(keys[i])) { + staleDataCnt++; + } + keyToOffset[keys[i]] = lastWriteOffset; + + uint64_t embSize = embeddings[i].size(); + localFileData.write(reinterpret_cast(&embSize), sizeof(embSize)); + localFileData.write(reinterpret_cast(embeddings[i].data()), + embeddings[i].size() * sizeof(float)); + + auto pos = localFileData.tellp(); + if (pos == -1) { + throw runtime_error("can't get file position pointer"); + } + lastWriteOffset = offset_t(pos); + } + dataCnt += dLen; +} + +vector> File::FetchEmbeddings(vector &keys) +{ + vector> ret; + for (emb_key_t k: keys) { + auto it = keyToOffset.find(k); + if (it == keyToOffset.end()) { + throw invalid_argument("key not exist"); + } + localFileData.seekg(it->second); // for fstream, this moves the file position pointer (both put and get) + if (localFileData.fail()) { + throw runtime_error("can't move file position pointer"); + } + + uint64_t embSize; + localFileData.read(reinterpret_cast(&embSize), sizeof(embSize)); + + vector tmp; + tmp.resize(embSize); + localFileData.read(reinterpret_cast(tmp.data()), tmp.size() * sizeof(float)); + ret.emplace_back(tmp); + } + return ret; +} + +void File::DeleteEmbedding(emb_key_t key) +{ + if (!IsKeyExist(key)) { + return; + } + keyToOffset.erase(key); + staleDataCnt += 1; +} + +void File::Save(int step) +{ + VLOG(GLOG_DEBUG) << StringFormat("start save file at step:%d, fileID:%llu", step, fileID); + + // write current meta into meta file + for (auto [key, offset]: keyToOffset) { + localFileMeta.write(reinterpret_cast(&key), sizeof(key)); + localFileMeta.write(reinterpret_cast(&offset), sizeof(offset)); + } + // flush not guarantee data already written into disk, must call close to force flush and wait + localFileMeta.flush(); + if (localFileMeta.fail()) { + throw runtime_error("fail to save latest meta"); + } + localFileMeta.close(); + + fs::path metaFileToSave = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta." + to_string(step)); + if (fs::exists(metaFileToSave)) { + throw invalid_argument("fail to save latest meta, file already exist"); + } + + VLOG(GLOG_DEBUG) << StringFormat("save latest meta file at step:%d, fileID:%llu", step, fileID); + if (!fs::copy_file(metaFilePath, metaFileToSave)) { + throw runtime_error("fail to Save latest meta"); + } + + // re-open new meta file for next saving + localFileMeta.open(metaFilePath, ios::out | ios::trunc | ios::binary); + if (!localFileMeta.is_open()) { + throw runtime_error("fail to re-open meta file"); + } + + // Save data + VLOG(GLOG_DEBUG) << StringFormat("save latest data file at step:%d", step); + localFileData.flush(); + if (localFileData.fail()) { + throw runtime_error("fail to Save data"); + } + localFileData.close(); + + fs::path dataFileToSave = fs::absolute(saveDir + "/" + to_string(fileID) + ".data." 
+ to_string(step)); + if (fs::exists(dataFileToSave)) { + throw invalid_argument("fail to save latest data, file already exist"); + } + if (!fs::copy_file(dataFilePath, dataFileToSave)) { + throw runtime_error("fail to Save latest data"); + } + + // re-open data file for other operation + localFileData.open(dataFilePath, ios::out | ios::in | ios::app | ios::binary); + if (!localFileData.is_open()) { + throw runtime_error("fail to re-open data file"); + } + + VLOG(GLOG_DEBUG) << StringFormat("end save file at step:%d, fileID:%llu", step, fileID); +} + +void File::Load() +{ + // file already validate and open in instantiation + VLOG(GLOG_DEBUG) << StringFormat("start reading meta file, fileID:%llu", fileID); + emb_key_t key; + offset_t offset; + do { + localFileMeta.read(reinterpret_cast(&key), keyDataLen); + if (!localFileMeta.eof() && localFileMeta.fail()) { + throw invalid_argument("file broken while reading key"); + } + + localFileMeta.read(reinterpret_cast(&offset), offsetDataLen); + if (!localFileMeta.eof() && localFileMeta.fail()) { + throw invalid_argument("file broken while reading offset"); + } + keyToOffset[key] = offset; + dataCnt += 1; + + // try reading one byte to see if pointer reach end + localFileMeta.get(); + if (localFileMeta.eof()) { + break; + } + localFileMeta.unget(); + } while (!localFileMeta.eof()); + localFileMeta.close(); + + // re-open new meta file for next saving + localFileMeta.open(metaFilePath, ios::out | ios::trunc | ios::binary); + if (!localFileMeta.is_open()) { + throw runtime_error("fail to re-open meta file"); + } + + VLOG(GLOG_DEBUG) << StringFormat("end reading meta file, fileID:%llu", fileID); +} + +vector File::GetKeys() +{ + vector ret; + for (auto item: keyToOffset) { + ret.push_back(item.first); + } + return ret; +} + +uint64_t File::GetDataCnt() +{ + return dataCnt; +} + +uint64_t File::GetFileID() +{ + return fileID; +} + +uint64_t File::GetStaleDataCnt() +{ + return staleDataCnt; +} diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h new file mode 100644 index 00000000..2ccf1246 --- /dev/null +++ b/src/core/ssd_engine/file.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h
new file mode 100644
index 00000000..2ccf1246
--- /dev/null
+++ b/src/core/ssd_engine/file.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#ifndef MXREC_FILE_H
+#define MXREC_FILE_H
+
+#include <cstdint>
+#include <string>
+#include <vector>
+#include <fstream>
+#include <unordered_map>
+#include <stdexcept>
+#include <experimental/filesystem>
+
+#include "utils/common.h"
+
+namespace MxRec {
+    using namespace std;
+    namespace fs = std::experimental::filesystem;
+
+    using offset_t = uint32_t;
+
+    class File {
+        static const uint64_t keyDataLen = sizeof(emb_key_t);
+        static const uint64_t offsetDataLen = sizeof(offset_t);
+
+    public:
+        File(uint64_t fileID, string &saveDir);
+
+        File(uint64_t fileID, string &saveDir, int step);  // initialize by loading the data of a specific step
+
+        ~File();
+
+        bool IsKeyExist(emb_key_t key);
+
+        void InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+
+        vector<vector<float>> FetchEmbeddings(vector<emb_key_t> &keys);
+
+        void DeleteEmbedding(emb_key_t key);
+
+        void Save(int step);
+
+        vector<emb_key_t> GetKeys();
+
+        uint64_t GetDataCnt();
+
+        uint64_t GetFileID();
+
+        uint64_t GetStaleDataCnt();
+
+    private:
+        uint64_t fileID;  // init by constructor
+        string saveDir;   // init by constructor
+        fs::path dataFilePath = "";
+        fs::path metaFilePath = "";
+        fstream localFileData{};
+        fstream localFileMeta{};
+
+        uint64_t dataCnt = 0;
+        uint64_t staleDataCnt = 0;
+        unordered_map<emb_key_t, offset_t> keyToOffset{};  // offset_t >> maxDataNumInFile * embDataSize
+        offset_t lastWriteOffset = 0;
+
+        void Load();
+    };
+}
+
+#endif // MXREC_FILE_H
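A quick back-of-envelope check (my arithmetic, not from the patch) on the comment that `offset_t` far exceeds `maxDataNumInFile * embDataSize`: with the defaults that appear later in table.h (10000 records per file) and a 240-float embedding, each record is 8 + 240 * 4 = 968 bytes, so a full file is about 9.7 MB, well under the 4 GiB ceiling of a 32-bit offset:

#include <cassert>
#include <cstdint>

int main()
{
    const uint64_t recordBytes = sizeof(uint64_t) + 240 * sizeof(float);  // 968
    const uint64_t fileBytes = 10000 * recordBytes;                       // ~9.7 MB
    assert(fileBytes < UINT32_MAX);  // uint32_t offsets are ample
    return 0;
}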
diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp
index dd15e8b2..74245cc1 100644
--- a/src/core/utils/common.cpp
+++ b/src/core/utils/common.cpp
@@ -180,4 +180,23 @@ namespace MxRec {
         }
         return threadNum;
     }
+
+    void ValidateReadFile(const string& dataDir, size_t datasetSize)
+    {
+        // validate soft link
+        struct stat fileInfo;
+        if (lstat(dataDir.c_str(), &fileInfo) != -1) {
+            if (S_ISLNK(fileInfo.st_mode)) {
+                LOG(ERROR) << StringFormat("soft link %s should not be in the path parameter", dataDir.c_str());
+                throw invalid_argument(StringFormat("soft link should not be the path parameter"));
+            }
+        }
+        // validate file size
+        if (datasetSize <= FILE_MIN_SIZE || datasetSize > FILE_MAX_SIZE) {
+            LOG(ERROR) << StringFormat("the reading file size is invalid, "
+                                       "not in range (%d,%d]", FILE_MIN_SIZE, FILE_MAX_SIZE);
+            throw invalid_argument(StringFormat("file size invalid"));
+        }
+    }
+
 } // end namespace MxRec
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 9c9cc2f3..5b02248a 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -62,6 +62,7 @@ namespace MxRec {
 
     // for GLOG
     extern int g_glogLevel;
+    extern string g_rankId;
     constexpr int GLOG_MAX_BUF_SIZE = 1024;
     constexpr int GLOG_TIME_WIDTH_2 = 2;
     constexpr int GLOG_TIME_WIDTH_6 = 6;
@@ -345,6 +346,8 @@ struct BatchTask {
         return ss.str();
     }
 
+    void ValidateReadFile(const string& dataDir, size_t datasetSize);
+
     inline void GenerateRandomValue(std::vector<float>& vecData, std::default_random_engine& generator,
                                     RandomInfo& randomInfo)
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 1b672523..302d7bc9 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -13,6 +13,8 @@ message("MXREC_TEST_SRC: " ${MXREC_TEST_SRC})
 
 set(CMAKE_CXX_FLAGS "--coverage")
 
+link_libraries(stdc++fs)
+
 add_executable(test_main ${MXREC_SRC} ${MXREC_TEST_SRC})
 
 if(NOT SECUREC_PATH)
diff --git a/src/tests/ssd_engine/file_test.cpp b/src/tests/ssd_engine/file_test.cpp
new file mode 100644
index 00000000..7eaac073
--- /dev/null
+++ b/src/tests/ssd_engine/file_test.cpp
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+
+#include <gtest/gtest.h>
+#include <mpi.h>
+
+#include "utils/common.h"
+#include "ssd_engine/file.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+
+TEST(File, CreateEmptyFile)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string savePath = g_rankId;
+    bool isExceptionThrown = false;
+    try {
+        auto f = make_shared<File>(0, savePath);
+    } catch (runtime_error &e) {
+        isExceptionThrown = true;
+        LOG(ERROR) << e.what();
+    }
+    ASSERT_EQ(isExceptionThrown, false);
+    fs::remove_all(savePath);
+}
+
+TEST(File, LoadFromFile)
+{
+    // prepare
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string savePath = g_rankId;
+    if (!fs::exists(fs::absolute(savePath))) {
+        if (!fs::create_directories(fs::absolute(savePath))) {
+            throw runtime_error("fail to create Save directory");
+        }
+    }
+
+    emb_key_t key = 0;
+    offset_t offset = 0;
+    vector<float> val = {1.0};
+
+    fstream localFileMeta;
+    localFileMeta.open(savePath + "/0.meta.0", ios::out | ios::trunc | ios::binary);
+    localFileMeta.write(reinterpret_cast<const char *>(&key), sizeof(key));
+    localFileMeta.write(reinterpret_cast<const char *>(&offset), sizeof(offset));
+    localFileMeta.flush();
+    if (localFileMeta.fail()) {
+        throw runtime_error("fail to prepare meta file");
+    }
+    localFileMeta.close();
+
+    fstream localFileData;
+    localFileData.open(savePath + "/0.data.0", ios::out | ios::trunc | ios::binary);
+    uint64_t embSize = val.size();
+    localFileData.write(reinterpret_cast<const char *>(&embSize), sizeof(embSize));
+    localFileData.write(reinterpret_cast<const char *>(val.data()), val.size() * sizeof(float));
+    localFileData.flush();
+    if (localFileData.fail()) {
+        throw runtime_error("fail to prepare data file");
+    }
+    localFileData.close();
+
+    // start test
+    bool isExceptionThrown = false;
+    try {
+        auto f = make_shared<File>(0, savePath, 0);
+    } catch (runtime_error &e) {
+        LOG(ERROR) << e.what();
+        isExceptionThrown = true;
+    }
+    ASSERT_EQ(isExceptionThrown, false);
+    fs::remove_all(savePath);
+}
+
+TEST(File, WriteAndRead)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string savePath = g_rankId;
+    auto f = make_shared<File>(0, savePath);
+
+    vector<emb_key_t> keys;
+    vector<vector<float>> embeddings;
+    for (emb_key_t k = 0; k < 10; k++) {
+        keys.emplace_back(k);
+        vector<float> emb = {static_cast<float>(k + 0.1), static_cast<float>(k + 0.2)};
+        embeddings.emplace_back(emb);
+    }
+
+    f->InsertEmbeddings(keys, embeddings);
+    auto ret = f->FetchEmbeddings(keys);
+    ASSERT_EQ(embeddings, ret);
+
+
+    f->DeleteEmbedding(0);
+    ASSERT_EQ(f->IsKeyExist(0), false);
+
+    fs::remove_all(savePath);
+}
+
+TEST(File, SaveAndLoad)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    int saveStep = 0;
+    string savePath = g_rankId;
+    auto fTmp = make_shared<File>(0, savePath);
+
+    vector<emb_key_t> key = {0};
+    vector<vector<float>> expect = {{1.0, 1.1}};
+    fTmp->InsertEmbeddings(key, expect);
+    fTmp->Save(saveStep);
+
+    auto fLoad = make_shared<File>(0, savePath, saveStep);
+    auto actual = fLoad->FetchEmbeddings(key);
+    ASSERT_EQ(expect, actual);
+
+    fs::remove_all(savePath);
+}
-- Gitee
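The `ValidateReadFile` helper added above rejects symlinked paths and out-of-range file sizes before any checkpoint is read. A standalone sketch of the same lstat-based symlink check, with illustrative names (not library API):

#include <sys/stat.h>
#include <string>

// Returns true when `path` is a symbolic link; a missing path returns
// false so the caller's subsequent open() reports the real error.
bool IsSymlink(const std::string &path)
{
    struct stat fileInfo {};
    if (lstat(path.c_str(), &fileInfo) == -1) {
        return false;
    }
    return S_ISLNK(fileInfo.st_mode);
}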
From 69fedc5462e4f9b25667bcfea7d113b8d149bb67 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 24 Aug 2023 16:49:57 +0800
Subject: [PATCH 279/551] Match-id-081469d442649adb3ed1a4b7e9374d27652dcfb8

---
 mx_rec/core/asc/merge_table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py
index 1d314e34..34e3d9ff 100644
--- a/mx_rec/core/asc/merge_table.py
+++ b/mx_rec/core/asc/merge_table.py
@@ -14,7 +14,7 @@ from mx_rec.util.initialize import get_enable_table_merge, export_table_instance
 
 def affirm(reach_op:List[Operation]) -> bool:
     for node in reach_op:
-        if node.type not in ("IdentityN", "Reshape"):
+        if node.type not in ("IdentityN", "Reshape", "Identity"):
             return False
     return True
 
-- Gitee

From 2caee34bb078e9055e020c2f80b0793bdd9cf5de Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 24 Aug 2023 17:32:15 +0800
Subject: [PATCH 280/551] Match-id-2a59c5aab1281baf5a32921c89f850a872eb3975

---
 mx_rec/core/embedding.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 014f6990..efb47985 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -17,8 +17,8 @@ from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
 from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute
 from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, MxRecMode, \
-    ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, \
+from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS,\
+    MxRecMode, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, \
     MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy
 from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \
     insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
@@ -655,8 +655,7 @@ class SparseEmbedding:
         if not use_dynamic_expansion:
             id_offsets_abs = tf.abs(id_offsets)
             local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
-            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings,
-                                                          feature_spec.access_threshold)
+            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings, feature_spec.access_threshold)
         else:
             local_embeddings = tf.identity(table, name="identity_local_emb")
@@ -689,8 +688,7 @@ class SparseEmbedding:
             logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
 
             embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
-            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
-                                                             restore_vector,
+            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector,
                                                              unique_embeddings_shape[0])
             bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args)
             if hot_pos is not None:
@@ -712,13 +710,11 @@ class SparseEmbedding:
                     update_grad = local_grad
             else:
                 if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
-                    unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
-                                                                          restore_vector_second,
+                    unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, restore_vector_second,
                                                                           array_ops.shape(unique_keys)[0])
                     update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_keys,
                                                     dense_shape=tf.shape(table))
                 else:
-
                     update_grad = ops.IndexedSlices(values=local_grad, indices=id_offsets,
                                                     dense_shape=tf.shape(table))
             return update_grad
@@ -736,7 +732,6 @@ class SparseEmbedding:
         is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \
                               ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name
         if is_training and is_table_name_valid:
-            tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets)
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings)
             if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
                 tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, unique_keys)
-- Gitee

From 7a7b0c3f687ddb536de12d63762537e37ce5ba10 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 10:42:41 +0800
Subject: [PATCH 281/551] Match-id-efb98e4d304255efe93e2f5f5c1a893e58052cb0

---
 src/core/ssd_engine/table.cpp       | 351 ++++++++++++++++++++++++++++
 src/core/ssd_engine/table.h         |  81 +++++++
 src/tests/ssd_engine/table_test.cpp | 129 ++++++++++
 3 files changed, 561 insertions(+)
 create mode 100644 src/core/ssd_engine/table.cpp
 create mode 100644 src/core/ssd_engine/table.h
 create mode 100644 src/tests/ssd_engine/table_test.cpp

diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp
new file mode 100644
index 00000000..9d5163b4
--- /dev/null
+++ b/src/core/ssd_engine/table.cpp
@@ -0,0 +1,351 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+
+#include "table.h"
+#include "utils/common.h"
+
+using namespace MxRec;
+
+/// Create a new table
+/// \param name table name
+/// \param savePaths storage paths of the table
+/// \param maxTableSize maximum capacity of the table, measured in number of keys
+/// \param compactThreshold compaction threshold; a file is cleaned once its share of stale data exceeds it
+Table::Table(const string &name, vector<string> &savePaths, uint64_t maxTableSize, double compactThreshold)
+    : name(name),
+      savePaths(savePaths),
+      maxTableSize(maxTableSize),
+      compactThreshold(compactThreshold)
+{
+    curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + g_rankId + "/" + name).string();
+    if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) {
+        throw runtime_error("fail to create table directory");
+    }
+    LOG(INFO) << StringFormat("create table:%s at path:%s", name.c_str(), curTablePath.c_str());
+}
+
+/// Load an existing table
+/// \param name table name
+/// \param saveDirs storage paths of the table
+/// \param maxTableSize maximum capacity of the table, measured in number of keys
+/// \param compactThreshold compaction threshold; a file is cleaned once its share of stale data exceeds it
+/// \param step training step to load
+Table::Table(const string &name, vector<string> &saveDirs, uint64_t maxTableSize, double compactThreshold, int step)
+    : name(name),
+      savePaths(saveDirs),
+      maxTableSize(maxTableSize),
+      compactThreshold(compactThreshold)
+{
+    bool isMetaFileFound = false;
+    for (const string &dirPath: saveDirs) {
+        auto metaFilePath = fs::absolute(
+            dirPath + "/" + g_rankId + "/" + name + "/" + name + ".meta" + "." + to_string(step)).string();
+        if (!fs::exists(metaFilePath)) {
+            continue;
+        }
+        Load(metaFilePath, step);
+        isMetaFileFound = true;
+        break;
+    }
+    if (!isMetaFileFound) {
+        throw invalid_argument("table meta file not found");
+    }
+
+    // always use the first path to save until it is full
+    curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + g_rankId + "/" + name).string();
+    LOG(INFO) << StringFormat("load table:%s done. try store at path:%s", name.c_str(), curTablePath.c_str());
+}
+
+bool Table::IsKeyExist(emb_key_t key)
+{
+    lock_guard<mutex> guard(rwLock);
+    auto it = keyToFile.find(key);
+    return !(it == keyToFile.end());
+}
+
+void Table::InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+{
+    lock_guard<mutex> guard(rwLock);
+    InsertEmbeddingsInner(keys, embeddings);
+}
+
+vector<vector<float>> Table::FetchEmbeddings(vector<emb_key_t> &keys)
+{
+    lock_guard<mutex> guard(rwLock);
+    return FetchEmbeddingsInner(keys);
+}
+
+
+void Table::DeleteEmbeddings(vector<emb_key_t> &keys)
+{
+    lock_guard<mutex> guard(rwLock);
+    DeleteEmbeddingsInner(keys);
+}
+
+void Table::Save(int step)
+{
+    LOG(INFO) << StringFormat("start save table:%s, at step:%d", name.c_str(), step);
+    Compact(true);
+
+    lock_guard<mutex> guard(rwLock);
+    auto metaFilePath = fs::absolute(curTablePath + "/" + name + ".meta" + "." + to_string(step));
+    if (fs::exists(metaFilePath)) {
+        throw invalid_argument("fail to save table meta, file already exist");
+    }
+
+    fstream metaFile;
+    metaFile.open(metaFilePath, ios::out | ios::trunc | ios::binary);
+    if (!metaFile.is_open()) {
+        throw runtime_error("fail to create table meta file");
+    }
+
+    // dump the table name
+    uint32_t nameSize = static_cast<uint32_t>(name.size());
+    metaFile.write(reinterpret_cast<const char *>(&nameSize), sizeof(nameSize));
+    metaFile.write(name.c_str(), nameSize);
+
+    // dump the file IDs
+    uint64_t fileCnt = fileSet.size();
+    metaFile.write(reinterpret_cast<const char *>(&fileCnt), sizeof(fileCnt));
+    for (const auto &f: fileSet) {
+        uint64_t fid = f->GetFileID();
+        metaFile.write(reinterpret_cast<const char *>(&fid), sizeof(fid));
+        f->Save(step);
+    }
+
+    metaFile.flush();
+    if (metaFile.fail()) {
+        throw runtime_error("fail to Save table meta file");
+    }
+
+    metaFile.close();
+    LOG(INFO) << StringFormat("end save table:%s, at step:%d", name.c_str(), step);
+}
+/// Load the data files listed in the metadata
+/// \param metaFile metadata file
+/// \param step training step to load
+void Table::LoadDataFileSet(const shared_ptr<fstream> &metaFile, int step)
+{
+    LOG(INFO) << StringFormat("table:%s, start load data file", name.c_str());
+    uint64_t fileCnt;
+    metaFile->read(reinterpret_cast<char *>(&fileCnt), sizeof(fileCnt));
+    uint64_t fileID;
+    uint64_t fidSize = sizeof(fileID);
+    for (uint64_t i = 0; i < fileCnt; ++i) {
+        metaFile->read(reinterpret_cast<char *>(&fileID), fidSize);
+        if (fileID > curMaxFileID) {
+            curMaxFileID = fileID;
+        }
+
+        bool isFileFound = false;
+        shared_ptr<File> tmp;
+        for (const string &p: savePaths) {
+            // try to find the data file in each path
+            string dataPath = p + "/" + g_rankId + "/" + name;
+            try {
+                tmp = make_shared<File>(fileID, dataPath, step);
+                fileSet.insert(tmp);
+                isFileFound = true;
+                break;
+            } catch (invalid_argument &e) {
+                // do nothing because the file may be in another path
+            }
+        }
+        if (!isFileFound) {
+            throw invalid_argument("data file not found");
+        }
+
+        auto keys = tmp->GetKeys();
+        totalKeyCnt += keys.size();
+        if (totalKeyCnt > maxTableSize) {
+            throw invalid_argument("table size too small, key quantity exceed while loading data");
+        }
+
+        for (emb_key_t k: keys) {
+            if (keyToFile.find(k) != keyToFile.end()) {
+                throw invalid_argument(
+                    "find duplicate key in files, compaction already done before saving, file may broken or modified");
+            }
+            keyToFile[k] = tmp;
+        }
+    }
+    curMaxFileID += 1;
+}
+
+
+void Table::Load(const string &metaFilePath, int step)
+{
+    ValidateReadFile(metaFilePath, fs::file_size(metaFilePath));
+
+    shared_ptr<fstream> metaFile = make_shared<fstream>();
+    metaFile->open(metaFilePath, ios::in | ios::binary);
+    LOG(INFO) << StringFormat("table:%s, load meta file from path:%s", name.c_str(), metaFilePath.c_str());
+    if (!metaFile->is_open()) {
+        throw invalid_argument("fail to open meta");
+    }
+
+    // Load the table name and validate it
+    uint32_t nameSize;
+    metaFile->read(reinterpret_cast<char *>(&nameSize), sizeof(nameSize));
+    char *tmpArr = new char[nameSize + 1];
+    metaFile->read(tmpArr, static_cast<streamsize>(nameSize));
+    tmpArr[nameSize] = '\0';
+    string tbNameInFile = tmpArr;
+    delete[] tmpArr;  // release the temporary name buffer
+    if (name != tbNameInFile) {
+        throw invalid_argument("table name not match");
+    }
+
+    // construct the file set
+    LoadDataFileSet(metaFile, step);
+    metaFile->close();
+    if (metaFile->fail()) {
+        throw runtime_error("fail to load table");
+    }
+    LOG(INFO) << StringFormat("table:%s, end load data file", name.c_str());
+}
+
+void Table::InsertEmbeddingsInner(vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+{
+    if (totalKeyCnt > maxTableSize) {
+        throw invalid_argument("table size too small, key quantity exceed while inserting data");
+    }
+
+    if (curFile == nullptr || curFile->GetDataCnt() >= maxDataNumInFile) {
+        // leave diskFreeSpaceThreshold of the space free on each disk
+        while (true) {
+            fs::space_info si = fs::space((curTablePath));
+            if ((double(si.free) / double(si.capacity)) > diskFreeSpaceThreshold) {
+                break;
+            }
+
+            curSavePathIdx += 1;
+            if (curSavePathIdx >= savePaths.size()) {
+                throw runtime_error("all disk's space not enough");
+            }
+            curTablePath = savePaths[curSavePathIdx];
+            LOG(INFO) << StringFormat(
+                "current data path's free space less than %f, try next path:%s",
+                diskFreeSpaceThreshold, curTablePath.c_str()
+            );
+        }
+
+        curFile = make_shared<File>(curMaxFileID, curTablePath);
+        fileSet.insert(curFile);
+        curMaxFileID++;
+    }
+
+    for (emb_key_t k: keys) {
+        auto it = keyToFile.find(k);
+        if (it != keyToFile.end()) {
+            it->second->DeleteEmbedding(k);
+            staleDataFileSet.insert(it->second);
+            totalKeyCnt -= 1;
+        }
+        keyToFile[k] = curFile;
+    }
+    curFile->InsertEmbeddings(keys, embeddings);
+    totalKeyCnt += keys.size();
+}
+
+vector<vector<float>> Table::FetchEmbeddingsInner(vector<emb_key_t> &keys)
+{
+    // build a mini batch for each file: the first element holds keys, the second holds indices
+    size_t dLen = keys.size();
+    unordered_map<shared_ptr<File>, shared_ptr<pair<vector<emb_key_t>, vector<size_t>>>> miniBatch;
+    for (size_t i = 0; i < dLen; ++i) {
+        auto it = keyToFile.find(keys[i]);
+        if (miniBatch[it->second] == nullptr) {
+            miniBatch[it->second] = make_shared<pair<vector<emb_key_t>, vector<size_t>>>();
+        }
+        miniBatch[it->second]->first.emplace_back(keys[i]);
+        miniBatch[it->second]->second.emplace_back(i);
+    }
+
+    // must convert the map to a list to perform a parallel query; omp cannot iterate a map
+    vector<tuple<shared_ptr<File>, vector<emb_key_t>, vector<size_t>>> queryList;
+    queryList.reserve(miniBatch.size());
+    for (auto [f, info]: miniBatch) {
+        queryList.emplace_back(f, info->first, info->second);
+    }
+
+    // read in parallel
+    vector<vector<float>> ret;
+    ret.resize(dLen);
+    size_t queryLen = queryList.size();
+#pragma omp parallel for num_threads(readThreadNum) default(none) shared(ret, queryLen, queryList)
+    for (size_t i = 0; i < queryLen; ++i) {
+        tuple<shared_ptr<File>, vector<emb_key_t>, vector<size_t>> item = queryList[i];
+        shared_ptr<File> f;
+        vector<emb_key_t> batchKeys;
+        vector<size_t> batchIdx;
+        tie(f, batchKeys, batchIdx) = item;
+        vector<vector<float>> batchRet = f->FetchEmbeddings(batchKeys);
+        size_t batchLen = batchRet.size();
+        for (size_t j = 0; j < batchLen; ++j) {
+            ret[batchIdx[j]] = batchRet[j];
+        }
+    }
+    return ret;
+}
+
+/// Compact the data: after valid data is moved to a new file, files containing stale data are deleted
+/// \param fullCompact whether to perform a full compaction
+void Table::Compact(bool fullCompact)
+{
+    lock_guard<mutex> guard(rwLock);
+
+    if (staleDataFileSet.empty()) {
+        return;
+    }
+
+    VLOG(GLOG_DEBUG) << StringFormat("table:%s, start compact", name.c_str());
+
+    vector<shared_ptr<File>> compactFileList;
+    for (const auto &f: staleDataFileSet) {
+        if (fullCompact) {
+            compactFileList.emplace_back(f);
+            continue;
+        }
+        if (double(f->GetDataCnt()) * compactThreshold < double(f->GetStaleDataCnt())) {
+            compactFileList.emplace_back(f);
+        }
+    }
+
+    // always move valid data to a new file to avoid repeated compaction
+    if (curFile != nullptr && curFile->GetStaleDataCnt() > 0) {
+        curFile = make_shared<File>(curMaxFileID, curTablePath);
+        fileSet.insert(curFile);
+        curMaxFileID++;
+    }
+
+    for (const auto &f: compactFileList) {
+        staleDataFileSet.erase(f);
+        fileSet.erase(f);
+        vector<emb_key_t> validKeys = f->GetKeys();
+        vector<vector<float>> validEmbs = f->FetchEmbeddings(validKeys);
+        InsertEmbeddingsInner(validKeys, validEmbs);
+    }
+    VLOG(GLOG_DEBUG) << StringFormat("table:%s, end compact", name.c_str());
+}
+
+uint64_t Table::GetTableAvailableSpace()
+{
+    lock_guard<mutex> guard(rwLock);
+    return maxTableSize - totalKeyCnt;
+}
+
+void Table::DeleteEmbeddingsInner(vector<emb_key_t> &keys)
+{
+    for (emb_key_t k: keys) {
+        auto it = keyToFile.find(k);
+        if (it != keyToFile.end()) {
+            it->second->DeleteEmbedding(k);
+            staleDataFileSet.insert(it->second);
+            keyToFile.erase(k);
+            totalKeyCnt -= 1;
+        }
+    }
+}
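`FetchEmbeddingsInner` groups the requested keys by owning file, reads each group in parallel, then scatters results back through the remembered positions so the caller's ordering is preserved. A minimal, self-contained sketch of the same pattern (names are illustrative; compile with -fopenmp, or the pragma is simply ignored):

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

int main()
{
    std::vector<int> owner = {0, 1, 0, 1, 2, 2};   // key -> owning "file" id
    std::vector<size_t> request = {5, 0, 3, 2};    // keys to fetch, in caller order

    // group keys per owner, remembering each key's slot in the request
    std::unordered_map<int, std::pair<std::vector<size_t>, std::vector<size_t>>> groups;
    for (size_t i = 0; i < request.size(); ++i) {
        auto &g = groups[owner[request[i]]];
        g.first.push_back(request[i]);
        g.second.push_back(i);
    }

    // OpenMP cannot iterate a map, so flatten to an indexable list first
    std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>> queryList;
    for (auto &kv : groups) {
        queryList.push_back(kv.second);
    }

    std::vector<float> ret(request.size());
#pragma omp parallel for
    for (std::int64_t q = 0; q < static_cast<std::int64_t>(queryList.size()); ++q) {
        for (size_t j = 0; j < queryList[q].first.size(); ++j) {
            // a real implementation would read the embedding from the file here
            ret[queryList[q].second[j]] = static_cast<float>(queryList[q].first[j]);
        }
    }

    for (float v : ret) {
        std::cout << v << " ";  // prints 5 0 3 2: caller order preserved
    }
    std::cout << "\n";
    return 0;
}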
diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h
new file mode 100644
index 00000000..de5327e5
--- /dev/null
+++ b/src/core/ssd_engine/table.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#ifndef MXREC_TABLE_H
+#define MXREC_TABLE_H
+
+#include <memory>
+#include <mutex>
+#include <set>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+#include "file.h"
+#include "utils/common.h"
+
+namespace MxRec {
+    using namespace std;
+
+    class Table {
+    public:
+        Table(const string &name, vector<string> &savePaths, uint64_t maxTableSize, double compactThreshold);
+
+        // initialize by loading the data of a specific step
+        Table(const string &name, vector<string> &saveDirs, uint64_t maxTableSize, double compactThreshold, int step);
+
+        bool IsKeyExist(emb_key_t key);
+
+        void InsertEmbeddings(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+
+        vector<vector<float>> FetchEmbeddings(vector<emb_key_t> &keys);
+
+        void DeleteEmbeddings(vector<emb_key_t> &keys);
+
+        void Save(int step);
+
+        uint64_t GetTableAvailableSpace();
+
+        void Compact(bool fullCompact);
+
+    private:
+        void Load(const string& metaFilePath, int step);
+
+        void InsertEmbeddingsInner(vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+
+        void DeleteEmbeddingsInner(vector<emb_key_t> &keys);
+
+        vector<vector<float>> FetchEmbeddingsInner(vector<emb_key_t> &keys);
+
+        void LoadDataFileSet(const shared_ptr<fstream>& metaFile, int step);
+
+        string name;               // init by constructor
+        vector<string> savePaths;  // init by constructor, supports Save and Load across multiple paths
+        uint64_t maxTableSize;     // init by constructor, maximum key-value volume
+        uint64_t totalKeyCnt = 0;
+        unordered_map<emb_key_t, shared_ptr<File>> keyToFile{};  // max mem cost 1.5G*2 for 100m keys
+        set<shared_ptr<File>> staleDataFileSet{};
+        string curTablePath = "";
+        uint32_t curSavePathIdx = 0;
+        set<shared_ptr<File>> fileSet{};
+        mutex rwLock{};
+        shared_ptr<File> curFile = nullptr;
+        uint64_t curMaxFileID = 0;  // no concurrent writing, always atomic increase
+
+        /* args for performance
+         * 2 read threads are optimal when:
+         *     embedding dimension=240, maxDataNumInFile=10000
+         *     fetching 1000000 keys at a time
+         *     QPS (embeddings fetched per second) reaches 109685
+         * when maxDataNumInFile=10000:
+         *     QPS (embeddings written per second) reaches 194212
+         */
+        int readThreadNum = 2;
+        uint32_t maxDataNumInFile = 10000;  // relaxed constraint for performance, needs tuning
+        double compactThreshold = 0.5;
+        double diskFreeSpaceThreshold = 0.05;
+    };
+}
+
+#endif // MXREC_TABLE_H
diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp
new file mode 100644
index 00000000..af6f58a1
--- /dev/null
+++ b/src/tests/ssd_engine/table_test.cpp
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#include <gtest/gtest.h>
+#include <mpi.h>
+
+#include "utils/common.h"
+#include "ssd_engine/table.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+
+TEST(Table, WriteAndReadAndDeleteAndCompact)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string tbName = "test";
+    vector<string> savePath = {g_rankId};
+    uint64_t maxTableSize = 1000000;
+    uint64_t embDim = 240;
+    double compactThreshold = 0.5;
+
+    // create
+    auto tb = make_shared<Table>(tbName, savePath, maxTableSize, compactThreshold);
+
+    // write
+    emb_key_t nData = 1000000;
+    emb_key_t batchSize = 10000;
+    vector<emb_key_t> allKeys;
+    vector<vector<float>> allEmbs;
+    vector<emb_key_t> batchKeys;
+    vector<vector<float>> batchEmbs;
+
+    chrono::milliseconds writeCost = 0ms;
+    for (emb_key_t k = 0; k < nData; k++) {
+        vector<float> emb;
+        emb.resize(embDim);
+        for (uint64_t i = 0; i < embDim; ++i) {
+            emb[i] = static_cast<float>(k + float(i) / float(10));
+        }
+        allKeys.emplace_back(k);
+        allEmbs.emplace_back(emb);
+        batchKeys.emplace_back(k);
+        batchEmbs.emplace_back(emb);
+        if ((k + 1) % batchSize == 0) {
+            auto start = chrono::high_resolution_clock::now();
+            tb->InsertEmbeddings(batchKeys, batchEmbs);
+            auto end = chrono::high_resolution_clock::now();
+            writeCost += chrono::duration_cast<chrono::milliseconds>(end - start);
+            batchKeys.clear();
+            batchEmbs.clear();
+        }
+    }
+
+    LOG(INFO) << "n data:" << nData << " ,batch size:" << batchSize << " ,write cost(ms):" << writeCost.count()
+              << " ,QPS:" << float(nData) * 1000 / writeCost.count();
+
+    // read
+    auto start = chrono::high_resolution_clock::now();
+    auto ret = tb->FetchEmbeddings(allKeys);
+    auto end = chrono::high_resolution_clock::now();
+    auto readCost = chrono::duration_cast<chrono::milliseconds>(end - start);
+    LOG(INFO) << "n data:" << nData << " ,batch size:" << batchSize << " ,read cost(ms):" << readCost.count()
+              << " ,QPS:" << float(nData) * 1000 / readCost.count();
+    ASSERT_EQ(allEmbs, ret);
+
+    // check space
+    auto availSpace = tb->GetTableAvailableSpace();
+    ASSERT_EQ(availSpace, maxTableSize - allKeys.size());
+
+    // delete
+    tb->DeleteEmbeddings(allKeys);
+    for (emb_key_t k: allKeys) {
+        ASSERT_EQ(tb->IsKeyExist(k), false);
+    }
+
+    // full compact: the old file is deleted, valid data moves to a new file
+    tb->Compact(true);
+    string oldDataFilePath = g_rankId + "/" + tbName + "/" + "0.data.latest";
+    string oldMetaFilePath = g_rankId + "/" + tbName + "/" + "0.meta.latest";
+    ASSERT_EQ(fs::exists(oldDataFilePath), false);
+    ASSERT_EQ(fs::exists(oldMetaFilePath), false);
+
+    for (string p: savePath) {
+        fs::remove_all(p);
+    }
+}
+
+TEST(Table, SaveAndLoad)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string tbName = "test";
+    vector<string> savePath = {g_rankId};
+    uint64_t maxTableSize = 100;
+    double compactThreshold = 0.5;
+    int saveStep = 0;
+
+    // create
+    auto tbSave = make_shared<Table>(tbName, savePath, maxTableSize, compactThreshold);
+
+    // write and save
+    emb_key_t nData = 10;
+    vector<emb_key_t> keys;
+    vector<vector<float>> embs;
+    for (emb_key_t k = 0; k < nData; k++) {
+        vector<float> emb = {static_cast<float>(k + 0.1), static_cast<float>(k + 0.2)};
+        keys.emplace_back(k);
+        embs.emplace_back(emb);
+    }
+    tbSave->InsertEmbeddings(keys, embs);
+    tbSave->Save(saveStep);
+
+    // load
+    auto tbLoad = make_shared<Table>(tbName, savePath, maxTableSize, compactThreshold, saveStep);
+    auto ret = tbLoad->FetchEmbeddings(keys);
+
+    ASSERT_EQ(embs, ret);
+
+    for (const string &p: savePath) {
+        fs::remove_all(p);
+    }
+}
\ No newline at end of file
-- Gitee
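The per-file decision inside `Table::Compact` is a simple stale-ratio predicate: a file is rewritten once its stale records exceed `compactThreshold` of its live record count (0.5 by default per table.h), or unconditionally on a full compaction before a save. Restated standalone, with names of my own choosing:

#include <cstdint>

bool ShouldCompact(uint64_t dataCnt, uint64_t staleCnt, double threshold, bool fullCompact)
{
    if (fullCompact) {
        return true;  // Save(step) always compacts everything first
    }
    // same comparison Table::Compact performs per stale file
    return static_cast<double>(dataCnt) * threshold < static_cast<double>(staleCnt);
}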
From 53bb59f19ccaceb843262e8057e811b5dc68a884 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 11:02:10 +0800
Subject: [PATCH 282/551] Match-id-6070461f6ef5b0da7c742f79686bd89746a3902a

---
 src/core/ssd_engine/ssd_engine.cpp   | 186 +++++++++++++++++++++
 src/core/ssd_engine/ssd_engine.h     |  56 ++++++
 src/tests/ssd_engine/engine_test.cpp | 135 +++++++++++++++
 3 files changed, 377 insertions(+)
 create mode 100644 src/core/ssd_engine/ssd_engine.cpp
 create mode 100644 src/core/ssd_engine/ssd_engine.h
 create mode 100644 src/tests/ssd_engine/engine_test.cpp

diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp
new file mode 100644
index 00000000..169e785b
--- /dev/null
+++ b/src/core/ssd_engine/ssd_engine.cpp
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#include "ssd_engine.h"
+
+using namespace MxRec;
+using namespace std;
+
+bool SSDEngine::IsTableExist(const string &tableName)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    return !(it == tableMap.end());
+}
+
+bool SSDEngine::IsKeyExist(const string &tableName, emb_key_t key)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+    return it->second->IsKeyExist(key);
+}
+
+void SSDEngine::CreateTable(const string &tableName, vector<string> savePaths, uint64_t maxTableSize)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    if (savePaths.empty()) {
+        throw invalid_argument("SSDEngine input savePaths is empty");
+    }
+    auto it = tableMap.find(tableName);
+    if (it != tableMap.end()) {
+        throw invalid_argument("table already exist");
+    }
+    tableMap[tableName] = make_shared<Table>(tableName, savePaths, maxTableSize, compactThreshold);
+}
+
+void SSDEngine::InsertEmbeddings(const string &tableName, vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+
+    if (keys.size() != embeddings.size()) {
+        throw invalid_argument("keys' length not equal to embeddings' length");
+    }
+
+    it->second->InsertEmbeddings(keys, embeddings);
+}
+
+void SSDEngine::DeleteEmbeddings(const string &tableName, vector<emb_key_t> &keys)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+
+    it->second->DeleteEmbeddings(keys);
+}
+
+int64_t SSDEngine::GetTableAvailableSpace(const string &tableName)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+
+    return it->second->GetTableAvailableSpace();
+}
+
+void SSDEngine::Save(int step)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    for (const auto &item: tableMap) {
+        item.second->Save(step);
+    }
+}
+
+void SSDEngine::Load(const string &tableName, vector<string> savePaths, uint64_t maxTableSize, int step)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it != tableMap.end()) {
+        throw invalid_argument("table already exist");
+    }
+
+    tableMap[tableName] = make_shared<Table>(tableName, savePaths, maxTableSize, compactThreshold, step);
+}
+void SSDEngine::Start()
+{
+    if (isRunning) {
+        return;
+    }
+    isRunning = true;
+    compactThread = make_shared<thread>([this] { CompactMonitor(); });
+}
+
+/// Compaction monitor: calls each table's compact interface whenever the check period elapses
+void SSDEngine::CompactMonitor()
+{
+    VLOG(GLOG_DEBUG) << "SSDEngine start CompactMonitor";
+    auto start = chrono::high_resolution_clock::now();
+    auto end = chrono::high_resolution_clock::now();
+    chrono::microseconds loopDuration = 100ms;
+    chrono::seconds duration;
+    while (isRunning) {
+        duration = chrono::duration_cast<chrono::seconds>(end - start);
+        if (duration >= compactPeriod) {
+            VLOG(GLOG_DEBUG) << "SSDEngine CompactMonitor start compact";
+            for (const auto &item: tableMap) {
+                item.second->Compact(false);
+            }
+            VLOG(GLOG_DEBUG) << "SSDEngine CompactMonitor end compact";
+            start = chrono::high_resolution_clock::now();
+        }
+        this_thread::sleep_for(loopDuration);
+        end = chrono::high_resolution_clock::now();
+    }
+    VLOG(GLOG_DEBUG) << "SSDEngine end CompactMonitor";
+}
+
+vector<vector<float>> SSDEngine::FetchEmbeddings(const string &tableName, vector<emb_key_t> &keys)
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    auto it = tableMap.find(tableName);
+    if (it == tableMap.end()) {
+        throw invalid_argument("table not found");
+    }
+
+    return it->second->FetchEmbeddings(keys);
+}
+
+void SSDEngine::Stop()
+{
+    if (!isRunning) {
+        throw invalid_argument("SSDEngine not running");
+    }
+    isRunning = false;
+    compactThread->join();
+    tableMap.clear();
+    compactThread = nullptr;
+
+    LOG(INFO) << "SSDEngine stop";
+}
+
+/// Set the file compaction period
+/// \param seconds compaction period
+void SSDEngine::SetCompactPeriod(chrono::seconds seconds)
+{
+    compactPeriod = seconds;
+}
+
+/// Set the file compaction threshold
+/// \param threshold stale-data ratio threshold
+void SSDEngine::SetCompactThreshold(double threshold)
+{
+    if (threshold >= 0 && threshold <= 1) {
+        compactThreshold = threshold;
+        return;
+    }
+    throw invalid_argument("compact threshold should be in range [0, 1]");
+}
\ No newline at end of file
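One thing worth noting about the monitor above: `isRunning` is a plain bool written by `Stop()` on one thread and polled by `CompactMonitor` on another, which is formally a data race. A minimal sketch (mine, not the patch's code) of the same periodic loop with the flag made a `std::atomic<bool>`:

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

int main()
{
    std::atomic<bool> running{true};

    std::thread monitor([&running] {
        auto last = std::chrono::steady_clock::now();
        const auto period = std::chrono::seconds(1);
        while (running.load()) {
            auto now = std::chrono::steady_clock::now();
            if (now - last >= period) {
                std::cout << "compact tick\n";  // Compact(false) would run here
                last = now;
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
    });

    std::this_thread::sleep_for(std::chrono::seconds(3));
    running.store(false);  // what Stop() does, made race-free
    monitor.join();
    return 0;
}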
diff --git a/src/core/ssd_engine/ssd_engine.h b/src/core/ssd_engine/ssd_engine.h
new file mode 100644
index 00000000..7a66c24f
--- /dev/null
+++ b/src/core/ssd_engine/ssd_engine.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+#include "table.h"
+
+#include <chrono>
+#include <map>
+#include <thread>
+
+#include "utils/common.h"
+
+
+namespace MxRec {
+
+    class SSDEngine {
+    public:
+        bool IsTableExist(const string &tableName);
+
+        bool IsKeyExist(const string &tableName, emb_key_t key);
+
+        void CreateTable(const string &tableName, vector<string> savePaths, uint64_t maxTableSize);
+
+        int64_t GetTableAvailableSpace(const string &tableName);
+
+        void InsertEmbeddings(const string &tableName, vector<emb_key_t> &keys, vector<vector<float>> &embeddings);
+
+        void DeleteEmbeddings(const string &tableName, vector<emb_key_t> &keys);
+
+        vector<vector<float>> FetchEmbeddings(const string &tableName, vector<emb_key_t> &keys);
+
+        void Save(int step);
+
+        void Load(const string &tableName, vector<string> savePaths, uint64_t maxTableSize, int step);
+
+        void Start();
+
+        void Stop();
+
+        void SetCompactPeriod(chrono::seconds seconds);
+
+        void SetCompactThreshold(double threshold);
+
+    private:
+        bool isRunning = false;
+
+        // leave 50% space for stale data to avoid modification in file
+        double compactThreshold = 0.5;
+        chrono::seconds compactPeriod = chrono::seconds(60);
+
+        map<string, shared_ptr<Table>> tableMap{};
+        shared_ptr<thread> compactThread = nullptr;
+
+        void CompactMonitor();
+    };
+}
+
diff --git a/src/tests/ssd_engine/engine_test.cpp b/src/tests/ssd_engine/engine_test.cpp
new file mode 100644
index 00000000..b32aa626
--- /dev/null
+++ b/src/tests/ssd_engine/engine_test.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#include <gtest/gtest.h>
+#include <mpi.h>
+
+#include "utils/common.h"
+#include "ssd_engine/ssd_engine.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+
+TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string tbName = "test";
+    vector<string> savePath = {g_rankId};
+    uint64_t maxTableSize = 100;
+    double compactThreshold = 0.5;
+    chrono::seconds compactPeriod = chrono::seconds(5);
+    int saveStep = 0;
+
+    // create and start
+    SSDEngine *eng = new SSDEngine();
+    eng->SetCompactThreshold(compactThreshold);
+    eng->SetCompactPeriod(compactPeriod);
+    eng->Start();
+    eng->CreateTable(tbName, savePath, maxTableSize);
+
+    // check table
+    ASSERT_EQ(eng->IsTableExist(tbName), true);
+
+    // write
+    vector<emb_key_t> keys;
+    vector<vector<float>> embeddings;
+    for (emb_key_t k = 0; k < 10; k++) {
+        keys.emplace_back(k);
+        vector<float> emb = {static_cast<float>(k + 0.1), static_cast<float>(k + 0.2)};
+        embeddings.emplace_back(emb);
+    }
+    eng->InsertEmbeddings(tbName, keys, embeddings);
+
+    // read data
+    auto ret = eng->FetchEmbeddings(tbName, keys);
+    ASSERT_EQ(embeddings, ret);
+
+    // check space
+    ASSERT_EQ(eng->GetTableAvailableSpace(tbName), maxTableSize - keys.size());
+
+    // delete and wait for auto compaction
+    vector<emb_key_t> deleteKeys = {0};
+    eng->DeleteEmbeddings(tbName, deleteKeys);
+    this_thread::sleep_for(compactPeriod);
+
+    // check space to see whether the stale data space was released
+    ASSERT_EQ(eng->GetTableAvailableSpace(tbName), maxTableSize - keys.size() + deleteKeys.size());
+
+    // save
+    eng->Save(saveStep);
+
+    eng->Stop();
+    delete eng;
+
+    // after saving, a full compaction has run and the old file is deleted
+    string oldDataFilePath = g_rankId + "/" + tbName + "/" + "0.data.latest";
+    string oldMetaFilePath = g_rankId + "/" + tbName + "/" + "0.meta.latest";
+    ASSERT_EQ(fs::exists(oldDataFilePath), false);
+    ASSERT_EQ(fs::exists(oldMetaFilePath), false);
+
+    // check that the saved data exists
+    string newDataFilePath = savePath.front() + "/" + g_rankId + "/" + tbName + "/" + "1.data." + to_string(saveStep);
+    string newMetaFilePath = savePath.front() + "/" + g_rankId + "/" + tbName + "/" + "1.meta." + to_string(saveStep);
+    string newTableMetaFilePath =
+        savePath.front() + "/" + g_rankId + "/" + tbName + "/" + tbName + ".meta." + to_string(saveStep);
+    ASSERT_EQ(fs::exists(newDataFilePath), true);
+    ASSERT_EQ(fs::exists(newMetaFilePath), true);
+    ASSERT_EQ(fs::exists(newTableMetaFilePath), true);
+
+    for (string p: savePath) {
+        fs::remove_all(p);
+    }
+}
+
+TEST(SSDEngine, LoadAndRead)
+{
+    int rankId;
+    MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+    g_rankId = to_string(rankId);
+
+    string tbName = "test";
+    vector<string> savePath = {g_rankId};
+    uint64_t maxTableSize = 100;
+    int saveStep = 0;
+
+    // create and start
+    shared_ptr<SSDEngine> engSave = make_shared<SSDEngine>();
+    chrono::seconds compactPeriod = chrono::seconds(5);
+    engSave->SetCompactPeriod(compactPeriod);
+    engSave->Start();
+    engSave->CreateTable(tbName, savePath, maxTableSize);
+
+    // write
+    vector<emb_key_t> keys;
+    vector<vector<float>> embeddings;
+    for (emb_key_t k = 0; k < 10; k++) {
+        keys.emplace_back(k);
+        vector<float> emb = {static_cast<float>(k + 0.1), static_cast<float>(k + 0.2)};
+        embeddings.emplace_back(emb);
+    }
+    engSave->InsertEmbeddings(tbName, keys, embeddings);
+
+    // save
+    engSave->Save(saveStep);
+    engSave->Stop();
+
+    // load
+    shared_ptr<SSDEngine> engLoad = make_shared<SSDEngine>();
+    engLoad->Start();
+    engLoad->Load(tbName, savePath, maxTableSize, saveStep);
+    for (emb_key_t k: keys) {
+        ASSERT_EQ(engLoad->IsKeyExist(tbName, k), true);
+    }
+    auto ret = engLoad->FetchEmbeddings(tbName, keys);
+    ASSERT_EQ(embeddings, ret);
+    engLoad->Stop();
+
+    for (string p: savePath) {
+        fs::remove_all(p);
+    }
+}
\ No newline at end of file
-- Gitee
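The tests above also pin down the checkpoint naming scheme: each data file dumps to `{fileID}.data.{step}` and `{fileID}.meta.{step}`, plus one table-level `{table}.meta.{step}`, all under `{savePath}/{rankId}/{table}/`. A tiny sketch that builds those paths (values are illustrative):

#include <cstdint>
#include <iostream>
#include <string>

int main()
{
    std::string root = "0";   // save path (the rank id in the tests)
    std::string rank = "0";   // g_rankId
    std::string table = "test";
    int step = 0;
    uint64_t fileId = 1;

    std::string base = root + "/" + rank + "/" + table + "/";
    std::cout << base + std::to_string(fileId) + ".data." + std::to_string(step) << "\n";
    std::cout << base + std::to_string(fileId) + ".meta." + std::to_string(step) << "\n";
    std::cout << base + table + ".meta." + std::to_string(step) << "\n";
    return 0;
}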
From 1466051b67b21dec4f1c8c1c43b88db5d7c29693 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 11:10:32 +0800
Subject: [PATCH 283/551] Match-id-d6d1677a8f80c4062b3bdb2c7f0da4d03d17122e

---
 src/core/ssd_engine/ssd_engine.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/core/ssd_engine/ssd_engine.h b/src/core/ssd_engine/ssd_engine.h
index 7a66c24f..d0d4ee59 100644
--- a/src/core/ssd_engine/ssd_engine.h
+++ b/src/core/ssd_engine/ssd_engine.h
@@ -1,6 +1,9 @@
 /*
  * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
  */
+#ifndef MXREC_ENGINE_H
+#define MXREC_ENGINE_H
+
 #include "table.h"
 
 #include <chrono>
@@ -54,3 +57,4 @@ namespace MxRec {
     };
 }
 
+#endif // MXREC_ENGINE_H
-- Gitee

From 398b23e12ffd9c3181a9670de7315510935254ec Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 11:12:17 +0800
Subject: [PATCH 284/551] Match-id-c1d4bca58edc3e4c941160cd74d924e2bc411bbb

---
 build/build_tf2.sh                   | 24 ++++++++++-----------
 mx_rec/core/asc/build_graph.py       |  9 +++++++-
 mx_rec/core/embedding.py             | 24 +++++++++++++--------
 mx_rec/graph/modifier.py             |  2 +-
 src/core/emb_hashmap/emb_hashmap.cpp | 16 +++++++-------
 src/core/emb_hashmap/emb_hashmap.h   |  3 ++-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 10 ++++-----
 src/core/hybrid_mgmt/hybrid_mgmt.h   |  1 -
 src/core/utils/common.h              | 31 ++++++++++++++++++----------
 9 files changed, 71 insertions(+), 49 deletions(-)

diff --git a/build/build_tf2.sh b/build/build_tf2.sh
index e9e7d05c..8f486ac5 100644
--- a/build/build_tf2.sh
+++ b/build/build_tf2.sh
@@ -12,13 +12,13 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
 ROOT_DIR=$(dirname "${SCRIPT_DIR}")
 cd "$SCRIPT_DIR"
 
-if [ "$(uname -m)" = "x86_64" ]
-then
-    virtualenv -p "$(which python3.7)" tf2_env
-    source /opt/buildtools/tf2_env/bin/activate
-    tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow
-    deactivate tf2_env
-fi
+#if [ "$(uname -m)" = "x86_64" ]
+#then
+#    virtualenv -p "$(which python3.7)" tf2_env
+#    source /opt/buildtools/tf2_env/bin/activate
+#    tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow
+#    deactivate tf2_env
+#fi
 
 if [ "$(uname -m)" = "aarch64" ]
 then
@@ -143,13 +143,13 @@
 then
     compile_acc_ctr_so_file
 
     echo "-----Build Start tf2 -----"
-    virtualenv -p "$(which python3.7)" tf2_env
-    source /opt/buildtools/tf2_env/bin/activate
+#    virtualenv -p "$(which python3.7)" tf2_env
+#    source /opt/buildtools/tf2_env/bin/activate
 
     compile_so_file "${tf2_path}"
     collect_so_file
     gen_wheel_file "${ROOT_DIR}"/tf2_whl
-    deactivate tf2_env
+#    deactivate tf2_env
 
     echo "-----Build tf2 finished -----"
 fi
@@ -161,12 +161,12 @@
 then
     compile_acc_ctr_so_file
 
     echo "-----Build Start tf2 -----"
-    source /opt/buildtools/tf2_env/bin/activate
+#    source /opt/buildtools/tf2_env/bin/activate
 
     compile_so_file "${tf2_path}"
     collect_so_file
     gen_wheel_file "${ROOT_DIR}"/tf2_whl
-    deactivate tf2_env
+#    deactivate tf2_env
 
     echo "-----Build tf2 finished -----"
     gen_tar_file
     echo "-----Build gen tar finished-----"
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
index 4b85317a..d2540082 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -102,11 +102,18 @@ def get_unique_keys(max_lookup_vec_size: int, config: dict) -> tf.Tensor:
     """
     logging.debug(f'Channel {config.get("table_name")}_uniquekeys_{config.get("channel_id")} was built for getnext')
     with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE):
+        if config.get("use_dynamic_expansion"):
+            unique_keys = npu_ops.gen_npu_ops.get_next(
+                output_types=[tf.int64],
+                output_shapes=[[max_lookup_vec_size]],
+                channel_name=f'{config.get("table_name")}_uniquekeys_{config.get("channel_id")}')[0]
+            return unique_keys
+
         unique_keys = npu_ops.gen_npu_ops.get_next(
             output_types=[tf.int32],
             output_shapes=[[max_lookup_vec_size]],
             channel_name=f'{config.get("table_name")}_uniquekeys_{config.get("channel_id")}')[0]
-    return unique_keys
+        return unique_keys
 
 
 def get_all2all_args(use_static: bool, config: dict) -> list:
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index efb47985..0c831585 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -688,7 +688,8 @@ class SparseEmbedding:
             logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
 
             embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
-            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, restore_vector,
+            unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
+                                                             restore_vector,
                                                              unique_embeddings_shape[0])
             bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args)
             if hot_pos is not None:
@@ -698,24 +699,28 @@ class SparseEmbedding:
             local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static)
             if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE:
                 try:
-                    local_grad = local_grad / get_rank_sizemin ()
+                    local_grad = local_grad / get_rank_size()
                 except ZeroDivisionError as exp:
                     raise ZeroDivisionError("Rank size cannot be zero.") from exp
 
             if use_dynamic_expansion:
                 if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
-                    update_grad = tf.compat.v1.unsorted_segment_sum(local_grad, restore_vector_second,
+                    update_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
+                                                                    restore_vector_second,
                                                                     array_ops.shape(unique_keys)[0])
                 else:
                     update_grad = local_grad
             else:
                 if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
-                    unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, restore_vector_second,
+                    unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
+                                                                          restore_vector_second,
                                                                           array_ops.shape(unique_keys)[0])
-                    update_grad = ops.IndexedSlices(values=unique_local_grad, indices=unique_keys,
+                    update_grad = ops.IndexedSlices(values=unique_local_grad,
+                                                    indices=unique_keys,
                                                     dense_shape=tf.shape(table))
                 else:
-                    update_grad = ops.IndexedSlices(values=local_grad, indices=id_offsets,
+                    update_grad = ops.IndexedSlices(values=local_grad,
+                                                    indices=id_offsets,
                                                     dense_shape=tf.shape(table))
             return update_grad
@@ -726,12 +731,11 @@ class SparseEmbedding:
             return sparse_forward(self.variable)
 
         local_embeddings = \
-            host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size,
-                                                          embedding_type=1)
+            host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size, embedding_type=1)
 
         is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \
                               ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name
-        if is_training and is_table_name_valid:
+        def add_to_collection():
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings)
             if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
                 tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, unique_keys)
@@ -740,6 +744,8 @@
             logging.debug(f"feature spec mode, table_name: {self.table_name}, "
                           f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}")
 
+        if is_training and is_table_name_valid:
+            add_to_collection()
         return sparse_forward(local_embeddings)
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 5eff82e8..3855162b 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -13,7 +13,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
-    ASCAnchorAttr, ASCEND_TIMESTAMP, ApplyGradientsStrategy
+    ASCAnchorAttr, ASCEND_TIMESTAMP
 from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
     terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \
     get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index a13ae0f1..1f8bbb4b 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -46,7 +46,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector<EmbInfo>& embInfos,
 }
 
 void EmbHashMap::Process(const string& embName, vector<emb_key_t>& keys, size_t iBatch,
-                         vector<Tensor>& tmpDataOut, int channelId, vector<int>& offsetsOut)
+                         DDRParam& ddrParam, int channelId)
 {
 #ifndef GTEST
     EASY_FUNCTION(profiler::colors::Pink)
@@ -69,12 +69,12 @@
     swapId++;
 
     EASY_BLOCK("hostHashMaps->tdt")
-    std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(offsetsOut));
+    std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut));
 
     auto lookUpVecSize = static_cast<int>(embHashMap.lookUpVec.size());
-    tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize }));
+    ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize }));
 
-    auto lookupTensorData = tmpDataOut.back().flat<int>();
+    auto lookupTensorData = ddrParam.tmpDataOut.back().flat<int>();
     for (int i = 0; i < lookUpVecSize; i++) {
         lookupTensorData(i) = static_cast<int>(embHashMap.lookUpVec[i]);
     }
@@ -82,9 +82,9 @@
         VLOG(GLOG_TRACE) << StringFormat("lookupTensor, %s", VectorToString(embHashMap.lookUpVec).c_str());
     }
     auto swapSize = static_cast<int>(embHashMap.swapPos.size());
-    tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize }));
+    ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize }));
 
-    auto swapTensorData = tmpDataOut.back().flat<int>();
+    auto swapTensorData = ddrParam.tmpDataOut.back().flat<int>();
     for (int i = 0; i < swapSize; i++) {
         swapTensorData(i) = static_cast<int>(embHashMap.swapPos[i]);
     }
@@ -98,8 +98,8 @@
     embHashMap.lookUpVec.clear();
     LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(),
                               embHashMap.maxOffset, embHashMap.devVocabSize, embHashMap.hostVocabSize);
-    tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 }));
-    auto swapLen = tmpDataOut.back().flat<int>();
+    ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 }));
+    auto swapLen = ddrParam.tmpDataOut.back().flat<int>();
     swapLen(0) = swapSize;
     EASY_END_BLOCK
 #endif
diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h
index 7c4a2d52..f9407415 100644
--- a/src/core/emb_hashmap/emb_hashmap.h
+++ b/src/core/emb_hashmap/emb_hashmap.h
@@ -13,6 +13,7 @@
 #include <vector>
 #include "absl/container/flat_hash_map.h"
 #include "host_emb/host_emb.h"
+#include "utils/common.h"
 
 namespace MxRec {
     using namespace std;
@@ -24,7 +25,7 @@
         void Init(const RankInfo& rankInfo, const vector<EmbInfo>& embInfos, bool ifLoad = false);
 
         void Process(const string& embName, std::vector<emb_key_t>& keys, size_t iBatch,
-                     vector<Tensor>& tmpDataOut, int channelId, vector<int>& offsetsOut);
+                     DDRParam& ddrParam, int channelId);
 
         void FindAndUpdateOffset(const string& embName, vector<emb_key_t>& keys, size_t currentBatchId,
                                  size_t keepBatchId, int channelId);
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index b3a26728..64e7e030 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -8,7 +8,6 @@
 
 #include "checkpoint/checkpoint.h"
 #include "utils/time_cost.h"
-#include "utils/common.h"
 
 using namespace MxRec;
 using namespace std;
@@ -799,7 +798,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId,
     vector<Tensor> tmpData;
     vector<int> offsetsOut;
     TimeCost hostHashMapProcessTC;
-    hostHashMaps->Process(embName, lookupKeys, iBatch, tmpData, channelId, offsetsOut);
+    DDRParam ddrParam(tmpData, offsetsOut);
+    hostHashMaps->Process(embName, lookupKeys, iBatch, ddrParam, channelId);
     VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS());
 
     if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) {
@@ -817,9 +817,9 @@
     }
 
     TimeCost sendTensorsTC;
-    hdTransfer->Send(TransferChannel::LOOKUP, { tmpData.front() }, channelId, embName);
-    tmpData.erase(tmpData.begin());
-    hdTransfer->Send(TransferChannel::SWAP, tmpData, channelId, embName);
+    hdTransfer->Send(TransferChannel::LOOKUP, { ddrParam.tmpDataOut.front() }, channelId, embName);
+    ddrParam.tmpDataOut.erase(ddrParam.tmpDataOut.begin());
+    hdTransfer->Send(TransferChannel::SWAP, ddrParam.tmpDataOut, channelId, embName);
     if (!mgmtRankInfo.useStatic) {
         auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL);
         hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName);
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 03ff37b2..c4ad2f55 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -14,7 +14,6 @@
 #include <string>
 #include <vector>
 #include "absl/container/flat_hash_map.h"
-#include "utils/common.h"
 #include "utils/singleton.h"
 #include "utils/task_queue.h"
 #include "hd_transfer/hd_transfer.h"
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index e1c7fc75..7d38b064 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -173,21 +173,30 @@
         time_t timestamp { -1 };
     };
 
-struct BatchTask {
-    vector<int> splits;
-    vector<string> embNames;
-    size_t batchSize;
-    int batchQueueId;
-    int batchId;
-    int channelId;
-    time_t timestamp { -1 };
-    bool flag;  // true int64, false int32
-    const void *tensor;
-};
+    struct BatchTask {
+        vector<int> splits;
+        vector<string> embNames;
+        size_t batchSize;
+        int batchQueueId;
+        int batchId;
+        int channelId;
+        time_t timestamp { -1 };
+        bool flag;  // true int64, false int32
+        const void *tensor;
+    };
 
     using emb_batch_t = Batch;
     using batch_task_t = BatchTask;
 
+    struct DDRParam {
+        vector<Tensor> tmpDataOut;
+        vector<int> offsetsOut;
+        DDRParam(vector<Tensor> tmpData, vector<int> offset)
+        {
+            tmpDataOut = tmpData;
+            offsetsOut = offset;
+        }
+    };
+
     struct RankInfo {
         RankInfo() = default;
-- Gitee
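The `DDRParam` introduced above bundles the lookup tensors and offsets into one parameter so `Process` takes two arguments instead of four. One design note: the constructor copy-assigns both vectors; a sketch of a move-based variant that avoids the copies (generic element types stand in for the real Tensor type, so this compiles standalone):

#include <utility>
#include <vector>

using TensorLike = std::vector<float>;  // hypothetical stand-in for tensorflow::Tensor

struct DDRParamSketch {
    std::vector<TensorLike> tmpDataOut;
    std::vector<int> offsetsOut;
    // take by value and move: callers that pass temporaries pay no copy
    DDRParamSketch(std::vector<TensorLike> tmpData, std::vector<int> offset)
        : tmpDataOut(std::move(tmpData)), offsetsOut(std::move(offset)) {}
};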
From 9033b64d88c7d2b80aa2a0ad06cfa7bdab8b600e Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 11:30:54 +0800
Subject: [PATCH 285/551] Match-id-6980490f0c6f373fd659a46faf7679bac32d36e2

---
 build/build_tf2.sh                   | 22 +++++++++++-----------
 src/core/emb_hashmap/emb_hashmap.cpp |  4 ++--
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/build/build_tf2.sh b/build/build_tf2.sh
index 8f486ac5..593628f3 100644
--- a/build/build_tf2.sh
+++ b/build/build_tf2.sh
@@ -12,13 +12,13 @@ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
 ROOT_DIR=$(dirname "${SCRIPT_DIR}")
 cd "$SCRIPT_DIR"
 
-#if [ "$(uname -m)" = "x86_64" ]
-#then
-#    virtualenv -p "$(which python3.7)" tf2_env
-#    source /opt/buildtools/tf2_env/bin/activate
-#    tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow
-#    deactivate tf2_env
-#fi
+if [ "$(uname -m)" = "x86_64" ]
+then
+    virtualenv -p "$(which python3.7)" tf2_env
+    source /opt/buildtools/tf2_env/bin/activate
+    tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow
+    deactivate tf2_env
+fi
 
@@ -143,8 +143,8 @@
 then
     compile_acc_ctr_so_file
 
     echo "-----Build Start tf2 -----"
-#    virtualenv -p "$(which python3.7)" tf2_env
-#    source /opt/buildtools/tf2_env/bin/activate
+    virtualenv -p "$(which python3.7)" tf2_env
+    source /opt/buildtools/tf2_env/bin/activate
 
     compile_so_file "${tf2_path}"
     collect_so_file
     gen_wheel_file "${ROOT_DIR}"/tf2_whl
@@ -161,12 +161,12 @@
 then
     compile_acc_ctr_so_file
 
     echo "-----Build Start tf2 -----"
-#    source /opt/buildtools/tf2_env/bin/activate
+    source /opt/buildtools/tf2_env/bin/activate
 
     compile_so_file "${tf2_path}"
     collect_so_file
     gen_wheel_file "${ROOT_DIR}"/tf2_whl
-#    deactivate tf2_env
+    deactivate tf2_env
 
     echo "-----Build tf2 finished -----"
     gen_tar_file
     echo "-----Build gen tar finished-----"
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index 1f8bbb4b..eca09417 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -64,10 +64,10 @@
     } else {
        FindOffset(embName, keys, swapId, keepBatch, channelId);
     }
-    VLOG(GLOG_DEBUG) << "FindOffset end";
+//    VLOG(GLOG_DEBUG) << "FindOffset end";
     swapId++;
 
-    EASY_BLOCK("hostHashMaps->tdt")
+//    EASY_BLOCK("hostHashMaps->tdt")
 
     std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut));
-- Gitee

From 5e7246ed7991d2fee5fbc3295ae957ae9e9fdcbd Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 14:13:53 +0800
Subject: [PATCH 286/551] Match-id-c2613262be0eff38e5ba86bc65f718cc9c97527c

---
 mx_rec/graph/modifier.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 5b64f6c2..9c6e5a7c 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -14,8 +14,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
-    ASCAnchorAttr, ASCEND_TIMESTAMP
-    ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME
+    ASCAnchorAttr, ASCEND_TIMESTAMP,ANCHOR_DATASET_NAME
 from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
     terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \
     get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \
-- Gitee

From 7fd1fde16f5a01228e64b44973563a6b1396db96 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 14:53:01 +0800
Subject: [PATCH 287/551] Match-id-f201b5a05e1102b8403cabbb461d7f6cfb722f53

---
 mx_rec/core/embedding.py             | 10 ++++++----
 mx_rec/graph/modifier.py             |  2 +-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 11 ++++-------
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index e4c3e64c..de96e018 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -17,9 +17,9 @@ from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
 from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute
 from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS,\
-    MxRecMode, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, \
-    MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy
+from mx_rec.constants.constants import (ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET,\
+    ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, MxRecMode, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \
+    MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy)
 from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \
     insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
@@ -660,7 +660,8 @@
         if not use_dynamic_expansion:
             id_offsets_abs = tf.abs(id_offsets)
             local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
-            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings, feature_spec.access_threshold)
+            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings,
+                                                          feature_spec.access_threshold)
         else:
             local_embeddings = tf.identity(table, name="identity_local_emb")
@@ -741,6 +742,7 @@
 
         is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \
                               ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name
+
         def add_to_collection():
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings)
             if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 9c6e5a7c..4d30a7e5 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -14,7 +14,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
-    ASCAnchorAttr, ASCEND_TIMESTAMP,ANCHOR_DATASET_NAME
+    ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME
 from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
     terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \
     get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index d0cc9551..5c84308e 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -712,9 +712,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId,
 
     // fetch the lookup keys
     auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId);
-    if (lookupKeys.empty()) {
-        remainBatchOut = false;
-    }
+    if (lookupKeys.empty()) { remainBatchOut = false; }
 
     // fetch the processed vectors; if any is a null pointer, exit the current function
     auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE);
@@ -728,8 +726,8 @@
     // compute the lookup vector; record the HBM offsets that need to be swapped out
     vector<Tensor> tmpData;
     vector<int> offsetsOut;
-    TimeCost hostHashMapProcessTC;
     DDRParam ddrParam(tmpData, offsetsOut);
+    TimeCost hostHashMapProcessTC;
     hostHashMaps->Process(embName, lookupKeys, iBatch, ddrParam, channelId);
     VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS());
@@ -761,9 +759,8 @@
             "getAndSendTensorsTC(ms):%d, channelId:%d", getAndSendTensorsTC.ElapsedMS(), channelId);
 
         if (embHashMap.HasFree(lookupKeys.size())) {  // check free > next one batch
-            LOG(WARNING) << StringFormat(
-                MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d", embName.c_str(), channelId,
-                batchId, iBatch, lookupKeys.size());
+            LOG(WARNING) << StringFormat(MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d",
+                                         embName.c_str(), channelId, batchId, iBatch, lookupKeys.size());
             return false;
         }
         return true;
-- Gitee

From d08c17eef227bb3f75ef02b4dd13acb75a4b52c3 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 15:14:44 +0800
Subject: [PATCH 288/551] Match-id-b116b2b6fee0d3c1a355edf20fea504ade7bd7be

---
 mx_rec/core/embedding.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index de96e018..e092a06e 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -613,6 +613,8 @@
             is_train:
             name: not in use
             modify_graph: if True, the original graph will be modified before building a Session instance
+
+        Returns: Tensors for lookup result
         """
         logging.debug(f"Enter ASC Branch, looking up with FeatureSpec.")
         self.check_mode(MxRecMode.ASC)
-- Gitee

From eb7bb28a857ca74eaea0aca1f3bff486aba79ce9 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 15:24:08 +0800
Subject: [PATCH 289/551] Match-id-47794915695daabd086cec6c4bb8da14ae562a69

---
 build/build_tf2.sh                   | 2 +-
 src/core/emb_hashmap/emb_hashmap.cpp | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/build/build_tf2.sh b/build/build_tf2.sh
index 799929dc..a42aeeb2 100644
--- a/build/build_tf2.sh
+++ b/build/build_tf2.sh
@@ -149,7 +149,7 @@ then
     collect_so_file
     gen_wheel_file "${ROOT_DIR}"/tf2_whl
 
-#    deactivate tf2_env
+    deactivate tf2_env
 
     echo "-----Build tf2 finished -----"
 fi
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index 56f04bfb..639522dc 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -71,10 +71,10 @@
     } else {
        FindOffset(embName, keys, swapId, keepBatch, channelId);
     }
-//    VLOG(GLOG_DEBUG) << "FindOffset end";
+    VLOG(GLOG_DEBUG) << "FindOffset end";
     swapId++;
 
-//    EASY_BLOCK("hostHashMaps->tdt")
+    EASY_BLOCK("hostHashMaps->tdt")
 
     std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut));
-- Gitee

From f3a0678e87d86c85a1533632d2dd03e967b29ea6 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 25 Aug 2023 15:42:01 +0800
Subject: [PATCH 290/551] Match-id-c4e47f0d88daa8f9ac9c58ff6bf4fe0f79878ab7

---
mx_rec/constants/constants.py | 1 + mx_rec/core/embedding.py | 5 ++++- src/core/emb_hashmap/emb_hashmap.cpp | 5 +++-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 24 +----------------------- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 -- src/core/utils/common.cpp | 19 +++++++++++++++++++ src/core/utils/common.h | 2 ++ 7 files changed, 30 insertions(+), 28 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 665f686b..b97326a1 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -39,6 +39,7 @@ DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24 TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 HASHTABLE_COLLECTION_NAME_LENGTH = 30 +MAX_HOST_VOCABULARY_SIZE = 10**10 # RANK INFO VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 16745e78..2f34e02b 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -19,7 +19,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temp from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, MxRecMode, \ ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, ASCEND_TABLE_NAME_MUST_CONTAIN, \ - MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy + MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_HOST_VOCABULARY_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ @@ -152,6 +152,9 @@ class SparseEmbedding: self.embedding_size = tf.TensorShape([self.embedding_size]) self.device_vocabulary_size = config.get("device_vocabulary_size") self.host_vocabulary_size = config.get("host_vocabulary_size") + if self.host_vocabulary_size > MAX_HOST_VOCABULARY_SIZE: + raise ValueError(f"host_vocabulary_size is larger than {MAX_HOST_VOCABULARY_SIZE}.") + self.table_name = config.get("table_name") self.key_dtype = config.get("key_dtype") self._optimizer_instance_list = config.get("optimizer_list") diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 00f72d34..d0192a32 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -62,7 +62,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t embHashMap.maxOffsetOld = embHashMap.maxOffset; auto keepBatch = swapId - iBatch; // 处理batch的次数,多个预取一起处理算一次 - bool findOffsetV2 = getenv("FIND_OFFSET_V2") != nullptr; + bool findOffsetV2 = GetEnv("FIND_OFFSET_V2"); + VLOG(GLOG_DEBUG) << StringFormat("FindOffset version:%d", findOffsetV2); // 找到所有key的偏移;dev和host需要交换的位置 @@ -203,7 +204,7 @@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat EmbHashMapInfo& embHashMap) const { EASY_FUNCTION() - bool findOffsetV3 = getenv("FIND_OFFSET_V3") != nullptr; + bool findOffsetV3 = GetEnv("FIND_OFFSET_V3"); for (size_t i = 0; i < keySize; i++) { int offset; auto& key = keys[i]; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index b63f7744..f8d13ddb 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -56,11 +56,7 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& 
rankInfo, const vector&
 }
 
     // Whether to use the dedup and bucketing features provided by the AccCTR library; disabled by default
-    const int defaultFastUnique = false;
-    PerfConfig::fastUnique = defaultFastUnique;
-    const char* envFastUnique = getenv("FAST_UNIQUE");
-    HybridMgmt::CheckFastUnique(envFastUnique);
-
+    PerfConfig::fastUnique = GetEnv("FAST_UNIQUE");
     // Initialize the data preprocessing class, configure it, and start the processing threads
     preprocess = Singleton::GetInstance();
     preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed);
@@ -69,24 +65,6 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector&
     return true;
 }
 
-void HybridMgmt::CheckFastUnique(const char *envFastUnique)
-{
-    if (envFastUnique != nullptr) {
-        try {
-            int tmp = std::stoi(envFastUnique);
-            if (tmp == 0 || tmp == 1) {
-                PerfConfig::fastUnique = (tmp == 1) ? true : false;
-                LOG(INFO) << StringFormat("Succeed to parse ${env:FAST_UNIQUE}: %d.", PerfConfig::fastUnique);
-            } else {
-                LOG(ERROR) << StringFormat("Invalid ${env:FAST_UNIQUE}: %s, which should be an 0 or 1.", envFastUnique);
-            }
-        } catch (const std::invalid_argument &e) {
-            LOG(ERROR) <<
-                StringFormat("Failed to parse ${env:FAST_UNIQUE}: %s, which should be an integer.", envFastUnique);
-        }
-    }
-}
-
 /// Set the OpenMPI communicator size, compute the total number of host features across all tables,
 /// and set the training mode (HBM/DDR)
 /// \param rankInfo
 /// \param embInfos
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 4c77d6d9..19dec09c 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -111,8 +111,6 @@ namespace MxRec {
         bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
                             const vector& thresholdValues, int seed);
 
-        void CheckFastUnique(const char* envFastUnique);
-
         void InitRankInfo(RankInfo& rankInfo, const vector& embInfos);
 
     private:
diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp
index 74245cc1..b21d6fb7 100644
--- a/src/core/utils/common.cpp
+++ b/src/core/utils/common.cpp
@@ -125,6 +125,25 @@ namespace MxRec {
             << " [" << l.severity << "] ";
     }
 
+    bool GetEnv(const char *envName)
+    {
+        const char* envString = getenv(envName);
+        int tmp = 0;
+        if (envString != nullptr) {
+            try {
+                tmp = std::stoi(envString);
+                if (tmp == 0 || tmp == 1) {
+                    LOG(INFO) << StringFormat("Succeed to parse ${env:%s}: %d.", envName, tmp);
+                } else {
+                    LOG(ERROR) << StringFormat("Invalid ${env:%s}: %d, which should be 0 or 1.", envName, tmp);
+                }
+            } catch (const std::invalid_argument &e) {
+                LOG(ERROR) <<
+                    StringFormat("Failed to parse ${env:%s}, which should be an integer.", envName);
+            }
+        }
+        return tmp == 1;
+    }
     string GetChipName(int devID)
     {
         int ret = 0;
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 5b02248a..9f9fa5fd 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -283,6 +283,8 @@ struct BatchTask {
 
     void CustomGlogFormat(std::ostream &s, const google::LogMessageInfo &l, void*);
 
+    bool GetEnv(const char *envName);
+
    template
    string StringFormat(const string& format, Args ...
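/* Usage sketch for the GetEnv helper declared above. FIND_OFFSET_V2,
   FIND_OFFSET_V3 and FAST_UNIQUE are the switches this patch series reads
   through it ("1" enables, anything else keeps the default):
       bool findOffsetV2 = GetEnv("FIND_OFFSET_V2");
       PerfConfig::fastUnique = GetEnv("FAST_UNIQUE");
*/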
args) { -- Gitee From bbe20c340ffac1c2802be2431c5bd524ee4fb890 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 25 Aug 2023 16:01:07 +0800 Subject: [PATCH 291/551] Match-id-e63b8fc241f1092a5dd1a5f90ee4fecb95be140e --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 5c84308e..70082a06 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -702,7 +702,8 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut) { - TimeCost getAndSendTensorsTC, getTensorsTC; + TimeCost getAndSendTensorsTC; + TimeCost getTensorsTC; auto& embHashMap = hostHashMaps->embHashMaps.at(embName); // 进行新一批预取数据时,计数初始化 -- Gitee From d66ef485cacbcb13f1a5817648a947209810066a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 26 Aug 2023 01:28:34 +0800 Subject: [PATCH 292/551] Match-id-54e51df050d013dd5e7e4bc0a2490d771a13fbd1 --- mx_rec/constants/constants.py | 1 + mx_rec/saver/saver.py | 4 ++-- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/utils/common.h | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 79178523..248b15c9 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -47,6 +47,7 @@ VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", MIN_SIZE = 1 MAX_CONFIG_SIZE = 10 * 1024 * 1024 MAX_SIZE = 1024 * 1024 * 1024 * 1024 +MAX_FILE_SIZE = 500 * 1024 * 1024 * 1024 MAX_DEVICE_NUM = 16 MAX_RANK_SIZE = 4095 MIN_DEVICE_NUM = 1 diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 9033b679..df219914 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -11,7 +11,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_SIZE +from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ send_host_data, get_ascend_global_hashtable_collection @@ -411,7 +411,7 @@ def validate_read_file(read_file_path): :param read_file_path: the file path to be validated """ file_validator = FileValidator(read_file_path) - file_validator.check_file_size(MAX_SIZE, MIN_SIZE) + file_validator.check_file_size(MAX_FILE_SIZE, MIN_SIZE) # local file need to check soft link if read_file_path.find("://") == -1: file_validator.check_not_soft_link() diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 4349331f..92216402 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -486,7 +486,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, return ; } auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); - if (embDataOuterSize <= 0) { + if (embDataOuterSize <= 0 || embDataOuterSize > MAX_VOCABULARY_SIZE) { throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str()); } std::ifstream readFile; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 82f90d83..c30d2b2e 100644 --- a/src/core/utils/common.h +++ 
b/src/core/utils/common.h @@ -59,6 +59,7 @@ namespace MxRec { constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; constexpr int KEY_PROCESS_THREAD = 6; constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; + constexpr int MAX_VOCABULARY_SIZE = 1e9; // for GLOG extern int g_glogLevel; -- Gitee From a5031d29277d316186bd9dbcce844a1aa6ffc0a9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 25 Aug 2023 16:31:50 +0800 Subject: [PATCH 293/551] Match-id-a42b87482eb73f441c620a3940438206832f5606 --- src/core/ssd_cache/lfu_cache.cpp | 154 +++++++++++++++++++++++++ src/core/ssd_cache/lfu_cache.h | 60 ++++++++++ src/tests/ssd_cache/lfu_cache_test.cpp | 109 +++++++++++++++++ 3 files changed, 323 insertions(+) create mode 100644 src/core/ssd_cache/lfu_cache.cpp create mode 100644 src/core/ssd_cache/lfu_cache.h create mode 100644 src/tests/ssd_cache/lfu_cache_test.cpp diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/ssd_cache/lfu_cache.cpp new file mode 100644 index 00000000..5e0c2a04 --- /dev/null +++ b/src/core/ssd_cache/lfu_cache.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: lfu cache module + * Author: MindX SDK + * Date: 2023/8/10 + */ +#include "lfu_cache.h" + +#include +#include +#include + +using namespace std; +using namespace MxRec; + +/// 仅获取当前key的频次,不增加频次;key不存在时返回-1 +/// \param key key +/// \return key的频次 +freq_num_t LFUCache::Get(emb_key_t key) +{ + auto it = keyTable.find(key); + if (it == keyTable.end()) { return -1; } + return it->second->freq; +} + +/// 返回num个不在keys中的最低频次的key和对应次数;并清理相应的key数据 +/// \param num 要返回的最低频次key info的个数 +/// \param keys 要返回的最低频次key不能在该列表内 +/// \param ddrSwapOutKeys 记录最低频次key +/// \param ddrSwapOutCounts 记录最低频次key对应次数 +void LFUCache::GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector& keys, + vector& ddrSwapOutKeys, vector& ddrSwapOutCounts) +{ + freq_num_t tempMinFreq = minFreq; + unordered_set retainedKeySet(keys.begin(), keys.end()); + int64_t counter = 0; + const size_t freqSize = freqTable.size(); + // 遍历freqTable<次数,keyList>时,次数可能不连续,要实际使用了1个keyList后才自增,手动增加计数器 + for (size_t i = 0; i < freqSize;) { + auto nodesIter = freqTable.find(tempMinFreq); + if (nodesIter == freqTable.end()) { + tempMinFreq++; + continue; + } + auto nodeIt = freqTable[tempMinFreq].begin(); + while (nodeIt != freqTable[tempMinFreq].end() && !freqTable[tempMinFreq].empty() && counter < num) { + emb_key_t currentKey = nodeIt->key; + if (retainedKeySet.find(currentKey) != retainedKeySet.end()) { + // 当前key在指定的集合中,不满足 + nodeIt++; + continue; + } + ddrSwapOutKeys.emplace_back(currentKey); + ddrSwapOutCounts.emplace_back(nodeIt->freq); + keyTable.erase(currentKey); + nodeIt = freqTable[tempMinFreq].erase(nodeIt); + counter++; + } + if (freqTable[tempMinFreq].empty()) { + freqTable.erase(tempMinFreq); + // 删除频次列表时,若最小频次和临时最小频次相等,则最小频次+1 + minFreq = tempMinFreq == minFreq ? 
minFreq + 1 : minFreq;
+        }
+        if (counter == num || freqTable.empty()) {
+            break;
+        }
+        tempMinFreq++;
+        i++;
+    }
+}
+
+/// Insert a key, creating it or bumping its count by one
+/// \param key key
+void LFUCache::Put(emb_key_t key)
+{
+    auto it = keyTable.find(key);
+    if (it == keyTable.end()) {
+        freqTable[1].emplace_front(key, 1);
+        keyTable[key] = freqTable[1].begin();
+        minFreq = 1;
+        return;
+    }
+    auto& node = it->second;
+    freq_num_t freq = node->freq;
+    freqTable[freq].erase(node);
+    if (freqTable[freq].empty()) {
+        freqTable.erase(freq);
+    }
+    if (minFreq == freq) { minFreq += 1; }
+    freqTable[freq + 1].emplace_front(key, freq + 1);
+    keyTable[key] = freqTable[freq + 1].begin();
+}
+
+void LFUCache::PutKeys(vector& keys)
+{
+    for (auto key : keys) {
+        Put(key);
+    }
+}
+
+/// Insert a key with an explicit count; used during initialization
+/// \param key key
+/// \param freq frequency
+void LFUCache::PutWithInit(emb_key_t key, freq_num_t freq)
+{
+    if (keyTable.find(key) != keyTable.end()) {
+        // During initialization the key should normally not have been inserted yet;
+        // replace the old frequency info for it here
+        LOG(WARNING) << StringFormat("key already exists during init, key:%d", key);
+        Pop(key);
+    }
+    freqTable[freq].emplace_front(key, freq);
+    keyTable[key] = freqTable[freq].begin();
+    if (minFreq == 0) {
+        minFreq = freq;
+    } else {
+        minFreq = freq < minFreq ? freq : minFreq;
+    }
+}
+
+/// Remove the given key
+bool LFUCache::Pop(emb_key_t key)
+{
+    auto it = keyTable.find(key);
+    if (it == keyTable.end()) {
+        return false;
+    }
+    auto& node = it->second;
+    freq_num_t oldFreq = node->freq;
+    freqTable[oldFreq].erase(node);
+    if (freqTable[oldFreq].empty()) {
+        freqTable.erase(oldFreq);
+        if (minFreq == oldFreq) { minFreq += 1; }
+    }
+    keyTable.erase(it);
+    return true;
+}
+
+/// Get all keys together with their counts
+/// \return frequency map
+std::unordered_map LFUCache::GetFreqTable()
+{
+    unordered_map freqMap(keyTable.size());
+    for (const auto& it : keyTable) {
+        freqMap[it.first] = it.second->freq;
+    }
+    return freqMap;
+}
+
+LFUCache::LFUCache()
+{
+    minFreq = 0;
+    keyTable.clear();
+    freqTable.clear();
+}
diff --git a/src/core/ssd_cache/lfu_cache.h b/src/core/ssd_cache/lfu_cache.h
new file mode 100644
index 00000000..170e0fc7
--- /dev/null
+++ b/src/core/ssd_cache/lfu_cache.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
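+ * Structure: freqTable maps a count to the list of (key, count) nodes holding
+ * that count, keyTable maps a key to its node's iterator inside freqTable, and
+ * minFreq caches the smallest live count so eviction scans in
+ * GetAndDeleteLeastFreqKeyInfo can start there. Get/Put/Pop are O(1) hash-map
+ * operations on top of these two indexes.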
+ * Description: lfu cache module + * Author: MindX SDK + * Date: 2023/8/10 + */ +#ifndef MXREC_LFU_CACHE_H +#define MXREC_LFU_CACHE_H + +#include +#include +#include +#include +#include + +#include "utils/common.h" + +namespace MxRec { + using namespace std; + + using freq_num_t = int_fast32_t; + + // 记录key和次数信息 + struct LFUCacheNode { + emb_key_t key; + freq_num_t freq; + + LFUCacheNode(emb_key_t key, freq_num_t freq) : key(key), freq(freq) + {} + }; + + class LFUCache { + public: + LFUCache(); + + freq_num_t Get(emb_key_t key); + + void GetAndDeleteLeastFreqKeyInfo(int64_t num, const vector& keys, + vector& ddrSwapOutKeys, + vector& ddrSwapOutCounts); + + void Put(emb_key_t key); + + void PutKeys(vector& keys); + + bool Pop(emb_key_t key); + + void PutWithInit(emb_key_t key, freq_num_t freq); + + std::unordered_map GetFreqTable(); + // 最小频次 + freq_num_t minFreq = 0; + // 次数, 该次数对应的key列表(key, freq) + std::unordered_map> freqTable; + // key, key所属node在freqTable的节点列表中的存储位置地址 + std::unordered_map::iterator> keyTable; + }; +} + +#endif // MXREC_LFU_CACHE_H diff --git a/src/tests/ssd_cache/lfu_cache_test.cpp b/src/tests/ssd_cache/lfu_cache_test.cpp new file mode 100644 index 00000000..6e4bd487 --- /dev/null +++ b/src/tests/ssd_cache/lfu_cache_test.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: lfu cache module + * Author: MindX SDK + * Date: 2023/8/10 + */ + +#include +#include + +#include "ssd_cache/lfu_cache.h" + +using namespace std; +using namespace MxRec; +using namespace testing; + +/* + * 要放入的key, 频次对应key列表中元素顺序和放入顺序相反; 如下列表key逐个放入的结果: + * 频次-对应key列表 + * 1 - 9,8 + * 2 - 6,4 + * 3 - 3,2,1 + */ +vector INPUT_KEYS = {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 6, 6, 8, 9}; + +inline void CompareHandleRet(vector& leastFreqKeys, vector& leastFreq, + vector& expectKeys, + vector& expectFreq) +{ + ASSERT_EQ(leastFreqKeys.size(), expectKeys.size()); + ASSERT_EQ(leastFreq.size(), expectFreq.size()); + for (size_t i = 0; i < leastFreqKeys.size(); i++) { + ASSERT_EQ(leastFreqKeys[i], expectKeys[i]); + ASSERT_EQ(leastFreq[i], expectFreq[i]); + } +} + +TEST(LFUCache, TestGetFreqTable) +{ + LFUCache cache; + cache.PutKeys(INPUT_KEYS); + auto ret = cache.GetFreqTable(); + ASSERT_EQ(ret[9], 1); + ASSERT_EQ(ret[6], 2); + ASSERT_EQ(ret[3], 3); +} + +TEST(LFUCache, PopTest) +{ + LFUCache cache; + cache.PutKeys(INPUT_KEYS); + cache.Pop(8); + cache.Pop(9); + ASSERT_EQ(cache.minFreq, 2); + ASSERT_EQ(cache.Get(8), -1); + ASSERT_EQ(cache.Get(9), -1); +} + +TEST(LFUCache, PutInitTest) +{ + LFUCache cache; + cache.PutWithInit(1, 3); + cache.PutWithInit(2, 3); + cache.PutWithInit(3, 3); + cache.PutWithInit(4, 2); + cache.PutWithInit(6, 2); + cache.PutWithInit(8, 1); + cache.PutWithInit(9, 1); + vector retainedKeys = {4, 6}; + vector leastFreqKeys; + vector leastFreq; + cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq); + vector expectKeys = {9, 8}; + vector expectFreq = {1, 1}; + CompareHandleRet(leastFreqKeys, leastFreq, expectKeys, expectFreq); + ASSERT_EQ(cache.minFreq, 2); +} + +TEST(LFUCache, LFUDeleteTotalFreqListTest) +{ + LFUCache cache; + cache.PutKeys(INPUT_KEYS); + vector retainedKeys = {4, 6, 8, 9}; + vector leastFreqKeys; + vector leastFreq; + cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq); + vector expectKeys = {3, 2}; + vector expectFreq = {3, 3}; + CompareHandleRet(leastFreqKeys, leastFreq, expectKeys, expectFreq); +} + +TEST(LFUCache, BaseCacheTest) +{ + LFUCache cache; + 
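+    /* A minimal sketch of the bookkeeping these tests exercise; key 7 is a
+     * hypothetical example, the behaviour follows from lfu_cache.cpp above:
+     *     LFUCache c;
+     *     c.Put(7);                 // new key -> freq 1, minFreq = 1
+     *     c.Put(7);                 // freq-1 bucket drains, minFreq advances to 2
+     *     assert(c.Get(7) == 2);    // Get only reads the count, never bumps it
+     *     assert(c.Get(42) == -1);  // missing keys report -1
+     */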
cache.PutKeys(INPUT_KEYS);
+    vector retainedKeys = {8, 4, 6, 2};
+    vector leastFreqKeys;
+    vector leastFreq;
+    cache.GetAndDeleteLeastFreqKeyInfo(2, retainedKeys, leastFreqKeys, leastFreq);
+    vector expectKeys = {9, 3};
+    vector expectFreq = {1, 3};
+    CompareHandleRet(leastFreqKeys, leastFreq, expectKeys, expectFreq);
+    ASSERT_EQ(cache.minFreq, 1);
+    ASSERT_EQ(cache.Get(9), -1);
+    cache.Put(9);
+    ASSERT_EQ(cache.Get(9), 1);
+    cache.Put(9);
+    ASSERT_EQ(cache.minFreq, 2);
+}
-- 
Gitee
From c2ca5a58850796517aefd43f28530d2e14747812 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sat, 26 Aug 2023 11:15:01 +0800
Subject: [PATCH 294/551] Match-id-8c5536b9ad401025a7ef6f1398af26729f138d15
---
 mx_rec/core/asc/manager.py                 |   5 +-
 mx_rec/core/embedding.py                   |  32 ++
 src/core/ssd_cache/cache_manager.cpp       | 448 +++++++++++++++++++++
 src/core/ssd_cache/cache_manager.h         | 121 ++++++
 src/core/utils/common.h                    |   9 +-
 src/pybind/module_main.cpp                 |   8 +-
 src/tests/ssd_cache/cache_manager_test.cpp | 352 ++++++++++++++++
 7 files changed, 968 insertions(+), 7 deletions(-)
 create mode 100644 src/core/ssd_cache/cache_manager.cpp
 create mode 100644 src/core/ssd_cache/cache_manager.h
 create mode 100644 src/tests/ssd_cache/cache_manager_test.cpp

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 50d69df4..b7f27337 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -62,12 +62,13 @@ def generate_table_info_list():
         if static_shape_rec_flag or dynamic_shape_rec_flag:
             logging.debug(f"table_instance.slice_device_vocabulary_size: {table_instance.slice_device_vocabulary_size}")
             logging.debug(f"table_instance.slice_host_vocabulary_size: {table_instance.slice_host_vocabulary_size}")
+            logging.debug(f"table_instance.slice_ssd_vocabulary_size: {table_instance.slice_ssd_vocabulary_size}")
             table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size,
                                  table_instance.ext_emb_size, table_instance.is_save,
                                  [table_instance.slice_device_vocabulary_size,
-                                  table_instance.slice_host_vocabulary_size],
+                                  table_instance.slice_host_vocabulary_size, table_instance.slice_ssd_vocabulary_size],
                                  [matched_emb_initializer(table_instance)] +
-                                 matched_opt_slot_initializers(table_instance))
+                                 matched_opt_slot_initializers(table_instance), table_instance.ssd_data_path)
             table_info_list.append(table_info)
     return table_info_list

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 16745e78..7d132d8d 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -4,6 +4,7 @@
 
 import logging
 import math
+import os
 import re
 from collections import defaultdict
 from typing import Optional
@@ -27,6 +28,24 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is
 from mx_rec.validator.validator import ClassValidator, StringValidator
 
 
+def check_ssd_vocab_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_path):
+    h_size = 0
+    s_size = 0
+    try:
+        h_size = int(host_vocabulary_size)
+        s_size = int(ssd_vocabulary_size)
+    except ValueError:
+        raise ValueError("invalid value in host_vocabulary_size, ssd_vocabulary_size, or both.")
+    if h_size == 0 and s_size != 0:
+        raise ValueError("ssd_vocabulary_size is invalid: a non-zero value requires a non-zero host_vocabulary_size.")
+    invalid_ssd_data_path = []
+    for tmp_path in ssd_data_path:
+        if not os.path.exists(tmp_path) or not os.path.isdir(tmp_path) or ".." in tmp_path:
+            invalid_ssd_data_path.append(tmp_path)
+    if invalid_ssd_data_path:
+        raise ValueError("ssd_data_path contains invalid entries: {}.".format(", ".join(invalid_ssd_data_path)))
+
+
 def create_table(**kwargs):
     """
     Args:
@@ -36,6 +55,8 @@ def create_table(**kwargs):
         emb_initializer: the initializer for embedding values
         device_vocabulary_size: embedding vector numbers on device
         host_vocabulary_size: embedding vector numbers on ddr
+        ssd_vocabulary_size: embedding vector numbers on ssd
+        ssd_data_path: ssd embedding data save and load path
         relation from feature to variable offset will be built
         optimizer_list: specify the optimizers to use for current hash table
         mode: specify which mode to run for current sparse table
@@ -55,6 +76,8 @@ def create_table(**kwargs):
     emb_initializer = kwargs.get("emb_initializer")
     device_vocabulary_size = kwargs.get("device_vocabulary_size", 1)
     host_vocabulary_size = kwargs.get("host_vocabulary_size", 0)
+    ssd_vocabulary_size = kwargs.get("ssd_vocabulary_size", 0)
+    ssd_data_path = kwargs.get("ssd_data_path", [os.getcwd()])
     optimizer_list = kwargs.get("optimizer_list")
     mode = kwargs.get("mode", MxRecMode.ASC)
     value_dtype = kwargs.get("value_dtype", tf.float32)
@@ -68,9 +91,11 @@ def create_table(**kwargs):
 
     name = fix_invalid_table_name(name)
     check_create_table_params(key_dtype, dim, name, emb_initializer)
+    check_ssd_vocab_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_path)
 
     config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer,
                   device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size,
+                  ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path,
                   optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num,
                   fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold,
                   init_param=init_param, is_save=is_save, all2all_gradients_op=all2all_gradients_op,
@@ -152,6 +177,8 @@ class SparseEmbedding:
             self.embedding_size = tf.TensorShape([self.embedding_size])
         self.device_vocabulary_size = config.get("device_vocabulary_size")
         self.host_vocabulary_size = config.get("host_vocabulary_size")
+        self.ssd_vocabulary_size = config.get("ssd_vocabulary_size")
+        self.ssd_data_path = config.get("ssd_data_path")
         self.table_name = config.get("table_name")
         self.key_dtype = config.get("key_dtype")
         self._optimizer_instance_list = config.get("optimizer_list")
@@ -171,6 +198,7 @@ class SparseEmbedding:
         self._optimizer = dict()
         self.slice_device_vocabulary_size = 0
         self.slice_host_vocabulary_size = 0
+        self.slice_ssd_vocabulary_size = 0
         self.variable = None
         self.lookup_info = set()
         self.lookup_result = dict()
@@ -352,6 +380,7 @@ class SparseEmbedding:
         else:
             self.slice_device_vocabulary_size = math.ceil(self.device_vocabulary_size / rank_size)
             self.slice_host_vocabulary_size = math.ceil(self.host_vocabulary_size / rank_size)
+            self.slice_ssd_vocabulary_size = math.ceil(self.ssd_vocabulary_size / rank_size)
 
     def register_anchor_attribute(self, anchor_ids, feature_spec, kwargs):
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self
@@ -755,6 +784,9 @@ class SparseEmbedding:
         logging.debug(f"Host vocabulary size for table {self.table_name} is {self.host_vocabulary_size}.")
         logging.debug(f"Slice host vocabulary_size for table {self.table_name} is"
                       f" {self.slice_host_vocabulary_size}.")
+        logging.debug(f"SSD vocabulary size for table {self.table_name} is {self.ssd_vocabulary_size}.")
+        logging.debug(f"Slice ssd
vocabulary_size for table {self.table_name} is" + f" {self.slice_ssd_vocabulary_size}.") def _initialize_variables(self): initialized_tensor = \ diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp new file mode 100644 index 00000000..735e734e --- /dev/null +++ b/src/core/ssd_cache/cache_manager.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: ssd cache module + * Author: MindX SDK + * Date: 2023/8/10 + */ +#include "cache_manager.h" + +#include +#include +#include + +#include "utils/common.h" +#include "utils/time_cost.h" + +using namespace MxRec; + +inline TransferRet TransferSuccess() +{ + return TransferRet::TRANSFER_OK; +} + +inline TransferRet TransferError() +{ + return TransferRet::TRANSFER_ERROR; +} + +inline TransferRet TransferSpaceWarning() +{ + return TransferRet::SSD_SPACE_NOT_ENOUGH; +} + +inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, const vector& keys) +{ + auto& hostHashMap = embHashMap.hostHashMap; + for (auto key : keys) { + if (hostHashMap.find(key) == hostHashMap.end()) { + externalKeys.emplace_back(key); + } + } +} + +void AddDebugAndTraceLog(vector& externalKeys, vector& externalSSDKeys) +{ + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: externalKeys size:%d, externalSSDKeys size:%d", + externalKeys.size(), externalSSDKeys.size()); + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << StringFormat("TransferDDREmbWithSSD: externalKeys:%s, externalSSDKeys:%s", + VectorToString(externalKeys).c_str(), VectorToString(externalSSDKeys).c_str()); + } +} + +/// DDR与SSD数据转移,使DDR内剩余空间能放置当前批次key +/// \param embTableName emb表名 +/// \param embHashMap emb表 +/// \param keys 当前批次key +/// \param channelId 通道id +/// \return 转移结果枚举 +TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + const vector& keys, int channelId) +{ + // 区分HBM+DDR内key,和HBM+DDR外的key(新key或保存在SSD中的key) + vector externalKeys; + GetExternalKeys(embHashMap, externalKeys, keys); + if (externalKeys.empty()) { + return TransferSuccess(); + } + + // 判断剩余内存空间是否足够; 可用内存空间计算:HBM+DDR-已占用; 若是训练,再加DDR已淘汰; + // SSD仅与DDR交互,不考虑HBM淘汰位置 + size_t ddrAvailableSize = embHashMap.devVocabSize + embHashMap.hostVocabSize - embHashMap.maxOffset; + if (channelId == TRAIN_CHANNEL_ID) { + ddrAvailableSize += embHashMap.evictPos.size(); + } + + CreateSSDTableIfNotExist(embTableName); + + // 调用ssdEngine查询当前批次key中保存在SSD中的key + vector externalSSDKeys; + GetSSDKeys(embTableName, externalKeys, externalSSDKeys); + // 后续判断maxOffset是否超出范围时,maxOffset=devVocabSize+hostVocabSize时可用,此处包含等于 + bool isDDRSpaceEnough = ddrAvailableSize >= externalKeys.size(); + bool ddrSpaceEnoughOrEval = channelId != TRAIN_CHANNEL_ID || isDDRSpaceEnough; + if (ddrSpaceEnoughOrEval && externalSSDKeys.empty()) { + // 部分场景后续不用处理,在此处返回 + return TransferSuccess(); + } + + AddDebugAndTraceLog(externalKeys, externalSSDKeys); + /* + * 前面 externalSSDKeys = 0 ,评估场景的 ddr空间可用、不可用已返回; 训练的可用已返回; + * 剩下的情况如下: + * 评估: + * externalSSDKeys > 0, 可用 & 不可用操作一样; + * 可选:Ddr->ssd, 腾出 externalSSDKeys 大小空间; + * Ssd->ddr, 需要移动 externalSSDKeys ; + * externalSSDKeys = 0 --已返回 + * 训练: + * externalSSDKeys > 0 + * 可用: + * 可选:Ddr->ssd, 腾出 externalSSDKeys 大小空间; + * Ssd->ddr, 需要移动 externalSSDKeys ; + * 不可用: + * 必选:Ddr->ssd, 腾出 externalKeys 大小空间; + * 需要计算ssd剩余空间:externalKeys - externalSSDKeys + * (注: 当前策略均转移externalKeys) + * Ssd->ddr, 需要移动 externalSSDKeys ; + * externalSSDKeys = 0 + * 可用: --已返回 + * 不可用: + 
* Ddr->ssd, 腾出 externalKeys 大小的空间; + * 需要计算ssd剩余空间: externalKeys + * 因cache每次只转移DDR最小空间,上述可选动作也需执行,避免SSD移入DDR时空间不足 + */ + // 训练场景检查SSD剩余空间 评估不考虑新key + if (channelId == TRAIN_CHANNEL_ID) { + size_t needSSDSize = externalKeys.size() - externalSSDKeys.size() - ddrAvailableSize; + const int64_t ssdAvailableSize = ssdEngine->GetTableAvailableSpace(embTableName); + if (int64_t(needSSDSize) > ssdAvailableSize) { + LOG(ERROR) << "TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data." + " needSSDSize:" << needSSDSize << ", ssdAvailableSize:" << ssdAvailableSize; + return TransferSpaceWarning(); + } + } + + // 从SSD获取emb数据并从SSD删除; 避免DDR->SSD时空间不够 + vector> ssdEmbData; + if (!externalSSDKeys.empty()) { + ssdEmbData = ssdEngine->FetchEmbeddings(embTableName, externalSSDKeys); + ssdEngine->DeleteEmbeddings(embTableName, externalSSDKeys); + } + + // 从ddr转移到ssd的key个数 + size_t ddrSwapOutSizeTmp = ddrSpaceEnoughOrEval ? externalSSDKeys.size() : externalKeys.size(); + auto ddrSwapOutSize = static_cast(ddrSwapOutSizeTmp - ddrAvailableSize); + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddrSwapOutSize:%d, ddrAvailableSize:%lld", ddrSwapOutSize, + ddrAvailableSize); + + /* + * 转移DDR中数据到SSD + */ + // 记录要从DDR转移到SSD的key对应的offset(相对值,需减去devVocabSize) + vector ddrTransferPos; + TransferRet ddr2SsdRet = TransferDDREmb2SSD(embTableName, embHashMap, ddrSwapOutSize, keys, ddrTransferPos); + if (ddr2SsdRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { + ssdEngine->InsertEmbeddings(embTableName, externalSSDKeys, ssdEmbData); + return ddr2SsdRet; + } + + HandleDDRTransferPos(ddrTransferPos, externalSSDKeys, embHashMap); + + /* + * 转移SSD中保存的当前批次key的emb数据到DDR + */ + return TransferSSDEmb2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos, ssdEmbData); +} + +/// SSD数据转移到DDR中后刷新映射和频次信息 +/// \param embTableName emb表名 +/// \param embHashMap emb hash表 +/// \param externalSSDKeys 存储在SSD中的key列表 +/// \param ddrTransferPos +void CacheManager::RefreshRelateInfoWithSSD2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& externalSSDKeys, vector& ddrTransferPos) +{ + for (size_t i = 0; i < externalSSDKeys.size(); i++) { + // 映射关系 ddrTransferPos是在ddrEmbHash中的位置,记录映射时需加上devVocabSize + auto& key = externalSSDKeys[i]; + embHashMap.hostHashMap[externalSSDKeys[i]] = ddrTransferPos[i] + embHashMap.devVocabSize; + // 频次 + ddrKeyFreqMap[embTableName].PutWithInit(key, excludeDDRKeyCountMap[embTableName][key]); + excludeDDRKeyCountMap[embTableName].erase(key); + } +} + +void CacheManager::GetDDREmbInfo(vector& keys, const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& ddrTransferPos, vector>& ddrEmbData) +{ + // 根据offset 获取对应Emb数据 + for (auto& key : keys) { + ddrTransferPos.emplace_back(embHashMap.hostHashMap[key] - embHashMap.devVocabSize); + } + + if (VLOG_IS_ON(GLOG_TRACE)) { + VLOG(GLOG_TRACE) << "DDR keys:" << VectorToString(keys); + VLOG(GLOG_TRACE) << "DDR key positions:" << VectorToString(ddrTransferPos); + } + + ddrEmbData.resize(keys.size()); + const auto& emb = hostEmbs->GetEmb(embTableName); +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ddrTransferPos, emb, ddrEmbData) + for (size_t i = 0; i < ddrTransferPos.size(); i++) { + auto& missingKeyPo = ddrTransferPos[i]; + const auto& src = emb.embData[missingKeyPo]; + ddrEmbData[i] = src; + } +} + +/// 使用ssdEmbData更新DDR内emb数据 +/// \param embTableName emb表名 +/// \param ddrTransferPos 需要更新的DDR内的offset +/// \param ssdEmbData SSD对应的emb数据 +void CacheManager::UpdateDDREmbInfo(const 
std::string& embTableName, + vector& ddrTransferPos, + vector>& ssdEmbData) +{ + auto& emb = hostEmbs->GetEmb(embTableName); + auto& embData = emb.embData; +#pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ddrTransferPos, embData, ssdEmbData) + for (size_t i = 0; i < ddrTransferPos.size(); i++) { + embData[ddrTransferPos[i]] = ssdEmbData[i]; + } +} + +/// DDR_2_SSD场景数据刷新: 仅刷新映射和频次,ddr转移出去的offset信息后续统一处理 +/// \param embTableName emb表名 +/// \param embHashMap emb map +/// \param ddrSwapOutKeys 从DDR中转移到SSD中key列表 +/// \param ddrSwapOutCounts 从DDR中转移到SSD中key频次数据 +void CacheManager::RefreshRelateInfoWithDDR2SSD(const string& embTableName, EmbHashMapInfo& embHashMap, + vector& ddrSwapOutKeys, + vector& ddrSwapOutCounts) +{ + auto& excludeFreqMap = excludeDDRKeyCountMap[embTableName]; + for (size_t i = 0; i < ddrSwapOutKeys.size(); i++) { + auto& key = ddrSwapOutKeys[i]; + embHashMap.hostHashMap.erase(key); + excludeFreqMap[key] = ddrSwapOutCounts[i]; + } +} + +/// key从DDR移入、移出、HBM淘汰时刷新频次信息;仅刷新频次信息 +/// \param embTableName emb表名 +/// \param keys 操作的key集合 +/// \param type TransferType +void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& keys, + TransferType type) +{ + if (type == TransferType::DDR_2_HBM) { + for (auto& key : keys) { + // 频次数据记录到 excludeDDRKeyCountMap,并删除ddrKeyFreqMap中频次数据 + excludeDDRKeyCountMap[embTableName][key] = ddrKeyFreqMap[embTableName].Get(key); + ddrKeyFreqMap[embTableName].Pop(key); + } + } else if (type == TransferType::HBM_2_DDR) { + for (auto& key : keys) { + // excludeDDRKeyCountMap 中次数转移到 ddrKeyFreqMap, 并删除原记录 + ddrKeyFreqMap[embTableName].PutWithInit(key, excludeDDRKeyCountMap[embTableName][key]); + excludeDDRKeyCountMap[embTableName].erase(key); + } + } else if (type == TransferType::DDR_2_EVICT) { + for (auto& key : keys) { + ddrKeyFreqMap[embTableName].Pop(key); + } + } else { + // TransferType::HBM_2_EVICT + for (auto& key : keys) { + excludeDDRKeyCountMap[embTableName].erase(key); + } + } +} + +void CacheManager::Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo) +{ + this->hostEmbs = hostEmbPtr; + for (auto& emb : mgmtEmbInfo) { + EmbBaseInfo baseInfo {emb.ssdVocabSize, emb.ssdDataPath, false}; + embBaseInfos.emplace(emb.name, baseInfo); + } + ssdEngine->Start(); + LOG(INFO) << "CacheManager Init method end."; +} + +bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key) +{ + return ssdEngine->IsKeyExist(embTableName, key); +} + +/// 淘汰SSD中Emb信息 +/// \param embTableName emb表名 +/// \param keys 淘汰key列表 +void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& keys) +{ + // 1 删除缓存中记录的key的次数 2 删除SSD中保存的Emb数据 + for (auto& key : keys) { + excludeDDRKeyCountMap[embTableName].erase(key); + } + ssdEngine->DeleteEmbeddings(embTableName, keys); +} + +/// 放入key,新增/更新(次数+1)次数 +/// \param embTableName emb表名 +/// \param key key +/// \param type 记录类型 +void CacheManager::PutKey(const string& embTableName, const emb_key_t& key, RecordType type) +{ + if (type == RecordType::DDR) { + ddrKeyFreqMap[embTableName].Put(key); + return; + } + auto& hashMap = excludeDDRKeyCountMap[embTableName]; + const auto& it = hashMap.find(key); + freq_num_t count = it == hashMap.end() ? 
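+    /* new key starts at 1, otherwise bump by one. Frequency state is deliberately
+       split in two: DDR-resident keys live in the LFU structure (ddrKeyFreqMap) so
+       the coldest ones can be chosen for SSD swap-out, while all other keys only
+       need a flat counter (excludeDDRKeyCountMap). A hypothetical call site, with
+       `mgr` standing in for the CacheManager instance:
+           mgr.PutKey("table1", key, RecordType::DDR);      // key currently in DDR
+           mgr.PutKey("table1", key, RecordType::NOT_DDR);  // key in HBM or SSD
+       */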
1 : it->second + 1; + hashMap[key] = count; +} + +/// DDR->SSD与SSD->DDR的key个数可能不一致,手动补齐/截取 +/// \param ddrTransferPos DDR->SSD的offset列表(hostEmb表内的偏移值) +/// \param externalSSDKeys SSD->DDR的key列表 +/// \param embHashMap emb hash表 +void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, + EmbHashMapInfo& embHashMap) +{ + if (ddrTransferPos.size() == externalSSDKeys.size()) { + return; + } + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: operate length is not equal, will padding or clipping, " + "pos len:%d, keys len:%d", ddrTransferPos.size(), externalSSDKeys.size()); + // ddrTransferPos中是DDR内偏移位置,存入evictPos时,需加上devVocabSize;取出时需减去 + if (ddrTransferPos.size() > externalSSDKeys.size()) { + while (ddrTransferPos.size() > externalSSDKeys.size()) { + embHashMap.evictPos.emplace_back(ddrTransferPos.back() + embHashMap.devVocabSize); + ddrTransferPos.pop_back(); + } + return; + } + // 补齐offset + while (ddrTransferPos.size() < externalSSDKeys.size() && !embHashMap.evictPos.empty()) { + ddrTransferPos.emplace_back(embHashMap.evictPos.back() - embHashMap.devVocabSize); + embHashMap.evictPos.pop_back(); + } + auto allSize = embHashMap.devVocabSize + embHashMap.hostVocabSize; + // 还不够继续使用maxOffset + while (ddrTransferPos.size() < externalSSDKeys.size() && embHashMap.maxOffset < allSize) { + auto nextPos = embHashMap.maxOffset++; + ddrTransferPos.emplace_back(nextPos - embHashMap.devVocabSize); + } + VLOG(GLOG_DEBUG) << StringFormat("HandleDDRTransferPos: handle end, pos len:%d, keys len:%d", + ddrTransferPos.size(), externalSSDKeys.size()); +} + +void CacheManager::GetSSDKeys(const std::string& embTableName, vector& externalKeys, + vector& externalSSDKeys) +{ + for (auto& key : externalKeys) { + if (ssdEngine->IsKeyExist(embTableName, key)) { + externalSSDKeys.emplace_back(key); + } + } +} + +TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHashMapInfo& embHashMap, + int64_t ddrSwapOutSize, + const vector& keys, vector& ddrTransferPos) +{ + if (ddrSwapOutSize <= 0) { + // 此时不需要转移数据 + return TransferRet::TRANSFER_OK; + } + + TimeCost ddr2SsdTc; + VLOG(GLOG_DEBUG) + << StringFormat("TransferDDREmbWithSSD: get ddr least freq keys, ddrSwapOutSize:%lld", ddrSwapOutSize); + // 获取DDR中指定数量的最低频次key,并获取相应emb数据,执行DDR换出到SSD + vector ddrSwapOutKeys; + vector ddrSwapOutCounts; + ddrKeyFreqMap[embTableName].GetAndDeleteLeastFreqKeyInfo(ddrSwapOutSize, keys, ddrSwapOutKeys, + ddrSwapOutCounts); + if (ddrSwapOutKeys.size() != ddrSwapOutSize) { + // 获取的最低频次key数量和预期不一致,DDR空间不足,不能放置当前批次数据 + LOG(ERROR) << StringFormat( + "TransferDDREmbWithSSD, vector length is not equal, ddrSwapOutKeys size:%d, ddrSwapOutSize:%lld", + ddrSwapOutKeys.size(), ddrSwapOutSize); + RestoreLeastFreqInfo(embTableName, ddrSwapOutKeys, ddrSwapOutCounts); + return TransferRet::DDR_SPACE_NOT_ENOUGH; + } + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: get DDR embeddings and save to SSD, size:%d", + ddrSwapOutKeys.size()); + // 获取DDR中emb数据 + vector> ddrEmbData; + GetDDREmbInfo(ddrSwapOutKeys, embTableName, embHashMap, ddrTransferPos, ddrEmbData); + // 调用SSDEngine接口,将DDR Emb数据保存到SSD + ssdEngine->InsertEmbeddings(embTableName, ddrSwapOutKeys, ddrEmbData); + + // 初始化DDR内被转移出去的位置 + hostEmbs->EvictInitEmb(embTableName, ddrTransferPos); + + // 更新记录的DDR中key频次信息 + RefreshRelateInfoWithDDR2SSD(embTableName, embHashMap, ddrSwapOutKeys, ddrSwapOutCounts); + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddr2SsdTc TimeCost(ms):%d", ddr2SsdTc.ElapsedMS()); + return 
TransferRet::TRANSFER_OK; +} + +TransferRet CacheManager::TransferSSDEmb2DDR(const string& embTableName, EmbHashMapInfo& embHashMap, + vector& externalSSDKeys, vector& ddrTransferPos, + vector>& ssdEmbData) +{ + if (externalSSDKeys.empty()) { + return TransferSuccess(); + } + TimeCost ssd2DdrTc; + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: get SSD embeddings and save to DDR, size:%d", + externalSSDKeys.size()); + if (ddrTransferPos.size() != externalSSDKeys.size()) { + LOG(ERROR) << StringFormat( + "TransferDDREmbWithSSD, vector length is not equal, ddrTransferPos len:%d, externalSSDKeys len:%d", + ddrTransferPos.size(), externalSSDKeys.size()); + return TransferError(); + } + // 将SSD emb存储到DDR中 刷新频次信息 + UpdateDDREmbInfo(embTableName, ddrTransferPos, ssdEmbData); + RefreshRelateInfoWithSSD2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos); + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):%d", ssd2DdrTc.ElapsedMS()); + return TransferSuccess(); +} + +void CacheManager::CreateSSDTableIfNotExist(const std::string& embTableName) +{ + if (embBaseInfos[embTableName].isExist) { + return; + } + if (!ssdEngine->IsTableExist(embTableName)) { + ssdEngine->CreateTable(embTableName, embBaseInfos[embTableName].savePath, + embBaseInfos[embTableName].maxTableSize); + embBaseInfos[embTableName].isExist = true; + LOG(INFO) << ("create ssd table end, embTableName:" + embTableName); + return; + } + embBaseInfos[embTableName].isExist = true; + LOG(INFO) << ("ssd table is exist, embTableName:" + embTableName); +} + +void CacheManager::RestoreLeastFreqInfo(const std::string& embTableName, vector& ddrSwapOutKeys, + vector& ddrSwapOutCounts) +{ + auto& lfuCache = ddrKeyFreqMap[embTableName]; + for (size_t i = 0; i < ddrSwapOutKeys.size(); i++) { + lfuCache.PutWithInit(ddrSwapOutKeys[i], ddrSwapOutCounts[i]); + } +} + +CacheManager::~CacheManager() +{ + hostEmbs = nullptr; + ssdEngine->Stop(); + ddrKeyFreqMap.clear(); + excludeDDRKeyCountMap.clear(); +} diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h new file mode 100644 index 00000000..c1145584 --- /dev/null +++ b/src/core/ssd_cache/cache_manager.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
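+ * Transfer policy, condensed from the case analysis in cache_manager.cpp: batch
+ * keys absent from both HBM+DDR and SSD are new; keys found in SSD are staged
+ * back into DDR. On the training channel, when DDR lacks free slots, the least
+ * frequently used DDR keys outside the current batch are swapped out to SSD
+ * first (SSD free space is checked up front); the eval channel skips the SSD
+ * space check since it does not admit new keys. Results: TRANSFER_OK,
+ * TRANSFER_ERROR, SSD_SPACE_NOT_ENOUGH, DDR_SPACE_NOT_ENOUGH (see TransferRet below).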
+ * Description: ssd cache module head + * Author: MindX SDK + * Date: 2023/8/10 + */ + +#ifndef MXREC_CACHE_MANAGER_H +#define MXREC_CACHE_MANAGER_H + +#include +#include +#include +#include +#include + +#include "hd_transfer/hd_transfer.h" +#include "host_emb/host_emb.h" +#include "lfu_cache.h" +#include "ssd_engine/ssd_engine.h" +#include "utils/common.h" + +namespace MxRec { + enum class TransferRet { + TRANSFER_OK = 0, // 转移成功或无需处理 + TRANSFER_ERROR, + SSD_SPACE_NOT_ENOUGH, + DDR_SPACE_NOT_ENOUGH, + }; + + enum class TransferType { + DDR_2_HBM = 0, + DDR_2_EVICT, + HBM_2_DDR, + HBM_2_EVICT, + }; + + enum class RecordType { + DDR = 0, + NOT_DDR, + }; + + class CacheManager { + public: + CacheManager() = default; + + ~CacheManager(); + + void Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo); + + // 保存/初始化模块相关数据 + bool SaveCacheManagerData(); + + // 转换DDR和SSD数据 + TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + const vector& keys, int channelId); + + /* HBM与DDR换入换出时刷新频次信息 */ + void RefreshFreqInfoCommon(const string& embTableName, vector& keys, + TransferType type); + + bool IsKeyInSSD(const string& embTableName, emb_key_t key); + + void EvictSSDEmbedding(const string& embTableName, vector& keys); + + void PutKey(const string& embTableName, const emb_key_t& key, RecordType type); + + // DDR内每个表中emb数据频次缓存;map + unordered_map ddrKeyFreqMap; + // 每张表中非DDR内key的出现次数 + unordered_map> excludeDDRKeyCountMap; + + private: + struct EmbBaseInfo { + uint64_t maxTableSize; + vector savePath; + bool isExist; + }; + + void GetDDREmbInfo(vector& keys, + const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& ddrTransferPos, vector>& ddrEmbData); + + void UpdateDDREmbInfo(const std::string& embTableName, + vector& ddrTransferPos, + vector>& ssdEmbData); + + void RefreshRelateInfoWithDDR2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& ddrSwapOutKeys, vector& ddrSwapOutCounts); + + void RefreshRelateInfoWithSSD2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& externalSSDKeys, vector& ddrTransferPos); + + void GetSSDKeys(const std::string& embTableName, vector& externalKeys, + vector& externalSSDKeys); + + TransferRet TransferDDREmb2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + int64_t ddrSwapOutSize, const vector& keys, + vector& ddrTransferPos); + + TransferRet TransferSSDEmb2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap, + vector& externalSSDKeys, vector& ddrTransferPos, + vector>& ssdEmbData); + + void CreateSSDTableIfNotExist(const std::string& embTableName); + + void RestoreLeastFreqInfo(const std::string& embTableName, vector& ddrSwapOutKeys, + vector& ddrSwapOutCounts); + + static void HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, + EmbHashMapInfo& embHashMap); + + unordered_map embBaseInfos; + + GTEST_PRIVATE: + shared_ptr ssdEngine = std::make_shared(); + HostEmb* hostEmbs {}; + }; +} + +#endif // MXREC_CACHE_MANAGER_H diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 5b02248a..7732f75f 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -208,6 +208,7 @@ struct BatchTask { uint32_t option {}; int nBatch {}; bool noDDR { false }; + bool isSSDEnabled { false }; bool useDynamicExpansion {false}; std::vector maxStep; }; @@ -436,12 +437,14 @@ struct BatchTask { int extEmbeddingSize, bool isSave, std::vector vocabsize, - std::vector initializeInfos) + std::vector initializeInfos, + std::vector 
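+                /* vocabsize packs the [device, host, ssd] capacities in that order; see the assignments below */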
ssdDataPath) : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), - isSave(isSave), initializeInfos(initializeInfos) + isSave(isSave), initializeInfos(initializeInfos), ssdDataPath(std::move(ssdDataPath)) { devVocabSize = vocabsize[0]; hostVocabSize = vocabsize[1]; + ssdVocabSize = vocabsize[2]; } std::string name; @@ -451,7 +454,9 @@ struct BatchTask { bool isSave; size_t devVocabSize; size_t hostVocabSize; + size_t ssdVocabSize; std::vector initializeInfos; + std::vector ssdDataPath; }; struct HostEmbTable { diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 9c1b5fb2..de01f850 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -94,9 +94,10 @@ void GetEmbInfo(pybind11::module_& m) { pybind11::class_(m, "EmbInfo") .def(pybind11::init, - std::vector&>(), + std::vector&, std::vector&>(), py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("ext_embedding_size"), - py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos")) + py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos"), + py::arg("ssd_data_path")) .def_readwrite("name", &EmbInfo::name) .def_readwrite("send_count", &EmbInfo::sendCount) .def_readwrite("embedding_size", &EmbInfo::embeddingSize) @@ -104,7 +105,8 @@ void GetEmbInfo(pybind11::module_& m) .def_readwrite("is_save", &EmbInfo::isSave) .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) - .def_readwrite("initialize_infos", &EmbInfo::initializeInfos); + .def_readwrite("initialize_infos", &EmbInfo::initializeInfos) + .def_readwrite("ssd_data_path", &EmbInfo::ssdDataPath); } void GetRandomInfo(pybind11::module_& m) diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp new file mode 100644 index 00000000..3c4147dc --- /dev/null +++ b/src/tests/ssd_cache/cache_manager_test.cpp @@ -0,0 +1,352 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: ssd cache module head + * Author: MindX SDK + * Date: 2023/8/19 + */ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "host_emb/host_emb.h" +#include "ssd_cache/lfu_cache.h" +#include "ssd_cache/cache_manager.h" +#include "utils/common.h" + +using namespace std; +using namespace MxRec; +using namespace testing; + +static const string SSD_SAVE_PATH = "savePath1"; + +static const float EPSILON = 1e-6f; + +void InitSSDEngine(CacheManager& manager, string embTableName, uint64_t ssdSize) +{ + // Init ssd engine data + chrono::seconds period = chrono::seconds(120); + manager.ssdEngine->SetCompactPeriod(period); + manager.ssdEngine->SetCompactThreshold(1); + manager.ssdEngine->CreateTable(embTableName, {SSD_SAVE_PATH}, ssdSize); + vector ssdKeys = {15, 25}; // 预设15, 25存储在SSD + std::vector> ssdEmbData = {{15.0f}, + {25.0f}}; + auto& excludeMap = manager.excludeDDRKeyCountMap[embTableName]; + excludeMap[15] = 3; // 初始化次数 + excludeMap[25] = 5; + manager.ssdEngine->InsertEmbeddings(embTableName, ssdKeys, ssdEmbData); +} + +void InitDDREmbData(absl::flat_hash_map& loadData, string& embTableName, + vector& mgmtEmbInfos) +{ + // 构造 HostEmb 对象 + EmbInfo embInfo; + embInfo.name = embTableName; + embInfo.hostVocabSize = 20; + embInfo.devVocabSize = 100; + embInfo.ssdVocabSize = 100; + embInfo.ssdDataPath = {SSD_SAVE_PATH}; + mgmtEmbInfos.emplace_back(embInfo); + + std::vector> t_embData; // 以DDR vocabSize=100设置 + t_embData.assign(100, {}); + t_embData[0] = {1.0f}; + t_embData[1] = {2.0f}; + t_embData[91] = {3.0f}; + t_embData[92] = {4.0f}; + t_embData[94] = {6.0f}; + t_embData[96] = {8.0f}; + t_embData[97] = {9.0f}; + HostEmbTable hEmbTable = {embInfo, t_embData}; + loadData[embTableName] = hEmbTable; +} + +class CacheManagerTest : public testing::Test { +protected: + void SetUp() + { + cacheManager.ddrKeyFreqMap[embTableName] = cache; + cacheManager.ddrKeyFreqMap[embTableName].PutKeys(input_keys); + LFUCache cache2; + cacheManager.ddrKeyFreqMap[embTableName2] = cache2; + cacheManager.ddrKeyFreqMap[embTableName2].PutKeys(input_keys); + unordered_map excludeDDRKeyFreq; + excludeDDRKeyFreq[27] = 10; + excludeDDRKeyFreq[30] = 10; + cacheManager.excludeDDRKeyCountMap[embTableName] = excludeDDRKeyFreq; + + // init cache manager + vector mgmtEmbInfos; + absl::flat_hash_map loadData = {}; + InitDDREmbData(loadData, embTableName, mgmtEmbInfos); + InitDDREmbData(loadData, embTableName2, mgmtEmbInfos); + + cacheManager.Init(hEmb, mgmtEmbInfos); + + InitSSDEngine(cacheManager, embTableName, 5); + InitSSDEngine(cacheManager, embTableName2, 10); + // load ddr emb data + cacheManager.hostEmbs->hostEmbs = loadData; + + auto& embMap = cacheManager.hostEmbs->hostEmbs; + + // 设置全局rankId,ssdEngine保存时会使用 + int workRankId; + MPI_Comm_rank(MPI_COMM_WORLD, &workRankId); + g_rankId = to_string(workRankId); + } + + CacheManager cacheManager; + LFUCache cache; + /* + * 频次-对应key列表 + * 1 - 9,8 + * 2 - 6,4 + * 3 - 3,2,1 + */ + vector input_keys = {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 6, 6, 8, 9}; + string embTableName = "table1"; + string embTableName2 = "table2"; + HostEmb* hEmb = Singleton::GetInstance(); + + void TearDown() + { + } +}; + +TEST_F(CacheManagerTest, RefreshFreqInfo) +{ + vector ddr2HbmKeys = {8, 9}; + cacheManager.RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); + ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2); + ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].keyTable.size(), 5); + 
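+    /* After moving 8 and 9 (the only freq-1 keys) out of the DDR LFU, {1,2,3}
+     * remain at freq 3 and {4,6} at freq 2: hence 5 keys, 2 buckets, minFreq 2. */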
+
+TEST_F(CacheManagerTest, RefreshFreqInfo)
+{
+    vector ddr2HbmKeys = {8, 9};
+    cacheManager.RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].keyTable.size(), 5);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable.size(), 2);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(8), -1);
+    ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName].size(), 6);
+
+    // construct frequency data for the HBM -> DDR transfer
+    cacheManager.excludeDDRKeyCountMap[embTableName][150] = 4;
+    cacheManager.excludeDDRKeyCountMap[embTableName][151] = 1;
+    vector hbm2DdrKeys = {150, 151};
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), -1);
+    cacheManager.RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(150), 4);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), 1);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1);
+    ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName].size(), 6);
+
+    vector ddr2EvictKeys = {151};
+    cacheManager.RefreshFreqInfoCommon(embTableName, ddr2EvictKeys, TransferType::DDR_2_EVICT);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(151), -1);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable.size(), 3);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2);
+
+    // HBM2Evict
+    cacheManager.excludeDDRKeyCountMap[embTableName][160] = 1;
+    vector hbm2EvictKeys = {160};
+    cacheManager.RefreshFreqInfoCommon(embTableName, hbm2EvictKeys, TransferType::HBM_2_EVICT);
+    const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(160);
+    ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end());
+    LOG(INFO) << "test RefreshFreqInfo end.";
+}
+
+TEST_F(CacheManagerTest, PutKey)
+{
+    vector putDDRKeys = {1, 9, 8, 15};
+    for (auto& key : putDDRKeys) {
+        cacheManager.PutKey(embTableName, key, RecordType::DDR);
+    }
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable[1].size(), 1);
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(15), 1);
+    LOG(INFO) << "test PutKey end.";
+}
+
+TEST_F(CacheManagerTest, IsKeyInSSD)
+{
+    vector checkKeys = {1, 2, 15, 25};
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, checkKeys[0]));
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, checkKeys[1]));
+    ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[2]));
+    ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[3]));
+    LOG(INFO) << "test IsKeyInSSD end.";
+}
+
+TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalKey)
+{
+    EmbHashMapInfo embHashMapInfo;
+    vector currentKeys = {55, 65, 75};
+    embHashMapInfo.hostHashMap[55] = 119;
+    embHashMapInfo.hostHashMap[65] = 118;
+    embHashMapInfo.hostHashMap[75] = 116;
+    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
+    LOG(INFO) << "test TransferDDREmbWithSSDByEmptyExternalKey end.";
+}
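The transfer tests that follow all hinge on one capacity rule, spelled out in their comments: the HBM+DDR map can absorb devVocabSize + hostVocabSize - maxOffset more keys, plus the DDR eviction list when training (for the setup below: 20 + 100 - 118 + 1 = 3). A hedged restatement, with a helper name of my own rather than one from the source:

    #include <cstddef>

    // Hypothetical helper restating the free-slot arithmetic these tests rely on.
    static size_t AvailableOffsetSlots(size_t devVocabSize, size_t hostVocabSize,
                                       size_t maxOffset, size_t evictPosSize, bool isTraining)
    {
        size_t freeSlots = devVocabSize + hostVocabSize - maxOffset;
        return isTraining ? freeSlots + evictPosSize : freeSlots;
    }

    // AvailableOffsetSlots(20, 100, 118, 1, true) == 3, so a training batch bringing in
    // more than 3 external keys must spill DDR residents to SSD or report an error.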
+
+TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess)
+{
+    vector ssdKeys = {15, 25};
+    vector> ssdKeyEmbInfo = {{1.5f}, {2.5f}};
+
+    // init EmbHashMapInfo
+    EmbHashMapInfo embHashMapInfo;
+    embHashMapInfo.devVocabSize = 20;
+    embHashMapInfo.hostVocabSize = 100;
+    embHashMapInfo.maxOffset = 118;           // 2 free slots left (in DDR, relative positions 98 and 99)
+    embHashMapInfo.evictPos.emplace_back(110); // eviction list
+
+    // build key -> offset mappings for keys already stored in DDR; DDR offsets occupy the range 20..119 of the map
+    embHashMapInfo.hostHashMap[9] = 117; // relative position in DDR: 97
+    embHashMapInfo.hostHashMap[8] = 116; // relative position in DDR: 96
+    embHashMapInfo.hostHashMap[6] = 114; // relative position in DDR: 94
+    embHashMapInfo.hostHashMap[4] = 112; // relative position in DDR: 92
+    embHashMapInfo.hostHashMap[3] = 111; // relative position in DDR: 91
+    embHashMapInfo.hostHashMap[2] = 21;  // relative position in DDR: 1
+    embHashMapInfo.hostHashMap[1] = 20;  // relative position in DDR: 0
+
+    // verify the constructed data
+    auto& embMap = cacheManager.hostEmbs->hostEmbs;
+    const auto& it = embMap.find(embTableName);
+    auto& hostData = it->second.embData;
+    ASSERT_TRUE(fabs(hostData[0][0] - 1.0f) < EPSILON);
+    ASSERT_TRUE(fabs(hostData[1][0] - 2.0f) < EPSILON);
+    ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON);
+    ASSERT_TRUE(fabs(hostData[97][0] - 9.0f) < EPSILON);
+    auto& excludeKeyCountMap = cacheManager.excludeDDRKeyCountMap[embTableName];
+    ASSERT_EQ(excludeKeyCountMap[15], 3);
+    ASSERT_EQ(excludeKeyCountMap[25], 5);
+    ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 9));
+    ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 8));
+    ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 15));
+
+    LOG(INFO) << "check detail data before transfer ok.";
+
+    // externalKeys: SSD(15, 25) + newKey(55, 65, 75)
+    // training scenario: offsetAvailableSize = 20 + 100 - 118 + evictPos.size() = 3
+    // frequency order in cacheManager (low -> high): 9 8 6 4 3 2 1
+    // construct a batch that exceeds the available SSD capacity
+    vector exceedKeys = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115};
+    auto spaceError1 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, exceedKeys, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(spaceError1, TransferRet::SSD_SPACE_NOT_ENOUGH);
+
+    // training + SSD capacity exceeded + current batch contains no keys stored in SSD
+    vector keys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115, 125, 135};
+    auto spaceError2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, keys2, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(spaceError2, TransferRet::SSD_SPACE_NOT_ENOUGH);
+
+    // current batch keys by location: SSD(15, 25), DDR(6, 4), new keys(55, 65, 75)
+    vector currentKeys = {15, 25, 6, 4, 55, 65, 75};
+    // 4 keys must move from DDR to SSD; low-frequency keys 6 and 4 are in the current batch and stay put,
+    // so the keys actually moved are 9, 8, 3, 2
+    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+
+    // verify the data after the transfer
+    ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
+    ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON);  // data that stayed in DDR
+    ASSERT_TRUE(fabs(hostData[96][0] - 25.0f) < EPSILON); // data moved from SSD into DDR
+    ASSERT_TRUE(fabs(hostData[97][0] - 15.0f) < EPSILON); // data moved from SSD into DDR
+    ASSERT_EQ(embHashMapInfo.evictPos.size(), 1);
+    ASSERT_EQ(embHashMapInfo.evictPos.back(), 110);
+
+    // the lowest-frequency DDR keys (9, 8; count 1) moved to SSD and the SSD keys (15, 25; counts 3, 5)
+    // moved in, so the DDR minFreq becomes 2
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2);
+    ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 9));
+    ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 8));
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15));
+    LOG(INFO) << "test TransferDDREmbWithSSDByAllProcess end.";
+}
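The floating-point checks here compare embedding values with fabs(a - b) < EPSILON, EPSILON being 1e-6f. The idiom is equivalent to a tiny helper like the following (a sketch for readability; the tests inline the fabs call directly):

    #include <cmath>

    // Equivalent of the inline fabs(...) < EPSILON assertions in this file.
    inline bool AlmostEqual(float lhs, float rhs, float eps = 1e-6f)
    {
        return std::fabs(lhs - rhs) < eps;
    }
    // e.g. ASSERT_TRUE(AlmostEqual(hostData[96][0], 25.0f));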
+
+TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalSSDKey)
+{
+    // train + eval: DDR has enough free space and externalSSDKeys is empty
+    EmbHashMapInfo embHashMapInfo;
+    embHashMapInfo.devVocabSize = 20;
+    embHashMapInfo.hostVocabSize = 100;
+    embHashMapInfo.hostHashMap[6] = 114; // relative position in DDR: 94
+    embHashMapInfo.hostHashMap[4] = 112; // relative position in DDR: 92
+    // 3 free slots (2 free in DDR at relative positions 98 and 99, plus 1 in the DDR eviction list)
+    embHashMapInfo.maxOffset = 118;
+    embHashMapInfo.evictPos.emplace_back(110);
+    vector currentKeys = {6, 4, 55, 65, 75};
+    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
+    auto retByEval = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, EVAL_CHANNEL_ID);
+    ASSERT_EQ(retByEval, TransferRet::TRANSFER_OK);
+
+    // eval scenario: DDR free space insufficient, externalSSDKeys empty
+    vector currentKeys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115};
+    auto ret2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, EVAL_CHANNEL_ID);
+    ASSERT_EQ(ret2, TransferRet::TRANSFER_OK);
+    // training scenario returns SSD_SPACE_NOT_ENOUGH
+    auto ret3 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(ret3, TransferRet::SSD_SPACE_NOT_ENOUGH);
+    LOG(INFO) << "test TransferDDREmbWithSSDByEmptyExternalSSDKey end.";
+}
+
+TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval)
+{
+    // eval + enough DDR space + empty externalSSDKeys
+    EmbHashMapInfo embHashMapInfo;
+    embHashMapInfo.devVocabSize = 20;
+    embHashMapInfo.hostVocabSize = 100;
+    embHashMapInfo.hostHashMap[9] = 117; // relative position in DDR: 97
+    embHashMapInfo.hostHashMap[8] = 116; // relative position in DDR: 96
+    embHashMapInfo.hostHashMap[6] = 114; // relative position in DDR: 94
+    embHashMapInfo.hostHashMap[4] = 112; // relative position in DDR: 92
+    // 3 free slots (2 free in DDR at relative positions 98 and 99, plus 1 in the DDR eviction list)
+    embHashMapInfo.maxOffset = 118;
+    embHashMapInfo.evictPos.emplace_back(110); // eviction list
+    vector currentKeys = {6, 4, 55, 65, 75};
+    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, EVAL_CHANNEL_ID);
+    ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
+    LOG(INFO) << "test eval+space enough+externalSSDKeysEmpty ok.";
+
+    // eval + enough DDR space + non-empty externalSSDKeys
+    vector currentKeys2 = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115};
+    auto ret2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, EVAL_CHANNEL_ID);
+    ASSERT_EQ(ret2, TransferRet::TRANSFER_OK);
+    // verify the data after the transfer
+    const auto& it = cacheManager.hostEmbs->hostEmbs.find(embTableName);
+    auto& hostData = it->second.embData;
+    ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON);  // data that stayed in DDR
+    ASSERT_TRUE(fabs(hostData[98][0] - 25.0f) < EPSILON); // data moved from SSD into DDR
+    ASSERT_TRUE(fabs(hostData[90][0] - 15.0f) < EPSILON); // data moved from SSD into DDR
+    ASSERT_EQ(embHashMapInfo.evictPos.size(), 0);
+    // eval does not move DDR keys out to SSD: keys 9 and 8 (count 1) stay in DDR while
+    // the SSD keys 15 and 25 (counts 3, 5) move in, so minFreq stays 1
+    ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1);
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 9));
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 8));
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15));
+    LOG(INFO) << "test eval+space enough+externalSSDKeysNotEmpty ok.";
+}
+
+TEST_F(CacheManagerTest, TransferDDREmbWithSSDByDDRSpaceNotEnough)
+{
+    // construct a batch larger than all remaining DDR space
+    EmbHashMapInfo embHashMapInfo;
+    embHashMapInfo.devVocabSize = 20;
+    embHashMapInfo.hostVocabSize = 10;
+    embHashMapInfo.maxOffset = 30;
+    embHashMapInfo.hostHashMap[6] = 9;
+    embHashMapInfo.hostHashMap[4] = 8;
+    // keys size:10, ddr keys:2 externalKeys:8 externalSSDKeys:0
+    vector currentKeys = {6, 4, 101, 102, 103, 104, 105, 106, 107, 108};
+    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName2, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    ASSERT_EQ(ret, TransferRet::DDR_SPACE_NOT_ENOUGH);
+    LOG(INFO) << "test train+ddr space not enough+externalSSDKeysEmpty ok.";
+}
+
+TEST_F(CacheManagerTest, EvictSSDEmbedding)
+{
+    // keys 15 and 25 were stored in SSD during setup
+    emb_key_t key = 15;
+    vector ssdKeys = {key};
+    cacheManager.EvictSSDEmbedding(embTableName, ssdKeys);
+    ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, key));
+    const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(key);
+    ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end());
+    LOG(INFO) << "test EvictSSDEmbedding end.";
+}
\ No newline at end of file
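This patch widens EmbInfo to describe a three-tier table: vocab_size now carries the device (HBM), host (DDR), and SSD capacities in order, and ssd_data_path lists the directories backing the SSD tier (both also exposed through the pybind binding above). A construction sketch, using a local mirror struct so the snippet stands alone (field values illustrative only):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Local mirror of the fields this patch adds/uses; not the real EmbInfo.
    struct EmbInfoSketch {
        std::string name;
        size_t devVocabSize { 0 };   // vocab_size[0]: keys resident in HBM
        size_t hostVocabSize { 0 };  // vocab_size[1]: keys resident in DDR
        size_t ssdVocabSize { 0 };   // vocab_size[2]: keys resident on SSD
        std::vector<std::string> ssdDataPath;
    };

    EmbInfoSketch MakeExampleTable()
    {
        EmbInfoSketch info;
        info.name = "table1";
        std::vector<size_t> vocabSize { 100, 20, 100 };  // dev, host, ssd
        info.devVocabSize = vocabSize[0];
        info.hostVocabSize = vocabSize[1];
        info.ssdVocabSize = vocabSize[2];
        info.ssdDataPath = { "savePath1" };              // one or more backing directories
        return info;
    }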
-- Gitee
From 9e778f3b186094ec126c34fe635036ddbc9bd301 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sat, 26 Aug 2023 16:02:30 +0800
Subject: [PATCH 295/551] Match-id-662ed2143e45e4514a05092b2ac0d330da993cfc

---
 src/core/emb_hashmap/emb_hashmap.cpp       | 55 ++++++++++++++++--
 src/core/emb_hashmap/emb_hashmap.h         | 15 ++++-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp       | 65 ++++++++++++++++++++--
 src/core/hybrid_mgmt/hybrid_mgmt.h         | 13 ++++-
 src/core/ssd_cache/cache_manager.cpp       | 22 ++++++++
 src/core/ssd_cache/cache_manager.h         |  4 +-
 src/tests/ssd_cache/cache_manager_test.cpp | 33 +++++++++++
 7 files changed, 193 insertions(+), 14 deletions(-)

diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index b91f31b4..67d0da80 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -45,6 +45,12 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos,
 #endif
 }

+inline void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap)
+{
+    embHashMap.swapPos.clear();
+    embHashMap.lookUpVec.clear();
+}
+
 /// In DDR mode, compute the features' offsets, swap info, etc.
 /// \param embName table name
 /// \param keys lookup keys
@@ -74,6 +80,9 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t
     }
     VLOG(GLOG_DEBUG) << "FindOffset end";

+    // refresh the frequency statistics
+    RefreshFreqInfoWithSwap(embName, embHashMap.oldSwap);
+
     swapId++;

     EASY_BLOCK("hostHashMaps->tdt")
@@ -107,8 +116,7 @@ ...
     }

     // clear the lookup and swap offsets recorded for this batch
-    embHashMap.swapPos.clear();
-    embHashMap.lookUpVec.clear();
+    ClearLookupAndSwapOffset(embHashMap);
     LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(),
                               embHashMap.maxOffset, embHashMap.devVocabSize, embHashMap.hostVocabSize);
     ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 }));
@@ -296,7 +304,8 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector&
     EASY_FUNCTION()
     size_t keySize = keys.size();
     auto& embHashMap = embHashMaps.at(embName);
-
+    vector evictHBMRKeys;
+    vector evictDDRKeys;
     for (size_t i = 0; i < keySize; i++) {
         size_t offset;
         auto key = keys[i];
@@ -319,10 +328,14 @@ ...
             embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]);
             embHashMap.devOffset2Key[offset] = -1;
             embHashMap.evictDevPos.emplace_back(offset);
+            evictHBMRKeys.emplace_back(key);
         } else {
-            embHashMap.evictPos.emplace_back(offset - embHashMap.devVocabSize);
+            embHashMap.evictPos.emplace_back(offset);
+            evictDDRKeys.emplace_back(key);
         }
     }
+    cacheManager->RefreshFreqInfoCommon(embName, evictHBMRKeys, TransferType::HBM_2_EVICT);
+    cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT);

     LOG(INFO) << StringFormat(
         "ddr EvictDeleteEmb, emb: [%s], hostEvictSize: %d, devEvictSize: %d ",
@@ -364,10 +377,12 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys
             embHashMap.lookUpVec.emplace_back(offset);
             embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast(embHashMap.devOffset2Key[offset]));
             embHashMap.devOffset2Key[offset] = key;
+            AddKeyFreqInfo(embName, key, RecordType::NOT_DDR);
         } else {
             // offset beyond HBM capacity: record the offset into host emb; find the HBM offset to swap with
             embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize);
             FindSwapPosOld(embName, key, offset, currentBatchId, keepBatchId);
+            AddKeyFreqInfo(embName, key, RecordType::DDR);
         }
     }
     if (currentBatchId == 0) {
@@ -533,4 +548,36 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos
     }
     return true;
 }
+
+/// Refresh frequency info when HBM and DDR swap keys
+/// \param embName emb table name
+/// \param oldSwap swapped key pairs; in each pair, oldKey is the key moved out of HBM and key is the key moved out of DDR
+void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName,
+    const std::vector>& oldSwap)
+{
+    if (!isSSDEnabled) {
+        return;
+    }
+    vector enterDDRKeys;
+    vector leaveDDRKeys;
+    for (auto keyPair : oldSwap) {
+        enterDDRKeys.emplace_back(keyPair.first);
+        leaveDDRKeys.emplace_back(keyPair.second);
+    }
+    cacheManager->RefreshFreqInfoCommon(embName, enterDDRKeys, TransferType::HBM_2_DDR);
+    cacheManager->RefreshFreqInfoCommon(embName, leaveDDRKeys, TransferType::DDR_2_HBM);
+}
+
+/// Record key frequency data
+/// \param embTableName emb table name
+/// \param key key
+/// \param type record type enum
+void EmbHashMap::AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type)
+{
+    if (!isSSDEnabled) {
+        return;
+    }
+    cacheManager->PutKey(embTableName, key, type);
+}
+
 #endif
diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h
index c8565a31..6a1fa414 100644
--- a/src/core/emb_hashmap/emb_hashmap.h
+++ b/src/core/emb_hashmap/emb_hashmap.h
@@ -13,6 +13,7 @@
 #include
 #include "absl/container/flat_hash_map.h"
 #include "host_emb/host_emb.h"
+#include "ssd_cache/cache_manager.h"
 #include "utils/common.h"

 namespace MxRec {
@@ -72,14 +73,22 @@ ...
             return embHashMaps.at(embName).evictPos;
         }

-    private:
-        RankInfo rankInfo;
-        int swapId { 0 };
+        bool isSSDEnabled { false };
+        CacheManager* cacheManager;

+    private:
         void FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize,
                                   EmbHashMapInfo& embHashMap) const;

         int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap);
+
+        void RefreshFreqInfoWithSwap(const string& embName,
+            const std::vector>& oldSwap);
+
+        void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type);
+
+        RankInfo rankInfo;
+        int swapId { 0 };
     };
 }
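The new hooks keep two frequency books in step with key placement: ddrKeyFreqMap (an LFU over keys resident in DDR) and excludeDDRKeyCountMap (plain counts for keys living in HBM or SSD). Judging by what the tests in PATCH 294 assert, the four TransferType transitions shuffle entries between the two as sketched below; this is a simplified stand-in for CacheManager::RefreshFreqInfoCommon, not the real implementation:

    #include <cstdint>
    #include <unordered_map>

    enum class TransferType { DDR_2_HBM, HBM_2_DDR, DDR_2_EVICT, HBM_2_EVICT };

    // Simplified stand-in for the two CacheManager structures (counts only, no LFU ordering).
    struct FreqBooks {
        std::unordered_map<int64_t, int> ddrFreq;       // keys resident in DDR (LFUCache in the real code)
        std::unordered_map<int64_t, int> excludeCount;  // keys outside DDR (HBM or SSD)

        void Move(int64_t key, TransferType type)
        {
            switch (type) {
                case TransferType::DDR_2_HBM:    // count survives; ownership changes
                    excludeCount[key] = ddrFreq[key];
                    ddrFreq.erase(key);
                    break;
                case TransferType::HBM_2_DDR:    // restore the count into the DDR book
                    ddrFreq[key] = excludeCount[key];
                    excludeCount.erase(key);
                    break;
                case TransferType::DDR_2_EVICT: ddrFreq.erase(key); break;
                case TransferType::HBM_2_EVICT: excludeCount.erase(key); break;
            }
        }
    };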
"PrepareDDRData: SSD available space is not enough."; + throw runtime_error("ssdVocabSize too small"); + } + if (prepareSSDRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { + LOG(ERROR) << "PrepareDDRData: DDR available space is not enough."; + throw runtime_error("ddrVocabSize too small"); + } + throw runtime_error("Transfer embedding with DDR and SSD error."); +} + #ifndef GTEST /// 构造训练所需的各种向量数据 /// \param embName 表名 @@ -685,9 +711,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto& embHashMap = hostHashMaps->embHashMaps.at(embName); // 进行新一批预取数据时,计数初始化 - if (iBatch == 0) { - embHashMap.SetStartCount(); - } + if (iBatch == 0) { embHashMap.SetStartCount(); } // 获取查询向量 auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); @@ -702,6 +726,9 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); + // 调用SSD cache缓存处理流程 + PrepareDDRData(embName, embHashMap, lookupKeys, channelId); + // 计算查询向量;记录需要被换出的HBM偏移 vector tmpData; vector offsetsOut; @@ -836,6 +863,7 @@ bool HybridMgmt::Evict() } else { for (auto evict : evictKeyMap) { EvictKeys(evict.first, evict.second); + EvictSSDKeys(evict.first, evict.second); } } return true; @@ -885,3 +913,32 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) hdTransfer->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); #endif } + +inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInfo& embHashMap, + const vector& keys, int channelId) +{ + if (!isSSDEnabled) { + return; + } + VLOG(GLOG_DEBUG) << "PrepareDDRData start."; + TimeCost prepareDDRDataTc; + TransferRet ret = cacheManager->TransferDDREmbWithSSD(embTableName, embHashMap, keys, channelId); + if (ret != TransferRet::TRANSFER_OK) { + HandlePrepareDDRDataRet(ret); + } + VLOG(GLOG_DEBUG) << StringFormat("PrepareDDRData end, TimeCost(ms):%d", prepareDDRDataTc.ElapsedMS()); +} + +void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) +{ + if (!isSSDEnabled) { + return; + } + vector ssdKeys; + for (auto& key : keys) { + if (cacheManager->IsKeyInSSD(embName, key)) { + ssdKeys.emplace_back(key); + } + } + cacheManager->EvictSSDEmbedding(embName, ssdKeys); +} diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 19dec09c..3338ec15 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -25,6 +25,7 @@ #include "emb_hashmap/emb_hashmap.h" #include "hd_transfer/hd_transfer.h" #include "key_process/key_process.h" +#include "ssd_cache/cache_manager.h" namespace MxRec { using namespace std; @@ -80,6 +81,9 @@ namespace MxRec { for (auto& t : procThreads) { t->join(); } + if (cacheManager != nullptr) { + cacheManager = nullptr; + } if (hostEmbs != nullptr) { hostEmbs->Join(TRAIN_CHANNEL_ID); hostEmbs->Join(EVAL_CHANNEL_ID); @@ -113,6 +117,11 @@ namespace MxRec { void InitRankInfo(RankInfo& rankInfo, const vector& embInfos); + void EvictSSDKeys(const string& embName, const vector& keys); + + void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, + const vector& keys, int channelId); + private: int currentBatchId; int trainBatchId = 0; // 0-199, 200- @@ -120,12 +129,14 @@ namespace MxRec { int sendBatchId; vector mgmtEmbInfo; RankInfo mgmtRankInfo; - unique_ptr hostEmbs {}; + CacheManager* cacheManager; + HostEmb* 
+        HostEmb* hostEmbs {};
         unique_ptr hostHashMaps {};
         vector> procThreads {};
         map> evictKeyMap {};
         KeyProcess *preprocess;
         HDTransfer *hdTransfer;
+        bool isSSDEnabled { false };
         bool isRunning;
         bool isLoad { false };
diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp
index 51ff1e06..bc521e85 100644
--- a/src/core/ssd_cache/cache_manager.cpp
+++ b/src/core/ssd_cache/cache_manager.cpp
@@ -447,3 +447,25 @@ CacheManager::~CacheManager()
     ddrKeyFreqMap.clear();
     excludeDDRKeyCountMap.clear();
 }
+
+/// Load data into CacheManager
+/// \param ddrFreqInitMap key frequency data for keys in DDR
+/// \param excludeDdrFreqInitMap key frequency data for keys outside DDR
+void CacheManager::Load(unordered_map>& ddrFreqInitMap,
+    unordered_map>& excludeDdrFreqInitMap)
+{
+    for (auto& it : ddrFreqInitMap) {
+        auto& embTableName = it.first;
+        auto& freqMap = it.second;
+        for (auto& freqIt : freqMap) {
+            ddrKeyFreqMap[embTableName].PutWithInit(freqIt.first, freqIt.second);
+        }
+    }
+    for (auto& it : excludeDdrFreqInitMap) {
+        auto& embTableName = it.first;
+        auto& freqMap = it.second;
+        for (auto& freqIt : freqMap) {
+            excludeDDRKeyCountMap[embTableName].emplace(freqIt.first, freqIt.second);
+        }
+    }
+}
diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h
index c1145584..ac62e30c 100644
--- a/src/core/ssd_cache/cache_manager.h
+++ b/src/core/ssd_cache/cache_manager.h
@@ -48,8 +48,8 @@ ...
         void Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo);

-        // save/initialize module data
-        bool SaveCacheManagerData();
+        void Load(unordered_map>& ddrFreqInitMap,
+            unordered_map>& excludeDdrFreqInitMap);

         // transfer data between DDR and SSD
         TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap,
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 3c4147dc..08aee0d3 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -9,6 +9,8 @@
 #include
 #include

+#define GTEST
+
 #include "absl/container/flat_hash_map.h"
 #include "host_emb/host_emb.h"
 #include "ssd_cache/lfu_cache.h"
@@ -349,4 +351,35 @@ TEST_F(CacheManagerTest, EvictSSDEmbedding)
     const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(key);
     ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end());
     LOG(INFO) << "test EvictSSDEmbedding end.";
+}
+
+TEST_F(CacheManagerTest, LoadTest)
+{
+    cacheManager.ddrKeyFreqMap.clear();
+    cacheManager.excludeDDRKeyCountMap.clear();
+    unordered_map> ddrMap;
+    string embTableName = "table1";
+    unordered_map ddrTableMap;
+    ddrTableMap.emplace(1, 3);
+    ddrTableMap.emplace(2, 3);
+    ddrTableMap.emplace(3, 3);
+    ddrTableMap.emplace(4, 2);
+    ddrTableMap.emplace(6, 2);
+    ddrTableMap.emplace(8, 1);
+    ddrTableMap.emplace(9, 1);
+    ddrMap.emplace(embTableName, ddrTableMap);
+    unordered_map> excludeDdrMap;
+    unordered_map excludeDdrTableMap;
+    excludeDdrTableMap.emplace(15, 1);
+    excludeDdrTableMap.emplace(25, 5);
+    excludeDdrMap.emplace(embTableName, excludeDdrTableMap);
+    cacheManager.Load(ddrMap, excludeDdrMap);
+    // verify the data
+    auto& ddrKeyFreqMap = cacheManager.ddrKeyFreqMap;
+    auto& excludeDDRKeyCountMap = cacheManager.excludeDDRKeyCountMap;
+    ASSERT_EQ(ddrKeyFreqMap[embTableName].minFreq, 1);
+    ASSERT_EQ(ddrKeyFreqMap[embTableName].freqTable.size(), 3);
+    ASSERT_EQ(ddrKeyFreqMap[embTableName].Get(2), 3);
+    ASSERT_EQ(ddrKeyFreqMap[embTableName].Get(12), -1);
+    ASSERT_EQ(excludeDDRKeyCountMap[embTableName][25], 5);
 }
\ No newline at end of file
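PATCH 295 also changes ownership: HostEmb is no longer a unique_ptr held by HybridMgmt but a process-wide object fetched through Singleton<HostEmb>::GetInstance() and shared with CacheManager, which is why the destructor merely nulls the cacheManager pointer instead of deleting it. The Singleton utility is not shown in this series; it presumably resembles a Meyers singleton, sketched here purely as an assumption for orientation:

    // Assumed shape of the Singleton<T> helper used above; the repo's real utility may differ.
    template <typename T>
    class Singleton {
    public:
        static T* GetInstance()
        {
            static T instance;  // constructed once; initialization is thread-safe since C++11
            return &instance;
        }
        Singleton() = delete;   // access only through GetInstance()
    };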
-- Gitee
From 527dc2361a8e358168347d071fb73ff27e57d9aa Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 28 Aug 2023 10:20:06 +0800
Subject: [PATCH 296/551] Match-id-9e9fe556c2ebc2313dd3af415848cad7ae0ee0d0

---
 src/core/ssd_engine/table.cpp | 5 ++++-
 src/core/ssd_engine/table.h   | 5 +++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp
index 9d5163b4..d9575428 100644
--- a/src/core/ssd_engine/table.cpp
+++ b/src/core/ssd_engine/table.cpp
@@ -21,7 +21,7 @@ Table::Table(const string &name, vector &savePaths, uint64_t maxTableSiz
 {
     curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + g_rankId + "/" + name).string();
     if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) {
-        throw runtime_error("fail to create table directory");
+        throw runtime_error("fail to create table directory");
     }
     LOG(INFO) << StringFormat("create table:%s at path:%s", name.c_str(), curTablePath.c_str());
 }
@@ -190,6 +190,9 @@ void Table::Load(const string &metaFilePath, int step)
     // Load table name and validate
     uint32_t nameSize;
     metaFile->read(reinterpret_cast(&nameSize), sizeof(nameSize));
+    if (nameSize > maxNameSize) {
+        throw invalid_argument("table name too large, file may be broken");
+    }
     char *tmpArr = new char[nameSize + 1];
     metaFile->read(tmpArr, static_cast(nameSize));
     tmpArr[nameSize] = '\0';
diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h
index de5327e5..4d454224 100644
--- a/src/core/ssd_engine/table.h
+++ b/src/core/ssd_engine/table.h
@@ -62,8 +62,9 @@ namespace MxRec {
         mutex rwLock{};
         shared_ptr curFile = nullptr;
         uint64_t curMaxFileID = 0; // no concurrent writing, always increases atomically
+        const uint32_t maxNameSize = 1024;

-        /* args for performance
+        /* args for performance (not exposed to users yet)
          * 2 read thread is optimal when:
          *   embedding's dimension=240, maxDataNumInFile=10000
          *   fetch 1000000 keys at a time
         int readThreadNum = 2;
         uint32_t maxDataNumInFile = 10000; // relaxed constraint for performance; needs tuning
         double compactThreshold = 0.5;
-        double diskFreeSpaceThreshold = 0.05;
+        double diskFreeSpaceThreshold = 0.05; // in [0, 1); keeps diskFreeSpaceThreshold*100 % of the disk free
     };
 }
-- Gitee
From c8375ffd5d5e11d71e70cc91ece0f54e57f1bcca Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 28 Aug 2023 20:40:44 +0800
Subject: [PATCH 297/551] Match-id-474ec8550ed56cc5ad039632cdd0da364d913bff

---
 src/core/checkpoint/checkpoint.cpp | 31 ++++++++++++++++++++----------
 src/core/checkpoint/checkpoint.h   |  8 ++++++--
 2 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index 92216402..b943f642 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -143,14 +143,17 @@ void Checkpoint::MakeSaveDir(const string& dirName)
     }
 }

-int Checkpoint::GetEmbeddingSize(const string& embName) const
+Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName)
 {
+    EmbSizeInfo embSizeInfo;
     for (const auto &embInfo: mgmtEmbInfo) {
         if (embInfo.name == embName) {
-            return embInfo.extEmbeddingSize;
+            embSizeInfo.embSize = embInfo.embeddingSize;
+            embSizeInfo.extEmbSize = embInfo.extEmbeddingSize;
+            return embSizeInfo;
         }
     }
-    return 0;
+    return embSizeInfo;
 }

 bool Checkpoint::CheckEmbNames(const string& embName)
@@ -185,10 +188,10 @@ void Checkpoint::SaveDataset(const vector& embNames,
     if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) {
dirSeparator + "key_embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; - auto embeddingSize = GetEmbeddingSize(embName); + auto embeddingSizeInfo = GetEmbeddingSize(embName); MakeSaveDir(embedPath); VLOG(GLOG_DEBUG) << StringFormat("====Start saving embedding data to: %s", datasetDir.c_str()); - WriteEmbedding(transData, embedDatasetDir, embeddingSize); + WriteEmbedding(transData, embedDatasetDir, embeddingSizeInfo.extEmbSize); } VLOG(GLOG_DEBUG) << StringFormat("====Start saving data to: %s", datasetDir.c_str()); @@ -231,7 +234,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da writeFile.close(); } -void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) +void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName) { std::ifstream readFile; readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); @@ -270,18 +273,26 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir) float *floatPtr = static_cast(newBlock); auto &transArr = transData.int64Arr; - for (size_t i{0}; i < transArr.size(); i += embHashNum) { + EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); + if (embSizeInfo.embSize == 0) { + throw runtime_error(StringFormat("embsize is 0").c_str()); + } + auto keyAddrElem = embSizeInfo.extEmbSize / embSizeInfo.embSize - 1; + if (keyAddrElem < 0) { + throw runtime_error(StringFormat("keyAddrElem: %d is less than 0", keyAddrElem).c_str()); + } + for (size_t i{0}, j{0}; i < transArr.size(); i += keyAddrElem, ++j) { vector row(embeddingSize); readFile.read((char *) (row.data()), embeddingSize * sizeof(float)); - aclError ret = aclrtMemcpy(floatPtr + i * embeddingSize, embeddingSize * sizeof(float), + aclError ret = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); readFile.close(); throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } - int64_t address = reinterpret_cast(floatPtr + i * embeddingSize); + int64_t address = reinterpret_cast(floatPtr + j * embeddingSize); transArr.at(i + 1) = address; } #endif @@ -412,7 +423,7 @@ void Checkpoint::LoadDataset(const vector& embNames, auto embedPath { dataDir + dirSeparator + "key_embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; VLOG(GLOG_DEBUG) << StringFormat("====Start loading embedding data from: %s", datasetDir.c_str()); - ReadEmbedding(transData, embedDatasetDir); + ReadEmbedding(transData, embedDatasetDir, embName); } VLOG(GLOG_DEBUG) << StringFormat( diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 481a8380..c860c29e 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -81,9 +81,13 @@ namespace MxRec { void WriteDataset(CkptTransData& transData, ofstream& writeFile, size_t writeSize, CkptDataType dataType, size_t idx); void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); - void ReadEmbedding(CkptTransData& transData, const string& dataDir); + void ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName); - int GetEmbeddingSize(const string& embName) const; + struct EmbSizeInfo { + int embSize; + int extEmbSize; 
// embSize + (optimizer's slot) * embSize + }; + EmbSizeInfo GetEmbeddingSize(const string& embName); bool CheckEmbNames(const string& embNames); void LoadProcess(CkptData& ckptData); -- Gitee From 5fe793ffb0fdb9aa5579d130344a418d735cbcb8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 28 Aug 2023 22:19:44 +0800 Subject: [PATCH 298/551] Match-id-ef79bb890fac7f02628dfe26334a27587fd0091a --- src/core/emb_hashmap/emb_hashmap.cpp | 7 ++- src/core/ssd_cache/cache_manager.cpp | 66 ++++++++++++++-------- src/tests/ssd_cache/cache_manager_test.cpp | 2 - 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 67d0da80..0d20278d 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -304,7 +304,7 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& EASY_FUNCTION() size_t keySize = keys.size(); auto& embHashMap = embHashMaps.at(embName); - vector evictHBMRKeys; + vector evictHBMKeys; vector evictDDRKeys; for (size_t i = 0; i < keySize; i++) { size_t offset; @@ -328,13 +328,13 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]); embHashMap.devOffset2Key[offset] = -1; embHashMap.evictDevPos.emplace_back(offset); - evictHBMRKeys.emplace_back(key); + evictHBMKeys.emplace_back(key); } else { embHashMap.evictPos.emplace_back(offset); evictDDRKeys.emplace_back(key); } } - cacheManager->RefreshFreqInfoCommon(embName, evictHBMRKeys, TransferType::HBM_2_EVICT); + cacheManager->RefreshFreqInfoCommon(embName, evictHBMKeys, TransferType::HBM_2_EVICT); cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT); LOG(INFO) << StringFormat( @@ -558,6 +558,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, if (!isSSDEnabled) { return; } + VLOG(GLOG_DEBUG) << StringFormat("RefreshFreqInfoWithSwap:oldSwap Size:%lld", oldSwap.size()); vector enterDDRKeys; vector leaveDDRKeys; for (auto keyPair : oldSwap) { diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index bc521e85..4da9d210 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -30,41 +30,59 @@ inline TransferRet TransferSpaceWarning() return TransferRet::SSD_SPACE_NOT_ENOUGH; } -inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, const vector& keys) +inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, + vector& internalKeys, const vector& keys) { auto& hostHashMap = embHashMap.hostHashMap; for (auto key : keys) { if (hostHashMap.find(key) == hostHashMap.end()) { externalKeys.emplace_back(key); + } else { + internalKeys.emplace_back(key); } } } -void AddDebugAndTraceLog(vector& externalKeys, vector& externalSSDKeys) +void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, vector& externalSSDKeys) { - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: externalKeys size:%d, externalSSDKeys size:%d", - externalKeys.size(), externalSSDKeys.size()); + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: batchKeySize:%d, externalKeys size:%d," + " externalSSDKeys size:%d", + batchKeySize, externalKeys.size(), externalSSDKeys.size()); if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat("TransferDDREmbWithSSD: externalKeys:%s, externalSSDKeys:%s", - VectorToString(externalKeys).c_str(), 
            VectorToString(externalSSDKeys).c_str());
+        VLOG(GLOG_TRACE) << "TransferDDREmbWithSSD: externalKeys:" << VectorToString(externalKeys).c_str()
+                         << ", externalSSDKeys:" << VectorToString(externalSSDKeys).c_str();
     }
 }

+inline vector DeleteRepeatKey(const vector& originalKeys)
+{
+    // de-duplicate while preserving the original key order; the result is deterministic, so it can be unit-tested
+    unordered_set keySet;
+    vector keys;
+    for (auto& key : originalKeys) {
+        if (keySet.find(key) == keySet.end()) {
+            keySet.emplace(key);
+            keys.emplace_back(key);
+        }
+    }
+    return keys;
+}
+
 /// Transfer data between DDR and SSD so that DDR has room for the current batch of keys
 /// \param embTableName emb table name
 /// \param embHashMap emb table
-/// \param keys current batch keys
+/// \param originalKeys current batch keys
 /// \param channelId channel id
 /// \return transfer result enum
 TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap,
-    const vector& keys, int channelId)
+    const vector& originalKeys, int channelId)
 {
+    vector keys = DeleteRepeatKey(originalKeys); // de-duplicate
     // split keys into those inside HBM+DDR and those outside (new keys, or keys stored in SSD)
     vector externalKeys;
-    GetExternalKeys(embHashMap, externalKeys, keys);
-    if (externalKeys.empty()) {
-        return TransferSuccess();
-    }
+    vector internalKeys;
+    GetExternalKeys(embHashMap, externalKeys, internalKeys, keys);
+    if (externalKeys.empty()) { return TransferSuccess(); }

     // check whether the remaining memory is sufficient; available = HBM + DDR - used,
     // plus the evicted DDR slots when training;
     // SSD exchanges only with DDR, so HBM eviction slots are not counted
@@ -86,7 +104,7 @@ ...
         return TransferSuccess();
     }

-    AddDebugAndTraceLog(externalKeys, externalSSDKeys);
+    AddDebugAndTraceLog(keys.size(), externalKeys, externalSSDKeys);
     /*
      * The externalSSDKeys == 0 cases have already returned above (eval with or without
      * enough DDR space, and training with enough space); the remaining cases are:
      */
@@ -141,7 +159,7 @@ ...
     // record the offsets (relative values: minus devVocabSize) of the keys to move from DDR to SSD
     vector ddrTransferPos;
-    TransferRet ddr2SsdRet = TransferDDREmb2SSD(embTableName, embHashMap, ddrSwapOutSize, keys, ddrTransferPos);
+    TransferRet ddr2SsdRet = TransferDDREmb2SSD(embTableName, embHashMap, ddrSwapOutSize, internalKeys, ddrTransferPos);
     if (ddr2SsdRet == TransferRet::DDR_SPACE_NOT_ENOUGH) {
         ssdEngine->InsertEmbeddings(embTableName, externalSSDKeys, ssdEmbData);
         return ddr2SsdRet;
@@ -231,11 +249,12 @@ void CacheManager::RefreshRelateInfoWithDDR2SSD(const string& embTableName, EmbH

 /// Refresh frequency info when keys move into/out of DDR or are evicted from HBM; only the frequency info is touched
 /// \param embTableName emb table name
-/// \param keys keys being operated on
+/// \param originalKeys keys being operated on
 /// \param type TransferType
-void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& keys,
+void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& originalKeys,
     TransferType type)
 {
+    vector keys = DeleteRepeatKey(originalKeys);
     if (type == TransferType::DDR_2_HBM) {
         for (auto& key : keys) {
             // record the count into excludeDDRKeyCountMap and remove the entry from ddrKeyFreqMap
@@ -278,14 +297,16 @@ bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key)

 /// Evict embedding info from SSD
 /// \param embTableName emb table name
-/// \param keys keys to evict
+/// \param originalKeys keys to evict
-void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& keys)
+void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& originalKeys)
 {
+    vector keys = DeleteRepeatKey(originalKeys);
     // 1. remove the recorded key counts from the cache; 2. delete the embedding data stored in SSD
     for (auto& key : keys) {
         excludeDDRKeyCountMap[embTableName].erase(key);
     }
-    ssdEngine->DeleteEmbeddings(embTableName, keys);
+    vector currentKeys(keys.begin(), keys.end());
+    ssdEngine->DeleteEmbeddings(embTableName, currentKeys);
 }
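DeleteRepeatKey above de-duplicates while keeping first-occurrence order, and its comment notes the result is deterministic and therefore testable. A gtest sketch of that property, re-implementing the function standalone so the snippet compiles on its own (the real one operates on the repo's key types):

    #include <cstdint>
    #include <unordered_set>
    #include <vector>
    #include <gtest/gtest.h>

    // Standalone restatement of the order-preserving dedup above, for illustration.
    static std::vector<int64_t> DedupPreserveOrder(const std::vector<int64_t>& in)
    {
        std::unordered_set<int64_t> seen;
        std::vector<int64_t> out;
        for (auto key : in) {
            if (seen.insert(key).second) {  // first occurrence only
                out.emplace_back(key);
            }
        }
        return out;
    }

    TEST(DeleteRepeatKeySketch, KeepsFirstOccurrenceOrder)
    {
        std::vector<int64_t> in { 5, 1, 5, 2, 1, 3 };
        std::vector<int64_t> expected { 5, 1, 2, 3 };
        ASSERT_EQ(DedupPreserveOrder(in), expected);
    }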
/// Put a key: insert it, or update it (count+1)
@@ -367,10 +388,11 @@ TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHash
     ddrKeyFreqMap[embTableName].GetAndDeleteLeastFreqKeyInfo(ddrSwapOutSize, keys, ddrSwapOutKeys, ddrSwapOutCounts);
     if (static_cast(ddrSwapOutKeys.size()) != ddrSwapOutSize) {
+        auto keyTableSize = ddrKeyFreqMap[embTableName].keyTable.size();
         // the retrieved number of least-frequent keys differs from the expectation:
         // DDR lacks the space to hold the current batch
-        LOG(ERROR) << StringFormat(
-            "TransferDDREmbWithSSD, vector length is not equal, ddrSwapOutKeys size:%d, ddrSwapOutSize:%lld",
-            ddrSwapOutKeys.size(), ddrSwapOutSize);
+        LOG(ERROR) << StringFormat("TransferDDREmbWithSSD, vector length is not equal, ddrSwapOutKeys size:%d, "
+                                   "ddrSwapOutSize:%lld, ddr lfu keyTable size:%lld", ddrSwapOutKeys.size(),
+                                   ddrSwapOutSize, keyTableSize);
         RestoreLeastFreqInfo(embTableName, ddrSwapOutKeys, ddrSwapOutCounts);
         return TransferRet::DDR_SPACE_NOT_ENOUGH;
     }
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 08aee0d3..37a4037d 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -9,6 +9,8 @@
 #include
 #include

-#define GTEST
-
 #include "absl/container/flat_hash_map.h"
 #include "host_emb/host_emb.h"
 #include "ssd_cache/lfu_cache.h"
-- Gitee
From 9b0114a264e714b303b5d213c616877aa678d6a9 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 29 Aug 2023 20:40:50 +0800
Subject: [PATCH 299/551] Match-id-f223cf141b0053dbe3b41b2f7d180ed5740ef98b

---
 src/core/key_process/key_process.cpp         | 96 ++++++++++++++------
 src/core/key_process/key_process.h           |  9 +-
 src/core/ock_ctr_common/include/error_code.h | 15 +++
 src/tests/key_process/key_process_test.cpp   |  8 +-
 src/tests/ssd_cache/cache_manager_test.cpp   | 10 +-
 5 files changed, 99 insertions(+), 39 deletions(-)
 create mode 100644 src/core/ock_ctr_common/include/error_code.h

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 89fa5e92..b6748736 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -114,7 +114,7 @@ int KeyProcess::Start()
     // | rank0 |  | rank1 |
     // each rank creates KEY_PROCESS_THREAD threads, each thread process one batchdata
     LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // number of CPU cores
-    auto fn = [this](int channel, int id) {
+    auto fn = [this](int channel, int threadId) {
 #ifndef GTEST
         auto ret = aclrtSetDevice(static_cast(rankInfo.deviceId));
         if (ret != ACL_ERROR_NONE) {
             return;
         }
 #endif
-        KeyProcessTask(channel, id);
+        if (PerfConfig::fastUnique) {
+            KeyProcessTaskWithFastUnique(channel, threadId);
+        } else {
+            KeyProcessTask(channel, threadId);
+        }
     }; // for clean code
     int threadNum = GetThreadNumEnv();
     for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) {
@@ -241,12 +245,15 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize,
         uniqueConf.maxIdVal = INT64_MAX;
         uniqueConf.dataType = ock::ctr::DataType::INT64;
-        unique->Initialize(uniqueConf);
+        auto ret = unique->Initialize(uniqueConf);
+        if (ret != ock::ctr::H_OK) {
+            throw runtime_error(StringFormat("fast unique init failed, code:%d", ret));
+        }
         uniqueInitialize = true;
     }
 }

-void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCESS_THREAD-1]
+void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
 {
     unique_ptr batch;
     UniquePtr unique = nullptr;
@@ -254,17 +261,18 @@ void KeyProcess::KeyProcessTask(int channel,
int id) // thread id [0, KEY_PROCES size_t preBatchSize = 0; bool uniqueInitialize = false; - if (PerfConfig::fastUnique) { - factory->CreateUnique(unique); - GetUniqueConfig(uniqueConf); + auto ret = factory->CreateUnique(unique); + if (ret != ock::ctr::H_OK) { + throw runtime_error(StringFormat("create fast unique failed, error code:%d", ret)); } + GetUniqueConfig(uniqueConf); TimeCost tc = TimeCost(); try { while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; - batch = GetBatchData(channel, id); // get batch data from SingletonQueue + batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue VLOG(GLOG_DEBUG) << StringFormat("getBatchDataTC(ms):%d", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; @@ -272,33 +280,62 @@ void KeyProcess::KeyProcessTask(int channel, int id) // thread id [0, KEY_PROCES auto getBatchTime = tc.ElapsedMS(); tc = TimeCost(); - bool ret = false; - if (PerfConfig::fastUnique) { - InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); - ret = KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id); - } else { - ret = KeyProcessTaskHelper(batch, channel, id); + InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); + if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) { + break; + } + LOG(INFO) << StringFormat( + KEY_PROCESS "getAndProcessTC(ms):%d, key process with fast unique cost:%d," + " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", + getAndProcessTC.ElapsedMS(), tc.ElapsedMS(), getBatchTime, + batch->name.c_str(), batch->channel, batch->batchId + ); + auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); + batchQueue->PutDirty(move(batch)); + } + unique->UnInitialize(); + } catch (const EndRunError &e) { + VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); + } + LOG(INFO) << StringFormat( + KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:%d thread:%d, channel:%d", + rankInfo.rankId, threadId, channel); +} + + +void KeyProcess::KeyProcessTask(int channel, int threadId) +{ + unique_ptr batch; + TimeCost tc = TimeCost(); + try { + while (true) { + TimeCost getAndProcessTC; + TimeCost getBatchDataTC; + batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue + VLOG(GLOG_DEBUG) << StringFormat("getBatchDataTC(ms):%d", getBatchDataTC.ElapsedMS()); + if (batch == nullptr) { + break; } + auto getBatchTime = tc.ElapsedMS(); + tc = TimeCost(); - if (!ret) { + if (!KeyProcessTaskHelper(batch, channel, threadId)) { break; } LOG(INFO) << StringFormat( - KEY_PROCESS "getAndProcessTC(ms):%d, key process cost:%d, get data time:%d batch %s[%d]:%d", + KEY_PROCESS "getAndProcessTC(ms):%d, key process cost:%d," + " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", getAndProcessTC.ElapsedMS(), tc.ElapsedMS(), getBatchTime, batch->name.c_str(), batch->channel, batch->batchId ); - auto batchQueue = SingletonQueue::getInstances(id + KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } - if (PerfConfig::fastUnique) { - unique->UnInitialize(); - } } catch (const EndRunError &e) { VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); } LOG(INFO) << StringFormat( - KEY_PROCESS "KeyProcessTask exit. rank:%d thread:%d, channel:%d", rankInfo.rankId, id, channel); + KEY_PROCESS "KeyProcessTask exit. 
rank:%d thread:%d, channel:%d", rankInfo.rankId, threadId, channel); } void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, @@ -320,14 +357,14 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < } bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, - int channel, int id) + int channel, int threadId) { // tuple for keyRec restore hotPos scAll countRecv isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; TimeCost fastUniqueTC; UniqueInfo uniqueInfo; - ProcessBatchWithFastUnique(batch, unique, id, uniqueInfo); + ProcessBatchWithFastUnique(batch, unique, threadId, uniqueInfo); VLOG(GLOG_DEBUG) << StringFormat("ProcessBatchWithFastUnique(ms):%d", fastUniqueTC.ElapsedMS()); // 特征准入&淘汰 @@ -336,7 +373,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", - rankInfo.rankId, id, channel); + rankInfo.rankId, threadId, channel); return false; } @@ -370,7 +407,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat return true; } -bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int id) +bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { vector splitKeys; vector restore; @@ -378,12 +415,12 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe vector> keyCount; HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); - auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, id, splitKeys); + auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, threadId, splitKeys); vector countRecv; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { - countRecv = GetCountRecv(batch, id, keyCount, scAll, ss); + countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss); } BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size())); @@ -394,7 +431,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys, countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", - rankInfo.rankId, id, channel); + rankInfo.rankId, threadId, channel); return false; } @@ -582,6 +619,9 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch uniqueOut.uniqueIdCnt = 0; int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut); + if (ret != ock::ctr::H_OK) { + throw runtime_error(StringFormat("fast unique DoEnhancedUnique failed, code:%d", ret)); + } EASY_END_BLOCK VLOG(GLOG_DEBUG) << StringFormat("FastUniqueCompute(ms):%d, ret:%d", uniqueTC.ElapsedMS(), ret); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 5214de53..d6b9df3d 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -22,6 +22,7 @@ #include #include "ock_ctr_common/include/factory.h" +#include "ock_ctr_common/include/error_code.h" #include "utils/common.h" #include "utils/time_cost.h" @@ -156,12 +157,14 @@ namespace MxRec { void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); - void KeyProcessTask(int channel, int 
-        void KeyProcessTask(int channel, int id);
+        void KeyProcessTask(int channel, int threadId);

-        bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int id);
+        void KeyProcessTaskWithFastUnique(int channel, int threadId);
+
+        bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId);

         bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, UniquePtr& unique,
-            int channel, int id);
+            int channel, int threadId);

         auto ProcessSplitKeys(const unique_ptr& batch, int id, vector& splitKeys)
             -> tuple, vector>;
diff --git a/src/core/ock_ctr_common/include/error_code.h b/src/core/ock_ctr_common/include/error_code.h
new file mode 100644
index 00000000..0905502a
--- /dev/null
+++ b/src/core/ock_ctr_common/include/error_code.h
@@ -0,0 +1,15 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#ifndef OCK_CTR_ERROR_CODE_H
+#define OCK_CTR_ERROR_CODE_H
+namespace ock {
+    namespace ctr {
+        using CTRCode = enum : int {
+            H_OK = 0,
+        };
+    }
+}
+
+#endif //OCK_CTR_ERROR_CODE_H
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index e7e97364..e2a66fbe 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -554,11 +554,13 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique)
     UniqueInfo uniqueInfo;
     batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue

-    ASSERT_EQ(factory->CreateUnique(unique), 0);
+    ASSERT_EQ(factory->CreateUnique(unique), ock::ctr::H_OK);
     UniqueConf uniqueConf;
+    size_t preBatchSize = 0;
+    bool uniqueInitialize = false;
     process.GetUniqueConfig(uniqueConf);
-    unique->Initialize(uniqueConf);
-
+    process.InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique);
+
     LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId);
     process.KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id);
     LOG(INFO) << StringFormat(
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 37a4037d..86b9d16b 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -68,6 +68,11 @@ class CacheManagerTest : public testing::Test {
 protected:
     void SetUp()
     {
+        // set the global rankId; ssdEngine uses it when saving
+        int workRankId;
+        MPI_Comm_rank(MPI_COMM_WORLD, &workRankId);
+        g_rankId = to_string(workRankId);
+
         cacheManager.ddrKeyFreqMap[embTableName] = cache;
         cacheManager.ddrKeyFreqMap[embTableName].PutKeys(input_keys);
         LFUCache cache2;
@@ -92,11 +97,6 @@ protected:
         cacheManager.hostEmbs->hostEmbs = loadData;

         auto& embMap = cacheManager.hostEmbs->hostEmbs;
-
-        // set the global rankId; ssdEngine uses it when saving
-        int workRankId;
-        MPI_Comm_rank(MPI_COMM_WORLD, &workRankId);
-        g_rankId = to_string(workRankId);
     }

     CacheManager cacheManager;
-- Gitee
From a40074dcc731ebbf3d4acd4139e763b639f93265 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 30 Aug 2023 10:17:12 +0800
Subject: [PATCH 300/551] Match-id-fb7f2242ab3d1bff5c9131b5fa0b89eb61490ab6

---
 src/core/ock_ctr_common/include/error_code.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/core/ock_ctr_common/include/error_code.h b/src/core/ock_ctr_common/include/error_code.h
index 0905502a..b8616b46 100644
--- a/src/core/ock_ctr_common/include/error_code.h
+++ b/src/core/ock_ctr_common/include/error_code.h
@@ -12,4 +12,4 @@ namespace ock {
     }
 }

-#endif //OCK_CTR_ERROR_CODE_H
+#endif // OCK_CTR_ERROR_CODE_H
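PATCHes 299 and 300 introduce ock::ctr::H_OK and make every fast-unique call (CreateUnique, Initialize, DoEnhancedUnique) check its return code and throw on failure. The repeated check-and-throw could be folded into one helper; a sketch, with a helper name of my own rather than anything from the series:

    #include <stdexcept>
    #include <string>

    namespace ock { namespace ctr { using CTRCode = enum : int { H_OK = 0 }; } }

    // Hypothetical wrapper for the check-and-throw pattern used in KeyProcess.
    inline void ThrowIfCtrError(int code, const std::string& what)
    {
        if (code != ock::ctr::H_OK) {
            throw std::runtime_error(what + ", code:" + std::to_string(code));
        }
    }

    // Usage sketch:
    //   ThrowIfCtrError(factory->CreateUnique(unique), "create fast unique failed");
    //   ThrowIfCtrError(unique->Initialize(uniqueConf), "fast unique init failed");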
-- Gitee
From 04b0b0030da7b0f624f5b7585c0e6d2e1684aa7b Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 30 Aug 2023 14:47:05 +0800
Subject: [PATCH 301/551] Match-id-24f6a8c08e5d46681119620b9617c7c8bfba4a78

---
 mx_rec/core/embedding.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 0a7bd9c2..6c7b3a77 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -722,7 +722,9 @@ class SparseEmbedding:
         if kwargs.get("multi_lookup"):
             lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size])
         else:
-            feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key)
+            feature_spec_tensor = None
+            if not self.modify_graph:
+                feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key)
             modify_graph_tensor = kwargs.get("ids")
             tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor
             if tensor is None:
-- Gitee
From ff56871084a294eb33cb97098d206dc3818215c4 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 30 Aug 2023 21:11:26 +0800
Subject: [PATCH 302/551] Match-id-03d202a5a46c3e0820f682539185dcfbba921418

---
 src/core/hd_transfer/hd_transfer.cpp       | 18 ++++-----
 src/core/key_process/key_process.cpp       | 14 +++----
 src/core/utils/common.h                    |  5 ++-
 src/tests/emb_mgmt/emb_mgmt_test.cpp       | 10 ++---
 src/tests/key_process/key_process_test.cpp | 44 ++++++++--------------
 5 files changed, 36 insertions(+), 55 deletions(-)

diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp
index f8937dac..48ebf562 100644
--- a/src/core/hd_transfer/hd_transfer.cpp
+++ b/src/core/hd_transfer/hd_transfer.cpp
@@ -156,12 +156,10 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in
     }
     string sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId);
-    if (g_glogLevel >= INFO) {
-        LOG(INFO) << StringFormat(
-            HD + "hd transfer send %s, send count is %d, size list:%s",
-            sendName.c_str(), sizes.size(), VectorToString(sizes).c_str()
-        );
-    }
+    REC_LOG(INFO) << StringFormat(
+        HD + "hd transfer send %s, send count is %d, size list:%s",
+        sendName.c_str(), sizes.size(), VectorToString(sizes).c_str()
+    );

     if (sizes.size() == 0) {
         LOG(WARNING) << "tensors num cannot be zero";
@@ -218,11 +216,9 @@ vector HDTransfer::Recv(TransferChannel channel, int channel
     for (auto& t: tensors) {
         sizes.push_back(t.NumElements());
     }
-    if (g_glogLevel >= INFO) {
-        LOG(INFO) << StringFormat(
-            "hd transfer recv:%s, size:%d cost:%dms", recvName.c_str(), VectorToString(sizes).c_str(), tc.ElapsedMS()
-        );
-    }
+    REC_LOG(INFO) << StringFormat(
+        "hd transfer recv:%s, size:%s cost:%dms", recvName.c_str(), VectorToString(sizes).c_str(), tc.ElapsedMS()
+    );
     return tensors;
 #endif
 }
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index b6748736..e49066f4 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -72,9 +72,7 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos
         }
     }

-    if (g_glogLevel >= INFO) {
-        LOG(INFO) << StringFormat(KEY_PROCESS "hot emb count info:%s", MapToString(hotEmbTotCount).c_str());
-    }
+    REC_LOG(INFO) << StringFormat(KEY_PROCESS "hot emb count info:%s", MapToString(hotEmbTotCount).c_str());
     MPI_Group worldGroup;
     MPI_Comm_group(MPI_COMM_WORLD, &worldGroup);
     for (auto& i: comm) {
@@ -97,12 +95,10 @@ ...
         Factory::Create(factory);
     }

-    if (g_glogLevel >= INFO) {
-        LOG(INFO) << StringFormat(
-            KEY_PROCESS "scInfo:%s,
localRankSize:%d, rankSize:%d, useStatic:%d, useHot:%d", - MapToString(scInfo).c_str(), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot - ); - } + REC_LOG(INFO) << StringFormat( + KEY_PROCESS "scInfo:%s, localRankSize:%d, rankSize:%d, useStatic:%d, useHot:%d", + MapToString(scInfo).c_str(), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot + ); return true; } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 29b25c4e..fd48a660 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -45,6 +45,8 @@ #define EASY_PROFILER_DISABLE #endif +#define REC_LOG(severity) if (g_glogLevel >= severity) LOG(severity) + namespace MxRec { #define INFO_PTR shared_ptr #define MGMT_CPY_THREADS 4 @@ -313,7 +315,8 @@ namespace MxRec { // use environment variable GLOG_v to decide if showing debug log. // default 0, debug message will not display. // 1 for debug, 2 for trace - const int GLOG_DEBUG = 1, GLOG_TRACE = 2; + constexpr int GLOG_DEBUG = 1; + constexpr int GLOG_TRACE = 2; template std::string VectorToString(const std::vector& vec) diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index a570915d..32c2d89d 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -72,12 +72,10 @@ protected: tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize; } for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) { - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat( - "hostEmb: embName %s, %d is: %s", embName.c_str(), i, - VectorToString(hostEmb->GetEmb(embName).embData[i]).c_str() - ); - } + REC_LOG(INFO) << StringFormat( + "hostEmb: embName %s, %d is: %s", embName.c_str(), i, + VectorToString(hostEmb->GetEmb(embName).embData[i]).c_str() + ); } LOG(INFO) << (HD + "update emb end"); d2h_emb.clear(); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index e2a66fbe..882b9c4e 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -178,18 +178,12 @@ protected: { for (int i = 0; i < rankSize; ++i) { std::cout << "splitKeys dev" << i << std::endl; - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat("%d", VectorToString(splitKeys[i]).c_str()); - } + REC_LOG(INFO) << StringFormat("%d", VectorToString(splitKeys[i]).c_str()); } std::cout << "restore" << std::endl; - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat("%d", VectorToString(restore).c_str()); - } + REC_LOG(INFO) << StringFormat("%d", VectorToString(restore).c_str()); std::cout << "hotPos" << std::endl; - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat("%d", VectorToString(hotPos).c_str()); - } + REC_LOG(INFO) << StringFormat("%d", VectorToString(hotPos).c_str()); } void GetExpectRestore(keys_t& sample, vector& blockOffset, vector& restoreVec) @@ -343,12 +337,10 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { 5, 0, 6, 2, 1, 3, 1, 7, 4, 8 }, { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } }; batch->sample = std::move(allBatchKeys[worldRank]); - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat( - KEY_PROCESS "test BuildRestoreVec: rank %d, batchKeys %s", - worldRank, VectorToString(batch->sample).c_str() - ); - } + LOG(INFO) << StringFormat( + KEY_PROCESS "test BuildRestoreVec: rank %d, batchKeys %s", + worldRank, VectorToString(batch->sample).c_str() + ); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); auto [splitKeys, restore] 
= process.HashSplit(batch); @@ -379,11 +371,9 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat( - "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() - ); - } + LOG(INFO) << StringFormat( + "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() + ); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { @@ -412,13 +402,11 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); process.BuildRestoreVec(batch, ss, restore, hotPos.size()); - if (g_glogLevel >= INFO) { - LOG(INFO) << StringFormat( - "rankid :%d,batchid: %d, lookupKeys: %s, scAll: %s, restore after build %s", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys).c_str(), - VectorToString(scAll).c_str(), VectorToString(restore).c_str() - ); - } + REC_LOG(INFO) << StringFormat( + "rankid :%d,batchid: %d, lookupKeys: %s, scAll: %s, restore after build %s", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys).c_str(), + VectorToString(scAll).c_str(), VectorToString(restore).c_str() + ); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { @@ -591,4 +579,4 @@ TEST(KeyProcess, SetupHotEmbUpdateStep) putenv("HOT_EMB_UPDATE_STEP=0"); kp.SetupHotEmbUpdateStep(); ASSERT_EQ(kp.hotEmbUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); -} \ No newline at end of file +} -- Gitee From 34c83f54dab58674db188e5aa5dce199e9569a2e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 29 Aug 2023 22:14:11 +0800 Subject: [PATCH 303/551] Match-id-b2764dc575579a332a64f1e1260dbd8e785c648a --- src/core/emb_hashmap/emb_hashmap.cpp | 63 ++++++++++++++++++++-- src/core/emb_hashmap/emb_hashmap.h | 5 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 41 ++++++++++++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 1 + src/core/ssd_cache/cache_manager.cpp | 55 +++++++++++++------ src/core/ssd_cache/cache_manager.h | 4 +- src/core/utils/common.h | 13 ++++- src/tests/ssd_cache/cache_manager_test.cpp | 2 +- 8 files changed, 156 insertions(+), 28 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 0d20278d..c87cd26f 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -81,7 +81,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t VLOG(GLOG_DEBUG) << "FindOffset end"; // 调用刷新频次数据方法 - RefreshFreqInfoWithSwap(embName, embHashMap.oldSwap); + RefreshFreqInfoWithSwap(embName, embHashMap); swapId++; EASY_BLOCK("hostHashMaps->tdt") @@ -260,18 +260,28 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { auto embHashMapsOld = embHashMaps; for (auto& temp: embHashMapsOld) { + auto& embTableName = temp.first; auto& embHashMap = temp.second; + vector hbm2DdrKeys; + vector ddr2HbmKeys; for (auto& swapKeys: embHashMap.oldSwap) { emb_key_t oldKey = swapKeys.first; emb_key_t key = swapKeys.second; int tempOffset = static_cast(embHashMap.hostHashMap[key]); embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; embHashMap.hostHashMap[oldKey] = 
static_cast(tempOffset); + hbm2DdrKeys.emplace_back(key); + ddr2HbmKeys.emplace_back(oldKey); } embHashMap.maxOffset = embHashMap.maxOffsetOld; for (auto& Offset2Key: embHashMap.devOffset2KeyOld) { embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; } + if (isSSDEnabled) { + // 恢复CacheManager中频次数据 + cacheManager->RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR); + cacheManager->RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); + } } return embHashMapsOld; } @@ -551,13 +561,14 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos /// HBM-DDR换入换出时刷新频次信息 /// \param embName emb表名 -/// \param oldSwap 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key -void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, - const std::vector>& oldSwap) +/// \param embHashMap emb hash map +void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) { if (!isSSDEnabled) { return; } + // 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key + auto& oldSwap = embHashMap.oldSwap; VLOG(GLOG_DEBUG) << StringFormat("RefreshFreqInfoWithSwap:oldSwap Size:%lld", oldSwap.size()); vector enterDDRKeys; vector leaveDDRKeys; @@ -567,6 +578,50 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, } cacheManager->RefreshFreqInfoCommon(embName, enterDDRKeys, TransferType::HBM_2_DDR); cacheManager->RefreshFreqInfoCommon(embName, leaveDDRKeys, TransferType::DDR_2_HBM); + + AddCacheManagerTraceLog(embName, embHashMap); +} + +/// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 +void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const +{ + if (!VLOG_IS_ON(GLOG_TRACE)) { + return; + } + auto& hostMap = embHashMap.hostHashMap; + auto& devSize = embHashMap.devVocabSize; + auto& lfu = cacheManager->ddrKeyFreqMap[embTableName]; + const auto& lfuTab = lfu.GetFreqTable(); + if (lfuTab.empty()) { + return; + } + size_t tableKeyInDdr = 0; + vector ddrKeys; // 获取hostHashMap中保存在DDR的key + for (const auto& item : hostMap) { + if (item.second < devSize) { + continue; + } + ddrKeys.emplace_back(item.first); + ++tableKeyInDdr; + } + vector lfuKeys; + for (const auto& it : lfuTab) { + lfuKeys.emplace_back(it.first); + } + std::sort(ddrKeys.begin(), ddrKeys.end()); + std::sort(lfuKeys.begin(), lfuKeys.end()); + std::string ddrKeysString = VectorToString(ddrKeys); + std::string lfuKeysString = VectorToString(lfuKeys); + if (ddrKeysString != lfuKeysString) { + LOG(ERROR) << "ERROR STRING not equal, ddrKeysString:" << ddrKeysString << ", lfuKeysString:" << lfuKeysString; + } else { + LOG(INFO) << "After HBM swap with DDR, table:" << embTableName << + ", ddrKeysString is equals with lfuKeysString, string length:" << lfuKeysString.length(); + } + + LOG(INFO) + << "After HBM swap with DDR, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << + ", tableKeyInLfu:" << lfu.keyTable.size(); } /// 记录key频次数据 diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 6a1fa414..5d6e9204 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -82,8 +82,9 @@ namespace MxRec { int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap); - void RefreshFreqInfoWithSwap(const string& embName, - const std::vector>& oldSwap); + void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap); + + void AddCacheManagerTraceLog(const string& embName, 
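
AddCacheManagerTraceLog above verifies the DDR-resident keys in hostHashMap against the table's LFU cache by sorting both key lists and comparing their string renderings, which materializes two potentially very large strings. For reference, a sketch of the same check that reports only the first divergence; emb_key_t is assumed to be a 64-bit integer, as elsewhere in this codebase:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using emb_key_t = int64_t;  // assumption: matches the project's key type

// Same consistency check without building two strings: sort both key sets
// and report the first mismatch, so the log stays small even when the
// tables hold millions of keys.
bool DdrKeysMatchLfu(std::vector<emb_key_t> ddrKeys, std::vector<emb_key_t> lfuKeys,
                     emb_key_t& firstMismatch)
{
    std::sort(ddrKeys.begin(), ddrKeys.end());
    std::sort(lfuKeys.begin(), lfuKeys.end());
    auto [itA, itB] = std::mismatch(ddrKeys.begin(), ddrKeys.end(),
                                    lfuKeys.begin(), lfuKeys.end());
    if (itA == ddrKeys.end() && itB == lfuKeys.end()) {
        return true;  // identical key sets
    }
    firstMismatch = (itA != ddrKeys.end()) ? *itA : *itB;
    return false;
}
```
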
const EmbHashMapInfo& embHashMap) const; void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index db6ed66a..3e8dc2ce 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -164,6 +164,35 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, return true; } +// 比较hostHashMap和cacheManager的数据是否一致 +void HybridMgmt::AddCacheManagerTraceLog(absl::flat_hash_map, EmbHashMapInfo>& embHashMaps) +{ + if (!isSSDEnabled || !VLOG_IS_ON(GLOG_TRACE)) { + return; + } + for (auto& it : embHashMaps) { + string embTableName = it.first; + auto& hostMap = it.second.hostHashMap; + auto& devSize = it.second.devVocabSize; + auto& lfu = cacheManager->ddrKeyFreqMap[embTableName]; + size_t tableKeyInDdr = 0; + for (const auto& item : hostMap) { + if (item.second < devSize) { + continue; + } + ++tableKeyInDdr; + auto cuKey = item.first; + auto lfuKeyCount = lfu.Get(cuKey); + if (lfuKeyCount == -1) { + LOG(ERROR) << "ERROR, SAVE Step, ddr key:" << cuKey << ", lfu count by key:" << + lfuKeyCount << ", hostHashMap offset:" << item.second; + } + } + LOG(INFO) << "SAVE Step, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << + ", tableKeyInLfu:" << lfu.keyTable.size(); + } +} + /// 保存模型 /// \param savePath 保存路径 /// \return @@ -180,6 +209,7 @@ bool HybridMgmt::Save(const string savePath) VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: ddr mode hashmap"); saveData.hostEmbs = hostEmbs->GetHostEmbs(); saveData.embHashMaps = hostHashMaps->GetHashMaps(); + AddCacheManagerTraceLog(saveData.embHashMaps); } else { // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: no ddr mode hashmap"); @@ -886,11 +916,16 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) // 初始化host侧的emb auto& evictOffset = hostHashMaps->GetEvictPos(embName); - if (evictOffset.size() != 0) { + vector evictOffset4Ddr; + auto devVocabSize = hostHashMaps->embHashMaps.at(embName).devVocabSize; + for (auto& offsetInHostHashMap : evictOffset) { + evictOffset4Ddr.emplace_back(offsetInHostHashMap - devVocabSize); + } + if (!evictOffset4Ddr.empty()) { VLOG(GLOG_DEBUG) << StringFormat( - MGMT + "ddr mode, delete emb: [%s]! evict size on host:%d", embName.c_str(), evictOffset.size() + MGMT + "ddr mode, delete emb: [%s]! 
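
The EvictKeys change above rebases evictOffset entries, which index the combined HBM+DDR table, into DDR-local positions before calling EvictInitEmb. This matches the layout documented later in common.h: offsets below devVocabSize live in HBM, the rest in DDR. A small sketch of that rebase under those assumptions:

```cpp
#include <cstddef>
#include <vector>

// Offsets in hostHashMap index one logical table laid out as [HBM | DDR]:
// [0, devVocabSize) lives in HBM and [devVocabSize, devVocabSize + hostVocabSize)
// lives in DDR. Host-side eviction works on DDR-local slots, hence the rebase.
std::vector<size_t> RebaseEvictOffsetsForDdr(const std::vector<size_t>& evictOffset,
                                             size_t devVocabSize)
{
    std::vector<size_t> ddrLocal;
    ddrLocal.reserve(evictOffset.size());
    for (size_t globalOffset : evictOffset) {
        // e.g. devVocabSize = 20: global offset 20 is DDR slot 0
        ddrLocal.push_back(globalOffset - devVocabSize);
    }
    return ddrLocal;
}
```
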
evict size on host:%d", embName.c_str(), evictOffset4Ddr.size() ); - hostEmbs->EvictInitEmb(embName, evictOffset); + hostEmbs->EvictInitEmb(embName, evictOffset4Ddr); } else { LOG(INFO) << StringFormat(MGMT + "ddr mode, evict size on host is empty"); } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 3338ec15..c131e19a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -121,6 +121,7 @@ namespace MxRec { void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, const vector& keys, int channelId); + void AddCacheManagerTraceLog(absl::flat_hash_map, EmbHashMapInfo>& embHashMaps); private: int currentBatchId; diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index 4da9d210..e71d72d8 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -33,9 +33,8 @@ inline TransferRet TransferSpaceWarning() inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, vector& internalKeys, const vector& keys) { - auto& hostHashMap = embHashMap.hostHashMap; - for (auto key : keys) { - if (hostHashMap.find(key) == hostHashMap.end()) { + for (const emb_key_t key : keys) { + if (embHashMap.hostHashMap.find(key) == embHashMap.hostHashMap.end()) { externalKeys.emplace_back(key); } else { internalKeys.emplace_back(key); @@ -54,18 +53,22 @@ void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, v } } -inline vector DeleteRepeatKey(const vector& originalKeys) +/// 去重和过滤无效key +/// \param originalKeys 原有keys +/// \param keys 处理后的keys +void HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) { // 去重并保持原key的顺序 结果可测试 unordered_set keySet; - vector keys; for (auto& key : originalKeys) { + if (key == INVALID_KEY_VALUE) { + continue; + } if (keySet.find(key) == keySet.end()) { keySet.emplace(key); keys.emplace_back(key); } } - return keys; } /// DDR与SSD数据转移,使DDR内剩余空间能放置当前批次key @@ -77,7 +80,8 @@ inline vector DeleteRepeatKey(const vector& originalKeys) TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, const vector& originalKeys, int channelId) { - vector keys = DeleteRepeatKey(originalKeys); // 去重 + vector keys; // 去重和删除无效key + HandleRepeatAndInvalidKey(originalKeys, keys); // 区分HBM+DDR内key,和HBM+DDR外的key(新key或保存在SSD中的key) vector externalKeys; vector internalKeys; @@ -85,12 +89,13 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, if (externalKeys.empty()) { return TransferSuccess(); } // 判断剩余内存空间是否足够; 可用内存空间计算:HBM+DDR-已占用; 若是训练,再加DDR已淘汰; - // SSD仅与DDR交互,不考虑HBM淘汰位置 + // SSD仅与DDR交互,不考虑HBM淘汰位置;由于maxOffset比实际使用大1,所以虽然从0开始也不用再减1 size_t ddrAvailableSize = embHashMap.devVocabSize + embHashMap.hostVocabSize - embHashMap.maxOffset; if (channelId == TRAIN_CHANNEL_ID) { ddrAvailableSize += embHashMap.evictPos.size(); } - + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD, maxOffset:%d, evictPos size:%d, ddrAvailableSize:%d", + embHashMap.maxOffset, embHashMap.evictPos.size(), ddrAvailableSize); CreateSSDTableIfNotExist(embTableName); // 调用ssdEngine查询当前批次key中保存在SSD中的key @@ -151,8 +156,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, // 从ddr转移到ssd的key个数 size_t ddrSwapOutSizeTmp = ddrSpaceEnoughOrEval ? 
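
HandleRepeatAndInvalidKey above filters out the INVALID_KEY_VALUE sentinel and deduplicates while keeping first-seen order. A standalone sketch of the same O(n) scheme; the key type and sentinel value here are assumptions for illustration:

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

using emb_key_t = int64_t;                    // assumption: project key type
constexpr emb_key_t INVALID_KEY_VALUE = -1;   // assumption: sentinel used upstream

// First-seen-order deduplication with invalid-key filtering: one pass,
// one hash set of already-seen keys.
std::vector<emb_key_t> DedupValidKeys(const std::vector<emb_key_t>& originalKeys)
{
    std::unordered_set<emb_key_t> seen;
    std::vector<emb_key_t> keys;
    keys.reserve(originalKeys.size());
    for (emb_key_t key : originalKeys) {
        if (key == INVALID_KEY_VALUE) {
            continue;
        }
        if (seen.insert(key).second) {  // true only for the first occurrence
            keys.push_back(key);
        }
    }
    return keys;
}
// {7, -1, 3, 7, 3, 9} -> {7, 3, 9}: order of first occurrence is kept.
```
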
externalSSDKeys.size() : externalKeys.size(); auto ddrSwapOutSize = static_cast(ddrSwapOutSizeTmp - ddrAvailableSize); - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddrSwapOutSize:%d, ddrAvailableSize:%lld", ddrSwapOutSize, - ddrAvailableSize); + VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddrSwapOutSize:%d", ddrSwapOutSize); /* * 转移DDR中数据到SSD @@ -254,7 +258,8 @@ void CacheManager::RefreshRelateInfoWithDDR2SSD(const string& embTableName, EmbH void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& originalKeys, TransferType type) { - vector keys = DeleteRepeatKey(originalKeys); + vector keys; + HandleRepeatAndInvalidKey(originalKeys, keys); if (type == TransferType::DDR_2_HBM) { for (auto& key : keys) { // 频次数据记录到 excludeDDRKeyCountMap,并删除ddrKeyFreqMap中频次数据 @@ -300,7 +305,8 @@ bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key) /// \param originalKeys 淘汰key列表 void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& originalKeys) { - vector keys = DeleteRepeatKey(originalKeys); + vector keys; + HandleRepeatAndInvalidKey(originalKeys, keys); // 1 删除缓存中记录的key的次数 2 删除SSD中保存的Emb数据 for (auto& key : keys) { excludeDDRKeyCountMap[embTableName].erase(key); @@ -336,7 +342,8 @@ void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector externalSSDKeys.size()) { while (ddrTransferPos.size() > externalSSDKeys.size()) { @@ -473,9 +480,12 @@ CacheManager::~CacheManager() /// 加载数据到CacheManager /// \param ddrFreqInitMap ddr内key频次数据 /// \param excludeDdrFreqInitMap 非DDR key频次数据 +/// \param step 加载SSDEngine传入步数 void CacheManager::Load(unordered_map>& ddrFreqInitMap, - unordered_map>& excludeDdrFreqInitMap) + unordered_map>& excludeDdrFreqInitMap, + int step) { + // 加载CacheManager数据 for (auto& it : ddrFreqInitMap) { auto& embTableName = it.first; auto& freqMap = it.second; @@ -490,4 +500,19 @@ void CacheManager::Load(unordered_mapLoad(embTableName, embBase.savePath, embBase.maxTableSize, step); + } +#endif +} + +void CacheManager::SaveSSDEngine(int step) +{ +#ifndef GTEST + ssdEngine->Save(step); +#endif } diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h index ac62e30c..8e3fd26d 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/ssd_cache/cache_manager.h @@ -49,7 +49,9 @@ namespace MxRec { void Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo); void Load(unordered_map>& ddrFreqInitMap, - unordered_map>& excludeDdrFreqInitMap); + unordered_map>& excludeDdrFreqInitMap, int step); + + void SaveSSDEngine(int step); // 转换DDR和SSD数据 TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 29b25c4e..fa5ba6a2 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -488,10 +488,19 @@ namespace MxRec { size_t freeSize; std::vector lookUpVec; std::vector missingKeysHostPos; // 用于记录当前batch在host上需要换出的偏移 - std::vector swapPos; + std::vector swapPos; // 记录从HBM换出到DDR的offset + /* + * 取值范围:[0,devVocabSize+hostVocabSize); + * [0,devVocabSize-1]时存储在HBM, [devVocabSize,devVocabSize+hostVocabSize)存储在DDR + */ size_t maxOffset { 0 }; + /* + * 记录DDR内淘汰列表,其值为相对HBM+DDR大表的;hostHashMap可直接使用;操作ddr内emb时需减掉devVocabSize + * 例如:HBM表大小20(offset:0~19),DDR表大小为100(offset:0~99); + * 若DDR内0位置被淘汰,记录到evictPos的值为0+20=20 + */ std::vector evictPos; - std::vector evictDevPos; + std::vector evictDevPos; // 记录HBM内淘汰列表 size_t maxOffsetOld { 0 }; std::vector evictPosChange; std::vector 
evictDevPosChange;
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 37a4037d..5087cba3 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -371,7 +371,7 @@ TEST_F(CacheManagerTest, LoadTest)
     excludeDdrTableMap.emplace(15, 1);
     excludeDdrTableMap.emplace(25, 5);
     excludeDdrMap.emplace(embTableName, excludeDdrTableMap);
-    cacheManager.Load(ddrMap, excludeDdrMap);
+    cacheManager.Load(ddrMap, excludeDdrMap, 0);
     // 数据检查
     auto& ddrKeyFreqMap = cacheManager.ddrKeyFreqMap;
    auto& excludeDDRKeyCountMap = cacheManager.excludeDDRKeyCountMap;
--
Gitee

From 41a06a2393fc1b0a328c7194681ab030c3a67705 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 1 Sep 2023 00:53:04 +0800
Subject: [PATCH 304/551] Match-id-fcf97349ea3586c88d3de2aa7e8580c42aaf7f82

---
 .../op_host/embedding_lookup_by_address.cpp   | 32 +++++++++++--
 .../op_host/embedding_update_by_address.cpp   | 39 +++++++++------
 .../op_kernel/embedding_lookup_by_address.cpp | 46 +++++++++---------
 .../op_kernel/embedding_update_by_address.cpp | 48 +++++++++----------
 src/ops_tf/hybrid_dataset_ops.cpp             | 10 ----
 5 files changed, 97 insertions(+), 78 deletions(-)

diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
index 117efaa3..a881d58f 100644
--- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
@@ -10,6 +10,11 @@ namespace optiling
     size_t usrSize = 256;
     size_t sysWorkspaceSize = 16 * 1024 * 1024;

+    if (context == nullptr) {
+        printf("Tiling context nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+
     size_t *currentWorkspace = context->GetWorkspaceSizes(1);
     currentWorkspace[0] = sysWorkspaceSize + usrSize;

@@ -17,9 +22,21 @@
     int32_t ub_limit = 175 * 1024;
     auto *attrs = context->GetAttrs();
     const auto *attr0_value = attrs->GetAttrPointer(0);
-    int32_t embbeding_dim = *attr0_value;
+    int32_t embbeding_dim = 0;
+    if (attr0_value == nullptr) {
+        printf(" Lookup embbeding_dim attr0_value nullptr\n");
+    }
+    else {
+        embbeding_dim = *attr0_value;
+    }
+
     const auto *attr1_value = attrs->GetAttrPointer(1);
-    int32_t embbeding_type = *attr1_value;
+    int32_t embbeding_type = 0;
+    if (attr1_value == nullptr) {
+        printf(" Lookup embbeding_type attr1_value nullptr\n");
+    }
+    else {
+        embbeding_type = *attr1_value;
+    }
+
     int32_t input_shape = context->GetInputTensor(0)->GetShapeSize();

     tiling.set_embbeding_type(embbeding_type);
@@ -43,7 +60,13 @@ namespace ge
     gert::Shape *y_shape = context->GetOutputShape(0);
     auto *attrs = context->GetAttrs();
     const auto *attr0_value = attrs->GetAttrPointer(0);
-    int64_t update_dim = *attr0_value;
+    int64_t update_dim = 0;
+    if (attr0_value == nullptr) {
+        printf(" Lookup update_dim attr0_value nullptr\n");
+    }
+    else {
+        update_dim = *attr0_value;
+    }
+
     int64_t input_shape = context->GetInputTensor(0)->GetShapeSize();
     y_shape->SetDimNum(2);
     y_shape->SetDim(0, input_shape);
@@ -56,8 +79,7 @@
     int64_t embbeding_type;
     auto *attrs = context->GetAttrs();
     const auto *attr1_value = attrs->GetAttrPointer(1);
-    if (attr1_value == nullptr)
-    {
+    if (attr1_value == nullptr) {
         printf(" Lookup embbeding_type nullptr\n");
     }
     else
diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp
index ca43db22..ea7a3730 100644
--- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp
@@ -11,6 +11,11 @@ namespace optiling
     size_t usrSize = 256;
     size_t sysWorkspaceSize = 16 * 1024 * 1024;

+    if (context == nullptr) {
+        printf("Update tiling context nullptr\n");
+        return ge::GRAPH_FAILED;
+    }
+
     size_t *currentWorkspace = context->GetWorkspaceSizes(1);
     currentWorkspace[0] = sysWorkspaceSize + usrSize;

@@ -21,24 +26,25 @@
     int32_t embbeding_type;
     int32_t input_shape = context->GetInputTensor(0)->GetShapeSize();

+    if (input_shape <= 0) {
+        printf("input_shape must be larger than 0\n");
+        return ge::GRAPH_FAILED;
+    }
+
     int32_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape;
     int32_t update_type = *(context->GetAttrs()->GetAttrPointer(0));
     ge::DataType input_datatype = context->GetInputTensor(1)->GetDataType();
-    switch (input_datatype)
-    {
-    case ge::DT_FLOAT16:
-        embbeding_type = 2;
-        break;
-    case ge::DT_FLOAT:
-        embbeding_type = 1;
-        break;
-    case ge::DT_INT32:
-        embbeding_type = 0;
-        break;
-    default:
-        embbeding_type = 1;
-        break;
+    switch (input_datatype) {
+        case ge::DT_FLOAT16:
+            embbeding_type = 2;
+            break;
+        case ge::DT_INT32:
+            embbeding_type = 0;
+            break;
+        default:
+            embbeding_type = 1;
+            break;
     }

     update_dim = input_dim;
@@ -61,6 +67,11 @@ namespace ge
 {
     gert::Shape *y_shape = context->GetOutputShape(0);
     int64_t input_shape = context->GetInputTensor(0)->GetShapeSize();
+    if (input_shape <= 0) {
+        printf("input_shape must be larger than 0\n");
+        return GRAPH_FAILED;
+    }
+
     int64_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape;
     y_shape->SetDimNum(2);
     y_shape->SetDim(0, input_shape);
diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
index 437e5b25..718f7c64 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
@@ -32,6 +32,7 @@ public:
     {
         GET_TILING_DATA(constData, tiling);
         // TODO: user kernel impl
+        // embedding dimension of one row
         int32_t update_dim = constData.update_dim;
         int32_t embbeding_type = constData.embbeding_type;
         int32_t block_total_nums = block_num;
@@ -39,24 +40,29 @@
         addr_nums = constData.addr_nums;
         if (embbeding_type == 2)
         {
-            single_data_size = 2;
+            singleDataSize = 2;
         }
         else
         {
-            single_data_size = 4;
+            singleDataSize = 4;
        }

+        // number of ping-pong buffers
         PingpongNum = 1;
-        int min_move_num = 32 / single_data_size;
-        once_move_nums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num);
-
-        int addr_max_num = ((int)((int)(ub_limit / (sizeof(int64_t) + single_data_size * (once_move_nums * ((int32_t)(update_dim - 1 + once_move_nums) / once_move_nums)) * PingpongNum * 2)) / 4)) * 4;
+        int min_move_num = 32 / singleDataSize;
+        // onceMoveNums is update_dim rounded up to a multiple of min_move_num, i.e. the number of
+        // elements moved per row; (update_dim - 1 + min_move_num) / min_move_num is integer division
+        // used to round up
+        onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num);
+        int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums;
+        // each address occupies sizeof(int64_t) bytes; singleDataSize is the byte width of one element;
+        // the payload is budgeted twice because every move keeps a second copy of the data
+        int occupy_address_bytes_num = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2;
+        // max number of addresses handled per round; the trailing /4 then *4 keeps the count a multiple
+        // of 4 so the 8-byte address buffer stays 32-byte aligned (sizeof(int64_t) = 8)
+        int addr_max_num = ((int)((int)(ub_limit / occupy_address_bytes_num) / 4)) * 4;

         int singlenum = (int)(addr_nums / block_total_nums);
         if (singlenum % 4)
         {
             singlenum -= singlenum % 4;
         }
-        roundSize = addr_max_num; // addr_max_num;
-        Veclen = roundSize * single_data_size * once_move_nums;
+        roundSize = addr_max_num;
+        Veclen = roundSize * singleDataSize * onceMoveNums;
         SingleCoreAddrLen = singlenum * sizeof(int64_t);
         cache = roundSize;
         dim = update_dim;
@@ -100,7 +106,7 @@ private:
         bool isFull = true;
         int nums = 0;
         int out_index = 0;
-        int times = once_move_nums / 8;
+        int times = onceMoveNums / 8;
         int tmp_cache = cache - 1;

         for (int i = 0; i < sizes; i++)
@@ -112,14 +118,14 @@
             if (address != 0)
             {
                 srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address));
-                DataCopy(dataLocal[once_move_nums * nums], srcDataBufferGm, once_move_nums);
+                DataCopy(dataLocal[onceMoveNums * nums], srcDataBufferGm, onceMoveNums);
             }
             else
             {
                 for (int j = 0; j < times; j++)
                 {
-                    Duplicate(dataLocal[once_move_nums * nums + j * 8], (T)0, 8);
+                    Duplicate(dataLocal[onceMoveNums * nums + j * 8], (T)0, 8);
                 }
             }
@@ -146,7 +152,7 @@
         DataCopyParams copyparams;
         copyparams.blockCount = 1;
-        copyparams.blockLen = once_move_nums * sizeof(T) * nums / 32;
+        copyparams.blockLen = onceMoveNums * sizeof(T) * nums / 32;
         DataCopy(dstLocal, srcLocal, copyparams);

         outQueue.EnQue(dstLocal);
@@ -159,19 +165,19 @@
         int offset = block_idx * dim * SingleCoreAddrLen / sizeof(int64_t) + (turns * roundSize * dim) + dim * index;

 #if defined(__DAV_C220_VEC__)
-        if (single_data_size == 4)
+        if (singleDataSize == 4)
         {
             copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums,
                                       dim * sizeof(T), 0, 0, 0, 0);
         }
-        else if (single_data_size == 2)
+        else if (singleDataSize == 2)
         {
             copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums,
                                       dim * sizeof(T), 0, 0, 0, 0);
         }
 #else
-        DataCopy(dstDataGm[offset], dstLocal, once_move_nums * nums);
+        DataCopy(dstDataGm[offset], dstLocal, onceMoveNums * nums);
 #endif
         outQueue.FreeTensor(dstLocal);
     }
@@ -179,7 +185,7 @@
 public:
     int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, cache, Veclen, dim, PingpongNum;
     int32_t addr_nums;
-    int32_t once_move_nums, single_data_size, update_type;
+    int32_t onceMoveNums, singleDataSize, update_type;

 private:
     TPipe pipe;
@@ -208,14 +214,6 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres
             op.Process();
         }
         break;
-    case 1:
-    {
-        KernelEimtable op;
-        op.Init_param(tiling);
-        op.Init(address, y);
-        op.Process();
-    }
-    break;
     case 2:
     {
         KernelEimtable op;
diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
index 4e45d2ac..8a6c537a 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
@@ -29,6 +29,7 @@ public:
     {
         GET_TILING_DATA(constData, tiling);
         // TODO: user kernel impl
+        // embedding dimension of one row
         int32_t update_dim = constData.update_dim;
         int32_t embbeding_type = constData.embbeding_type;
         int32_t block_total_nums = block_num;
@@ -37,24 +38,29 @@
         addr_nums = constData.addr_nums;
         if (embbeding_type == 2)
         {
-            single_data_size = 2;
+            singleDataSize = 2;
         }
         else
         {
-            single_data_size = 4;
+            singleDataSize = 4;
         }

+        // number of ping-pong buffers
         PingpongNum = 1;
-        int min_move_num = 32 / single_data_size;
-        once_move_nums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num);
-
-        int addr_max_num = ((int)((int)(ub_limit / (sizeof(int64_t) + single_data_size * (once_move_nums * ((int32_t)(update_dim - 1 + once_move_nums) / once_move_nums)) * PingpongNum * 2)) / 4)) * 4;
+        int min_move_num = 32 / singleDataSize;
+        // onceMoveNums is update_dim rounded up to a multiple of min_move_num, i.e. the number of
+        // elements moved per row; (update_dim - 1 + min_move_num) / min_move_num is integer division
+        // used to round up
+        onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num);
+        int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums;
+        // each address occupies sizeof(int64_t) bytes; singleDataSize is the byte width of one element;
+        // the payload is budgeted twice because every move keeps a second copy of the data
+        int occupy_address_bytes_num = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2;
+        // max number of addresses handled per round; the trailing /4 then *4 keeps the count a multiple
+        // of 4 so the 8-byte address buffer stays 32-byte aligned (sizeof(int64_t) = 8)
+        int addr_max_num = ((int)((int)(ub_limit / occupy_address_bytes_num) / 4)) * 4;

         int singlenum = (int)(addr_nums / block_total_nums);
         if (singlenum % 4)
         {
             singlenum -= singlenum % 4;
         }
-        roundSize = addr_max_num; // addr_max_num;
-        Veclen = roundSize * single_data_size * once_move_nums;
+        roundSize = addr_max_num;
+        Veclen = roundSize * singleDataSize * onceMoveNums;
         SingleCoreAddrLen = singlenum * sizeof(int64_t);
         cache = roundSize;
         dim = update_dim;
@@ -98,10 +104,10 @@ private:
         int out_index = 0;
         int offset = 0;
         int64_t address = 0;
-        if (dim == once_move_nums)
+        if (dim == onceMoveNums)
         {
             dataLocal = inQueue.AllocTensor();
-            DataCopy(dataLocal, srcDataBufferGm[turns * roundSize], sizes * once_move_nums);
+            DataCopy(dataLocal, srcDataBufferGm[turns * roundSize], sizes * onceMoveNums);
             inQueue.EnQue(dataLocal);
             Compute(sizes);
             LocalTensor dstLocal = outQueue.DeQue();
@@ -115,7 +121,7 @@
                 if (address != 0)
                 {
                     dstDataGm.SetGlobalBuffer((__gm__ T*)(address));
-                    DataCopy(dstDataGm, dstLocal[i*once_move_nums], once_move_nums);
+                    DataCopy(dstDataGm, dstLocal[i*onceMoveNums], onceMoveNums);
                 }
             }
             if (update_type == 0)
@@ -129,7 +135,7 @@
             for (int i = 0; i < sizes; i++)
             {
                 dataLocal = inQueue.AllocTensor();
-                DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize], once_move_nums);
+                DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize], onceMoveNums);
                 inQueue.EnQue(dataLocal);
                 Compute(1);
                 address = srcAddrLocal.GetValue(i);
@@ -145,7 +151,7 @@
         LocalTensor dstLocal = outQueue.AllocTensor();
         DataCopyParams copyparams;
         copyparams.blockCount = 1;
-        copyparams.blockLen = once_move_nums * sizeof(T) * nums / 32;
+        copyparams.blockLen = onceMoveNums * sizeof(T) * nums / 32;
         DataCopy(dstLocal, srcLocal, copyparams);
         outQueue.EnQue(dstLocal);
         inQueue.FreeTensor(srcLocal);
@@ -167,19 +173,19 @@
             }

 #if defined(__DAV_C220_VEC__)
-            if (single_data_size == 4)
+            if (singleDataSize == 4)
             {
                 copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, 1,
                                           dim * sizeof(T), 0, 0, 0, 0);
             }
-            else if (single_data_size == 2)
+            else if (singleDataSize == 2)
             {
                 copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, 1,
                                           dim * sizeof(T), 0, 0, 0, 0);
             }
 #else
-            DataCopy(dstDataGm, dstLocal, once_move_nums);
+            DataCopy(dstDataGm, dstLocal, onceMoveNums);
 #endif
         }
         if (update_type == 0)
@@ -191,7 +197,7 @@
 public:
     int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, addr_nums, cache, Veclen, dim, PingpongNum;
-    int32_t once_move_nums, single_data_size, update_type;
+    int32_t onceMoveNums, singleDataSize, update_type;

 private:
     TPipe pipe;
@@ -219,14 +225,6 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres
             op.Process();
         }
         break;
-    case 1:
-    {
-        KernelEimtable_update op;
-        op.Init_param(tiling);
-        op.Init(address, embedding, y);
-        op.Process();
-    }
-
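
Both kernels budget unified-buffer memory with the same formula rewritten above. A worked instance, assuming fp32 embeddings (singleDataSize = 4), update_dim = 100, ub_limit = 175 KiB and PingpongNum = 1; the numbers are illustrative inputs, not tiling constants from the repo:

```cpp
#include <cstdint>
#include <cstdio>

// Worked instance of the kernels' UB budgeting. All inputs are assumptions
// chosen for illustration.
int main()
{
    const int ubLimit = 175 * 1024;
    const int singleDataSize = 4;   // fp32
    const int updateDim = 100;
    const int pingpongNum = 1;

    const int minMoveNum = 32 / singleDataSize;                           // 8 elements = one 32B block
    const int onceMoveNums =
        minMoveNum * ((updateDim - 1 + minMoveNum) / minMoveNum);         // 100 -> 104 (rounded up)
    const int numToMove = (updateDim - 1 + onceMoveNums) / onceMoveNums;  // 1: one row fits one move
    // 8-byte address plus a double-buffered row payload per address
    const int bytesPerAddr = static_cast<int>(sizeof(int64_t)) +
        singleDataSize * onceMoveNums * numToMove * pingpongNum * 2;      // 8 + 832 = 840
    const int addrMaxNum = (ubLimit / bytesPerAddr / 4) * 4;              // 213 -> 212, multiple of 4

    printf("bytesPerAddr=%d addrMaxNum=%d\n", bytesPerAddr, addrMaxNum);  // 840, 212
    return 0;
}
```

So with these inputs each round can process 212 addresses before the per-address working set would overflow the 175 KiB budget.
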
break; case 2: { KernelEimtable_update op; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 06000536..e1381d29 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -38,16 +38,6 @@ TimeCost staticSw {}; TimeCost staticReadRaw {}; array batchIdsInfo {}; -size_t GetBatchSize(OpKernelContextPtr context, const size_t dataSize, const size_t fieldNum) -{ - if (fieldNum == 0 || dataSize / fieldNum <= 0) { - context->SetStatus( - errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat("batchSize error. %d/%d", dataSize, fieldNum))); - return 0; - } - return dataSize / fieldNum; -} - REGISTER_OP("ClearChannel").Attr("channel_id : int"); class ClearChannel : public OpKernel { -- Gitee From 89db1fbb83e13b7d15023919a7e28e6e9b7f4fd7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 30 Aug 2023 17:24:26 +0800 Subject: [PATCH 305/551] Match-id-3452ea937b45522bce546336dec1e8ce75d424ae --- mx_rec/logger/log.py | 35 +-------- src/core/emb_hashmap/emb_hashmap.cpp | 51 +++---------- src/core/emb_hashmap/emb_hashmap.h | 5 -- src/core/emb_table/emb_table.cpp | 103 +++------------------------ src/core/emb_table/emb_table.h | 13 ---- 5 files changed, 20 insertions(+), 187 deletions(-) diff --git a/mx_rec/logger/log.py b/mx_rec/logger/log.py index e24a7efe..d00ecb2d 100644 --- a/mx_rec/logger/log.py +++ b/mx_rec/logger/log.py @@ -6,9 +6,8 @@ import logging.config import os import yaml -from mx_rec.constants.constants import MAX_SIZE, LOG_MAX_SIZE +from mx_rec.constants.constants import LOG_MAX_SIZE from mx_rec.validator.validator import FileValidator -from mx_rec.validator.validator import DirectoryValidator def init_sys_log(): @@ -33,38 +32,6 @@ def init_sys_log(): logging.config.dictConfig(log_cfg) -def init_log_dir_for_dt(log_cfg): - """Create log directory for local environment dt test. - - :param log_cfg: log configuration dictionary from yml file. 
- :return: None - """ - handlers = log_cfg.get('handlers') - if not handlers: - return - - for handler_name in handlers: - handler_dict = handlers.get(handler_name) - log_file = handler_dict.get('filename') - - if not log_file: - continue - - log_file_standard = os.path.realpath(log_file) - if log_file_standard != log_file: - continue - - log_dir = os.path.dirname(log_file_standard) - if not DirectoryValidator(log_dir) \ - .check_is_not_none() \ - .check_dir_name() \ - .should_not_contains_sensitive_words() \ - .with_blacklist() \ - .check() \ - .is_valid(): - continue - - init_sys_log() srv_stream_log = logging.getLogger("logStream") env_log_level = os.getenv("MXREC_LOG_LEVEL") diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index c87cd26f..2739cc26 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -20,17 +20,17 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, #ifndef GTEST this->rankInfo = rankInfo; if (!ifLoad) { - EmbHashMapInfo embHashMap; + EmbHashMapInfo embHashMapInfo; LOG(INFO) << "init emb hash map from scratch"; for (const auto& embInfo: embInfos) { - embHashMap.devOffset2Batch.resize(embInfo.devVocabSize); - embHashMap.devOffset2Key.resize(embInfo.devVocabSize); - embHashMap.hostVocabSize = embInfo.hostVocabSize; - embHashMap.devVocabSize = embInfo.devVocabSize; - embHashMap.currentUpdatePos = 0; - fill(embHashMap.devOffset2Batch.begin(), embHashMap.devOffset2Batch.end(), -1); - fill(embHashMap.devOffset2Key.begin(), embHashMap.devOffset2Key.end(), -1); - embHashMaps[embInfo.name] = embHashMap; + embHashMapInfo.devOffset2Batch.resize(embInfo.devVocabSize); + embHashMapInfo.devOffset2Key.resize(embInfo.devVocabSize); + embHashMapInfo.hostVocabSize = embInfo.hostVocabSize; + embHashMapInfo.devVocabSize = embInfo.devVocabSize; + embHashMapInfo.currentUpdatePos = 0; + fill(embHashMapInfo.devOffset2Batch.begin(), embHashMapInfo.devOffset2Batch.end(), -1); + fill(embHashMapInfo.devOffset2Key.begin(), embHashMapInfo.devOffset2Key.end(), -1); + embHashMaps[embInfo.name] = embHashMapInfo; if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat( @@ -159,19 +159,6 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& k } } -void EmbHashMap::ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, - int pos) -{ - embHashMap.devOffset2Batch[pos] = static_cast(currentBatchId); - embHashMap.hostHashMap[key] = pos; - auto& oldKey = embHashMap.devOffset2Key[pos]; - if (oldKey != -1) { - embHashMap.oldSwap.emplace_back(oldKey, key); - embHashMap.hostHashMap[oldKey] = hostOffset; - } - oldKey = key; -} - int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) { int32_t offset; @@ -236,26 +223,6 @@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat } } -void EmbHashMap::FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId) -{ - while (num != 0) { - if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { - embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); - num -= 1; - } - embHashMap.currentUpdatePos++; - embHashMap.freeSize--; - if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { - embHashMap.currentUpdatePos = 0; - } - if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - LOG(ERROR) << "devVocabSize is too small"; - throw runtime_error("devVocabSize is too small"); - } - } -} - - auto 
EmbHashMap::GetHashMaps() -> absl::flat_hash_map { auto embHashMapsOld = embHashMaps; diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 5d6e9204..00559509 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -31,11 +31,6 @@ namespace MxRec { void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId); - void ChangeSwapInfo(EmbHashMapInfo& embHashMap, emb_key_t key, size_t hostOffset, size_t currentBatchId, - int pos); - - void FindPos(EmbHashMapInfo& embHashMap, int num, size_t keepBatchId); - auto GetHashMaps() -> absl::flat_hash_map; void LoadHashMap(absl::flat_hash_map& loadData); diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 8bde2888..a13b1eed 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -41,16 +41,11 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); throw AclError(); } - if (newBlock == nullptr) { - // 内存不足,抛出异常 - throw OutOfMemoryError(); - } else { - // 申请内存初始化 - RandomInit(newBlock, embInfo.initializeInfos, seed); - // 将新的内存块加入内存链表 - memoryList.push_back(newBlock); - SplitMemoryBlock(newBlock); - } + // 申请内存初始化 + RandomInit(newBlock, embInfo.initializeInfos, seed); + // 将新的内存块加入内存链表 + memoryList.push_back(newBlock); + SplitMemoryBlock(newBlock); } totalCapacity = static_cast(memoryList.size()); LOG(INFO) << StringFormat( @@ -85,16 +80,11 @@ int64_t EmbTable::GetEmbAddress() LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); throw AclError(); } - if (addBlock == nullptr) { - // 内存不足,抛出异常 - throw OutOfMemoryError(); - } else { - RandomInit(addBlock, embInfo.initializeInfos, seed); - // 将新的内存块加入内存list - memoryList.push_back(addBlock); - SplitMemoryBlock(addBlock); - totalCapacity++; - } + RandomInit(addBlock, embInfo.initializeInfos, seed); + // 将新的内存块加入内存list + memoryList.push_back(addBlock); + SplitMemoryBlock(addBlock); + totalCapacity++; } float *embAddr = embeddingList.front(); embeddingList.pop_front(); @@ -103,15 +93,6 @@ int64_t EmbTable::GetEmbAddress() #endif } -// 将一个emb地址放入embeddingList中 -void EmbTable::PutEmbAddress(int64_t curAddress) -{ -#ifndef GTEST - embeddingList.push_back(reinterpret_cast(curAddress)); - usedCapacity--; -#endif -} - void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos, int seed) { #ifndef GTEST @@ -195,67 +176,3 @@ void EmbTable::PrintStatus() // 输出embedding table的未使用的使用容量 LOG(INFO) << StringFormat("Unused capacity:%d", totalCapacity * blockSize - usedCapacity * embSize); } - -// 用于保存 -map> EmbTable::SaveEmb() -{ -#ifndef GTEST - if (embSize == 0) { - throw std::runtime_error("SaveEmb Divided by Zero!"); - } - map> savedEmb; - for (auto ptr : memoryList) { - float* floatPtr = static_cast(ptr); - for (int i = 0; i < BLOCK_EMB_COUNT; ++i) { - // 访问 aclmemcpy - vector row(embSize); - aclError ret = aclrtMemcpy(row.data(), embSize * sizeof(float), - floatPtr + i * embSize, embSize * sizeof(float), - ACL_MEMCPY_HOST_TO_DEVICE); - if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); - throw AclError(); - } - savedEmb[reinterpret_cast(floatPtr + i * embSize)] = move(row); - } - } - return savedEmb; -#endif -} - -// 用于加载 输入一个vector,申请内存,存储输入信息 , list返回全部地址 -list EmbTable::LoadEmb(const vector> &savedEmb) -{ -#ifndef GTEST - list addressList; - int embCapacity = 
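
EmbTable's Init and GetEmbAddress above hand out embedding rows from a free list carved out of large device blocks, growing one block at a time when the list runs dry. A host-side sketch of that slab-style scheme, with plain new[] standing in for aclrtMalloc and illustrative sizes:

```cpp
#include <cstddef>
#include <deque>
#include <list>

// Slab-style free list over fixed-size rows, mirroring EmbTable's
// memoryList/embeddingList pair. new[] stands in for aclrtMalloc here.
class RowPool {
public:
    RowPool(size_t rowElems, size_t rowsPerBlock)
        : rowElems_(rowElems), rowsPerBlock_(rowsPerBlock) {}

    ~RowPool()
    {
        for (float* block : blocks_) { delete[] block; }
    }

    float* GetRow()
    {
        if (freeRows_.empty()) { AddBlock(); }  // grow one block at a time
        float* row = freeRows_.front();
        freeRows_.pop_front();
        return row;
    }

private:
    void AddBlock()
    {
        float* block = new float[rowElems_ * rowsPerBlock_]();
        blocks_.push_back(block);
        for (size_t i = 0; i < rowsPerBlock_; ++i) {
            freeRows_.push_back(block + i * rowElems_);  // carve block into row slots
        }
    }

    size_t rowElems_;
    size_t rowsPerBlock_;
    std::list<float*> blocks_;
    std::deque<float*> freeRows_;
};
```

Allocating in large blocks amortizes the cost of device allocation and keeps row addresses stable, which is what lets the by-address lookup and update ops treat a row pointer as a long-lived handle.
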
static_cast(savedEmb.size());
-    if (savedEmb.size() == 0 || savedEmb[0].size() == 0) {
-        LOG(ERROR) << "Load invalid savedEmb";
-        return addressList;
-    }
-    embSize = static_cast(savedEmb[0].size());
-    void *newBlock = nullptr;
-    aclError ret = aclrtMalloc(&newBlock, embCapacity * embSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
-    if (ret != ACL_SUCCESS) {
-        LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret);
-        throw AclError();
-    }
-    if (newBlock == nullptr) {
-        // 内存不足,抛出异常
-        throw OutOfMemoryError();
-    }
-    float *floatPtr = static_cast(newBlock);
-    for (int i = 0; i < embCapacity; i++) {
-        aclError ret = aclrtMemcpy(floatPtr + i * embSize, embSize * sizeof(float),
-                                   savedEmb[i].data(), embSize * sizeof(float),
-                                   ACL_MEMCPY_HOST_TO_DEVICE);
-        if (ret != ACL_SUCCESS) {
-            LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret);
-            throw AclError();
-        }
-        addressList.push_back(floatPtr + i * embSize);
-    }
-    memoryList.push_back(newBlock);
-    return addressList;
-#endif
-}
diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h
index 500a324c..cc83a4e2 100644
--- a/src/core/emb_table/emb_table.h
+++ b/src/core/emb_table/emb_table.h
@@ -31,16 +31,9 @@ namespace MxRec {
         // 从embeddingList获取获取一个可用的emb地址
         int64_t GetEmbAddress();

-        // 将一个emb地址放入embeddingList中
-        void PutEmbAddress(int64_t curAddress);
-
         // 打印emb表使用情况
         void PrintStatus();

-        int GetTotalCap();
-
-        int GetUsedCap();
-
         EmbTable(const EmbTable&) = delete;

         EmbTable(EmbTable&&) = delete;
@@ -51,12 +44,6 @@
         void ExecuteAclMemcpy(void* newBlock, vector devEmb);

-        // 用于保存
-        map> SaveEmb();
-
-        // 用于加载 输入一个vector,创建一个embeddingtable类,申请内存,存储输入信息 , list返回全部地址
-        list LoadEmb(const vector> &savedEmb);
-
 GTEST_PRIVATE:
         constexpr static int BLOCK_EMB_COUNT = 100000;
         constexpr static int INIT_BLOCK_COUNT = 5;
--
Gitee

From 97c0b30d9234df7b33f3f30472d30185eb16fb77 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 1 Sep 2023 14:52:15 +0800
Subject: [PATCH 306/551] Match-id-1fb358261ecb190c7647e2fd2c3f98a9d0d5a898

---
 mx_rec/core/embedding.py             |  6 ++++--
 src/core/emb_hashmap/emb_hashmap.cpp |  6 ++++--
 src/core/ssd_cache/cache_manager.cpp | 19 ++++++++-----------
 src/core/ssd_cache/lfu_cache.cpp     |  2 +-
 4 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 0a7bd9c2..74663444 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -35,9 +35,11 @@ def check_ssd_relate_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_p
         h_size = int(host_vocabulary_size)
         s_size = int(ssd_vocabulary_size)
     except ValueError:
-        raise ValueError("exist invalid value in host_vocabulary_size or ssd_vocabulary_size or both.")
+        raise ValueError("host_vocabulary_size and ssd_vocabulary_size should be integers")
     if h_size == 0 and s_size != 0:
-        raise ValueError("ssd_vocabulary_size value is invalid, it need host_vocabulary_size value not equals 0.")
+        raise ValueError("ssd_vocabulary_size is invalid: it requires host_vocabulary_size to be non-zero")
+    if h_size != 0 and s_size < 0:
+        raise ValueError("ssd_vocabulary_size is invalid: it must not be negative")
     invalid_ssd_data_path = []
     for tmp_path in ssd_data_path:
         if is_invalid_path(tmp_path):
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index c87cd26f..c8823eb7 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -344,8 +344,10 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& evict
const vector& evictDDRKeys.emplace_back(key); } } - cacheManager->RefreshFreqInfoCommon(embName, evictHBMKeys, TransferType::HBM_2_EVICT); - cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT); + if (isSSDEnabled) { + cacheManager->RefreshFreqInfoCommon(embName, evictHBMKeys, TransferType::HBM_2_EVICT); + cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT); + } LOG(INFO) << StringFormat( "ddr EvictDeleteEmb, emb: [%s], hostEvictSize: %d, devEvictSize: %d ", diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index e71d72d8..2108c915 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -253,13 +253,10 @@ void CacheManager::RefreshRelateInfoWithDDR2SSD(const string& embTableName, EmbH /// key从DDR移入、移出、HBM淘汰时刷新频次信息;仅刷新频次信息 /// \param embTableName emb表名 -/// \param originalKeys 操作的key集合 +/// \param keys 操作的key集合 /// \param type TransferType -void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& originalKeys, - TransferType type) +void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector& keys, TransferType type) { - vector keys; - HandleRepeatAndInvalidKey(originalKeys, keys); if (type == TransferType::DDR_2_HBM) { for (auto& key : keys) { // 频次数据记录到 excludeDDRKeyCountMap,并删除ddrKeyFreqMap中频次数据 @@ -302,17 +299,17 @@ bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key) /// 淘汰SSD中Emb信息 /// \param embTableName emb表名 -/// \param originalKeys 淘汰key列表 -void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& originalKeys) +/// \param keys 淘汰key列表 +void CacheManager::EvictSSDEmbedding(const string& embTableName, vector& keys) { - vector keys; - HandleRepeatAndInvalidKey(originalKeys, keys); + if (keys.empty()) { + return; + } // 1 删除缓存中记录的key的次数 2 删除SSD中保存的Emb数据 for (auto& key : keys) { excludeDDRKeyCountMap[embTableName].erase(key); } - vector currentKeys(keys.begin(), keys.end()); - ssdEngine->DeleteEmbeddings(embTableName, currentKeys); + ssdEngine->DeleteEmbeddings(embTableName, keys); } /// 放入key,新增/更新(次数+1)次数 diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/ssd_cache/lfu_cache.cpp index 5e0c2a04..8c68ecca 100644 --- a/src/core/ssd_cache/lfu_cache.cpp +++ b/src/core/ssd_cache/lfu_cache.cpp @@ -105,7 +105,7 @@ void LFUCache::PutWithInit(emb_key_t key, freq_num_t freq) { if (keyTable.find(key) != keyTable.end()) { // 一般初始化时,key应该不存在已经被插入的情况;此处替换就的key频次信息 - LOG(WARNING) << StringFormat("key has exist when init process, key:%d", key); + VLOG(GLOG_DEBUG) << StringFormat("key has exist when init process, key:%d", key); Pop(key); } freqTable[freq].emplace_front(key, freq); -- Gitee From d44bf8113344690238e5fb3c86670ab24e5dc6e2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Sep 2023 09:12:58 +0800 Subject: [PATCH 307/551] Match-id-6f8752a1a254340ad2ceb24bad9deea11e62ccce --- mx_rec/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 1a65d74e..d16ccae0 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -744,7 +744,7 @@ class SparseEmbedding: if hot_pos is not None: hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) - unique_grads = tf.tensor_scatter_nd_update(cold, tf.expand_dims(hot_pos, 1), hot) + unique_grads = tf.tensor_scatter_nd_add(cold, tf.expand_dims(hot_pos, 1), hot) local_grad = 
self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: try: -- Gitee From 2f3f7392a9ba370a5bdbcf1ed292c704e4b40a0f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Sep 2023 09:55:50 +0800 Subject: [PATCH 308/551] Match-id-76837a4522d5ca21d5fdb14d89fc53693fecd413 --- mx_rec/saver/saver.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index df219914..a9c6630d 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -149,6 +149,12 @@ class Saver(object): def _save(self, sess, root_dir): result = sess.run(self.save_op_dict) for table_name, dump_data_dict in result.items(): + table_instance = get_table_instance_by_name(table_name) + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) if is_asc_manager_initialized() and self.save_easy_mode: host_data = get_host_data(table_name) key = np.array(list(host_data.keys())) -- Gitee From 8ff812a6fb169f50131ca313e1d9626e3ad0232e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Sep 2023 10:27:01 +0800 Subject: [PATCH 309/551] Match-id-961e1654f3f82f102c8b9d6e6ae90780e82800ba --- mx_rec/saver/saver.py | 8 +- src/core/checkpoint/checkpoint.cpp | 36 ++++- src/core/checkpoint/checkpoint.h | 8 +- .../ckpt_data_handler/ckpt_data_handler.h | 8 +- .../emb_hash_ckpt/emb_hash_ckpt.cpp | 37 ++++- .../emb_hash_ckpt/emb_hash_ckpt.h | 9 +- .../key_freq_map_ckpt/key_freq_map_ckpt.cpp | 137 ++++++++++++++++++ .../key_freq_map_ckpt/key_freq_map_ckpt.h | 51 +++++++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 64 ++++++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 4 + src/core/utils/common.cpp | 4 +- src/core/utils/common.h | 15 +- src/test_ut.sh | 2 +- src/tests/checkpoint/checkpoint_test.cpp | 72 ++++++++- 14 files changed, 415 insertions(+), 40 deletions(-) create mode 100644 src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index df219914..57a59956 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -148,6 +148,10 @@ class Saver(object): def _save(self, sess, root_dir): result = sess.run(self.save_op_dict) + if is_asc_manager_initialized() and not self.save_easy_mode: + save_host_data(root_dir) + logging.debug(f"host data was saved.") + for table_name, dump_data_dict in result.items(): if is_asc_manager_initialized() and self.save_easy_mode: host_data = get_host_data(table_name) @@ -155,9 +159,7 @@ class Saver(object): offset = list(host_data.values()) get_valid_dict_data(dump_data_dict, offset) save_key_data(root_dir, table_name, key, self.rank_id) - if is_asc_manager_initialized() and not self.save_easy_mode: - save_host_data(root_dir) - logging.debug(f"host data was saved.") + save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id) table_instance = get_table_instance_by_name(table_name) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index b943f642..1c18b456 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -15,6 +15,7 @@ #include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h" #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" 
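
The Checkpoint changes just below register the new key-frequency handler alongside the existing ones. For orientation, the registry amounts to a feature-to-factory map over polymorphic handlers; a minimal standalone sketch of that shape, with illustrative names rather than the repo's exact interface:

```cpp
#include <functional>
#include <map>
#include <memory>
#include <vector>

// Minimal shape of the handler registry extended below: each checkpoint
// feature maps to a factory that appends its handler to the pipeline.
struct Handler {
    virtual ~Handler() = default;
};
struct EmbHashHandler : Handler {};
struct KeyFreqMapHandler : Handler {};

enum class FeatureType { EMB_HASHMAP, DDR_KEY_FREQ_MAP };

std::vector<std::unique_ptr<Handler>> BuildHandlers(const std::vector<FeatureType>& features)
{
    std::vector<std::unique_ptr<Handler>> handlers;
    const std::map<FeatureType, std::function<void()>> registry {
        { FeatureType::EMB_HASHMAP,
          [&] { handlers.push_back(std::make_unique<EmbHashHandler>()); } },
        { FeatureType::DDR_KEY_FREQ_MAP,
          [&] { handlers.push_back(std::make_unique<KeyFreqMapHandler>()); } },
    };
    for (const auto& feature : features) {
        registry.at(feature)();  // throws on an unknown feature, like map::at below
    }
    return handlers;
}
```
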
#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" +#include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" #include "utils/time_cost.h" #include "utils/common.h" @@ -75,6 +76,9 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) !ckptData.histRec.historyRecords.empty()) { dataHandlers.push_back(make_unique()); } + if (!ckptData.ddrKeyFreqMaps.empty() && !ckptData.excludeDDRKeyFreqMaps.empty()) { + dataHandlers.push_back(make_unique()); + } } void Checkpoint::SetDataHandler(const vector& featureTypes) @@ -84,7 +88,8 @@ void Checkpoint::SetDataHandler(const vector& featureTypes) { CkptFeatureType::EMB_HASHMAP, [&] { dataHandlers.push_back(make_unique()); } }, { CkptFeatureType::MAX_OFFSET, [&] { dataHandlers.push_back(make_unique()); } }, { CkptFeatureType::KEY_OFFSET_MAP, [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::FEAT_ADMIT_N_EVICT, [&] { dataHandlers.push_back(make_unique()); } } }; + { CkptFeatureType::FEAT_ADMIT_N_EVICT, [&] { dataHandlers.push_back(make_unique()); } }, + { CkptFeatureType::DDR_KEY_FREQ_MAP, [&] { dataHandlers.push_back(make_unique()); } } }; for (const auto& featureType : featureTypes) { setCkptMap.at(featureType)(); @@ -359,12 +364,13 @@ void Checkpoint::LoadProcess(CkptData& ckptData) vector embNames {}; vector dirNames { dataHandler->GetDirNames() }; vector saveDataTypes { dataHandler->GetDataTypes() }; - GetUpperLayerLoadDir(dirNames); - embNames = GetEmbedTableNames(); - + if (find(dirNames.begin(), dirNames.end(), ssdSymbol) != dirNames.end()) { + embNames = GetTableLayerLoadDir(); + } else { + embNames = GetEmbedTableNames(); + } LoadDataset(embNames, saveDataTypes, dataHandler, ckptData); - dataHandler->GetProcessData(ckptData); } } @@ -390,6 +396,25 @@ vector Checkpoint::GetEmbedTableNames() return loadTableNames; } +vector Checkpoint::GetTableLayerLoadDir() +{ + vector loadTableDir; + auto dir { opendir(innerDirPath.c_str()) }; + struct dirent* en; + if (dir != nullptr) { + while ((en = readdir(dir)) != nullptr) { + if (strcmp(en->d_name, currDir.c_str()) != 0 && + strcmp(en->d_name, prevDir.c_str()) != 0) { + loadTableDir.emplace_back(en->d_name); + } + } + closedir(dir); + } else { + LOG(WARNING) << "when loading data in ssd, there are no table files."; + } + return loadTableDir; +} + void Checkpoint::LoadDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler, @@ -465,7 +490,6 @@ void Checkpoint::ReadStream(CkptTransData& transData, auto resizeSize { datasetSize / dataElmtBytes }; SetTransDataSize(transData, resizeSize, dataType); - if (readFile.is_open()) { size_t idx = 0; size_t readSize = 0; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index c860c29e..b2c1124d 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -31,7 +31,8 @@ namespace MxRec { const string dataFileType { ".data" }; const string attribFileType { ".attribute" }; const string dirSeparator { "/" }; - const mode_t dirMode { 0400 }; + const string ssdSymbol {"SSD"}; + const mode_t dirMode { 0500 }; const string currDir { "." }; const string prevDir { ".." 
}; @@ -48,7 +49,9 @@ namespace MxRec { CkptDataType::EMB_HASHMAP, CkptDataType::DEV_OFFSET, CkptDataType::HIST_REC, - CkptDataType::NDDR_FEATMAP + CkptDataType::NDDR_FEATMAP, + CkptDataType::DDR_FREQ_MAP, + CkptDataType::EXCLUDE_FREQ_MAP }; const set floatTransSet{ CkptDataType::EMB_DATA @@ -93,6 +96,7 @@ namespace MxRec { void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); vector GetEmbedTableNames(); + vector GetTableLayerLoadDir(); void LoadDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler, CkptData& ckptData); void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index aea1d2b7..82ee2f3a 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -48,9 +48,13 @@ namespace MxRec { "max_offset", "key_offset_map", "table_2_threshold", - "history_record" + "history_record", + "attribute", + "ddr_key_freq_map", + "exclude_ddr_key_freq_map", + "evict_pos" }; - const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8 }; + const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8, 8, 8, 8}; const uint32_t eightBytes { 8 }; const uint32_t fourBytes { 4 }; diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp index a8de3aed..d8bf2f07 100644 --- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp @@ -51,7 +51,8 @@ CkptTransData EmbHashCkpt::GetDataset(CkptDataType dataType, string embName) map> dataTransMap { { CkptDataType::EMB_HASHMAP, [=] { SetEmbHashMapTrans(embName); } }, { CkptDataType::DEV_OFFSET, [=] { SetDevOffsetTrans(embName); } }, - { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStatTrans(embName); } } }; + { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStatTrans(embName); } }, + { CkptDataType::EVICT_POS, [=] { SetEvictPosTrans(embName); } } }; CleanTransfer(); dataTransMap.at(dataType)(); @@ -62,7 +63,8 @@ void EmbHashCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransDat { map> dataLoadMap { { CkptDataType::EMB_HASHMAP, [=] { SetEmbHashMap(embName); } }, { CkptDataType::DEV_OFFSET, [=] { SetDevOffset(embName); } }, - { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStat(embName); } } }; + { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStat(embName); } }, + { CkptDataType::EVICT_POS, [=] { SetEvictPos(embName); } } }; CleanTransfer(); transferData = move(loadedData); @@ -124,6 +126,23 @@ void EmbHashCkpt::SetEmbCurrStatTrans(string embName) transArr.push_back(static_cast(saveEmbHashMaps.at(embName).currentUpdatePos)); transArr.push_back(static_cast(saveEmbHashMaps.at(embName).hostVocabSize)); transArr.push_back(static_cast(saveEmbHashMaps.at(embName).devVocabSize)); + transArr.push_back(static_cast(saveEmbHashMaps.at(embName).maxOffset)); +} + +void EmbHashCkpt::SetEvictPosTrans(string embName) +{ + const auto& evictPos = saveEmbHashMaps.at(embName).evictPos; + auto& transArr = transferData.int64Arr; + auto& attribute = transferData.attribute; + auto evictPosSize = evictPos.size(); + + attribute.push_back(evictPosSize); + attribute.push_back(eightBytes); + transferData.datasetSize = evictPosSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(evictPosSize); + 
transArr.insert(transArr.end(), evictPos.begin(), evictPos.end()); } void EmbHashCkpt::SetEmbHashMap(string embName) @@ -134,7 +153,6 @@ void EmbHashCkpt::SetEmbHashMap(string embName) if (i + embHashElmtNum > transArr.size()) { // this is an error, need to log this } - hostHashMap[transArr.at(i)] = static_cast(transArr.at(i + 1)); } } @@ -153,6 +171,18 @@ void EmbHashCkpt::SetDevOffset(string embName) dev2Key.insert(dev2Key.begin(), transArr.begin() + attribute.at(attrbDev2BatchIdx), transArr.end()); } +void EmbHashCkpt::SetEvictPos(string embName) +{ + const auto& transArr = transferData.int64Arr; + const auto& attribute = transferData.attribute; + auto& evictPos = loadEmbHashMaps[embName].evictPos; + + evictPos.resize(attribute.at(attrEvictPosIdx)); + fill(evictPos.begin(), evictPos.end(), -1); + evictPos.insert(evictPos.begin(), transArr.begin() + attribute.at(attrEvictPosIdx), transArr.end()); +} + + void EmbHashCkpt::SetEmbCurrStat(string embName) { auto& embCurrStat = loadEmbHashMaps[embName]; @@ -161,4 +191,5 @@ void EmbHashCkpt::SetEmbCurrStat(string embName) embCurrStat.currentUpdatePos = static_cast(transArr.at(currUpdataPosIdx)); embCurrStat.hostVocabSize = static_cast(transArr.at(hostVocabIdx)); embCurrStat.devVocabSize = static_cast(transArr.at(devVocabIdx)); + embCurrStat.maxOffset = static_cast(transArr.at(maxOffsetIdx)); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h index 8596f74a..4fd413af 100644 --- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h @@ -32,27 +32,32 @@ namespace MxRec { private: const vector fileDirNames { "HashTable", "DDR" }; const vector saveDataTypes { CkptDataType::EMB_HASHMAP, CkptDataType::DEV_OFFSET, - CkptDataType::EMB_CURR_STAT }; + CkptDataType::EMB_CURR_STAT, CkptDataType::EVICT_POS }; const int currUpdataPosIdx { 0 }; const int hostVocabIdx { 1 }; const int devVocabIdx { 2 }; + const int maxOffsetIdx { 3 }; const int attrbDev2BatchIdx { 0 }; const int attrbDev2KeyIdx { 1 }; + const int attrEvictPosIdx {0}; + const int embHashElmtNum { 2 }; - const int embCurrStatNum { 3 }; + const int embCurrStatNum { 4 }; emb_hash_mem_t saveEmbHashMaps; emb_hash_mem_t loadEmbHashMaps; void SetEmbHashMapTrans(string embName); void SetDevOffsetTrans(string embName); void SetEmbCurrStatTrans(string embName); + void SetEvictPosTrans(string embName); void SetEmbHashMap(string embName); void SetDevOffset(string embName); void SetEmbCurrStat(string embName); + void SetEvictPos(string embName); }; } diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp new file mode 100644 index 00000000..f9377af7 --- /dev/null +++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2023-08-17 + */ + +#include "key_freq_map_ckpt.h" + +using namespace std; +using namespace MxRec; + +void KeyFreqMapCkpt::SetProcessData(CkptData& processData) +{ + saveDDRKeyFreqMaps.clear(); + loadDDRKeyFreqMaps.clear(); + saveExcludeDDRKeyFreqMaps.clear(); + loadExcludeDDRKeyFreqMaps.clear(); + + saveDDRKeyFreqMaps = std::move(processData.ddrKeyFreqMaps); + saveExcludeDDRKeyFreqMaps = std::move(processData.excludeDDRKeyFreqMaps); +} + +void KeyFreqMapCkpt::GetProcessData(CkptData& processData) +{ + processData.ddrKeyFreqMaps = std::move(loadDDRKeyFreqMaps); + processData.excludeDDRKeyFreqMaps = std::move(loadExcludeDDRKeyFreqMaps); + + saveDDRKeyFreqMaps.clear(); + loadDDRKeyFreqMaps.clear(); + saveExcludeDDRKeyFreqMaps.clear(); + loadExcludeDDRKeyFreqMaps.clear(); +} + +vector KeyFreqMapCkpt::GetDataTypes() +{ + return saveDataTypes; +} + +vector KeyFreqMapCkpt::GetDirNames() +{ + return fileDirNames; +} + +vector KeyFreqMapCkpt::GetEmbNames() +{ + vector embNames; + for (const auto& item :saveDDRKeyFreqMaps) { + embNames.push_back(item.first); + } + return embNames; +} + +CkptTransData KeyFreqMapCkpt::GetDataset(CkptDataType dataType, string embName) +{ + map> dataTransMap { + { CkptDataType::DDR_FREQ_MAP, [=] { SetDDRFreqMapTrans(embName); } }, + { CkptDataType::EXCLUDE_FREQ_MAP, [=] { SetExcludeDDRFreqMapTrans(embName); } } }; + + CleanTransfer(); + dataTransMap.at(dataType)(); + return move(transferData); +} + +void KeyFreqMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) +{ + map> dataLoadMap { + { CkptDataType::DDR_FREQ_MAP, [=] { SetDDRFreqMaps(embName); } }, + { CkptDataType::EXCLUDE_FREQ_MAP, [=] { SetExcludeDDRFreqMaps(embName); } } }; + + CleanTransfer(); + transferData = move(loadedData); + dataLoadMap.at(dataType)(); +} + +// set DDRFreqMapTrans for save +void KeyFreqMapCkpt::SetDDRFreqMapTrans(string embName) +{ + auto& transArr = transferData.int64Arr; + auto& attribute = transferData.attribute; + auto ddrFreqMapSize = saveDDRKeyFreqMaps.at(embName).size(); + + attribute.push_back(ddrFreqMapSize); + ddrFreqMapSize = ddrFreqMapSize * freqMapElmtNum; + + attribute.push_back(freqMapElmtNum); + attribute.push_back(eightBytes); + + transferData.datasetSize = ddrFreqMapSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(ddrFreqMapSize); + for (const auto& it : saveDDRKeyFreqMaps.at(embName)) { + transArr.push_back(it.first); + transArr.push_back(static_cast(it.second)); + } +} + +void KeyFreqMapCkpt::SetExcludeDDRFreqMapTrans(string embName) +{ + auto& transArr = transferData.int64Arr; + auto& attribute = transferData.attribute; + auto excludeDDRFreqMapSize = saveExcludeDDRKeyFreqMaps.at(embName).size(); + + attribute.push_back(excludeDDRFreqMapSize); + excludeDDRFreqMapSize = excludeDDRFreqMapSize * freqMapElmtNum; + + attribute.push_back(freqMapElmtNum); + attribute.push_back(eightBytes); + + transferData.datasetSize = excludeDDRFreqMapSize * eightBytes; + transferData.attributeSize = attribute.size() * eightBytes; + + transArr.reserve(excludeDDRFreqMapSize); + for (const auto& it : saveExcludeDDRKeyFreqMaps.at(embName)) { + transArr.push_back(it.first); + transArr.push_back(static_cast(it.second)); + } +} + +void KeyFreqMapCkpt::SetDDRFreqMaps(string embName) +{ + auto& ddrKeyFreqMap = loadDDRKeyFreqMaps[embName]; + const auto& transArr = transferData.int64Arr; + for (size_t i = 0; i < transArr.size(); i += freqMapElmtNum) { + 
ddrKeyFreqMap[transArr.at(i)] = static_cast(transArr.at(i + 1)); + } +} + +void KeyFreqMapCkpt::SetExcludeDDRFreqMaps(string embName) +{ + auto& excludeDDRKeyFreqMap = loadExcludeDDRKeyFreqMaps[embName]; + const auto& transArr = transferData.int64Arr; + for (size_t i = 0; i < transArr.size(); i += freqMapElmtNum) { + excludeDDRKeyFreqMap[transArr.at(i)] = static_cast(transArr.at(i + 1)); + } +} diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h new file mode 100644 index 00000000..2db7091c --- /dev/null +++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-08-17 + */ + +#ifndef MX_REC_KEY_FREQ_MAP_CKPT_H +#define MX_REC_KEY_FREQ_MAP_CKPT_H + +#include "ckpt_data_handler/ckpt_data_handler.h" + +namespace MxRec { + using namespace std; + + class KeyFreqMapCkpt : public CkptDataHandler { + public: + KeyFreqMapCkpt() = default; + ~KeyFreqMapCkpt() override {} + + void SetProcessData(CkptData& processData) override; + void GetProcessData(CkptData& processData) override; + + vector GetDataTypes() override; + + vector GetDirNames() override; + vector GetEmbNames() override; + CkptTransData GetDataset(CkptDataType dataType, string embName) override; + + void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; + + private: + const vector fileDirNames { "HashTable", "SSD" }; + const vector saveDataTypes { CkptDataType::DDR_FREQ_MAP, CkptDataType::EXCLUDE_FREQ_MAP}; + + const int freqMapElmtNum { 2 }; // Number of element types in the keyFreqMap during saving + + key_freq_mem_t saveDDRKeyFreqMaps; + key_freq_mem_t loadDDRKeyFreqMaps; + key_freq_mem_t saveExcludeDDRKeyFreqMaps; + key_freq_mem_t loadExcludeDDRKeyFreqMaps; + + void SetDDRFreqMapTrans(string embName); + void SetExcludeDDRFreqMapTrans(string embName); + + void SetDDRFreqMaps(string embName); + void SetExcludeDDRFreqMaps(string embName); + }; +} + +#endif // MX_REC_KEY_FREQ_MAP_CKPT_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 3e8dc2ce..f47569eb 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -9,6 +9,7 @@ #include "utils/time_cost.h" #include "checkpoint/checkpoint.h" + using namespace MxRec; using namespace std; @@ -217,6 +218,15 @@ bool HybridMgmt::Save(const string savePath) saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); } + if (isSSDEnabled) { + for (auto& it : cacheManager->ddrKeyFreqMap) { + saveData.ddrKeyFreqMaps[it.first] = it.second.GetFreqTable(); + } + saveData.excludeDDRKeyFreqMaps = cacheManager->excludeDDRKeyCountMap; + auto step = GetStepFromPath(savePath); + cacheManager->SaveSSDEngine(step); + } + // 保存特征准入淘汰相关的数据 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { @@ -249,24 +259,11 @@ bool HybridMgmt::Load(const string& loadPath) CkptData loadData; Checkpoint loadCkpt; vector loadFeatures; - if (!mgmtRankInfo.noDDR) { - // DDR模式加载的类型为host的emb表以及hashmap - loadFeatures.push_back(CkptFeatureType::HOST_EMB); - loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP); - } else { - // HBM模式加载的类型为最大偏移(真正使用了多少vocab容量),特征到偏移的映射 - loadFeatures.push_back(CkptFeatureType::MAX_OFFSET); - loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP); - } - // 添加特征准入淘汰相关的数据类型的加载 auto& 
featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); - if (featAdmitNEvict.GetFunctionSwitch()) { - loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); - } + SetFeatureTypeForLoad(loadFeatures, featAdmitNEvict); loadData.hostEmbs = hostEmbs->GetHostEmbs(); // 获取已经初始化好的host emb - // 执行加载操作 loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); @@ -294,6 +291,12 @@ bool HybridMgmt::Load(const string& loadPath) featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } + if (isSSDEnabled) { + VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: ssd key freq map"); + auto step = GetStepFromPath(loadPath); + cacheManager->Load(loadData.ddrKeyFreqMaps, loadData.excludeDDRKeyFreqMaps, step); + } + VLOG(GLOG_DEBUG) << (MGMT + "Finish host side load process"); preprocess->LoadSaveUnlock(); @@ -306,6 +309,29 @@ bool HybridMgmt::Load(const string& loadPath) return true; } +void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, + const FeatureAdmitAndEvict& featAdmitNEvict) +{ + if (!mgmtRankInfo.noDDR) { + // DDR模式加载的类型为host的emb表以及hashmap + loadFeatures.push_back(CkptFeatureType::HOST_EMB); + loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP); + } else { + // HBM模式加载的类型为最大偏移(真正使用了多少vocab容量),特征到偏移的映射 + loadFeatures.push_back(CkptFeatureType::MAX_OFFSET); + loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP); + } + + // 添加特征准入淘汰相关的数据类型的加载 + if (featAdmitNEvict.GetFunctionSwitch()) { + loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); + } + + if (isSSDEnabled) { + loadFeatures.push_back(CkptFeatureType::DDR_KEY_FREQ_MAP); + } +} + /// 获取key对应的offset,python侧调用 /// \param tableName 表名 /// \return @@ -977,3 +1003,13 @@ void HybridMgmt::EvictSSDKeys(const string& embName, const vector& ke } cacheManager->EvictSSDEmbedding(embName, ssdKeys); } + +int HybridMgmt::GetStepFromPath(const string& loadPath) +{ + regex pattern("sparse-model-\\d+-(\\d+)"); + smatch match; + if (regex_search(loadPath, match, pattern)) { + return stoi(match[1]); + } + return 0; +} \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index c131e19a..5889c564 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -58,6 +58,9 @@ namespace MxRec { bool Load(const string& loadPath); + void SetFeatureTypeForLoad(vector& loadFeatures, + const FeatureAdmitAndEvict& featAdmitNEvict); + key_offset_map_t SendHostMap(const string tableName); void ReceiveHostMap(all_key_offset_map_t keyOffsetMap); @@ -121,6 +124,7 @@ namespace MxRec { void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, const vector& keys, int channelId); + int GetStepFromPath(const string& loadPath); void AddCacheManagerTraceLog(absl::flat_hash_map, EmbHashMapInfo>& embHashMaps); private: diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index b21d6fb7..451d384a 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -211,9 +211,9 @@ namespace MxRec { } } // validate file size - if (datasetSize <= FILE_MIN_SIZE || datasetSize > FILE_MAX_SIZE) { + if (datasetSize > FILE_MAX_SIZE) { LOG(ERROR) << StringFormat("the reading file size is invalid, " - "not in range (%d,%d]", FILE_MIN_SIZE, FILE_MAX_SIZE); + "not in range [%d,%lld]", FILE_MIN_SIZE, FILE_MAX_SIZE); throw invalid_argument(StringFormat("file size invalid")); } } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index c03ad828..6f0b6e42 100644 --- a/src/core/utils/common.h +++ 
b/src/core/utils/common.h @@ -78,7 +78,7 @@ namespace MxRec { // validate file constexpr long long FILE_MAX_SIZE = 1LL << 40; - constexpr int FILE_MIN_SIZE = 1; + constexpr int FILE_MIN_SIZE = 0; struct PerfConfig { static int keyProcessThreadNum; @@ -102,6 +102,7 @@ namespace MxRec { const string COMBINE_HISTORY_NAME = "combine_table_history"; using emb_key_t = int64_t; + using freq_num_t = int64_t; using emb_name_t = std::string; using keys_t = std::vector; using lookup_key_t = std::tuple; // batch_id quarry_lable keys_vector @@ -540,13 +541,16 @@ namespace MxRec { using trans_serialize_t = uint8_t; using key_offset_map_t = std::map; using all_key_offset_map_t = std::map>; + using key_freq_mem_t = unordered_map>; enum class CkptFeatureType { HOST_EMB = 0, EMB_HASHMAP = 1, MAX_OFFSET = 2, KEY_OFFSET_MAP = 3, - FEAT_ADMIT_N_EVICT = 4 + FEAT_ADMIT_N_EVICT = 4, + DDR_KEY_FREQ_MAP = 5, + EXCLUDE_DDR_KEY_FREQ_MAP = 6 }; struct CkptData { @@ -556,6 +560,8 @@ namespace MxRec { key_offset_mem_t keyOffsetMap; table_2_thresh_mem_t table2Thresh; AdmitAndEvictData histRec; + key_freq_mem_t ddrKeyFreqMaps; + key_freq_mem_t excludeDDRKeyFreqMaps; }; struct CkptTransData { @@ -578,7 +584,10 @@ namespace MxRec { NDDR_FEATMAP = 6, TABLE_2_THRESH = 7, HIST_REC = 8, - ATTRIBUTE = 9 + ATTRIBUTE = 9, + DDR_FREQ_MAP = 10, + EXCLUDE_FREQ_MAP = 11, + EVICT_POS = 12 }; } // end namespace MxRec diff --git a/src/test_ut.sh b/src/test_ut.sh index aa8edf3c..98cdcf39 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -81,7 +81,7 @@ cd "$(dirname "${PWD}")" COVERAGE_FILE=coverage.info REPORT_FOLDER=coverage_report lcov --rc lcov_branch_coverage=1 -c -d build -o "${COVERAGE_FILE}"_tmp -lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '/usr1/mxRec/src/core/emb_table*' '7/ext*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" +lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '/usr1/mxRec/src/core/emb_table*' '7/ext*' '*7/bits*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" genhtml --rc genhtml_branch_coverage=1 "${COVERAGE_FILE}" -o "${REPORT_FOLDER}" [ -d "${COVERAGE_FILE}"_tmp ] && rm -rf "${COVERAGE_FILE}"_tmp [ -d "${COVERAGE_FILE}" ] && rm -rf "${COVERAGE_FILE}" diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 1cf27378..623c001b 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -161,7 +161,6 @@ protected: { for (int64_t i { 0 }; i < hostVocabSize; ++i) { testKeyOffsetMap[featMem] = i; - featMem++; } } @@ -175,6 +174,40 @@ protected: } } + void SetDDRKeyFreqMap(unordered_map& testDDRKeyFreqMap) + { + for (int64_t i { 0 }; i < hostVocabSize; ++i) { + testDDRKeyFreqMap[featMem] = i; + featMem++; + } + } + + void SetExcludeDDRKeyFreqMap(unordered_map& testExcludeDDRKeyFreqMap) + { + for (int64_t i { 0 }; i < hostVocabSize; ++i) { + testExcludeDDRKeyFreqMap[featMem] = i; + featMem++; + } + } + + void SetDDRKeyFreqMaps(key_freq_mem_t& testDDRKeyFreqMaps) + { + unordered_map 
testDDRKeyFreqMap; + for (const auto& testEmbInfo : testEmbInfos) { + SetDDRKeyFreqMap(testDDRKeyFreqMap); + testDDRKeyFreqMaps[testEmbInfo.name] = std::move(testDDRKeyFreqMap); + } + } + + void SetExcludeDDRKeyFreqMaps(key_freq_mem_t& testExcludeDDRKeyFreqMaps) + { + unordered_map testExcludeDDRKeyFreqMap; + for (const auto& testEmbInfo : testEmbInfos) { + SetExcludeDDRKeyFreqMap(testExcludeDDRKeyFreqMap); + testExcludeDDRKeyFreqMaps[testEmbInfo.name] = std::move(testExcludeDDRKeyFreqMap); + } + } + void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold) { for (const auto& testEmbInfo : testEmbInfos) { @@ -489,4 +522,39 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); } } -} \ No newline at end of file +} + + +TEST_F(CheckpointTest, KeyFreqMaps) +{ + key_freq_mem_t testDDRKeyFreqMaps; + key_freq_mem_t validDDRKeyFreqMaps; + key_freq_mem_t testExcludeDDRKeyFreqMaps; + key_freq_mem_t validExcludeDDRKeyFreqMaps; + + SetEmbInfo(); + SetDDRKeyFreqMaps(testDDRKeyFreqMaps); + SetExcludeDDRKeyFreqMaps(testExcludeDDRKeyFreqMaps); + validDDRKeyFreqMaps = testDDRKeyFreqMaps; + validExcludeDDRKeyFreqMaps = testExcludeDDRKeyFreqMaps; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.ddrKeyFreqMaps = std::move(testDDRKeyFreqMaps); + testSaveData.excludeDDRKeyFreqMaps = std::move(testExcludeDDRKeyFreqMaps); + validLoadData.ddrKeyFreqMaps = std::move(validDDRKeyFreqMaps); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::DDR_KEY_FREQ_MAP }); + EXPECT_EQ(validLoadData.ddrKeyFreqMaps.size(), testLoadData.ddrKeyFreqMaps.size()); + + for (const auto& it : validLoadData.ddrKeyFreqMaps) { + EXPECT_EQ(1, testLoadData.ddrKeyFreqMaps.count(it.first)); + const auto& ddrKeyFreqMap = testLoadData.ddrKeyFreqMaps.at(it.first); + EXPECT_EQ(it.second, ddrKeyFreqMap); + } +} + -- Gitee From 7df7f72a2e296c2aed0f398289e0364a7d79bb68 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Sep 2023 14:52:51 +0800 Subject: [PATCH 310/551] Match-id-60c765ad4db72b0009f113f82e84221d0f8b6deb --- mx_rec/saver/saver.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 15f11a7d..9e4bbab4 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -57,7 +57,13 @@ class Saver(object): @performance("Save") def save(self, sess, save_path="model", global_step=None): """ - Save sparse tables + Save sparse tables. For local save, both save_easy mode and normal mode is supported. For HDFS save, + only save_easy mode is supported. 
+ For easy_save mode, checkpoint is saved in under format: + ./rank_id/HashTable/HBM/embed_table_name/key/xxx.data + ./rank_id/HashTable/HBM/embed_table_name/key/xxx.attribute + ./rank_id/HashTable/HBM/embed_table_name/embedding/xxx.data + ./rank_id/HashTable/HBM/embed_table_name/embedding/xxx.attribute :param sess: A Session to use to save the sparse table variables :param save_path: Only absolute path supported :param global_step: If provided the global step number is appended to save_path to create @@ -85,13 +91,13 @@ class Saver(object): if tf.io.gfile.exists(saving_path): tf.io.gfile.rmtree(saving_path) - logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") + logging.info(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") tf.io.gfile.makedirs(saving_path) - logging.debug(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") + logging.info(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") self._save(sess, saving_path) logging.info(f"sparse model was saved in dir '{saving_path}' .") - logging.debug(f"======== Saving finished for rank id {self.rank_id} ========") + logging.info(f"======== Saving finished for rank id {self.rank_id} ========") @performance("Restore") def restore(self, sess, reading_path): @@ -109,6 +115,9 @@ class Saver(object): def _build_save(self): for var in self.var_list: + if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: + continue + table_instance = get_table_instance(var) table_name = table_instance.table_name with tf.compat.v1.variable_scope(table_name): @@ -119,6 +128,8 @@ class Saver(object): def _build_restore(self): for var in self.var_list: + if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: + continue table_instance = get_table_instance(var) sub_placeholder_dict = self.placeholder_dict[table_instance.table_name] with tf.compat.v1.variable_scope(table_instance.table_name): @@ -207,10 +218,10 @@ class Saver(object): if is_asc_manager_initialized() and self.save_easy_mode: send_host_data(key_offset_dict) - logging.debug(f"host data was sent to the host pipeline.") + logging.info("host data was sent to the host pipeline.") if is_asc_manager_initialized() and not self.save_easy_mode: restore_host_data(reading_path) - logging.debug(f"host data was restored.") + logging.info("host data was restored.") sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict) -- Gitee From d2a2d506672e6675ff4dc900fa7e6136f9f6586d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Sep 2023 17:03:18 +0800 Subject: [PATCH 311/551] Match-id-41e2288f683b026f395e1ad04c5c548e983643c9 --- mx_rec/util/communication/hccl_mgmt.py | 23 ++++++++++++----------- src/core/key_process/key_process.cpp | 3 +++ src/pybind/module_main.cpp | 5 ++--- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 866c6682..0a9af6a0 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -45,10 +45,13 @@ def parse_hccl_json(): raise ValueError(f"hccl_json device_id wrong.") import mxrec_pybind - device_id = mxrec_pybind.get_logic_id(int(device.get("device_id"))) - if device_id > MAX_DEVICE_ID: + res = mxrec_pybind.get_logic_id(int(device.get("device_id"))) + if res < 0: + raise RuntimeError( + f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") + if res > MAX_DEVICE_ID: raise 
ValueError(f"get logic id from physic id fail, the device id is invalid.") - rank_to_device_dict[rank_id] = device_id + rank_to_device_dict[rank_id] = res return rank_to_device_dict @@ -84,17 +87,15 @@ def set_hccl_info_without_json(): for device_idx in sorted_device_list: import mxrec_pybind + res = mxrec_pybind.get_logic_id(int(device_idx)) + if res < 0: + raise RuntimeError( + f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") - try: - device_id = mxrec_pybind.get_logic_id(int(device_idx)) - except RuntimeError as exp: - raise RuntimeError(f"get logic id from physic id fail. Possible reasons: 1) running user permission " - f"is not enough to call dsmi api 2) driver has been used by other process") from \ - exp - if device_id > MAX_DEVICE_ID: + if res > MAX_DEVICE_ID: raise ValueError(f"get logic id from physic id fail.") index = sorted_device_list.index(device_idx) - rank_to_device_dict[index + 1] = device_id + rank_to_device_dict[index + 1] = res return rank_to_device_dict diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index e49066f4..a2005bdd 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -468,7 +468,10 @@ void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tenso if (PerfConfig::gradientStrategy && channel == TRAIN_CHANNEL_ID) { keys_t uniqueKeys; vector restoreVecSec; + + TimeCost globalUniqueSyncTC; GlobalUnique(lookupKeys, uniqueKeys, restoreVecSec); + VLOG(GLOG_DEBUG) << StringFormat("globalUniqueSyncTC(ms):%d", globalUniqueSyncTC.ElapsedMS()); tensors->push_back(Vec2TensorI32(restoreVecSec)); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)); } diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index de01f850..70559c83 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -35,11 +35,10 @@ int GetUBHotSize(int devID) return static_cast(static_cast(MxRec::GetUBSize(devID)) / sizeof(float) * HOT_EMB_CACHE_PCT) ; } -uint32_t GetLogicID(uint32_t phyid) +int32_t GetLogicID(uint32_t phyid) { - int32_t ret = 0; uint32_t logicId; - ret = dsmi_get_logicid_from_phyid(phyid, &logicId); + int32_t ret = dsmi_get_logicid_from_phyid(phyid, &logicId); if (ret != 0) { return ret; } -- Gitee From afb30eb244276d8d7b81210534ff928f0ddaba9c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 5 Sep 2023 19:59:24 +0800 Subject: [PATCH 312/551] Match-id-62eb51829bbb374faa01227ac9ad3aaea5dc64b7 --- src/core/emb_table/emb_table.cpp | 38 +-------- src/core/host_emb/host_emb.cpp | 77 ++----------------- .../constant_initializer.cpp | 2 + .../random_normal_initializer.cpp | 3 + .../truncated_normal_initializer.cpp | 3 + src/core/utils/common.cpp | 6 +- src/core/utils/common.h | 4 +- 7 files changed, 23 insertions(+), 110 deletions(-) diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index a13b1eed..fea2f521 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -100,42 +100,10 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali "Device GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size()); vector devEmb(blockSize); for (auto initializeInfo: initializeInfos) { - Initializer* initializer; - switch (initializeInfo.initializerType) { - case InitializerType::CONSTANT: { - LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Constant Initializer by value %f. 
name %s, start %d, len %d.", - initializeInfo.constantInitializerInfo.constantValue, - initializeInfo.name.c_str(), initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.constantInitializer; - break; - } - case InitializerType::TRUNCATED_NORMAL: { - LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. " - "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), - initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.truncatedNormalInitializer; - break; - } - case InitializerType::RANDOM_NORMAL: { - LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. " - "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), - initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.randomNormalInitializer; - break; - } - default: { - LOG(WARNING) << "Device Invalid Initializer Type. Using default Constant Initializer with value 0."; - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); - initializer = &defaultInitializer; - } - } + LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", + initializeInfo.name.c_str()); for (int i = 0; i < BLOCK_EMB_COUNT; i++) { - initializer->GenerateData(&devEmb[i * embSize], embSize); + initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); } } LOG(INFO) << StringFormat("Device GenerateEmbData End, seed:%d", seed); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 649fae66..e5ba9d6a 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -49,45 +49,11 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in embData.resize(vocabSize, vector(embeddingSize)); for (auto initializeInfo: initializeInfos) { - Initializer* initializer; - - switch (initializeInfo.initializerType) { - case InitializerType::CONSTANT: { - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData ing using Constant Initializer by value %f. name %s, " - "start %d, len %d.", initializeInfo.constantInitializerInfo.constantValue, - initializeInfo.name.c_str(), initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.constantInitializer; - break; - } - case InitializerType::TRUNCATED_NORMAL: { - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. " - "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), - initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.truncatedNormalInitializer; - break; - } - case InitializerType::RANDOM_NORMAL: { - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. " - "name %s, start %d, len %d.", initializeInfo.normalInitializerInfo.mean, - initializeInfo.normalInitializerInfo.stddev, initializeInfo.name.c_str(), - initializeInfo.start, initializeInfo.len); - initializer = &initializeInfo.randomNormalInitializer; - break; - } - default: { - LOG(WARNING) << ( - HOSTEMB + "Invalid Initializer Type. 
Using default Constant Initializer with value 0."); - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); - initializer = &defaultInitializer; - } - } - + LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", + initializeInfo.name.c_str()); for (int i = 0; i < vocabSize; i++) { - initializer->GenerateData(embData.at(i).data(), embeddingSize); + initializeInfo.initializer->GenerateData(embData.at(i).data(), + embeddingSize); } } LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData End, seed:%d", seed); @@ -270,38 +236,11 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve const vector& offset) { for (auto initializeInfo: initializeInfos) { - Initializer* initializer; - - switch (initializeInfo.initializerType) { - case InitializerType::CONSTANT: { - LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData ing using Constant Initializer by value %d.", - initializeInfo.constantInitializerInfo.constantValue); - initializer = &initializeInfo.constantInitializer; - break; - } - case InitializerType::TRUNCATED_NORMAL: { - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); - initializer = &initializeInfo.truncatedNormalInitializer; - break; - } - case InitializerType::RANDOM_NORMAL: { - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f.", - initializeInfo.normalInitializerInfo.mean, initializeInfo.normalInitializerInfo.stddev); - initializer = &initializeInfo.randomNormalInitializer; - break; - } - default: { - LOG(ERROR) << (HOSTEMB + "Invalid Initializer Type. Using default Constant Initializer with value 0."); - ConstantInitializer defaultInitializer(initializeInfo.start, initializeInfo.len, 0, 1); - initializer = &defaultInitializer; - } - } - + LOG(INFO) << StringFormat("Device GenerateEmbData ing. 
name %s", + initializeInfo.name.c_str()); for (size_t i = 0; i < offset.size(); i++) { - initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); + initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), + static_cast(embData[0].size())); } } } diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 4327b638..e14d9ccb 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -19,6 +19,8 @@ ConstantInitializer::ConstantInitializer(int start, int len, float value, float void ConstantInitializer::GenerateData(float* const emb, const int embSize) { + LOG(INFO) << StringFormat("Device GenerateEmbData ing using Constant Initializer by value %f., start %d, len %d.", + value, start, len); if (len == 0) { return; } diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 0cff333d..54cd813f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -21,6 +21,9 @@ RandomNormalInitializer::RandomNormalInitializer(int start, int len, std::tuple< void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) { + LOG(INFO) << StringFormat( + "Device GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. " + "start %d, len %d.", mean, stddev, start, len); if (len == 0) { return; } diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 85fb4a45..4223b00c 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -25,6 +25,9 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, std:: void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSize) { + LOG(INFO) << StringFormat( + "Device GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. 
" + "start %d, len %d.", mean, stddev, start, len); if (len == 0) { return; } diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 451d384a..e972891e 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -73,7 +73,7 @@ namespace MxRec { { if (name == "constant_initializer") { initializerType = InitializerType::CONSTANT; - constantInitializer = ConstantInitializer(start, len, constantInitializerInfo.constantValue, + initializer = make_shared(start, len, constantInitializerInfo.constantValue, constantInitializerInfo.initK); } else { throw std::invalid_argument("Invalid Initializer Type."); @@ -88,10 +88,10 @@ namespace MxRec { if (name == "truncated_normal_initializer") { initializerType = InitializerType::TRUNCATED_NORMAL; - truncatedNormalInitializer = TruncatedNormalInitializer(start, len, ret); + initializer = make_shared(start, len, ret); } else if (name == "random_normal_initializer") { initializerType = InitializerType::RANDOM_NORMAL; - randomNormalInitializer = RandomNormalInitializer(start, len, ret); + initializer = make_shared(start, len, ret); } else { throw std::invalid_argument("Invalid Initializer Type."); } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 6f0b6e42..e58b6900 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -418,9 +418,7 @@ namespace MxRec { ConstantInitializerInfo constantInitializerInfo; NormalInitializerInfo normalInitializerInfo; - ConstantInitializer constantInitializer; - TruncatedNormalInitializer truncatedNormalInitializer; - RandomNormalInitializer randomNormalInitializer; + shared_ptr initializer; }; template -- Gitee From a5b3249040b337502f1dd196a03ead488a8d53eb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 5 Sep 2023 20:07:55 +0800 Subject: [PATCH 313/551] Match-id-eb70253e79597b83ed1ace35e8ec2733ff8d6e2b --- src/core/emb_table/emb_table.cpp | 3 +-- src/core/host_emb/host_emb.cpp | 12 ++++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index fea2f521..74195ad2 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -100,8 +100,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali "Device GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size()); vector devEmb(blockSize); for (auto initializeInfo: initializeInfos) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", - initializeInfo.name.c_str()); + LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); for (int i = 0; i < BLOCK_EMB_COUNT; i++) { initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); } diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index e5ba9d6a..4ce6d6b4 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -49,11 +49,9 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in embData.resize(vocabSize, vector(embeddingSize)); for (auto initializeInfo: initializeInfos) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", - initializeInfo.name.c_str()); + LOG(INFO) << StringFormat("Device GenerateEmbData ing. 
name %s", initializeInfo.name.c_str()); for (int i = 0; i < vocabSize; i++) { - initializeInfo.initializer->GenerateData(embData.at(i).data(), - embeddingSize); + initializeInfo.initializer->GenerateData(embData.at(i).data(), embeddingSize); } } LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData End, seed:%d", seed); @@ -236,11 +234,9 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve const vector& offset) { for (auto initializeInfo: initializeInfos) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", - initializeInfo.name.c_str()); + LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); for (size_t i = 0; i < offset.size(); i++) { - initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), - static_cast(embData[0].size())); + initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); } } } -- Gitee From 880f41bd547814cb9957e05948d0de8781fc7473 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 5 Sep 2023 20:15:55 +0800 Subject: [PATCH 314/551] Match-id-9b78f5e2b56d5c5535c5f3a5da3d86acba48fc06 --- src/core/host_emb/host_emb.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 4ce6d6b4..113bacdb 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -236,7 +236,8 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve for (auto initializeInfo: initializeInfos) { LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); for (size_t i = 0; i < offset.size(); i++) { - initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); + initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), + static_cast(embData[0].size())); } } } -- Gitee From 4642ae341f5275b8b5a2befc4163805cabd16a12 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 5 Sep 2023 20:25:57 +0800 Subject: [PATCH 315/551] Match-id-70fdfef9262e0e0a59423c60e05531761481c7d3 --- src/core/host_emb/host_emb.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 113bacdb..e008a781 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -237,7 +237,7 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve LOG(INFO) << StringFormat("Device GenerateEmbData ing. 
name %s", initializeInfo.name.c_str()); for (size_t i = 0; i < offset.size(); i++) { initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), - static_cast(embData[0].size())); + static_cast(embData[0].size())); } } } -- Gitee From 8c60c29f859ac894dbc9e416b5369ae22b7c7089 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 6 Sep 2023 10:36:07 +0800 Subject: [PATCH 316/551] Match-id-39909da0bd930d560d714beab3241fab0bd9607a --- mx_rec/saver/saver.py | 4 +- mx_rec/saver/sparse.py | 61 ++++++++----------- mx_rec/util/initialize.py | 22 ++++++- .../constant_initializer.cpp | 2 - .../random_normal_initializer.cpp | 3 - .../truncated_normal_initializer.cpp | 3 - 6 files changed, 48 insertions(+), 47 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 9e4bbab4..b6b053b2 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -14,7 +14,7 @@ from tensorflow.python.util import compat from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ - send_host_data, get_ascend_global_hashtable_collection + send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir from mx_rec.util.perf import performance from mx_rec.validator.validator import DirectoryValidator, FileValidator @@ -82,6 +82,8 @@ class Saver(object): ckpt_name = f"sparse-{base_name}" saving_path = os.path.join(directory, ckpt_name) + set_sparse_dir(saving_path) + try: if save_path.find("://") == -1: DirectoryValidator(saving_path).with_blacklist(exact_compare=False).check() diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 88f5316c..2a1f9307 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -8,14 +8,14 @@ import json import numpy as np -from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set +from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set, get_sparse_dir from mx_rec.validator.validator import FileValidator class SparseProcessor: single_instance = None - def __init__(self, model_dir, **kwargs): + def __init__(self, **kwargs): self.sep = "/" self.export_name = "key-emb" self.device_dir_list = ["HashTable", "HBM"] @@ -28,9 +28,6 @@ class SparseProcessor: self.attrib_suffix = ".attribute" self.json_attrib_dtype = "data_type" self.json_attrib_shape = "shape" - self.model_dir = model_dir - if not os.path.exists(model_dir): - raise FileExistsError(f"the model_dir supported {model_dir} does not exist.") self.table_list = kwargs.get("table_list") self.default_table_list = list(export_table_name_set()) @@ -41,8 +38,8 @@ class SparseProcessor: self.table_list = check_table_param(self.table_list, self.default_table_list) @staticmethod - def set_instance(model_dir, **kwargs): - SparseProcessor.single_instance = SparseProcessor(model_dir, **kwargs) + def set_instance(**kwargs): + SparseProcessor.single_instance = SparseProcessor(**kwargs) @staticmethod def _get_data(data_dir, dtype, data_shape): @@ -87,26 +84,24 @@ class SparseProcessor: def export_sparse_data(self): logging.info("table list to be exported is %s", self.table_list) - sparse_dirs = self._get_sparse_dirs() - for sparse_dir in sparse_dirs: - ddr = False - sparse_dir = os.path.join(self.model_dir, sparse_dir) - dev_dir = set_upper_dir(sparse_dir, self.device_dir_list) - host_dir = 
set_upper_dir(sparse_dir, self.host_dir_list) - for table in self.table_list: - table_instance = get_table_instance_by_name(table) - device_table_dir = os.path.join(dev_dir, table) - host_table_dir = os.path.join(host_dir, table) - if table_instance.host_vocabulary_size != 0: - ddr = True - out_dir = host_table_dir - else: - out_dir = device_table_dir - key, offset = self._get_hashmap(out_dir, ddr) - emb_data = self.get_embedding(device_table_dir, host_table_dir, ddr) - emb_data = emb_data[offset] - transformed_data = dict(zip(key[:], emb_data[:])) - np.save(out_dir + self.sep + self.export_name + ".npy", transformed_data) + sparse_dir = get_sparse_dir() + ddr = False + dev_dir = set_upper_dir(sparse_dir, self.device_dir_list) + host_dir = set_upper_dir(sparse_dir, self.host_dir_list) + for table in self.table_list: + table_instance = get_table_instance_by_name(table) + device_table_dir = os.path.join(dev_dir, table) + host_table_dir = os.path.join(host_dir, table) + if table_instance.host_vocabulary_size != 0: + ddr = True + out_dir = host_table_dir + else: + out_dir = device_table_dir + key, offset = self._get_hashmap(out_dir, ddr) + emb_data = self.get_embedding(device_table_dir, host_table_dir, ddr) + emb_data = emb_data[offset] + transformed_data = dict(zip(key[:], emb_data[:])) + np.save(out_dir + self.sep + self.export_name + ".npy", transformed_data) def get_embedding(self, device_table_dir, host_table_dir, ddr): emb_dir = os.path.join(device_table_dir, self.device_emb_dir) @@ -151,14 +146,6 @@ class SparseProcessor: key = raw_hashmap[:, 0] return key, offset - def _get_sparse_dirs(self): - sub_dirs = [] - for _, sub_dir, _ in os.walk(self.model_dir): - sub_dirs.append(sub_dir) - if not sub_dirs: - raise FileExistsError("There is no sparse folder in the model ") - return sub_dirs[0] - def _get_file_names(self, directory): files = [] data_file = None @@ -177,9 +164,9 @@ class SparseProcessor: return data_file, attribute_file -def export(model_dir, **kwargs): +def export(**kwargs): empty_value = 0 - SparseProcessor.set_instance(model_dir, **kwargs) + SparseProcessor.set_instance(**kwargs) if SparseProcessor.single_instance.table_list: return SparseProcessor.single_instance.export_sparse_data() else: diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 4d4612f2..d7e70fa5 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -54,6 +54,7 @@ class ConfigInitializer: self._merged_multi_lookup = dict() self._target_batch = dict() self._iterator_type = "" + self._sparse_dir = "" if self._use_mpi: logging.debug(f"Using mpi to launch task.") @@ -112,6 +113,10 @@ class ConfigInitializer: def modify_graph(self): return self._modify_graph + @property + def sparse_dir(self): + return self._sparse_dir + @property def feature_spec_dict(self): return self._feature_spec_dict @@ -362,6 +367,13 @@ class ConfigInitializer: self._modify_graph = is_modify_graph + @sparse_dir.setter + def sparse_dir(self, sparse_dir): + if not isinstance(sparse_dir, str): + raise TypeError(f"sparse_dir should be str.") + + self._sparse_dir = sparse_dir + @is_last_round.setter def is_last_round(self, last_round): if not isinstance(last_round, bool): @@ -468,6 +480,14 @@ def set_modify_graph(is_modify_graph): ConfigInitializer.get_instance().modify_graph = is_modify_graph +def set_sparse_dir(sparse_dir): + ConfigInitializer.get_instance().sparse_dir = sparse_dir + + +def get_sparse_dir(): + return ConfigInitializer.get_instance().sparse_dir + + def is_mpi_in_use(): return 
ConfigInitializer.get_instance().use_mpi @@ -894,4 +914,4 @@ def bind_cpu(rank_id: int, rank_size: int = None): process.cpu_affinity(cpu_list) except IndexError: logging.error(f"failed to bind cpu for rank {rank_id}: {cpu_list}") - logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") \ No newline at end of file + logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index e14d9ccb..4327b638 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -19,8 +19,6 @@ ConstantInitializer::ConstantInitializer(int start, int len, float value, float void ConstantInitializer::GenerateData(float* const emb, const int embSize) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing using Constant Initializer by value %f., start %d, len %d.", - value, start, len); if (len == 0) { return; } diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 54cd813f..0cff333d 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -21,9 +21,6 @@ RandomNormalInitializer::RandomNormalInitializer(int start, int len, std::tuple< void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) { - LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Random Normal Initializer by mean: %f stddev: %f. " - "start %d, len %d.", mean, stddev, start, len); if (len == 0) { return; } diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 4223b00c..85fb4a45 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -25,9 +25,6 @@ TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, std:: void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSize) { - LOG(INFO) << StringFormat( - "Device GenerateEmbData ing using Truncated Normal Initializer by mean: %f stddev: %f. 
" - "start %d, len %d.", mean, stddev, start, len); if (len == 0) { return; } -- Gitee From 671006df64225c854ad632747c1ffb24ababda7a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 6 Sep 2023 15:52:22 +0800 Subject: [PATCH 317/551] Match-id-69256a8a1006ca264cb00f050dcfe9b0017f0844 --- mx_rec/graph/modifier.py | 92 ++++++++++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 17 deletions(-) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 4d30a7e5..90fa8b62 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -4,10 +4,12 @@ import logging from collections import defaultdict +from typing import Any import tensorflow as tf from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter from tensorflow.python.framework.ops import Operation +from tensorflow.python.framework.errors_impl import InvalidArgumentError from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.feature_spec import FeatureSpec @@ -38,19 +40,57 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso "'pipeline_input_indexes' was given.") def map_func(*args): - def print_tensors(batch_id, tracker=None): - if tracker is None: - tracker = [] - if isinstance(batch_id, dict): - for key, item in batch_id.items(): - print_tensors(item, tracker + [key]) - if isinstance(batch_id, tf.Tensor): - logging.debug(f"######## tracker: {tracker}, tensor: {batch_id} ########") - - for batch in args: - print_tensors(batch) + def parse_batch(data_args: Any, data_batch: dict, key: str = None): + """ + 解析原始数据集中的batch,并将非dict格式的batch转为dict格式. + Args: + data_args: 待解析的batch + data_batch: 解析后的batch + key: batch中的key + + Returns: None + + """ + + def parse_tensor(data_tensor: tf.Tensor, data_batch: dict, key: str = None): + """ + 将待解析batch中的tensor写入解析后的batch中,如果key存在则使用原key,不存在则生成batch中字典序最小的key. 
+ Args: + data_tensor: 待解析batch中的tensor + data_batch: 解析后的batch + key: batch中的key + + Returns: None + + """ + + if key is not None: + data_batch[key] = data_tensor + return + + last_key = f"{sorted(data_batch)[-1]}_last_key" + data_batch[last_key] = data_tensor + + # 开始解析old batch + if isinstance(data_args, dict): + for key, data_tensor in data_args.items(): + parse_batch(data_tensor, data_batch, key) + return + elif isinstance(data_args, (list, tuple)): + for data_arg in data_args: + parse_batch(data_arg, data_batch, key) + return + elif isinstance(data_args, tf.Tensor): + # 将old batch中的tensor加入到dict中 + parse_tensor(data_args, data_batch, key) + return + else: + raise ValueError("Encounter a invalid batch.") - batch = args[0] + logging.debug("In get_preprocessing_map_func, the old batch is: %s.", args) + batch = dict() + parse_batch(args, batch, key=None) + logging.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch) input_tensors = [] if batch_tensor_names is not None: @@ -67,11 +107,12 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso tensor = graph.get_tensor_by_name("args_%d:0" % index) input_tensors.append(tensor) + # 以tf.import_graph_def()作为read emb key的输入,保证数据读取到传入lookup的ids过程中的特征处理关系能够保留在子图中。 output_list = tf.import_graph_def(graph_def, input_map=dict(zip(input_names, input_tensors)), return_elements=output_names) - output_batch = list(args) - output_batch.append(tuple(output_list)) + output_batch = [batch, tuple(output_list)] + logging.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch) return tuple(output_batch) return map_func @@ -233,7 +274,19 @@ def get_sub_graph(input_tensors, output_tensors): return sub_graph_def, input_name_list, output_name_list -def update_input_tensor_with_new_batch(replacement_specs, new_get_next_op_name): +def update_input_tensor_with_new_batch(replacement_specs: dict, new_get_next_op_name: str, new_batch: dict): + """ + 用新batch中的IteratorGetNext替换计算图中老batch的IteratorGetNext. + + Args: + replacement_specs: 记录待替换算子的dict,key为老batch的IteratorGetNext,value为以老batch作为输入的算子 + new_get_next_op_name: 新数据集的get_next算子名称 + new_batch: 新数据集的batch + + Returns: None + + """ + graph = tf.compat.v1.get_default_graph() for old_tensor, item in replacement_specs.items(): for idx, operator in item: @@ -241,7 +294,12 @@ def update_input_tensor_with_new_batch(replacement_specs, new_get_next_op_name): output_index = old_tensor_name.split(":")[-1] new_tensor_name = f"{new_get_next_op_name}:{output_index}" new_tensor = graph.get_tensor_by_name(new_tensor_name) - operator._update_input(idx, new_tensor) + try: + operator._update_input(idx, new_tensor) + except InvalidArgumentError as err: + logging.info("The replacement specs keys (old batch) is: %s. 
\n\t\t" + "The new batch is: %s.", replacement_specs.keys(), new_batch) + raise RuntimeError(f"Cannot update edge, old tensor: {old_tensor}, new tensor: {new_tensor}.") from err def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int: @@ -414,7 +472,7 @@ def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapte except IndexError as err: raise IndexError("Cannot find a tensor from given batch.") from err new_get_next_op_name = find_target_dataset_op(new_batch_tensor.op, "IteratorGetNext").name - update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name) + update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name, new_batch) @performance("graph_modifier") -- Gitee From d905c376067eab96882c45a934f3acab5853cd86 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 6 Sep 2023 20:49:02 +0800 Subject: [PATCH 318/551] Match-id-a8a9b840b7072146066f6c7a271da0dbd4d1c1c9 --- mx_rec/util/ops.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index 90bea48e..920c8e70 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -11,11 +11,7 @@ from mx_rec.constants.constants import HOST_PIPELINE_OPS_LIB_PATH def import_host_pipeline_ops(): - host_pipeline_ops_lib_path = os.getenv(HOST_PIPELINE_OPS_LIB_PATH) - if host_pipeline_ops_lib_path and os.path.exists(host_pipeline_ops_lib_path): - logging.debug(f"Using the HOST_PIPELINE_OPS_LIB_PATH '{host_pipeline_ops_lib_path}' to get ops lib.") - return tf.load_op_library(host_pipeline_ops_lib_path) - elif os.path.exists( + if os.path.exists( os.path.join(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), 'mx_rec/libasc/libasc_ops.so')): default_so_path = os.path.join( @@ -24,5 +20,4 @@ def import_host_pipeline_ops(): logging.debug(f"Using the DEFAULT PATH '{default_so_path}' to get ops lib.") return tf.load_op_library(default_so_path) else: - raise ValueError("Invalid host pipeline ops lib path. 
Please check if libasc_ops.so exists or corrected " - "configured") + raise ValueError("Please check if libasc_ops.so exists (mxRec correctly installed)") -- Gitee From 653c2514be1a796756a4da2fea0e92c8225e91e5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 6 Sep 2023 21:13:07 +0800 Subject: [PATCH 319/551] Match-id-d83e0399b8bce947d65e9c4076d61ab9a8c7fa19 --- mx_rec/core/asc/build_graph.py | 30 +- mx_rec/core/embedding.py | 25 +- mx_rec/util/initialize.py | 2 + src/core/emb_hashmap/emb_hashmap.cpp | 63 ++-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 124 +++----- src/core/hybrid_mgmt/hybrid_mgmt.h | 10 +- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 268 ++++++++++++++++++ src/core/hybrid_mgmt/hybrid_mgmt_block.h | 81 ++++++ src/core/key_process/key_process.cpp | 16 ++ src/core/key_process/key_process.h | 2 +- src/ops_tf/hybrid_dataset_ops.cpp | 21 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 137 +++++++++ 12 files changed, 647 insertions(+), 132 deletions(-) create mode 100644 src/core/hybrid_mgmt/hybrid_mgmt_block.cpp create mode 100644 src/core/hybrid_mgmt/hybrid_mgmt_block.h create mode 100644 src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index d2540082..916e4eb2 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -34,21 +34,21 @@ def get_restore_vector(config): restore_size = config.get("batch_size") * config.get("feat_cnt") else: restore_size = None - - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if use_hot and emb_size: - device_id = int(config.get("device_id")) - hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) - restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32, tf.int32], - output_shapes=[restore_size, [hot_size]], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}' - ) - else: - restore_vector = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[restore_size], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0] + with tf.control_dependencies([config.get("notify_hybridmgmt_op")]): + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + if use_hot and emb_size: + device_id = int(config.get("device_id")) + hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) + restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[restore_size, [hot_size]], + channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}' + ) + else: + restore_vector = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[restore_size], + channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0] return restore_vector, hot_pos diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index d16ccae0..08db37ab 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -26,6 +26,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set from mx_rec.validator.validator import ClassValidator, StringValidator +from mx_rec.util.tf_version_adapter import npu_ops def check_ssd_relate_param(host_vocabulary_size, 
ssd_vocabulary_size, ssd_data_path):
@@ -156,6 +157,7 @@ def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs):
     kwargs["is_train"] = is_train
     check_lookup_kwargs()
     scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
+
     with tf.compat.v1.variable_scope(scope_name):
         if hashtable.mode != MxRecMode.ASC:
             raise EnvironmentError("Invalid MxRec Mode.")
@@ -641,6 +643,22 @@ class SparseEmbedding:
                 dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0)
                 self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape)

+    def generate_lookup_id_notify_hybrid(self, channel_id: int):
+
+        """
+        Args:
+            channel_id: channel id, 0 for train, 1 for eval
+        Returns: an npu_ops.outfeed_enqueue_op that notifies the preprocess step
+        """
+        sparse_lookup_id = ConfigInitializer.get_instance().notify_hybrid_channel_sparse_id[channel_id]
+        notify_message = tf.constant([sparse_lookup_id], dtype=tf.int32)
+        ConfigInitializer.get_instance().notify_hybrid_channel_sparse_id[channel_id] += 1
+        channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id)
+        logging.debug("%s was built for op outfeed sparse id : %s.", channel_name, sparse_lookup_id)
+        notify_hybridmgmt_op = npu_ops.outfeed_enqueue_op(
+            channel_name=channel_name, inputs=[notify_message])
+        return notify_hybridmgmt_op
+
     def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs):
         """
         Args:
@@ -668,12 +686,15 @@ class SparseEmbedding:
         channel_id = get_training_mode_channel_id(is_training=is_training)
         logging.debug(f"get preprocessed tensor for asc for table {self.table_name} with skip emb transfer "
                       f"{self.skip_emb_transfer} is_training: {is_training}, channel_id: {channel_id} .")
-
+        # Notify the C++ side that sparse lookup starts here; note this is done once per table name.
+        notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id)
+        # Put notify_hybridmgmt_op into the config; the restore get_next op takes a control dependency on it.
         config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count,
                       rank_size=rank_size, channel_id=channel_id, table_name=self.table_name,
                       skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size,
                       emb_size=self.emb_size, use_hot=use_hot, device_id=device_id,
-                      use_dynamic_expansion=use_dynamic_expansion, gradients_strategy=self.apply_gradients_strategy)
+                      use_dynamic_expansion=use_dynamic_expansion, gradients_strategy=self.apply_gradients_strategy,
+                      notify_hybridmgmt_op=notify_hybridmgmt_op)

         if self.skip_emb_transfer:
             result = get_preprocessed_tensor_for_asc(self.variable, config)
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 4d4612f2..53b8cb40 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -76,6 +76,8 @@ class ConfigInitializer:
         if kwargs.get("bind_cpu", True):
             bind_cpu(self._rank_id, self._rank_size)
         self.enable_table_merge = True if os.getenv("TF_DEVICE") == "NPU" else False
+        # sparse lookup ids of the two channels, used as identifiers for the notify messages
+        self.notify_hybrid_channel_sparse_id = [0, 0]

     def __del__(self):
         self.terminate()
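The notify op built above is what lets the C++ HybridMgmtBlock count Python-side steps: each sparse
lookup enqueues its sparse id on a device-to-host channel, and the restore get_next in build_graph.py
takes a control dependency on that enqueue. Here is a rough sketch of the ordering contract using stock
TensorFlow only: tf.print stands in for npu_ops.outfeed_enqueue_op (which needs an NPU), and the channel
id, shapes, and names are illustrative.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

def build_notify_op(channel_id, sparse_lookup_id):
    # On the NPU this would be npu_ops.outfeed_enqueue_op on channel f"d2h_notify_hybridmgmt_{channel_id}".
    message = tf.constant([sparse_lookup_id], dtype=tf.int32)
    return tf.print("notify channel", channel_id, "sparse id", message)

notify_op = build_notify_op(channel_id=0, sparse_lookup_id=0)
with tf.control_dependencies([notify_op]):
    # The restore vector is fetched only after the notify message is sent,
    # mirroring get_restore_vector() / get_all2all_args() in build_graph.py.
    restore_vector = tf.identity(tf.zeros([8], dtype=tf.int32))

with tf.compat.v1.Session() as sess:
    sess.run(restore_vector)

diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index 9e0e3d52..1b68536d 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -9,8 +9,10 @@
 #include 
 #include 
 #include 
-#include "hd_transfer/hd_transfer.h"
+
+#include "checkpoint/checkpoint.h"
+#include "hd_transfer/hd_transfer.h"
+#include "hybrid_mgmt/hybrid_mgmt_block.h"
 #include "utils/common.h"

 using namespace MxRec;

@@ -225,30 +227,45 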
@@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { + VLOG(GLOG_DEBUG) << (HYBRID_BLOCKING + " start GetHashMaps"); + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); auto embHashMapsOld = embHashMaps; - for (auto& temp: embHashMapsOld) { - auto& embTableName = temp.first; - auto& embHashMap = temp.second; - vector hbm2DdrKeys; - vector ddr2HbmKeys; - for (auto& swapKeys: embHashMap.oldSwap) { - emb_key_t oldKey = swapKeys.first; - emb_key_t key = swapKeys.second; - int tempOffset = static_cast(embHashMap.hostHashMap[key]); - embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; - embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); - hbm2DdrKeys.emplace_back(key); - ddr2HbmKeys.emplace_back(oldKey); - } - embHashMap.maxOffset = embHashMap.maxOffsetOld; - for (auto& Offset2Key: embHashMap.devOffset2KeyOld) { - embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; - } - if (isSSDEnabled) { - // 恢复CacheManager中频次数据 - cacheManager->RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR); - cacheManager->RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); + int checkResult = hybridMgmtBlock->CheckSaveEmbdMapValid(); + if (checkResult == 0) { + // 检查是否需要回退 + return embHashMapsOld; + } + if (checkResult == 1) { + // 回退一步 + for (auto& temp: embHashMapsOld) { + auto &embTableName = temp.first; + auto &embHashMap = temp.second; + vector hbm2DdrKeys; + vector ddr2HbmKeys; + for (auto &swapKeys: embHashMap.oldSwap) { + emb_key_t oldKey = swapKeys.first; + emb_key_t key = swapKeys.second; + int tempOffset = static_cast(embHashMap.hostHashMap[key]); + embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; + embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); + hbm2DdrKeys.emplace_back(key); + ddr2HbmKeys.emplace_back(oldKey); + } + embHashMap.maxOffset = embHashMap.maxOffsetOld; + for (auto &Offset2Key: embHashMap.devOffset2KeyOld) { + embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; + } + if (isSSDEnabled) { + // 恢复CacheManager中频次数据 + cacheManager->RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR); + cacheManager->RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); + } } + return embHashMapsOld; + } + // 此时需要回退2步,无法满足此条件,保存的东西错误,直接回退 + if (not rankInfo.noDDR) { + throw HybridMgmtBlockingException("EmbHashMap::GetHashMaps() "); } return embHashMapsOld; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index f47569eb..f8ed8d58 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -124,6 +124,9 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hdTransfer = Singleton::GetInstance(); hdTransfer->Init(embInfos, rankInfo.deviceId); + hybridMgmtBlock = Singleton::GetInstance(); + hybridMgmtBlock->SetRankInfo(rankInfo); + hybridMgmtBlock->StartNotifySignalMonitor(); // 启动数据处理线程 bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, seed); if (!rc) { @@ -473,13 +476,13 @@ void HybridMgmt::Start() if (!mgmtRankInfo.noDDR) { auto parseKeysTaskForTrain = [this]() { - TaskForTrain(TaskType::DDR); + TrainTask(TaskType::DDR); LOG(INFO) << StringFormat("parseKeysTaskForTrain done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); auto parseKeysTaskForEval = [this]() { - TaskForEval(TaskType::DDR); + EvalTask(TaskType::DDR); LOG(INFO) << 
StringFormat("parseKeysTaskForEval done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); @@ -492,13 +495,13 @@ void HybridMgmt::InsertThreadForHBM() { #ifndef GTEST auto parseKeysTaskForHBMTrain = [this]() { - TaskForTrain(TaskType::HBM); + TrainTask(TaskType::HBM); LOG(INFO) << "parseKeysTaskForHBMTrain done"; }; procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); auto parseKeysTaskForHBMEval = [this]() { - TaskForEval(TaskType::HBM); + EvalTask(TaskType::HBM); LOG(INFO) << "parseKeysTaskForHBMEval done"; }; procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); @@ -506,110 +509,71 @@ void HybridMgmt::InsertThreadForHBM() } #ifndef GTEST -/// 启动训练数据处理线程 -/// \param type 存储模式 -void HybridMgmt::TaskForTrain(TaskType type) -{ - bool isFirstIn = true; - while (isRunning) { - if (isFirstIn) { - LOG(INFO) << StringFormat(MGMT + "Start Train Task: %d", type); - isFirstIn = false; - } - if (mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[TRAIN_CHANNEL_ID] > 0) { - if (!TrainTask(type)) { - return; - } - } - this_thread::sleep_for(1ms); - } -} - -/// 启动推理数据处理线程 -/// \param type 存储模式 -void HybridMgmt::TaskForEval(TaskType type) +/// 启动hybrid处理任务 +/// \param type +void HybridMgmt::TrainTask(TaskType type) { - bool isFirstIn = true; - while (isRunning) { - if (isFirstIn) { - LOG(INFO) << StringFormat(MGMT + "Start Eval Task: %d", type); - isFirstIn = false; - } - if (mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] == -1 || mgmtRankInfo.maxStep[EVAL_CHANNEL_ID] > 0) { - if (!EvalTask(type)) { - return; - } - } - this_thread::sleep_for(1ms); - } -} - -/// 训练数据处理:数据处理状态正常,处理的batch数小于用户预设值或者设为-1时,会循环处理; -/// \param type 存储模式 -/// \return -bool HybridMgmt::TrainTask(TaskType type) -{ - bool isContinue; - bool status; + int channelId = TRAIN_CHANNEL_ID; + int& theTrainBatchId = hybridMgmtBlock->hybridBatchId[channelId]; do { + hybridMgmtBlock->CheckAndSetBlock(channelId); + if (hybridMgmtBlock->GetBlockStatus(channelId)) { + hybridMgmtBlock->DoBlock(channelId); + } if (!isRunning) { - return false; + return; } + LOG(INFO) << StringFormat(HYBRID_BLOCKING + + "hybrid start task channel %d batch %d", channelId, theTrainBatchId); switch (type) { case TaskType::HBM: - status = ParseKeysHBM(TRAIN_CHANNEL_ID, trainBatchId); - isContinue = !EndBatch(trainBatchId, TRAIN_CHANNEL_ID); - LOG(INFO) << StringFormat(MGMT + "ParseKeysHBMBatchId = %d", trainBatchId); + ParseKeysHBM(TRAIN_CHANNEL_ID, theTrainBatchId); + + LOG(INFO) << StringFormat(MGMT + "ParseKeysHBMBatchId = %d", theTrainBatchId); break; case TaskType::DDR: - status = ParseKeys(TRAIN_CHANNEL_ID, trainBatchId); - isContinue = !EndBatch(trainBatchId, TRAIN_CHANNEL_ID); - LOG(INFO) << StringFormat(MGMT + "parseKeysBatchId = %d", trainBatchId); + ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId); + + LOG(INFO) << StringFormat(MGMT + "parseKeysBatchId = %d", theTrainBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); } - - if (!status) { - return false; - } - } while (isContinue); - - return true; + } while (true); } /// 推理数据处理:数据处理状态正常,处理的batch数小于用户预设值或者设为-1时,会循环处理; /// \param type 存储模式 /// \return -bool HybridMgmt::EvalTask(TaskType type) +void HybridMgmt::EvalTask(TaskType type) { - int evalBatchId = 0; // 0-99, 0-99 + int channelId = EVAL_CHANNEL_ID; + int& evalBatchId = hybridMgmtBlock->hybridBatchId[channelId]; do { + hybridMgmtBlock->CheckAndSetBlock(channelId); + if (hybridMgmtBlock->GetBlockStatus(channelId)) { + hybridMgmtBlock->DoBlock(channelId); + } if 
(!isRunning) { - return false; + return; } - bool status = false; + LOG(INFO) << StringFormat(HYBRID_BLOCKING + + "hybrid start task channel %d batch %d", channelId, evalBatchId); switch (type) { case TaskType::HBM: - status = ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); + ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); LOG(INFO) << StringFormat(MGMT + "HBM evalBatchId = %d", evalBatchId); break; case TaskType::DDR: - status = ParseKeys(EVAL_CHANNEL_ID, evalBatchId); + ParseKeys(EVAL_CHANNEL_ID, evalBatchId); LOG(INFO) << StringFormat(MGMT + "DDR evalBatchId = %d", evalBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); } - - if (!status) { - return false; - } - } while (!EndBatch(evalBatchId, EVAL_CHANNEL_ID)); - - return true; + } while (true); } /// HBM模式下,发送key process线程已处理好的各类型向量到指定通道中 @@ -714,9 +678,7 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) // 通道数据已空 if (!remainBatch) { - TimeCost embHdTrans1; - EmbHDTransWrap(channelId, batchId, start, iBatch); - VLOG(GLOG_DEBUG) << StringFormat("embHdTrans1TC TimeCost(ms):%d", embHdTrans1.ElapsedMS()); + VLOG(GLOG_DEBUG) << StringFormat("last batch ending"); return false; } } @@ -752,6 +714,7 @@ inline void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) } #ifndef GTEST + /// 构造训练所需的各种向量数据 /// \param embName 表名 /// \param batchId 已处理的batch数 @@ -771,7 +734,10 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, // 获取查询向量 auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); - if (lookupKeys.empty()) { remainBatchOut = false; } + if (lookupKeys.empty()) { + remainBatchOut = false; + return false; + } // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 5889c564..df5d60de 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -75,12 +75,13 @@ namespace MxRec { return; } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 - isRunning = false; + isRunning = false; // 先发送停止信号给preprocess,用于停止查询中lookup卡住状态 preprocess->isRunning = false; // 停止hdTransfer,用于停止mgmt的recv中卡住状态 hdTransfer->Destroy(); + hybridMgmtBlock->Destroy(); for (auto& t : procThreads) { t->join(); } @@ -132,6 +133,7 @@ namespace MxRec { int trainBatchId = 0; // 0-199, 200- int getInfoBatchId; // 0-199, 200- int sendBatchId; + HybridMgmtBlock* hybridMgmtBlock; vector mgmtEmbInfo; RankInfo mgmtRankInfo; CacheManager* cacheManager; @@ -145,10 +147,8 @@ namespace MxRec { bool isRunning; bool isLoad { false }; - void TaskForTrain(TaskType type); - void TaskForEval(TaskType type); - bool TrainTask(TaskType type); - bool EvalTask(TaskType type); + void TrainTask(TaskType type); + void EvalTask(TaskType type); bool EndBatch(int batchId, int channelId) const; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp new file mode 100644 index 00000000..910f8176 --- /dev/null +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: hybrid mgmt module; records the number of program running steps and
+ *              manages blocking and wakeup
+ * Author: MindX SDK
+ * Date: 2023/08/15
+ */
+#include <thread>
+
+#include "utils/common.h"
+#include "hybrid_mgmt_block.h"
+
+/// Check whether hybrid has reached the position where it should block
+/// \param channelId train 0, eval 1
+void HybridMgmtBlock::CheckAndSetBlock(int channelId)
+{
+    if (stepsInterval[channelId] == -1) {
+        return;
+    }
+    if (stepsInterval[channelId] == 0) {
+        // 0 means block right away, and also avoids the division by zero below
+        isBlock[channelId] = true;
+        return;
+    }
+    if (hybridBatchId[channelId] % stepsInterval[channelId] == 0) {
+        isBlock[channelId] = true;
+    }
+}
+
+/// Check whether a data-channel switch happened; if so, validate the parameters.
+/// Uses the Python-side batchId and the hybrid batchId to decide whether the current
+/// step count has reached the point where the blocked thread should be woken up.
+/// \param channelId train 0, eval 1
+void HybridMgmtBlock::CheckAndNotifyWake(int channelId)
+{
+    CheckValid(channelId);
+    if (pythonBatchId[channelId] >= hybridBatchId[channelId]) {
+        isBlock[channelId] = false;
+    }
+}
+
+/// If the parameter check fails and would lead to an exception, wait first:
+/// the data transfer may simply not have finished yet.
+/// \param channelId train 0, eval 1
+bool HybridMgmtBlock::WaitValid(int channelId)
+{
+    // wait for hybrid processing to finish
+    int reTryNumber = 100;
+    VLOG(INFO) << StringFormat(HYBRID_BLOCKING +
+        "check step invalid, wait. channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]);
+    // wake up again once hybrid processing has finished
+    while (pythonBatchId[lastRunChannelId] != hybridBatchId[lastRunChannelId] and isRunning) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(10));
+        reTryNumber--;
+        if (reTryNumber <= 0) {
+            break;
+        }
+    }
+
+    if (pythonBatchId[channelId] == hybridBatchId[channelId]) {
+        return true;
+    } else {
+        // if hybrid still cannot catch up with the Python side after a long wait, it is an error
+        return false;
+    }
+}
+
+void HybridMgmtBlock::CountPythonStep(int channelId)
+{
+    // count the corresponding notification
+    pythonBatchId[channelId]++;
+}
+
+/// Check whether the channel was switched and whether the current step is valid
+/// \param channelId
+void HybridMgmtBlock::CheckValid(int channelId)
+{
+    // the channel was not switched, nothing to do
+    if (lastRunChannelId == channelId) {
+        return;
+    }
+    // skip the parameter check on the very first call from the Python side
+    if (lastRunChannelId == -1) {
+        VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+            "The data channel was called for the first time, and the parameters were "
+            "checked to be normal channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]);
+
+        lastRunChannelId = channelId;
+        return;
+    }
+    // on a channel switch, the batches preprocessed by hybrid match the Python side
+    if (pythonBatchId[lastRunChannelId] == hybridBatchId[lastRunChannelId]) {
+        VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+            "HybridMgmt is switching data channels and checking for normal parameters. The number of steps "
+            "in the previous round is lastRunChannelId %d pythonBatchId %d hybridBatchId %d",
+            lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]);
+    } else if (pythonBatchId[lastRunChannelId] < hybridBatchId[lastRunChannelId]) {
+        // on a channel switch, the previous channel processed more data than the Python side consumed
+        if (!WaitValid(lastRunChannelId)) {
+            throw HybridMgmtBlockingException("when channel switch");
+        }
+    } else {
+        // on a channel switch, hybrid has not yet caught up with the Python side; wait for it to finish
+        VLOG(INFO) << StringFormat(HYBRID_BLOCKING +
+            "When switching data channels, it was found that HybridMgmt processed less data than the "
+            "Python side. In this case, after reading the dataset, the Python side called it again, but it was "
+            "interrupted midway, which did not affect the subsequent calls lastRunChannelId %d hybridBatchId %d",
+            lastRunChannelId, hybridBatchId[lastRunChannelId]);
+    }
+    lastRunChannelId = channelId;
+    return;
+}
+
+/// Perform the blocking operation
+/// \param channelId train 0, eval 1
+void HybridMgmtBlock::DoBlock(int channelId)
+{
+    VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+        "HybridMgmt starts blocking channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]);
+
+    while (isBlock[channelId]) {
+        std::this_thread::sleep_for(SLEEP_MS);
+        if (!isRunning) {
+            return;
+        }
+    }
+    VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+        "HybridMgmt is starting to wake up channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]);
+    return;
+}
+
+/// Reset all step counters; mainly used when the graph is rebuilt and the readembedkey op is recreated
+/// \param channelId train 0, eval 1
+void HybridMgmtBlock::ResetAll(int channelId)
+{
+    VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+        "HybridMgmt is resetting data channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]);
+
+    readEmbedBatchId[channelId] = 0;
+    pythonBatchId[channelId] = 0;
+    hybridBatchId[channelId] = 0;
+    isBlock[channelId] = true;
+    // reset the sparse ids of the eval and train channels together, to avoid stale sparse ids
+    uniqueSparseLookID[EVAL_CHANNEL_ID] = -1;
+    uniqueSparseLookID[TRAIN_CHANNEL_ID] = -1;
+}
+
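The save-time rollback decision that follows reduces to comparing the two per-channel counters. A tiny
model of that decision table, with plain Python ints standing in for the C++ state (illustrative only):

# Toy model of CheckSaveEmbdMapValid: python_batch_id counts notify messages received from
# the device, hybrid_batch_id counts batches already preprocessed by the hybrid thread.
def check_save_valid(python_batch_id, hybrid_batch_id):
    if python_batch_id >= hybrid_batch_id:
        return 0   # consistent: save as-is
    if python_batch_id + 1 == hybrid_batch_id:
        return 1   # hybrid ran one batch ahead: roll the hash map back one step
    return -1      # more than one step ahead: the saved state would be wrong

assert check_save_valid(5, 5) == 0
assert check_save_valid(5, 6) == 1
assert check_save_valid(5, 7) == -1

+/// Check whether the current step allows saving
+/// \return 0 is legal, 1 means roll back one step, -1 means error
+int HybridMgmtBlock::CheckSaveEmbdMapValid()
+{
+    // check whether this data channel's hash map has been processed ahead of time
+    if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) {
+        VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING +
+            "HybridMgmt is checking the step and checking that the parameters are normal. "
+            "The number of steps in the previous round is "
+            "lastRunChannelId %d pythonBatchId %d hybridBatchId %d",
+            lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]);
+        return 0;
+    } else if (pythonBatchId[lastRunChannelId] + 1 == hybridBatchId[lastRunChannelId]) {
+        // the previous channel processed one more batch than the Python side consumed
+        VLOG(INFO) << StringFormat(HYBRID_BLOCKING +
+            "HybridMgmt is checking the step, and the parameters have been processed one step "
+            "in advance. 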
The number of steps in the previous round was " + "lastRunChannelId %d pythonBatchId %d hybridBatchId %d", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + + return 1; + } else { + // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 + VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId %d hybridBatchId %d", + lastRunChannelId, hybridBatchId[lastRunChannelId]); + return -1; + } +} + +bool HybridMgmtBlock::GetBlockStatus(int channelId) +{ + return isBlock[channelId]; +} + +void HybridMgmtBlock::SetBlockStatus(int channelId, bool block) +{ + isBlock[channelId] = block; +} + +/// python侧调用的npu.outfeed_enqueue_op 发送的消息。用来判断当前python执行的步数 +void HybridMgmtBlock::StartNotifySignalMonitor() +{ +#ifndef GTEST + auto fn = [this](int channelId) { + while (isRunning) { + std::vector tensors; + tensorflow::Status status = tensorflow::RecvTensorByAcl(aclHandles[channelId], tensors); + if (!isRunning) { + break; + } + if (status != tensorflow::Status::OK()) { + LOG(ERROR) << StringFormat(HYBRID_BLOCKING + + "%s hd recv error '%s'", d2hChannelName[channelId].c_str(), status.error_message().c_str()); + throw runtime_error("rev error"); + } + VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + + "send message to hybrid channelId %d pythonBatchId %d hybridBatchId %d", + channelId, pythonBatchId[channelId], hybridBatchId[channelId]); + + int sparseLookupId = *tensors[0].flat().data(); + VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + + "send sparse_lookup_id channel %d sparse id %d unique id %d", + channelId, sparseLookupId, uniqueSparseLookID[channelId]); + + if (uniqueSparseLookID[channelId] == -1) { + // 初始化,只有第一个sparse loop id能进行计数和唤 + uniqueSparseLookID[channelId] = sparseLookupId; + } + // 只被计数一次 + if (sparseLookupId == uniqueSparseLookID[channelId]) { + // 只有最先来的id才能进行唤醒和计数 + CheckAndNotifyWake(channelId); + CountPythonStep(channelId); + } + } + LOG(INFO) << StringFormat(HYBRID_BLOCKING + "BLOCKING thread stop"); + }; + uint32_t localRankId = rankInfo.deviceId; + for (int channelId = 0; channelId < MAX_CHANNEL_NUM; ++channelId) { + d2hChannelName[channelId] = StringFormat(D2H_CHANNEL_NAME_PRE + "%d", channelId); + auto aclChannelHandle = tdtCreateChannel(localRankId, d2hChannelName[channelId].c_str(), PING_PONG_SIZE); + LOG(INFO) << StringFormat(HYBRID_BLOCKING + " %d %s", localRankId, d2hChannelName[channelId].c_str()); + aclHandles[channelId] = aclChannelHandle; + procThreads.emplace_back(std::make_unique(fn, channelId)); + } +#endif +} + +void HybridMgmtBlock::Destroy() +{ + if (!isRunning) { + // 已经销毁过了,不用再次销毁会报错 + return; + } + isRunning = false; +#ifndef GTEST + for (int channelId = 0; channelId < MAX_CHANNEL_NUM; ++channelId) { + tensorflow::StopRecvTensorByAcl(&aclHandles[channelId], d2hChannelName[channelId]); + procThreads[channelId]->join(); + } + LOG(INFO) << StringFormat(HYBRID_BLOCKING + "BLOCKING stop"); +#endif +} + + +void HybridMgmtBlock::SetRankInfo(RankInfo rankInfo) +{ + this->stepsInterval[0] = rankInfo.maxStep[0]; + this->stepsInterval[1] = rankInfo.maxStep[1]; + this->rankInfo = rankInfo; +}; + +void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep) +{ + this->stepsInterval[0] = trainStep; + this->stepsInterval[1] = evalStep; +}; + +HybridMgmtBlock::~HybridMgmtBlock() +{ + Destroy(); +} \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h new file mode 100644 index 00000000..e72a8523 --- /dev/null +++ 
b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: hybrid mgmt module,Record the number of program running steps, + * manage blocking and wakeup + * Author: MindX SDK + * Date: 2023/08/15 + */ +#ifndef MX_REC_HYBRID_BLOCKING_H +#define MX_REC_HYBRID_BLOCKING_H +#include +#include + +#include "hd_transfer/hd_transfer.h" +#include "utils/common.h" +#include "utils/singleton.h" +using namespace MxRec; +const std::string HYBRID_BLOCKING = "[HYBRID_BLOCKING] "; +const std::string D2H_CHANNEL_NAME_PRE = "d2h_notify_hybridmgmt_"; +const std::chrono::milliseconds SLEEP_MS = 20ms; +class HybridMgmtBlock { +public: + // 上一次运行的通道ID + int lastRunChannelId = -1; + // hybrid将要处理的batch id + int hybridBatchId[2] = {0, 0}; + // python侧将要处理的batch id + int pythonBatchId[2] = {0, 0}; + // readEmbed算子侧将要处理的batch id + int readEmbedBatchId[2] = {0, 0}; + bool isRunning = true; + // 每个sparse lookup都会生成一个唯一的id,保证每次运行只有一个id在进行计数 + int uniqueSparseLookID[2]{-1, -1}; + + ~HybridMgmtBlock(); + void CheckAndNotifyWake(int channelId); + void CountPythonStep(int channelId); + void CheckAndSetBlock(int channelId); + void CheckValid(int channelId); + void DoBlock(int channelId); + void ResetAll(int channelId); + int CheckSaveEmbdMapValid(); + bool GetBlockStatus(int channelId); + void SetBlockStatus(int channelId, bool block); + void SetRankInfo(RankInfo rankInfo); + void SetStepInterval(int trainStep, int evalStep); + void StartNotifySignalMonitor(); + bool WaitValid(int channelId); + void Destroy(); +private: + // 通道i运行多少步后切换为通道j + int stepsInterval[2] = {0, 0}; + // 控制通道阻塞的变量 + bool isBlock[2] = {true, true}; + string d2hChannelName[2]; + RankInfo rankInfo; + acltdtChannelHandle* aclHandles[2]; + std::vector> procThreads {}; +}; + +class HybridMgmtBlockingException : public std::exception { +public: + explicit HybridMgmtBlockingException(const string scene) + { + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + // int channelId, int preprocessBatchNumber, int currentBatchNumber + int channelId = hybridMgmtBlock->lastRunChannelId; + int preprocessBatchNumber = hybridMgmtBlock->hybridBatchId[channelId]; + int currentBatchNumber = hybridMgmtBlock->pythonBatchId[channelId]; + str = StringFormat("Error happened at HyBridmgmt Blocking, it finds that " + "preprocess batch number not match current using batch number " + "%s , last use channel id is %d, preprocessBatchNumber is %d ," + "currentBatchNumber is %d. please check your setting of train " + "steps and eval steps", scene.c_str(), channelId, preprocessBatchNumber, + currentBatchNumber); + LOG(ERROR) << str; + } + +private: + string str; +}; +#endif \ No newline at end of file diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a2005bdd..20fca8ff 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1194,6 +1194,14 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) if (!isRunning) { return {}; } + // 判断此时的batch id是否已经过期,即通道已经刷新 + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + if (batch != hybridMgmtBlock->hybridBatchId[channel]) { + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! 
%s[%d]:%d", + embName.c_str(), channel, batch); + return {}; + } if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { LOG(WARNING) << StringFormat( KEY_PROCESS "getting lookup keys timeout! %s[%d]:%d", embName.c_str(), channel, batch); @@ -1241,6 +1249,14 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa if (!isRunning) { return nullptr; } + // 判断此时的batch id是否已经过期,即通道已经刷新 + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + if (batch != hybridMgmtBlock->hybridBatchId[channel]) { + VLOG(GLOG_DEBUG) << StringFormat( + KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! %s[%d]:%d", + embName.c_str(), channel, batch); + return nullptr; + } if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { LOG(WARNING) << StringFormat( KEY_PROCESS "getting lookup keys timeout! %s[%d]:%d", embName.c_str(), channel, batch); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index d6b9df3d..65fbaae9 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -32,7 +32,7 @@ #include "emb_table/emb_table.h" #include "feature_admit_and_evict.h" - +#include "hybrid_mgmt/hybrid_mgmt_block.h" namespace MxRec { using namespace std; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index e1381d29..533a4442 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -64,7 +64,8 @@ public: void Compute(OpKernelContextPtr context) override { LOG(INFO) << StringFormat("clear channel %d, context %d", channelId, context->step_id()); - batchIdsInfo.at(channelId) = 0; + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + hybridMgmtBlock->ResetAll(channelId); } private: @@ -122,7 +123,7 @@ public: OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); - + hybridMgmtBlock = Singleton::GetInstance(); // 特征准入&淘汰功能 相关校验 // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 @@ -141,7 +142,8 @@ public: MAX_CHANNEL_NUM))); return; } - batchIdsInfo.at(channelId) = 0; + VLOG(INFO) << StringFormat(HYBRID_BLOCKING + " reset channel %d", channelId); + hybridMgmtBlock->ResetAll(channelId); threadNum = GetThreadNumEnv(); auto keyProcess = Singleton::GetInstance(); @@ -158,7 +160,7 @@ public: EASY_FUNCTION(); VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2Dynamic"; TimeCost tc = TimeCost(); - int batchId = batchIdsInfo.at(channelId)++; + int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { LOG(WARNING) << StringFormat("skip excess batch after %d/%d", batchId, maxStep); @@ -296,6 +298,7 @@ public: int maxStep = 0; bool isTimestamp { false }; int threadNum = 0; + HybridMgmtBlock* hybridMgmtBlock; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), ReadEmbKeyV2Dynamic); @@ -325,6 +328,7 @@ public: OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); fieldNum = accumulate(splits.begin(), splits.end(), 0); + hybridMgmtBlock = Singleton::GetInstance(); // 特征准入&淘汰功能 相关校验 // 配置了,也不能多配、配不相关的;同时支持“准入&淘汰”,则不能没有时间戳 @@ -346,7 +350,9 @@ public: "ReadEmbKeyV2 channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:%d)", MAX_CHANNEL_NUM))); return; } - batchIdsInfo.at(channelId) = 0; + VLOG(INFO) << StringFormat(HYBRID_BLOCKING + " reset channel %d", channelId); + // 重置此数据通道中所有的步数 + hybridMgmtBlock->ResetAll(channelId); threadNum = GetThreadNumEnv(); auto keyProcess = Singleton::GetInstance(); @@ -364,7 +370,7 @@ public: EASY_FUNCTION(); VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2"; TimeCost tc = TimeCost(); - int batchId = batchIdsInfo.at(channelId)++; + int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; Tensor* output = nullptr; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { @@ -499,6 +505,7 @@ public: int maxStep = 0; bool isTimestamp { false }; int threadNum = KEY_PROCESS_THREAD; + HybridMgmtBlock* hybridMgmtBlock; }; REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); @@ -616,4 +623,4 @@ REGISTER_OP("EmbeddingUpdateByAddress") return Status::OK(); }); -REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); +REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); \ No newline at end of file diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp new file mode 100644 index 00000000..380ed5f7 --- /dev/null +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: key process test + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + +#include +#include +#include +#include +#include + +#include "hybrid_mgmt/hybrid_mgmt_block.h" +#include "utils/common.h" +using namespace MxRec; +using namespace std::chrono_literals; + +class HybridMgmtBlockTest : public testing::Test { +public: + std::unique_ptr hybridMgmtBlock; + std::vector> procThreads {}; + bool isRunning = true; +protected: + void SetUp() + { + VLOG(GLOG_DEBUG) << StringFormat("%s", "start initialize") ; + } +}; + +TEST_F(HybridMgmtBlockTest, CheckAndDoBlock) +{ + int steps[] = {-1, 1}; + hybridMgmtBlock = std::make_unique(); + hybridMgmtBlock->SetStepInterval(1, 1); + hybridMgmtBlock->CheckAndSetBlock(0); + hybridMgmtBlock->CheckAndSetBlock(1); + ASSERT_EQ(hybridMgmtBlock->GetBlockStatus(0), true); +} + +TEST_F(HybridMgmtBlockTest, CountAndNotifyWake) +{ + hybridMgmtBlock = std::make_unique(); + hybridMgmtBlock->SetStepInterval(1, 1); + hybridMgmtBlock->CheckAndNotifyWake(0); + hybridMgmtBlock->CountPythonStep(0); + hybridMgmtBlock->pythonBatchId[0] = 1; + hybridMgmtBlock->hybridBatchId[0] = 0; + auto fn = [this](int channelId) { + hybridMgmtBlock->CheckAndNotifyWake(channelId); + hybridMgmtBlock->CountPythonStep(0); + return 0; + }; + procThreads.emplace_back(std::make_unique(fn, 0)); + std::this_thread::sleep_for(std::chrono::milliseconds(2ms)); + hybridMgmtBlock->hybridBatchId[0] = 1; + for (auto p = procThreads.begin(); p != procThreads.end(); p++) { + (*p)->join(); + } +} + +TEST_F(HybridMgmtBlockTest, CheckValid) +{ + hybridMgmtBlock = std::make_unique(); + hybridMgmtBlock->SetStepInterval(1, 1); + hybridMgmtBlock->pythonBatchId[0] = 0; + hybridMgmtBlock->hybridBatchId[0] = 0; + hybridMgmtBlock->CheckValid(0); + hybridMgmtBlock->CheckValid(0); + + int step2 = 2; + hybridMgmtBlock->pythonBatchId[0] = 0; + hybridMgmtBlock->hybridBatchId[0] = step2; + hybridMgmtBlock->lastRunChannelId = 0; + try { + hybridMgmtBlock->CheckValid(1); + ASSERT_EQ(-1, 0); + } catch (HybridMgmtBlockingException e) { + VLOG(INFO) 
<< StringFormat(HYBRID_BLOCKING + "success");
+        ASSERT_EQ(0, 0);
+    }
+    hybridMgmtBlock->pythonBatchId[0] = 0;
+    hybridMgmtBlock->hybridBatchId[0] = 1;
+    hybridMgmtBlock->CheckValid(0);
+}
+
+TEST_F(HybridMgmtBlockTest, DoBlock)
+{
+    hybridMgmtBlock = std::make_unique<HybridMgmtBlock>();
+    hybridMgmtBlock->SetStepInterval(1, 1);
+    hybridMgmtBlock->pythonBatchId[0] = 1;
+    hybridMgmtBlock->hybridBatchId[0] = 1;
+    auto fn = [this](int channelId) {
+        hybridMgmtBlock->DoBlock(channelId);
+        return 0;
+    };
+    procThreads.emplace_back(std::make_unique<std::thread>(fn, 0));
+    std::this_thread::sleep_for(std::chrono::milliseconds(2));
+    hybridMgmtBlock->SetBlockStatus(0, false);
+    for (auto p = procThreads.begin(); p != procThreads.end(); p++) {
+        (*p)->join();
+    }
+}
+
+TEST_F(HybridMgmtBlockTest, ResetAll)
+{
+    hybridMgmtBlock = std::make_unique<HybridMgmtBlock>();
+    hybridMgmtBlock->SetStepInterval(1, 1);
+    hybridMgmtBlock->ResetAll(0);
+    ASSERT_EQ(hybridMgmtBlock->hybridBatchId[0], 0);
+}
+
+TEST_F(HybridMgmtBlockTest, CheckSaveEmbdMapValid)
+{
+    hybridMgmtBlock = std::make_unique<HybridMgmtBlock>();
+    hybridMgmtBlock->SetStepInterval(1, 1);
+    hybridMgmtBlock->lastRunChannelId = 0;
+
+    hybridMgmtBlock->pythonBatchId[0] = 0;
+    hybridMgmtBlock->hybridBatchId[0] = 0;
+    hybridMgmtBlock->CheckSaveEmbdMapValid();
+    int status0 = hybridMgmtBlock->CheckSaveEmbdMapValid();
+
+    hybridMgmtBlock->pythonBatchId[0] = 0;
+    hybridMgmtBlock->hybridBatchId[0] = 1;
+    hybridMgmtBlock->CheckSaveEmbdMapValid();
+    int status1 = hybridMgmtBlock->CheckSaveEmbdMapValid();
+
+    int step2 = 2;
+    hybridMgmtBlock->pythonBatchId[0] = 0;
+    hybridMgmtBlock->hybridBatchId[0] = step2;
+    int status2 = hybridMgmtBlock->CheckSaveEmbdMapValid();
+    ASSERT_EQ(status0, 0);
+    ASSERT_EQ(status1, 1);
+    ASSERT_EQ(status2, -1);
+}
\ No newline at end of file
-- Gitee

From 9ec79ad9b188032f83b6f3abc3ea440d2e41ced5 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 4 Sep 2023 11:02:56 +0800
Subject: [PATCH 320/551] Match-id-54c2e64295b58e93969e118c875c9ccf7f67880a

---
 mx_rec/core/asc/manager.py           |  6 ++-
 mx_rec/util/initialize.py            | 22 ++++++++-
 src/core/emb_hashmap/emb_hashmap.cpp | 10 +++-
 src/core/emb_hashmap/emb_hashmap.h   |  1 +
 src/core/hybrid_mgmt/hybrid_mgmt.cpp |  1 +
 src/core/key_process/key_process.cpp | 71 +++++++++++++++++++++++-----
 src/core/utils/common.cpp            |  3 +-
 src/core/utils/common.h              |  4 ++
 8 files changed, 100 insertions(+), 18 deletions(-)

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index b7f27337..064fce38 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -12,7 +12,7 @@ from mx_rec.constants.constants import MxRecMode
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
     is_asc_manager_initialized, get_train_steps, get_eval_steps, get_prefetch_batch_number, \
     export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \
-    get_use_hot, get_use_dynamic_expansion, export_optimizer, export_dangling_table
+    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num
 from mx_rec.core.asc.merge_table import find_dangling_table, should_skip


@@ -227,5 +227,9 @@ def start_asc_pipeline():
     if not table_info_list:
         logging.error("table_info_list is empty!")
         raise RuntimeError("table_info_list is empty!")
+    if get_stat_on():
+        logging.info(f"[StatInfo] current_table_num {export_table_num()}")
     if not is_asc_manager_initialized() and table_info_list:
         initialize_emb_cache(table_info_list, threshold_list)
+
+
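The STAT_ON switch introduced in this patch wires a family of "[StatInfo]" log lines through both the
Python and C++ sides. Here is a hedged sketch of consuming them after a run; the log path and regexes
are illustrative, and only the [StatInfo] marker and key names (current_table_num, swap_key_size, and
so on) come from the patch:

# Sketch: enable STAT_ON for a run, then aggregate the [StatInfo] lines it emits.
import os
import re
from collections import defaultdict

os.environ["STAT_ON"] = "1"  # must be "0" or "1"; anything else raises in set_stat_flag()

STAT_LINE = re.compile(r"\[StatInfo\] (.*)")
PAIR = re.compile(r"(\w+) (-?\d+)")

def collect_stats(log_path):
    # Collect every "key value" pair found on [StatInfo] lines into per-key series.
    stats = defaultdict(list)
    with open(log_path, encoding="utf-8") as log_file:
        for line in log_file:
            match = STAT_LINE.search(line)
            if not match:
                continue
            for key, value in PAIR.findall(match.group(1)):
                stats[key].append(int(value))
    return stats

# Example (hypothetical log file name): average swap size per batch.
# stats = collect_stats("worker_0.log")
# print(sum(stats["swap_key_size"]) / max(len(stats["swap_key_size"]), 1))

diff --git 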
a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 53b8cb40..3a8e85f5 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -78,6 +78,7 @@ class ConfigInitializer: self.enable_table_merge = True if os.getenv("TF_DEVICE") == "NPU" else False # 两个通道的sparse look id,用于通讯的标识 self.notify_hybrid_channel_sparse_id = [0, 0] + self.stat_on = set_stat_flag() if os.getenv("STAT_ON") else False def __del__(self): self.terminate() @@ -255,6 +256,8 @@ class ConfigInitializer: self._table_name_to_feature_spec[name] = {True: [], False: []} self._name_to_var_dict[name] = key self._table_instance_dict[key] = instance + if self.stat_on: + logging.info(f"[StatInfo] current_table_num {len(self._table_instance_dict)}") def insert_bool_gauge(self, name): if not isinstance(name, str): @@ -620,6 +623,10 @@ def export_table_instances(): return ConfigInitializer.get_instance().table_instance_dict +def export_table_num(): + return len(ConfigInitializer.get_instance().table_instance_dict) + + def export_dangling_table(): return ConfigInitializer.get_instance().dangling_table @@ -668,6 +675,10 @@ def get_use_static(): return ConfigInitializer.get_instance().use_static +def get_stat_on(): + return ConfigInitializer.get_instance().stat_on + + def get_use_hot(): return ConfigInitializer.get_instance().use_hot @@ -896,4 +907,13 @@ def bind_cpu(rank_id: int, rank_size: int = None): process.cpu_affinity(cpu_list) except IndexError: logging.error(f"failed to bind cpu for rank {rank_id}: {cpu_list}") - logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") \ No newline at end of file + logging.info(f"bind cpu for rank {rank_id}: {cpu_list}") + + +def set_stat_flag(): + if os.getenv("STAT_ON") == "1": + return True + elif os.getenv("STAT_ON") == "0": + return False + else: + raise ValueError(f"STAT_ON can only be 0 or 1.") \ No newline at end of file diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 1b68536d..ba778464 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -64,6 +64,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t { #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) + TimeCost swapTimeCost; auto& embHashMap = embHashMaps.at(embName); embHashMap.devOffset2KeyOld.clear(); embHashMap.oldSwap.clear(); @@ -85,7 +86,6 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t // 调用刷新频次数据方法 RefreshFreqInfoWithSwap(embName, embHashMap); - swapId++; EASY_BLOCK("hostHashMaps->tdt") std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut)); @@ -116,7 +116,6 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t if (VLOG_IS_ON(GLOG_TRACE)) { VLOG(GLOG_TRACE) << StringFormat("swapTensor, %s", VectorToString(embHashMap.swapPos).c_str()); } - // 清空本次记录的查询偏移和交换偏移 ClearLookupAndSwapOffset(embHashMap); LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), embHashMap.maxOffset, @@ -124,6 +123,13 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = ddrParam.tmpDataOut.back().flat(); swapLen(0) = swapSize; + + if (g_statOn) { + LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d swap_key_size %d swap_time_cost %d", + channelId, swapId, rankInfo.rankId, swapSize, swapTimeCost.ElapsedMS()); + } + + swapId++; EASY_END_BLOCK #endif } diff --git 
a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 00559509..1b6c8b25 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -15,6 +15,7 @@ #include "host_emb/host_emb.h" #include "ssd_cache/cache_manager.h" #include "utils/common.h" +#include "utils/time_cost.h" namespace MxRec { using namespace std; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index f8ed8d58..2c8f1e80 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -112,6 +112,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, // 设置日志的级别,对日志格式进行配置 SetLog(rankInfo.rankId); InitRankInfo(rankInfo, embInfos); + g_statOn = GetEnv("STAT_ON"); LOG(INFO) << StringFormat( MGMT + "begin initialize, localRankSize:%d, localRankId:%d, rank:%d", diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 20fca8ff..46e6497d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -263,7 +263,6 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) } GetUniqueConfig(uniqueConf); - TimeCost tc = TimeCost(); try { while (true) { TimeCost getAndProcessTC; @@ -273,8 +272,8 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) if (batch == nullptr) { break; } - auto getBatchTime = tc.ElapsedMS(); - tc = TimeCost(); + auto getBatchTime = getBatchDataTC.ElapsedMS(); + TimeCost processDataTime = TimeCost(); InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) { @@ -283,7 +282,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) LOG(INFO) << StringFormat( KEY_PROCESS "getAndProcessTC(ms):%d, key process with fast unique cost:%d," " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", - getAndProcessTC.ElapsedMS(), tc.ElapsedMS(), getBatchTime, + getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name.c_str(), batch->channel, batch->batchId ); auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); @@ -293,16 +292,16 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) } catch (const EndRunError &e) { VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); } + LOG(INFO) << StringFormat( - KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:%d thread:%d, channel:%d", - rankInfo.rankId, threadId, channel); + KEY_PROCESS "KeyProcessTaskWithFastUnique exit. 
rank:%d thread:%d, channel:%d", + rankInfo.rankId, threadId, channel); } void KeyProcess::KeyProcessTask(int channel, int threadId) { unique_ptr batch; - TimeCost tc = TimeCost(); try { while (true) { TimeCost getAndProcessTC; @@ -312,8 +311,8 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) if (batch == nullptr) { break; } - auto getBatchTime = tc.ElapsedMS(); - tc = TimeCost(); + auto getBatchTime = getBatchDataTC.ElapsedMS(); + TimeCost processDataTime = TimeCost(); if (!KeyProcessTaskHelper(batch, channel, threadId)) { break; @@ -321,7 +320,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) LOG(INFO) << StringFormat( KEY_PROCESS "getAndProcessTC(ms):%d, key process cost:%d," " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", - getAndProcessTC.ElapsedMS(), tc.ElapsedMS(), getBatchTime, + getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name.c_str(), batch->channel, batch->batchId ); auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); @@ -330,6 +329,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) } catch (const EndRunError &e) { VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); } + LOG(INFO) << StringFormat( KEY_PROCESS "KeyProcessTask exit. rank:%d thread:%d, channel:%d", rankInfo.rankId, threadId, channel); } @@ -358,6 +358,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat // tuple for keyRec restore hotPos scAll countRecv isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; + TimeCost totalTimeCost = TimeCost(); TimeCost fastUniqueTC; UniqueInfo uniqueInfo; ProcessBatchWithFastUnique(batch, unique, threadId, uniqueInfo); @@ -399,6 +400,11 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); + if (g_statOn) { + LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d " + "key_process_time_cost_with_fast_unique %d", + channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS()); + } VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); return true; } @@ -409,7 +415,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe vector restore; vector hotPos; vector> keyCount; - + TimeCost totalTimeCost = TimeCost(); HashSplitHelper(batch, splitKeys, restore, hotPos, keyCount); auto [lookupKeys, scAll, ss] = ProcessSplitKeys(batch, threadId, splitKeys); @@ -460,6 +466,11 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe PushResult(batch, move(tensors), lookupKeys); VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); + if (g_statOn) { + LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d " + "key_process_time_cost %d", + channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS()); + } return true; } @@ -635,6 +646,13 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch batch->batchId, batch->Size(), batch->channel, batch->name.c_str(), uniqueInfoOut.restore.size(), keySendInfo.keyCount.size() ); + + if (g_statOn) { + LOG(INFO) << StringFormat( + STAT_INFO "channel_id %d batch_id %d rank_id %d " + "batch_key_num_with_fast_unique %d unique_key_num_with_fast_unique %d", + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueOut.uniqueIdCnt); + } } 
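Taken together, the [StatInfo] lines above expose batch_key_num and unique_key_num per batch, which give
a quick read on intra-batch key reuse. A small helper in the spirit of the collect_stats() sketch shown
earlier (key names come from the patch; the helper itself is illustrative):

# Estimate the key-dedup ratio for the fast-unique path from collected [StatInfo] counters.
def dedup_ratio(stats):
    batch_keys = stats.get("batch_key_num_with_fast_unique", [])
    unique_keys = stats.get("unique_key_num_with_fast_unique", [])
    total_batch = sum(batch_keys)
    if total_batch == 0:
        return 0.0
    # close to 1.0 means little key reuse inside a batch; small values mean heavy reuse
    return sum(unique_keys) / total_batch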
void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, @@ -816,6 +834,16 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); } + + if (g_statOn) { + size_t UniqueKeyNum = 0; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + UniqueKeyNum += splitKeys[devId].size(); + } + LOG(INFO) << StringFormat( + STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d unique_key_num %ld", + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + } return { splitKeys, restore }; } @@ -868,6 +896,15 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); } + if (g_statOn) { + size_t UniqueKeyNum = 0; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + UniqueKeyNum += splitKeys[devId].size(); + } + LOG(INFO) << StringFormat( + STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d faae_unique_key_num %ld", + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + } return { splitKeys, restore, keyCount }; } @@ -886,7 +923,6 @@ tuple, vector, vector> lock.unlock(); vector hotPos(hotEmbTotCount[batch->name]); vector hotPosDev(hotEmbTotCount[batch->name]); - int hotCount = 0; int hotOffset = hotEmbTotCount[batch->name]; for (size_t i = 0; i < miniBs; i++) { // for mini batch @@ -919,10 +955,19 @@ tuple, vector, vector> } uKey[key] = restore[i]; } + + if (g_statOn) { + size_t UniqueKeyNum = 0; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + UniqueKeyNum += splitKeys[devId].size(); + } + LOG(INFO) << StringFormat( + STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d hot_unique_key_num %ld", + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + } UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch); - return { splitKeys, restore, hotPos }; } diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index e972891e..00b5700e 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -28,7 +28,7 @@ namespace MxRec { string g_rankId; int g_glogLevel; bool g_isGlogInit = false; - + bool g_statOn = false; RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), @@ -144,6 +144,7 @@ namespace MxRec { } return (tmp == 1) ? 
true : false; } + string GetChipName(int devID) { int ret = 0; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index e58b6900..786c708c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -65,11 +65,14 @@ namespace MxRec { constexpr int SSD_SIZE_INDEX = 2; // for GLOG + extern bool g_statOn; extern int g_glogLevel; extern string g_rankId; constexpr int GLOG_MAX_BUF_SIZE = 1024; constexpr int GLOG_TIME_WIDTH_2 = 2; constexpr int GLOG_TIME_WIDTH_6 = 6; + constexpr char GLOG_STAT_FLAG[] = "STAT_ON"; + // unique related config constexpr int UNIQUE_BUCKET = 6; @@ -590,6 +593,7 @@ namespace MxRec { } // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " +#define STAT_INFO "[StatInfo] " #ifdef GTEST #define GTEST_PRIVATE public #else -- Gitee From 3089792bf80988db7c1cfd8ddd7d7ef8e26c86ae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 11:16:37 +0800 Subject: [PATCH 321/551] Match-id-2743c26284b5c3fdfb3db1655d244b9da9a0ae71 --- mx_rec/core/embedding.py | 2 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 4 ++-- .../ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 08db37ab..25784d0d 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -830,7 +830,7 @@ class SparseEmbedding: logging.debug(f"Slice host vocabulary_size for table {self.table_name} is" f" {self.slice_host_vocabulary_size}.") logging.debug(f"SSD vocabulary size for table {self.table_name} is {self.ssd_vocabulary_size}.") - logging.debug("Slice ssd vocabulary_size for table {self.table_name} is" + logging.debug(f"Slice ssd vocabulary_size for table {self.table_name} is" f" {self.slice_ssd_vocabulary_size}.") def _initialize_variables(self): diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index 6379ed18..c5406366 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -64,7 +64,7 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) transArr.push_back(it.first); transArr.push_back(it.second); } - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType%d is", CkptDataType::EMB_INFO, dataType); + LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -83,5 +83,5 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran int64_t key { transArr.at(i) }; hostHashMap[key] = transArr.at(i + 1); } - LOG(INFO) << StringFormat("dataType%d is", dataType); + LOG(INFO) << StringFormat("dataType:%d", dataType); } diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp index 356fb80d..62ce257c 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -52,7 +52,7 @@ CkptTransData NddrOffsetCkpt::GetDataset(CkptDataType dataType, string embName) transferData.attribute.push_back(1); transferData.attribute.push_back(fourBytes); transferData.attributeSize = transferData.attribute.size() * eightBytes; - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d is", CkptDataType::EMB_INFO, dataType); + 
LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -61,5 +61,5 @@ void NddrOffsetCkpt::SetDataset(CkptDataType dataType, string embName, CkptTrans CleanTransfer(); transferData = move(loadedData); loadMaxOffset[embName] = transferData.int32Arr.front(); - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d is", CkptDataType::EMB_INFO, dataType); + LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); } -- Gitee From ab3a0c94a1bd8b19280211a3efd4e07dde53a2ec Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 11:19:27 +0800 Subject: [PATCH 322/551] Match-id-27fcf729bcf92206e83255fbe74eb5e3e8355df0 --- src/core/utils/common.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 786c708c..780fddf4 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -12,7 +12,6 @@ #include #include -#include #include #include @@ -369,20 +368,6 @@ namespace MxRec { void ValidateReadFile(const string& dataDir, size_t datasetSize); - inline void GenerateRandomValue(std::vector& vecData, - std::default_random_engine& generator, - RandomInfo& randomInfo) - { - float min = ((!randomInfo.randomMin) ? -0.1f : randomInfo.randomMin); - float max = ((!randomInfo.randomMax) ? 0.1f : randomInfo.randomMax); - if (randomInfo.len == 0) { - return; - } - assert(static_cast(vecData.size()) >= randomInfo.len + randomInfo.start); - std::uniform_real_distribution distribution(min, max); - std::generate_n(vecData.begin() + randomInfo.start, randomInfo.len, [&]() { return distribution(generator); }); - } - enum class InitializerType { CONSTANT, TRUNCATED_NORMAL, -- Gitee From ca34f89bb108df5c2e9cca88c3cf122d5dc36110 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 14:51:52 +0800 Subject: [PATCH 323/551] Match-id-e8682b16a202aab35f0edb4ec9f841c2b2adaa8c --- mx_rec/core/asc/build_graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 916e4eb2..6e3258fb 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -123,9 +123,12 @@ def get_all2all_args(use_static: bool, config: dict) -> list: :param config: embedding config :return: all2all parametrs """ + all2all_args = None - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if not use_static: + if use_static: + return all2all_args + with tf.control_dependencies([config.get("notify_hybridmgmt_op")]): + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): with tf.compat.v1.variable_scope("all2all"): logging.debug( f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') -- Gitee From 599a11312b6feaef4ef2485678323404630faa2c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 15:42:19 +0800 Subject: [PATCH 324/551] Match-id-97f6709e7b422afe5e0ddee31349882d6f9c38a4 --- .../op_host/embedding_lookup_by_address.cpp | 20 +++++++++++-------- .../op_kernel/embedding_lookup_by_address.cpp | 8 ++++---- .../op_kernel/embedding_update_by_address.cpp | 8 ++++---- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index a881d58f..401d7148 100644 --- 
a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -24,18 +24,22 @@ namespace optiling const auto *attr0_value = attrs->GetAttrPointer(0); if (attr0_value == nullptr) { printf(" Lookup embbeding_type attr0_value nullptr\n"); + return ge::GRAPH_FAILED; } - else { - int32_t embbeding_dim = *attr0_value; + + int32_t embbeding_dim = *attr0_value; + if (embbeding_dim <= 0) { + printf("embbeding_dim must larger than 0\n"); + return ge::GRAPH_FAILED; } const auto *attr1_value = attrs->GetAttrPointer(1); if (attr1_value == nullptr) { printf(" Lookup embbeding_type attr1_value nullptr\n"); + return ge::GRAPH_FAILED; } - else { - int32_t embbeding_type = *attr1_value; - } + + int32_t embbeding_type = *attr1_value; int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); @@ -62,10 +66,10 @@ namespace ge const auto *attr0_value = attrs->GetAttrPointer(0); if (attr0_value == nullptr) { printf(" Lookup embbeding_type attr0_value nullptr\n"); + return GRAPH_FAILED; } - else { - int64_t update_dim = *attr0_value; - } + + int64_t update_dim = *attr0_value; int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); y_shape->SetDimNum(2); diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 718f7c64..bd43a9b5 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -51,17 +51,17 @@ public: int min_move_num = 32 / singleDataSize; // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + min_move_num) / min_move_num表示除以min_move_num向下取整 int onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); - int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums + int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums; // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int occupy_address_bytes_num = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2 + int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2; // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int addr_max_num = ((int)((int)(ub_limit / occupy_address_bytes_num) / 4)) * 4; + int addrMaxNum = ((int)((int)(ub_limit / occupyAddressBytesNum) / 4)) * 4; int singlenum = (int)(addr_nums / block_total_nums); if (singlenum % 4) { singlenum -= singlenum % 4; } - roundSize = addr_max_num; + roundSize = addrMaxNum; Veclen = roundSize * singleDataSize * onceMoveNums; SingleCoreAddrLen = singlenum * sizeof(int64_t); cache = roundSize; diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index 8a6c537a..6b1c52cc 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -49,17 +49,17 @@ public: int min_move_num = 32 / singleDataSize; // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + min_move_num) / min_move_num表示除以min_move_num向下取整 int onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); - int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums + int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums; // 
每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int occupy_address_bytes_num = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2 + int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2; // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int addr_max_num = ((int)((int)(ub_limit / occupy_address_bytes_num) / 4)) * 4; + int addrMaxNum = ((int)((int)(ub_limit / occupyAddressBytesNum) / 4)) * 4; int singlenum = (int)(addr_nums / block_total_nums); if (singlenum % 4) { singlenum -= singlenum % 4; } - roundSize = addr_max_num; + roundSize = addrMaxNum; Veclen = roundSize * singleDataSize * onceMoveNums; SingleCoreAddrLen = singlenum * sizeof(int64_t); cache = roundSize; -- Gitee From cc39f688a04e2655087b8fa0faee8aee668b62b0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 15:52:09 +0800 Subject: [PATCH 325/551] Match-id-9f77932fa97788344a78babc3672ef5039f17e9b --- src/core/ssd_engine/table.cpp | 8 ++++---- src/core/ssd_engine/table.h | 1 + src/tests/ssd_engine/engine_test.cpp | 21 +++++++++++++-------- src/tests/ssd_engine/table_test.cpp | 6 ++++-- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index d9575428..89880989 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -19,7 +19,7 @@ Table::Table(const string &name, vector &savePaths, uint64_t maxTableSiz maxTableSize(maxTableSize), compactThreshold(compactThreshold) { - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + g_rankId + "/" + name).string(); + curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { throw runtime_error("fail to create table directory"); } @@ -41,7 +41,7 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize bool isMetaFileFound = false; for (const string &dirPath: saveDirs) { auto metaFilePath = fs::absolute( - dirPath + "/" + g_rankId + "/" + name + "/" + name + ".meta" + "." + to_string(step)).string(); + dirPath + "/" + saveDirPrefix + g_rankId + "/" + name + "/" + name + ".meta." + to_string(step)).string(); if (!fs::exists(metaFilePath)) { continue; } @@ -54,7 +54,7 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize } // always use first path to save until it's full - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + g_rankId + "/" + name).string(); + curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); LOG(INFO) << StringFormat("load table:%s done. 
try store at path:%s", name.c_str(), curTablePath.c_str()); } @@ -144,7 +144,7 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) shared_ptr tmp; for (const string &p: savePaths) { // try to find data file from each path - string dataPath = p + "/" + g_rankId + "/" + name; + string dataPath = p + "/" + saveDirPrefix + g_rankId + "/" + name; try { tmp = make_shared(fileID, dataPath, step); fileSet.insert(tmp); diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index 4d454224..679f702d 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -63,6 +63,7 @@ namespace MxRec { shared_ptr curFile = nullptr; uint64_t curMaxFileID = 0; // no concurrent writing, always atomic increase const uint32_t maxNameSize = 1024; + const string saveDirPrefix = "ssd_sparse_model_rank_"; /* args for performance(not expose to user yet) * 2 read thread is optimal when: diff --git a/src/tests/ssd_engine/engine_test.cpp b/src/tests/ssd_engine/engine_test.cpp index b32aa626..d6805f46 100644 --- a/src/tests/ssd_engine/engine_test.cpp +++ b/src/tests/ssd_engine/engine_test.cpp @@ -19,7 +19,7 @@ TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave) g_rankId = to_string(rankId); string tbName = "test"; - vector savePath = {g_rankId}; + vector savePath = {"."}; uint64_t maxTableSize = 100; double compactThreshold = 0.5; chrono::seconds compactPeriod = chrono::seconds(5); @@ -67,22 +67,27 @@ TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave) delete eng; // after saving, full compact will perform, old file will be deleted - string oldDataFilePath = g_rankId + "/" + tbName + "/" + "0.data.latest"; - string oldMetaFilePath = g_rankId + "/" + tbName + "/" + "0.meta.latest"; + string oldDataFilePath = + savePath.front() + "ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.data.latest"; + string oldMetaFilePath = + savePath.front() + "ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.meta.latest"; ASSERT_EQ(fs::exists(oldDataFilePath), false); ASSERT_EQ(fs::exists(oldMetaFilePath), false); // check saved data existence - string newDataFilePath = savePath.front() + "/" + g_rankId + "/" + tbName + "/" + "1.data." + to_string(saveStep); - string newMetaFilePath = savePath.front() + "/" + g_rankId + "/" + tbName + "/" + "1.meta." + to_string(saveStep); + string newDataFilePath = + savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "1.data." + to_string(saveStep); + string newMetaFilePath = + savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "1.meta." + to_string(saveStep); string newTableMetaFilePath = - savePath.front() + "/" + g_rankId + "/" + tbName + "/" + tbName + ".meta." + to_string(saveStep); + savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + tbName + ".meta." 
+ + to_string(saveStep); ASSERT_EQ(fs::exists(newDataFilePath), true); ASSERT_EQ(fs::exists(newMetaFilePath), true); ASSERT_EQ(fs::exists(newTableMetaFilePath), true); - for (string p: savePath) { - fs::remove_all(p); + for (const string& p: savePath) { + fs::remove_all(p + "/ssd_sparse_model_rank_" + g_rankId); } } diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index af6f58a1..48c6a6d1 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -80,8 +80,10 @@ TEST(Table, WriteAndReadAndDeleteAndCompact) // full compact, old file will delete, valid data will move to new file tb->Compact(true); - string oldDataFilePath = g_rankId + "/" + tbName + "/" + "0.data.latest"; - string oldMetaFilePath = g_rankId + "/" + tbName + "/" + "0.meta.latest"; + string oldDataFilePath = + savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.data.latest"; + string oldMetaFilePath = + savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.meta.latest"; ASSERT_EQ(fs::exists(oldDataFilePath), false); ASSERT_EQ(fs::exists(oldMetaFilePath), false); -- Gitee From fb5a098b4d580a908c28223dd6b403d7f90ecf65 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 16:13:07 +0800 Subject: [PATCH 326/551] Match-id-2d85b547a2bd866d683d6bf347d4eaa3adfd35d8 --- src/core/emb_hashmap/emb_hashmap.cpp | 20 ++------ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 72 +++++++++++++++++++++++----- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- 3 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index ba778464..2c79ed62 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -244,28 +244,18 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map if (checkResult == 1) { // 回退一步 for (auto& temp: embHashMapsOld) { - auto &embTableName = temp.first; auto &embHashMap = temp.second; - vector hbm2DdrKeys; - vector ddr2HbmKeys; for (auto &swapKeys: embHashMap.oldSwap) { emb_key_t oldKey = swapKeys.first; emb_key_t key = swapKeys.second; int tempOffset = static_cast(embHashMap.hostHashMap[key]); embHashMap.hostHashMap[key] = embHashMap.hostHashMap[oldKey]; embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); - hbm2DdrKeys.emplace_back(key); - ddr2HbmKeys.emplace_back(oldKey); } embHashMap.maxOffset = embHashMap.maxOffsetOld; for (auto &Offset2Key: embHashMap.devOffset2KeyOld) { embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; } - if (isSSDEnabled) { - // 恢复CacheManager中频次数据 - cacheManager->RefreshFreqInfoCommon(embTableName, hbm2DdrKeys, TransferType::HBM_2_DDR); - cacheManager->RefreshFreqInfoCommon(embTableName, ddr2HbmKeys, TransferType::DDR_2_HBM); - } } return embHashMapsOld; } @@ -605,15 +595,15 @@ void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHa std::string ddrKeysString = VectorToString(ddrKeys); std::string lfuKeysString = VectorToString(lfuKeys); if (ddrKeysString != lfuKeysString) { - LOG(ERROR) << "ERROR STRING not equal, ddrKeysString:" << ddrKeysString << ", lfuKeysString:" << lfuKeysString; + LOG(ERROR) << "swap HBM with DDR step error, key string not equal, ddrKeysString:" << ddrKeysString + << ", lfuKeysString:" << lfuKeysString; } else { - LOG(INFO) << "After HBM swap with DDR, table:" << embTableName << + LOG(INFO) << "swap HBM with DDR step OK, table:" << embTableName << ", ddrKeysString is equals with lfuKeysString, string 
length:" << lfuKeysString.length(); } - LOG(INFO) - << "After HBM swap with DDR, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << - ", tableKeyInLfu:" << lfu.keyTable.size(); + LOG(INFO) << "swap HBM with DDR step end, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr + << ", tableKeyInLfu:" << lfu.keyTable.size(); } /// 记录key频次数据 diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2c8f1e80..28262f8c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -170,16 +170,18 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, } // 比较hostHashMap和cacheManager的数据是否一致 -void HybridMgmt::AddCacheManagerTraceLog(absl::flat_hash_map, EmbHashMapInfo>& embHashMaps) +void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) { - if (!isSSDEnabled || !VLOG_IS_ON(GLOG_TRACE)) { + if (!VLOG_IS_ON(GLOG_TRACE)) { return; } + auto& embHashMaps = saveData.embHashMaps; + auto& ddrKeyFreqMap = saveData.ddrKeyFreqMaps; for (auto& it : embHashMaps) { string embTableName = it.first; auto& hostMap = it.second.hostHashMap; auto& devSize = it.second.devVocabSize; - auto& lfu = cacheManager->ddrKeyFreqMap[embTableName]; + auto& lfu = ddrKeyFreqMap[embTableName]; size_t tableKeyInDdr = 0; for (const auto& item : hostMap) { if (item.second < devSize) { @@ -187,14 +189,61 @@ void HybridMgmt::AddCacheManagerTraceLog(absl::flat_hash_map, } ++tableKeyInDdr; auto cuKey = item.first; - auto lfuKeyCount = lfu.Get(cuKey); - if (lfuKeyCount == -1) { - LOG(ERROR) << "ERROR, SAVE Step, ddr key:" << cuKey << ", lfu count by key:" << - lfuKeyCount << ", hostHashMap offset:" << item.second; + if (lfu.find(cuKey) == lfu.end()) { + LOG(ERROR) << "save step error, ddr key:" << cuKey << ", not exist in lfu, hostHashMap offset:" + << item.second; } } - LOG(INFO) << "SAVE Step, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << - ", tableKeyInLfu:" << lfu.keyTable.size(); + LOG(INFO) << "save step end, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << + ", tableKeyInLfu:" << lfu.size(); + } +} + +/// 保存CacheManager时恢复数据(与恢复hostHashMap类似,仅恢复保存数据,不修改源数据) +/// \param saveData 保存数据 +void HybridMgmt::RestoreFreq4Save(CkptData& saveData) +{ + // 仅在差异1步时执行恢复操作 + int checkResult = hybridMgmtBlock->CheckSaveEmbdMapValid(); + if (checkResult != 1) { + return; + } + auto& ddrKeyFreqMaps = saveData.ddrKeyFreqMaps; + auto& excludeDDRKeyFreqMaps = saveData.excludeDDRKeyFreqMaps; + + for (const auto& it : saveData.embHashMaps) { + auto& embTableName = it.first; + auto& embHashMap = it.second; + vector hbm2DdrKeys; + vector ddr2HbmKeys; + LOG(INFO) << "restore freq info for save step, table:" << embTableName << ", embHashMap.oldSwap size:" + << embHashMap.oldSwap.size(); + LOG(INFO) << "before, ddr key table size:" << ddrKeyFreqMaps[embTableName].size() + << ", exclude ddr key table size:" << excludeDDRKeyFreqMaps[embTableName].size(); + for (const auto& swapKeys : embHashMap.oldSwap) { + hbm2DdrKeys.emplace_back(swapKeys.second); + ddr2HbmKeys.emplace_back(swapKeys.first); + } + int hbm2DdrKeysNotInExcludeMapCount = 0; + int ddr2HbmKeysNotInDDRMapCount = 0; + for (auto& key : hbm2DdrKeys) { + if (excludeDDRKeyFreqMaps[embTableName].find(key) == excludeDDRKeyFreqMaps[embTableName].end()) { + ++hbm2DdrKeysNotInExcludeMapCount; + } + ddrKeyFreqMaps[embTableName][key] = excludeDDRKeyFreqMaps[embTableName][key]; + excludeDDRKeyFreqMaps[embTableName].erase(key); + } + for (auto& key : ddr2HbmKeys) 
{ + if (ddrKeyFreqMaps[embTableName].find(key) == ddrKeyFreqMaps[embTableName].end()) { + ++ddr2HbmKeysNotInDDRMapCount; + } + excludeDDRKeyFreqMaps[embTableName][key] = ddrKeyFreqMaps[embTableName][key]; + ddrKeyFreqMaps[embTableName].erase(key); + } + LOG(INFO) << "hbm2DdrKeysNotInExcludeMapCount:" << hbm2DdrKeysNotInExcludeMapCount + << ", ddr2HbmKeysNotInDDRMapCount:" << ddr2HbmKeysNotInDDRMapCount; + LOG(INFO) << "after, ddr key table size:" << ddrKeyFreqMaps[embTableName].size() + << ", exclude ddr key table size:" << excludeDDRKeyFreqMaps[embTableName].size(); } } @@ -214,7 +263,6 @@ bool HybridMgmt::Save(const string savePath) VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: ddr mode hashmap"); saveData.hostEmbs = hostEmbs->GetHostEmbs(); saveData.embHashMaps = hostHashMaps->GetHashMaps(); - AddCacheManagerTraceLog(saveData.embHashMaps); } else { // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: no ddr mode hashmap"); @@ -227,6 +275,8 @@ bool HybridMgmt::Save(const string savePath) saveData.ddrKeyFreqMaps[it.first] = it.second.GetFreqTable(); } saveData.excludeDDRKeyFreqMaps = cacheManager->excludeDDRKeyCountMap; + RestoreFreq4Save(saveData); + AddCacheManagerTraceLog(saveData); auto step = GetStepFromPath(savePath); cacheManager->SaveSSDEngine(step); } @@ -787,7 +837,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, VLOG(GLOG_DEBUG) << StringFormat( "getAndSendTensorsTC(ms):%d, channelId:%d", getAndSendTensorsTC.ElapsedMS(), channelId); - if (embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch + if (!isSSDEnabled && embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch LOG(WARNING) << StringFormat(MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d", embName.c_str(), channelId, batchId, iBatch, lookupKeys.size()); return false; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index df5d60de..f7d24c80 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -126,7 +126,8 @@ namespace MxRec { void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, const vector& keys, int channelId); int GetStepFromPath(const string& loadPath); - void AddCacheManagerTraceLog(absl::flat_hash_map, EmbHashMapInfo>& embHashMaps); + static void AddCacheManagerTraceLog(CkptData& saveData); + void RestoreFreq4Save(CkptData& saveData); private: int currentBatchId; -- Gitee From 7cee34c8b3c1626c90a0e19d343f00c19f6452b7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 16:35:11 +0800 Subject: [PATCH 327/551] Match-id-c8f1c3a1aaa95e2d4fd87b06c7db82a92e92cb73 --- mx_rec/core/asc/feature_spec.py | 8 ++------ mx_rec/core/embedding.py | 16 +--------------- mx_rec/util/initialize.py | 2 ++ mx_rec/util/normalization.py | 24 ++++++++++++++++++++++++ 4 files changed, 29 insertions(+), 21 deletions(-) create mode 100644 mx_rec/util/normalization.py diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 50eede22..4b29a7fe 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -11,6 +11,7 @@ import tensorflow as tf from mx_rec.util.atomic import AtomicInteger from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static +from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.constants.constants import MAX_INT32 feature_spec_global_id = AtomicInteger() @@ -27,7 +28,7 @@ class 
FeatureSpec: spec_name = name + f"_{feature_spec_global_id}" self.name = spec_name self._index_key = kwargs.get("index_key") if kwargs.get("index_key") else name - self._table_name = kwargs.get("table_name") if kwargs.get("table_name") else name + self._table_name = fix_invalid_table_name(kwargs.get("table_name") if kwargs.get("table_name") else name) self._feat_cnt = kwargs.get("feat_count") self._access_threshold = kwargs.get("access_threshold") self._eviction_threshold = kwargs.get("eviction_threshold") @@ -42,7 +43,6 @@ class FeatureSpec: self.initialized = False self._pipeline_mode = set() - self.fix_invalid_table_name() self.check_params() @property @@ -133,10 +133,6 @@ class FeatureSpec: if self._is_timestamp is not None: check_bool(self._is_timestamp, "is_timestamp") - def fix_invalid_table_name(self): - if not re.match("^[0-9A-Za-z_]+$", self._table_name): - self._table_name = re.sub(r'\W+', '_', self._table_name) - def set_feat_pos(self, is_training): if is_training: self.feat_pos_train = FeatureSpec.instance_count_train diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 25784d0d..db7f3076 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -27,6 +27,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set from mx_rec.validator.validator import ClassValidator, StringValidator from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.util.normalization import fix_invalid_table_name def check_ssd_relate_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_path): @@ -904,18 +905,3 @@ def check_create_table_params(key_dtype, dim, name, emb_initializer): emb_initializer_validator = ClassValidator(value=emb_initializer, classes=(InitializerV1, InitializerV2)) emb_initializer_validator.check_isinstance() emb_initializer_validator.check() - - -def fix_invalid_table_name(name): - """ - 校验table name字符串中是否含有特殊字符,如有,替换为下划线 - :param name: table name - :return : the fixed table name - """ - if re.match("^[0-9A-Za-z_]+$", name): - return name - fix_name = re.sub(r'\W+', '_', name) - logging.warning(f"The table name {name} contains invalid characters." - f"The system automatically replaces invalid characters with underscores (_). " - f"The table name was changed to {fix_name}") - return fix_name diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 8b9e984d..e9804c43 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -63,6 +63,8 @@ class ConfigInitializer: self._comm = MPI.COMM_WORLD self._rank_id = self._comm.Get_rank() self._rank_size = self._comm.Get_size() + else: + raise ValueError("only mpi is supported for launching task.") self._rank_to_device_dict = parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else set_hccl_info_without_json() self.train_steps = kwargs.get("train_steps", -1) diff --git a/mx_rec/util/normalization.py b/mx_rec/util/normalization.py new file mode 100644 index 00000000..1a515954 --- /dev/null +++ b/mx_rec/util/normalization.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
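+# Illustrative examples (assuming the behaviour implemented below): +#   fix_invalid_table_name("user_click")   -> "user_click"  (already valid, returned unchanged) +#   fix_invalid_table_name("user-click#3") -> "userclick3"  (characters outside [0-9A-Za-z_] removed) +#   fix_invalid_table_name("###")          -> raises ValueError (no valid character remains)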
+import re +import logging + + +def fix_invalid_table_name(name): + """ + Check whether the table name contains characters outside [0-9A-Za-z_]; if it does, remove them. + :param name: table name + :return : the fixed table name + """ + pattern = "^[0-9A-Za-z_]+$" + if re.match(pattern, name): + return name + fix_name = re.sub(r'\W+', '', name) + if not fix_name: + raise ValueError(f"The table name '{name}' doesn't contain any valid character, " + f"according to the rule '{pattern}'") + logging.warning(f"The table name '%s' contains invalid characters. " + f"The system automatically removes invalid characters. " + f"The table name was changed to '%s'", name, fix_name) + return fix_name -- Gitee From 380a028fc8daa6264d9b65e2e76c9ca88c3cf122d5dc36110 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 21:18:30 +0800 Subject: [PATCH 328/551] Match-id-e4dabac6b3969d5b6feea4f4605a0a3b567be60c --- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 910f8176..ea95aec9 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -138,7 +138,7 @@ void HybridMgmtBlock::ResetAll(int channelId) readEmbedBatchId[channelId] = 0; pythonBatchId[channelId] = 0; hybridBatchId[channelId] = 0; - isBlock[channelId] = true; + isBlock[channelId] = false; // eval train通道的sparse 同时进行重置,以防出现sparse id失效的问题 uniqueSparseLookID[EVAL_CHANNEL_ID] = -1; uniqueSparseLookID[TRAIN_CHANNEL_ID] = -1; -- Gitee From 083512f736321c71d9ca7cd120797befa301c176 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Sep 2023 09:52:48 +0800 Subject: [PATCH 329/551] Match-id-4d9335cc6d16e3bca37ffae46257399355cc185c --- src/core/emb_hashmap/emb_hashmap.cpp | 88 +++---- .../key_process/feature_admit_and_evict.cpp | 30 +-- src/core/key_process/key_process.cpp | 237 ++++++++---------- src/core/ssd_cache/cache_manager.cpp | 52 ++-- src/core/ssd_engine/ssd_engine.cpp | 10 +- src/core/ssd_engine/table.cpp | 20 +- src/core/utils/common.cpp | 44 ++-- src/core/utils/common.h | 4 +- src/core/utils/log.cpp | 56 +++++ src/core/utils/log.h | 125 +++++++++ .../feature_admit_and_evict_test.cpp | 18 +- src/tests/key_process/key_process_test.cpp | 95 +++---- src/tests/ssd_cache/cache_manager_test.cpp | 10 +- src/tests/ssd_engine/table_test.cpp | 10 +- src/tests/utils/log_test.cpp | 135 ++++++++++ 15 files changed, 574 insertions(+), 360 deletions(-) create mode 100644 src/core/utils/log.cpp create mode 100644 src/core/utils/log.h create mode 100644 src/tests/utils/log_test.cpp diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 2c79ed62..3566f621 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -73,7 +73,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t auto keepBatch = swapId - iBatch; // 处理batch的次数,多个预取一起处理算一次 bool findOffsetV2 = GetEnv("FIND_OFFSET_V2"); - VLOG(GLOG_DEBUG) << StringFormat("FindOffset version:%d", findOffsetV2); + LOG_DEBUG("FindOffset version:{}", findOffsetV2); // 找到所有key的偏移;dev和host需要交换的位置 if (findOffsetV2) { @@ -81,7 +81,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t } else { FindOffset(embName, keys, swapId, keepBatch, channelId); } - VLOG(GLOG_DEBUG) << "FindOffset end"; + LOG_DEBUG("FindOffset end"); // 调用刷新频次数据方法 RefreshFreqInfoWithSwap(embName, embHashMap); @@ -98,9 +98,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t for (int i = 0; 
i < lookUpVecSize; i++) { lookupTensorData(i) = static_cast(embHashMap.lookUpVec[i]); } - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat("lookupTensor, %s", VectorToString(embHashMap.lookUpVec).c_str()); - } + LOG_TRACE("lookupTensor, {}", VectorToString(embHashMap.lookUpVec)); // 构造交换向量tensor auto swapSize = static_cast(embHashMap.swapPos.size()); @@ -111,21 +109,19 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t swapTensorData(i) = static_cast(embHashMap.swapPos[i]); } if (swapSize > 0) { - VLOG(GLOG_DEBUG) << StringFormat("swap num: %d", swapSize); - } - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat("swapTensor, %s", VectorToString(embHashMap.swapPos).c_str()); + LOG_DEBUG("swap num: {}", swapSize); } + LOG_TRACE("swapTensor, {}", VectorToString(embHashMap.swapPos)); // 清空本次记录的查询偏移和交换偏移 ClearLookupAndSwapOffset(embHashMap); - LOG(INFO) << StringFormat("current ddr emb:%s, usage:%d/[%d+%d]", embName.c_str(), embHashMap.maxOffset, - embHashMap.devVocabSize, embHashMap.hostVocabSize); + LOG_INFO("current ddr emb:{}, usage:{}/[{}+{}]", embName, embHashMap.maxOffset, + embHashMap.devVocabSize, embHashMap.hostVocabSize); ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = ddrParam.tmpDataOut.back().flat(); swapLen(0) = swapSize; if (g_statOn) { - LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d swap_key_size %d swap_time_cost %d", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} swap_key_size {} swap_time_cost {}", channelId, swapId, rankInfo.rankId, swapSize, swapTimeCost.ElapsedMS()); } @@ -176,15 +172,13 @@ int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashM } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 offset = static_cast(embHashMap.evictDevPos.back()); embHashMap.hostHashMap[key] = offset; - VLOG(GLOG_TRACE) << StringFormat( - "ddr mode, dev evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictDevPos.size()); embHashMap.evictDevPos.pop_back(); } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 offset = static_cast(embHashMap.evictPos.back()); embHashMap.hostHashMap[key] = offset; - VLOG(GLOG_TRACE) << StringFormat( - "ddr mode, host evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { @@ -192,10 +186,10 @@ int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashM offset = static_cast(embHashMap.maxOffset); embHashMap.maxOffset++; if (embHashMap.maxOffset == embHashMap.devVocabSize) { - LOG(INFO) << "start using host vocab!"; + LOG_INFO("start using host vocab!"); } if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - LOG(ERROR) << StringFormat("hostVocabSize too small! dev:%d host:%d", embHashMap.devVocabSize, + LOG_ERROR("hostVocabSize too small! 
dev:{} host:{}", embHashMap.devVocabSize, embHashMap.hostVocabSize); throw runtime_error("hostVocabSize too small"); } @@ -233,7 +227,7 @@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { - VLOG(GLOG_DEBUG) << (HYBRID_BLOCKING + " start GetHashMaps"); + LOG_DEBUG(HYBRID_BLOCKING + " start GetHashMaps"); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); auto embHashMapsOld = embHashMaps; int checkResult = hybridMgmtBlock->CheckSaveEmbdMapValid(); @@ -300,14 +294,14 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& size_t offset; auto key = keys[i]; if (key == -1) { - LOG(WARNING) << "evict key equal -1!"; + LOG_WARN("evict key equal -1!"); continue; } const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; embHashMap.hostHashMap.erase(iter); - VLOG(GLOG_TRACE) << StringFormat("evict embName %s , offset , %d", embName.c_str(), offset); + LOG_TRACE("evict embName %s, offset %d", embName, offset); } else { // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 continue; @@ -329,13 +323,9 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& cacheManager->RefreshFreqInfoCommon(embName, evictDDRKeys, TransferType::DDR_2_EVICT); } - LOG(INFO) << StringFormat( - "ddr EvictDeleteEmb, emb: [%s], hostEvictSize: %d, devEvictSize: %d ", - embName.c_str(), embHashMap.evictPos.size(), embHashMap.evictDevPos.size() - ); - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat("hostHashMap, %s", MapToString(embHashMaps[embName].hostHashMap).c_str()); - } + LOG_INFO("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {}", + embName, embHashMap.evictPos.size(), embHashMap.evictDevPos.size()); + LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap)); } /// 从embHashMaps获取key对应的位置,构造查询向量;更新devOffset2Batch;记录dev与host需要交换的偏移 @@ -378,11 +368,9 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys } } if (currentBatchId == 0) { - LOG(INFO) << StringFormat("max offset %d", embHashMap.maxOffset); - } - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat("hostHashMap, %s", MapToString(embHashMaps[embName].hostHashMap).c_str()); + LOG_INFO("max offset {}", embHashMap.maxOffset); } + LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap)); } @@ -398,20 +386,17 @@ bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashM const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; - VLOG(GLOG_TRACE) << StringFormat("devVocabSize, %d , offset , %d", embHashMap.devVocabSize, offset); + LOG_TRACE("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset); } else if (embHashMap.evictDevPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // 优先复用hbm表 offset = embHashMap.evictDevPos.back(); embHashMap.hostHashMap[key] = offset; - VLOG(GLOG_TRACE) << StringFormat( - "ddr mode, dev evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", - key, offset, embHashMap.evictDevPos.size() - ); + LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, embHashMap.evictDevPos.size()); embHashMap.evictDevPos.pop_back(); } else if (embHashMap.evictPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // hbm不足,再复用ddr表 offset = embHashMap.evictPos.back(); embHashMap.hostHashMap[key] = offset; - 
VLOG(GLOG_TRACE) << StringFormat( - "ddr mode, host evictPos is not null, key [%d] reuse offset [%d], evictSize [%d]", + LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", key, offset, embHashMap.evictPos.size()); embHashMap.evictPos.pop_back(); } else { @@ -420,11 +405,10 @@ bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashM offset = embHashMap.maxOffset; embHashMap.maxOffset++; if (embHashMap.maxOffset == embHashMap.devVocabSize) { - LOG(INFO) << ("start using host vocab!"); + LOG_INFO("start using host vocab!"); } if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - LOG(ERROR) << StringFormat( - "hostVocabSize too small! dev:%d host:%d", embHashMap.devVocabSize, embHashMap.hostVocabSize); + LOG_ERROR("hostVocabSize too small! dev:{} host:{}", embHashMap.devVocabSize, embHashMap.hostVocabSize); throw runtime_error("hostVocabSize too small"); } } else { @@ -452,7 +436,7 @@ void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatc if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; - VLOG(GLOG_TRACE) << StringFormat("key will be used, %d , offset , %d", key, offset); + LOG_TRACE("key will be used, {} , offset , {}", key, offset); if (offset < embHashMap.devVocabSize) { embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); } @@ -489,7 +473,7 @@ int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostO embHashMap.currentUpdatePos = 0; } if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - LOG(ERROR) << "devVocabSize is too small"; + LOG_ERROR("devVocabSize is too small"); throw runtime_error("devVocabSize is too small"); } } @@ -534,7 +518,7 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos // 已经找完了整个HBM空间,没有找到可用位置,表示HBM空间不足以放下整个batch(预取batch数)的key,无法正常执行训练,固运行时错误退出 if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - LOG(ERROR) << "devVocabSize is too small"; + LOG_ERROR("devVocabSize is too small"); throw runtime_error("devVocabSize is too small"); } } @@ -551,7 +535,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& } // 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key auto& oldSwap = embHashMap.oldSwap; - VLOG(GLOG_DEBUG) << StringFormat("RefreshFreqInfoWithSwap:oldSwap Size:%lld", oldSwap.size()); + LOG_DEBUG("RefreshFreqInfoWithSwap:oldSwap Size:{}", oldSwap.size()); vector enterDDRKeys; vector leaveDDRKeys; for (auto keyPair : oldSwap) { @@ -595,15 +579,15 @@ void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHa std::string ddrKeysString = VectorToString(ddrKeys); std::string lfuKeysString = VectorToString(lfuKeys); if (ddrKeysString != lfuKeysString) { - LOG(ERROR) << "swap HBM with DDR step error, key string not equal, ddrKeysString:" << ddrKeysString - << ", lfuKeysString:" << lfuKeysString; + LOG_ERROR("swap HBM with DDR step error, key string not equal, ddrKeysString:{}, lfuKeysString:{}", + ddrKeysString, lfuKeysString); } else { - LOG(INFO) << "swap HBM with DDR step OK, table:" << embTableName << - ", ddrKeysString is equals with lfuKeysString, string length:" << lfuKeysString.length(); + LOG_INFO("swap HBM with DDR step OK, table:{}, ddrKeysString == lfuKeysString, string length:{}", + embTableName, lfuKeysString.length()); } - LOG(INFO) << "swap HBM with DDR step end, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr - << ", tableKeyInLfu:" << 
lfu.keyTable.size(); + LOG_INFO("swap HBM with DDR step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", + embTableName, tableKeyInDdr, lfu.keyTable.size()); } /// 记录key频次数据 diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 80db3003..f521ff5e 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -179,12 +179,10 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v } if (evictKey.size() == 0) { - LOG(INFO) << StringFormat( - "table-name[%s]'s lastTime[%d], had no key to delete ...", embName.c_str(), currTime); + LOG_INFO("table-name[{}]'s lastTime[{}], had no key to delete ...", embName, currTime); return; } - LOG(INFO) << StringFormat( - "table-name[%s]'s lastTime[%d], had size[%d] keys to delete ...", embName.c_str(), currTime, evictKey.size()); + LOG_INFO("table-name[{}]'s lastTime[{}], had size[{}] keys to delete ...", embName, currTime, evictKey.size()); // 真正从 m_historyRecords 中淘汰 absl::flat_hash_map& historyRecords = m_recordsData.historyRecords[embName]; @@ -200,9 +198,9 @@ void FeatureAdmitAndEvict::FeatureEvictHelper(const std::string& embName, std::v void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) { if (isEnableEvict) { - LOG(INFO) << "feature admit-and-evict switch is opened ..."; + LOG_INFO("feature admit-and-evict switch is opened ..."); } else { - LOG(INFO) << "feature admit-and-evict switch is closed ..."; + LOG_INFO("feature admit-and-evict switch is closed ..."); } m_isEnableFunction = isEnableEvict; } @@ -240,16 +238,15 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t for (size_t i = 0; i < thresholds.size(); ++i) { auto it = std::find(embNames.begin(), embNames.end(), thresholds[i].tableName); if (it == embNames.end()) { // 配置不存在于当前跑的模型,也要报错 - LOG(ERROR) << StringFormat( - "embName[%s] is not exist at current model ...", thresholds[i].tableName.c_str()); + LOG_ERROR("embName[{}] is not exist at current model ...", thresholds[i].tableName); return false; } else { // 同时支持“准入&淘汰”,却没有传时间戳 if (m_embStatus[*it] == SingleEmbTableStatus::SETS_ERROR) { - LOG(ERROR) << StringFormat("embName[%s] config error, please check ...", embNames[i].c_str()); + LOG_ERROR("embName[{}] config error, please check ...", embNames[i]); return false; } else if (m_embStatus[*it] == SingleEmbTableStatus::SETS_BOTH && !isTimestamp) { - LOG(ERROR) << StringFormat("embName[%s] admit and evict, but no timestamp", embNames[i].c_str()); + LOG_ERROR("embName[{}] admit and evict, but no timestamp", embNames[i]); return false; } } @@ -286,32 +283,31 @@ void FeatureAdmitAndEvict::LoadHistoryRecords(AdmitAndEvictData& loadData) bool FeatureAdmitAndEvict::ParseThresholdCfg(const std::vector& thresholdValues) { if (thresholdValues.empty()) { - LOG(ERROR) << "thresholdValues is empty ..."; + LOG_ERROR("thresholdValues is empty ..."); return false; } m_cfgThresholds = thresholdValues; for (const auto& value : thresholdValues) { - LOG(INFO) << StringFormat( - "embName[%s], count[%d], time[%d], coefficient[%d] ...", - value.tableName.c_str(), value.countThreshold, value.timeThreshold, value.faaeCoefficient); + LOG_INFO("embName[{}], count[{}], time[{}], coefficient[{}] ...", + value.tableName, value.countThreshold, value.timeThreshold, value.faaeCoefficient); auto it = m_table2Threshold.find(value.tableName); if (it != m_table2Threshold.end()) { // train和eval同时开启,会出现表重复配置 - LOG(INFO) << StringFormat("[%s] is repeated 
configuration ...", value.tableName.c_str()); + LOG_INFO("[{}] is repeated configuration ...", value.tableName); return true; } m_table2Threshold[value.tableName] = value; if (value.faaeCoefficient < 1) { - LOG(ERROR) << StringFormat("[%s] config error, coefficient smaller than 1 ...", value.tableName.c_str()); + LOG_ERROR("[{}] config error, coefficient smaller than 1 ...", value.tableName); } if (value.countThreshold != -1 && value.timeThreshold != -1) { m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_BOTH; } else if (value.countThreshold != -1 && value.timeThreshold == -1) { m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ONLY_ADMIT; } else { - LOG(ERROR) << StringFormat("[%s] config error, have evict but no admit ...", value.tableName.c_str()); + LOG_ERROR("[{}] config error, have evict but no admit ...", value.tableName); m_embStatus[value.tableName] = SingleEmbTableStatus::SETS_ERROR; return false; } diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 46e6497d..10e2889a 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -37,13 +37,13 @@ void KeyProcess::SetupHotEmbUpdateStep() int tmp = std::stoi(envUpdateStep); if (tmp >= 1 && tmp <= maxUpdateStep) { this->hotEmbUpdateStep = tmp; - LOG(INFO) << StringFormat("Succeed to parse ${env:HOT_EMB_UPDATE_STEP}: %d.", this->hotEmbUpdateStep); + LOG_INFO("Succeed to parse ${env:HOT_EMB_UPDATE_STEP}: {}.", this->hotEmbUpdateStep); } else { - LOG(ERROR) << StringFormat("${env:HOT_EMB_UPDATE_STEP}: %d should be in [1, 1000], set default: %d.", + LOG_ERROR("${env:HOT_EMB_UPDATE_STEP}: {} should be in [1, 1000], set default: {}.", tmp, HOT_EMB_UPDATE_STEP_DEFAULT); } } catch (const std::invalid_argument &e) { - LOG(ERROR) << StringFormat("Failed to parse ${env:HOT_EMB_UPDATE_STEP}: %s, set default: %d.", + LOG_ERROR("Failed to parse ${env:HOT_EMB_UPDATE_STEP}: {}, set default: {}.", envUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); } } @@ -68,11 +68,11 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos if (rankInfo.useDynamicExpansion) { // 动态扩容 embeddingTableMap[info.name].Init(info, rInfo, seed); - LOG(INFO) << StringFormat(KEY_PROCESS "EmbeddingTableMap:%s init success", info.name.c_str()); + LOG_INFO(KEY_PROCESS "EmbeddingTableMap:{} init success", info.name); } } - REC_LOG(INFO) << StringFormat(KEY_PROCESS "hot emb count info:%s", MapToString(hotEmbTotCount).c_str()); + LOG_INFO(KEY_PROCESS "hot emb count info:{}", MapToString(hotEmbTotCount)); MPI_Group worldGroup; MPI_Comm_group(MPI_COMM_WORLD, &worldGroup); for (auto& i: comm) { @@ -88,17 +88,15 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos m_featureAdmitAndEvict.Init(thresholdValues); } else { m_featureAdmitAndEvict.SetFunctionSwitch(false); - LOG(WARNING) << KEY_PROCESS "Feature admit-and-evict function is unavailable ..."; + LOG_WARN(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); } if (PerfConfig::fastUnique) { Factory::Create(factory); } - REC_LOG(INFO) << StringFormat( - KEY_PROCESS "scInfo:%s, localRankSize:%d, rankSize:%d, useStatic:%d, useHot:%d", - MapToString(scInfo).c_str(), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot - ); + LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", + MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); return true; } @@ -109,12 +107,12 @@ int KeyProcess::Start() // 0 1 2 3 4 5 0 1 2 
3 4 5 // | rank0 | | rank1 | // each rank creates KEY_PROCESS_THREAD threads, each thread process one batchdata - LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int threadId) { #ifndef GTEST auto ret = aclrtSetDevice(static_cast(rankInfo.deviceId)); if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << StringFormat("Set device failed, device_id:%d", rankInfo.deviceId); + LOG_ERROR("Set device failed, device_id:{}", rankInfo.deviceId); return; } #endif @@ -126,10 +124,10 @@ int KeyProcess::Start() }; // for clean code int threadNum = GetThreadNumEnv(); for (int channel = 0; channel < MAX_CHANNEL_NUM; ++channel) { - LOG(INFO) << StringFormat(KEY_PROCESS "key process thread num: %d", threadNum); + LOG_INFO(KEY_PROCESS "key process thread num: {}", threadNum); for (int id = 0; id < threadNum; ++id) { - procThreads.emplace_back( - std::make_unique(fn, channel, id)); // use lambda expression initialize thread + // use lambda expression initialize thread + procThreads.emplace_back(std::make_unique(fn, channel, id)); } } return 0; @@ -176,12 +174,12 @@ void KeyProcess::LoadKeyOffsetMap(key_offset_mem_t& loadData) void KeyProcess::Destroy() { isRunning = false; - LOG(INFO) << StringFormat(KEY_PROCESS "rank %d begin destroy.", rankInfo.rankId); + LOG_INFO(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); for (auto& i: procThreads) { i->join(); } procThreads.clear(); - LOG(INFO) << StringFormat(KEY_PROCESS "rank %d destroy success.", rankInfo.rankId); + LOG_INFO(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); } /// 每个数据通道的所有数据处理线程上锁 @@ -243,7 +241,7 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, auto ret = unique->Initialize(uniqueConf); if (ret != ock::ctr::H_OK) { - throw runtime_error(StringFormat("fast unique init failed, code:%d", ret)); + throw runtime_error(Log::Format("fast unique init failed, code:{}", ret)); } uniqueInitialize = true; } @@ -259,7 +257,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) auto ret = factory->CreateUnique(unique); if (ret != ock::ctr::H_OK) { - throw runtime_error(StringFormat("create fast unique failed, error code:%d", ret)); + throw runtime_error(Log::Format("create fast unique failed, error code:{}", ret)); } GetUniqueConfig(uniqueConf); @@ -268,7 +266,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) TimeCost getAndProcessTC; TimeCost getBatchDataTC; batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue - VLOG(GLOG_DEBUG) << StringFormat("getBatchDataTC(ms):%d", getBatchDataTC.ElapsedMS()); + LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; } @@ -279,23 +277,19 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) { break; } - LOG(INFO) << StringFormat( - KEY_PROCESS "getAndProcessTC(ms):%d, key process with fast unique cost:%d," - " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", + LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process with fast unique cost:{}," + " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name.c_str(), batch->channel, batch->batchId - ); + batch->name, batch->channel, batch->batchId); auto batchQueue = 
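// descriptive note on the index used just below: queues are laid out one per (channel, thread) pair, i.e. index = threadId + KEY_PROCESS_THREAD * channel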
SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } unique->UnInitialize(); } catch (const EndRunError &e) { - VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); + LOG_ERROR(KEY_PROCESS "abort run: {}", e.what()); } - - LOG(INFO) << StringFormat( - KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:%d thread:%d, channel:%d", - rankInfo.rankId, threadId, channel); + LOG_INFO(KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:{} thread:{}, channel:{}", + rankInfo.rankId, threadId, channel); } @@ -307,7 +301,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) TimeCost getAndProcessTC; TimeCost getBatchDataTC; batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue - VLOG(GLOG_DEBUG) << StringFormat("getBatchDataTC(ms):%d", getBatchDataTC.ElapsedMS()); + LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; } @@ -317,21 +311,17 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) if (!KeyProcessTaskHelper(batch, channel, threadId)) { break; } - LOG(INFO) << StringFormat( - KEY_PROCESS "getAndProcessTC(ms):%d, key process cost:%d," - " get data time(ms):%d, batch name:%s, channel:%d, batchID:%d", + LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," + " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name.c_str(), batch->channel, batch->batchId - ); + batch->name, batch->channel, batch->batchId); auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } } catch (const EndRunError &e) { - VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "abort run: %s", e.what()); + LOG_ERROR(KEY_PROCESS "abort run: {}", e.what()); } - - LOG(INFO) << StringFormat( - KEY_PROCESS "KeyProcessTask exit. rank:%d thread:%d, channel:%d", rankInfo.rankId, threadId, channel); + LOG_INFO(KEY_PROCESS "KeyProcessTask exit. 
rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel); } void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, @@ -349,7 +339,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 } } - VLOG(GLOG_DEBUG) << StringFormat("UniqueTC(ms):%d", UniqueTC.ElapsedMS()); + LOG_DEBUG("UniqueTC(ms):{}", UniqueTC.ElapsedMS()); } bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, @@ -362,15 +352,15 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat TimeCost fastUniqueTC; UniqueInfo uniqueInfo; ProcessBatchWithFastUnique(batch, unique, threadId, uniqueInfo); - VLOG(GLOG_DEBUG) << StringFormat("ProcessBatchWithFastUnique(ms):%d", fastUniqueTC.ElapsedMS()); + LOG_DEBUG("ProcessBatchWithFastUnique(ms):{}", fastUniqueTC.ElapsedMS()); // 特征准入&淘汰 if (isWithFAAE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, uniqueInfo.all2AllInfo.countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { - LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", - rankInfo.rankId, threadId, channel); + LOG_ERROR(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + rankInfo.rankId, threadId, channel); return false; } @@ -379,7 +369,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat if (rankInfo.noDDR) { TimeCost key2OffsetTC; Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); - VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); + LOG_DEBUG("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } if (!rankInfo.useStatic) { // Static all2all,need send count SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); @@ -401,11 +391,10 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat TimeCost pushResultTC; PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv); if (g_statOn) { - LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d " - "key_process_time_cost_with_fast_unique %d", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost_with_fast_unique {}", channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS()); } - VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); + LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); return true; } @@ -432,8 +421,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE && (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, lookupKeys, countRecv) == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { - LOG(ERROR) << StringFormat(KEY_PROCESS "rank:%d thread:%d, channel:%d, Feature-admit-and-evict error ...", - rankInfo.rankId, threadId, channel); + LOG_ERROR(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", + rankInfo.rankId, threadId, channel); return false; } @@ -465,10 +454,9 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe } PushResult(batch, move(tensors), lookupKeys); - VLOG(GLOG_DEBUG) << StringFormat("pushResultTC(ms):%d", pushResultTC.ElapsedMS()); + LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS()); if (g_statOn) { - LOG(INFO) << StringFormat(STAT_INFO "channel_id %d batch_id %d rank_id %d " - 
"key_process_time_cost %d", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost {}", channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS()); } return true; @@ -482,7 +470,7 @@ void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tenso TimeCost globalUniqueSyncTC; GlobalUnique(lookupKeys, uniqueKeys, restoreVecSec); - VLOG(GLOG_DEBUG) << StringFormat("globalUniqueSyncTC(ms):%d", globalUniqueSyncTC.ElapsedMS()); + LOG_DEBUG("globalUniqueSyncTC(ms):{}", globalUniqueSyncTC.ElapsedMS()); tensors->push_back(Vec2TensorI32(restoreVecSec)); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys)); } @@ -514,7 +502,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, countRecv.resize(rs.back() + rc.back()); MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); - VLOG(GLOG_DEBUG) << StringFormat("getCountRecvTC(ms)(with-all2all):%d", getCountRecvTC.ElapsedMS()); + LOG_DEBUG("getCountRecvTC(ms)(with-all2all):{}", getCountRecvTC.ElapsedMS()); return countRecv; } @@ -569,11 +557,8 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) } } EASY_END_BLOCK - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "rank %d thread %d get batch %s[%d]:%d done. bs:%d sample:[%s]", - rankInfo.rankId, commId, - batch->name.c_str(), batch->channel, batch->batchId, batch->Size(), batch->UnParse().c_str() - ); + LOG_DEBUG(KEY_PROCESS "rank {} thread {} get batch {}[{}]:{} done. bs:{} sample:[{}]", + rankInfo.rankId, commId, batch->name, batch->channel, batch->batchId, batch->Size(), batch->UnParse()); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (batch->batchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE @@ -633,7 +618,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch throw runtime_error(StringFormat("fast unique DoEnhancedUnique failed, code:%d", ret)); } EASY_END_BLOCK - VLOG(GLOG_DEBUG) << StringFormat("FastUniqueCompute(ms):%d, ret:%d", uniqueTC.ElapsedMS(), ret); + LOG_DEBUG("FastUniqueCompute(ms):{}, ret:{}", uniqueTC.ElapsedMS(), ret); vector sc; HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize); @@ -669,7 +654,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni TimeCost ComputeHotTc; ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); - VLOG(GLOG_DEBUG) << StringFormat("ComputeHot TimeCost(ms):%d", ComputeHotTc.ElapsedMS()); + LOG_DEBUG("ComputeHot TimeCost(ms):{}", ComputeHotTc.ElapsedMS()); UpdateHotMapForUnique(keySendInfo.keySend, keySendInfo.keyCount, hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); } @@ -713,7 +698,7 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS { TimeCost getScAllTC; GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) - VLOG(GLOG_DEBUG) << StringFormat("GetScAll TimeCost(ms):%d", getScAllTC.ElapsedMS()); + LOG_DEBUG("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS()); TimeCost all2allTC; auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 @@ -733,7 +718,7 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); } - VLOG(GLOG_DEBUG) << StringFormat("all2allTC 
TimeCost(ms):%d", all2allTC.ElapsedMS()); + LOG_DEBUG("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS()); EASY_END_BLOCK } @@ -743,15 +728,14 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, TimeCost processSplitKeysTC; EASY_FUNCTION(profiler::colors::Purple) EASY_VALUE("batchId", batch->batchId) - LOG(INFO) << StringFormat( - KEY_PROCESS "ProcessSplitKeys start batchId:%d, channel:%d", batch->batchId, batch->channel); + LOG_INFO(KEY_PROCESS "ProcessSplitKeys start batchId:{}, channel:{}", batch->batchId, batch->channel); // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 if (rankInfo.useStatic) { // maybe move after all2all for (auto& i: splitKeys) { if (static_cast(i.size()) > embInfos[batch->name].sendCount) { - LOG(ERROR) << StringFormat("%s[%d]:%d overflow! set send count bigger than %d", - batch->name.c_str(), batch->channel, batch->batchId, i.size()); + LOG_ERROR("{}[{}]:{} overflow! set send count bigger than {}", + batch->name, batch->channel, batch->batchId, i.size()); throw runtime_error( StringFormat("%s[%d]:%d overflow! set send count bigger than %d", batch->name.c_str(), batch->channel, batch->batchId, i.size()).c_str()); @@ -769,7 +753,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, TimeCost getScAllTC; auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 - VLOG(GLOG_DEBUG) << StringFormat("getScAllTC(ms)(AllReduce-AllGather):%d", getScAllTC.ElapsedMS()); + LOG_DEBUG("getScAllTC(ms)(AllReduce-AllGather):{}", getScAllTC.ElapsedMS()); auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc; // receive count @@ -779,19 +763,19 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, } auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 keyRecv.resize(rs.back() + rc.back()); - VLOG(GLOG_TRACE) << StringFormat(KEY_PROCESS "MPI_Alltoallv begin. rank %d thread %d batch %d %s", - rankInfo.rankId, id, batch->batchId, batch->name.c_str()); + LOG_TRACE(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}", + rankInfo.rankId, id, batch->batchId, batch->name); EASY_BLOCK("all2all") TimeCost uniqueAll2AllTC; MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); - VLOG(GLOG_DEBUG) << StringFormat("uniqueAll2AllTC(ms):%d", uniqueAll2AllTC.ElapsedMS()); + LOG_DEBUG("uniqueAll2AllTC(ms):{}", uniqueAll2AllTC.ElapsedMS()); EASY_END_BLOCK - VLOG(GLOG_TRACE) << StringFormat(KEY_PROCESS "MPI_Alltoallv finish. rank %d thread %d batch %d %s", - rankInfo.rankId, id, batch->batchId, batch->name.c_str()); - VLOG(GLOG_DEBUG) << StringFormat("processSplitKeysTC(ms):%d", processSplitKeysTC.ElapsedMS()); + LOG_TRACE(KEY_PROCESS "MPI_Alltoallv finish. 
rank {} thread {} batch {} {}", + rankInfo.rankId, id, batch->batchId, batch->name); + LOG_DEBUG("processSplitKeysTC(ms):{}", processSplitKeysTC.ElapsedMS()); return { keyRecv, scAll, ss }; } @@ -823,7 +807,8 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } } EASY_END_BLOCK - if (VLOG_IS_ON(GLOG_TRACE)) { + + LOG_TRACE("dump splitKeys {}", [&] { stringstream ssTrace; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { ssTrace << '|' << devId << ":"; @@ -832,16 +817,15 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } ssTrace << '|'; } - VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); - } + return ssTrace.str(); + }()); if (g_statOn) { size_t UniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { UniqueKeyNum += splitKeys[devId].size(); } - LOG(INFO) << StringFormat( - STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d unique_key_num %ld", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} unique_key_num {}", batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); } return { splitKeys, restore }; @@ -884,7 +868,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } EASY_END_BLOCK - if (VLOG_IS_ON(GLOG_TRACE)) { + LOG_TRACE("dump splitKeys {}", [&] { stringstream ssTrace; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { ssTrace << '|' << devId << ":"; @@ -893,8 +877,8 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } ssTrace << '|'; } - VLOG(GLOG_TRACE) << "dump splitKeys " << ssTrace.str(); - } + return ssTrace.str(); + }()); if (g_statOn) { size_t UniqueKeyNum = 0; @@ -1055,15 +1039,11 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "barrier time:%d", tc.ElapsedMS()); + LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS()); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - if (VLOG_IS_ON(GLOG_DEBUG)) { - VLOG(GLOG_DEBUG) << StringFormat( - "rank %d key scAll matrix:\n%s", rankInfo.rankId, VectorToString(scAll).c_str() - ); - } + LOG_DEBUG("rank {} key scAll matrix:\n{}", rankInfo.rankId, VectorToString(scAll)); return scAll; } @@ -1080,15 +1060,11 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in throw EndRunError("GetScAll end run."); } EASY_END_BLOCK; - VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "barrier time:%d", tc.ElapsedMS()); + LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS()); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); - if (VLOG_IS_ON(GLOG_DEBUG)) { - VLOG(GLOG_DEBUG) << StringFormat( - "rank %d key scAllOut matrix:\n%s", rankInfo.rankId, VectorToString(scAllOut).c_str() - ); - } + LOG_DEBUG("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, VectorToString(scAllOut)); } void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel) @@ -1110,10 +1086,8 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha size_t offset; // 新值, emb有pos可复用 offset = evictPos.back(); - VLOG(GLOG_TRACE) << StringFormat( - "HBM mode, evictPos is not null, name[%s] key [%d] reuse offset [%d], 
evictSize [%d]!!!", - embName.c_str(), key, offset, evictPos.size() - ); + LOG_TRACE("HBM mode, evictPos is not null, name[{}] key [{}] reuse offset [{}], evictSize [{}]!!!", + embName, key, offset, evictPos.size()); key2Offset[key] = offset; key = offset; evictPos.pop_back(); @@ -1131,9 +1105,8 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, embInfos[embName].devVocabSize); throw std::runtime_error("dev cache overflow!"); } - VLOG(GLOG_DEBUG) << StringFormat("current hbm emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, - embInfos[embName].devVocabSize); - VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); + LOG_DEBUG("current hbm emb:{}, usage:{}/{} key2OffsetTC({} ms)", + embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); } void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel) @@ -1166,9 +1139,8 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s key = 0; } } - VLOG(GLOG_DEBUG) << StringFormat("current expansion emb:%s, usage:%d/%d", embName.c_str(), maxOffsetTmp, - embInfos[embName].devVocabSize); - VLOG(GLOG_DEBUG) << StringFormat("key2OffsetTC(ms):%d", key2OffsetTC.ElapsedMS()); + LOG_DEBUG("current expansion emb:{}, usage:{}/{}, key2OffsetTC({} ms)", + embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); } /* @@ -1208,17 +1180,16 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in { std::lock_guard lockGuard(mut); if (list[embName][channel].empty()) { - VLOG(GLOG_TRACE) << "get info list is empty."; + LOG_TRACE("get info list is empty."); throw EmptyList(); } auto topBatch = get(list[embName][channel].top()); if (topBatch < batch) { - LOG(ERROR) << StringFormat( - "wrong batch id, top:%d getting:%d, channel:%d, may not clear channel", topBatch, batch, channel); + LOG_ERROR("wrong batch id, top:{} getting:{}, channel:{}, may not clear channel", topBatch, batch, channel); this_thread::sleep_for(1s); } if (topBatch != batch) { - VLOG(GLOG_TRACE) << StringFormat("topBatch(%d) is not equal batch(%d).", topBatch, batch); + LOG_TRACE("topBatch({}) is not equal batch({}).", topBatch, batch); throw WrongListTop(); } auto t = list[embName][channel].top(); @@ -1242,25 +1213,22 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) // 判断此时的batch id是否已经过期,即通道已经刷新 HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); if (batch != hybridMgmtBlock->hybridBatchId[channel]) { - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! %s[%d]:%d", - embName.c_str(), channel, batch); + LOG_DEBUG(KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! {}[{}]:{}", + embName, channel, batch); return {}; } if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { - LOG(WARNING) << StringFormat( - KEY_PROCESS "getting lookup keys timeout! %s[%d]:%d", embName.c_str(), channel, batch); + LOG_WARN(KEY_PROCESS "getting lookup keys timeout! 
{}[{}]:{}", embName, channel, batch); return {}; } try { auto ret = GetInfo(lookupKeysList, batch, embName, channel); return get(ret); } catch (EmptyList&) { - VLOG(GLOG_TRACE) << StringFormat("getting info failed %s[%d]:%d", embName.c_str(), channel, batch); + LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - VLOG(GLOG_TRACE) << StringFormat( - "getting info failed %s[%d]:%d wrong top", embName.c_str(), channel, batch); + LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); this_thread::sleep_for(1ms); } } @@ -1297,14 +1265,12 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa // 判断此时的batch id是否已经过期,即通道已经刷新 HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); if (batch != hybridMgmtBlock->hybridBatchId[channel]) { - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! %s[%d]:%d", - embName.c_str(), channel, batch); + LOG_DEBUG(KEY_PROCESS "Detected that the batch has expired at this time, exiting the loop! {}[{}]:{}", + embName, channel, batch); return nullptr; } if (batch != 0 && channel != 0 && tc.ElapsedSec() > KEY_PROCESS_TIMEOUT) { - LOG(WARNING) << StringFormat( - KEY_PROCESS "getting lookup keys timeout! %s[%d]:%d", embName.c_str(), channel, batch); + LOG_WARN(KEY_PROCESS "getting lookup keys timeout! {}[{}]:{}", embName, channel, batch); return nullptr; } try { @@ -1315,11 +1281,10 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa storage.erase(it); return uTensor; } catch (EmptyList&) { - VLOG(GLOG_TRACE) << StringFormat("getting info failed %s[%d]:%d", embName.c_str(), channel, batch); + LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch); this_thread::sleep_for(1ms); } catch (WrongListTop&) { - VLOG(GLOG_TRACE) << StringFormat( - "getting info failed %s[%d]:%d wrong top", embName.c_str(), channel, batch); + LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); this_thread::sleep_for(1ms); } } @@ -1351,7 +1316,7 @@ int KeyProcess::GetMaxStep(int channelId) const void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm { - LOG(INFO) << StringFormat(KEY_PROCESS "hbm funEvictCall: [%s]! keySize:%d", embName.c_str(), keys.size()); + LOG_INFO(KEY_PROCESS "hbm funEvictCall: [{}]! keySize:{}", embName, keys.size()); // 删除映射关系 if (keys.size() != 0) { @@ -1375,7 +1340,7 @@ void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vectorsecond; devHashMap.erase(iter); evictPos.emplace_back(offset); - VLOG(GLOG_TRACE) << StringFormat("evict embName:%s, offset:%d", embName.c_str(), offset); + LOG_TRACE("evict embName:{}, offset:{}", embName, offset); } - LOG(INFO) << StringFormat( - KEY_PROCESS "hbm EvictDeleteDeviceEmb: [%s]! evict size on dev:%d", embName.c_str(), evictPos.size()); + LOG_INFO(KEY_PROCESS "hbm EvictDeleteDeviceEmb: [{}]! evict size on dev:{}", embName, evictPos.size()); } void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset) { if (offset.size() > embInfos[embName].devVocabSize) { - LOG(ERROR) << StringFormat( - "%s overflow! init evict dev, evictOffset size %s bigger than dev vocabSize %d", - embName.c_str(), offset.size(), embInfos[embName].devVocabSize); + LOG_ERROR("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", + embName, offset.size(), embInfos[embName].devVocabSize); throw runtime_error( - StringFormat( - "%s overflow! 
init evict dev, evictOffset size %d bigger than dev vocabSize %d", - embName.c_str(), offset.size(), embInfos[embName].devVocabSize + Log::Format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", + embName, offset.size(), embInfos[embName].devVocabSize ).c_str()); } @@ -1417,6 +1379,5 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset auto trans = Singleton::GetInstance(); trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); - LOG(INFO) << StringFormat( - KEY_PROCESS "hbm EvictInitDeviceEmb: [%s]! send offsetSize:%d", embName.c_str(), offset.size()); + LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size()); } diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index 2108c915..5c4391f2 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -140,8 +140,8 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, size_t needSSDSize = externalKeys.size() - externalSSDKeys.size() - ddrAvailableSize; const int64_t ssdAvailableSize = ssdEngine->GetTableAvailableSpace(embTableName); if (int64_t(needSSDSize) > ssdAvailableSize) { - LOG(ERROR) << "TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data." - " needSSDSize:" << needSSDSize << ", ssdAvailableSize:" << ssdAvailableSize; + LOG_ERROR("TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data. " + "needSSDSize:{}, ssdAvailableSize:{}", needSSDSize, ssdAvailableSize); return TransferSpaceWarning(); } } @@ -156,7 +156,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, // 从ddr转移到ssd的key个数 size_t ddrSwapOutSizeTmp = ddrSpaceEnoughOrEval ? 
externalSSDKeys.size() : externalKeys.size(); auto ddrSwapOutSize = static_cast(ddrSwapOutSizeTmp - ddrAvailableSize); - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddrSwapOutSize:%d", ddrSwapOutSize); + LOG_DEBUG("TransferDDREmbWithSSD: ddrSwapOutSize:{}", ddrSwapOutSize); /* * 转移DDR中数据到SSD @@ -203,10 +203,8 @@ void CacheManager::GetDDREmbInfo(vector& keys, const std::string& emb ddrTransferPos.emplace_back(embHashMap.hostHashMap[key] - embHashMap.devVocabSize); } - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << "DDR keys:" << VectorToString(keys); - VLOG(GLOG_TRACE) << "DDR key positions:" << VectorToString(ddrTransferPos); - } + LOG_TRACE("DDR keys:{}", VectorToString(keys)); + LOG_TRACE("DDR key positions:{}", VectorToString(ddrTransferPos)); ddrEmbData.resize(keys.size()); const auto& emb = hostEmbs->GetEmb(embTableName); @@ -289,7 +287,7 @@ void CacheManager::Init(HostEmb* hostEmbPtr, vector& mgmtEmbInfo) embBaseInfos.emplace(emb.name, baseInfo); } ssdEngine->Start(); - LOG(INFO) << "CacheManager Init method end."; + LOG_INFO("CacheManager Init method end."); } bool CacheManager::IsKeyInSSD(const string& embTableName, emb_key_t key) @@ -338,9 +336,9 @@ void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector externalSSDKeys.size()) { while (ddrTransferPos.size() > externalSSDKeys.size()) { @@ -360,8 +358,8 @@ void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector& externalKeys, @@ -384,8 +382,7 @@ TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHash } TimeCost ddr2SsdTc; - VLOG(GLOG_DEBUG) - << StringFormat("TransferDDREmbWithSSD: get ddr least freq keys, ddrSwapOutSize:%lld", ddrSwapOutSize); + LOG_DEBUG("TransferDDREmbWithSSD: get ddr least freq keys, ddrSwapOutSize:{}", ddrSwapOutSize); // 获取DDR中指定数量的最低频次key,并获取相应emb数据,执行DDR换出到SSD vector ddrSwapOutKeys; vector ddrSwapOutCounts; @@ -394,14 +391,13 @@ TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHash if (static_cast(ddrSwapOutKeys.size()) != ddrSwapOutSize) { auto keyTableSize = ddrKeyFreqMap[embTableName].keyTable.size(); // 获取的最低频次key数量和预期不一致,DDR空间不足,不能放置当前批次数据 - LOG(ERROR) << StringFormat("TransferDDREmbWithSSD, vector length is not equal, ddrSwapOutKeys size:%d, " - "ddrSwapOutSize:%lld, ddr lfu keyTable size:%lld", ddrSwapOutKeys.size(), - ddrSwapOutSize, keyTableSize); + LOG_ERROR("TransferDDREmbWithSSD, vector length is not equal, ddrSwapOutKeys size:{}, " + "ddrSwapOutSize:{}, ddr lfu keyTable size:{}", + ddrSwapOutKeys.size(), ddrSwapOutSize, keyTableSize); RestoreLeastFreqInfo(embTableName, ddrSwapOutKeys, ddrSwapOutCounts); return TransferRet::DDR_SPACE_NOT_ENOUGH; } - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: get DDR embeddings and save to SSD, size:%d", - ddrSwapOutKeys.size()); + LOG_DEBUG("TransferDDREmbWithSSD: get DDR embeddings and save to SSD, size:{}", ddrSwapOutKeys.size()); // 获取DDR中emb数据 vector> ddrEmbData; GetDDREmbInfo(ddrSwapOutKeys, embTableName, embHashMap, ddrTransferPos, ddrEmbData); @@ -413,7 +409,7 @@ TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHash // 更新记录的DDR中key频次信息 RefreshRelateInfoWithDDR2SSD(embTableName, embHashMap, ddrSwapOutKeys, ddrSwapOutCounts); - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ddr2SsdTc TimeCost(ms):%d", ddr2SsdTc.ElapsedMS()); + LOG_DEBUG("TransferDDREmbWithSSD: ddr2SsdTc TimeCost(ms):{}", ddr2SsdTc.ElapsedMS()); return TransferRet::TRANSFER_OK; } @@ -425,18 +421,16 @@ TransferRet 
CacheManager::TransferSSDEmb2DDR(const string& embTableName, EmbHash
         return TransferSuccess();
     }
     TimeCost ssd2DdrTc;
-    VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: get SSD embeddings and save to DDR, size:%d",
-                                     externalSSDKeys.size());
+    LOG_DEBUG("TransferDDREmbWithSSD: get SSD embeddings and save to DDR, size:{}", externalSSDKeys.size());
     if (ddrTransferPos.size() != externalSSDKeys.size() || externalSSDKeys.size() != ssdEmbData.size()) {
-        LOG(ERROR) << StringFormat(
-            "TransferDDREmbWithSSD, vector length is not equal, ddrTransferPos len:%d, externalSSDKeys len:%d, "
-            "ssdEmbData len:%d", ddrTransferPos.size(), externalSSDKeys.size(), ssdEmbData.size());
+        LOG_ERROR("TransferDDREmbWithSSD, vector length is not equal, ddrTransferPos len:{}, externalSSDKeys len:{}, "
+                  "ssdEmbData len:{}", ddrTransferPos.size(), externalSSDKeys.size(), ssdEmbData.size());
         return TransferError();
     }
     // store the SSD embeddings into DDR and refresh the frequency info
     UpdateDDREmbInfo(embTableName, ddrTransferPos, ssdEmbData);
     RefreshRelateInfoWithSSD2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos);
-    VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):%d", ssd2DdrTc.ElapsedMS());
+    LOG_DEBUG("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):{}", ssd2DdrTc.ElapsedMS());
     return TransferSuccess();
 }
@@ -449,12 +443,12 @@ void CacheManager::CreateSSDTableIfNotExist(const std::string& embTableName)
         ssdEngine->CreateTable(embTableName, embBaseInfos[embTableName].savePath,
                                embBaseInfos[embTableName].maxTableSize);
         embBaseInfos[embTableName].isExist = true;
-        LOG(INFO) << ("create ssd table end, embTableName:" + embTableName);
+        LOG_INFO("create ssd table end, embTableName:" + embTableName);
         return;
     }
     // resumed-training scenario: embBaseInfos is not saved and thus not initialized; the SSD table is initialized, so it already exists here
     embBaseInfos[embTableName].isExist = true;
-    LOG(INFO) << ("ssd table is exist, embTableName:" + embTableName);
+    LOG_INFO("ssd table already exists, embTableName:" + embTableName);
 }
 void CacheManager::RestoreLeastFreqInfo(const std::string& embTableName, vector& ddrSwapOutKeys,
diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp
index 169e785b..df226a56 100644
--- a/src/core/ssd_engine/ssd_engine.cpp
+++ b/src/core/ssd_engine/ssd_engine.cpp
@@ -120,7 +120,7 @@ void SSDEngine::Start()
 /// Compaction monitor: calls each table's compact interface once the check period is reached
 void SSDEngine::CompactMonitor()
 {
-    VLOG(GLOG_DEBUG) << "SSDEngine start CompactMonitor";
+    LOG_DEBUG("SSDEngine start CompactMonitor");
     auto start = chrono::high_resolution_clock::now();
     auto end = chrono::high_resolution_clock::now();
     chrono::microseconds loopDuration = 100ms;
@@ -128,17 +128,17 @@
     while (isRunning) {
         duration = chrono::duration_cast(end - start);
         if (duration >= compactPeriod) {
-            VLOG(GLOG_DEBUG) << "SSDEngine CompactMonitor start compact";
+            LOG_DEBUG("SSDEngine CompactMonitor start compact");
             for (const auto &item: tableMap) {
                 item.second->Compact(false);
             }
-            VLOG(GLOG_DEBUG) << "SSDEngine CompactMonitor end compact";
+            LOG_DEBUG("SSDEngine CompactMonitor end compact");
             start = chrono::high_resolution_clock::now();
         }
         this_thread::sleep_for(loopDuration);
         end = chrono::high_resolution_clock::now();
     }
-    VLOG(GLOG_DEBUG) << "SSDEngine end CompactMonitor";
+    LOG_DEBUG("SSDEngine end CompactMonitor");
 }
 vector> SSDEngine::FetchEmbeddings(const string &tableName, vector &keys)
@@ -164,7 +164,7 @@ void SSDEngine::Stop()
     tableMap.clear();
     compactThread = nullptr;
-    LOG(INFO) << "SSDEngine stop";
+    LOG_INFO("SSDEngine stop");
 }
 /// Set the file compaction period
diff --git 
a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index 89880989..212de4e7 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -23,7 +23,7 @@ Table::Table(const string &name, vector &savePaths, uint64_t maxTableSiz if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { throw runtime_error("fail to create table directory"); } - LOG(INFO) << StringFormat("create table:%s at path:%s", name.c_str(), curTablePath.c_str()); + LOG_INFO("create table:{} at path:{}", name, curTablePath); } /// 加载表 @@ -55,7 +55,7 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize // always use first path to save until it's full curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); - LOG(INFO) << StringFormat("load table:%s done. try store at path:%s", name.c_str(), curTablePath.c_str()); + LOG_INFO("load table:{} done. try store at path:{}", name, curTablePath); } bool Table::IsKeyExist(emb_key_t key) @@ -86,7 +86,7 @@ void Table::DeleteEmbeddings(vector &keys) void Table::Save(int step) { - LOG(INFO) << StringFormat("start save table:%s, at step:%d", name.c_str(), step); + LOG_INFO("start save table:{}, at step:{}", name, step); Compact(true); lock_guard guard(rwLock); @@ -121,7 +121,7 @@ void Table::Save(int step) } metaFile.close(); - LOG(INFO) << StringFormat("end save table:%s, at step:%d", name.c_str(), step); + LOG_INFO("end save table:{}, at step:{}", name, step); } /// 根据元数据加载data文件 @@ -129,7 +129,7 @@ void Table::Save(int step) /// \param step 加载的步数 void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) { - LOG(INFO) << StringFormat("table:%s, start load data file", name.c_str()); + LOG_INFO("table:{}, start load data file", name); uint64_t fileCnt; metaFile->read(reinterpret_cast(&fileCnt), sizeof(fileCnt)); uint64_t fileID; @@ -182,7 +182,7 @@ void Table::Load(const string &metaFilePath, int step) shared_ptr metaFile = make_shared(); metaFile->open(metaFilePath, ios::in | ios::binary); - LOG(INFO) << StringFormat("table:%s, load meta file from path:%s", name.c_str(), metaFilePath.c_str()); + LOG_INFO("table:{}, load meta file from path:{}", name, metaFilePath); if (!metaFile->is_open()) { throw invalid_argument("fail to open meta"); } @@ -207,7 +207,7 @@ void Table::Load(const string &metaFilePath, int step) if (metaFile->fail()) { throw runtime_error("fail to load table"); } - LOG(INFO) << StringFormat("table:%s, end load data file", name.c_str()); + LOG_INFO("table:{}, end load data file", name); } void Table::InsertEmbeddingsInner(vector &keys, vector> &embeddings) @@ -229,10 +229,8 @@ void Table::InsertEmbeddingsInner(vector &keys, vector> throw runtime_error("all disk's space not enough"); } curTablePath = savePaths[curSavePathIdx]; - LOG(INFO) << StringFormat( - "current data path's free space less than %f, try next path:%s", - diskFreeSpaceThreshold, curTablePath.c_str() - ); + LOG_INFO("current data path's free space less than {}, try next path:{}", + diskFreeSpaceThreshold, curTablePath); } curFile = make_shared(curMaxFileID, curTablePath); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 00b5700e..037e54d7 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -109,22 +109,12 @@ namespace MxRec { g_rankId = std::to_string(rank); } if (!g_isGlogInit) { - InitGoogleLogging("mxRec", &CustomGlogFormat); + Log::SetLevel(g_glogLevel); + Log::SetRank(rank); g_isGlogInit = true; } } - void 
CustomGlogFormat(std::ostream &s, const google::LogMessageInfo &l, void*)
-    {
-        s << "["
-          << setw(GLOG_TIME_WIDTH_2) << l.time.hour() << ':'
-          << setw(GLOG_TIME_WIDTH_2) << l.time.min() << ':'
-          << setw(GLOG_TIME_WIDTH_2) << l.time.sec() << "."
-          << setw(GLOG_TIME_WIDTH_6) << l.time.usec() << "]"
-          << " [" + g_rankId + "]"
-          << " [" << l.severity << "] ";
-    }
-
     bool GetEnv(const char *envName)
     {
         const char* envString = getenv(envName);
@@ -133,13 +123,12 @@ namespace MxRec {
             try {
                 tmp = std::stoi(envString);
                 if (tmp == 0 || tmp == 1) {
-                    LOG(INFO) << StringFormat("Succeed to parse ${env:%s}: %d.", envName, tmp);
+                    LOG_INFO("Succeeded to parse ${env:{}}: {}.", envName, tmp);
                 } else {
-                    LOG(ERROR) << StringFormat("Invalid ${env:%s}: %d, which should be an 0 or 1.", envName, tmp);
+                    LOG_ERROR("Invalid ${env:{}}: {}, which should be 0 or 1.", envName, tmp);
                 }
             } catch (const std::invalid_argument &e) {
-                LOG(ERROR) <<
-                    StringFormat("Failed to parse ${env:%s}, which should be an integer.", envName);
+                LOG_ERROR("Failed to parse ${env:{}}, which should be an integer.", envName);
             }
         }
         return (tmp == 1) ? true : false;
@@ -153,11 +142,10 @@ namespace MxRec {
             { 0 }};
         ret = dsmi_get_chip_info(devID, &info);
         if (ret == 0) {
-            VLOG(GLOG_DEBUG) << StringFormat(
-                "dsmi_get_chip_info successful, ret = %d, chip_name = %s", ret,
-                reinterpret_cast(info.chip_name)
-            );
-            return reinterpret_cast(info.chip_name);
+            stringstream ss;
+            ss << info.chip_name;
+            LOG_DEBUG("dsmi_get_chip_info successful, ret = {}, chip_name = {}", ret, ss.str());
+            return ss.str();
         }
         throw std::runtime_error("dsmi_get_chip_info failed, ret = " + to_string(ret));
@@ -170,9 +158,9 @@ namespace MxRec {
         if (faaeMode != nullptr) {
             try {
                 isCombine = (std::stoi(faaeMode) == 1);
-                LOG(INFO) << StringFormat("If combine history table: %d", isCombine);
+                LOG_INFO("If combine history table: {}", isCombine);
             } catch (const std::invalid_argument& e) {
-                LOG(ERROR) << "The value of USE_COMBINE_FAAE is invalid!";
+                LOG_ERROR("The value of USE_COMBINE_FAAE is invalid!");
                 throw std::invalid_argument("Invalid env value USE_COMBINE_FAAE");
             }
         }
@@ -188,15 +176,14 @@ namespace MxRec {
                 threadNum = std::stoi(threadNumEnv);
             } catch (const std::invalid_argument& e) {
                 threadNum = KEY_PROCESS_THREAD;
-                LOG(INFO) << StringFormat("error value of threadNum, use default KEY_PROCESS_THREAD: %d",
-                                          threadNum);
+                LOG_INFO("invalid value of threadNum, using default KEY_PROCESS_THREAD: {}", threadNum);
             }
             if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) {
                 throw runtime_error(StringFormat("%d is not valid", threadNum));
             }
         } else {
             threadNum = KEY_PROCESS_THREAD;
-            LOG(INFO) << StringFormat("use default KEY_PROCESS_THREAD: %d", threadNum);
+            LOG_INFO("use default KEY_PROCESS_THREAD: {}", threadNum);
         }
         return threadNum;
     }
@@ -207,14 +194,13 @@ namespace MxRec {
         struct stat fileInfo;
         if (lstat(dataDir.c_str(), &fileInfo) != -1) {
             if (S_ISLNK(fileInfo.st_mode)) {
-                LOG(ERROR) << StringFormat("soft link %s should not in the path parameter", dataDir.c_str());
+                LOG_ERROR("soft link {} should not be in the path parameter", dataDir);
                 throw invalid_argument(StringFormat("soft link should not be the path parameter"));
             }
         }
         // validate file size
         if (datasetSize > FILE_MAX_SIZE) {
-            LOG(ERROR) << StringFormat("the reading file size is invalid, "
-                                       "not in range [%d,%lld]", FILE_MIN_SIZE, FILE_MAX_SIZE);
+            LOG_ERROR("the reading file size is invalid, not in range [{},{}]", FILE_MIN_SIZE, FILE_MAX_SIZE);
             throw invalid_argument(StringFormat("file size invalid"));
         }
     }
diff --git 
a/src/core/utils/common.h b/src/core/utils/common.h
index 780fddf4..55f859ba 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -23,9 +23,9 @@
 #include
 #include
 #include "tensorflow/core/framework/tensor.h"
-#include "glog/logging.h" // note: must set behind any tensorflow reference, otherwise will overwrite logging.h
 #include "absl/container/flat_hash_map.h"
 #include "securec.h"
+#include "utils/log.h"
 #include "initializer/initializer.h"
 #include "initializer/constant_initializer/constant_initializer.h"
@@ -298,8 +298,6 @@ namespace MxRec {
     void SetLog(int rank);
-    void CustomGlogFormat(std::ostream &s, const google::LogMessageInfo &l, void*);
-
     bool GetEnv(const char *envName);
     template
diff --git a/src/core/utils/log.cpp b/src/core/utils/log.cpp
new file mode 100644
index 00000000..c75545ee
--- /dev/null
+++ b/src/core/utils/log.cpp
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: common module
+ * Author: MindX SDK
+ * Create: 2023
+ * History: NA
+ */
+
+
+#include "utils/log.h"
+
+namespace MxRec {
+
+int MxRec::Log::_level = MxRec::Log::INFO;
+int MxRec::Log::_rank = 0;
+
+void Log::SetRank(int rank)
+{
+    Log::_rank = rank;
+}
+
+void Log::SetLevel(int level)
+{
+    Log::_level = level;
+}
+
+int Log::GetLevel()
+{
+    return Log::_level;
+}
+
+const char* Log::LevelToStr(int level)
+{
+    if (level < TRACE || level > ERROR) {
+        return "INVALID LEVEL";
+    }
+    static const char* msg[] = {
+        "TRACE",
+        "DEBUG",
+        "INFO",
+        "WARN",
+        "ERROR",
+    };
+    return msg[level];
+}
+
+void Log::LogUnpack(queue& fmt, stringstream &ss)
+{
+    while (!fmt.empty()) {
+        ss << fmt.front();
+        fmt.pop();
+    }
+    return;
+}
+
+}
\ No newline at end of file
diff --git a/src/core/utils/log.h b/src/core/utils/log.h
new file mode 100644
index 00000000..c1078510
--- /dev/null
+++ b/src/core/utils/log.h
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: common module
+ * Author: MindX SDK
+ * Create: 2023
+ * History: NA
+ */
+
+#ifndef MXREC_LOG_H_
+#define MXREC_LOG_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace std;
+
+namespace MxRec {
+
+constexpr int YEAR_BASE = 1900;
+constexpr size_t DELIM_LEN = 2;
+
+class Log {
+public:
+
+    static constexpr int TRACE = 0;
+    static constexpr int DEBUG = 1;
+    static constexpr int INFO = 2;
+    static constexpr int WARN = 3;
+    static constexpr int ERROR = 4;
+
+    static void SetRank(int rank);
+
+    static void SetLevel(int level);
+
+    static int GetLevel();
+
+    template
+    static void Format(stringstream& ss, const char* fmt, Args &&...args)
+    {
+        queue formats;
+        string tmp(fmt);
+        // match the whole "{}" placeholder substring, not either brace on its own
+        for (size_t pos = tmp.find("{}"); pos != string::npos; pos = tmp.find("{}")) {
+            string x = tmp.substr(0, pos);
+            formats.push(x);
+            tmp = tmp.substr(pos + DELIM_LEN);
+        }
+        formats.push(tmp);
+        LogUnpack(formats, ss, args...);
+    }
+
+    template
+    static string Format(const char* fmt, Args &&...args)
+    {
+        stringstream ss;
+        Log::Format(ss, fmt, args...);
+        return ss.str();
+    }
+
+    template
+    static void log(const char* file, int line, int level, const char* fmt, Args &&...args)
+    {
+        stringstream ss;
+        struct tm t;
+        struct timeval tv;
+        gettimeofday(&tv, NULL);
+        localtime_r(&tv.tv_sec, &t);
+        // tm_mon is zero-based, hence the +1
+        ss << "[MxRec][" << YEAR_BASE + t.tm_year << "/" << t.tm_mon + 1 << "/" << t.tm_mday << " "
+           << t.tm_hour << ":" << t.tm_min << ":" << t.tm_sec << "." 
<< tv.tv_usec << "] [" + << Log::_rank << "] ["<< Log::LevelToStr(level) << "] [" + << (strrchr(file, '/') ? strrchr(file, '/') + 1 : file) << ":" << line << "] "; + Log::Format(ss, fmt, args...); + ss << std::endl; + std::cout << ss.str(); + } + + template + static void log(const char* file, int line, int level, const string& fmt, Args &&...args) + { + Log::log(file, line, level, fmt.c_str(), args...); + } + +private: + static const char* LevelToStr(int level); + + static void LogUnpack(queue& fmt, stringstream &ss); + + template + static void LogUnpack(queue& fmt, stringstream &ss, head &h, tail &&...tails) + { + if (!fmt.empty()) { + ss << fmt.front(); + fmt.pop(); + } + ss << h; + LogUnpack(fmt, ss, tails...); + }; + static int _level; + static int _rank; +}; + + +#define LOG_TRACE(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::TRACE) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::TRACE, args) + +#define LOG_DEBUG(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::DEBUG) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::DEBUG, args) + +#define LOG_INFO(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::INFO) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::INFO, args) + +#define LOG_WARN(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::WARN) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::WARN, args) + +#define LOG_ERROR(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::ERROR) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::ERROR, args) + +} + +#endif // MXREC_LOG_H_ \ No newline at end of file diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index f09eb3e2..463798e8 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -147,7 +147,7 @@ protected: batch->name = embName; batch->timestamp = ts; printf("\n"); - LOG(INFO) << StringFormat("current admit embName[%s] at time[%d] ...", embName.c_str(), ts); + LOG_INFO("current admit embName[{}] at time[{}] ...", embName, ts); // 校验调接口不出错 ASSERT_EQ(faae.FeatureAdmit(channel, batch, args.keys, args.cnt) != @@ -170,7 +170,7 @@ protected: batch->name = embName; batch->timestamp = ts; printf("\n"); - LOG(INFO) << StringFormat("current admit embName[%s] at time[%d] ...", embName.c_str(), ts); + LOG_INFO("current admit embName[{}] at time[{}] ...", embName, ts); // 校验调接口不出错 faae.FeatureAdmit(channel, batch, args.keys, args.cnt); @@ -231,7 +231,7 @@ protected: void StartEvictThread() { evictThr = std::thread([&]() { - LOG(INFO) << "Evict-thread start ..."; + LOG_INFO("Evict-thread start ..."); time_t currTime = 0; time_t lastTime = 0; @@ -239,13 +239,13 @@ protected: std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); currTime = time(nullptr); if (currTime - lastTime >= SleepTime::SLEEP_SECOND_4) { - LOG(INFO) << StringFormat("Evict-thread doing at currTime[%d] ...", currTime); + LOG_INFO("Evict-thread doing at currTime[{}] ...", currTime); map> evictPosMap {}; faae.FeatureEvict(evictPosMap); lastTime = currTime; } } - LOG(INFO) << "Evict-thread exit ..."; + LOG_INFO("Evict-thread exit ..."); }); } void WaitEvictThread() @@ -318,7 +318,7 @@ protected: FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args5); WaitEvictThread(); - LOG(INFO) << "TestCase1: single thread test over ..."; + LOG_INFO("TestCase1: single thread test over ..."); } // 进行“准入”逻辑时,若(splitKey.size() != keyCount.size()),则业务报错退出;(说明是前面all2all通信数据错误) @@ -338,7 +338,7 @@ protected: // 
verify that the interface call returns an error
         ASSERT_EQ(faae.FeatureAdmit(0, batch, tmpKeys, tmpCnt) ==
                   FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR, true);
-        LOG(INFO) << "TestCase2 over ...";
+        LOG_INFO("TestCase2 over ...");
     }
     // admission and eviction thresholds can be configured separately; configuring only the 'admission' threshold without the 'eviction' threshold still works correctly;
@@ -413,7 +413,7 @@ protected:
         }
         WaitEvictThread();
-        LOG(INFO) << "TestCase5: multi thread test over ...";
+        LOG_INFO("TestCase5: multi thread test over ...");
     }
     // with neither 'admission' nor 'eviction' threshold configured, the feature admit & evict function is 'not supported';
@@ -428,7 +428,7 @@ protected:
         batch->timestamp = time(nullptr);
         // verify the interface call reports 'not supported'
-        LOG(INFO) << "TestCase6 over ...";
+        LOG_INFO("TestCase6 over ...");
     }
     bool isExitFlag { false };
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index 882b9c4e..7ee83494 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -41,7 +41,7 @@ static void CTRLog(int level, const char *msg)
 {
     switch (level) {
         case 0:
-            VLOG(GLOG_DEBUG) << StringFormat("%s", msg);
+            LOG_DEBUG(msg);
             break;
         default:
             break;
@@ -57,7 +57,7 @@ protected:
         ASSERT_EQ(claimed, MPI_THREAD_MULTIPLE);
         MPI_Comm_rank(MPI_COMM_WORLD, &worldRank);
         MPI_Comm_size(MPI_COMM_WORLD, &worldSize);
-        LOG(INFO) << StringFormat(KEY_PROCESS "wordRank: %d, worldSize: %d", worldRank, worldSize);
+        LOG_INFO(KEY_PROCESS "worldRank: {}, worldSize: {}", worldRank, worldSize);
         // initialize rank info
         rankInfo.rankId = worldRank;
         rankInfo.rankSize = worldSize;
@@ -91,10 +91,9 @@ protected:
             batch->name = embInfos[i].name;
             batch->batchId = batchId;
             batch->channel = channel;
-            VLOG(GLOG_DEBUG) << StringFormat(
-                "[%d/%d]" KEY_PROCESS "PrepareBatch: batchQueueId: %d, %s[%d]%d, sampleSize:%d",
+            LOG_DEBUG("[{}/{}]" KEY_PROCESS "PrepareBatch: batchQueueId: {}, {}[{}]{}, sampleSize:{}",
                 worldRank, worldSize,
-                batchQueueId, batch->name.c_str(), batch->channel, batch->batchId, batch->sample.size()
+                batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size()
             );
             emb_batch_t temp;
             temp.sample = batch->sample;
@@ -178,12 +177,12 @@ protected:
     {
         for (int i = 0; i < rankSize; ++i) {
             std::cout << "splitKeys dev" << i << std::endl;
-            REC_LOG(INFO) << StringFormat("%d", VectorToString(splitKeys[i]).c_str());
+            LOG_INFO(VectorToString(splitKeys[i]));
         }
         std::cout << "restore" << std::endl;
-        REC_LOG(INFO) << StringFormat("%d", VectorToString(restore).c_str());
+        LOG_INFO(VectorToString(restore));
         std::cout << "hotPos" << std::endl;
-        REC_LOG(INFO) << StringFormat("%d", VectorToString(hotPos).c_str());
+        LOG_INFO(VectorToString(hotPos));
     }
     void GetExpectRestore(keys_t& sample, vector& blockOffset, vector& restoreVec)
@@ -270,9 +269,7 @@ TEST_F(KeyProcessTest, HashSplit)
     vector expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 };
     vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } };
     batch->sample = std::move(batchKeys);
-    if (VLOG_IS_ON(GLOG_DEBUG)) {
-        VLOG(GLOG_DEBUG) << StringFormat(KEY_PROCESS "batch sample: %s", VectorToString(batch->sample).c_str());
-    }
+    LOG_DEBUG(KEY_PROCESS "batch sample: {}", VectorToString(batch->sample));
     ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
     ASSERT_EQ(process.isRunning, true);
     process.rankInfo.rankSize = rankSize;
@@ -287,11 +284,7 @@ TEST_F(KeyProcessTest, HashSplit)
 TEST_F(KeyProcessTest, GetScAll)
 {
     vector keyScLocal(worldSize, worldRank + 1); // initialize the send counts with worldRank+1
-    if (VLOG_IS_ON(GLOG_DEBUG)) {
-        VLOG(GLOG_DEBUG) << StringFormat(
-            KEY_PROCESS "rank %d keyScLocal: %s", worldRank, VectorToString(keyScLocal).c_str()
-        );
-    }
+    LOG_DEBUG(KEY_PROCESS "rank {} keyScLocal: 
{}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); for (unsigned int i = 0; i < expectScAll.size(); ++i) { expectScAll[i] = floor(i / worldSize) + 1; @@ -307,11 +300,7 @@ TEST_F(KeyProcessTest, GetScAll) TEST_F(KeyProcessTest, GetScAllForUnique) { vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 - if (VLOG_IS_ON(GLOG_DEBUG)) { - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "rank %d keyScLocal: %s", worldRank, VectorToString(keyScLocal).c_str() - ); - } + LOG_DEBUG(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); for (unsigned int i = 0; i < expectScAll.size(); ++i) { expectScAll[i] = floor(i / worldSize) + 1; @@ -337,20 +326,20 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { 5, 0, 6, 2, 1, 3, 1, 7, 4, 8 }, { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } }; batch->sample = std::move(allBatchKeys[worldRank]); - LOG(INFO) << StringFormat( - KEY_PROCESS "test BuildRestoreVec: rank %d, batchKeys %s", - worldRank, VectorToString(batch->sample).c_str() - ); + LOG_INFO(KEY_PROCESS "test BuildRestoreVec: rank {}, batchKeys {}", + worldRank, VectorToString(batch->sample)); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); auto [splitKeys, restore] = process.HashSplit(batch); - if (VLOG_IS_ON(GLOG_DEBUG)) { + + LOG_DEBUG("rank: {} splitKeys: {}", worldRank, [&] { vector tmp; for (const auto& i : splitKeys) { tmp.emplace_back(VectorToString(i)); } - VLOG(GLOG_DEBUG) << StringFormat("rank: %d splitKeys: %s", worldRank, VectorToString(tmp).c_str()); - } + return VectorToString(tmp); + }()); + process.BuildRestoreVec(batch, allExpectSs[worldRank], restore); ASSERT_THAT(restore, ElementsAreArray(allExpectRestore[worldRank])); } @@ -359,7 +348,7 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) { PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + LOG_INFO("CPU Core Num: %{}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; @@ -369,11 +358,9 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) vector hotPos; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); + LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - LOG(INFO) << StringFormat( - "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() - ); + LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { @@ -389,7 +376,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) { PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; @@ -398,15 +385,13 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) vector hotPos; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG(INFO) << StringFormat("rankid :%d,batchid: %d", 
rankInfo.rankId, batch->batchId); + LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); process.BuildRestoreVec(batch, ss, restore, hotPos.size()); - REC_LOG(INFO) << StringFormat( - "rankid :%d,batchid: %d, lookupKeys: %s, scAll: %s, restore after build %s", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys).c_str(), - VectorToString(scAll).c_str(), VectorToString(restore).c_str() - ); + LOG_INFO("rankid :{},batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), + VectorToString(scAll), VectorToString(restore)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { @@ -430,9 +415,8 @@ TEST_F(KeyProcessTest, Key2Offset) tmp.insert(pair(it->first, MapToString(it->second).c_str())); } - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", - VectorToString(lookupKeys).c_str(), MapToString(tmp).c_str()); + LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", + VectorToString(lookupKeys), MapToString(tmp)); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); keys_t lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 }; @@ -442,9 +426,8 @@ TEST_F(KeyProcessTest, Key2Offset) for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { tmp.insert(pair(it->first, MapToString(it->second).c_str())); } - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", VectorToString(lookupKeys2).c_str(), - MapToString(tmp2).c_str()); + LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", + VectorToString(lookupKeys2), MapToString(tmp2).c_str()); ASSERT_THAT(lookupKeys2, ElementsAreArray(expectOffset2)); } @@ -455,17 +438,15 @@ TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2OffsetDynamicExpansion("emb0", lookupKeys, EVAL_CHANNEL_ID); - if (VLOG_IS_ON(GLOG_DEBUG)) { + + LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", VectorToString(lookupKeys), [&] { map tmp; for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { tmp.insert(pair(it->first, MapToString(it->second).c_str())); } + return MapToString(tmp); + }()); - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "test Key2Offset: lookupKeys: %s, keyOffsetMap: %s", - VectorToString(lookupKeys).c_str(), MapToString(tmp).c_str() - ); - } ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); } @@ -491,7 +472,7 @@ TEST_F(KeyProcessTest, ProcessPrefetchTask) ASSERT_EQ(process.Start(), 0); // 所有线程处理完(训练结束)后调用 this_thread::sleep_for(5s); - LOG(INFO) << "wait 20s for thread running"; + LOG_INFO("wait 20s for thread running"); this_thread::sleep_for(20s); process.Destroy(); } @@ -528,7 +509,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG(INFO) << StringFormat("CPU Core Num: %d", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { UniquePtr unique; @@ -549,11 +530,9 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) process.GetUniqueConfig(uniqueConf); 
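        // Flow under test: GetUniqueConfig fills uniqueConf, InitializeUnique prepares the
        // fast-unique helper for this batch, and KeyProcessTaskHelperWithFastUnique below
        // then drives the dedup and all2all key pipeline end to end.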
process.InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); - LOG(INFO) << StringFormat("rankid :%d,batchid: %d", rankInfo.rankId, batch->batchId); + LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); process.KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id); - LOG(INFO) << StringFormat( - "rankid :%d,batchid: %d, hotPos %s", rankInfo.rankId, batch->batchId, VectorToString(hotPos).c_str() - ); + LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp index 81752460..8d1ca883 100644 --- a/src/tests/ssd_cache/cache_manager_test.cpp +++ b/src/tests/ssd_cache/cache_manager_test.cpp @@ -283,7 +283,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalSSDKey) // 训练场景,返回ssd空间不足 auto ret3 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, TRAIN_CHANNEL_ID); ASSERT_EQ(ret3, TransferRet::SSD_SPACE_NOT_ENOUGH); - LOG(INFO) << "test TransferDDREmbWithSSDByEmptyExternalSSDKey end."; + LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalSSDKey end."); } TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval) @@ -302,7 +302,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval) vector currentKeys = {6, 4, 55, 65, 75}; auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, EVAL_CHANNEL_ID); ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - LOG(INFO) << "test eval+space enough+externalSSDKeysEmpty ok."; + LOG_INFO("test eval+space enough+externalSSDKeysEmpty ok."); // 评估+DDR剩余空间足够+externalSSDKeys非空 vector currentKeys2 = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115}; @@ -320,7 +320,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval) ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 9)); ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 8)); ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15)); - LOG(INFO) << "test eval+space enough+externalSSDKeysNotEmpty ok."; + LOG_INFO("test eval+space enough+externalSSDKeysNotEmpty ok."); } TEST_F(CacheManagerTest, TransferDDREmbWithSSDByDDRSpaceNotEnough) @@ -336,7 +336,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByDDRSpaceNotEnough) vector currentKeys = {6, 4, 101, 102, 103, 104, 105, 106, 107, 108}; auto ret = cacheManager.TransferDDREmbWithSSD(embTableName2, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID); ASSERT_EQ(ret, TransferRet::DDR_SPACE_NOT_ENOUGH); - LOG(INFO) << "test train+ddr space enough+externalSSDKeysEmpty ok."; + LOG_INFO("test train+ddr space enough+externalSSDKeysEmpty ok."); } TEST_F(CacheManagerTest, EvictSSDEmbedding) @@ -348,7 +348,7 @@ TEST_F(CacheManagerTest, EvictSSDEmbedding) ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, key)); const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(key); ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end()); - LOG(INFO) << "test EvictSSDEmbedding end."; + LOG_INFO("test EvictSSDEmbedding end."); } TEST_F(CacheManagerTest, LoadTest) diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index 48c6a6d1..ed5b045d 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -56,16 +56,18 @@ TEST(Table, WriteAndReadAndDeleteAndCompact) } } - LOG(INFO) << "n data:" << nData << " ,batch size:" << batchSize << " ,write cost(ms):" 
<< writeCost.count() << " ,QPS:" << float(nData) * 1000 / writeCost.count();
+    LOG_INFO("n data:{} ,batch size:{} ,write cost(ms): {} ,QPS:{}",
+             nData, batchSize, writeCost.count(), float(nData) * 1000 / writeCost.count());
     // read
     auto start = chrono::high_resolution_clock::now();
     auto ret = tb->FetchEmbeddings(allKeys);
     auto end = chrono::high_resolution_clock::now();
     auto readCost = chrono::duration_cast(end - start);
-    LOG(INFO) << "n data:" << nData << " ,batch size:" << batchSize << " ,read cost(ms):" << readCost.count()
-              << " ,QPS:" << float(nData) * 1000 / readCost.count();
+
+    LOG_INFO("n data:{} ,batch size:{} ,read cost(ms):{} ,QPS:{}",
+             nData, batchSize, readCost.count(), float(nData) * 1000 / readCost.count());
+
     ASSERT_EQ(allEmbs, ret);
     // check space
diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp
new file mode 100644
index 00000000..44dfd67a
--- /dev/null
+++ b/src/tests/utils/log_test.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
+ * Description: common module
+ * Author: MindX SDK
+ * Create: 2023
+ * History: NA
+ */
+
+#include
+#include "utils/common.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+
+TEST(Log, Format)
+{
+    string test = Log::Format("{}{}{}", 1, 2, 3);
+    EXPECT_STREQ(test.c_str(), "123");
+}
+
+TEST(Log, LogLevel)
+{
+    MxRec::Log::SetLevel(Log::DEBUG);
+    testing::internal::CaptureStdout();
+    LOG_DEBUG("debug log {}", "hellow");
+    LOG_INFO("info log {}", "hellow");
+    LOG_WARN("warn log {}", "hellow");
+    LOG_ERROR("error log {}", "hellow");
+    std::string output = testing::internal::GetCapturedStdout();
+    EXPECT_NE(output.find("debug log hellow"), string::npos);
+    EXPECT_NE(output.find("info log hellow"), string::npos);
+    EXPECT_NE(output.find("warn log hellow"), string::npos);
+    EXPECT_NE(output.find("error log hellow"), string::npos);
+
+    MxRec::Log::SetLevel(Log::INFO);
+    testing::internal::CaptureStdout();
+    LOG_DEBUG("debug log {}", "hellow");
+    LOG_INFO("info log {}", "hellow");
+    LOG_WARN("warn log {}", "hellow");
+    LOG_ERROR("error log {}", "hellow");
+    output = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(output.find("debug log hellow"), string::npos);
+    EXPECT_NE(output.find("info log hellow"), string::npos);
+    EXPECT_NE(output.find("warn log hellow"), string::npos);
+    EXPECT_NE(output.find("error log hellow"), string::npos);
+
+    MxRec::Log::SetLevel(Log::WARN);
+    testing::internal::CaptureStdout();
+    LOG_DEBUG("debug log {}", "hellow");
+    LOG_INFO("info log {}", "hellow");
+    LOG_WARN("warn log {}", "hellow");
+    LOG_ERROR("error log {}", "hellow");
+    output = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(output.find("debug log hellow"), string::npos);
+    EXPECT_EQ(output.find("info log hellow"), string::npos);
+    EXPECT_NE(output.find("warn log hellow"), string::npos);
+    EXPECT_NE(output.find("error log hellow"), string::npos);
+
+    MxRec::Log::SetLevel(Log::ERROR);
+    testing::internal::CaptureStdout();
+    LOG_DEBUG("debug log {}", "hellow");
+    LOG_INFO("info log {}", "hellow");
+    LOG_WARN("warn log {}", "hellow");
+    LOG_ERROR("error log {}", "hellow");
+    output = testing::internal::GetCapturedStdout();
+    EXPECT_EQ(output.find("debug log hellow"), string::npos);
+    EXPECT_EQ(output.find("info log hellow"), string::npos);
+    EXPECT_EQ(output.find("warn log hellow"), string::npos);
+    EXPECT_NE(output.find("error log hellow"), string::npos);
+}
+
+TEST(Log, LazyEvaluation)
+{
+    MxRec::Log::SetLevel(Log::WARN);
+ 
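+    // The LOG_* macros expand to `if (Log::GetLevel() <= level) Log::log(...)`, so the
+    // argument list is evaluated only when the level is enabled; the immediately-invoked
+    // lambdas below rely on this: at WARN level the INFO lambda never runs (flag1 stays 0)
+    // while the WARN lambda does run (flag2 becomes 1).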
testing::internal::CaptureStdout(); + int flag1 = 0; + int flag2 = 0; + LOG_INFO("info log {} {}", "hellow", [&] { + flag1 = 1; + return "hellow"; + }()); + LOG_WARN("warn log {} {}", "hellow", [&] { + flag2 = 1; + return "hellow"; + }()); + LOG_ERROR("error log {}", "hellow"); + LOG_DEBUG("debug log {}", "hellow"); + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output.find("debug log hellow"), string::npos); + EXPECT_EQ(output.find("info log hellow hellow"), string::npos); + EXPECT_NE(output.find("warn log hellow hellow"), string::npos); + EXPECT_NE(output.find("error log hellow"), string::npos); + EXPECT_EQ(flag1, 0); + EXPECT_EQ(flag2, 1); +} + +TEST(Log, Basic) +{ + MxRec::Log::SetLevel(Log::INFO); + testing::internal::CaptureStdout(); + LOG_INFO("basictest"); + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_NE(output.find("basictest"), string::npos); +} + +TEST(Log, TooManyArgs1) +{ + MxRec::Log::SetLevel(Log::INFO); + testing::internal::CaptureStdout(); + LOG_INFO("{} {} {}", 0.1f, 'h', 'e', "llow"); + std::string output = testing::internal::GetCapturedStdout(); + cout << output << endl; + EXPECT_NE(output.find("0.1 h ellow"), string::npos); +} + +TEST(Log, TooManyArgs2) +{ + MxRec::Log::SetLevel(Log::INFO); + testing::internal::CaptureStdout(); + LOG_INFO("{}", "h", "h", "h", "h", "h", "h", "h"); + std::string output = testing::internal::GetCapturedStdout(); + cout << output << endl; + EXPECT_NE(output.find("hhhhhhh"), string::npos); +} + +TEST(Log, FewArgs) +{ + MxRec::Log::SetLevel(Log::INFO); + testing::internal::CaptureStdout(); + LOG_INFO("{} {} {} {} {} {}", "hellow", "hellow"); + std::string output = testing::internal::GetCapturedStdout(); + cout << output << endl; + EXPECT_NE(output.find("hellow hellow"), string::npos); +} \ No newline at end of file -- Gitee From fe34840bebee974d42c9e8b59d39ddf40cc08cc2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 8 Sep 2023 18:57:57 +0800 Subject: [PATCH 330/551] Match-id-bc98fa432f6db288acc28d49bf9b99a321058e74 --- .../op_host/embedding_lookup_by_address.cpp | 5 +++ .../op_host/embedding_update_by_address.cpp | 40 +++++++++++-------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 401d7148..18e3f295 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -16,6 +16,11 @@ namespace optiling } size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (currentWorkspace == nullptr) { + printf("currentWorkspace nullptr\n"); + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = sysWorkspaceSize + usrSize; int32_t block_total_nums = 48; diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index ea7a3730..21caebf1 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -9,22 +9,23 @@ namespace optiling { TilingData2 tiling; - size_t usrSize = 256; - size_t sysWorkspaceSize = 16 * 1024 * 1024; + size_t usrSize = 256, sysWorkspaceSize = 16 * 1024 * 1024; if (context == nullptr) { printf("Update embbeding_type context nullptr\n"); return ge::GRAPH_FAILED; } size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (currentWorkspace == nullptr) { + 
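+        // GetWorkspaceSizes may hand back a null pointer; return early here rather than
+        // dereferencing it in the currentWorkspace[0] assignment below.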
printf("currentWorkspace nullptr\n"); + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = sysWorkspaceSize + usrSize; int32_t block_total_nums = 48; int32_t ub_limit = 175 * 1024; - - int32_t update_dim; - int32_t embbeding_type; - + int32_t update_dim, embbeding_type; int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); if (input_shape <= 0) { printf("input_shape must larger than 0\n"); @@ -32,22 +33,27 @@ namespace optiling } int32_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; + if (context->GetAttrs()->GetAttrPointer(0) == nullptr) { + printf("context GetAttrs GetAttrPointer nullptr\n"); + return ge::GRAPH_FAILED; + } + int32_t update_type = *(context->GetAttrs()->GetAttrPointer(0)); ge::DataType input_datatype = context->GetInputTensor(1)->GetDataType(); - - switch (input_datatype) { - case ge::DT_FLOAT16: - embbeding_type = 2; - break; - case ge::DT_INT32: - embbeding_type = 0; - break; - default: - embbeding_type = 1; - break; + if (input_datatype == ge::DT_FLOAT16) { + embbeding_type = 2; + } else if (input_datatype == ge::DT_INT32) { + embbeding_type = 0; + } else { + embbeding_type = 1; } update_dim = input_dim; + if (update_dim <= 0) { + printf("update_dim must larger than 0\n"); + return ge::GRAPH_FAILED; + } + tiling.set_update_type(update_type); tiling.set_embbeding_type(embbeding_type); tiling.set_update_dim(update_dim); -- Gitee From e65428066fd5a556a23ea6c09ab134cd3e83e76e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 09:42:33 +0800 Subject: [PATCH 331/551] Match-id-d3e518ee7a7d6e82f7ca000fdf7838a38606b536 --- mx_rec/core/embedding.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index db7f3076..64a6365c 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -24,7 +24,8 @@ from mx_rec.constants.constants import (ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SP from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ - get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set + get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \ + get_table_instance_by_name from mx_rec.validator.validator import ClassValidator, StringValidator from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.normalization import fix_invalid_table_name @@ -317,6 +318,27 @@ class SparseEmbedding: return tf.reshape(src_emb, reshape_info) + @staticmethod + def get_emb_table_size(table_name: str) -> int: + """ + For HBM or DDR mode, return the size of sparse embedding table + :param table_name: the name of sparse embedding table + :return: the size of the sparse embedding table + """ + table_instance = get_table_instance_by_name(table_name) + host_vocabulary_size = table_instance.host_vocabulary_size() + device_vocabulary_size = table_instance.device_vocabulary_size + if not host_vocabulary_size and not get_use_dynamic_expansion(): + embed_dim = table_instance.emb_size + size = embed_dim * device_vocabulary_size + elif not host_vocabulary_size and get_use_dynamic_expansion(): + embed_dim = table_instance.ext_emb_size + size = embed_dim * device_vocabulary_size + 
else: + embed_dim = table_instance.ext_emb_size + size = (device_vocabulary_size + host_vocabulary_size) * embed_dim + return size + def check_optimizer_instance(self): for optimizer_instance in self._optimizer_instance_list: if tf.__version__.startswith("1"): -- Gitee From bdc8c09a7c8c856fd96f6eb10c9b8cb05d8c3be2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 10:49:21 +0800 Subject: [PATCH 332/551] Match-id-107ef14199873b0dd9b8d05578f6e0a0f6e7af49 --- src/core/utils/log.cpp | 3 ++- src/core/utils/log.h | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/core/utils/log.cpp b/src/core/utils/log.cpp index c75545ee..2314a581 100644 --- a/src/core/utils/log.cpp +++ b/src/core/utils/log.cpp @@ -41,7 +41,8 @@ const char* Log::LevelToStr(int level) "WARN", "ERROR", }; - return msg[level]; + constexpr int LEVEL_OFFSET = 2; + return msg[level + LEVEL_OFFSET]; } void Log::LogUnpack(queue& fmt, stringstream &ss) diff --git a/src/core/utils/log.h b/src/core/utils/log.h index c1078510..3f3bf7bf 100644 --- a/src/core/utils/log.h +++ b/src/core/utils/log.h @@ -28,11 +28,11 @@ constexpr size_t DELIM_LEN = 2; class Log { public: - static constexpr int TRACE = 0; - static constexpr int DEBUG = 1; - static constexpr int INFO = 2; - static constexpr int WARN = 3; - static constexpr int ERROR = 4; + static constexpr int TRACE = -2; + static constexpr int DEBUG = -1; + static constexpr int INFO = 0; + static constexpr int WARN = 1; + static constexpr int ERROR = 2; static void SetRank(int rank); -- Gitee From 6cf7ddef4fb9de600025c7de84f5dfe6c1cefc94 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 5 Sep 2023 15:08:25 +0800 Subject: [PATCH 333/551] Match-id-09aba9f6f38e5e904d2fb5f7066c4720570b433c --- mx_rec/core/asc/build_graph.py | 56 ++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index d2540082..1841f2d9 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -9,7 +9,7 @@ import tensorflow as tf import mxrec_pybind from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops -from mx_rec.constants.constants import ApplyGradientsStrategy +from mx_rec.constants.constants import ApplyGradientsStrategy, TRAIN_CHANNEL_ID def get_restore_vector(config): @@ -138,20 +138,20 @@ def get_all2all_args(use_static: bool, config: dict) -> list: return all2all_args -def get_preprocessed_tensor_for_asc(table, config): +def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variable) -> list: + """ + Get swap info if threshold is configured. 
+ :param config: training job config + :param swap_len: swap length + :param swap_pos: swap position + :param table: the instance to do swap + :return: swap info + """ use_static = get_use_static() max_lookup_vec_size = None if use_static: max_lookup_vec_size = config.get("send_count") * config.get("rank_size") - with tf.compat.v1.variable_scope("restore_vector"): - restore_vector, hot_pos = get_restore_vector(config) - - with tf.compat.v1.variable_scope("id_offsets"): - id_offsets, swap_pos, swap_len = get_id_offsets(max_lookup_vec_size, config) - - all2all_args = get_all2all_args(use_static, config) - if config.get("skip_emb_transfer"): swap_in = [tf.no_op()] else: @@ -179,6 +179,24 @@ def get_preprocessed_tensor_for_asc(table, config): h2d_emb_split = tf.split(h2d_emb, table_num, axis=1) swap_in = [tf.compat.v1.scatter_nd_update(table[i], nd_swap_pos, h2d_emb_split[i]) for i in range(len(table))] + return swap_in + + +def get_preprocessed_tensor_for_asc(table, config): + use_static = get_use_static() + max_lookup_vec_size = None + if use_static: + max_lookup_vec_size = config.get("send_count") * config.get("rank_size") + + with tf.compat.v1.variable_scope("restore_vector"): + restore_vector, hot_pos = get_restore_vector(config) + + with tf.compat.v1.variable_scope("id_offsets"): + id_offsets, swap_pos, swap_len = get_id_offsets(max_lookup_vec_size, config) + + all2all_args = get_all2all_args(use_static, config) + + swap_in = get_swap_info(config, swap_len, swap_pos, table) result = { 'restore_vector': restore_vector, @@ -188,10 +206,16 @@ def get_preprocessed_tensor_for_asc(table, config): 'all2all_args': all2all_args, } - if config.get("gradients_strategy") == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: - with tf.compat.v1.variable_scope("restore_vector_second"): - restore_vector_second = get_restore_vector_second(max_lookup_vec_size, config) - with tf.compat.v1.variable_scope("unique_keys"): - unique_keys = get_unique_keys(max_lookup_vec_size, config) - result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) + if config.get("gradients_strategy") != ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + return result + + if config.get("channel_id") != TRAIN_CHANNEL_ID: + return result + + with tf.compat.v1.variable_scope("restore_vector_second"): + restore_vector_second = get_restore_vector_second(max_lookup_vec_size, config) + + with tf.compat.v1.variable_scope("unique_keys"): + unique_keys = get_unique_keys(max_lookup_vec_size, config) + result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) return result -- Gitee From 699e409aaa2a07cd727b01aa2f991998350e3ac5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 15:20:27 +0800 Subject: [PATCH 334/551] Match-id-e1af324093b685706ebbd8a80b53daa62722453b --- build/build.sh | 4 ++-- build/build_tf2.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build/build.sh b/build/build.sh index b9544a58..8a2dd5d4 100644 --- a/build/build.sh +++ b/build/build.sh @@ -90,7 +90,7 @@ then gen_tar_file echo "-----Build gen tar finished-----" - clean + # clean echo "-----Done-----" fi @@ -102,6 +102,6 @@ then gen_tar_file echo "-----Build gen tar finished-----" - clean + # clean echo "-----Done-----" fi \ No newline at end of file diff --git a/build/build_tf2.sh b/build/build_tf2.sh index a42aeeb2..e3d01417 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -118,7 +118,7 @@ gen_wheel_file() python3.7 setup.py bdist_wheel 
--plat-name=linux_$(arch) mkdir -p "$1" mv dist/mx_rec*.whl "$1" - remove "${ROOT_DIR}"/mx_rec/libasc + # remove "${ROOT_DIR}"/mx_rec/libasc } gen_tar_file() -- Gitee From b1cf9e38dac273900cec717ee0ac115afda906b5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 17:01:38 +0800 Subject: [PATCH 335/551] Match-id-7be77becb9fdfdd09fe1ee9eaca6063959a3de5a --- mx_rec/__init__.py | 3 +- mx_rec/core/asc/build_graph.py | 53 ++++---- mx_rec/core/embedding.py | 29 ++--- mx_rec/graph/patch.py | 123 +++++++++++++++++- mx_rec/util/initialize.py | 4 + src/core/hybrid_mgmt/hybrid_mgmt.cpp | 11 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 67 +--------- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 9 +- src/core/key_process/key_process.cpp | 12 +- src/core/utils/common.h | 6 + src/pybind/module_main.cpp | 3 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 4 +- 13 files changed, 202 insertions(+), 125 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index f5d02cd0..9c1157fd 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -8,7 +8,7 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \ - patch_for_end, patch_for_assert_eval_spec, patch_for_scale_loss + patch_for_end, patch_for_assert_eval_spec, patch_for_scale_loss, patch_for_session from mx_rec.optimizers.base import patch_for_optimizer patch_for_saver() @@ -19,6 +19,7 @@ patch_for_assert_eval_spec() patch_for_bool_gauge() patch_for_end() patch_for_optimizer() +patch_for_session() __version__ = "5.0.RC2" diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index 1fd17219..9ecd5108 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -34,21 +34,21 @@ def get_restore_vector(config): restore_size = config.get("batch_size") * config.get("feat_cnt") else: restore_size = None - with tf.control_dependencies([config.get("notify_hybridmgmt_op")]): - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if use_hot and emb_size: - device_id = int(config.get("device_id")) - hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) - restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32, tf.int32], - output_shapes=[restore_size, [hot_size]], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}' - ) - else: - restore_vector = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[restore_size], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0] + + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + if use_hot and emb_size: + device_id = int(config.get("device_id")) + hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) + restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[restore_size, [hot_size]], + channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}' + ) + else: + restore_vector = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32], + output_shapes=[restore_size], + channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0] return restore_vector, hot_pos @@ -123,20 
+123,19 @@ def get_all2all_args(use_static: bool, config: dict) -> list: :param config: embedding config :return: all2all parametrs """ - all2all_args = None if use_static: return all2all_args - with tf.control_dependencies([config.get("notify_hybridmgmt_op")]): - with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - with tf.compat.v1.variable_scope("all2all"): - logging.debug( - f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') - all2all_args = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int64], - output_shapes=[[config.get("rank_size"), config.get("rank_size")]], - channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', - name="a2a_get_next")[0] * config.get("emb_size") + + with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): + with tf.compat.v1.variable_scope("all2all"): + logging.debug( + f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') + all2all_args = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int64], + output_shapes=[[config.get("rank_size"), config.get("rank_size")]], + channel_name=f'{config.get("table_name")}_all2all_{config.get("channel_id")}', + name="a2a_get_next")[0] * config.get("emb_size") return all2all_args @@ -221,4 +220,4 @@ def get_preprocessed_tensor_for_asc(table, config): with tf.compat.v1.variable_scope("unique_keys"): unique_keys = get_unique_keys(max_lookup_vec_size, config) result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) - return result + return result \ No newline at end of file diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 64a6365c..4c94d675 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -673,13 +673,8 @@ class SparseEmbedding: channel_id: channel id 0 for train,1 for eval Returns: npu_ops.outfeed_enqueue_op notify preprocess step """ - sparse_lookup_id = ConfigInitializer.get_instance().notify_hybrid_channel_sparse_id[channel_id] - notify_message = tf.constant([sparse_lookup_id], dtype=tf.int32) - ConfigInitializer.get_instance().notify_hybrid_channel_sparse_id[channel_id] += 1 channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id) - logging.debug("%s was built for op outfeed sparse id : %s.", channel_name, sparse_lookup_id) - notify_hybridmgmt_op = npu_ops.outfeed_enqueue_op( - channel_name=channel_name, inputs=[notify_message]) + notify_hybridmgmt_op = tf.no_op(channel_name) return notify_hybridmgmt_op def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs): @@ -709,21 +704,23 @@ class SparseEmbedding: channel_id = get_training_mode_channel_id(is_training=is_training) logging.debug(f"get preprocessed tensor for asc for table {self.table_name} with skip emb transfer " f"{self.skip_emb_transfer} is_training: {is_training}, channel_id: {channel_id} .") - # 通知c++侧此处开始执行sparse look up的逻辑,注意此处每一个tablename做一次 - notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id) - # 将notify_hybridmgmt_op加入到config中,在restore的get next 算子中做控制依赖 + config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, rank_size=rank_size, channel_id=channel_id, table_name=self.table_name, skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size, emb_size=self.emb_size, use_hot=use_hot, device_id=device_id, - use_dynamic_expansion=use_dynamic_expansion, 
                      gradients_strategy=self.apply_gradients_strategy,
-                      notify_hybridmgmt_op=notify_hybridmgmt_op)
+                      use_dynamic_expansion=use_dynamic_expansion, gradients_strategy=self.apply_gradients_strategy)

-        if self.skip_emb_transfer:
-            result = get_preprocessed_tensor_for_asc(self.variable, config)
-        else:
-            variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list]
-            result = get_preprocessed_tensor_for_asc(variable_list, config)
+        # Marker op: its name identifies whether this sparse lookup belongs to train or eval.
+        # Later, at session.run time, the op is located by a backward graph search over the
+        # fetched subgraph; its name tells which channel the run targets, so the C++ side
+        # can be notified to count steps and wake up.
+        notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id)
+        with tf.control_dependencies([notify_hybridmgmt_op]):
+            if self.skip_emb_transfer:
+                result = get_preprocessed_tensor_for_asc(self.variable, config)
+            else:
+                variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list]
+                result = get_preprocessed_tensor_for_asc(variable_list, config)
         restore_vector = result.get("restore_vector")
         restore_vector_second = result.get("restore_vector_second")
         hot_pos = result.get("hot_pos")
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
index 65683f5b..1ca73423 100644
--- a/mx_rec/graph/patch.py
+++ b/mx_rec/graph/patch.py
@@ -18,9 +18,10 @@
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.training.optimizer import Optimizer
+from tensorflow.python.client.session import BaseSession

 from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \
-    get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round
+    get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager
 from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook
 from mx_rec.graph.merge_lookup import do_merge_lookup

@@ -38,10 +39,130 @@ def init_dataset(self, input_data):
     self._graph_attr = ops.get_default_graph()


+def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
+    """
+    Replacement for TensorFlow's Session.run: besides executing the fetches, it
+    notifies the hybridMgmt side to wake up and advance its step count on every call.
+
+    Args:
+        fetches: A single graph element, a list of graph elements, or a dictionary
+            whose values are graph elements or lists of graph elements (described above).
+        feed_dict: A dictionary that maps graph elements to values (described above).
+        options: A [`RunOptions`] protocol buffer
+        run_metadata: A [`RunMetadata`] protocol buffer
+
+    Returns:
+        Either a single value if `fetches` is a single graph element, or a list of
+        values if `fetches` is a list, or a dictionary with the same keys as `fetches`
+        if that is a dictionary (described above). Order in which `fetches` operations
+        are evaluated inside the call is undefined.
+
+    Raises:
+        RuntimeError: If this `Session` is in an invalid state (e.g. has been closed).
+        TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
+        ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
+            `Tensor` that doesn't exist.
+ Returns:None + """ + + all_op = [] + + def get_all_tensor(tensor_or_tensorlist): + # 把所有的tensor和Operation取出来 + if isinstance(tensor_or_tensorlist, (list, tuple)) : + for i in tensor_or_tensorlist: + get_all_tensor(i) + elif isinstance(tensor_or_tensorlist, dict): + for k in tensor_or_tensorlist.keys(): + get_all_tensor(tensor_or_tensorlist.get(k)) + elif isinstance(tensor_or_tensorlist, (tf.Tensor, tf.Operation)): + name = tensor_or_tensorlist.name + if ":" in name: + name = name[:name.find(":")] + all_op.append(name) + + def get_channel_id_by_sub_graph(input_tensors, name2channel_cache): + # 通过fetches需要运行的节点来找到 spase look up中的打桩tensor + # 从而判断该session run运行的是train还是eval + name_list_str_key = "_".join(input_tensors) + if name_list_str_key in name2channel_cache.keys(): + return name2channel_cache.get(name_list_str_key) + this_channel_id = -1 + graph_def = self.graph_def + cut_graph_input = tf.compat.v1.graph_util.extract_sub_graph(graph_def, input_tensors) + node_list_input = cut_graph_input.node + for node in node_list_input: + if "d2h_notify_hybridmgmt_" in node.name: + this_channel_id = int(node.name[-1]) + break + name2channel_cache[name_list_str_key] = this_channel_id + return this_channel_id + + # patch的方式为session增加步数属性 + step = self.get_mxrec_steps() + # 进行缓存,避免每次都进行图查询 + if not step: + step = 1 + for custom_optimizer in self.get_config().graph_options.rewrite_options.custom_optimizers: + if custom_optimizer.name == "NpuOptimizer": + step = custom_optimizer.parameter_map["iterations_per_loop"].i + break + self.steps = step + + # patch的方式为图增加缓存属性 + name2channel_cache = self.get_mxrec_name2channel_cache() + + # 查找相应的channel_id + try: + get_all_tensor(fetches) + channel_id = get_channel_id_by_sub_graph(all_op, name2channel_cache) + except AssertionError: + channel_id = -1 + + if channel_id != -1: + get_asc_manager().send_message_to_hybrid(channel_id, self.steps) + + #调用tensorflow原生的方法 + return self.old_run_method(fetches, feed_dict, options, run_metadata) + + def patch_for_dataset(): DatasetV2.__init__ = init_dataset +def patch_for_session(): + + def get_mxrec_steps(self): + try: + # 不能在未调用非__init__函数之前调用非__init__中定义的实例化属性 + return self.steps + except AttributeError: + self.steps = None + return self.steps + + def get_mxrec_name2channel_cache(self): + try: + # 不能在未调用非__init__函数之前调用非__init__中定义的实例化属性 + return self.name2channel_cache + except AttributeError: + self.name2channel_cache = {} + return self.name2channel_cache + + def get_config(self): + return getattr(self, '_config') + + BaseSession.old_run_method = BaseSession.run + BaseSession.run = run + BaseSession.get_mxrec_name2channel_cache = get_mxrec_name2channel_cache + BaseSession.get_mxrec_steps = get_mxrec_steps + BaseSession.get_config = get_config + + def chief_session_creator_init(self, scaffold=None, master='', config=None, checkpoint_dir=None, checkpoint_filename_with_path=None): """ diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index e9804c43..47093085 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -515,6 +515,10 @@ def set_asc_manager(manager): ConfigInitializer.get_instance().set_asc_manager(manager) +def get_asc_manager(): + return ConfigInitializer.get_instance().get_asc_manager() + + def trigger_evict(): if not is_asc_manager_initialized(): raise RuntimeError("ASC manager does not exist.") diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 28262f8c..ea560ad1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ 
b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -127,7 +127,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hybridMgmtBlock = Singleton::GetInstance(); hybridMgmtBlock->SetRankInfo(rankInfo); - hybridMgmtBlock->StartNotifySignalMonitor(); + // 启动数据处理线程 bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, seed); if (!rc) { @@ -1029,4 +1029,13 @@ int HybridMgmt::GetStepFromPath(const string& loadPath) return stoi(match[1]); } return 0; +} + +/// 通过pyBind在python侧调用,通知hybridMgmt上层即将进行图的执行 +/// \param channelID 通道id +/// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 +void HybridMgmt::CallBySessionRun(int channelID, int steps) +{ + hybridMgmtBlock->CheckAndNotifyWake(channelID); + hybridMgmtBlock->CountPythonStep(channelID, steps); } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index f7d24c80..93d34009 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -115,6 +115,8 @@ namespace MxRec { bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount); + void CallBySessionRun(int channelID, int steps); + private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed); @@ -128,7 +130,6 @@ namespace MxRec { int GetStepFromPath(const string& loadPath); static void AddCacheManagerTraceLog(CkptData& saveData); void RestoreFreq4Save(CkptData& saveData); - private: int currentBatchId; int trainBatchId = 0; // 0-199, 200- diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index ea95aec9..4d39c571 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -32,6 +32,10 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId) /// \param channelId train 0 eval 1 void HybridMgmtBlock::CheckAndNotifyWake(int channelId) { + VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + + "start notify channelId %d pythonBatchId %d hybridBatchId %d", + channelId, pythonBatchId[lastRunChannelId], hybridBatchId[channelId]); + CheckValid(channelId); if (pythonBatchId[channelId] >= hybridBatchId[channelId]) { isBlock[channelId] = false; @@ -63,10 +67,10 @@ bool HybridMgmtBlock::WaitValid(int channelId) } } -void HybridMgmtBlock::CountPythonStep(int channelId) +void HybridMgmtBlock::CountPythonStep(int channelId, int steps) { // 相应的通知计数 - pythonBatchId[channelId]++; + pythonBatchId[channelId] += steps; } /// 检查是否进行了通道切换,检查当前的step是否合理 @@ -139,9 +143,6 @@ void HybridMgmtBlock::ResetAll(int channelId) pythonBatchId[channelId] = 0; hybridBatchId[channelId] = 0; isBlock[channelId] = false; - // eval train通道的sparse 同时进行重置,以防出现sparse id失效的问题 - uniqueSparseLookID[EVAL_CHANNEL_ID] = -1; - uniqueSparseLookID[TRAIN_CHANNEL_ID] = -1; } /// 检查当前的步数是否可以进行save @@ -183,55 +184,6 @@ void HybridMgmtBlock::SetBlockStatus(int channelId, bool block) isBlock[channelId] = block; } -/// python侧调用的npu.outfeed_enqueue_op 发送的消息。用来判断当前python执行的步数 -void HybridMgmtBlock::StartNotifySignalMonitor() -{ -#ifndef GTEST - auto fn = [this](int channelId) { - while (isRunning) { - std::vector tensors; - tensorflow::Status status = tensorflow::RecvTensorByAcl(aclHandles[channelId], tensors); - if (!isRunning) { - break; - } - if (status != tensorflow::Status::OK()) { - LOG(ERROR) << StringFormat(HYBRID_BLOCKING + - "%s hd recv error '%s'", d2hChannelName[channelId].c_str(), status.error_message().c_str()); - throw runtime_error("rev error"); 
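// Context for this removal: the StartNotifySignalMonitor thread deleted here is
// superseded by HybridMgmt::CallBySessionRun (added above), which the patched Python
// Session.run reaches directly through the pybind binding send_message_to_hybrid, so
// per-step counting no longer needs an acltdt device-to-host notify channel.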
- } - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "send message to hybrid channelId %d pythonBatchId %d hybridBatchId %d", - channelId, pythonBatchId[channelId], hybridBatchId[channelId]); - - int sparseLookupId = *tensors[0].flat().data(); - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "send sparse_lookup_id channel %d sparse id %d unique id %d", - channelId, sparseLookupId, uniqueSparseLookID[channelId]); - - if (uniqueSparseLookID[channelId] == -1) { - // 初始化,只有第一个sparse loop id能进行计数和唤 - uniqueSparseLookID[channelId] = sparseLookupId; - } - // 只被计数一次 - if (sparseLookupId == uniqueSparseLookID[channelId]) { - // 只有最先来的id才能进行唤醒和计数 - CheckAndNotifyWake(channelId); - CountPythonStep(channelId); - } - } - LOG(INFO) << StringFormat(HYBRID_BLOCKING + "BLOCKING thread stop"); - }; - uint32_t localRankId = rankInfo.deviceId; - for (int channelId = 0; channelId < MAX_CHANNEL_NUM; ++channelId) { - d2hChannelName[channelId] = StringFormat(D2H_CHANNEL_NAME_PRE + "%d", channelId); - auto aclChannelHandle = tdtCreateChannel(localRankId, d2hChannelName[channelId].c_str(), PING_PONG_SIZE); - LOG(INFO) << StringFormat(HYBRID_BLOCKING + " %d %s", localRankId, d2hChannelName[channelId].c_str()); - aclHandles[channelId] = aclChannelHandle; - procThreads.emplace_back(std::make_unique(fn, channelId)); - } -#endif -} - void HybridMgmtBlock::Destroy() { if (!isRunning) { @@ -239,13 +191,6 @@ void HybridMgmtBlock::Destroy() return; } isRunning = false; -#ifndef GTEST - for (int channelId = 0; channelId < MAX_CHANNEL_NUM; ++channelId) { - tensorflow::StopRecvTensorByAcl(&aclHandles[channelId], d2hChannelName[channelId]); - procThreads[channelId]->join(); - } - LOG(INFO) << StringFormat(HYBRID_BLOCKING + "BLOCKING stop"); -#endif } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index e72a8523..5b0821f9 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -28,12 +28,10 @@ public: // readEmbed算子侧将要处理的batch id int readEmbedBatchId[2] = {0, 0}; bool isRunning = true; - // 每个sparse lookup都会生成一个唯一的id,保证每次运行只有一个id在进行计数 - int uniqueSparseLookID[2]{-1, -1}; ~HybridMgmtBlock(); void CheckAndNotifyWake(int channelId); - void CountPythonStep(int channelId); + void CountPythonStep(int channelId, int steps); void CheckAndSetBlock(int channelId); void CheckValid(int channelId); void DoBlock(int channelId); @@ -43,7 +41,6 @@ public: void SetBlockStatus(int channelId, bool block); void SetRankInfo(RankInfo rankInfo); void SetStepInterval(int trainStep, int evalStep); - void StartNotifySignalMonitor(); bool WaitValid(int channelId); void Destroy(); private: @@ -51,10 +48,7 @@ private: int stepsInterval[2] = {0, 0}; // 控制通道阻塞的变量 bool isBlock[2] = {true, true}; - string d2hChannelName[2]; RankInfo rankInfo; - acltdtChannelHandle* aclHandles[2]; - std::vector> procThreads {}; }; class HybridMgmtBlockingException : public std::exception { @@ -62,7 +56,6 @@ public: explicit HybridMgmtBlockingException(const string scene) { HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); - // int channelId, int preprocessBatchNumber, int currentBatchNumber int channelId = hybridMgmtBlock->lastRunChannelId; int preprocessBatchNumber = hybridMgmtBlock->hybridBatchId[channelId]; int currentBatchNumber = hybridMgmtBlock->pythonBatchId[channelId]; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 10e2889a..69f2278b 100644 --- a/src/core/key_process/key_process.cpp +++ 
b/src/core/key_process/key_process.cpp @@ -345,6 +345,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, int channel, int threadId) { + std::lock_guard lock(loadSaveMut[channel][threadId]); // tuple for keyRec restore hotPos scAll countRecv isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE; @@ -371,9 +372,8 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); LOG_DEBUG("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } - if (!rankInfo.useStatic) { // Static all2all,need send count - SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); - } + // Static all2all,need send count + if (!rankInfo.useStatic) { SendA2A(uniqueInfo.all2AllInfo.scAll, batch->name, batch->channel, batch->batchId); } auto tensors = make_unique>(); tensors->push_back(Vec2TensorI32(uniqueInfo.restore)); @@ -400,6 +400,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { + std::lock_guard lock(loadSaveMut[channel][threadId]); vector splitKeys; vector restore; vector hotPos; @@ -436,9 +437,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe } } - if (!rankInfo.useStatic) { // Static all2all,need send count - SendA2A(scAll, batch->name, batch->channel, batch->batchId); - } + // Static all2all,need send count + if (!rankInfo.useStatic) { SendA2A(scAll, batch->name, batch->channel, batch->batchId); } TimeCost pushResultTC; auto tensors = make_unique>(); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 55f859ba..beae701e 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -504,12 +504,18 @@ namespace MxRec { keys_t keyRecv; vector scAll; vector countRecv; + All2AllInfo() = default; + All2AllInfo(keys_t keyRecv, vector scAll, vector countRecv) + : keyRecv(keyRecv), scAll(scAll), countRecv(countRecv) {} }; struct UniqueInfo { vector restore; vector hotPos; All2AllInfo all2AllInfo; + UniqueInfo() = default; + UniqueInfo(vector restore, vector hotPos, All2AllInfo all2AllInfo) + : restore(restore), hotPos(hotPos), all2AllInfo(all2AllInfo) {} }; struct KeySendInfo { diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 70559c83..403db042 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -164,7 +164,8 @@ void GetHybridMgmt(pybind11::module_& m) .def("destroy", &MxRec::HybridMgmt::Destroy) .def("evict", &MxRec::HybridMgmt::Evict) .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") - .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")); + .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) + .def("send_message_to_hybrid", &MxRec::HybridMgmt::CallBySessionRun, py::arg("channel_id"), py::arg("steps")=1); } void GetThresholdValue(pybind11::module_& m) diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index 380ed5f7..b3b0213d 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -44,12 +44,12 @@ TEST_F(HybridMgmtBlockTest, CountAndNotifyWake) hybridMgmtBlock = std::make_unique(); hybridMgmtBlock->SetStepInterval(1, 1); 
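// Note on the test changes below: CountPythonStep now takes an explicit step count.
// With loop sinking, a single session.run() can advance iterations_per_loop steps,
// so the Python side reports how many steps to add instead of an implicit one.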
hybridMgmtBlock->CheckAndNotifyWake(0); - hybridMgmtBlock->CountPythonStep(0); + hybridMgmtBlock->CountPythonStep(0, 1); hybridMgmtBlock->pythonBatchId[0] = 1; hybridMgmtBlock->hybridBatchId[0] = 0; auto fn = [this](int channelId) { hybridMgmtBlock->CheckAndNotifyWake(channelId); - hybridMgmtBlock->CountPythonStep(0); + hybridMgmtBlock->CountPythonStep(0, 1); return 0; }; procThreads.emplace_back(std::make_unique(fn, 0)); -- Gitee From 52b7b302c86911b4b5bae7a297c949a805c66d97 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 8 Sep 2023 09:57:46 +0800 Subject: [PATCH 336/551] Match-id-62dd33dc710b3667bd19879faa5e5b568b7f4665 --- src/core/ssd_engine/table.cpp | 35 +++++++++++++++++++---------------- src/core/ssd_engine/table.h | 5 ++++- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index 212de4e7..6aca62f6 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -217,22 +217,7 @@ void Table::InsertEmbeddingsInner(vector &keys, vector> } if (curFile == nullptr || (curFile != nullptr && curFile->GetDataCnt() >= maxDataNumInFile)) { - // leave diskFreeSpaceThreshold % space for each disk - while (true) { - fs::space_info si = fs::space((curTablePath)); - if ((double(si.free) / double(si.capacity)) > diskFreeSpaceThreshold) { - break; - } - - curSavePathIdx += 1; - if (curSavePathIdx >= savePaths.size()) { - throw runtime_error("all disk's space not enough"); - } - curTablePath = savePaths[curSavePathIdx]; - LOG_INFO("current data path's free space less than {}, try next path:{}", - diskFreeSpaceThreshold, curTablePath); - } - + SetValidPath(); curFile = make_shared(curMaxFileID, curTablePath); fileSet.insert(curFile); curMaxFileID++; @@ -350,3 +335,21 @@ void Table::DeleteEmbeddingsInner(vector &keys) } } } + +void Table::SetValidPath() +{ + while (true) { + fs::space_info si = fs::space((curTablePath)); + if ((double(si.available) / double(si.capacity)) > diskAvailSpaceThreshold) { + break; + } + + curSavePathIdx += 1; + if (curSavePathIdx >= savePaths.size()) { + throw runtime_error("all disk's space not enough"); + } + curTablePath = savePaths[curSavePathIdx]; + LOG_INFO("current data path's free space less than {}%, try next path:{}", + diskAvailSpaceThreshold * convertToPercentage, curTablePath); + } +} diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index 679f702d..f82be2df 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -50,6 +50,8 @@ namespace MxRec { void LoadDataFileSet(const shared_ptr& metaFile, int step); + void SetValidPath(); + string name; // init by constructor vector savePaths; // init by constructor, support Save and Load from multiple path uint64_t maxTableSize; // init by constructor, maximum key-value volume @@ -64,6 +66,7 @@ namespace MxRec { uint64_t curMaxFileID = 0; // no concurrent writing, always atomic increase const uint32_t maxNameSize = 1024; const string saveDirPrefix = "ssd_sparse_model_rank_"; + const int convertToPercentage = 100; /* args for performance(not expose to user yet) * 2 read thread is optimal when: @@ -76,7 +79,7 @@ namespace MxRec { int readThreadNum = 2; uint32_t maxDataNumInFile = 10000; // relax constrain for performance, need tuning double compactThreshold = 0.5; - double diskFreeSpaceThreshold = 0.05; // in range [0, 1), leave diskFreeSpaceThreshold*100 % for disk space + double diskAvailSpaceThreshold = 0.05; // in range [0, 1), leave diskAvailSpaceThreshold*100 % for disk space }; } 
-- Gitee From 7c88a000b803b266aab8d1aefa685531f2d6faa5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 20:57:52 +0800 Subject: [PATCH 337/551] Match-id-b5b57c7f308070e028de69f0e4315d5425aaa9cd --- .../op_host/embedding_update_by_address.cpp | 11 ----------- .../op_kernel/embedding_lookup_by_address.cpp | 2 +- .../op_kernel/embedding_update_by_address.cpp | 2 +- cust_op/cust_op_by_addr/run.sh | 4 ++-- 4 files changed, 4 insertions(+), 15 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index 21caebf1..b75a5912 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -71,17 +71,6 @@ namespace ge { static ge::graphStatus InferShape(gert::InferShapeContext *context) { - gert::Shape *y_shape = context->GetOutputShape(0); - int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); - if (input_shape <= 0) { - printf("input_shape must larger than 0\n"); - return GRAPH_FAILED; - } - - int64_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; - y_shape->SetDimNum(2); - y_shape->SetDim(0, input_shape); - y_shape->SetDim(1, input_dim); return GRAPH_SUCCESS; } static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index bd43a9b5..84548729 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -50,7 +50,7 @@ public: PingpongNum = 1; int min_move_num = 32 / singleDataSize; // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + min_move_num) / min_move_num表示除以min_move_num向下取整 - int onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); + onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums; // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2; diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index 6b1c52cc..b35a56c2 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -48,7 +48,7 @@ public: PingpongNum = 1; int min_move_num = 32 / singleDataSize; // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + min_move_num) / min_move_num表示除以min_move_num向下取整 - int onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); + onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums; // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2; diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh index e994c1b2..86ca88cf 100644 --- a/cust_op/cust_op_by_addr/run.sh +++ b/cust_op/cust_op_by_addr/run.sh @@ -14,8 +14,8 @@ export PATH=$parent_dir:$PATH # 利用msopgen生成可编译文件 rm -rf ./custom_op 
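# Two notes on this patch: in the op_kernel hunks above, dropping the local "int"
# declaration makes onceMoveNums assign the variable declared in the enclosing scope
# rather than a shadowing local, so the computed value stays visible after this block;
# and the msopgen lines below retarget the soc-version from ai_core-ascend910b to the
# concrete chip variant ai_core-ascend910b1.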
-msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 0 -op EmbeddingLookupByAddress -msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b -lan cpp -out ./custom_op -m 1 -op EmbeddingUpdateByAddress +msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 0 -op EmbeddingLookupByAddress +msopgen gen -i emb_custom.json -f tf -c ai_core-ascend910b1 -lan cpp -out ./custom_op -m 1 -op EmbeddingUpdateByAddress cp -rf op_kernel custom_op/ cp -rf op_host custom_op/ -- Gitee From fee9c1d879c6fec549eface832004e4d5a3a86ae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Sep 2023 18:13:05 +0800 Subject: [PATCH 338/551] Match-id-99ea283c93c9011e0d10c46db1088fa2d055a248 --- src/core/checkpoint/checkpoint.cpp | 68 +++--- .../ckpt_data_handler/ckpt_data_handler.cpp | 6 +- .../feat_admit_n_evict_ckpt.cpp | 9 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 9 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 4 +- .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 4 +- src/core/emb_hashmap/emb_hashmap.cpp | 16 +- src/core/emb_table/emb_table.cpp | 33 ++- src/core/hd_transfer/hd_transfer.cpp | 66 +++-- src/core/host_emb/host_emb.cpp | 58 ++--- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 229 ++++++++---------- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 69 +++--- .../constant_initializer.cpp | 3 +- .../random_normal_initializer.cpp | 3 +- .../truncated_normal_initializer.cpp | 3 +- .../key_process/feature_admit_and_evict.cpp | 20 +- src/core/key_process/key_process.cpp | 35 ++- src/core/ssd_cache/cache_manager.cpp | 15 +- src/core/ssd_cache/lfu_cache.cpp | 2 +- src/core/ssd_engine/file.cpp | 20 +- src/core/ssd_engine/table.cpp | 4 +- src/core/utils/common.cpp | 6 + src/core/utils/common.h | 4 +- src/ops_tf/hybrid_dataset_ops.cpp | 53 ++-- src/tests/emb_mgmt/emb_mgmt_test.cpp | 18 +- src/tests/emb_table/emb_table_test.cpp | 11 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 4 +- src/tests/ssd_cache/cache_manager_test.cpp | 12 +- src/tests/ssd_engine/file_test.cpp | 4 +- src/tests/utils/log_test.cpp | 14 ++ 30 files changed, 364 insertions(+), 438 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 1c18b456..4bbcaad7 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -32,12 +32,12 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; mgmtEmbInfo = EmbInfo; - LOG(INFO) << "Start host side saving data."; - VLOG(GLOG_DEBUG) << "==Start to create save data handler."; + LOG_INFO("Start host side saving data."); + LOG_DEBUG("==Start to create save data handler."); SetDataHandler(ckptData); - VLOG(GLOG_DEBUG) << "==Start save data process."; + LOG_DEBUG("==Start save data process."); SaveProcess(ckptData); - LOG(INFO) << "Finish host side saving data."; + LOG_INFO("Finish host side saving data."); } void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, @@ -49,12 +49,12 @@ void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRa useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; mgmtEmbInfo = EmbInfo; - LOG(INFO) << "Start host side loading data."; - VLOG(GLOG_DEBUG) << "==Start to create load data handler."; + LOG_INFO("Start host side loading data."); + LOG_DEBUG("==Start to create load data handler."); SetDataHandler(featureTypes); - VLOG(GLOG_DEBUG) << "==Start load data process."; + 
LOG_DEBUG("==Start load data process."); LoadProcess(ckptData); - LOG(INFO) << "Finish host side loading data."; + LOG_INFO("Finish host side loading data."); } void Checkpoint::SetDataHandler(CkptData& ckptData) @@ -143,7 +143,7 @@ void Checkpoint::MakeSaveDir(const string& dirName) { if (access(dirName.c_str(), F_OK) == -1) { if (mkdir(dirName.c_str(), dirMode) == -1) { - VLOG(GLOG_DEBUG) << StringFormat("Unable to create directory: %s", dirName.c_str()); + LOG_DEBUG("Unable to create directory: {}", dirName); } } } @@ -186,7 +186,7 @@ void Checkpoint::SaveDataset(const vector& embNames, auto datasetDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto attributeDir { datasetPath + dirSeparator + datasetName + to_string(rankId) + attribFileType }; - VLOG(GLOG_DEBUG) << StringFormat("====Start getting data from handler to: %s", datasetDir.c_str()); + LOG_DEBUG("====Start getting data from handler to: {}", datasetDir); auto transData { dataHandler->GetDataset(saveDataType, embName) }; // save embedding when dynamic expansion is open @@ -195,13 +195,13 @@ void Checkpoint::SaveDataset(const vector& embNames, auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto embeddingSizeInfo = GetEmbeddingSize(embName); MakeSaveDir(embedPath); - VLOG(GLOG_DEBUG) << StringFormat("====Start saving embedding data to: %s", datasetDir.c_str()); + LOG_DEBUG("====Start saving embedding data to: {}", datasetDir); WriteEmbedding(transData, embedDatasetDir, embeddingSizeInfo.extEmbSize); } - VLOG(GLOG_DEBUG) << StringFormat("====Start saving data to: %s", datasetDir.c_str()); + LOG_DEBUG("====Start saving data to: {}", datasetDir); WriteStream(transData, datasetDir, transData.datasetSize, saveDataType); - VLOG(GLOG_DEBUG) << StringFormat("====Start saving data to: %s", attributeDir.c_str()); + LOG_DEBUG("====Start saving data to: {}", attributeDir); WriteStream(transData, attributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE); } } @@ -215,8 +215,8 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { - LOG(ERROR) << StringFormat("Set device failed, device_id:%d", deviceId); - throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + LOG_ERROR("Set device failed, device_id:{}", deviceId); + throw runtime_error(Log::Format("Set device failed, device_id:{}", deviceId).c_str()); } auto &transArr = transData.int64Arr; @@ -229,8 +229,8 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); - throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); + LOG_ERROR("aclrtMemcpy failed, ret={}", ret); + throw runtime_error(Log::Format("aclrtMemcpy failed, ret={}", ret).c_str()); } writeFile.write((const char *) (row.data()), embeddingSize * sizeof(float)); @@ -255,7 +255,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { - LOG(ERROR) << StringFormat("Set device failed, device_id:%d", deviceId); + LOG_ERROR("Set device failed, device_id:{}", deviceId); readFile.close(); throw runtime_error(StringFormat("Set device failed, device_id:%d", 
deviceId).c_str()); } @@ -271,7 +271,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, void *newBlock = nullptr; ret = aclrtMalloc(&newBlock, static_cast(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); + LOG_ERROR("aclrtMalloc failed, ret={}", ret); readFile.close(); throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } @@ -293,7 +293,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, aclError ret = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); + LOG_ERROR("aclrtMemcpy failed, ret={}", ret); readFile.close(); throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } @@ -310,7 +310,7 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); if (!writeFile.is_open()) { - VLOG(GLOG_DEBUG) << StringFormat("unable to open save file: %s", dataDir.c_str()); + LOG_DEBUG("unable to open save file: {}", dataDir); writeFile.close(); return; } @@ -410,7 +410,7 @@ vector Checkpoint::GetTableLayerLoadDir() } closedir(dir); } else { - LOG(WARNING) << "when loading data in ssd, there are no table files."; + LOG_WARN("when loading data in ssd, there are no table files."); } return loadTableDir; } @@ -430,7 +430,7 @@ void Checkpoint::LoadDataset(const vector& embNames, CkptTransData transData; - VLOG(GLOG_DEBUG) << StringFormat("====Start reading data from: %s", attributeDir.c_str()); + LOG_DEBUG("====Start reading data from: {}", attributeDir); auto dataElmtBytes { dataHandler->GetDataElmtBytes(CkptDataType::ATTRIBUTE) }; ReadStream(transData, attributeDir, CkptDataType::ATTRIBUTE, dataElmtBytes); @@ -439,7 +439,7 @@ void Checkpoint::LoadDataset(const vector& embNames, ReadStreamForEmbData(transData, datasetDir, dataElmtBytes, ckptData, embName); continue; } else { - VLOG(GLOG_DEBUG) << StringFormat("====Start reading data from: %s", datasetDir.c_str()); + LOG_DEBUG("====Start reading data from: {}", datasetDir); ReadStream(transData, datasetDir, saveDataType, dataElmtBytes); } @@ -447,13 +447,11 @@ void Checkpoint::LoadDataset(const vector& embNames, if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { auto embedPath { dataDir + dirSeparator + "key_embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; - VLOG(GLOG_DEBUG) << StringFormat("====Start loading embedding data from: %s", datasetDir.c_str()); + LOG_DEBUG("====Start loading embedding data from: {}", datasetDir); ReadEmbedding(transData, embedDatasetDir, embName); } - VLOG(GLOG_DEBUG) << StringFormat( - "====Start loading data from: %s to data handler.", attributeDir.c_str() - ); + LOG_DEBUG("====Start loading data from: {} to data handler.", attributeDir); if ((saveDataType == CkptDataType::EMB_INFO)) { dataHandler->SetDatasetForLoadEmb(saveDataType, embName, transData, ckptData); } else { @@ -469,7 +467,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, uint32_t dataElmtBytes) { if (dataElmtBytes == 0) { - LOG(WARNING) << "dataElmtBytes is 0, don't handle [/ %] operation"; + LOG_WARN("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } @@ -481,11 +479,11 
+479,11
@@ void Checkpoint::ReadStream(CkptTransData& transData, ValidateReadFile(dataDir, datasetSize); } catch (const std::invalid_argument& e) { readFile.close(); - throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); + throw runtime_error(Log::Format("Invalid read file path: {}", e.what())); } if (datasetSize % dataElmtBytes > 0) { - VLOG(GLOG_DEBUG) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); + LOG_DEBUG("data is missing or incomplete in load file: {}", dataDir); } auto resizeSize { datasetSize / dataElmtBytes }; @@ -504,7 +502,7 @@ void Checkpoint::ReadStream(CkptTransData& transData, idx += readSize; } } else { - VLOG(GLOG_DEBUG) << StringFormat("unable to open load file: %s", dataDir.c_str()); + LOG_DEBUG("unable to open load file: {}", dataDir); } readFile.close(); @@ -517,7 +515,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, string embName) { if (dataElmtBytes == 0) { - LOG(ERROR) << "dataElmtBytes is 0, don't handle [/ %] operation"; + LOG_ERROR("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); @@ -536,7 +534,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { - LOG(ERROR) << StringFormat("data is missing or incomplete in load file: %s", dataDir.c_str()); + LOG_ERROR("data is missing or incomplete in load file: {}", dataDir); readFile.close(); throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } @@ -546,7 +544,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, auto onceReadByteSize { datasetSize / embDataOuterSize }; if (!readFile.is_open()) { - VLOG(GLOG_DEBUG) << StringFormat("unable to open load file: %s", dataDir.c_str()); + LOG_DEBUG("unable to open load file: {}", dataDir); readFile.close(); return; } diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index 9f5a5522..8faec17b 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -33,9 +33,7 @@ void CkptDataHandler::CleanTransfer() void CkptDataHandler::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) { - LOG(ERROR) << StringFormat( - "Load host emb failed. dataType:%d, embName:%s, loadedData:%d, ckptData:%d", - dataType, embName.c_str(), loadedData.datasetSize, ckptData.embHashMaps.empty() - ); + LOG_ERROR("Load host emb failed. 
dataType:{}, embName:{}, loadedData:{}, ckptData:{}", + dataType, embName, loadedData.datasetSize, ckptData.embHashMaps.empty()); throw runtime_error("only EMB_INFO and EMB_DATA supported for load host emb"); } \ No newline at end of file diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index a2d6f96e..48687bca 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -16,7 +16,7 @@ void FeatAdmitNEvictCkpt::SetProcessData(CkptData& processData) ClearData(); if (processData.table2Thresh.empty() || processData.histRec.timestamps.empty() || processData.histRec.historyRecords.empty()) { - LOG(ERROR) << "Missing Feature Admit and Evict data"; + LOG_ERROR("Missing Feature Admit and Evict data"); throw std::runtime_error("Missing Feature Admit and Evict data"); } saveTable2Thresh = std::move(processData.table2Thresh); @@ -137,15 +137,14 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) timestamp = transArr.front(); size_t featItemInfoTotalSize = attribute.front() * static_cast(featItemInfoSaveNum); - VLOG(GLOG_DEBUG) << StringFormat("====Start SetHistRec, name: %s, featItemInfoTotalSize: %ld", embName.c_str(), - featItemInfoTotalSize); + LOG_DEBUG("====Start SetHistRec, name: {}, featItemInfoTotalSize: {}", embName, featItemInfoTotalSize); size_t process = 0; size_t printPerStep = ((featItemInfoTotalSize / 100) > 0 ? (featItemInfoTotalSize / 100) : 1); for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { process = i % printPerStep; if (process == 1) { - VLOG(GLOG_DEBUG) << StringFormat("====in SetHistRec, process : %f", i/featItemInfoTotalSize); + LOG_DEBUG("====in SetHistRec, process : %f", i/featItemInfoTotalSize); } auto featureId = transArr[i + featureIdIdxOffset]; auto count = transArr[i + countIdxOffset]; @@ -153,7 +152,7 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) histRecs.emplace(featureId, FeatureItemInfo(static_cast(count), lastTime)); } - VLOG(GLOG_DEBUG) << StringFormat("====End SetHistRec, name: %s", embName.c_str()); + LOG_DEBUG("====End SetHistRec, name: {}", embName); } int FeatAdmitNEvictCkpt::GetTable2ThreshSize() diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index 9b9ffd6f..d0fa9499 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -24,7 +24,7 @@ void HostEmbCkpt::GetProcessData(CkptData& processData) { saveHostEmbs = nullptr; loadHostEmbs = nullptr; - LOG(INFO) << StringFormat("processData.embHashMaps.empty():%d", processData.embHashMaps.empty()); + LOG_INFO("processData.embHashMaps.empty():{}", processData.embHashMaps.empty()); } vector HostEmbCkpt::GetDataTypes() @@ -59,9 +59,8 @@ CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - LOG(INFO) << StringFormat( - "Parameter dataType:%d, embName:%s, loadedData:%d", dataType, embName.c_str(), loadedData.datasetSize - ); + LOG_INFO("Parameter dataType:{}, embName:{}, loadedData:{}", + dataType, embName, loadedData.datasetSize); return; } @@ -119,7 +118,7 @@ void HostEmbCkpt::SetEmbInfo(string embName, CkptData& 
ckptData) // load Emb data void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) { - LOG(INFO) << StringFormat("Parameter embName:%s, ckptData:%d", embName.c_str(), ckptData.embHashMaps.empty()); + LOG_INFO("Parameter embName:{}, ckptData:{}", embName, ckptData.embHashMaps.empty()); return; } diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index c5406366..a3f36f4d 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -64,7 +64,7 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) transArr.push_back(it.first); transArr.push_back(it.second); } - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); + LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -83,5 +83,5 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran int64_t key { transArr.at(i) }; hostHashMap[key] = transArr.at(i + 1); } - LOG(INFO) << StringFormat("dataType:%d", dataType); + LOG_INFO("dataType:{}", dataType); } diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp index 62ce257c..261911b2 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -52,7 +52,7 @@ CkptTransData NddrOffsetCkpt::GetDataset(CkptDataType dataType, string embName) transferData.attribute.push_back(1); transferData.attribute.push_back(fourBytes); transferData.attributeSize = transferData.attribute.size() * eightBytes; - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); + LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::EMB_INFO, dataType); return move(transferData); } @@ -61,5 +61,5 @@ void NddrOffsetCkpt::SetDataset(CkptDataType dataType, string embName, CkptTrans CleanTransfer(); transferData = move(loadedData); loadMaxOffset[embName] = transferData.int32Arr.front(); - LOG(INFO) << StringFormat("CkptDataType::EMB_INFO:%d, dataType:%d", CkptDataType::EMB_INFO, dataType); + LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::EMB_INFO, dataType); } diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 3566f621..2109f1e8 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -23,7 +23,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, this->rankInfo = rankInfo; if (!ifLoad) { EmbHashMapInfo embHashMapInfo; - LOG(INFO) << "init emb hash map from scratch"; + LOG_INFO("init emb hash map from scratch"); for (const auto& embInfo: embInfos) { embHashMapInfo.devOffset2Batch.resize(embInfo.devVocabSize); embHashMapInfo.devOffset2Key.resize(embInfo.devVocabSize); @@ -34,14 +34,8 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, fill(embHashMapInfo.devOffset2Key.begin(), embHashMapInfo.devOffset2Key.end(), -1); embHashMaps[embInfo.name] = embHashMapInfo; - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat( - "devOffset2Key, %s", VectorToString(embHashMaps.at(embInfo.name).devOffset2Key).c_str() - ); - VLOG(GLOG_TRACE) << StringFormat( - "devOffset2Batch, %s", 
VectorToString(embHashMaps.at(embInfo.name).devOffset2Batch).c_str() - ); - } + LOG_TRACE("devOffset2Key, {}", VectorToString(embHashMaps.at(embInfo.name).devOffset2Key)); + LOG_TRACE("devOffset2Batch, {}", VectorToString(embHashMaps.at(embInfo.name).devOffset2Batch)); } } #endif @@ -301,7 +295,7 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& if (iter != embHashMap.hostHashMap.end()) { offset = iter->second; embHashMap.hostHashMap.erase(iter); - LOG_TRACE("evict embName %s, offset %d", embName, offset); + LOG_TRACE("evict embName {}, offset {}", embName, offset); } else { // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 continue; @@ -551,7 +545,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& /// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const { - if (!VLOG_IS_ON(GLOG_TRACE)) { + if (Log::GetLevel() != Log::TRACE) { return; } auto& hostMap = embHashMap.hostHashMap; diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 74195ad2..f879f4d4 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -23,12 +23,11 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) #ifndef GTEST this->rankInfo = rInfo; this->seed = seed; - LOG(INFO) << StringFormat( - "EmbTable init, deviceID %d, embSize %d running", rInfo.deviceId, embInfo.extEmbeddingSize); + LOG_INFO("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); // 计算embedding table需要分配的内存块数 auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << StringFormat("Set device failed, device_id:%d, ret=%d", rInfo.deviceId, ret); + LOG_ERROR("Set device failed, device_id:{}, ret={}", rInfo.deviceId, ret); throw AclError(); } embSize = embInfo.extEmbeddingSize; @@ -38,7 +37,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) void *newBlock = nullptr; aclError ret = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); + LOG_ERROR("aclrtMalloc failed, ret={}", ret); throw AclError(); } // 申请内存初始化 @@ -48,9 +47,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) SplitMemoryBlock(newBlock); } totalCapacity = static_cast(memoryList.size()); - LOG(INFO) << StringFormat( - "aclrtMalloc success, emb name:%s, total capacity:%d", embInfo.name.c_str(), totalCapacity - ); + LOG_INFO("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); #endif } @@ -61,7 +58,7 @@ EmbTable::~EmbTable() // 释放内存块 aclError ret = aclrtFree(block); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtFree failed, ret=%d", ret); + LOG_ERROR("aclrtFree failed, ret={}", ret); } } #endif @@ -73,11 +70,11 @@ int64_t EmbTable::GetEmbAddress() #ifndef GTEST if (embeddingList.empty()) { PrintStatus(); - VLOG(GLOG_DEBUG) << "GetEmbAddress, embedding_list size: empty! Add block!"; + LOG_DEBUG("GetEmbAddress, embedding_list size: empty! 
Add block!"); void *addBlock = nullptr; aclError ret = aclrtMalloc(&addBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMalloc failed, ret=%d", ret); + LOG_ERROR("aclrtMalloc failed, ret={}", ret); throw AclError(); } RandomInit(addBlock, embInfo.initializeInfos, seed); @@ -96,16 +93,15 @@ int64_t EmbTable::GetEmbAddress() void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos, int seed) { #ifndef GTEST - LOG(INFO) << StringFormat( - "Device GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size()); + LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); vector devEmb(blockSize); for (auto initializeInfo: initializeInfos) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); + LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name.c_str()); for (int i = 0; i < BLOCK_EMB_COUNT; i++) { initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); } } - LOG(INFO) << StringFormat("Device GenerateEmbData End, seed:%d", seed); + LOG_INFO("Device GenerateEmbData End, seed:{}", seed); ExecuteAclMemcpy(newBlock, devEmb); #endif } @@ -116,7 +112,7 @@ void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) aclError ret = aclrtMemcpy( newBlock, blockSize * sizeof(float), devEmb.data(), blockSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_SUCCESS) { - LOG(ERROR) << StringFormat("aclrtMemcpy failed, ret=%d", ret); + LOG_ERROR("aclrtMemcpy failed, ret={}", ret); throw AclError(); } #endif @@ -138,8 +134,7 @@ void EmbTable::SplitMemoryBlock(void *newBlock) void EmbTable::PrintStatus() { - // 输出embedding table的总容量 - LOG(INFO) << StringFormat("Total capacity:%d", totalCapacity * blockSize); - // 输出embedding table的未使用的使用容量 - LOG(INFO) << StringFormat("Unused capacity:%d", totalCapacity * blockSize - usedCapacity * embSize); + // 输出embedding table的总容量和未使用的使用容量 + LOG_INFO("Total capacity:{}, Unused capacity:{}", + totalCapacity * blockSize, totalCapacity * blockSize - usedCapacity * embSize); } diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 48ebf562..16d6d6d9 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -19,22 +19,22 @@ using namespace std; int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) { #ifndef GTEST - LOG(INFO) << StringFormat(MGMT + "begin hd_transfer initialize, rank:%d", localRankId); + LOG_INFO(MGMT + "begin hd_transfer initialize, rank:{}", localRankId); // 使用AscendCL接口开发应用时,必须先调用aclInit接口,否则可能会导致后续系统内部资源初始化出错,进而导致其它业务异常。 aclError retOk = aclInit(nullptr); - LOG(INFO) << StringFormat(MGMT + "end aclInit, rank:%d", localRankId); + LOG_INFO(MGMT + "end aclInit, rank:{}", localRankId); if (retOk != ACL_SUCCESS) { - LOG(ERROR) << StringFormat(MGMT + "aclInit fail, rank:%d, errno:%d", localRankId, retOk); + LOG_ERROR(MGMT + "aclInit fail, rank:{}, errno:{}", localRankId, retOk); return false; } - LOG(INFO) << StringFormat(MGMT + "start Set device, rank:%d", localRankId); + LOG_INFO(MGMT + "start Set device, rank:{}", localRankId); // 指定当前进程或线程中用于运算的Device,同时隐式创建默认Context auto ret = aclrtSetDevice(static_cast(localRankId)); if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << StringFormat("Set device failed, device_id:%d", localRankId); + LOG_ERROR("Set device failed, device_id:{}", localRankId); return false; } - LOG(INFO) << StringFormat(MGMT + 
"end Set device, rank:%d", localRankId); + LOG_INFO(MGMT + "end Set device, rank:{}", localRankId); for (const auto& embInfo: embInfos) { auto embName = embInfo.name; for (int i = 0; i < MAX_CHANNEL_NUM; ++i) { @@ -51,18 +51,18 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) int32_t tmp = std::stoi(envTimeout); if (tmp >= -1 && tmp <= INT32_MAX) { this->timeout = tmp; - LOG(INFO) << StringFormat("Succeed to parse ${env:AclTimeout}: %d", tmp); + LOG_INFO("Succeed to parse ${env:AclTimeout}: {}", tmp); } else { - LOG(ERROR) << StringFormat("Failed to parse ${env:AclTimeout}: %d, expected in (0, INT32_MAX)", tmp); + LOG_ERROR("Failed to parse ${env:AclTimeout}: {}, expected in (0, INT32_MAX)", tmp); } } catch (const std::invalid_argument &e) { - LOG(ERROR) << StringFormat("Failed to parse ${env:AclTimeout}: %s, expected a integer, set to default: %d", + LOG_ERROR("Failed to parse ${env:AclTimeout}: {}, expected a integer, set to default: {}", envTimeout, defaultAclTimeout); } } - VLOG(GLOG_DEBUG) << StringFormat("hd transfer timeout:%d", timeout); + LOG_DEBUG("hd transfer timeout:{}", timeout); running = true; - LOG(INFO) << "hd_transfer init"; + LOG_INFO("hd_transfer init"); #endif return true; } @@ -72,11 +72,11 @@ void HDTransfer::Destroy() { #ifndef GTEST running = false; - LOG(INFO) << (HD + "destroy channel start"); + LOG_INFO(HD + "destroy channel start"); for (auto& c: transferChannels) { - LOG(INFO) << StringFormat(HD + "start destroy channel:%s", c.first.c_str()); + LOG_INFO(HD + "start destroy channel:{}", c.first); tensorflow::StopRecvTensorByAcl(&c.second, c.first); - LOG(INFO) << StringFormat(HD + "destroy channel:%s", c.first.c_str()); + LOG_INFO(HD + "destroy channel:{}", c.first); } for (auto& d: aclDatasets) { if (acltdtDestroyDataset(d.second) != ACL_ERROR_NONE) { @@ -102,17 +102,17 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName try { channelSize = stoi(env); } catch (const std::invalid_argument& e) { - LOG(WARNING) << StringFormat("wrong HD_CHANNEL_SIZE env %s", e.what()); + LOG_WARN("wrong HD_CHANNEL_SIZE env {}", e.what()); channelSize = LARGE_CHANNEL_SIZE; } catch (const std::out_of_range& e) { - LOG(WARNING) << StringFormat("wrong HD_CHANNEL_SIZE env %s", e.what()); + LOG_WARN("wrong HD_CHANNEL_SIZE env {}", e.what()); channelSize = LARGE_CHANNEL_SIZE; } if (channelSize <= 0) { channelSize = LARGE_CHANNEL_SIZE; } } - LOG(INFO) << StringFormat("user config all2all restore lookup channel size:%d", channelSize); + LOG_INFO("user config all2all restore lookup channel size:{}", channelSize); for (int c = static_cast(TransferChannel::D2H); c != static_cast(TransferChannel::INVALID); c++) { auto channel = static_cast(c); string sendName = StringFormat( @@ -129,9 +129,7 @@ void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName } else { transferChannels[sendName] = tdtCreateChannel(localRankId, sendName.c_str(), PING_PONG_SIZE); } - LOG(INFO) << StringFormat( - "create channel:%s %d", sendName.c_str(), static_cast(transferChannels[sendName]) - ); + LOG_INFO("create channel:{} {}", sendName, static_cast(transferChannels[sendName])); } #endif } @@ -156,13 +154,11 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in } string sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - REC_LOG(INFO) << StringFormat( - HD + "hd transfer send %s, send count is %d, size list:%s", - sendName.c_str(), sizes.size(), 
VectorToString(sizes).c_str() - ); + LOG_INFO(HD + "hd transfer send {}, send count is {}, size list:{}", + sendName, sizes.size(), VectorToString(sizes)); if (sizes.size() == 0) { - LOG(WARNING) << "tensors num can not be zero"; + LOG_WARN("tensors num can not be zero"); return; } bool isNeedResend = false; @@ -175,15 +171,11 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in return; } if (status != tensorflow::Status::OK()) { - LOG(ERROR) << StringFormat( - MGMT + "hd send %s error '%s'", sendName.c_str(), status.error_message().c_str() - ); + LOG_ERROR(MGMT + "hd send {} error '{}'", sendName, status.error_message()); throw runtime_error("hd send error"); } if (batchId != -1 && resendTime != 0) { - LOG(WARNING) << StringFormat( - MGMT + "hd send %s batch: %d failed, retry: %d ", sendName.c_str(), batchId, resendTime - ); + LOG_WARN(MGMT + "hd send {} batch: {} failed, retry: {} ", sendName, batchId, resendTime); } resendTime++; } while (isNeedResend); @@ -201,14 +193,14 @@ vector HDTransfer::Recv(TransferChannel channel, int channel #ifndef GTEST std::vector tensors; string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - VLOG(GLOG_DEBUG) << StringFormat("hd transfer try recv:%s", recvName.c_str()); + LOG_DEBUG("hd transfer try recv:{}", recvName); TimeCost tc = TimeCost(); tensorflow::Status status = tensorflow::RecvTensorByAcl(transferChannels[recvName], tensors); if (!running) { return {}; } if (status != tensorflow::Status::OK()) { - LOG(ERROR) << StringFormat(MGMT + "%s hd recv error '%s'", recvName.c_str(), status.error_message().c_str()); + LOG_ERROR(MGMT + "{} hd recv error '{}'", recvName, status.error_message()); throw runtime_error("hd recv error"); } @@ -216,9 +208,7 @@ vector HDTransfer::Recv(TransferChannel channel, int channel for (auto& t: tensors) { sizes.push_back(t.NumElements()); } - REC_LOG(INFO) << StringFormat( - "hd transfer recv:%s, size:%d cost:%dms", recvName.c_str(), VectorToString(sizes).c_str(), tc.ElapsedMS() - ); + LOG_INFO("hd transfer recv:{}, size:{} cost:{}ms", recvName, VectorToString(sizes), tc.ElapsedMS()); return tensors; #endif } @@ -234,7 +224,7 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& #ifndef GTEST std::vector tensors; string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); - VLOG(GLOG_DEBUG) << StringFormat("hd transfer try recv:%s", recvName.c_str()); + LOG_DEBUG("hd transfer try recv:{}", recvName); TimeCost tc = TimeCost(); if (aclDatasets[embName] == nullptr) { throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str()); @@ -246,7 +236,7 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); } - LOG(INFO) << StringFormat("hd transfer recv:%s cost:%dms", recvName.c_str(), tc.ElapsedMS()); + LOG_INFO("hd transfer recv:{} cost:{}ms", recvName, tc.ElapsedMS()); return acltdtGetDatasetSize(aclDatasets[embName]); #endif } diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index e008a781..521468a7 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -28,7 +28,7 @@ void HostEmb::Initialize(const vector& embInfos, int seed) EmbDataGenerator(embInfo.initializeInfos, seed, 
static_cast(embInfo.hostVocabSize), embInfo.extEmbeddingSize, hostEmb.embData); hostEmbs[embInfo.name] = move(hostEmb); - LOG(INFO) << (HOSTEMB + "HostEmb Initialize End"); + LOG_INFO(HOSTEMB + "HostEmb Initialize End"); } } @@ -42,19 +42,17 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in int embeddingSize, vector> &embData) { #ifndef GTEST - LOG(INFO) << StringFormat( - HOSTEMB + "GenerateEmbData Start, seed:%d, initializer num: %d", seed, initializeInfos.size() - ); + LOG_INFO(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); for (auto initializeInfo: initializeInfos) { - LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); + LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); for (int i = 0; i < vocabSize; i++) { initializeInfo.initializer->GenerateData(embData.at(i).data(), embeddingSize); } } - LOG(INFO) << StringFormat(HOSTEMB + "GenerateEmbData End, seed:%d", seed); + LOG_INFO(HOSTEMB + "GenerateEmbData End, seed:{}", seed); #endif } @@ -65,26 +63,22 @@ void HostEmb::Join(int channelId) TimeCost tc = TimeCost(); switch (channelId) { case TRAIN_CHANNEL_ID: - VLOG(GLOG_DEBUG) << StringFormat( - HOSTEMB + "start join, channelId:%d, procThreadsForTrain num:%d", + LOG_DEBUG(HOSTEMB + "start join, channelId:{}, procThreadsForTrain num:{}", channelId, procThreadsForTrain.size()); for (auto& t: procThreadsForTrain) { t->join(); } procThreadsForTrain.clear(); - VLOG(GLOG_DEBUG) << StringFormat( - HOSTEMB + "end join, channelId:%d, cost:%dms", channelId, tc.ElapsedMS()); + LOG_DEBUG(HOSTEMB + "end join, channelId:{}, cost:{}ms", channelId, tc.ElapsedMS()); break; case EVAL_CHANNEL_ID: - VLOG(GLOG_DEBUG) << StringFormat( - HOSTEMB + "start join, channelId:%d, procThreadsForEval num:%d", + LOG_DEBUG(HOSTEMB + "start join, channelId:{}, procThreadsForEval num:{}", channelId, procThreadsForEval.size()); for (auto& t: procThreadsForEval) { t->join(); } procThreadsForEval.clear(); - VLOG(GLOG_DEBUG) << StringFormat( - HOSTEMB + "end join, channelId:%d, cost:%dms", channelId, tc.ElapsedMS()); + LOG_DEBUG(HOSTEMB + "end join, channelId:{}, cost:{}ms", channelId, tc.ElapsedMS()); break; default: throw invalid_argument("channelId not in [TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID]"); @@ -99,15 +93,15 @@ void HostEmb::Join(int channelId) /// \param embName 表名 void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, const string& embName) { - LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmb, channelId:%d, embName:%s", channelId, embName.c_str()); + LOG_INFO(HOSTEMB + "UpdateEmb, channelId:{}, embName:{}", channelId, embName); EASY_FUNCTION(profiler::colors::Purple); TimeCost tc = TimeCost(); auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; - LOG(INFO) << StringFormat(HOSTEMB + "wait D2H embs, channelId:%d", channelId); + LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); const auto tensors = hdTransfer->Recv(transferName, channelId, embName); if (tensors.empty()) { - LOG(WARNING) << (HOSTEMB + "recv empty data"); + LOG_WARN(HOSTEMB + "recv empty data"); return; } const Tensor& d2hEmb = tensors[0]; @@ -116,9 +110,8 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; auto& embData = hostEmbs[embName].embData; - VLOG(GLOG_DEBUG) << StringFormat(HOSTEMB + 
"embName:%s, UpdateEmb missingKeys len = %d, embeddingSize = %d, " - "embData.size = %d", embName.c_str(), missingKeysHostPos.size(), embeddingSize, - embData.size()); + LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}, " + "embData.size = {}", embName, missingKeysHostPos.size(), embeddingSize, embData.size()); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ shared(missingKeysHostPos, tensorPtr, embData, embeddingSize) @@ -129,7 +122,7 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, dst[j] = tensorPtr[j + embeddingSize * i]; } } - LOG(INFO) << StringFormat(HOSTEMB + "update emb end cost: %dms", tc.ElapsedMS()); + LOG_INFO(HOSTEMB + "update emb end cost: {}ms", tc.ElapsedMS()); EASY_END_BLOCK } @@ -139,16 +132,16 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, /// \param embName 表名 void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelId, const string& embName) { - LOG(INFO) << StringFormat(HOSTEMB + "UpdateEmbV2, channelId:%d, embName:%s", channelId, embName.c_str()); + LOG_INFO(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); EASY_FUNCTION(profiler::colors::Purple) auto updateThread = [&, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; - LOG(INFO) << StringFormat(HOSTEMB + "wait D2H embs, channelId:%d", channelId); + LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); auto size = hdTransfer->RecvAcl(transferName, channelId, embName); if (size == 0) { - LOG(WARNING) << (HOSTEMB + "recv empty data"); + LOG_WARN(HOSTEMB + "recv empty data"); return; } TimeCost tc = TimeCost(); @@ -164,10 +157,9 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI size_t elementSize = acltdtGetDataSizeFromItem(aclData); size_t dimNum = acltdtGetDimNumFromItem(aclData); - VLOG(GLOG_DEBUG) << StringFormat(HOSTEMB + "embName:%s, UpdateEmb missingKeys len = %d, embeddingSize = %d," - " embData.size = %d, RecvAcl = %d, elementSize = %d, dimNum = %d", - embName.c_str(), missingKeysHostPos.size(), embeddingSize, embData.size(), - size, elementSize, dimNum); + LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," + " embData.size = {}, RecvAcl = {}, elementSize = {}, dimNum = {}", + embName, missingKeysHostPos.size(), embeddingSize, embData.size(), size, elementSize, dimNum); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) for (size_t j = 0; j < missingKeysHostPos.size(); j++) { auto& dst = embData[missingKeysHostPos[j]]; @@ -176,7 +168,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI dst[k] = ptr[k + embeddingSize * j]; } } - LOG(INFO) << StringFormat(HOSTEMB + "update emb end cost: %dms", tc.ElapsedMS()); + LOG_INFO(HOSTEMB + "update emb end cost: {}ms", tc.ElapsedMS()); }; switch (channelId) { @@ -215,8 +207,7 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& tmpData(j + i * embeddingSize) = src[j]; } } - LOG(INFO) << StringFormat( - "GetH2DEmb end, missingKeys count:%d cost:%dms", missingKeysHostPos.size(), tc.ElapsedMS()); + LOG_INFO("GetH2DEmb end, missingKeys count:{} cost:{}ms", missingKeysHostPos.size(), tc.ElapsedMS()); } /// 获取hostEmbs的指针 @@ -234,7 +225,7 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve const vector& offset) { for (auto initializeInfo: initializeInfos) { - 
LOG(INFO) << StringFormat("Device GenerateEmbData ing. name %s", initializeInfo.name.c_str()); + LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); for (size_t i = 0; i < offset.size(); i++) { initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); @@ -251,7 +242,6 @@ void HostEmb::EvictInitEmb(const string& embName, const vector& offset) #ifndef GTEST auto& hostEmb = GetEmb(embName); EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); - LOG(INFO) << StringFormat( - HOSTEMB + "ddr EvictInitEmb!host embName %s, init offsets size: %d", embName.c_str(), offset.size()); + LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); #endif } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ea560ad1..d77f053c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -7,6 +7,7 @@ #include "hybrid_mgmt.h" #include "utils/time_cost.h" +#include "utils/log.h" #include "checkpoint/checkpoint.h" @@ -27,33 +28,31 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& if (getenv("APPLY_GRADIENTS_STRATEGY") != nullptr) { bool strategy = (!strcmp(getenv("APPLY_GRADIENTS_STRATEGY"), SUM_SAME_ID)); PerfConfig::gradientStrategy = strategy; - LOG(INFO) << StringFormat("config GRADIENTS_STRATEGY:%d", strategy); + LOG_INFO("config GRADIENTS_STRATEGY:{}", strategy); } // 设置当前进程用于数据处理的线程数,默认为6,取值1-10;取值不在范围内,则数据处理线程启动失败退出 if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { - LOG(ERROR) << StringFormat( - "[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:%d, should in range [1, %d]", + LOG_ERROR("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{}, should in range [1, {}]", num, MAX_KEY_PROCESS_THREAD); return false; } PerfConfig::keyProcessThreadNum = num; - LOG(INFO) << StringFormat("config KEY_PROCESS_THREAD_NUM:%d", num); + LOG_INFO("config KEY_PROCESS_THREAD_NUM:{}", num); } // 设置AccCTR去重线程数,默认为8,取值1-8;取值不在范围内,则数据处理线程启动失败退出 if (getenv("MAX_UNIQUE_THREAD_NUM") != nullptr) { int num = std::atoi(getenv("MAX_UNIQUE_THREAD_NUM")); if (num < 1 || num > DEFAULT_MAX_UNIQUE_THREAD_NUM) { - LOG(ERROR) << StringFormat( - "[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:%d, should in range [1, %d]", + LOG_ERROR("[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:{}, should in range [1, {}]", num, DEFAULT_MAX_UNIQUE_THREAD_NUM); return false; } PerfConfig::maxUniqueThreadNum = num; - LOG(INFO) << StringFormat("config MAX_UNIQUE_THREAD_NUM:%d", num); + LOG_INFO("config MAX_UNIQUE_THREAD_NUM:{}", num); } // 设置是否使用AccCTR库提供的去重、分桶功能,默认关闭 @@ -114,8 +113,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, InitRankInfo(rankInfo, embInfos); g_statOn = GetEnv("STAT_ON"); - LOG(INFO) << StringFormat( - MGMT + "begin initialize, localRankSize:%d, localRankId:%d, rank:%d", + LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); mgmtRankInfo = rankInfo; @@ -158,12 +156,10 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, } for (const auto& info: embInfos) { - LOG(INFO) << StringFormat( - MGMT + "emb[%s] vocab size %d+%d sc:%d", - info.name.c_str(), info.devVocabSize, info.hostVocabSize, info.sendCount); + LOG_INFO(MGMT + 
"emb[{}] vocab size {}+{} sc:{}", + info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); } - LOG(INFO) << StringFormat( - MGMT + "end initialize, noDDR:%d, maxStep:[%d, %d], rank:%d", rankInfo.noDDR, + LOG_INFO(MGMT + "end initialize, noDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.noDDR, rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); #endif return true; @@ -172,7 +168,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, // 比较hostHashMap和cacheManager的数据是否一致 void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) { - if (!VLOG_IS_ON(GLOG_TRACE)) { + if (Log::GetLevel() != Log::TRACE) { return; } auto& embHashMaps = saveData.embHashMaps; @@ -190,12 +186,12 @@ void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) ++tableKeyInDdr; auto cuKey = item.first; if (lfu.find(cuKey) == lfu.end()) { - LOG(ERROR) << "save step error, ddr key:" << cuKey << ", not exist in lfu, hostHashMap offset:" - << item.second; + LOG_ERROR("save step error, ddr key:{}, not exist in lfu, hostHashMap offset:", + cuKey, item.second); } } - LOG(INFO) << "save step end, table:" << embTableName << ", tableKeyInDdr:" << tableKeyInDdr << - ", tableKeyInLfu:" << lfu.size(); + LOG_INFO("save step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", + embTableName, tableKeyInDdr, lfu.size()); } } @@ -216,10 +212,10 @@ void HybridMgmt::RestoreFreq4Save(CkptData& saveData) auto& embHashMap = it.second; vector hbm2DdrKeys; vector ddr2HbmKeys; - LOG(INFO) << "restore freq info for save step, table:" << embTableName << ", embHashMap.oldSwap size:" - << embHashMap.oldSwap.size(); - LOG(INFO) << "before, ddr key table size:" << ddrKeyFreqMaps[embTableName].size() - << ", exclude ddr key table size:" << excludeDDRKeyFreqMaps[embTableName].size(); + LOG_INFO("restore freq info for save step, table:{}, embHashMap.oldSwap size:{}", + embTableName, embHashMap.oldSwap.size()); + LOG_INFO("before, ddr key table size:{}, exclude ddr key table size:{}", + ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); for (const auto& swapKeys : embHashMap.oldSwap) { hbm2DdrKeys.emplace_back(swapKeys.second); ddr2HbmKeys.emplace_back(swapKeys.first); @@ -240,10 +236,10 @@ void HybridMgmt::RestoreFreq4Save(CkptData& saveData) excludeDDRKeyFreqMaps[embTableName][key] = ddrKeyFreqMaps[embTableName][key]; ddrKeyFreqMaps[embTableName].erase(key); } - LOG(INFO) << "hbm2DdrKeysNotInExcludeMapCount:" << hbm2DdrKeysNotInExcludeMapCount - << ", ddr2HbmKeysNotInDDRMapCount:" << ddr2HbmKeysNotInDDRMapCount; - LOG(INFO) << "after, ddr key table size:" << ddrKeyFreqMaps[embTableName].size() - << ", exclude ddr key table size:" << excludeDDRKeyFreqMaps[embTableName].size(); + LOG_INFO("hbm2DdrKeysNotInExcludeMapCount:{}, ddr2HbmKeysNotInDDRMapCount:{}", + hbm2DdrKeysNotInExcludeMapCount, ddr2HbmKeysNotInDDRMapCount); + LOG_INFO("after, ddr key table size:{}, exclude ddr key table size:{}", + ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); } } @@ -260,12 +256,12 @@ bool HybridMgmt::Save(const string savePath) Checkpoint saveCkpt; if (!mgmtRankInfo.noDDR) { // DDR模式保存host的emb表以及hashmap - VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap"); saveData.hostEmbs = hostEmbs->GetHostEmbs(); saveData.embHashMaps = hostHashMaps->GetHashMaps(); } else { // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 - VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: no 
ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start host side save: no ddr mode hashmap"); saveData.maxOffset = preprocess->GetMaxOffset(); saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); } @@ -284,7 +280,7 @@ bool HybridMgmt::Save(const string savePath) // 保存特征准入淘汰相关的数据 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { - VLOG(GLOG_DEBUG) << (MGMT + "Start host side save: feature admit and evict"); + LOG_DEBUG(MGMT + "Start host side save: feature admit and evict"); saveData.table2Thresh = featAdmitNEvict.GetTableThresholds(); saveData.histRec.timestamps = featAdmitNEvict.GetHistoryRecords().timestamps; saveData.histRec.historyRecords = featAdmitNEvict.GetHistoryRecords().historyRecords; @@ -308,7 +304,7 @@ bool HybridMgmt::Load(const string& loadPath) // 数据处理线程上锁 preprocess->LoadSaveLock(); - VLOG(GLOG_DEBUG) << (MGMT + "Start host side load process"); + LOG_DEBUG(MGMT + "Start host side load process"); CkptData loadData; Checkpoint loadCkpt; @@ -329,29 +325,29 @@ bool HybridMgmt::Load(const string& loadPath) if (!mgmtRankInfo.noDDR) { // DDR模式 将加载的hash map进行赋值 - VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap"); hostHashMaps->LoadHashMap(loadData.embHashMaps); } else { // HBM模式 将加载的最大偏移(真正使用了多少vocab容量)、特征到偏移的映射,进行赋值 - VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: no ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap"); preprocess->LoadMaxOffset(loadData.maxOffset); preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap); } // 将加载的特征准入淘汰记录进行赋值 if (featAdmitNEvict.GetFunctionSwitch()) { - VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: feature admit and evict"); + LOG_DEBUG(MGMT + "Start host side load: feature admit and evict"); featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); featAdmitNEvict.LoadHistoryRecords(loadData.histRec); } if (isSSDEnabled) { - VLOG(GLOG_DEBUG) << (MGMT + "Start host side load: ssd key freq map"); + LOG_DEBUG(MGMT + "Start host side load: ssd key freq map"); auto step = GetStepFromPath(loadPath); cacheManager->Load(loadData.ddrKeyFreqMaps, loadData.excludeDDRKeyFreqMaps, step); } - VLOG(GLOG_DEBUG) << (MGMT + "Finish host side load process"); + LOG_DEBUG(MGMT + "Finish host side load process"); preprocess->LoadSaveUnlock(); @@ -397,9 +393,9 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) key_offset_map_t sendKeyOffsetMap; if (!mgmtRankInfo.noDDR) { - VLOG(GLOG_DEBUG) << (MGMT + "Start send sparse data: ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start send sparse data: ddr mode hashmap"); } else { - VLOG(GLOG_DEBUG) << (MGMT + "Start send sparse data: no ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start send sparse data: no ddr mode hashmap"); keyOffsetMap = preprocess->GetKeyOffsetMap(); } @@ -433,9 +429,9 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) } } if (!mgmtRankInfo.noDDR) { - VLOG(GLOG_DEBUG) << (MGMT + "Start receive sparse data: ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start receive sparse data: ddr mode hashmap"); } else { - VLOG(GLOG_DEBUG) << (MGMT + "Start receive sparse data: no ddr mode hashmap"); + LOG_DEBUG(MGMT + "Start receive sparse data: no ddr mode hashmap"); preprocess->LoadKeyOffsetMap(loadKeyOffsetMap); preprocess->LoadMaxOffset(loadMaxOffset); } @@ -461,36 +457,30 @@ bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEm const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; if 
(setupHostEmbs->sendCount != loadEmbInfo.sendCount) { - LOG(ERROR) << StringFormat( - MGMT + "Load data sendCount %d for table %s does not match setup sendCount %d", - setupHostEmbs->sendCount, setupHostEmbs->name.c_str(), loadEmbInfo.sendCount); + LOG_ERROR(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", + setupHostEmbs->sendCount, setupHostEmbs->name, loadEmbInfo.sendCount); loadDataMatches = false; } if (setupHostEmbs->extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { - LOG(ERROR) << StringFormat( - MGMT + "Load data extEmbeddingSize %d for table %s does not match setup extEmbeddingSize %d", - setupHostEmbs->extEmbeddingSize, setupHostEmbs->name.c_str(), loadEmbInfo.extEmbeddingSize); + LOG_ERROR(MGMT + "Load data extEmbeddingSize {} for table {} does not match setup extEmbeddingSize {}", + setupHostEmbs->extEmbeddingSize, setupHostEmbs->name, loadEmbInfo.extEmbeddingSize); loadDataMatches = false; } if (setupHostEmbs->devVocabSize != loadEmbInfo.devVocabSize) { - LOG(ERROR) << StringFormat( - MGMT + "Load data devVocabSize %d for table %s does not match setup devVocabSize %d", - setupHostEmbs->devVocabSize, setupHostEmbs->name.c_str(), loadEmbInfo.devVocabSize); + LOG_ERROR(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}", + setupHostEmbs->devVocabSize, setupHostEmbs->name, loadEmbInfo.devVocabSize); loadDataMatches = false; } if (setupHostEmbs->hostVocabSize != loadEmbInfo.hostVocabSize) { - LOG(ERROR) << StringFormat( - MGMT + "Load data hostVocabSize %d for table %s does not match setup hostVocabSize %d", - setupHostEmbs->hostVocabSize, setupHostEmbs->name.c_str(), loadEmbInfo.hostVocabSize); + LOG_ERROR(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}", + setupHostEmbs->hostVocabSize, setupHostEmbs->name, loadEmbInfo.hostVocabSize); loadDataMatches = false; } if (!loadDataMatches) { return false; } } else { - LOG(ERROR) << StringFormat( - MGMT + "Load data does not contain table with table name: %s", setupHostEmbs->name.c_str() - ); + LOG_ERROR(MGMT + "Load data does not contain table with table name: {}", setupHostEmbs->name); return false; } return true; @@ -510,7 +500,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) } if (embTableCount < loadHostEmbs->size()) { - LOG(ERROR) << StringFormat(MGMT + "Load data has %d tables more than setup table num %d", + LOG_ERROR(MGMT + "Load data has {} tables more than setup table num {}", loadHostEmbs->size(), embTableCount); return false; } @@ -528,13 +518,13 @@ void HybridMgmt::Start() if (!mgmtRankInfo.noDDR) { auto parseKeysTaskForTrain = [this]() { TrainTask(TaskType::DDR); - LOG(INFO) << StringFormat("parseKeysTaskForTrain done"); + LOG_INFO("parseKeysTaskForTrain done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); auto parseKeysTaskForEval = [this]() { EvalTask(TaskType::DDR); - LOG(INFO) << StringFormat("parseKeysTaskForEval done"); + LOG_INFO("parseKeysTaskForEval done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); } @@ -547,13 +537,13 @@ void HybridMgmt::InsertThreadForHBM() #ifndef GTEST auto parseKeysTaskForHBMTrain = [this]() { TrainTask(TaskType::HBM); - LOG(INFO) << "parseKeysTaskForHBMTrain done"; + LOG_INFO("parseKeysTaskForHBMTrain done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMTrain)); auto parseKeysTaskForHBMEval = [this]() { EvalTask(TaskType::HBM); - LOG(INFO) << "parseKeysTaskForHBMEval done"; + 
LOG_INFO("parseKeysTaskForHBMEval done"); }; procThreads.emplace_back(std::make_unique(parseKeysTaskForHBMEval)); #endif @@ -574,19 +564,16 @@ void HybridMgmt::TrainTask(TaskType type) if (!isRunning) { return; } - LOG(INFO) << StringFormat(HYBRID_BLOCKING + - "hybrid start task channel %d batch %d", channelId, theTrainBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, theTrainBatchId); switch (type) { case TaskType::HBM: ParseKeysHBM(TRAIN_CHANNEL_ID, theTrainBatchId); - - LOG(INFO) << StringFormat(MGMT + "ParseKeysHBMBatchId = %d", theTrainBatchId); + LOG_INFO(MGMT + "ParseKeysHBMBatchId = {}", theTrainBatchId); break; case TaskType::DDR: ParseKeys(TRAIN_CHANNEL_ID, theTrainBatchId); - - LOG(INFO) << StringFormat(MGMT + "parseKeysBatchId = %d", theTrainBatchId); + LOG_INFO(MGMT + "parseKeysBatchId = {}", theTrainBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); @@ -609,17 +596,16 @@ void HybridMgmt::EvalTask(TaskType type) if (!isRunning) { return; } - LOG(INFO) << StringFormat(HYBRID_BLOCKING + - "hybrid start task channel %d batch %d", channelId, evalBatchId); + LOG_INFO(HYBRID_BLOCKING + "hybrid start task channel {} batch {}", channelId, evalBatchId); switch (type) { case TaskType::HBM: ParseKeysHBM(EVAL_CHANNEL_ID, evalBatchId); - LOG(INFO) << StringFormat(MGMT + "HBM evalBatchId = %d", evalBatchId); + LOG_INFO(MGMT + "HBM evalBatchId = {}", evalBatchId); break; case TaskType::DDR: ParseKeys(EVAL_CHANNEL_ID, evalBatchId); - LOG(INFO) << StringFormat(MGMT + "DDR evalBatchId = %d", evalBatchId); + LOG_INFO(MGMT + "DDR evalBatchId = {}", evalBatchId); break; default: throw std::invalid_argument("Invalid TaskType Type."); @@ -633,8 +619,7 @@ void HybridMgmt::EvalTask(TaskType type) /// \return bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) { - LOG(INFO) << StringFormat( - MGMT + "start parse keys HBM, nBatch:%d , [%d]:%d", mgmtRankInfo.nBatch, channelId, batchId); + LOG_INFO(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); // 循环处理每个表的数据 for (const auto& embInfo: mgmtEmbInfo) { @@ -645,8 +630,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { - LOG(INFO) << StringFormat( - MGMT + "ParseKeys infoVecs empty ! batchId:%d, channelId:%d", batchId, channelId); + LOG_INFO(MGMT + "ParseKeys infoVecs empty ! 
batchId:{}, channelId:{}", batchId, channelId); return false; } @@ -655,43 +639,41 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) if (!mgmtRankInfo.useStatic) { all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); } - VLOG(GLOG_DEBUG) << StringFormat("getTensorsSyncTC(ms):%d", getTensorsSyncTC.ElapsedMS()); + LOG_DEBUG("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); // 动态shape场景下,发送all2all向量(通信量矩阵) TimeCost sendTensorsSyncTC; if (!mgmtRankInfo.useStatic) { TimeCost sendAll2AllScSyncTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); - VLOG(GLOG_DEBUG) << StringFormat("sendAll2AllScSyncTC(ms):%d", sendAll2AllScSyncTC.ElapsedMS()); + LOG_DEBUG("sendAll2AllScSyncTC(ms):{}", sendAll2AllScSyncTC.ElapsedMS()); } // 发送查询向量 TimeCost sendLookupSyncTC; hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); - VLOG(GLOG_DEBUG) << StringFormat("sendLookupSyncTC(ms):%d", sendLookupSyncTC.ElapsedMS()); + LOG_DEBUG("sendLookupSyncTC(ms):{}", sendLookupSyncTC.ElapsedMS()); // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID) { TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); - VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); + LOG_DEBUG("sendUnikeysSyncTC(ms):{}", sendUnikeysSyncTC.ElapsedMS()); TimeCost sendRestoreVecSecSyncTC; hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); + LOG_DEBUG("sendRestoreVecSecSyncTC(ms):{}", sendRestoreVecSecSyncTC.ElapsedMS()); } // 发送恢复向量 TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); - - VLOG(GLOG_DEBUG) << StringFormat("sendTensorsSyncTC(ms):%d", sendTensorsSyncTC.ElapsedMS()); - VLOG(GLOG_DEBUG) << StringFormat("ParseKeysTC HBM mode (ms):%d", ParseKeysTC.ElapsedMS()); + LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, ParseKeysTC HBM mode (ms):{}", + sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), ParseKeysTC.ElapsedMS()); } batchId++; return true; @@ -714,8 +696,7 @@ bool HybridMgmt::EndBatch(int batchId, int channelId) const bool HybridMgmt::ParseKeys(int channelId, int& batchId) { #ifndef GTEST - LOG(INFO) << StringFormat( - MGMT + "DDR mode, start parse keys, nBatch:%d , [%d]:%d", + LOG_INFO(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); TimeCost parseKeyTC; int start = batchId; @@ -723,13 +704,13 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) bool ifHashmapFree = true; bool remainBatch = true; // 是否从通道获取了数据 while (true) { - LOG(INFO) << StringFormat(MGMT + "parse keys, [%d]:%d", channelId, batchId); + LOG_INFO(MGMT + "parse keys, [{}]:{}", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); // 通道数据已空 if (!remainBatch) { - VLOG(GLOG_DEBUG) << StringFormat("last batch ending"); + LOG_DEBUG("last batch ending"); return false; } } @@ -744,21 +725,21 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) } TimeCost 
embHdTrans2TC; EmbHDTransWrap(channelId, batchId - 1, start, iBatch); - VLOG(GLOG_DEBUG) << StringFormat("embHdTrans2TC TimeCost(ms):%d", embHdTrans2TC.ElapsedMS()); - VLOG(GLOG_DEBUG) << StringFormat("[%d]-%d, parseKeyTC TimeCost(ms):%d", channelId, batchId, parseKeyTC.ElapsedMS()); + LOG_DEBUG("embHdTrans2TC TimeCost(ms):{}", embHdTrans2TC.ElapsedMS()); + LOG_DEBUG("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); #endif return true; } inline void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) { - LOG(ERROR) << "Transfer embedding with DDR and SSD error."; + LOG_ERROR("Transfer embedding with DDR and SSD error."); if (prepareSSDRet == TransferRet::SSD_SPACE_NOT_ENOUGH) { - LOG(ERROR) << "PrepareDDRData: SSD available space is not enough."; + LOG_ERROR("PrepareDDRData: SSD available space is not enough."); throw runtime_error("ssdVocabSize too small"); } if (prepareSSDRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { - LOG(ERROR) << "PrepareDDRData: DDR available space is not enough."; + LOG_ERROR("PrepareDDRData: DDR available space is not enough."); throw runtime_error("ddrVocabSize too small"); } throw runtime_error("Transfer embedding with DDR and SSD error."); @@ -793,11 +774,11 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { return false; } - VLOG(GLOG_DEBUG) << StringFormat("getTensorsTC(ms):%d", getTensorsTC.ElapsedMS()); + LOG_DEBUG("getTensorsTC(ms):{}", getTensorsTC.ElapsedMS()); TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreSyncTC(ms):%d", sendRestoreSyncTC.ElapsedMS()); + LOG_DEBUG("sendRestoreSyncTC(ms):{}", sendRestoreSyncTC.ElapsedMS()); // 调用SSD cache缓存处理流程 PrepareDDRData(embName, embHashMap, lookupKeys, channelId); @@ -808,20 +789,21 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, DDRParam ddrParam(tmpData, offsetsOut); TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, iBatch, ddrParam, channelId); - VLOG(GLOG_DEBUG) << StringFormat("hostHashMapProcessTC(ms):%d", hostHashMapProcessTC.ElapsedMS()); + LOG_DEBUG("hostHashMapProcessTC(ms):{}", hostHashMapProcessTC.ElapsedMS()); if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - vector uniqueKeys, restoreVecSec; + vector uniqueKeys; + vector restoreVecSec; preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? 
Vec2TensorI64(uniqueKeys) : Vec2TensorI32(uniqueKeys) }, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("sendUnikeysSyncTC(ms):%d", sendUnikeysSyncTC.ElapsedMS()); TimeCost sendRestoreVecSecSyncTC; hdTransfer->Send(TransferChannel::RESTORE_SECOND, { Vec2TensorI32(restoreVecSec) }, channelId, embName); - VLOG(GLOG_DEBUG) << StringFormat("sendRestoreVecSecSyncTC(ms):%d", sendRestoreVecSecSyncTC.ElapsedMS()); + LOG_DEBUG("sendUnikeysSyncTC(ms):{}, sendRestoreVecSecSyncTC(ms):{}", + sendUnikeysSyncTC.ElapsedMS(), sendRestoreVecSecSyncTC.ElapsedMS()); } TimeCost sendTensorsTC; @@ -832,14 +814,12 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } - VLOG(GLOG_DEBUG) << StringFormat("sendTensorsTC(ms):%d", sendTensorsTC.ElapsedMS()); - - VLOG(GLOG_DEBUG) << StringFormat( - "getAndSendTensorsTC(ms):%d, channelId:%d", getAndSendTensorsTC.ElapsedMS(), channelId); + LOG_DEBUG("sendTensorsTC(ms):{} getAndSendTensorsTC(ms):{}, channelId:{}", + sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS(), channelId); if (!isSSDEnabled && embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - LOG(WARNING) << StringFormat(MGMT + "embName %s[%d]%d,iBatch:%d freeSize not enough, %d", - embName.c_str(), channelId, batchId, iBatch, lookupKeys.size()); + LOG_WARN(MGMT + "embName {}[{}]{}, iBatch:{} freeSize not enough, {}", + embName, channelId, batchId, iBatch, lookupKeys.size()); return false; } return true; @@ -855,16 +835,16 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, in if (iBatch == 0) { return; } - LOG(INFO) << StringFormat(MGMT + "trans emb, batchId:[%d-%d], channelId:%d", start, batchId, channelId); + LOG_INFO(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId); TimeCost hostEmbsTC; hostEmbs->Join(channelId); - VLOG(GLOG_DEBUG) << StringFormat("hostEmbsTC(ms):%d", hostEmbsTC.ElapsedMS()); + LOG_DEBUG("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS()); EmbHDTrans(channelId, batchId); for (int i = 0; i < iBatch - 1; ++i) { // need send empty - LOG(INFO) << StringFormat(MGMT + "trans emb dummy, batchId:%d, ", start + 1 + i); + LOG_INFO(MGMT + "trans emb dummy, batchId:{}, ", start + 1 + i); EmbHDTrans(channelId, batchId); } } @@ -876,7 +856,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) { EASY_FUNCTION(profiler::colors::Blue) EASY_VALUE("mgmtProcess", batchId) - VLOG(GLOG_DEBUG) << StringFormat(MGMT + "trans emb, batchId:%d, channelId:%d", batchId, channelId); + LOG_DEBUG(MGMT + "trans emb, batchId:{}, channelId:{}", batchId, channelId); TimeCost tr; TimeCost h2dTC; // 发送host需要换出的emb @@ -886,7 +866,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! 
hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); } - VLOG(GLOG_DEBUG) << StringFormat("h2dTC(ms):%d", h2dTC.ElapsedMS()); + LOG_DEBUG("h2dTC(ms):{}", h2dTC.ElapsedMS()); TimeCost d2hTC; // 接收device换出的emb,并更新到host上 @@ -900,11 +880,8 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) } hostHashMaps->ClearMissingKeys(embInfo.name); } - VLOG(GLOG_DEBUG) << StringFormat("d2hTC(ms):%d", d2hTC.ElapsedMS()); - - VLOG(GLOG_DEBUG) << StringFormat( - "EmbHDTrans TimeCost(ms):%d batchId: %d channelId:%d", tr.ElapsedMS(), batchId, channelId - ); + LOG_DEBUG("D2HTC(ms):{} EmbHDTrans TimeCost(ms):{} batchId: {} channelId:{}", + d2hTC.ElapsedMS(), tr.ElapsedMS(), batchId, channelId); } #endif @@ -918,14 +895,14 @@ bool HybridMgmt::Evict() if (featAdmitNEvict.GetFunctionSwitch()) { featAdmitNEvict.FeatureEvict(evictKeyMap); } else { - LOG(WARNING) << (MGMT + "Hook can not trigger evict, cause AdmitNEvict is not open"); + LOG_WARN(MGMT + "Hook can not trigger evict, cause AdmitNEvict is not open"); return false; } - VLOG(GLOG_DEBUG) << StringFormat(MGMT + "evict triggered by hook, evict TableNum %d ", evictKeyMap.size()); + LOG_DEBUG(MGMT + "evict triggered by hook, evict TableNum {}", evictKeyMap.size()); // 表为空,淘汰触发失败 if (evictKeyMap.size() == 0) { - LOG(WARNING) << (MGMT + "evict triggered by hook before dataset in injected"); + LOG_WARN(MGMT + "evict triggered by hook before dataset in injected"); return false; } @@ -949,9 +926,7 @@ bool HybridMgmt::Evict() void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { #ifndef GTEST - VLOG(GLOG_DEBUG) << StringFormat( - MGMT + "ddr mode, delete emb: [%s]! evict keySize:%d", embName.c_str(), keys.size() - ); + LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! evict keySize:{}", embName.c_str(), keys.size()); // 删除映射关系 if (keys.size() != 0) { hostHashMaps->EvictDeleteEmb(embName, keys); @@ -965,19 +940,15 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) evictOffset4Ddr.emplace_back(offsetInHostHashMap - devVocabSize); } if (!evictOffset4Ddr.empty()) { - VLOG(GLOG_DEBUG) << StringFormat( - MGMT + "ddr mode, delete emb: [%s]! evict size on host:%d", embName.c_str(), evictOffset4Ddr.size() - ); + LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! evict size on host:{}", embName, evictOffset4Ddr.size()); hostEmbs->EvictInitEmb(embName, evictOffset4Ddr); } else { - LOG(INFO) << StringFormat(MGMT + "ddr mode, evict size on host is empty"); + LOG_INFO(MGMT + "ddr mode, evict size on host is empty"); } // 发送dev侧的淘汰pos,以便dev侧初始化emb auto evictDevOffset = hostHashMaps->embHashMaps.at(embName).evictDevPos; - VLOG(GLOG_DEBUG) << StringFormat( - MGMT + "ddr mode, init dev emb: [%s]! evict size on dev :%d", embName.c_str(), evictDevOffset.size() - ); + LOG_DEBUG(MGMT + "ddr mode, init dev emb: [{}]! 
evict size on dev :{}", embName, evictDevOffset.size()); vector tmpDataOut; Tensor tmpData = Vec2TensorI32(evictDevOffset); @@ -998,13 +969,13 @@ inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInf if (!isSSDEnabled) { return; } - VLOG(GLOG_DEBUG) << "PrepareDDRData start."; + LOG_DEBUG("PrepareDDRData start."); TimeCost prepareDDRDataTc; TransferRet ret = cacheManager->TransferDDREmbWithSSD(embTableName, embHashMap, keys, channelId); if (ret != TransferRet::TRANSFER_OK) { HandlePrepareDDRDataRet(ret); } - VLOG(GLOG_DEBUG) << StringFormat("PrepareDDRData end, TimeCost(ms):%d", prepareDDRDataTc.ElapsedMS()); + LOG_DEBUG("PrepareDDRData end, TimeCost(ms):{}", prepareDDRDataTc.ElapsedMS()); } void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 4d39c571..662582fc 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -32,9 +32,8 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId) /// \param channelId train 0 eval 1 void HybridMgmtBlock::CheckAndNotifyWake(int channelId) { - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "start notify channelId %d pythonBatchId %d hybridBatchId %d", - channelId, pythonBatchId[lastRunChannelId], hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "start notify channelId {} pythonBatchId {} hybridBatchId {}", + channelId, pythonBatchId[lastRunChannelId], hybridBatchId[channelId]); CheckValid(channelId); if (pythonBatchId[channelId] >= hybridBatchId[channelId]) { @@ -48,8 +47,7 @@ bool HybridMgmtBlock::WaitValid(int channelId) { // 等待hybrid处理完成 int reTryNumber = 100; - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + - "check step invalid, wait", channelId, hybridBatchId[channelId]); + LOG_INFO(HYBRID_BLOCKING + "check step invalid, wait {} {}", channelId, hybridBatchId[channelId]); // 等待hybrid处理完成后再一次唤醒 while (pythonBatchId[lastRunChannelId] != hybridBatchId[lastRunChannelId] and isRunning) { std::this_thread::sleep_for(std::chrono::milliseconds(10ms)); @@ -83,19 +81,18 @@ void HybridMgmtBlock::CheckValid(int channelId) } // 当python侧第一次调用时,此时跳过参数检查 if (lastRunChannelId == -1) { - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "The data channel was called for the first time, and the parameters were " - "checked to be normal channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "The data channel was called for the first time, and the parameters were " + "checked to be normal channelId {} hybridBatchId {}", channelId, hybridBatchId[channelId]); lastRunChannelId = channelId; return; } // 在通道切换时,hybrid预处理的batch与python的一致。 if (pythonBatchId[lastRunChannelId] == hybridBatchId[lastRunChannelId]) { - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "HybridMgmt is switching data channels and checking for normal parameters. he number of steps " - "in the previous round is lastRunChannelId %d pythonBatchId %d hybridBatchId %d", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + LOG_DEBUG(HYBRID_BLOCKING + + "HybridMgmt is switching data channels and checking for normal parameters. 
The number of steps " + "in the previous round is lastRunChannelId {} pythonBatchId {} hybridBatchId {}", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); } else if (pythonBatchId[lastRunChannelId] < hybridBatchId[lastRunChannelId]) { // 在通道切换时,上一个通道处理的数据超出了python侧的调用 if (!WaitValid(lastRunChannelId)) { @@ -103,11 +100,11 @@ void HybridMgmtBlock::CheckValid(int channelId) } } else { // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + - "When switching data channels, it was found that HybridMgmt processed less data than the " - "Python side.In this case, after reading the dataset, the Python side called it again, but it was " - "interrupted midway,which did not affect the subsequent calls lastRunChannelId %d hybridBatchId %d", - lastRunChannelId, hybridBatchId[lastRunChannelId]); + LOG_INFO(HYBRID_BLOCKING + + "When switching data channels, it was found that HybridMgmt processed less data than the " + "Python side. In this case, after reading the dataset, the Python side called it again, but it was " + "interrupted midway, which did not affect the subsequent calls lastRunChannelId {} hybridBatchId {}", + lastRunChannelId, hybridBatchId[lastRunChannelId]); } lastRunChannelId = channelId; return; @@ -118,8 +115,8 @@ void HybridMgmtBlock::DoBlock(int channelId) { // 通道没有切换,不用处理 - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "HybridMgmt starts blocking channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt starts blocking channelId {} hybridBatchId {}", + channelId, hybridBatchId[channelId]); while (isBlock[channelId]) { std::this_thread::sleep_for(SLEEP_MS); @@ -127,8 +124,8 @@ return; } } - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "HybridMgmt is starting to wake up channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is starting to wake up channelId {} hybridBatchId {}", + channelId, hybridBatchId[channelId]); return; } @@ -136,8 +133,8 @@ /// \param channelId channelId train 0 eval 1 void HybridMgmtBlock::ResetAll(int channelId) { - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "Hybridmgmt is resetting data channelId %d hybridBatchId %d", channelId, hybridBatchId[channelId]); + LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is resetting data channelId {} hybridBatchId {}", + channelId, hybridBatchId[channelId]); readEmbedBatchId[channelId] = 0; pythonBatchId[channelId] = 0; @@ -151,25 +148,25 @@ int HybridMgmtBlock::CheckSaveEmbdMapValid() { // 检查数据通道此时的HashMap是否被提前处理了 if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) { - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + - "HybridMgmt is checking the step and checking that the parameters are normal. " - "The number of steps in the previous round is " - "lastRunChannelId %d pythonBatchId %d hybridBatchId %d", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + LOG_DEBUG(HYBRID_BLOCKING + + "HybridMgmt is checking the step and checking that the parameters are normal. 
" + "The number of steps in the previous round is " + "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); return 0; } else if (pythonBatchId[lastRunChannelId] + 1 == hybridBatchId[lastRunChannelId]) { // 在通道切换时,上一个通道处理的数据超出了python侧的调用 - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + - "HybridMgmt is checking the step, and the parameters have been processed one step " - "in advance. The number of steps in the previous round was " - "lastRunChannelId %d pythonBatchId %d hybridBatchId %d", - lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); + LOG_DEBUG(HYBRID_BLOCKING + + "HybridMgmt is checking the step, and the parameters have been processed one step " + "in advance. The number of steps in the previous round was " + "lastRunChannelId {} pythonBatchId {} hybridBatchId {}", + lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); return 1; } else { // 在通道切换时,hybrid处理的数据还没有赶上python侧,此时需要等待hybrid处理完成 - VLOG(GLOG_DEBUG) << StringFormat(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId %d hybridBatchId %d", - lastRunChannelId, hybridBatchId[lastRunChannelId]); + LOG_DEBUG(HYBRID_BLOCKING + "ERROR FLAG lastRunChannelId {} hybridBatchId {}", + lastRunChannelId, hybridBatchId[lastRunChannelId]); return -1; } } @@ -184,7 +181,7 @@ void HybridMgmtBlock::SetBlockStatus(int channelId, bool block) isBlock[channelId] = block; } -void HybridMgmtBlock::Destroy() +void HybridMgmtBlock::Destroy() { if (!isRunning) { // 已经销毁过了,不用再次销毁会报错 diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 4327b638..2c2e1489 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -23,8 +23,7 @@ void ConstantInitializer::GenerateData(float* const emb, const int embSize) return; } if (embSize < (start + len)) { - LOG(WARNING) << StringFormat( - "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); + LOG_WARN("InitializeInfo start {} + len {} is larger than embedding size {}.", start, len, embSize); return; } std::fill_n(emb + start, len, initParam * value); diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 0cff333d..3b066bba 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -25,8 +25,7 @@ void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) return; } if (embSize < (start + len)) { - LOG(WARNING) << StringFormat( - "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); + LOG_WARN("InitializeInfo start {} + len {} is larger than embedding size {}.", start, len, embSize); return; } std::generate_n(emb + start, len, [&]() { return initParam * distribution(generator); }); diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 85fb4a45..7379a871 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ 
b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -29,8 +29,7 @@ void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSiz return; } if (embSize < (start + len)) { - LOG(WARNING) << StringFormat( - "InitializeInfo start %d + len %d is larger than embedding size %d.", start, len, embSize); + LOG_WARN("InitializeInfo start {} + len {} is larger than embedding size {}.", start, len, embSize); return; } std::generate_n(emb + start, len, [&]() { diff --git a/src/core/key_process/feature_admit_and_evict.cpp index f521ff5e..8465e37b 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -28,7 +28,7 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu { if (!ParseThresholdCfg(thresholdValues)) { m_isEnableFunction = false; - LOG(ERROR) << "Config is error, feature admin-and-evict function is not available ...\n"; + LOG_ERROR("Config error: feature admit-and-evict function is not available ..."); return false; } SetCombineSwitch(); @@ -41,7 +41,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, const std::unique_ptr& batch, keys_t& splitKey, std::vector& keyCount) { if (splitKey.size() != keyCount.size()) { - LOG(ERROR) << StringFormat("splitKey.size %d != keyCount.size %d", splitKey.size(), keyCount.size()); + LOG_ERROR("splitKey.size {} != keyCount.size {}", splitKey.size(), keyCount.size()); return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } @@ -60,9 +60,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, absl::flat_hash_map records(m_recordsInitSize); m_recordsData.historyRecords[tableName] = records; } - VLOG(GLOG_DEBUG) << StringFormat( - "FeatureAdmitAndEvict PrintSize, name:[%s], history key:[%d] ...", tableName.c_str(), - m_recordsData.historyRecords[tableName].size()); + LOG_DEBUG("FeatureAdmitAndEvict PrintSize, name:[{}], history key:[{}] ...", + tableName.c_str(), m_recordsData.historyRecords[tableName].size()); if (batch->timestamp > m_recordsData.timestamps[tableName]) { m_recordsData.timestamps[tableName] = batch->timestamp; @@ -89,11 +88,8 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, key = -1; } } - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << StringFormat( - "FeatureAdmit, name:[%s], channel:[%d], after admit, splitKey:[%s] ...", tableName.c_str(), channel, - VectorToString(splitKey).c_str()); - } + LOG_TRACE("FeatureAdmit, name:[{}], channel:[{}], after admit, splitKey:[{}] ...", + tableName, channel, VectorToString(splitKey)); return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } @@ -144,11 +140,11 @@ void FeatureAdmitAndEvict::FeatureEvict(map> { std::vector tableNames = GetAllNeedEvictTableNames(); if (tableNames.empty()) { - LOG(INFO) << "EmbNames is empty, no evict function ..."; + LOG_INFO("EmbNames is empty, no evict function ..."); return ; } if (!m_isEnableFunction) { - LOG(WARNING) << "m_isEnableFunction switch is false, no evict function ..."; + LOG_WARN("m_isEnableFunction switch is false, no evict function ..."); return ; } std::lock_guard lock(m_syncMutexs);
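The FeatureAdmit hunks above mask keys that fail admission by writing key = -1 after consulting the per-table history counts. The full admission logic is not part of this excerpt, so the following is only a minimal standalone sketch of that masking step; the names MarkNotAdmitted, history and threshold are illustrative assumptions, not repository identifiers.

#include <cstdint>
#include <unordered_map>
#include <vector>

// Mask every key whose accumulated count is still below the admit threshold,
// mirroring the "key = -1" marking visible in FeatureAdmit above.
void MarkNotAdmitted(std::vector<int64_t> &splitKey,
                     const std::unordered_map<int64_t, uint64_t> &history,
                     uint64_t threshold)
{
    for (auto &key : splitKey) {
        const auto it = history.find(key);
        if (it == history.end() || it->second < threshold) {
            key = -1;  // downstream stages treat -1 as "not admitted"
        }
    }
}

diff --git a/src/core/key_process/key_process.cpp index 69f2278b..4acde948 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -541,10 +541,8 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) } if (tc.ElapsedSec() > 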
GET_BATCH_TIMEOUT) { if (commId == 0) { - LOG(WARNING) << StringFormat( - KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. channel[%d] commId[%d]", - channel, commId - ); + LOG_WARN(KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. " + "channel[{}] commId[{}]", channel, commId); } this_thread::sleep_for(seconds(1)); tc = TimeCost(); @@ -625,17 +623,14 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "ProcessBatchWithFastUnique get batchId:%d, batchSize:%d," - " channel:%d, name:%s, restore:%d, keyCount:%d", - batch->batchId, batch->Size(), batch->channel, batch->name.c_str(), - uniqueInfoOut.restore.size(), keySendInfo.keyCount.size() - ); + LOG_DEBUG(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{}," + " channel:{}, name:{}, restore:{}, keyCount:{}", + batch->batchId, batch->Size(), batch->channel, batch->name, + uniqueInfoOut.restore.size(), keySendInfo.keyCount.size()); if (g_statOn) { - LOG(INFO) << StringFormat( - STAT_INFO "channel_id %d batch_id %d rank_id %d " - "batch_key_num_with_fast_unique %d unique_key_num_with_fast_unique %d", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} " + "batch_key_num_with_fast_unique {} unique_key_num_with_fast_unique {}", batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueOut.uniqueIdCnt); } } @@ -885,8 +880,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const for (int devId = 0; devId < rankInfo.rankSize; ++devId) { UniqueKeyNum += splitKeys[devId].size(); } - LOG(INFO) << StringFormat( - STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d faae_unique_key_num %ld", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} faae_unique_key_num {}", batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); } return { splitKeys, restore, keyCount }; @@ -945,8 +939,7 @@ tuple, vector, vector> for (int devId = 0; devId < rankInfo.rankSize; ++devId) { UniqueKeyNum += splitKeys[devId].size(); } - LOG(INFO) << StringFormat( - STAT_INFO "channel_id %d batch_id %d rank_id %d batch_key_num %d hot_unique_key_num %ld", + LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} hot_unique_key_num {}", batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); } @@ -1102,7 +1095,7 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha } } if (maxOffsetTmp > embInfos[embName].devVocabSize) { - LOG(ERROR) << StringFormat("dev cache overflow %d>%d", maxOffsetTmp, embInfos[embName].devVocabSize); + LOG_ERROR("dev cache overflow {} > {}", maxOffsetTmp, embInfos[embName].devVocabSize); throw std::runtime_error("dev cache overflow!"); } LOG_DEBUG("current hbm emb:{}, usage:{}/{} key2OffsetTC({} ms)", @@ -1161,12 +1154,12 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec int devId = static_cast(d) & (rankInfo.rankSize - 1); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; - } else if (VLOG_IS_ON(GLOG_DEBUG)) { + } else if (Log::GetLevel() >= Log::DEBUG) { hotNum += 1; } } - VLOG(GLOG_DEBUG) << StringFormat("hot num in all:%d/%d", hotNum, batch->Size()); - VLOG(GLOG_DEBUG) << StringFormat("buildRestoreVecTC(ms):%d", buildRestoreVecTC.ElapsedMS()); + LOG_DEBUG("hot num in all:{}/{} buildRestoreVecTC(ms):{}", + hotNum, batch->Size(), 
buildRestoreVecTC.ElapsedMS()); } class EmptyList : public std::exception { diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index 5c4391f2..5b7c359f 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -44,13 +44,10 @@ inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& exter void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, vector& externalSSDKeys) { - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD: batchKeySize:%d, externalKeys size:%d," - " externalSSDKeys size:%d", - batchKeySize, externalKeys.size(), externalSSDKeys.size()); - if (VLOG_IS_ON(GLOG_TRACE)) { - VLOG(GLOG_TRACE) << "TransferDDREmbWithSSD: externalKeys:" << VectorToString(externalKeys).c_str() - << ", externalSSDKeys:%s" << VectorToString(externalSSDKeys).c_str(); - } + LOG_DEBUG("TransferDDREmbWithSSD: batchKeySize:{}, externalKeys size:{}, externalSSDKeys size:{}", + batchKeySize, externalKeys.size(), externalSSDKeys.size()); + LOG_TRACE("TransferDDREmbWithSSD: externalKeys:{}, externalSSDKeys:{}", + VectorToString(externalKeys), VectorToString(externalSSDKeys)); } /// 去重和过滤无效key @@ -94,8 +91,8 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, if (channelId == TRAIN_CHANNEL_ID) { ddrAvailableSize += embHashMap.evictPos.size(); } - VLOG(GLOG_DEBUG) << StringFormat("TransferDDREmbWithSSD, maxOffset:%d, evictPos size:%d, ddrAvailableSize:%d", - embHashMap.maxOffset, embHashMap.evictPos.size(), ddrAvailableSize); + LOG_DEBUG("TransferDDREmbWithSSD, maxOffset:{}, evictPos size:{}, ddrAvailableSize:{}", + embHashMap.maxOffset, embHashMap.evictPos.size(), ddrAvailableSize); CreateSSDTableIfNotExist(embTableName); // 调用ssdEngine查询当前批次key中保存在SSD中的key diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/ssd_cache/lfu_cache.cpp index 8c68ecca..f6f6001f 100644 --- a/src/core/ssd_cache/lfu_cache.cpp +++ b/src/core/ssd_cache/lfu_cache.cpp @@ -105,7 +105,7 @@ void LFUCache::PutWithInit(emb_key_t key, freq_num_t freq) { if (keyTable.find(key) != keyTable.end()) { // 一般初始化时,key应该不存在已经被插入的情况;此处替换就的key频次信息 - VLOG(GLOG_DEBUG) << StringFormat("key has exist when init process, key:%d", key); + LOG_DEBUG("key has exist when init process, key:{}", key); Pop(key); } freqTable[freq].emplace_front(key, freq); diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index ff1c67a6..1834f6df 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -14,7 +14,7 @@ using namespace MxRec; /// \param saveDir 保存文件夹的路径 File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) { - VLOG(GLOG_DEBUG) << StringFormat("start init file, fileID:%llu", fileID); + LOG_DEBUG("start init file, fileID:{}", fileID); if (!fs::exists(fs::absolute(saveDir))) { if (!fs::create_directories(fs::absolute(saveDir))) { @@ -35,7 +35,7 @@ File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) } fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); - VLOG(GLOG_DEBUG) << StringFormat("end init file, fileID:%llu", fileID); + LOG_DEBUG("end init file, fileID:{}", fileID); } /// 创建文件实例并加载,从保存路径中读取元数据文件、数据文件 @@ -44,7 +44,7 @@ File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) /// \param step 加载的步数 File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir(saveDir) { - VLOG(GLOG_DEBUG) << StringFormat("start init file with load, fileID:%llu", fileID); + LOG_DEBUG("start 
init file with load, fileID:{}", fileID); fs::path metaFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta." + to_string(step)); fs::path dataFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".data." + to_string(step)); @@ -82,7 +82,7 @@ File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); Load(); - VLOG(GLOG_DEBUG) << StringFormat("end init file with load, fileID:%llu", fileID); + LOG_DEBUG("end init file with load, fileID:{}", fileID); } File::~File() @@ -165,7 +165,7 @@ void File::DeleteEmbedding(emb_key_t key) void File::Save(int step) { - VLOG(GLOG_DEBUG) << StringFormat("start save file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("start save file at step:{}, fileID:{}", step, fileID); // write current meta into meta file for (auto [key, offset]: keyToOffset) { @@ -184,7 +184,7 @@ void File::Save(int step) throw invalid_argument("fail to save latest meta, file already exist"); } - VLOG(GLOG_DEBUG) << StringFormat("save latest meta file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("save latest meta file at step:{}, fileID:{}", step, fileID); if (!fs::copy_file(metaFilePath, metaFileToSave)) { throw runtime_error("fail to Save latest meta"); } @@ -196,7 +196,7 @@ void File::Save(int step) } // Save data - VLOG(GLOG_DEBUG) << StringFormat("save latest data file at step:%d", step); + LOG_DEBUG("save latest data file at step:{}", step); localFileData.flush(); if (localFileData.fail()) { throw runtime_error("fail to Save data"); @@ -217,13 +217,13 @@ void File::Save(int step) throw runtime_error("fail to re-open data file"); } - VLOG(GLOG_DEBUG) << StringFormat("end save file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("end save file at step:{}, fileID:{}", step, fileID); } void File::Load() { // file already validate and open in instantiation - VLOG(GLOG_DEBUG) << StringFormat("start reading meta file, fileID:%llu", fileID); + LOG_DEBUG("start reading meta file, fileID:{}", fileID); emb_key_t key; offset_t offset; do { @@ -254,7 +254,7 @@ void File::Load() throw runtime_error("fail to re-open meta file"); } - VLOG(GLOG_DEBUG) << StringFormat("end reading meta file, fileID:%llu", fileID); + LOG_DEBUG("end reading meta file, fileID:{}", fileID); } vector File::GetKeys() diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index 212de4e7..958694c1 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -302,7 +302,7 @@ void Table::Compact(bool fullCompact) return; } - VLOG(GLOG_DEBUG) << StringFormat("table:%s, start compact", name.c_str()); + LOG_DEBUG("table:{}, start compact", name); vector> compactFileList; for (const auto &f: staleDataFileSet) { @@ -329,7 +329,7 @@ void Table::Compact(bool fullCompact) vector> validEmbs = f->FetchEmbeddings(validKeys); InsertEmbeddingsInner(validKeys, validEmbs); } - VLOG(GLOG_DEBUG) << StringFormat("table:%s, end compact", name.c_str()); + LOG_DEBUG("table:{}, end compact", name); } uint64_t Table::GetTableAvailableSpace() diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 037e54d7..d84a1119 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -205,4 +205,10 @@ namespace MxRec { } } + ostream& operator<<(ostream& ss, MxRec::CkptDataType type) + { + ss << static_cast(type); + return ss; + } + } // end namespace MxRec diff --git a/src/core/utils/common.h b/src/core/utils/common.h index beae701e..7a29286d 
100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -44,8 +44,6 @@ #define EASY_PROFILER_DISABLE #endif -#define REC_LOG(severity) if (g_glogLevel >= severity) LOG(severity) - namespace MxRec { #define INFO_PTR shared_ptr #define MGMT_CPY_THREADS 4 @@ -579,6 +577,8 @@ namespace MxRec { EXCLUDE_FREQ_MAP = 11, EVICT_POS = 12 }; + + ostream& operator<<(ostream& s, MxRec::CkptDataType type); } // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 533a4442..ad7fbf68 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -63,7 +63,7 @@ public: void Compute(OpKernelContextPtr context) override { - LOG(INFO) << StringFormat("clear channel %d, context %d", channelId, context->step_id()); + LOG_INFO("clear channel {}, context {}", channelId, context->step_id()); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); hybridMgmtBlock->ResetAll(channelId); } @@ -119,7 +119,7 @@ class ReadEmbKeyV2Dynamic : public OpKernel { public: explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context) { - VLOG(GLOG_DEBUG) << "ReadEmbKeyV2Dynamic init"; + LOG_DEBUG("ReadEmbKeyV2Dynamic init"); OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); @@ -142,7 +142,7 @@ public: MAX_CHANNEL_NUM))); return; } - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + " reset channel %d", channelId); + LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); hybridMgmtBlock->ResetAll(channelId); threadNum = GetThreadNumEnv(); @@ -158,12 +158,12 @@ public: void Compute(OpKernelContextPtr context) override { EASY_FUNCTION(); - VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2Dynamic"; + LOG_DEBUG("enter ReadEmbKeyV2Dynamic"); TimeCost tc = TimeCost(); int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - LOG(WARNING) << StringFormat("skip excess batch after %d/%d", batchId, maxStep); + LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); return; } } @@ -195,11 +195,9 @@ public: TimeCost enqueueTC; EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):%d, elapsed from last(ms):%d," - " enqueueTC(ms):%d, batch[%d]:%d", - tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId - ); + LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{}," + " enqueueTC(ms):{}, batch[{}]:{}", + tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); staticSw = TimeCost(); } @@ -208,9 +206,7 @@ public: auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < embNames.size(); ++i) { if (!keyProcess->hasEmbName(embNames.at(i))) { - LOG(INFO) << StringFormat( - "ReadEmbKeyV2Dynamic not found emb_name:%d %s", i, embNames.at(i).c_str() - ); + LOG_INFO("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { tableUsed.push_back(true); @@ -260,18 +256,18 @@ public: size_t& dataSize) { if (dataSize - fieldNumTmp != 1) { // 说明没有传时间戳 - LOG(ERROR) << StringFormat("dataSize[%d], fieldNum[%d] ...", dataSize, fieldNumTmp); + LOG_ERROR("dataSize[{}], fieldNum[{}] ...", 
dataSize, fieldNumTmp); return false; } // 前面8个字节、即占一个featureId位,是unix时间戳 auto src = (const time_t*)inputTensor.tensor_data().data(); std::copy(src, src + 1, &timestamp); - LOG(INFO) << StringFormat("current batchId[%d] timestamp[%d]", batchId, timestamp); + LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); dataSize -= 1; if (timestamp <= 0) { - LOG(ERROR) << StringFormat("timestamp[%d] <= 0 ", timestamp); + LOG_ERROR("timestamp[{}] <= 0 ", timestamp); return false; } @@ -321,7 +317,7 @@ class ReadEmbKeyV2 : public OpKernel { public: explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context) { - VLOG(GLOG_DEBUG) << "ReadEmbKeyV2 init"; + LOG_DEBUG("ReadEmbKeyV2 init"); OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // 每个表的field Number @@ -350,7 +346,7 @@ public: "ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", MAX_CHANNEL_NUM))); return; } - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + " reset channel %d", channelId); + LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); // 重置此数据通道中所有的步数 hybridMgmtBlock->ResetAll(channelId); @@ -368,13 +364,13 @@ public: void Compute(OpKernelContextPtr context) override { EASY_FUNCTION(); - VLOG(GLOG_DEBUG) << "enter ReadEmbKeyV2"; + LOG_DEBUG("enter ReadEmbKeyV2"); TimeCost tc = TimeCost(); int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; Tensor* output = nullptr; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - LOG(WARNING) << StringFormat("skip excess batch after %d/%d", batchId, maxStep); + LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); out(0) = batchId; @@ -404,9 +400,8 @@ public: TimeCost enqueueTC; EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); - VLOG(GLOG_DEBUG) << StringFormat( - KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):%d, elapsed from last(ms):%d," - " enqueueTC(ms):%d, batch[%d]:%d", + LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{}," + " enqueueTC(ms):{}, batch[{}]:{}", tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); staticSw = TimeCost(); } @@ -416,7 +411,7 @@ public: auto keyProcess = Singleton<KeyProcess>::GetInstance(); for (size_t i = 0; i < splits.size(); ++i) { if (!keyProcess->hasEmbName(embNames.at(i))) { - LOG(INFO) << StringFormat("ReadEmbKeyV2 not found emb_name:%d %s", i, embNames.at(i).c_str()); + LOG_INFO("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { tableUsed.push_back(true); @@ -466,18 +461,18 @@ public: size_t& dataSize) { if (dataSize - fieldNumTmp != 1) { // 说明没有传时间戳 - LOG(ERROR) << StringFormat("dataSize[%d], fieldNum[%d] ...", dataSize, fieldNumTmp); + LOG_ERROR("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); return false; } // 前面8个字节、即占一个featureId位,是unix时间戳 auto src = (const time_t*)inputTensor.tensor_data().data(); std::copy(src, src + 1, &timestamp); - LOG(INFO) << StringFormat("current batchId[%d] timestamp[%d]", batchId, timestamp); + LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); dataSize -= 1; if (timestamp <= 0) { - LOG(ERROR) << StringFormat("timestamp[%d] <= 0 ", timestamp); + LOG_ERROR("timestamp[{}] <= 0 ", timestamp); return false; }
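This patch replaces printf-style VLOG/StringFormat logging with LOG_* macros that take {} placeholders. The real MxRec Log implementation is not shown in this excerpt, so the following is a rough sketch, under stated assumptions, of how such a {}-placeholder formatter with level filtering can be written; FormatTo, LogAt, the logsketch namespace and the integer level arguments are invented names, not repository APIs.

#include <iostream>
#include <sstream>
#include <string>
#include <utility>

namespace logsketch {
// Base case: no arguments left, so emit the rest of the format string verbatim.
inline void FormatTo(std::ostringstream &os, const std::string &fmt, size_t pos)
{
    os << fmt.substr(pos);
}

// Recursive case: copy text up to the next "{}", stream one argument, recurse.
template <typename T, typename... Rest>
void FormatTo(std::ostringstream &os, const std::string &fmt, size_t pos, T &&value, Rest &&... rest)
{
    const size_t brace = fmt.find("{}", pos);
    if (brace == std::string::npos) {  // more arguments than placeholders: drop the extras
        os << fmt.substr(pos);
        return;
    }
    os << fmt.substr(pos, brace - pos) << value;  // relies on operator<< for the argument type
    FormatTo(os, fmt, brace + 2, std::forward<Rest>(rest)...);
}

// Emit the message only when its level passes the configured threshold.
template <typename... Args>
void LogAt(int level, int threshold, const std::string &fmt, Args &&... args)
{
    if (level < threshold) {
        return;
    }
    std::ostringstream os;
    FormatTo(os, fmt, 0, std::forward<Args>(args)...);
    std::cout << os.str() << std::endl;
}
}  // namespace logsketch

A call such as logsketch::LogAt(0, 0, "channel[{}] batch[{}]", 1, 42); prints "channel[1] batch[42]". Streaming each argument through operator<< would also explain why this patch adds an ostream overload for CkptDataType in common.cpp: an enum must be printable before it can stand behind a {} placeholder.

@@ 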
-563,7 +558,7 @@ public: for (int i { 0 }; i < restoreLen; ++i) { r(i) = i % lookupLen; } - LOG(WARNING) << StringFormat("dummy read batch cost: %d,elapsed from last %d", + LOG_WARN("dummy read batch cost: {},elapsed from last {}", tc.ElapsedMS(), staticSw.ElapsedMS()); tc = TimeCost(); } @@ -581,7 +576,7 @@ public: void Compute(OpKernelContextPtr context) override { - LOG(INFO) << StringFormat("context %d", context->step_id()); + LOG_INFO("context {}", context->step_id()); std::cout << " Cust opp not installed!!" << std::endl; } diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 32c2d89d..8b08443f 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -58,9 +58,9 @@ protected: void UpdateEmb(vector &missingKeysHostPos, int channelId, const string &embName, std::unique_ptr &hostEmb, vector &d2h_emb) { - LOG(INFO) << (HD + "update emb start"); + LOG_INFO(HD + "update emb start"); if (d2h_emb.size() == 0) { - LOG(INFO) << StringFormat(HD + "emb is none channelId:%d", channelId); + LOG_INFO(HD + "emb is none channelId:{}", channelId); return; } @@ -72,12 +72,10 @@ protected: tensorPtr = tensorPtr + hostEmb->GetEmb(embName).hostEmbInfo.extEmbeddingSize; } for (size_t i = 0; i < hostEmb->GetEmb(embName).embData.size(); ++i) { - REC_LOG(INFO) << StringFormat( - "hostEmb: embName %s, %d is: %s", embName.c_str(), i, - VectorToString(hostEmb->GetEmb(embName).embData[i]).c_str() - ); + LOG_INFO("hostEmb: embName {}, {} is: {}", embName, i, + VectorToString(hostEmb->GetEmb(embName).embData[i])); } - LOG(INFO) << (HD + "update emb end"); + LOG_INFO(HD + "update emb end"); d2h_emb.clear(); } @@ -136,7 +134,7 @@ TEST_F(EmbMgmtTest, Initialize) vector tmpData; hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); auto missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG(INFO) << StringFormat("missingKeys %d", missingKeys); + LOG_INFO("missingKeys {}", missingKeys); hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); auto status = Float2TensorVec(tmpDatas, d2h_emb); ASSERT_EQ(status, true); @@ -146,7 +144,7 @@ TEST_F(EmbMgmtTest, Initialize) lookupKeys = { 2, 3, 5, 6 }; hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG(INFO) << StringFormat("missingKeys %d", missingKeys); + LOG_INFO("missingKeys {}", missingKeys); hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); status = Float2TensorVec(tmpDatas, d2h_emb); ASSERT_EQ(status, true); @@ -156,7 +154,7 @@ TEST_F(EmbMgmtTest, Initialize) lookupKeys = { 1, 7, 9, 10 }; hostHashMaps->Process(embInfo.name, lookupKeys, currentBatchId, tmpData); missingKeys = hostHashMaps->embHashMaps[embInfo.name].missingKeysHostPos; - LOG(INFO) << StringFormat("missingKeys %d", missingKeys); + LOG_INFO("missingKeys {}", missingKeys); hostEmbs->EmbDataGenerator(initializeInfos, seed, missingKeys.size(), embeddingSize, tmpDatas); Float2TensorVec(tmpDatas, d2h_emb); status = Float2TensorVec(tmpDatas, d2h_emb); diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp index 92a71c77..ecc8711d 100644 --- a/src/tests/emb_table/emb_table_test.cpp +++ b/src/tests/emb_table/emb_table_test.cpp @@ -26,8 +26,8 @@ protected: { // 设置测试用的EmbInfo embInfo.extEmbeddingSize = embTable.TEST_EMB_SIZE; - LOG(INFO) << StringFormat( - "EmbTable 
BLOCK_EMB_COUNT %d INIT_BLOCK_COUNT %d", embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT); + LOG_INFO("EmbTable BLOCK_EMB_COUNT {} INIT_BLOCK_COUNT {}", + embTable.BLOCK_EMB_COUNT, embTable.INIT_BLOCK_COUNT); rankInfo.rankId = 0; rankInfo.rankSize = 1; rankInfo.localRankSize = 1; @@ -38,7 +38,7 @@ protected: rankInfo.deviceId = 0; // 初始化EmbeddingTable #ifndef GTEST - LOG(INFO) << StringFormat("rank %d running", rankInfo.deviceId); + LOG_INFO("rank {} running", rankInfo.deviceId); aclInit(nullptr); #endif } @@ -58,15 +58,14 @@ TEST_F(EmbTableTest, Init) #ifndef GTEST // 测试初始化是否出现异常 EXPECT_NO_THROW(embTable.Init(embInfo, rankInfo, 0)); - LOG(INFO) << "embTable Init succeed!"; + LOG_INFO("embTable Init succeed!"); ASSERT_EQ(embTable.rankInfo.g_rankId, rankInfo.g_rankId); ASSERT_EQ(embTable.rankInfo.rankSize, rankInfo.rankSize); ASSERT_EQ(embTable.rankInfo.localRankSize, rankInfo.localRankSize); ASSERT_EQ(embTable.rankInfo.useStatic, rankInfo.useStatic); ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId); // 测试容量是否正常 - LOG(INFO) << StringFormat( - "totalCapacity %d, INIT_BLOCK_COUNT %d", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); + LOG_INFO("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); #endif } diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index b3b0213d..261a1890 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -25,7 +25,7 @@ public: protected: void SetUp() { - VLOG(GLOG_DEBUG) << StringFormat("%s", "start initialize") ; + LOG_DEBUG("start initialize") ; } }; @@ -77,7 +77,7 @@ TEST_F(HybridMgmtBlockTest, CheckValid) hybridMgmtBlock->CheckValid(1); ASSERT_EQ(-1, 0); } catch (HybridMgmtBlockingException e) { - VLOG(INFO) << StringFormat(HYBRID_BLOCKING + "sucess"); + LOG_INFO(HYBRID_BLOCKING + "sucess"); ASSERT_EQ(0, 0); } hybridMgmtBlock->pythonBatchId[0] = 0; diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp index 8d1ca883..651b13c7 100644 --- a/src/tests/ssd_cache/cache_manager_test.cpp +++ b/src/tests/ssd_cache/cache_manager_test.cpp @@ -150,7 +150,7 @@ TEST_F(CacheManagerTest, RefreshFreqInfo) cacheManager.RefreshFreqInfoCommon(embTableName, hbm2EvictKeys, TransferType::HBM_2_EVICT); const auto it = cacheManager.excludeDDRKeyCountMap[embTableName].find(160); ASSERT_EQ(it, cacheManager.excludeDDRKeyCountMap[embTableName].end()); - LOG(INFO) << "test RefreshFreqInfo end."; + LOG_INFO("test RefreshFreqInfo end."); } TEST_F(CacheManagerTest, PutKey) @@ -162,7 +162,7 @@ TEST_F(CacheManagerTest, PutKey) ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1); ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].freqTable[1].size(), 1); ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(15), 1); - LOG(INFO) << "test PutKey end."; + LOG_INFO("test PutKey end."); } TEST_F(CacheManagerTest, IsKeyInSSD) @@ -172,7 +172,7 @@ TEST_F(CacheManagerTest, IsKeyInSSD) ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, checkKeys[1])); ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[2])); ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, checkKeys[3])); - LOG(INFO) << "test IsKeyInSSD end."; + LOG_INFO("test IsKeyInSSD end."); } TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalKey) @@ -184,7 +184,7 @@ TEST_F(CacheManagerTest, 
TransferDDREmbWithSSDByEmptyExternalKey) embHashMapInfo.hostHashMap[75] = 116; auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID); ASSERT_EQ(ret, TransferRet::TRANSFER_OK); - LOG(INFO) << "test TransferDDREmbWithSSDByEmptyExternalKey end."; + LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalKey end."); } TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess) @@ -223,7 +223,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess) ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 8)); ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 15)); - LOG(INFO) << "check detail data before transfer ok."; + LOG_INFO("check detail data before transfer ok."); // externalKeys: SSD(15, 25) + newKey(55, 65, 75) // 训练场景,构造结果:offsetAvailableSize=20+100-118+evictPos.size()=3 @@ -256,7 +256,7 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess) ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 9)); ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 8)); ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 15)); - LOG(INFO) << "test TransferDDREmbWithSSDByAllProcess end."; + LOG_INFO("test TransferDDREmbWithSSDByAllProcess end."); } TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalSSDKey) diff --git a/src/tests/ssd_engine/file_test.cpp index 7eaac073..cd95b5f1 100644 --- a/src/tests/ssd_engine/file_test.cpp +++ b/src/tests/ssd_engine/file_test.cpp @@ -25,7 +25,7 @@ TEST(File, CreateEmptyFile) auto f = make_shared(0, savePath); } catch (runtime_error &e) { isExceptionThrown = true; - LOG(ERROR) << e.what(); + LOG_ERROR(e.what()); } ASSERT_EQ(isExceptionThrown, false); fs::remove_all(savePath); @@ -75,7 +75,7 @@ TEST(File, LoadFromFile) try { auto f = make_shared(0, savePath, 0); } catch (runtime_error &e) { - LOG(ERROR) << e.what(); + LOG_ERROR(e.what()); isExceptionThrown = true; } ASSERT_EQ(isExceptionThrown, false); diff --git a/src/tests/utils/log_test.cpp index 44dfd67a..09ed911a 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -132,4 +132,18 @@ TEST(Log, FewArgs) std::string output = testing::internal::GetCapturedStdout(); cout << output << endl; EXPECT_NE(output.find("hellow hellow"), string::npos); +} + +TEST(Log, CkptType) +{ + MxRec::Log::SetLevel(Log::INFO); + testing::internal::CaptureStdout(); + LOG_INFO("ckpt type={}", CkptDataType::EMB_DATA); + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_NE(output.find("ckpt type=1"), string::npos); + + testing::internal::CaptureStdout(); + LOG_INFO("ckpt type={}", CkptDataType::NDDR_OFFSET); + output = testing::internal::GetCapturedStdout(); + EXPECT_NE(output.find("ckpt type=5"), string::npos); } \ No newline at end of file
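The new Log.CkptType test above exercises the operator<< overload that this patch adds for CkptDataType in common.cpp/common.h, which lets enum values be printed behind a {} placeholder. A minimal, self-contained illustration of the same pattern follows; DataKind is an invented stand-in, not a repository type, and only the streaming technique is taken from the patch.

#include <iostream>

enum class DataKind { EMB_DATA = 1, NDDR_OFFSET = 5 };

// Stream the numeric value of the enum, as common.cpp does for CkptDataType.
std::ostream &operator<<(std::ostream &os, DataKind k)
{
    os << static_cast<int>(k);
    return os;
}

int main()
{
    std::cout << "ckpt type=" << DataKind::EMB_DATA << '\n';     // prints: ckpt type=1
    std::cout << "ckpt type=" << DataKind::NDDR_OFFSET << '\n';  // prints: ckpt type=5
    return 0;
}

-- Gitee From 3a48927439314eff71137a896da3c32365147929 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 12 Sep 2023 11:12:34 +0800 Subject: [PATCH 339/551] Match-id-193c6c4f0c0b97cc4494882ce8b6b7d4cec005d5 --- src/core/key_process/key_process.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/key_process/key_process.cpp index 4acde948..03681180 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -286,7 +286,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) } unique->UnInitialize(); } catch (const EndRunError &e) { - LOG_ERROR(KEY_PROCESS "abort run: {}", e.what()); + 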
LOG_INFO(KEY_PROCESS "abort run: {}", e.what()); } LOG_INFO(KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel); @@ -319,7 +319,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) batchQueue->PutDirty(move(batch)); } } catch (const EndRunError &e) { - LOG_ERROR(KEY_PROCESS "abort run: {}", e.what()); + LOG_INFO(KEY_PROCESS "abort run: {}", e.what()); } LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel); } -- Gitee From 7e6fd4eb92b6da98436f14e637480641a488f0cb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 12 Sep 2023 17:08:08 +0800 Subject: [PATCH 340/551] Match-id-f5e7ac7b21e3f17eec140d8ea0164e7c28be499d --- mx_rec/saver/saver.py | 97 ++++++++++++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index b6b053b2..4049dd5a 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -5,6 +5,7 @@ import json import os import logging +import threading from collections import defaultdict import numpy as np @@ -19,6 +20,19 @@ from mx_rec.util.perf import performance from mx_rec.validator.validator import DirectoryValidator, FileValidator +# define save model thread +class SaveModelThread(threading.Thread): + def __init__(self, sess, result, root_dir, table_name): + super().__init__() + self.result = result + self.root_dir = root_dir + self.table_name = table_name + self.sess = sess + + def run(self): + Saver().save_table_name_data(self.sess, self.result, self.root_dir, self.table_name) + + class Saver(object): customized_ops = get_customized_ops() @@ -34,7 +48,7 @@ class Saver(object): self.restore_fetch_list = [] self.placeholder_dict = defaultdict(dict) # save_easy_mode : only save the embedding and key data of sparse tables - self.save_easy_mode = os.getenv("SAVE_EASY", 0) + self.save_easy_mode = int(os.getenv("SAVE_EASY", 0)) self.build() def build(self): @@ -108,6 +122,7 @@ class Saver(object): ckpt_name = f"sparse-{base_name}" reading_path = os.path.join(directory, ckpt_name) + set_sparse_dir(reading_path) if not tf.io.gfile.exists(reading_path): raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") @@ -115,6 +130,26 @@ class Saver(object): logging.info(f"sparse model was restored from dir '{reading_path}' .") logging.debug("======== Restoring finished ========") + @performance("save_table_name_data") + def save_table_name_data(self, sess, result, root_dir, table_name): + dump_data_dict = sess.run(result.get(table_name)) + + table_instance = get_table_instance_by_name(table_name) + self._make_table_name_dir(root_dir, table_instance, table_name) + # save key + if is_asc_manager_initialized() and self.save_easy_mode: + self._save_easy_mode_save_key_data(dump_data_dict, root_dir, table_name) + # save embedding + save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id) + if table_instance.use_feature_mapping: + save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id) + save_offset_data(root_dir, table_name, dump_data_dict, self.rank_id) + if "optimizer" in dump_data_dict: + dump_optimizer_data_dict = dump_data_dict.get("optimizer") + for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items(): + save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, + self.rank_id) + def _build_save(self): for var in self.var_list: if os.getenv("TF_DEVICE", " ") 
== "NPU" and "merged" not in var.name: @@ -159,37 +194,38 @@ class Saver(object): assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) self.restore_fetch_list.append(assign_op) + @performance("_save") def _save(self, sess, root_dir): - result = sess.run(self.save_op_dict) + result = self.save_op_dict if is_asc_manager_initialized() and not self.save_easy_mode: save_host_data(root_dir) logging.debug(f"host data was saved.") - for table_name, dump_data_dict in result.items(): - table_instance = get_table_instance_by_name(table_name) - if table_instance.host_vocabulary_size > 0: - table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) - else: - table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) - tf.io.gfile.makedirs(table_dir) - if is_asc_manager_initialized() and self.save_easy_mode: - host_data = get_host_data(table_name) - key = np.array(list(host_data.keys())) - offset = list(host_data.values()) - get_valid_dict_data(dump_data_dict, offset) - save_key_data(root_dir, table_name, key, self.rank_id) - - save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id) - table_instance = get_table_instance_by_name(table_name) + threads = [] + for table_name in result.keys(): + thread = SaveModelThread(sess, result, root_dir, table_name) + threads.append(thread) - if table_instance.use_feature_mapping: - save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id) - save_offset_data(root_dir, table_name, dump_data_dict, self.rank_id) - if "optimizer" in dump_data_dict: - dump_optimizer_data_dict = dump_data_dict.get("optimizer") - for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items(): - save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, - self.rank_id) + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) + + def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name): + host_data = get_host_data(table_name) + key = np.array(list(host_data.keys())) + offset = list(host_data.values()) + get_valid_dict_data(dump_data_dict, offset) + save_key_data(root_dir, table_name, key, self.rank_id) def _restore(self, sess, reading_path): restore_feed_dict = defaultdict(dict) @@ -283,6 +319,7 @@ def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_des feed_dict[embedding_placeholder] = data +@performance("save_embedding_data") def save_embedding_data(root_dir, table_name, dump_data_dict, suffix): target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.EMBEDDING.value) data_to_write = dump_data_dict.get(DataName.EMBEDDING.value) @@ -362,8 +399,8 @@ def write_binary_data(writing_path, suffix, data, attributes=None): if target_data_dir.find("://") != -1: logging.debug(f"use hdfs path {target_data_dir} to save sparse data.") - with tf.io.gfile.GFile(target_data_dir, "w") as file: - data = json.dumps(data.flatten().tolist()) + with tf.io.gfile.GFile(target_data_dir, "wb") as file: + data = data.tostring() file.write(data) else: logging.debug(f"use local file path {target_data_dir} to save sparse data.") @@ -401,12 +438,12 @@ def read_binary_data(reading_path: str, suffix: int, data_name: 
str, table_name: if DataAttr.DATATYPE.value not in attributes: raise AttributeError(f"Lack of attribute {DataAttr.DATATYPE.value}.") - with tf.io.gfile.GFile(target_data_dir, "r") as file: + with tf.io.gfile.GFile(target_data_dir, "rb") as file: validate_read_file(target_data_dir) if target_data_dir.find("://") != -1: logging.debug("use hdfs path %s to restore sparse data.", target_data_dir) data_to_restore = file.read() - data_to_restore = np.array(json.loads(data_to_restore)) + data_to_restore = np.fromstring(data_to_restore, dtype=attributes.pop(DataAttr.DATATYPE.value)) else: logging.debug("use local file path %s to restore sparse data.", target_data_dir) data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) -- Gitee From 4097166f7e6b1a17292fe21b45fd937119126142 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 8 Sep 2023 17:54:54 +0800 Subject: [PATCH 341/551] Match-id-01deb669d1ef5d4c35ea31802a4d41f92e38b8ef --- src/core/ssd_engine/file.cpp | 51 ++++++++++++----------- src/core/ssd_engine/file.h | 8 ++-- src/core/ssd_engine/ssd_engine.cpp | 1 + src/core/ssd_engine/table.cpp | 63 +++++++++++++++++++---------- src/core/ssd_engine/table.h | 2 +- src/tests/ssd_engine/file_test.cpp | 35 ++++++++-------- src/tests/ssd_engine/table_test.cpp | 2 +- 7 files changed, 91 insertions(+), 71 deletions(-) diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index ff1c67a6..995fdc7e 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -11,19 +11,17 @@ using namespace MxRec; /// 创建新文件实例,包含元数据文件、数据文件 /// \param fileID 文件ID -/// \param saveDir 保存文件夹的路径 -File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) +/// \param fileDir 当前文件目录 +File::File(uint64_t fileID, string &fileDir) : fileID(fileID), fileDir(fileDir) { - VLOG(GLOG_DEBUG) << StringFormat("start init file, fileID:%llu", fileID); + LOG_DEBUG("start init file, fileID:{}", fileID); - if (!fs::exists(fs::absolute(saveDir))) { - if (!fs::create_directories(fs::absolute(saveDir))) { - throw runtime_error("fail to create Save directory"); - } + if (!fs::exists(fs::absolute(fileDir)) && (!fs::create_directories(fs::absolute(fileDir)))) { + throw runtime_error("fail to create Save directory"); } - metaFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta.latest"); - dataFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".data.latest"); + metaFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".meta.latest"); + dataFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".data.latest"); localFileMeta.open(metaFilePath, ios::out | ios::trunc | ios::binary); if (!localFileMeta.is_open()) { throw runtime_error("fail to create meta file"); @@ -35,19 +33,20 @@ File::File(uint64_t fileID, string &saveDir) : fileID(fileID), saveDir(saveDir) } fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); - VLOG(GLOG_DEBUG) << StringFormat("end init file, fileID:%llu", fileID); + LOG_DEBUG("end init file, fileID:{}", fileID); } -/// 创建文件实例并加载,从保存路径中读取元数据文件、数据文件 +/// 创建文件实例并加载,从加载路径中读取元数据文件、数据文件,生成临时文件到当前文件目录下 /// \param fileID 文件ID -/// \param saveDir 保存文件夹的路径 +/// \param loadDir 加载文件的目录 +/// \param fileDir 当前文件目录 /// \param step 加载的步数 -File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir(saveDir) +File::File(uint64_t fileID, string &fileDir, string &loadDir, int step) : fileID(fileID), fileDir(fileDir) { - VLOG(GLOG_DEBUG) << StringFormat("start init file with load, 
fileID:%llu", fileID); + LOG_DEBUG("start init file with load, fileID:{}", fileID); - fs::path metaFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta." + to_string(step)); - fs::path dataFileToLoad = fs::absolute(saveDir + "/" + to_string(fileID) + ".data." + to_string(step)); + fs::path metaFileToLoad = fs::absolute(loadDir + "/" + to_string(fileID) + ".meta." + to_string(step)); + fs::path dataFileToLoad = fs::absolute(loadDir + "/" + to_string(fileID) + ".data." + to_string(step)); if (!fs::exists(metaFileToLoad)) { throw invalid_argument("meta file not found while loading"); } @@ -58,8 +57,8 @@ File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir ValidateReadFile(metaFileToLoad, fs::file_size(metaFileToLoad)); ValidateReadFile(dataFileToLoad, fs::file_size(dataFileToLoad)); - metaFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".meta.latest"); - dataFilePath = fs::absolute(saveDir + "/" + to_string(fileID) + ".data.latest"); + metaFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".meta.latest"); + dataFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".data.latest"); fs::remove(metaFilePath); fs::remove(dataFilePath); @@ -82,7 +81,7 @@ File::File(uint64_t fileID, string &saveDir, int step) : fileID(fileID), saveDir fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); Load(); - VLOG(GLOG_DEBUG) << StringFormat("end init file with load, fileID:%llu", fileID); + LOG_DEBUG("end init file with load, fileID:{}", fileID); } File::~File() @@ -163,9 +162,9 @@ void File::DeleteEmbedding(emb_key_t key) staleDataCnt += 1; } -void File::Save(int step) +void File::Save(const string &saveDir, int step) { - VLOG(GLOG_DEBUG) << StringFormat("start save file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("start save file at step:{}, fileID:{}", step, fileID); // write current meta into meta file for (auto [key, offset]: keyToOffset) { @@ -184,7 +183,7 @@ void File::Save(int step) throw invalid_argument("fail to save latest meta, file already exist"); } - VLOG(GLOG_DEBUG) << StringFormat("save latest meta file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("save latest meta file at step:{}, fileID:{}", step, fileID); if (!fs::copy_file(metaFilePath, metaFileToSave)) { throw runtime_error("fail to Save latest meta"); } @@ -196,7 +195,7 @@ void File::Save(int step) } // Save data - VLOG(GLOG_DEBUG) << StringFormat("save latest data file at step:%d", step); + LOG_DEBUG("save latest data file at step:{}", step); localFileData.flush(); if (localFileData.fail()) { throw runtime_error("fail to Save data"); @@ -217,13 +216,13 @@ void File::Save(int step) throw runtime_error("fail to re-open data file"); } - VLOG(GLOG_DEBUG) << StringFormat("end save file at step:%d, fileID:%llu", step, fileID); + LOG_DEBUG("end save file at step:{}, fileID:{}", step, fileID); } void File::Load() { // file already validate and open in instantiation - VLOG(GLOG_DEBUG) << StringFormat("start reading meta file, fileID:%llu", fileID); + LOG_DEBUG("start reading meta file, fileID:{}", fileID); emb_key_t key; offset_t offset; do { @@ -254,7 +253,7 @@ void File::Load() throw runtime_error("fail to re-open meta file"); } - VLOG(GLOG_DEBUG) << StringFormat("end reading meta file, fileID:%llu", fileID); + LOG_DEBUG("end reading meta file, fileID:{}", fileID); } vector File::GetKeys() diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h index 2ccf1246..5233265d 100644 --- a/src/core/ssd_engine/file.h +++ 
b/src/core/ssd_engine/file.h @@ -25,9 +25,9 @@ namespace MxRec { static const uint64_t offsetDataLen = sizeof(offset_t); public: - File(uint64_t fileID, string &saveDir); + File(uint64_t fileID, string &fileDir); - File(uint64_t fileID, string &saveDir, int step); // initialize with loading specific step data + File(uint64_t fileID, string &fileDir, string &loadDir, int step); // initialize with loading specific step data ~File(); @@ -39,7 +39,7 @@ namespace MxRec { void DeleteEmbedding(emb_key_t key); - void Save(int step); + void Save(const string &saveDir, int step); vector GetKeys(); @@ -51,7 +51,7 @@ namespace MxRec { private: uint64_t fileID; // init by constructor - string saveDir; // init by constructor + string fileDir; // init by constructor fs::path dataFilePath = ""; fs::path metaFilePath = ""; fstream localFileData{}; diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp index df226a56..ba4b9ba7 100644 --- a/src/core/ssd_engine/ssd_engine.cpp +++ b/src/core/ssd_engine/ssd_engine.cpp @@ -115,6 +115,7 @@ void SSDEngine::Start() } isRunning = true; compactThread = make_shared([this] { CompactMonitor(); }); + LOG_INFO("SSDEngine start"); } /// 压缩监控方法,达到检查周期时调用表的压缩接口 diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index 6aca62f6..a064ee61 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -38,6 +38,12 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize maxTableSize(maxTableSize), compactThreshold(compactThreshold) { + // always use first path to save until it's full + curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); + if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + throw runtime_error("fail to create table directory"); + } + bool isMetaFileFound = false; for (const string &dirPath: saveDirs) { auto metaFilePath = fs::absolute( @@ -50,11 +56,9 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize break; } if (!isMetaFileFound) { - throw invalid_argument("table meta file not found"); + throw invalid_argument(StringFormat("table:%s meta file not found", name.c_str())); } - // always use first path to save until it's full - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); LOG_INFO("load table:{} done. 
try store at path:{}", name, curTablePath); } @@ -112,7 +116,11 @@ void Table::Save(int step) for (const auto &f: fileSet) { uint64_t fid = f->GetFileID(); metaFile.write(reinterpret_cast(&fid), sizeof(fid)); - f->Save(step); + SetTablePathToDiskWithSpace(); + if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + throw runtime_error("fail to create table directory"); + } + f->Save(curTablePath, step); } metaFile.flush(); @@ -140,25 +148,27 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) curMaxFileID = fileID; } - bool isFileFound = false; - shared_ptr tmp; + shared_ptr loadedFile = nullptr; for (const string &p: savePaths) { // try to find data file from each path - string dataPath = p + "/" + saveDirPrefix + g_rankId + "/" + name; + string loadPath = p + "/" + saveDirPrefix + g_rankId + "/" + name; + SetTablePathToDiskWithSpace(); + if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + throw runtime_error("fail to create table directory"); + } try { - tmp = make_shared(fileID, dataPath, step); - fileSet.insert(tmp); - isFileFound = true; + loadedFile = make_shared(fileID, curTablePath, loadPath, step); + fileSet.insert(loadedFile); break; } catch (invalid_argument &e) { // do nothing because file may in other path } } - if (!isFileFound) { + if (loadedFile == nullptr) { throw invalid_argument("data file not found"); } - auto keys = tmp->GetKeys(); + auto keys = loadedFile->GetKeys(); totalKeyCnt += keys.size(); if (totalKeyCnt > maxTableSize) { throw invalid_argument("table size too small, key quantity exceed while loading data"); @@ -169,7 +179,7 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) throw invalid_argument( "find duplicate key in files, compaction already done before saving, file may broken or modified"); } - keyToFile[k] = tmp; + keyToFile[k] = loadedFile; } } curMaxFileID += 1; @@ -217,7 +227,10 @@ void Table::InsertEmbeddingsInner(vector &keys, vector> } if (curFile == nullptr || (curFile != nullptr && curFile->GetDataCnt() >= maxDataNumInFile)) { - SetValidPath(); + SetTablePathToDiskWithSpace(); + if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + throw runtime_error("fail to create table directory"); + } curFile = make_shared(curMaxFileID, curTablePath); fileSet.insert(curFile); curMaxFileID++; @@ -287,7 +300,7 @@ void Table::Compact(bool fullCompact) return; } - VLOG(GLOG_DEBUG) << StringFormat("table:%s, start compact", name.c_str()); + LOG_DEBUG("table:{}, start compact", name); vector> compactFileList; for (const auto &f: staleDataFileSet) { @@ -314,7 +327,7 @@ void Table::Compact(bool fullCompact) vector> validEmbs = f->FetchEmbeddings(validKeys); InsertEmbeddingsInner(validKeys, validEmbs); } - VLOG(GLOG_DEBUG) << StringFormat("table:%s, end compact", name.c_str()); + LOG_DEBUG("table:{}, end compact", name); } uint64_t Table::GetTableAvailableSpace() @@ -336,10 +349,12 @@ void Table::DeleteEmbeddingsInner(vector &keys) } } -void Table::SetValidPath() +void Table::SetTablePathToDiskWithSpace() { - while (true) { - fs::space_info si = fs::space((curTablePath)); + constexpr int nMaxLoop = 1024; + int loopCnt = 0; + while (loopCnt < nMaxLoop) { + fs::space_info si = fs::space(savePaths.at(curSavePathIdx)); if ((double(si.available) / double(si.capacity)) > diskAvailSpaceThreshold) { break; } @@ -348,8 +363,12 @@ void Table::SetValidPath() if (curSavePathIdx >= savePaths.size()) { throw runtime_error("all disk's space not enough"); } - curTablePath = 
savePaths[curSavePathIdx]; - LOG_INFO("current data path's free space less than {}%, try next path:{}", - diskAvailSpaceThreshold * convertToPercentage, curTablePath); + curTablePath = fs::absolute( + savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); + + LOG_INFO("current data path's available space less than {}%, try next path:{}", + diskAvailSpaceThreshold * convertToPercentage, curTablePath); + loopCnt += 1; } } + diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index f82be2df..64044126 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -50,7 +50,7 @@ namespace MxRec { void LoadDataFileSet(const shared_ptr& metaFile, int step); - void SetValidPath(); + void SetTablePathToDiskWithSpace(); string name; // init by constructor vector savePaths; // init by constructor, support Save and Load from multiple path diff --git a/src/tests/ssd_engine/file_test.cpp b/src/tests/ssd_engine/file_test.cpp index 7eaac073..8be2f0ff 100644 --- a/src/tests/ssd_engine/file_test.cpp +++ b/src/tests/ssd_engine/file_test.cpp @@ -19,16 +19,16 @@ TEST(File, CreateEmptyFile) MPI_Comm_rank(MPI_COMM_WORLD, &rankId); g_rankId = to_string(rankId); - string savePath = g_rankId; + string fileDir = g_rankId; bool isExceptionThrown = false; try { - auto f = make_shared(0, savePath); + auto f = make_shared(0, fileDir); } catch (runtime_error &e) { isExceptionThrown = true; LOG(ERROR) << e.what(); } ASSERT_EQ(isExceptionThrown, false); - fs::remove_all(savePath); + fs::remove_all(fileDir); } TEST(File, LoadFromFile) @@ -38,11 +38,9 @@ TEST(File, LoadFromFile) MPI_Comm_rank(MPI_COMM_WORLD, &rankId); g_rankId = to_string(rankId); - string savePath = g_rankId; - if (!fs::exists(fs::absolute(savePath))) { - if (!fs::create_directories(fs::absolute(savePath))) { - throw runtime_error("fail to create Save directory"); - } + string fileDir = g_rankId; + if (!fs::exists(fs::absolute(fileDir)) && !fs::create_directories(fs::absolute(fileDir))) { + throw runtime_error("fail to create Save directory"); } emb_key_t key = 0; @@ -50,7 +48,7 @@ TEST(File, LoadFromFile) vector val = {1.0}; fstream localFileMeta; - localFileMeta.open(savePath + "/0.meta.0", ios::out | ios::trunc | ios::binary); + localFileMeta.open(fileDir + "/0.meta.0", ios::out | ios::trunc | ios::binary); localFileMeta.write(reinterpret_cast(&key), sizeof(key)); localFileMeta.write(reinterpret_cast(&offset), sizeof(offset)); localFileMeta.flush(); @@ -60,7 +58,7 @@ TEST(File, LoadFromFile) localFileMeta.close(); fstream localFileData; - localFileData.open(savePath + "/0.data.0", ios::out | ios::trunc | ios::binary); + localFileData.open(fileDir + "/0.data.0", ios::out | ios::trunc | ios::binary); uint64_t embSize = val.size(); localFileData.write(reinterpret_cast(&embSize), sizeof(embSize)); localFileData.write(reinterpret_cast(val.data()), val.size() * sizeof(float)); @@ -72,14 +70,15 @@ TEST(File, LoadFromFile) // start test bool isExceptionThrown = false; + string loadDir = fileDir; // for test convenience try { - auto f = make_shared(0, savePath, 0); + auto f = make_shared(0, fileDir, loadDir, 0); } catch (runtime_error &e) { LOG(ERROR) << e.what(); isExceptionThrown = true; } ASSERT_EQ(isExceptionThrown, false); - fs::remove_all(savePath); + fs::remove_all(fileDir); } TEST(File, WriteAndRead) @@ -117,17 +116,19 @@ TEST(File, SaveAndLoad) g_rankId = to_string(rankId); int saveStep = 0; - string savePath = g_rankId; - auto fTmp = make_shared(0, savePath); + string fileDir = g_rankId; + 
auto fTmp = make_shared(0, fileDir); vector key = {0}; vector> expect = {{1.0, 1.1}}; fTmp->InsertEmbeddings(key, expect); - fTmp->Save(saveStep); + string saveDir = fileDir; // for test convenience + fTmp->Save(saveDir, saveStep); - auto fLoad = make_shared(0, savePath, saveStep); + string loadDir = fileDir; // for test convenience + auto fLoad = make_shared(0, fileDir, loadDir, saveStep); auto actual = fLoad->FetchEmbeddings(key); ASSERT_EQ(expect, actual); - fs::remove_all(savePath); + fs::remove_all(fileDir); } diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index ed5b045d..44b1704c 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -89,7 +89,7 @@ TEST(Table, WriteAndReadAndDeleteAndCompact) ASSERT_EQ(fs::exists(oldDataFilePath), false); ASSERT_EQ(fs::exists(oldMetaFilePath), false); - for (string p: savePath) { + for (const string& p: savePath) { fs::remove_all(p); } } -- Gitee From 1b3dfc850749ec1ff1591c284c1da86fd8997d12 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 12 Sep 2023 20:22:58 +0800 Subject: [PATCH 342/551] Match-id-ca81286b48f18d3e5b2ba02caa37b77ae21cd8f3 --- mx_rec/core/feature_process.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index 489a2946..9ffc0324 100644 --- a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -28,7 +28,10 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): self._global_step_tensor = None self.check_evict_init_params() - logging.info(f"_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", self._evict_time_interval, + if evict_step_interval is None: + logging.info(f"_EvictHook - > evict_time_interval: %d", self._evict_time_interval) + else: + logging.info(f"_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", self._evict_time_interval, self._evict_step_interval) def begin(self): -- Gitee From fbde137ba809002427ab66fa845e0234ba6ca02c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 09:08:45 +0800 Subject: [PATCH 343/551] Match-id-cd3ef4b144876a50c9bc374efc722ad792039c08 --- .../op_kernel/embedding_lookup_by_address.cpp | 2 - .../op_kernel/embedding_update_by_address.cpp | 1 - mx_rec/constants/constants.py | 81 +++- mx_rec/core/asc/build_graph.py | 25 +- mx_rec/core/asc/feature_spec.py | 93 ++-- mx_rec/core/asc/helper.py | 46 +- mx_rec/core/asc/manager.py | 89 ++-- mx_rec/core/asc/merge_table.py | 13 +- mx_rec/core/embedding.py | 370 +++++++--------- mx_rec/core/feature_process.py | 43 +- mx_rec/graph/merge_lookup.py | 21 +- mx_rec/graph/modifier.py | 60 +-- mx_rec/graph/patch.py | 30 +- mx_rec/logger/log.py | 16 +- mx_rec/optimizers/adagrad.py | 18 +- mx_rec/optimizers/base.py | 5 +- mx_rec/optimizers/ftrl.py | 42 +- mx_rec/optimizers/ftrl_t.py | 2 - mx_rec/optimizers/ftrl_t_dense.py | 8 +- mx_rec/optimizers/gradient_descent.py | 13 +- mx_rec/optimizers/gradient_descent_by_addr.py | 9 +- mx_rec/optimizers/lazy_adam.py | 36 +- mx_rec/optimizers/lazy_adam_by_addr.py | 32 +- mx_rec/optimizers/momentum.py | 41 +- mx_rec/saver/patch.py | 24 +- mx_rec/saver/saver.py | 57 +-- mx_rec/saver/sparse.py | 18 +- mx_rec/util/__init__.py | 4 - mx_rec/util/communication/hccl_mgmt.py | 58 ++- mx_rec/util/global_env_conf.py | 99 +++++ mx_rec/util/initialize.py | 179 ++++---- mx_rec/util/log.py | 25 +- mx_rec/util/normalization.py | 8 +- mx_rec/util/ops.py | 5 +- mx_rec/util/perf.py | 5 +- mx_rec/validator/validator.py | 409 
+++++++++++++-----
 .../feat_admit_n_evict_ckpt.cpp               |   4 +-
 .../feat_admit_n_evict_ckpt.h                 |   1 +
 src/core/emb_hashmap/emb_hashmap.cpp          |  16 +-
 src/core/emb_hashmap/emb_hashmap.h            |   3 +-
 src/core/hd_transfer/hd_transfer.cpp          |  43 +-
 src/core/hd_transfer/hd_transfer.h            |   3 +-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp          | 154 +++----
 src/core/hybrid_mgmt/hybrid_mgmt.h            |   5 +-
 .../key_process/feature_admit_and_evict.cpp   |   2 +-
 src/core/key_process/key_process.cpp          |  28 +-
 src/core/key_process/key_process.h            |   1 +
 src/core/utils/common.cpp                     |  54 +--
 src/core/utils/common.h                       |  17 +-
 src/core/utils/config.cpp                     | 131 ++++++
 src/core/utils/config.h                       |  57 +++
 src/pybind/module_main.cpp                    |   5 +-
 src/tests/checkpoint/checkpoint_test.cpp      |   5 +-
 .../ckpt_data_handler_test.cpp                |  11 +-
 .../feature_admit_and_evict_test.cpp          |  13 +-
 src/tests/key_process/key_process_test.cpp    |  20 +-
 src/tests/utils/common_test.cpp               |  11 -
 tests/mx_rec/validator/test_validators.py     |  71 ++-
 58 files changed, 1466 insertions(+), 1176 deletions(-)
 create mode 100644 mx_rec/util/global_env_conf.py
 create mode 100644 src/core/utils/config.cpp
 create mode 100644 src/core/utils/config.h

diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
index 84548729..235a9561 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
@@ -31,7 +31,6 @@ public:
     __aicore__ inline void Init_param(GM_ADDR tiling)
     {
         GET_TILING_DATA(constData, tiling);
-        // TODO: user kernel impl
         // number of data dimensions
         int32_t update_dim = constData.update_dim;
         int32_t embbeding_type = constData.embbeding_type;
@@ -200,7 +199,6 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres
     GM_ADDR tiling)
 {
     GET_TILING_DATA(constData, tiling);
-    // // TODO: user kernel impl

     int32_t embbeding_type = constData.embbeding_type;

diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
index b35a56c2..6fc875a6 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
@@ -28,7 +28,6 @@ public:
     __aicore__ inline void Init_param(GM_ADDR tiling)
    {
         GET_TILING_DATA(constData, tiling);
-        // TODO: user kernel impl
         // number of data dimensions
         int32_t update_dim = constData.update_dim;
         int32_t embbeding_type = constData.embbeding_type;

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 248b15c9..d5eb61b7 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -19,8 +19,8 @@ ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX = "ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX"
 ASCEND_SPARSE_LOOKUP_HOT_POS = "ASCEND_SPARSE_LOOKUP_HOT_POS"
 ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP"
 CUSTOMIZED_OPS_LIB_PATH = "CUSTOMIZED_OPS_LIB_PATH"
-HOST_PIPELINE_OPS_LIB_PATH = "HOST_PIPELINE_OPS_LIB_PATH"
 ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
+EMPTY_STR = ""

 # anchor name used to locate the dataset in the computation graph under auto-graph-modification mode
 ANCHOR_DATASET_NAME = "PrefetchDataset"
@@ -31,16 +31,33 @@ ASCEND_TABLE_NAME_MUST_CONTAIN = None
 # this number is a temp plan to solve a problem
 # to avoid op "scatter_nd_update" may get a None tensor for input
 AVOID_TENSOR_POS = 439999
-LOCAL_RANK_SIZE = "LOCAL_RANK_SIZE"  # number of NPU cards used by the current server during training
-MAX_DEVICE_NUM_LOCAL_MACHINE = 16  # maximum number of cards on a single server
-DEFAULT_DEVICE_NUM_LOCAL_MACHINE = 8  # default number of cards on a single server
+
+# depth of the ACL data channel
+DEFAULT_HD_CHANNEL_SIZE = 40
+MAX_HD_CHANNEL_SIZE = 8192
+MIN_HD_CHANNEL_SIZE = 2
+
+# number of key-process threads
+DEFAULT_KP_THREAD_NUM = 6
+MIN_KP_THREAD_NUM = 1
+MAX_KP_THREAD_NUM = 32
+
+# maximum number of threads for Fast Unique deduplication
+DEFAULT_FAST_UNIQUE_THREAD_NUM = 8
+MIN_FAST_UNIQUE_THREAD_NUM = 1
+MAX_FAST_UNIQUE_THREAD_NUM = 8
+
+# Hot Embedding update step count
+DEFAULT_HOT_EMB_UPDATE_STEP = 1000
+MIN_HOT_EMB_UPDATE_STEP = 1
+MAX_HOT_EMB_UPDATE_STEP = 1000

 MULTI_LOOKUP_TIMES = 128
 DEFAULT_EVICT_TIME_INTERVAL = 60 * 60 * 24
 TRAIN_CHANNEL_ID = 0
 EVAL_CHANNEL_ID = 1
 HASHTABLE_COLLECTION_NAME_LENGTH = 30
-MAX_HOST_VOCABULARY_SIZE = 10**10
+MAX_VOCABULARY_SIZE = 10**10

 # RANK INFO
 VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"]
@@ -78,6 +95,29 @@ class BaseEnum(Enum):
                          f"'{list(map(lambda c: c.value, cls))}'.")


+class EnvOption(Enum):
+    MXREC_LOG_LEVEL = "MXREC_LOG_LEVEL"
+    SAVE_EASY = "SAVE_EASY"
+    RANK_TABLE_FILE = "RANK_TABLE_FILE"
+    ASCEND_VISIBLE_DEVICES = "ASCEND_VISIBLE_DEVICES"
+    CM_CHIEF_DEVICE = "CM_CHIEF_DEVICE"
+    CM_WORKER_SIZE = "CM_WORKER_SIZE"
+    TF_DEVICE = "TF_DEVICE"
+    APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY"
+    ACL_TIMEOUT = "AclTimeout"
+    HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE"
+    FIND_OFFSET_V2 = "FIND_OFFSET_V2"
+    FIND_OFFSET_V3 = "FIND_OFFSET_V3"
+    KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM"
+    MAX_UNIQUE_THREAD_NUM = "MAX_UNIQUE_THREAD_NUM"
+    FAST_UNIQUE = "FAST_UNIQUE"
+    UPDATEEMB_V2 = "UpdateEmb_V2"
+    HOT_EMB_UPDATE_STEP = "HOT_EMB_UPDATE_STEP"
+    GLOG_STDERRTHREAHOLD = "GLOG_stderrthreshold"
+    USE_COMBINE_FAAE = "USE_COMBINE_FAAE"
+    STAT_ON = "STAT_ON"
+
+
 class DataName(Enum):
     KEY = "key"
     EMBEDDING = "embedding"
@@ -108,10 +148,6 @@ class ASCAnchorAttr(Enum):
     GRADIENTS_STRATEGY = "gradients_strategy"


-class MxRecMode(BaseEnum):
-    ASC = "ASC"  # Ascend Sparse with Cpu-hashtable
-
-
 class OptimizerType(Enum):
     LAZY_ADAM = "LazyAdam"
     SGD = "SGD"
@@ -136,3 +172,30 @@ class All2allGradientsOp(BaseEnum):
 class ApplyGradientsStrategy(BaseEnum):
     DIRECT_APPLY = "direct_apply"
     SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply"
+
+
+class RecPyLogLevel(Enum):
+    DEBUG = "DEBUG"
+    INFO = "INFO"
+    ERROR = "ERROR"
+
+
+class RecCPPLogLevel(Enum):
+    TRACE = "-2"
+    DEBUG = "-1"
+    INFO = "0"
+    WARN = "1"
+    ERROR = "2"
+
+
+class TFDevice(Enum):
+    CPU = "CPU"
+    NPU = "NPU"
+    GPU = "GPU"
+
+
+class Flag(Enum):
+    TRUE = "1"
+    FALSE = "0"
+
+
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
index 9ecd5108..e9d54115 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -2,18 +2,18 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
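
The build_graph.py hunks that follow are typical of this commit's Python-side cleanup: eager f-string logging is swapped for the module logger's lazy %-style calls. A minimal sketch of the difference, assuming only standard-library logging semantics (the logger name and values below are illustrative, not this module's real objects):

    import logging

    logger = logging.getLogger("mx_rec_demo")   # hypothetical logger name
    table_name, channel_id = "demo_table", 0    # hypothetical values

    # Eager: the f-string is built even when the DEBUG level is disabled.
    logger.debug(f"Channel {table_name}_lookup_{channel_id} was built for getnext")

    # Lazy: interpolation is deferred until a handler actually emits the record.
    logger.debug("Channel %s_lookup_%s was built for getnext", table_name, channel_id)

At high log levels the lazy form skips the string formatting entirely, which matters in per-step code paths like these graph builders.
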
-import logging - import tensorflow as tf import mxrec_pybind from mx_rec.util.initialize import get_use_static from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.constants.constants import ApplyGradientsStrategy, TRAIN_CHANNEL_ID +from mx_rec.util.global_env_conf import global_env +from mx_rec.util.log import logger def get_restore_vector(config): - logging.debug(f'Channel {config.get("table_name")}_restore_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_restore_%s was built for getnext', config.get("table_name"), config.get("channel_id")) if config.get("skip_emb_transfer"): if not isinstance(config.get("emb_size"), int) or config.get("emb_size") < 1: raise TypeError(f"emb_size must be a int") @@ -54,7 +54,7 @@ def get_restore_vector(config): def get_id_offsets(max_lookup_vec_size, config): - logging.debug(f'Channel {config.get("table_name")}_lookup_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_lookup_%s was built for getnext', config.get("table_name"), config.get("channel_id")) # 自动扩容当前只支持HBM模式,默认没有换入换出 with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): if config.get("use_dynamic_expansion"): @@ -84,7 +84,8 @@ def get_restore_vector_second(max_lookup_vec_size: int, config: dict) -> tf.Tens :param config: embedding config :return: the restore vector calculated after the second all2all """ - logging.debug(f'Channel {config.get("table_name")}_restore_second_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_restore_second_%s was built for getnext', + config.get("table_name"), config.get("channel_id")) with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): restore_vector_second = npu_ops.gen_npu_ops.get_next( output_types=[tf.int32], @@ -100,7 +101,7 @@ def get_unique_keys(max_lookup_vec_size: int, config: dict) -> tf.Tensor: :param config: embedding config :return: the global unique keys calculated after the second all2all """ - logging.debug(f'Channel {config.get("table_name")}_uniquekeys_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_uniquekeys_%s was built for getnext', config.get("table_name"), config.get("channel_id")) with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): if config.get("use_dynamic_expansion"): unique_keys = npu_ops.gen_npu_ops.get_next( @@ -129,8 +130,7 @@ def get_all2all_args(use_static: bool, config: dict) -> list: with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): with tf.compat.v1.variable_scope("all2all"): - logging.debug( - f'Channel {config.get("table_name")}_a2a_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_a2a_%s was built for getnext', config.get("table_name"), config.get("channel_id")) all2all_args = npu_ops.gen_npu_ops.get_next( output_types=[tf.int64], output_shapes=[[config.get("rank_size"), config.get("rank_size")]], @@ -158,12 +158,12 @@ def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variabl swap_in = [tf.no_op()] else: with tf.compat.v1.variable_scope("h2d_emb"): - logging.debug(f'Channel {config.get("table_name")}_h2d_{config.get("channel_id")} was built for getnext') + logger.debug('Channel %s_h2d_%s was built for getnext', config.get("table_name"), config.get("channel_id")) h2d_emb = npu_ops.gen_npu_ops.get_next( output_types=[tf.float32], output_shapes=[[max_lookup_vec_size, config.get("ext_emb_size")]], 
channel_name=f'{config.get("table_name")}_h2d_{config.get("channel_id")}')[0] - logging.debug(f"h2d_emb shape: {h2d_emb}") + logger.debug("h2d_emb shape: %s", h2d_emb) if not isinstance(table, list): raise RuntimeError("When enable emb_transfer, optimizer should have slots") if use_static: @@ -171,8 +171,7 @@ def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variabl h2d_emb = h2d_emb[0:swap_len, :] swap_outs = [tf.gather(one_table, swap_pos) for one_table in table] swap_out = tf.concat(swap_outs, axis=1) - logging.debug( - f'Channel {config.get("table_name")}_d2h_{config.get("channel_id")} was built for op outfeed.') + logger.debug('Channel %s_d2h_%s was built for op outfeed.', config.get("table_name"), config.get("channel_id")) swap_out_op = npu_ops.outfeed_enqueue_op( channel_name=f'{config.get("table_name")}_d2h_{config.get("channel_id")}', inputs=[swap_out]) with tf.control_dependencies([swap_out_op]): @@ -208,7 +207,7 @@ def get_preprocessed_tensor_for_asc(table, config): 'all2all_args': all2all_args, } - if config.get("gradients_strategy") != ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY: + if global_env.apply_gradients_strategy != ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value: return result if config.get("channel_id") != TRAIN_CHANNEL_ID: diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 4b29a7fe..d1e8d7fe 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -2,9 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging -import re -from typing import Union +from typing import Union, Optional from functools import reduce import tensorflow as tf @@ -13,6 +11,9 @@ from mx_rec.util.atomic import AtomicInteger from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import ClassValidator, StringValidator, para_checker_decorator, \ + OptionalStringValidator, OptionalIntValidator +from mx_rec.util.log import logger feature_spec_global_id = AtomicInteger() @@ -23,27 +24,42 @@ class FeatureSpec: use_timestamp_train = False use_timestamp_eval = False - def __init__(self, name, **kwargs): + @para_checker_decorator(check_option_list=[ + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]), + ("table_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), + ("index_key", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), + ("access_threshold", OptionalIntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("eviction_threshold", OptionalIntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("is_timestamp", ClassValidator, {"classes": (bool, type(None))}), + ("batch_size", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), + ("faae_coefficient", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]) + ]) + def __init__(self, name: str, table_name: str, + index_key: Optional[str] = None, + access_threshold: Optional[int] = None, + eviction_threshold: Optional[int] = None, is_timestamp: Optional[bool] = None, + batch_size: Optional[int] = None, faae_coefficient: int = 1): feature_spec_global_id.increase() spec_name = name + f"_{feature_spec_global_id}" self.name = spec_name - 
self._index_key = kwargs.get("index_key") if kwargs.get("index_key") else name - self._table_name = fix_invalid_table_name(kwargs.get("table_name") if kwargs.get("table_name") else name) - self._feat_cnt = kwargs.get("feat_count") - self._access_threshold = kwargs.get("access_threshold") - self._eviction_threshold = kwargs.get("eviction_threshold") - self._faae_coefficient = kwargs.get("faae_coefficient", 1) - self._is_timestamp = kwargs.get("is_timestamp") + self._index_key = index_key if index_key else name + self._table_name = fix_invalid_table_name(table_name if table_name else name) + self._feat_cnt = None + self._access_threshold = access_threshold + self._eviction_threshold = eviction_threshold + self._faae_coefficient = faae_coefficient + self._is_timestamp = is_timestamp self.feat_pos_train = None self.feat_pos_eval = None self.dims = None self.rank = None - self.batch_size = kwargs.get("batch_size") + self.batch_size = batch_size self.split = None # usually split == batch_size * feature_count self.initialized = False self._pipeline_mode = set() - self.check_params() + if self._access_threshold is None and self._eviction_threshold is not None: + raise ValueError(f"Access_threshold should be configured before eviction_threshold.") @property def is_timestamp(self): @@ -90,49 +106,6 @@ class FeatureSpec: def use_timestamp(is_training): return FeatureSpec.use_timestamp_train if is_training else FeatureSpec.use_timestamp_eval - def check_params(self): - def check_str(arg, param_name): - if not isinstance(arg, str): - raise TypeError(f"{param_name} should be a string, whose value is {arg} with type '{type(arg)}' " - f"in fact.") - - def check_natural_number(arg, param_name): - if not isinstance(arg, int) or arg < 1: - raise TypeError(f"{param_name} should be a natural number, whose value is {arg} with type " - f"'{type(arg)}' in fact.") - - def check_bool(arg, param_name): - if not isinstance(arg, bool): - raise TypeError(f"{param_name} should be a bool, whose value is {arg} with type " - f"'{type(arg)}' in fact.") - - check_str(self.name, "name") - check_str(self._table_name, "table_name") - - if self._feat_cnt is not None: - check_natural_number(self._feat_cnt, "feat_count") - - if self._access_threshold is not None: - check_natural_number(self._access_threshold, "access_threshold") - if self._access_threshold > MAX_INT32: - raise ValueError(f"Access_threshold is too big that exceed int32.") - - elif self._eviction_threshold is not None: - raise ValueError(f"Access_threshold should be configured before eviction_threshold.") - - if self._eviction_threshold is not None: - check_natural_number(self._eviction_threshold, "eviction_threshold") - if self._eviction_threshold > MAX_INT32: - raise ValueError(f"Eviction_threshold is too big that exceed int32.") - - if self._faae_coefficient is not None: - check_natural_number(self._faae_coefficient, "eviction_threshold") - if self._faae_coefficient > MAX_INT32: - raise ValueError(f"Eviction_threshold is too big that exceed int32.") - - if self._is_timestamp is not None: - check_bool(self._is_timestamp, "is_timestamp") - def set_feat_pos(self, is_training): if is_training: self.feat_pos_train = FeatureSpec.instance_count_train @@ -146,7 +119,7 @@ class FeatureSpec: raise TypeError("Is training mode must be a boolean.") if mode and mode in self._pipeline_mode: - logging.info(f"FeatureSpec{self.name}. Is training mode [{mode}] has been set.") + logger.info("FeatureSpec%s. 
Is training mode [%s] has been set.", self.name, mode) return insert_training_mode_channel_id(is_training=mode) @@ -166,8 +139,8 @@ class FeatureSpec: raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {self.rank} now.") inferred_feat_cnt = 1 if self.rank == 1 else reduce(lambda x, y: x * y, self.dims[1:]) - logging.debug(f"update feature_spec[{self.name}] feature_count " - f"from {self._feat_cnt} to {inferred_feat_cnt} via {self.dims}") + logger.debug("update feature_spec[%s] feature_count to %s via %s", self.name, inferred_feat_cnt, + self.dims) self.batch_size = self.dims[0] self._feat_cnt = inferred_feat_cnt self.split = self.batch_size * self._feat_cnt @@ -180,7 +153,7 @@ class FeatureSpec: self._feat_cnt = 1 else: - logging.debug(f"The initialized Feature Spec was set once again.") + logger.debug("The initialized Feature Spec was set once again.") if get_use_static(): if self.dims != tensor.shape.as_list(): raise ValueError(f"Given static Tensor shape mismatches with the last one, whose is_training mode " diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 4963fec6..64bffc04 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging from functools import reduce import tensorflow as tf @@ -10,8 +9,16 @@ import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.merge_table import find_dangling_table, should_skip +from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator +from mx_rec.util.log import logger +@para_checker_decorator(check_option_list=[ + (["tgt_key_specs", "args_index_list"], ValueCompareValidator, {"target": None}, + ["check_at_least_one_not_equal_to_target"]), + (["tgt_key_specs", "args_index_list"], ValueCompareValidator, {"target": None}, + ["check_at_least_one_equal_to_target"]), +]) def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None, table_names=None, **kwargs): ''' @@ -35,6 +42,10 @@ def create_asc_insert_func_with_specs(tgt_key_specs, **kwargs): return get_asc_insert_func_inner(tgt_key_specs=tgt_key_specs, **kwargs) +@para_checker_decorator(check_option_list=[ + (["args_index_list", "feature_counts", "table_names"], ValueCompareValidator, {"target": None}, + ["check_all_not_equal_to_target"]), +]) def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names, **kwargs): ''' 自动改图模式 auto change graph @@ -47,10 +58,6 @@ def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None, table_names=None, **kwargs): - both_none = tgt_key_specs is None and args_index_list is None - both_no_none = tgt_key_specs is not None and args_index_list is not None - if both_none or both_no_none: - raise ValueError("Args tgt_key_specs and args_index_list should and only can choice one to get insert tensors.") is_training = kwargs.get("is_training", True) dump_graph = kwargs.get("dump_graph", False) @@ -71,7 +78,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ "splits": [] } get_target_tensors_with_feature_specs(tgt_key_specs, data_src, is_training, read_emb_key_inputs_dict) - logging.debug(f"do_insert with spec for 
{read_emb_key_inputs_dict.get('table_names')}") + logger.debug("do_insert with spec for %s", read_emb_key_inputs_dict.get('table_names')) return do_insert(args, insert_tensors=read_emb_key_inputs_dict.get("insert_tensors"), splits=read_emb_key_inputs_dict.get("splits"), @@ -84,18 +91,15 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ insert_fn = insert_fn_for_feature_specs else: - if feature_counts is None or table_names is None: - raise ValueError("Please config 'args_index_list', 'feature_counts' and 'table_names' at the same time.") - dangling_tables = find_dangling_table(table_names) - logging.info(f"In insert found dangling table(s): {dangling_tables} " - f"which does not need to be provided to the EmbInfo.") + logger.info("In insert found dangling table(s): %s which does not need to be provided to the EmbInfo.", + dangling_tables) def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) - logging.debug(f"do_insert without spec for {table_names}") + logger.debug("do_insert without spec for %s", table_names) splits = [] for insert_tensor in insert_tensors: split = reduce(lambda x, y: x * y, insert_tensor.shape.as_list()) @@ -104,12 +108,12 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_ new_insert_tensors, new_splits, new_table_names = [], [], [] for idx, table_name in enumerate(table_names): if table_name in dangling_tables: - logging.info(f"do_insert skip table by graph : {table_name}") + logger.info("do_insert skip table by graph : %s", table_name) continue skip = should_skip(table_name) if skip: - logging.info(f"do_insert skip table by keyword: {table_name}") + logger.info("do_insert skip table by keyword: %s", table_name) continue new_insert_tensors.append(insert_tensors[idx]) @@ -144,7 +148,7 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list): f"len(table_name_list): {len(table_name_list)}") feature_id_requests = zip(feature_id_list, split_list, table_name_list) feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name)) - logging.debug(f" features to merge: {feature_id_requests}") + logger.debug("features to merge: %s", feature_id_requests) last_table_name = None last_split = 0 @@ -170,8 +174,8 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list): output_table_name_list.append(last_table_name) output_split_list.append(last_split) output_tensorshape_split_list.append(last_tensorshape_split) - logging.debug(f"merge request from {table_name_list} {split_list} " - f" to {output_table_name_list} {output_split_list}") + logger.debug("merge request from %s %s to %s %s", table_name_list, split_list, + output_table_name_list, output_split_list) list_set = { 'output_feature_id_list': output_feature_id_list, @@ -210,12 +214,12 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list, raise RuntimeError(f"The length of split list can not be 0.") if use_static: - logging.info(f"read_emb_key_v2(static), table_name_list: {table_name_list}, split_list: {split_list}") + logger.info("read_emb_key_v2(static), table_name_list: %s, split_list: %s", table_name_list, split_list) return host_pipeline_ops.read_emb_key_v2(concat_tensor, channel_id=channel_id, splits=split_list, emb_name=table_name_list, timestamp=timestamp) - logging.info(f"read_emb_key_v2(dynamic), table_name_list: {table_name_list}, " - f"tensorshape_split_list: {tensorshape_split_list}") + 
logger.info("read_emb_key_v2(dynamic), table_name_list: %s, tensorshape_split_list: %s", + table_name_list, tensorshape_split_list) return host_pipeline_ops.read_emb_key_v2_dynamic(concat_tensor, tensorshape_split_list, channel_id=channel_id, emb_name=table_name_list, timestamp=timestamp) @@ -295,7 +299,7 @@ def get_target_tensors_with_args_indexes(args_index_list): for index in args_index_list: tensor = graph.get_tensor_by_name("args_%d:0" % index) if tensor.dtype != tf.int64: - logging.debug(f"Input tensor dtype is {tensor.dtype}, which will be transferred to tf.int64.") + logger.debug("Input tensor dtype is %s, which will be transferred to tf.int64.", tensor.dtype) tensor = tf.cast(tensor, tf.int64) insert_tensors.append(tf.reshape(tensor, [-1, ])) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 064fce38..eaf3faf1 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -2,18 +2,16 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging - import tensorflow as tf from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo -from mx_rec.constants.constants import MxRecMode from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ - is_asc_manager_initialized, get_train_steps, get_eval_steps, get_prefetch_batch_number, \ - export_table_instances, export_feature_spec, get_if_load, get_training_mode_channel_id, get_use_static, \ + is_asc_manager_initialized, get_train_steps, get_eval_steps, \ + export_table_instances, export_feature_spec, get_if_load, get_use_static, \ get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num from mx_rec.core.asc.merge_table import find_dangling_table, should_skip +from mx_rec.util.log import logger def check_dangling_table(): @@ -49,20 +47,20 @@ def generate_table_info_list(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: table_instance.ext_emb_size = table_instance.scalar_emb_size * (1 + optimizer.slot_num) - logging.debug(f"ext_emb_size is reset to be {table_instance.ext_emb_size} for EmbInfo") + logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size) skip = should_skip(table_instance.table_name) if table_instance.table_name in dangling_table or skip: - logging.info(f"skip table {skip}: {table_instance.table_name} " - f"which does not need to be provided to the EmbInfo.") + logger.info("skip table %s: %s which does not need to be provided to the EmbInfo.", + skip, table_instance.table_name) continue - rec_mode_asc_flag = table_instance.mode == MxRecMode.ASC - static_shape_rec_flag = rec_mode_asc_flag and get_use_static() and table_instance.send_count > 0 - dynamic_shape_rec_flag = rec_mode_asc_flag and not get_use_static() + static_shape_rec_flag = get_use_static() and table_instance.send_count > 0 + dynamic_shape_rec_flag = not get_use_static() if static_shape_rec_flag or dynamic_shape_rec_flag: - logging.debug(f"table_instance.slice_device_vocabulary_size: {table_instance.slice_device_vocabulary_size}") - logging.debug(f"table_instance.slice_host_vocabulary_size: {table_instance.slice_host_vocabulary_size}") - logging.debug(f"table_instance.slice_ssd_vocabulary_size: {table_instance.slice_ssd_vocabulary_size}") + logger.debug("table_instance.slice_device_vocabulary_size: %s", + table_instance.slice_device_vocabulary_size) + 
logger.debug("table_instance.slice_host_vocabulary_size: %s", table_instance.slice_host_vocabulary_size) + logger.debug("table_instance.slice_ssd_vocabulary_size: %s", table_instance.slice_ssd_vocabulary_size) table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size, table_instance.ext_emb_size, table_instance.is_save, [table_instance.slice_device_vocabulary_size, @@ -76,36 +74,36 @@ def generate_table_info_list(): def matched_constant_initializer(tabel_info): init_param = tabel_info.init_param - logging.debug(f"constant_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + logger.debug("constant_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) return InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, - constant_initializer_info=ConstantInitializerInfo( - constant_val=tabel_info.emb_initializer.value, initK=init_param)) + constant_initializer_info=ConstantInitializerInfo( + constant_val=tabel_info.emb_initializer.value, initK=init_param)) def matched_random_normal_initializer(tabel_info): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed init_param = tabel_info.init_param - logging.debug(f"random_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + logger.debug("random_normal_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) return InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, - normal_initializer_info=NormalInitializerInfo( - mean=tabel_info.emb_initializer.mean, - stddev=tabel_info.emb_initializer.stddev, - seed=random_seed, - initK=init_param - )) + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed, + initK=init_param + )) def matched_truncated_normal_initializer(tabel_info): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed init_param = tabel_info.init_param - logging.debug(f"truncated_normal_initializer, tabel: {tabel_info.table_name}, initK is {init_param}.") + logger.debug("truncated_normal_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) return InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, - normal_initializer_info=NormalInitializerInfo( - mean=tabel_info.emb_initializer.mean, - stddev=tabel_info.emb_initializer.stddev, - seed=random_seed, - initK=init_param - )) + normal_initializer_info=NormalInitializerInfo( + mean=tabel_info.emb_initializer.mean, + stddev=tabel_info.emb_initializer.stddev, + seed=random_seed, + initK=init_param + )) def matched_emb_initializer(tabel_info): @@ -143,8 +141,8 @@ def matched_emb_initializer(tabel_info): def matched_opt_slot_initializers(table_instance): start_index = table_instance.scalar_emb_size slot_initializers = [] - logging.debug(f"matched_opt_slot_initializers, scalar emb size:{table_instance.ext_emb_size}, " - f"optimizer_instance_list size:{len(table_instance.optimizer_instance_list)}") + logger.debug("matched_opt_slot_initializers, scalar emb size:%s, optimizer_instance_list size:%s", + table_instance.ext_emb_size, len(table_instance.optimizer_instance_list)) for optimizer in table_instance.optimizer_instance_list: for slot_init_value in optimizer.get_slot_init_values(): slot_initializer = InitializeInfo(name="constant_initializer", @@ -190,7 +188,6 @@ def 
initialize_emb_cache(table_info_list, threshold_list): rank_size = get_rank_size() train_steps = get_train_steps() eval_steps = get_eval_steps() - n_batch_to_prefetch = get_prefetch_batch_number() if_load = get_if_load() option = 0 if get_use_static(): @@ -201,23 +198,23 @@ def initialize_emb_cache(table_info_list, threshold_list): option = option | USE_DYNAMIC_EXPANSION # [train_steps, eval_steps] pass step information to HybridMgmt for data process loop - rank_info = RankInfo(rank_id, device_id, rank_size, option, n_batch_to_prefetch, [train_steps, eval_steps]) + rank_info = RankInfo(rank_id, device_id, rank_size, option, [train_steps, eval_steps]) emb_cache = HybridMgmt() is_initialized = emb_cache.initialize(rank_info=rank_info, emb_info=table_info_list, if_load=if_load, - threshold_values=threshold_list) + threshold_values=threshold_list) + if is_initialized is False: - logging.error("Failed to init emb_cache!") + logger.error("Failed to init emb_cache!") raise RuntimeError("emb_cache has not been initialized successfully.") set_asc_manager(emb_cache) - logging.info("Preprocessing has been sunk into the host pipeline.") - logging.debug(f"Flag if load is {if_load}.") - logging.debug(f"n_batch_to_prefetch is {n_batch_to_prefetch}.") - logging.debug(f"train_steps is {train_steps}.") - logging.debug(f"eval_steps is {eval_steps}.") - logging.debug(f"threshold_values are {threshold_list}.") + logger.info("Preprocessing has been sunk into the host pipeline.") + logger.debug("Flag if load is %s.", if_load) + logger.debug("train_steps is %s.", train_steps) + logger.debug("eval_steps is %s.", eval_steps) + logger.debug("threshold_values are %s.", threshold_list) def start_asc_pipeline(): @@ -225,11 +222,9 @@ def start_asc_pipeline(): threshold_list = generate_threshold_list() if not table_info_list: - logging.error("table_info_list is empty!") + logger.error("table_info_list is empty!") raise RuntimeError("table_info_list is empty!") if get_stat_on(): - logging.info(f"[StatInfo] current_table_num {export_table_num()}") + logger.info("[StatInfo] current_table_num %s", export_table_num()) if not is_asc_manager_initialized() and table_info_list: initialize_emb_cache(table_info_list, threshold_list) - - diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 34e3d9ff..38ba8c71 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
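
The merge_table.py hunks that follow centre on find_dangling_table, which flags tables whose lookup results never reach the rest of the graph so their keys need not be fed to the host pipeline. Its core is a breadth-first walk starting from each table's lookup ops. A minimal sketch of that reachability idea under TF1-style graph introspection (the helper name is illustrative, not this module's API; the module itself scans op inputs over the whole graph, while Tensor.consumers() expresses the same edge more compactly):

    import tensorflow as tf

    def reachable_ops(start_tensors):
        # Breadth-first walk: follow every tensor to the ops that consume it,
        # then spread through those ops' outputs until nothing new appears.
        visited, frontier = set(), list(start_tensors)
        while frontier:
            spread = []
            for tensor in frontier:
                for op in tensor.consumers():
                    if op not in visited:
                        visited.add(op)
                        spread.extend(op.outputs)
            frontier = spread
        return visited
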
-import logging from typing import Dict, List import tensorflow as tf @@ -10,9 +9,10 @@ from tensorflow import Operation, Tensor from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \ get_bool_gauge_set +from mx_rec.util.log import logger -def affirm(reach_op:List[Operation]) -> bool: +def affirm(reach_op: List[Operation]) -> bool: for node in reach_op: if node.type not in ("IdentityN", "Reshape", "Identity"): return False @@ -81,7 +81,6 @@ def find_dangling_table(table_names: List[str]) -> List[str]: table_lookup_op[table_name].append(the_op) table_reachable_tensor[table_name].extend(the_op.outputs) - def extend(op_list: List[Operation], tensor: Tensor, spread_tensors: List[Tensor]) -> None: @@ -96,7 +95,6 @@ def find_dangling_table(table_names: List[str]) -> List[str]: if tensor in the_op.inputs: spread_tensors.extend(the_op.outputs) - def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool): """find all the tensors which table lookup op can reach @@ -118,9 +116,8 @@ def find_dangling_table(table_names: List[str]) -> List[str]: next_to_visit = spread_tensors return op_visited, False - if not is_train_task(): - logging.info(f"!!merge table only available in train task.") + logger.info("!!merge table only available in train task.") return [] if not get_enable_table_merge(): return [] @@ -138,12 +135,12 @@ def find_dangling_table(table_names: List[str]) -> List[str]: for table_name in table_names: find_table_op(table_name, the_op, table_lookup_op, table_reachable_tensor) - logging.debug("*********** find tables: %s ***********", table_lookup_op) + logger.debug("*********** find tables: %s ***********", table_lookup_op) dangling_table = [] for table_name in table_names: if table_name not in table_lookup_op: - logging.debug("*********** created table %s but never look up***********", table_name) + logger.debug("*********** created table %s but never look up***********", table_name) dangling_table.append(table_name) insert_dangling_table(table_name) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 4c94d675..1045bb56 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -2,12 +2,10 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
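
The embedding.py rework that follows replaces create_table's **kwargs parsing with an explicit, validator-decorated signature. A hedged usage sketch assembled only from the arguments documented in the new docstring (the table name and sizes are made up, and an initialized mxRec runtime is assumed):

    import tensorflow as tf
    from mx_rec.core.embedding import create_table

    user_table = create_table(
        key_dtype=tf.int64,                  # tf.int64, tf.int32 or tf.string
        dim=tf.TensorShape([16]),            # embedding vector size
        name="user_id_table",                # hypothetical table name
        emb_initializer=tf.compat.v1.truncated_normal_initializer(),
        device_vocabulary_size=100000)       # embedding vectors kept on device
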
-import logging import math import os -import re from collections import defaultdict -from typing import Optional +from typing import Optional, Union import tensorflow as tf from tensorflow.python.framework import ops @@ -18,163 +16,85 @@ from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2 from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.constants.constants import (ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET,\ - ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, MxRecMode, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES,\ - ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_HOST_VOCABULARY_SIZE) -from mx_rec.util.initialize import get_rank_id, get_rank_size, is_mpi_in_use, is_asc_frozen, get_customized_ops, \ +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \ + ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, \ + ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE +from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \ get_table_instance_by_name -from mx_rec.validator.validator import ClassValidator, StringValidator +from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \ + para_checker_decorator, IntValidator, NumValidator, OptionValidator from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.normalization import fix_invalid_table_name - - -def check_ssd_relate_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_path): - h_size = 0 - s_size = 0 - try: - h_size = int(host_vocabulary_size) - s_size = int(ssd_vocabulary_size) - except ValueError: - raise ValueError("host_vocabulary_size and ssd_vocabulary_size should be integer") - if h_size == 0 and s_size != 0: - raise ValueError("ssd_vocabulary_size value is invalid, it effected by host_vocabulary_size not zero") - if h_size != 0 and s_size < 0: - raise ValueError("ssd_vocabulary_size value is invalid, it need be greater than 0") - invalid_ssd_data_path = [] - for tmp_path in ssd_data_path: - if is_invalid_path(tmp_path): - invalid_ssd_data_path.append(tmp_path) - if invalid_ssd_data_path: - raise ValueError("ssd_data_path value is invalid, detail:{}, the path need exist and is real path" - .format(", ".join(invalid_ssd_data_path))) - - -def is_invalid_path(tmp_path): - return not os.path.exists(tmp_path) or not os.path.isdir(tmp_path) or os.path.islink(tmp_path) or ".." 
in tmp_path
-
-
-def create_table(**kwargs):
+from mx_rec.util.global_env_conf import global_env
+from mx_rec.util.log import logger
+
+
+@para_checker_decorator(check_option_list=[
+    ("key_dtype", OptionValidator, {"options": (tf.int64, tf.int32, tf.string)}),
+    ("dim", ClassValidator, {"classes": (int, tf.TensorShape)}),
+    ("dim", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]),
+    ("name", StringValidator, {"max_len": 255}, ["check_string_length", "check_whitelist"]),
+    ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}),
+    ("optimizer_list", ClassValidator, {"classes": (list, type(None))}),
+    (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator),
+    ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
+    ("host_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
+    ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
+    ("ssd_data_path", ClassValidator, {"classes": (list, tuple)}),
+    ("is_save", ClassValidator, {"classes": (bool, )}),
+    ("init_param", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]),
+    ("all2all_gradients_op", OptionValidator, {"options": [i.value for i in list(All2allGradientsOp)]}),
+    ("value_dtype", OptionValidator, {"options": [tf.float32]}),
+    ("shard_num", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]),
+    ("fusion_optimizer_var", ClassValidator, {"classes": (bool, )}),
+    ("hashtable_threshold", NumValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"])
+])
+def create_table(key_dtype, dim, name, emb_initializer,
+                 optimizer_list=Optional[list],
+                 device_vocabulary_size=1,
+                 host_vocabulary_size=0,
+                 ssd_vocabulary_size=0,
+                 ssd_data_path=(os.getcwd(), ),
+                 is_save=True,
+                 init_param=1.,
+                 all2all_gradients_op=All2allGradientsOp.SUM_GRADIENTS.value,
+                 value_dtype=tf.float32,
+                 shard_num=1,
+                 fusion_optimizer_var=True,
+                 hashtable_threshold=0):
     """
     Args:
         key_dtype: data type for feature id
         dim: embedding vector size
         name: hash table name
         emb_initializer: the initializer for embedding values
+        optimizer_list: specify the optimizers to use for current hash table
         device_vocabulary_size: embedding vector numbers on device
         host_vocabulary_size: embedding vector numbers on ddr
         ssd_vocabulary_size: embedding vector numbers on ssd
-        ssd_data_path: ssd embedding data save and load path
-            relation from feature to variable offset will be built
-        optimizer_list: specify the optimizers to use for current hash table
-        mode: specify which mode to run for current sparse table
-        value_dtype: the type of the value tensors.
-        shard_num: embedding partition number
-        fusion_optimizer_var: fusion optimizer variable with embedding
-        hashtable_threshold: choose to implement based on hash table or linear layer
+        ssd_data_path: ssd embedding data save and load path; relation from feature to variable offset will be built
         is_save: switch whether to store sparse table data.
         init_param: embedding init param-coefficient
         all2all_gradients_op: sum_grads (default) or sum_gradients_and_div_by_ranksize.
-        apply_gradients_strategy: direct_apply (default) or sum_same_id_gradients_and_apply.
-
+        value_dtype: the type of the value tensors. Only tf.float32 is supported for now.
+ shard_num: embedding partition number + fusion_optimizer_var: fusion optimizer variable with embedding + hashtable_threshold: choose to implement based on hash table or linear layer """ - key_dtype = kwargs.get("key_dtype") - dim = kwargs.get("dim") - name = kwargs.get("name") - emb_initializer = kwargs.get("emb_initializer") - device_vocabulary_size = kwargs.get("device_vocabulary_size", 1) - host_vocabulary_size = kwargs.get("host_vocabulary_size", 0) - ssd_vocabulary_size = kwargs.get("ssd_vocabulary_size", 0) - ssd_data_path = kwargs.get("ssd_data_path", [os.getcwd()]) - optimizer_list = kwargs.get("optimizer_list") - mode = kwargs.get("mode", MxRecMode.ASC) - value_dtype = kwargs.get("value_dtype", tf.float32) - shard_num = kwargs.get("shard_num", 1) - fusion_optimizer_var = kwargs.get("fusion_optimizer_var", True) - hashtable_threshold = kwargs.get("hashtable_threshold", 0) - is_save = kwargs.get("is_save", True) - init_param = kwargs.get("init_param", 1.0) - all2all_gradients_op = kwargs.get("all2all_gradients_op", All2allGradientsOp.SUM_GRADIENTS) - apply_gradients_strategy = kwargs.get("apply_gradients_strategy", ApplyGradientsStrategy.DIRECT_APPLY) - name = fix_invalid_table_name(name) - check_create_table_params(key_dtype, dim, name, emb_initializer) - check_ssd_relate_param(host_vocabulary_size, ssd_vocabulary_size, ssd_data_path) config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, - optimizer_list=optimizer_list, mode=mode, value_dtype=value_dtype, shard_num=shard_num, - fusion_optimizer_var=fusion_optimizer_var, hashtable_threshold=hashtable_threshold, - init_param=init_param, is_save=is_save, all2all_gradients_op=all2all_gradients_op, - apply_gradients_strategy=apply_gradients_strategy) + optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, + all2all_gradients_op=all2all_gradients_op) embedding = SparseEmbedding(config) return embedding -def sparse_lookup(hashtable, ids, send_count, is_train, **kwargs): - """ - Args: - hashtable: SparseEmbedding instance to be looked up - ids: Tensor to lookup from hashtable - send_count: used to config all2all communication parameters - is_train: indicates whether the mode is train. - kwargs: - dim: not in use - is_train: not in use - name: will be used to build scope_name together with hashtable name - modify_graph: if True, the original graph will be modified before building a Session instance - - Returns: Tensor for lookup result - - """ - - def check_lookup_kwargs(): - kwargs["name"] = kwargs.get("name", hashtable.get_default_lookup_name()) - if not isinstance(kwargs.get("name"), str): - raise TypeError("Given name must be a string.") - - kwargs["modify_graph"] = kwargs.get("modify_graph", False) - if not isinstance(kwargs.get("modify_graph"), bool): - raise TypeError("Given modify_graph must be a boolean.") - - if not isinstance(kwargs.get("is_train"), bool): - raise TypeError("Given is_train must be a boolean.") - - if send_count is not None and not isinstance(send_count, int): - raise TypeError("Given send_count must be an int.") - - def check_table_legality_for_feature_spec(table, feature_spec): - # check whether the name of the table exists with FeatureSpec. 
- if table.table_name != feature_spec.table_name: - raise ValueError(f"The table name '{feature_spec.table_name}' specified by FeatureSpec is inconsistent with" - f" the SparseEmbedding table name '{table.table_name}'.") - - def check_modify_graph(): - if not kwargs.get("modify_graph"): - raise ValueError(f"modify_graph must be turn-on when lookup by ids(Tensor, not FeatureSpec).") - - kwargs["is_train"] = is_train - check_lookup_kwargs() - scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name")) - - with tf.compat.v1.variable_scope(scope_name): - if hashtable.mode != MxRecMode.ASC: - raise EnvironmentError("Invalid MxRec Mode.") - if not isinstance(ids, (FeatureSpec, tf.Tensor)): - raise ValueError(f"Invalid ids type, it should be: `FeatureSpec` or `tf.Tensor`, but get `{type(ids)}`.") - - if isinstance(ids, FeatureSpec): - check_table_legality_for_feature_spec(hashtable, ids) - return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs) - - check_modify_graph() - set_modify_graph(True) - return hashtable.lookup_for_asc(ids, send_count, **kwargs) - - class SparseEmbedding: """ each feat_name has its own sparse_embedding_layer. @@ -189,14 +109,11 @@ class SparseEmbedding: self.device_vocabulary_size = config.get("device_vocabulary_size") self.host_vocabulary_size = config.get("host_vocabulary_size") self.ssd_vocabulary_size = config.get("ssd_vocabulary_size") - self.ssd_data_path = config.get("ssd_data_path") - if self.host_vocabulary_size > MAX_HOST_VOCABULARY_SIZE: - raise ValueError(f"host_vocabulary_size is larger than {MAX_HOST_VOCABULARY_SIZE}.") + self.ssd_data_path = list(config.get("ssd_data_path")) self.table_name = config.get("table_name") self.key_dtype = config.get("key_dtype") self._optimizer_instance_list = config.get("optimizer_list") self.emb_initializer = config.get("emb_initializer") - self._mode = config.get("mode") self.is_save = config.get("is_save") self.optimizer_slot_info_list = [] self._slot_num = dict() @@ -220,11 +137,10 @@ class SparseEmbedding: self.modify_graph = False self.init_param = config.get("init_param") self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op")) - self.apply_gradients_strategy = ApplyGradientsStrategy.mapping(config.get("apply_gradients_strategy")) self.set_slice_vocab_size() self.set_emb_size() - if self._mode == MxRecMode.ASC and is_asc_frozen() and self.table_name in get_name_to_var_dict(): + if is_asc_frozen() and self.table_name in get_name_to_var_dict(): self.variable = tf.compat.v1.get_variable(self.table_name, shape=(self.slice_device_vocabulary_size, self.emb_size)) if not self.skip_emb_transfer: @@ -243,10 +159,6 @@ class SparseEmbedding: def scalar_emb_size(self): return self.emb_size - @property - def mode(self): - return self._mode - @property def send_count(self): return self._send_count @@ -354,24 +266,16 @@ class SparseEmbedding: raise ValueError(f"args optimizer list must be a list or an instance of CustomizedOptimizer.") def check_and_format_init_params(self): - if not isinstance(self.embedding_size, tf.TensorShape): - raise TypeError("Parameter 'embedding_size' must be a tf.TensorShape instance.") - if self.embedding_size.ndims != 1: raise ValueError("Parameter 'embedding_size' can only be one dim shape.") - if self.mode == MxRecMode.ASC and is_asc_frozen(): + if is_asc_frozen(): raise EnvironmentError(f"Emb cache management has been established, you cannot build new ASC hash table.") - if self.mode != MxRecMode.ASC and self.host_vocabulary_size > 0: - 
raise ValueError(f"Only ASC mode can use host_vocabulary_size > 0.") - - if self.mode == MxRecMode.ASC and not is_mpi_in_use(): - raise EnvironmentError(f"Hash table with ASC mode must use mpi to start task.") - if not self.skip_emb_transfer and not self._optimizer_instance_list: raise ValueError("ASC with DDR mode should config optimizers before instantiating sparse table, " "but nothing was configured.") + if not self.skip_emb_transfer and self.use_dynamic_expansion: raise ValueError("DDR mode do not support embedding dynamic_expansion for now.") @@ -388,7 +292,7 @@ class SparseEmbedding: def get_default_lookup_name(self): self._default_name_count += 1 default_name = "sparse_lookup_%d" % self._default_name_count - logging.debug(f"getting one default lookup name {default_name}") + logger.debug("getting one default lookup name %s", default_name) return default_name def set_using_feature_mapping(self): @@ -402,12 +306,10 @@ class SparseEmbedding: if self.use_dynamic_expansion and len(self._optimizer_instance_list) != 0: self.ext_coefficient += self._slot_num.get(self.table_name) self.ext_emb_size = self.emb_size * self.ext_coefficient - logging.debug(f"init table, ext_emb_size is set to be {self.ext_emb_size}") + logger.debug("init table, ext_emb_size is set to be %s", self.ext_emb_size) def set_slice_vocab_size(self): rank_size = get_rank_size() - if rank_size == 0: - raise ZeroDivisionError("Rank size cannot be zero.") if self.use_dynamic_expansion: self.slice_device_vocabulary_size = 1 # 动态扩容模式下,保留device侧variable,大小设置为 1 self.slice_host_vocabulary_size = 0 @@ -421,11 +323,6 @@ class SparseEmbedding: SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train") SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec - def check_mode(self, method_mode): - if self.mode != method_mode: - raise RuntimeError(f"Current sparse table was config in {self.mode.value} mode, but sparse lookup method " - f"for {method_mode} was in use.") - def check_multi_lookup_times(self): lookup_times = len(self.lookup_name_list) if self.modify_graph else len(self.lookup_result) if not self.modify_graph and get_training_mode_channel_id(True) is not None and \ @@ -437,7 +334,7 @@ class SparseEmbedding: f"({self.table_name}) is {MULTI_LOOKUP_TIMES}, and current times is {lookup_times}.") def check_and_format_lookup_params(self, feature, send_count, is_training): - logging.debug(f"sparse lookup for table {self.table_name} with is_training {is_training}") + logger.debug("sparse lookup for table %s with is_training %s", self.table_name, is_training) def check_params(): if not isinstance(is_training, bool): @@ -451,7 +348,7 @@ class SparseEmbedding: f"feature with func sparse_lookup at first.") elif isinstance(feature, tf.Tensor): - logging.debug("Input feature is a Tensor.") + logger.debug("Input feature is a Tensor.") else: raise TypeError(f"Given feature must be a FeatureSpec or tf.Tensor.") @@ -465,8 +362,8 @@ class SparseEmbedding: if get_use_static(): if isinstance(send_count, int) and send_count > 0: if self._send_count and self._send_count != send_count: - logging.warning(f"A new send count {send_count} will be used to replace the old one" - f"({self._send_count}).") + logger.warning("A new send count %s will be used to replace the old one (%s).", + send_count, self._send_count) self._send_count = send_count else: @@ -479,7 +376,7 @@ class SparseEmbedding: f"{self.slice_device_vocabulary_size} and slice_host_vocabulary_size was " 
f"{self.slice_host_vocabulary_size} ") - is_check_mode = self.mode == MxRecMode.ASC and not self.skip_emb_transfer and not self.use_dynamic_expansion + is_check_mode = not self.skip_emb_transfer and not self.use_dynamic_expansion if is_check_mode and self.slice_device_vocabulary_size < self.send_count * get_rank_size(): raise ValueError(f"Given device_vocabulary_size was too small for table '{self.table_name}', in which " f"slice_device_vocabulary_size was {self.slice_device_vocabulary_size} and " @@ -513,9 +410,7 @@ class SparseEmbedding: Returns: Tensor for lookup result """ - logging.debug(f"Enter ASC Branch.") - # check params - self.check_mode(MxRecMode.ASC) + logger.debug(f"Enter ASC Branch.") is_training = kwargs.get("is_train") self.check_and_format_lookup_params(ids, send_count, is_training) if is_asc_frozen() and is_training: @@ -552,7 +447,7 @@ class SparseEmbedding: mock_lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value) SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result - logging.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name) + logger.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name) return mock_lookup_result def lookup_for_asc_with_feature_spec(self, feature_spec: FeatureSpec, send_count: int, **kwargs): @@ -613,7 +508,7 @@ class SparseEmbedding: # Ensure that tensors in the same table are sorted according to the lookup sequence (modify graph mode) or # the sequence in which feature specs are created (feature spec mode). same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) - mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", feat_count=1, table_name=table_name) + mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", table_name=table_name) if get_use_static(): tensor_list = [] @@ -628,7 +523,7 @@ class SparseEmbedding: kwargs["multi_lookup"] = True total_send_count = self.same_table_send_count if self.modify_graph else send_count * same_table_spec_count lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, total_send_count, **kwargs) - logging.debug(f"lookup table {table_name} via {tensor_split_list}") + logger.debug("lookup table %s via %s", table_name, tensor_split_list) self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, is_training) @@ -691,8 +586,7 @@ class SparseEmbedding: Returns: Tensor for lookup result """ - logging.debug(f"Enter ASC Branch, looking up with FeatureSpec.") - self.check_mode(MxRecMode.ASC) + logger.debug(f"Enter ASC Branch, looking up with FeatureSpec.") is_training = kwargs.get("is_train") self.check_and_format_lookup_params(feature_spec, send_count, is_training) rank_size = get_rank_size() @@ -702,14 +596,13 @@ class SparseEmbedding: # check training mode order and ensure channel id channel_id = get_training_mode_channel_id(is_training=is_training) - logging.debug(f"get preprocessed tensor for asc for table {self.table_name} with skip emb transfer " - f"{self.skip_emb_transfer} is_training: {is_training}, channel_id: {channel_id} .") - + logger.debug("get preprocessed tensor for asc for table %s with skip emb transfer %s is_training: %s, " + "channel_id: %s .", self.table_name, self.skip_emb_transfer, is_training, channel_id) config = 
dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count,
                      rank_size=rank_size, channel_id=channel_id, table_name=self.table_name,
                      skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size,
                      emb_size=self.emb_size, use_hot=use_hot, device_id=device_id,
-                      use_dynamic_expansion=use_dynamic_expansion, gradients_strategy=self.apply_gradients_strategy)
+                      use_dynamic_expansion=use_dynamic_expansion)

         # stub op node: its name marks whether this sparse lookup is for train or eval
         # at session-run time, this op is found again by searching backwards through this subgraph
@@ -738,7 +631,7 @@ class SparseEmbedding:
         @tf.custom_gradient
         def sparse_forward(table):
-            logging.debug(f"fp rank size: {rank_size}")
+            logger.debug("fp rank size: %s", rank_size)
             if not use_dynamic_expansion:
                 id_offsets_abs = tf.abs(id_offsets)
                 local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
@@ -776,7 +669,7 @@ class SparseEmbedding:
             lookup_result = array_ops.reshape(embeddings, dest_shape)

             def grad(lookup_diff):
-                logging.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
+                logger.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
                 embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
                 unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
                                                                  restore_vector,
@@ -794,14 +687,16 @@ class SparseEmbedding:
                     raise ZeroDivisionError("Rank size cannot be zero.") from exp

                 if use_dynamic_expansion:
-                    if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
+                    if global_env.apply_gradients_strategy == \
+                            ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value:
                         update_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
                                                                         restore_vector_second,
                                                                         array_ops.shape(unique_keys)[0])
                     else:
                         update_grad = local_grad
                 else:
-                    if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
+                    if global_env.apply_gradients_strategy == \
+                            ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value:
                         unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
                                                                               restore_vector_second,
                                                                               array_ops.shape(unique_keys)[0])
@@ -829,13 +724,13 @@ class SparseEmbedding:
         def add_to_collection():
             tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings)
-            if self.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY:
+            if global_env.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value:
                 tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, unique_keys)
             else:
                 tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets)

-        logging.debug(f"feature spec mode, table_name: {self.table_name}, "
-                      f"ASCEND_TABLE_NAME_MUST_CONTAIN: {ASCEND_TABLE_NAME_MUST_CONTAIN}")
+        logger.debug("feature spec mode, table_name: %s, ASCEND_TABLE_NAME_MUST_CONTAIN: %s",
+                     self.table_name, ASCEND_TABLE_NAME_MUST_CONTAIN)
         if is_training and is_table_name_valid:
             add_to_collection()
@@ -843,15 +738,14 @@ class SparseEmbedding:
     def _record(self):
         insert_table_instance(self.table_name, self.variable, self)
-        logging.debug(f"Device vocabulary_size for table {self.table_name} is {self.device_vocabulary_size}.")
-        logging.debug(f"Slice_device_vocabulary_size for table {self.table_name} is"
-                      f" {self.slice_device_vocabulary_size}.")
-        logging.debug(f"Host vocabulary size for table {self.table_name} is {self.host_vocabulary_size}.")
-        logging.debug(f"Slice host vocabulary_size for table {self.table_name} is"
-                      f" 
{self.slice_host_vocabulary_size}.")
-        logging.debug(f"SSD vocabulary size for table {self.table_name} is {self.ssd_vocabulary_size}.")
-        logging.debug(f"Slice ssd vocabulary_size for table {self.table_name} is"
-                      f" {self.slice_ssd_vocabulary_size}.")
+        logger.debug("Device vocabulary_size for table %s is %s.", self.table_name, self.device_vocabulary_size)
+        logger.debug("Slice_device_vocabulary_size for table %s is %s.",
+                     self.table_name, self.slice_device_vocabulary_size)
+        logger.debug("Host vocabulary size for table %s is %s.", self.table_name, self.host_vocabulary_size)
+        logger.debug("Slice host vocabulary_size for table %s is %s.",
+                     self.table_name, self.slice_host_vocabulary_size)
+        logger.debug("SSD vocabulary size for table %s is %s.", self.table_name, self.ssd_vocabulary_size)
+        logger.debug("Slice ssd vocabulary_size for table %s is %s.", self.table_name, self.slice_ssd_vocabulary_size)

     def _initialize_variables(self):
         initialized_tensor = \
@@ -865,9 +759,10 @@ class SparseEmbedding:
         if self.use_dynamic_expansion:
             for sparse_optimizer_instance in self._optimizer_instance_list:
                 self._slot_num[self.table_name] = sparse_optimizer_instance.slot_num
-                logging.info(f"init emb, table name: {self.table_name}, slot_num: {sparse_optimizer_instance.slot_num}")
+                logger.info("init emb, table name: %s, slot_num: %s",
+                            self.table_name, sparse_optimizer_instance.slot_num)

-        if self.mode == MxRecMode.ASC and not self.skip_emb_transfer:
+        if not self.skip_emb_transfer:
             # build optimizer states
             for sparse_optimizer_instance in self._optimizer_instance_list:
                 slot_info_list = sparse_optimizer_instance.initialize_slots(self.variable, self)
@@ -877,6 +772,62 @@ class SparseEmbedding:
                     self.set_optimizer_slot(slot_info)


+@para_checker_decorator(check_option_list=[
+    ("hashtable", ClassValidator, {"classes": (SparseEmbedding, )}),
+    ("ids", ClassValidator, {"classes": (FeatureSpec, tf.Tensor)}),
+    ("is_train", ClassValidator, {"classes": (bool, )}),
+    ("send_count", ClassValidator, {"classes": (int, type(None))}),
+    ("name", ClassValidator, {"classes": (str, type(None))}),
+    ("modify_graph", ClassValidator, {"classes": (bool, type(None))}),
+    ("batch", ClassValidator, {"classes": (dict, type(None))}),
+    ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}),
+])
+def sparse_lookup(hashtable: SparseEmbedding,
+                  ids: Union[FeatureSpec, tf.Tensor],
+                  send_count: Optional[int] = None,
+                  is_train: bool = True,
+                  name: Optional[str] = None,
+                  modify_graph: bool = False,
+                  batch: Optional[dict] = None,
+                  access_and_evict_config: Optional[dict] = None,
+                  **kwargs):
+    """
+    Args:
+        hashtable: SparseEmbedding instance to be looked up
+        ids: FeatureSpec or Tensor of ids to look up in the hashtable
+        send_count: used to configure the all-to-all communication parameters
+        is_train: whether the lookup runs in training mode.
+        name: identifier for the lookup ops; it is combined with the hashtable name to build the scope name
+        modify_graph: if True, the original graph will be modified before building a Session instance
+        batch: the value returned by the get_next() method of TF Dataset
+        access_and_evict_config: configuration for feature admission (filtering) and eviction
+
+    Returns: Tensor for lookup result
+
+    """
+    kwargs["is_train"] = is_train
+    kwargs["name"] = name if name is not None else hashtable.get_default_lookup_name()
+    kwargs["modify_graph"] = modify_graph
+    kwargs["batch"] = batch
+    kwargs["access_and_evict_config"] = access_and_evict_config
+    scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
+
+    with tf.compat.v1.variable_scope(scope_name):
+        if isinstance(ids, FeatureSpec):
+            # check that the table name carried by the FeatureSpec matches this table.
+            if hashtable.table_name != ids.table_name:
+                raise ValueError(f"The table name '{ids.table_name}' specified by FeatureSpec is inconsistent with"
+                                 f" the SparseEmbedding table name '{hashtable.table_name}'.")
+
+            return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs)
+
+        if not modify_graph:
+            raise ValueError("When 'ids' is of type tf.Tensor, 'modify_graph' must be set to True.")
+
+        set_modify_graph(modify_graph)
+        return hashtable.lookup_for_asc(ids, send_count, **kwargs)
+
+
 def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor],
                                access_threshold: bool):
     """
@@ -897,30 +848,3 @@ def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Opti
     id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1)
     embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings))
     return embeddings
-
-
-def check_create_table_params(key_dtype, dim, name, emb_initializer):
-    """
-    校验create_table接口必选参数:key_dtype, dim, name, emb_initializer和optimizer_list(已有校验)
-    :param key_dtype: data type for feature id, tf.int64 or tf.int32 or tf.string
-    :param dim: embedding vector size, dim's type: int or tf.TensorShape
-    :param name: hash table name, name's type: str
-    :param emb_initializer: the initializer for embedding values
-    :return:
-    """
-    # check key_dtype
-    if key_dtype not in [tf.int64, tf.int32, tf.string]:
-        raise ValueError(f"key_dtype: {key_dtype} not in [tf.int64, tf.int32, tf.string]")
-    # check dim
-    dim_validator = ClassValidator(value=dim, classes=(int, tf.TensorShape))
-    dim_validator.check_isinstance()
-    dim_validator.check()
-    # check name
-    name_validator = StringValidator(value=name, max_len=255)
-    name_validator.check_string_length()
-    name_validator.check_whitelist()
-    name_validator.check()
-    # check emb_initializer
-    emb_initializer_validator = ClassValidator(value=emb_initializer, classes=(InitializerV1, InitializerV2))
-    emb_initializer_validator.check_isinstance()
-    emb_initializer_validator.check()
diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py
index 9ffc0324..85b00b59 100644
--- a/mx_rec/core/feature_process.py
+++ b/mx_rec/core/feature_process.py
@@ -2,18 +2,25 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
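As a self-contained sketch of the masking applied by `set_zero_for_non_valid_key` above (values and shapes are illustrative; only plain TensorFlow ops are used, not mxRec internals):

import tensorflow as tf

# Negative id offsets mark keys rejected by the admission threshold.
id_offsets = tf.constant([2, -1, 5, -3])
embeddings = tf.ones([4, 3])

# Expand to [batch, 1] so the boolean mask broadcasts across the embedding dim.
mask = tf.expand_dims(id_offsets >= 0, axis=-1)
masked = tf.where(mask, embeddings, tf.zeros_like(embeddings))
# Rows 1 and 3 come back as all-zero vectors; valid rows pass through unchanged.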
-import logging
 import time

 import tensorflow as tf

-from mx_rec.constants.constants import DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID
+from mx_rec.constants.constants import DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MAX_INT32
 from mx_rec.util.initialize import trigger_evict, get_table_instance_by_name, export_feature_spec
+from mx_rec.validator.validator import para_checker_decorator, ClassValidator, IntValidator, OptionalIntValidator
+from mx_rec.util.log import logger


 class _EvictHook(tf.compat.v1.train.SessionRunHook):
     """Sets evict based on global step or time."""
-
+    @para_checker_decorator(
+        check_option_list=[
+            ("evict_enable", ClassValidator, {"classes": (bool,)}),
+            ("evict_time_interval", IntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]),
+            ("evict_step_interval", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]),
+        ]
+    )
     def __init__(self,
                  evict_enable=False,
                  evict_time_interval=DEFAULT_EVICT_TIME_INTERVAL,
@@ -27,12 +34,11 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
         self._evict_op = dict()
         self._global_step_tensor = None

-        self.check_evict_init_params()
         if evict_step_interval is None:
-            logging.info(f"_EvictHook - > evict_time_interval: %d", self._evict_time_interval)
+            logger.info("_EvictHook - > evict_time_interval: %d", self._evict_time_interval)
         else:
-            logging.info(f"_EvictHook - > evict_time_interval: %d, evict_step_interval: %d", self._evict_time_interval,
-                         self._evict_step_interval)
+            logger.info("_EvictHook - > evict_time_interval: %d, evict_step_interval: %d",
+                        self._evict_time_interval, self._evict_step_interval)

     def begin(self):
         self._global_step_tensor = tf.compat.v1.train.get_or_create_global_step()
@@ -42,7 +48,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
         for name, instance in self._hash_table_instance.items():
             scope_name = f"{instance.table_name}//evict"
             with tf.compat.v1.variable_scope(scope_name):
-                logging.debug('Channel %s_evict_%d was built for op getnext', instance.table_name, TRAIN_CHANNEL_ID)
+                logger.debug('Channel %s_evict_%d was built for op getnext', instance.table_name, TRAIN_CHANNEL_ID)

                 from mx_rec.util.tf_version_adapter import npu_ops
                 evict_pos, evict_len = npu_ops.gen_npu_ops.get_next(
@@ -55,7 +61,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):

                 initialized_tensor = initialized_tensor[0:evict_len, :]

-                logging.debug(
+                logger.debug(
                     'evict_pos output shape %r, and slice_device_vocabulary_size %d, initialized_tensor shape: %r',
                     evict_pos, instance.slice_device_vocabulary_size, initialized_tensor)

@@ -65,7 +71,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):

     def after_create_session(self, session, coord):
         self._global_step = session.run(self._global_step_tensor)
-        logging.debug("_EvictHook - > after_create_session, step: %d", self._global_step)
+        logger.debug("_EvictHook - > after_create_session, step: %d", self._global_step)

     def after_run(self, run_context, run_values):
         if not self._evict_enable:
@@ -75,7 +81,7 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook):
         cur_time = time.time()
         if cur_time - self._start_time > self._evict_time_interval or \
                 (self._evict_step_interval is not None and self._global_step % self._evict_step_interval == 0):
-            logging.info("_EvictHook - > evict switch on!!! after_run step: %d", self._global_step)
+            logger.info("_EvictHook - > evict switch on!!! 
after_run step: %d", self._global_step) if not trigger_evict(): return self._start_time = cur_time @@ -85,19 +91,6 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): def check_name_and_get_hashtable(self): for _, feature_spec in export_feature_spec().items(): if feature_spec.eviction_threshold: - logging.debug("_EvictHook - > check and get instance: table_names %s", feature_spec.table_name) + logger.debug("_EvictHook - > check and get instance: table_names %s", feature_spec.table_name) self._hash_table_instance[feature_spec.table_name] = get_table_instance_by_name(feature_spec.table_name) - def check_evict_init_params(self): - def check_type(arg, n_type, param_name): - if not isinstance(arg, n_type): - raise TypeError(f"{param_name} should be type '{n_type}', whose value is {arg} with type " - f"'{type(arg)}' in fact.") - if type(arg) == int and arg < 1: - raise ValueError(f"{param_name} should be bigger than 0, whose value is {arg} in fact") - - check_type(self._evict_enable, bool, "evict_enable") - if self._evict_time_interval is not None: - check_type(self._evict_time_interval, int, "evict_time_interval") - if self._evict_step_interval is not None: - check_type(self._evict_step_interval, int, "evict_time_interval") diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 11bc09b5..7a88e085 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -2,14 +2,13 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging - import tensorflow as tf from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE from mx_rec.core.embedding import SparseEmbedding from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec from mx_rec.util.initialize import get_modify_graph, get_merged_multi_lookup, insert_merged_multi_lookup, get_use_static +from mx_rec.util.log import logger def do_merge_lookup(is_train: bool = True): @@ -28,12 +27,12 @@ def do_merge_lookup(is_train: bool = True): """ if not get_modify_graph(): - logging.debug("The `do_merge_multi_lookup` function is called only for `modify graph` mode.") + logger.debug("The `do_merge_multi_lookup` function is called only for `modify graph` mode.") return if get_merged_multi_lookup(is_train): - logging.debug("The merge multi lookup has been executed once and does not need to be executed again.") + logger.debug("The merge multi lookup has been executed once and does not need to be executed again.") return - logging.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train) + logger.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train) # get anchor ids cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) @@ -47,8 +46,8 @@ def do_merge_lookup(is_train: bool = True): for cutting_point in cutting_point_list: is_training = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING) if is_training != is_train: - logging.debug("Skip! The current mode(train: True, eval: False) is %s, but the mode of %s is %s.", - is_train, cutting_point, is_training) + logger.debug("Skip! 
The current mode(train: True, eval: False) is %s, but the mode of %s is %s.", + is_train, cutting_point, is_training) continue table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) @@ -68,8 +67,8 @@ def do_merge_lookup(is_train: bool = True): table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) if len(table_instance.lookup_name_list) == 1: - logging.debug("The origin lookup result of %s for %s does not need to be replaced.", feature_spec.name, - table_instance.table_name) + logger.debug("The origin lookup result of %s for %s does not need to be replaced.", + feature_spec.name, table_instance.table_name) continue send_count = table_instance.send_count @@ -78,8 +77,8 @@ def do_merge_lookup(is_train: bool = True): kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs) replace_anchor_vec(cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result) - logging.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name) + logger.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name) # records whether the current mode has been merged or restored lookup insert_merged_multi_lookup(is_train, True) - logging.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train) + logger.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 90fa8b62..d9c2855a 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
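The merge pass above hinges on a graph-collection handshake: every lookup site registers its anchor tensor under a shared collection key, and do_merge_lookup later retrieves the full list. A minimal sketch of that pattern follows; the collection name is illustrative, standing in for the ASCEND_SPARSE_LOOKUP_ENTRANCE constant whose string value is defined elsewhere in mx_rec.constants:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # A lookup site registers its anchor tensor at graph-construction time.
    anchor = tf.constant(0, name="mock_lookup_anchor")
    tf.compat.v1.add_to_collection("sparse_lookup_entrance", anchor)

    # The merge pass walks the collection to find every cutting point.
    cutting_point_list = tf.compat.v1.get_collection("sparse_lookup_entrance")
    assert cutting_point_list[0] is anchor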
-import logging from collections import defaultdict from typing import Any @@ -25,6 +24,8 @@ from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ export_pb_graph, make_sorted_key_to_tensor_list from mx_rec.graph.merge_lookup import do_merge_lookup +from mx_rec.validator.validator import para_checker_decorator, ClassValidator +from mx_rec.util.log import logger def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tensor_names=None, @@ -87,10 +88,10 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso else: raise ValueError("Encounter a invalid batch.") - logging.debug("In get_preprocessing_map_func, the old batch is: %s.", args) + logger.debug("In get_preprocessing_map_func, the old batch is: %s.", args) batch = dict() parse_batch(args, batch, key=None) - logging.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch) + logger.debug("In get_preprocessing_map_func, the parse batch is: %s.", batch) input_tensors = [] if batch_tensor_names is not None: @@ -112,7 +113,7 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso return_elements=output_names) output_batch = [batch, tuple(output_list)] - logging.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch) + logger.debug("In get_preprocessing_map_func, the output batch is: %s.", output_batch) return tuple(output_batch) return map_func @@ -143,7 +144,7 @@ def find_make_iterator_op(batch_tensor): for input_tensor in batch_tensor.op.inputs: if input_tensor.op.outputs and input_tensor.op.outputs[0] in list( each_op.inputs) and each_op.type == "MakeIterator": - logging.debug(f"Op MakeIterator '{each_op.name}' was found.") + logger.debug("Op MakeIterator '%s' was found.", each_op.name) return each_op raise ValueError(f"Op MakeIterator was not found.") @@ -221,10 +222,10 @@ def get_passing_tensor_list(src_tensors, target_op): if passing_tensor not in passing_tensor_list: passing_tensor_list.append(passing_tensor) if len(passing_tensors) != 0: - logging.info(f"passing_tensors: {passing_tensors}") + logger.info("passing_tensors: %s", passing_tensors) sub_src_tensors.append(tensor) else: - logging.info(f"Cannot find passing tensor for given tensor '{tensor}'.") + logger.info("Cannot find passing tensor for given tensor '%s'.", tensor) output_index_list = [int(tensor.name.split(":")[1]) for tensor in passing_tensor_list] @@ -237,7 +238,7 @@ def find_target_instance_dataset(variant_tensor): if ins._variant_tensor == variant_tensor: if not isinstance(ins, DatasetV1Adapter): ins = ins._input_dataset - logging.debug(f"Find target instance '{ins}', whose variant_tensor is '{variant_tensor}'.") + logger.debug("Find target instance '%s', whose variant_tensor is '%s'.", ins, variant_tensor) if not isinstance(ins.element_spec, dict) and not ( isinstance(ins.element_spec, (list, tuple)) and len(ins.element_spec) == 2 and isinstance( ins.element_spec[0], dict)): @@ -297,8 +298,8 @@ def update_input_tensor_with_new_batch(replacement_specs: dict, new_get_next_op_ try: operator._update_input(idx, new_tensor) except InvalidArgumentError as err: - logging.info("The replacement specs keys (old batch) is: %s. \n\t\t" - "The new batch is: %s.", replacement_specs.keys(), new_batch) + logger.info("The replacement specs keys (old batch) is: %s. 
\n\t\t The new batch is: %s.", + replacement_specs.keys(), new_batch) raise RuntimeError(f"Cannot update edge, old tensor: {old_tensor}, new tensor: {new_tensor}.") from err @@ -321,6 +322,9 @@ def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int: return len(src_sorted_keys) +@para_checker_decorator( + check_option_list=[("dump_graph", ClassValidator, {"classes": (bool,)})] +) def modify_graph_and_start_emb_cache(dump_graph=False): modify_graph_for_asc(dump_graph=dump_graph) start_asc_pipeline() @@ -331,7 +335,7 @@ def generate_get_next_op_specs(cutting_point_list, dump_graph): for input_tensor in cutting_point_list: get_next_op = find_target_dataset_op(input_tensor.op, "IteratorGetNext") if get_next_op not in get_next_op_map: - logging.debug(f"find a new get_next_op named '{get_next_op.name}'") + logger.debug("find a new get_next_op named '%s'", get_next_op.name) replacement_specs = record_ops_to_replace(get_next_op) get_next_op_map[get_next_op]["replacement_specs"] = replacement_specs passing_tensor_list, batch_tensor_index_list, sub_cutting_point_list = \ @@ -368,11 +372,11 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt try: target_op = get_dataset_op(get_next_op) except (ValueError, TypeError, RuntimeError) as err: - logging.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err) + logger.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err) graph = tf.compat.v1.get_default_graph() dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] - logging.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.", - is_training, dataset_op_list) + logger.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.", + is_training, dataset_op_list) if len(dataset_op_list) == 1: target_op = dataset_op_list[0] @@ -390,7 +394,7 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt if not target_op.outputs: raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") - logging.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) + logger.debug("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) src_dataset = find_target_instance_dataset(target_op.outputs[0]) return src_dataset @@ -456,7 +460,7 @@ def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapte raise RuntimeError(f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " f"but the current iterator is `{iterator_type}`.") set_iterator_type(iterator_type) - logging.info("The iterator type of dataset is `%s`.", iterator_type) + logger.info("The iterator type of dataset is `%s`.", iterator_type) if iterator_type == "MakeIterator": new_iterator = tgt_dataset.make_initializable_iterator() @@ -480,13 +484,13 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE) check_cutting_points(cutting_point_list) if not cutting_point_list: - logging.warning("Nothing to revise.") + logger.warning("Nothing to revise.") return export_pb_graph("old_graph.pb", dump_graph) get_next_op_map = generate_get_next_op_specs(cutting_point_list, dump_graph) - logging.debug("In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", - 
len(get_next_op_map), get_next_op_map.keys()) + logger.debug("In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", + len(get_next_op_map), get_next_op_map.keys()) for get_next_op, records in get_next_op_map.items(): is_training = records.get("is_training") @@ -514,14 +518,14 @@ def modify_graph_for_asc(dump_graph=False, prefetch=10): if not is_training: do_merge_lookup(is_train=False) if 'evaluate' in get_bool_gauge_set(): - logging.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") + logger.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") insert_merged_multi_lookup(is_training, False) # In training mode, `do_merge_lookup` should have been executed in compute gradients phase. if is_training and not get_merged_multi_lookup(True): raise RuntimeError("In training mode, `do_merge_lookup` should have been executed in compute gradients " "phase. Please check whether compute gradients is performed.") - logging.info("Graph has been revised.") + logger.info("Graph has been revised.") export_pb_graph("new_graph.pb", dump_graph) @@ -547,6 +551,12 @@ def get_timestamp_index(get_next_op, is_training): class GraphModifierHook(tf.estimator.SessionRunHook): + @para_checker_decorator( + check_option_list=[ + ("dump_graph", ClassValidator, {"classes": (bool,)}), + ("modify_graph", ClassValidator, {"classes": (bool,)}) + ] + ) def __init__(self, dump_graph=True, modify_graph=True): self._dump_graph = dump_graph self._modify_graph = modify_graph @@ -562,7 +572,7 @@ class GraphModifierHook(tf.estimator.SessionRunHook): self._iterator_type = get_iterator_type() if self._modify_graph and self._iterator_type not in ("MakeIterator", "OneShotIterator"): raise ValueError("The value of iterator type should be like `MakeIterator` or `OneShotIterator`.") - logging.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) + logger.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) def after_create_session(self, session, coord): if self._modify_graph and self._iterator_type == "MakeIterator": @@ -570,11 +580,11 @@ class GraphModifierHook(tf.estimator.SessionRunHook): def end(self, session): bool_gauge_set = get_bool_gauge_set() - logging.debug(f"GraphModifierHook, bool_gauge_set: {bool_gauge_set}") + logger.debug("GraphModifierHook, bool_gauge_set: %s", bool_gauge_set) # In eval or predict mode, the initializer can be directly terminated. if 'train' not in bool_gauge_set: - logging.debug(f"In evaluate or predict case, GraphModifierHook call 'terminate_config_initializer'...") + logger.debug("In evaluate or predict case, GraphModifierHook call 'terminate_config_initializer'...") terminate_config_initializer() return @@ -582,5 +592,5 @@ class GraphModifierHook(tf.estimator.SessionRunHook): increase_run_times() # In 'train_and_evaluate' mode, the terminate function should be executed last. if get_is_last_round(): - logging.debug(f"In train_and_evaluate case, GraphModifierHook call 'terminate_config_initializer'...") + logger.debug("In train_and_evaluate case, GraphModifierHook call 'terminate_config_initializer'...") terminate_config_initializer() diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 1ca73423..4f0845c1 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -3,7 +3,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
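The patch module below follows one monkey-patching convention throughout: a module-level replacement function is assigned over the target method, and a patch_for_* helper performs the assignment once at import time. A minimal sketch with a stand-in class (not the real TensorFlow target) follows:

class SessionCreatorStub:
    """Stand-in for the TensorFlow class being patched."""

    def __init__(self):
        print("original init")


def patched_init(self):
    # The replacement keeps the original signature and runs mxRec checks instead.
    print("mxrec checks run instead of the original init")


def patch_for_session_creator_stub():
    SessionCreatorStub.__init__ = patched_init


patch_for_session_creator_stub()
SessionCreatorStub()  # prints the patched message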
import weakref -import logging from typing import Any import tensorflow as tf @@ -24,6 +23,7 @@ from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_ get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook from mx_rec.graph.merge_lookup import do_merge_lookup +from mx_rec.util.log import logger def init_dataset(self, input_data): @@ -178,7 +178,7 @@ def chief_session_creator_init(self, scaffold=None, master='', config=None, chec checkpoint_filename_with_path: Full file name path to the checkpoint file. Returns:None """ - logging.debug("Enter the mxrec init function of Class 'monitored_session.ChiefSessionCreator'.") + logger.debug("Enter the mxrec init function of Class 'monitored_session.ChiefSessionCreator'.") if get_modify_graph() and not get_is_graph_modify_hook_running(): raise RuntimeError( f"When 'modify_graph' is True, 'GraphModifierHook' must be configured. Example: \n" @@ -200,7 +200,7 @@ def patch_for_chief_session_creator(): Returns:None """ tf.compat.v1.train.ChiefSessionCreator.__init__ = chief_session_creator_init - logging.debug("__init__ in Class 'monitored_session.ChiefSessionCreator' has been patched.") + logger.debug("__init__ in Class 'monitored_session.ChiefSessionCreator' has been patched.") def get_cell(self: BoolGauge, *labels: Any) -> Any: @@ -213,9 +213,9 @@ def get_cell(self: BoolGauge, *labels: Any) -> Any: Returns: Obtains the cell value set by the user. """ - logging.debug("Enter patch 'BoolGauge.get_cell'.") + logger.debug("Enter patch 'BoolGauge.get_cell'.") if len(labels) > 0: - logging.debug("BoolGauge insert: %s.", labels[0]) + logger.debug("BoolGauge insert: %s.", labels[0]) insert_bool_gauge(labels[0]) return BoolGaugeCell(super(BoolGauge, self).get_cell(*labels)) @@ -224,7 +224,7 @@ def patch_for_bool_gauge(): """Patch for 'BoolGauge.get_cell'.""" BoolGauge.get_cell = get_cell - logging.debug("Function 'get_cell' in Class 'BoolGauge' has been patched.") + logger.debug("Function 'get_cell' in Class 'BoolGauge' has been patched.") def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session): @@ -239,14 +239,14 @@ def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session): """ - logging.debug("Enter patch 'NPUCheckpointSaverHook.end'.") - logging.info("NPUCheckpointSaverHook end...") + logger.debug("Enter patch 'NPUCheckpointSaverHook.end'.") + logger.info("NPUCheckpointSaverHook end...") basic_session_run_hooks.CheckpointSaverHook.end(self, session) if 'train_and_evaluate' in get_bool_gauge_set() and get_run_times() == 1: set_is_last_round(True) return - logging.debug("NPUCheckpointSaverHook call 'terminate_config_initializer'...") + logger.debug("NPUCheckpointSaverHook call 'terminate_config_initializer'...") terminate_config_initializer() @@ -254,7 +254,7 @@ def patch_for_end(): """Patch for 'NPUCheckpointSaverHook.end'.""" NPUCheckpointSaverHook.end = end - logging.debug("Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.") + logger.debug("Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.") def assert_eval_spec(eval_spec: EvalSpec): @@ -268,20 +268,20 @@ def assert_eval_spec(eval_spec: EvalSpec): """ - logging.debug("Enter patch 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.") + logger.debug("Enter patch 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.") if not isinstance(eval_spec, EvalSpec): raise TypeError('`eval_spec` must 
have type `tf.estimator.EvalSpec`. Got: {}'.format(type(eval_spec))) if 'train_and_evaluate' not in get_bool_gauge_set(): insert_bool_gauge('train_and_evaluate') - logging.debug("assert_eval_spec: add 'train_and_evaluate' to BoolGaugeCell.") + logger.debug("assert_eval_spec: add 'train_and_evaluate' to BoolGaugeCell.") def patch_for_assert_eval_spec(): """Patch for 'tensorflow_estimator.python.estimator.training._assert_eval_spec'.""" tensorflow_estimator_lib.python.estimator.training._assert_eval_spec = assert_eval_spec - logging.debug("Function '_assert_eval_spec' in 'tensorflow_estimator.python.estimator.training' has been patched.") + logger.debug("Function '_assert_eval_spec' in 'tensorflow_estimator.python.estimator.training' has been patched.") def scale_loss(self: Optimizer, loss_value: tf.Tensor) -> tf.Tensor: @@ -297,7 +297,7 @@ def scale_loss(self: Optimizer, loss_value: tf.Tensor) -> tf.Tensor: """ - logging.debug("Enter patch 'Optimizer._scale_loss'.") + logger.debug("Enter patch 'Optimizer._scale_loss'.") # In train mode, merge lookup must be completed during compute gradients. # Ensure that the backward of graph is constructed and the gradient calculation is correct. do_merge_lookup(is_train=True) @@ -317,4 +317,4 @@ def patch_for_scale_loss(): """Patch for 'Optimizer._scale_loss'.""" Optimizer._scale_loss = scale_loss - logging.debug("Function '_scale_loss' in Class 'Optimizer' has been patched.") + logger.debug("Function '_scale_loss' in Class 'Optimizer' has been patched.") diff --git a/mx_rec/logger/log.py b/mx_rec/logger/log.py index d00ecb2d..fd630a8a 100644 --- a/mx_rec/logger/log.py +++ b/mx_rec/logger/log.py @@ -8,6 +8,7 @@ import yaml from mx_rec.constants.constants import LOG_MAX_SIZE from mx_rec.validator.validator import FileValidator +from mx_rec.util.global_env_conf import global_env def init_sys_log(): @@ -15,14 +16,14 @@ def init_sys_log(): log_cfg_file = os.path.join(work_dir, "logger.yaml") real_config_path = os.path.realpath(log_cfg_file) - if not FileValidator(log_cfg_file).check_file_size(real_config_path).check().is_valid(): + if not FileValidator("log_cfg_file", log_cfg_file).check_file_size(real_config_path).check().is_valid(): raise ValueError("Config file size is not valid.") with open(real_config_path, 'r', encoding='utf-8') as open_file: - if not FileValidator(real_config_path).\ - check_file_size(LOG_MAX_SIZE).\ - check_not_soft_link().\ - check_user_group().\ + if not FileValidator("log_cfg_file", real_config_path). \ + check_file_size(LOG_MAX_SIZE). \ + check_not_soft_link(). \ + check_user_group(). 
\ is_valid(): raise ValueError("Log config file is not valid.") @@ -34,8 +35,5 @@ def init_sys_log(): init_sys_log() srv_stream_log = logging.getLogger("logStream") -env_log_level = os.getenv("MXREC_LOG_LEVEL") srv_log = srv_stream_log -if env_log_level: - srv_log.setLevel(env_log_level) - +srv_log.setLevel(global_env.mxrec_log_level) diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 5430373a..4c396d7f 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -6,10 +6,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict -import tensorflow as tf from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import adagrad, training_ops @@ -17,9 +15,16 @@ from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_param_type +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("initial_accumulator_value", NumValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), + ("use_locking", ClassValidator, {"classes": (bool, )}), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer(learning_rate=0.001, initial_accumulator_value=0.9, use_locking=False, @@ -53,8 +58,6 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): use_locking=use_locking, name=self.unique_name) - self._check_input_param() - def initialize_slots(self, var, table_instance): # Create slots for the first and second moments. 
def creat_one_single_slot(var, op_name): @@ -85,11 +88,6 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer): initial_accumulator_value = 0.0 return [initial_accumulator_value] - def _check_input_param(self): - check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float)) - check_param_type("initial_accumulator_value", self._initial_accumulator_value, (tf.Tensor, float)) - check_param_type("use_locking", self._use_locking, bool) - def _create_slots(self, var_list): for var in var_list: dtype = var.dtype.base_dtype diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py index ad654cd9..313fded8 100644 --- a/mx_rec/optimizers/base.py +++ b/mx_rec/optimizers/base.py @@ -6,12 +6,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict from tensorflow.python.framework import ops from tensorflow.python.training.optimizer import _TensorProcessor +from mx_rec.util.log import logger + class CustomizedOptimizer: @@ -51,4 +52,4 @@ def custom_update_op(self, opt, grad): def patch_for_optimizer(): _TensorProcessor.update_op = custom_update_op - logging.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file + logger.debug("update_op in Class optimizer._TensorProcessor has been patched.") \ No newline at end of file diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 0b4efb66..1cd342ea 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -6,7 +6,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf @@ -21,11 +20,25 @@ from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range - - +from mx_rec.util.variable import check_and_get_config_via_var +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, OptionalStringValidator, ClassValidator, NumValidator, \ + StringValidator + + +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("initial_accumulator_value", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("learning_rate_power", NumValidator, {"min_value": -MAX_INT32, "max_value": 0}, ["check_value"]), + ("l1_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("l2_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("l2_shrinkage_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("use_locking", ClassValidator, {"classes": (bool,)}), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]), + ("accum_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), + ("linear_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): - return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) @@ -48,25 +61,6 @@ class 
CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): l2_shrinkage_regularization_strength=kwargs.get("l2_shrinkage_regularization_strength", 0.0) ) - param_name_list = ["initial_accumulator_value", "l1_regularization_strength", - "l2_regularization_strength", "l2_shrinkage_regularization_strength"] - - def _check_param_type_range(param_name_list): - for name in param_name_list: - if kwargs.get(name, None): - check_param_type(name, kwargs.get(name), (int, float)) - check_param_range(name, kwargs.get(name), 0, 1e4) - - if kwargs.get("accum_name", None): - check_param_type("accum_name", kwargs.get("accum_name"), str) - - if kwargs.get("linear_name", None): - check_param_type("linear_name", kwargs.get("linear_name"), str) - - check_param_type("use_locking", use_locking, bool) - - _check_param_type_range(param_name_list) - def initialize_slots(self, var, table_instance): val = constant_op.constant( self._initial_accumulator_value, dtype=var.dtype, shape=var.get_shape()) diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py index 12710c57..ec5c38ce 100644 --- a/mx_rec/optimizers/ftrl_t.py +++ b/mx_rec/optimizers/ftrl_t.py @@ -6,12 +6,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py index 271da078..a59e793c 100644 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ b/mx_rec/optimizers/ftrl_t_dense.py @@ -6,22 +6,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.training import optimizer -from tensorflow.python.training import slot_creator -from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var +from mx_rec.util.initialize import insert_removing_var_list def create_ftrl_dense_optimizer(learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): @@ -51,7 +46,6 @@ class CustomizedFtrlTZ(optimizer.Optimizer): self._epsilon_tensor = None self._grad_factor_tensor = None super(CustomizedFtrlTZ, self).__init__(use_locking, name) - logging.debug("CustomizedFtrlTZ __init__ ok") def _prepare(self): self._learning_rate_tensor = ops.convert_to_tensor( diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 25747aed..5c9b1d1a 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -6,21 +6,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf -from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import 
CustomizedOptimizer -from mx_rec.util.variable import check_param_type +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("use_locking", ClassValidator, {"classes": (bool,)}), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescent"): - return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name) @@ -33,8 +36,6 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name) - check_param_type("use_locking", use_locking, bool) - def initialize_slots(self, var, table_instance): return [] diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 0fa3cb6e..6af75e3d 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -6,7 +6,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict from tensorflow.python.ops import math_ops @@ -14,8 +13,16 @@ from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("weight_decay", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("use_locking", ClassValidator, {"classes": (bool,)}), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescentByAddr"): optimizer_by_addr = CustomizedGradientDescentByAddr(learning_rate=learning_rate, weight_decay=weight_decay, diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 4f1ee9e9..2ec61c64 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -6,7 +6,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf @@ -20,9 +19,18 @@ from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range - - +from mx_rec.util.variable import check_and_get_config_via_var +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator + + +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("beta1", NumValidator, {"min_value": 0, 
"max_value": 1}, ["check_value"]), + ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"): """ Args: @@ -46,17 +54,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, use_locking=use_locking, name=self.unique_name) - check_param_type("beta1", beta1, (int, float)) - check_param_range("beta1", beta1, 0, 1) - - check_param_type("beta2", beta2, (int, float)) - check_param_range("beta2", beta2, 0, 1) - - check_param_type("epsilon", epsilon, (int, float)) - check_param_range("epsilon", epsilon, 0, 1) - - check_param_type("use_locking", use_locking, bool) - def initialize_slots(self, var, table_instance): # Create slots for the first and second moments. def creat_one_single_slot(var, op_name): @@ -91,7 +88,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): initial_velocity_value = 0.0 return [initial_momentum_value, initial_velocity_value] - def _apply_sparse_duplicate_indices(self, grad, var): # _apply_sparse_duplicate_indices method include tf.unique and unsorted_segment_sum operations which may # introduce dynamic shape problem, if encounter that, please de-annotation the method below. @@ -110,10 +106,10 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): temp_b2 = math_ops.cast(self._beta2_t, var_type) temp_epsilon = math_ops.cast(self._epsilon_t, var_type) temp = { - 'temp_lr' : temp_lr, - 'temp_b1' : temp_b1, - 'temp_b2' : temp_b2, - 'temp_epsilon' : temp_epsilon, + 'temp_lr': temp_lr, + 'temp_b1': temp_b1, + 'temp_b2': temp_b2, + 'temp_epsilon': temp_epsilon, } return temp diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 76f6f72d..d7cdda4b 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -6,7 +6,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict import tensorflow as tf @@ -15,9 +14,17 @@ from tensorflow.python.training import adam from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.variable import check_param_type, check_param_range +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) +]) def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdamByAddress"): """ @@ -49,7 +56,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): name=self.unique_name) self._slot_num = 2 - self._check_input_param() @property def slot_num(self): 
@@ -61,18 +67,6 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): initial_velocity_value = 0.0 return [initial_momentum_value, initial_velocity_value] - def _check_input_param(self): - check_param_type("beta1", self._beta1, (int, float)) - check_param_range("beta1", self._beta1, 0, 1) - - check_param_type("beta2", self._beta2, (int, float)) - check_param_range("beta2", self._beta2, 0, 1) - - check_param_type("epsilon", self._epsilon, (int, float)) - check_param_range("epsilon", self._epsilon, 0, 1) - - check_param_type("use_locking", self._use_locking, bool) - def _create_slots(self, addr_list): first_addr = addr_list[0] self._create_non_slot_variable( @@ -90,10 +84,10 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): temp_b2 = math_ops.cast(self._beta2_t, var_type) temp_epsilon = math_ops.cast(self._epsilon_t, var_type) temp = { - 'temp_lr' : temp_lr, - 'temp_b1' : temp_b1, - 'temp_b2' : temp_b2, - 'temp_epsilon' : temp_epsilon, + 'temp_lr': temp_lr, + 'temp_b1': temp_b1, + 'temp_b2': temp_b2, + 'temp_epsilon': temp_epsilon, } return temp diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index daae8119..7e585f54 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -6,10 +6,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import logging from collections import defaultdict -import tensorflow as tf from tensorflow.python.ops import math_ops from tensorflow.python.training import training_ops from tensorflow.python.training import momentum @@ -17,17 +15,26 @@ from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var, check_param_type, check_param_range - - -def create_hash_optimizer(learning_rate_input=0.001, mom=0.9, enable_locking=False, optimizer_name="momentum", +from mx_rec.util.variable import check_and_get_config_via_var +from mx_rec.constants.constants import MAX_INT32 +from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator, ClassValidator + + +@para_checker_decorator(check_option_list=[ + ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("mom", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("use_locking", ClassValidator, {"classes": (bool,)}), + ("name", StringValidator, {"max_len": 255}, ["check_string_length"]), + ("enable_nesterov", ClassValidator, {"classes": (bool,)}), +]) +def create_hash_optimizer(learning_rate=0.001, mom=0.9, use_locking=False, name="momentum", enable_nesterov=False): """ Create an instance of hash optimizer - :param learning_rate_input: A `Tensor` or a floating point value. The learning rate. + :param learning_rate: A `Tensor` or a floating point value. The learning rate. :param mom: A `Tensor` or a floating point value. The momentum. - :param enable_locking: If `True` use locks for update operations. - :param optimizer_name: Optional name prefix for the operations created when applying gradients. + :param use_locking: If `True` use locks for update operations. + :param name: Optional name prefix for the operations created when applying gradients. Defaults to "Momentum". :param enable_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., 2013). 
This implementation always computes gradients at the value of the variable(s) passed to the optimizer. Using Nesterov Momentum makes the @@ -37,10 +44,10 @@ def create_hash_optimizer(learning_rate_input=0.001, mom=0.9, enable_locking=Fal the change in the average gradient. :return: momentum hash optimizer instance """ - return CustomizedMomentum(learning_rate=learning_rate_input, + return CustomizedMomentum(learning_rate=learning_rate, momentum_var=mom, - use_locking=enable_locking, - name=optimizer_name, + use_locking=use_locking, + name=name, use_nesterov=enable_nesterov) @@ -61,8 +68,6 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): name=self.unique_name, use_nesterov=use_nesterov) - self._check_input_param() - def initialize_slots(self, var, table_instance): # Create slots for the first and second moments. def creat_one_single_slot(var, op_name): @@ -93,14 +98,6 @@ class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): initial_momentum_value = 0.0 return [initial_momentum_value] - def _check_input_param(self): - check_param_type("learning_rate", self._learning_rate, (tf.Tensor, float)) - check_param_type("momentum", self._momentum, (tf.Tensor, float)) - check_param_type("use_locking", self._use_locking, bool) - check_param_type("use_nesterov", self._use_nesterov, bool) - - check_param_range("momentum", self._momentum, 0.0, 1.0) - def _create_slots(self, var_list): m_state_name = self._name + "/" + "momentum" for var in var_list: diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index f60a0c36..3a53ba30 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -4,7 +4,6 @@ import os import time -import logging import tensorflow as tf from tensorflow.core.protobuf import saver_pb2 @@ -27,6 +26,8 @@ from tensorflow.python.training.saving import saveable_object_util from mx_rec.saver.saver import Saver as SparseSaver from mx_rec.util.initialize import get_ascend_global_hashtable_collection, export_removing_var_list +from mx_rec.validator.validator import para_checker_decorator, ClassValidator +from mx_rec.util.log import logger def get_sparse_vars(var_list): @@ -115,7 +116,7 @@ def get_model_checkpoint_path(self, checkpoint_file, sess): if self.sparse_saver: self.sparse_saver.save(sess, save_path=checkpoint_file) - logging.info("Save model into dir %s", checkpoint_file) + logger.info("Save model into dir %s", checkpoint_file) else: self._build_eager(checkpoint_file, build_save=True, build_restore=False) model_checkpoint_path = self.saver_def.save_tensor_name @@ -169,6 +170,17 @@ def build(self): self._build(self._filename, build_save=True, build_restore=True) +@para_checker_decorator(check_option_list=[ + ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), + ("save_path", ClassValidator, {"classes": (str, )}), + ("global_step", ClassValidator, {"classes": (int, type(None))}), + ("latest_filename", ClassValidator, {"classes": (str, type(None))}), + ("meta_graph_suffix", ClassValidator, {"classes": (str, type(None))}), + ("write_meta_graph", ClassValidator, {"classes": (bool, type(None))}), + ("write_state", ClassValidator, {"classes": (bool, type(None))}), + ("strip_default_attrs", ClassValidator, {"classes": (bool, type(None))}), + ("save_debug_info", ClassValidator, {"classes": (bool, type(None))}) +]) def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, 
save_debug_info=False): # since tf 2.6.0, tf needs tensorflow_io to support hdfs path @@ -207,6 +219,10 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra return model_checkpoint_path +@para_checker_decorator(check_option_list=[ + ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), + ("save_path", ClassValidator, {"classes": (str, )}) +]) def restore(self, sess, save_path): if save_path is None: raise ValueError("Can't load save_path when it is None.") @@ -231,7 +247,7 @@ def restore(self, sess, save_path): sess.run(self.saver_def.restore_op_name, {self.saver_def.filename_tensor_name: save_path}) - logging.info("Restore from dir %s", save_path) + logger.info("Restore from dir %s", save_path) else: self._build_eager(save_path, build_save=False, build_restore=True) @@ -382,6 +398,6 @@ def patch_for_saver(): dense_saver.save = save dense_saver.restore = restore dense_saver.build = build - logging.debug("Class tf.train.Saver has been patched.") + logger.debug("Class tf.train.Saver has been patched.") diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 4049dd5a..c031648a 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -4,7 +4,6 @@ import json import os -import logging import threading from collections import defaultdict @@ -12,12 +11,14 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE +from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ - send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir + send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size from mx_rec.util.perf import performance from mx_rec.validator.validator import DirectoryValidator, FileValidator +from mx_rec.util.global_env_conf import global_env +from mx_rec.util.log import logger # define save model thread @@ -41,20 +42,20 @@ class Saver(object): self._prefix_name = prefix_name self.var_list = var_list self.rank_id = get_rank_id() - self.local_rank_id = self.rank_id % 8 + self.local_rank_size = get_local_rank_size() + self.local_rank_id = self.rank_id % self.local_rank_size self.rank_size = get_rank_size() - self.local_rank_size = min(self.rank_size, 8) self.save_op_dict = defaultdict(dict) self.restore_fetch_list = [] self.placeholder_dict = defaultdict(dict) # save_easy_mode : only save the embedding and key data of sparse tables - self.save_easy_mode = int(os.getenv("SAVE_EASY", 0)) + self.save_easy_mode = (global_env.save_easy == Flag.TRUE.value) self.build() def build(self): if self.var_list is None: self.var_list = [] - logging.debug(f"optimizer collection name: {get_ascend_global_hashtable_collection()}") + logger.debug("optimizer collection name: %s", get_ascend_global_hashtable_collection()) temp_var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) for var in temp_var_list: table_instance = get_table_instance(var) @@ -66,7 +67,7 @@ class Saver(object): with tf.compat.v1.variable_scope("mx_rec_restore"): self._build_restore() - logging.debug("Save & Restore graph was built.") + logger.debug("Save & Restore graph was built.") @performance("Save") def 
save(self, sess, save_path="model", global_step=None): @@ -84,7 +85,7 @@ class Saver(object): the checkpoint filenames. The optional argument can be a Tensor, a Tensor name or an integer. :return: None """ - logging.debug(f"======== Start saving for rank id {self.rank_id} ========") + logger.debug("======== Start saving for rank id %s ========", self.rank_id) save_path = save_path if save_path else self._prefix_name directory, base_name = os.path.split(save_path) @@ -100,24 +101,24 @@ class Saver(object): try: if save_path.find("://") == -1: - DirectoryValidator(saving_path).with_blacklist(exact_compare=False).check() + DirectoryValidator("saving_path", saving_path).with_blacklist(exact_compare=False).check() except ValueError as err: raise ValueError(f"The saving path {saving_path} cannot be a system directory " f"or a subdirectory of the system directory.") from err if tf.io.gfile.exists(saving_path): tf.io.gfile.rmtree(saving_path) - logging.info(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been deleted.") + logger.info("rank id %s | Saving_path '%s' has been deleted.", self.rank_id, saving_path) tf.io.gfile.makedirs(saving_path) - logging.info(f"rank id {self.rank_id} | Saving_path '{saving_path}' has been made.") + logger.info("rank id %s | Saving_path '%s' has been made.", self.rank_id, saving_path) self._save(sess, saving_path) - logging.info(f"sparse model was saved in dir '{saving_path}' .") - logging.info(f"======== Saving finished for rank id {self.rank_id} ========") + logger.info("sparse model was saved in dir '%s' .", saving_path) + logger.info("======== Saving finished for rank id %s ========", self.rank_id) @performance("Restore") def restore(self, sess, reading_path): - logging.debug("======== Start restoring ========") + logger.debug("======== Start restoring ========") directory, base_name = os.path.split(reading_path) ckpt_name = f"sparse-{base_name}" @@ -127,8 +128,8 @@ class Saver(object): raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") self._restore(sess, reading_path) - logging.info(f"sparse model was restored from dir '{reading_path}' .") - logging.debug("======== Restoring finished ========") + logger.info("sparse model was restored from dir '%s' .", reading_path) + logger.debug("======== Restoring finished ========") @performance("save_table_name_data") def save_table_name_data(self, sess, result, root_dir, table_name): @@ -165,7 +166,7 @@ class Saver(object): def _build_restore(self): for var in self.var_list: - if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: + if global_env.tf_device == TFDevice.NPU.value and "merged" not in var.name: continue table_instance = get_table_instance(var) sub_placeholder_dict = self.placeholder_dict[table_instance.table_name] @@ -199,7 +200,7 @@ class Saver(object): result = self.save_op_dict if is_asc_manager_initialized() and not self.save_easy_mode: save_host_data(root_dir) - logging.debug(f"host data was saved.") + logger.debug(f"host data was saved.") threads = [] for table_name in result.keys(): @@ -256,10 +257,10 @@ class Saver(object): if is_asc_manager_initialized() and self.save_easy_mode: send_host_data(key_offset_dict) - logging.info("host data was sent to the host pipeline.") + logger.info("host data was sent to the host pipeline.") if is_asc_manager_initialized() and not self.save_easy_mode: restore_host_data(reading_path) - logging.info("host data was restored.") + logger.info("host data was restored.") sess.run(self.restore_fetch_list, 
feed_dict=restore_feed_dict) @@ -398,12 +399,12 @@ def write_binary_data(writing_path, suffix, data, attributes=None): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} exists before writing.") if target_data_dir.find("://") != -1: - logging.debug(f"use hdfs path {target_data_dir} to save sparse data.") + logger.debug("use hdfs path %s to save sparse data.", target_data_dir) with tf.io.gfile.GFile(target_data_dir, "wb") as file: data = data.tostring() file.write(data) else: - logging.debug(f"use local file path {target_data_dir} to save sparse data.") + logger.debug("use local file path %s to save sparse data.", target_data_dir) data.tofile(target_data_dir) if attributes is not None: @@ -441,11 +442,11 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: with tf.io.gfile.GFile(target_data_dir, "rb") as file: validate_read_file(target_data_dir) if target_data_dir.find("://") != -1: - logging.debug("use hdfs path %s to restore sparse data.", target_data_dir) + logger.debug("use hdfs path %s to restore sparse data.", target_data_dir) data_to_restore = file.read() data_to_restore = np.fromstring(data_to_restore, dtype=attributes.pop(DataAttr.DATATYPE.value)) else: - logging.debug("use local file path %s to restore sparse data.", target_data_dir) + logger.debug("use local file path %s to restore sparse data.", target_data_dir) data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: @@ -457,8 +458,8 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: data_to_restore = process_embedding_data(data_to_restore, current_data_shape, data_shape) data_dict = {data_name: data_to_restore} - logging.debug(f"Attribute: '{target_attribute_dir}' and data file: '{target_data_dir}' have been read.") - logging.debug(f"Reading shape is {data_to_restore.shape}.") + logger.debug("Attribute: '%s' and data file: '%s' have been read.", target_attribute_dir, target_data_dir) + logger.debug("Reading shape is %s.", data_to_restore.shape) return data_dict @@ -468,7 +469,7 @@ def validate_read_file(read_file_path): Validate file before reading,including validating soft link, file size :param read_file_path: the file path to be validated """ - file_validator = FileValidator(read_file_path) + file_validator = FileValidator("read_file_path", read_file_path) file_validator.check_file_size(MAX_FILE_SIZE, MIN_SIZE) # local file need to check soft link if read_file_path.find("://") == -1: diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 675726a3..581acb93 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -2,7 +2,6 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
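A minimal, hypothetical sketch (not part of the patch) of the path-scheme dispatch that write_binary_data/read_binary_data above rely on: paths containing "://" (e.g. HDFS) go through tf.io.gfile, while plain local paths use numpy file IO directly; the helper name write_array is invented for illustration.

import numpy as np
import tensorflow as tf

def write_array(path: str, data: np.ndarray) -> None:
    if path.find("://") != -1:
        # remote filesystem such as hdfs://...: tf.io.gfile resolves the scheme
        with tf.io.gfile.GFile(path, "wb") as file:
            file.write(data.tobytes())
    else:
        # plain local path: dump the raw bytes straight to disk
        data.tofile(path)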
-import logging
 import os
 import json
 
@@ -10,6 +9,8 @@ import numpy as np
 
 from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set, get_sparse_dir
 from mx_rec.validator.validator import FileValidator
+from mx_rec.validator.validator import para_checker_decorator, ClassValidator
+from mx_rec.util.log import logger
 
 
 class SparseProcessor:
@@ -31,7 +32,7 @@ class SparseProcessor:
         self.default_table_list = list(export_table_name_set())
 
         if not self.table_list:
-            logging.debug("table list not be set, use default value : all table created ")
+            logger.debug("table list not set, using default value: all created tables")
             self.table_list = self.default_table_list
         else:
             self.table_list = check_table_param(self.table_list, self.default_table_list)
@@ -44,7 +45,7 @@ class SparseProcessor:
     def _get_data(data_dir, dtype, data_shape):
         with open(data_dir, "rb") as file:
             # check whether data file is valid
-            file_validator = FileValidator(data_dir)
+            file_validator = FileValidator("data_dir", data_dir)
             # 1.check whether data_dir is soft link
             file_validator.check_not_soft_link()
             # 2.check data file size
@@ -64,7 +65,7 @@ class SparseProcessor:
         try:
             with open(attribute_dir, "r") as fin:
                 # check whether attribute file is valid
-                file_validator = FileValidator(attribute_dir)
+                file_validator = FileValidator("attribute_dir", attribute_dir)
                 # 1.check whether attribute_dir is soft link
                 file_validator.check_not_soft_link()
                 # 2.check attribute file size
@@ -82,7 +83,7 @@ class SparseProcessor:
         return attributes
 
     def export_sparse_data(self):
-        logging.info("table list to be exported is %s", self.table_list)
+        logger.info("table list to be exported is %s", self.table_list)
         sparse_dir = get_sparse_dir()
         ddr = False
         dev_dir = set_upper_dir(sparse_dir, self.device_dir_list)
@@ -164,13 +165,16 @@ class SparseProcessor:
         return data_file, attribute_file
 
 
+@para_checker_decorator(check_option_list=[
+    ("table_list", ClassValidator, {"classes": (list, )})
+])
 def export(**kwargs):
     empty_value = 0
     SparseProcessor.set_instance(**kwargs)
     if SparseProcessor.single_instance.table_list:
         return SparseProcessor.single_instance.export_sparse_data()
     else:
-        logging.warning("no table can be exported ,please check if you have saved or created tables")
+        logger.warning("no table can be exported, please check whether you have saved or created tables")
         return empty_value
 
 
@@ -178,7 +182,7 @@ def check_table_param(table_list, default_table_list):
     out_list = []
     for table in table_list:
         if table not in default_table_list:
-            logging.warning(f"{table} not be created , please check your table name.")
+            logger.warning("%s was not created, please check your table name.", table)
     out_list.append(table)
     return out_list
 
diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py
index f46049eb..4c2bd953 100644
--- a/mx_rec/util/__init__.py
+++ b/mx_rec/util/__init__.py
@@ -3,7 +3,3 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
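For reference, a hedged usage sketch of the validated export() entry point above; the table names are purely illustrative and not taken from the patch:

from mx_rec.saver.sparse import export

# passing a non-list table_list now raises ValueError via ClassValidator
exported = export(table_list=["user_table", "item_table"])  # hypothetical table names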
__all__ = ["initialize", "variable"] - -from mx_rec.util.log import get_log_level - -get_log_level() diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 0a9af6a0..3a2db57c 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -5,18 +5,19 @@ import json import os -from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID -from mx_rec.validator.validator import RankInfoValidator, FileValidator +from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID, \ + MIN_RANK_SIZE, MAX_RANK_SIZE +from mx_rec.validator.validator import FileValidator, para_checker_decorator, StringValidator, \ + Convert2intValidator +from mx_rec.util.global_env_conf import global_env def parse_hccl_json(): - rank_table_path = os.path.realpath(os.getenv("RANK_TABLE_FILE")) - if not os.path.exists(rank_table_path): - raise FileExistsError(f"Target_hccl_json_dir {rank_table_path} does not exist when reading.") + rank_table_path = os.path.realpath(global_env.rank_table_file) with open(rank_table_path, "r", encoding="utf-8") as file: # check whether json file is valid - file_validator = FileValidator(rank_table_path) + file_validator = FileValidator("RANK_TABLE_FILE", rank_table_path) # 1.check whether rank_table_path is soft link file_validator.check_not_soft_link() # 2.check json file size @@ -32,11 +33,13 @@ def parse_hccl_json(): raise AttributeError(f"Lack of attribute device.") rank_to_device_dict = dict() + local_rank_size = -1 for server_list in table_hccl.get("server_list"): devices = server_list.get("device") if devices is None: raise ValueError("device is empty") + local_rank_size = len(devices) for device in devices: if "rank_id" not in device or not device.get("rank_id").isdigit(): raise ValueError(f"hccl_json rank_id wrong.") @@ -53,31 +56,40 @@ def parse_hccl_json(): raise ValueError(f"get logic id from physic id fail, the device id is invalid.") rank_to_device_dict[rank_id] = res - return rank_to_device_dict + return rank_to_device_dict, local_rank_size -def set_hccl_info_without_json(): +@para_checker_decorator(check_option_list=[ + ("visible_devices", StringValidator, {"msg": "please config ASCEND_VISIBLE_DEVICES in docker container start"}), + ("rank_size", StringValidator, {"msg": "please config CM_WORKER_SIZE in docker container start"}), + ("chief_device", StringValidator, {"msg": "please config CM_CHIEF_DEVICE in docker container start"}), + ("rank_size", Convert2intValidator, {"min_value": MIN_RANK_SIZE, "max_value": MAX_RANK_SIZE, + "constrained_options": [1, 2, 4, 8, 16]}, ["check_value"]), + ("chief_device", Convert2intValidator, {"min_value": 0, "max_value": 15}, ["check_value"]), +]) +def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_device: str): """ Used for no rank table file configured training situation. Now, only less than or equal 8p training job is supported. 
-    :return: None
+    :param visible_devices: devices visible to the Ascend processors, used to restrict the program to a subset of devices.
+    :param rank_size: number of devices participating in the distributed training job.
+    :param chief_device: device id of the chief node.
+    :return:
     """
-    RankInfoValidator().check_visible_devices()
-    ascend_visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES")
-    device_list = get_device_list(ascend_visible_devices)
+    device_list = get_device_list(visible_devices)
+    chief_device = int(chief_device)
+    rank_size = int(rank_size)
 
-    chief_device = os.getenv("CM_CHIEF_DEVICE")
-    rank_size = os.getenv("CM_WORKER_SIZE")
     sorted_device_list = sorted(device_list)
-    if int(rank_size) != len(sorted_device_list):
-        raise ValueError(f"Rank size {rank_size} is different from device num {len(sorted_device_list)}.")
-    rank_to_device_dict = dict()
-    try:
-        rank_to_device_dict[0] = int(chief_device)
-    except ValueError as err:
-        raise ValueError("CM_WORKER_SIZE or CM_CHIEF_DEVICE uncorrected configured.") from err
+    local_rank_size = len(sorted_device_list)
+
+    if rank_size < local_rank_size:
+        raise ValueError(f"Rank size {rank_size} is less than the number of devices: {local_rank_size}.")
+
+    rank_to_device_dict = {0: chief_device}
+
     try:
-        sorted_device_list.pop(int(chief_device) % len(sorted_device_list))
+        sorted_device_list.pop(chief_device % local_rank_size)
     except IndexError as err:
         raise IndexError(
             f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \
@@ -96,7 +108,7 @@
             raise ValueError(f"get logic id from physic id fail.")
         index = sorted_device_list.index(device_idx)
         rank_to_device_dict[index + 1] = res
-    return rank_to_device_dict
+    return rank_to_device_dict, local_rank_size
 
 
 def get_device_list(ascend_visible_devices):
diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py
new file mode 100644
index 00000000..b617fd62
--- /dev/null
+++ b/mx_rec/util/global_env_conf.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
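A small worked example (not from the patch) of the rank-to-device assignment implemented above, assuming ASCEND_VISIBLE_DEVICES="0,1,2,3" and CM_CHIEF_DEVICE="2"; the real code additionally maps each physical id to a logic id before storing it:

devices = sorted([0, 1, 2, 3])            # parsed from ASCEND_VISIBLE_DEVICES
chief_device = 2                          # parsed from CM_CHIEF_DEVICE
mapping = {0: chief_device}               # rank 0 is always the chief
devices.pop(chief_device % len(devices))  # drop the chief from the remaining pool
for rank, dev in enumerate(devices, start=1):
    mapping[rank] = dev                   # -> {0: 2, 1: 0, 2: 1, 3: 3}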
+import os
+import dataclasses
+from dataclasses import dataclass
+
+from mx_rec.constants.constants import EnvOption, RecPyLogLevel, Flag, EMPTY_STR, ApplyGradientsStrategy, \
+    DEFAULT_HD_CHANNEL_SIZE, DEFAULT_KP_THREAD_NUM, DEFAULT_FAST_UNIQUE_THREAD_NUM, RecCPPLogLevel, MAX_INT32, \
+    MIN_HD_CHANNEL_SIZE, MAX_HD_CHANNEL_SIZE, MIN_KP_THREAD_NUM, MAX_KP_THREAD_NUM, \
+    MIN_FAST_UNIQUE_THREAD_NUM, MAX_FAST_UNIQUE_THREAD_NUM, DEFAULT_HOT_EMB_UPDATE_STEP, MIN_HOT_EMB_UPDATE_STEP, \
+    MAX_HOT_EMB_UPDATE_STEP
+from mx_rec.validator.validator import para_checker_decorator, OptionValidator, DirectoryValidator, Convert2intValidator
+
+
+@dataclass
+class RecEnv:
+    mxrec_log_level: str
+    save_easy: str
+    rank_table_file: str
+    ascend_visible_devices: str
+    cm_chief_device: str
+    cm_worker_size: str
+    tf_device: str
+    apply_gradients_strategy: str
+    acl_timeout: str
+    hd_channel_size: str
+    find_offset_v2: str
+    find_offset_v3: str
+    key_process_thread_num: str
+    max_unique_thread_num: str
+    fast_unique: str
+    updateemb_v2: str
+    hot_emb_update_step: str
+    glog_stderrthreahold: str
+    use_combine_faae: str
+    stat_on: str
+
+
+def get_global_env_conf() -> RecEnv:
+    """
+    Fetch mxRec's global environment variables and validate them.
+    :return:
+    """
+    rec_env = RecEnv(
+        mxrec_log_level=os.getenv(EnvOption.MXREC_LOG_LEVEL.value, RecPyLogLevel.INFO.value),
+        save_easy=os.getenv(EnvOption.SAVE_EASY.value, Flag.FALSE.value),
+        rank_table_file=os.getenv(EnvOption.RANK_TABLE_FILE.value, EMPTY_STR),
+        ascend_visible_devices=os.getenv(EnvOption.ASCEND_VISIBLE_DEVICES.value),
+        cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value),
+        cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value),
+        tf_device=os.getenv(EnvOption.TF_DEVICE.value),
+        apply_gradients_strategy=os.getenv(EnvOption.APPLY_GRADIENTS_STRATEGY.value,
+                                           ApplyGradientsStrategy.DIRECT_APPLY.value),
+        acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"),
+        hd_channel_size=os.getenv(EnvOption.HD_CHANNEL_SIZE.value, DEFAULT_HD_CHANNEL_SIZE),
+        find_offset_v2=os.getenv(EnvOption.FIND_OFFSET_V2.value, Flag.FALSE.value),
+        find_offset_v3=os.getenv(EnvOption.FIND_OFFSET_V3.value, Flag.FALSE.value),
+        key_process_thread_num=os.getenv(EnvOption.KEY_PROCESS_THREAD_NUM.value, DEFAULT_KP_THREAD_NUM),
+        max_unique_thread_num=os.getenv(EnvOption.MAX_UNIQUE_THREAD_NUM.value, DEFAULT_FAST_UNIQUE_THREAD_NUM),
+        fast_unique=os.getenv(EnvOption.FAST_UNIQUE.value, Flag.FALSE.value),
+        updateemb_v2=os.getenv(EnvOption.UPDATEEMB_V2.value, Flag.FALSE.value),
+        hot_emb_update_step=os.getenv(EnvOption.HOT_EMB_UPDATE_STEP.value, DEFAULT_HOT_EMB_UPDATE_STEP),
+        glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value),
+        use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value),
+        stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value)
+    )
+
+    return rec_env
+
+
+@para_checker_decorator(check_option_list=[
+    ("mxrec_log_level", OptionValidator, {"options": [i.value for i in list(RecPyLogLevel)]}),
+    ("save_easy", OptionValidator, {"options": [i.value for i in list(Flag)]}),
+    ("rank_table_file", DirectoryValidator, {}, ["check_exists_if_not_empty"]),
+    ("apply_gradients_strategy", OptionValidator, {"options": [i.value for i in list(ApplyGradientsStrategy)]}),
+    ("acl_timeout", Convert2intValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]),
+    ("hd_channel_size", Convert2intValidator,
+     {"min_value": MIN_HD_CHANNEL_SIZE, "max_value": MAX_HD_CHANNEL_SIZE}, ["check_value"]),
+    ("find_offset_v2", OptionValidator, {"options": 
[i.value for i in list(Flag)]}), + ("find_offset_v3", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("key_process_thread_num", Convert2intValidator, + {"min_value": MIN_KP_THREAD_NUM, "max_value": MAX_KP_THREAD_NUM}, ["check_value"]), + ("max_unique_thread_num", Convert2intValidator, + {"min_value": MIN_FAST_UNIQUE_THREAD_NUM, "max_value": MAX_FAST_UNIQUE_THREAD_NUM}, ["check_value"]), + ("fast_unique", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("updateemb_v2", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("hot_emb_update_step", Convert2intValidator, + {"min_value": MIN_HOT_EMB_UPDATE_STEP, "max_value": MAX_HOT_EMB_UPDATE_STEP}, ["check_value"]), + ("glog_stderrthreahold", OptionValidator, {"options": [i.value for i in list(RecCPPLogLevel)]}), + ("use_combine_faae", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("stat_on", OptionValidator, {"options": [i.value for i in list(Flag)]}) +]) +def check_env(**kwargs): + pass + + +global_env = get_global_env_conf() + +check_env(**dataclasses.asdict(global_env)) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 47093085..1cea5677 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -1,20 +1,24 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -import logging + import os from collections import defaultdict +import dataclasses +import json import psutil import mx_rec.constants.constants -from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, VALID_DEVICE_ID_LIST, LOCAL_RANK_SIZE, \ - MAX_DEVICE_NUM_LOCAL_MACHINE, DEFAULT_DEVICE_NUM_LOCAL_MACHINE, HASHTABLE_COLLECTION_NAME_LENGTH, \ - TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH, \ + TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_RANK_SIZE, MAX_INT32, TFDevice, Flag from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import StringValidator, FileValidator +from mx_rec.validator.validator import StringValidator, FileValidator, para_checker_decorator, ClassValidator, \ + IntValidator, ValueCompareValidator from mx_rec.util.atomic import AtomicInteger +from mx_rec.util.global_env_conf import global_env +from mx_rec.util.log import logger class ConfigInitializer: @@ -22,7 +26,19 @@ class ConfigInitializer: customized_ops = None host_pipeline_ops = import_host_pipeline_ops() - def __init__(self, use_mpi, **kwargs): + @para_checker_decorator(check_option_list=[ + ("use_mpi", ClassValidator, {"classes": (bool, )}), + ("train_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), + ("eval_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), + (["train_steps", "eval_steps"], ValueCompareValidator, {"target": 0}, + ["check_at_least_one_not_equal_to_target"]), + ("if_load", ClassValidator, {"classes": (bool, )}), + ("use_dynamic", ClassValidator, {"classes": (bool, )}), + ("use_hot", ClassValidator, {"classes": (bool, )}), + ("use_dynamic_expansion", ClassValidator, {"classes": (bool, )}), + ("bind_cpu", ClassValidator, {"classes": (bool, )}), + ]) + def __init__(self, use_mpi=True, **kwargs): self._use_mpi = use_mpi self._rank_id = kwargs.get("rank_id", 0) self._rank_size = kwargs.get("rank_size", 1) @@ 
-57,7 +73,7 @@ class ConfigInitializer:
         self._sparse_dir = ""
 
         if self._use_mpi:
-            logging.debug(f"Using mpi to launch task.")
+            logger.debug("Using mpi to launch task.")
             from mpi4py import MPI
             self._mpi = MPI
             self._comm = MPI.COMM_WORLD
@@ -66,22 +82,23 @@
         else:
             raise ValueError("only mpi is supported for launching task.")
 
-        self._rank_to_device_dict = parse_hccl_json() if os.getenv("RANK_TABLE_FILE") else set_hccl_info_without_json()
+        self._rank_to_device_dict, self._local_rank_size = parse_hccl_json() if global_env.rank_table_file else \
+            set_hccl_info_without_json(visible_devices=global_env.ascend_visible_devices,
+                                       rank_size=global_env.cm_worker_size,
+                                       chief_device=global_env.cm_chief_device)
 
         self.train_steps = kwargs.get("train_steps", -1)
         self.eval_steps = kwargs.get("eval_steps", -1)
-        self.check_parameters()
-
         self._prefetch_batch_number = kwargs.get("prefetch_batch_number", 1)
         self.if_load = kwargs.get("if_load", False)
         self.use_static = not kwargs.get("use_dynamic", True)
         self.use_hot = kwargs.get("use_hot", True)
         self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False)
         if kwargs.get("bind_cpu", True):
-            bind_cpu(self._rank_id, self._rank_size)
+            bind_cpu(self._rank_id, self._local_rank_size)
-        self.enable_table_merge = True if os.getenv("TF_DEVICE") == "NPU" else False
+        self.enable_table_merge = True if global_env.tf_device == TFDevice.NPU.value else False
         # sparse lookup ids for the two channels, used as identifiers for communication
         self.notify_hybrid_channel_sparse_id = [0, 0]
-        self.stat_on = set_stat_flag() if os.getenv("STAT_ON") else False
+        self.stat_on = (global_env.stat_on == Flag.TRUE.value)
 
     def __del__(self):
         self.terminate()
@@ -90,6 +107,10 @@ def iterator_type(self):
         return self._iterator_type
 
+    @property
+    def local_rank_size(self):
+        return self._local_rank_size
+
     @property
     def merged_multi_lookup(self):
         return self._merged_multi_lookup
@@ -176,10 +197,6 @@ def eval_steps(self):
         return self._eval_steps
 
-    @property
-    def prefetch_batch_number(self):
-        return self._prefetch_batch_number
-
     @property
     def if_load(self):
         return self._if_load
@@ -211,14 +228,14 @@
             ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs)
 
     def terminate(self):
-        logging.info("python process run into terminate")
+        logger.info("python process run into terminate")
         if self._is_terminated:
-            logging.warning("The initializer has already been released once, please do not release it again.")
+            logger.warning("The initializer has already been released once, please do not release it again.")
             return
 
         if self._asc_manager is not None:
             self.del_asc_manager()
-        logging.info("python process run terminate success")
+        logger.info("python process run terminate success")
         self._is_terminated = True
         ConfigInitializer._single_instance = None
 
@@ -257,14 +274,14 @@
         if name in self._table_name_set:
             raise ValueError(f"Duplicated hashtable name '{name}' was used.")
 
-        logging.debug(f"Record one hash table, with name: {name}, key: {key}.")
+        logger.debug("Record one hash table, with name: %s, key: %s.", name, key)
         self._table_name_set.add(name)
         if name not in self._table_name_to_feature_spec:
             self._table_name_to_feature_spec[name] = {True: [], False: []}
         self._name_to_var_dict[name] = key
         self._table_instance_dict[key] = instance
         if self.stat_on:
-            logging.info(f"[StatInfo] current_table_num {len(self._table_instance_dict)}")
+            logger.info("[StatInfo] current_table_num %s", len(self._table_instance_dict))
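Worth noting (and visible in the validator changes later in this patch): para_checker_decorator inspects only kwargs, plus defaulted arguments that it temporarily merges into kwargs, so arguments passed positionally bypass validation. A hedged sketch of the calling convention, with an invented function name:

@para_checker_decorator(check_option_list=[
    ("use_mpi", ClassValidator, {"classes": (bool,)}),
])
def launch(use_mpi=True, **kwargs):  # hypothetical function, for illustration only
    ...

launch(use_mpi=1)  # validated: raises ValueError, since use_mpi arrives as a keyword
launch(1)          # positional call: the actual value is not validated by the wrapper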
     def insert_bool_gauge(self, name):
         if not isinstance(name, str):
@@ -288,25 +305,6 @@ class ConfigInitializer:
     def insert_optimizer(self, optimizer):
         self._optimizer_instance = optimizer
 
-    def check_parameters(self):
-        if not isinstance(self._use_mpi, bool):
-            raise ValueError(f"Arg use_mpi must be a boolean.")
-
-        if not isinstance(self.rank_id, int) or not isinstance(self.rank_size, int):
-            raise ValueError(f"Args rank_size and rank_id must be integers. {self.rank_id} {self.rank_size}")
-
-        if self.rank_id < 0:
-            raise ValueError(f"Arg rank_id must be larger than 0, which is {self.rank_id} now.")
-
-        if self.rank_size < 1:
-            raise ValueError(f"Arg rank_size must be larger than 1, which is {self.rank_size} now.")
-
-        if self.rank_id >= self.rank_size:
-            raise ValueError(f"Rank_id must be within the range from 0 to rank_size.")
-
-        if self._train_steps == 0 and self._eval_steps == 0:
-            raise ValueError(f"Train steps and eval steps could not both equal 0.")
-
     def freeze(self):
         self._is_frozen = True
 
@@ -329,7 +327,7 @@
         self._asc_manager.destroy()
         self._asc_manager = None
         self.unfreeze()
-        logging.debug("ASC manager has been destroyed.")
+        logger.debug("ASC manager has been destroyed.")
 
     @iterator_type.setter
     def iterator_type(self, iterator_type):
@@ -348,11 +346,6 @@
         check_step(steps)
         self._eval_steps = steps
 
-    @prefetch_batch_number.setter
-    def prefetch_batch_number(self, number):
-        check_step(number, 1)
-        self._prefetch_batch_number = number
-
     @if_load.setter
     def if_load(self, flag):
         if not isinstance(flag, bool):
@@ -390,7 +383,8 @@
     @ascend_global_hashtable_collection.setter
     def ascend_global_hashtable_collection(self, name):
-        string_validator = StringValidator(name, max_len=HASHTABLE_COLLECTION_NAME_LENGTH, min_len=1)
+        string_validator = StringValidator(name="hashtable_collection", value=name,
+                                           max_len=HASHTABLE_COLLECTION_NAME_LENGTH, min_len=1)
         if not string_validator.check_string_length().check_whitelist().is_valid():
             raise ValueError(string_validator.msg)
         self._ascend_global_hashtable_collection = name
@@ -426,6 +420,9 @@
         self._initializer_dict = {}
 
 
+@para_checker_decorator(check_option_list=[
+    ("name", ClassValidator, {"classes": (str, type(None))})
+])
 def set_ascend_global_hashtable_collection(name=ASCEND_GLOBAL_HASHTABLE_COLLECTION):
     ConfigInitializer.get_instance().ascend_global_hashtable_collection = name
 
@@ -443,6 +440,8 @@
 
 
 def init(use_mpi, **kwargs):
+    logger.info("The environment variables set for mxRec are: %s",
+                json.dumps(dataclasses.asdict(global_env), ensure_ascii=False))
     ConfigInitializer.set_instance(use_mpi, **kwargs)
     set_ascend_env()
 
@@ -495,10 +494,6 @@
     return ConfigInitializer.get_instance().sparse_dir
 
 
-def is_mpi_in_use():
-    return ConfigInitializer.get_instance().use_mpi
-
-
 def get_rank_size():
     return ConfigInitializer.get_instance().rank_size
 
@@ -524,9 +519,9 @@
         raise RuntimeError("ASC manager does not exist.")
 
     if ConfigInitializer.get_instance().get_asc_manager().evict():
-        logging.debug("Feature evict is triggered by ops.")
+        logger.debug("Feature evict is triggered by ops.")
         return True
-    logging.warning("Feature evict not success, skip this time!")
+    logger.warning("Feature evict did not succeed, skipping this time!")
     return False
 
 
@@ -534,7 +529,7 @@
 def clear_channel(is_train_channel=False):
     if not isinstance(is_train_channel, bool):
         raise ValueError("Arg is_train_channel should
be a boolean.") channel_id = get_training_mode_channel_id(is_train_channel) - logging.info(f"clear channel: {channel_id}") + logger.info("clear channel: %s", channel_id) return ConfigInitializer.get_instance().host_pipeline_ops.clear_channel(channel_id) @@ -546,7 +541,7 @@ def is_asc_manager_initialized(): def get_host_data(table_name): if not is_asc_manager_initialized(): raise RuntimeError("ASC manager does not exist.") - logging.debug("start to get host data.") + logger.debug("start to get host data.") return ConfigInitializer.get_instance().get_asc_manager().send(table_name) @@ -554,7 +549,7 @@ def send_host_data(key_offset_map): if not is_asc_manager_initialized(): raise RuntimeError("ASC manager does not exist.") ConfigInitializer.get_instance().get_asc_manager().receive(key_offset_map) - logging.debug("Data has been send to the host pipeline.") + logger.debug("Data has been send to the host pipeline.") def save_host_data(root_dir): @@ -562,7 +557,7 @@ def save_host_data(root_dir): raise RuntimeError("ASC manager does not exist.") ConfigInitializer.get_instance().get_asc_manager().save(root_dir) - logging.debug("Data from host pipeline has been saved.") + logger.debug("Data from host pipeline has been saved.") def restore_host_data(root_dir): @@ -573,16 +568,16 @@ def restore_host_data(root_dir): terminate_config_initializer() raise TypeError("Asc load data does not match usr setups, \ please re-consider if you want to restore from this dir") - logging.debug("Data from host pipeline has been restored.") + logger.debug("Data from host pipeline has been restored.") def destroy_asc_manager(): initializer = ConfigInitializer.get_instance() if initializer.get_asc_manager() is not None: - logging.debug("start destroy asc manager...") + logger.debug("start destroy asc manager...") initializer.del_asc_manager() else: - logging.warning("ASC manager does not exist, please check your code.") + logger.warning("ASC manager does not exist, please check your code.") def is_asc_frozen(): @@ -617,14 +612,6 @@ def set_eval_steps(steps: int): ConfigInitializer.get_instance().eval_steps = steps -def get_prefetch_batch_number(): - return ConfigInitializer.get_instance().prefetch_batch_number - - -def set_prefetch_batch_number(number): - ConfigInitializer.get_instance().prefetch_batch_number = number - - def get_table_instance(key): return ConfigInitializer.get_instance().get_table_instance(key) @@ -689,6 +676,9 @@ def export_feature_spec(): return ConfigInitializer.get_instance().feature_spec_dict +@para_checker_decorator(check_option_list=[ + ("if_load", ClassValidator, {"classes": (bool, )}) +]) def set_if_load(if_load): ConfigInitializer.get_instance().if_load = if_load @@ -725,6 +715,9 @@ def get_name_to_var_dict(): return ConfigInitializer.get_instance().name_to_var_dict +@para_checker_decorator(check_option_list=[ + ("is_training", ClassValidator, {"classes": (bool, )}) +]) def get_initializer(is_training): return ConfigInitializer.get_instance().get_initializer(is_training) @@ -787,6 +780,14 @@ def get_iterator_type() -> str: return ConfigInitializer.get_instance().iterator_type +def get_local_rank_size() -> int: + """ + 获取当前worker参与任务的进程数 + Returns: + """ + return ConfigInitializer.get_instance().local_rank_size + + def set_iterator_type(iterator_type: str): """ 记录数据集的迭代器类型. 
@@ -818,7 +819,7 @@ def set_ascend_env():
     os.environ["ASCEND_DEVICE_ID"] = device_id
     os.environ["DEVICE_INDEX"] = device_id
 
-    if os.getenv("RANK_TABLE_FILE"):
+    if global_env.rank_table_file:
         os.environ["RANK_SIZE"] = str(rank_size)
 
     os.environ["HCCL_CONNECT_TIMEOUT"] = "1200"
@@ -829,7 +830,7 @@
     os.environ["EXPERIMENTAL_DYNAMIC_PARTITION"] = "1"
     os.environ["ENABLE_FORCE_V2_CONTROL"] = "1"
 
-    logging.debug(f"Ascend env has been set.")
+    logger.debug("Ascend env has been set.")
 
 
 def get_available_cpu_num_and_range():
@@ -846,13 +847,13 @@
     for cpu in cpu_available:
         f_path = cpu_pkg_id_file.format(cpu)
         if not os.path.exists(f_path):
-            logging.warning(f"failed to get numa node of cpu: {cpu}")
+            logger.warning("failed to get numa node of cpu: %s", cpu)
             is_ok = False
             break
 
         with open(f_path, "r", encoding="utf-8") as f_in:
             # check whether file is valid
-            file_validator = FileValidator(f_path)
+            file_validator = FileValidator("cpu_topology_file", f_path)
             # 1.check whether f_path is soft link
             file_validator.check_not_soft_link()
             # 2.check file size
@@ -877,7 +878,7 @@
     valid_cpu_range_list = []
     if is_ok:
-        logging.info(f"available numa node num: {len(pkg_id2cpu_list)}")
+        logger.info("available numa node num: %s", len(pkg_id2cpu_list))
         for _, part_cpu_list in pkg_id2cpu_list.items():
             parse_range(part_cpu_list, valid_cpu_range_list)
     else:
@@ -885,31 +886,20 @@
     return len(cpu_available), valid_cpu_range_list
 
 
-def bind_cpu(rank_id: int, rank_size: int = None):
+def bind_cpu(rank_id: int, local_rank_size: int):
     """
     Bind CPUs to each process in a balanced way.
-    :param rank_id: rank_id of the current process
-    :param rank_size: number of processes
+    :param rank_id: rank_id of the current process
+    :param local_rank_size: number of processes on the current worker
     :return:
     """
     import math
-    try:
-        local_rank_size = int(os.getenv(LOCAL_RANK_SIZE)) if rank_size is None else rank_size
-    except (ValueError, TypeError):
-        logging.warning(f"no valid LOCAL_RANK_SIZE was set. {DEFAULT_DEVICE_NUM_LOCAL_MACHINE} is set as default value")
-        local_rank_size = DEFAULT_DEVICE_NUM_LOCAL_MACHINE
-
-    if not (1 <= local_rank_size <= MAX_DEVICE_NUM_LOCAL_MACHINE):
-        logging.warning(f"LOCAL_RANK_SIZE should be between 1 and {MAX_DEVICE_NUM_LOCAL_MACHINE}. "
-                        f"{DEFAULT_DEVICE_NUM_LOCAL_MACHINE} is set as default value")
-        local_rank_size = DEFAULT_DEVICE_NUM_LOCAL_MACHINE
-
     total_cpu, cpu_range_list = get_available_cpu_num_and_range()
     avg_count = math.ceil(total_cpu / local_rank_size)
 
     while True:
         if avg_count == 0:
-            logging.warning(f"not enough cpu to bind. cpu num: {total_cpu}, range: {cpu_range_list}")
+            logger.warning("not enough cpu to bind. 
cpu num: %s, range: %s", total_cpu, cpu_range_list)
             return
 
         max_split = 0
@@ -932,14 +922,5 @@
     try:
         process.cpu_affinity(cpu_list)
     except IndexError:
-        logging.error(f"failed to bind cpu for rank {rank_id}: {cpu_list}")
-    logging.info(f"bind cpu for rank {rank_id}: {cpu_list}")
-
-
-def set_stat_flag():
-    if os.getenv("STAT_ON") == "1":
-        return True
-    elif os.getenv("STAT_ON") == "0":
-        return False
-    else:
-        raise ValueError(f"STAT_ON can only be 0 or 1.")
\ No newline at end of file
+        logger.error("failed to bind cpu for rank %s: %s", rank_id, cpu_list)
+    logger.info("bind cpu for rank %s: %s", rank_id, cpu_list)
\ No newline at end of file
diff --git a/mx_rec/util/log.py b/mx_rec/util/log.py
index 0093704b..9fb1c678 100644
--- a/mx_rec/util/log.py
+++ b/mx_rec/util/log.py
@@ -5,17 +5,22 @@
 import os
 import logging
 
+from mx_rec.constants.constants import RecPyLogLevel, EnvOption
 
-def get_log_level():
-    env_log_level = os.getenv("MXREC_LOG_LEVEL")
-    if env_log_level is None:
-        env_log_level = "INFO"
-    log_level = logging.getLevelName(env_log_level)
-    if not isinstance(log_level, int):
-        raise EnvironmentError("A wrong log level string was given.")
 
+def get_logger(log_level: str):
+    options = [i.value for i in list(RecPyLogLevel)]
+    if log_level not in options:
+        raise ValueError(f"log level set for mxRec is not valid, only {options} are allowed, but got {log_level}")
 
-    log_format = "%(asctime)s\t%(levelname)s\t%(message)s"
-    date_format = "%m/%d/%Y %H:%M:%S %p"
+    rec_logger = logging.getLogger("MxRec")
+    formatter = logging.Formatter(fmt="[MxRec][%(asctime)s] [%(levelname)s] %(message)s",
+                                  datefmt="%m/%d/%Y %H:%M:%S %p")
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(formatter)
+    rec_logger.addHandler(stream_handler)
+    rec_logger.setLevel(log_level)
+    return rec_logger
 
-    logging.basicConfig(level=log_level, format=log_format, datefmt=date_format)
+
+logger = get_logger(log_level=os.getenv(EnvOption.MXREC_LOG_LEVEL.value, RecPyLogLevel.INFO.value))
diff --git a/mx_rec/util/normalization.py b/mx_rec/util/normalization.py
index 1a515954..45f9e2fd 100644
--- a/mx_rec/util/normalization.py
+++ b/mx_rec/util/normalization.py
@@ -2,7 +2,8 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 import re
-import logging
+
+from mx_rec.util.log import logger
 
 
 def fix_invalid_table_name(name):
@@ -18,7 +19,6 @@
     if not fix_name:
         raise ValueError(f"The table name '{name}' doesn't contain valid character, "
                          f"according to the rule '{pattern}'")
-    logging.warning(f"The table name '%s' contains invalid characters. "
-                    f"The system automatically remove invalid characters. "
-                    f"The table name was changed to '%s'", name, fix_name)
+    logger.warning("The table name '%s' contains invalid characters. The system automatically "
+                   "removes invalid characters. The table name was changed to '%s'", name, fix_name)
     return fix_name
diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py
index 920c8e70..121e65f3 100644
--- a/mx_rec/util/ops.py
+++ b/mx_rec/util/ops.py
@@ -2,12 +2,11 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
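Given the formatter configured in log.py above, a call such as logger.info("Save & Restore graph was built.") emits a line of roughly this shape (the timestamp is illustrative):

[MxRec][05/15/2023 14:21:30 PM] [INFO] Save & Restore graph was built.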
-import logging
 import os
 
 import tensorflow as tf
 
-from mx_rec.constants.constants import HOST_PIPELINE_OPS_LIB_PATH
+from mx_rec.util.log import logger
 
 
 def import_host_pipeline_ops():
@@ -17,7 +16,7 @@
         default_so_path = os.path.join(
             os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")),
             'mx_rec/libasc/libasc_ops.so')
-        logging.debug(f"Using the DEFAULT PATH '{default_so_path}' to get ops lib.")
+        logger.debug("Using the DEFAULT PATH '%s' to get ops lib.", default_so_path)
         return tf.load_op_library(default_so_path)
     else:
         raise ValueError("Please check if libasc_ops.so exists (mxRec correctly installed)")
diff --git a/mx_rec/util/perf.py b/mx_rec/util/perf.py
index 66501b5f..5070773e 100644
--- a/mx_rec/util/perf.py
+++ b/mx_rec/util/perf.py
@@ -3,7 +3,8 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 
 import time
-import logging
+
+from mx_rec.util.log import logger
 
 
 def performance(method_name):
@@ -12,7 +13,7 @@
             start = time.perf_counter()
             result = func(*args, **kwargs)
             span = time.perf_counter() - start
-            logging.debug(f"{method_name} method consume {span:.6f}s.")
+            logger.debug("%s method consumed %s (s).", method_name, round(span, 6))
             return result
         return wrapper
     return decorator
diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py
index 5ce9a26a..ce68162a 100644
--- a/mx_rec/validator/validator.py
+++ b/mx_rec/validator/validator.py
@@ -3,19 +3,17 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
 
 
-import os
+from typing import List, Tuple, Any, Callable, Dict, Optional, Union, Type
 import re
-from typing import Callable, Any
-from typing import List, Optional, Tuple
+
+import os
+import inspect
+import functools
 
 import tensorflow as tf
 
-from mx_rec.constants.constants import MIN_SIZE
-from mx_rec.constants.constants import MAX_SIZE
-from mx_rec.constants.constants import MAX_DEVICE_NUM
-from mx_rec.constants.constants import MAX_RANK_SIZE
-from mx_rec.constants.constants import MIN_DEVICE_NUM
-from mx_rec.constants.constants import MIN_RANK_SIZE
+from mx_rec.constants.constants import MIN_SIZE, MAX_SIZE, MAX_RANK_SIZE, MIN_RANK_SIZE
+from mx_rec.util.log import logger
 
 
 class Validator:
@@ -23,11 +21,11 @@
     A validator to check the input parameters
     """
 
-    def __init__(self, value, msg="value is invalid"):
+    def __init__(self, name: Union[List[str], str], value: Union[List[Any], Any], msg="value is invalid"):
         """
-        :param value: the value for validation
         :param msg: default error msg
         """
+        self.name = name
         self.value = value
         self.msg = msg
         self.checkers = []
@@ -40,7 +38,7 @@
         if self.is_valid_state is None:
             self.is_valid_state = True
             for checker, msg in self.checkers:
-                if not checker(self.value):
+                if not checker():
                     self.msg = msg
                     raise ValueError(self.msg)
         if self.is_valid_state:
@@ -52,8 +50,87 @@
         self.check()
         return self.is_valid_state
 
-    def get_value(self, default=None):
-        return self.value if self.is_valid() else default
+
+def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str],
+                                                         Type[Validator],
+                                                         Optional[Dict],
+                                                         Optional[List[str]]]]):
+    """
+    Decorator for validating function parameters.
+    :param check_option_list:
+        parameters to validate together with their validators: [parameter (or parameter group) to check,
+        validator class, validator kwargs, validator methods to invoke (registering the specified checks)]
+    :return:
+    """
+
+    def para_checker(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            func_spec = inspect.getfullargspec(func)
+            # merge the function's defaulted parameters into kwargs
+            
args_with_default = set()
+            if func_spec.defaults is not None:
+                arg_with_default_num = len(func_spec.defaults)
+                for arg, default in zip(func_spec.args[-arg_with_default_num:], func_spec.defaults):
+                    if arg in kwargs:
+                        continue
+                    args_with_default.add(arg)
+                    kwargs.update({arg: default})
+            logger.debug("[checker wrapper]func %s args: %s, kwargs: %s", func.__name__, args, kwargs)
+            # run every check item
+            for option in check_option_list:
+                optional_check_list = None
+                validator_kwargs = {}
+
+                # unpack each check item: parameter names to check, validator, validator kwargs, specific check methods
+                option_num = len(option)
+                if option_num == 2:
+                    para_list_to_be_check, validator = option
+                elif option_num == 3:
+                    para_list_to_be_check, validator, validator_kwargs = option
+                else:
+                    para_list_to_be_check, validator, validator_kwargs, optional_check_list = option
+
+                if not isinstance(para_list_to_be_check, list):
+                    para_list_to_be_check = [para_list_to_be_check]
+
+                # confirm whether the parameters required by this check item are present in the function arguments
+                paras = []
+                for para_to_be_check in para_list_to_be_check:
+                    if para_to_be_check not in kwargs:
+                        logger.debug("[checker wrapper]invalid para '%s' to be checked, "
+                                     "not passed to the function '%s'", para_to_be_check, func.__name__)
+                        continue
+                    paras.append(kwargs.get(para_to_be_check))
+
+                # skip this check item if none of its parameters were passed
+                if not paras:
+                    continue
+
+                # update the validator's arguments
+                validator_kwargs.update(
+                    {
+                        "name": para_list_to_be_check[0] if len(para_list_to_be_check) == 1 else para_list_to_be_check,
+                        "value": paras[0] if len(paras) == 1 else paras
+                    }
+                )
+
+                validator_instance = validator(**validator_kwargs)
+
+                # register the validator's specific check methods
+                if optional_check_list and len(optional_check_list) != 0:
+                    for optional_check in optional_check_list:
+                        getattr(validator_instance, optional_check)()
+
+                # run the checks
+                validator_instance.check()
+
+            for arg in args_with_default:
+                del kwargs[arg]
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return para_checker
 
 
 class ClassValidator(Validator):
@@ -61,14 +138,75 @@
     Check class validator.
     """
 
-    def __init__(self, value, classes):
-        super().__init__(value)
+    def __init__(self, name, value, classes):
+        super(ClassValidator, self).__init__(name, value)
         self.classes = classes
+        self.register()
 
-    def check_isinstance(self):
+    def register(self):
         """Check arg isinstance of classes"""
-        self.register_checker(lambda path: isinstance(self.value, self.classes), f"Invalid parameter type, not "
-                              f"in {self.classes}")
+        self.register_checker(lambda: isinstance(self.value, self.classes),
+                              f"Invalid parameter type of para '{self.name}', "
+                              f"not in {self.classes}, but: '{type(self.value)}'")
+        return self
+
+
+class OptionValidator(Validator):
+    """
+    Check option validator.
+    """
+
+    def __init__(self, name, value, options):
+        super(OptionValidator, self).__init__(name, value)
+        self.options = options
+        self.register()
+
+    def register(self):
+        """Check arg is one of the allowed options"""
+        self.register_checker(lambda: self.value in self.options,
+                              f"Invalid option of '{self.name}', "
+                              f"should be one of '{self.options}', but: '{self.value}'")
+        return self
+
+
+class ValueCompareValidator(Validator):
+    """
+    Value compare validator: checks whether values equal a target value.
+    """
+    def __init__(self, name: Union[List[str], str], value: Union[List[Any], Any], target: Any):
+        super(ValueCompareValidator, self).__init__(name, value)
+        self.name = name if isinstance(name, list) else [name]
+        self.value = value if isinstance(value, list) else [value]
+        self.target = target
+
+    def check_at_least_one_not_equal_to_target(self):
+        """
+        At least one value must not equal the target value.
+        Returns:
+
+        """
+        self.register_checker(lambda: not all([v == self.target for v in self.value]),
+                              f"at least one of '{','.join(self.name)}' should not be equal to {self.target}")
+        return self
+
+    def check_at_least_one_equal_to_target(self):
+        """
+        At least one value must equal the target value.
+        Returns:
+
+        """
+        self.register_checker(lambda: any([v == self.target for v in self.value]),
+                              f"at least one of '{','.join(self.name)}' should be equal to {self.target}")
+        return self
+
+    def check_all_not_equal_to_target(self):
+        """
+        No value may equal the target value.
+        Returns:
+
+        """
+        self.register_checker(lambda: all([v != self.target for v in self.value]),
+                              f"all of '{','.join(self.name)}' should not be equal to {self.target}")
        return self
 
 
 class StringValidator(Validator):
@@ -77,28 +215,35 @@
     String type validator.
     """
 
-    def __init__(self, value, max_len=None, min_len=0):
-        super().__init__(value)
+    def __init__(self, name, value, max_len: Optional[int] = None, min_len: Optional[int] = 0,
+                 element: Optional[str] = None, msg=""):
+        super(StringValidator, self).__init__(name, value)
         self.max_len = max_len
         self.min_len = min_len
         self.whitelist = "^[0-9A-Za-z_]+$"
-        self.register_checker(lambda x: isinstance(x, str), "type is not str")
+        self.element = element
+        msg = msg if msg else f"type of '{name}' is not str, '{value}' is '{type(value)}'"
+        self.register_checker(lambda: isinstance(value, str), msg)
 
     def check_string_length(self):
         if self.min_len is not None:
-            self.register_checker(lambda x: len(x) >= self.min_len, f"length is less than {self.min_len}")
+            self.register_checker(lambda: len(self.value) >= self.min_len,
+                                  f"'{self.name}' length is less than {self.min_len}")
         if self.max_len is not None:
-            self.register_checker(lambda x: len(x) <= self.max_len, f"length is bigger than {self.max_len}")
+            self.register_checker(lambda: len(self.value) <= self.max_len,
+                                  f"'{self.name}' length is bigger than {self.max_len}")
         return self
 
-    def check_not_contain_black_element(self, element):
-        self.register_checker(lambda x: x is not None and element is not None and x.find(element) == -1)
+    def check_not_contain_black_element(self):
+        if self.value is not None and self.element is not None and self.element != "":
+            self.register_checker(lambda: self.value.find(self.element) == -1,
+                                  f"'{self.name}' contains blacklisted element '{self.element}'")
         return self
 
     def check_whitelist(self):
         """Perform whitelist verification on the input string"""
-        self.register_checker(lambda x: x is not None and re.match(self.whitelist, x) is not None,
-                              "The string is invalid, please check the input string. "
+        self.register_checker(lambda: self.value is not None and re.match(self.whitelist, self.value) is not None,
+                              f"The string '{self.name}' is invalid, please check the input string. 
" "Note: It should be a string consisting of numbers, letters, and underscores.") return self @@ -122,50 +267,145 @@ class StringValidator(Validator): return self -class IntValidator(Validator): +class OptionalStringValidator(StringValidator): """ - Int type validator + String type validator if value is not None """ - def __init__(self, value: int, min_value: int = None, max_value: int = None): - super().__init__(value) + def __init__(self, name, value, max_len=None, min_len=0, element: Optional[str] = None, msg=""): + if value is None: + super(OptionalStringValidator, self).__init__(name, "", None, None, None, msg) + else: + super(OptionalStringValidator, self).__init__(name, value, max_len, min_len, element, msg) + + +class SSDFeatureValidator(Validator): + """ + Check SSD related parameters + """ + + def __init__(self, name, value): + super(SSDFeatureValidator, self).__init__(name, value) + self.register() + + def register(self): + """Check ssd related parameters""" + s_size, ssd_data_path, h_size = self.value + self.register_checker(lambda: isinstance(s_size, int), + f"'{self.name[0]}', not int, but '{type(s_size)}'") + self.register_checker(lambda: isinstance(h_size, int), + f"'{self.name[2]}', not int, but '{type(h_size)}'") + + if s_size != 0: + self.register_checker(lambda: not (h_size == 0 and s_size > 0), + f"'{self.name[2]}' should be greater than 0 when enabling ssd feature") + + self.register_checker(lambda: not (h_size != 0 and s_size < 0), + f"'{self.name[0]}' should be greater than 0 when enabling ssd feature") + + self.register_checker(lambda: isinstance(ssd_data_path, (list, tuple)) and len(ssd_data_path) != 0, + f"'{self.name[1]}' should be type of list and not empty") + + self.register_checker(lambda: len([p for p in ssd_data_path if self._is_invalid_path(p)]) == 0, + f"'{self.name[1]}' contains invalid path") + + return self + + def _is_invalid_path(self, path: str): + return not os.path.exists(path) or not os.path.isdir(path) or os.path.islink(path) or ".." 
in path + + +class NumValidator(Validator): + """ + number validator float or int + """ + + def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, + invalid_options: List = None, constrained_options: List = None, msg: str = ""): + if isinstance(value, tf.TensorShape) and value.ndims == 1: + value = value.as_list()[0] + super(NumValidator, self).__init__(name, value) + self.min_value = min_value self.max_value = max_value - self.register_checker(lambda x: isinstance(x, int), "type is not int") + self.invalid_options = invalid_options + self.constrained_options = constrained_options + self.register_checker(lambda: isinstance(self.value, (int, float)), + msg if msg else f"type of '{name}' is not int or float") def check_value(self): if self.min_value is not None: - self.register_checker(lambda x: x >= self.min_value, f"value is less than {self.min_value}") + self.register_checker(lambda: self.value >= self.min_value, f"'{self.name}' is less than {self.min_value}") if self.max_value is not None: - self.register_checker(lambda x: x <= self.max_value, f"value is bigger than {self.max_value}") + self.register_checker(lambda: self.value <= self.max_value, + f"'{self.name}' is bigger than {self.max_value}") + if self.invalid_options is not None: + self.register_checker(lambda: self.value not in self.invalid_options, + f"'{self.name}' is invalid, num in '{self.invalid_options}' is forbidden") + + if self.constrained_options is not None: + self.register_checker(lambda: self.value in self.constrained_options, + f"'{self.name}' is invalid, only num in '{self.constrained_options}' is allowed") + return self -class RankSizeValidator(IntValidator): +class IntValidator(NumValidator): """ - Distributed training job size validator + Int type validator """ - def check_rank_size_valid(self): - super().__init__(self.value) - self.register_checker(lambda x: MIN_RANK_SIZE <= self.value <= MAX_RANK_SIZE, - "Invalid rank size") - return self + def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, + invalid_options: List = None, constrained_options: List = None, msg: str = ""): + super(IntValidator, self).__init__(name, value, min_value, max_value, invalid_options, constrained_options, msg) + self.register_checker(lambda: isinstance(self.value, int), msg if msg else f"type of '{name}' is not int") - def check_device_num_valid(self): - super().__init__(self.value) - self.register_checker(lambda x: MIN_DEVICE_NUM <= self.value <= MAX_DEVICE_NUM, - "Invalid device num") - return self + +class OptionalIntValidator(IntValidator): + """ + Int type validator if value is not None + """ + + def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, + invalid_options: List = None, constrained_options: List = None, msg: str = ""): + if value is None: + super(OptionalIntValidator, self).__init__(name, 0, None, None, None, None, msg) + else: + super(OptionalIntValidator, self).__init__(name, value, min_value, max_value, + invalid_options, constrained_options, msg) + + +class Convert2intValidator(IntValidator): + """ + check whether a variable can be converted to int or not. 
+    """
+    def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None,
+                 invalid_options: List = None, constrained_options: List = None, msg: str = ""):
+        convertable = True
+        int_value = None
+        try:
+            int_value = int(value)
+        except (TypeError, ValueError):
+            convertable = False
+        if convertable:
+            super(Convert2intValidator, self).__init__(name, int_value, min_value, max_value, invalid_options,
+                                                       constrained_options, msg)
+        else:
+            super(Convert2intValidator, self).__init__(name,
+                                                       value,
+                                                       min_value,
+                                                       max_value,
+                                                       invalid_options,
+                                                       constrained_options, f"'{name}' cannot be converted to int")
 
 
 class DirectoryValidator(StringValidator):
 
-    def __init__(self, value, max_len=None, min_len=1):
+    def __init__(self, name, value, max_len=None, min_len=1):
         """
         @param value: the path, should not be empty string, should not contain double dot(../)
         """
-        super().__init__(value, max_len, min_len)
-        self.register_checker(lambda x: isinstance(x, str), "type is not str")
+        super(DirectoryValidator, self).__init__(name, value, max_len, min_len)
+        self.register_checker(lambda: isinstance(value, str), "type is not str")
 
     @staticmethod
     def remove_prefix(string: Optional[str], prefix: Optional[str]) -> Tuple[bool, Optional[str]]:
@@ -210,25 +450,29 @@
         return True
 
     def check_is_not_none(self):
-        self.register_checker(lambda path: self.value is not None and len(self.value) > 0,
+        self.register_checker(lambda: self.value is not None and len(self.value) > 0,
                               "Invalid directory parameter")
         return self
 
     def check_not_soft_link(self):
-        self.register_checker(lambda path: os.path.realpath(self.value) == os.path.normpath(self.value),
+        self.register_checker(lambda: os.path.realpath(self.value) == os.path.normpath(self.value),
                               "soft link or relative path should not be in the path parameter")
         return self
 
     def path_should_exist(self, is_file=True, msg=None):
-        self.register_checker(lambda path: os.path.exists(self.value),
+        self.register_checker(lambda: os.path.exists(self.value),
                               msg if msg else "path parameter does not exist")
         if is_file:
-            self.register_checker(lambda path: os.path.isfile(self.value),
+            self.register_checker(lambda: os.path.isfile(self.value),
                                   msg if msg else "path parameter is not a file")
         return self
 
+    def check_exists_if_not_empty(self):
+        if self.value:
+            self.register_checker(lambda: os.path.exists(os.path.realpath(self.value)), f"'{self.value}' does not exist")
+
     def path_should_not_exist(self):
-        self.register_checker(lambda path: not os.path.exists(self.value), "path parameter does not exist")
+        self.register_checker(lambda: not os.path.exists(self.value), "path parameter already exists")
         return self
 
     def with_blacklist(self, lst: List = None, exact_compare: bool = True, msg: str = None):
@@ -239,17 +483,17 @@
         if msg is None:
             msg = "path should not in blacklist"
         if exact_compare:
-            self.register_checker(lambda path: path not in [os.path.realpath(each) for each in lst], msg)
+            self.register_checker(lambda: self.value not in [os.path.realpath(each) for each in lst], msg)
         else:
             self.register_checker(
-                lambda path: not any([DirectoryValidator.check_is_children_path(each, path) for each in lst]), msg
+                lambda: not any([DirectoryValidator.check_is_children_path(each, self.value) for each in lst]), msg
             )
         return self
 
     def should_not_contains_sensitive_words(self, words: List = None, msg=None):
         if words is None:
             words = ["Key", "password", "privatekey"]
-        self.register_checker(lambda path: 
DirectoryValidator.__check_with_sensitive_words(path, words), msg) + self.register_checker(lambda: DirectoryValidator.__check_with_sensitive_words(self.value, words), msg) return self @@ -258,21 +502,21 @@ class FileValidator(StringValidator): Check if file is valid. """ - def __init__(self, value): + def __init__(self, name, value): """ @param value: the file path, should not be emtpy string, should not contain double dot(../) """ - super().__init__(value) - self.register_checker(lambda x: isinstance(x, str), "parameter value's type is not str") + super(FileValidator, self).__init__(name, value) + self.register_checker(lambda: isinstance(self.value, str), "parameter value's type is not str") def check_file_size(self, max_size=MAX_SIZE, min_size=MIN_SIZE): file_stat = tf.io.gfile.stat(self.value) - self.register_checker(lambda path: min_size < file_stat.length <= max_size, + self.register_checker(lambda: min_size < file_stat.length <= max_size, f"file size: {file_stat.length} is invalid, not in ({min_size}, {max_size}]") return self def check_not_soft_link(self): - self.register_checker(lambda path: not os.path.islink(self.value), + self.register_checker(lambda: not os.path.islink(self.value), f"soft link or relative path: {self.value} should not be in the path parameter") return self @@ -282,41 +526,6 @@ class FileValidator(StringValidator): stat_info = os.stat(self.value) file_uid = stat_info.st_uid file_gid = stat_info.st_gid - self.register_checker( - lambda path: process_uid == file_uid or process_gid == file_gid, "Invalid log file user or group.") + self.register_checker(lambda: process_uid == file_uid or process_gid == file_gid, + "Invalid log file user or group.") return self - - -class RankInfoValidator: - """ - Check replace rank table system environment configuration. - """ - - @staticmethod - def check_visible_devices(): - visible_devices = os.getenv("ASCEND_VISIBLE_DEVICES") - device_res = StringValidator(visible_devices).check() - if not device_res: - raise TypeError("env variable ascend_visible_devices is null, please config ASCEND_VISIBLE_DEVICES in " - "docker container start.") - - rank_size = os.getenv("CM_WORKER_SIZE") - rank_size_res = StringValidator(rank_size).check() - if not rank_size_res: - raise TypeError("env variable CM_WORKER_SIZE is null, please config CM_WORKER_SIZE. For example, " - "CM_WORKER_SIZE=1") - - try: - rank_size_value = int(rank_size) - except ValueError as err: - raise ValueError("Invalid rank size, rank size is a valid integer.") from err - - res = RankSizeValidator(rank_size_value, 1, 16).check_rank_size_valid() - if not res and rank_size_value not in [1, 2, 4, 8, 16]: - raise ValueError("Invalid rank size, rank size must between 0 and 15 in recommendation training.") - - chief_device = os.getenv("CM_CHIEF_DEVICE") - chief_device_res = StringValidator(chief_device).check() - if not chief_device_res: - raise TypeError("env variable CM_CHIEF_DEVICE is null, please config CM_CHIEF_DEVICE. 
For example, " - "CM_CHIEF_DEVICE=0") diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 48687bca..f9179635 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -95,7 +95,7 @@ void FeatAdmitNEvictCkpt::SetTable2ThreshTrans(string embName) // save void FeatAdmitNEvictCkpt::SetHistRecTrans(string embName) { - if (GetCombineSwitch()) { + if (GlobalEnv::useCombineFaae) { embName = COMBINE_HISTORY_NAME; } auto histRecSize = GetHistRecSize(embName); @@ -126,7 +126,7 @@ void FeatAdmitNEvictCkpt::SetTable2Thresh(string embName) // load void FeatAdmitNEvictCkpt::SetHistRec(string embName) { - if (GetCombineSwitch()) { + if (GlobalEnv::useCombineFaae) { embName = COMBINE_HISTORY_NAME; } const auto& transArr = transferData.int64Arr; diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index 268120c3..a3fe3be5 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -9,6 +9,7 @@ #define MXREC_FEAT_ADMIT_N_EVICT_CKPT_H #include "ckpt_data_handler/ckpt_data_handler.h" +#include "utils/config.h" namespace MxRec { using namespace std; diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 2109f1e8..e965cc56 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -50,11 +50,9 @@ inline void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) /// DDR模型下处理特征的offset、swap信息等 /// \param embName 表名 /// \param keys 查询向量 -/// \param iBatch 预取数据处理计数 -/// \param tmpDataOut 临时向量 +/// \param DDRParam 临时向量 /// \param channelId 通道索引(训练/推理) -void EmbHashMap::Process(const string& embName, vector& keys, size_t iBatch, - DDRParam& ddrParam, int channelId) +void EmbHashMap::Process(const string& embName, vector& keys, DDRParam& ddrParam, int channelId) { #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) @@ -64,13 +62,12 @@ void EmbHashMap::Process(const string& embName, vector& keys, size_t embHashMap.oldSwap.clear(); embHashMap.maxOffsetOld = embHashMap.maxOffset; - auto keepBatch = swapId - iBatch; // 处理batch的次数,多个预取一起处理算一次 - bool findOffsetV2 = GetEnv("FIND_OFFSET_V2"); + auto keepBatch = swapId; // 处理batch的次数,多个预取一起处理算一次 - LOG_DEBUG("FindOffset version:{}", findOffsetV2); + LOG_DEBUG("FindOffset version:{}", GlobalEnv::findOffsetV2); // 找到所有key的偏移;dev和host需要交换的位置 - if (findOffsetV2) { + if (GlobalEnv::findOffsetV2) { FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); } else { FindOffset(embName, keys, swapId, keepBatch, channelId); @@ -195,7 +192,6 @@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat EmbHashMapInfo& embHashMap) const { EASY_FUNCTION() - bool findOffsetV3 = GetEnv("FIND_OFFSET_V3"); for (size_t i = 0; i < keySize; i++) { int offset; auto& key = keys[i]; @@ -204,7 +200,7 @@ void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBat } const auto& iter = embHashMap.hostHashMap.find(key); if (iter != embHashMap.hostHashMap.end()) { // found - if (findOffsetV3) { + if (GlobalEnv::findOffsetV3) { key = -1; } offset = static_cast(iter->second); diff --git a/src/core/emb_hashmap/emb_hashmap.h 
b/src/core/emb_hashmap/emb_hashmap.h index 1b6c8b25..2eecc7f8 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -26,8 +26,7 @@ namespace MxRec { void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); - void Process(const string& embName, std::vector& keys, size_t iBatch, - DDRParam& ddrParam, int channelId); + void Process(const string& embName, std::vector& keys, DDRParam& ddrParam, int channelId); void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId); diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 16d6d6d9..5f32b457 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -43,26 +43,8 @@ int HDTransfer::Init(const vector& embInfos, uint32_t localRankId) // 创建acltdtDataset类型的数据,对等一个Vector。同步接口。 aclDatasets[embInfo.name] = acltdtCreateDataset(); } - const int defaultAclTimeout = -1; - this->timeout = defaultAclTimeout; - const char *envTimeout = getenv("AclTimeout"); - if (envTimeout != nullptr) { - try { - int32_t tmp = std::stoi(envTimeout); - if (tmp >= -1 && tmp <= INT32_MAX) { - this->timeout = tmp; - LOG_INFO("Succeed to parse ${env:AclTimeout}: {}", tmp); - } else { - LOG_ERROR("Failed to parse ${env:AclTimeout}: {}, expected in (0, INT32_MAX)", tmp); - } - } catch (const std::invalid_argument &e) { - LOG_ERROR("Failed to parse ${env:AclTimeout}: {}, expected a integer, set to default: {}", - envTimeout, defaultAclTimeout); - } - } - LOG_DEBUG("hd transfer timeout:{}", timeout); running = true; - LOG_INFO("hd_transfer init"); + LOG(INFO) << "hd_transfer init"; #endif return true; } @@ -94,24 +76,7 @@ void HDTransfer::Destroy() void HDTransfer::CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum) { #ifndef GTEST - int channelSize; - const char* env = getenv("HD_CHANNEL_SIZE"); - if (env == nullptr) { - channelSize = LARGE_CHANNEL_SIZE; - } else { - try { - channelSize = stoi(env); - } catch (const std::invalid_argument& e) { - LOG_WARN("wrong HD_CHANNEL_SIZE env {}", e.what()); - channelSize = LARGE_CHANNEL_SIZE; - } catch (const std::out_of_range& e) { - LOG_WARN("wrong HD_CHANNEL_SIZE env {}", e.what()); - channelSize = LARGE_CHANNEL_SIZE; - } - if (channelSize <= 0) { - channelSize = LARGE_CHANNEL_SIZE; - } - } + int channelSize = GlobalEnv::hdChannelSize; LOG_INFO("user config all2all restore lookup channel size:{}", channelSize); for (int c = static_cast(TransferChannel::D2H); c != static_cast(TransferChannel::INVALID); c++) { auto channel = static_cast(c); @@ -155,7 +120,7 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in string sendName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); LOG_INFO(HD + "hd transfer send {}, send count is {}, size list:{}", - sendName, sizes.size(), VectorToString(sizes)); + sendName, sizes.size(), VectorToString(sizes)); if (sizes.size() == 0) { LOG_WARN("tensors num can not be zero"); @@ -229,7 +194,7 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& if (aclDatasets[embName] == nullptr) { throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str()); } - auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], timeout /*-1 no timeout */); + auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], 
GlobalEnv::aclTimeout); if (!running) { return 0; } diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index ac09e6e6..b7f9ea38 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -14,6 +14,7 @@ #include "acl/acl_tdt_queue.h" #include "acl_channel.h" #include "utils/common.h" +#include "utils/config.h" #ifndef tdtCreateChannel #define tdtCreateChannel acltdtCreateChannelWithCapacity @@ -25,7 +26,6 @@ namespace MxRec { const std::string HD = "\033[32m[HD]\033[0m "; const std::string HOSTEMB = "\033[32m[HostEmb]\033[0m "; const int PING_PONG_SIZE = 6; - const int LARGE_CHANNEL_SIZE = 40; enum class TransferChannel { D2H, @@ -88,7 +88,6 @@ namespace MxRec { std::unordered_map transferChannels; #endif bool running; - int32_t timeout{-1}; void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum); }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index d77f053c..28298ee6 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -24,39 +24,6 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& const vector& thresholdValues, int seed) { #ifndef GTEST - // 是否设置全局去重(相同key的梯度先累加),默认为false - if (getenv("APPLY_GRADIENTS_STRATEGY") != nullptr) { - bool strategy = (!strcmp(getenv("APPLY_GRADIENTS_STRATEGY"), SUM_SAME_ID)); - PerfConfig::gradientStrategy = strategy; - LOG_INFO("config GRADIENTS_STRATEGY:{}", strategy); - } - - // 设置当前进程用于数据处理的线程数,默认为6,取值1-10;取值不在范围内,则数据处理线程启动失败退出 - if (getenv("KEY_PROCESS_THREAD_NUM") != nullptr) { - int num = std::atoi(getenv("KEY_PROCESS_THREAD_NUM")); - if (num < 1 || num > MAX_KEY_PROCESS_THREAD) { - LOG_ERROR("[HybridMgmt::InitKeyProcess] KEY_PROCESS_THREAD_NUM:{}, should in range [1, {}]", - num, MAX_KEY_PROCESS_THREAD); - return false; - } - PerfConfig::keyProcessThreadNum = num; - LOG_INFO("config KEY_PROCESS_THREAD_NUM:{}", num); - } - - // 设置AccCTR去重线程数,默认为8,取值1-8;取值不在范围内,则数据处理线程启动失败退出 - if (getenv("MAX_UNIQUE_THREAD_NUM") != nullptr) { - int num = std::atoi(getenv("MAX_UNIQUE_THREAD_NUM")); - if (num < 1 || num > DEFAULT_MAX_UNIQUE_THREAD_NUM) { - LOG_ERROR("[HybridMgmt::InitKeyProcess] MAX_UNIQUE_THREAD_NUM:{}, should in range [1, {}]", - num, DEFAULT_MAX_UNIQUE_THREAD_NUM); - return false; - } - PerfConfig::maxUniqueThreadNum = num; - LOG_INFO("config MAX_UNIQUE_THREAD_NUM:{}", num); - } - - // 设置是否使用AccCTR库提供的去重、分桶功能,默认关闭 - PerfConfig::fastUnique = GetEnv("FAST_UNIQUE"); // 初始化数据处理类,配置相关信息,启动处理线程 preprocess = Singleton::GetInstance(); preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); @@ -103,18 +70,25 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, const vector& thresholdValues, bool ifLoad) { #ifndef GTEST + // 环境变量初始化 + ConfigGlobalEnv(); + + // 设置日志的级别,对日志格式进行配置 + SetLog(rankInfo.rankId); + + // 打印环境变量 + LogGlobalEnv(); + // 判断是否已经拉起特征处理线程(key process) if (isRunning) { return true; } - // 设置日志的级别,对日志格式进行配置 - SetLog(rankInfo.rankId); InitRankInfo(rankInfo, embInfos); - g_statOn = GetEnv("STAT_ON"); + g_statOn = GlobalEnv::statOn; LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}", - rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); + rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId); mgmtRankInfo = rankInfo; mgmtEmbInfo = embInfos; @@ -157,10 +131,10 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, for (const auto& info: 
embInfos) { LOG_INFO(MGMT + "emb[{}] vocab size {}+{} sc:{}", - info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); + info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); } LOG_INFO(MGMT + "end initialize, noDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.noDDR, - rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); + rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); #endif return true; } @@ -187,11 +161,11 @@ void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) auto cuKey = item.first; if (lfu.find(cuKey) == lfu.end()) { LOG_ERROR("save step error, ddr key:{}, not exist in lfu, hostHashMap offset:", - cuKey, item.second); + cuKey, item.second); } } LOG_INFO("save step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", - embTableName, tableKeyInDdr, lfu.size()); + embTableName, tableKeyInDdr, lfu.size()); } } @@ -213,9 +187,9 @@ void HybridMgmt::RestoreFreq4Save(CkptData& saveData) vector hbm2DdrKeys; vector ddr2HbmKeys; LOG_INFO("restore freq info for save step, table:{}, embHashMap.oldSwap size:{}", - embTableName, embHashMap.oldSwap.size()); + embTableName, embHashMap.oldSwap.size()); LOG_INFO("before, ddr key table size:{}, exclude ddr key table size:{}", - ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); + ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); for (const auto& swapKeys : embHashMap.oldSwap) { hbm2DdrKeys.emplace_back(swapKeys.second); ddr2HbmKeys.emplace_back(swapKeys.first); @@ -237,9 +211,9 @@ void HybridMgmt::RestoreFreq4Save(CkptData& saveData) ddrKeyFreqMaps[embTableName].erase(key); } LOG_INFO("hbm2DdrKeysNotInExcludeMapCount:{}, ddr2HbmKeysNotInDDRMapCount:{}", - hbm2DdrKeysNotInExcludeMapCount, ddr2HbmKeysNotInDDRMapCount); + hbm2DdrKeysNotInExcludeMapCount, ddr2HbmKeysNotInDDRMapCount); LOG_INFO("after, ddr key table size:{}, exclude ddr key table size:{}", - ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); + ddrKeyFreqMaps[embTableName].size(), excludeDDRKeyFreqMaps[embTableName].size()); } } @@ -458,22 +432,22 @@ bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEm const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; if (setupHostEmbs->sendCount != loadEmbInfo.sendCount) { LOG_ERROR(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", - setupHostEmbs->sendCount, setupHostEmbs->name, loadEmbInfo.sendCount); + setupHostEmbs->sendCount, setupHostEmbs->name, loadEmbInfo.sendCount); loadDataMatches = false; } if (setupHostEmbs->extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { LOG_ERROR(MGMT + "Load data extEmbeddingSize {} for table {} does not match setup extEmbeddingSize {}", - setupHostEmbs->extEmbeddingSize, setupHostEmbs->name, loadEmbInfo.extEmbeddingSize); + setupHostEmbs->extEmbeddingSize, setupHostEmbs->name, loadEmbInfo.extEmbeddingSize); loadDataMatches = false; } if (setupHostEmbs->devVocabSize != loadEmbInfo.devVocabSize) { LOG_ERROR(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}", - setupHostEmbs->devVocabSize, setupHostEmbs->name, loadEmbInfo.devVocabSize); + setupHostEmbs->devVocabSize, setupHostEmbs->name, loadEmbInfo.devVocabSize); loadDataMatches = false; } if (setupHostEmbs->hostVocabSize != loadEmbInfo.hostVocabSize) { LOG_ERROR(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}", - 
setupHostEmbs->hostVocabSize, setupHostEmbs->name, loadEmbInfo.hostVocabSize); + setupHostEmbs->hostVocabSize, setupHostEmbs->name, loadEmbInfo.hostVocabSize); loadDataMatches = false; } if (!loadDataMatches) { @@ -501,7 +475,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) if (embTableCount < loadHostEmbs->size()) { LOG_ERROR(MGMT + "Load data has {} tables more than setup table num {}", - loadHostEmbs->size(), embTableCount); + loadHostEmbs->size(), embTableCount); return false; } return true; @@ -656,7 +630,8 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) LOG_DEBUG("sendLookupSyncTC(ms):{}", sendLookupSyncTC.ElapsedMS()); // during training, gradients of identical keys are aggregated globally: send the globally deduplicated keys and the matching restore vectors - if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID) { + if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && + channelId == TRAIN_CHANNEL_ID) { TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); @@ -671,9 +646,8 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) // send the restore vectors TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, ParseKeysTC HBM mode (ms):{}", - sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), ParseKeysTC.ElapsedMS()); + sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), ParseKeysTC.ElapsedMS()); } batchId++; return true; @@ -696,35 +670,27 @@ bool HybridMgmt::EndBatch(int batchId, int channelId) const bool HybridMgmt::ParseKeys(int channelId, int& batchId) { #ifndef GTEST - LOG_INFO(MGMT + "DDR mode, start parse keys, nBatch:{} , [{}]:{}", - mgmtRankInfo.nBatch, channelId, batchId); + LOG_INFO(MGMT + "DDR mode, start parse keys, [{}]:{}", channelId, batchId); TimeCost parseKeyTC; int start = batchId; - int iBatch = 0; // prefetch batch counter - bool ifHashmapFree = true; bool remainBatch = true; // whether data was fetched from the channel - while (true) { - LOG_INFO(MGMT + "parse keys, [{}]:{}", channelId, batchId); - for (const auto& embInfo : mgmtEmbInfo) { - ifHashmapFree = ProcessEmbInfo(embInfo.name, batchId, channelId, iBatch, remainBatch); - - // the channel has run dry - if (!remainBatch) { - LOG_DEBUG("last batch ending"); - return false; - } - } - batchId++; - iBatch++; - if (EndBatch(batchId, channelId) || iBatch == mgmtRankInfo.nBatch || !ifHashmapFree || !isRunning) { - break; + + LOG_INFO(MGMT + "parse keys, [{}]:{}", channelId, batchId); + for (const auto& embInfo : mgmtEmbInfo) { + ProcessEmbInfo(embInfo.name, batchId, channelId, remainBatch); + // the channel has run dry + if (!remainBatch) { + LOG_DEBUG("last batch ending"); + return false; } } + batchId++; + if (!isRunning) { return false; } TimeCost embHdTrans2TC; - EmbHDTransWrap(channelId, batchId - 1, start, iBatch); + EmbHDTransWrap(channelId, batchId - 1, start); LOG_DEBUG("embHdTrans2TC TimeCost(ms):{}", embHdTrans2TC.ElapsedMS()); LOG_DEBUG("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); #endif @@ -751,18 +717,16 @@ inline void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) /// \param embName table name /// \param batchId number of batches already processed /// \param channelId channel index (train/eval) -/// \param iBatch prefetch batch counter /// \param remainBatchOut whether data was fetched from the channel /// \return whether HBM still has free space -bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, - int channelId, int iBatch, bool& remainBatchOut)
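A quick aside on the strategy switch used above: with APPLY_GRADIENTS_STRATEGY set to sum_same_id_gradients_and_apply, gradients belonging to the same key are summed once before being applied, instead of being applied per occurrence. A toy Python illustration of that reduction (the real pipeline works on deduplicated keys plus restore vectors, not dicts; the shapes and values here are made up):

    import numpy as np

    def sum_same_id_gradients(keys, grads):
        """Toy reduction: one summed gradient row per unique key."""
        unique_keys = sorted(set(keys))
        row = {k: i for i, k in enumerate(unique_keys)}
        summed = np.zeros((len(unique_keys), grads.shape[1]), dtype=grads.dtype)
        for k, g in zip(keys, grads):
            summed[row[k]] += g  # duplicate ids accumulate into a single row
        return unique_keys, summed

    uniq, summed = sum_same_id_gradients([7, 3, 7], np.ones((3, 8), np.float32))
    # key 7 now carries a row of 2s, key 3 a row of 1s

+bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int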
batchId, int channelId, bool& remainBatchOut) { TimeCost getAndSendTensorsTC; TimeCost getTensorsTC; auto& embHashMap = hostHashMaps->embHashMaps.at(embName); - // 进行新一批预取数据时,计数初始化 - if (iBatch == 0) { embHashMap.SetStartCount(); } + // 计数初始化 + embHashMap.SetStartCount(); // 获取查询向量 auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); @@ -788,22 +752,22 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, vector offsetsOut; DDRParam ddrParam(tmpData, offsetsOut); TimeCost hostHashMapProcessTC; - hostHashMaps->Process(embName, lookupKeys, iBatch, ddrParam, channelId); + hostHashMaps->Process(embName, lookupKeys, ddrParam, channelId); LOG_DEBUG("hostHashMapProcessTC(ms):{}", hostHashMapProcessTC.ElapsedMS()); - if (PerfConfig::gradientStrategy && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - vector uniqueKeys; - vector restoreVecSec; + if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && + channelId == TRAIN_CHANNEL_ID && remainBatchOut) { + vector uniqueKeys, restoreVecSec; preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : - Vec2TensorI32(uniqueKeys) }, channelId, embName); + Vec2TensorI32(uniqueKeys) }, channelId, embName); TimeCost sendRestoreVecSecSyncTC; hdTransfer->Send(TransferChannel::RESTORE_SECOND, { Vec2TensorI32(restoreVecSec) }, channelId, embName); LOG_DEBUG("sendUnikeysSyncTC(ms):{}sendRestoreVecSecSyncTC(ms):{}", - sendUnikeysSyncTC.ElapsedMS(), sendRestoreVecSecSyncTC.ElapsedMS()); + sendUnikeysSyncTC.ElapsedMS(), sendRestoreVecSecSyncTC.ElapsedMS()); } TimeCost sendTensorsTC; @@ -815,11 +779,10 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } LOG_DEBUG("sendTensorsTC(ms):{} getAndSendTensorsTC(ms):{}, channelId:{}", - sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS(), channelId); + sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS(), channelId); if (!isSSDEnabled && embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - LOG_WARN(MGMT + "embName {}[{}]{}, iBatch:{} freeSize not enough, {}", - embName, channelId, batchId, iBatch, lookupKeys.size()); + LOG_WARN(MGMT + "embName {}[{}]{}, freeSize not enough, {}", embName, channelId, batchId, lookupKeys.size()); return false; } return true; @@ -829,24 +792,14 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, /// \param channelId 通道索引(训练/推理) /// \param batchId 已处理的batch数 /// \param start -/// \param iBatch 预取数据处理计数 -void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch) +void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start) { - if (iBatch == 0) { - return; - } LOG_INFO(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId); TimeCost hostEmbsTC; hostEmbs->Join(channelId); LOG_DEBUG("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS()); EmbHDTrans(channelId, batchId); - - for (int i = 0; i < iBatch - 1; ++i) { - // need send empty - LOG_INFO(MGMT + "trans emb dummy, batchId:{}, ", start + 1 + i); - EmbHDTrans(channelId, batchId); - } } /// 发送H2D和接收D2H向量,并更新host emb @@ -872,8 +825,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) // 接收device换出的emb,并更新到host上 for (const auto& embInfo: mgmtEmbInfo) { const auto& missingKeys = 
hostHashMaps->GetMissingKeys(embInfo.name); - auto updateEmbV2 = getenv("UpdateEmb_V2"); - if (updateEmbV2 != nullptr and atoi(updateEmbV2) == 1) { + if (GlobalEnv::updateEmbV2) { hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! } else { hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! @@ -881,7 +833,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) hostHashMaps->ClearMissingKeys(embInfo.name); } LOG_DEBUG("D2HTC(ms):{} EmbHDTrans TimeCost(ms):{} batchId: {} channelId:{}", - d2hTC.ElapsedMS(), tr.ElapsedMS(), batchId, channelId); + d2hTC.ElapsedMS(), tr.ElapsedMS(), batchId, channelId); } #endif diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 93d34009..ebba695c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -19,6 +19,7 @@ #include "absl/container/flat_hash_map.h" #include "utils/common.h" +#include "utils/config.h" #include "utils/singleton.h" #include "host_emb/host_emb.h" @@ -105,7 +106,7 @@ namespace MxRec { bool ParseKeysHBM(int channelId, int& batchId); - bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, int iBatch, bool& remainBatchOut); + bool ProcessEmbInfo(const std::string& embName, int batchId, int channelId, bool& remainBatchOut); void EmbHDTrans(const int channelId, const int batchId); @@ -154,7 +155,7 @@ namespace MxRec { bool EndBatch(int batchId, int channelId) const; - void EmbHDTransWrap(int channelId, const int& batchId, int start, int iBatch); + void EmbHDTransWrap(int channelId, const int& batchId, int start); bool LoadMatchesDDRSetup(const CkptData& loadData); }; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 8465e37b..a80ff826 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -203,7 +203,7 @@ void FeatureAdmitAndEvict::SetFunctionSwitch(bool isEnableEvict) void FeatureAdmitAndEvict::SetCombineSwitch() { - m_isCombine = GetCombineSwitch(); + m_isCombine = GlobalEnv::useCombineFaae; } bool FeatureAdmitAndEvict::GetFunctionSwitch() const diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 03681180..003c8464 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -29,24 +29,7 @@ inline vector Count2Start(const vector& count) void KeyProcess::SetupHotEmbUpdateStep() { - const auto maxUpdateStep = 1000; - this->hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; - const char *envUpdateStep = getenv("HOT_EMB_UPDATE_STEP"); - if (envUpdateStep != nullptr) { - try { - int tmp = std::stoi(envUpdateStep); - if (tmp >= 1 && tmp <= maxUpdateStep) { - this->hotEmbUpdateStep = tmp; - LOG_INFO("Succeed to parse ${env:HOT_EMB_UPDATE_STEP}: {}.", this->hotEmbUpdateStep); - } else { - LOG_ERROR("${env:HOT_EMB_UPDATE_STEP}: {} should be in [1, 1000], set default: {}.", - tmp, HOT_EMB_UPDATE_STEP_DEFAULT); - } - } catch (const std::invalid_argument &e) { - LOG_ERROR("Failed to parse ${env:HOT_EMB_UPDATE_STEP}: {}, set default: {}.", - envUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); - } - } + this->hotEmbUpdateStep = GlobalEnv::hotEmbUpdateStep; } bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos, @@ -91,7 +74,7 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos LOG_WARN(KEY_PROCESS "Feature admit-and-evict function is unavailable ..."); } - if 
(PerfConfig::fastUnique) { + if (GlobalEnv::fastUnique) { Factory::Create(factory); } @@ -116,7 +99,7 @@ int KeyProcess::Start() return; } #endif - if (PerfConfig::fastUnique) { + if (GlobalEnv::fastUnique) { KeyProcessTaskWithFastUnique(channel, threadId); } else { KeyProcessTask(channel, threadId); @@ -219,7 +202,7 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) uniqueConf.useIdCount = true; uniqueConf.outputType = OutputType::ENHANCED; uniqueConf.minThreadNum = MIN_UNIQUE_THREAD_NUM; - uniqueConf.maxThreadNum = PerfConfig::maxUniqueThreadNum; + uniqueConf.maxThreadNum = GlobalEnv::maxUniqueThreadNum; } void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, @@ -464,7 +447,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel) { - if (PerfConfig::gradientStrategy && channel == TRAIN_CHANNEL_ID) { + if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && + channel == TRAIN_CHANNEL_ID) { keys_t uniqueKeys; vector restoreVecSec; diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 65fbaae9..c4e109cc 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -25,6 +25,7 @@ #include "ock_ctr_common/include/error_code.h" #include "utils/common.h" +#include "utils/config.h" #include "utils/time_cost.h" #include "utils/safe_queue.h" diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index d84a1119..31787b75 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -21,18 +21,13 @@ using namespace std; using std::chrono::system_clock; namespace MxRec { - int PerfConfig::keyProcessThreadNum = DEFAULT_KEY_PROCESS_THREAD; - int PerfConfig::maxUniqueThreadNum = DEFAULT_MAX_UNIQUE_THREAD_NUM; - bool PerfConfig::fastUnique = false; - bool PerfConfig::gradientStrategy = false; string g_rankId; int g_glogLevel; bool g_isGlogInit = false; bool g_statOn = false; - RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, - const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), - nBatch(nBatch), maxStep(maxStep) + RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, const vector& maxStep) + : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), maxStep(maxStep) { MPI_Comm_size(MPI_COMM_WORLD, &rankSize); if (localRankSize != 0) { @@ -43,8 +38,8 @@ namespace MxRec { useDynamicExpansion = option bitand HybridOption::USE_DYNAMIC_EXPANSION; } - RankInfo::RankInfo(int localRankSize, int option, int nBatch, const vector& maxStep) - : localRankSize(localRankSize), option(option), nBatch(nBatch), maxStep(maxStep) + RankInfo::RankInfo(int localRankSize, int option, const vector& maxStep) + : localRankSize(localRankSize), option(option), maxStep(maxStep) { MPI_Comm_rank(MPI_COMM_WORLD, &rankId); MPI_Comm_size(MPI_COMM_WORLD, &rankSize); @@ -99,12 +94,7 @@ namespace MxRec { void SetLog(int rank) { - auto logLevel = getenv("GLOG_stderrthreshold"); - if (logLevel == nullptr) { - g_glogLevel = 0; // default as INFO - } else { - g_glogLevel = atoi(logLevel); - } + g_glogLevel = GlobalEnv::glogStderrthreshold; if (g_rankId.empty()) { g_rankId = std::to_string(rank); } @@ -151,41 +141,9 @@ namespace MxRec { throw std::runtime_error("dsmi_get_chip_info failed, ret = 
" + to_string(ret)); } - bool GetCombineSwitch() - { - const char* faaeMode = std::getenv("USE_COMBINE_FAAE"); // 获取环境变量 - bool isCombine = false; - if (faaeMode != nullptr) { - try { - isCombine = (std::stoi(faaeMode) == 1); - LOG_INFO("If combine history table: {}", isCombine); - } catch (const std::invalid_argument& e) { - LOG_ERROR("The value of USE_COMBINE_FAAE is invalid!"); - throw std::invalid_argument("Invalid env value USE_COMBINE_FAAE"); - } - } - return isCombine; - } - int GetThreadNumEnv() { - int threadNum = 0; - const char* threadNumEnv = getenv("KEY_PROCESS_THREAD_NUM"); - if (threadNumEnv != nullptr) { - try { - threadNum = std::stoi(threadNumEnv); - } catch (const std::invalid_argument& e) { - threadNum = KEY_PROCESS_THREAD; - LOG_INFO("error value of threadNum, use default KEY_PROCESS_THREAD: {}", threadNum); - } - if (threadNum > KEY_PROCESS_THREAD || threadNum < 0) { - throw runtime_error(StringFormat("%d is not valid", threadNum)); - } - } else { - threadNum = KEY_PROCESS_THREAD; - LOG_INFO("use default KEY_PROCESS_THREAD: {}", threadNum); - } - return threadNum; + return GlobalEnv::keyProcessThreadNum; } void ValidateReadFile(const string& dataDir, size_t datasetSize) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 7a29286d..2348594a 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -26,6 +26,7 @@ #include "absl/container/flat_hash_map.h" #include "securec.h" #include "utils/log.h" +#include "utils/config.h" #include "initializer/initializer.h" #include "initializer/constant_initializer/constant_initializer.h" @@ -68,25 +69,17 @@ namespace MxRec { constexpr int GLOG_MAX_BUF_SIZE = 1024; constexpr int GLOG_TIME_WIDTH_2 = 2; constexpr int GLOG_TIME_WIDTH_6 = 6; - constexpr char GLOG_STAT_FLAG[] = "STAT_ON"; + constexpr char GLOG_STAT_FLAG[] = "statOn"; // unique related config constexpr int UNIQUE_BUCKET = 6; constexpr int MIN_UNIQUE_THREAD_NUM = 1; - constexpr int DEFAULT_MAX_UNIQUE_THREAD_NUM = 8; // validate file constexpr long long FILE_MAX_SIZE = 1LL << 40; constexpr int FILE_MIN_SIZE = 0; - struct PerfConfig { - static int keyProcessThreadNum; - static int maxUniqueThreadNum; - static bool fastUnique; - static bool gradientStrategy; - }; - constexpr int KEY_PROCESS_TIMEOUT = 120; constexpr int GET_BATCH_TIMEOUT = 300; @@ -116,7 +109,6 @@ namespace MxRec { }; string GetChipName(int devID); - bool GetCombineSwitch(); int GetThreadNumEnv(); namespace UBSize { @@ -208,9 +200,8 @@ namespace MxRec { struct RankInfo { RankInfo() = default; - RankInfo(int rankId, int deviceId, int localRankSize, int option, int nBatch, - const std::vector& maxStep); - RankInfo(int localRankSize, int option, int nBatch, const std::vector& maxStep); + RankInfo(int rankId, int deviceId, int localRankSize, int option, const std::vector& maxStep); + RankInfo(int localRankSize, int option, const std::vector& maxStep); int rankId {}; int deviceId {}; diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp new file mode 100644 index 00000000..0e482427 --- /dev/null +++ b/src/core/utils/config.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: config module + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + + +#include "config.h" +#include "log.h" + +using namespace std; + +namespace MxRec { + // 设置环境变量默认值 + string GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::DIRECT_APPLY; + int GlobalEnv::aclTimeout = -1; // 默认阻塞方式,一直等待直到数据接收完成。 + int GlobalEnv::hdChannelSize = 40; // 默认通道深度40 + bool GlobalEnv::findOffsetV2 = false; + bool GlobalEnv::findOffsetV3 = false; + int GlobalEnv::keyProcessThreadNum = 6; // 默认6个线程 + int GlobalEnv::maxUniqueThreadNum = 8; // 默认最大8个线程 + bool GlobalEnv::fastUnique = false; + bool GlobalEnv::updateEmbV2 = false; + int GlobalEnv::hotEmbUpdateStep = 1000; // 默认1000步更新 + int GlobalEnv::glogStderrthreshold = 0; // 默认info级别 + bool GlobalEnv::useCombineFaae = false; + bool GlobalEnv::statOn = false; + + /// 配置环境变量,Python侧已经做了变量值校验,CPP侧直接使用即可;bool类型,1代表true,0代表false + void ConfigGlobalEnv() + { + // 设置梯度策略 + const char *envStrategy = getenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY); + if (envStrategy != nullptr) { + GlobalEnv::applyGradientsStrategy = envStrategy; + } + + // 设置ACL超时时间 + const char *envAclTimeout = getenv(RecEnvNames::ACL_TIMEOUT); + if (envAclTimeout != nullptr) { + GlobalEnv::aclTimeout = std::stoi(envAclTimeout); + } + + // 设置ACL通道深度 + const char *envHDChannelSize = getenv(RecEnvNames::HD_CHANNEL_SIZE); + if (envHDChannelSize != nullptr) { + GlobalEnv::hdChannelSize = std::stoi(envHDChannelSize); + } + + // 设置偏移查找策略V2 + const char *envFindOffsetV2 = getenv(RecEnvNames::FIND_OFFSET_V2); + if (envFindOffsetV2 != nullptr) { + GlobalEnv::findOffsetV2 = (std::stoi(envFindOffsetV2) == 1); + } + + // 设置偏移查找策略V3 + const char *envFindOffsetV3 = getenv(RecEnvNames::FIND_OFFSET_V3); + if (envFindOffsetV3 != nullptr) { + GlobalEnv::findOffsetV3 = (std::stoi(envFindOffsetV3) == 1); + } + + // 设置数据处理线程数 + const char *envKPNum = getenv(RecEnvNames::KEY_PROCESS_THREAD_NUM); + if (envKPNum != nullptr) { + GlobalEnv::keyProcessThreadNum = std::stoi(envKPNum); + } + + // 设置去重处理线程数 + const char *envUniqNum = getenv(RecEnvNames::MAX_UNIQUE_THREAD_NUM); + if (envUniqNum != nullptr) { + GlobalEnv::maxUniqueThreadNum = std::stoi(envUniqNum); + } + + // 设置是否使用fast unique库进行去重 + const char *envFastUnique = getenv(RecEnvNames::FAST_UNIQUE); + if (envFastUnique != nullptr) { + GlobalEnv::fastUnique = (std::stoi(envFastUnique) == 1); + } + + // 设置是否使用异步更新d2h的host emb + const char *envUpdateEmbV2 = getenv(RecEnvNames::UPDATE_EMB_V2); + if (envUpdateEmbV2 != nullptr) { + GlobalEnv::updateEmbV2 = (std::stoi(envUpdateEmbV2) == 1); + } + + // 设置hot emb更新步数 + const char *envHotEmbStep = getenv(RecEnvNames::HOT_EMB_UPDATE_STEP); + if (envHotEmbStep != nullptr) { + GlobalEnv::hotEmbUpdateStep = std::stoi(envHotEmbStep); + } + + // 设置日志级别 + const char *envLogLevel = getenv(RecEnvNames::GLOG_STDERR_THRESHOLD); + if (envLogLevel != nullptr) { + GlobalEnv::glogStderrthreshold = std::stoi(envLogLevel); + } + + // 设置特征准入统计模式 + const char *envFAAEMode = getenv(RecEnvNames::USE_COMBINE_FAAE); + if (envFAAEMode != nullptr) { + GlobalEnv::useCombineFaae = (std::stoi(envFAAEMode) == 1); + } + + // 设置打开维测信息 + const char *envStat = getenv(RecEnvNames::STAT_ON); + if (envStat != nullptr) { + GlobalEnv::statOn = (std::stoi(envStat) == 1); + } + } + + void LogGlobalEnv() + { + LOG_DEBUG("Environment variables are: [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], " + "[{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}]", + RecEnvNames::APPLY_GRADIENTS_STRATEGY, 
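ConfigGlobalEnv() deliberately trusts its input: as its comment notes, value validation happens on the Python side before these variables are read with std::stoi. A hedged sketch of what that Python-side pre-validation could look like (the helper name and the HD_CHANNEL_SIZE bound are illustrative assumptions; the thread-count ranges 1-10 and 1-8 come from the removed InitKeyProcess checks):

    import os

    def export_validated_int(name: str, value: int, lo: int, hi: int) -> None:
        # Hypothetical helper: range-check, then hand the value to the C++ side.
        if not lo <= value <= hi:
            raise ValueError(f"{name}={value} must be in [{lo}, {hi}]")
        os.environ[name] = str(value)

    export_validated_int("KEY_PROCESS_THREAD_NUM", 6, 1, 10)
    export_validated_int("MAX_UNIQUE_THREAD_NUM", 8, 1, 8)
    export_validated_int("HD_CHANNEL_SIZE", 40, 1, 4096)  # upper bound assumed
    os.environ["APPLY_GRADIENTS_STRATEGY"] = "sum_same_id_gradients_and_apply"

+ void LogGlobalEnv() + { + LOG_DEBUG("Environment variables are: [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], " + "[{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}], [{}: {}]", + RecEnvNames::APPLY_GRADIENTS_STRATEGY,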
GlobalEnv::applyGradientsStrategy, + RecEnvNames::ACL_TIMEOUT, GlobalEnv::aclTimeout, + RecEnvNames::HD_CHANNEL_SIZE, GlobalEnv::hdChannelSize, + RecEnvNames::FIND_OFFSET_V2, GlobalEnv::findOffsetV2, + RecEnvNames::FIND_OFFSET_V3, GlobalEnv::findOffsetV3, + RecEnvNames::KEY_PROCESS_THREAD_NUM, GlobalEnv::keyProcessThreadNum, + RecEnvNames::MAX_UNIQUE_THREAD_NUM, GlobalEnv::maxUniqueThreadNum, + RecEnvNames::FAST_UNIQUE, GlobalEnv::fastUnique, + RecEnvNames::UPDATE_EMB_V2, GlobalEnv::updateEmbV2, + RecEnvNames::HOT_EMB_UPDATE_STEP, GlobalEnv::hotEmbUpdateStep, + RecEnvNames::GLOG_STDERR_THRESHOLD, GlobalEnv::glogStderrthreshold, + RecEnvNames::USE_COMBINE_FAAE, GlobalEnv::useCombineFaae, + RecEnvNames::STAT_ON, GlobalEnv::statOn); + } +} \ No newline at end of file diff --git a/src/core/utils/config.h b/src/core/utils/config.h new file mode 100644 index 00000000..58040cb3 --- /dev/null +++ b/src/core/utils/config.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: config module + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + +#ifndef MXREC_CONFIG_H +#define MXREC_CONFIG_H + +#include + +namespace MxRec { + namespace RecEnvNames { + const char *const APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY"; + const char *const ACL_TIMEOUT = "AclTimeout"; + const char *const HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE"; + const char *const FIND_OFFSET_V2 = "FIND_OFFSET_V2"; + const char *const FIND_OFFSET_V3 = "FIND_OFFSET_V3"; + const char *const KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM"; + const char *const MAX_UNIQUE_THREAD_NUM = "MAX_UNIQUE_THREAD_NUM"; + const char *const FAST_UNIQUE = "FAST_UNIQUE"; + const char *const UPDATE_EMB_V2 = "UpdateEmb_V2"; + const char *const HOT_EMB_UPDATE_STEP = "HOT_EMB_UPDATE_STEP"; + const char *const GLOG_STDERR_THRESHOLD = "GLOG_stderrthreshold"; + const char *const USE_COMBINE_FAAE = "USE_COMBINE_FAAE"; + const char *const STAT_ON = "STAT_ON"; + }; + + namespace ApplyGradientsStrategyOptions { + const std::string DIRECT_APPLY = "direct_apply"; + const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply"; + }; + + struct GlobalEnv { + static std::string applyGradientsStrategy; + static int aclTimeout; + static int hdChannelSize; + static bool findOffsetV2; + static bool findOffsetV3; + static int keyProcessThreadNum; + static int maxUniqueThreadNum; + static bool fastUnique; + static bool updateEmbV2; + static int hotEmbUpdateStep; + static int glogStderrthreshold; + static bool useCombineFaae; + static bool statOn; + }; + + void ConfigGlobalEnv(); + void LogGlobalEnv(); +} + +#endif + diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 403db042..7116efcf 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -77,15 +77,14 @@ PYBIND11_MODULE(mxrec_pybind, m) void GetRankInfo(pybind11::module_& m) { pybind11::class_(m, "RankInfo") - .def(py::init>(), py::arg("rank_id"), py::arg("device_id"), - py::arg("local_rank_size"), py::arg("option"), py::arg("num_batch") = 1, + .def(py::init>(), py::arg("rank_id"), py::arg("device_id"), + py::arg("local_rank_size"), py::arg("option"), py::arg("max_step") = vector { -1, -1 }) .def_readwrite("rank_id", &RankInfo::rankId) .def_readwrite("device_id", &RankInfo::deviceId) .def_readwrite("rank_size", &RankInfo::rankSize) .def_readwrite("local_rank_size", &RankInfo::localRankSize) .def_readwrite("option", &RankInfo::option) - .def_readwrite("num_batch", &RankInfo::nBatch) 
.def_readwrite("max_step", &RankInfo::maxStep); } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 623c001b..820fe57f 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -29,7 +29,7 @@ protected: int64_t int64Min { static_cast(UINT32_MAX) }; int maxChannelNum = MAX_CHANNEL_NUM; - int keyProcessThread = PerfConfig::keyProcessThreadNum; + int keyProcessThread = 1; int embInfoNum { 10 }; @@ -469,8 +469,9 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) SetEmbInfo(); SetTable2Threshold(testTrens2Thresh); validTrens2Thresh = testTrens2Thresh; + bool isCombine = false; - if (GetCombineSwitch()) { + if (isCombine) { SetHistRecCombine(testHistRec); } else { SetHistRec(testHistRec); diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 55babf89..39dcc36f 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -42,7 +42,7 @@ protected: int64_t int64Min { static_cast(UINT32_MAX) }; int maxChannelNum { MAX_CHANNEL_NUM }; - int keyProcessThread { PerfConfig::keyProcessThreadNum }; + int keyProcessThread { 6 }; vector testEmbInfos; valid_int_t validEmbInfo; @@ -183,8 +183,9 @@ protected: EXPECT_EQ(args.validTrens2ThreshArr.at(embName), testSaveData.int32Arr); // need other test method EXPECT_EQ(args.validTrens2ThreshAttrib.at(embName), testSaveData.attribute); testSaveData = args.testCkpt.GetDataset(CkptDataType::HIST_REC, embName); + bool isCombine = false; - if (!GetCombineSwitch()) { + if (!isCombine) { EXPECT_EQ(1, args.validData.histRec.timestamps.count(embName)); EXPECT_EQ(1, args.validData.histRec.historyRecords.count(embName)); EXPECT_EQ(args.validHistRecAttrib.at(embName), testSaveData.attribute); @@ -202,8 +203,9 @@ protected: testLoadData.int32Arr = args.validTrens2ThreshArr.at(embName); testLoadData.attribute = args.validTrens2ThreshAttrib.at(embName); args.testCkpt.SetDataset(CkptDataType::TABLE_2_THRESH, embName, testLoadData); + bool isCombine = false; - if (!GetCombineSwitch()) { + if (!isCombine) { testLoadData.int64Arr = args.validHistRecArr.at(embName); testLoadData.attribute = args.validHistRecAttrib.at(embName); } else { @@ -258,8 +260,9 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) SetEmbInfo(); SetTable2Threshold(testTrens2Thresh, validTrens2ThreshArr, validTrens2ThreshAttrib); validTrens2Thresh = testTrens2Thresh; + bool isCombine = false; - if (GetCombineSwitch()) { + if (isCombine) { SetHistRecCombine(testHistRec, validHistRecArr, validHistRecAttrib); } else { SetHistRec(testHistRec, validHistRecArr, validHistRecAttrib); diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index 463798e8..384a0c37 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -379,16 +379,17 @@ protected: faae.SetCombineSwitch(); StartEvictThread(); - std::thread thrs[PerfConfig::keyProcessThreadNum]; + std::thread thrs[6]; + int threadNum = 1; // 测试多线程的 - for (int i = 0; i < PerfConfig::keyProcessThreadNum; ++i) { + for (int i = 0; i < threadNum; ++i) { std::string name("thread-"); name += std::to_string(i); thrs[i] = std::thread(TestMultiThread, this, std::ref(name)); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); } - for (int i = 0; i < PerfConfig::keyProcessThreadNum; 
++i) { + for (int i = 0; i < threadNum; ++i) { if (thrs[i].joinable()) { thrs[i].join(); } @@ -408,8 +409,8 @@ protected: keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123被淘汰掉了 vector expectCnt2 = {5, 1, 1, 3, 4}; std::lock_guard lock(faae.m_syncMutexs); // 与 evict-thread 竞争资源 - CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, PerfConfig::keyProcessThreadNum); - CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tableName, PerfConfig::keyProcessThreadNum); + CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, threadNum); + CheckMultiThreadRet(expectKeys2, expectCnt2, thresholds[1].tableName, threadNum); } WaitEvictThread(); @@ -450,7 +451,7 @@ protected: void SetEnv() { - const char* name = "USE_COMBINE_FAAE"; + const char* name = "useCombineFaae"; const char* mode = "0"; int overwrite = 1; diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 7ee83494..be8f6c58 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -252,9 +252,9 @@ TEST_F(KeyProcessTest, Start) { ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); - setenv("KEY_PROCESS_THREAD_NUM", "2", 1); + setenv("keyProcessThreadNum", "2", 1); ASSERT_EQ(process.Start(), 0); - setenv("KEY_PROCESS_THREAD_NUM", "abc", 1); + setenv("keyProcessThreadNum", "abc", 1); ASSERT_EQ(process.Start(), 0); CTRLog(0, "key process start successful"); process.Destroy(); @@ -543,19 +543,3 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) this_thread::sleep_for(20s); process.Destroy(); } - -TEST(KeyProcess, SetupHotEmbUpdateStep) -{ - KeyProcess kp; - - kp.SetupHotEmbUpdateStep(); - ASSERT_EQ(kp.hotEmbUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); - - putenv("HOT_EMB_UPDATE_STEP=1"); - kp.SetupHotEmbUpdateStep(); - ASSERT_EQ(kp.hotEmbUpdateStep, 1); - - putenv("HOT_EMB_UPDATE_STEP=0"); - kp.SetupHotEmbUpdateStep(); - ASSERT_EQ(kp.hotEmbUpdateStep, HOT_EMB_UPDATE_STEP_DEFAULT); -} diff --git a/src/tests/utils/common_test.cpp b/src/tests/utils/common_test.cpp index b9e835b9..20813619 100644 --- a/src/tests/utils/common_test.cpp +++ b/src/tests/utils/common_test.cpp @@ -16,17 +16,6 @@ using namespace std; using namespace MxRec; using namespace testing; -TEST(common, SetLog) -{ - int rankId = 0; - SetLog(rankId); - ASSERT_EQ(g_glogLevel, 0); - - putenv("GLOG_stderrthreshold=1"); - SetLog(rankId); - ASSERT_EQ(g_glogLevel, 1); -} - TEST(common, InitializeInfo) { NormalInitializerInfo nInfoTruncatedNormal; diff --git a/tests/mx_rec/validator/test_validators.py b/tests/mx_rec/validator/test_validators.py index 1b44c84b..72f180f1 100644 --- a/tests/mx_rec/validator/test_validators.py +++ b/tests/mx_rec/validator/test_validators.py @@ -6,7 +6,9 @@ import sys import tempfile import unittest -from mx_rec.validator.validator import Validator, StringValidator, DirectoryValidator +from mx_rec.validator.validator import Validator, StringValidator, DirectoryValidator, para_checker_decorator, \ + ClassValidator, OptionValidator, ValueCompareValidator, OptionalStringValidator, \ + OptionalIntValidator, NumValidator, IntValidator, Convert2intValidator sys.modules['mxrec_pybind'] = __import__('os') @@ -30,11 +32,11 @@ class ParameterCheckerTest(unittest.TestCase): super().tearDown() def test_validator_should_return_default_if_invalid(self): - validation = Validator('aa') + validation = Validator("v1", 'aa') validation.register_checker(lambda x: len(x) < 5, 'length of string should 
be less than 5') self.assertTrue(validation.is_valid()) try: - validation = Validator('123456') + validation = Validator("v2", '123456') validation.register_checker(lambda x: len(x) < 5, 'length of string should be less than 5') validation.is_valid() except ValueError as exp: @@ -44,57 +46,58 @@ class ParameterCheckerTest(unittest.TestCase): def test_string_validator_max_len_parameter(self): try: - StringValidator('aa.1245', max_len=3).check_string_length().check().is_valid() + StringValidator("val", 'aa.1245', max_len=3).check_string_length().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") - self.assertTrue(StringValidator('aa.1245', max_len=30).check().is_valid()) + self.assertTrue(StringValidator("val", 'aa.1245', max_len=30).check().is_valid()) # default infinity - self.assertTrue(StringValidator('aa.124512132456').check().is_valid()) + self.assertTrue(StringValidator("val", 'aa.124512132456').check().is_valid()) def test_string_validator_min_len_parameter(self): try: - StringValidator('aa', min_len=3).check_string_length().check().is_valid() + StringValidator("val", 'aa', min_len=3).check_string_length().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") - self.assertTrue(StringValidator('aa.', min_len=3).check().is_valid()) + self.assertTrue(StringValidator("val", 'aa.', min_len=3).check().is_valid()) # default 0 - self.assertTrue(StringValidator('1').check().is_valid()) + self.assertTrue(StringValidator("val", '1').check().is_valid()) def test_string_validator_can_be_transformed2int(self): - self.assertFalse(StringValidator('9' * 20).can_be_transformed2int().check().is_valid()) - self.assertFalse(StringValidator('1,2').can_be_transformed2int().check().is_valid()) - self.assertTrue(StringValidator('12').can_be_transformed2int().check().is_valid()) - self.assertFalse(StringValidator('12').can_be_transformed2int(min_value=100, max_value=200).check().is_valid()) + self.assertFalse(StringValidator("val", '9' * 20).can_be_transformed2int().check().is_valid()) + self.assertFalse(StringValidator("val", '1,2').can_be_transformed2int().check().is_valid()) + self.assertTrue(StringValidator("val", '12').can_be_transformed2int().check().is_valid()) + self.assertFalse( + StringValidator("val", '12').can_be_transformed2int(min_value=100, max_value=200).check().is_valid()) def test_directory_black_list(self): try: - DirectoryValidator('/abc/d/e').with_blacklist(lst=['/abc/d/e']).check().is_valid() + DirectoryValidator("val", '/abc/d/e').with_blacklist(lst=['/abc/d/e']).check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") - self.assertTrue(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/']).check().is_valid()) - self.assertTrue(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=True).check() + self.assertTrue(DirectoryValidator("val", '/abc/d/e').with_blacklist(['/abc/d/']).check().is_valid()) + self.assertTrue(DirectoryValidator("val", '/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=True).check() .is_valid()) # if not exact compare, the /abc/d/e is children path of /abc/d/, so it is invalid try: - self.assertFalse(DirectoryValidator('/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=False) + self.assertFalse(DirectoryValidator("val", '/abc/d/e').with_blacklist(['/abc/d/'], exact_compare=False) .check().is_valid()) except ValueError as 
exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") - self.assertTrue(DirectoryValidator('/usr/bin/bash').with_blacklist().check().is_valid()) + self.assertTrue(DirectoryValidator("val", '/usr/bin/bash').with_blacklist().check().is_valid()) try: - DirectoryValidator('/usr/bin/bash').with_blacklist(exact_compare=False).check().is_valid() + DirectoryValidator("val", '/usr/bin/bash').with_blacklist(exact_compare=False).check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: @@ -122,7 +125,7 @@ class ParameterCheckerTest(unittest.TestCase): try: # do stuff with temp tmp.write(b'stuff') - DirectoryValidator(path).check_not_soft_link().check().is_valid() + DirectoryValidator("val", path).check_not_soft_link().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: @@ -135,26 +138,46 @@ class ParameterCheckerTest(unittest.TestCase): def test_directory_check(self): try: - DirectoryValidator('a/b/.././c/a.txt').check_not_soft_link().check().is_valid() + DirectoryValidator("val", 'a/b/.././c/a.txt').check_not_soft_link().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") try: - DirectoryValidator("").check_is_not_none().check().is_valid() + DirectoryValidator("val", "").check_is_not_none().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") try: - DirectoryValidator(None).check_is_not_none().check().is_valid() + DirectoryValidator("val", None).check_is_not_none().check().is_valid() except ValueError as exp: self.assertEqual(type(exp), ValueError) else: self.fail("ValueError not raised.") - self.assertTrue(DirectoryValidator("a/bc/d").check().is_valid()) + self.assertTrue(DirectoryValidator("val", "a/bc/d").check().is_valid()) + + def test_decorator(self): + @para_checker_decorator(check_option_list=[ + ("class_arg", ClassValidator, {"classed": (bool,)}), + ("options_arg", OptionValidator, {"options": (1, 2, 3)}), + (["options_arg", "int_arg"], ValueCompareValidator, {"target": -1}, ["check_all_not_equal_to_target"]), + ("string_arg", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), + ("int_arg", IntValidator, {"min_value": 1, "max_value": 100}, ["check_value"]), + ("int_arg", OptionalIntValidator, {"min_value": 1, "max_value": 100}, ["check_value"]), + ("int_arg", NumValidator, {"min_value": 1, "max_value": 100}, ["check_value"]), + ("string_arg", Convert2intValidator, {"min_value": 1, "max_value": 100}, ["check_value"]), + ]) + def demo_func(class_arg: bool, options_arg: int, string_arg: str, int_arg: int): + return True + + try: + result = demo_func(class_arg=True, options_arg=1, string_arg="72", int_arg=10) + except ValueError: + result = False + self.assertTrue(result) if __name__ == '__main__': -- Gitee From e5a6c0b974bbc17f51d2fdf9c1b20de8eaaf7b39 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 15:38:01 +0800 Subject: [PATCH 344/551] Match-id-8062dcd196058f1306b1388e39984e96b6b3a410 --- src/core/ssd_engine/file.cpp | 16 ++++++++++++---- src/core/ssd_engine/file.h | 3 +++ src/core/ssd_engine/table.cpp | 15 ++++++++++++++- src/core/ssd_engine/table.h | 4 +++- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index 995fdc7e..82d56c78 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -116,13 +116,15 @@ void 
File::InsertEmbeddings(vector &keys, vector> &embeddings) keyToOffset[keys[i]] = lastWriteOffset; uint64_t embSize = embeddings[i].size(); + if (embSize > maxEmbSize) { + throw invalid_argument("embedding size too large"); + } localFileData.write(reinterpret_cast(&embSize), sizeof(embSize)); - localFileData.write(reinterpret_cast(embeddings[i].data()), - embeddings[i].size() * sizeof(float)); + localFileData.write(reinterpret_cast(embeddings[i].data()), embSize * sizeof(float)); auto pos = localFileData.tellp(); if (pos == -1) { - throw runtime_error("can't get file position pointer"); + throw runtime_error("can't get file position pointer, write data failed"); } lastWriteOffset = offset_t(pos); } @@ -139,11 +141,17 @@ vector> File::FetchEmbeddings(vector &keys) } localFileData.seekg(it->second); // for fstream, this moves the file position pointer (both put and get) if (localFileData.fail()) { - throw runtime_error("can't move file position pointer"); + throw runtime_error("can't move file position pointer, read data failed"); } uint64_t embSize; localFileData.read(reinterpret_cast(&embSize), sizeof(embSize)); + if (localFileData.fail()) { + throw invalid_argument("failed to read embedding size, file may be broken"); + } + if (embSize > maxEmbSize) { + throw invalid_argument("embedding size too large, file may be broken"); + } vector tmp; tmp.resize(embSize); diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h index 5233265d..8ae9ac26 100644 --- a/src/core/ssd_engine/file.h +++ b/src/core/ssd_engine/file.h @@ -57,6 +57,9 @@ namespace MxRec { fstream localFileData{}; fstream localFileMeta{}; + // for safety validation + const uint64_t maxEmbSize = 8192 * 10; // x10 for optimizer state data + uint64_t dataCnt = 0; uint64_t staleDataCnt = 0; unordered_map keyToOffset{}; // offset_t >> maxDataNumInFile * embDataSize diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index a064ee61..7fe23f2f 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -140,10 +140,16 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) LOG_INFO("table:{}, start load data file", name); uint64_t fileCnt; metaFile->read(reinterpret_cast(&fileCnt), sizeof(fileCnt)); + if (metaFile->fail()) { + throw invalid_argument("failed to read nFile, meta file may be broken"); + } uint64_t fileID; uint64_t fidSize = sizeof(fileID); for (uint64_t i = 0; i < fileCnt; ++i) { metaFile->read(reinterpret_cast(&fileID), fidSize); + if (metaFile->fail()) { + throw invalid_argument("failed to read fileID, meta file may be broken"); + } if (fileID > curMaxFileID) { curMaxFileID = fileID; } @@ -200,10 +206,17 @@ void Table::Load(const string &metaFilePath, int step) // Load table name and validate uint32_t nameSize; metaFile->read(reinterpret_cast(&nameSize), sizeof(nameSize)); + if (metaFile->fail()) { + throw invalid_argument("failed to read table name size"); + } if (nameSize > maxNameSize) { throw invalid_argument("table name too large, file may be broken"); } - char *tmpArr = new char[nameSize + 1]; + char tmpArr[nameSize + 1]; + auto ec = memset_s(tmpArr, nameSize + 1, '\0', nameSize + 1); + if (ec != EOK) { + throw runtime_error("failed to init table name array"); + } metaFile->read(tmpArr, static_cast(nameSize)); tmpArr[nameSize] = '\0'; string tbNameInFile = tmpArr;
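The checks added above pin down the on-disk record layout: each embedding is stored as a uint64 length followed by that many float32 values, and any length above maxEmbSize is treated as corruption. A hedged Python sketch of reading one such record under those assumptions (native little-endian byte order assumed; this is illustrative, not mxRec tooling):

    import struct

    MAX_EMB_SIZE = 8192 * 10  # mirrors File::maxEmbSize above

    def read_embedding_record(f):
        """Read one [uint64 length][length * float32] record with validation."""
        header = f.read(8)
        if len(header) != 8:
            raise IOError("failed to read embedding size, file may be broken")
        (emb_size,) = struct.unpack("<Q", header)
        if emb_size > MAX_EMB_SIZE:
            raise ValueError("embedding size too large, file may be broken")
        payload = f.read(4 * emb_size)
        if len(payload) != 4 * emb_size:
            raise IOError("truncated embedding payload, file may be broken")
        return struct.unpack(f"<{emb_size}f", payload)

diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index 64044126..3218b871 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -64,10 +64,12 @@ namespace MxRec { mutex rwLock{};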
shared_ptr curFile = nullptr; uint64_t curMaxFileID = 0; // no concurrent writing, always atomic increase - const uint32_t maxNameSize = 1024; const string saveDirPrefix = "ssd_sparse_model_rank_"; const int convertToPercentage = 100; + // for safety validation + const uint32_t maxNameSize = 1024; + /* args for performance(not expose to user yet) * 2 read thread is optimal when: * embedding's dimension=240, maxDataNumInFile=10000 -- Gitee From 12271c64a3e0b672c118cd49b9f4c4ef4a4f3cad Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 15:48:16 +0800 Subject: [PATCH 345/551] Match-id-5448b6c2373002f0d18b6155044cf1664c5b2ff6 --- mx_rec/core/asc/feature_spec.py | 8 +++++++- mx_rec/core/embedding.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index d1e8d7fe..fb8fcc78 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -34,7 +34,8 @@ class FeatureSpec: ("batch_size", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), ("faae_coefficient", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]) ]) - def __init__(self, name: str, table_name: str, + def __init__(self, name: str, + table_name: Optional[str] = None, index_key: Optional[str] = None, access_threshold: Optional[int] = None, eviction_threshold: Optional[int] = None, is_timestamp: Optional[bool] = None, @@ -93,6 +94,10 @@ class FeatureSpec: def pipeline_mode(self): return self._pipeline_mode + @feat_cnt.setter + def feat_cnt(self, feat_cnt: int): + self._feat_cnt = feat_cnt + @staticmethod def include_timestamp(is_training): if is_training: @@ -197,6 +202,7 @@ def set_temporary_feature_spec_attribute(mock_feature_spec: FeatureSpec, total_f """ mock_feature_spec.batch_size = total_feature_count + mock_feature_spec.feat_cnt = 1 mock_feature_spec.dims = [total_feature_count, 1] mock_feature_spec.initialized = True mock_feature_spec.pipeline_mode.add(True) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 1045bb56..7fb95af3 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -53,7 +53,7 @@ from mx_rec.util.log import logger ("hashtable_threshold", NumValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]) ]) def create_table(key_dtype, dim, name, emb_initializer, - optimizer_list=Optional[list], + optimizer_list: Optional[list] = None, device_vocabulary_size=1, host_vocabulary_size=0, ssd_vocabulary_size=0, -- Gitee From 1cd5094424af3ebda92f3c403a1d30cdae30f4a2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 16:08:26 +0800 Subject: [PATCH 346/551] Match-id-022ce310d2438db3089d14e848c8799dacb2507c --- mx_rec/core/asc/manager.py | 9 +++++--- mx_rec/graph/patch.py | 25 +++++++++++----------- mx_rec/util/initialize.py | 20 +++++++++++++++++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 11 ++++++++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 4 +++- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 15 ++++++++++--- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 6 ++++-- src/core/key_process/key_process.cpp | 6 ++---- src/pybind/module_main.cpp | 3 ++- 9 files changed, 70 insertions(+), 29 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index eaf3faf1..faf63a8d 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -7,7 +7,7 @@ import tensorflow as tf from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo from 
-    is_asc_manager_initialized, get_train_steps, get_eval_steps, \
+    is_asc_manager_initialized, get_train_steps, get_eval_steps, get_save_steps, \
     export_table_instances, export_feature_spec, get_if_load, get_use_static, \
     get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num
 from mx_rec.core.asc.merge_table import find_dangling_table, should_skip
@@ -188,6 +188,8 @@ def initialize_emb_cache(table_info_list, threshold_list):
     rank_size = get_rank_size()
     train_steps = get_train_steps()
     eval_steps = get_eval_steps()
+    save_steps = get_save_steps()
+
     if_load = get_if_load()
 
     option = 0
     if get_use_static():
@@ -197,8 +199,9 @@
     if get_use_dynamic_expansion():
         option = option | USE_DYNAMIC_EXPANSION
 
-    # [train_steps, eval_steps] pass step information to HybridMgmt for data process loop
-    rank_info = RankInfo(rank_id, device_id, rank_size, option, [train_steps, eval_steps])
+    # [train_steps, eval_steps, save_steps] pass step information to HybridMgmt for data process loop
+    rank_info = RankInfo(rank_id, device_id, rank_size, option,
+                         n_batch_to_prefetch, [train_steps, eval_steps, save_steps])
 
     emb_cache = HybridMgmt()
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
index 4f0845c1..6c107293 100644
--- a/mx_rec/graph/patch.py
+++ b/mx_rec/graph/patch.py
@@ -104,16 +104,7 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
         return this_channel_id
 
     # patched in: give the session a step-count attribute
-    step = self.get_mxrec_steps()
-    # cache the value to avoid querying the graph on every call
-    if not step:
-        step = 1
-        for custom_optimizer in self.get_config().graph_options.rewrite_options.custom_optimizers:
-            if custom_optimizer.name == "NpuOptimizer":
-                step = custom_optimizer.parameter_map["iterations_per_loop"].i
-                break
-        self.steps = step
-
+    steps = self.get_mxrec_steps()
    # patched in: give the graph a cache attribute
    name2channel_cache = self.get_mxrec_name2channel_cache()

@@ -125,10 +116,13 @@
        channel_id = -1

    if channel_id != -1:
-        get_asc_manager().send_message_to_hybrid(channel_id, self.steps)
+        get_asc_manager().block_notify_wake(channel_id)

    #call TensorFlow's native run method
-    return self.old_run_method(fetches, feed_dict, options, run_metadata)
+    result = self.old_run_method(fetches, feed_dict, options, run_metadata)
+    if channel_id != -1:
+        get_asc_manager().block_count_steps(channel_id, steps)
+    return result


def patch_for_dataset():
@@ -142,7 +136,12 @@ def patch_for_session():
            # instance attributes defined outside __init__ cannot be read before the method that sets them has run
            return self.steps
        except AttributeError:
-            self.steps = None
+            self.steps = 1
+            for custom_optimizer in self.get_config().graph_options.rewrite_options.custom_optimizers:
+                if custom_optimizer.name == "NpuOptimizer" \
+                        and custom_optimizer.parameter_map["iterations_per_loop"].i != 0:
+                    self.steps = custom_optimizer.parameter_map["iterations_per_loop"].i
+                    break
            return self.steps

    def get_mxrec_name2channel_cache(self):
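The reworked get_mxrec_steps above caches NpuOptimizer's iterations_per_loop on first use, treating AttributeError as the cache-miss signal. A self-contained Python sketch of the idiom, for illustration only (the real config objects come from the NPU tf_adapter; plain dicts stand in for them here):

class PatchedSession:
    """Stand-in for the patched tf Session; only the caching idiom is real."""

    def __init__(self, custom_optimizers):
        self._custom_optimizers = custom_optimizers

    def get_mxrec_steps(self):
        try:
            return self.steps  # AttributeError doubles as "not cached yet"
        except AttributeError:
            self.steps = 1  # default: one step per session.run
            for opt in self._custom_optimizers:
                if opt.get("name") == "NpuOptimizer" and opt.get("iterations_per_loop", 0) != 0:
                    self.steps = opt["iterations_per_loop"]  # loop sinking: N steps per run
                    break
            return self.steps

sess = PatchedSession([{"name": "NpuOptimizer", "iterations_per_loop": 10}])
assert sess.get_mxrec_steps() == 10  # second call returns the cached attribute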
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 1cea5677..e0d8854b 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -49,6 +49,7 @@ class ConfigInitializer:
         self._is_frozen = False
         self._train_steps = None
         self._eval_steps = None
+        self._save_steps = None
         self._if_load = None
         self._table_instance_dict = dict()
         self._dangling_table = []
@@ -88,6 +89,8 @@
                                                           chief_device=global_env.cm_chief_device)
         self.train_steps = kwargs.get("train_steps", -1)
         self.eval_steps = kwargs.get("eval_steps", -1)
+        self.save_steps = kwargs.get("save_steps", -1)
+
         self.if_load = kwargs.get("if_load", False)
         self.use_static = not kwargs.get("use_dynamic", True)
@@ -197,6 +200,10 @@
     def eval_steps(self):
         return self._eval_steps
 
+    @property
+    def save_steps(self):
+        return self._save_steps
+
     @property
     def if_load(self):
         return self._if_load
@@ -346,6 +353,11 @@
         check_step(steps)
         self._eval_steps = steps
 
+    @save_steps.setter
+    def save_steps(self, steps):
+        check_step(steps)
+        self._save_steps = steps
+
     @if_load.setter
     def if_load(self, flag):
         if not isinstance(flag, bool):
@@ -604,6 +616,10 @@
 def get_eval_steps():
     return ConfigInitializer.get_instance().eval_steps
 
+def get_save_steps():
+    return ConfigInitializer.get_instance().save_steps
+
 def set_train_steps(steps: int):
     ConfigInitializer.get_instance().train_steps = steps
@@ -612,6 +628,10 @@
 def set_eval_steps(steps: int):
     ConfigInitializer.get_instance().eval_steps = steps
 
+def set_save_steps(steps: int):
+    ConfigInitializer.get_instance().save_steps = steps
+
 def get_table_instance(key):
     return ConfigInitializer.get_instance().get_table_instance(key)
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 28298ee6..2573c4b3 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -954,11 +954,18 @@ int HybridMgmt::GetStepFromPath(const string& loadPath)
     return 0;
 }
 
-/// Called from Python via pybind to tell HybridMgmt a graph execution is about to start
+/// Called from Python via pybind to tell HybridMgmt a graph execution is about to start; a wake-up is required
 /// \param channelID channel id
 /// \param steps number of steps to run; with loop sinking, one session.run may correspond to N steps
-void HybridMgmt::CallBySessionRun(int channelID, int steps)
+void HybridMgmt::NotifyBySessionRun(int channelID)
 {
     hybridMgmtBlock->CheckAndNotifyWake(channelID);
+}
+
+/// Called from Python via pybind to tell HybridMgmt a graph execution is about to start
+/// \param channelID channel id
+/// \param steps number of steps to run; with loop sinking, one session.run may correspond to N steps
+void HybridMgmt::CountStepBySessionRun(int channelID, int steps)
+{
     hybridMgmtBlock->CountPythonStep(channelID, steps);
 }
\ No newline at end of file
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index ebba695c..563296d1 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -116,7 +116,9 @@ namespace MxRec {
         bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount);
 
-        void CallBySessionRun(int channelID, int steps);
+        void NotifyBySessionRun(int channelID);
+
+        void CountStepBySessionRun(int channelID, int steps);
 
     private:
         bool InitKeyProcess(const RankInfo& rankInfo, const vector<EmbInfo>& embInfos,
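Splitting CallBySessionRun into NotifyBySessionRun and CountStepBySessionRun puts the wake-up before the graph executes and the step accounting after it returns. A toy condition-variable model of that handshake, in Python and for illustration only (the real synchronization lives in HybridMgmtBlock; this Condition-based rendering is an assumption):

import threading

class ToyHybridBlock:
    def __init__(self):
        self._cv = threading.Condition()
        self._python_steps = {0: 0, 1: 0}  # per-channel step counters (0 train, 1 eval)

    def notify_by_session_run(self, channel_id):
        # called just before session.run: wake any preprocess thread blocked on us
        with self._cv:
            self._cv.notify_all()

    def count_step_by_session_run(self, channel_id, steps=1):
        # called after session.run returns: with loop sinking, one run credits N steps
        with self._cv:
            self._python_steps[channel_id] += steps
            self._cv.notify_all()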
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
index 662582fc..550564d4 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
@@ -14,6 +14,14 @@
 /// \param channelId train 0 eval 1
 void HybridMgmtBlock::CheckAndSetBlock(int channelId)
 {
+    // check whether this step must block for a save:
+    // block when on the training channel, the save interval is neither 0 nor -1 (no blocking needed), and the current step is one that requires blocking
+    if (channelId==TRAIN_CHANNEL_ID && saveInterval!=0
+        && saveInterval!=-1 && hybridBatchId[TRAIN_CHANNEL_ID]%saveInterval==0) {
+        LOG_DEBUG(HYBRID_BLOCKING + "blocking by save saveInterval {} pythonBatchId {} hybridBatchId {}",
+                  saveInterval, pythonBatchId[channelId], hybridBatchId[channelId]);
+        isBlock[TRAIN_CHANNEL_ID] = true;
+    }
     if (stepsInterval[channelId] == -1) {
         return;
     }
@@ -95,7 +103,7 @@ void HybridMgmtBlock::CheckValid(int channelId)
                   lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]);
     } else if (pythonBatchId[lastRunChannelId] < hybridBatchId[lastRunChannelId]) {
         // on channel switch, the previous channel preprocessed more data than the Python side consumed
-        if (!WaitValid(lastRunChannelId)) {
+        if (!rankInfo.noDDR and !WaitValid(lastRunChannelId)) {
             throw HybridMgmtBlockingException("when channel switch");
         }
     } else {
@@ -193,8 +201,9 @@ void HybridMgmtBlock::Destroy()
 
 void HybridMgmtBlock::SetRankInfo(RankInfo rankInfo)
 {
-    this->stepsInterval[0] = rankInfo.maxStep[0];
-    this->stepsInterval[1] = rankInfo.maxStep[1];
+    this->stepsInterval[TRAIN_CHANNEL_ID] = rankInfo.maxStep[TRAIN_CHANNEL_ID];
+    this->stepsInterval[EVAL_CHANNEL_ID] = rankInfo.maxStep[EVAL_CHANNEL_ID];
+    this->saveInterval = rankInfo.maxStep[SAVE_STEP_INDEX];
     this->rankInfo = rankInfo;
 };
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h
index 5b0821f9..257de095 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h
@@ -15,7 +15,7 @@
 #include "utils/singleton.h"
 using namespace MxRec;
 const std::string HYBRID_BLOCKING = "[HYBRID_BLOCKING] ";
-const std::string D2H_CHANNEL_NAME_PRE = "d2h_notify_hybridmgmt_";
+const int SAVE_STEP_INDEX = 2;
 const std::chrono::milliseconds SLEEP_MS = 20ms;
 class HybridMgmtBlock {
 public:
@@ -48,6 +48,8 @@ private:
     int stepsInterval[2] = {0, 0};
     // flags controlling per-channel blocking
     bool isBlock[2] = {true, true};
+    // number of training steps between saves
+    int saveInterval = 0;
     RankInfo rankInfo;
 };
 
@@ -65,7 +67,7 @@ public:
                  "currentBatchNumber is %d. please check your setting of train "
                  "steps and eval steps",
                  scene.c_str(), channelId, preprocessBatchNumber, currentBatchNumber);
-        LOG(ERROR) << str;
+        LOG_ERROR(str);
     }
 
 private:
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 003c8464..fede6f71 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -328,7 +328,6 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector <
 bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique,
                                                     int channel, int threadId)
 {
-    std::lock_guard<std::mutex> lock(loadSaveMut[channel][threadId]);
     // tuple for keyRec restore hotPos scAll countRecv
     isWithFAAE = m_featureAdmitAndEvict.GetFunctionSwitch() &&
                  FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE;
@@ -347,7 +346,7 @@
                   rankInfo.rankId, threadId, channel);
         return false;
     }
-
+    std::lock_guard<std::mutex> lock(loadSaveMut[channel][threadId]);
     // without host, just device, all embedding vectors were stored in device
     // map key to offset directly by lookup keyOffsetMap (hashmap)
     if (rankInfo.noDDR) {
@@ -383,7 +382,6 @@
 bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId)
 {
-    std::lock_guard<std::mutex> lock(loadSaveMut[channel][threadId]);
     vector splitKeys;
     vector restore;
     vector hotPos;
@@ -397,7 +395,7 @@
         FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) {
         countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss);
     }
-
+    std::lock_guard<std::mutex> lock(loadSaveMut[channel][threadId]);
     BuildRestoreVec(batch, ss, restore, static_cast<int>(hotPos.size()));
 
     // feature admission & eviction
diff --git 
a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 7116efcf..18277c3b 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -164,7 +164,8 @@ void GetHybridMgmt(pybind11::module_& m) .def("evict", &MxRec::HybridMgmt::Evict) .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) - .def("send_message_to_hybrid", &MxRec::HybridMgmt::CallBySessionRun, py::arg("channel_id"), py::arg("steps")=1); + .def("block_notify_wake", &MxRec::HybridMgmt::NotifyBySessionRun, py::arg("channel_id")) + .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, py::arg("channel_id"), py::arg("steps")=1); } void GetThresholdValue(pybind11::module_& m) -- Gitee From 5fd3e6386d516041a2f27dfa5ce917522512d806 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 19:26:02 +0800 Subject: [PATCH 347/551] Match-id-96d5502ebc696ba866d55ad24f1fe67a16930ed3 --- mx_rec/core/asc/manager.py | 3 +-- mx_rec/util/initialize.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index faf63a8d..dcac0ec8 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -200,8 +200,7 @@ def initialize_emb_cache(table_info_list, threshold_list): option = option | USE_DYNAMIC_EXPANSION # [train_steps, eval_steps, save_steps] pass step information to HybridMgmt for data process loop - rank_info = RankInfo(rank_id, device_id, rank_size, option, - n_batch_to_prefetch, [train_steps, eval_steps, save_steps]) + rank_info = RankInfo(rank_id, device_id, rank_size, option, [train_steps, eval_steps, save_steps]) emb_cache = HybridMgmt() diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index e0d8854b..f6895702 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -30,6 +30,7 @@ class ConfigInitializer: ("use_mpi", ClassValidator, {"classes": (bool, )}), ("train_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), ("eval_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), + ("save_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), (["train_steps", "eval_steps"], ValueCompareValidator, {"target": 0}, ["check_at_least_one_not_equal_to_target"]), ("if_load", ClassValidator, {"classes": (bool, )}), -- Gitee From ea80f6b81ebdf60efc436cdbf4ec814fc5707478 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Sep 2023 22:16:49 +0800 Subject: [PATCH 348/551] Match-id-5c099b2c3755b21a2f4d64d15116119e978bcab7 --- mx_rec/core/asc/feature_spec.py | 12 +++++++++--- mx_rec/core/asc/helper.py | 11 ++++++++++- mx_rec/core/embedding.py | 7 ++++--- mx_rec/core/feature_process.py | 3 ++- mx_rec/optimizers/adagrad.py | 2 +- mx_rec/optimizers/ftrl.py | 6 +++--- mx_rec/optimizers/gradient_descent.py | 2 +- mx_rec/optimizers/gradient_descent_by_addr.py | 2 +- mx_rec/optimizers/lazy_adam.py | 2 +- mx_rec/optimizers/lazy_adam_by_addr.py | 2 +- mx_rec/optimizers/momentum.py | 2 +- mx_rec/saver/patch.py | 14 ++++++++++---- mx_rec/validator/validator.py | 6 +++--- 13 files changed, 47 insertions(+), 24 deletions(-) diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index fb8fcc78..1eb0c35a 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -25,13 +25,19 @@ class FeatureSpec: use_timestamp_eval = False @para_checker_decorator(check_option_list=[ - ("name", StringValidator, {"max_len": 255}, 
["check_string_length"]), - ("table_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), - ("index_key", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("table_name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("table_name", ClassValidator, {"classes": (str, type(None))}), + ("index_key", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("index_key", OptionalIntValidator, {"min_value": 0, "max_value": 255}, ["check_value"]), + ("index_key", ClassValidator, {"classes": (str, int, type(None))}), ("access_threshold", OptionalIntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("access_threshold", ClassValidator, {"classes": (int, type(None))}), ("eviction_threshold", OptionalIntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("eviction_threshold", ClassValidator, {"classes": (int, type(None))}), ("is_timestamp", ClassValidator, {"classes": (bool, type(None))}), ("batch_size", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), + ("batch_size", ClassValidator, {"classes": (int, type(None))}), ("faae_coefficient", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]) ]) def __init__(self, name: str, diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 64bffc04..3ec042f5 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -9,8 +9,10 @@ import tensorflow as tf from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.merge_table import find_dangling_table, should_skip -from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator +from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator, \ + OptionalIntValidator from mx_rec.util.log import logger +from mx_rec.constants.constants import MAX_INT32 @para_checker_decorator(check_option_list=[ @@ -18,6 +20,13 @@ from mx_rec.util.log import logger ["check_at_least_one_not_equal_to_target"]), (["tgt_key_specs", "args_index_list"], ValueCompareValidator, {"target": None}, ["check_at_least_one_equal_to_target"]), + ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, type(None))}), + ("args_index_list", ClassValidator, {"classes": (list, type(None))}), + ("feature_numbers", ClassValidator, {"classes": (int, type(None))}), + ("feature_numbers", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), + ("table_names", ClassValidator, {"classes": (list, type(None))}), + ("is_training", ClassValidator, {"classes": (bool, type(None))}), + ("dump_graph", ClassValidator, {"classes": (bool, type(None))}), ]) def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None, table_names=None, **kwargs): diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 7fb95af3..597167b9 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -25,7 +25,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, ge get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \ get_table_instance_by_name from mx_rec.validator.validator import ClassValidator, StringValidator, 
SSDFeatureValidator, \ - para_checker_decorator, IntValidator, NumValidator, OptionValidator + para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.util.global_env_conf import global_env @@ -36,7 +36,7 @@ from mx_rec.util.log import logger ("key_dtype", OptionValidator, {"options": (tf.int64, tf.int32, tf.string)}), ("dim", ClassValidator, {"classes": (int, tf.TensorShape)}), ("dim", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), - ("name", StringValidator, {"max_len": 255}, ["check_string_length", "check_whitelist"]), + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length", "check_whitelist"]), ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}), ("optimizer_list", ClassValidator, {"classes": (list, type(None))}), (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator), @@ -777,9 +777,10 @@ class SparseEmbedding: ("ids", ClassValidator, {"classes": (FeatureSpec, tf.Tensor)}), ("is_train", ClassValidator, {"classes": (bool, )}), ("send_count", ClassValidator, {"classes": (int, type(None))}), + ("send_count", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), ("name", ClassValidator, {"classes": (str, type(None))}), ("modify_graph", ClassValidator, {"classes": (bool, type(None))}), - ("batch", ClassValidator, {"classes": (dict, type(None))}), + ("batch", ClassValidator, {"classes": (dict, list, tuple, type(None))}), ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}), ]) def sparse_lookup(hashtable: SparseEmbedding, diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index 85b00b59..c2c16f11 100644 --- a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -16,9 +16,10 @@ class _EvictHook(tf.compat.v1.train.SessionRunHook): """Sets evict based on global step or time.""" @para_checker_decorator( check_option_list=[ - ("evict_enable", ClassValidator, {"classes": (bool,)}), + ("evict_enable", ClassValidator, {"classes": (bool, )}), ("evict_time_interval", IntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), ("evict_step_interval", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), + ("evict_step_interval", ClassValidator, {"classes": (int, type(None))}), ] ) def __init__(self, diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 4c396d7f..c4b23cdd 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -23,7 +23,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("initial_accumulator_value", NumValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool, )}), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate=0.001, initial_accumulator_value=0.9, diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 1cd342ea..5cbc83bd 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -34,9 +34,9 @@ from mx_rec.validator.validator import para_checker_decorator, 
OptionalStringVal ("l2_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), ("l2_shrinkage_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]), - ("accum_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), - ("linear_name", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("accum_name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("linear_name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 5c9b1d1a..39f5725b 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -21,7 +21,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, @para_checker_decorator(check_option_list=[ ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescent"): return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name) diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 6af75e3d..7542121d 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -21,7 +21,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("weight_decay", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescentByAddr"): optimizer_by_addr = CustomizedGradientDescentByAddr(learning_rate=learning_rate, diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 2ec61c64..8b1670f3 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -29,7 +29,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, 
name="LazyAdam"): """ diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index d7cdda4b..1e601221 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -23,7 +23,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]) + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdamByAddress"): diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index 7e585f54..dff1a14f 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -24,7 +24,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("mom", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), - ("name", StringValidator, {"max_len": 255}, ["check_string_length"]), + ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ("enable_nesterov", ClassValidator, {"classes": (bool,)}), ]) def create_hash_optimizer(learning_rate=0.001, mom=0.9, use_locking=False, name="momentum", diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 3a53ba30..47141c7f 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -23,11 +23,14 @@ from tensorflow.python.util import compat from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util +import numpy as np from mx_rec.saver.saver import Saver as SparseSaver from mx_rec.util.initialize import get_ascend_global_hashtable_collection, export_removing_var_list -from mx_rec.validator.validator import para_checker_decorator, ClassValidator +from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ + OptionalStringValidator from mx_rec.util.log import logger +from mx_rec.constants.constants import MAX_INT32 def get_sparse_vars(var_list): @@ -172,10 +175,13 @@ def build(self): @para_checker_decorator(check_option_list=[ ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), - ("save_path", ClassValidator, {"classes": (str, )}), - ("global_step", ClassValidator, {"classes": (int, type(None))}), + ("save_path", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("global_step", ClassValidator, {"classes": (int, np.int64, type(None))}), + ("global_step", OptionalIntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), ("latest_filename", ClassValidator, {"classes": (str, type(None))}), + ("latest_filename", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ("meta_graph_suffix", ClassValidator, {"classes": (str, type(None))}), + ("meta_graph_suffix", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ("write_meta_graph", ClassValidator, 
{"classes": (bool, type(None))}), ("write_state", ClassValidator, {"classes": (bool, type(None))}), ("strip_default_attrs", ClassValidator, {"classes": (bool, type(None))}), @@ -221,7 +227,7 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra @para_checker_decorator(check_option_list=[ ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), - ("save_path", ClassValidator, {"classes": (str, )}) + ("save_path", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ]) def restore(self, sess, save_path): if save_path is None: diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index ce68162a..e26a5b68 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -273,9 +273,9 @@ class OptionalStringValidator(StringValidator): """ def __init__(self, name, value, max_len=None, min_len=0, element: Optional[str] = None, msg=""): - if value is None: + if not isinstance(value, str): super(OptionalStringValidator, self).__init__(name, "", None, None, None, msg) - else: + elif isinstance(value, str): super(OptionalStringValidator, self).__init__(name, value, max_len, min_len, element, msg) @@ -368,7 +368,7 @@ class OptionalIntValidator(IntValidator): def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, invalid_options: List = None, constrained_options: List = None, msg: str = ""): - if value is None: + if not isinstance(value, int): super(OptionalIntValidator, self).__init__(name, 0, None, None, None, None, msg) else: super(OptionalIntValidator, self).__init__(name, value, min_value, max_value, -- Gitee From 2e162c4392adc4a502fcb022f9934633aa24a92e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 15 Sep 2023 09:48:29 +0800 Subject: [PATCH 349/551] Match-id-f875be55e759aed37cd8d9584b24e667facafb43 --- mx_rec/core/asc/helper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 3ec042f5..6c88344d 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -20,9 +20,9 @@ from mx_rec.constants.constants import MAX_INT32 ["check_at_least_one_not_equal_to_target"]), (["tgt_key_specs", "args_index_list"], ValueCompareValidator, {"target": None}, ["check_at_least_one_equal_to_target"]), - ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, type(None))}), + ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, list, tuple, type(None))}), ("args_index_list", ClassValidator, {"classes": (list, type(None))}), - ("feature_numbers", ClassValidator, {"classes": (int, type(None))}), + ("feature_numbers", ClassValidator, {"classes": (int, list, type(None))}), ("feature_numbers", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]), ("table_names", ClassValidator, {"classes": (list, type(None))}), ("is_training", ClassValidator, {"classes": (bool, type(None))}), -- Gitee From ae7d62065fb812c83d6ee5ec68238484055d42b3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 15 Sep 2023 10:34:15 +0800 Subject: [PATCH 350/551] Match-id-dbcebd2d8ae50cbf939b4f7099204defe2c41805 --- mx_rec/core/embedding.py | 24 ++++++++++++---------- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 597167b9..7ca5e723 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -604,16 +604,12 @@ class 
SparseEmbedding:
                                                      emb_size=self.emb_size, use_hot=use_hot, device_id=device_id,
                                                      use_dynamic_expansion=use_dynamic_expansion)
-        # marker op: its name tells whether this sparse lookup belongs to train or eval
-        # at session.run time, the op is found by searching the fetched subgraph backwards
-        # the name then identifies which channel the run uses, so the C++ side can count steps and wake up
-        notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id)
-        with tf.control_dependencies([notify_hybridmgmt_op]):
-            if self.skip_emb_transfer:
-                result = get_preprocessed_tensor_for_asc(self.variable, config)
-            else:
-                variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list]
-                result = get_preprocessed_tensor_for_asc(variable_list, config)
+
+        if self.skip_emb_transfer:
+            result = get_preprocessed_tensor_for_asc(self.variable, config)
+        else:
+            variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list]
+            result = get_preprocessed_tensor_for_asc(variable_list, config)
         restore_vector = result.get("restore_vector")
         restore_vector_second = result.get("restore_vector_second")
         hot_pos = result.get("hot_pos")
@@ -650,7 +646,13 @@ class SparseEmbedding:
             unique_embeddings_shape = unique_embeddings.shape.as_list()
         else:
             unique_embeddings_shape = tf.shape(unique_embeddings)
-        embeddings = tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector")
+
+        # marker op: its name tells whether this sparse lookup belongs to train or eval
+        # at session.run time, the op is found by searching the fetched subgraph backwards
+        # the name then identifies which channel the run uses, so the C++ side can count steps and wake up
+        notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id)
+        with tf.control_dependencies([notify_hybridmgmt_op]):
+            embeddings = tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector")
 
         if use_static:
             lookup_result = tf.reshape(embeddings, feature_spec.dims + [self.scalar_emb_size])
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
index 550564d4..cd0649e2 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
@@ -41,7 +41,7 @@ void HybridMgmtBlock::CheckAndNotifyWake(int channelId)
 {
     LOG_DEBUG(HYBRID_BLOCKING + "start notify channelId {} pythonBatchId {} hybridBatchId {}",
-              channelId, pythonBatchId[lastRunChannelId], hybridBatchId[channelId]);
+              channelId, pythonBatchId[channelId], hybridBatchId[channelId]);
 
     CheckValid(channelId);
     if (pythonBatchId[channelId] >= hybridBatchId[channelId]) {
-- Gitee

From 1f05babd8e067546ab6dfbfcad271ee9a34e9f36 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 15 Sep 2023 15:43:58 +0800
Subject: [PATCH 351/551] Match-id-8e3e984b0bd01cebbd631237c5850aff0552ac03
---
 src/core/key_process/key_process.cpp | 23 ++++++++++++++++++-----
 src/core/utils/common.h              |  1 -
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index fede6f71..0880d54f 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -27,6 +27,19 @@ inline vector Count2Start(const vector& count)
     return start;
 }
 
+class EndRunExit : public std::exception {
+public:
+    explicit EndRunExit(const char* message) : errorMessage(message) {}
+
+    const char* what() const noexcept override
+    {
+        return errorMessage;
+    }
+
+private:
+    const char* errorMessage;
+};
+
 void KeyProcess::SetupHotEmbUpdateStep()
 {
     this->hotEmbUpdateStep = GlobalEnv::hotEmbUpdateStep;
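Patch 351 replaces the old `using EndRunError = std::runtime_error` alias (removed further below) with this dedicated exception type, so the catch blocks that treat shutdown as a clean exit can no longer swallow genuine runtime errors. The same idea in Python, for illustration only:

class EndRunExit(Exception):
    """Raised only as the cooperative shutdown signal; nothing else raises it."""

def worker(batches):
    try:
        for batch in batches:
            if batch is None:  # shutdown sentinel from the batch queue
                raise EndRunExit("GetBatchData end run.")
            print("processing", batch)
    except EndRunExit as err:
        print("abort run:", err)  # expected, clean exit
    # an unexpected RuntimeError now propagates instead of being misread as shutdown

worker([1, 2, None])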
@@ -268,7 +281,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
             batchQueue->PutDirty(move(batch));
         }
         unique->UnInitialize();
-    } catch (const EndRunError &e) {
+    } catch (const EndRunExit &e) {
         LOG_INFO(KEY_PROCESS "abort run: {}", e.what());
     }
     LOG_INFO(KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:{} thread:{}, channel:{}",
@@ -301,7 +314,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId)
             auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel);
             batchQueue->PutDirty(move(batch));
         }
-    } catch (const EndRunError &e) {
+    } catch (const EndRunExit &e) {
         LOG_INFO(KEY_PROCESS "abort run: {}", e.what());
     }
     LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel);
@@ -533,7 +546,7 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId)
             // termination signal for communication: exit in sync to keep threads from getting stuck
             int exitFlag = isRunning;
             MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
-            throw EndRunError("GetBatchData end run.");
+            throw EndRunExit("GetBatchData end run.");
         }
     }
     EASY_END_BLOCK
@@ -1011,7 +1024,7 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int
     int exitFlag = isRunning;
     MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
     if (exitFlag < rankInfo.rankSize) {
-        throw EndRunError("GetScAll end run.");
+        throw EndRunExit("GetScAll end run.");
     }
     EASY_END_BLOCK;
     LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS());
@@ -1032,7 +1045,7 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, int
     int exitFlag = isRunning;
     MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
     if (exitFlag < rankInfo.rankSize) {
-        throw EndRunError("GetScAll end run.");
+        throw EndRunExit("GetScAll end run.");
     }
     EASY_END_BLOCK;
     LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS());
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 2348594a..981c93e8 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -100,7 +100,6 @@ namespace MxRec {
     using keys_t = std::vector;
     using lookup_key_t = std::tuple; // batch_id query_label keys_vector
     using tensor_info_t = std::tuple>>::iterator>;
-    using EndRunError = std::runtime_error;
 
     namespace HybridOption {
         const int USE_STATIC = 0x001;
-- Gitee

From b8c6b9b696d28260bd528f6cedcd755e3d7df0b2 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Sat, 16 Sep 2023 10:58:13 +0800
Subject: [PATCH 352/551] Match-id-1783e291bc98cbecf51caecda8a3dc2556a51424
---
 mx_rec/core/embedding.py  | 65 ++++++++++++++++++++-------------------
 mx_rec/graph/patch.py     |  4 +--
 mx_rec/saver/saver.py     | 16 +++++-----
 mx_rec/util/initialize.py | 34 +++++++++++---------
 4 files changed, 62 insertions(+), 57 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 597167b9..6e440acc 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -171,6 +171,18 @@ class SparseEmbedding:
     def optimizer_instance_list(self):
         return self._optimizer_instance_list
 
+    @staticmethod
+    def generate_lookup_id_notify_hybrid(channel_id: int):
+
+        """
+        Args:
+            channel_id: channel id: 0 for train, 1 for eval
+        Returns: npu_ops.outfeed_enqueue_op notify preprocess step
+        """
+        channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id)
+        notify_hybridmgmt_op = tf.no_op(channel_name)
+        return notify_hybridmgmt_op
+
     @staticmethod
     def get_anchor_attribute(anchor, attr):
         if not isinstance(anchor, tf.Tensor):
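generate_lookup_id_notify_hybrid, now a staticmethod above, builds a no-op whose name encodes the channel. A minimal TF1-style sketch of how such a marker gates a downstream op via a control dependency (assumes TensorFlow is installed; the tensor shapes and values are illustrative only):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
unique_embeddings = tf.constant([[1.0, 2.0], [3.0, 4.0]])
restore_vector = tf.constant([0, 1, 0])
notify_op = tf.no_op(name="d2h_notify_hybridmgmt_0")  # suffix 0 = train channel
with tf.control_dependencies([notify_op]):
    embeddings = tf.gather(unique_embeddings, restore_vector, axis=0,
                           name="gather_for_restore_vector")
# session.run later walks backwards from the fetches, finds the marker by name,
# and derives the channel id from the trailing digit.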
@@ -194,6 +206,27 @@
                 optimizer.insert_slot(slot, named_slot_key, slot_name)
 
+    @staticmethod
+    def get_emb_table_size(table_name: str) -> int:
+        """
+        For HBM or DDR mode, return the size of sparse embedding table
+        :param table_name: the name of sparse embedding table
+        :return: the size of the sparse embedding table
+        """
+        table_instance = get_table_instance_by_name(table_name)
+        host_vocabulary_size = table_instance.host_vocabulary_size()
+        device_vocabulary_size = table_instance.device_vocabulary_size
+        if not host_vocabulary_size and not get_use_dynamic_expansion():
+            embed_dim = table_instance.emb_size
+            size = embed_dim * device_vocabulary_size
+        elif not host_vocabulary_size and get_use_dynamic_expansion():
+            embed_dim = table_instance.ext_emb_size
+            size = embed_dim * device_vocabulary_size
+        else:
+            embed_dim = table_instance.ext_emb_size
+            size = (device_vocabulary_size + host_vocabulary_size) * embed_dim
+        return size
+
     @staticmethod
     def _get_own_emb(emb, all2all_args, emb_size, use_static):
         """
@@ -230,27 +263,6 @@ class SparseEmbedding:
 
         return tf.reshape(src_emb, reshape_info)
 
-    @staticmethod
-    def get_emb_table_size(table_name: str) -> int:
-        """
-        For HBM or DDR mode, return the size of sparse embedding table
-        :param table_name: the name of sparse embedding table
-        :return: the size of the sparse embedding table
-        """
-        table_instance = get_table_instance_by_name(table_name)
-        host_vocabulary_size = table_instance.host_vocabulary_size()
-        device_vocabulary_size = table_instance.device_vocabulary_size
-        if not host_vocabulary_size and not get_use_dynamic_expansion():
-            embed_dim = table_instance.emb_size
-            size = embed_dim * device_vocabulary_size
-        elif not host_vocabulary_size and get_use_dynamic_expansion():
-            embed_dim = table_instance.ext_emb_size
-            size = embed_dim * device_vocabulary_size
-        else:
-            embed_dim = table_instance.ext_emb_size
-            size = (device_vocabulary_size + host_vocabulary_size) * embed_dim
-        return size
-
     def check_optimizer_instance(self):
         for optimizer_instance in self._optimizer_instance_list:
             if tf.__version__.startswith("1"):
@@ -561,17 +573,6 @@ class SparseEmbedding:
             dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0)
             self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape)
 
-    def generate_lookup_id_notify_hybrid(self, channel_id: int):
-
-        """
-        Args:
-            channel_id: channel id: 0 for train, 1 for eval
-        Returns: npu_ops.outfeed_enqueue_op notify preprocess step
-        """
-        channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id)
-        notify_hybridmgmt_op = tf.no_op(channel_name)
-        return notify_hybridmgmt_op
-
     def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs):
         """
         Args:
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
index 6c107293..5fab4a8c 100644
--- a/mx_rec/graph/patch.py
+++ b/mx_rec/graph/patch.py
@@ -109,8 +109,8 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
     name2channel_cache = self.get_mxrec_name2channel_cache()
 
     # look up the corresponding channel_id
+    get_all_tensor(fetches)
     try:
-        get_all_tensor(fetches)
         channel_id = get_channel_id_by_sub_graph(all_op, name2channel_cache)
     except AssertionError:
         channel_id = -1
@@ -118,7 +118,7 @@
 
     if channel_id != -1:
         get_asc_manager().block_notify_wake(channel_id)
 
-    #call TensorFlow's native run method
+    # call TensorFlow's native run method
     result = self.old_run_method(fetches, feed_dict, options, run_metadata)
     if channel_id != -1:
get_asc_manager().block_count_steps(channel_id, steps) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index c031648a..5d7feb47 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -52,6 +52,14 @@ class Saver(object): self.save_easy_mode = (global_env.save_easy == Flag.TRUE.value) self.build() + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) + def build(self): if self.var_list is None: self.var_list = [] @@ -213,14 +221,6 @@ class Saver(object): for thread in threads: thread.join() - @staticmethod - def _make_table_name_dir(root_dir, table_instance, table_name): - if table_instance.host_vocabulary_size > 0: - table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) - else: - table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) - tf.io.gfile.makedirs(table_dir) - def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name): host_data = get_host_data(table_name) key = np.array(list(host_data.keys())) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index f6895702..3a4233c0 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -11,7 +11,7 @@ import psutil import mx_rec.constants.constants from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH, \ - TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_RANK_SIZE, MAX_INT32, TFDevice, Flag + TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_INT32, TFDevice, Flag from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import StringValidator, FileValidator, para_checker_decorator, ClassValidator, \ @@ -27,17 +27,17 @@ class ConfigInitializer: host_pipeline_ops = import_host_pipeline_ops() @para_checker_decorator(check_option_list=[ - ("use_mpi", ClassValidator, {"classes": (bool, )}), + ("use_mpi", ClassValidator, {"classes": (bool,)}), ("train_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), ("eval_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), ("save_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), (["train_steps", "eval_steps"], ValueCompareValidator, {"target": 0}, ["check_at_least_one_not_equal_to_target"]), - ("if_load", ClassValidator, {"classes": (bool, )}), - ("use_dynamic", ClassValidator, {"classes": (bool, )}), - ("use_hot", ClassValidator, {"classes": (bool, )}), - ("use_dynamic_expansion", ClassValidator, {"classes": (bool, )}), - ("bind_cpu", ClassValidator, {"classes": (bool, )}), + ("if_load", ClassValidator, {"classes": (bool,)}), + ("use_dynamic", ClassValidator, {"classes": (bool,)}), + ("use_hot", ClassValidator, {"classes": (bool,)}), + ("use_dynamic_expansion", ClassValidator, {"classes": (bool,)}), + ("bind_cpu", ClassValidator, {"classes": (bool,)}), ]) def __init__(self, use_mpi=True, **kwargs): self._use_mpi = use_mpi @@ -84,10 +84,14 @@ class ConfigInitializer: else: raise ValueError("only mpi is supported for launching task.") - self._rank_to_device_dict, self._local_rank_size = parse_hccl_json() if global_env.rank_table_file else \ - set_hccl_info_without_json(visible_devices=global_env.ascend_visible_devices, - 
rank_size=global_env.cm_worker_size, - chief_device=global_env.cm_chief_device) + if global_env.rank_table_file: + self._rank_to_device_dict, self._local_rank_size = parse_hccl_json() + else: + self._rank_to_device_dict, self._local_rank_size = set_hccl_info_without_json( + visible_devices=global_env.ascend_visible_devices, + rank_size=global_env.cm_worker_size, + chief_device=global_env.cm_chief_device) + self.train_steps = kwargs.get("train_steps", -1) self.eval_steps = kwargs.get("eval_steps", -1) self.save_steps = kwargs.get("save_steps", -1) @@ -698,7 +702,7 @@ def export_feature_spec(): @para_checker_decorator(check_option_list=[ - ("if_load", ClassValidator, {"classes": (bool, )}) + ("if_load", ClassValidator, {"classes": (bool,)}) ]) def set_if_load(if_load): ConfigInitializer.get_instance().if_load = if_load @@ -714,7 +718,7 @@ def get_use_static(): def get_stat_on(): return ConfigInitializer.get_instance().stat_on - + def get_use_hot(): return ConfigInitializer.get_instance().use_hot @@ -737,7 +741,7 @@ def get_name_to_var_dict(): @para_checker_decorator(check_option_list=[ - ("is_training", ClassValidator, {"classes": (bool, )}) + ("is_training", ClassValidator, {"classes": (bool,)}) ]) def get_initializer(is_training): return ConfigInitializer.get_instance().get_initializer(is_training) @@ -944,4 +948,4 @@ def bind_cpu(rank_id: int, local_rank_size: int): process.cpu_affinity(cpu_list) except IndexError: logger.error("failed to bind cpu for rank %s: %s", rank_id, cpu_list) - logger.info("bind cpu for rank %s: %s", rank_id, cpu_list) \ No newline at end of file + logger.info("bind cpu for rank %s: %s", rank_id, cpu_list) -- Gitee From bb971d0113489e1e1826a7db89e3fde6700b17fb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 16 Sep 2023 14:32:05 +0800 Subject: [PATCH 353/551] Match-id-4eb473be80c5d0da46bfa1d19bcf9ad0b93c042e --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 14 +++++++++++++- src/core/utils/common.cpp | 19 ------------------- src/core/utils/common.h | 2 -- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2573c4b3..e63fa5a0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -949,7 +949,19 @@ int HybridMgmt::GetStepFromPath(const string& loadPath) regex pattern("sparse-model-\\d+-(\\d+)"); smatch match; if (regex_search(loadPath, match, pattern)) { - return stoi(match[1]); + int res = 0; + unsigned int minSize = 2; + if (match.size() < minSize) { + return res; + } + try { + res = stoi(match[1]); + } catch (const std::invalid_argument& e) { + LOG_ERROR(e.what()); + } catch (const std::out_of_range& e) { + LOG_ERROR(e.what()); + } + return res; } return 0; } diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 31787b75..c0ab83f9 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -105,25 +105,6 @@ namespace MxRec { } } - bool GetEnv(const char *envName) - { - const char* envString = getenv(envName); - int tmp = 0; - if (envString != nullptr) { - try { - tmp = std::stoi(envString); - if (tmp == 0 || tmp == 1) { - LOG_INFO("Succeed to parse ${env:{}}: {}.", envName, tmp); - } else { - LOG_ERROR("Invalid ${env:{}}: {}, which should be an 0 or 1.", envName, tmp); - } - } catch (const std::invalid_argument &e) { - LOG_ERROR("Failed to parse ${env:{}}, which should be an integer.", envName); - } - } - return (tmp == 1) ? 
true : false;
-    }
-
     string GetChipName(int devID)
     {
         int ret = 0;
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 981c93e8..515a2426 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -286,8 +286,6 @@
     void SetLog(int rank);
 
-    bool GetEnv(const char *envName);
-
     template
     string StringFormat(const string& format, Args ... args)
     {
-- Gitee

From 712006be9f4c7aedf6f5afb529364afa1cf40105 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 18 Sep 2023 17:20:49 +0800
Subject: [PATCH 354/551] Match-id-2d014c70516a2d6044680a769956402d238a2fcb
---
 src/core/emb_hashmap/emb_hashmap.cpp | 18 +++++++++---------
 src/core/emb_hashmap/emb_hashmap.h   |  5 +++--
 src/core/ssd_cache/cache_manager.cpp |  4 +++-
 src/core/ssd_cache/lfu_cache.cpp     |  2 +-
 src/core/utils/common.h              |  6 +++++-
 5 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index e965cc56..43cdcb48 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -45,6 +45,7 @@ inline void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap)
 {
     embHashMap.swapPos.clear();
     embHashMap.lookUpVec.clear();
+    embHashMap.ddr2HbmKeys.clear();
 }
 
 /// In DDR mode, compute feature offsets, swap info, etc.
@@ -343,18 +344,16 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys
             embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE);
             continue;
         }
-
+        AddKeyFreqInfo(embName, key, RecordType::NOT_DDR);
         if (offset < embHashMap.devVocabSize) {
             // offset within HBM capacity: push straight into the lookup vector; record the key previously at this offset and the current key
             embHashMap.lookUpVec.emplace_back(offset);
             embHashMap.devOffset2KeyOld.emplace_back(offset, static_cast<emb_key_t>(embHashMap.devOffset2Key[offset]));
             embHashMap.devOffset2Key[offset] = key;
-            AddKeyFreqInfo(embName, key, RecordType::NOT_DDR);
         } else {
             // offset beyond HBM capacity: record the offset into the host emb; find the HBM offset to swap with
             embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize);
             FindSwapPosOld(embName, key, offset, currentBatchId, keepBatchId);
-            AddKeyFreqInfo(embName, key, RecordType::DDR);
         }
     }
     if (currentBatchId == 0) {
@@ -377,6 +376,9 @@ bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashM
     if (iter != embHashMap.hostHashMap.end()) {
         offset = iter->second;
         LOG_TRACE("devVocabSize, {} , offset , {}", embHashMap.devVocabSize, offset);
+        if (isSSDEnabled && offset >= embHashMap.devVocabSize) {
+            embHashMap.ddr2HbmKeys.emplace_back(key);
+        }
     } else if (embHashMap.evictDevPos.size() != 0 && channelId == TRAIN_CHANNEL_ID) { // prefer reusing freed HBM slots
         offset = embHashMap.evictDevPos.back();
         embHashMap.hostHashMap[key] = offset;
@@ -493,7 +495,7 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos
             embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos,
                                                      embHashMap.devOffset2Key[embHashMap.currentUpdatePos]);
             auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos];
-            embHashMap.oldSwap.emplace_back(oldKey, key); // record the two swapped keys
+            embHashMap.oldSwap.emplace_back(oldKey, key); // record the two swapped keys; oldKey: HBM->DDR, key: DDR->HBM
             embHashMap.hostHashMap[oldKey] = hostOffset; // update the offset of the replaced key
             oldKey = key;
             notFind = false;
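The new ddr2HbmKeys list narrows the DDR-to-HBM frequency refresh to keys that actually lived in DDR: a brand-new key that FindSwapPosOld merely assigns through a DDR offset has no DDR frequency record to move. A rough Python model of the distinction, for illustration only (the real maps live in EmbHashMapInfo and CacheManager):

def collect_ddr2hbm_keys(host_hash_map, dev_vocab_size, batch_keys, ssd_enabled):
    """Return only the keys that truly migrate DDR -> HBM in this batch."""
    ddr2hbm = []
    for key in batch_keys:
        offset = host_hash_map.get(key)
        if offset is None:
            continue  # new key: nothing in the DDR frequency map to refresh
        if ssd_enabled and offset >= dev_vocab_size:
            ddr2hbm.append(key)  # existing DDR resident about to enter HBM
    return ddr2hbm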
@@ -506,8 +508,8 @@
         embHashMap.currentUpdatePos = 0;
     }
 
-    // scanned the whole HBM space and found no usable slot: HBM cannot hold the keys of a whole batch (times the prefetch count), training cannot proceed, so exit with a runtime error
-    if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) {
+    // scanned the entire HBM space without finding a usable slot: HBM cannot hold the keys of a whole batch (times the prefetch count), training cannot proceed, so exit with a runtime error
+    if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart && notFind) {
         LOG_ERROR("devVocabSize is too small");
         throw runtime_error("devVocabSize is too small");
     }
@@ -527,13 +529,11 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo&
     auto& oldSwap = embHashMap.oldSwap;
     LOG_DEBUG("RefreshFreqInfoWithSwap:oldSwap Size:{}", oldSwap.size());
     vector<emb_key_t> enterDDRKeys;
-    vector<emb_key_t> leaveDDRKeys;
     for (auto keyPair : oldSwap) {
         enterDDRKeys.emplace_back(keyPair.first);
-        leaveDDRKeys.emplace_back(keyPair.second);
     }
     cacheManager->RefreshFreqInfoCommon(embName, enterDDRKeys, TransferType::HBM_2_DDR);
-    cacheManager->RefreshFreqInfoCommon(embName, leaveDDRKeys, TransferType::DDR_2_HBM);
+    cacheManager->RefreshFreqInfoCommon(embName, embHashMap.ddr2HbmKeys, TransferType::DDR_2_HBM);
     AddCacheManagerTraceLog(embName, embHashMap);
 }
diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h
index 2eecc7f8..654e52ed 100644
--- a/src/core/emb_hashmap/emb_hashmap.h
+++ b/src/core/emb_hashmap/emb_hashmap.h
@@ -77,14 +77,15 @@ namespace MxRec {
 
         int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap);
 
-        void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap);
-
         void AddCacheManagerTraceLog(const string& embName, const EmbHashMapInfo& embHashMap) const;
 
         void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type);
 
         RankInfo rankInfo;
         int swapId { 0 };
+
+    GTEST_PRIVATE:
+        void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap);
     };
 }
diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp
index 5b7c359f..46ddd2f5 100644
--- a/src/core/ssd_cache/cache_manager.cpp
+++ b/src/core/ssd_cache/cache_manager.cpp
@@ -255,7 +255,9 @@ void CacheManager::RefreshFreqInfoCommon(const string& embTableName, vector
     std::vector evictDevPosChange;
     std::vector> devOffset2KeyOld;
     std::vector> oldSwap; // (old on dev, old on host)
-
+    /*
+     * Keys already resident in DDR that move to HBM during an HBM<->DDR swap (new keys excluded); used in SSD mode
+     * (unlike oldSwap, whose pair.second holds the existing DDR keys plus new keys mapped to DDR before the swap)
+     */
+    std::vector<emb_key_t> ddr2HbmKeys;
     void SetStartCount();
     bool HasFree(size_t i);
-- Gitee

From 6697132c256f01e7704a6f7e2e4bbc041e195caf Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 19 Sep 2023 10:53:56 +0800
Subject: [PATCH 355/551] Match-id-05e35bc1f89bc0cbef82af807ac7fb0f25f74725
---
 src/core/emb_hashmap/emb_hashmap.cpp       |  77 +++++++-------
 src/tests/emb_hashmap/emb_hashmap_test.cpp | 111 +++++++++++++++++++++
 2 files changed, 148 insertions(+), 40 deletions(-)
 create mode 100644 src/tests/emb_hashmap/emb_hashmap_test.cpp

diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index 43cdcb48..855f2187 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -19,7 +19,6 @@ using namespace MxRec;
 void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad)
 {
-#ifndef GTEST
     this->rankInfo = rankInfo;
     if (!ifLoad) {
         EmbHashMapInfo embHashMapInfo;
@@ -38,7 +37,6 @@
             LOG_TRACE("devOffset2Batch, {}", VectorToString(embHashMaps.at(embInfo.name).devOffset2Batch));
         }
     }
-#endif
 }
 
 inline void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap)
@@ -319,6 +317,43 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector&
     LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap));
 }
+/*
+ * Use each key's most recently used batchId in devOffset2Batch to pick the keys to evict;
+ * record the eviction positions and the keys needed on the device side
+ */
+int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId,
+                              size_t keepBatchId)
+{
+    bool notFind = true;
+    auto& embHashMap = embHashMaps.at(embName);
+    int newDevOffset;
+    while (notFind) {
+        if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast<int>(keepBatchId)) {
+            embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast<int>(currentBatchId);
+            embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos);
+            newDevOffset = static_cast<int>(embHashMap.currentUpdatePos);
+            embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos;
+            embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos,
+                embHashMap.devOffset2Key[embHashMap.currentUpdatePos]);
+            auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos];
+            embHashMap.oldSwap.emplace_back(oldKey, key);
+            embHashMap.hostHashMap[oldKey] = hostOffset;
+            oldKey = key;
+            notFind = false;
+        }
+        embHashMap.currentUpdatePos++;
+        embHashMap.freeSize--;
+        if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) {
+            embHashMap.currentUpdatePos = 0;
+        }
+        if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) {
+            LOG_ERROR("devVocabSize is too small");
+            throw runtime_error("devVocabSize is too small");
+        }
+    }
+    return newDevOffset;
+}
+#endif
+
 /// Get each key's position from embHashMaps and build the lookup vector; update devOffset2Batch; record the offsets that dev and host need to swap
 /// \param embName table name
 /// \param keys lookup keys
@@ -436,42 +471,6 @@ void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatc
     }
 }
 
-/*
- * Use each key's most recently used batchId in devOffset2Batch to pick the keys to evict;
- * record the eviction positions and the keys needed on the device side
- */
-int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId,
-                              size_t keepBatchId)
-{
-    bool notFind = true;
-    auto& embHashMap = embHashMaps.at(embName);
-    int newDevOffset;
-    while (notFind) {
-        if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast<int>(keepBatchId)) {
-            embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast<int>(currentBatchId);
-            embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos);
-            newDevOffset = static_cast<int>(embHashMap.currentUpdatePos);
-            embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos;
-            embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos,
-                embHashMap.devOffset2Key[embHashMap.currentUpdatePos]);
-            auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos];
-            embHashMap.oldSwap.emplace_back(oldKey, key);
-            embHashMap.hostHashMap[oldKey] = hostOffset;
-            oldKey = key;
-            notFind = false;
-        }
-        embHashMap.currentUpdatePos++;
-        embHashMap.freeSize--;
-        if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) {
-            embHashMap.currentUpdatePos = 0;
-        }
-        if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) {
-            LOG_ERROR("devVocabSize is too small");
-            throw runtime_error("devVocabSize is too small");
-        }
-    }
-    return newDevOffset;
-}
-
 /// Use each key's most recently used batchId in devOffset2Batch to pick the keys to evict; record the eviction positions and the keys needed on the device side
 /// \param embName table name
 /// \param key input feature
@@ -591,5 +590,3 @@ void EmbHashMap::AddKeyFreqInfo(const string& embTableName, const emb_key_t& key
     }
     cacheManager->PutKey(embTableName, key, type);
 }
-
-#endif
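FindSwapPosV2, moved above under the GTEST guard, is a clock-style scan: advance a cursor over the device table until a slot's last-use batch id is older than keepBatchId, wrapping at devVocabSize; a full lap without a hit means HBM is too small. The control flow in Python, simplified for illustration (the C++ version also advances the cursor after a hit so the next search starts one slot further on):

def find_swap_pos(dev_offset2batch, start_pos, keep_batch_id, current_batch_id):
    pos = start_pos
    while True:
        if dev_offset2batch[pos] < keep_batch_id:  # stale slot: evict and reuse it
            dev_offset2batch[pos] = current_batch_id
            return pos
        pos = (pos + 1) % len(dev_offset2batch)  # wrap at devVocabSize
        if pos == start_pos:  # scanned the whole table without finding a stale slot
            raise RuntimeError("devVocabSize is too small")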
+// Verify that the CacheManager frequency info is refreshed on HBM<->DDR swaps
+TEST(EmbHashMap, TestFindOffset)
+{
+    LOG_INFO("start TestFindOffset");
+    string embTableName = "table1";
+    EmbHashMap hostHashMaps;
+    RankInfo rankInfo;
+    auto embInfo = GetEmbInfoList();
+    hostHashMaps.Init(rankInfo, embInfo, false);
+    CacheManager cacheManager;
+    cacheManager.Init(nullptr, embInfo);
+    bool isSSDEnabled = true;
+    hostHashMaps.isSSDEnabled = isSSDEnabled;
+    hostHashMaps.cacheManager = &cacheManager;
+    int channelId = 0;
+    size_t currentBatchId = 0;
+    size_t keepBatchId = 0;
+    int opTimes = 0;
+
+    vector<emb_key_t> keys = {1, 2, 3, 4, 5};
+    hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId);
+    RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++);
+
+    vector<emb_key_t> keys2 = {6, 7, 8, 9, 10};
+    hostHashMaps.FindOffset(embTableName, keys2, currentBatchId++, keepBatchId++, channelId);
+    RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++);
+
+    auto& excludeKeyMap = cacheManager.excludeDDRKeyCountMap[embTableName];
+    auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName];
+
+    auto logLevelTemp = Log::GetLevel();
+    Log::SetLevel(Log::TRACE);
+    vector<emb_key_t> keys4 = {21, 21, 21, 21}; // duplicated new key that also requires a swap
+    hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId);
+    RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++);
+    ASSERT_EQ(excludeKeyMap[INT_21], INT_4);
+    ASSERT_EQ(ddrKeyMap.Get(1), 1);
+
+    keys4 = {41, 42, 43, 44, 45, 46, 47, 48, 49, 50}; // swap a whole HBM's worth of keys
+    hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId);
+    RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++);
+    ASSERT_EQ(ddrKeyMap.Get(INT_21), INT_4);
+
+    keys4 = {51, 52, 53, 1, 2, 21, 41, 42, 43, 44}; // 3 new keys, 3 in DDR, 4 in HBM
+    hostHashMaps.FindOffset(embTableName, keys4, currentBatchId, keepBatchId, channelId);
+    RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes);
+    ASSERT_EQ(excludeKeyMap[1], INT_2);
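+    // keys 1 and 42 have each been looked up twice overall, so a count of 2
+    // is expected for both in the exclude-DDR (HBM-resident) count map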
+    ASSERT_EQ(excludeKeyMap[INT_42], INT_2);
+    ASSERT_EQ(ddrKeyMap.Get(INT_21), NEGATIVE_INT_1);
+    ASSERT_EQ(ddrKeyMap.Get(1), NEGATIVE_INT_1);
+    Log::SetLevel(logLevelTemp); // restore the log level
+    LOG_INFO("test TestFindOffset end.");
+}
\ No newline at end of file
-- Gitee

From d8ea2bc5e95b1a037baa50a4b04bcbca42d49cd8 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 20 Sep 2023 17:34:31 +0800
Subject: [PATCH 356/551] Match-id-194a1eb3a31329950e8c21ad794e556b1e47ab86

---
 mx_rec/validator/validator.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py
index e26a5b68..6d1bfbbd 100644
--- a/mx_rec/validator/validator.py
+++ b/mx_rec/validator/validator.py
@@ -324,6 +324,9 @@ class NumValidator(Validator):
                  invalid_options: List = None, constrained_options: List = None, msg: str = ""):
         if isinstance(value, tf.TensorShape) and value.ndims == 1:
             value = value.as_list()[0]
+        if isinstance(value, tf.Tensor):
+            with tf.compat.v1.Session() as sess:
+                value = sess.run(value).item()
 
         super(NumValidator, self).__init__(name, value)
         self.min_value = min_value
-- Gitee

From 6437f9114a2745eac90a456fa8ec6c74a2259d85 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 21 Sep 2023 10:15:03 +0800
Subject: [PATCH 357/551] Match-id-963748da9e0612963293eb18b4dae791452c1cb2

---
 src/CMakeLists.txt       | 13 ++++++++-----
 src/core/CMakeLists.txt  |  1 -
 src/tests/CMakeLists.txt |  2 +-
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 590d7e58..6d1aa6de 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -23,6 +23,13 @@ set(CMAKE_PREFIX_PATH ${OMPI_PATH} ${HDF5_PATH})
 find_package(OpenMP REQUIRED)
 find_package(MPI REQUIRED)
 find_package(PythonLibs 3.7 REQUIRED)
+
+find_program(CCACHE_FOUND ccache)
+if(CCACHE_FOUND)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+    set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ccache)
+endif(CCACHE_FOUND)
+
 if (CMAKE_BUILD_TYPE MATCHES "Debug")
     find_package(easy_profiler)
 else ()
@@ -35,7 +42,7 @@ else ()
     ADD_DEFINITIONS(-DBUILD_WITH_EASY_PROFILER)
 endif ()
 set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb")
-set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wall -DNDEBUG -fPIC -fstack-protector-all -Wextra -Wconversion -D_FORTIFY_SOURCE=2 -s")
+set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wfatal-errors -DNDEBUG -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -s")
 set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack")
 
 option(ENABLE_DEBUG "use debug mode" OFF)
@@ -94,10 +101,6 @@
 
 if(IS_DIRECTORY ${OPENSOURCE_DIR})
     add_subdirectory(${OPENSOURCE_DIR}/pybind11 pybind11.out)
-
-    add_subdirectory(${OPENSOURCE_DIR}/glog glog.out)
-    include_directories(glog.out)
-    install(TARGETS glog LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX})
 else()
     message(FATAL_ERROR "INVALID FOLDER, ${OPENSOURCE_DIR}")
 endif()
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index 725d1b14..e8eda15c 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -48,7 +48,6 @@ target_link_libraries(ASC PUBLIC
         -l:_tf_adapter.so
         OpenMP::OpenMP_CXX
         ${MPI_CXX_LIBRARIES}
         ${PYTHON_LIBRARY}
-        glog::glog
         )
 find_package(easy_profiler PATHS ${EASY_PROFILER_PATH} NO_DEFAULT_PATH)
 if (easy_profiler_FOUND)
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 302d7bc9..9854324a 100644
--- 
a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -62,4 +62,4 @@ target_link_libraries(test_main PUBLIC MPI::MPI_CXX) target_link_libraries(test_main PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf - profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host glog::glog) \ No newline at end of file + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) -- Gitee From fc9faa84c188b1e177dd0f1a817fc989e2285433 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 21 Sep 2023 11:07:36 +0800 Subject: [PATCH 358/551] Match-id-394d4dd51973d6c3cb2d8938e0d0be22aba3f181 --- mx_rec/util/initialize.py | 5 +++-- mx_rec/validator/validator.py | 18 +++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 3a4233c0..6903f65d 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -15,7 +15,7 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHT from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import StringValidator, FileValidator, para_checker_decorator, ClassValidator, \ - IntValidator, ValueCompareValidator + IntValidator, ValueCompareValidator, OptionalStringValidator from mx_rec.util.atomic import AtomicInteger from mx_rec.util.global_env_conf import global_env from mx_rec.util.log import logger @@ -438,7 +438,8 @@ class ConfigInitializer: @para_checker_decorator(check_option_list=[ - ("name", ClassValidator, {"classes": (str, type(None))}) + ("name", ClassValidator, {"classes": (str, type(None))}), + ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ]) def set_ascend_global_hashtable_collection(name=ASCEND_GLOBAL_HASHTABLE_COLLECTION): ConfigInitializer.get_instance().ascend_global_hashtable_collection = name diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index e26a5b68..f314e810 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -65,6 +65,18 @@ def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str], def para_checker(func): @functools.wraps(func) def wrapper(*args, **kwargs): + signature = inspect.signature(func) + # 获取实际传入的参数,包括默认值参数 + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + actual_args = dict(bound_args.arguments) + + temp_kwargs = dict() + if "kwargs" in actual_args: + temp_kwargs = actual_args["kwargs"] + del actual_args["kwargs"] + actual_args.update(temp_kwargs) + func_spec = inspect.getfullargspec(func) # 将函数有默认值的参数加入kwargs args_with_default = set() @@ -75,7 +87,7 @@ def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str], continue args_with_default.add(arg) kwargs.update({arg: default}) - logger.debug("[checker wrapper]func %s args: %s, kwargs: %s", func.__name__, args, kwargs) + logger.debug("[checker wrapper]func %s kwargs: %s", func.__name__, actual_args) # 执行每一个检查项 for option in check_option_list: optional_check_list = None @@ -96,11 +108,11 @@ def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str], # 确认当前检查项需要检查的参数是否在函数参数中 paras = [] for para_to_be_check in para_list_to_be_check: - if para_to_be_check not in kwargs: + if para_to_be_check not in actual_args: logger.debug("[checker wrapper]invalid 
para '%s' to be checked, " "not passed to the function '%s'", para_to_be_check, func.__name__) continue - paras.append(kwargs.get(para_to_be_check)) + paras.append(actual_args.get(para_to_be_check)) # 如果检查的参数不在传参中,跳过该检查项 if not paras: -- Gitee From 7e7d2559786779af197bef5618a324946e7afc99 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 20 Sep 2023 16:28:54 +0800 Subject: [PATCH 359/551] Match-id-77bb456d19f68d9fc1d8732d9f91dc3d62afce37 --- src/core/checkpoint/checkpoint.cpp | 69 ++++----- src/core/checkpoint/checkpoint.h | 14 +- .../emb_hash_ckpt/emb_hash_ckpt.cpp | 29 ++-- .../feat_admit_n_evict_ckpt.cpp | 14 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 15 +- .../host_emb_ckpt/host_emb_ckpt.h | 2 +- .../key_freq_map_ckpt/key_freq_map_ckpt.cpp | 8 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 2 +- .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 2 +- src/core/emb_hashmap/emb_hashmap.cpp | 22 +-- src/core/emb_hashmap/emb_hashmap.h | 14 +- src/core/emb_table/emb_table.cpp | 20 +-- src/core/emb_table/emb_table.h | 11 +- src/core/hd_transfer/acl_channel.h | 8 +- src/core/host_emb/host_emb.cpp | 6 +- src/core/host_emb/host_emb.h | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 75 +++++----- src/core/hybrid_mgmt/hybrid_mgmt.h | 24 ++- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 23 ++- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 138 ++++++++++-------- .../constant_initializer.h | 2 - src/core/initializer/initializer.cpp | 46 +++++- src/core/initializer/initializer.h | 53 ++++++- .../random_normal_initializer.cpp | 8 +- .../random_normal_initializer.h | 2 +- .../truncated_normal_initializer.cpp | 8 +- .../truncated_normal_initializer.h | 2 +- .../key_process/feature_admit_and_evict.cpp | 2 +- .../key_process/feature_admit_and_evict.h | 4 +- src/core/key_process/key_process.cpp | 125 ++++++---------- src/core/key_process/key_process.h | 46 +++++- src/core/ssd_cache/cache_manager.cpp | 40 ++--- src/core/ssd_cache/cache_manager.h | 14 +- src/core/ssd_engine/file.cpp | 6 +- src/core/ssd_engine/file.h | 6 +- src/core/ssd_engine/ssd_engine.cpp | 20 +-- src/core/utils/common.cpp | 38 ----- src/core/utils/common.h | 43 +----- src/ops_tf/hybrid_dataset_ops.cpp | 4 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 12 +- src/tests/initializer/initializer_test.cpp | 8 +- src/tests/key_process/key_process_test.cpp | 1 + 42 files changed, 508 insertions(+), 482 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 4bbcaad7..195be601 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -24,13 +24,13 @@ using namespace std; using namespace MxRec; -void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo) +void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& embInfo) { processPath = savePath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; - mgmtEmbInfo = EmbInfo; + mgmtEmbInfo = embInfo; LOG_INFO("Start host side saving data."); LOG_DEBUG("==Start to create save data handler."); @@ -40,14 +40,14 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa LOG_INFO("Finish host side saving data."); } -void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, +void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& embInfo, const vector& 
featureTypes) { processPath = loadPath; rankId = mgmtRankInfo.rankId; deviceId = mgmtRankInfo.deviceId; useDynamicExpansion = mgmtRankInfo.useDynamicExpansion; - mgmtEmbInfo = EmbInfo; + mgmtEmbInfo = embInfo; LOG_INFO("Start host side loading data."); LOG_DEBUG("==Start to create load data handler."); @@ -83,13 +83,14 @@ void Checkpoint::SetDataHandler(CkptData& ckptData) void Checkpoint::SetDataHandler(const vector& featureTypes) { - map> setCkptMap { { CkptFeatureType::HOST_EMB, - [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::EMB_HASHMAP, [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::MAX_OFFSET, [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::KEY_OFFSET_MAP, [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::FEAT_ADMIT_N_EVICT, [&] { dataHandlers.push_back(make_unique()); } }, - { CkptFeatureType::DDR_KEY_FREQ_MAP, [&] { dataHandlers.push_back(make_unique()); } } }; + map> setCkptMap{ + {CkptFeatureType::HOST_EMB, [this] { dataHandlers.push_back(make_unique()); }}, + {CkptFeatureType::EMB_HASHMAP, [this] { dataHandlers.push_back(make_unique()); }}, + {CkptFeatureType::MAX_OFFSET, [this] { dataHandlers.push_back(make_unique()); }}, + {CkptFeatureType::KEY_OFFSET_MAP, [this] { dataHandlers.push_back(make_unique()); }}, + {CkptFeatureType::FEAT_ADMIT_N_EVICT, [this] { dataHandlers.push_back(make_unique()); }}, + {CkptFeatureType::DDR_KEY_FREQ_MAP, [this] { dataHandlers.push_back(make_unique()); }} + }; for (const auto& featureType : featureTypes) { setCkptMap.at(featureType)(); @@ -139,7 +140,7 @@ void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, } } -void Checkpoint::MakeSaveDir(const string& dirName) +void Checkpoint::MakeSaveDir(const string& dirName) const { if (access(dirName.c_str(), F_OK) == -1) { if (mkdir(dirName.c_str(), dirMode) == -1) { @@ -164,7 +165,7 @@ Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName) bool Checkpoint::CheckEmbNames(const string& embName) { for (const auto &embInfo: mgmtEmbInfo) { - if (embInfo.name == embName && embInfo.isSave == true) { + if (embInfo.name == embName && embInfo.isSave) { return true; } } @@ -233,7 +234,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da throw runtime_error(Log::Format("aclrtMemcpy failed, ret={}", ret).c_str()); } - writeFile.write((const char *) (row.data()), embeddingSize * sizeof(float)); + writeFile.write(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); } #endif writeFile.close(); @@ -260,8 +261,8 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); } - auto &AttributeArr = transData.attribute; - auto embHashMapSize = AttributeArr.at(0); + auto &attributeArr = transData.attribute; + auto embHashMapSize = attributeArr.at(0); if (embHashMapSize <= 0) { throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str()); } @@ -286,16 +287,16 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, if (keyAddrElem < 0) { throw runtime_error(StringFormat("keyAddrElem: %d is less than 0", keyAddrElem).c_str()); } - for (size_t i{0}, j{0}; i < transArr.size(); i += keyAddrElem, ++j) { + for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { vector row(embeddingSize); - readFile.read((char *) (row.data()), embeddingSize * sizeof(float)); + 
readFile.read(reinterpret_cast (row.data()), embeddingSize * sizeof(float)); - aclError ret = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), - row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMemcpy failed, ret={}", ret); + aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), + row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); + if (ec != ACL_SUCCESS) { + LOG_ERROR("aclrtMemcpy failed, ret={}", ec); readFile.close(); - throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str()); } int64_t address = reinterpret_cast(floatPtr + j * embeddingSize); transArr.at(i + 1) = address; @@ -330,7 +331,7 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si writeSize = dataCol; } if (floatTransSet.find(dataType) != floatTransSet.end()) { - writeFile.write((const char*)(transData.floatArr[i]) + idx, writeSize); + writeFile.write(reinterpret_cast(transData.floatArr[i]) + idx, writeSize); } else { WriteDataset(transData, writeFile, writeSize, dataType, idx); } @@ -350,11 +351,11 @@ void Checkpoint::WriteDataset(CkptTransData& transData, size_t idx) { if (int32TransSet.find(dataType) != int32TransSet.end()) { - writeFile.write((const char*)(transData.int32Arr.data()) + idx, writeSize); + writeFile.write(reinterpret_cast(transData.int32Arr.data()) + idx, writeSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - writeFile.write((const char*)(transData.int64Arr.data()) + idx, writeSize); + writeFile.write(reinterpret_cast(transData.int64Arr.data()) + idx, writeSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - writeFile.write((const char*)(transData.attribute.data()) + idx, writeSize); + writeFile.write(reinterpret_cast(transData.attribute.data()) + idx, writeSize); } } @@ -388,7 +389,7 @@ vector Checkpoint::GetEmbedTableNames() { vector loadTableNames; for (const auto& embInfo : mgmtEmbInfo) { - if (embInfo.isSave == true) { + if (embInfo.isSave) { loadTableNames.push_back(embInfo.name); } } @@ -512,7 +513,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, CkptData& ckptData, - string embName) + string embName) const { if (dataElmtBytes == 0) { LOG_ERROR("dataElmtBytes is 0, don't handle [/ %] operation"); @@ -558,7 +559,7 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } else { readSize = dataCol; } - readFile.read((char*)(dst[i].data()) + idx, readSize); + readFile.read(reinterpret_cast(dst[i].data()) + idx, readSize); dataCol -= readSize; idx += readSize; } @@ -586,12 +587,12 @@ void Checkpoint::ReadDataset(CkptTransData& transData, size_t idx) { if (int32TransSet.find(dataType) != int32TransSet.end()) { - readFile.read((char*)(transData.int32Arr.data()) + idx, readSize); + readFile.read(reinterpret_cast(transData.int32Arr.data()) + idx, readSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - readFile.read((char*)(transData.int64Arr.data()) + idx, readSize); + readFile.read(reinterpret_cast(transData.int64Arr.data()) + idx, readSize); } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - readFile.read((char*)(transData.floatArr.data()) + idx, readSize); + readFile.read(reinterpret_cast(transData.floatArr.data()) + idx, readSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - 
readFile.read((char*)(transData.attribute.data()) + idx, readSize); + readFile.read(reinterpret_cast(transData.attribute.data()) + idx, readSize); } } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index b2c1124d..f7ad4ccc 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -22,8 +22,8 @@ namespace MxRec { Checkpoint() = default; ~Checkpoint() {}; - void SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo); - void LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& EmbInfo, + void SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& embInfo); + void LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRankInfo, const vector& embInfo, const vector& featureTypes); private: @@ -77,7 +77,7 @@ namespace MxRec { void MakeUpperLayerSaveDir(const vector& dirNames); void MakeDataLayerSaveDir(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler); - void MakeSaveDir(const string& dirName); + void MakeSaveDir(const string& dirName) const; void SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler); void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); @@ -87,11 +87,11 @@ namespace MxRec { void ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName); struct EmbSizeInfo { - int embSize; - int extEmbSize; // embSize + (optimizer's slot) * embSize + int embSize = 0; + int extEmbSize = 0; // embSize + (optimizer's slot) * embSize }; EmbSizeInfo GetEmbeddingSize(const string& embName); - bool CheckEmbNames(const string& embNames); + bool CheckEmbNames(const string& embName); void LoadProcess(CkptData& ckptData); void GetUpperLayerLoadDir(const vector& dirNames); @@ -101,7 +101,7 @@ namespace MxRec { const unique_ptr& dataHandler, CkptData& ckptData); void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); void ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, - CkptData& ckptData, string embName); + CkptData& ckptData, string embName) const; void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, size_t idx); diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp index d8bf2f07..f334e49c 100644 --- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp @@ -48,11 +48,11 @@ vector EmbHashCkpt::GetEmbNames() CkptTransData EmbHashCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::EMB_HASHMAP, - [=] { SetEmbHashMapTrans(embName); } }, - { CkptDataType::DEV_OFFSET, [=] { SetDevOffsetTrans(embName); } }, - { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStatTrans(embName); } }, - { CkptDataType::EVICT_POS, [=] { SetEvictPosTrans(embName); } } }; + map> dataTransMap { + {CkptDataType::EMB_HASHMAP, [this, embName] { SetEmbHashMapTrans(embName); }}, + {CkptDataType::DEV_OFFSET, [this, embName] { SetDevOffsetTrans(embName); }}, + {CkptDataType::EMB_CURR_STAT, [this, embName] { SetEmbCurrStatTrans(embName); }}, + {CkptDataType::EVICT_POS, [this, embName] { 
SetEvictPosTrans(embName); }}}; CleanTransfer(); dataTransMap.at(dataType)(); @@ -61,10 +61,11 @@ CkptTransData EmbHashCkpt::GetDataset(CkptDataType dataType, string embName) void EmbHashCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::EMB_HASHMAP, [=] { SetEmbHashMap(embName); } }, - { CkptDataType::DEV_OFFSET, [=] { SetDevOffset(embName); } }, - { CkptDataType::EMB_CURR_STAT, [=] { SetEmbCurrStat(embName); } }, - { CkptDataType::EVICT_POS, [=] { SetEvictPos(embName); } } }; + map> dataLoadMap { + {CkptDataType::EMB_HASHMAP, [this, embName] { SetEmbHashMap(embName); }}, + {CkptDataType::DEV_OFFSET, [this, embName] { SetDevOffset(embName); }}, + {CkptDataType::EMB_CURR_STAT, [this, embName] { SetEmbCurrStat(embName); }}, + {CkptDataType::EVICT_POS, [this, embName] { SetEvictPos(embName); }}}; CleanTransfer(); transferData = move(loadedData); @@ -107,8 +108,8 @@ void EmbHashCkpt::SetDevOffsetTrans(string embName) transferData.attributeSize = attribute.size() * eightBytes; transArr.reserve(embDevOffsetSize); - transArr.insert(transArr.end(), devOffset2Batch.begin(), devOffset2Batch.end()); - transArr.insert(transArr.end(), devOffset2Key.begin(), devOffset2Key.end()); + transArr.insert(transArr.cend(), devOffset2Batch.cbegin(), devOffset2Batch.cend()); + transArr.insert(transArr.cend(), devOffset2Key.cbegin(), devOffset2Key.cend()); } void EmbHashCkpt::SetEmbCurrStatTrans(string embName) @@ -142,7 +143,7 @@ void EmbHashCkpt::SetEvictPosTrans(string embName) transferData.attributeSize = attribute.size() * eightBytes; transArr.reserve(evictPosSize); - transArr.insert(transArr.end(), evictPos.begin(), evictPos.end()); + transArr.insert(transArr.cend(), evictPos.cbegin(), evictPos.cend()); } void EmbHashCkpt::SetEmbHashMap(string embName) @@ -168,7 +169,7 @@ void EmbHashCkpt::SetDevOffset(string embName) dev2Key.reserve(attribute.at(attrbDev2KeyIdx)); fill(dev2Batch.begin(), dev2Batch.end(), -1); - dev2Key.insert(dev2Key.begin(), transArr.begin() + attribute.at(attrbDev2BatchIdx), transArr.end()); + dev2Key.insert(dev2Key.cbegin(), transArr.cbegin() + attribute.at(attrbDev2BatchIdx), transArr.cend()); } void EmbHashCkpt::SetEvictPos(string embName) @@ -179,7 +180,7 @@ void EmbHashCkpt::SetEvictPos(string embName) evictPos.resize(attribute.at(attrEvictPosIdx)); fill(evictPos.begin(), evictPos.end(), -1); - evictPos.insert(evictPos.begin(), transArr.begin() + attribute.at(attrEvictPosIdx), transArr.end()); + evictPos.insert(evictPos.cbegin(), transArr.cbegin() + attribute.at(attrEvictPosIdx), transArr.cend()); } diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index f9179635..4d3c7444 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -51,9 +51,9 @@ vector FeatAdmitNEvictCkpt::GetEmbNames() CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::TABLE_2_THRESH, - [=] { SetTable2ThreshTrans(embName); } }, - { CkptDataType::HIST_REC, [=] { SetHistRecTrans(embName); } } }; + map> dataTransMap{ + {CkptDataType::TABLE_2_THRESH, [this, embName] { SetTable2ThreshTrans(embName); }}, + {CkptDataType::HIST_REC, [this, embName] { SetHistRecTrans(embName); }}}; CleanTransfer(); dataTransMap.at(dataType)(); @@ -62,9 +62,9 
@@ CkptTransData FeatAdmitNEvictCkpt::GetDataset(CkptDataType dataType, string embN void FeatAdmitNEvictCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { - map> dataLoadMap { { CkptDataType::TABLE_2_THRESH, - [=] { SetTable2Thresh(embName); } }, - { CkptDataType::HIST_REC, [=] { SetHistRec(embName); } } }; + map> dataLoadMap{ + {CkptDataType::TABLE_2_THRESH, [this, embName] { SetTable2Thresh(embName); }}, + {CkptDataType::HIST_REC, [this, embName] { SetHistRec(embName); }}}; CleanTransfer(); transferData = move(loadedData); @@ -144,7 +144,7 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) for (size_t i = featItemInfoOffset; i < featItemInfoTotalSize + featItemInfoOffset; i += featItemInfoSaveNum) { process = i % printPerStep; if (process == 1) { - LOG_DEBUG("====in SetHistRec, process : %f", i/featItemInfoTotalSize); + LOG_DEBUG("====in SetHistRec, process : %f", i / featItemInfoTotalSize); } auto featureId = transArr[i + featureIdIdxOffset]; auto count = transArr[i + countIdxOffset]; diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index d0fa9499..9eb8d2b7 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -49,8 +49,9 @@ vector HostEmbCkpt::GetEmbNames() // save info and data CkptTransData HostEmbCkpt::GetDataset(CkptDataType dataType, string embName) { - map> dataTransMap { { CkptDataType::EMB_INFO, [=] { SetEmbInfoTrans(embName); } }, - { CkptDataType::EMB_DATA, [=] { SetEmbDataTrans(embName); } } }; + map> dataTransMap{ + {CkptDataType::EMB_INFO, [this, embName] { SetEmbInfoTrans(embName); }}, + {CkptDataType::EMB_DATA, [this, embName] { SetEmbDataTrans(embName); }}}; CleanTransfer(); dataTransMap.at(dataType)(); @@ -61,16 +62,15 @@ void HostEmbCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransDat { LOG_INFO("Parameter dataType:{}, embName:{}, loadedData:{}", dataType, embName, loadedData.datasetSize); - return; } // load info and data void HostEmbCkpt::SetDatasetForLoadEmb(CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData) { - map> dataLoadMap { - { CkptDataType::EMB_INFO, [&] { SetEmbInfo(embName, ckptData); } }, - { CkptDataType::EMB_DATA, [&] { SetEmbData(embName, ckptData); } } }; + map> dataLoadMap{ + {CkptDataType::EMB_INFO, [this, embName, &ckptData] { SetEmbInfo(embName, ckptData); }}, + {CkptDataType::EMB_DATA, [this, embName, &ckptData] { SetEmbData(embName, ckptData); }}}; CleanTransfer(); transferData = move(loadedData); @@ -116,10 +116,9 @@ void HostEmbCkpt::SetEmbInfo(string embName, CkptData& ckptData) } // load Emb data -void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) +void HostEmbCkpt::SetEmbData(string embName, CkptData& ckptData) const { LOG_INFO("Parameter embName:{}, ckptData:{}", embName, ckptData.embHashMaps.empty()); - return; } int HostEmbCkpt::GetEmbInfoSize() diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h index c1b95ff8..1dc34691 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h @@ -51,7 +51,7 @@ namespace MxRec { void SetEmbDataTrans(string embName); void SetEmbInfo(string embName, CkptData& ckptData); - void SetEmbData(string embName, CkptData& ckptData); + void SetEmbData(string embName, CkptData& ckptData) const; 
int GetEmbInfoSize(); size_t GetEmbDataSize(string embName); diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp index f9377af7..1609faaf 100644 --- a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp @@ -54,8 +54,8 @@ vector KeyFreqMapCkpt::GetEmbNames() CkptTransData KeyFreqMapCkpt::GetDataset(CkptDataType dataType, string embName) { map> dataTransMap { - { CkptDataType::DDR_FREQ_MAP, [=] { SetDDRFreqMapTrans(embName); } }, - { CkptDataType::EXCLUDE_FREQ_MAP, [=] { SetExcludeDDRFreqMapTrans(embName); } } }; + {CkptDataType::DDR_FREQ_MAP, [this, embName] { SetDDRFreqMapTrans(embName); }}, + {CkptDataType::EXCLUDE_FREQ_MAP, [this, embName] { SetExcludeDDRFreqMapTrans(embName); }}}; CleanTransfer(); dataTransMap.at(dataType)(); @@ -65,8 +65,8 @@ CkptTransData KeyFreqMapCkpt::GetDataset(CkptDataType dataType, string embName) void KeyFreqMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) { map> dataLoadMap { - { CkptDataType::DDR_FREQ_MAP, [=] { SetDDRFreqMaps(embName); } }, - { CkptDataType::EXCLUDE_FREQ_MAP, [=] { SetExcludeDDRFreqMaps(embName); } } }; + {CkptDataType::DDR_FREQ_MAP, [this, embName] { SetDDRFreqMaps(embName); }}, + {CkptDataType::EXCLUDE_FREQ_MAP, [this, embName] { SetExcludeDDRFreqMaps(embName); }}}; CleanTransfer(); transferData = move(loadedData); diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index a3f36f4d..46b8e8b2 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -37,7 +37,7 @@ vector NddrFeatMapCkpt::GetDirNames() vector NddrFeatMapCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveKeyOffsetMap) { + for (const auto& item : as_const(saveKeyOffsetMap)) { embNames.push_back(item.first); } return embNames; diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp index 261911b2..41e81a0c 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp @@ -38,7 +38,7 @@ vector NddrOffsetCkpt::GetDirNames() vector NddrOffsetCkpt::GetEmbNames() { vector embNames; - for (const auto& item : saveMaxOffset) { + for (const auto& item : as_const(saveMaxOffset)) { embNames.push_back(item.first); } return embNames; diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 855f2187..863e1f6a 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -17,9 +17,9 @@ using namespace MxRec; -void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad) +void EmbHashMap::Init(const RankInfo& ri, const vector& embInfos, bool ifLoad) { - this->rankInfo = rankInfo; + this->rankInfo = ri; if (!ifLoad) { EmbHashMapInfo embHashMapInfo; LOG_INFO("init emb hash map from scratch"); @@ -39,7 +39,7 @@ void EmbHashMap::Init(const RankInfo& rankInfo, const vector& embInfos, } } -inline void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) +void EmbHashMap::ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) const { embHashMap.swapPos.clear(); embHashMap.lookUpVec.clear(); @@ -153,7 
+153,7 @@ void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& k } } -int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) +int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) const { int32_t offset; const auto& iter = embHashMap.hostHashMap.find(key); @@ -219,7 +219,7 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map LOG_DEBUG(HYBRID_BLOCKING + " start GetHashMaps"); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); auto embHashMapsOld = embHashMaps; - int checkResult = hybridMgmtBlock->CheckSaveEmbdMapValid(); + int checkResult = hybridMgmtBlock->CheckSaveEmbMapValid(); if (checkResult == 0) { // 检查是否需要回退 return embHashMapsOld; @@ -236,8 +236,8 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map embHashMap.hostHashMap[oldKey] = static_cast(tempOffset); } embHashMap.maxOffset = embHashMap.maxOffsetOld; - for (auto &Offset2Key: embHashMap.devOffset2KeyOld) { - embHashMap.devOffset2Key[Offset2Key.first] = Offset2Key.second; + for (auto &offset2Key: embHashMap.devOffset2KeyOld) { + embHashMap.devOffset2Key[offset2Key.first] = offset2Key.second; } } return embHashMapsOld; @@ -264,7 +264,7 @@ void EmbHashMapInfo::SetStartCount() /// 判断HBM是否有剩余空间 /// \param i 查询向量的大小 /// \return -bool EmbHashMapInfo::HasFree(size_t i) +bool EmbHashMapInfo::HasFree(size_t i) const { return freeSize < i; } @@ -404,7 +404,7 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys /// \param channelId 通道索引(训练/推理) /// \param offset 未初始化变量,用于记录 /// \return -bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) +bool EmbHashMap::FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) const { const auto& iter = embHashMap.hostHashMap.find(key); @@ -519,7 +519,7 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, emb_key_t key, size_t hos /// HBM-DDR换入换出时刷新频次信息 /// \param embName emb表名 /// \param embHashMap emb hash map -void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) +void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const { if (!isSSDEnabled) { return; @@ -583,7 +583,7 @@ void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHa /// \param embTableName emb表名 /// \param key key /// \param type 记录类型枚举 -void EmbHashMap::AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type) +void EmbHashMap::AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type) const { if (!isSSDEnabled) { return; diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 654e52ed..df44e74c 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -24,7 +24,7 @@ namespace MxRec { public: EmbHashMap() = default; - void Init(const RankInfo& rankInfo, const vector& embInfos, bool ifLoad = false); + void Init(const RankInfo& ri, const vector& embInfos, bool ifLoad = false); void Process(const string& embName, std::vector& keys, DDRParam& ddrParam, int channelId); @@ -52,7 +52,7 @@ namespace MxRec { void FindOffset(const string& embName, const vector& keys, size_t currentBatchId, size_t keepBatchId, int channelId); - bool FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset); + bool FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, 
size_t& offset) const; void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; @@ -75,17 +75,19 @@ namespace MxRec { void FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; - int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap); + int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) const; - void AddCacheManagerTraceLog(const string& embName, const EmbHashMapInfo& embHashMap) const; + void AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const; - void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type); + void AddKeyFreqInfo(const string& embTableName, const emb_key_t& key, RecordType type) const; + + void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) const; RankInfo rankInfo; int swapId { 0 }; GTEST_PRIVATE: - void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap); + void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const; }; } diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index f879f4d4..ce6b7eed 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -18,11 +18,11 @@ using namespace std; using namespace MxRec; using namespace tensorflow; -void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) +void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int initSeed) { #ifndef GTEST this->rankInfo = rInfo; - this->seed = seed; + this->seed = initSeed; LOG_INFO("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); // 计算embedding table需要分配的内存块数 auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); @@ -35,13 +35,13 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed) for (int i = 0; i < INIT_BLOCK_COUNT; ++i) { // 申请新的内存块 void *newBlock = nullptr; - aclError ret = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ret); + aclError ec = aclrtMalloc(&newBlock, blockSize * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST); + if (ec != ACL_SUCCESS) { + LOG_ERROR("aclrtMalloc failed, ret={}", ec); throw AclError(); } // 申请内存初始化 - RandomInit(newBlock, embInfo.initializeInfos, seed); + RandomInit(newBlock, embInfo.initializeInfos); // 将新的内存块加入内存链表 memoryList.push_back(newBlock); SplitMemoryBlock(newBlock); @@ -77,7 +77,7 @@ int64_t EmbTable::GetEmbAddress() LOG_ERROR("aclrtMalloc failed, ret={}", ret); throw AclError(); } - RandomInit(addBlock, embInfo.initializeInfos, seed); + RandomInit(addBlock, embInfo.initializeInfos); // 将新的内存块加入内存list memoryList.push_back(addBlock); SplitMemoryBlock(addBlock); @@ -90,7 +90,7 @@ int64_t EmbTable::GetEmbAddress() #endif } -void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos, int seed) +void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos) { #ifndef GTEST LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); @@ -106,7 +106,7 @@ void EmbTable::RandomInit(void* newBlock, const vector& initiali #endif } -void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) +void EmbTable::ExecuteAclMemcpy(void* newBlock, vector devEmb) const { #ifndef GTEST aclError ret = aclrtMemcpy( @@ -132,7 +132,7 @@ void EmbTable::SplitMemoryBlock(void *newBlock) #endif } 
-void EmbTable::PrintStatus() +void EmbTable::PrintStatus() const { // 输出embedding table的总容量和未使用的使用容量 LOG_INFO("Total capacity:{}, Unused capacity:{}", diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index cc83a4e2..7f3eb326 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -12,9 +12,8 @@ #include #include #include + #include "utils/common.h" -#include -#include namespace MxRec { @@ -32,7 +31,7 @@ namespace MxRec { int64_t GetEmbAddress(); // 打印emb表使用情况 - void PrintStatus(); + void PrintStatus() const; EmbTable(const EmbTable&) = delete; @@ -42,7 +41,7 @@ namespace MxRec { EmbTable& operator=(EmbTable&&) = delete; - void ExecuteAclMemcpy(void* newBlock, vector devEmb); + void ExecuteAclMemcpy(void* newBlock, vector devEmb) const; GTEST_PRIVATE: constexpr static int BLOCK_EMB_COUNT = 100000; @@ -55,14 +54,12 @@ namespace MxRec { int totalCapacity = 1; int usedCapacity = 0; int seed = 0; - float mean = 0; - float stddev = 1; // embedding地址的列表 list embeddingList; // 内存块列表 vector memoryList; - void RandomInit(void* newBlock, const vector &initializeInfos, int seed); + void RandomInit(void* newBlock, const vector &initializeInfos); // embSize由embInfo得出 void SplitMemoryBlock(void* newBlock); diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h index ce8da921..08da510f 100644 --- a/src/core/hd_transfer/acl_channel.h +++ b/src/core/hd_transfer/acl_channel.h @@ -19,14 +19,14 @@ namespace tensorflow { Status RecvTensorByAcl(acltdtChannelHandle *acl_handle, std::vector &tensors); #else - Status RecvTensorByAcl(const acltdtChannelHandle* acl_handle, std::vector& tensors); + Status RecvTensorByAcl(const acltdtChannelHandle* aclHandle, std::vector& tensors); - Status StopRecvTensorByAcl(acltdtChannelHandle **handle, const std::string &channel_name); + Status StopRecvTensorByAcl(acltdtChannelHandle **handle, const std::string &channelName); #endif - Status SendTensorsByAcl(const acltdtChannelHandle* acl_handle, acltdtTensorType acl_type, - const std::vector& tensors, bool& is_need_resend); + Status SendTensorsByAcl(const acltdtChannelHandle* aclHandle, acltdtTensorType aclType, + const std::vector& tensors, bool& isNeedResend); } // namespace tensorflow diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 521468a7..b189e8b3 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -39,7 +39,7 @@ void HostEmb::Initialize(const vector& embInfos, int seed) /// \param embeddingSize emb维度 /// \param embData emb数据 void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, - int embeddingSize, vector> &embData) + int embeddingSize, vector> &embData) const { #ifndef GTEST LOG_INFO(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); @@ -135,7 +135,7 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI LOG_INFO(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); EASY_FUNCTION(profiler::colors::Purple) auto updateThread = - [&, missingKeysHostPos, channelId, embName] { + [this, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); @@ -222,7 +222,7 @@ auto HostEmb::GetHostEmbs() -> absl::flat_hash_map* /// \param embData emb数据 /// \param offset 偏移列表 void HostEmb::EmbPartGenerator(const vector &initializeInfos, 
vector> &embData, - const vector& offset) + const vector& offset) const { for (auto initializeInfo: initializeInfos) { LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index cfe9339a..d61b7cad 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -54,9 +54,9 @@ namespace MxRec { std::vector> procThreadsForEval; void EmbDataGenerator(const vector& initializeInfos, int seed, int vocabSize, int embeddingSize, - vector>& embData); + vector>& embData) const; void EmbPartGenerator(const vector &initializeInfos, vector> &embData, - const vector& offset); + const vector& offset) const; }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e63fa5a0..75d65b3e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -35,7 +35,7 @@ bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& /// Openmpi通信域进程数设置、计算所有表host特征数量总数、设置训练模式(HBM/DDR) /// \param rankInfo /// \param embInfos -void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos) +void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const { #ifndef GTEST MPI_Comm_size(MPI_COMM_WORLD, &rankInfo.rankSize); @@ -171,10 +171,10 @@ void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) /// 保存CacheManager时恢复数据(与恢复hostHashMap类似,仅恢复保存数据,不修改源数据) /// \param saveData 保存数据 -void HybridMgmt::RestoreFreq4Save(CkptData& saveData) +void HybridMgmt::RestoreFreq4Save(CkptData& saveData) const { // 仅在差异1步时执行恢复操作 - int checkResult = hybridMgmtBlock->CheckSaveEmbdMapValid(); + int checkResult = hybridMgmtBlock->CheckSaveEmbMapValid(); if (checkResult != 1) { return; } @@ -373,7 +373,7 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) keyOffsetMap = preprocess->GetKeyOffsetMap(); } - if ((!keyOffsetMap.empty()) && keyOffsetMap.count(tableName)) { + if ((!keyOffsetMap.empty()) && keyOffsetMap.count(tableName) > 0) { for (const auto& it : keyOffsetMap.at(tableName)) { sendKeyOffsetMap[it.first] = it.second; } @@ -393,13 +393,13 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) key_offset_mem_t loadKeyOffsetMap; offset_mem_t loadMaxOffset; if (!ReceiveKeyOffsetMap.empty()) { - for (const auto& KeyOffsetMap : ReceiveKeyOffsetMap) { - auto& SingleHashMap = loadKeyOffsetMap[KeyOffsetMap.first]; - auto& MaxOffset = loadMaxOffset[KeyOffsetMap.first]; - for (const auto& it : KeyOffsetMap.second) { - SingleHashMap[it.first] = it.second; + for (const auto& keyOffsetMap : as_const(ReceiveKeyOffsetMap)) { + auto& singleHashMap = loadKeyOffsetMap[keyOffsetMap.first]; + auto& maxOffset = loadMaxOffset[keyOffsetMap.first]; + for (const auto& it : keyOffsetMap.second) { + singleHashMap[it.first] = it.second; } - MaxOffset = KeyOffsetMap.second.size(); + maxOffset = keyOffsetMap.second.size(); } } if (!mgmtRankInfo.noDDR) { @@ -422,39 +422,39 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) /// \param setupHostEmbs /// \param embTableCount /// \return -bool HybridMgmt::IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount) +bool HybridMgmt::IsLoadDataMatches(emb_mem_t& loadHostEmbs, EmbInfo& setupHostEmbs, size_t& embTableCount) const { bool loadDataMatches = { true }; - const auto& loadEmbTable { loadHostEmbs->find(setupHostEmbs->name) }; - if (loadEmbTable != loadHostEmbs->end()) { + const auto& loadEmbTable { 
loadHostEmbs.find(setupHostEmbs.name) }; + if (loadEmbTable != loadHostEmbs.end()) { embTableCount++; const auto& loadEmbInfo { loadEmbTable->second.hostEmbInfo }; - if (setupHostEmbs->sendCount != loadEmbInfo.sendCount) { + if (setupHostEmbs.sendCount != loadEmbInfo.sendCount) { LOG_ERROR(MGMT + "Load data sendCount {} for table {} does not match setup sendCount {}", - setupHostEmbs->sendCount, setupHostEmbs->name, loadEmbInfo.sendCount); + setupHostEmbs.sendCount, setupHostEmbs.name, loadEmbInfo.sendCount); loadDataMatches = false; } - if (setupHostEmbs->extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { + if (setupHostEmbs.extEmbeddingSize != loadEmbInfo.extEmbeddingSize) { LOG_ERROR(MGMT + "Load data extEmbeddingSize {} for table {} does not match setup extEmbeddingSize {}", - setupHostEmbs->extEmbeddingSize, setupHostEmbs->name, loadEmbInfo.extEmbeddingSize); + setupHostEmbs.extEmbeddingSize, setupHostEmbs.name, loadEmbInfo.extEmbeddingSize); loadDataMatches = false; } - if (setupHostEmbs->devVocabSize != loadEmbInfo.devVocabSize) { + if (setupHostEmbs.devVocabSize != loadEmbInfo.devVocabSize) { LOG_ERROR(MGMT + "Load data devVocabSize {} for table {} does not match setup devVocabSize {}", - setupHostEmbs->devVocabSize, setupHostEmbs->name, loadEmbInfo.devVocabSize); + setupHostEmbs.devVocabSize, setupHostEmbs.name, loadEmbInfo.devVocabSize); loadDataMatches = false; } - if (setupHostEmbs->hostVocabSize != loadEmbInfo.hostVocabSize) { + if (setupHostEmbs.hostVocabSize != loadEmbInfo.hostVocabSize) { LOG_ERROR(MGMT + "Load data hostVocabSize {} for table {} does not match setup hostVocabSize {}", - setupHostEmbs->hostVocabSize, setupHostEmbs->name, loadEmbInfo.hostVocabSize); + setupHostEmbs.hostVocabSize, setupHostEmbs.name, loadEmbInfo.hostVocabSize); loadDataMatches = false; } if (!loadDataMatches) { return false; } } else { - LOG_ERROR(MGMT + "Load data does not contain table with table name: {}", setupHostEmbs->name); + LOG_ERROR(MGMT + "Load data does not contain table with table name: {}", setupHostEmbs.name); return false; } return true; @@ -468,7 +468,7 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) size_t embTableCount { 0 }; auto loadHostEmbs { loadData.hostEmbs }; for (EmbInfo setupHostEmbs : mgmtEmbInfo) { - if (!IsLoadDataMatches(loadHostEmbs, &setupHostEmbs, embTableCount)) { + if (!IsLoadDataMatches(*loadHostEmbs, setupHostEmbs, embTableCount)) { return false; } } @@ -597,7 +597,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) // 循环处理每个表的数据 for (const auto& embInfo: mgmtEmbInfo) { - TimeCost ParseKeysTC; + TimeCost parseKeysTc; // get TimeCost getTensorsSyncTC; @@ -646,8 +646,8 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) // 发送恢复向量 TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); - LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, ParseKeysTC HBM mode (ms):{}", - sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), ParseKeysTC.ElapsedMS()); + LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", + sendRestoreSyncTC.ElapsedMS(), sendTensorsSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); } batchId++; return true; @@ -697,7 +697,7 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) return true; } -inline void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) +void HybridMgmt::HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const { LOG_ERROR("Transfer embedding with DDR and SSD error."); if 
(prepareSSDRet == TransferRet::SSD_SPACE_NOT_ENOUGH) { @@ -757,7 +757,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - vector uniqueKeys, restoreVecSec; + vector uniqueKeys; + vector restoreVecSec; preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; @@ -772,7 +773,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha TimeCost sendTensorsTC; hdTransfer->Send(TransferChannel::LOOKUP, { ddrParam.tmpDataOut.front() }, channelId, embName); - ddrParam.tmpDataOut.erase(ddrParam.tmpDataOut.begin()); + ddrParam.tmpDataOut.erase(ddrParam.tmpDataOut.cbegin()); hdTransfer->Send(TransferChannel::SWAP, ddrParam.tmpDataOut, channelId, embName); if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); @@ -853,17 +854,17 @@ bool HybridMgmt::Evict() LOG_DEBUG(MGMT + "evict triggered by hook, evict TableNum {}", evictKeyMap.size()); // 表为空,淘汰触发失败 - if (evictKeyMap.size() == 0) { + if (evictKeyMap.empty()) { LOG_WARN(MGMT + "evict triggered by hook before dataset in injected"); return false; } if (mgmtRankInfo.noDDR) { - for (auto evict : evictKeyMap) { + for (const auto& evict : as_const(evictKeyMap)) { preprocess->EvictKeys(evict.first, evict.second); } } else { - for (auto evict : evictKeyMap) { + for (const auto& evict : as_const(evictKeyMap)) { EvictKeys(evict.first, evict.second); EvictSSDKeys(evict.first, evict.second); } @@ -916,7 +917,7 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) } inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInfo& embHashMap, - const vector& keys, int channelId) + const vector& keys, int channelId) const { if (!isSSDEnabled) { return; @@ -930,7 +931,7 @@ inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInf LOG_DEBUG("PrepareDDRData end, TimeCost(ms):{}", prepareDDRDataTc.ElapsedMS()); } -void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) +void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) const { if (!isSSDEnabled) { return; @@ -944,7 +945,7 @@ void HybridMgmt::EvictSSDKeys(const string& embName, const vector& ke cacheManager->EvictSSDEmbedding(embName, ssdKeys); } -int HybridMgmt::GetStepFromPath(const string& loadPath) +int HybridMgmt::GetStepFromPath(const string& loadPath) const { regex pattern("sparse-model-\\d+-(\\d+)"); smatch match; @@ -969,7 +970,7 @@ int HybridMgmt::GetStepFromPath(const string& loadPath) /// 通过pyBind在python侧调用,通知hybridMgmt上层即将进行图的执行,需要进行唤醒 /// \param channelID 通道id /// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 -void HybridMgmt::NotifyBySessionRun(int channelID) +void HybridMgmt::NotifyBySessionRun(int channelID) const { hybridMgmtBlock->CheckAndNotifyWake(channelID); } @@ -977,7 +978,7 @@ void HybridMgmt::NotifyBySessionRun(int channelID) /// 通过pyBind在python侧调用,通知hybridMgmt上层即将进行图的执行 /// \param channelID 通道id /// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 -void HybridMgmt::CountStepBySessionRun(int channelID, int steps) +void HybridMgmt::CountStepBySessionRun(int channelID, int steps) const { hybridMgmtBlock->CountPythonStep(channelID, steps); } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 
563296d1..fd902589 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -27,6 +27,7 @@
 #include "hd_transfer/hd_transfer.h"
 #include "key_process/key_process.h"
 #include "ssd_cache/cache_manager.h"
+#include "hybrid_mgmt_block.h"
 namespace MxRec {
 using namespace std;
@@ -114,25 +115,29 @@ namespace MxRec {
        void EvictKeys(const string& embName, const vector& keys);
-       bool IsLoadDataMatches(emb_mem_t* loadHostEmbs, EmbInfo* setupHostEmbs, size_t& embTableCount);
+       bool IsLoadDataMatches(emb_mem_t& loadHostEmbs, EmbInfo& setupHostEmbs, size_t& embTableCount) const;
-       void NotifyBySessionRun(int channelID);
+       void NotifyBySessionRun(int channelID) const;
-       void CountStepBySessionRun(int channelID, int steps);
+       void CountStepBySessionRun(int channelID, int steps) const;
 private:
        bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
                            const vector& thresholdValues, int seed);
-       void InitRankInfo(RankInfo& rankInfo, const vector& embInfos);
+       void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const;
-       void EvictSSDKeys(const string& embName, const vector& keys);
+       void EvictSSDKeys(const string& embName, const vector& keys) const;
        void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap,
-                           const vector& keys, int channelId);
-       int GetStepFromPath(const string& loadPath);
+                           const vector &keys, int channelId) const;
+
+       int GetStepFromPath(const string& loadPath) const;
+
        static void AddCacheManagerTraceLog(CkptData& saveData);
-       void RestoreFreq4Save(CkptData& saveData);
+
+       void RestoreFreq4Save(CkptData& saveData) const;
+
 private:
        int currentBatchId;
        int trainBatchId = 0; // 0-199, 200-
@@ -153,6 +158,7 @@ namespace MxRec {
        bool isLoad { false };
        void TrainTask(TaskType type);
+       void EvalTask(TaskType type);
        bool EndBatch(int batchId, int channelId) const;
@@ -160,6 +166,8 @@ namespace MxRec {
        void EmbHDTransWrap(int channelId, const int& batchId, int start);
        bool LoadMatchesDDRSetup(const CkptData& loadData);
+
+       void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const;
 };
 }
 #endif // MX_REC_EMB_MGMT_H
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
index cd0649e2..7f530a64 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
@@ -10,14 +10,16 @@
 #include "utils/common.h"
 #include "hybrid_mgmt_block.h"
+using namespace MxRec;
+
 /// Check whether the hybrid thread has reached a point where it must block
 /// \param channelId train 0, eval 1
 void HybridMgmtBlock::CheckAndSetBlock(int channelId)
 {
     // Decide whether to block for a save: we are on the training channel, saveInterval is
     // neither 0 nor -1 (both mean "never block"), and the current step is a save step
-    if (channelId==TRAIN_CHANNEL_ID && saveInterval!=0
-        && saveInterval!=-1 && hybridBatchId[TRAIN_CHANNEL_ID]%saveInterval==0) {
+    if (channelId == TRAIN_CHANNEL_ID && saveInterval != 0 &&
+        saveInterval != -1 && hybridBatchId[TRAIN_CHANNEL_ID] % saveInterval == 0) {
         LOG_DEBUG(HYBRID_BLOCKING + "blocking by save saveInterval {} pythonBatchId {} hybridBatchId {}",
                   saveInterval, pythonBatchId[channelId], hybridBatchId[channelId]);
         isBlock[TRAIN_CHANNEL_ID] = true;
@@ -41,7 +43,7 @@ void HybridMgmtBlock::CheckAndSetBlock(int channelId)
 void HybridMgmtBlock::CheckAndNotifyWake(int channelId)
 {
     LOG_DEBUG(HYBRID_BLOCKING + "start notify channelId {} pythonBatchId {} hybridBatchId {}",
-              channelId, pythonBatchId[channelId], hybridBatchId[channelId]);
+        channelId, pythonBatchId[channelId], hybridBatchId[channelId]);
     CheckValid(channelId);
     if (pythonBatchId[channelId] >= hybridBatchId[channelId]) {
@@ -115,7 +117,6 @@ void HybridMgmtBlock::CheckValid(int channelId)
                   lastRunChannelId, hybridBatchId[lastRunChannelId]);
     }
     lastRunChannelId = channelId;
-    return;
 }

 /// Perform the actual blocking
@@ -134,7 +135,6 @@ void HybridMgmtBlock::DoBlock(int channelId)
     }
     LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt is starting to wake up channelId {} hybridBatchId {}",
               channelId, hybridBatchId[channelId]);
-    return;
 }

 /// Reset all step counters; used mainly when the graph is rebuilt and the readembedkey op is recreated
@@ -152,7 +152,7 @@ void HybridMgmtBlock::ResetAll(int channelId)

 /// Check whether a save may happen at the current step
 /// \return 0 is legal, 1 means roll back one step, -1 indicates an error
-int HybridMgmtBlock::CheckSaveEmbdMapValid()
+int HybridMgmtBlock::CheckSaveEmbMapValid()
 {
     // Check whether the data channel's HashMap has already been processed ahead of time
     if (pythonBatchId[lastRunChannelId] >= hybridBatchId[lastRunChannelId]) {
@@ -198,13 +198,12 @@ void HybridMgmtBlock::Destroy()
     isRunning = false;
 }
-
-void HybridMgmtBlock::SetRankInfo(RankInfo rankInfo)
+void HybridMgmtBlock::SetRankInfo(RankInfo ri)
 {
-    this->stepsInterval[TRAIN_CHANNEL_ID] = rankInfo.maxStep[TRAIN_CHANNEL_ID];
-    this->stepsInterval[EVAL_CHANNEL_ID] = rankInfo.maxStep[EVAL_CHANNEL_ID];
-    this->saveInterval = rankInfo.maxStep[SAVE_STEP_INDEX];
-    this->rankInfo = rankInfo;
+    this->stepsInterval[TRAIN_CHANNEL_ID] = ri.maxStep[TRAIN_CHANNEL_ID];
+    this->stepsInterval[EVAL_CHANNEL_ID] = ri.maxStep[EVAL_CHANNEL_ID];
+    this->saveInterval = ri.maxStep[SAVE_STEP_INDEX];
+    this->rankInfo = ri;
 };

 void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep)
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h
index 257de095..02d4a070 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h
@@ -7,70 +7,88 @@
 */
 #ifndef MX_REC_HYBRID_BLOCKING_H
 #define MX_REC_HYBRID_BLOCKING_H
+
 #include
 #include
 #include "hd_transfer/hd_transfer.h"
 #include "utils/common.h"
 #include "utils/singleton.h"
-using namespace MxRec;
-const std::string HYBRID_BLOCKING = "[HYBRID_BLOCKING] ";
-const int SAVE_STEP_INDEX = 2;
-const std::chrono::milliseconds SLEEP_MS = 20ms;
-class HybridMgmtBlock {
-public:
-    // channel ID of the previous run
-    int lastRunChannelId = -1;
-    // next batch id the hybrid side will process
-    int hybridBatchId[2] = {0, 0};
-    // next batch id the Python side will process
-    int pythonBatchId[2] = {0, 0};
-    // next batch id the readEmbed op will process
-    int readEmbedBatchId[2] = {0, 0};
-    bool isRunning = true;
-
-    ~HybridMgmtBlock();
-    void CheckAndNotifyWake(int channelId);
-    void CountPythonStep(int channelId, int steps);
-    void CheckAndSetBlock(int channelId);
-    void CheckValid(int channelId);
-    void DoBlock(int channelId);
-    void ResetAll(int channelId);
-    int CheckSaveEmbdMapValid();
-    bool GetBlockStatus(int channelId);
-    void SetBlockStatus(int channelId, bool block);
-    void SetRankInfo(RankInfo rankInfo);
-    void SetStepInterval(int trainStep, int evalStep);
-    bool WaitValid(int channelId);
-    void Destroy();
-private:
-    // how many steps channel i runs before switching to channel j
-    int stepsInterval[2] = {0, 0};
-    // flags that control whether each channel is blocked
-    bool isBlock[2] = {true, true};
-    // number of training steps between saves
-    int saveInterval = 0;
-    RankInfo rankInfo;
-};
-
-class HybridMgmtBlockingException : public std::exception {
-public:
-    explicit HybridMgmtBlockingException(const string scene)
-    {
-        HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance();
-        int channelId = hybridMgmtBlock->lastRunChannelId;
-        int preprocessBatchNumber = hybridMgmtBlock->hybridBatchId[channelId];
-        int currentBatchNumber = hybridMgmtBlock->pythonBatchId[channelId];
-        str = StringFormat("Error happened at HybridMgmt blocking: the preprocess batch number "
-                           "does not match the batch number currently in use (%s), last used channel id is %d, "
-                           "preprocessBatchNumber is %d, currentBatchNumber is %d. please check your setting of train "
-                           "steps and eval steps", scene.c_str(), channelId, preprocessBatchNumber,
-                           currentBatchNumber);
-        LOG_ERROR(str);
-    }
-
-private:
-    string str;
-};
+
+namespace MxRec {
+    const std::string HYBRID_BLOCKING = "[HYBRID_BLOCKING] ";
+    const int SAVE_STEP_INDEX = 2;
+    const std::chrono::milliseconds SLEEP_MS = 20ms;
+
+    class HybridMgmtBlock {
+    public:
+        // channel ID of the previous run
+        int lastRunChannelId = -1;
+        // next batch id the hybrid side will process
+        int hybridBatchId[2] = {0, 0};
+        // next batch id the Python side will process
+        int pythonBatchId[2] = {0, 0};
+        // next batch id the readEmbed op will process
+        int readEmbedBatchId[2] = {0, 0};
+        bool isRunning = true;
+
+        ~HybridMgmtBlock();
+
+        void CheckAndNotifyWake(int channelId);
+
+        void CountPythonStep(int channelId, int steps);
+
+        void CheckAndSetBlock(int channelId);
+
+        void CheckValid(int channelId);
+
+        void DoBlock(int channelId);
+
+        void ResetAll(int channelId);
+
+        int CheckSaveEmbMapValid();
+
+        bool GetBlockStatus(int channelId);
+
+        void SetBlockStatus(int channelId, bool block);
+
+        void SetRankInfo(RankInfo ri);
+
+        void SetStepInterval(int trainStep, int evalStep);
+
+        bool WaitValid(int channelId);
+
+        void Destroy();
+
+    private:
+        // how many steps channel i runs before switching to channel j
+        int stepsInterval[2] = {0, 0};
+        // flags that control whether each channel is blocked
+        bool isBlock[2] = {true, true};
+        // number of training steps between saves
+        int saveInterval = 0;
+        RankInfo rankInfo;
+    };
+
+    class HybridMgmtBlockingException : public std::exception {
+    public:
+        explicit HybridMgmtBlockingException(const string scene)
+        {
+            HybridMgmtBlock *hybridMgmtBlock = Singleton::GetInstance();
+            int channelId = hybridMgmtBlock->lastRunChannelId;
+            int preprocessBatchNumber = hybridMgmtBlock->hybridBatchId[channelId];
+            int currentBatchNumber = hybridMgmtBlock->pythonBatchId[channelId];
+            str = StringFormat("Error happened at HybridMgmt blocking: the preprocess batch number " "does not match the batch number currently in use (%s), last used channel id is %d, " "preprocessBatchNumber is %d, " "currentBatchNumber is %d. 
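The header above defines the shared state but not the choreography. As a rough sketch of how the two sides are meant to drive it — the method names come from HybridMgmtBlock itself, while the surrounding loop and the session hook are assumptions, not code from this patch:

    // Hybrid worker thread, once per preprocessing step (sketch).
    void HybridStep(HybridMgmtBlock* block, int channelId)
    {
        block->CheckAndSetBlock(channelId);  // raises isBlock when the step hits saveInterval
        block->DoBlock(channelId);           // waits (in SLEEP_MS slices) while blocked
        // ... preprocess one batch, then advance hybridBatchId[channelId] ...
    }

    // Python-facing side after each session run (sketch; NotifyBySessionRun and
    // CountStepBySessionRun in hybrid_mgmt.h suggest this is where it hangs off).
    void OnSessionRun(HybridMgmtBlock* block, int channelId)
    {
        block->CountPythonStep(channelId, 1);  // pythonBatchId[channelId] catches up
        block->CheckAndNotifyWake(channelId);  // wakes the worker once the ids line up
    }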
please check your setting of train " + "steps and eval steps", scene.c_str(), channelId, preprocessBatchNumber, + currentBatchNumber); + LOG_ERROR(str); + } + + private: + string str; + }; +} #endif \ No newline at end of file diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h index 8af23170..6ca1c3e5 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.h +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -8,11 +8,9 @@ #ifndef MX_REC_CONSTANT_INITIALIZER_H #define MX_REC_CONSTANT_INITIALIZER_H -#include #include "initializer/initializer.h" namespace MxRec { - using std::vector; class ConstantInitializer : public Initializer { public: diff --git a/src/core/initializer/initializer.cpp b/src/core/initializer/initializer.cpp index afb44d20..51586eac 100644 --- a/src/core/initializer/initializer.cpp +++ b/src/core/initializer/initializer.cpp @@ -3,4 +3,48 @@ * Description: initializer module * Author: MindX SDK * Date: 2022/12/22 -*/ \ No newline at end of file +*/ +#include "initializer.h" + +#include + +#include "constant_initializer/constant_initializer.h" +#include "random_normal_initializer/random_normal_initializer.h" +#include "truncated_normal_initializer/truncated_normal_initializer.h" + +using namespace MxRec; + +ConstantInitializerInfo::ConstantInitializerInfo(float constantValue, float initK) + : constantValue(constantValue), initK(initK) +{} + +NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed, float initK) + : mean(mean), stddev(stddev), seed(seed), initK(initK) +{} + +InitializeInfo::InitializeInfo(string &name, int start, int len, + ConstantInitializerInfo constantInitializerInfo) + : name(name), start(start), len(len), constantInitializerInfo(constantInitializerInfo) +{ + if (name == "constant_initializer") { + initializerType = InitializerType::CONSTANT; + initializer = make_shared(start, len, constantInitializerInfo.constantValue, + constantInitializerInfo.initK); + } else { + throw invalid_argument("Invalid Initializer Type."); + } +} + +InitializeInfo::InitializeInfo(string &name, int start, int len, NormalInitializerInfo normalInitializerInfo) + : name(name), start(start), len(len), normalInitializerInfo(normalInitializerInfo) +{ + if (name == "truncated_normal_initializer") { + initializerType = InitializerType::TRUNCATED_NORMAL; + initializer = make_shared(start, len, normalInitializerInfo); + } else if (name == "random_normal_initializer") { + initializerType = InitializerType::RANDOM_NORMAL; + initializer = make_shared(start, len, normalInitializerInfo); + } else { + throw invalid_argument("Invalid Initializer Type."); + } +} diff --git a/src/core/initializer/initializer.h b/src/core/initializer/initializer.h index feec8729..dbe59ac9 100644 --- a/src/core/initializer/initializer.h +++ b/src/core/initializer/initializer.h @@ -8,20 +8,67 @@ #ifndef MX_REC_INITIALIZER_H #define MX_REC_INITIALIZER_H -#include +#include +#include + namespace MxRec { - using std::vector; + using namespace std; class Initializer { public: Initializer() = default; virtual ~Initializer() {}; - virtual void GenerateData(float* emb, int embSize)= 0; + virtual void GenerateData(float *emb, int embSize) = 0; int start; int len; float initParam = 1.0; }; + + enum class InitializerType { + INVALID, + CONSTANT, + TRUNCATED_NORMAL, + RANDOM_NORMAL + }; + + struct ConstantInitializerInfo { + ConstantInitializerInfo() = default; + + 
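The two InitializeInfo constructors above act as a small factory keyed on the initializer's name string. A minimal usage sketch — only the name strings and signatures come from the code, the values are arbitrary:

    std::string name = "truncated_normal_initializer";
    NormalInitializerInfo normalInfo(0.0f, 0.3f, 1, 1.0f);  // mean, stddev, seed, initK
    InitializeInfo info(name, 0, 8, normalInfo);            // fills info.initializer
    // "random_normal_initializer" selects RandomNormalInitializer instead;
    // any other name throws std::invalid_argument("Invalid Initializer Type.").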
explicit ConstantInitializerInfo(float constantValue, float initK); + + float constantValue; + float initK = 1.0; + }; + + struct NormalInitializerInfo { + NormalInitializerInfo() = default; + + NormalInitializerInfo(float mean, float stddev, int seed, float initK); + + float mean; + float stddev; + int seed; + float initK = 1.0; + }; + + struct InitializeInfo { + InitializeInfo() = default; + + InitializeInfo(string &name, int start, int len, ConstantInitializerInfo constantInitializerInfo); + + InitializeInfo(string &name, int start, int len, NormalInitializerInfo normalInitializerInfo); + + string name; + int start; + int len; + InitializerType initializerType = InitializerType::INVALID; + + ConstantInitializerInfo constantInitializerInfo; + NormalInitializerInfo normalInitializerInfo; + + std::shared_ptr initializer; + }; } #endif // MX_REC_INITIALIZER_H \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index 3b066bba..c10b7e46 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -11,10 +11,10 @@ using namespace MxRec; -RandomNormalInitializer::RandomNormalInitializer(int start, int len, std::tuple ret) - : start(start), len(len) +RandomNormalInitializer::RandomNormalInitializer(int start, int len, NormalInitializerInfo& initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) { - std::tie(mean, stddev, seed, initParam) = ret; + initParam = initInfo.initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); } @@ -28,5 +28,5 @@ void RandomNormalInitializer::GenerateData(float* const emb, const int embSize) LOG_WARN("InitializeInfo start {} + len {} is larger than embedding size {}.", start, len, embSize); return; } - std::generate_n(emb + start, len, [&]() { return initParam * distribution(generator); }); + std::generate_n(emb + start, len, [this]() { return initParam * distribution(generator); }); } \ No newline at end of file diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index e020bd64..fdedb190 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -20,7 +20,7 @@ namespace MxRec { class RandomNormalInitializer : public Initializer { public: RandomNormalInitializer() = default; - RandomNormalInitializer(int start, int len, std::tuple); + RandomNormalInitializer(int start, int len, NormalInitializerInfo& initInfo); ~RandomNormalInitializer() override {}; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 7379a871..d3ea48fb 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -11,10 +11,10 @@ using namespace MxRec; -TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, std::tuple ret) - : start(start), len(len) +TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, NormalInitializerInfo& 
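For reference, what RandomNormalInitializer::GenerateData above boils down to is a scaled draw from std::normal_distribution; a self-contained sketch with arbitrary parameters:

    #include <algorithm>
    #include <random>
    #include <vector>

    int main()
    {
        std::default_random_engine generator(1);           // seeded like the constructor
        std::normal_distribution<float> distribution(0.0f, 0.3f);
        float initParam = 1.0f;                            // the initK scale factor
        std::vector<float> emb(8);
        std::generate_n(emb.begin(), emb.size(),
                        [&]() { return initParam * distribution(generator); });
        return 0;
    }

The switch from [&] to [this] in the patched lambda captures exactly the members being used instead of everything in scope; behavior is unchanged.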
initInfo) + : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) { - std::tie(mean, stddev, seed, initParam) = ret; + initParam = initInfo.initK; generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); @@ -32,7 +32,7 @@ void TruncatedNormalInitializer::GenerateData(float* const emb, const int embSiz LOG_WARN("InitializeInfo start {} + len {} is larger than embedding size {}.", start, len, embSize); return; } - std::generate_n(emb + start, len, [&]() { + std::generate_n(emb + start, len, [this]() { float tmp = initParam * distribution(generator); while (tmp < minBound || tmp > maxBound) { tmp = initParam * distribution(generator); diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index 67285072..e68ad02a 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -20,7 +20,7 @@ namespace MxRec { class TruncatedNormalInitializer : public Initializer { public: TruncatedNormalInitializer() = default; - TruncatedNormalInitializer(int start, int len, std::tuple); + TruncatedNormalInitializer(int start, int len, NormalInitializerInfo& initInfo); ~TruncatedNormalInitializer() override {}; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index a80ff826..c7a97daa 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -212,7 +212,7 @@ bool FeatureAdmitAndEvict::GetFunctionSwitch() const } void FeatureAdmitAndEvict::PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, - absl::flat_hash_map& mergeKeys) + absl::flat_hash_map& mergeKeys) const { for (size_t i = 0; i < splitKey.size(); ++i) { if (splitKey[i] == -1) { diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 85219c8a..4f56f5b5 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -67,7 +67,7 @@ namespace MxRec { bool GetFunctionSwitch() const; void PreProcessKeys(const std::vector& splitKey, std::vector& keyCount, - absl::flat_hash_map& mergeKeys); + absl::flat_hash_map& mergeKeys) const; // 判断配置是否正确的接口 static bool IsThresholdCfgOK(const std::vector& thresholds, @@ -88,7 +88,7 @@ namespace MxRec { // 解析m_table2Threshold bool ParseThresholdCfg(const std::vector& thresholdValues); std::vector GetAllNeedEvictTableNames(); - FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableName, + FeatureAdmitType FeatureAdmitHelper(const int channel, const std::string& tableNameOrigin, const int64_t featureId, const uint32_t featureCnt); void FeatureEvictHelper(const std::string& embName, std::vector& evictKey); void ResetAllRecords(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 0880d54f..db142930 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -9,6 +9,7 @@ #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" +#include "ock_ctr_common/include/error_code.h" using namespace std; using namespace chrono; @@ -17,29 +18,6 @@ using namespace ock::ctr; static shared_mutex g_smut; -template -inline vector Count2Start(const vector& count) 
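The truncated variant retries any draw that lands outside its bounds, as the loop above shows. Sketched standalone (the bound values themselves are not shown in this hunk):

    float DrawTruncated(std::default_random_engine& gen,
                        std::normal_distribution<float>& dist,
                        float initParam, float minBound, float maxBound)
    {
        float tmp = initParam * dist(gen);
        while (tmp < minBound || tmp > maxBound) {  // reject and redraw
            tmp = initParam * dist(gen);
        }
        return tmp;
    }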
-{ - vector start = { 0 }; - for (size_t i = 0; i < count.size() - 1; ++i) { - start.push_back(count[i] + start.back()); - } - return start; -} - -class EndRunExit : public std::exception { -public: - explicit EndRunExit(const char* message) : errorMessage(message) {} - - const char* what() const noexcept override - { - return errorMessage; - } - -private: - const char* errorMessage; -}; - void KeyProcess::SetupHotEmbUpdateStep() { this->hotEmbUpdateStep = GlobalEnv::hotEmbUpdateStep; @@ -221,7 +199,7 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, const unique_ptr & batch, UniquePtr& unique) { - uniqueConf.desiredSize = (uint32_t)batch->Size(); + uniqueConf.desiredSize = static_cast(batch->Size()); if (preBatchSize != batch->Size()) { uniqueInitialize = false; preBatchSize = batch->Size(); @@ -324,10 +302,10 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < vector & restore, vector & hotPos, vector >& keyCount) { - TimeCost UniqueTC; + TimeCost uniqueTc; if (m_featureAdmitAndEvict.GetFunctionSwitch() && FeatureAdmitAndEvict::m_embStatus[batch->name] != SingleEmbTableStatus::SETS_NONE) { - tie(splitKeys, restore, keyCount) = HashSplit_withFAAE(batch); // 按存储dev id切分并去重 + tie(splitKeys, restore, keyCount) = HashSplitWithFAAE(batch); // 按存储dev id切分并去重 } else { if (rankInfo.useHot) { tie(splitKeys, restore, hotPos) = HotHashSplit(batch); // 按存储dev id切分并去重 @@ -335,7 +313,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < tie(splitKeys, restore) = HashSplit(batch); // 按存储dev id切分并去重 } } - LOG_DEBUG("UniqueTC(ms):{}", UniqueTC.ElapsedMS()); + LOG_DEBUG("uniqueTc(ms):{}", uniqueTc.ElapsedMS()); } bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, @@ -352,9 +330,9 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat // 特征准入&淘汰 if (isWithFAAE && - (m_featureAdmitAndEvict.FeatureAdmit(channel, batch, uniqueInfo.all2AllInfo.keyRecv, - uniqueInfo.all2AllInfo.countRecv) - == FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { + (m_featureAdmitAndEvict.FeatureAdmit( + channel, batch, uniqueInfo.all2AllInfo.keyRecv, uniqueInfo.all2AllInfo.countRecv) == + FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR)) { LOG_ERROR(KEY_PROCESS "rank:{} thread:{}, channel:{}, Feature-admit-and-evict error ...", rankInfo.rankId, threadId, channel); return false; @@ -482,7 +460,7 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, } vector countSend; for (auto& cnt: keyCount) { - countSend.insert(countSend.end(), cnt.begin(), cnt.end()); + countSend.insert(countSend.cend(), cnt.cbegin(), cnt.cend()); } vector sc; for (int i = 0; i < rankInfo.rankSize; ++i) { @@ -590,12 +568,12 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch keySendInfo.keyCount.resize(size); UniqueIn uniqueIn; - uniqueIn.inputIdCnt = (uint32_t)batch->Size(); + uniqueIn.inputIdCnt = static_cast(batch->Size()); uniqueIn.inputId = reinterpret_cast(batch->sample.data()); EnhancedUniqueOut uniqueOut; uniqueOut.uniqueId = reinterpret_cast(keySendInfo.keySend.data()); - uniqueOut.index = (uint32_t*)uniqueInfoOut.restore.data(); + uniqueOut.index = reinterpret_cast(uniqueInfoOut.restore.data()); if (rankInfo.useStatic) { uniqueOut.idCnt = idCount.data(); uniqueOut.idCntFill = keySendInfo.keyCount.data(); @@ -642,9 +620,9 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni 
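Count2Start — removed here as a free template and re-added further down as a private member of KeyProcess — turns per-rank element counts into exclusive prefix sums, i.e. each rank's starting offset in a packed buffer. Worked through on sample data:

    std::vector<int> count = {3, 5, 2, 4};
    std::vector<int> start = {0};                 // the first slice always starts at 0
    for (size_t i = 0; i + 1 < count.size(); ++i) {
        start.push_back(count[i] + start.back());
    }
    // start == {0, 3, 8, 10}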
uniqueInfoOut.hotPos.resize(hotEmbTotCount[batch->name]); hotOffset = hotEmbTotCount[batch->name]; - TimeCost ComputeHotTc; + TimeCost computeHotTc; ComputeHotPos(batch, hotMap, uniqueInfoOut.hotPos, uniqueInfoOut.restore, hotOffset); - LOG_DEBUG("ComputeHot TimeCost(ms):{}", ComputeHotTc.ElapsedMS()); + LOG_DEBUG("ComputeHot TimeCost(ms):{}", computeHotTc.ElapsedMS()); UpdateHotMapForUnique(keySendInfo.keySend, keySendInfo.keyCount, hotOffset, batch->batchId % hotEmbUpdateStep == 0, batch->name); } @@ -660,7 +638,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni } void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, - vector &hotPos, vector &restore, const int hotOffset) + vector &hotPos, vector &restore, const int hotOffset) const { auto* inputData = batch->sample.data(); size_t miniBs = batch->Size(); @@ -737,7 +715,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, vector sc; // send count for (const auto& i: splitKeys) { sc.push_back(static_cast(i.size())); - keySend.insert(keySend.end(), i.begin(), i.end()); + keySend.insert(keySend.cend(), i.cbegin(), i.cend()); } keys_t keyRecv; @@ -786,7 +764,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< EASY_BLOCK("split push back") for (size_t i = 0; i < miniBs; i++) { const emb_key_t& key = batchData[i]; - int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + int devId = key % static_cast(rankInfo.rankSize); auto result = uKey.find(key); if (result == uKey.end()) { splitKeys[devId].push_back(key); @@ -798,30 +776,20 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< } EASY_END_BLOCK - LOG_TRACE("dump splitKeys {}", [&] { - stringstream ssTrace; - for (int devId = 0; devId < rankInfo.rankSize; ++devId) { - ssTrace << '|' << devId << ":"; - for (auto x: splitKeys[devId]) { - ssTrace << x << ','; - } - ssTrace << '|'; - } - return ssTrace.str(); - }()); + LOG_TRACE("dump splitKeys {}", DumpSplitKeys(splitKeys)); if (g_statOn) { - size_t UniqueKeyNum = 0; + size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { - UniqueKeyNum += splitKeys[devId].size(); + uniqueKeyNum += splitKeys[devId].size(); } LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} unique_key_num {}", - batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueKeyNum); } return { splitKeys, restore }; } -auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const +auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const -> tuple, vector, vector>> { EASY_FUNCTION(profiler::colors::Gold) @@ -835,7 +803,7 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const EASY_BLOCK("split push back") for (size_t i = 0; i < miniBs; i++) { const emb_key_t& key = batchData[i]; - int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + int devId = key % static_cast(rankInfo.rankSize); auto result = uKey.find(key); if (result == uKey.end()) { splitKeys[devId].push_back(key); @@ -858,25 +826,15 @@ auto KeyProcess::HashSplit_withFAAE(const unique_ptr& batch) const } EASY_END_BLOCK - LOG_TRACE("dump splitKeys {}", [&] { - stringstream ssTrace; - for (int devId = 0; devId < rankInfo.rankSize; ++devId) { - ssTrace << '|' << devId << ":"; - for (auto x : splitKeys[devId]) { - ssTrace << x << ','; - } - ssTrace << '|'; - } - return 
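Both HashSplit variants switch the shard computation from a bitmask to a plain modulo. The two only agree when rankSize is a power of two; a quick illustration with invented values:

    int rankSize = 6;                          // not a power of two
    long long key = 10;
    long long maskDev = key & (rankSize - 1);  // 10 & 5 == 0: wrong shard
    long long modDev  = key % rankSize;        // 10 % 6 == 4: intended shard

Patch 360 further down hardens the same expression against negative keys.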
ssTrace.str(); - }()); + LOG_TRACE("dump splitKeys {}", DumpSplitKeys(splitKeys)); if (g_statOn) { - size_t UniqueKeyNum = 0; + size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { - UniqueKeyNum += splitKeys[devId].size(); + uniqueKeyNum += splitKeys[devId].size(); } LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} faae_unique_key_num {}", - batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueKeyNum); } return { splitKeys, restore, keyCount }; } @@ -903,7 +861,7 @@ tuple, vector, vector> if (batch->batchId % hotEmbUpdateStep == 0) { keyCountMap[key]++; } - int devId = static_cast(key) & (rankInfo.rankSize - 1); // 数据所在的设备devID = key % dev总数 support -1 + int devId = key % static_cast(rankInfo.rankSize); // 数据所在的设备devID = key % dev总数 support -1 auto result = uKey.find(key); if (result != uKey.end()) { // // already in splitKeys restore[i] = result->second; @@ -930,12 +888,12 @@ tuple, vector, vector> } if (g_statOn) { - size_t UniqueKeyNum = 0; + size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { - UniqueKeyNum += splitKeys[devId].size(); + uniqueKeyNum += splitKeys[devId].size(); } LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} batch_key_num {} hot_unique_key_num {}", - batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), UniqueKeyNum); + batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueKeyNum); } UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name); @@ -1145,8 +1103,8 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec EASY_FUNCTION() int hotNum = 0; for (size_t i = 0; i < batch->Size(); ++i) { - const emb_key_t d = batch->sample[i]; - int devId = static_cast(d) & (rankInfo.rankSize - 1); + const emb_key_t key = batch->sample[i]; + int devId = key % static_cast(rankInfo.rankSize); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; } else if (Log::GetLevel() >= Log::DEBUG) { @@ -1157,12 +1115,6 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec hotNum, batch->Size(), buildRestoreVecTC.ElapsedMS()); } -class EmptyList : public std::exception { -}; - -class WrongListTop : public std::exception { -}; - template T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, int channel) { @@ -1369,3 +1321,16 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! 
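BuildRestoreVec above globalizes per-device restore indices: cold keys (index at or beyond hotPosSize) are shifted by their device's block offset, while hot keys keep their reserved slots. Schematically — the names come from the code, but the data and the assumption that blockOffset is a Count2Start-style prefix sum are mine:

    for (size_t i = 0; i < batchSize; ++i) {
        long long devId = keys[i] % rankSize;      // same shard rule as HashSplit
        if (restoreVec[i] >= hotPosSize) {
            restoreVec[i] += blockOffset[devId];   // cold key: move into its device block
        }                                          // hot key: position already global
    }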
send offsetSize:{}", embName, offset.size()); } + +string KeyProcess::DumpSplitKeys(vector> &splitKeys) const +{ + stringstream ssTrace; + for (int devId = 0; devId < rankInfo.rankSize; ++devId) { + ssTrace << '|' << devId << ":"; + for (auto key: splitKeys[devId]) { + ssTrace << key << ','; + } + ssTrace << '|'; + } + return ssTrace.str(); +} diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index c4e109cc..ce885d0c 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -22,7 +22,6 @@ #include #include "ock_ctr_common/include/factory.h" -#include "ock_ctr_common/include/error_code.h" #include "utils/common.h" #include "utils/config.h" @@ -41,7 +40,7 @@ namespace MxRec { template struct Cmp { - bool operator()(const T& a, const T& b) + bool operator()(const T& a, const T& b) const { return get(a) > get(b); // batch id order } @@ -59,6 +58,25 @@ namespace MxRec { INVALID }; + class EndRunExit : public std::exception { + public: + explicit EndRunExit(const char* message) : errorMessage(message) {} + + const char* what() const noexcept override + { + return errorMessage; + } + + private: + const char* errorMessage; + }; + + class EmptyList : public std::exception { + }; + + class WrongListTop : public std::exception { + }; + class KeyProcess { public: bool Initialize(const RankInfo& rInfo, const vector& eInfos, @@ -126,9 +144,9 @@ namespace MxRec { bool isRunning { false }; - inline bool hasEmbName(const string &emb_name) + inline bool HasEmbName(const string& embName) { - return embInfos.find(emb_name) != embInfos.end(); + return embInfos.find(embName) != embInfos.end(); }; GTEST_PRIVATE: template @@ -187,7 +205,7 @@ namespace MxRec { auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; - auto HashSplit_withFAAE(const unique_ptr& batch) const + auto HashSplitWithFAAE(const unique_ptr& batch) const -> tuple, vector, vector>>; vector GetScAll(const vector& keyScLocal, int commId, int channel) const; @@ -200,7 +218,7 @@ namespace MxRec { unique_ptr GetBatchData(int channel, int commId); - void BuildRestoreVec(const unique_ptr& batch, const vector& rs, + void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, vector& restoreVec, int hotPosSize = 0) const; void SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch); @@ -220,13 +238,13 @@ namespace MxRec { void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys); - void PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int chanel); + void PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel); void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, const unique_ptr& batch); void ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, - vector &hotPos, vector &restore, const int hotOffset); + vector &hotPos, vector &restore, const int hotOffset) const; vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); @@ -234,6 +252,18 @@ namespace MxRec { void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, vector & restore, vector & hotPos, vector >& keyCount); + + template + inline vector Count2Start(const vector& count) + { + vector start = { 0 }; + for (size_t i = 0; i < count.size() - 1; ++i) { + start.push_back(count[i] + start.back()); + } + return start; + } + + string DumpSplitKeys(vector>& splitKeys) const; }; } // end namespace MxRec #endif // 
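EndRunExit, EmptyList and WrongListTop move out of key_process.cpp into the header so other translation units can catch them. EndRunExit in particular is a control-flow signal rather than an error; an illustrative catch site (the loop itself is invented):

    try {
        while (true) {
            // ... pull a batch, process it, and throw EndRunExit("end of run")
            //     once the input channel is drained ...
        }
    } catch (const EndRunExit& e) {
        LOG_INFO(KEY_PROCESS "processing loop finished: {}", e.what());
    }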
MX_REC_KEY_PROCESS_H diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index 46ddd2f5..de25abf2 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -15,23 +15,8 @@ using namespace MxRec; -inline TransferRet TransferSuccess() -{ - return TransferRet::TRANSFER_OK; -} - -inline TransferRet TransferError() -{ - return TransferRet::TRANSFER_ERROR; -} - -inline TransferRet TransferSpaceWarning() -{ - return TransferRet::SSD_SPACE_NOT_ENOUGH; -} - -inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, - vector& internalKeys, const vector& keys) +inline void CacheManager::GetExternalKeys(EmbHashMapInfo &embHashMap, vector &externalKeys, + vector &internalKeys, const vector &keys) { for (const emb_key_t key : keys) { if (embHashMap.hostHashMap.find(key) == embHashMap.hostHashMap.end()) { @@ -42,7 +27,8 @@ inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& exter } } -void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, vector& externalSSDKeys) +void CacheManager::AddDebugAndTraceLog(size_t batchKeySize, vector &externalKeys, + vector &externalSSDKeys) { LOG_DEBUG("TransferDDREmbWithSSD: batchKeySize:{}, externalKeys size:{}, externalSSDKeys size:{}", batchKeySize, externalKeys.size(), externalSSDKeys.size()); @@ -53,7 +39,7 @@ void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, v /// 去重和过滤无效key /// \param originalKeys 原有keys /// \param keys 处理后的keys -void HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) +void CacheManager::HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) { // 去重并保持原key的顺序 结果可测试 unordered_set keySet; @@ -83,7 +69,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, vector externalKeys; vector internalKeys; GetExternalKeys(embHashMap, externalKeys, internalKeys, keys); - if (externalKeys.empty()) { return TransferSuccess(); } + if (externalKeys.empty()) { return TransferRet::TRANSFER_OK; } // 判断剩余内存空间是否足够; 可用内存空间计算:HBM+DDR-已占用; 若是训练,再加DDR已淘汰; // SSD仅与DDR交互,不考虑HBM淘汰位置;由于maxOffset比实际使用大1,所以虽然从0开始也不用再减1 @@ -103,7 +89,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, bool ddrSpaceEnoughOrEval = channelId != TRAIN_CHANNEL_ID || isDDRSpaceEnough; if (ddrSpaceEnoughOrEval && externalSSDKeys.empty()) { // 部分场景后续不用处理,在此处返回 - return TransferSuccess(); + return TransferRet::TRANSFER_OK; } AddDebugAndTraceLog(keys.size(), externalKeys, externalSSDKeys); @@ -139,7 +125,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, if (int64_t(needSSDSize) > ssdAvailableSize) { LOG_ERROR("TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data. 
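HandleRepeatAndInvalidKey keeps the first occurrence of each key in its original order. The core of that pattern, sketched standalone (emb_key_t assumed to be a 64-bit integer; the real method also filters invalid keys while copying):

    using emb_key_t = long long;  // assumption for the sketch

    std::vector<emb_key_t> keys;
    std::unordered_set<emb_key_t> keySet;
    for (emb_key_t k : originalKeys) {
        if (keySet.insert(k).second) {  // true only on the first occurrence
            keys.push_back(k);
        }
    }

With the TransferSuccess()/TransferError() wrappers gone, callers see TransferRet values directly. A sketch of a call site — the dispatch is an assumption, though HandlePrepareDDRDataRet in hybrid_mgmt.h suggests one exists:

    TransferRet ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMap, keys, channelId);
    switch (ret) {
        case TransferRet::TRANSFER_OK:          break;  // nothing to move, or move succeeded
        case TransferRet::SSD_SPACE_NOT_ENOUGH: break;  // warn and skip the swap this step
        case TransferRet::TRANSFER_ERROR:       break;  // length mismatch etc.; abort the batch
        default:                                break;
    }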
" "needSSDSize:{}, ssdAvailableSize:{}", needSSDSize, ssdAvailableSize); - return TransferSpaceWarning(); + return TransferRet::SSD_SPACE_NOT_ENOUGH; } } @@ -193,7 +179,7 @@ void CacheManager::RefreshRelateInfoWithSSD2DDR(const std::string& embTableName, } void CacheManager::GetDDREmbInfo(vector& keys, const std::string& embTableName, EmbHashMapInfo& embHashMap, - vector& ddrTransferPos, vector>& ddrEmbData) + vector& ddrTransferPos, vector>& ddrEmbData) const { // 根据offset 获取对应Emb数据 for (auto& key : keys) { @@ -219,7 +205,7 @@ void CacheManager::GetDDREmbInfo(vector& keys, const std::string& emb /// \param ssdEmbData SSD对应的emb数据 void CacheManager::UpdateDDREmbInfo(const std::string& embTableName, vector& ddrTransferPos, - vector>& ssdEmbData) + vector>& ssdEmbData) const { auto& emb = hostEmbs->GetEmb(embTableName); auto& embData = emb.embData; @@ -417,20 +403,20 @@ TransferRet CacheManager::TransferSSDEmb2DDR(const string& embTableName, EmbHash vector>& ssdEmbData) { if (externalSSDKeys.empty()) { - return TransferSuccess(); + return TransferRet::TRANSFER_OK; } TimeCost ssd2DdrTc; LOG_DEBUG("TransferDDREmbWithSSD: get SSD embeddings and save to DDR, size:{}", externalSSDKeys.size()); if (ddrTransferPos.size() != externalSSDKeys.size() || externalSSDKeys.size() != ssdEmbData.size()) { LOG_ERROR("TransferDDREmbWithSSD, vector length is not equal, ddrTransferPos len:{}, externalSSDKeys len:{}, " "ssdEmbData len:{}", ddrTransferPos.size(), externalSSDKeys.size(), ssdEmbData.size()); - return TransferError(); + return TransferRet::TRANSFER_ERROR; } // 将SSD emb存储到DDR中 刷新频次信息 UpdateDDREmbInfo(embTableName, ddrTransferPos, ssdEmbData); RefreshRelateInfoWithSSD2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos); LOG_DEBUG("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):{}", ssd2DdrTc.ElapsedMS()); - return TransferSuccess(); + return TransferRet::TRANSFER_OK; } void CacheManager::CreateSSDTableIfNotExist(const std::string& embTableName) diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h index 8e3fd26d..c240d67c 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/ssd_cache/cache_manager.h @@ -55,7 +55,7 @@ namespace MxRec { // 转换DDR和SSD数据 TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, - const vector& keys, int channelId); + const vector& originalKeys, int channelId); /* HBM与DDR换入换出时刷新频次信息 */ void RefreshFreqInfoCommon(const string& embTableName, vector& keys, @@ -81,11 +81,11 @@ namespace MxRec { void GetDDREmbInfo(vector& keys, const std::string& embTableName, EmbHashMapInfo& embHashMap, - vector& ddrTransferPos, vector>& ddrEmbData); + vector& ddrTransferPos, vector>& ddrEmbData) const; void UpdateDDREmbInfo(const std::string& embTableName, vector& ddrTransferPos, - vector>& ssdEmbData); + vector>& ssdEmbData) const; void RefreshRelateInfoWithDDR2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, vector& ddrSwapOutKeys, vector& ddrSwapOutCounts); @@ -112,6 +112,14 @@ namespace MxRec { static void HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, EmbHashMapInfo& embHashMap); + inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, + vector& internalKeys, const vector& keys); + + void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, + vector& externalSSDKeys); + + void HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys); + unordered_map embBaseInfos; GTEST_PRIVATE: diff --git a/src/core/ssd_engine/file.cpp 
b/src/core/ssd_engine/file.cpp index 82d56c78..fe6d9724 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -273,17 +273,17 @@ vector File::GetKeys() return ret; } -uint64_t File::GetDataCnt() +uint64_t File::GetDataCnt() const { return dataCnt; } -uint64_t File::GetFileID() +uint64_t File::GetFileID() const { return fileID; } -uint64_t File::GetStaleDataCnt() +uint64_t File::GetStaleDataCnt() const { return staleDataCnt; } diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h index 8ae9ac26..1c234d11 100644 --- a/src/core/ssd_engine/file.h +++ b/src/core/ssd_engine/file.h @@ -43,11 +43,11 @@ namespace MxRec { vector GetKeys(); - uint64_t GetDataCnt(); + uint64_t GetDataCnt() const; - uint64_t GetFileID(); + uint64_t GetFileID() const; - uint64_t GetStaleDataCnt(); + uint64_t GetStaleDataCnt() const; private: uint64_t fileID; // init by constructor diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp index ba4b9ba7..f9cedccc 100644 --- a/src/core/ssd_engine/ssd_engine.cpp +++ b/src/core/ssd_engine/ssd_engine.cpp @@ -11,7 +11,7 @@ bool SSDEngine::IsTableExist(const string &tableName) if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); return !(it == tableMap.end()); } @@ -20,7 +20,7 @@ bool SSDEngine::IsKeyExist(const string &tableName, emb_key_t key) if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it == tableMap.end()) { throw invalid_argument("table not found"); } @@ -35,7 +35,7 @@ void SSDEngine::CreateTable(const string &tableName, vector savePaths, u if (savePaths.empty()) { throw invalid_argument("SSDEngine input savePaths is empty"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it != tableMap.end()) { throw invalid_argument("table already exist"); } @@ -47,7 +47,7 @@ void SSDEngine::InsertEmbeddings(const string &tableName, vector &key if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it == tableMap.end()) { throw invalid_argument("table not found"); } @@ -64,7 +64,7 @@ void SSDEngine::DeleteEmbeddings(const string &tableName, vector &key if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it == tableMap.end()) { throw invalid_argument("table not found"); } @@ -77,7 +77,7 @@ int64_t SSDEngine::GetTableAvailableSpace(const string &tableName) if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it == tableMap.end()) { throw invalid_argument("table not found"); } @@ -90,7 +90,7 @@ void SSDEngine::Save(int step) if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - for (auto item: tableMap) { + for (auto item: as_const(tableMap)) { item.second->Save(step); } } @@ -100,7 +100,7 @@ void SSDEngine::Load(const string &tableName, vector savePaths, uint64_t if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it != tableMap.end()) { throw invalid_argument("table already exist"); } @@ -130,7 +130,7 @@ void 
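Much of this patch is const-hygiene: accessors that only read state gain a trailing const, as with File's getters here. The benefit is that they remain callable through const references and the compiler rejects accidental writes:

    #include <cstdint>

    class File {
    public:
        uint64_t GetDataCnt() const { return dataCnt; }  // read-only, so const is free
    private:
        uint64_t dataCnt = 0;
    };

    uint64_t Inspect(const File& f)
    {
        return f.GetDataCnt();  // compiles only because GetDataCnt is const
    }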
SSDEngine::CompactMonitor() duration = chrono::duration_cast(end - start); if (duration >= compactPeriod) { LOG_DEBUG("SSDEngine CompactMonitor start compact"); - for (const auto &item: tableMap) { + for (const auto &item: as_const(tableMap)) { item.second->Compact(false); } LOG_DEBUG("SSDEngine CompactMonitor end compact"); @@ -147,7 +147,7 @@ vector> SSDEngine::FetchEmbeddings(const string &tableName, vector if (!isRunning) { throw invalid_argument("SSDEngine not running"); } - auto it = tableMap.find(tableName); + auto it = as_const(tableMap).find(tableName); if (it == tableMap.end()) { throw invalid_argument("table not found"); } diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index c0ab83f9..b2a09dae 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -54,44 +54,6 @@ namespace MxRec { : start(start), len(len), constantVal(constantVal), randomMin(randomMin), randomMax(randomMax) {} - ConstantInitializerInfo::ConstantInitializerInfo(float constantValue, float initK) - : constantValue(constantValue), initK(initK) - {} - - NormalInitializerInfo::NormalInitializerInfo(float mean, float stddev, int seed, float initK) - : mean(mean), stddev(stddev), seed(seed), initK(initK) - {} - - InitializeInfo::InitializeInfo(std::string& name, int start, int len, - ConstantInitializerInfo constantInitializerInfo) - : name(name), start(start), len(len), constantInitializerInfo(constantInitializerInfo) - { - if (name == "constant_initializer") { - initializerType = InitializerType::CONSTANT; - initializer = make_shared(start, len, constantInitializerInfo.constantValue, - constantInitializerInfo.initK); - } else { - throw std::invalid_argument("Invalid Initializer Type."); - } - } - - InitializeInfo::InitializeInfo(std::string& name, int start, int len, NormalInitializerInfo normalInitializerInfo) - : name(name), start(start), len(len), normalInitializerInfo(normalInitializerInfo) - { - std::tuple ret(normalInitializerInfo.mean, normalInitializerInfo.stddev, - normalInitializerInfo.seed, normalInitializerInfo.initK); - - if (name == "truncated_normal_initializer") { - initializerType = InitializerType::TRUNCATED_NORMAL; - initializer = make_shared(start, len, ret); - } else if (name == "random_normal_initializer") { - initializerType = InitializerType::RANDOM_NORMAL; - initializer = make_shared(start, len, ret); - } else { - throw std::invalid_argument("Invalid Initializer Type."); - } - } - void SetLog(int rank) { g_glogLevel = GlobalEnv::glogStderrthreshold; diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 8a6bdc39..2adbcb5f 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -352,47 +352,6 @@ namespace MxRec { void ValidateReadFile(const string& dataDir, size_t datasetSize); - enum class InitializerType { - CONSTANT, - TRUNCATED_NORMAL, - RANDOM_NORMAL - }; - - struct ConstantInitializerInfo { - ConstantInitializerInfo() = default; - explicit ConstantInitializerInfo(float constantValue, float initK); - - float constantValue; - float initK = 1.0; - }; - - struct NormalInitializerInfo { - NormalInitializerInfo() = default; - NormalInitializerInfo(float mean, float stddev, int seed, float initK); - - float mean; - float stddev; - int seed; - float initK = 1.0; - }; - - struct InitializeInfo { - InitializeInfo() = default; - - InitializeInfo(std::string& name, int start, int len, ConstantInitializerInfo constantInitializerInfo); - InitializeInfo(std::string& name, int start, int len, NormalInitializerInfo normalInitializerInfo); 
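The repeated as_const(tableMap).find(...) rewrite in ssd_engine.cpp forces the const overload of find, documenting that the lookup never mutates the map and yielding a const_iterator. Minimal sketch:

    #include <map>
    #include <string>
    #include <utility>

    std::map<std::string, int> tableMap;

    bool Exists(const std::string& name)
    {
        auto it = std::as_const(tableMap).find(name);  // const_iterator, C++17
        return it != tableMap.cend();
    }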
- - std::string name; - int start; - int len; - InitializerType initializerType; - - ConstantInitializerInfo constantInitializerInfo; - NormalInitializerInfo normalInitializerInfo; - - shared_ptr initializer; - }; - template inline Tensor Vec2TensorI32(const std::vector& data) { @@ -487,7 +446,7 @@ namespace MxRec { std::vector ddr2HbmKeys; void SetStartCount(); - bool HasFree(size_t i); + bool HasFree(size_t i) const; }; struct All2AllInfo { diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index ad7fbf68..368d12c6 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -205,7 +205,7 @@ public: { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < embNames.size(); ++i) { - if (!keyProcess->hasEmbName(embNames.at(i))) { + if (!keyProcess->HasEmbName(embNames.at(i))) { LOG_INFO("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { @@ -410,7 +410,7 @@ public: { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < splits.size(); ++i) { - if (!keyProcess->hasEmbName(embNames.at(i))) { + if (!keyProcess->HasEmbName(embNames.at(i))) { LOG_INFO("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index 261a1890..ff0bc6ba 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -111,7 +111,7 @@ TEST_F(HybridMgmtBlockTest, ResetAll) ASSERT_EQ(hybridMgmtBlock->hybridBatchId[0], 0); } -TEST_F(HybridMgmtBlockTest, CheckSaveEmbdMapValid) +TEST_F(HybridMgmtBlockTest, CheckSaveEmbMapValid) { hybridMgmtBlock = std::make_unique(); hybridMgmtBlock->SetStepInterval(1, 1); @@ -119,18 +119,18 @@ TEST_F(HybridMgmtBlockTest, CheckSaveEmbdMapValid) hybridMgmtBlock->pythonBatchId[0] = 0; hybridMgmtBlock->hybridBatchId[0] = 0; - hybridMgmtBlock->CheckSaveEmbdMapValid(); - int status0 = hybridMgmtBlock->CheckSaveEmbdMapValid(); + hybridMgmtBlock->CheckSaveEmbMapValid(); + int status0 = hybridMgmtBlock->CheckSaveEmbMapValid(); hybridMgmtBlock->pythonBatchId[0] = 0; hybridMgmtBlock->hybridBatchId[0] = 1; - hybridMgmtBlock->CheckSaveEmbdMapValid(); - int status1 = hybridMgmtBlock->CheckSaveEmbdMapValid(); + hybridMgmtBlock->CheckSaveEmbMapValid(); + int status1 = hybridMgmtBlock->CheckSaveEmbMapValid(); int step2 = 2; hybridMgmtBlock->pythonBatchId[0] = 0; hybridMgmtBlock->hybridBatchId[0] = step2; - int status2 = hybridMgmtBlock->CheckSaveEmbdMapValid(); + int status2 = hybridMgmtBlock->CheckSaveEmbMapValid(); ASSERT_EQ(status0, 0); ASSERT_EQ(status1, 1); ASSERT_EQ(status2, -1); diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp index 99ffce4d..b27a6a26 100644 --- a/src/tests/initializer/initializer_test.cpp +++ b/src/tests/initializer/initializer_test.cpp @@ -47,8 +47,8 @@ TEST(InitializerTest, TruncatedNormalInitializerTest) { TruncatedNormalInitializer truncatedNormalInitializer; - std::tuple ret(1.0, 0.3, 1, 0.1); - truncatedNormalInitializer = TruncatedNormalInitializer(1, 10, ret); + auto initInfo = NormalInitializerInfo(1.0, 0.3, 1, 0.1); + truncatedNormalInitializer = TruncatedNormalInitializer(1, 10, initInfo); vector> embData; int vocabSize = 5; @@ -76,8 +76,8 @@ TEST(InitializerTest, TruncatedNormalInitializerTest) TEST(InitializerTest, RandomNormalInitializerTest) { - std::tuple ret(2.0, 0.5, 1, 
0.1); - RandomNormalInitializer randomNormalInitializer(1, 10, ret); + auto initInfo = NormalInitializerInfo(1.0, 0.3, 1, 0.1); + RandomNormalInitializer randomNormalInitializer(1, 10, initInfo); vector> embData; int vocabSize = 5; diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index be8f6c58..9b3f6886 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -15,6 +15,7 @@ #include "utils/common.h" #include "key_process/key_process.h" #include "ock_ctr_common/include/unique.h" +#include "ock_ctr_common/include/error_code.h" using namespace std; using namespace MxRec; -- Gitee From 56f075cb004730e7415c8ef126d114b09f03d8b9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 21 Sep 2023 20:52:48 +0800 Subject: [PATCH 360/551] Match-id-09870dd404b044baed421d868a02b79ba6fb8d76 --- src/core/key_process/key_process.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index db142930..27df4961 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -764,7 +764,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< EASY_BLOCK("split push back") for (size_t i = 0; i < miniBs; i++) { const emb_key_t& key = batchData[i]; - int devId = key % static_cast(rankInfo.rankSize); + emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); auto result = uKey.find(key); if (result == uKey.end()) { splitKeys[devId].push_back(key); @@ -803,7 +803,7 @@ auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const EASY_BLOCK("split push back") for (size_t i = 0; i < miniBs; i++) { const emb_key_t& key = batchData[i]; - int devId = key % static_cast(rankInfo.rankSize); + emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); auto result = uKey.find(key); if (result == uKey.end()) { splitKeys[devId].push_back(key); @@ -861,7 +861,7 @@ tuple, vector, vector> if (batch->batchId % hotEmbUpdateStep == 0) { keyCountMap[key]++; } - int devId = key % static_cast(rankInfo.rankSize); // 数据所在的设备devID = key % dev总数 support -1 + emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); auto result = uKey.find(key); if (result != uKey.end()) { // // already in splitKeys restore[i] = result->second; @@ -1104,7 +1104,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec int hotNum = 0; for (size_t i = 0; i < batch->Size(); ++i) { const emb_key_t key = batch->sample[i]; - int devId = key % static_cast(rankInfo.rankSize); + emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; } else if (Log::GetLevel() >= Log::DEBUG) { -- Gitee From b6c7b80498c16866486316ad59c0463b8f74d188 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 22 Sep 2023 14:01:25 +0800 Subject: [PATCH 361/551] Match-id-4e5cb2d76362fd8468151b440f2647a89d7c99c6 --- mx_rec/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 6267d09b..43a4a863 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -50,7 +50,7 @@ from mx_rec.util.log import logger ("value_dtype", OptionValidator, {"options": [tf.float32]}), ("shard_num", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), ("fusion_optimizer_var", ClassValidator, {"classes": (bool, )}), - ("hashtable_threshold", NumValidator, 
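Patch 360 above wraps every key % rankSize in abs(). The reason is C++'s truncating division: a negative key yields a negative remainder, which would index splitKeys out of range. Illustration:

    #include <cstdlib>

    long long key = -1;                          // keys may be negative ("support -1")
    long long rankSize = 8;
    long long rem = key % rankSize;              // -1: the C++ remainder keeps the sign
    long long devId = std::abs(key % rankSize);  // 1: always within [0, rankSize)

Note that abs() folds the sign rather than computing a true modulus, so -1 and 1 land on the same shard; that is enough to keep the index valid.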
{"min_value": 0, "max_value": MAX_INT32}, ["check_value"]) + ("hashtable_threshold", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]) ]) def create_table(key_dtype, dim, name, emb_initializer, optimizer_list: Optional[list] = None, -- Gitee From 48e1a471c218681c71ae6dca391ec80cb5d149c2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 22 Sep 2023 14:21:29 +0800 Subject: [PATCH 362/551] Match-id-df1721d8a53f89422d57a46bc4885150e699721c --- src/core/checkpoint/checkpoint.cpp | 2 +- src/core/emb_table/emb_table.cpp | 13 +++++++------ src/core/emb_table/emb_table.h | 4 ++-- src/core/hd_transfer/acl_channel.h | 6 +++--- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 6 +++--- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- .../random_normal_initializer.h | 2 -- .../truncated_normal_initializer.h | 6 ++---- src/core/key_process/key_process.h | 2 +- src/core/ssd_cache/cache_manager.cpp | 6 +++--- src/core/ssd_cache/cache_manager.h | 6 +++--- 11 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 195be601..9f1415a5 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -289,7 +289,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, } for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { vector row(embeddingSize); - readFile.read(reinterpret_cast (row.data()), embeddingSize * sizeof(float)); + readFile.read(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index ce6b7eed..85fde413 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -18,11 +18,12 @@ using namespace std; using namespace MxRec; using namespace tensorflow; -void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int initSeed) +void EmbTable::Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed) { #ifndef GTEST this->rankInfo = rInfo; this->seed = initSeed; + this->embInfo = eInfo; LOG_INFO("EmbTable init, deviceID {}, embSize {} running", rInfo.deviceId, embInfo.extEmbeddingSize); // 计算embedding table需要分配的内存块数 auto ret = aclrtSetDevice(static_cast(rInfo.deviceId)); @@ -41,7 +42,7 @@ void EmbTable::Init(const EmbInfo& embInfo, const RankInfo& rInfo, int initSeed) throw AclError(); } // 申请内存初始化 - RandomInit(newBlock, embInfo.initializeInfos); + RandomInit(newBlock); // 将新的内存块加入内存链表 memoryList.push_back(newBlock); SplitMemoryBlock(newBlock); @@ -77,7 +78,7 @@ int64_t EmbTable::GetEmbAddress() LOG_ERROR("aclrtMalloc failed, ret={}", ret); throw AclError(); } - RandomInit(addBlock, embInfo.initializeInfos); + RandomInit(addBlock); // 将新的内存块加入内存list memoryList.push_back(addBlock); SplitMemoryBlock(addBlock); @@ -90,12 +91,12 @@ int64_t EmbTable::GetEmbAddress() #endif } -void EmbTable::RandomInit(void* newBlock, const vector& initializeInfos) +void EmbTable::RandomInit(void* newBlock) { #ifndef GTEST - LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); + LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed, embInfo.initializeInfos.size()); vector devEmb(blockSize); - for (auto initializeInfo: initializeInfos) { + for (const auto& initializeInfo: as_const(embInfo.initializeInfos)) { 
LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name.c_str()); for (int i = 0; i < BLOCK_EMB_COUNT; i++) { initializeInfo.initializer->GenerateData(&devEmb[i * embSize], embSize); diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index 7f3eb326..42eca691 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -23,7 +23,7 @@ namespace MxRec { public: EmbTable() = default; - void Init(const EmbInfo& embInfo, const RankInfo& rInfo, int seed = 0); + void Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed = 0); ~EmbTable(); @@ -59,7 +59,7 @@ namespace MxRec { // 内存块列表 vector memoryList; - void RandomInit(void* newBlock, const vector &initializeInfos); + void RandomInit(void* newBlock); // embSize由embInfo得出 void SplitMemoryBlock(void* newBlock); diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h index 08da510f..6cbd4e0c 100644 --- a/src/core/hd_transfer/acl_channel.h +++ b/src/core/hd_transfer/acl_channel.h @@ -5,8 +5,8 @@ * Date: 2022/11/15 */ -#ifndef ACL_CHANNEL_H_ -#define ACL_CHANNEL_H_ +#ifndef ACL_CHANNEL_H +#define ACL_CHANNEL_H #include #include @@ -30,5 +30,5 @@ namespace tensorflow { } // namespace tensorflow -#endif // ACL_CHANNEL_H_ +#endif // ACL_CHANNEL_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 75d65b3e..de679b5c 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -386,14 +386,14 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) /// 加载key对应的offset,python侧调用;启动数据处理线程 /// \param ReceiveKeyOffsetMap -void HybridMgmt::ReceiveHostMap(all_key_offset_map_t ReceiveKeyOffsetMap) +void HybridMgmt::ReceiveHostMap(all_key_offset_map_t receiveKeyOffsetMap) { #ifndef GTEST preprocess->LoadSaveLock(); key_offset_mem_t loadKeyOffsetMap; offset_mem_t loadMaxOffset; - if (!ReceiveKeyOffsetMap.empty()) { - for (const auto& keyOffsetMap : as_const(ReceiveKeyOffsetMap)) { + if (!receiveKeyOffsetMap.empty()) { + for (const auto& keyOffsetMap : as_const(receiveKeyOffsetMap)) { auto& singleHashMap = loadKeyOffsetMap[keyOffsetMap.first]; auto& maxOffset = loadMaxOffset[keyOffsetMap.first]; for (const auto& it : keyOffsetMap.second) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index fd902589..dd0eb46f 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -65,7 +65,7 @@ namespace MxRec { key_offset_map_t SendHostMap(const string tableName); - void ReceiveHostMap(all_key_offset_map_t keyOffsetMap); + void ReceiveHostMap(all_key_offset_map_t receiveKeyOffsetMap); void Start(); diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index fdedb190..fedb0b5f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -8,9 +8,7 @@ #ifndef MX_REC_RANDOM_NORMAL_INITIALIZER_H #define MX_REC_RANDOM_NORMAL_INITIALIZER_H -#include #include -#include #include "initializer/initializer.h" diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index e68ad02a..8de6ad52 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ 
b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -8,9 +8,7 @@ #ifndef MX_REC_TRUNCATED_NORMAL_INITIALIZER_H #define MX_REC_TRUNCATED_NORMAL_INITIALIZER_H -#include #include -#include #include "initializer/initializer.h" @@ -36,8 +34,8 @@ namespace MxRec { std::default_random_engine generator; std::normal_distribution distribution; - float minBound; - float maxBound; + float minBound = 0; + float maxBound = 0; }; } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index ce885d0c..3ba374fb 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -254,7 +254,7 @@ namespace MxRec { vector >& keyCount); template - inline vector Count2Start(const vector& count) + inline vector Count2Start(const vector& count) const { vector start = { 0 }; for (size_t i = 0; i < count.size() - 1; ++i) { diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index de25abf2..3d421dcf 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -16,7 +16,7 @@ using namespace MxRec; inline void CacheManager::GetExternalKeys(EmbHashMapInfo &embHashMap, vector &externalKeys, - vector &internalKeys, const vector &keys) + vector &internalKeys, const vector &keys) const { for (const emb_key_t key : keys) { if (embHashMap.hostHashMap.find(key) == embHashMap.hostHashMap.end()) { @@ -28,7 +28,7 @@ inline void CacheManager::GetExternalKeys(EmbHashMapInfo &embHashMap, vector &externalKeys, - vector &externalSSDKeys) + vector &externalSSDKeys) const { LOG_DEBUG("TransferDDREmbWithSSD: batchKeySize:{}, externalKeys size:{}, externalSSDKeys size:{}", batchKeySize, externalKeys.size(), externalSSDKeys.size()); @@ -39,7 +39,7 @@ void CacheManager::AddDebugAndTraceLog(size_t batchKeySize, vector &e /// 去重和过滤无效key /// \param originalKeys 原有keys /// \param keys 处理后的keys -void CacheManager::HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) +void CacheManager::HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) const { // 去重并保持原key的顺序 结果可测试 unordered_set keySet; diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h index c240d67c..86352598 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/ssd_cache/cache_manager.h @@ -113,12 +113,12 @@ namespace MxRec { EmbHashMapInfo& embHashMap); inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector& externalKeys, - vector& internalKeys, const vector& keys); + vector& internalKeys, const vector& keys) const; void AddDebugAndTraceLog(size_t batchKeySize, vector& externalKeys, - vector& externalSSDKeys); + vector& externalSSDKeys) const; - void HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys); + void HandleRepeatAndInvalidKey(const vector& originalKeys, vector& keys) const; unordered_map embBaseInfos; -- Gitee From 975a109929ff632488c9e244c648340b97615646 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 22 Sep 2023 18:21:29 +0800 Subject: [PATCH 363/551] Match-id-0e8f52fb6454f6034fc5c9490781a02dd7fd893d --- .../emb_hash_ckpt/emb_hash_ckpt.h | 4 +- .../feat_admit_n_evict_ckpt.h | 4 +- .../host_emb_ckpt/host_emb_ckpt.h | 4 +- .../key_freq_map_ckpt/key_freq_map_ckpt.h | 8 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.h | 4 +- .../nddr_offset_ckpt/nddr_offset_ckpt.h | 4 +- src/core/emb_hashmap/emb_hashmap.cpp | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 18 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 6 +- 
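
HandleRepeatAndInvalidKey above deduplicates a key batch while preserving first-seen order (its comment notes the dedup keeps the original key order and the result is testable), using an unordered_set purely as a membership filter. A minimal sketch of the pattern; the invalid-key rule below (negative keys) is an assumption for illustration:

#include <cstdint>
#include <unordered_set>
#include <vector>

using EmbKey = int64_t;  // illustrative key type

// Keep the first occurrence of each key, dropping repeats and invalid keys.
// e.g. {7, 3, 7, -1, 3} -> {7, 3}
std::vector<EmbKey> DedupPreservingOrder(const std::vector<EmbKey>& in)
{
    std::unordered_set<EmbKey> seen;
    std::vector<EmbKey> out;
    out.reserve(in.size());
    for (EmbKey k : in) {
        if (k < 0) {
            continue;  // assumed invalid-key rule, for illustration only
        }
        if (seen.insert(k).second) {  // true only on first insertion
            out.push_back(k);
        }
    }
    return out;
}
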
.../key_process/feature_admit_and_evict.cpp | 6 +- .../key_process/feature_admit_and_evict.h | 10 +- src/core/key_process/key_process.cpp | 102 +- src/core/key_process/key_process.h | 80 +- src/core/utils/common.h | 76 +- src/core/utils/config.cpp | 5 + src/core/utils/config.h | 4 +- src/core/utils/log.cpp | 18 +- src/core/utils/log.h | 73 +- src/core/utils/safe_queue.h | 220 ++-- src/core/utils/singleton.h | 63 +- src/core/utils/time_cost.h | 47 +- src/ops_tf/hybrid_dataset_ops.cpp | 1041 +++++++++-------- src/pybind/module_main.cpp | 269 ++--- src/pybind/module_main.h | 11 - src/tests/checkpoint/checkpoint_test.cpp | 54 +- .../ckpt_data_handler_test.cpp | 6 +- src/tests/emb_hashmap/emb_hashmap_test.cpp | 2 +- .../feature_admit_and_evict_test.cpp | 48 +- src/tests/key_process/key_process_test.cpp | 74 +- src/tests/utils/log_test.cpp | 20 +- 30 files changed, 1145 insertions(+), 1140 deletions(-) delete mode 100644 src/pybind/module_main.h diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h index 4fd413af..c7cd8aec 100644 --- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h +++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h @@ -46,8 +46,8 @@ namespace MxRec { const int embHashElmtNum { 2 }; const int embCurrStatNum { 4 }; - emb_hash_mem_t saveEmbHashMaps; - emb_hash_mem_t loadEmbHashMaps; + EmbHashMemT saveEmbHashMaps; + EmbHashMemT loadEmbHashMaps; void SetEmbHashMapTrans(string embName); void SetDevOffsetTrans(string embName); diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index a3fe3be5..adc9a830 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -46,8 +46,8 @@ namespace MxRec { const int countIdxOffset { 1 }; const int lastTimeIdxOffset { 2 }; - table_2_thresh_mem_t saveTable2Thresh; - table_2_thresh_mem_t loadTable2Thresh; + Table2ThreshMemT saveTable2Thresh; + Table2ThreshMemT loadTable2Thresh; AdmitAndEvictData saveHistRec; AdmitAndEvictData loadHistRec; diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h index 1dc34691..f51a5120 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h @@ -44,8 +44,8 @@ namespace MxRec { const int attribEmbDataInnerIdx { 1 }; const int embSveElmtNum { 4 }; - emb_mem_t* saveHostEmbs; - emb_mem_t* loadHostEmbs; + EmbMemT* saveHostEmbs; + EmbMemT* loadHostEmbs; void SetEmbInfoTrans(string embName); void SetEmbDataTrans(string embName); diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h index 2db7091c..cde71e68 100644 --- a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h +++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h @@ -35,10 +35,10 @@ namespace MxRec { const int freqMapElmtNum { 2 }; // Number of element types in the keyFreqMap during saving - key_freq_mem_t saveDDRKeyFreqMaps; - key_freq_mem_t loadDDRKeyFreqMaps; - key_freq_mem_t saveExcludeDDRKeyFreqMaps; - key_freq_mem_t loadExcludeDDRKeyFreqMaps; + KeyFreqMemT saveDDRKeyFreqMaps; + KeyFreqMemT loadDDRKeyFreqMaps; + KeyFreqMemT saveExcludeDDRKeyFreqMaps; + KeyFreqMemT 
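
Patch 363 is largely a mechanical rename: the snake_case aliases ending in _t (emb_hash_mem_t, key_offset_mem_t, offset_mem_t, key_freq_mem_t, ...) become CamelCase (EmbHashMemT, KeyOffsetMemT, OffsetMemT, KeyFreqMemT, ...). The _t suffix is reserved by POSIX and commonly flagged by style checkers; the underlying types are untouched. Shape of the change, with illustrative element types (the real aliases use the project's own key and offset types):

#include <cstdint>
#include <map>
#include <string>

// before: using offset_mem_t = std::map<std::string, int64_t>;
// after: the same type under the new spelling
using OffsetMemT = std::map<std::string, int64_t>;
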
loadExcludeDDRKeyFreqMaps; void SetDDRFreqMapTrans(string embName); void SetExcludeDDRFreqMapTrans(string embName); diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h index f670c302..dd7f4e16 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h @@ -35,8 +35,8 @@ namespace MxRec { const int embHashElmtNum { 2 }; - key_offset_mem_t saveKeyOffsetMap; - key_offset_mem_t loadKeyOffsetMap; + KeyOffsetMemT saveKeyOffsetMap; + KeyOffsetMemT loadKeyOffsetMap; }; } diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h index c8664e1c..0e414462 100644 --- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h +++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h @@ -33,8 +33,8 @@ namespace MxRec { const vector fileDirNames { "HashTable", "HBM" }; const vector saveDataTypes { CkptDataType::NDDR_OFFSET }; - offset_mem_t saveMaxOffset; - offset_mem_t loadMaxOffset; + OffsetMemT saveMaxOffset; + OffsetMemT loadMaxOffset; }; } diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 863e1f6a..ff24531c 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -249,7 +249,7 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map return embHashMapsOld; } -void EmbHashMap::LoadHashMap(emb_hash_mem_t& loadData) +void EmbHashMap::LoadHashMap(EmbHashMemT& loadData) { embHashMaps = std::move(loadData); } @@ -540,7 +540,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& /// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const { - if (Log::GetLevel() != Log::TRACE) { + if (Log::GetLevel() != Log::trace) { return; } auto& hostMap = embHashMap.hostHashMap; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index de679b5c..ade55872 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -142,7 +142,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, // 比较hostHashMap和cacheManager的数据是否一致 void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) { - if (Log::GetLevel() != Log::TRACE) { + if (Log::GetLevel() != Log::trace) { return; } auto& embHashMaps = saveData.embHashMaps; @@ -359,12 +359,12 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, /// 获取key对应的offset,python侧调用 /// \param tableName 表名 /// \return -key_offset_map_t HybridMgmt::SendHostMap(const string tableName) +KeyOffsetMapT HybridMgmt::SendHostMap(const string tableName) { #ifndef GTEST preprocess->LoadSaveLock(); - key_offset_mem_t keyOffsetMap; - key_offset_map_t sendKeyOffsetMap; + KeyOffsetMemT keyOffsetMap; + KeyOffsetMapT sendKeyOffsetMap; if (!mgmtRankInfo.noDDR) { LOG_DEBUG(MGMT + "Start send sparse data: ddr mode hashmap"); @@ -386,12 +386,12 @@ key_offset_map_t HybridMgmt::SendHostMap(const string tableName) /// 加载key对应的offset,python侧调用;启动数据处理线程 /// \param ReceiveKeyOffsetMap -void HybridMgmt::ReceiveHostMap(all_key_offset_map_t receiveKeyOffsetMap) +void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) { #ifndef GTEST preprocess->LoadSaveLock(); - key_offset_mem_t loadKeyOffsetMap; - 
offset_mem_t loadMaxOffset; + KeyOffsetMemT loadKeyOffsetMap; + OffsetMemT loadMaxOffset; if (!receiveKeyOffsetMap.empty()) { for (const auto& keyOffsetMap : as_const(receiveKeyOffsetMap)) { auto& singleHashMap = loadKeyOffsetMap[keyOffsetMap.first]; @@ -422,7 +422,9 @@ void HybridMgmt::ReceiveHostMap(all_key_offset_map_t receiveKeyOffsetMap) /// \param setupHostEmbs /// \param embTableCount /// \return -bool HybridMgmt::IsLoadDataMatches(emb_mem_t& loadHostEmbs, EmbInfo& setupHostEmbs, size_t& embTableCount) const +bool HybridMgmt::IsLoadDataMatches(const EmbMemT& loadHostEmbs, + const EmbInfo& setupHostEmbs, + size_t& embTableCount) const { bool loadDataMatches = { true }; const auto& loadEmbTable { loadHostEmbs.find(setupHostEmbs.name) }; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index dd0eb46f..8f45a745 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -63,9 +63,9 @@ namespace MxRec { void SetFeatureTypeForLoad(vector& loadFeatures, const FeatureAdmitAndEvict& featAdmitNEvict); - key_offset_map_t SendHostMap(const string tableName); + KeyOffsetMapT SendHostMap(const string tableName); - void ReceiveHostMap(all_key_offset_map_t receiveKeyOffsetMap); + void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); void Start(); @@ -115,7 +115,7 @@ namespace MxRec { void EvictKeys(const string& embName, const vector& keys); - bool IsLoadDataMatches(emb_mem_t& loadHostEmbs, EmbInfo& setupHostEmbs, size_t& embTableCount) const; + bool IsLoadDataMatches(const EmbMemT& loadHostEmbs, const EmbInfo& setupHostEmbs, size_t& embTableCount) const; void NotifyBySessionRun(int channelID) const; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index c7a97daa..b3354b9c 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -38,7 +38,7 @@ bool FeatureAdmitAndEvict::Init(const std::vector& thresholdValu // 以下为类的公共接口 FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, - const std::unique_ptr& batch, keys_t& splitKey, std::vector& keyCount) + const std::unique_ptr& batch, KeysT& splitKey, std::vector& keyCount) { if (splitKey.size() != keyCount.size()) { LOG_ERROR("splitKey.size {} != keyCount.size {}", splitKey.size(), keyCount.size()); @@ -251,7 +251,7 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t return true; } -auto FeatureAdmitAndEvict::GetTableThresholds() -> table_2_thresh_mem_t +auto FeatureAdmitAndEvict::GetTableThresholds() -> Table2ThreshMemT { std::lock_guard lock(m_syncMutexs); return m_table2Threshold; @@ -263,7 +263,7 @@ auto FeatureAdmitAndEvict::GetHistoryRecords() -> AdmitAndEvictData& return m_recordsData; } -void FeatureAdmitAndEvict::LoadTableThresholds(table_2_thresh_mem_t& loadData) +void FeatureAdmitAndEvict::LoadTableThresholds(Table2ThreshMemT& loadData) { std::lock_guard lock(m_syncMutexs); m_table2Threshold = std::move(loadData); diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 4f56f5b5..cc4f76ca 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -52,13 +52,13 @@ namespace MxRec { // 以下为类的公共接口 // 特征准入接口 - FeatureAdmitReturnType FeatureAdmit(int channel, const std::unique_ptr& batch, - keys_t& splitKey, std::vector& keyCount); + FeatureAdmitReturnType FeatureAdmit(int channel, const std::unique_ptr& 
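
The Load* entry points in these hunks (LoadHashMap earlier, LoadTableThresholds here, LoadMaxOffset and LoadKeyOffsetMap below) share one pattern: take the staged checkpoint container by non-const reference and std::move it into the member, so restoring a large table swaps internal buffers instead of copying elements. A sketch of the pattern with an illustrative payload type:

#include <cstdint>
#include <map>
#include <string>
#include <utility>

using OffsetMemT = std::map<std::string, int64_t>;  // illustrative payload

class LoaderSketch {
public:
    // Steals the caller's container: afterwards loadData is valid but
    // unspecified (in practice empty), and no per-element copy happened.
    void LoadMaxOffset(OffsetMemT& loadData)
    {
        maxOffset = std::move(loadData);
    }

private:
    OffsetMemT maxOffset;
};

int main()
{
    OffsetMemT staged { {"table_a", 128}, {"table_b", 256} };
    LoaderSketch loader;
    loader.LoadMaxOffset(staged);  // staged is consumed here
}
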
batch, + KeysT& splitKey, std::vector& keyCount); // 特征淘汰接口 void FeatureEvict(map>& evictKeyMap); void ExecuteFeatureAdmit( - const string& tableName, int channel, keys_t& splitKey, absl::flat_hash_map& mergeKeys); + const string& tableName, int channel, KeysT& splitKey, absl::flat_hash_map& mergeKeys); // 特征淘汰的使能接口 void SetFunctionSwitch(bool isEnableEvict); @@ -74,10 +74,10 @@ namespace MxRec { const std::vector& embNames, bool isTimestamp); // 与模型保存加载交互的接口 - auto GetTableThresholds() -> table_2_thresh_mem_t; + auto GetTableThresholds() -> Table2ThreshMemT; auto GetHistoryRecords() -> AdmitAndEvictData&; - void LoadTableThresholds(table_2_thresh_mem_t& loadData); + void LoadTableThresholds(Table2ThreshMemT& loadData); void LoadHistoryRecords(AdmitAndEvictData& loadData); static std::vector m_cfgThresholds; // 用于判断阈值配置的有效性 diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 27df4961..742b0217 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -32,7 +32,7 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos SetupHotEmbUpdateStep(); } - map scInfo; + map scInfo; for (const auto& info: eInfos) { embInfos[info.name] = info; scInfo[info.name] = info.sendCount; @@ -117,12 +117,12 @@ void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) HOT_EMB_CACHE_PCT / static_cast(embeddingSize)); } -auto KeyProcess::GetMaxOffset() -> offset_mem_t +auto KeyProcess::GetMaxOffset() -> OffsetMemT { return maxOffset; } -auto KeyProcess::GetKeyOffsetMap() -> key_offset_mem_t +auto KeyProcess::GetKeyOffsetMap() -> KeyOffsetMemT { return keyOffsetMap; } @@ -132,14 +132,14 @@ auto KeyProcess::GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict& return m_featureAdmitAndEvict; } -void KeyProcess::LoadMaxOffset(offset_mem_t& loadData) +void KeyProcess::LoadMaxOffset(OffsetMemT& loadData) { maxOffset = std::move(loadData); } /// 加载每张表key到offset的映射 /// \param loadData -void KeyProcess::LoadKeyOffsetMap(key_offset_mem_t& loadData) +void KeyProcess::LoadKeyOffsetMap(KeyOffsetMemT& loadData) { keyOffsetMap = std::move(loadData); } @@ -197,7 +197,7 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) } void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, UniquePtr& unique) + const unique_ptr & batch, UniquePtr& unique) { uniqueConf.desiredSize = static_cast(batch->Size()); if (preBatchSize != batch->Size()) { @@ -223,7 +223,7 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) { - unique_ptr batch; + unique_ptr batch; UniquePtr unique = nullptr; UniqueConf uniqueConf; size_t preBatchSize = 0; @@ -239,7 +239,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; - batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue + batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; @@ -255,7 +255,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, batch->channel, batch->batchId); - auto batchQueue = 
SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = SingletonQueue::GetInstances(threadId + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } unique->UnInitialize(); @@ -269,12 +269,12 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) void KeyProcess::KeyProcessTask(int channel, int threadId) { - unique_ptr batch; + unique_ptr batch; try { while (true) { TimeCost getAndProcessTC; TimeCost getBatchDataTC; - batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue + batch = GetBatchData(channel, threadId); // get batch data from SingletonQueue LOG_DEBUG("getBatchDataTC(ms):{}", getBatchDataTC.ElapsedMS()); if (batch == nullptr) { break; @@ -289,7 +289,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, batch->channel, batch->batchId); - auto batchQueue = SingletonQueue::getInstances(threadId + KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = SingletonQueue::GetInstances(threadId + KEY_PROCESS_THREAD * batch->channel); batchQueue->PutDirty(move(batch)); } } catch (const EndRunExit &e) { @@ -298,7 +298,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel); } -void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, +void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, vector & restore, vector & hotPos, vector >& keyCount) { @@ -316,7 +316,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector < LOG_DEBUG("uniqueTc(ms):{}", uniqueTc.ElapsedMS()); } -bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, +bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, UniquePtr& unique, int channel, int threadId) { // tuple for keyRec restore hotPos scAll countRecv @@ -371,9 +371,9 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& bat return true; } -bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) +bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId) { - vector splitKeys; + vector splitKeys; vector restore; vector hotPos; vector> keyCount; @@ -434,11 +434,11 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channe return true; } -void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel) +void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tensors, KeysT& lookupKeys, int channel) { if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && channel == TRAIN_CHANNEL_ID) { - keys_t uniqueKeys; + KeysT uniqueKeys; vector restoreVecSec; TimeCost globalUniqueSyncTC; @@ -449,7 +449,7 @@ void KeyProcess::PushGlobalUniqueTensors(const unique_ptr>& tenso } } -vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, +vector KeyProcess::GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss) { TimeCost getCountRecvTC; @@ -479,8 +479,8 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, return countRecv; } -void KeyProcess::PushResult(unique_ptr& batch, unique_ptr> tensors, - keys_t& lookupKeys) +void KeyProcess::PushResult(unique_ptr& batch, unique_ptr> 
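
GetCountRecv and the All2All exchange above work in terms of per-destination counts, and the Count2Start helper (const-qualified in the previous patch) converts such a count vector into exclusive start offsets, i.e. the displacement array an all-to-all needs. A self-contained version of that transform:

#include <cstddef>
#include <vector>

// Exclusive prefix sum: start[i] is the sum of count[0..i-1].
// e.g. counts {3, 1, 4} -> starts {0, 3, 4}
template <typename T>
std::vector<T> Count2StartSketch(const std::vector<T>& count)
{
    std::vector<T> start = { 0 };
    // i + 1 < size() also keeps the loop well-defined for an empty input
    for (size_t i = 0; i + 1 < count.size(); ++i) {
        start.push_back(start.back() + count[i]);
    }
    return start;
}
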
tensors, + KeysT& lookupKeys) { std::unique_lock lockGuard(mut); storage.push_front(move(tensors)); @@ -492,16 +492,16 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 + * 从共享队列SingletonQueue中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 * commID为线程标识[0, KEY_PROCESS_THREAD-1],不同线程、训练或推理数据用不同的共享队列通信 */ -unique_ptr KeyProcess::GetBatchData(int channel, int commId) +unique_ptr KeyProcess::GetBatchData(int channel, int commId) { EASY_FUNCTION() - unique_ptr batch = nullptr; + unique_ptr batch = nullptr; // train data, queue id = thread id [0, KEY_PROCESS_THREAD-1] - auto batchQueue = SingletonQueue::getInstances(commId + KEY_PROCESS_THREAD * channel); + auto batchQueue = SingletonQueue::GetInstances(commId + KEY_PROCESS_THREAD * channel); EASY_BLOCK("get samples") EASY_VALUE("run on CPU", sched_getcpu()) TimeCost tc = TimeCost(); @@ -540,7 +540,7 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) return batch; } -size_t KeyProcess::GetKeySize(const unique_ptr &batch) +size_t KeyProcess::GetKeySize(const unique_ptr &batch) { size_t size = rankInfo.rankSize * embInfos[batch->name].sendCount; if (!rankInfo.useStatic) { @@ -549,7 +549,7 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, +void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut) { EASY_FUNCTION(profiler::colors::Purple) @@ -608,7 +608,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch } } -void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, +void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, KeySendInfo& keySendInfo, vector& sc, vector& splitSize) { std::shared_lock lock(g_smut); @@ -637,7 +637,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uni } } -void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, +void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, vector &hotPos, vector &restore, const int hotOffset) const { auto* inputData = batch->sample.data(); @@ -690,8 +690,8 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS EASY_END_BLOCK } -auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, - vector& splitKeys) -> tuple, vector> +auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, + vector& splitKeys) -> tuple, vector> { TimeCost processSplitKeysTC; EASY_FUNCTION(profiler::colors::Purple) @@ -711,13 +711,13 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, i.resize(embInfos[batch->name].sendCount, -1); } } - keys_t keySend; + KeysT keySend; vector sc; // send count for (const auto& i: splitKeys) { sc.push_back(static_cast(i.size())); keySend.insert(keySend.cend(), i.cbegin(), i.cend()); } - keys_t keyRecv; + KeysT keyRecv; TimeCost getScAllTC; auto scAll = GetScAll(sc, id, batch->channel); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 @@ -752,12 +752,12 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, * splitKeys返回:将数据的key切分到其所在dev id对应的桶中,并去重。 * restore返回:去重后key在桶内偏移量(用于计算恢复向量) */ -auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple, vector> +auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple, vector> { EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t 
miniBs = batch->Size(); - vector splitKeys(rankInfo.rankSize); + vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 absl::flat_hash_map uKey; // 用于去重查询 @@ -789,13 +789,13 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple< return { splitKeys, restore }; } -auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const - -> tuple, vector, vector>> +auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const + -> tuple, vector, vector>> { EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - vector splitKeys(rankInfo.rankSize); + vector splitKeys(rankInfo.rankSize); vector> keyCount(rankInfo.rankSize); // splitKeys在原始batch中对应的频次 vector restore(batch->Size()); vector hashSplitLens(rankInfo.rankSize); // 初始化全0,记录每个桶的长度 @@ -839,13 +839,13 @@ auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const return { splitKeys, restore, keyCount }; } -auto KeyProcess::HotHashSplit(const unique_ptr& batch) -> -tuple, vector, vector> +auto KeyProcess::HotHashSplit(const unique_ptr& batch) -> +tuple, vector, vector> { EASY_FUNCTION(profiler::colors::Gold) auto* batchData = batch->sample.data(); size_t miniBs = batch->Size(); - vector splitKeys(rankInfo.rankSize); + vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); absl::flat_hash_map uKey; // 用于去重查询 absl::flat_hash_map keyCountMap; @@ -901,8 +901,8 @@ tuple, vector, vector> return { splitKeys, restore, hotPos }; } -void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, - const unique_ptr& batch) +void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, + const unique_ptr& batch) { vector splitKeysSize {}; if (rankInfo.useStatic) { @@ -920,7 +920,7 @@ void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& h } } -void KeyProcess::UpdateHotMapForUnique(const keys_t &keySend, const vector &keyCount, +void KeyProcess::UpdateHotMapForUnique(const KeysT &keySend, const vector &keyCount, uint32_t count, bool refresh, const string& embName) { auto& hotMap = hotKey[embName]; @@ -1013,7 +1013,7 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in LOG_DEBUG("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, VectorToString(scAllOut)); } -void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel) +void KeyProcess::Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel) { TimeCost key2OffsetTC; EASY_FUNCTION(profiler::colors::Blue600) @@ -1055,7 +1055,7 @@ void KeyProcess::Key2Offset(const emb_name_t& embName, keys_t& splitKey, int cha embName, maxOffsetTmp, embInfos[embName].devVocabSize, key2OffsetTC.ElapsedMS()); } -void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel) +void KeyProcess::Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel) { TimeCost key2OffsetTC; EASY_FUNCTION(profiler::colors::Blue600) @@ -1096,7 +1096,7 @@ void KeyProcess::Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& s * 实现方案2:用map记录keySend中key和表内index/offset的映射,在恢复emb时直接根据batch的key查询该map即可找到receive * emb中的 位置,时间复杂度:O(map构建keySend.size + map查询),空间复杂度:O(map) */ -void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, +void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, vector& restoreVec, int hotPosSize) 
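
HashSplit above is the core routing step: each key is assigned to its owning device by key % rankSize, per-device buckets are deduplicated, and a restore vector records where each sample's key landed so the gathered embeddings can be scattered back into batch order. A stripped-down sketch of that logic (the real code also folds in hot-key handling, per-device block offsets, and frequency counts):

#include <cstdint>
#include <cstdlib>
#include <unordered_map>
#include <vector>

using EmbKey = int64_t;

void HashSplitSketch(const std::vector<EmbKey>& batch, int rankSize,
                     std::vector<std::vector<EmbKey>>& splitKeys,
                     std::vector<int>& restore)
{
    splitKeys.assign(static_cast<size_t>(rankSize), {});
    restore.assign(batch.size(), 0);
    // One dedup map per destination device: key -> offset inside its bucket.
    std::vector<std::unordered_map<EmbKey, int>> seen(static_cast<size_t>(rankSize));
    for (size_t i = 0; i < batch.size(); ++i) {
        const EmbKey key = batch[i];
        const auto dev = static_cast<size_t>(std::abs(key % rankSize));
        auto [it, inserted] =
            seen[dev].try_emplace(key, static_cast<int>(splitKeys[dev].size()));
        if (inserted) {
            splitKeys[dev].push_back(key);  // first time this device sees the key
        }
        restore[i] = it->second;  // in-bucket offset used for scatter-back
    }
}
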
const { TimeCost buildRestoreVecTC; @@ -1107,7 +1107,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vec emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; - } else if (Log::GetLevel() >= Log::DEBUG) { + } else if (Log::GetLevel() >= Log::debug) { hotNum += 1; } } @@ -1142,7 +1142,7 @@ T KeyProcess::GetInfo(info_list_t& list, int batch, const string& embName, in /// \param embName 表名 /// \param channel 通道索引(训练/推理) /// \return -keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) +KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) { TimeCost tc = TimeCost(); // 循环尝试获取list中的数据;如果key process线程退出或者处理数据超时,返回空vector @@ -1163,7 +1163,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) } try { auto ret = GetInfo(lookupKeysList, batch, embName, channel); - return get(ret); + return get(ret); } catch (EmptyList&) { LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch); this_thread::sleep_for(1ms); @@ -1183,7 +1183,7 @@ keys_t KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type) { TimeCost tc = TimeCost(); - info_list_t* list; + info_list_t* list; // 根据数据类型,选择对应的list switch (type) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 3ba374fb..767fe387 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -50,7 +50,7 @@ namespace MxRec { using heap_t = priority_queue, Cmp>; template - using info_list_t = map, MAX_QUEUE_NUM>>; + using info_list_t = map, MAX_QUEUE_NUM>>; enum class ProcessedInfo { RESTORE, @@ -84,21 +84,21 @@ namespace MxRec { unique_ptr> GetInfoVec(int batch, const string& embName, int channel, ProcessedInfo type); - keys_t GetLookupKeys(int batch, const string& embName, int channel); + KeysT GetLookupKeys(int batch, const string& embName, int channel); int GetMaxStep(int channelId) const; int Start(); - auto GetMaxOffset() -> offset_mem_t; + auto GetMaxOffset() -> OffsetMemT; - auto GetKeyOffsetMap() -> key_offset_mem_t; + auto GetKeyOffsetMap() -> KeyOffsetMemT; auto GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&; - void LoadMaxOffset(offset_mem_t& loadData); + void LoadMaxOffset(OffsetMemT& loadData); - void LoadKeyOffsetMap(key_offset_mem_t& loadData); + void LoadKeyOffsetMap(KeyOffsetMemT& loadData); void Destroy(); @@ -153,22 +153,22 @@ namespace MxRec { T GetInfo(info_list_t& list, int batch, const string& embName, int channel); RankInfo rankInfo; - map embInfos; + map embInfos; MPI_Comm comm[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD]; std::mutex mut {}; vector> procThreads {}; std::mutex loadSaveMut[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD] {}; - info_list_t lookupKeysList; + info_list_t lookupKeysList; list>> storage; - info_list_t infoList; - info_list_t all2AllList; - map maxOffset {}; - map> keyOffsetMap {}; + info_list_t infoList; + info_list_t all2AllList; + map maxOffset {}; + map> keyOffsetMap {}; FeatureAdmitAndEvict m_featureAdmitAndEvict {}; - map> evictPosMap {}; - map> hotKey {}; - map hotEmbTotCount; - map embeddingTableMap {}; + map> evictPosMap {}; + map> hotKey {}; + map hotEmbTotCount; + map embeddingTableMap {}; FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; @@ -180,45 +180,45 @@ namespace MxRec { void KeyProcessTaskWithFastUnique(int channel, 
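
GetLookupKeys above polls the result list, catching EmptyList and sleeping 1ms between attempts, until data arrives, the worker threads exit, or a timeout budget runs out. The skeleton of such a poll-with-deadline loop, generalized (names are illustrative):

#include <chrono>
#include <optional>
#include <thread>

// Calls fetch() until it yields a value or the deadline passes; fetch is
// expected to return an empty std::optional while no data is ready.
template <typename Fetch>
auto PollWithDeadline(Fetch fetch, std::chrono::milliseconds timeout) -> decltype(fetch())
{
    const auto deadline = std::chrono::steady_clock::now() + timeout;
    while (std::chrono::steady_clock::now() < deadline) {
        if (auto value = fetch()) {
            return value;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    return {};  // timed out: empty result
}
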
int threadId); - bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId); + bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId); - bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, UniquePtr& unique, + bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, UniquePtr& unique, int channel, int threadId); - auto ProcessSplitKeys(const unique_ptr& batch, int id, - vector& splitKeys) -> tuple, vector>; + auto ProcessSplitKeys(const unique_ptr& batch, int id, + vector& splitKeys) -> tuple, vector>; void GetUniqueConfig(UniqueConf& uniqueConf); void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, UniquePtr& unique); + const unique_ptr & batch, UniquePtr& unique); - void ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, + void ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut); - size_t GetKeySize(const unique_ptr &batch); + size_t GetKeySize(const unique_ptr &batch); void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, All2AllInfo& all2AllInfoOut); - auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; + auto HashSplit(const unique_ptr& batch) const -> tuple, vector>; - auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; + auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; - auto HashSplitWithFAAE(const unique_ptr& batch) const - -> tuple, vector, vector>>; + auto HashSplitWithFAAE(const unique_ptr& batch) const + -> tuple, vector, vector>>; vector GetScAll(const vector& keyScLocal, int commId, int channel) const; void GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const; - void Key2Offset(const emb_name_t& embName, keys_t& splitKey, int channel); + void Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel); - void Key2OffsetDynamicExpansion(const emb_name_t& embName, keys_t& splitKey, int channel); + void Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel); - unique_ptr GetBatchData(int channel, int commId); + unique_ptr GetBatchData(int channel, int commId); - void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, + void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, vector& restoreVec, int hotPosSize = 0) const; void SendA2A(const vector& a2aInfo, const string& embName, int channel, int batch); @@ -230,26 +230,26 @@ namespace MxRec { void UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh, const string& embName); - void UpdateHotMapForUnique(const keys_t &keySend, const vector &keyCount, + void UpdateHotMapForUnique(const KeysT &keySend, const vector &keyCount, uint32_t count, bool refresh, const string& embName); - void HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, + void HandleHotAndSendCount(const unique_ptr &batch, UniqueInfo& uniqueInfoOut, KeySendInfo& keySendInfo, vector& sc, vector& splitSize); - void PushResult(unique_ptr& batch, unique_ptr> tensors, keys_t& lookupKeys); + void PushResult(unique_ptr& batch, unique_ptr> tensors, KeysT& lookupKeys); - void PushGlobalUniqueTensors(const unique_ptr>& tensors, keys_t& lookupKeys, int channel); + void PushGlobalUniqueTensors(const unique_ptr>& tensors, KeysT& lookupKeys, int channel); - void AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, - const unique_ptr& batch); + void 
AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, + const unique_ptr& batch); - void ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, + void ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, vector &hotPos, vector &restore, const int hotOffset) const; - vector GetCountRecv(const unique_ptr& batch, int id, + vector GetCountRecv(const unique_ptr& batch, int id, vector>& keyCount, vector scAll, vector ss); - void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, + void HashSplitHelper(const unique_ptr & batch, vector & splitKeys, vector & restore, vector & hotPos, vector >& keyCount); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 2adbcb5f..80c51498 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -96,10 +96,10 @@ namespace MxRec { using emb_key_t = int64_t; using freq_num_t = int64_t; - using emb_name_t = std::string; - using keys_t = std::vector; - using lookup_key_t = std::tuple; // batch_id quarry_lable keys_vector - using tensor_info_t = std::tuple>>::iterator>; + using EmbNameT= std::string; + using KeysT = std::vector; + using LookupKeyT = std::tuple; // batch_id quarry_lable keys_vector + using TensorInfoT = std::tuple>>::iterator>; namespace HybridOption { const int USE_STATIC = 0x001; @@ -128,15 +128,15 @@ namespace MxRec { inline int GetUBSize(int devID) { - std::map ChipUbSizeList = {{"910A", UBSize::ASCEND910_A}, + std::map chipUbSizeList = {{"910A", UBSize::ASCEND910_A}, {"910B", UBSize::ASCEND910_B}, {"920A", UBSize::ASCEND920_A}, {"910B1", UBSize::ASCEND910_B1}, {"910B2", UBSize::ASCEND910_B2}, {"910B3", UBSize::ASCEND910_B3}, {"910B4", UBSize::ASCEND910_B4}}; - auto it = ChipUbSizeList.find(GetChipName(devID)); - if (it != ChipUbSizeList.end()) { + const auto it = chipUbSizeList.find(GetChipName(devID)); + if (it != chipUbSizeList.end()) { return it->second; } @@ -153,8 +153,8 @@ namespace MxRec { std::string UnParse() const { std::string s; - constexpr size_t MAX_DISP_LEN = 20; - int maxLen = static_cast(std::min(sample.size(), MAX_DISP_LEN)); + constexpr size_t maxDispLen = 20; + int maxLen = static_cast(std::min(sample.size(), maxDispLen)); for (int i = 0; i < maxLen; i++) { s += std::to_string(sample[i]) + " "; } @@ -183,8 +183,8 @@ namespace MxRec { const void *tensor; }; - using emb_batch_t = Batch; - using batch_task_t = BatchTask; + using EmbBatchT = Batch; + using BatchTaskT = BatchTask; struct DDRParam { vector tmpDataOut; @@ -254,7 +254,7 @@ namespace MxRec { struct ThresholdValue { ThresholdValue() = default; - ThresholdValue(emb_name_t name, int countThre, int timeThre, int faaeCoef) + ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef) { tableName = name; countThreshold = countThre; @@ -262,7 +262,7 @@ namespace MxRec { faaeCoefficient = faaeCoef; } - emb_name_t tableName { "" }; // embName + EmbNameT tableName { "" }; // embName int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 @@ -292,7 +292,7 @@ namespace MxRec { auto size = static_cast(GLOG_MAX_BUF_SIZE); unique_ptr buf(new char[size]); memset_s(buf.get(), size, 0, size); - int nChar = snprintf_s(buf.get(), size, size-1, format.c_str(), args ...); + int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...); if (nChar == -1) { throw 
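
The StringFormat helper defined here formats into a fixed GLOG_MAX_BUF_SIZE buffer using the bounds-checked snprintf_s/memset_s from the secure C library and throws on failure. A portable sketch of the same utility on top of standard snprintf, sized in two passes instead of a fixed buffer:

#include <cstdio>
#include <stdexcept>
#include <string>

template <typename... Args>
std::string StringFormatSketch(const char* format, Args... args)
{
    // First pass measures, second pass writes into an exactly-sized string.
    const int needed = std::snprintf(nullptr, 0, format, args...);
    if (needed < 0) {
        throw std::invalid_argument("StringFormat failed");
    }
    std::string out(static_cast<size_t>(needed), '\0');
    std::snprintf(&out[0], out.size() + 1, format, args...);
    return out;
}

// e.g. StringFormatSketch("channel %d of %d", 1, 8) yields "channel 1 of 8"
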
invalid_argument("StringFormat failed"); } @@ -386,11 +386,13 @@ namespace MxRec { std::vector initializeInfos, std::vector ssdDataPath) : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), - isSave(isSave), initializeInfos(initializeInfos), ssdDataPath(std::move(ssdDataPath)) + isSave(isSave), + devVocabSize(vocabsize[0]), + hostVocabSize(vocabsize[1]), + ssdVocabSize(vocabsize[SSD_SIZE_INDEX]), + initializeInfos(initializeInfos), + ssdDataPath(std::move(ssdDataPath)) { - devVocabSize = vocabsize[0]; - hostVocabSize = vocabsize[1]; - ssdVocabSize = vocabsize[SSD_SIZE_INDEX]; } std::string name; @@ -450,11 +452,11 @@ namespace MxRec { }; struct All2AllInfo { - keys_t keyRecv; + KeysT keyRecv; vector scAll; vector countRecv; All2AllInfo() = default; - All2AllInfo(keys_t keyRecv, vector scAll, vector countRecv) + All2AllInfo(KeysT keyRecv, vector scAll, vector countRecv) : keyRecv(keyRecv), scAll(scAll), countRecv(countRecv) {} }; @@ -468,19 +470,19 @@ namespace MxRec { }; struct KeySendInfo { - keys_t keySend; + KeysT keySend; vector keyCount; }; - using emb_mem_t = absl::flat_hash_map; - using emb_hash_mem_t = absl::flat_hash_map; - using offset_mem_t = std::map; - using key_offset_mem_t = std::map>; - using table_2_thresh_mem_t = absl::flat_hash_map; + using EmbMemT = absl::flat_hash_map; + using EmbHashMemT = absl::flat_hash_map; + using OffsetMemT = std::map; + using KeyOffsetMemT = std::map>; + using Table2ThreshMemT = absl::flat_hash_map; using trans_serialize_t = uint8_t; - using key_offset_map_t = std::map; - using all_key_offset_map_t = std::map>; - using key_freq_mem_t = unordered_map>; + using KeyOffsetMapT = std::map; + using AllKeyOffsetMapT = std::map>; + using KeyFreqMemT = unordered_map>; enum class CkptFeatureType { HOST_EMB = 0, @@ -493,14 +495,14 @@ namespace MxRec { }; struct CkptData { - emb_mem_t* hostEmbs = nullptr; - emb_hash_mem_t embHashMaps; - offset_mem_t maxOffset; - key_offset_mem_t keyOffsetMap; - table_2_thresh_mem_t table2Thresh; + EmbMemT* hostEmbs = nullptr; + EmbHashMemT embHashMaps; + OffsetMemT maxOffset; + KeyOffsetMemT keyOffsetMap; + Table2ThreshMemT table2Thresh; AdmitAndEvictData histRec; - key_freq_mem_t ddrKeyFreqMaps; - key_freq_mem_t excludeDDRKeyFreqMaps; + KeyFreqMemT ddrKeyFreqMaps; + KeyFreqMemT excludeDDRKeyFreqMaps; }; struct CkptTransData { @@ -529,7 +531,7 @@ namespace MxRec { EVICT_POS = 12 }; - ostream& operator<<(ostream& s, MxRec::CkptDataType type); + ostream& operator<<(ostream& ss, MxRec::CkptDataType type); } // end namespace MxRec #define KEY_PROCESS "\033[45m[KeyProcess]\033[0m " diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp index 0e482427..35f0e6ff 100644 --- a/src/core/utils/config.cpp +++ b/src/core/utils/config.cpp @@ -13,6 +13,11 @@ using namespace std; namespace MxRec { + namespace ApplyGradientsStrategyOptions { + const std::string DIRECT_APPLY = "direct_apply"; + const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply"; + }; + // 设置环境变量默认值 string GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::DIRECT_APPLY; int GlobalEnv::aclTimeout = -1; // 默认阻塞方式,一直等待直到数据接收完成。 diff --git a/src/core/utils/config.h b/src/core/utils/config.h index 58040cb3..f29c2346 100644 --- a/src/core/utils/config.h +++ b/src/core/utils/config.h @@ -29,8 +29,8 @@ namespace MxRec { }; namespace ApplyGradientsStrategyOptions { - const std::string DIRECT_APPLY = "direct_apply"; - const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY = 
"sum_same_id_gradients_and_apply"; + extern const std::string DIRECT_APPLY; + extern const std::string SUM_SAME_ID_GRADIENTS_AND_APPLY; }; struct GlobalEnv { diff --git a/src/core/utils/log.cpp b/src/core/utils/log.cpp index 2314a581..f3e490b2 100644 --- a/src/core/utils/log.cpp +++ b/src/core/utils/log.cpp @@ -11,27 +11,27 @@ namespace MxRec { -int MxRec::Log::_level = MxRec::Log::INFO; -int MxRec::Log::_rank = 0; +int MxRec::Log::level = MxRec::Log::info; +int MxRec::Log::rank = 0; void Log::SetRank(int rank) { - Log::_rank = rank; + Log::rank = rank; } void Log::SetLevel(int level) { - Log::_level = level; + Log::level = level; } int Log::GetLevel() { - return Log::_level; + return Log::level; } const char* Log::LevelToStr(int level) { - if (level < TRACE || level > ERROR) { + if (level < trace || level > error) { return "INVALID LEVEL"; } static const char* msg[] = { @@ -41,11 +41,11 @@ const char* Log::LevelToStr(int level) "WARN", "ERROR", }; - constexpr int LEVEL_OFFSET = 2; - return msg[level + LEVEL_OFFSET]; + constexpr int levelOffset = 2; + return msg[level + levelOffset]; } -void Log::LogUnpack(queue& fmt, stringstream &ss) +void Log::LogUnpack(std::queue& fmt, std::stringstream &ss) { while (!fmt.empty()) { ss << fmt.front(); diff --git a/src/core/utils/log.h b/src/core/utils/log.h index 3f3bf7bf..ed740cbf 100644 --- a/src/core/utils/log.h +++ b/src/core/utils/log.h @@ -6,19 +6,18 @@ * History: NA */ -#ifndef MXREC_LOG_H_ -#define MXREC_LOG_H_ +#ifndef MXREC_LOG_H +#define MXREC_LOG_H -#include -#include -#include +#include +#include +#include #include #include #include #include #include -using namespace std; namespace MxRec { @@ -28,11 +27,11 @@ constexpr size_t DELIM_LEN = 2; class Log { public: - static constexpr int TRACE = -2; - static constexpr int DEBUG = -1; - static constexpr int INFO = 0; - static constexpr int WARN = 1; - static constexpr int ERROR = 2; + static constexpr int trace = -2; + static constexpr int debug = -1; + static constexpr int info = 0; + static constexpr int warn = 1; + static constexpr int error = 2; static void SetRank(int rank); @@ -41,12 +40,12 @@ public: static int GetLevel(); template - static void Format(stringstream& ss, const char* fmt, Args &&...args) + static void Format(std::stringstream& ss, const char* fmt, Args &&...args) { - queue formats; - string tmp(fmt); - for (size_t pos = tmp.find_first_of("{}"); pos != string::npos; pos = tmp.find_first_of("{}")) { - string x = tmp.substr(0, pos); + std::queue formats; + std::string tmp(fmt); + for (size_t pos = tmp.find_first_of("{}"); pos != std::string::npos; pos = tmp.find_first_of("{}")) { + std::string x = tmp.substr(0, pos); formats.push(x); tmp = tmp.substr(pos + DELIM_LEN); } @@ -55,9 +54,9 @@ public: } template - static string Format(const char* fmt, Args &&...args) + static std::string Format(const char* fmt, Args &&...args) { - stringstream ss; + std::stringstream ss; Log::Format(ss, fmt, args...); return ss.str(); } @@ -65,14 +64,14 @@ public: template static void log(const char* file, int line, int level, const char* fmt, Args &&...args) { - stringstream ss; + std::stringstream ss; struct tm t; struct timeval tv; - gettimeofday(&tv, NULL); + gettimeofday(&tv, nullptr); localtime_r(&tv.tv_sec, &t); ss << "[MxRec][" << YEAR_BASE + t.tm_year << "/" << t.tm_mon << "/" << t.tm_mday<< " " << t.tm_hour << ":" << t.tm_min << ":" << t.tm_sec << "." 
<< tv.tv_usec << "] [" - << Log::_rank << "] ["<< Log::LevelToStr(level) << "] [" + << Log::rank << "] ["<< Log::LevelToStr(level) << "] [" << (strrchr(file, '/') ? strrchr(file, '/') + 1 : file) << ":" << line << "] "; Log::Format(ss, fmt, args...); ss << std::endl; @@ -80,7 +79,7 @@ public: } template - static void log(const char* file, int line, int level, const string& fmt, Args &&...args) + static void log(const char* file, int line, int level, const std::string& fmt, Args &&...args) { Log::log(file, line, level, fmt.c_str(), args...); } @@ -88,10 +87,10 @@ public: private: static const char* LevelToStr(int level); - static void LogUnpack(queue& fmt, stringstream &ss); + static void LogUnpack(std::queue& fmt, std::stringstream &ss); template - static void LogUnpack(queue& fmt, stringstream &ss, head &h, tail &&...tails) + static void LogUnpack(std::queue& fmt, std::stringstream &ss, head &h, tail &&...tails) { if (!fmt.empty()) { ss << fmt.front(); @@ -100,26 +99,26 @@ private: ss << h; LogUnpack(fmt, ss, tails...); }; - static int _level; - static int _rank; + static int level; + static int rank; }; -#define LOG_TRACE(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::TRACE) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::TRACE, args) +#define LOG_TRACE(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::trace) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::trace, args) -#define LOG_DEBUG(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::DEBUG) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::DEBUG, args) +#define LOG_DEBUG(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::debug) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::debug, args) -#define LOG_INFO(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::INFO) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::INFO, args) +#define LOG_INFO(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::info) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::info, args) -#define LOG_WARN(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::WARN) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::WARN, args) +#define LOG_WARN(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::warn) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::warn, args) -#define LOG_ERROR(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::ERROR) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::ERROR, args) +#define LOG_ERROR(args...) 
if (MxRec::Log::GetLevel() <= MxRec::Log::error) \ +MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::error, args) } -#endif // MXREC_LOG_H_ \ No newline at end of file +#endif // MXREC_LOG_H \ No newline at end of file diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h index 3c2e2be5..8f7f2efc 100644 --- a/src/core/utils/safe_queue.h +++ b/src/core/utils/safe_queue.h @@ -16,137 +16,137 @@ #include #include #include "common.h" +namespace MxRec { + template + class SafeQueue { + static constexpr uint64_t defaultCap = 10; -template -class SafeQueue { - static constexpr uint64_t DEFAULT_CAP = 10; + public: + SafeQueue() = default; -public: - SafeQueue() = default; + ~SafeQueue() = default; - ~SafeQueue() = default; - - SafeQueue(SafeQueue const& other) - { - std::lock_guard lk(other.mut); - dataQueue = other.dataQueue; - } + SafeQueue(SafeQueue const &other) + { + std::lock_guard lk(other.mut); + dataQueue = other.dataQueue; + } - SafeQueue& operator=(SafeQueue const& other) - { - if (this == &other) { + SafeQueue &operator=(SafeQueue const &other) + { + if (this == &other) { + return *this; + } + std::lock_guard lk(other.mut); + dataQueue = other.dataQueue; return *this; } - std::lock_guard lk(other.mut); - dataQueue = other.dataQueue; - return *this; - } - - std::unique_ptr GetOne() - { - std::lock_guard lk(mut); - if (emptyQueue.empty()) { - return std::make_unique(); - } else { + + std::unique_ptr GetOne() + { + std::lock_guard lk(mut); + if (emptyQueue.empty()) { + return std::make_unique(); + } else { + auto t = move(emptyQueue.back()); + emptyQueue.pop_back(); + return move(t); + } + } + + std::unique_ptr WaitAndGetOne() + { + { + std::lock_guard lk(mut); + if (creatNum < capacity) { + creatNum++; + return std::make_unique(); + } + } + std::unique_lock locker(mut); + dirtyCond.wait(locker, [this] { return !emptyQueue.empty(); }); auto t = move(emptyQueue.back()); emptyQueue.pop_back(); return move(t); } - } - std::unique_ptr WaitAndGetOne() - { + void PutDirty(std::unique_ptr &&t) { - std::lock_guard lk(mut); - if (creatNum < capacity) { - creatNum++; - return std::make_unique(); + std::lock_guard lk(mut); + emptyQueue.push_back(move(t)); + dirtyCond.notify_one(); + } + + void Pushv(std::unique_ptr &&t) // 入队操作 + { + std::lock_guard lk(mut); + dataQueue.push_back(move(t)); + dataCond.notify_one(); + } + + std::unique_ptr WaitAndPop() + { + std::unique_lock lk(mut); + dataCond.wait(lk, [this] { return !dataQueue.empty(); }); + std::unique_ptr res = std::move(dataQueue.front()); + dataQueue.pop_front(); + return move(res); + } + + std::unique_ptr TryPop() + { + std::lock_guard lk(mut); + if (dataQueue.empty()) { + return nullptr; } + std::unique_ptr res = std::move(dataQueue.front()); + dataQueue.pop_front(); + return move(res); } - std::unique_lock locker(mut); - dirtyCond.wait(locker, [this] { return !emptyQueue.empty(); }); - auto t = move(emptyQueue.back()); - emptyQueue.pop_back(); - return move(t); - } - - void PutDirty(std::unique_ptr&& t) - { - std::lock_guard lk(mut); - emptyQueue.push_back(move(t)); - dirtyCond.notify_one(); - } - - void Pushv(std::unique_ptr&& t) // 入队操作 - { - std::lock_guard lk(mut); - dataQueue.push_back(move(t)); - dataCond.notify_one(); - } - - std::unique_ptr WaitAndPop() - { - std::unique_lock lk(mut); - dataCond.wait(lk, [this] { return !dataQueue.empty(); }); - std::unique_ptr res = std::move(dataQueue.front()); - dataQueue.pop_front(); - return move(res); - } - - std::unique_ptr TryPop() - { - std::lock_guard lk(mut); - if 
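
The SafeQueue reworked in this hunk (now inside namespace MxRec, with the capacity constant lower-cased) pairs a data queue with a pool of recycled buffers: GetOne/WaitAndGetOne hand out pooled objects, Pushv publishes a filled one, WaitAndPop blocks for data, and PutDirty returns a consumed buffer to the pool. A usage sketch against that interface; the Msg payload type is illustrative:

#include <memory>
#include <string>
#include <utility>
// plus the safe_queue.h header shown in this hunk

struct Msg {
    std::string payload;
};

void OneHandoff(MxRec::SafeQueue<Msg>& queue)
{
    // Producer side: reuse a pooled buffer (or get a fresh one) and publish.
    std::unique_ptr<Msg> buf = queue.GetOne();
    buf->payload = "batch-0";
    queue.Pushv(std::move(buf));

    // Consumer side: block until data arrives, process, recycle the buffer.
    std::unique_ptr<Msg> got = queue.WaitAndPop();
    // ... use got->payload ...
    queue.PutDirty(std::move(got));
}
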
(dataQueue.empty()) { - return nullptr; + + bool Empty() const + { + std::lock_guard lk(mut); + return dataQueue.empty(); } - std::unique_ptr res = std::move(dataQueue.front()); - dataQueue.pop_front(); - return move(res); - } - - bool Empty() const - { - std::lock_guard lk(mut); - return dataQueue.empty(); - } - - size_t Size() const - { - std::lock_guard lk(mut); - return dataQueue.size(); - } - -private: - mutable std::mutex mut; - uint64_t capacity = DEFAULT_CAP; - std::atomic creatNum {}; - std::list> dataQueue; - std::list> emptyQueue; - std::condition_variable dataCond; - std::condition_variable dirtyCond; -}; - -template -class SingletonQueue { -public: - static SafeQueue* getInstances(int i) - { - static SafeQueue instance[MxRec::MAX_QUEUE_NUM]; - if (i >= MxRec::MAX_QUEUE_NUM || i < 0) { - return nullptr; + + size_t Size() const + { + std::lock_guard lk(mut); + return dataQueue.size(); } - return &instance[i]; + + private: + mutable std::mutex mut; + uint64_t capacity = defaultCap; + std::atomic creatNum{}; + std::list > dataQueue; + std::list > emptyQueue; + std::condition_variable dataCond; + std::condition_variable dirtyCond; }; - SingletonQueue() = delete; + template + class SingletonQueue { + public: + static SafeQueue *GetInstances(int i) + { + static SafeQueue instance[MxRec::MAX_QUEUE_NUM]; + if (i >= MxRec::MAX_QUEUE_NUM || i < 0) { + return nullptr; + } + return &instance[i]; + }; - ~SingletonQueue() = delete; + SingletonQueue() = delete; - SingletonQueue(T&&) = delete; + ~SingletonQueue() = delete; - SingletonQueue(const T&) = delete; + SingletonQueue(T &&) = delete; - void operator=(const T&) = delete; -}; + SingletonQueue(const T &) = delete; + void operator=(const T &) = delete; + }; +} #endif \ No newline at end of file diff --git a/src/core/utils/singleton.h b/src/core/utils/singleton.h index 2ec50940..7a265a29 100644 --- a/src/core/utils/singleton.h +++ b/src/core/utils/singleton.h @@ -17,37 +17,38 @@ * T must be destructed * @tparam T */ -template -class Singleton { -public: - Singleton() = delete; - - Singleton(const Singleton& singleton) = delete; - - Singleton& operator=(const Singleton& singleton) = delete; - - static T* GetInstance() - { - try { - static T instance; - return &instance; - } catch (std::exception& e) { - std::cerr << " create singleton error" << std::endl; - return nullptr; +namespace MxRec { + template + class Singleton { + public: + Singleton() = delete; + + Singleton(const Singleton &singleton) = delete; + + Singleton &operator=(const Singleton &singleton) = delete; + + static T *GetInstance() + { + try { + static T instance; + return &instance; + } catch (std::exception &e) { + std::cerr << " create singleton error" << std::endl; + return nullptr; + } } - } - - template - static T* GetInstance(P&& ... args) - { - try { - static T instance(std::forward
<P>
(args)...); - return &instance; - } catch (std::exception& e) { - std::cerr << " create singleton error" << std::endl; - return nullptr; - } - } -}; + template + static T *GetInstance(P &&... args) + { + try { + static T instance(std::forward
<P>
(args)...); + return &instance; + } catch (std::exception &e) { + std::cerr << " create singleton error" << std::endl; + return nullptr; + } + } + }; +} #endif diff --git a/src/core/utils/time_cost.h b/src/core/utils/time_cost.h index 9852d30f..495282c1 100644 --- a/src/core/utils/time_cost.h +++ b/src/core/utils/time_cost.h @@ -10,30 +10,31 @@ #define TIMECOST_H #include +namespace MxRec { + class TimeCost { + public: + TimeCost() noexcept + { + start_ = std::chrono::high_resolution_clock::now(); + } -class TimeCost { -public: - TimeCost() - { - start_ = std::chrono::high_resolution_clock::now(); - } + double ElapsedSec() + { + std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); + std::chrono::duration d = + std::chrono::duration_cast < std::chrono::duration < double >> (end - start_); + return d.count(); + } - double ElapsedSec() - { - std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); - std::chrono::duration d = std::chrono::duration_cast>(end - start_); - return d.count(); - } - - size_t ElapsedMS() - { - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::milliseconds d = std::chrono::duration_cast(end - start_); - return d.count(); - } - -private: - std::chrono::high_resolution_clock::time_point start_; -}; + size_t ElapsedMS() + { + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::milliseconds d = std::chrono::duration_cast(end - start_); + return d.count(); + } + private: + std::chrono::high_resolution_clock::time_point start_; + }; +} #endif \ No newline at end of file diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 368d12c6..c8b29dcc 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -34,588 +34,591 @@ using OpKernelConstructionPtr = OpKernelConstruction*; using OpKernelContextPtr = OpKernelContext*; using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*; -TimeCost staticSw {}; -TimeCost staticReadRaw {}; -array batchIdsInfo {}; - -REGISTER_OP("ClearChannel").Attr("channel_id : int"); - -class ClearChannel : public OpKernel { -public: - explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); +namespace { + static TimeCost g_staticSw{}; +} + +namespace MxRec { + class ClearChannel : public OpKernel { + public: + explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); + + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + throw runtime_error(StringFormat( + "channelId is invalid, It should be in range [0, %d)", MAX_CHANNEL_NUM)); + } - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - throw runtime_error(StringFormat( - "channelId is invalid, It should be in range [0, %d)", MAX_CHANNEL_NUM)); + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", + MAX_CHANNEL_NUM))); + return; + } } - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "ClearChannel channelId invalid. 
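
The TimeCost helper from the time_cost.h hunk above is a minimal stopwatch: construction records a high_resolution_clock timestamp, and ElapsedSec/ElapsedMS report the time since then, which is how the key-processing code meters each stage. Usage sketch:

#include <chrono>
#include <iostream>
#include <thread>
// plus the time_cost.h header shown above

int main()
{
    MxRec::TimeCost tc;  // starts timing at construction
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    std::cout << "stage took " << tc.ElapsedMS() << " ms\n";
}
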
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 368d12c6..c8b29dcc 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -34,588 +34,591 @@ using OpKernelConstructionPtr = OpKernelConstruction*; using OpKernelContextPtr = OpKernelContext*; using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*; -TimeCost staticSw {}; -TimeCost staticReadRaw {}; -array batchIdsInfo {}; - -REGISTER_OP("ClearChannel").Attr("channel_id : int"); - -class ClearChannel : public OpKernel { -public: - explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); +namespace { + static TimeCost g_staticSw{}; +} + +namespace MxRec { + class ClearChannel : public OpKernel { + public: + explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) + { + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - throw runtime_error(StringFormat( - "channelId is invalid, It should be in range [0, %d)", MAX_CHANNEL_NUM)); + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", + MAX_CHANNEL_NUM))); + return; + } + } - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", - MAX_CHANNEL_NUM))); - return; - } - } + ~ClearChannel() = default; - ~ClearChannel() = default; + void Compute(OpKernelContextPtr context) override + { + LOG_INFO("clear channel {}, context {}", channelId, context->step_id()); + HybridMgmtBlock* hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); + hybridMgmtBlock->ResetAll(channelId); + } - void Compute(OpKernelContextPtr context) override - { - LOG_INFO("clear channel {}, context {}", channelId, context->step_id()); - HybridMgmtBlock* hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); - hybridMgmtBlock->ResetAll(channelId); - } + private: + int channelId {}; + }; -private: - int channelId {}; -}; + class ReturnTimestamp : public OpKernel { + public: + explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context) + {} -REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), ClearChannel); + ~ReturnTimestamp() = default; + void Compute(OpKernelContextPtr context) override + { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat<int64>(); + out(0) = time(nullptr); + } + }; + + class ReadEmbKeyV2Dynamic : public OpKernel { + public: + explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context) + { + LOG_DEBUG("ReadEmbKeyV2Dynamic init"); + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference + OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); + OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); + // validation for the feature admit & evict functionality + + // if thresholds are configured, they must not contain extra or unrelated tables; supporting both admit & evict also requires a timestamp + if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && + !FeatureAdmitAndEvict::IsThresholdCfgOK( + FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp) + ) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...")); + return; + } -// ##################### ReturnTimestamp ####################### -REGISTER_OP("ReturnTimestamp") - .Input("input: int64") - .Output("output: int64") - .SetShapeFn([](InferenceContextPtr c) { - c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); - return Status::OK(); - }); -class ReturnTimestamp : public OpKernel { -public: - explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context) - {} - - ~ReturnTimestamp() = default; - - void Compute(OpKernelContextPtr context) override - { - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat<int64>(); - out(0) = time(nullptr); - } -}; - -REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), ReturnTimestamp); + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ReadEmbKeyV2Dynamic channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:%d)", + MAX_CHANNEL_NUM))); + return; + } + LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); + hybridMgmtBlock->ResetAll(channelId); -// ##################### ReadEmbKeyV2Dynamic ####################### -REGISTER_OP("ReadEmbKeyV2Dynamic") - .Input("sample: T") - .Input("splits: int32") - .Output("output: int32") - .Attr("T: {int64, int32}") - .Attr("channel_id: int") - .Attr("emb_name: list(string)") // for which table to lookup - .Attr("timestamp: bool") // use for feature evict, (unix timestamp) - .SetShapeFn([](InferenceContextPtr c) { - c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); - return Status::OK(); - }); - -class ReadEmbKeyV2Dynamic : public OpKernel { -public: - explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context) - { - LOG_DEBUG("ReadEmbKeyV2Dynamic init"); - OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference - OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); - OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); - hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); - // validation for the feature admit & evict functionality - - // if thresholds are configured, they must not contain extra or unrelated tables; supporting both admit & evict also requires a timestamp - if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && - !FeatureAdmitAndEvict::IsThresholdCfgOK( - FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp) - ) { - context->SetStatus( - errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...")); - return; + threadNum = GetThreadNumEnv(); + auto keyProcess = Singleton<KeyProcess>::GetInstance(); + if (!keyProcess->isRunning) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); + return; + } + maxStep = keyProcess->GetMaxStep(channelId); } + ~ReadEmbKeyV2Dynamic() = default; + + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + LOG_DEBUG("enter ReadEmbKeyV2Dynamic"); + TimeCost tc = TimeCost(); + int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; + if (channelId == 1) { + if (maxStep != -1 && batchId >= maxStep) { + LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); + return; + } + } + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + const auto& splits = context->input(TENSOR_INDEX_1).flat<int32>(); + int fieldNum = 0; + for (int i = 0; i < splits.size(); ++i) { + fieldNum += splits(i); + } + size_t dataSize = inputTensor.NumElements(); - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "ReadEmbKeyV2Dynamic channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:%d)", - MAX_CHANNEL_NUM))); - return; - } - LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); - hybridMgmtBlock->ResetAll(channelId); - - threadNum = GetThreadNumEnv(); - auto keyProcess = Singleton<KeyProcess>::GetInstance(); - if (!keyProcess->isRunning) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); - return; - } - maxStep = keyProcess->GetMaxStep(channelId); - } - ~ReadEmbKeyV2Dynamic() = default; - - void Compute(OpKernelContextPtr context) override - { - EASY_FUNCTION(); - LOG_DEBUG("enter ReadEmbKeyV2Dynamic"); - TimeCost tc = TimeCost(); - int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; - if (channelId == 1) { - if (maxStep != -1 && batchId >= maxStep) { - LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); + time_t timestamp = -1; + // if a timestamp is passed in, parse and validate it + if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep))); return; } + // make sure every name in embNames has a status record in m_embStatus + SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); + + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat<int32>(); + out(0) = batchId; + + TimeCost enqueueTC; + EnqueueBatchData(std::vector<int>{batchId, batchQueueId}, timestamp, inputTensor, splits); + LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{}," + " enqueueTC(ms):{}, batch[{}]:{}", + tc.ElapsedMS(), g_staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); + g_staticSw = TimeCost(); + }
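The queue-routing arithmetic in Compute above is easy to misread, so the following stand-alone sketch (an editorial illustration, not part of the patch) replays it with concrete numbers; the constant values are assumed here purely for the demonstration, since the real KEY_PROCESS_THREAD and GetThreadNumEnv() live elsewhere in the codebase.

#include <cstdio>

int main()
{
    const int KEY_PROCESS_THREAD = 4;  // assumed value, for illustration only
    const int threadNum = 4;           // stands in for the GetThreadNumEnv() result

    for (int channelId = 0; channelId <= 1; ++channelId) {        // 0 train, 1 inference
        for (int batchId = 0; batchId < 6; ++batchId) {
            // [batchId % threadNum] picks the worker thread; [KEY_PROCESS_THREAD * channelId]
            // shifts inference queues past the train queues.
            int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId;
            std::printf("channel=%d batch=%d -> queue=%d\n", channelId, batchId, batchQueueId);
        }
    }
    return 0;
}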
- const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); - const auto& splits = context->input(TENSOR_INDEX_1).flat<int32>(); - int fieldNum = 0; - for (int i = 0; i < splits.size(); ++i) { - fieldNum += splits(i); - } - size_t dataSize = inputTensor.NumElements(); - - time_t timestamp = -1; - // if a timestamp is passed in, parse and validate it - if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep))); - return; - } - // make sure every name in embNames has a status record in m_embStatus - SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - - // [batchId % KEY_PROCESS_THREAD] which thread process this batch - // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat<int32>(); - out(0) = batchId; - - TimeCost enqueueTC; - EnqueueBatchData(std::vector<int>{batchId, batchQueueId}, timestamp, inputTensor, splits); - LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{}," - " enqueueTC(ms):{}, batch[{}]:{}", - tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); - staticSw = TimeCost(); - } - - void CheckEmbTables() - { - auto keyProcess = Singleton<KeyProcess>::GetInstance(); - for (size_t i = 0; i < embNames.size(); ++i) { - if (!keyProcess->HasEmbName(embNames.at(i))) { - LOG_INFO("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); - tableUsed.push_back(false); - } else { - tableUsed.push_back(true); + + void CheckEmbTables() + { + auto keyProcess = Singleton<KeyProcess>::GetInstance(); + for (size_t i = 0; i < embNames.size(); ++i) { + if (!keyProcess->HasEmbName(embNames.at(i))) { + LOG_INFO("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); + tableUsed.push_back(false); + } else { + tableUsed.push_back(true); + } } } - } - void EnqueueBatchData(std::vector<int> ids, time_t timestamp, - const Tensor& inputTensor, const TTypes<int32>::ConstFlat& splits) - { - if (tableUsed.empty()) { - CheckEmbTables(); - } - auto queue = SingletonQueue::getInstances(ids[1]); - size_t offset = 0; - if (isTimestamp) { - offset += 1; // the first 8 bytes are a unix timestamp - } - for (int i = 0; i < splits.size(); ++i) { - if (!tableUsed.at(i)) { - offset += splits(i); - continue; + void EnqueueBatchData(std::vector<int> ids, time_t timestamp, + const Tensor& inputTensor, const TTypes<int32>::ConstFlat& splits) + { + if (tableUsed.empty()) { + CheckEmbTables(); } - auto batchData = queue->GetOne(); // get dirty or empty data block - batchData->name = embNames.at(i); - size_t len = splits(i); - batchData->channel = channelId; - batchData->batchId = ids[0]; - batchData->sample.resize(len); + auto queue = SingletonQueue::GetInstances(ids[1]); + size_t offset = 0; if (isTimestamp) { - batchData->timestamp = timestamp; + offset += 1; // the first 8 bytes are a unix timestamp } - - if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { - auto src = (const int32_t*)inputTensor.tensor_data().data(); - copy(src + offset, src + offset + len, batchData->sample.data()); - } else { - auto src = (const int64_t*)inputTensor.tensor_data().data(); - copy(src + offset, src + offset + len, batchData->sample.data()); + for (int i = 0; i < splits.size(); ++i) { + if (!tableUsed.at(i)) { + offset += splits(i); + continue; + } + auto batchData = queue->GetOne(); // get dirty or empty data block + batchData->name = embNames.at(i); + size_t len = splits(i); + batchData->channel = channelId; + batchData->batchId = ids[0]; + batchData->sample.resize(len); + if (isTimestamp) { + batchData->timestamp = timestamp; + } + + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + auto src = reinterpret_cast<const int32_t*>(inputTensor.tensor_data().data()); + copy(src + offset, src + offset + len, batchData->sample.data()); + } else { + auto src = reinterpret_cast<const int64_t*>(inputTensor.tensor_data().data()); + copy(src + offset, src + offset + len, batchData->sample.data()); + } + offset += len; + queue->Pushv(move(batchData)); } - offset += len; - queue->Pushv(move(batchData)); - } - }
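To make the buffer layout that EnqueueBatchData walks above concrete, here is a small stand-alone sketch (not part of the patch) of a hypothetical sample: one leading unix-timestamp slot, then the keys of two tables with splits of 3 and 2; all values are invented for illustration.

#include <cstdint>
#include <cstdio>
#include <ctime>
#include <vector>

int main()
{
    // Hypothetical batch: one timestamp slot, then table 0 (3 keys) and table 1 (2 keys).
    std::vector<int64_t> sample = {1684000000, 11, 12, 13, 21, 22};
    std::vector<int> splits = {3, 2};

    size_t offset = 0;
    time_t timestamp = static_cast<time_t>(sample[0]);  // first 8 bytes occupy one featureId slot
    offset += 1;

    for (size_t i = 0; i < splits.size(); ++i) {
        std::printf("table %zu keys:", i);
        for (int k = 0; k < splits[i]; ++k) {
            std::printf(" %lld", static_cast<long long>(sample[offset + k]));
        }
        std::printf("\n");
        offset += splits[i];                            // advance past this table's slice
    }
    std::printf("timestamp=%lld\n", static_cast<long long>(timestamp));
    return 0;
}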
- } - - bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, - size_t& dataSize) - { - if (dataSize - fieldNumTmp != 1) { // no timestamp was passed in - LOG_ERROR("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); - return false; - } - // the first 8 bytes (occupying one featureId slot) are a unix timestamp - auto src = (const time_t*)inputTensor.tensor_data().data(); - std::copy(src, src + 1, &timestamp); - LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); - dataSize -= 1; - if (timestamp <= 0) { - LOG_ERROR("timestamp[{}] <= 0 ", timestamp); - return false; - } + bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, + size_t& dataSize) const + { + if (dataSize - fieldNumTmp != 1) { // no timestamp was passed in + LOG_ERROR("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); + return false; + } + // the first 8 bytes (occupying one featureId slot) are a unix timestamp + auto src = reinterpret_cast<const time_t*>(inputTensor.tensor_data().data()); + std::copy(src, src + 1, &timestamp); + LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); + dataSize -= 1; + if (timestamp <= 0) { + LOG_ERROR("timestamp[{}] <= 0 ", timestamp); + return false; + } + + return true; } - return true; - } - - void SetCurrEmbNamesStatus(const vector<string>& embeddingNames, - absl::flat_hash_map<string, SingleEmbTableStatus>& embStatus) - { - for (size_t i = 0; i < embeddingNames.size(); ++i) { - auto it = embStatus.find(embeddingNames[i]); - // validate the tables that are configured - if (it == embStatus.end()) { - // tables without a config do not need the admit & evict functionality - embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); + void SetCurrEmbNamesStatus(const vector<string>& embeddingNames, + absl::flat_hash_map<string, SingleEmbTableStatus>& embStatus) const + { + for (size_t i = 0; i < embeddingNames.size(); ++i) { + auto it = embStatus.find(embeddingNames[i]); + // validate the tables that are configured + if (it == embStatus.end()) { + // tables without a config do not need the admit & evict functionality + embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); + } } } - } - int channelId {}; - vector<string> embNames {}; - vector<bool> tableUsed{}; - int maxStep = 0; - bool isTimestamp { false }; - int threadNum = 0; - HybridMgmtBlock* hybridMgmtBlock; -}; + int channelId {}; + vector<string> embNames {}; + vector<bool> tableUsed{}; + int maxStep = 0; + bool isTimestamp { false }; + int threadNum = 0; + HybridMgmtBlock* hybridMgmtBlock; + }; -REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), ReadEmbKeyV2Dynamic); + class ReadEmbKeyV2 : public OpKernel { + public: + explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context) + { + LOG_DEBUG("ReadEmbKeyV2 init"); + OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference + OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); + OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // field number of each table + OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); + fieldNum = accumulate(splits.begin(), splits.end(), 0); + + hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); + // validation for the feature admit & evict functionality + + // if thresholds are configured, they must not contain extra or unrelated tables; supporting both admit & evict also requires a timestamp + if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && + !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { + context->SetStatus( + errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...") + ); + return; + } -// ##################### ReadEmbKeyV2 ####################### -REGISTER_OP("ReadEmbKeyV2") - .Input("sample: T") - .Output("output: int32") - .Attr("T: {int64, int32}") - .Attr("channel_id: int") - .Attr("splits: list(int)") - .Attr("emb_name: list(string)") // for which table to lookup - .Attr("timestamp: bool") // use for feature evict, (unix timestamp) - .SetShapeFn([](InferenceContextPtr c) { - c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); - return Status::OK(); - }); - -class ReadEmbKeyV2 : public OpKernel { -public: - explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context) - { - LOG_DEBUG("ReadEmbKeyV2 init"); - OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); // 0 train or 1 inference - OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames)); - OP_REQUIRES_OK(context, context->GetAttr("splits", &splits)); // field number of each table - OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp)); - fieldNum = 
accumulate(splits.begin(), splits.end(), 0); - - hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance(); - // validation for the feature admit & evict functionality - - // if thresholds are configured, they must not contain extra or unrelated tables; supporting both admit & evict also requires a timestamp - if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() && - !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) { - context->SetStatus( - errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ...") - ); - return; + if (splits.size() != embNames.size()) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "splits & embNames size error.%d %d", splits.size(), embNames.size()))); + return; + } + if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", + MAX_CHANNEL_NUM))); + return; + } + LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); + // reset all step counters in this data channel + hybridMgmtBlock->ResetAll(channelId); + + threadNum = GetThreadNumEnv(); + auto keyProcess = Singleton<KeyProcess>::GetInstance(); + if (!keyProcess->isRunning) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); + return; + } + maxStep = keyProcess->GetMaxStep(channelId); + }
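A brief worked example (editorial, not part of the patch) of the two numbers the constructor and ParseTimestampAndCheck above agree on: fieldNum is the sum of the splits attr, and a batch carries a timestamp exactly when the sample tensor holds one extra element beyond fieldNum. The split values here are invented for the demonstration.

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int> splits = {3, 2};                                 // "splits" attr: fields per table
    int fieldNum = std::accumulate(splits.begin(), splits.end(), 0);  // 5, as in the constructor above

    size_t dataSizeWithTs = 6;  // 5 keys + 1 leading timestamp slot -> passes the check
    size_t dataSizeNoTs = 5;    // no timestamp slot -> fails the check

    std::printf("fieldNum=%d\n", fieldNum);
    std::printf("with ts:    %s\n", (dataSizeWithTs - fieldNum == 1) ? "timestamp present" : "error");
    std::printf("without ts: %s\n", (dataSizeNoTs - fieldNum == 1) ? "timestamp present" : "error");
    return 0;
}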
- if (splits.size() != embNames.size()) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "splits & embNames size error.%d %d", splits.size(), embNames.size()))); - return; - } - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)", MAX_CHANNEL_NUM))); - return; - } - LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); - // reset all step counters in this data channel - hybridMgmtBlock->ResetAll(channelId); - - threadNum = GetThreadNumEnv(); - auto keyProcess = Singleton<KeyProcess>::GetInstance(); - if (!keyProcess->isRunning) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running.")); - return; - } - maxStep = keyProcess->GetMaxStep(channelId); - } - - ~ReadEmbKeyV2() = default; - - void Compute(OpKernelContextPtr context) override - { - EASY_FUNCTION(); - LOG_DEBUG("enter ReadEmbKeyV2"); - TimeCost tc = TimeCost(); - int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; - Tensor* output = nullptr; - if (channelId == 1) { - if (maxStep != -1 && batchId >= maxStep) { - LOG_WARN(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat<int32>(); - out(0) = batchId; + ~ReadEmbKeyV2() = default; + + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + LOG_DEBUG("enter ReadEmbKeyV2"); + TimeCost tc = TimeCost(); + int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; + Tensor* output = nullptr; + if (channelId == 1) { + if (maxStep != -1 && batchId >= maxStep) { + LOG_WARN(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat<int32>(); + out(0) = batchId; + return; + } + } + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + size_t dataSize = inputTensor.NumElements(); + + time_t timestamp = -1; + // if a timestamp is passed in, parse and validate it + if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { + context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( + "timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep))); return; } + // make sure every name in embNames has a status record in m_embStatus + SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); + + // [batchId % KEY_PROCESS_THREAD] which thread process this batch + // [KEY_PROCESS_THREAD * 0 or 1] train or inference + int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; + + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat<int32>(); + out(0) = batchId; + + TimeCost enqueueTC; + EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); + LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{}," + " enqueueTC(ms):{}, batch[{}]:{}", + tc.ElapsedMS(), g_staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); + g_staticSw = TimeCost(); } - const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); - size_t dataSize = inputTensor.NumElements(); - - time_t timestamp = -1; - // if a timestamp is passed in, parse and validate it - if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) { - context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( - "timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep))); - return; - } - // make sure every name in embNames has a status record in m_embStatus - SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus); - - // [batchId % KEY_PROCESS_THREAD] which thread process this batch - // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; - - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat<int32>(); - out(0) = batchId; - - TimeCost enqueueTC; - EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); - LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{}," - " enqueueTC(ms):{}, batch[{}]:{}", - tc.ElapsedMS(), staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId); - staticSw = TimeCost(); - } - - void CheckEmbTables() - { - auto keyProcess = Singleton<KeyProcess>::GetInstance(); - for (size_t i = 0; i < splits.size(); ++i) { - if (!keyProcess->HasEmbName(embNames.at(i))) { - LOG_INFO("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); - tableUsed.push_back(false); - } else { - tableUsed.push_back(true); + + void CheckEmbTables() + { + auto keyProcess = Singleton<KeyProcess>::GetInstance(); + for (size_t i = 0; i < splits.size(); ++i) { + if (!keyProcess->HasEmbName(embNames.at(i))) { + LOG_INFO("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); + tableUsed.push_back(false); + } else { + tableUsed.push_back(true); + } } } - } - void EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) - { - if (tableUsed.empty()) { - CheckEmbTables(); - } - auto queue = SingletonQueue::getInstances(batchQueueId); + void EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor) + { + if (tableUsed.empty()) { + CheckEmbTables(); + } + auto queue = SingletonQueue::GetInstances(batchQueueId); - size_t offset = 0; - if (isTimestamp) { - offset += 1; // the first 8 bytes are a unix timestamp + size_t offset = 0; + if (isTimestamp) { + offset += 1; // the first 8 bytes are a unix timestamp + } + for (size_t i = 0; i < splits.size(); ++i) { + if (!tableUsed.at(i)) { + offset += splits.at(i); + continue; + } + auto batchData = queue->GetOne(); // 
get dirty or empty data block + batchData->name = embNames.at(i); + size_t len = splits.at(i); + batchData->channel = channelId; + batchData->batchId = batchId; + batchData->sample.resize(len); + if (isTimestamp) { + batchData->timestamp = timestamp; + } + + if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { + auto src = reinterpret_cast<const int32_t*>(inputTensor.tensor_data().data()); + copy(src + offset, src + offset + len, batchData->sample.data()); + } else { + auto src = reinterpret_cast<const int64_t*>(inputTensor.tensor_data().data()); + copy(src + offset, src + offset + len, batchData->sample.data()); + } + offset += len; + queue->Pushv(move(batchData)); + } } - auto batchData = queue->GetOne(); // get dirty or empty data block - batchData->name = embNames.at(i); - size_t len = splits.at(i); - batchData->channel = channelId; - batchData->batchId = batchId; - batchData->sample.resize(len); - if (isTimestamp) { - batchData->timestamp = timestamp; + bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, + size_t& dataSize) const + { + if (dataSize - fieldNumTmp != 1) { // no timestamp was passed in + LOG_ERROR("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); + return false; } - if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) { - auto src = (const int32_t*)inputTensor.tensor_data().data(); - copy(src + offset, src + offset + len, batchData->sample.data()); - } else { - auto src = (const int64_t*)inputTensor.tensor_data().data(); - copy(src + offset, src + offset + len, batchData->sample.data()); + // the first 8 bytes (occupying one featureId slot) are a unix timestamp + auto src = reinterpret_cast<const time_t*>(inputTensor.tensor_data().data()); + std::copy(src, src + 1, &timestamp); + LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); + dataSize -= 1; + + if (timestamp <= 0) { + LOG_ERROR("timestamp[{}] <= 0 ", timestamp); + return false; } - offset += len; - queue->Pushv(move(batchData)); - } - } - - bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp, - size_t& dataSize) - { - if (dataSize - fieldNumTmp != 1) { // no timestamp was passed in - LOG_ERROR("dataSize[{}], fieldNum[{}] ...", dataSize, fieldNumTmp); - return false; - } - // the first 8 bytes (occupying one featureId slot) are a unix timestamp - auto src = (const time_t*)inputTensor.tensor_data().data(); - std::copy(src, src + 1, &timestamp); - LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); - dataSize -= 1; + return true; + } + void SetCurrEmbNamesStatus(const vector<string>& embeddingNames, + absl::flat_hash_map<string, SingleEmbTableStatus>& embStatus) const + { + for (size_t i = 0; i < embeddingNames.size(); ++i) { + auto it = embStatus.find(embeddingNames[i]); + // validate the tables that are configured + if (it == embStatus.end()) { + // tables without a config do not need the admit & evict functionality + embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); + } } + } - if (timestamp <= 0) { - LOG_ERROR("timestamp[{}] 
<= 0 ", timestamp); - return false; - } + void Compute(OpKernelContextPtr context) override + { + EASY_FUNCTION(); + TimeCost tc = TimeCost(); + const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); + auto input = inputTensor.flat(); + const int restoreLen = static_cast(input.size()); - return true; - } - void SetCurrEmbNamesStatus(const vector& embeddingNames, - absl::flat_hash_map& embStatus) - { - for (size_t i = 0; i < embeddingNames.size(); ++i) { - auto it = embStatus.find(embeddingNames[i]); - // 对配置了的,进行校验 - if (it == embStatus.end()) { - // 没有配置的,则不需要“准入&淘汰”功能 - embStatus.insert(std::pair(embeddingNames[i], SingleEmbTableStatus::SETS_NONE)); - } - } - } + // write lookup & restore vec + Tensor* lookupVec = nullptr; + Tensor* restoreVecTensor = nullptr; - int channelId {}; - vector splits {}; - vector tableUsed{}; - int fieldNum {}; - vector embNames {}; - int maxStep = 0; - bool isTimestamp { false }; - int threadNum = KEY_PROCESS_THREAD; - HybridMgmtBlock* hybridMgmtBlock; -}; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { lookupLen }, &lookupVec)); + OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { restoreLen }, &restoreVecTensor)); + auto l = lookupVec->flat(); + auto r = restoreVecTensor->flat(); -REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), ReadEmbKeyV2); + // check whether lookupLen is zero + if (lookupLen == 0) { + throw runtime_error("lookupLen is 0, it causes the denominator to be 0 during division"); + } -// ##################### ReadEmbKeyDatasetDummy ####################### -REGISTER_OP("ReadEmbKeyDatasetDummy") - .Input("sample: T") - .Output("lookup_vec: int32") - .Output("restore_vec: int32") - .Attr("T: {int64}") - .Attr("max_lookup_len: int") - .SetShapeFn([](InferenceContextPtr c) { - int temp; - TF_RETURN_IF_ERROR(c->GetAttr("max_lookup_len", &temp)); - c->set_output(TensorIndex::TENSOR_INDEX_0, c->Vector(temp)); - c->set_output(TensorIndex::TENSOR_INDEX_1, c->input(TensorIndex::TENSOR_INDEX_0)); - return Status::OK(); - }); - -class ReadEmbKeyDatasetDummy : public OpKernel { -public: - explicit ReadEmbKeyDatasetDummy(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("max_lookup_len", &lookupLen)); - } - - ~ReadEmbKeyDatasetDummy() override = default; - - void Compute(OpKernelContextPtr context) override - { - EASY_FUNCTION(); - TimeCost tc = TimeCost(); - const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); - auto input = inputTensor.flat(); - const int restoreLen = static_cast(input.size()); - - // write lookup & restore vec - Tensor* lookupVec = nullptr; - Tensor* restoreVecTensor = nullptr; - - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { lookupLen }, &lookupVec)); - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { restoreLen }, &restoreVecTensor)); - auto l = lookupVec->flat(); - auto r = restoreVecTensor->flat(); - - // check whether lookupLen is zero - if (lookupLen == 0) { - throw runtime_error("lookupLen is 0, it causes the denominator to be 0 during division"); + // dummy data + for (int i { 0 }; i < lookupLen; ++i) { + l(i) = i; + } + for (int i { 0 }; i < restoreLen; ++i) { + r(i) = i % lookupLen; + } + LOG_WARN("dummy read batch cost: {},elapsed from last {}", + tc.ElapsedMS(), g_staticSw.ElapsedMS()); + tc = TimeCost(); } - // dummy data - for (int i { 0 }; i < lookupLen; ++i) { - l(i) = i; + int lookupLen {}; + }; + + class CustOps : public OpKernel { + public: + 
-class CustOps : public OpKernel { -public: - explicit CustOps(OpKernelConstructionPtr context) : OpKernel(context) - { - } - void Compute(OpKernelContextPtr context) override - { - LOG_INFO("context {}", context->step_id()); - std::cout << " Cust opp not installed!!" << std::endl; - } - ~CustOps() = default; -}; + class CustOps : public OpKernel { + public: + explicit CustOps(OpKernelConstructionPtr context) : OpKernel(context) + { + } + + void Compute(OpKernelContextPtr context) override + { + LOG_INFO("context {}", context->step_id()); + std::cout << " Cust op not installed!!" << std::endl; + } + + ~CustOps() = default; + }; + +} +REGISTER_OP("ClearChannel").Attr("channel_id : int"); +REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel); + +// ##################### ReturnTimestamp ####################### +REGISTER_OP("ReturnTimestamp") +.Input("input: int64") +.Output("output: int64") +.SetShapeFn([](InferenceContextPtr c) { +c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); +return Status::OK(); +}); +REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), MxRec::ReturnTimestamp); + +// ##################### ReadEmbKeyV2Dynamic ####################### +REGISTER_OP("ReadEmbKeyV2Dynamic") +.Input("sample: T") +.Input("splits: int32") +.Output("output: int32") +.Attr("T: {int64, int32}") +.Attr("channel_id: int") +.Attr("emb_name: list(string)") // for which table to lookup +.Attr("timestamp: bool") // use for feature evict, (unix timestamp) +.SetShapeFn([](InferenceContextPtr c) { +c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); +return Status::OK(); +}); + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2Dynamic); + +// ##################### ReadEmbKeyV2 ####################### +REGISTER_OP("ReadEmbKeyV2") +.Input("sample: T") +.Output("output: int32") +.Attr("T: {int64, int32}") +.Attr("channel_id: int") +.Attr("splits: list(int)") +.Attr("emb_name: list(string)") // for which table to lookup +.Attr("timestamp: bool") // use for feature evict, (unix timestamp) +.SetShapeFn([](InferenceContextPtr c) { +c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); +return Status::OK(); +}); + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2); + +// ##################### ReadEmbKeyDatasetDummy ####################### +REGISTER_OP("ReadEmbKeyDatasetDummy") +.Input("sample: T") +.Output("lookup_vec: int32") +.Output("restore_vec: int32") +.Attr("T: {int64}") +.Attr("max_lookup_len: int") +.SetShapeFn([](InferenceContextPtr c) { +int temp; +TF_RETURN_IF_ERROR(c->GetAttr("max_lookup_len", &temp)); +c->set_output(TensorIndex::TENSOR_INDEX_0, c->Vector(temp)); +c->set_output(TensorIndex::TENSOR_INDEX_1, c->input(TensorIndex::TENSOR_INDEX_0)); +return Status::OK(); +}); + +REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyDatasetDummy").Device(DEVICE_CPU), MxRec::ReadEmbKeyDatasetDummy); REGISTER_OP("EmbeddingLookupByAddress") - .Input("address: int64") - .Attr("embedding_dim: int") - .Attr("embedding_type: int") - .Output("y: float") - .SetIsStateful() - .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { - ShapeHandle addrShape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); - int 
embSize; - TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embSize)); - tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); - c->set_output(TENSOR_INDEX_0, c->Matrix(rows, embSize)); - return Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupByAddress").Device(DEVICE_CPU), CustOps); +.Input("address: int64") +.Attr("embedding_dim: int") +.Attr("embedding_type: int") +.Output("y: float") +.SetIsStateful() +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { +ShapeHandle addrShape; +TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); +int embSize; +TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embSize)); +tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); +c->set_output(TENSOR_INDEX_0, c->Matrix(rows, embSize)); +return Status::OK(); +}); + +REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupByAddress").Device(DEVICE_CPU), MxRec::CustOps); REGISTER_OP("EmbeddingUpdateByAddress") - .Input("address: int64") - .Input("embedding: float") - .Attr("update_type: int") - .Output("y: float") - .SetIsStateful() - .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { - ShapeHandle addrShape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); - ShapeHandle embeddingShape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &embeddingShape)); - tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); - tensorflow::shape_inference::DimensionHandle cols = c->Dim(embeddingShape, 1); - c->set_output(TENSOR_INDEX_0, c->Matrix(rows, cols)); - return Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), CustOps); \ No newline at end of file +.Input("address: int64") +.Input("embedding: float") +.Attr("update_type: int") +.Output("y: float") +.SetIsStateful() +.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { +ShapeHandle addrShape; +TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); +ShapeHandle embeddingShape; +TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &embeddingShape)); +tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); +tensorflow::shape_inference::DimensionHandle cols = c->Dim(embeddingShape, 1); +c->set_output(TENSOR_INDEX_0, c->Matrix(rows, cols)); +return Status::OK(); +}); + +REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), MxRec::CustOps); \ No newline at end of file diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 18277c3b..e64eb33f 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -9,171 +9,174 @@ #include #include "hybrid_mgmt/hybrid_mgmt.h" -#include "module_main.h" namespace py = pybind11; using namespace MxRec; +namespace { + void GetRankInfo(py::module_& m); -void GetRankInfo(py::module_& m); + void GetEmbInfo(py::module_& m); -void GetEmbInfo(py::module_& m); + void GetRandomInfo(py::module_& m); -void GetRandomInfo(py::module_& m); + void GetHybridMgmt(py::module_& m); -void GetHybridMgmt(py::module_& m); + void GetThresholdValue(pybind11::module_& m); -void GetThresholdValue(pybind11::module_& m); + void GetInitializeInfo(pybind11::module_& m); -void GetInitializeInfo(pybind11::module_& m); + void GetConstantInitializerInfo(pybind11::module_& m); -void GetConstantInitializerInfo(pybind11::module_& m); + void GetNormalInitializerInfo(pybind11::module_& m); -void GetNormalInitializerInfo(pybind11::module_& m); - -int GetUBHotSize(int devID) -{ - return 
static_cast(static_cast(MxRec::GetUBSize(devID)) / sizeof(float) * HOT_EMB_CACHE_PCT) ; -} + int GetUBHotSize(int devID) + { + return static_cast(static_cast(MxRec::GetUBSize(devID)) / sizeof(float) * HOT_EMB_CACHE_PCT) ; + } -int32_t GetLogicID(uint32_t phyid) -{ - uint32_t logicId; - int32_t ret = dsmi_get_logicid_from_phyid(phyid, &logicId); - if (ret != 0) { - return ret; + int32_t GetLogicID(uint32_t phyid) + { + uint32_t logicId; + int32_t ret = dsmi_get_logicid_from_phyid(phyid, &logicId); + if (ret != 0) { + return ret; + } + return logicId; } - return logicId; -} -PYBIND11_MODULE(mxrec_pybind, m) -{ - m.def("get_ub_hot_size", &GetUBHotSize, py::arg("device_id")); + PYBIND11_MODULE(mxrec_pybind, m) + { + m.def("get_ub_hot_size", &GetUBHotSize, py::arg("device_id")); - m.def("get_logic_id", &GetLogicID, py::arg("physic_id")); + m.def("get_logic_id", &GetLogicID, py::arg("physic_id")); - m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); + m.attr("USE_STATIC") = py::int_(HybridOption::USE_STATIC); - m.attr("USE_HOT") = py::int_(HybridOption::USE_HOT); + m.attr("USE_HOT") = py::int_(HybridOption::USE_HOT); - m.attr("USE_DYNAMIC_EXPANSION") = py::int_(HybridOption::USE_DYNAMIC_EXPANSION); + m.attr("USE_DYNAMIC_EXPANSION") = py::int_(HybridOption::USE_DYNAMIC_EXPANSION); - GetRankInfo(m); + GetRankInfo(m); - GetEmbInfo(m); + GetEmbInfo(m); - GetRandomInfo(m); + GetRandomInfo(m); - GetHybridMgmt(m); + GetHybridMgmt(m); - GetThresholdValue(m); + GetThresholdValue(m); - GetInitializeInfo(m); + GetInitializeInfo(m); - GetConstantInitializerInfo(m); + GetConstantInitializerInfo(m); - GetNormalInitializerInfo(m); -} + GetNormalInitializerInfo(m); + } -void GetRankInfo(pybind11::module_& m) -{ - pybind11::class_(m, "RankInfo") - .def(py::init>(), py::arg("rank_id"), py::arg("device_id"), - py::arg("local_rank_size"), py::arg("option"), - py::arg("max_step") = vector { -1, -1 }) - .def_readwrite("rank_id", &RankInfo::rankId) - .def_readwrite("device_id", &RankInfo::deviceId) - .def_readwrite("rank_size", &RankInfo::rankSize) - .def_readwrite("local_rank_size", &RankInfo::localRankSize) - .def_readwrite("option", &RankInfo::option) - .def_readwrite("max_step", &RankInfo::maxStep); -} + void GetRankInfo(pybind11::module_& m) + { + pybind11::class_(m, "RankInfo") + .def(py::init>(), py::arg("rank_id"), py::arg("device_id"), + py::arg("local_rank_size"), py::arg("option"), + py::arg("max_step") = vector { -1, -1 }) + .def_readwrite("rank_id", &RankInfo::rankId) + .def_readwrite("device_id", &RankInfo::deviceId) + .def_readwrite("rank_size", &RankInfo::rankSize) + .def_readwrite("local_rank_size", &RankInfo::localRankSize) + .def_readwrite("option", &RankInfo::option) + .def_readwrite("max_step", &RankInfo::maxStep); + } -void GetEmbInfo(pybind11::module_& m) -{ - pybind11::class_(m, "EmbInfo") - .def(pybind11::init, - std::vector&, std::vector&>(), - py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("ext_embedding_size"), - py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos"), - py::arg("ssd_data_path")) - .def_readwrite("name", &EmbInfo::name) - .def_readwrite("send_count", &EmbInfo::sendCount) - .def_readwrite("embedding_size", &EmbInfo::embeddingSize) - .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize) - .def_readwrite("is_save", &EmbInfo::isSave) - .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) - .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) - .def_readwrite("initialize_infos", &EmbInfo::initializeInfos) - 
.def_readwrite("ssd_data_path", &EmbInfo::ssdDataPath); -} + void GetEmbInfo(pybind11::module_& m) + { + pybind11::class_(m, "EmbInfo") + .def(pybind11::init, + std::vector&, std::vector&>(), + py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("ext_embedding_size"), + py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos"), + py::arg("ssd_data_path")) + .def_readwrite("name", &EmbInfo::name) + .def_readwrite("send_count", &EmbInfo::sendCount) + .def_readwrite("embedding_size", &EmbInfo::embeddingSize) + .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize) + .def_readwrite("is_save", &EmbInfo::isSave) + .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize) + .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize) + .def_readwrite("initialize_infos", &EmbInfo::initializeInfos) + .def_readwrite("ssd_data_path", &EmbInfo::ssdDataPath); + } -void GetRandomInfo(pybind11::module_& m) -{ - pybind11::class_(m, "RandomInfo") - .def(pybind11::init()) - .def_readwrite("start", &RandomInfo::start) - .def_readwrite("len", &RandomInfo::len) - .def_readwrite("constant_val", &RandomInfo::constantVal) - .def_readwrite("random_min", &RandomInfo::randomMin) - .def_readwrite("random_max", &RandomInfo::randomMax); -} + void GetRandomInfo(pybind11::module_& m) + { + pybind11::class_(m, "RandomInfo") + .def(pybind11::init()) + .def_readwrite("start", &RandomInfo::start) + .def_readwrite("len", &RandomInfo::len) + .def_readwrite("constant_val", &RandomInfo::constantVal) + .def_readwrite("random_min", &RandomInfo::randomMin) + .def_readwrite("random_max", &RandomInfo::randomMax); + } -void GetInitializeInfo(pybind11::module_ &m) -{ - pybind11::class_(m, "InitializeInfo") - .def(py::init(), py::arg("name"), py::arg("start"), - py::arg("len"), py::arg("constant_initializer_info")) - .def(py::init(), py::arg("name"), py::arg("start"), - py::arg("len"), py::arg("normal_initializer_info")) - .def_readwrite("name", &InitializeInfo::name) - .def_readwrite("start", &InitializeInfo::start) - .def_readwrite("len", &InitializeInfo::len) - .def_readwrite("ConstantInitializerInfo", &InitializeInfo::constantInitializerInfo) - .def_readwrite("NormalInitializerInfo", &InitializeInfo::normalInitializerInfo); -} + void GetInitializeInfo(pybind11::module_ &m) + { + pybind11::class_(m, "InitializeInfo") + .def(py::init(), py::arg("name"), py::arg("start"), + py::arg("len"), py::arg("constant_initializer_info")) + .def(py::init(), py::arg("name"), py::arg("start"), + py::arg("len"), py::arg("normal_initializer_info")) + .def_readwrite("name", &InitializeInfo::name) + .def_readwrite("start", &InitializeInfo::start) + .def_readwrite("len", &InitializeInfo::len) + .def_readwrite("ConstantInitializerInfo", &InitializeInfo::constantInitializerInfo) + .def_readwrite("NormalInitializerInfo", &InitializeInfo::normalInitializerInfo); + } -void GetConstantInitializerInfo(pybind11::module_ &m) -{ - pybind11::class_(m, "ConstantInitializerInfo") - .def(py::init(), py::arg("constant_val") = 0, py::arg("initK") = 1.0) - .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue) - .def_readwrite("initK", &ConstantInitializerInfo::initK); -} + void GetConstantInitializerInfo(pybind11::module_ &m) + { + pybind11::class_(m, "ConstantInitializerInfo") + .def(py::init(), py::arg("constant_val") = 0, py::arg("initK") = 1.0) + .def_readwrite("constant_val", &ConstantInitializerInfo::constantValue) + .def_readwrite("initK", &ConstantInitializerInfo::initK); + } -void 
GetNormalInitializerInfo(pybind11::module_ &m) -{ - pybind11::class_(m, "NormalInitializerInfo") - .def(py::init(), py::arg("mean") = 0.0, py::arg("stddev") = 1.0, py::arg("seed") = 0, - py::arg("initK") = 1.0) - .def_readwrite("mean", &NormalInitializerInfo::mean) - .def_readwrite("stddev", &NormalInitializerInfo::stddev) - .def_readwrite("seed", &NormalInitializerInfo::seed) - .def_readwrite("initK", &NormalInitializerInfo::initK); -} + void GetNormalInitializerInfo(pybind11::module_ &m) + { + pybind11::class_(m, "NormalInitializerInfo") + .def(py::init(), py::arg("mean") = 0.0, + py::arg("stddev") = 1.0, py::arg("seed") = 0, + py::arg("initK") = 1.0) + .def_readwrite("mean", &NormalInitializerInfo::mean) + .def_readwrite("stddev", &NormalInitializerInfo::stddev) + .def_readwrite("seed", &NormalInitializerInfo::seed) + .def_readwrite("initK", &NormalInitializerInfo::initK); + } -void GetHybridMgmt(pybind11::module_& m) -{ - pybind11::class_(m, "HybridMgmt") - .def(py::init()) - .def("initialize", &MxRec::HybridMgmt::Initialize, py::arg("rank_info"), py::arg("emb_info"), - py::arg("seed") = DEFAULT_RANDOM_SEED, py::arg("threshold_values") = vector {}, - py::arg("if_load") = false) - .def("save", &MxRec::HybridMgmt::Save, py::arg("save_path") = "") - .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "") - .def("destroy", &MxRec::HybridMgmt::Destroy) - .def("evict", &MxRec::HybridMgmt::Evict) - .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") - .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) - .def("block_notify_wake", &MxRec::HybridMgmt::NotifyBySessionRun, py::arg("channel_id")) - .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, py::arg("channel_id"), py::arg("steps")=1); -} + void GetHybridMgmt(pybind11::module_& m) + { + pybind11::class_(m, "HybridMgmt") + .def(py::init()) + .def("initialize", &MxRec::HybridMgmt::Initialize, py::arg("rank_info"), py::arg("emb_info"), + py::arg("seed") = DEFAULT_RANDOM_SEED, py::arg("threshold_values") = vector {}, + py::arg("if_load") = false) + .def("save", &MxRec::HybridMgmt::Save, py::arg("save_path") = "") + .def("load", &MxRec::HybridMgmt::Load, py::arg("load_path") = "") + .def("destroy", &MxRec::HybridMgmt::Destroy) + .def("evict", &MxRec::HybridMgmt::Evict) + .def("send", &MxRec::HybridMgmt::SendHostMap, py::arg("table_name") = "") + .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) + .def("block_notify_wake", &MxRec::HybridMgmt::NotifyBySessionRun, py::arg("channel_id")) + .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, + py::arg("channel_id"), py::arg("steps")=1); + } + + void GetThresholdValue(pybind11::module_& m) + { + pybind11::class_(m, "ThresholdValue") + .def(pybind11::init()) + .def_readwrite("table_name", &ThresholdValue::tableName) + .def_readwrite("count_threshold", &ThresholdValue::countThreshold) + .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) + .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); + } -void GetThresholdValue(pybind11::module_& m) -{ - pybind11::class_(m, "ThresholdValue") - .def(pybind11::init()) - .def_readwrite("table_name", &ThresholdValue::tableName) - .def_readwrite("count_threshold", &ThresholdValue::countThreshold) - .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) - .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); } diff --git a/src/pybind/module_main.h b/src/pybind/module_main.h deleted file mode 
100644 index e0ffa3bd..00000000 --- a/src/pybind/module_main.h +++ /dev/null @@ -1,11 +0,0 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. -* Description: main head file -* Author: MindX SDK -* Date: 2022/11/15 -*/ - -#ifndef SPARSE_SSD_DEMO_MODULE_MAIN_H -#define SPARSE_SSD_DEMO_MODULE_MAIN_H - -#endif // SPARSE_SSD_DEMO_MODULE_MAIN_H diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 820fe57f..c71f42fc 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -85,7 +85,7 @@ protected: } } - void SetHostEmbs(std::shared_ptr testHostEmbs) + void SetHostEmbs(std::shared_ptr testHostEmbs) { vector> testEmbData; for (const auto& testEmbInfo : testEmbInfos) { @@ -95,7 +95,7 @@ protected: } } - void SetHostEmptyEmbs(std::shared_ptr loadHostEmbs) + void SetHostEmptyEmbs(std::shared_ptr loadHostEmbs) { vector> testEmbData; for (const auto& testEmbInfo : testEmbInfos) { @@ -128,7 +128,7 @@ protected: fill(testDev2B.begin(), testDev2B.end(), -1); } - void SetEmbHashMaps(emb_hash_mem_t& testEmbHashMaps) + void SetEmbHashMaps(EmbHashMemT& testEmbHashMaps) { EmbHashMapInfo embHashInfo; absl::flat_hash_map testHash; @@ -150,7 +150,7 @@ protected: } } - void SetMaxOffset(offset_mem_t& testMaxOffset) + void SetMaxOffset(OffsetMemT& testMaxOffset) { for (const auto& testEmbInfo : testEmbInfos) { testMaxOffset[testEmbInfo.name] = offsetMem; @@ -165,7 +165,7 @@ protected: } } - void SetKeyOffsetMaps(key_offset_mem_t& testKeyOffsetMaps) + void SetKeyOffsetMaps(KeyOffsetMemT& testKeyOffsetMaps) { absl::flat_hash_map testKeyOffsetMap; for (const auto& testEmbInfo : testEmbInfos) { @@ -190,7 +190,7 @@ protected: } } - void SetDDRKeyFreqMaps(key_freq_mem_t& testDDRKeyFreqMaps) + void SetDDRKeyFreqMaps(KeyFreqMemT& testDDRKeyFreqMaps) { unordered_map testDDRKeyFreqMap; for (const auto& testEmbInfo : testEmbInfos) { @@ -199,7 +199,7 @@ protected: } } - void SetExcludeDDRKeyFreqMaps(key_freq_mem_t& testExcludeDDRKeyFreqMaps) + void SetExcludeDDRKeyFreqMaps(KeyFreqMemT& testExcludeDDRKeyFreqMaps) { unordered_map testExcludeDDRKeyFreqMap; for (const auto& testEmbInfo : testEmbInfos) { @@ -208,7 +208,7 @@ protected: } } - void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold) + void SetTable2Threshold(Table2ThreshMemT& testTable2Threshold) { for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; @@ -276,12 +276,12 @@ protected: TEST_F(CheckpointTest, HostEmbs) { - std::shared_ptr testHostEmbs = std::make_shared(); + std::shared_ptr testHostEmbs = std::make_shared(); SetEmbInfo(); SetHostEmbs(testHostEmbs); - shared_ptr validHostEmbs = std::make_shared(); + shared_ptr validHostEmbs = std::make_shared(); SetHostEmbs(validHostEmbs); - shared_ptr loadHostEmbs = std::make_shared(); + shared_ptr loadHostEmbs = std::make_shared(); SetHostEmptyEmbs(loadHostEmbs); CkptData testSaveData; @@ -313,8 +313,8 @@ TEST_F(CheckpointTest, HostEmbs) TEST_F(CheckpointTest, EmbHashMaps) { - emb_hash_mem_t testEmbHashMaps; - emb_hash_mem_t validEmbHashMaps; + EmbHashMemT testEmbHashMaps; + EmbHashMemT validEmbHashMaps; SetEmbInfo(); SetEmbHashMaps(testEmbHashMaps); @@ -355,8 +355,8 @@ TEST_F(CheckpointTest, EmbHashMaps) TEST_F(CheckpointTest, MaxOffset) { - offset_mem_t testMaxOffset; - offset_mem_t validMaxOffset; + OffsetMemT testMaxOffset; + OffsetMemT validMaxOffset; SetEmbInfo(); SetMaxOffset(testMaxOffset); @@ -385,8 +385,8 @@ TEST_F(CheckpointTest, MaxOffset) 
TEST_F(CheckpointTest, KeyOffsetMaps) { - key_offset_mem_t testKeyOffsetMaps; - key_offset_mem_t validKeyOffsetMaps; + KeyOffsetMemT testKeyOffsetMaps; + KeyOffsetMemT validKeyOffsetMaps; SetEmbInfo(); SetKeyOffsetMaps(testKeyOffsetMaps); @@ -413,10 +413,10 @@ TEST_F(CheckpointTest, KeyOffsetMaps) TEST_F(CheckpointTest, AllMgmt) { - offset_mem_t testMaxOffset; - offset_mem_t validMaxOffset; - key_offset_mem_t testKeyOffsetMaps; - key_offset_mem_t validKeyOffsetMaps; + OffsetMemT testMaxOffset; + OffsetMemT validMaxOffset; + KeyOffsetMemT testKeyOffsetMaps; + KeyOffsetMemT validKeyOffsetMaps; SetEmbInfo(); SetMaxOffset(testMaxOffset); @@ -461,8 +461,8 @@ TEST_F(CheckpointTest, AllMgmt) TEST_F(CheckpointTest, FeatAdmitNEvict) { - table_2_thresh_mem_t testTrens2Thresh; - table_2_thresh_mem_t validTrens2Thresh; + Table2ThreshMemT testTrens2Thresh; + Table2ThreshMemT validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; @@ -528,10 +528,10 @@ TEST_F(CheckpointTest, FeatAdmitNEvict) TEST_F(CheckpointTest, KeyFreqMaps) { - key_freq_mem_t testDDRKeyFreqMaps; - key_freq_mem_t validDDRKeyFreqMaps; - key_freq_mem_t testExcludeDDRKeyFreqMaps; - key_freq_mem_t validExcludeDDRKeyFreqMaps; + KeyFreqMemT testDDRKeyFreqMaps; + KeyFreqMemT validDDRKeyFreqMaps; + KeyFreqMemT testExcludeDDRKeyFreqMaps; + KeyFreqMemT validExcludeDDRKeyFreqMaps; SetEmbInfo(); SetDDRKeyFreqMaps(testDDRKeyFreqMaps); diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 39dcc36f..2ac01a45 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -79,7 +79,7 @@ protected: } } - void SetTable2Threshold(table_2_thresh_mem_t& testTable2Threshold, + void SetTable2Threshold(Table2ThreshMemT& testTable2Threshold, valid_int_t& validArr, valid_attrib_t& validAttrib) { @@ -247,8 +247,8 @@ protected: TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) { - table_2_thresh_mem_t testTrens2Thresh; - table_2_thresh_mem_t validTrens2Thresh; + Table2ThreshMemT testTrens2Thresh; + Table2ThreshMemT validTrens2Thresh; AdmitAndEvictData testHistRec; AdmitAndEvictData validHistRec; diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index 80814d6f..6c03c9c7 100644 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp @@ -87,7 +87,7 @@ TEST(EmbHashMap, TestFindOffset) auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName]; auto logLevelTemp = Log::GetLevel(); - Log::SetLevel(Log::TRACE); + Log::SetLevel(Log::trace); vector keys4 = {21, 21, 21, 21}; // duplicated new keys that also require swap-in/swap-out hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId); RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index 384a0c37..acbdbdf1 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -36,16 +36,16 @@ enum SleepTime : uint32_t { using HashMapInfo = absl::flat_hash_map; struct InputArgs { - keys_t keys; + KeysT keys; vector cnt; - keys_t expectKeys; + KeysT expectKeys; HashMapInfo lastHistory; HashMapInfo expectHistory; }; class FeatureAdmitAndEvictTest : public testing::Test { protected: - HashMapInfo GetHistoryRecords(keys_t& keys, vector<int>& cnt, time_t ts, std::string embName, 
+ HashMapInfo GetHistoryRecords(KeysT& keys, vector<int>& cnt, time_t ts, std::string embName, HashMapInfo& oldInfos) { HashMapInfo newInfos; @@ -113,7 +113,7 @@ protected: return true; } - bool IsAllTheSameVector(keys_t& keys1, keys_t& keys2) + bool IsAllTheSameVector(KeysT& keys1, KeysT& keys2) { printf("\nrun ret: keys1 ===> \n\t"); for (auto &k1 : keys1) { @@ -142,8 +142,8 @@ protected: void FeatureAdmitCommon(FeatureAdmitAndEvict& faae, int channel, string embName, InputArgs& args) { time_t ts = time(nullptr); - keys_t tmpKeys = args.keys; - std::unique_ptr<emb_batch_t> batch = make_unique<emb_batch_t>(); + KeysT tmpKeys = args.keys; + std::unique_ptr<EmbBatchT> batch = make_unique<EmbBatchT>(); batch->name = embName; batch->timestamp = ts; printf("\n"); @@ -165,8 +165,8 @@ protected: void FeatureAdmitCommonMultiThr(FeatureAdmitAndEvict& faae, int channel, string embName, InputArgs& args) { time_t ts = time(nullptr); - keys_t tmpKeys = args.keys; - std::unique_ptr<emb_batch_t> batch = make_unique<emb_batch_t>(); + KeysT tmpKeys = args.keys; + std::unique_ptr<EmbBatchT> batch = make_unique<EmbBatchT>(); batch->name = embName; batch->timestamp = ts; printf("\n"); @@ -273,7 +273,7 @@ protected: keys1 = {11, 11, 33, 44, 11, 55, 88, 55} cnt1 = 1 2 1 3 1 1 4 1 */ - keys_t expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}; + KeysT expectRet1 = {11, 11, -1, 44, 11, 55, 88, 55}; InputArgs args1 = {keys1, cnt1, expectRet1, initHistory, {}}; // the first record of each table must be appended with initHistory FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args1); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_1)); @@ -283,7 +283,7 @@ protected: keys2 = {11, 12, 33, 21, 11, 12} cnt2 = 1 2 1 1 2 3 */ - keys_t expectRet2 = {11, 12, 33, -1, 11, 12}; + KeysT expectRet2 = {11, 12, 33, -1, 11, 12}; InputArgs args2 = {keys2, cnt2, expectRet2, args1.expectHistory, {}}; FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args2); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); @@ -293,7 +293,7 @@ protected: keys3 = {123, 121, 121, 212, 211} cnt3 = 1 2 1 1 2 */ - keys_t expectRet3 = {-1, 121, 121, -1, -1}; + KeysT expectRet3 = {-1, 121, 121, -1, -1}; InputArgs args3 = {keys3, cnt3, expectRet3, initHistory, {}}; FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args3); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_6)); @@ -303,7 +303,7 @@ protected: keys4 = {11, 11, 33, 44, 55, 88, 55} cnt4 = 1 2 3 2 1 2 1 */ - keys_t expectRet4 = {11, 11, 33, 44, 55, 88, 55}; + KeysT expectRet4 = {11, 11, 33, 44, 55, 88, 55}; InputArgs args4 = {keys4, cnt4, expectRet4, args2.expectHistory, {}}; FeatureAdmitCommon(faae, 0, thresholds[0].tableName, args4); std::this_thread::sleep_for(std::chrono::seconds(SleepTime::SLEEP_SECOND_2)); @@ -313,7 +313,7 @@ protected: keys5 = {125, 121, 122, 212, 211} cnt5 = 1 2 1 3 1 */ - keys_t expectRet5 = {-1, 121, -1, 212, 211}; + KeysT expectRet5 = {-1, 121, -1, 212, 211}; InputArgs args5 = {keys5, cnt5, expectRet5, args3.expectHistory, {}}; FeatureAdmitCommon(faae, 0, thresholds[1].tableName, args5); @@ -328,10 +328,10 @@ protected: faae.ParseThresholdCfg(thresholds); // test point: tmpCnt.size() != tmpKeys.size() - keys_t tmpKeys = {11, 11, 33, 44, 11, 55, 88, 55}; + KeysT tmpKeys = {11, 11, 33, 44, 11, 55, 88, 55}; vector tmpCnt = {1, 2, 1, 3, 1, 1, 4}; - std::unique_ptr<emb_batch_t> batch = make_unique<emb_batch_t>(); + std::unique_ptr<EmbBatchT> batch = make_unique<EmbBatchT>(); batch->name = thresholds[0].tableName; batch->timestamp = time(nullptr);
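The expected outputs above encode the admit rule these tests exercise. The following stand-alone sketch (an editorial illustration, not test code from the patch) reproduces expectRet1 for tableAAA under the reading that a key is emitted as -1 until its accumulated count, including all of its occurrences in the current batch, reaches the table's count threshold (2 for tableAAA, per the thresholds fixture below); time-based eviction is ignored here.

#include <cstdio>
#include <unordered_map>
#include <vector>

int main()
{
    const int countThreshold = 2;                     // tableAAA's count threshold, from the fixture
    std::unordered_map<long long, int> history;       // per-key accumulated counts

    std::vector<long long> keys = {11, 11, 33, 44, 11, 55, 88, 55};  // keys1 from the test
    std::vector<int> cnt = {1, 2, 1, 3, 1, 1, 4, 1};                  // cnt1 from the test

    // First fold the whole batch into the history, then decide per position.
    for (size_t i = 0; i < keys.size(); ++i) {
        history[keys[i]] += cnt[i];
    }
    for (size_t i = 0; i < keys.size(); ++i) {
        long long out = (history[keys[i]] >= countThreshold) ? keys[i] : -1;
        std::printf("%lld -> %lld\n", keys[i], out);  // prints 11, 11, -1, 44, 11, 55, 88, 55
    }
    return 0;
}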
embName,
+ void CheckMultiThreadRet(KeysT& expectKeys, std::vector& expectCnt, const std::string& embName,
 int threadCnt)
 {
 // verify the history record table info
@@ -404,9 +404,9 @@ protected:
 tableBBB data will be {121, 122, 123, 125, 211, 212}; 5 1 1 1 3 4
 */
- keys_t expectKeys1 = {11, 33, 44, 55, 88}; // 12 and 21 were evicted
+ KeysT expectKeys1 = {11, 33, 44, 55, 88}; // 12 and 21 were evicted
 vector expectCnt1 = {10, 5, 5, 4, 6};
- keys_t expectKeys2 = {121, 122, 125, 211, 212}; // 123 was evicted
+ KeysT expectKeys2 = {121, 122, 125, 211, 212}; // 123 was evicted
 vector expectCnt2 = {5, 1, 1, 3, 4};
 std::lock_guard lock(faae.m_syncMutexs); // competes with the evict-thread for resources
 CheckMultiThreadRet(expectKeys1, expectCnt1, thresholds[0].tableName, threadNum);
@@ -423,7 +423,7 @@ protected:
 faae.ResetAllRecords();
 faae.ParseThresholdCfg(thresholds);
- std::unique_ptr batch = make_unique();
+ std::unique_ptr batch = make_unique();
 // test point: tableDDD has no threshold configured, so it is not supported
 batch->name = std::string("tableDDD");
 batch->timestamp = time(nullptr);
@@ -436,15 +436,15 @@ protected:
 HashMapInfo initHistory;
 FeatureAdmitAndEvict faae;
 std::thread evictThr;
- keys_t keys1 = {11, 11, 33, 44, 11, 55, 88, 55};
+ KeysT keys1 = {11, 11, 33, 44, 11, 55, 88, 55};
 vector cnt1 = {1, 2, 1, 3, 1, 1, 4, 1};
- keys_t keys2 = {11, 12, 33, 21, 11, 12};
+ KeysT keys2 = {11, 12, 33, 21, 11, 12};
 vector cnt2 = {1, 2, 1, 1, 2, 3};
- keys_t keys3 = {123, 121, 121, 212, 211};
+ KeysT keys3 = {123, 121, 121, 212, 211};
 vector cnt3 = {1, 2, 1, 1, 2};
- keys_t keys4 = {11, 11, 33, 44, 55, 88, 55};
+ KeysT keys4 = {11, 11, 33, 44, 55, 88, 55};
 vector cnt4 = {1, 2, 3, 2, 1, 2, 1};
- keys_t keys5 = {125, 121, 122, 212, 211};
+ KeysT keys5 = {125, 121, 122, 212, 211};
 vector cnt5 = {1, 2, 1, 3, 1};
 std::vector thresholds = {{"tableAAA", 2, 5, 1}, {"tableBBB", 3, 7, 1}, {"tableCCC", 5, 9, 1}};
 };
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 9b3f6886..3c4ebf73 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -72,14 +72,14 @@ protected:
 splits = fieldNums;
 }
- vector> PrepareBatch()
+ vector> PrepareBatch()
 {
- vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM);
+ vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM);
 // write the KEY_PROCESS_THREAD * BATCH_NUM_EACH_THREAD batches handled by all threads of this process into the shared queue
 for (size_t threadId = 0; threadId < KEY_PROCESS_THREAD; ++threadId) {
 int batchQueueId = threadId + KEY_PROCESS_THREAD * channel;
 unsigned int seed = batchQueueId * 10;
- auto queue = SingletonQueue::getInstances(batchQueueId);
+ auto queue = SingletonQueue::GetInstances(batchQueueId);
 for (size_t batchNum = 0; batchNum < BATCH_NUM_EACH_THREAD; ++batchNum) {
 size_t batchId =
@@ -96,7 +96,7 @@ protected:
 worldRank, worldSize, batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size()
 );
- emb_batch_t temp;
+ EmbBatchT temp;
 temp.sample = batch->sample;
 temp.name = batch->name;
 temp.batchId = batch->batchId;
@@ -153,9 +153,9 @@ protected:
 return true;
 }
- auto GetSplitAndRestore(keys_t& sample) -> tuple, vector>
+ auto GetSplitAndRestore(KeysT& sample) -> tuple, vector>
 {
- vector expectSplitKeys(worldSize);
+ vector expectSplitKeys(worldSize);
 vector expectRestore(sample.size());
 absl::flat_hash_map uKey;
 for (unsigned int i = 0; i < sample.size(); ++i) {
@@ -172,7 +172,7 @@ protected:
 return { expectSplitKeys, expectRestore };
 }
- void PrintHotHashSplit(const vector& splitKeys,
+ void PrintHotHashSplit(const vector& splitKeys,
 const vector& restore, const vector& hotPos, int rankSize)
 {
@@ 
-186,7 +186,7 @@ protected: LOG_INFO(VectorToString(hotPos)); } - void GetExpectRestore(keys_t& sample, vector& blockOffset, vector& restoreVec) + void GetExpectRestore(KeysT& sample, vector& blockOffset, vector& restoreVec) { for (unsigned int i = 0; i < sample.size(); ++i) { int devId = sample[i] % worldSize; @@ -221,8 +221,8 @@ protected: vector src; vector allRankInfo; vector embInfos; - unique_ptr batchData; - vector splitKeys; + unique_ptr batchData; + vector splitKeys; vector restore; KeyProcess process; @@ -264,9 +264,9 @@ TEST_F(KeyProcessTest, Start) TEST_F(KeyProcessTest, HashSplit) { int rankSize = 4; - auto queue = SingletonQueue::getInstances(0); + auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); - keys_t batchKeys = { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }; + KeysT batchKeys = { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }; vector expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 }; vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } }; batch->sample = std::move(batchKeys); @@ -315,9 +315,9 @@ TEST_F(KeyProcessTest, GetScAllForUnique) TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { - auto queue = SingletonQueue::getInstances(0); + auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); - vector allBatchKeys = { { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }, + vector allBatchKeys = { { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }, { 5, 17, 26, 9, 27, 22, 27, 28, 15, 3 }, { 10, 4, 22, 17, 24, 13, 24, 26, 29, 11 }, { 14, 21, 18, 25, 21, 4, 20, 24, 13, 19 } }; @@ -354,11 +354,11 @@ TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; process.hotEmbTotCount[embName] = 10; - vector splitKeys; + vector splitKeys; vector restore; vector hotPos; - unique_ptr batch; - batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue + unique_ptr batch; + batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); @@ -381,11 +381,11 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; - vector splitKeys; + vector splitKeys; vector restore; vector hotPos; - unique_ptr batch; - batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue + unique_ptr batch; + batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); @@ -406,26 +406,26 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) TEST_F(KeyProcessTest, Key2Offset) { - keys_t lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 }; - keys_t expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 }; + KeysT lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 }; + KeysT expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID); - map tmp; + map tmp; for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { - tmp.insert(pair(it->first, MapToString(it->second).c_str())); + tmp.insert(pair(it->first, 
MapToString(it->second).c_str())); } LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", VectorToString(lookupKeys), MapToString(tmp)); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); - keys_t lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 }; - keys_t expectOffset2 = { -1, -1, -1, -1, -1, -1, -1, -1 }; + KeysT lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 }; + KeysT expectOffset2 = { -1, -1, -1, -1, -1, -1, -1, -1 }; process.Key2Offset("emb0", lookupKeys2, EVAL_CHANNEL_ID); - map tmp2; + map tmp2; for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { - tmp.insert(pair(it->first, MapToString(it->second).c_str())); + tmp.insert(pair(it->first, MapToString(it->second).c_str())); } LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", VectorToString(lookupKeys2), MapToString(tmp2).c_str()); @@ -434,16 +434,16 @@ TEST_F(KeyProcessTest, Key2Offset) TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) { - keys_t lookupKeys = { 4, 16, 28, -1, 24, -1, 20, 24 }; - keys_t expectOffset = { 0, 0, 0, 0, 0, 0, 0, 0 }; + KeysT lookupKeys = { 4, 16, 28, -1, 24, -1, 20, 24 }; + KeysT expectOffset = { 0, 0, 0, 0, 0, 0, 0, 0 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); process.Key2OffsetDynamicExpansion("emb0", lookupKeys, EVAL_CHANNEL_ID); LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", VectorToString(lookupKeys), [&] { - map tmp; + map tmp; for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { - tmp.insert(pair(it->first, MapToString(it->second).c_str())); + tmp.insert(pair(it->first, MapToString(it->second).c_str())); } return MapToString(tmp); }()); @@ -485,7 +485,7 @@ TEST_F(KeyProcessTest, InitializeUnique) ASSERT_EQ(factory->CreateUnique(unique), 0); PrepareBatch(); - unique_ptr batch; + unique_ptr batch; batch = process.GetBatchData(0, 0); UniqueConf uniqueConf; process.rankInfo.rankSize = worldSize; @@ -498,7 +498,7 @@ TEST_F(KeyProcessTest, InitializeUnique) TEST_F(KeyProcessTest, GetKeySize) { PrepareBatch(); - unique_ptr batch; + unique_ptr batch; batch = process.GetBatchData(0, 0); process.rankInfo.rankSize = worldSize; process.rankInfo.useStatic = true; @@ -517,12 +517,12 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) auto embName = embInfos[0].name; process.hotEmbTotCount[embName] = 10; - vector splitKeys; + vector splitKeys; vector restore; vector hotPos; - unique_ptr batch; + unique_ptr batch; UniqueInfo uniqueInfo; - batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue + batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue ASSERT_EQ(factory->CreateUnique(unique), ock::ctr::H_OK); UniqueConf uniqueConf; diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp index 09ed911a..120eb303 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -21,7 +21,7 @@ TEST(Log, Format) TEST(Log, LogLevel) { - MxRec::Log::SetLevel(Log::DEBUG); + MxRec::Log::SetLevel(Log::debug); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -33,7 +33,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -45,7 
+45,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::WARN); + MxRec::Log::SetLevel(Log::warn); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -57,7 +57,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::ERROR); + MxRec::Log::SetLevel(Log::error); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -72,7 +72,7 @@ TEST(Log, LogLevel) TEST(Log, LayzEvalution) { - MxRec::Log::SetLevel(Log::WARN); + MxRec::Log::SetLevel(Log::warn); testing::internal::CaptureStdout(); int flag1 = 0; int flag2 = 0; @@ -97,7 +97,7 @@ TEST(Log, LayzEvalution) TEST(Log, Basic) { - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_INFO("basictest"); std::string output = testing::internal::GetCapturedStdout(); @@ -106,7 +106,7 @@ TEST(Log, Basic) TEST(Log, TooManyArgs1) { - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_INFO("{} {} {}", 0.1f, 'h', 'e', "llow"); std::string output = testing::internal::GetCapturedStdout(); @@ -116,7 +116,7 @@ TEST(Log, TooManyArgs1) TEST(Log, TooManyArgs2) { - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_INFO("{}", "h", "h", "h", "h", "h", "h", "h"); std::string output = testing::internal::GetCapturedStdout(); @@ -126,7 +126,7 @@ TEST(Log, TooManyArgs2) TEST(Log, FewArgs) { - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_INFO("{} {} {} {} {} {}", "hellow", "hellow"); std::string output = testing::internal::GetCapturedStdout(); @@ -136,7 +136,7 @@ TEST(Log, FewArgs) TEST(Log, CkptType) { - MxRec::Log::SetLevel(Log::INFO); + MxRec::Log::SetLevel(Log::info); testing::internal::CaptureStdout(); LOG_INFO("ckpt type={}", CkptDataType::EMB_DATA); std::string output = testing::internal::GetCapturedStdout(); -- Gitee From b7152ec9143b6ad503c2e257e0652680ea35b8fa Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 23 Sep 2023 16:19:53 +0800 Subject: [PATCH 364/551] Match-id-b5edb93b6f5c0d8f62490b05cf955533a742359e --- mx_rec/core/asc/manager.py | 7 +-- src/core/checkpoint/checkpoint.cpp | 6 +-- src/core/emb_hashmap/emb_hashmap.cpp | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 6 +-- src/core/key_process/key_process.cpp | 20 ++++---- src/core/ssd_engine/table.cpp | 13 ++++-- src/core/utils/common.cpp | 26 +++++------ src/core/utils/common.h | 53 ++++++++++++++++------ src/core/utils/config.cpp | 2 +- src/core/utils/{log.cpp => logger.cpp} | 24 +++++----- src/core/utils/{log.h => logger.h} | 47 ++++++++++--------- src/pybind/module_main.cpp | 27 +++++++++-- src/tests/emb_hashmap/emb_hashmap_test.cpp | 6 +-- src/tests/emb_mgmt/emb_mgmt_test.cpp | 20 ++++---- src/tests/ssd_cache/cache_manager_test.cpp | 2 +- src/tests/ssd_engine/engine_test.cpp | 21 +++++---- src/tests/ssd_engine/file_test.cpp | 16 +++---- src/tests/ssd_engine/table_test.cpp | 12 ++--- src/tests/utils/log_test.cpp | 22 ++++----- 19 files changed, 194 insertions(+), 140 deletions(-) rename src/core/utils/{log.cpp => logger.cpp} (54%) rename src/core/utils/{log.h => logger.h} (64%) diff --git a/mx_rec/core/asc/manager.py 
b/mx_rec/core/asc/manager.py index dcac0ec8..11514e7b 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -27,7 +27,7 @@ def check_dangling_table():
 def generate_table_info_list():
-    from mxrec_pybind import EmbInfo
+    from mxrec_pybind import EmbInfo, EmbInfoParams
     from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN
     # table_name corresponds to channel_name, which is used in operator gen_npu_ops.get_next
     table_info_list = []
@@ -61,8 +61,9 @@ def generate_table_info_list():
                      table_instance.slice_device_vocabulary_size)
         logger.debug("table_instance.slice_host_vocabulary_size: %s", table_instance.slice_host_vocabulary_size)
         logger.debug("table_instance.slice_ssd_vocabulary_size: %s", table_instance.slice_ssd_vocabulary_size)
-        table_info = EmbInfo(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size,
-                             table_instance.ext_emb_size, table_instance.is_save,
+        params = EmbInfoParams(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size,
+                               table_instance.ext_emb_size, table_instance.is_save)
+        table_info = EmbInfo(params,
                              [table_instance.slice_device_vocabulary_size,
                               table_instance.slice_host_vocabulary_size,
                               table_instance.slice_ssd_vocabulary_size],
                              [matched_emb_initializer(table_instance)] +
diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 9f1415a5..d30acb2d 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -217,7 +217,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da
 auto res = aclrtSetDevice(static_cast(deviceId));
 if (res != ACL_ERROR_NONE) {
 LOG_ERROR("Set device failed, device_id:{}", deviceId);
- throw runtime_error(Log::Format("Set device failed, device_id:{}", deviceId).c_str());
+ throw runtime_error(Logger::Format("Set device failed, device_id:{}", deviceId).c_str());
 }
 auto &transArr = transData.int64Arr;
@@ -231,7 +231,7 @@
 ACL_MEMCPY_DEVICE_TO_HOST);
 if (ret != ACL_SUCCESS) {
 LOG_ERROR("aclrtMemcpy failed, ret={}", ret);
- throw runtime_error(Log::Format("aclrtMemcpy failed, ret={}", ret).c_str());
+ throw runtime_error(Logger::Format("aclrtMemcpy failed, ret={}", ret).c_str());
 }
 writeFile.write(reinterpret_cast(row.data()), embeddingSize * sizeof(float));
@@ -480,7 +480,7 @@ void Checkpoint::ReadStream(CkptTransData& transData,
 ValidateReadFile(dataDir, datasetSize);
 } catch (const std::invalid_argument& e) {
 readFile.close();
- throw runtime_error(Log::Format("Invalid read file path: {}", e.what()));
+ throw runtime_error(Logger::Format("Invalid read file path: {}", e.what()));
 }
 if (datasetSize % dataElmtBytes > 0) {
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index ff24531c..203377b3 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -110,7 +110,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara
 auto swapLen = ddrParam.tmpDataOut.back().flat();
 swapLen(0) = swapSize;
- if (g_statOn) {
+ if (GlogConfig::gStatOn) {
 LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} swap_key_size {} swap_time_cost {}",
 channelId, swapId, rankInfo.rankId, swapSize, swapTimeCost.ElapsedMS());
 }
@@ -540,7 +540,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo&
 /// Trace logging: after HBM/DDR swap-in and swap-out, compare the DDR keys in hostHashMap with the keys held by the table's corresponding lfuCache object
 void 
EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const
 {
- if (Log::GetLevel() != Log::trace) {
+ if (Logger::GetLevel() != Logger::trace) {
 return;
 }
 auto& hostMap = embHashMap.hostHashMap;
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ade55872..04fe1596 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -7,7 +7,7 @@
 #include "hybrid_mgmt.h"
 #include "utils/time_cost.h"
-#include "utils/log.h"
+#include "utils/logger.h"
 #include "checkpoint/checkpoint.h"
@@ -85,7 +85,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos,
 }
 InitRankInfo(rankInfo, embInfos);
- g_statOn = GlobalEnv::statOn;
+ GlogConfig::gStatOn = GlobalEnv::statOn;
 LOG_INFO(MGMT + "begin initialize, localRankSize:{}, localRankId:{}, rank:{}",
 rankInfo.localRankSize, rankInfo.localRankId, rankInfo.rankId);
@@ -142,7 +142,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos,
 // compare whether the data in hostHashMap and cacheManager are consistent
 void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData)
 {
- if (Log::GetLevel() != Log::trace) {
+ if (Logger::GetLevel() != Logger::trace) {
 return;
 }
 auto& embHashMaps = saveData.embHashMaps;
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 742b0217..48f67583 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -215,7 +215,7 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize,
 auto ret = unique->Initialize(uniqueConf);
 if (ret != ock::ctr::H_OK) {
- throw runtime_error(Log::Format("fast unique init failed, code:{}", ret));
+ throw runtime_error(Logger::Format("fast unique init failed, code:{}", ret));
 }
 uniqueInitialize = true;
 }
@@ -231,7 +231,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
 auto ret = factory->CreateUnique(unique);
 if (ret != ock::ctr::H_OK) {
- throw runtime_error(Log::Format("create fast unique failed, error code:{}", ret));
+ throw runtime_error(Logger::Format("create fast unique failed, error code:{}", ret));
 }
 GetUniqueConfig(uniqueConf);
@@ -363,7 +363,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch
 TimeCost pushResultTC;
 PushResult(batch, move(tensors), uniqueInfo.all2AllInfo.keyRecv);
- if (g_statOn) {
+ if (GlogConfig::gStatOn) {
 LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost_with_fast_unique {}",
 channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS());
 }
@@ -427,7 +427,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel,
 PushResult(batch, move(tensors), lookupKeys);
 LOG_DEBUG("pushResultTC(ms):{}", pushResultTC.ElapsedMS());
- if (g_statOn) {
+ if (GlogConfig::gStatOn) {
 LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} key_process_time_cost {}",
 channel, batch->batchId, rankInfo.rankId, totalTimeCost.ElapsedMS());
 }
@@ -601,7 +601,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch,
 batch->batchId, batch->Size(), batch->channel, batch->name,
 uniqueInfoOut.restore.size(), keySendInfo.keyCount.size());
- if (g_statOn) {
+ if (GlogConfig::gStatOn) {
 LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} "
 "batch_key_num_with_fast_unique {} unique_key_num_with_fast_unique {}",
 batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueOut.uniqueIdCnt);
@@ -778,7 +778,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) 
const -> tuple& batch) const EASY_END_BLOCK LOG_TRACE("dump splitKeys {}", DumpSplitKeys(splitKeys)); - if (g_statOn) { + if (GlogConfig::gStatOn) { size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { uniqueKeyNum += splitKeys[devId].size(); @@ -887,7 +887,7 @@ tuple, vector, vector> uKey[key] = restore[i]; } - if (g_statOn) { + if (GlogConfig::gStatOn) { size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { uniqueKeyNum += splitKeys[devId].size(); @@ -1107,7 +1107,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vecto emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; - } else if (Log::GetLevel() >= Log::debug) { + } else if (Logger::GetLevel() >= Logger::debug) { hotNum += 1; } } @@ -1301,7 +1301,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset LOG_ERROR("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", embName, offset.size(), embInfos[embName].devVocabSize); throw runtime_error( - Log::Format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", + Logger::Format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", embName, offset.size(), embInfos[embName].devVocabSize ).c_str()); } diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index 7fe23f2f..e3627059 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -19,7 +19,8 @@ Table::Table(const string &name, vector &savePaths, uint64_t maxTableSiz maxTableSize(maxTableSize), compactThreshold(compactThreshold) { - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); + curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + + saveDirPrefix + GlogConfig::gRankId + "/" + name).string(); if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { throw runtime_error("fail to create table directory"); } @@ -39,7 +40,8 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize compactThreshold(compactThreshold) { // always use first path to save until it's full - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); + curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + + saveDirPrefix + GlogConfig::gRankId + "/" + name).string(); if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { throw runtime_error("fail to create table directory"); } @@ -47,7 +49,8 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize bool isMetaFileFound = false; for (const string &dirPath: saveDirs) { auto metaFilePath = fs::absolute( - dirPath + "/" + saveDirPrefix + g_rankId + "/" + name + "/" + name + ".meta." + to_string(step)).string(); + dirPath + "/" + saveDirPrefix + GlogConfig::gRankId + "/" + + name + "/" + name + ".meta." 
+ to_string(step)).string(); if (!fs::exists(metaFilePath)) { continue; } @@ -157,7 +160,7 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) shared_ptr loadedFile = nullptr; for (const string &p: savePaths) { // try to find data file from each path - string loadPath = p + "/" + saveDirPrefix + g_rankId + "/" + name; + string loadPath = p + "/" + saveDirPrefix + GlogConfig::gRankId + "/" + name; SetTablePathToDiskWithSpace(); if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { throw runtime_error("fail to create table directory"); @@ -377,7 +380,7 @@ void Table::SetTablePathToDiskWithSpace() throw runtime_error("all disk's space not enough"); } curTablePath = fs::absolute( - savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + g_rankId + "/" + name).string(); + savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + GlogConfig::gRankId + "/" + name).string(); LOG_INFO("current data path's available space less than {}%, try next path:{}", diskAvailSpaceThreshold * convertToPercentage, curTablePath); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index b2a09dae..9bf5297e 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -21,10 +21,10 @@ using namespace std; using std::chrono::system_clock; namespace MxRec { - string g_rankId; - int g_glogLevel; bool g_isGlogInit = false; - bool g_statOn = false; + bool GlogConfig::gStatOn = false; + int GlogConfig::gGlogLevel; + string GlogConfig::gRankId; RankInfo::RankInfo(int rankId, int deviceId, int localRankSize, int option, const vector& maxStep) : rankId(rankId), deviceId(deviceId), localRankSize(localRankSize), option(option), maxStep(maxStep) @@ -33,9 +33,9 @@ namespace MxRec { if (localRankSize != 0) { localRankId = rankId % localRankSize; } - useStatic = option bitand HybridOption::USE_STATIC; - useHot = option bitand HybridOption::USE_HOT; - useDynamicExpansion = option bitand HybridOption::USE_DYNAMIC_EXPANSION; + useStatic = static_cast(option) bitand HybridOption::USE_STATIC; + useHot = static_cast(option) bitand HybridOption::USE_HOT; + useDynamicExpansion = static_cast(option) bitand HybridOption::USE_DYNAMIC_EXPANSION; } RankInfo::RankInfo(int localRankSize, int option, const vector& maxStep) @@ -46,8 +46,8 @@ namespace MxRec { if (localRankSize != 0) { localRankId = rankId % localRankSize; } - useStatic = option & HybridOption::USE_STATIC; - useHot = option & HybridOption::USE_HOT; + useStatic = static_cast(option) & HybridOption::USE_STATIC; + useHot = static_cast(option) & HybridOption::USE_HOT; } RandomInfo::RandomInfo(int start, int len, float constantVal, float randomMin, float randomMax) @@ -56,13 +56,13 @@ namespace MxRec { void SetLog(int rank) { - g_glogLevel = GlobalEnv::glogStderrthreshold; - if (g_rankId.empty()) { - g_rankId = std::to_string(rank); + GlogConfig::gGlogLevel = GlobalEnv::glogStderrthreshold; + if (GlogConfig::gRankId.empty()) { + GlogConfig::gRankId = std::to_string(rank); } if (!g_isGlogInit) { - Log::SetLevel(g_glogLevel); - Log::SetRank(rank); + Logger::SetLevel(GlogConfig::gGlogLevel); + Logger::SetRank(rank); g_isGlogInit = true; } } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 80c51498..afdf7252 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -25,7 +25,7 @@ #include "tensorflow/core/framework/tensor.h" #include "absl/container/flat_hash_map.h" #include "securec.h" -#include "utils/log.h" +#include "utils/logger.h" #include "utils/config.h" #include "initializer/initializer.h" @@ 
-63,15 +63,17 @@ namespace MxRec { constexpr int SSD_SIZE_INDEX = 2; // for GLOG - extern bool g_statOn; - extern int g_glogLevel; - extern string g_rankId; + struct GlogConfig { + static bool gStatOn; + static int gGlogLevel; + static string gRankId; + }; + constexpr int GLOG_MAX_BUF_SIZE = 1024; constexpr int GLOG_TIME_WIDTH_2 = 2; constexpr int GLOG_TIME_WIDTH_6 = 6; constexpr char GLOG_STAT_FLAG[] = "statOn"; - // unique related config constexpr int UNIQUE_BUCKET = 6; constexpr int MIN_UNIQUE_THREAD_NUM = 1; @@ -102,9 +104,9 @@ namespace MxRec { using TensorInfoT = std::tuple>>::iterator>; namespace HybridOption { - const int USE_STATIC = 0x001; - const int USE_HOT = 0x001 << 1; - const int USE_DYNAMIC_EXPANSION = 0x001 << 2; + const unsigned int USE_STATIC = 0x001; + const unsigned int USE_HOT = 0x001 << 1; + const unsigned int USE_DYNAMIC_EXPANSION = 0x001 << 2; }; string GetChipName(int devID); @@ -135,7 +137,7 @@ namespace MxRec { {"910B2", UBSize::ASCEND910_B2}, {"910B3", UBSize::ASCEND910_B3}, {"910B4", UBSize::ASCEND910_B4}}; - const auto it = chipUbSizeList.find(GetChipName(devID)); + std::map::const_iterator it = chipUbSizeList.find(GetChipName(devID)); if (it != chipUbSizeList.end()) { return it->second; } @@ -374,19 +376,40 @@ namespace MxRec { return tmpTensor; } - struct EmbInfo { - EmbInfo() = default; + struct EmbInfoParams { + EmbInfoParams() = default; - EmbInfo(const std::string& name, + EmbInfoParams(const std::string& name, int sendCount, int embeddingSize, int extEmbeddingSize, - bool isSave, + bool isSave) + : name(name), + sendCount(sendCount), + embeddingSize(embeddingSize), + extEmbeddingSize(extEmbeddingSize), + isSave(isSave) + { + } + std::string name; + int sendCount; + int embeddingSize; + int extEmbeddingSize; + bool isSave; + }; + + struct EmbInfo { + EmbInfo() = default; + + EmbInfo(const EmbInfoParams& embInfoParams, std::vector vocabsize, std::vector initializeInfos, std::vector ssdDataPath) - : name(name), sendCount(sendCount), embeddingSize(embeddingSize), extEmbeddingSize(extEmbeddingSize), - isSave(isSave), + : name(embInfoParams.name), + sendCount(embInfoParams.sendCount), + embeddingSize(embInfoParams.embeddingSize), + extEmbeddingSize(embInfoParams.extEmbeddingSize), + isSave(embInfoParams.isSave), devVocabSize(vocabsize[0]), hostVocabSize(vocabsize[1]), ssdVocabSize(vocabsize[SSD_SIZE_INDEX]), diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp index 35f0e6ff..255df2b9 100644 --- a/src/core/utils/config.cpp +++ b/src/core/utils/config.cpp @@ -8,7 +8,7 @@ #include "config.h" -#include "log.h" +#include "logger.h" using namespace std; diff --git a/src/core/utils/log.cpp b/src/core/utils/logger.cpp similarity index 54% rename from src/core/utils/log.cpp rename to src/core/utils/logger.cpp index f3e490b2..3e8c5339 100644 --- a/src/core/utils/log.cpp +++ b/src/core/utils/logger.cpp @@ -7,31 +7,31 @@ */ -#include "utils/log.h" +#include "utils/logger.h" namespace MxRec { -int MxRec::Log::level = MxRec::Log::info; -int MxRec::Log::rank = 0; +int MxRec::Logger::level = MxRec::Logger::info; +int MxRec::Logger::rank = 0; -void Log::SetRank(int rank) +void Logger::SetRank(int logRank) { - Log::rank = rank; + Logger::rank = logRank; } -void Log::SetLevel(int level) +void Logger::SetLevel(int logLevel) { - Log::level = level; + Logger::level = logLevel; } -int Log::GetLevel() +int Logger::GetLevel() { - return Log::level; + return Logger::level; } -const char* Log::LevelToStr(int level) +const char* Logger::LevelToStr(int logLevel) { - if 
(level < trace || level > error) { + if (logLevel < trace || logLevel > error) { return "INVALID LEVEL"; } static const char* msg[] = { @@ -45,7 +45,7 @@ const char* Log::LevelToStr(int level) return msg[level + levelOffset]; } -void Log::LogUnpack(std::queue& fmt, std::stringstream &ss) +void Logger::LogUnpack(std::queue& fmt, std::stringstream &ss) { while (!fmt.empty()) { ss << fmt.front(); diff --git a/src/core/utils/log.h b/src/core/utils/logger.h similarity index 64% rename from src/core/utils/log.h rename to src/core/utils/logger.h index ed740cbf..68028f11 100644 --- a/src/core/utils/log.h +++ b/src/core/utils/logger.h @@ -6,8 +6,8 @@ * History: NA */ -#ifndef MXREC_LOG_H -#define MXREC_LOG_H +#ifndef MXREC_LOGGER_H +#define MXREC_LOGGER_H #include #include @@ -24,7 +24,7 @@ namespace MxRec { constexpr int YEAR_BASE = 1900; constexpr size_t DELIM_LEN = 2; -class Log { +class Logger { public: static constexpr int trace = -2; @@ -33,9 +33,9 @@ public: static constexpr int warn = 1; static constexpr int error = 2; - static void SetRank(int rank); + static void SetRank(int logRank); - static void SetLevel(int level); + static void SetLevel(int logLevel); static int GetLevel(); @@ -57,12 +57,12 @@ public: static std::string Format(const char* fmt, Args &&...args) { std::stringstream ss; - Log::Format(ss, fmt, args...); + Logger::Format(ss, fmt, args...); return ss.str(); } template - static void log(const char* file, int line, int level, const char* fmt, Args &&...args) + static void Log(const char* file, int line, int level, const char* fmt, Args &&...args) { std::stringstream ss; struct tm t; @@ -71,21 +71,21 @@ public: localtime_r(&tv.tv_sec, &t); ss << "[MxRec][" << YEAR_BASE + t.tm_year << "/" << t.tm_mon << "/" << t.tm_mday<< " " << t.tm_hour << ":" << t.tm_min << ":" << t.tm_sec << "." << tv.tv_usec << "] [" - << Log::rank << "] ["<< Log::LevelToStr(level) << "] [" + << Logger::rank << "] ["<< Logger::LevelToStr(level) << "] [" << (strrchr(file, '/') ? strrchr(file, '/') + 1 : file) << ":" << line << "] "; - Log::Format(ss, fmt, args...); + Logger::Format(ss, fmt, args...); ss << std::endl; std::cout << ss.str(); } template - static void log(const char* file, int line, int level, const std::string& fmt, Args &&...args) + static void Log(const char* file, int line, int level, const std::string& fmt, Args &&...args) { - Log::log(file, line, level, fmt.c_str(), args...); + Logger::Log(file, line, level, fmt.c_str(), args...); } private: - static const char* LevelToStr(int level); + static const char* LevelToStr(int logLevel); static void LogUnpack(std::queue& fmt, std::stringstream &ss); @@ -103,22 +103,21 @@ private: static int rank; }; +#define LOG_TRACE(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::trace) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::trace, args) -#define LOG_TRACE(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::trace) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::trace, args) +#define LOG_DEBUG(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::debug) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::debug, args) -#define LOG_DEBUG(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::debug) \ -MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::debug, args) +#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::info) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::info, args) -#define LOG_INFO(args...) 
if (MxRec::Log::GetLevel() <= MxRec::Log::info) \
-MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::info, args)
+#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::info) \
+MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::info, args)
-#define LOG_WARN(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::warn) \
-MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::warn, args)
+#define LOG_WARN(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::warn) \
+MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::warn, args)
-#define LOG_ERROR(args...) if (MxRec::Log::GetLevel() <= MxRec::Log::error) \
-MxRec::Log::log(__FILE__, __LINE__, MxRec::Log::error, args)
+#define LOG_ERROR(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::error) \
+MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::error, args)
 }
-#endif // MXREC_LOG_H \ No newline at end of file
+#endif // MXREC_LOGGER_H \ No newline at end of file
diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index e64eb33f..8a04b288 100644
--- a/src/pybind/module_main.cpp
+++ b/src/pybind/module_main.cpp
@@ -15,6 +15,8 @@ using namespace MxRec;
 namespace {
 void GetRankInfo(py::module_& m);
+ void GetEmbInfoParams(py::module_& m);
+
 void GetEmbInfo(py::module_& m);
 void GetRandomInfo(py::module_& m);
@@ -58,6 +60,8 @@ namespace {
 GetRankInfo(m);
+ GetEmbInfoParams(m);
+
 GetEmbInfo(m);
 GetRandomInfo(m);
@@ -87,13 +91,30 @@ namespace {
 .def_readwrite("max_step", &RankInfo::maxStep);
 }
+ void GetEmbInfoParams(pybind11::module_& m)
+ {
+ pybind11::class_(m, "EmbInfoParams")
+ .def(pybind11::init(),
+ py::arg("name"),
+ py::arg("send_count"),
+ py::arg("embedding_size"),
+ py::arg("ext_embedding_size"),
+ py::arg("is_save"))
+ .def_readwrite("name", &EmbInfoParams::name)
+ .def_readwrite("send_count", &EmbInfoParams::sendCount)
+ .def_readwrite("embedding_size", &EmbInfoParams::embeddingSize)
+ .def_readwrite("ext_embedding_size", &EmbInfoParams::extEmbeddingSize)
+ .def_readwrite("is_save", &EmbInfoParams::isSave);
+ }
+
 void GetEmbInfo(pybind11::module_& m)
 {
 pybind11::class_(m, "EmbInfo")
- .def(pybind11::init,
+ .def(pybind11::init,
 std::vector&,
 std::vector&>(),
- py::arg("name"), py::arg("send_count"), py::arg("embedding_size"), py::arg("ext_embedding_size"),
- py::arg("is_save"), py::arg("vocab_size"), py::arg("initialize_infos"),
+ py::arg("embInfoParams"),
+ py::arg("vocab_size"),
+ py::arg("initialize_infos"),
 py::arg("ssd_data_path"))
 .def_readwrite("name", &EmbInfo::name)
 .def_readwrite("send_count", &EmbInfo::sendCount)
diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index 6c03c9c7..fe46ac07 100644
--- a/src/tests/emb_hashmap/emb_hashmap_test.cpp
+++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp
@@ -86,8 +86,8 @@ TEST(EmbHashMap, TestFindOffset)
 auto& excludeKeyMap = cacheManager.excludeDDRKeyCountMap[embTableName];
 auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName];
- auto logLevelTemp = Log::GetLevel();
- Log::SetLevel(Log::trace);
+ auto logLevelTemp = Logger::GetLevel();
+ Logger::SetLevel(Logger::trace);
 vector keys4 = {21, 21, 21, 21}; // duplicated new keys that also require swap-in/swap-out
 hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId);
 RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++);
@@ -106,6 +106,6 @@
 ASSERT_EQ(excludeKeyMap[INT_42], INT_2);
 ASSERT_EQ(ddrKeyMap.Get(INT_21), NEGATIVE_INT_1);
 ASSERT_EQ(ddrKeyMap.Get(1), NEGATIVE_INT_1);
- Log::SetLevel(logLevelTemp); // restore the log level
+ Logger::SetLevel(logLevelTemp); // restore the log level
 LOG_INFO("test TestFindOffset end.");
 } \ No newline at end of file
diff --git 
a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 8b08443f..39e53d6b 100644
--- a/src/tests/emb_mgmt/emb_mgmt_test.cpp
+++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp
@@ -113,14 +113,15 @@ protected:
 TEST_F(EmbMgmtTest, Initialize)
 {
 vector vocabsize = { devVocabSize, hostVocabSize };
- embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
+ auto params = EmbInfoParams(name, sendCount, embeddingSize, extEmbeddingSize, isSave);
+ embInfo = EmbInfo(params, vocabsize, initializeInfos);
 embInfos.emplace_back(embInfo);
 vector thresholdValues = {};
 auto hybridMgmt = Singleton::GetInstance();
 cout << "setup..." << endl;
- allRank = RankInfo(g_rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
+ allRank = RankInfo(GlogConfig::gRankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
 hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
 auto hostEmbs = make_unique();
 hostEmbs->Initialize(embInfos, seed);
@@ -172,14 +173,15 @@ TEST_F(EmbMgmtTest, Initialize_HBM)
 devVocabSize = HBM_DEVICE_SIZE;
 hostVocabSize = HBM_HOST_SIZE;
 vector vocabsize = { devVocabSize, hostVocabSize };
- embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
+ auto params = EmbInfoParams(name, sendCount, embeddingSize, extEmbeddingSize, isSave);
+ embInfo = EmbInfo(params, vocabsize, initializeInfos);
 embInfos.emplace_back(embInfo);
 vector thresholdValues;
 thresholdValues.emplace_back(name, 1, 1, 1);
 auto hybridMgmt = Singleton::GetInstance();
 cout << "setup..." << endl;
- allRank = RankInfo(g_rankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
+ allRank = RankInfo(GlogConfig::gRankId, deviceId, localRankSize, useStatic, nBatch, maxStep);
 hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
 hybridMgmt->Destroy();
@@ -192,14 +194,15 @@ TEST_F(EmbMgmtTest, Evict)
 {
 size_t devVocabSize = DDR_DEVICE_SIZE;
 size_t hostVocabSize = DDR_HOST_SIZE;
 vector vocabsize = { devVocabSize, hostVocabSize };
- embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
+ auto params = EmbInfoParams(name, sendCount, embeddingSize, extEmbeddingSize, isSave);
+ embInfo = EmbInfo(params, vocabsize, initializeInfos);
 embInfos.emplace_back(embInfo);
 vector thresholdValues;
 thresholdValues.emplace_back(name, 1, 1, 1);
 auto hybridMgmt = Singleton::GetInstance();
 cout << "setup..." << endl;
- allRank = RankInfo(g_rankId, deviceId, localRankSize, true, nBatch, maxStep);
+ allRank = RankInfo(GlogConfig::gRankId, deviceId, localRankSize, true, nBatch, maxStep);
 hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
 // evict test, ddr
@@ -215,14 +218,15 @@ TEST_F(EmbMgmtTest, Evict_HBM)
 devVocabSize = HBM_DEVICE_SIZE;
 hostVocabSize = HBM_HOST_SIZE;
 vector vocabsize = { devVocabSize, hostVocabSize };
- embInfo = EmbInfo(name, sendCount, embeddingSize, extEmbeddingSize, isSave, vocabsize, initializeInfos);
+ auto params = EmbInfoParams(name, sendCount, embeddingSize, extEmbeddingSize, isSave);
+ embInfo = EmbInfo(params, vocabsize, initializeInfos);
 embInfos.emplace_back(embInfo);
 vector thresholdValues;
 thresholdValues.emplace_back(name, 1, 1, 1);
 auto hybridMgmt = Singleton::GetInstance();
 cout << "setup..." 
<< endl;
- allRank = RankInfo(g_rankId, deviceId, localRankSize, true, nBatch, maxStep);
+ allRank = RankInfo(GlogConfig::gRankId, deviceId, localRankSize, true, nBatch, maxStep);
 hybridMgmt->Initialize(allRank, embInfos, seed, thresholdValues, false);
 // evict test, hbm
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp index 651b13c7..45c589bf 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -71,7 +71,7 @@ protected:
 // set the global rankId, which ssdEngine uses when saving
 int workRankId;
 MPI_Comm_rank(MPI_COMM_WORLD, &workRankId);
- g_rankId = to_string(workRankId);
+ GlogConfig::gRankId = to_string(workRankId);
 cacheManager.ddrKeyFreqMap[embTableName] = cache;
 cacheManager.ddrKeyFreqMap[embTableName].PutKeys(input_keys);
diff --git a/src/tests/ssd_engine/engine_test.cpp b/src/tests/ssd_engine/engine_test.cpp index d6805f46..a09b2078 100644
--- a/src/tests/ssd_engine/engine_test.cpp
+++ b/src/tests/ssd_engine/engine_test.cpp
@@ -16,7 +16,7 @@ TEST(SSDEngine, CreateAndWriteAndReadAndAutoCompactAndSave)
 {
 int rankId;
 MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
- g_rankId = to_string(rankId);
+ GlogConfig::gRankId = to_string(rankId);
 string tbName = "test";
 vector savePath = {"."};
@@ -68,26 +68,29 @@
 // after saving, full compact will perform, old file will be deleted
 string oldDataFilePath =
- savePath.front() + "ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.data.latest";
+ savePath.front() + "ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" + tbName + "/" + "0.data.latest";
 string oldMetaFilePath =
- savePath.front() + "ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.meta.latest";
+ savePath.front() + "ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" + tbName + "/" + "0.meta.latest";
 ASSERT_EQ(fs::exists(oldDataFilePath), false);
 ASSERT_EQ(fs::exists(oldMetaFilePath), false);
 // check saved data existence
 string newDataFilePath =
- savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "1.data." + to_string(saveStep);
+ savePath.front() + "/ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" +
+ tbName + "/" + "1.data." + to_string(saveStep);
 string newMetaFilePath =
- savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "1.meta." + to_string(saveStep);
+ savePath.front() + "/ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" +
+ tbName + "/" + "1.meta." + to_string(saveStep);
 string newTableMetaFilePath =
- savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + tbName + ".meta." +
+ savePath.front() + "/ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" +
+ tbName + "/" + tbName + ".meta." 
+ to_string(saveStep); ASSERT_EQ(fs::exists(newDataFilePath), true); ASSERT_EQ(fs::exists(newMetaFilePath), true); ASSERT_EQ(fs::exists(newTableMetaFilePath), true); for (const string& p: savePath) { - fs::remove_all(p + "/ssd_sparse_model_rank_" + g_rankId); + fs::remove_all(p + "/ssd_sparse_model_rank_" + GlogConfig::gRankId); } } @@ -95,10 +98,10 @@ TEST(SSDEngine, LoadAndRead) { int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); string tbName = "test"; - vector savePath = {g_rankId}; + vector savePath = {GlogConfig::gRankId}; uint64_t maxTableSize = 100; int saveStep = 0; diff --git a/src/tests/ssd_engine/file_test.cpp b/src/tests/ssd_engine/file_test.cpp index e4458514..1b60a801 100644 --- a/src/tests/ssd_engine/file_test.cpp +++ b/src/tests/ssd_engine/file_test.cpp @@ -17,9 +17,9 @@ TEST(File, CreateEmptyFile) { int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); - string fileDir = g_rankId; + string fileDir = GlogConfig::gRankId; bool isExceptionThrown = false; try { auto f = make_shared(0, fileDir); @@ -36,9 +36,9 @@ TEST(File, LoadFromFile) // prepare int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); - string fileDir = g_rankId; + string fileDir = GlogConfig::gRankId; if (!fs::exists(fs::absolute(fileDir)) && !fs::create_directories(fs::absolute(fileDir))) { throw runtime_error("fail to create Save directory"); } @@ -85,9 +85,9 @@ TEST(File, WriteAndRead) { int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); - string savePath = g_rankId; + string savePath = GlogConfig::gRankId; auto f = make_shared(0, savePath); vector keys; @@ -113,10 +113,10 @@ TEST(File, SaveAndLoad) { int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); int saveStep = 0; - string fileDir = g_rankId; + string fileDir = GlogConfig::gRankId; auto fTmp = make_shared(0, fileDir); vector key = {0}; diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index 44b1704c..f9765f58 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -16,10 +16,10 @@ TEST(Table, WriteAndReadAndDeleteAndCompact) { int rankId; MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); string tbName = "test"; - vector savePath = {g_rankId}; + vector savePath = {GlogConfig::gRankId}; uint64_t maxTableSize = 1000000; uint64_t embDim = 240; double compactThreshold = 0.5; @@ -83,9 +83,9 @@ TEST(Table, WriteAndReadAndDeleteAndCompact) // full compact, old file will delete, valid data will move to new file tb->Compact(true); string oldDataFilePath = - savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.data.latest"; + savePath.front() + "/ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" + tbName + "/" + "0.data.latest"; string oldMetaFilePath = - savePath.front() + "/ssd_sparse_model_rank_" + g_rankId + "/" + tbName + "/" + "0.meta.latest"; + savePath.front() + "/ssd_sparse_model_rank_" + GlogConfig::gRankId + "/" + tbName + "/" + "0.meta.latest"; ASSERT_EQ(fs::exists(oldDataFilePath), false); ASSERT_EQ(fs::exists(oldMetaFilePath), false); @@ -98,10 +98,10 @@ TEST(Table, SaveAndLoad) { int rankId; 
MPI_Comm_rank(MPI_COMM_WORLD, &rankId); - g_rankId = to_string(rankId); + GlogConfig::gRankId = to_string(rankId); string tbName = "test"; - vector savePath = {g_rankId}; + vector savePath = {GlogConfig::gRankId}; uint64_t maxTableSize = 100; double compactThreshold = 0.5; int saveStep = 0; diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp index 120eb303..9aa70bd8 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -15,13 +15,13 @@ using namespace testing; TEST(Log, Format) { - string test = Log::Format("{}{}{}", 1, 2, 3); + string test = Logger::Format("{}{}{}", 1, 2, 3); EXPECT_STREQ(test.c_str(), "123"); } TEST(Log, LogLevel) { - MxRec::Log::SetLevel(Log::debug); + MxRec::Logger::SetLevel(Logger::debug); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -33,7 +33,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -45,7 +45,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::warn); + MxRec::Logger::SetLevel(Logger::warn); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -57,7 +57,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Log::SetLevel(Log::error); + MxRec::Logger::SetLevel(Logger::error); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -72,7 +72,7 @@ TEST(Log, LogLevel) TEST(Log, LayzEvalution) { - MxRec::Log::SetLevel(Log::warn); + MxRec::Logger::SetLevel(Logger::warn); testing::internal::CaptureStdout(); int flag1 = 0; int flag2 = 0; @@ -97,7 +97,7 @@ TEST(Log, LayzEvalution) TEST(Log, Basic) { - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_INFO("basictest"); std::string output = testing::internal::GetCapturedStdout(); @@ -106,7 +106,7 @@ TEST(Log, Basic) TEST(Log, TooManyArgs1) { - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_INFO("{} {} {}", 0.1f, 'h', 'e', "llow"); std::string output = testing::internal::GetCapturedStdout(); @@ -116,7 +116,7 @@ TEST(Log, TooManyArgs1) TEST(Log, TooManyArgs2) { - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_INFO("{}", "h", "h", "h", "h", "h", "h", "h"); std::string output = testing::internal::GetCapturedStdout(); @@ -126,7 +126,7 @@ TEST(Log, TooManyArgs2) TEST(Log, FewArgs) { - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_INFO("{} {} {} {} {} {}", "hellow", "hellow"); std::string output = testing::internal::GetCapturedStdout(); @@ -136,7 +136,7 @@ TEST(Log, FewArgs) TEST(Log, CkptType) { - MxRec::Log::SetLevel(Log::info); + MxRec::Logger::SetLevel(Logger::info); testing::internal::CaptureStdout(); LOG_INFO("ckpt type={}", CkptDataType::EMB_DATA); std::string output = testing::internal::GetCapturedStdout(); -- Gitee From 
60a8afce114e5f97c636ad4f2591dfd863f3ae69 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 25 Sep 2023 09:45:35 +0800
Subject: [PATCH 365/551] Match-id-b1ba931b466a30d7e2efe1baf4211ffd19f029a4

---
 src/core/utils/common.h | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/core/utils/common.h b/src/core/utils/common.h index afdf7252..47d1a8ff 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -130,14 +130,14 @@ namespace MxRec {
 inline int GetUBSize(int devID)
 {
- std::map chipUbSizeList = {{"910A", UBSize::ASCEND910_A},
- {"910B", UBSize::ASCEND910_B},
- {"920A", UBSize::ASCEND920_A},
- {"910B1", UBSize::ASCEND910_B1},
- {"910B2", UBSize::ASCEND910_B2},
- {"910B3", UBSize::ASCEND910_B3},
- {"910B4", UBSize::ASCEND910_B4}};
- std::map::const_iterator it = chipUbSizeList.find(GetChipName(devID));
+ const std::map chipUbSizeList = {{"910A", UBSize::ASCEND910_A},
+ {"910B", UBSize::ASCEND910_B},
+ {"920A", UBSize::ASCEND920_A},
+ {"910B1", UBSize::ASCEND910_B1},
+ {"910B2", UBSize::ASCEND910_B2},
+ {"910B3", UBSize::ASCEND910_B3},
+ {"910B4", UBSize::ASCEND910_B4}};
+ auto it = chipUbSizeList.find(GetChipName(devID));
 if (it != chipUbSizeList.end()) {
 return it->second;
 }
-- Gitee
From 81180e2ad81ef052da1eeac4e62b125426ff79c1 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 25 Sep 2023 09:59:16 +0800
Subject: [PATCH 366/551] Match-id-5638ce0ac3fe1f64a44b985ce644f0357ac799cd

---
 mx_rec/saver/saver.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 5d7feb47..b3e67497 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -206,10 +206,6 @@ class Saver(object):
     @performance("_save")
     def _save(self, sess, root_dir):
         result = self.save_op_dict
-        if is_asc_manager_initialized() and not self.save_easy_mode:
-            save_host_data(root_dir)
-            logger.debug(f"host data was saved.")
-
         threads = []
         for table_name in result.keys():
             thread = SaveModelThread(sess, result, root_dir, table_name)
@@ -221,6 +217,10 @@
         for thread in threads:
             thread.join()
+        if is_asc_manager_initialized() and not self.save_easy_mode:
+            save_host_data(root_dir)
+            logger.debug(f"host data was saved.")
+
     def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name):
         host_data = get_host_data(table_name)
         key = np.array(list(host_data.keys()))
-- Gitee
From fc53b047ebf91e3ca41381ccbbc7ba571dcea649 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 25 Sep 2023 12:28:52 +0800
Subject: [PATCH 367/551] Match-id-2ae7e6d23278490961ce1adb40f40cf3ff5192ce

---
 src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 7f530a64..b11f8d44 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp
@@ -16,6 +16,11 @@ using namespace MxRec;
 /// \param channelId train 0 eval 1
 void HybridMgmtBlock::CheckAndSetBlock(int channelId)
 {
+ // when hybridBatchId is 0 there are only two cases, program startup and Reset, and neither should block
+ if (hybridBatchId[channelId] == 0) {
+ return;
+ }
+
 // check the blocking conditions during save
 // on the training channel, when the save interval is neither 0 nor -1 (no blocking needed) and the step that requires blocking has been reached
 if (channelId == TRAIN_CHANNEL_ID && saveInterval != 0 &&
@@ -127,7 +132,7 @@ void HybridMgmtBlock::DoBlock(int channelId)
 LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt starts blocking channelId {} hybridBatchId {}", channelId,
 hybridBatchId[channelId]);
- while (isBlock[channelId]) {
+ while (isBlock[channelId] and !rankInfo.noDDR) {
 std::this_thread::sleep_for(SLEEP_MS);
 if (!isRunning) {
 return;
-- Gitee
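The control flow patched above is easier to follow outside the diff. Below is a minimal, self-contained Python sketch of the same guard logic; all names here (do_block, is_block, no_ddr) are illustrative stand-ins rather than the repo's C++ API. It shows the two behaviours the hunk introduces: batch id 0, which only occurs at startup or after a Reset, never blocks, and ranks running without DDR skip the polling loop entirely.

    import time

    SLEEP_MS = 0.005  # stand-in for the C++ SLEEP_MS poll interval

    def do_block(channel_id, hybrid_batch_id, is_block, is_running, no_ddr):
        """Poll until the channel is unblocked, mirroring the patched flow."""
        # batch id 0 happens only at program startup or after a Reset;
        # neither case should ever block (the early return added above)
        if hybrid_batch_id[channel_id] == 0:
            return
        # ranks configured without DDR never enter the wait loop
        while is_block[channel_id] and not no_ddr:
            time.sleep(SLEEP_MS)
            if not is_running():  # service shutting down, stop waiting
                return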
From 2abff3670e2e836e8ac4624db87f2741a5b8bccb Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 22 Sep 2023 17:02:21 +0800
Subject: [PATCH 368/551] Match-id-deff982a7fef823b2a725a2a646e6508a9f4e67a

---
 mx_rec/saver/saver.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 5d7feb47..87cbb222 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -50,6 +50,7 @@ class Saver(object):
         self.placeholder_dict = defaultdict(dict)
         # save_easy_mode : only save the embedding and key data of sparse tables
         self.save_easy_mode = (global_env.save_easy == Flag.TRUE.value)
+        self._last_checkpoints = []
         self.build()

     @staticmethod
@@ -121,6 +122,13 @@
         logger.info("rank id %s | Saving_path '%s' has been made.", self.rank_id, saving_path)
         self._save(sess, saving_path)
+        if self.max_to_keep:
+            self._last_checkpoints.append(saving_path)
+            if len(self._last_checkpoints) > self.max_to_keep:
+                logger.info("checkpoints num %d > max_to_keep %d, deleting %s",
+                            len(self._last_checkpoints), self.max_to_keep,
+                            self._last_checkpoints[0])
+                tf.io.gfile.rmtree(self._last_checkpoints.pop(0))
         logger.info("sparse model was saved in dir '%s' .", saving_path)
         logger.info("======== Saving finished for rank id %s ========", self.rank_id)
-- Gitee
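The max_to_keep bookkeeping added in the patch above is compact enough to demonstrate standalone. Here is a hedged, self-contained Python sketch of the same rotation behaviour; the CheckpointRotator class and its record method are hypothetical names introduced for illustration, and shutil stands in for the tf.io.gfile.rmtree call the real Saver uses.

    import shutil

    class CheckpointRotator:
        """Keep at most max_to_keep checkpoint directories, deleting the oldest first."""

        def __init__(self, max_to_keep):
            self.max_to_keep = max_to_keep
            self._last_checkpoints = []

        def record(self, saving_path):
            if not self.max_to_keep:  # 0/None disables rotation, everything is kept
                return
            self._last_checkpoints.append(saving_path)
            if len(self._last_checkpoints) > self.max_to_keep:
                oldest = self._last_checkpoints.pop(0)  # same pop(0) as the hunk
                shutil.rmtree(oldest, ignore_errors=True)

With max_to_keep = 3, recording a fourth saving_path removes the first directory, which matches the hunk's oldest-first deletion order.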
+        for (auto& map : keyOffsetMap) {
+            EvictDeleteDeviceEmb(map.first, keys);
+        }
+    }
+    for (auto map : evictPosMap) {
+        // 初始化 dev
+        EvictInitDeviceEmb(map.first, map.second);
+    }
+}
+
 void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector& keys)
 {
     EASY_FUNCTION(profiler::colors::Blue600)

diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index 767fe387..a15bbc71 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -108,6 +108,8 @@ namespace MxRec {
 
         void EvictKeys(const string& embName, const vector& keys);
 
+        void EvictKeysCombine(const vector& keys);
+
         void SetupHotEmbUpdateStep();
 
         template
-- 
Gitee

From d3ebce6ddb7d05e324df9f646513dc20f86937a8 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 25 Sep 2023 17:10:45 +0800
Subject: [PATCH 370/551] Match-id-8bb5a6e79220c32703b65e70321cd641c422acb9

---
 mx_rec/constants/constants.py |  4 ++++
 mx_rec/util/initialize.py     | 12 +++++++++---
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index d5eb61b7..1dc9d3bb 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -20,8 +20,12 @@ ASCEND_SPARSE_LOOKUP_HOT_POS = "ASCEND_SPARSE_LOOKUP_HOT_POS"
 ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP"
 CUSTOMIZED_OPS_LIB_PATH = "CUSTOMIZED_OPS_LIB_PATH"
 ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB"
+
 EMPTY_STR = ""
 
+# Error message used when fetching the ConfigInitializer instance fails
+GET_CONFIG_INSTANCE_ERR_MSG = "Please init the environment for mx_rec at first."
+
 # Anchor name used to locate the dataset in the computation graph under automatic graph-modification mode
 ANCHOR_DATASET_NAME = "PrefetchDataset"
 
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 6903f65d..7b11d836 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -11,7 +11,8 @@ import psutil
 
 import mx_rec.constants.constants
 from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH, \
-    TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_INT32, TFDevice, Flag
+    TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_INT32, TFDevice, Flag, \
+    GET_CONFIG_INSTANCE_ERR_MSG
 from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json
 from mx_rec.util.ops import import_host_pipeline_ops
 from mx_rec.validator.validator import StringValidator, FileValidator, para_checker_decorator, ClassValidator, \
@@ -228,7 +229,7 @@ class ConfigInitializer:
     @staticmethod
     def get_instance():
         if ConfigInitializer._single_instance is None:
-            raise EnvironmentError("Please init the environment for mx_rec at first.")
+            raise EnvironmentError(GET_CONFIG_INSTANCE_ERR_MSG)
 
         return ConfigInitializer._single_instance
 
@@ -734,7 +735,12 @@ def get_use_dynamic_expansion():
 
 
 def terminate_config_initializer():
-    ConfigInitializer.get_instance().terminate()
+    try:
+        ConfigInitializer.get_instance().terminate()
+    except EnvironmentError as err:
+        if GET_CONFIG_INSTANCE_ERR_MSG not in str(err):
+            raise err
+        logger.warning(GET_CONFIG_INSTANCE_ERR_MSG)
 
 
 def get_name_to_var_dict():
-- 
Gitee

From 97692dfe56cda43c51d858dce0f97b35109e713a Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 25 Sep 2023 19:43:31 +0800
Subject: [PATCH 371/551] Match-id-5d8dd6b78e79934c8443d000b68baeae0b3a410e

---
 mx_rec/core/embedding.py     | 16 ++++++++--------
 mx_rec/graph/merge_lookup.py |  4 ++--
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index
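PATCH 370 softens teardown: terminate_config_initializer() now swallows only the specific "not initialized" error rather than failing when the singleton was never created. The pattern reduced to its essentials (illustrative names, not the real ConfigInitializer):

    # Editorial sketch of a guarded singleton teardown.
    class Config:
        _instance = None

        @classmethod
        def get_instance(cls):
            if cls._instance is None:
                raise EnvironmentError("Please init the environment for mx_rec at first.")
            return cls._instance

    def terminate():
        try:
            Config.get_instance().terminate()
        except EnvironmentError:
            pass  # nothing to tear down; the real code logs a warning here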
43a4a863..6f4c7c59 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -133,7 +133,7 @@ class SparseEmbedding: self.lookup_info = set() self.lookup_result = dict() self.use_dynamic_expansion = get_use_dynamic_expansion() - self.lookup_name_list = [] + self.lookup_name_dict = {True: [], False: []} self.modify_graph = False self.init_param = config.get("init_param") self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op")) @@ -335,8 +335,8 @@ class SparseEmbedding: SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train") SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec - def check_multi_lookup_times(self): - lookup_times = len(self.lookup_name_list) if self.modify_graph else len(self.lookup_result) + def check_multi_lookup_times(self, is_training): + lookup_times = len(self.lookup_name_dict.get(is_training)) if self.modify_graph else len(self.lookup_result) if not self.modify_graph and get_training_mode_channel_id(True) is not None and \ get_training_mode_channel_id(False) is not None: lookup_times = int(lookup_times / 2) @@ -447,11 +447,11 @@ class SparseEmbedding: # record multi lookup info ids_lookup_name = feature_spec.name + "_lookup_ids" - # set in train mode, train and eval mode, eval mode - if is_training or eval_mode: - self.lookup_name_list.append(ids_lookup_name) + if self.lookup_name_dict.get(is_training) is None: + self.lookup_name_dict[is_training] = [] + self.lookup_name_dict.get(is_training).append(ids_lookup_name) self.modify_graph = kwargs.get("modify_graph", True) - self.check_multi_lookup_times() + self.check_multi_lookup_times(is_training) # return the stub tensor of the lookup result if not get_use_static(): @@ -540,7 +540,7 @@ class SparseEmbedding: is_training) if not self.modify_graph: - self.check_multi_lookup_times() + self.check_multi_lookup_times(is_training) return self.lookup_result.get(spec_name).get(is_training) def split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list, diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 7a88e085..3e0df2c0 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -51,7 +51,7 @@ def do_merge_lookup(is_train: bool = True): continue table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) - if not get_use_static() and len(table_instance.lookup_name_list) > 1: + if not get_use_static() and len(table_instance.lookup_name_dict.get(is_train)) > 1: feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) feature_spec_name_ids_dict[feature_spec.name] = cutting_point if sub_cutting_points_dict.get(is_training) is None: @@ -66,7 +66,7 @@ def do_merge_lookup(is_train: bool = True): for cutting_point in sub_cutting_point_list: table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) - if len(table_instance.lookup_name_list) == 1: + if len(table_instance.lookup_name_dict.get(is_train)) == 1: logger.debug("The origin lookup result of %s for %s does not need to be replaced.", feature_spec.name, table_instance.table_name) continue -- Gitee From 74b4ba62de200114b83988513b1e431be88dd3af Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 26 Sep 2023 19:48:38 +0800 Subject: [PATCH 372/551] 
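PATCH 371 above replaces a single lookup-name list with one list per training mode, so train and eval lookups against the same table no longer share a counter. A minimal model of that bookkeeping (illustrative names only):

    # Editorial sketch: lookup names tracked per mode, as in lookup_name_dict.
    from collections import defaultdict

    class LookupRegistry:
        def __init__(self):
            # is_training (True/False) -> list of lookup names
            self._names = defaultdict(list)

        def record(self, name, is_training):
            self._names[is_training].append(name)

        def lookup_times(self, is_training):
            return len(self._names[is_training])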
Match-id-dc31a6af873caabbecb87b54405e38bf3e29eaf7 --- mx_rec/core/embedding.py | 6 +++++- mx_rec/util/initialize.py | 8 ++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 6f4c7c59..dea3fc5d 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -485,6 +485,8 @@ class SparseEmbedding: raise RuntimeError("When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") table_name = feature_spec.table_name same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training] + logger.debug("The feature spec of the same table is %s, table name is %s.", + [fs.name for fs in same_table_feature_spec], self.table_name) same_table_spec_count = len(same_table_feature_spec) if same_table_spec_count == 0: raise RuntimeError(f"spec_name {spec_name} not in table {table_name}.") @@ -535,9 +537,11 @@ class SparseEmbedding: kwargs["multi_lookup"] = True total_send_count = self.same_table_send_count if self.modify_graph else send_count * same_table_spec_count lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, total_send_count, **kwargs) - logger.debug("lookup table %s via %s", table_name, tensor_split_list) + logger.debug("multi lookup table %s via %s.", table_name, tensor_split_list) self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, is_training) + # 当一表多查完成后,将此表对应的feature specs列表清空,便于estimator模式下多轮eval时不会累加上轮eval的feature specs + ConfigInitializer.get_instance().clear_same_table_feature_spec(self.table_name, is_training) if not self.modify_graph: self.check_multi_lookup_times(is_training) diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 7b11d836..23ee3866 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -253,6 +253,14 @@ class ConfigInitializer: self._is_terminated = True ConfigInitializer._single_instance = None + def clear_same_table_feature_spec(self, table_name, is_training): + if self.table_name_to_feature_spec.get(table_name) is None or \ + self.table_name_to_feature_spec.get(table_name).get(is_training) is None: + raise KeyError("The table name `%s` does not exist in table_name_to_feature_spec, " + "please check whether the insert_feature_spec(...) 
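clear_same_table_feature_spec above resets one table's spec list per mode so that repeated estimator eval rounds do not accumulate stale feature specs from the previous round. A reduced model of the registry it operates on (illustrative names, not the real ConfigInitializer):

    # Editorial sketch of a table -> {is_training: [spec, ...]} registry.
    class SpecRegistry:
        def __init__(self):
            self._by_table = {}

        def insert(self, table_name, spec, is_training):
            modes = self._by_table.setdefault(table_name, {True: [], False: []})
            modes[is_training].append(spec)

        def clear(self, table_name, is_training):
            if table_name not in self._by_table:
                raise KeyError(f"table {table_name!r} has no registered feature specs")
            self._by_table[table_name][is_training] = []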
is invoked.", table_name) + self.table_name_to_feature_spec.get(table_name)[is_training] = [] + logger.debug("The feature spec of the table name `%s` has been cleared.", table_name) + def insert_feature_spec(self, feature, is_training): self._feature_spec_dict[feature.name] = feature if feature.table_name not in self._table_name_to_feature_spec: -- Gitee From d96aa215442c0e0c1948f4a1233a2660e0a7ee10 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 27 Sep 2023 14:26:30 +0800 Subject: [PATCH 373/551] Match-id-20e0ff538f209a86594e0d53cf4f6c5b4a62d837 --- mx_rec/core/embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index dea3fc5d..4109ca7a 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -48,7 +48,7 @@ from mx_rec.util.log import logger ("init_param", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("all2all_gradients_op", OptionValidator, {"options": [i.value for i in list(All2allGradientsOp)]}), ("value_dtype", OptionValidator, {"options": [tf.float32]}), - ("shard_num", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), + ("shard_num", IntValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), ("fusion_optimizer_var", ClassValidator, {"classes": (bool, )}), ("hashtable_threshold", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]) ]) -- Gitee From 88a10e92a23542567adb970b209fc435ec94482c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 27 Sep 2023 22:49:46 +0800 Subject: [PATCH 374/551] Match-id-543512ef359715ab60da7222af43c6de0f1aef95 --- src/core/key_process/key_process.cpp | 56 ++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 34828e4e..e80d6782 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -473,8 +473,11 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, in auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 vector countRecv; countRecv.resize(rs.back() + rc.back()); - MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), - rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); + auto retCode = MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), + rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode); + } LOG_DEBUG("getCountRecvTC(ms)(with-all2all):{}", getCountRecvTC.ElapsedMS()); return countRecv; } @@ -523,7 +526,10 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) if (!isRunning) { // 通信终止信号,同步退出,防止线程卡住 int exitFlag = isRunning; - MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); + } throw EndRunExit("GetBatchData end run."); } } @@ -678,13 +684,18 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, 
-                  rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]);
-
+    auto retCode = MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
+                                 all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]);
+    if (retCode != MPI_SUCCESS) {
+        LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
+    }
     all2AllInfoOut.countRecv.resize(rs.back() + rc.back());
     if (isWithFAAE) {
-        MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, all2AllInfoOut.countRecv.data(),
-                      rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]);
+        retCode = MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T,
+                                all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]);
+        if (retCode != MPI_SUCCESS) {
+            LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
+        }
     }
     LOG_DEBUG("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS());
     EASY_END_BLOCK
@@ -736,8 +747,11 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id,
 
     EASY_BLOCK("all2all")
     TimeCost uniqueAll2AllTC;
-    MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
+    auto retCode = MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
                   keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]);
+    if (retCode != MPI_SUCCESS) {
+        LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
+    }
     LOG_DEBUG("uniqueAll2AllTC(ms):{}", uniqueAll2AllTC.ElapsedMS());
     EASY_END_BLOCK
 
@@ -980,15 +994,21 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int
         // Shutdown signal: exit in lock-step so no thread gets stuck
         TimeCost tc = TimeCost();
         int exitFlag = isRunning;
-        MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
+        auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
+        if (retCode != MPI_SUCCESS) {
+            LOG_ERROR("rank {} commId {}, MPI_Allreduce failed:{}", rankInfo.rankId, commId, retCode);
+        }
         if (exitFlag < rankInfo.rankSize) {
             throw EndRunExit("GetScAll end run.");
         }
         EASY_END_BLOCK;
         LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS());
     // allgather keyScLocal(key all2all keyScLocal = device all2all rc)
-    MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
-                  scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
+    retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
+                            scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
+    if (retCode != MPI_SUCCESS) {
+        LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode);
+    }
     LOG_DEBUG("rank {} key scAll matrix:\n{}", rankInfo.rankId, VectorToString(scAll));
     return scAll;
 }
@@ -1001,15 +1021,21 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, in
         // Shutdown signal: exit in lock-step so no thread gets stuck
         TimeCost tc = TimeCost();
         int exitFlag = isRunning;
-        MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
+        auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
+        if (retCode != MPI_SUCCESS) {
+            LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode);
+        }
         if (exitFlag < rankInfo.rankSize) {
             throw EndRunExit("GetScAll end run.");
         }
         EASY_END_BLOCK;
         LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS());
     // allgather keyScLocal(key all2all keyScLocal = device all2all rc)
-    MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
-                  scAllOut.data(), rankInfo.rankSize, MPI_INT,
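PATCH 374 repeats the same retCode check after every MPI collective. A tiny wrapper would centralize the pattern; this is an editorial sketch only, not part of the series:

    // Hypothetical helper; returns true on success, logs and returns false otherwise.
    #include <mpi.h>
    #include <cstdio>

    inline bool CheckMpi(int ret, const char* what, int rank)
    {
        if (ret != MPI_SUCCESS) {
            std::fprintf(stderr, "rank %d, %s failed: %d\n", rank, what, ret);
            return false;
        }
        return true;
    }

    // Usage: CheckMpi(MPI_Allgather(send, n, MPI_INT, recv, n, MPI_INT, comm),
    //                 "MPI_Allgather", rankId);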
comm[channel][commId]); + retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, + scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode); + } LOG_DEBUG("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, VectorToString(scAllOut)); } -- Gitee From 02030f83d3117d30c91fd0e4cdccd7dc74f4d4bc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 28 Sep 2023 11:37:44 +0800 Subject: [PATCH 375/551] Match-id-889210316bffe32b8452f978c1e5237a98c2e702 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2bbb7db8..d0253cb0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -761,7 +761,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha channelId == TRAIN_CHANNEL_ID && remainBatchOut) { vector uniqueKeys; vector restoreVecSec; - preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); + preprocess->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : -- Gitee From 3a259c68dcfb072cd4a43a2e416e057d7016a23b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 28 Sep 2023 16:07:06 +0800 Subject: [PATCH 376/551] Match-id-1c078accde32e3d36280286f79bce0aae8b47286 --- src/core/checkpoint/buffer_queue.cpp | 30 +++++++++++ src/core/checkpoint/buffer_queue.h | 30 +++++++++++ src/core/checkpoint/checkpoint.cpp | 80 +++++++++++++++++++++++----- src/core/checkpoint/checkpoint.h | 15 ++++-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- src/core/utils/common.h | 5 +- 6 files changed, 141 insertions(+), 21 deletions(-) create mode 100644 src/core/checkpoint/buffer_queue.cpp create mode 100644 src/core/checkpoint/buffer_queue.h diff --git a/src/core/checkpoint/buffer_queue.cpp b/src/core/checkpoint/buffer_queue.cpp new file mode 100644 index 00000000..ec8d9138 --- /dev/null +++ b/src/core/checkpoint/buffer_queue.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: checkpoint module + * Author: MindX SDK + * Date: 2023/9/28 + * History: NA + */ + +#include "buffer_queue.h" + +using namespace MxRec; +using namespace std; + +void BufferQueue::push(std::vector&& buffer) +{ + std::unique_lock lock(mtx); + bufferQueue.push(std::move(buffer)); + cv.notify_one(); +} + +std::vector BufferQueue::pop() +{ + std::unique_lock lock(mtx); + cv.wait(lock, [this] { + return !bufferQueue.empty(); + }); + auto buffer = std::move(bufferQueue.front()); + bufferQueue.pop(); + return buffer; +} \ No newline at end of file diff --git a/src/core/checkpoint/buffer_queue.h b/src/core/checkpoint/buffer_queue.h new file mode 100644 index 00000000..d635a452 --- /dev/null +++ b/src/core/checkpoint/buffer_queue.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ * Description: checkpoint module + * Author: MindX SDK + * Date: 2023/9/28 + * History: NA + */ + +#ifndef MXREC_BUFFER_QUEUE_H +#define MXREC_BUFFER_QUEUE_H + +#include +#include +#include +#include +#include + +namespace MxRec { + class BufferQueue { + public: + void push(std::vector&& buffer); + std::vector pop(); + private: + std::queue> bufferQueue; + std::mutex mtx; + std::condition_variable cv; + }; +} + +#endif diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index d30acb2d..36043227 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "ckpt_data_handler//emb_hash_ckpt/emb_hash_ckpt.h" #include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h" @@ -307,15 +308,18 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) { - ofstream writeFile; - writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); - - if (!writeFile.is_open()) { - LOG_DEBUG("unable to open save file: {}", dataDir); - writeFile.close(); + int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, (mode_t)0600); + if (fd == -1) { + LOG_ERROR("Error opening file for writing"); return; } + buffer.reserve(BUFFER_SIZE); + + BufferQueue queue; + + std::thread writer(&Checkpoint::WriterFn, this, std::ref(queue), fd); + int loops = 1; if (dataType == CkptDataType::EMB_DATA) { loops = static_cast(transData.floatArr.size()); @@ -331,9 +335,9 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si writeSize = dataCol; } if (floatTransSet.find(dataType) != floatTransSet.end()) { - writeFile.write(reinterpret_cast(transData.floatArr[i]) + idx, writeSize); + FillToBuffer(queue, reinterpret_cast(transData.floatArr[i]) + idx, writeSize); } else { - WriteDataset(transData, writeFile, writeSize, dataType, idx); + WriteDataset(transData, fd, writeSize, dataType, idx); } dataCol -= writeSize; @@ -341,21 +345,71 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si } } - writeFile.close(); + // After all data has been processed, check if there is any data left in the buffer + if (!buffer.empty()) { + queue.push(std::move(buffer)); + buffer.clear(); + } + + queue.push(std::vector()); + + writer.join(); + + close(fd); +} + +void Checkpoint::WriterFn(BufferQueue& queue, int fd) +{ + while (true) { + auto buffer = queue.pop(); + if (buffer.size() == 0) { + break; + } + ssize_t result = write(fd, buffer.data(), buffer.size()); + if (result != buffer.size()) { + LOG_ERROR("Error writing to file"); + } + buffer.clear(); + } } void Checkpoint::WriteDataset(CkptTransData& transData, - ofstream& writeFile, + int fd, size_t writeSize, CkptDataType dataType, size_t idx) { + ssize_t result; if (int32TransSet.find(dataType) != int32TransSet.end()) { - writeFile.write(reinterpret_cast(transData.int32Arr.data()) + idx, writeSize); + result = write(fd, (const char*)(transData.int32Arr.data()) + idx, writeSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - writeFile.write(reinterpret_cast(transData.int64Arr.data()) + idx, writeSize); + result = write(fd, (const char*)(transData.int64Arr.data()) + idx, writeSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - writeFile.write(reinterpret_cast(transData.attribute.data()) + idx, writeSize); + result = write(fd, (const 
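WriteDataset above treats any short write as an error, but POSIX write(2) may legitimately transfer fewer bytes than requested (signals, quotas, pipes). The usual remedy is a retry loop; this is an editorial sketch assuming a blocking file descriptor, not code from the series:

    // Write the full buffer, retrying on partial writes and EINTR.
    #include <cerrno>
    #include <cstddef>
    #include <unistd.h>

    bool WriteAll(int fd, const char* data, size_t size)
    {
        size_t done = 0;
        while (done < size) {
            ssize_t n = write(fd, data + done, size - done);
            if (n < 0) {
                if (errno == EINTR) {
                    continue;  // interrupted by a signal, retry
                }
                return false;  // genuine I/O error
            }
            done += static_cast<size_t>(n);
        }
        return true;
    }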
char*)(transData.attribute.data()) + idx, writeSize); + } + + if (result != writeSize) { + LOG_ERROR("Error writing to file, please check the disk buffer or temporary folder space or file permissions!"); + return; + } +} + +void Checkpoint::FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize) +{ + size_t dataIdx = 0; + while (dataIdx < dataSize) { + size_t remainingSpace = BUFFER_SIZE - buffer.size(); + if (dataSize - dataIdx <= remainingSpace) { + buffer.insert(buffer.end(), data + dataIdx, data + dataSize); + return; + } else { + buffer.insert(buffer.end(), data + dataIdx, data + dataIdx + remainingSpace); + queue.push(std::move(buffer)); + if (BUFFER_SIZE > buffer.capacity()) { + buffer.reserve(BUFFER_SIZE); + } + dataIdx += remainingSpace; + } } } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index f7ad4ccc..47ba0b7a 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -11,8 +11,9 @@ #include #include #include - +#include "utils/common.h" #include "ckpt_data_handler/ckpt_data_handler.h" +#include "buffer_queue.h" namespace MxRec { using namespace std; @@ -27,6 +28,7 @@ namespace MxRec { const vector& featureTypes); private: + std::vector buffer; const string datasetName { "slice_" }; const string dataFileType { ".data" }; const string attribFileType { ".attribute" }; @@ -81,8 +83,15 @@ namespace MxRec { void SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler); void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); - void WriteDataset(CkptTransData& transData, ofstream& writeFile, size_t writeSize, CkptDataType dataType, - size_t idx); + void FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize); + void WriteDataset(CkptTransData& transData, + int fd, + size_t writeSize, + CkptDataType dataType, + size_t idx); + + void WriterFn(BufferQueue& queue, int fd); + void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); void ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2bbb7db8..d0253cb0 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -761,7 +761,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha channelId == TRAIN_CHANNEL_ID && remainBatchOut) { vector uniqueKeys; vector restoreVecSec; - preprocess->GlobalUnique(offsetsOut, uniqueKeys, restoreVecSec); + preprocess->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUnikeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? 
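WriteStream obtains a raw file descriptor with open() and must remember to close it on every exit path. A minimal RAII guard makes the early-return cases safe; editorial sketch only, the series itself keeps the explicit close():

    // Hypothetical scope guard for a POSIX file descriptor.
    #include <unistd.h>

    class UniqueFd {
    public:
        explicit UniqueFd(int fd) : fd(fd) {}
        ~UniqueFd()
        {
            if (fd >= 0) {
                close(fd);  // released automatically on every return path
            }
        }
        UniqueFd(const UniqueFd&) = delete;
        UniqueFd& operator=(const UniqueFd&) = delete;
        int Get() const { return fd; }

    private:
        int fd;
    };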
Vec2TensorI64(uniqueKeys) : diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 47d1a8ff..0a434c3d 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -10,15 +10,11 @@ #define COMMON_H #include - #include - #include #include #include -#include #include -#include #include #include #include @@ -81,6 +77,7 @@ namespace MxRec { // validate file constexpr long long FILE_MAX_SIZE = 1LL << 40; constexpr int FILE_MIN_SIZE = 0; + constexpr size_t BUFFER_SIZE{1024 * 1024 * 64}; constexpr int KEY_PROCESS_TIMEOUT = 120; constexpr int GET_BATCH_TIMEOUT = 300; -- Gitee From 1898f820275ee723bda965f557241bff03b92010 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 28 Sep 2023 16:34:45 +0800 Subject: [PATCH 377/551] Match-id-ca35400726e940a5dd124b62b58f72f74db393dd --- src/core/host_emb/host_emb.cpp | 25 +++++++++------------- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 2 +- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index b189e8b3..de56ebdb 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -136,30 +136,25 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI EASY_FUNCTION(profiler::colors::Purple) auto updateThread = [this, missingKeysHostPos, channelId, embName] { + LOG_INFO(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); + EASY_FUNCTION(profiler::colors::Purple); + TimeCost tc = TimeCost(); auto hdTransfer = Singleton::GetInstance(); TransferChannel transferName = TransferChannel::D2H; LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - auto size = hdTransfer->RecvAcl(transferName, channelId, embName); - if (size == 0) { + const auto tensors = hdTransfer->Recv(transferName, channelId, embName); + if (tensors.empty()) { LOG_WARN(HOSTEMB + "recv empty data"); return; } - TimeCost tc = TimeCost(); - + const Tensor& d2hEmb = tensors[0]; EASY_BLOCK("Update") - auto& embData = hostEmbs[embName].embData; + const float* ptr = d2hEmb.flat().data(); auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; - auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[embName], 0); - if (aclData == nullptr) { - throw runtime_error("Acl get tensor data from dataset failed."); - } - float* ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + auto& embData = hostEmbs[embName].embData; - size_t elementSize = acltdtGetDataSizeFromItem(aclData); - size_t dimNum = acltdtGetDimNumFromItem(aclData); - LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," - " embData.size = {}, RecvAcl = {}, elementSize = {}, dimNum = {}", - embName, missingKeysHostPos.size(), embeddingSize, embData.size(), size, elementSize, dimNum); + LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}, embData.size = {}", + embName, missingKeysHostPos.size(), embeddingSize, embData.size()); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) for (size_t j = 0; j < missingKeysHostPos.size(); j++) { auto& dst = embData[missingKeysHostPos[j]]; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index b11f8d44..bde5ce80 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -132,7 +132,7 @@ void HybridMgmtBlock::DoBlock(int channelId) LOG_DEBUG(HYBRID_BLOCKING + "HybridMgmt starts blocking channelId {} 
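UpdateEmbV2 above scatters the received rows into host embedding storage under an OpenMP parallel-for. A standalone sketch of that scatter copy, with simplified types and the assumption that every destination row is already sized to dim:

    // Copy row j of the packed source into its destination slot, in parallel.
    #include <cstring>
    #include <vector>

    void ScatterRows(const float* src, std::vector<std::vector<float>>& emb,
                     const std::vector<size_t>& rows, size_t dim)
    {
    #pragma omp parallel for
        for (size_t j = 0; j < rows.size(); ++j) {
            std::memcpy(emb[rows[j]].data(), src + j * dim, dim * sizeof(float));
        }
    }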
hybridBatchId {}", channelId, hybridBatchId[channelId]); - while (isBlock[channelId] and !rankInfo.noDDR) { + while (isBlock[channelId]) { std::this_thread::sleep_for(SLEEP_MS); if (!isRunning) { return; -- Gitee From 5a5f6f5f983a83b2a57f2e3a7589c0c7f5297651 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 28 Sep 2023 17:01:12 +0800 Subject: [PATCH 378/551] Match-id-f067e3e4c7c40d9d9c13c04e585720b9eee85e29 --- .../op_kernel/embedding_lookup_by_address.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 235a9561..4296fada 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -111,7 +111,7 @@ private: for (int i = 0; i < sizes; i++) { - dataLocal = isFull ? inQueue.AllocTensor() : dataLocal; + dataLocal = inQueue.AllocTensor(); int64_t address = srcAddrLocal.GetValue(i); if (address != 0) @@ -128,18 +128,9 @@ private: } } - - nums++; - isFull = ( i == tmp_cache || i == sizes - 1); - if (isFull) - { inQueue.EnQue(dataLocal); - Compute(nums); - CopyOut(out_index, turns, nums); - nums = 0; - out_index = i + 1; - tmp_cache += cache; - } + Compute(1); + CopyOut(out_index, turns, 1); } } -- Gitee From 9baa6e0e0e41c657c709a718a8fc2a7921286077 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 6 Oct 2023 15:29:28 +0800 Subject: [PATCH 379/551] Match-id-1bb246dbab2a968beee5646c7fc21566b67b51a6 --- .../cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 4296fada..e54a8596 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -130,7 +130,7 @@ private: } inQueue.EnQue(dataLocal); Compute(1); - CopyOut(out_index, turns, 1); + CopyOut(i, turns, 1); } } -- Gitee From cdb3e2c939a9fa9818718b7dee89fd96047192a1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 6 Oct 2023 16:34:47 +0800 Subject: [PATCH 380/551] Match-id-2795ff2120491c3f892e4a1a5105937a3adcbad3 --- src/core/utils/logger.cpp | 2 +- src/core/utils/logger.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/utils/logger.cpp b/src/core/utils/logger.cpp index 3e8c5339..b12f9870 100644 --- a/src/core/utils/logger.cpp +++ b/src/core/utils/logger.cpp @@ -42,7 +42,7 @@ const char* Logger::LevelToStr(int logLevel) "ERROR", }; constexpr int levelOffset = 2; - return msg[level + levelOffset]; + return msg[logLevel + levelOffset]; } void Logger::LogUnpack(std::queue& fmt, std::stringstream &ss) diff --git a/src/core/utils/logger.h b/src/core/utils/logger.h index 68028f11..dc760e60 100644 --- a/src/core/utils/logger.h +++ b/src/core/utils/logger.h @@ -69,7 +69,7 @@ public: struct timeval tv; gettimeofday(&tv, nullptr); localtime_r(&tv.tv_sec, &t); - ss << "[MxRec][" << YEAR_BASE + t.tm_year << "/" << t.tm_mon << "/" << t.tm_mday<< " " + ss << "[MxRec][" << YEAR_BASE + t.tm_year << "/" << 1 + t.tm_mon << "/" << t.tm_mday<< " " << t.tm_hour << ":" << t.tm_min << ":" << t.tm_sec << "." << tv.tv_usec << "] [" << Logger::rank << "] ["<< Logger::LevelToStr(level) << "] [" << (strrchr(file, '/') ? 
strrchr(file, '/') + 1 : file) << ":" << line << "] "; -- Gitee From cde442a2936a05f37f20f1eb715c1afbf42738c3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 7 Oct 2023 14:06:15 +0800 Subject: [PATCH 381/551] Match-id-915a6b5785171b4796c63d99acbbfa8f209e8c16 --- mx_rec/graph/patch.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 5fab4a8c..2d5abf61 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -19,6 +19,7 @@ from tensorflow.python.distribute import reduce_util as ds_reduce_util from tensorflow.python.training.optimizer import Optimizer from tensorflow.python.client.session import BaseSession +from mx_rec.constants import constants from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \ get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook @@ -103,8 +104,6 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None): name2channel_cache[name_list_str_key] = this_channel_id return this_channel_id - # patch的方式为session增加步数属性 - steps = self.get_mxrec_steps() # patch的方式为图增加缓存属性 name2channel_cache = self.get_mxrec_name2channel_cache() @@ -118,6 +117,13 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None): if channel_id != -1: get_asc_manager().block_notify_wake(channel_id) + if channel_id == constants.EVAL_CHANNEL_ID: + # eval的时候不进行循环下沉 + steps = 1 + else: + # patch的方式为session增加步数属性 + steps = self.get_mxrec_steps() + # 调用tensorflow原生的方法 result = self.old_run_method(fetches, feed_dict, options, run_metadata) if channel_id != -1: -- Gitee From cc90170a007ebb0149657022bdff567bd232e3e1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 8 Oct 2023 11:11:09 +0800 Subject: [PATCH 382/551] Match-id-a80a0b162ae407be0c72b575ec94805da4beaa80 --- src/core/key_process/key_process.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index e80d6782..c8c7304a 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1298,11 +1298,11 @@ void KeyProcess::EvictKeysCombine(const vector& keys) // hbm LOG_INFO(KEY_PROCESS "hbm combine funEvictCall, keySize:{}", keys.size()); // 删除映射关系 if (keys.size() != 0) { - for (auto& map : keyOffsetMap) { + for (const auto& map : keyOffsetMap) { EvictDeleteDeviceEmb(map.first, keys); } } - for (auto map : evictPosMap) { + for (const auto map : evictPosMap) { // 初始化 dev EvictInitDeviceEmb(map.first, map.second); } -- Gitee From 92e44621ca0e193845fc4c4b896a4d226a1e3908 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 9 Oct 2023 15:21:18 +0800 Subject: [PATCH 383/551] Match-id-3d98c66cc6350274617ee8aa5f284ace3561309e --- src/core/checkpoint/buffer_queue.cpp | 7 +++--- src/core/checkpoint/buffer_queue.h | 5 ++--- src/core/checkpoint/checkpoint.cpp | 32 ++++++++++++++-------------- src/core/checkpoint/checkpoint.h | 1 + 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/core/checkpoint/buffer_queue.cpp b/src/core/checkpoint/buffer_queue.cpp index ec8d9138..0d289359 100644 --- a/src/core/checkpoint/buffer_queue.cpp +++ b/src/core/checkpoint/buffer_queue.cpp @@ -11,20 +11,19 @@ using namespace MxRec; using namespace std; -void BufferQueue::push(std::vector&& buffer) +void BufferQueue::Push(std::vector &&buffer) { std::unique_lock 
lock(mtx); bufferQueue.push(std::move(buffer)); cv.notify_one(); } -std::vector BufferQueue::pop() +void BufferQueue::Pop(std::vector& buffer) { std::unique_lock lock(mtx); cv.wait(lock, [this] { return !bufferQueue.empty(); }); - auto buffer = std::move(bufferQueue.front()); + buffer = std::move(bufferQueue.front()); bufferQueue.pop(); - return buffer; } \ No newline at end of file diff --git a/src/core/checkpoint/buffer_queue.h b/src/core/checkpoint/buffer_queue.h index d635a452..cf38dff8 100644 --- a/src/core/checkpoint/buffer_queue.h +++ b/src/core/checkpoint/buffer_queue.h @@ -12,14 +12,13 @@ #include #include #include -#include #include namespace MxRec { class BufferQueue { public: - void push(std::vector&& buffer); - std::vector pop(); + void Push(std::vector &&buffer); + void Pop(std::vector& buffer); private: std::queue> bufferQueue; std::mutex mtx; diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 36043227..3cab55dc 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -308,7 +308,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) { - int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, (mode_t)0600); + int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, static_cast(0600)); if (fd == -1) { LOG_ERROR("Error opening file for writing"); return; @@ -347,11 +347,11 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si // After all data has been processed, check if there is any data left in the buffer if (!buffer.empty()) { - queue.push(std::move(buffer)); + queue.Push(std::move(buffer)); buffer.clear(); } - queue.push(std::vector()); + queue.Push(std::vector()); writer.join(); @@ -361,15 +361,15 @@ void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, si void Checkpoint::WriterFn(BufferQueue& queue, int fd) { while (true) { - auto buffer = queue.pop(); - if (buffer.size() == 0) { + queue.Pop(writeBuffer); + if (writeBuffer.size() == 0) { break; } - ssize_t result = write(fd, buffer.data(), buffer.size()); - if (result != buffer.size()) { + ssize_t result = write(fd, writeBuffer.data(), writeBuffer.size()); + if (result != writeBuffer.size()) { LOG_ERROR("Error writing to file"); } - buffer.clear(); + writeBuffer.clear(); } } @@ -379,13 +379,13 @@ void Checkpoint::WriteDataset(CkptTransData& transData, CkptDataType dataType, size_t idx) { - ssize_t result; + ssize_t result = 0; if (int32TransSet.find(dataType) != int32TransSet.end()) { - result = write(fd, (const char*)(transData.int32Arr.data()) + idx, writeSize); + result = write(fd, reinterpret_cast(transData.int32Arr.data()) + idx, writeSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - result = write(fd, (const char*)(transData.int64Arr.data()) + idx, writeSize); + result = write(fd, reinterpret_cast(transData.int64Arr.data()) + idx, writeSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - result = write(fd, (const char*)(transData.attribute.data()) + idx, writeSize); + result = write(fd, reinterpret_cast(transData.attribute.data()) + idx, writeSize); } if (result != writeSize) { @@ -400,12 +400,12 @@ void Checkpoint::FillToBuffer(BufferQueue& queue, const char* data, size_t dataS while (dataIdx < dataSize) { size_t remainingSpace = BUFFER_SIZE - buffer.size(); if (dataSize - dataIdx <= remainingSpace) { - 
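The Push(std::move(buffer)) handoff above leaves the producer's vector in a valid but unspecified state, which is why the code clears it and re-reserves BUFFER_SIZE before the next fill. A condensed illustration, assuming the BufferQueue declared in the patches above:

    // Hand the filled buffer to the consumer and make the producer's
    // vector ready for the next chunk without reallocating later.
    #include <utility>
    #include <vector>
    #include "buffer_queue.h"  // BufferQueue from the series above

    void Handoff(std::vector<char>& producerBuf, MxRec::BufferQueue& queue, size_t capacity)
    {
        queue.Push(std::move(producerBuf));  // producerBuf is now unspecified-but-valid
        producerBuf.clear();                 // make its state definite again
        producerBuf.reserve(capacity);       // restore capacity up front
    }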
buffer.insert(buffer.end(), data + dataIdx, data + dataSize); + buffer.insert(buffer.cend(), data + dataIdx, data + dataSize); return; } else { - buffer.insert(buffer.end(), data + dataIdx, data + dataIdx + remainingSpace); - queue.push(std::move(buffer)); - if (BUFFER_SIZE > buffer.capacity()) { + buffer.insert(buffer.cend(), data + dataIdx, data + dataIdx + remainingSpace); + queue.Push(std::move(buffer)); + if (buffer.capacity() < BUFFER_SIZE) { buffer.reserve(BUFFER_SIZE); } dataIdx += remainingSpace; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 47ba0b7a..d4a4ac2c 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -29,6 +29,7 @@ namespace MxRec { private: std::vector buffer; + std::vector writeBuffer; const string datasetName { "slice_" }; const string dataFileType { ".data" }; const string attribFileType { ".attribute" }; -- Gitee From a738835dcad837af4179acbf6fae7d1b4cf8d0f4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 9 Oct 2023 15:34:53 +0800 Subject: [PATCH 384/551] Match-id-3606f3451d845eb841906f08b3854f4be264b278 --- src/core/emb_table/emb_table.cpp | 1 + src/core/ssd_engine/table.cpp | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 85fde413..cd8b112d 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -61,6 +61,7 @@ EmbTable::~EmbTable() if (ret != ACL_SUCCESS) { LOG_ERROR("aclrtFree failed, ret={}", ret); } + block = nullptr; } #endif } diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index e3627059..a5f1c546 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -4,7 +4,6 @@ #include "table.h" -#include "utils/common.h" using namespace MxRec; @@ -271,7 +270,10 @@ vector> Table::FetchEmbeddingsInner(vector &keys) size_t dLen = keys.size(); unordered_map, shared_ptr, vector>>> miniBatch; for (size_t i = 0; i < dLen; ++i) { - auto it = keyToFile.find(keys[i]); + auto it = as_const(keyToFile).find(keys[i]); + if (it == keyToFile.end()) { + throw invalid_argument(StringFormat("failed to find the key, {key=%d} not exist!", keys[i])); + } if (miniBatch[it->second] == nullptr) { miniBatch[it->second] = make_shared, vector>>(); } -- Gitee From 701ccbdbf6ee8b9434f342be3e57bd7871c896d8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 9 Oct 2023 20:30:48 +0800 Subject: [PATCH 385/551] Match-id-7dbde16e1b65299665cf62604aa95b1d19ff595e --- .../op_host/embedding_lookup_by_address.cpp | 60 +++++++++++++----- .../op_host/embedding_update_by_address.cpp | 62 +++++++++++++------ 2 files changed, 86 insertions(+), 36 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 18e3f295..52ca55f4 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -4,20 +4,30 @@ namespace optiling { + + template + static ge::graphStatus CheckNullPointer(T *pointer, const char *errorMessage) + { + if (pointer == nullptr) { + printf("%s nullptr\n", errorMessage); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; + } + static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData1 tiling; size_t usrSize = 256; size_t sysWorkspaceSize = 16 * 1024 * 1024; - if (context == nullptr) { - printf("Tiling context nullptr\n"); + if 
(CheckNullPointer(context, "Tiling context") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } size_t *currentWorkspace = context->GetWorkspaceSizes(1); - if (currentWorkspace == nullptr) { - printf("currentWorkspace nullptr\n"); + if (CheckNullPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -27,8 +37,7 @@ namespace optiling int32_t ub_limit = 175 * 1024; auto *attrs = context->GetAttrs(); const auto *attr0_value = attrs->GetAttrPointer(0); - if (attr0_value == nullptr) { - printf(" Lookup embbeding_type attr0_value nullptr\n"); + if (CheckNullPointer(attr0_value, " Lookup embbeding_type attr0_value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -39,14 +48,18 @@ namespace optiling } const auto *attr1_value = attrs->GetAttrPointer(1); - if (attr1_value == nullptr) { - printf(" Lookup embbeding_type attr1_value nullptr\n"); + if (CheckNullPointer(attr1_value, "Lookup embbeding_type attr1_value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } int32_t embbeding_type = *attr1_value; - int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); + auto inputTensor = context->GetInputTensor(0); + if (CheckNullPointer(inputTensor, "inputTensor") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + + int32_t input_shape = inputTensor->GetShapeSize(); tiling.set_embbeding_type(embbeding_type); tiling.set_update_dim(embbeding_dim); @@ -67,10 +80,17 @@ namespace ge { gert::Shape *y_shape = context->GetOutputShape(0); + if (optiling::CheckNullPointer(y_shape, "y_shape") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + auto *attrs = context->GetAttrs(); + if (optiling::CheckNullPointer(attrs, "attrs") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + const auto *attr0_value = attrs->GetAttrPointer(0); - if (attr0_value == nullptr) { - printf(" Lookup embbeding_type attr0_value nullptr\n"); + if (optiling::CheckNullPointer(attr0_value, "Lookup embbeding_type attr0_value") != ge::GRAPH_SUCCESS) { return GRAPH_FAILED; } @@ -86,15 +106,21 @@ namespace ge { int64_t embbeding_type; + if (optiling::CheckNullPointer(context, "context") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + auto *attrs = context->GetAttrs(); - const auto *attr1_value = attrs->GetAttrPointer(1); - if (attr1_value == nullptr) { - printf(" Lookup embbeding_type nullptr\n"); + if (optiling::CheckNullPointer(attrs, "attrs") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; } - else - { - embbeding_type = *attr1_value; + + const auto *attr1_value = attrs->GetAttrPointer(1); + if (optiling::CheckNullPointer(attr1_value, "Lookup embbeding_type") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; } + + embbeding_type = *attr1_value; if (embbeding_type == 0) { context->SetOutputDataType(0, ge::DataType(DT_INT32)); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index b75a5912..0d82535e 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -5,41 +5,67 @@ namespace optiling { + template + static ge::graphStatus CheckPointer(T *pointer, const char *errorMessage) + { + if (pointer == nullptr) { + printf("%s nullptr\n", errorMessage); + return ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; + } + + static ge::graphStatus CheckPositiveInt(int32_t value, const char *errorMessage) + { + if (value <= 0) { + printf("%s must larger than 0\n", errorMessage); + return 
ge::GRAPH_FAILED; + } + + return ge::GRAPH_SUCCESS; + } + static ge::graphStatus TilingFunc(gert::TilingContext *context) { TilingData2 tiling; size_t usrSize = 256, sysWorkspaceSize = 16 * 1024 * 1024; - if (context == nullptr) { - printf("Update embbeding_type context nullptr\n"); + if (CheckPointer(context, "Update embbeding_type context") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - } size_t *currentWorkspace = context->GetWorkspaceSizes(1); - if (currentWorkspace == nullptr) { - printf("currentWorkspace nullptr\n"); + if (CheckPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - } currentWorkspace[0] = sysWorkspaceSize + usrSize; int32_t block_total_nums = 48; int32_t ub_limit = 175 * 1024; int32_t update_dim, embbeding_type; - int32_t input_shape = context->GetInputTensor(0)->GetShapeSize(); - if (input_shape <= 0) { - printf("input_shape must larger than 0\n"); + auto inputTensor = context->GetInputTensor(0); + if (CheckPointer(inputTensor, "GetInputTensor inputTensor") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - } - int32_t input_dim = context->GetInputTensor(1)->GetShapeSize() / input_shape; - if (context->GetAttrs()->GetAttrPointer(0) == nullptr) { - printf("context GetAttrs GetAttrPointer nullptr\n"); + int32_t input_shape = inputTensor->GetShapeSize(); + if (CheckPositiveInt(input_shape, "input_shape") != ge::GRAPH_SUCCESS) + return ge::GRAPH_FAILED; + + auto inputTensor1 = context->GetInputTensor(1); + if (CheckPointer(inputTensor1, "GetInputTensor inputTensor1") != ge::GRAPH_SUCCESS) + return ge::GRAPH_FAILED; + + int32_t input_dim = inputTensor1->GetShapeSize() / input_shape; + auto attrs = context->GetAttrs(); + if (CheckPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - } - int32_t update_type = *(context->GetAttrs()->GetAttrPointer(0)); - ge::DataType input_datatype = context->GetInputTensor(1)->GetDataType(); + auto attrPointer = attrs->GetAttrPointer(0); + if (CheckPointer(attrPointer, "attrPointer") != ge::GRAPH_SUCCESS) + return ge::GRAPH_FAILED; + + int32_t update_type = *(attrPointer); + ge::DataType input_datatype = inputTensor1->GetDataType(); if (input_datatype == ge::DT_FLOAT16) { embbeding_type = 2; } else if (input_datatype == ge::DT_INT32) { @@ -49,10 +75,8 @@ namespace optiling } update_dim = input_dim; - if (update_dim <= 0) { - printf("update_dim must larger than 0\n"); + if (CheckPositiveInt(update_dim, "update_dim") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - } tiling.set_update_type(update_type); tiling.set_embbeding_type(embbeding_type); -- Gitee From 64a8f42f623003930dc423664d2b8ab5b8cd924e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 13 Oct 2023 10:53:20 +0800 Subject: [PATCH 386/551] Match-id-fe815b40c3fbb0683488ea142afac41fa4af9daf --- build/build_all.sh | 5 ++--- build/build_tf1.sh | 3 --- build/build_tf2.sh | 2 -- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/build/build_all.sh b/build/build_all.sh index 26719538..bd33cc5f 100644 --- a/build/build_all.sh +++ b/build/build_all.sh @@ -13,13 +13,12 @@ ROOT_DIR=$(dirname "${SCRIPT_DIR}") cd "$SCRIPT_DIR" if [ "$(uname -m)" = "aarch64" ] then - virtualenv -p "$(which python3.7)" tf2_env source tf2_env/bin/activate tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" [ ! 
-f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env - virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ @@ -34,7 +33,7 @@ then [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env - virtualenv -p "$(which python3.7)" tf1_env + source tf1_env/bin/activate tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 6d7deabf..d2bb5b42 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -14,7 +14,6 @@ cd "$SCRIPT_DIR" if [ "$(uname -m)" = "x86_64" ] then - virtualenv -p "$(which python3.7)" tf1_env source /opt/buildtools/tf1_env/bin/activate tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env @@ -144,8 +143,6 @@ then compile_acc_ctr_so_file echo "-----Build Start tf1 -----" - virtualenv -p "$(which python3.7)" tf1_env - echo "--tf1 env ${env}---" source /opt/buildtools/tf1_env/bin/activate compile_so_file "${tf1_path}" collect_so_file diff --git a/build/build_tf2.sh b/build/build_tf2.sh index e3d01417..deebfe76 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -14,7 +14,6 @@ cd "$SCRIPT_DIR" if [ "$(uname -m)" = "x86_64" ] then - virtualenv -p "$(which python3.7)" tf2_env source /opt/buildtools/tf2_env/bin/activate tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow deactivate tf2_env @@ -143,7 +142,6 @@ then compile_acc_ctr_so_file echo "-----Build Start tf2 -----" - virtualenv -p "$(which python3.7)" tf2_env source /opt/buildtools/tf2_env/bin/activate compile_so_file "${tf2_path}" collect_so_file -- Gitee From b666f19c2fe3aaeb9c7cf0310535822d232f4dfc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 11 Oct 2023 11:47:55 +0800 Subject: [PATCH 387/551] Match-id-a2fdbe341b3858d62eead1f5b200eb49a834b33d --- src/CMakeLists.txt | 3 ++- src/core/ssd_engine/file.cpp | 2 ++ src/core/utils/common.cpp | 8 ++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6d1aa6de..fa4e1fa5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,6 +15,7 @@ endif () if (DEFINED PYTHON_PATH) set(PYTHON_INCLUDE_DIR ${PYTHON_PATH}/include/python3.7m) set(PYTHON_LIBRARY ${PYTHON_PATH}/lib/libpython3.7m.so) + set(PYTHON_LIB_PATH ${PYTHON_PATH}/lib) else () message("ERROR no PYTHON_PATH") endif () @@ -43,7 +44,7 @@ else () endif () set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb") set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wfatal-errors -DNDEBUG -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -s") -set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,--disable-new-dtags,--rpath=${PYTHON_LIB_PATH}") option(ENABLE_DEBUG "use debug mode" OFF) if (ENABLE_DEBUG) diff --git 
a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index fe6d9724..7e5592f0 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -20,6 +20,7 @@ File::File(uint64_t fileID, string &fileDir) : fileID(fileID), fileDir(fileDir) throw runtime_error("fail to create Save directory"); } + // latest file is temporary, unnecessary to check file existence and privilege metaFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".meta.latest"); dataFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".data.latest"); localFileMeta.open(metaFilePath, ios::out | ios::trunc | ios::binary); @@ -57,6 +58,7 @@ File::File(uint64_t fileID, string &fileDir, string &loadDir, int step) : fileID ValidateReadFile(metaFileToLoad, fs::file_size(metaFileToLoad)); ValidateReadFile(dataFileToLoad, fs::file_size(dataFileToLoad)); + // latest file is temporary, unnecessary to check file existence and privilege metaFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".meta.latest"); dataFilePath = fs::absolute(fileDir + "/" + to_string(fileID) + ".data.latest"); fs::remove(metaFilePath); diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 9bf5297e..48c3a8f4 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include @@ -21,6 +22,8 @@ using namespace std; using std::chrono::system_clock; namespace MxRec { + namespace fs = std::experimental::filesystem; + bool g_isGlogInit = false; bool GlogConfig::gStatOn = false; int GlogConfig::gGlogLevel; @@ -104,6 +107,11 @@ namespace MxRec { LOG_ERROR("the reading file size is invalid, not in range [{},{}]", FILE_MIN_SIZE, FILE_MAX_SIZE); throw invalid_argument(StringFormat("file size invalid")); } + // validate file privilege + fs::perms permissions = fs::status(dataDir).permissions(); + if ((permissions & fs::perms::owner_read) == fs::perms::none) { + throw invalid_argument(StringFormat("no read permission for file:%s", dataDir.c_str())); + } } ostream& operator<<(ostream& ss, MxRec::CkptDataType type) -- Gitee From 5b3a84b6ba23e816b0211be3061594d31b8cda00 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 24 Oct 2023 20:36:32 +0800 Subject: [PATCH 388/551] Match-id-4be60ff587cf4d1f80f05010d4857b3f3db94b86 --- .../op_kernel/embedding_lookup_by_address.cpp | 89 ++++++++++--------- .../op_kernel/embedding_update_by_address.cpp | 4 +- src/core/key_process/key_process.h | 4 +- 3 files changed, 54 insertions(+), 43 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index e54a8596..0adf5d2e 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -14,7 +14,7 @@ public: NeedComputeAddrLen = SingleCoreAddrLen; if (block_idx == block_num - 1) { - NeedComputeAddrLen = addr_nums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + NeedComputeAddrLen = addrNums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); } round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); // pipe alloc memory to queue, the unit is Bytes @@ -32,12 +32,12 @@ public: { GET_TILING_DATA(constData, tiling); // 数据的维度数 - int32_t update_dim = constData.update_dim; - int32_t embbeding_type = constData.embbeding_type; - int32_t block_total_nums = block_num; - int32_t ub_limit = constData.ub_limit; - addr_nums = constData.addr_nums; - if (embbeding_type == 
2)
+        int32_t updateDim = constData.update_dim;
+        int32_t embeddingType = constData.embbeding_type;
+        int32_t blockTotalNums = block_num;
+        int32_t ubLimit = constData.ub_limit;
+        addrNums = constData.addr_nums;
+        if (embeddingType == 2)
        {
            singleDataSize = 2;
        }
@@ -47,24 +47,24 @@ public:
        }
        // number of buffers
        PingpongNum = 1;
-        int min_move_num = 32 / singleDataSize;
-        // onceMoveNums is the number of moves needed per data dimension; (update_dim - 1 + min_move_num) / min_move_num is the division by min_move_num rounded down
-        onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num);
-        int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums;
+        int minMoveNum = 32 / singleDataSize;
+        // onceMoveNums is the number of moves needed per data dimension; (update_dim - 1 + minMoveNum) / minMoveNum is the division by minMoveNum rounded down
+        onceMoveNums = minMoveNum * ((int)(updateDim - 1 + minMoveNum) / minMoveNum);
+        int numToMove = (int32_t)(updateDim - 1 + onceMoveNums) / onceMoveNums;
        // each address takes sizeof(int64_t) bytes; singleDataSize is the byte width of one element; twice the memory is needed because every move keeps an extra copy of the data
-        int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2;
+        int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * PingpongNum * 2;
        // maximum number of addresses handled in one round; the trailing /4 then *4 aligns to 32, since sizeof(int64_t) = 8
-        int addrMaxNum = ((int)((int)(ub_limit / occupyAddressBytesNum) / 4)) * 4;
-        int singlenum = (int)(addr_nums / block_total_nums);
-        if (singlenum % 4)
+        int addrMaxNum = ((int)((int)(ubLimit / occupyAddressBytesNum) / 4)) * 4;
+        int singleNum = (int)(addrNums / blockTotalNums);
+        if (singleNum % 4)
        {
-            singlenum -= singlenum % 4;
+            singleNum -= singleNum % 4;
        }
        roundSize = addrMaxNum;
        Veclen = roundSize * singleDataSize * onceMoveNums;
-        SingleCoreAddrLen = singlenum * sizeof(int64_t);
+        SingleCoreAddrLen = singleNum * sizeof(int64_t);
        cache = roundSize;
-        dim = update_dim;
+        dim = updateDim;
    }

    __aicore__ inline void Process()
@@ -81,18 +81,18 @@
        }
    }

-        int unprocess = (NeedComputeAddrLen / sizeof(int64_t)) % roundSize;
-        if (unprocess)
+        int unProcess = (NeedComputeAddrLen / sizeof(int64_t)) % roundSize;
+        if (unProcess)
        {
            // handle an address list that is not 32-byte aligned
-            int unprocess_once_copyaddr = unprocess;
-            if (unprocess_once_copyaddr % 4 != 0)
+            int unProcessOnceCopyAddr = unProcess;
+            if (unProcessOnceCopyAddr % 4 != 0)
            {
-                unprocess_once_copyaddr += (4 - unprocess % 4);
+                unProcessOnceCopyAddr += (4 - unProcess % 4);
            }
-            DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocess_once_copyaddr);
-            MoveProcess(srcAddrLocal, round, unprocess);
+            DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unProcessOnceCopyAddr);
+            MoveProcess(srcAddrLocal, round, unProcess);
        }
    }

@@ -101,17 +101,17 @@ private:
    {
        set_flag(PIPE_MTE2, PIPE_S, 0);
        wait_flag(PIPE_MTE2, PIPE_S, 0);
-        LocalTensor<T> dataLocal;
-        bool isFull = true;
+        LocalTensor<T> dataLocal = inQueue.AllocTensor<T>();
+        bool isFull = false;
        int nums = 0;
-        int out_index = 0;
+        int outIndex = 0;
        int times = onceMoveNums / 8;
-        int tmp_cache = cache - 1;
+        int tmpCache = cache - 1;
        for (int i = 0; i < sizes; i++)
        {
-            dataLocal = inQueue.AllocTensor<T>();
+            dataLocal = isFull ?
inQueue.AllocTensor<T>() : dataLocal;
            int64_t address = srcAddrLocal.GetValue(i);
            if (address != 0)
            {
@@ -128,9 +128,18 @@
                }
            }

-            inQueue.EnQue(dataLocal);
-            Compute(1);
-            CopyOut(i, turns, 1);
+
+            nums++;
+            isFull = (i == tmpCache || i == sizes - 1);
+            if (isFull)
+            {
+                inQueue.EnQue(dataLocal);
+                Compute(nums);
+                CopyOut(outIndex, turns, nums);
+                nums = 0;
+                outIndex = i + 1;
+                tmpCache += cache;
+            }
        }
    }

@@ -140,10 +149,10 @@ private:
        LocalTensor<T> srcLocal = inQueue.DeQue<T>();
        LocalTensor<T> dstLocal = outQueue.AllocTensor<T>();

-        DataCopyParams copyparams;
-        copyparams.blockCount = 1;
-        copyparams.blockLen = onceMoveNums * sizeof(T) * nums / 32;
-        DataCopy(dstLocal, srcLocal, copyparams);
+        DataCopyParams copyParams;
+        copyParams.blockCount = 1;
+        copyParams.blockLen = onceMoveNums * sizeof(T) * nums / 32;
+        DataCopy(dstLocal, srcLocal, copyParams);

        outQueue.EnQue(dstLocal);
        inQueue.FreeTensor(srcLocal);
@@ -174,7 +183,7 @@ private:
public:
    int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, cache, Veclen, dim, PingpongNum;
-    int32_t addr_nums;
+    int32_t addrNums;
    int32_t onceMoveNums, singleDataSize, update_type;

private:
@@ -191,9 +200,9 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres
{
    GET_TILING_DATA(constData, tiling);
-    int32_t embbeding_type = constData.embbeding_type;
+    int32_t embeddingType = constData.embbeding_type;

-    switch (embbeding_type)
+    switch (embeddingType)
    {
        case 0:
        {
diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
index 6fc875a6..d13075e0 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp
@@ -106,7 +106,7 @@ private:
        if (dim == onceMoveNums)
        {
            dataLocal = inQueue.AllocTensor<T>();
-            DataCopy(dataLocal, srcDataBufferGm[turns * roundSize], sizes * onceMoveNums);
+            DataCopy(dataLocal, srcDataBufferGm[turns * roundSize * dim], sizes * onceMoveNums);
            inQueue.EnQue(dataLocal);
            Compute(sizes);
            LocalTensor<T> dstLocal = outQueue.DeQue<T>();
@@ -134,7 +134,7 @@ private:
        for (int i = 0; i < sizes; i++)
        {
            dataLocal = inQueue.AllocTensor<T>();
-            DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize], onceMoveNums);
+            DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize * dim], onceMoveNums);
            inQueue.EnQue(dataLocal);
            Compute(1);
            address = srcAddrLocal.GetValue(i);
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index a15bbc71..fba24d0e 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -121,9 +121,11 @@ namespace MxRec {
            for (size_t i = 0; i < lookupKeys.size(); ++i) {
                int64_t key = lookupKeys[i];
-                if (rankInfo.useStatic && key == -1) {
+                if (rankInfo.useStatic && (
+                    (!rankInfo.useDynamicExpansion && key == -1) || (rankInfo.useDynamicExpansion && key == 0))) {
                    continue;
                }
+
                auto result = umap.find(key);
                if (result == umap.end()) {
                    uniqueKeys.push_back(lookupKeys[i]);
-- Gitee
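Note: the key_process.h hunk above changes which padding key the deduplication loop skips: -1 in static-shape mode, 0 when dynamic expansion is on. A minimal standalone sketch of that filtering (generic types, not the mxRec implementation):

#include <cstdint>
#include <unordered_map>
#include <vector>

// Returns the unique keys, skipping the padding key: -1 in static-shape
// mode, 0 when dynamic expansion is enabled.
std::vector<int64_t> UniqueKeys(const std::vector<int64_t>& lookupKeys,
                                bool useStatic, bool useDynamicExpansion)
{
    std::unordered_map<int64_t, size_t> umap;
    std::vector<int64_t> uniqueKeys;
    for (int64_t key : lookupKeys) {
        if (useStatic && ((!useDynamicExpansion && key == -1) ||
                          (useDynamicExpansion && key == 0))) {
            continue;  // padding slot, never looked up
        }
        if (umap.emplace(key, uniqueKeys.size()).second) {
            uniqueKeys.push_back(key);
        }
    }
    return uniqueKeys;
}

From d2d61b1b01fdd07010957ea79c3302ab53ba0864 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 26 Oct 2023 09:35:29 +0800
Subject: [PATCH 389/551] Match-id-948469bc2dd545598ba8d3c235087b671f8ec730

---
 .../key_process/feature_admit_and_evict.cpp   | 46 +++++++++--
 .../key_process/feature_admit_and_evict.h     |  2 +
 src/ops_tf/hybrid_dataset_ops.cpp             | 80 +++++++++++++++++++
 3 files changed, 121 insertions(+), 7 deletions(-)

diff --git a/src/core/key_process/feature_admit_and_evict.cpp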
b/src/core/key_process/feature_admit_and_evict.cpp
index b3354b9c..e6bdfd61 100644
--- a/src/core/key_process/feature_admit_and_evict.cpp
+++ b/src/core/key_process/feature_admit_and_evict.cpp
@@ -106,8 +106,13 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con
    absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tableName];
    auto innerIt = historyRecordInfos.find(featureId);
-
-    if (channel == TRAIN_CHANNEL_ID) {
+    uint32_t countThreshold = static_cast<uint32_t>(m_table2Threshold[tableNameOrigin].countThreshold);
+    // when countThreshold == 0 or in eval, only read the count without accumulating; a new key keeps the initial count of 0
+    if (channel == EVAL_CHANNEL_ID || countThreshold == 0) {
+        if (innerIt != historyRecordInfos.end()) {
+            currKeyCount = historyRecordInfos[featureId].count;
+        }
+    } else if (channel == TRAIN_CHANNEL_ID) { // train and countThreshold > 0
        if (innerIt == historyRecordInfos.end()) {
            // maintain m_historyRecords
            FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tableName]);
@@ -121,14 +126,10 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con
            info.lastTime = m_recordsData.timestamps[tableName];
            currKeyCount = info.count;
        }
-    } else if (channel == EVAL_CHANNEL_ID) { // eval
-        if (innerIt != historyRecordInfos.end()) {
-            currKeyCount = historyRecordInfos[featureId].count;
-        }
    }

    // admission check
-    if (currKeyCount >= static_cast<uint32_t>(m_table2Threshold[tableNameOrigin].countThreshold)) {
+    if (currKeyCount >= countThreshold) {
        return FeatureAdmitType::FEATURE_ADMIT_OK;
    }
@@ -251,6 +252,37 @@ bool FeatureAdmitAndEvict::IsThresholdCfgOK(const std::vector& t
    return true;
}

+bool FeatureAdmitAndEvict::SetTableThresholds(int threshold, string embName)
+{
+    std::lock_guard lock(m_syncMutexs);
+    if (!embName.empty()) {
+        return SetTableThreshold(threshold, embName);
+    }
+
+    bool result = true;
+    for (const auto& m : m_table2Threshold) {
+        if (!SetTableThreshold(threshold, m.second.tableName)) {
+            result = false;
+        }
+    }
+    return result;
+}
+
+bool FeatureAdmitAndEvict::SetTableThreshold(int threshold, string embName)
+{
+    auto it = m_table2Threshold.find(embName);
+    if (it == m_table2Threshold.end()) {
+        LOG_WARN("SetTableThreshold failed, cause embName [{}] is not in m_table2Threshold...", embName);
+        return false;
+    }
+    LOG_INFO("SetTableThreshold success, embName[{}], count before [{}], count after [{}], time[{}], "
+             "coefficient[{}] ...", embName, m_table2Threshold[embName].countThreshold, threshold,
+             m_table2Threshold[embName].timeThreshold, m_table2Threshold[embName].faaeCoefficient);
+
+    m_table2Threshold[embName].countThreshold = threshold;
+    return true;
+}
+
 auto FeatureAdmitAndEvict::GetTableThresholds() -> Table2ThreshMemT
 {
    std::lock_guard lock(m_syncMutexs);
diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h
index cc4f76ca..6b8ff4fe 100644
--- a/src/core/key_process/feature_admit_and_evict.h
+++ b/src/core/key_process/feature_admit_and_evict.h
@@ -73,6 +73,8 @@ namespace MxRec {
        static bool IsThresholdCfgOK(const std::vector& thresholds,
                                     const std::vector& embNames, bool isTimestamp);
+        bool SetTableThresholds(int threshold, string embName);
+        bool SetTableThreshold(int threshold, string embName);
        // interfaces used by model save/load
        auto GetTableThresholds() -> Table2ThreshMemT;
        auto GetHistoryRecords() -> AdmitAndEvictData&;
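Note: the admission rule implemented above is easy to misread in diff form. A compact sketch of the decision, under simplified types (this is an illustration, not the mxRec signature): in eval, or when the configured threshold is 0, the stored count is only read; in train with a positive threshold it is accumulated; admission succeeds once the count reaches the threshold.

#include <cstdint>

// Sketch of the admission decision: read-only in eval or when the
// threshold is 0, accumulate in train, then compare against the threshold.
bool AdmitKey(bool isTrain, uint32_t countThreshold, uint32_t& storedCount)
{
    if (isTrain && countThreshold > 0) {
        ++storedCount;  // train with a positive threshold: accumulate
    }
    // eval or threshold 0: storedCount is left untouched (new keys stay at 0)
    return storedCount >= countThreshold;
}

diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index c8b29dcc..791dfbdf 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp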
@@ -43,6 +43,7 @@ namespace MxRec {
    public:
        explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context)
        {
+            LOG_INFO("clear channel init");
            OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId));

            if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) {
@@ -71,6 +72,74 @@ namespace MxRec {
        int channelId {};
    };

+    class SetThreshold : public OpKernel {
+    public:
+        explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context)
+        {
+            LOG_INFO("SetThreshold init");
+            OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName));
+            OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName)); // lookup name used by sparse_lookup
+        }
+
+        ~SetThreshold() = default;
+
+        void Compute(OpKernelContextPtr context) override
+        {
+            LOG_DEBUG("enter SetThreshold");
+            int threshold = 1;
+            const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0);
+
+            int available = ParseThresholdAndCheck(inputTensor, threshold);
+            if (available == 0) {
+                context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ",
+                    StringFormat("threshold[%d] error", threshold)));
+                return;
+            }
+
+            // the threshold-update op may only be called when feature admission is enabled
+            if (!FeatureAdmitAndEvict::m_cfgThresholds.empty()) {
+                auto keyProcess = Singleton<KeyProcess>::GetInstance();
+                if (!keyProcess->isRunning) {
+                    context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running."));
+                    return;
+                }
+
+                if (!keyProcess->GetFeatAdmitAndEvict().SetTableThresholds(threshold, embName)) {
+                    context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold set error ...")
+                    );
+                    return;
+                }
+            } else {
+                LOG_WARN("SetThreshold failed, because feature admit-and-evict switch is closed");
+            }
+
+            Tensor* output = nullptr;
+            OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
+            auto out = output->flat<int32>();
+            out(0) = available;
+        }
+
+        int ParseThresholdAndCheck(const Tensor& inputTensor, int& threshold) const
+        {
+            // the first 8 bytes, i.e. one featureId slot, hold a unix timestamp
+            auto src = reinterpret_cast<const int*>(inputTensor.tensor_data().data());
+            std::copy(src, src + 1, &threshold);
+
+            if (threshold <= 0) {
+                LOG_ERROR("set threshold[{}] <= 0 ", threshold);
+                return 0;
+            }
+            LOG_INFO("ParseThresholdAndCheck, emb_name:[{}], ids_name: [{}], threshold: [{}]",
+                     embName, idsName, threshold);
+
+            return 1;
+        }
+
+    private:
+        string embName {};
+        string idsName {};
+    };
+
    class ReturnTimestamp : public OpKernel {
    public:
        explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context)
@@ -525,6 +594,17 @@ namespace MxRec {
REGISTER_OP("ClearChannel").Attr("channel_id : int");
REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel);

// ##################### SetThreshold #######################
REGISTER_OP("SetThreshold")
.Input("input: int32")
.Attr("emb_name: string = ''")
.Attr("ids_name: string = ''")
.Output("output: int32")
.SetShapeFn([](InferenceContextPtr c) {
c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar());
return Status::OK();
});
REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), MxRec::SetThreshold);

// ##################### ReturnTimestamp #######################
REGISTER_OP("ReturnTimestamp")
-- Gitee
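Note: once this op is registered, it can be driven from Python through a loaded custom-op library. A hypothetical usage sketch — the library filename and the way it is loaded are assumptions, not part of this patch:

import tensorflow as tf

# Hypothetical: load the shared library that registers SetThreshold.
ops_lib = tf.load_op_library("libmxrec_ops.so")

# An int32 tensor carrying the new admission threshold.
new_threshold = tf.constant([5], dtype=tf.int32)
set_op = ops_lib.set_threshold(new_threshold, emb_name="user_emb", ids_name="user_ids")

with tf.compat.v1.Session() as sess:
    ok = sess.run(set_op)  # 1 when the threshold parsed as a positive value, 0 otherwise

From cf56b569c18783199493206cd19a87e6cbb77baa Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 26 Oct 2023 09:38:52 +0800
Subject: [PATCH 390/551] Match-id-907e42ec0f225706fb9cdf1109103b227a99c2be

---
 src/core/checkpoint/checkpoint.cpp | 15 +++++++++++++--
 src/core/utils/common.h            |  2 +-
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git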
a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 3cab55dc..4c9901e2 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -218,6 +218,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { LOG_ERROR("Set device failed, device_id:{}", deviceId); + writeFile.close(); throw runtime_error(Logger::Format("Set device failed, device_id:{}", deviceId).c_str()); } @@ -232,6 +233,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da ACL_MEMCPY_DEVICE_TO_HOST); if (ret != ACL_SUCCESS) { LOG_ERROR("aclrtMemcpy failed, ret={}", ret); + writeFile.close(); throw runtime_error(Logger::Format("aclrtMemcpy failed, ret={}", ret).c_str()); } @@ -265,6 +267,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, auto &attributeArr = transData.attribute; auto embHashMapSize = attributeArr.at(0); if (embHashMapSize <= 0) { + readFile.close(); throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str()); } auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize); @@ -282,10 +285,12 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, auto &transArr = transData.int64Arr; EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); if (embSizeInfo.embSize == 0) { + readFile.close(); throw runtime_error(StringFormat("embsize is 0").c_str()); } auto keyAddrElem = embSizeInfo.extEmbSize / embSizeInfo.embSize - 1; if (keyAddrElem < 0) { + readFile.close(); throw runtime_error(StringFormat("keyAddrElem: %d is less than 0", keyAddrElem).c_str()); } for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { @@ -457,11 +462,17 @@ vector Checkpoint::GetTableLayerLoadDir() auto dir { opendir(innerDirPath.c_str()) }; struct dirent* en; if (dir != nullptr) { + int fileNum = 0; while ((en = readdir(dir)) != nullptr) { - if (strcmp(en->d_name, currDir.c_str()) != 0 && - strcmp(en->d_name, prevDir.c_str()) != 0) { + if (fileNum > MAX_FILE_NUM) { + closedir(dir); + throw std::runtime_error("The number of files has exceeded the limit " + std::to_string(MAX_FILE_NUM)); + } + if (strncmp(en->d_name, currDir.c_str(), strlen(currDir.c_str())) != 0 && + strncmp(en->d_name, prevDir.c_str(), strlen(prevDir.c_str())) != 0) { loadTableDir.emplace_back(en->d_name); } + fileNum++; } closedir(dir); } else { diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 0a434c3d..b5cc3e25 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -57,7 +57,7 @@ namespace MxRec { constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; constexpr int MAX_VOCABULARY_SIZE = 1e9; constexpr int SSD_SIZE_INDEX = 2; - + constexpr int MAX_FILE_NUM = 1000; // for GLOG struct GlogConfig { static bool gStatOn; -- Gitee From 4a3f8a79e7e6c76eb2dfccf552584ebf22506eb1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 26 Oct 2023 10:00:36 +0800 Subject: [PATCH 391/551] Match-id-0341c91cd0ce75ef59c117baadd1bc6089bf8784 --- src/core/checkpoint/checkpoint.cpp | 128 ++++++++++++++++++++++------- src/core/checkpoint/checkpoint.h | 5 ++ src/core/utils/common.h | 1 + 3 files changed, 104 insertions(+), 30 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 4c9901e2..e92fa0a6 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ 
b/src/core/checkpoint/checkpoint.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 #include "ckpt_data_handler//emb_hash_ckpt/emb_hash_ckpt.h"
 #include "ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h"
@@ -19,7 +20,6 @@
 #include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h"
 #include "utils/time_cost.h"
 #include "utils/common.h"
-
 #include "checkpoint.h"

 using namespace std;
@@ -574,6 +574,62 @@ void Checkpoint::ReadStream(CkptTransData& transData,
    readFile.close();
}

+void Checkpoint::ValidateFile(int fd, const string& dataDir, size_t datasetSize) const
+{
+    try {
+        ValidateReadFile(dataDir, datasetSize);
+    } catch (const std::invalid_argument& e) {
+        close(fd);
+        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+    }
+}
+
+void Checkpoint::HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize,
+                                  vector<vector<float>>& dst, size_t cnt) const
+{
+#pragma omp parallel for
+    for (size_t j = 0; j < mapRowNum; ++j) {
+        size_t idx = 0;
+        size_t readSize = 0;
+        size_t dataCol = onceReadByteSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                readSize = oneTimeReadWriteLen;
+            } else {
+                readSize = dataCol;
+            }
+
+            errno_t err = memcpy_s(dst[cnt + j].data() + idx, readSize,
+                                   mappedData + j * onceReadByteSize + idx, readSize);
+            if (err != 0) {
+                throw std::runtime_error("Error execution memcpy_s: " + std::to_string(err));
+            }
+            dataCol -= readSize;
+            idx += readSize;
+        }
+    }
+}
+
+void Checkpoint::CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const
+{
+    // bytes mapped per round
+    mapByteSize = MAP_BYTE_SIZE;
+    // make mapByteSize a multiple of both onceReadByteSize and pageSize, so that every mapping offset is page-aligned
+    size_t pageSize = sysconf(_SC_PAGESIZE);
+    if (pageSize == -1) {
+        throw std::runtime_error("Failed to get page size: " + std::string(strerror(errno)));
+    }
+    size_t lcmVal = std::lcm(onceReadByteSize, pageSize);
+    mapByteSize = (mapByteSize / lcmVal) * lcmVal;
+
+    // if the file is smaller than one mapping round, map it in one shot; when the length is not a multiple of the
+    // page size, mmap rounds it up and the extra bytes are zero-filled
+    if (fileSize <= mapByteSize) {
+        mapByteSize = fileSize;
+    }
+
+    mapRowNum = mapByteSize / onceReadByteSize;
+}
+
 void Checkpoint::ReadStreamForEmbData(CkptTransData& transData,
                                       const string& dataDir,
                                       uint32_t dataElmtBytes,
@@ -588,48 +644,60 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData,
    if (embDataOuterSize <= 0 || embDataOuterSize > MAX_VOCABULARY_SIZE) {
        throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str());
    }
-    std::ifstream readFile;
-    readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
-    size_t datasetSize = static_cast<size_t>(readFile.tellg());
-    readFile.seekg(0, std::ios::beg);
-    try {
-        ValidateReadFile(dataDir, datasetSize);
-    } catch (const std::invalid_argument& e) {
-        readFile.close();
-        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+
+    int fd = open(dataDir.c_str(), O_RDONLY);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("Failed to open file: %s", dataDir).c_str());
    }
+    off_t fileSize = lseek(fd, 0, SEEK_END);
+
+    size_t datasetSize = fileSize;
+    ValidateFile(fd, dataDir, datasetSize);
+
    if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) {
        LOG_ERROR("data is missing or incomplete in load file: {}", dataDir);
-        readFile.close();
+        close(fd);
        throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data");
    }
+
    auto loadHostEmbs = ckptData.hostEmbs;
    auto& dst = (*loadHostEmbs)[embName].embData;
dst.reserve(embDataOuterSize);
+
    auto onceReadByteSize { datasetSize / embDataOuterSize };
-    if (!readFile.is_open()) {
-        LOG_DEBUG("unable to open load file: {}", dataDir);
-        readFile.close();
-        return;
-    }
-    for (size_t i = 0; i < embDataOuterSize; ++i) {
-        size_t idx = 0;
-        size_t readSize = 0;
-        size_t dataCol = onceReadByteSize;
-        while (dataCol != 0) {
-            if (dataCol > oneTimeReadWriteLen) {
-                readSize = oneTimeReadWriteLen;
-            } else {
-                readSize = dataCol;
-            }
-            readFile.read(reinterpret_cast<char*>(dst[i].data()) + idx, readSize);
-            dataCol -= readSize;
-            idx += readSize;
+    size_t mapByteSize;
+    size_t mapRowNum;
+    CalculateMapSize(fileSize, mapByteSize, mapRowNum, onceReadByteSize);
+
+    off_t offset = 0;
+    size_t remainBytes = fileSize;
+
+    for (size_t i = 0; i < embDataOuterSize; i += mapRowNum) {
+        // if fewer bytes remain than one mapping round, shrink the mapping size and row count accordingly
+        if (remainBytes < mapByteSize) {
+            mapByteSize = remainBytes;
+            mapRowNum = mapByteSize / onceReadByteSize;
+        }
+
+        void* tempMappedData = mmap(NULL, mapByteSize, PROT_READ, MAP_PRIVATE, fd, offset);
+        if (tempMappedData == MAP_FAILED) {
+            close(fd);
+            throw std::runtime_error("Failed to map file: " + dataDir + ", errno: " + std::to_string(errno));
        }
+        char* mappedData = static_cast<char*>(tempMappedData);
+
+        // copy the mapped rows out
+        HandleMappedData(mappedData, mapRowNum, onceReadByteSize, dst, i);
+
+        munmap(mappedData, mapByteSize);
+
+        offset += mapByteSize;
+        remainBytes -= mapByteSize;
    }
-    readFile.close();
+
+    close(fd);
}

void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType)
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
index d4a4ac2c..cadd86d1 100644
--- a/src/core/checkpoint/checkpoint.h
+++ b/src/core/checkpoint/checkpoint.h
@@ -110,6 +110,11 @@ namespace MxRec {
        void LoadDataset(const vector& embNames, const vector& saveDataTypes,
                         const unique_ptr& dataHandler, CkptData& ckptData);
        void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes);
+        void ValidateFile(int fd, const string& dataDir, size_t datasetSize) const;
+        void HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize,
+                              vector<vector<float>>& dst, size_t cnt) const;
+        void CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const;
+
        void ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes,
                                  CkptData& ckptData, string embName) const;
        void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType);
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index b5cc3e25..9fb11da0 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -78,6 +78,7 @@ namespace MxRec {
    constexpr long long FILE_MAX_SIZE = 1LL << 40;
    constexpr int FILE_MIN_SIZE = 0;
    constexpr size_t BUFFER_SIZE{1024 * 1024 * 64};
+    constexpr size_t MAP_BYTE_SIZE{static_cast<size_t>(10) * 1024 * 1024 * 1024};

    constexpr int KEY_PROCESS_TIMEOUT = 120;
    constexpr int GET_BATCH_TIMEOUT = 300;
-- Gitee
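Note: the loader above maps the checkpoint file in large chunks whose size is a common multiple of the row width and the page size, so every mmap offset stays page-aligned. A standalone sketch of that chunking arithmetic, without the error handling above (it assumes the wanted chunk size is at least lcm(rowBytes, pageSize)):

#include <numeric>    // std::lcm
#include <unistd.h>   // sysconf

// Pick a chunk size that is a multiple of both the row width and the page
// size, capped by the file size, so each mmap offset is page-aligned.
size_t PickMapBytes(size_t wantedBytes, size_t rowBytes, size_t fileBytes)
{
    size_t pageSize = static_cast<size_t>(sysconf(_SC_PAGESIZE));
    size_t step = std::lcm(rowBytes, pageSize);
    size_t mapBytes = (wantedBytes / step) * step;  // round down to a full step
    return mapBytes < fileBytes ? mapBytes : fileBytes;
}

From 630053f6a6dff288eea0e3643a7bb01b6d37403c Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 26 Oct 2023 10:33:10 +0800
Subject: [PATCH 392/551] Match-id-c50942c12a5c38eaefc6db558ab5a3839b87cd11

---
 src/core/checkpoint/checkpoint.cpp | 15 +++++----------
 src/core/checkpoint/checkpoint.h   |  1 +
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index e92fa0a6..c536d6d2 100644
---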
a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -271,6 +271,11 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir,
        throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str());
    }
    auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize);
+    EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName);
+    if (embeddingSize != embSizeInfo.extEmbSize) {
+        readFile.close();
+        throw runtime_error(StringFormat("Invalid embedding size to be read, the saved file may have been changed").c_str());
+    }

    aclError ret;
    void *newBlock = nullptr;
@@ -283,16 +288,6 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir,
    float *floatPtr = static_cast<float*>(newBlock);

    auto &transArr = transData.int64Arr;
-    EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName);
-    if (embSizeInfo.embSize == 0) {
-        readFile.close();
-        throw runtime_error(StringFormat("embsize is 0").c_str());
-    }
-    auto keyAddrElem = embSizeInfo.extEmbSize / embSizeInfo.embSize - 1;
-    if (keyAddrElem < 0) {
-        readFile.close();
-        throw runtime_error(StringFormat("keyAddrElem: %d is less than 0", keyAddrElem).c_str());
-    }
    for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) {
        vector<float> row(embeddingSize);
        readFile.read(reinterpret_cast<char*>(row.data()), embeddingSize * sizeof(float));
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
index cadd86d1..3403a576 100644
--- a/src/core/checkpoint/checkpoint.h
+++ b/src/core/checkpoint/checkpoint.h
@@ -72,6 +72,7 @@ namespace MxRec {
        const int embHashNum { 2 };
        const int attribEmbDataOuterIdx { 0 };
        const int attribEmbDataInnerIdx { 1 };
+        const int keyAddrElem { 2 };

        void SetDataHandler(CkptData& ckptData);
        void SetDataHandler(const vector& featureTypes);
-- Gitee

From b8143e67a1b81afeaaa843c38a1150ad266ee1d3 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 26 Oct 2023 16:37:10 +0800
Subject: [PATCH 393/551] Match-id-084241dc07816c12e5e6bb5f29fa85d60db2531e

---
 mx_rec/core/asc/helper.py      | 14 ++++----------
 mx_rec/core/embedding.py       |  3 ++-
 mx_rec/optimizers/lazy_adam.py |  4 ++--
 mx_rec/validator/validator.py  | 26 ++++++++++++++++++++++++++
 4 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index 6c88344d..7925fbd4 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -22,14 +22,11 @@ from mx_rec.constants.constants import MAX_INT32
      ["check_at_least_one_equal_to_target"]),
     ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, list, tuple, type(None))}),
     ("args_index_list", ClassValidator, {"classes": (list, type(None))}),
-    ("feature_numbers", ClassValidator, {"classes": (int, list, type(None))}),
-    ("feature_numbers", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]),
     ("table_names", ClassValidator, {"classes": (list, type(None))}),
     ("is_training", ClassValidator, {"classes": (bool, type(None))}),
     ("dump_graph", ClassValidator, {"classes": (bool, type(None))}),
 ])
-def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, feature_numbers=None,
-                        table_names=None, **kwargs):
+def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs):
     '''
     deprecated.
use create_asc_insert_func_with_specs or create_asc_insert_func_with_agc
     '''
     if tgt_key_specs is not None:
         return create_asc_insert_func_with_specs(tgt_key_specs=tgt_key_specs, **kwargs)
     if args_index_list is not None:
         return create_asc_insert_func_with_acg(args_index_list=args_index_list,
-                                               feature_counts=feature_numbers,
                                                table_names=table_names,
                                                **kwargs)
     raise RuntimeError("get_asc_insert_func was called incorrectly.")
@@ -52,21 +48,19 @@ def create_asc_insert_func_with_specs(tgt_key_specs, **kwargs):

 @para_checker_decorator(check_option_list=[
-    (["args_index_list", "feature_counts", "table_names"], ValueCompareValidator, {"target": None},
+    (["args_index_list", "table_names"], ValueCompareValidator, {"target": None},
      ["check_all_not_equal_to_target"]),
 ])
-def create_asc_insert_func_with_acg(args_index_list, feature_counts, table_names, **kwargs):
+def create_asc_insert_func_with_acg(args_index_list, table_names, **kwargs):
     '''
     auto-change-graph mode
     '''
     return get_asc_insert_func_inner(args_index_list=args_index_list,
-                                     feature_counts=feature_counts,
                                      table_names=table_names,
                                      **kwargs)


-def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, feature_counts=None,
-                              table_names=None, **kwargs):
+def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs):
     is_training = kwargs.get("is_training", True)
     dump_graph = kwargs.get("dump_graph", False)
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 4109ca7a..e7d5032a 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -25,7 +25,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, ge
     get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \
     get_table_instance_by_name
 from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \
-    para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator
+    para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, OptionalStringValidator
 from mx_rec.util.tf_version_adapter import npu_ops
 from mx_rec.util.normalization import fix_invalid_table_name
 from mx_rec.util.global_env_conf import global_env
@@ -786,6 +786,7 @@ class SparseEmbedding:
     ("send_count", ClassValidator, {"classes": (int, type(None))}),
     ("send_count", OptionalIntValidator, {"min_value": 1, "max_value": MAX_INT32}, ["check_value"]),
     ("name", ClassValidator, {"classes": (str, type(None))}),
+    ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]),
     ("modify_graph", ClassValidator, {"classes": (bool, type(None))}),
     ("batch", ClassValidator, {"classes": (dict, list, tuple, type(None))}),
     ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}),
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index 8b1670f3..ec67fab0 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -26,9 +26,9 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator,

 @para_checker_decorator(check_option_list=[
     ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]),
-    ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]),
+    ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value_for_open_interval"]),
     ("beta2", NumValidator,
{"min_value": 0, "max_value": 1}, ["check_value"]), - ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value_for_left_open_interval"]), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"): diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index da1f7d04..63fa11b2 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -364,6 +364,32 @@ class NumValidator(Validator): return self + def check_value_for_open_interval(self): + if self.min_value is not None: + self.register_checker(lambda: self.value > self.min_value, + f"'{self.name}' is less than or equal {self.min_value}") + if self.max_value is not None: + self.register_checker(lambda: self.value < self.max_value, + f"'{self.name}' is bigger than or equal {self.max_value}") + return self + + def check_value_for_left_open_interval(self): + if self.min_value is not None: + self.register_checker(lambda: self.value > self.min_value, + f"'{self.name}' is less than or equal {self.min_value}") + if self.max_value is not None: + self.register_checker(lambda: self.value <= self.max_value, + f"'{self.name}' is bigger than {self.max_value}") + return self + + def check_value_for_right_open_interval(self): + if self.min_value is not None: + self.register_checker(lambda: self.value >= self.min_value, f"'{self.name}' is less than {self.min_value}") + if self.max_value is not None: + self.register_checker(lambda: self.value < self.max_value, + f"'{self.name}' is bigger than or equal {self.max_value}") + return self + class IntValidator(NumValidator): """ -- Gitee From 155297133e12ffaaf4fa2f181c5525e91da93299 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 26 Oct 2023 17:09:17 +0800 Subject: [PATCH 394/551] Match-id-00b415fad9347dff7d5b723133922b39bfd1dfdd --- src/core/hd_transfer/acl_transfer.cpp | 27 +++++++++++++++++++++++++++ src/core/hd_transfer/acl_transfer.h | 19 +++++++++++++++++++ src/core/hd_transfer/hd_transfer.cpp | 20 +++++++++----------- src/core/hd_transfer/hd_transfer.h | 2 +- src/core/host_emb/host_emb.cpp | 20 +++++++------------- 5 files changed, 63 insertions(+), 25 deletions(-) create mode 100644 src/core/hd_transfer/acl_transfer.cpp create mode 100644 src/core/hd_transfer/acl_transfer.h diff --git a/src/core/hd_transfer/acl_transfer.cpp b/src/core/hd_transfer/acl_transfer.cpp new file mode 100644 index 00000000..375a2b43 --- /dev/null +++ b/src/core/hd_transfer/acl_transfer.cpp @@ -0,0 +1,27 @@ +// +// Created by w00842226 on 2023/10/26. 
+//
+#include "acl_transfer.h"
+
+AclTransferStatus RecvByAcl(const acltdtChannelHandle *handle, acltdtDataset *dataset, float*& resultPtr)
+{
+#ifndef GTEST
+    if (dataset == nullptr || handle == nullptr) {
+        throw runtime_error("handle or dataset is nullptr.");
+    }
+    auto aclStatus = acltdtReceiveTensor(handle, dataset, GlobalEnv::aclTimeout);
+    if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) {
+        return AclTransferStatus::F001;
+    }
+    auto size = acltdtGetDatasetSize(dataset);
+    if (size == 0) {
+        LOG_WARN(HOSTEMB + "recv empty data");
+        return AclTransferStatus::OK;
+    }
+    auto aclData = acltdtGetDataItem(dataset, 0);
+    if (aclData == nullptr) {
+        return AclTransferStatus::F001;
+    }
+    // hand the device buffer back through the reference out-parameter
+    resultPtr = reinterpret_cast<float*>(acltdtGetDataAddrFromItem(aclData));
+#endif
+    return AclTransferStatus::OK;
+}
\ No newline at end of file
diff --git a/src/core/hd_transfer/acl_transfer.h b/src/core/hd_transfer/acl_transfer.h
new file mode 100644
index 00000000..5bdba433
--- /dev/null
+++ b/src/core/hd_transfer/acl_transfer.h
@@ -0,0 +1,19 @@
+//
+// Created by w00842226 on 2023/10/26.
+//
+#include "hd_transfer.h"
+#include
+#include "utils/common.h"
+#include "utils/time_cost.h"
+
+using namespace MxRec;
+using namespace std;
+#ifndef MXREC_ACL_TRANSFER_H
+#define MXREC_ACL_TRANSFER_H
+enum class AclTransferStatus {
+    OK,
+    F001
+};
+AclTransferStatus RecvByAcl(const acltdtChannelHandle *handle, acltdtDataset *dataset, float*& resultPtr);
+
+#endif //MXREC_ACL_TRANSFER_H
diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp
index 5f32b457..f1b259b9 100644
--- a/src/core/hd_transfer/hd_transfer.cpp
+++ b/src/core/hd_transfer/hd_transfer.cpp
@@ -8,6 +8,7 @@
 #include
 #include "utils/common.h"
 #include "utils/time_cost.h"
+#include "acl_transfer.h"

 using namespace MxRec;
 using namespace std;
@@ -57,7 +58,7 @@ void HDTransfer::Destroy()
    LOG_INFO(HD + "destroy channel start");
    for (auto& c: transferChannels) {
        LOG_INFO(HD + "start destroy channel:{}", c.first);
-        tensorflow::StopRecvTensorByAcl(&c.second, c.first);
+        acltdtDestroyChannel(c.second);
        LOG_INFO(HD + "destroy channel:{}", c.first);
    }
    for (auto& d: aclDatasets) {
@@ -183,25 +184,22 @@ vector HDTransfer::Recv(TransferChannel channel, int channel
/// \param channelId channel index (train / eval)
/// \param embName table name
/// \return
-size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName)
+void HDTransfer::RecvD2H(int channelId, const string& embName, float*& resultPtr)
{
    EASY_FUNCTION()
#ifndef GTEST
    std::vector tensors;
-    string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId);
+    TransferChannel channelName = TransferChannel::D2H;
+    string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channelName).c_str(), channelId);
    LOG_DEBUG("hd transfer try recv:{}", recvName);
    TimeCost tc = TimeCost();
-    if (aclDatasets[embName] == nullptr) {
-        throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str());
+    auto aclStatus = RecvByAcl(transferChannels[recvName], aclDatasets[embName], resultPtr);
+    if (aclStatus != AclTransferStatus::OK) {
+        throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str());
    }
-    auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], GlobalEnv::aclTimeout);
    if (!running) {
-        return 0;
-    }
-    if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) {
-        throw
runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); + return; } LOG_INFO("hd transfer recv:{} cost:{}ms", recvName, tc.ElapsedMS()); - return acltdtGetDatasetSize(aclDatasets[embName]); #endif } diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index b7f9ea38..40104105 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -79,7 +79,7 @@ namespace MxRec { vector Recv(TransferChannel channel, int channelId, const string& embName); - size_t RecvAcl(TransferChannel channel, int channelId, const string& embName); + void RecvD2H(int channelId, const string& embName, float*& resultPtr); void Destroy(); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index de56ebdb..06c1e75d 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -136,25 +136,19 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI EASY_FUNCTION(profiler::colors::Purple) auto updateThread = [this, missingKeysHostPos, channelId, embName] { - LOG_INFO(HOSTEMB + "UpdateEmbV2, channelId:{}, embName:{}", channelId, embName); - EASY_FUNCTION(profiler::colors::Purple); - TimeCost tc = TimeCost(); auto hdTransfer = Singleton::GetInstance(); - TransferChannel transferName = TransferChannel::D2H; LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - const auto tensors = hdTransfer->Recv(transferName, channelId, embName); - if (tensors.empty()) { - LOG_WARN(HOSTEMB + "recv empty data"); + float* ptr = nullptr; + hdTransfer->RecvD2H(channelId, embName, ptr); + if (!ptr) { return; } - const Tensor& d2hEmb = tensors[0]; + + TimeCost tc = TimeCost(); EASY_BLOCK("Update") - const float* ptr = d2hEmb.flat().data(); - auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; auto& embData = hostEmbs[embName].embData; - - LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}, embData.size = {}", - embName, missingKeysHostPos.size(), embeddingSize, embData.size()); + LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," + " embData.size = {}, RecvAcl = {}, elementSize = {}, dimNum = {}", #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) for (size_t j = 0; j < missingKeysHostPos.size(); j++) { auto& dst = embData[missingKeysHostPos[j]]; -- Gitee From 08eb18277ec36d407c270020398ff574370dd600 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 26 Oct 2023 22:14:48 +0800 Subject: [PATCH 395/551] Match-id-28182757e3d73bcc70f802ed17607cee18b7fdff --- mx_rec/constants/constants.py | 1 + mx_rec/core/asc/manager.py | 4 +++- mx_rec/core/embedding.py | 18 +++++++++++++++++- mx_rec/graph/merge_lookup.py | 4 +++- mx_rec/graph/patch.py | 12 +++++++++++- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 9 +++++++-- src/core/utils/common.h | 9 +++++++-- src/pybind/module_main.cpp | 9 ++++++--- 8 files changed, 55 insertions(+), 11 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 1dc9d3bb..137532bc 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -150,6 +150,7 @@ class ASCAnchorAttr(Enum): RESTORE_VECTOR_SECOND = "restore_vector_second" UNIQUE_KEYS = "unique_keys" GRADIENTS_STRATEGY = "gradients_strategy" + IS_GRAD = "is_grad" class OptimizerType(Enum): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 11514e7b..982bf924 100644 --- 
a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -61,8 +61,10 @@ def generate_table_info_list():
                      table_instance.slice_device_vocabulary_size)
         logger.debug("table_instance.slice_host_vocabulary_size: %s", table_instance.slice_host_vocabulary_size)
         logger.debug("table_instance.slice_ssd_vocabulary_size: %s", table_instance.slice_ssd_vocabulary_size)
+        logger.debug("EmbInfoParams: The table name is %s, and the value of `is_grad` in this table is %s.",
+                     table_instance.table_name, table_instance.is_grad)
         params = EmbInfoParams(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size,
-                               table_instance.ext_emb_size, table_instance.is_save)
+                               table_instance.ext_emb_size, table_instance.is_save, table_instance.is_grad)
         table_info = EmbInfo(params,
                              [table_instance.slice_device_vocabulary_size,
                               table_instance.slice_host_vocabulary_size, table_instance.slice_ssd_vocabulary_size],
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 4109ca7a..77dc363f 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -137,6 +137,7 @@ class SparseEmbedding:
         self.modify_graph = False
         self.init_param = config.get("init_param")
         self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op"))
+        self.is_grad = False

         self.set_slice_vocab_size()
         self.set_emb_size()
@@ -334,6 +335,7 @@ class SparseEmbedding:
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train")
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec
+        SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_GRAD] = kwargs.get("is_grad")

     def check_multi_lookup_times(self, is_training):
         lookup_times = len(self.lookup_name_dict.get(is_training)) if self.modify_graph else len(self.lookup_result)
@@ -458,6 +460,8 @@ class SparseEmbedding:
         kwargs["ids"] = ids
         mock_lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs)
         mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value)
+        if not kwargs.get("is_grad"):
+            mock_lookup_result = tf.stop_gradient(mock_lookup_result, name="mock_stop_grad_lookup_res")
         SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result
         logger.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name)
         return mock_lookup_result
@@ -479,6 +483,8 @@ class SparseEmbedding:
         spec_name = feature_spec.name
         is_training = kwargs.get("is_train")
         if spec_name in self.lookup_result and is_training in self.lookup_result.get(spec_name):
+            if not kwargs.get("is_grad"):
+                return tf.stop_gradient(self.lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res")
             return self.lookup_result.get(spec_name).get(is_training)

         if not get_use_static() and not self.modify_graph and kwargs.get("batch") is None:
@@ -545,6 +551,8 @@ class SparseEmbedding:
         if not self.modify_graph:
             self.check_multi_lookup_times(is_training)

+        if not kwargs.get("is_grad"):
+            return tf.stop_gradient(self.lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res")
         return self.lookup_result.get(spec_name).get(is_training)

     def split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list,
@@ -609,7 +617,6 @@ class SparseEmbedding:
                                     emb_size=self.emb_size, use_hot=use_hot, device_id=device_id,
                                     use_dynamic_expansion=use_dynamic_expansion)
-
         if self.skip_emb_transfer:
             result = get_preprocessed_tensor_for_asc(self.variable, config)
         else:
@@ -738,6 +745,7 @@ class SparseEmbedding:
         logger.debug("feature spec mode, table_name: %s, ASCEND_TABLE_NAME_MUST_CONTAIN: %s", self.table_name,
                      ASCEND_TABLE_NAME_MUST_CONTAIN)
+
         if is_training and is_table_name_valid:
             add_to_collection()
@@ -789,6 +797,7 @@ class SparseEmbedding:
     ("modify_graph", ClassValidator, {"classes": (bool, type(None))}),
     ("batch", ClassValidator, {"classes": (dict, list, tuple, type(None))}),
     ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}),
+    ("is_grad", ClassValidator, {"classes": (bool, )}),
 ])
 def sparse_lookup(hashtable: SparseEmbedding,
                   ids: Union[FeatureSpec, tf.Tensor],
@@ -798,6 +807,7 @@ def sparse_lookup(hashtable: SparseEmbedding,
                   modify_graph: bool = False,
                   batch: Optional[dict] = None,
                   access_and_evict_config: Optional[dict] = None,
+                  is_grad: bool = True,
                   **kwargs):
     """
     Args:
@@ -809,16 +819,22 @@ def sparse_lookup(hashtable: SparseEmbedding,
         modify_graph: if True, the original graph will be modified before building a Session instance
         batch: the value returned by the get_next() method of TF Dataset
         access_and_evict_config: the configuration for the feature of feature filtering and eviction
+        is_grad: indicates whether this lookup requires gradient updates

     Returns:
         Tensor for lookup result
     """
+    kwargs["is_grad"] = is_grad
+    # When one table is looked up several times, the table needs grad as soon as any lookup does; otherwise the
+    # whole table skips grad, and under global unique the C++ side does not need to send data either.
+    hashtable.is_grad |= is_grad
     kwargs["is_train"] = is_train
     kwargs["name"] = name if name is not None else hashtable.get_default_lookup_name()
     kwargs["modify_graph"] = modify_graph
     kwargs["batch"] = batch
     kwargs["access_and_evict_config"] = access_and_evict_config
     scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
+    logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.",
use_dynamic_expansion=use_dynamic_expansion) - if self.skip_emb_transfer: result = get_preprocessed_tensor_for_asc(self.variable, config) else: @@ -738,6 +745,7 @@ class SparseEmbedding: logger.debug("feature spec mode, table_name: %s, ASCEND_TABLE_NAME_MUST_CONTAIN: %s", self.table_name, ASCEND_TABLE_NAME_MUST_CONTAIN) + if is_training and is_table_name_valid: add_to_collection() @@ -789,6 +797,7 @@ class SparseEmbedding: ("modify_graph", ClassValidator, {"classes": (bool, type(None))}), ("batch", ClassValidator, {"classes": (dict, list, tuple, type(None))}), ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}), + ("is_grad", ClassValidator, {"classes": (bool, )}), ]) def sparse_lookup(hashtable: SparseEmbedding, ids: Union[FeatureSpec, tf.Tensor], @@ -798,6 +807,7 @@ def sparse_lookup(hashtable: SparseEmbedding, modify_graph: bool = False, batch: Optional[dict] = None, access_and_evict_config: Optional[dict] = None, + is_grad: bool = True, **kwargs): """ Args: @@ -809,16 +819,22 @@ def sparse_lookup(hashtable: SparseEmbedding, modify_graph: if True, the original graph will be modified before building a Session instance batch: the value returned by the get_next() method of TF Dataset access_and_evict_config: the configuration for the feature of feature filtering and eviction + is_grad: indicate whether this lookup requires update gradients Returns: Tensor for lookup result """ + kwargs["is_grad"] = is_grad + # 一表多查时,只要有一次查询需要grad,那么这张表也需要grad;否则整张表都不需要gard,同时在全局unique情况下,C++也不需要send数据 + hashtable.is_grad |= is_grad kwargs["is_train"] = is_train kwargs["name"] = name if name is not None else hashtable.get_default_lookup_name() kwargs["modify_graph"] = modify_graph kwargs["batch"] = batch kwargs["access_and_evict_config"] = access_and_evict_config scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name")) + logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.", + hashtable.table_name, name, is_grad) with tf.compat.v1.variable_scope(scope_name): if isinstance(ids, FeatureSpec): diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 3e0df2c0..79f2649a 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -63,16 +63,18 @@ def do_merge_lookup(is_train: bool = True): if not sub_cutting_point_list: raise RuntimeError(f"The current mode(train: True, eval: False) is {is_train}, and the sparse table does not " f"have anchor ids.") + for cutting_point in sub_cutting_point_list: table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) + is_grad = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_GRAD) if len(table_instance.lookup_name_dict.get(is_train)) == 1: logger.debug("The origin lookup result of %s for %s does not need to be replaced.", feature_spec.name, table_instance.table_name) continue send_count = table_instance.send_count - kwargs = dict(is_train=is_train, ids=cutting_point, multi_lookup=True) + kwargs = dict(is_train=is_train, ids=cutting_point, multi_lookup=True, is_grad=is_grad) if not get_use_static(): kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 2d5abf61..d3f920f7 100644 --- a/mx_rec/graph/patch.py +++ 
b/mx_rec/graph/patch.py
@@ -21,7 +21,8 @@
 from tensorflow.python.client.session import BaseSession

 from mx_rec.constants import constants
 from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \
-    get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager
+    get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager, \
+    export_table_instances
 from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook
 from mx_rec.graph.merge_lookup import do_merge_lookup
 from mx_rec.util.log import logger
@@ -307,6 +308,15 @@ def scale_loss(self: Optimizer, loss_value: tf.Tensor) -> tf.Tensor:
     # Ensure that the backward of graph is constructed and the gradient calculation is correct.
     do_merge_lookup(is_train=True)

+    # In training, at least one variable must take part in the backward pass; otherwise raise an error.
+    is_grad = False
+    table_var_list = []
+    for _, table_instance in export_table_instances().items():
+        is_grad |= table_instance.is_grad
+        table_var_list.append(table_instance.variable)
+    if not is_grad:
+        raise RuntimeError("No gradients provided for any variable: %s." % (table_var_list,))
+
     # origin code
     ops.get_default_graph()._is_loss_scaled_by_optimizer = False
     if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index d0253cb0..89a9fd74 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -635,12 +635,17 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId)
    if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY &&
        channelId == TRAIN_CHANNEL_ID) {
        TimeCost sendUnikeysSyncTC;
-        hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name);
+        LOG_DEBUG("global unique, table name: {}, is grad: {}", embInfo.name, embInfo.isGrad);
+        if (embInfo.isGrad) {
+            hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name);
+        }
        infoVecs->pop_back();
        LOG_DEBUG("sendUnikeysSyncTC(ms):{}", sendUnikeysSyncTC.ElapsedMS());

        TimeCost sendRestoreVecSecSyncTC;
-        hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embInfo.name);
+        if (embInfo.isGrad) {
+            hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embInfo.name);
+        }
        infoVecs->pop_back();
        LOG_DEBUG("sendRestoreVecSecSyncTC(ms):{}", sendRestoreVecSecSyncTC.ElapsedMS());
    }
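Note: on the Python side, is_grad defaults to True; marking a lookup with is_grad=False freezes that result while other lookups on the same table still train. A hypothetical usage sketch (the table and ids names are made up for illustration):

from mx_rec.core.embedding import sparse_lookup

# Trainable lookup: gradients flow back into user_table.
train_emb = sparse_lookup(user_table, user_ids, is_train=True, name="lookup_train")

# Frozen lookup on the same table: wrapped in tf.stop_gradient internally.
frozen_emb = sparse_lookup(user_table, user_ids, is_train=True,
                           name="lookup_frozen", is_grad=False)

# user_table.is_grad stays True because at least one lookup keeps gradients.

diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 9fb11da0..45a275eb 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -381,12 +381,14 @@ namespace MxRec {
            int sendCount,
            int embeddingSize,
            int extEmbeddingSize,
-            bool isSave)
+            bool isSave,
+            bool isGrad)
            : name(name),
              sendCount(sendCount),
              embeddingSize(embeddingSize),
              extEmbeddingSize(extEmbeddingSize),
-              isSave(isSave)
+              isSave(isSave),
+              isGrad(isGrad)
        {
        }
        std::string name;
@@ -394,6 +396,7 @@ namespace MxRec {
        int embeddingSize;
        int extEmbeddingSize;
        bool isSave;
+        bool isGrad;
    };

    struct EmbInfo {
@@ -408,6 +411,7 @@ namespace MxRec {
              embeddingSize(embInfoParams.embeddingSize),
              extEmbeddingSize(embInfoParams.extEmbeddingSize),
              isSave(embInfoParams.isSave),
+              isGrad(embInfoParams.isGrad),
              devVocabSize(vocabsize[0]),
              hostVocabSize(vocabsize[1]),
              ssdVocabSize(vocabsize[SSD_SIZE_INDEX]),
@@ -421,6 +425,7 @@ namespace MxRec {
        int embeddingSize;
        int extEmbeddingSize;
        bool isSave;
+        bool isGrad;
        size_t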
devVocabSize;
        size_t hostVocabSize;
        size_t ssdVocabSize;
diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp
index 8a04b288..387c6307 100644
--- a/src/pybind/module_main.cpp
+++ b/src/pybind/module_main.cpp
@@ -94,17 +94,19 @@ namespace {
    void GetEmbInfoParams(pybind11::module_& m)
    {
        pybind11::class_<EmbInfoParams>(m, "EmbInfoParams")
-            .def(pybind11::init<std::string, int, int, int, bool>(),
+            .def(pybind11::init<std::string, int, int, int, bool, bool>(),
                 py::arg("name"),
                 py::arg("send_count"),
                 py::arg("embedding_size"),
                 py::arg("ext_embedding_size"),
-                 py::arg("is_save"))
+                 py::arg("is_save"),
+                 py::arg("is_grad"))
            .def_readwrite("name", &EmbInfoParams::name)
            .def_readwrite("send_count", &EmbInfoParams::sendCount)
            .def_readwrite("embedding_size", &EmbInfoParams::embeddingSize)
            .def_readwrite("ext_embedding_size", &EmbInfoParams::extEmbeddingSize)
-            .def_readwrite("is_save", &EmbInfoParams::isSave);
+            .def_readwrite("is_save", &EmbInfoParams::isSave)
+            .def_readwrite("is_grad", &EmbInfoParams::isGrad);
    }

    void GetEmbInfo(pybind11::module_& m)
@@ -121,6 +123,7 @@ namespace {
            .def_readwrite("embedding_size", &EmbInfo::embeddingSize)
            .def_readwrite("ext_embedding_size", &EmbInfo::extEmbeddingSize)
            .def_readwrite("is_save", &EmbInfo::isSave)
+            .def_readwrite("is_grad", &EmbInfo::isGrad)
            .def_readwrite("dev_vocab_size", &EmbInfo::devVocabSize)
            .def_readwrite("host_vocab_size", &EmbInfo::hostVocabSize)
            .def_readwrite("initialize_infos", &EmbInfo::initializeInfos)
-- Gitee

From 63e3e07c5f4755c446103cc1413fbcb2816b2d13 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 26 Oct 2023 10:26:06 +0800
Subject: [PATCH 396/551] Match-id-97832c1985bfb03f7e913b038339700c9ed1c4e9

---
 src/core/ssd_cache/cache_manager.cpp |  9 +++++++
 src/core/ssd_cache/cache_manager.h   |  2 ++
 src/core/ssd_engine/ssd_engine.cpp   | 34 +++++++++++++++++---------
 src/core/ssd_engine/ssd_engine.h     |  2 ++
 src/core/ssd_engine/table.cpp        |  6 +++++
 src/core/ssd_engine/table.h          |  2 ++
 src/tests/ssd_engine/table_test.cpp  | 36 ++++++++++++++++++++++++++++
 7 files changed, 80 insertions(+), 11 deletions(-)

diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp
index 3d421dcf..fdf3eac3 100644
--- a/src/core/ssd_cache/cache_manager.cpp
+++ b/src/core/ssd_cache/cache_manager.cpp
@@ -492,3 +492,12 @@ void CacheManager::SaveSSDEngine(int step)
    ssdEngine->Save(step);
#endif
}
+
+int64_t CacheManager::GetTableEmbeddingSize(const string& tableName)
+{
+    if (ssdEngine == nullptr) {
+        throw runtime_error("SSDEngine not init");
+    }
+    return ssdEngine->GetTableEmbeddingSize(tableName);
+}
+
diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h
index 86352598..1995556a 100644
--- a/src/core/ssd_cache/cache_manager.h
+++ b/src/core/ssd_cache/cache_manager.h
@@ -72,6 +72,8 @@ namespace MxRec {
        // per-table occurrence counts for keys outside DDR
        unordered_map> excludeDDRKeyCountMap;

+        int64_t GetTableEmbeddingSize(const string& tableName);
+
    private:
        struct EmbBaseInfo {
            uint64_t maxTableSize;
diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp
index f9cedccc..25d2da38 100644
--- a/src/core/ssd_engine/ssd_engine.cpp
+++ b/src/core/ssd_engine/ssd_engine.cpp
@@ -9,7 +9,7 @@ using namespace std;
bool SSDEngine::IsTableExist(const string &tableName)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    return !(it == tableMap.end());
@@ -18,7 +18,7 @@ bool SSDEngine::IsKeyExist(const string &tableName,
emb_key_t key)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it == tableMap.end()) {
@@ -30,7 +30,7 @@ void SSDEngine::CreateTable(const string &tableName, vector<string> savePaths, uint64_t maxTableSize)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    if (savePaths.empty()) {
        throw invalid_argument("SSDEngine input savePaths is empty");
@@ -45,7 +45,7 @@ void SSDEngine::InsertEmbeddings(const string &tableName, vector<emb_key_t> &keys, vector<vector<float>> &embeddings)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it == tableMap.end()) {
@@ -62,7 +62,7 @@ void SSDEngine::DeleteEmbeddings(const string &tableName, vector<emb_key_t> &keys)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it == tableMap.end()) {
@@ -75,7 +75,7 @@ int64_t SSDEngine::GetTableAvailableSpace(const string &tableName)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it == tableMap.end()) {
@@ -88,7 +88,7 @@ void SSDEngine::Save(int step)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    for (auto item: as_const(tableMap)) {
        item.second->Save(step);
@@ -98,7 +98,7 @@ void SSDEngine::Load(const string &tableName, vector<string> savePaths, uint64_t maxTableSize, int step)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it != tableMap.end()) {
@@ -145,7 +145,7 @@ void SSDEngine::CompactMonitor()
vector<vector<float>> SSDEngine::FetchEmbeddings(const string &tableName, vector<emb_key_t> &keys)
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    auto it = as_const(tableMap).find(tableName);
    if (it == tableMap.end()) {
@@ -158,7 +158,7 @@ vector<vector<float>> SSDEngine::FetchEmbeddings(const string &tableName, vector
void SSDEngine::Stop()
{
    if (!isRunning) {
-        throw invalid_argument("SSDEngine not running");
+        throw runtime_error("SSDEngine not running");
    }
    isRunning = false;
    compactThread->join();
@@ -184,4 +184,16 @@ void SSDEngine::SetCompactThreshold(double threshold)
        return;
    }
    throw invalid_argument("compact threshold should in range [0, 1]");
-}
\ No newline at end of file
+}
+
+int64_t SSDEngine::GetTableEmbeddingSize(const string &tableName)
+{
+    if (!isRunning) {
+        throw runtime_error("SSDEngine not running");
+    }
+    auto it = as_const(tableMap).find(tableName);
+    if (it == tableMap.end()) {
+        return -1;
+    }
+    return static_cast<int64_t>(it->second->GetTableUsage());
+}
diff --git a/src/core/ssd_engine/ssd_engine.h b/src/core/ssd_engine/ssd_engine.h
index d0d4ee59..b6ad644d 100644
--- a/src/core/ssd_engine/ssd_engine.h
+++
b/src/core/ssd_engine/ssd_engine.h @@ -43,6 +43,8 @@ namespace MxRec { void SetCompactThreshold(double threshold); + int64_t GetTableEmbeddingSize(const string& tableName); + private: bool isRunning = false; diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index a5f1c546..e6a324ef 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -390,3 +390,9 @@ void Table::SetTablePathToDiskWithSpace() } } +uint64_t Table::GetTableUsage() +{ + lock_guard guard(rwLock); + return totalKeyCnt; +} + diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index 3218b871..cb743f15 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -39,6 +39,8 @@ namespace MxRec { void Compact(bool fullCompact); + uint64_t GetTableUsage(); + private: void Load(const string& metaFilePath, int step); diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp index f9765f58..6fdb06b8 100644 --- a/src/tests/ssd_engine/table_test.cpp +++ b/src/tests/ssd_engine/table_test.cpp @@ -130,4 +130,40 @@ TEST(Table, SaveAndLoad) for (const string &p: savePath) { fs::remove_all(p); } +} + +TEST(Table, GetTableUsage) +{ + int rankId; + MPI_Comm_rank(MPI_COMM_WORLD, &rankId); + GlogConfig::gRankId = to_string(rankId); + + string tbName = "test"; + vector savePath = {GlogConfig::gRankId}; + uint64_t maxTableSize = 100; + double compactThreshold = 0.5; + int saveStep = 0; + + // create + auto tbSave = make_shared
<Table>
(tbName, savePath, maxTableSize, compactThreshold); + + // write + uint64_t expectKeyCnt = 2; + vector<emb_key_t> keys = {1, 2}; + vector<vector<float>> embs = {{0.1}, {0.2}}; + tbSave->InsertEmbeddings(keys, embs); + + // check before saving + uint64_t keyCntSave = tbSave->GetTableUsage(); + ASSERT_EQ(keyCntSave, expectKeyCnt); + + // check after saving + tbSave->Save(saveStep); + uint64_t keyCntSave2 = tbSave->GetTableUsage(); + ASSERT_EQ(keyCntSave2, expectKeyCnt); + + // check after load + auto tbLoad = make_shared<Table>
(tbName, savePath, maxTableSize, compactThreshold, saveStep); + uint64_t keyCntLoad = tbLoad->GetTableUsage(); + ASSERT_EQ(keyCntLoad, expectKeyCnt); } \ No newline at end of file -- Gitee From 26cc646ad271ce5466cb57336e6e2178cb122850 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 27 Oct 2023 10:58:23 +0800 Subject: [PATCH 397/551] Match-id-de400826ba2b0aa204cba7fc6ab22454c7f1a4ca --- mx_rec/core/embedding.py | 42 +++++++++---------- src/core/emb_table/emb_table.cpp | 21 ++++++++-- src/core/emb_table/emb_table.h | 4 ++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 58 +++++++++++++++++++++++++- src/core/hybrid_mgmt/hybrid_mgmt.h | 4 ++ src/core/key_process/key_process.cpp | 26 ++++++++++++ src/core/key_process/key_process.h | 4 ++ src/pybind/module_main.cpp | 4 +- src/tests/emb_table/emb_table_test.cpp | 2 +- 9 files changed, 137 insertions(+), 28 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 74c804ea..80f14f2f 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -23,7 +23,7 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, ge insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \ - get_table_instance_by_name + get_table_instance_by_name, get_asc_manager from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \ para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, OptionalStringValidator from mx_rec.util.tf_version_adapter import npu_ops @@ -207,27 +207,6 @@ class SparseEmbedding: optimizer.insert_slot(slot, named_slot_key, slot_name) - @staticmethod - def get_emb_table_size(table_name: str) -> int: - """ - For HBM or DDR mode, return the size of sparse embedding table - :param table_name: the name of sparse embedding table - :return: the size of the sparse embedding table - """ - table_instance = get_table_instance_by_name(table_name) - host_vocabulary_size = table_instance.host_vocabulary_size() - device_vocabulary_size = table_instance.device_vocabulary_size - if not host_vocabulary_size and not get_use_dynamic_expansion(): - embed_dim = table_instance.emb_size - size = embed_dim * device_vocabulary_size - elif not host_vocabulary_size and get_use_dynamic_expansion(): - embed_dim = table_instance.ext_emb_size - size = embed_dim * device_vocabulary_size - else: - embed_dim = table_instance.ext_emb_size - size = (device_vocabulary_size + host_vocabulary_size) * embed_dim - return size - @staticmethod def _get_own_emb(emb, all2all_args, emb_size, use_static): """ @@ -264,6 +243,25 @@ class SparseEmbedding: return tf.reshape(src_emb, reshape_info) + def size(self) -> int: + """ + For HBM or DDR or SSD mode, return the size of sparse table + """ + return get_asc_manager().get_table_size(self.table_name) + + def capacity(self) -> int: + """ + For HBM or DDR or SSD mode, return the capacity of sparse table + """ + if get_use_dynamic_expansion(): + return get_asc_manager().get_table_capacity(self.table_name) + + if not self.host_vocabulary_size and not self.ssd_vocabulary_size: + return self.device_vocabulary_size + if not self.ssd_vocabulary_size: + return self.device_vocabulary_size + self.host_vocabulary_size + return self.device_vocabulary_size + self.host_vocabulary_size + 
self.ssd_vocabulary_size + def check_optimizer_instance(self): for optimizer_instance in self._optimizer_instance_list: if tf.__version__.startswith("1"): diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index cd8b112d..148796fc 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -14,6 +14,7 @@ #include "initializer/initializer.h" #include "emb_table/emb_table.h" + using namespace std; using namespace MxRec; using namespace tensorflow; @@ -47,7 +48,7 @@ void EmbTable::Init(const EmbInfo& eInfo, const RankInfo& rInfo, int initSeed) memoryList.push_back(newBlock); SplitMemoryBlock(newBlock); } - totalCapacity = static_cast(memoryList.size()); + totalCapacity = static_cast(memoryList.size()) * BLOCK_EMB_COUNT; LOG_INFO("aclrtMalloc success, emb name:{}, total capacity:{}", embInfo.name, totalCapacity); #endif } @@ -83,7 +84,7 @@ int64_t EmbTable::GetEmbAddress() // 将新的内存块加入内存list memoryList.push_back(addBlock); SplitMemoryBlock(addBlock); - totalCapacity++; + totalCapacity += BLOCK_EMB_COUNT; } float *embAddr = embeddingList.front(); embeddingList.pop_front(); @@ -138,5 +139,19 @@ void EmbTable::PrintStatus() const { // 输出embedding table的总容量和未使用的使用容量 LOG_INFO("Total capacity:{}, Unused capacity:{}", - totalCapacity * blockSize, totalCapacity * blockSize - usedCapacity * embSize); + totalCapacity * embSize, totalCapacity * embSize - usedCapacity * embSize); +} + +int64_t EmbTable::GetTableSize() const +{ +#ifndef GTEST + return static_cast(usedCapacity); +#endif +} + +int64_t EmbTable::GetTableCapacity() const +{ +#ifndef GTEST + return static_cast(totalCapacity); +#endif } diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index 42eca691..8136165e 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -33,6 +33,10 @@ namespace MxRec { // 打印emb表使用情况 void PrintStatus() const; + int64_t GetTableSize() const; + + int64_t GetTableCapacity() const; + EmbTable(const EmbTable&) = delete; EmbTable(EmbTable&&) = delete; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 89a9fd74..2dcac774 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -898,7 +898,7 @@ bool HybridMgmt::Evict() void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { #ifndef GTEST - LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! evict keySize:{}", embName.c_str(), keys.size()); + LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! 
evict keySize:{}", embName, keys.size()); // 删除映射关系 if (keys.size() != 0) { hostHashMaps->EvictDeleteEmb(embName, keys); @@ -1000,4 +1000,60 @@ void HybridMgmt::NotifyBySessionRun(int channelID) const void HybridMgmt::CountStepBySessionRun(int channelID, int steps) const { hybridMgmtBlock->CountPythonStep(channelID, steps); +} + +/// 获取table表使用大小 +/// \param embName 表名 +/// \return 表使用大小 +int64_t HybridMgmt::GetTableSize(const string& embName) const +{ +#ifndef GTEST + if (mgmtRankInfo.useDynamicExpansion) { + int64_t size = preprocess->GetExpansionTableSize(embName); + LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] size:{}", embName, size); + return size; + } + if (mgmtRankInfo.noDDR) { + auto maxOffset = preprocess->GetMaxOffset(); + const auto& iter = maxOffset.find(embName); + if (iter == maxOffset.end()) { + LOG_ERROR(MGMT + "get maxOffset, wrong embName:{} ", embName); + return -1; + } + int64_t size = static_cast(maxOffset[embName]); + LOG_INFO(MGMT + "HBM mode, get emb:[{}] size:{}", embName, size); + return size; + } + int64_t ssdSize = 0; + if (mgmtRankInfo.isSSDEnabled) { + ssdSize= cacheManager->GetTableEmbeddingSize(embName); + } + + const auto& iter = hostHashMaps->embHashMaps.find(embName); + if (iter == hostHashMaps->embHashMaps.end()) { + LOG_ERROR(MGMT + "get maxOffset, wrong embName:{} ", embName); + return -1; + } + auto maxOffset = hostHashMaps->embHashMaps.at(embName).maxOffset; + int64_t size = static_cast(maxOffset) + ssdSize; + + LOG_INFO(MGMT + "DDR/SSD mode, get emb:[{}] size:{}", embName, size); + return size; +#endif +} + +/// 获取table表容量大小 +/// \param embName 表名 +/// \return 表容量大小 +int64_t HybridMgmt::GetTableCapacity(const string& embName) const +{ +#ifndef GTEST + if (mgmtRankInfo.useDynamicExpansion) { + int64_t capacity = preprocess->GetExpansionTableCapacity(embName); + LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] capacity:{}", embName, capacity); + return capacity; + } + LOG_WARN(MGMT + "no dynamic expansion mode, get emb:[{}] capacity failed", embName); + return -1; +#endif } \ No newline at end of file diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 8f45a745..7f533aa8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -121,6 +121,10 @@ namespace MxRec { void CountStepBySessionRun(int channelID, int steps) const; + int64_t GetTableSize(const string& embName) const; + + int64_t GetTableCapacity(const string& embName) const; + private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index c8c7304a..57ea47d0 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1375,3 +1375,29 @@ string KeyProcess::DumpSplitKeys(vector> &splitKeys) const } return ssTrace.str(); } + +int64_t KeyProcess::GetExpansionTableSize(const string& embName) +{ +#ifndef GTEST + const auto& iter = embeddingTableMap.find(embName); + if (iter == embeddingTableMap.end()) { + LOG_ERROR(KEY_PROCESS "GetExpansionEmbSize, wrong embName:{} ", embName); + return -1; + } + std::lock_guard lk(mut); // lock for PROCESS_THREAD + return embeddingTableMap[embName].GetTableSize(); +#endif +} + +int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) +{ +#ifndef GTEST + const auto& iter = embeddingTableMap.find(embName); + if (iter == embeddingTableMap.end()) { + LOG_ERROR(KEY_PROCESS 
"GetExpansionEmbSize, wrong embName:{} ", embName); + return -1; + } + std::lock_guard lk(mut); // lock for PROCESS_THREAD + return embeddingTableMap[embName].GetTableCapacity(); +#endif +} \ No newline at end of file diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index fba24d0e..83b71a03 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -112,6 +112,10 @@ namespace MxRec { void SetupHotEmbUpdateStep(); + int64_t GetExpansionTableSize(const string& embName); + + int64_t GetExpansionTableCapacity(const string& embName); + template void GlobalUnique(T& lookupKeys, T& uniqueKeys, vector& restoreVecSec) { diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index 387c6307..a8954c7d 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -190,7 +190,9 @@ namespace { .def("receive", &MxRec::HybridMgmt::ReceiveHostMap, py::arg("key_offset_map")) .def("block_notify_wake", &MxRec::HybridMgmt::NotifyBySessionRun, py::arg("channel_id")) .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, - py::arg("channel_id"), py::arg("steps")=1); + py::arg("channel_id"), py::arg("steps")=1) + .def("get_table_size", &MxRec::HybridMgmt::GetTableSize, py::arg("table_name")) + .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name")); } void GetThresholdValue(pybind11::module_& m) diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp index ecc8711d..d669c7c6 100644 --- a/src/tests/emb_table/emb_table_test.cpp +++ b/src/tests/emb_table/emb_table_test.cpp @@ -66,7 +66,7 @@ TEST_F(EmbTableTest, Init) ASSERT_EQ(embTable.rankInfo.localRankId, rankInfo.localRankId); // 测试容量是否正常 LOG_INFO("totalCapacity {}, INIT_BLOCK_COUNT {}", embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); - EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT); + EXPECT_EQ(embTable.totalCapacity, embTable.INIT_BLOCK_COUNT * embTable.BLOCK_EMB_COUNT); #endif } -- Gitee From ead318735464055b05abb9569aec3eca3affe2f9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 27 Oct 2023 14:44:50 +0800 Subject: [PATCH 398/551] Match-id-45e56402f8ffb18a42e9d5f05842d75d44a6fd17 --- src/core/hd_transfer/acl_transfer.cpp | 27 --------------------------- src/core/hd_transfer/acl_transfer.h | 19 ------------------- src/core/hd_transfer/hd_transfer.cpp | 17 ++++++++++------- src/core/hd_transfer/hd_transfer.h | 2 +- src/core/host_emb/host_emb.cpp | 19 +++++++++++++++---- 5 files changed, 26 insertions(+), 58 deletions(-) delete mode 100644 src/core/hd_transfer/acl_transfer.cpp delete mode 100644 src/core/hd_transfer/acl_transfer.h diff --git a/src/core/hd_transfer/acl_transfer.cpp b/src/core/hd_transfer/acl_transfer.cpp deleted file mode 100644 index 375a2b43..00000000 --- a/src/core/hd_transfer/acl_transfer.cpp +++ /dev/null @@ -1,27 +0,0 @@ -// -// Created by w00842226 on 2023/10/26. 
-// -#include "acl_transfer.h" - -AclTransferStatus RecvByAcl(const acltdtChannelHandle *handle, acltdtDataset *dataset, float* resultPtr){ -#ifndef GTEST - if (dataset==nullptr || handle==nullptr) { - throw runtime_error(StringFormat("handle or dataset is nullptr:%s."); - } - auto aclStatus = acltdtReceiveTensor(handle, dataset, GlobalEnv::aclTimeout); - if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { - return AclTransferStatus::F001; - } - auto size = acltdtGetDatasetSize(dataset); - if (size == 0) { - LOG_WARN(HOSTEMB + "recv empty data"); - return AclTransferStatus::OK; - } - auto aclData = acltdtGetDataItem(dataset, 0); - if (aclData == nullptr) { - return AclTransferStatus::F001; - } - resultPtr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); -#endif - return AclTransferStatus::OK;; -} \ No newline at end of file diff --git a/src/core/hd_transfer/acl_transfer.h b/src/core/hd_transfer/acl_transfer.h deleted file mode 100644 index 5bdba433..00000000 --- a/src/core/hd_transfer/acl_transfer.h +++ /dev/null @@ -1,19 +0,0 @@ -// -// Created by w00842226 on 2023/10/26. -// -#include "hd_transfer.h" -#include -#include "utils/common.h" -#include "utils/time_cost.h" - -using namespace MxRec; -using namespace std; -#ifndef MXREC_ACL_TRANSFER_H -#define MXREC_ACL_TRANSFER_H -enum class AclTransferStatus { - OK, - F001 -}; -AclTransferStatus RecvByAcl(const acltdtChannelHandle *handle, const acltdtDataset *dataset, float* resultPtr); - -#endif //MXREC_ACL_TRANSFER_H diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index f1b259b9..a20fd164 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -184,22 +184,25 @@ vector HDTransfer::Recv(TransferChannel channel, int channel /// \param channelId 通道索引(训练/推理) /// \param embName 表名 /// \return -void HDTransfer::RecvD2H(int channelId, const string& embName, float*& resultPtr) +size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& embName) { EASY_FUNCTION() #ifndef GTEST std::vector tensors; - TransferChannel channleName = TransferChannel::D2H; - string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channleName).c_str(), channelId); + string recvName = StringFormat("%s_%s_%d", embName.c_str(), TransferChannel2Str(channel).c_str(), channelId); LOG_DEBUG("hd transfer try recv:{}", recvName); TimeCost tc = TimeCost(); - auto aclStatus = RecvByAcl(transferChannels[recvName], aclDatasets[embName], resultPtr); - if (aclStatus != AclTransferStatus::OK) { - throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); + if (aclDatasets[embName] == nullptr) { + throw runtime_error(StringFormat("Failed recv:%s.", recvName.c_str()).c_str()); } + auto aclStatus = acltdtReceiveTensor(transferChannels[recvName], aclDatasets[embName], GlobalEnv::aclTimeout); if (!running) { - return; + return 0; + } + if (aclStatus != ACL_ERROR_NONE && aclStatus != ACL_ERROR_RT_QUEUE_EMPTY) { + throw runtime_error(StringFormat("Failed receive data from acl channel, acl status:%d", aclStatus).c_str()); } LOG_INFO("hd transfer recv:{} cost:{}ms", recvName, tc.ElapsedMS()); + return acltdtGetDatasetSize(aclDatasets[embName]); #endif } diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index 40104105..b7f9ea38 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -79,7 +79,7 @@ namespace MxRec { vector 
Recv(TransferChannel channel, int channelId, const string& embName); - void RecvD2H(int channelId, const string& embName, float*& resultPtr); + size_t RecvAcl(TransferChannel channel, int channelId, const string& embName); void Destroy(); diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 06c1e75d..eba0b971 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -137,18 +137,29 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI auto updateThread = [this, missingKeysHostPos, channelId, embName] { auto hdTransfer = Singleton::GetInstance(); + TransferChannel transferName = TransferChannel::D2H; LOG_INFO(HOSTEMB + "wait D2H embs, channelId:{}", channelId); - float* ptr = nullptr; - hdTransfer->RecvD2H(channelId, embName, ptr); - if (!ptr) { + auto size = hdTransfer->RecvAcl(transferName, channelId, embName); + if (size == 0) { + LOG_WARN(HOSTEMB + "recv empty data"); return; } - TimeCost tc = TimeCost(); + EASY_BLOCK("Update") auto& embData = hostEmbs[embName].embData; + auto embeddingSize = hostEmbs[embName].hostEmbInfo.extEmbeddingSize; + auto aclData = acltdtGetDataItem(hdTransfer->aclDatasets[embName], 0); + if (aclData == nullptr) { + throw runtime_error("Acl get tensor data from dataset failed."); + } + float* ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + + size_t elementSize = acltdtGetDataSizeFromItem(aclData); + size_t dimNum = acltdtGetDimNumFromItem(aclData); LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," " embData.size = {}, RecvAcl = {}, elementSize = {}, dimNum = {}", + embName, missingKeysHostPos.size(), embeddingSize, embData.size(), size, elementSize, dimNum); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ptr, embData, embeddingSize) for (size_t j = 0; j < missingKeysHostPos.size(); j++) { auto& dst = embData[missingKeysHostPos[j]]; -- Gitee From b6c3d8c29c334fec0a329384ea6a61e2e82bf77d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 23 Oct 2023 16:27:31 +0800 Subject: [PATCH 399/551] Match-id-14497ab49c10a2955defeda1828dc7254f1232a1 --- src/core/key_process/key_process.cpp | 13 ++++++++----- src/core/key_process/key_process.h | 4 ++-- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index c8c7304a..e2d617c4 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -160,7 +160,7 @@ void KeyProcess::Destroy() void KeyProcess::LoadSaveLock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + for (int threadId { 0 }; threadId < MAX_KEY_PROCESS_THREAD; ++threadId) { loadSaveMut[channelId][threadId].lock(); } } @@ -170,7 +170,7 @@ void KeyProcess::LoadSaveLock() void KeyProcess::LoadSaveUnlock() { for (int channelId { 0 }; channelId < MAX_CHANNEL_NUM; ++channelId) { - for (int threadId { 0 }; threadId < KEY_PROCESS_THREAD; ++threadId) { + for (int threadId { 0 }; threadId < MAX_KEY_PROCESS_THREAD; ++threadId) { loadSaveMut[channelId][threadId].unlock(); } } @@ -255,7 +255,8 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, batch->channel, batch->batchId); - auto batchQueue = 
SingletonQueue::GetInstances(threadId + KEY_PROCESS_THREAD * batch->channel); + int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); } unique->UnInitialize(); @@ -289,7 +290,8 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, batch->name, batch->channel, batch->batchId); - auto batchQueue = SingletonQueue::GetInstances(threadId + KEY_PROCESS_THREAD * batch->channel); + int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); + auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); } } catch (const EndRunExit &e) { @@ -504,7 +506,8 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) unique_ptr batch = nullptr; // train data, queue id = thread id [0, KEY_PROCESS_THREAD-1] - auto batchQueue = SingletonQueue::GetInstances(commId + KEY_PROCESS_THREAD * channel); + int queueIndex = commId + (MAX_KEY_PROCESS_THREAD * channel); + auto batchQueue = SingletonQueue::GetInstances(queueIndex); EASY_BLOCK("get samples") EASY_VALUE("run on CPU", sched_getcpu()) TimeCost tc = TimeCost(); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index fba24d0e..9ecf3831 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -158,10 +158,10 @@ namespace MxRec { RankInfo rankInfo; map embInfos; - MPI_Comm comm[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD]; + MPI_Comm comm[MAX_CHANNEL_NUM][MAX_KEY_PROCESS_THREAD]; std::mutex mut {}; vector> procThreads {}; - std::mutex loadSaveMut[MAX_CHANNEL_NUM][KEY_PROCESS_THREAD] {}; + std::mutex loadSaveMut[MAX_CHANNEL_NUM][MAX_KEY_PROCESS_THREAD] {}; info_list_t lookupKeysList; list>> storage; info_list_t infoList; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index c8b29dcc..22b702a4 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -159,7 +159,7 @@ namespace MxRec { // [batchId % KEY_PROCESS_THREAD] which thread process this batch // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; + int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); -- Gitee From 14bc49c075b0da7062f688d733ef50754853daf0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 27 Oct 2023 17:16:41 +0800 Subject: [PATCH 400/551] Match-id-4fc0db38f8635aa04fdd16a677d80b287a8fa969 --- src/core/ssd_engine/table.cpp | 5 +---- src/tests/key_process/key_process_test.cpp | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index a5f1c546..a51d6388 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -295,10 +295,7 @@ vector> Table::FetchEmbeddingsInner(vector &keys) #pragma omp parallel for num_threads(readThreadNum) default(none) shared(ret, queryLen, queryList) for (size_t i = 0; i < queryLen; ++i) { tuple item = queryList[i]; - shared_ptr f; - vector batchKeys; - vector batchIdx; - tie(f, batchKeys, batchIdx) = item; + auto [f, batchKeys, batchIdx] = item; vector> batchRet = f->FetchEmbeddings(batchKeys); size_t batchLen = 
batchRet.size(); for (size_t j = 0; j < batchLen; ++j) { diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 3c4ebf73..35688816 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -493,6 +493,7 @@ TEST_F(KeyProcessTest, InitializeUnique) bool uniqueInitialize = false; size_t preBatchSize = 0; process.InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); + unique->UnInitialize(); } TEST_F(KeyProcessTest, GetKeySize) -- Gitee From a88c83334e5106ca4bce32b0792c1c67909e18da Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 28 Oct 2023 09:05:09 +0800 Subject: [PATCH 401/551] Match-id-84893a68c8b22db458b96911058bbcb33caae807 --- src/core/checkpoint/checkpoint.cpp | 15 +++++---------- src/core/checkpoint/checkpoint.h | 1 + 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index e92fa0a6..c536d6d2 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -271,6 +271,11 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str()); } auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize); + EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); + if (embeddingSize != embSizeInfo.extEmbSize) { + readFile.close(); + throw runtime_error(StringFormat("Invalid embedding size to be read, may read file has been changed").c_str()); + } aclError ret; void *newBlock = nullptr; @@ -283,16 +288,6 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, float *floatPtr = static_cast(newBlock); auto &transArr = transData.int64Arr; - EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); - if (embSizeInfo.embSize == 0) { - readFile.close(); - throw runtime_error(StringFormat("embsize is 0").c_str()); - } - auto keyAddrElem = embSizeInfo.extEmbSize / embSizeInfo.embSize - 1; - if (keyAddrElem < 0) { - readFile.close(); - throw runtime_error(StringFormat("keyAddrElem: %d is less than 0", keyAddrElem).c_str()); - } for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { vector row(embeddingSize); readFile.read(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index cadd86d1..3403a576 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -72,6 +72,7 @@ namespace MxRec { const int embHashNum { 2 }; const int attribEmbDataOuterIdx { 0 }; const int attribEmbDataInnerIdx { 1 }; + const int keyAddrElem { 2 }; void SetDataHandler(CkptData& ckptData); void SetDataHandler(const vector& featureTypes); -- Gitee From a954154599a0327231daaa7c91697e0c8cb75b95 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 30 Oct 2023 09:23:52 +0800 Subject: [PATCH 402/551] Match-id-1bdd73b0132ce0d224f6e3658bd1956b12152c66 --- src/core/hd_transfer/hd_transfer.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index a20fd164..44736319 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -8,7 +8,6 @@ #include #include "utils/common.h" #include "utils/time_cost.h" -#include "acl_transfer.h" using namespace MxRec; using namespace std; -- Gitee From 
a28ddf76e5485a8e56750cc3259e57e871554286 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 30 Oct 2023 11:07:18 +0800 Subject: [PATCH 403/551] Match-id-8c55c839aa3d1f5c32e4929bd019b57237141a36 --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 5056ac2a..fb380f04 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -418,7 +418,7 @@ namespace MxRec { // [batchId % KEY_PROCESS_THREAD] which thread process this batch // [KEY_PROCESS_THREAD * 0 or 1] train or inference - int batchQueueId = batchId % threadNum + KEY_PROCESS_THREAD * channelId; + int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); auto out = output->flat(); -- Gitee From f20f08a45eebcce729ca4e455a0e8c2dfe70a8f7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 30 Oct 2023 20:32:55 +0800 Subject: [PATCH 404/551] Match-id-2cc5cec05a819ec1aa2969dbddf395235dbb5edd --- mx_rec/saver/saver.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index b4915352..d2f6a962 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -110,10 +110,13 @@ class Saver(object): try: if save_path.find("://") == -1: - DirectoryValidator("saving_path", saving_path).with_blacklist(exact_compare=False).check() + directory_validator = DirectoryValidator("saving_path", saving_path) + directory_validator.check_not_soft_link() + directory_validator.with_blacklist(exact_compare=False) + directory_validator.check() except ValueError as err: raise ValueError(f"The saving path {saving_path} cannot be a system directory " - f"or a subdirectory of the system directory.") from err + f"and cannot be soft link.") from err if tf.io.gfile.exists(saving_path): tf.io.gfile.rmtree(saving_path) -- Gitee From 688537fdbe4e2cf92f2ac6be310912fb41e0ba1b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 31 Oct 2023 16:08:05 +0800 Subject: [PATCH 405/551] Match-id-b92a408c6ff8f8f33da932ed3a43b20fa7192841 --- mx_rec/core/asc/manager.py | 6 ++++-- .../feat_admit_n_evict_ckpt.cpp | 3 +++ .../feat_admit_n_evict_ckpt.h | 3 ++- src/core/key_process/feature_admit_and_evict.cpp | 16 ++++++++++------ src/core/utils/common.h | 4 +++- src/ops_tf/hybrid_dataset_ops.cpp | 4 ++-- src/pybind/module_main.cpp | 5 +++-- src/tests/checkpoint/checkpoint_test.cpp | 1 + .../ckpt_data_handler/ckpt_data_handler_test.cpp | 6 ++++-- src/tests/emb_mgmt/emb_mgmt_test.cpp | 6 +++--- .../key_process/feature_admit_and_evict_test.cpp | 3 ++- 11 files changed, 37 insertions(+), 20 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 982bf924..8c39338f 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -170,14 +170,16 @@ def generate_threshold_list(): threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, feature_spec.eviction_threshold, - coef) + coef, + True) threshold_list.append(threshold) continue if feature_spec.access_threshold: threshold = ThresholdValue(feature_spec.table_name, feature_spec.access_threshold, -1, - coef) + coef, + True) threshold_list.append(threshold) return threshold_list diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index 
4d3c7444..fbbce658 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -90,6 +90,8 @@ void FeatAdmitNEvictCkpt::SetTable2ThreshTrans(string embName) transArr.reserve(table2ThreshSize); transArr.push_back(table2Thresh.countThreshold); transArr.push_back(table2Thresh.timeThreshold); + int32_t isSum = table2Thresh.isEnableSum ? 1 : 0; + transArr.push_back(isSum); } // save @@ -121,6 +123,7 @@ void FeatAdmitNEvictCkpt::SetTable2Thresh(string embName) tens2Thresh.tableName = embName; tens2Thresh.countThreshold = transArr[countThresholdIdx]; tens2Thresh.timeThreshold = transArr[timeThresholdIdx]; + tens2Thresh.isEnableSum = (transArr[isSumThresholdIdx] == 1); } // load diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index adc9a830..ee716abf 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -35,10 +35,11 @@ namespace MxRec { const vector saveDataTypes { CkptDataType::TABLE_2_THRESH, CkptDataType::HIST_REC }; const int featItemInfoSaveNum { 3 }; - const int threshValSaveNum { 2 }; + const int threshValSaveNum { 3 }; const int countThresholdIdx { 0 }; const int timeThresholdIdx { 1 }; + const int isSumThresholdIdx { 2 }; const int featItemInfoOffset { 1 }; diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index e6bdfd61..3476b8e0 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -106,13 +106,13 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con absl::flat_hash_map& historyRecordInfos = m_recordsData.historyRecords[tableName]; auto innerIt = historyRecordInfos.find(featureId); - uint32_t countThreshold = static_cast(m_table2Threshold[tableNameOrigin].countThreshold); - // countThreshold = 0或者eval,只查询count,不做累加,若是新key,则count使用初始值0 - if (channel == EVAL_CHANNEL_ID || countThreshold == 0) { + + // isEnableSum = false或者eval,只查询count,不做累加,若是新key,则count使用初始值0 + if (channel == EVAL_CHANNEL_ID || m_table2Threshold[tableNameOrigin].isEnableSum == false) { if (innerIt != historyRecordInfos.end()) { currKeyCount = historyRecordInfos[featureId].count; } - } else if (channel == TRAIN_CHANNEL_ID) { // train 且 countThreshold > 0 + } else if (channel == TRAIN_CHANNEL_ID) { // train 且 isEnableSum = true if (innerIt == historyRecordInfos.end()) { // 维护 m_historyRecords FeatureItemInfo info(featureCnt, m_recordsData.timestamps[tableName]); @@ -127,9 +127,8 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con currKeyCount = info.count; } } - // 准入条件判断 - if (currKeyCount >= countThreshold) { + if (currKeyCount >= static_cast(m_table2Threshold[tableNameOrigin].countThreshold)) { return FeatureAdmitType::FEATURE_ADMIT_OK; } @@ -275,6 +274,11 @@ bool FeatureAdmitAndEvict::SetTableThreshold(int threshold, string embName) LOG_WARN("SetTableThreshold failed, cause embName [{}] is not in m_table2Threshold...", embName); return false; } + if (threshold == 0) { + LOG_INFO("SetTableThreshold success, embName[{}], isEnableSum = false ...", embName); + m_table2Threshold[embName].isEnableSum = false; + return true; + } LOG_INFO("SetTableThreshold success, embName[{}], 
count before [{}], count after [{}], time[{}], " "coefficient[{}] ...", embName, m_table2Threshold[embName].countThreshold, threshold, m_table2Threshold[embName].timeThreshold, m_table2Threshold[embName].faaeCoefficient); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 45a275eb..d1e1ca7f 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -254,18 +254,20 @@ namespace MxRec { struct ThresholdValue { ThresholdValue() = default; - ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef) + ThresholdValue(EmbNameT name, int countThre, int timeThre, int faaeCoef, bool isSum) { tableName = name; countThreshold = countThre; timeThreshold = timeThre; faaeCoefficient = faaeCoef; + isEnableSum = isSum; } EmbNameT tableName { "" }; // embName int countThreshold { -1 }; // 只配置count,即“只有准入、而没有淘汰”功能,对应SingleHostEmbTableStatus::SETS_ONLY_ADMIT状态 int timeThreshold { -1 }; // 只配置time,配置错误;即准入是淘汰的前提,对应SingleHostEmbTableStatus::SETS_BOTH状态 int faaeCoefficient { 1 }; // 配置后,该表在准入时,count计数会乘以该系数 + bool isEnableSum {true}; // 配置false,该表在准入时,count计数不会累加 }; struct FeatureItemInfo { diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index fb380f04..a919637b 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -125,8 +125,8 @@ namespace MxRec { auto src = reinterpret_cast(inputTensor.tensor_data().data()); std::copy(src, src + 1, &threshold); - if (threshold <= 0) { - LOG_ERROR("set threshold[{}] <= 0 ", threshold); + if (threshold < 0) { + LOG_ERROR("set threshold[{}] < 0 ", threshold); return 0; } LOG_INFO("ParseThresholdAndCheck, emb_name:[{}], ids_name: [{}], threshold: [{}]", diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index a8954c7d..d3793f1d 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -198,11 +198,12 @@ namespace { void GetThresholdValue(pybind11::module_& m) { pybind11::class_(m, "ThresholdValue") - .def(pybind11::init()) + .def(pybind11::init()) .def_readwrite("table_name", &ThresholdValue::tableName) .def_readwrite("count_threshold", &ThresholdValue::countThreshold) .def_readwrite("time_threshold", &ThresholdValue::timeThreshold) - .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient); + .def_readwrite("faae_coefficient", &ThresholdValue::faaeCoefficient) + .def_readwrite("is_enable_sum", &ThresholdValue::isEnableSum); } } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index c71f42fc..8bd950df 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -216,6 +216,7 @@ protected: val.countThreshold = offsetMem; val.timeThreshold = offsetMem; val.faaeCoefficient = 1; + val.isEnableSum = true; offsetMem++; diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 2ac01a45..0e81ea56 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -85,21 +85,23 @@ protected: { int countThreshold { 20 }; int timeThreshold { 100 }; + int isSum {1}; for (const auto& testEmbInfo : testEmbInfos) { ThresholdValue val; val.tableName = testEmbInfo.name; val.countThreshold = countThreshold; val.timeThreshold = timeThreshold; + val.isEnableSum = true; - vector valid { countThreshold, timeThreshold }; + vector valid { countThreshold, timeThreshold, isSum}; countThreshold++; timeThreshold++; 
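// Sketch of the per-table record this fixture mirrors (three int32 values,
// matching threshValSaveNum = 3 in feat_admit_n_evict_ckpt.h): the save side
// emits { countThreshold, timeThreshold, isEnableSum ? 1 : 0 }, and the load
// side (SetTable2Thresh) reverses it:
//   val.countThreshold = transArr[countThresholdIdx];        // index 0
//   val.timeThreshold  = transArr[timeThresholdIdx];         // index 1
//   val.isEnableSum    = (transArr[isSumThresholdIdx] == 1); // index 2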
testTable2Threshold[testEmbInfo.name] = move(val); validArr[testEmbInfo.name] = move(valid); - validAttrib[testEmbInfo.name].push_back(2); // 2 is element num in one vector + validAttrib[testEmbInfo.name].push_back(3); // 3 is element num in one vector validAttrib[testEmbInfo.name].push_back(int32Bytes); } } diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp index 39e53d6b..20735b94 100644 --- a/src/tests/emb_mgmt/emb_mgmt_test.cpp +++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp @@ -177,7 +177,7 @@ TEST_F(EmbMgmtTest, Initialize_HBM) embInfo = EmbInfo(params, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1, true); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -198,7 +198,7 @@ TEST_F(EmbMgmtTest, Evict) embInfo = EmbInfo(params, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1, true); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; @@ -222,7 +222,7 @@ TEST_F(EmbMgmtTest, Evict_HBM) embInfo = EmbInfo(params, vocabsize, initializeInfos); embInfos.emplace_back(embInfo); vector thresholdValues; - thresholdValues.emplace_back(name, 1, 1, 1); + thresholdValues.emplace_back(name, 1, 1, 1, true); auto hybridMgmt = Singleton::GetInstance(); cout << "setup..." << endl; diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp index acbdbdf1..4ef990f2 100644 --- a/src/tests/key_process/feature_admit_and_evict_test.cpp +++ b/src/tests/key_process/feature_admit_and_evict_test.cpp @@ -446,7 +446,8 @@ protected: vector cnt4 = {1, 2, 3, 2, 1, 2, 1}; KeysT keys5 = {125, 121, 122, 212, 211}; vector cnt5 = {1, 2, 1, 3, 1}; - std::vector thresholds = {{"tableAAA", 2, 5, 1}, {"tableBBB", 3, 7, 1}, {"tableCCC", 5, 9, 1}}; + std::vector thresholds = {{"tableAAA", 2, 5, 1, true}, {"tableBBB", 3, 7, 1, true}, + {"tableCCC", 5, 9, 1, true}}; }; void SetEnv() -- Gitee From 0a020e38cd73e6733f497b7460a9e390c9ceabbd Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 1 Nov 2023 09:17:11 +0800 Subject: [PATCH 406/551] Match-id-a6b86622c61a0108d6f11796681162a9793c29f9 --- mx_rec/validator/validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 63fa11b2..4bd2012b 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -496,8 +496,8 @@ class DirectoryValidator(StringValidator): return self def check_not_soft_link(self): - self.register_checker(lambda: os.path.realpath(self.value) == os.path.normpath(self.value), - "soft link or relative path should not be in the path parameter") + self.register_checker(lambda: not os.path.islink(self.value), + f"soft link or relative path: {self.value} should not be in the path parameter") return self def path_should_exist(self, is_file=True, msg=None): -- Gitee From a0f80231a35569bce76f5348c1cd526f1c7b388c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 2 Nov 2023 19:47:45 +0800 Subject: [PATCH 407/551] Match-id-f031f974ade36435cc700092ceea34273e2bf606 --- src/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d48ae2c9..6c770bca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ 
-15,7 +15,6 @@ endif () if (DEFINED PYTHON_PATH) set(PYTHON_INCLUDE_DIR ${PYTHON_PATH}/include/python3.7m) set(PYTHON_LIBRARY ${PYTHON_PATH}/lib/libpython3.7m.so) - set(PYTHON_LIB_PATH ${PYTHON_PATH}/lib) else () message("ERROR no PYTHON_PATH") endif () @@ -44,7 +43,7 @@ else () endif () set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -ffunction-sections -O0 -Wall -g2 -ggdb") set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -ffunction-sections -O3 -Wfatal-errors -DNDEBUG -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -s") -set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -Wl,--disable-new-dtags,--rpath=${PYTHON_LIB_PATH}") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack") option(ENABLE_DEBUG "use debug mode" OFF) if (ENABLE_DEBUG) -- Gitee From b858800e729621eded926488f2bcd8f216cda853 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 2 Nov 2023 15:10:06 +0800 Subject: [PATCH 408/551] Match-id-9fa5196aa32ec6149a547b2e4717470b5c1bef27 --- mx_rec/core/asc/helper.py | 12 +++++++++++- mx_rec/util/initialize.py | 6 +++--- src/core/checkpoint/checkpoint.cpp | 4 ++++ src/tests/key_process/key_process_test.cpp | 1 + 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 7925fbd4..fd9d4214 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -12,6 +12,7 @@ from mx_rec.core.asc.merge_table import find_dangling_table, should_skip from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator, \ OptionalIntValidator from mx_rec.util.log import logger +from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.constants.constants import MAX_INT32 @@ -31,11 +32,20 @@ def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, table_names=No desperated. 
use create_asc_insert_func_with_specs or create_asc_insert_func_with_agc ''' + # condition 1: only tgt_key_specs if tgt_key_specs is not None: + if args_index_list is not None or table_names is not None: + raise RuntimeError("call get_asc_insert_func in-correctly, when tgt_key_specs is not None, " + "please set args_index_list and table_names None.") return create_asc_insert_func_with_specs(tgt_key_specs=tgt_key_specs, **kwargs) + # condition 2: only args_index_list and table_names if args_index_list is not None: + if table_names is None: + raise RuntimeError("call get_asc_insert_func in-correctly, when args_index_list is not None, " + "please set tgt_key_specs None and set table_names correctly.") + fixed_table_names = [fix_invalid_table_name(table_name) for table_name in table_names] return create_asc_insert_func_with_acg(args_index_list=args_index_list, - table_names=table_names, + table_names=fixed_table_names, **kwargs) raise RuntimeError("call get_asc_insert_func in-correctly.") diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index 23ee3866..eb624b24 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -29,9 +29,9 @@ class ConfigInitializer: @para_checker_decorator(check_option_list=[ ("use_mpi", ClassValidator, {"classes": (bool,)}), - ("train_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), - ("eval_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), - ("save_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}), + ("train_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("eval_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), + ("save_steps", IntValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), (["train_steps", "eval_steps"], ValueCompareValidator, {"target": 0}, ["check_at_least_one_not_equal_to_target"]), ("if_load", ClassValidator, {"classes": (bool,)}), diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index c536d6d2..f2e041a0 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -646,6 +646,10 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } off_t fileSize = lseek(fd, 0, SEEK_END); + if (fileSize == 0) { + close(fd); + throw runtime_error(StringFormat("emb data file's size is 0").c_str()); + } size_t datasetSize = fileSize; ValidateFile(fd, dataDir, datasetSize); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 35688816..d406f36b 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -535,6 +535,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); process.KeyProcessTaskHelperWithFastUnique(batch, unique, channel, id); LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); + unique->UnInitialize(); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { -- Gitee From ebe96229380c54df05ff1d22dec71a0639827b62 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 11:56:58 +0800 Subject: [PATCH 409/551] Match-id-48b1b808025e3a5d130cb594c7610f83baa26f58 --- mx_rec/logger/log.py | 11 ++++------- mx_rec/util/communication/hccl_mgmt.py | 7 ++++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mx_rec/logger/log.py 
b/mx_rec/logger/log.py index fd630a8a..2855e329 100644 --- a/mx_rec/logger/log.py +++ b/mx_rec/logger/log.py @@ -14,19 +14,16 @@ from mx_rec.util.global_env_conf import global_env def init_sys_log(): work_dir = os.path.dirname(os.path.dirname(__file__)) log_cfg_file = os.path.join(work_dir, "logger.yaml") - real_config_path = os.path.realpath(log_cfg_file) - - if not FileValidator("log_cfg_file", log_cfg_file).check_file_size(real_config_path).check().is_valid(): - raise ValueError("Config file size is not valid.") - - with open(real_config_path, 'r', encoding='utf-8') as open_file: - if not FileValidator("log_cfg_file", real_config_path). \ + with open(log_cfg_file, 'r', encoding='utf-8'): + if not FileValidator("log_cfg_file", log_cfg_file). \ check_file_size(LOG_MAX_SIZE). \ check_not_soft_link(). \ check_user_group(). \ is_valid(): raise ValueError("Log config file is not valid.") + real_config_path = os.path.realpath(log_cfg_file) + with open(real_config_path, 'r', encoding='utf-8') as open_file: data = open_file.read(LOG_MAX_SIZE) log_cfg = yaml.safe_load(data) diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 3a2db57c..bf17c309 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -13,9 +13,8 @@ from mx_rec.util.global_env_conf import global_env def parse_hccl_json(): - rank_table_path = os.path.realpath(global_env.rank_table_file) - - with open(rank_table_path, "r", encoding="utf-8") as file: + rank_table_path = global_env.rank_table_file + with open(rank_table_path, "r", encoding="utf-8"): # check whether json file is valid file_validator = FileValidator("RANK_TABLE_FILE", rank_table_path) # 1.check whether rank_table_path is soft link @@ -24,6 +23,8 @@ def parse_hccl_json(): file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE) file_validator.check() + rank_table_path = os.path.realpath(global_env.rank_table_file) + with open(rank_table_path, "r", encoding="utf-8") as file: table_hccl = json.load(file) if "server_list" not in table_hccl: raise AttributeError(f"Lack of attribute server_list.") -- Gitee From 7be983fad621644d98ccaf673ce09eb082077b37 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 12:51:39 +0800 Subject: [PATCH 410/551] Match-id-c808a6a84452f3d990d036ad045be1bcedecbd26 --- mx_rec/__init__.py | 3 +-- mx_rec/constants/constants.py | 3 +-- mx_rec/core/asc/manager.py | 29 +++++++++++++------------ mx_rec/optimizers/__init__.py | 3 +-- mx_rec/saver/saver.py | 36 +++++++++++++++---------------- src/core/utils/common.h | 1 + src/core/utils/safe_queue.h | 4 ++-- src/ops_tf/hybrid_dataset_ops.cpp | 12 +++++------ 8 files changed, 45 insertions(+), 46 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 9c1157fd..770c8e03 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
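# A minimal sketch of the validate-before-resolve pattern applied in the log.py
# and hccl_mgmt.py hunks above (FileValidator and its chained checks appear in
# this patch; the helper name load_text_config is hypothetical):
import os
from mx_rec.validator.validator import FileValidator

def load_text_config(path, max_size):
    validator = FileValidator("config_file", path)
    validator.check_not_soft_link()      # reject symlinked paths up front
    validator.check_file_size(max_size)  # bound the read before opening
    validator.check()
    real_path = os.path.realpath(path)   # resolve only after validation
    with open(real_path, "r", encoding="utf-8") as f:
        return f.read(max_size)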
-__all__ = ["constants", "core", "graph", "util", - "version", "__version__"] +__all__ = ["constants", "core", "graph", "util", "version", "__version__"] from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 137532bc..caa5ed7d 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -165,8 +165,7 @@ class OptimizerType(Enum): raise ValueError(f"Invalid mode value, please choose one from {list(map(lambda c: c.value, OptimizerType))}") -OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], - OptimizerType.SGD: []} +OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], OptimizerType.SGD: []} class All2allGradientsOp(BaseEnum): diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 8c39338f..42d94476 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -110,20 +110,21 @@ def matched_truncated_normal_initializer(tabel_info): def matched_emb_initializer(tabel_info): - initializer_case_map = {"tf1/tf2_constant_initializer": - isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or - isinstance(tabel_info.emb_initializer, tf.constant_initializer), - "tf1/tf2_random_normal_initializer": - isinstance(tabel_info.emb_initializer, tf.keras.initializers.RandomNormal) or - isinstance(tabel_info.emb_initializer, tf.random_normal_initializer), - "tf1_truncated_normal_initializer": - tf.__version__.startswith("1") and - (isinstance(tabel_info.emb_initializer, tf.truncated_normal_initializer) or - isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal)), - "tf2_truncated_normal_initializer": - tf.__version__.startswith("2") and - isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), - } + initializer_case_map = { + "tf1/tf2_constant_initializer": + isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or + isinstance(tabel_info.emb_initializer, tf.constant_initializer), + "tf1/tf2_random_normal_initializer": + isinstance(tabel_info.emb_initializer, tf.keras.initializers.RandomNormal) or + isinstance(tabel_info.emb_initializer, tf.random_normal_initializer), + "tf1_truncated_normal_initializer": + tf.__version__.startswith("1") and + (isinstance(tabel_info.emb_initializer, tf.truncated_normal_initializer) or + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal)), + "tf2_truncated_normal_initializer": + tf.__version__.startswith("2") and + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), + } if initializer_case_map.get("tf1/tf2_constant_initializer"): initializer = matched_constant_initializer(tabel_info) elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 22b891b6..733d27b2 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -2,8 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
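# Illustrative sketch of the dispatch in matched_emb_initializer above (TF-2
# branch; the tf.keras initializer classes are standard TensorFlow API):
import tensorflow as tf

init = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)
# Under TF 2.x this instance satisfies the "tf2_truncated_normal_initializer"
# case; Constant and RandomNormal instances take the earlier branches.
assert isinstance(init, tf.keras.initializers.TruncatedNormal)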
-__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", - "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address"] +__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address"] from mx_rec.optimizers.adagrad import create_hash_optimizer from mx_rec.optimizers.ftrl import create_hash_optimizer diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d2f6a962..f4fe8904 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -170,6 +170,24 @@ class Saver(object): save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, self.rank_id) + @performance("_save") + def _save(self, sess, root_dir): + result = self.save_op_dict + threads = [] + for table_name in result.keys(): + thread = SaveModelThread(sess, result, root_dir, table_name) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + if is_asc_manager_initialized() and not self.save_easy_mode: + save_host_data(root_dir) + logger.debug(f"host data was saved.") + def _build_save(self): for var in self.var_list: if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: @@ -214,24 +232,6 @@ class Saver(object): assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) self.restore_fetch_list.append(assign_op) - @performance("_save") - def _save(self, sess, root_dir): - result = self.save_op_dict - threads = [] - for table_name in result.keys(): - thread = SaveModelThread(sess, result, root_dir, table_name) - threads.append(thread) - - for thread in threads: - thread.start() - - for thread in threads: - thread.join() - - if is_asc_manager_initialized() and not self.save_easy_mode: - save_host_data(root_dir) - logger.debug(f"host data was saved.") - def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name): host_data = get_host_data(table_name) key = np.array(list(host_data.keys())) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index d1e1ca7f..4d3d8084 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -293,6 +293,7 @@ namespace MxRec { { auto size = static_cast(GLOG_MAX_BUF_SIZE); unique_ptr buf(new char[size]); + auto buf = std::make_unique(size); memset_s(buf.get(), size, 0, size); int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...); if (nChar == -1) { diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h index 8f7f2efc..4b493d6c 100644 --- a/src/core/utils/safe_queue.h +++ b/src/core/utils/safe_queue.h @@ -19,7 +19,7 @@ namespace MxRec { template class SafeQueue { - static constexpr uint64_t defaultCap = 10; + static constexpr uint64_t DEFAULT_CAP = 10; public: SafeQueue() = default; @@ -118,7 +118,7 @@ namespace MxRec { private: mutable std::mutex mut; - uint64_t capacity = defaultCap; + uint64_t capacity = DEFAULT_CAP; std::atomic creatNum{}; std::list > dataQueue; std::list > emptyQueue; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index a919637b..1658f43d 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -59,7 +59,7 @@ namespace MxRec { } } - ~ClearChannel() = default; + ~ClearChannel() override = default; void Compute(OpKernelContextPtr context) override { @@ -81,7 +81,7 @@ namespace MxRec { OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName)); // sparse_lookup查询 } - ~SetThreshold() = default; + ~SetThreshold() override = 
default; void Compute(OpKernelContextPtr context) override { @@ -145,7 +145,7 @@ namespace MxRec { explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context) {} - ~ReturnTimestamp() = default; + ~ReturnTimestamp() override = default; void Compute(OpKernelContextPtr context) override { @@ -194,7 +194,7 @@ namespace MxRec { } maxStep = keyProcess->GetMaxStep(channelId); } - ~ReadEmbKeyV2Dynamic() = default; + ~ReadEmbKeyV2Dynamic() override = default; void Compute(OpKernelContextPtr context) override { @@ -385,7 +385,7 @@ namespace MxRec { maxStep = keyProcess->GetMaxStep(channelId); } - ~ReadEmbKeyV2() = default; + ~ReadEmbKeyV2() override = default; void Compute(OpKernelContextPtr context) override { @@ -587,7 +587,7 @@ namespace MxRec { std::cout << " Cust opp not installed!!" << std::endl; } - ~CustOps() = default; + ~CustOps() override = default; }; } -- Gitee From 2cff7570ec98face97dd90f13c2b78c4b31f88d3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 14:35:30 +0800 Subject: [PATCH 411/551] Match-id-3cb2cd07f6ad932d4db8c9bf87c11bc742a74c9f --- src/core/ssd_engine/table.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index bc8c0e0b..3335b69d 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -170,6 +170,7 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) break; } catch (invalid_argument &e) { // do nothing because file may in other path + LOG_INFO("insert exception, do nothing because file may in other path"); } } if (loadedFile == nullptr) { -- Gitee From 9e4d4d59ed1da7b347bef4d6f88f517988721dab Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 14:31:26 +0800 Subject: [PATCH 412/551] Match-id-6a54f5492c6fd896cc65477b264538722e03e22f --- src/core/checkpoint/checkpoint.cpp | 1 + src/core/ssd_engine/table.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index f2e041a0..574f30ca 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -213,6 +213,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da { ofstream writeFile; writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); + fs::permissions(dataDir.c_str(), fs::perms::owner_read | fs::perms::owner_write); #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index bc8c0e0b..a1088e86 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -106,6 +106,7 @@ void Table::Save(int step) if (!metaFile.is_open()) { throw runtime_error("fail to create table meta file"); } + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); // dump table name uint32_t nameSize = static_cast(name.size()); -- Gitee From ab141ae99b1f90d2bd7b9365829b1f1ae468a0ae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 15:10:44 +0800 Subject: [PATCH 413/551] Match-id-edbb457d5b04582c7c1bf4c0adef93fb6f24cd54 --- mx_rec/saver/patch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 47141c7f..7ce8f5d8 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -28,7 +28,7 @@ import numpy as np from mx_rec.saver.saver import Saver as SparseSaver from mx_rec.util.initialize import 
get_ascend_global_hashtable_collection, export_removing_var_list from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ - OptionalStringValidator + OptionalStringValidator, DirectoryValidator from mx_rec.util.log import logger from mx_rec.constants.constants import MAX_INT32 @@ -236,6 +236,12 @@ def restore(self, sess, save_path): if tf.__version__.startswith("2") and save_path.find("://") != -1: import tensorflow_io as tfio + if save_path.find("://") == -1: + directory_validator = DirectoryValidator("reading_path", save_path) + directory_validator.check_not_soft_link() + directory_validator.with_blacklist(exact_compare=False) + directory_validator.check() + checkpoint_prefix = compat.as_text(save_path) if self._is_empty: return -- Gitee From 8765279f731c95c85e9f484bdb4c37bb479eff2b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 18:56:00 +0800 Subject: [PATCH 414/551] Match-id-f3b50f7d146f00256744bd719ac07c2c074e0e73 --- .../feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp index fbbce658..918e026e 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp @@ -136,7 +136,9 @@ void FeatAdmitNEvictCkpt::SetHistRec(string embName) const auto& attribute = transferData.attribute; auto& timestamp = loadHistRec.timestamps[embName]; auto& histRecs = loadHistRec.historyRecords[embName]; - + if (transArr.empty() || attribute.empty()) { + throw std::runtime_error("transArr or attribute is empty"); + } timestamp = transArr.front(); size_t featItemInfoTotalSize = attribute.front() * static_cast(featItemInfoSaveNum); -- Gitee From a757c830f22ecd5959591e745712c9e8e4ff1620 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 4 Nov 2023 18:56:47 +0800 Subject: [PATCH 415/551] Match-id-1f3eef75a20eb19e62fc869d3eb6a380cd386b36 --- src/ops_tf/hybrid_dataset_ops.cpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index a919637b..1b93b31b 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -202,6 +202,10 @@ namespace MxRec { LOG_DEBUG("enter ReadEmbKeyV2Dynamic"); TimeCost tc = TimeCost(); int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); @@ -229,10 +233,6 @@ namespace MxRec { // [batchId % KEY_PROCESS_THREAD] which thread process this batch // [KEY_PROCESS_THREAD * 0 or 1] train or inference int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); - Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat(); - out(0) = batchId; TimeCost enqueueTC; EnqueueBatchData(std::vector{batchId, batchQueueId}, timestamp, inputTensor, splits); @@ -394,12 +394,12 @@ namespace MxRec { TimeCost tc = TimeCost(); int batchId = hybridMgmtBlock->readEmbedBatchId[channelId]++; 
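For reference, the patch-415 hunks here hoist the batch-id output allocation ahead of the early "skip excess batch" return (the ReadEmbKeyV2 hunk continues below), so every exit path of Compute() produces the output tensor. A reduced, self-contained model of that invariant, with a stand-in context instead of the real TensorFlow kernel API (FakeContext and its members are illustrative, not mxRec code):

```cpp
#include <iostream>
#include <optional>

// Stand-in for OpKernelContext; the only behaviour modelled is "the op
// must set its output on every path, including early returns".
struct FakeContext {
    std::optional<int> output;
    void AllocateOutput(int value) { output = value; }
};

// Mirrors the patched control flow: allocate first, then early-return.
bool Compute(FakeContext& ctx, int batchId, int maxStep, int channelId) {
    ctx.AllocateOutput(batchId);  // hoisted above all early returns
    if (channelId == 1 && maxStep != -1 && batchId >= maxStep) {
        return false;             // "skip excess batch": output is already set
    }
    // ... enqueue batch data for key processing ...
    return true;
}

int main() {
    FakeContext ctx;
    Compute(ctx, /*batchId=*/12, /*maxStep=*/10, /*channelId=*/1);
    std::cout << (ctx.output.has_value() ? "output set" : "output missing") << '\n';
    return 0;
}
```

Before the change, the allocation had to be duplicated inside the early-return branch; hoisting it leaves a single allocation site.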
Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); + auto out = output->flat(); + out(0) = batchId; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { LOG_WARN(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat(); - out(0) = batchId; return; } } @@ -420,10 +420,6 @@ namespace MxRec { // [KEY_PROCESS_THREAD * 0 or 1] train or inference int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId); - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output)); - auto out = output->flat(); - out(0) = batchId; - TimeCost enqueueTC; EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor); LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{}," -- Gitee From 11a3ed08e889879ac44a814189af82dcb53469ef Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 6 Nov 2023 22:52:56 +0800 Subject: [PATCH 416/551] Match-id-eb6ffaad287eb78be9a9a927224108269a8c2213 --- src/core/checkpoint/checkpoint.cpp | 9 +++++++-- src/core/ssd_engine/table.cpp | 20 ++++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 574f30ca..e5c8b157 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -689,8 +689,13 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, char* mappedData = static_cast(tempMappedData); // 处理映射的数据 - HandleMappedData(mappedData, mapRowNum, onceReadByteSize, dst, i); - + try { + HandleMappedData(mappedData, mapRowNum, onceReadByteSize, dst, i); + } catch (const std::runtime_error& e) { + close(fd); + munmap(mappedData, mapByteSize); + throw runtime_error(StringFormat("handle mapped data error: %s", e.what())); + } munmap(mappedData, mapByteSize); offset += mapByteSize; diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index bd7bc3ff..fd36975b 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -119,8 +119,14 @@ void Table::Save(int step) for (const auto &f: fileSet) { uint64_t fid = f->GetFileID(); metaFile.write(reinterpret_cast(&fid), sizeof(fid)); - SetTablePathToDiskWithSpace(); + try { + SetTablePathToDiskWithSpace(); + } catch (runtime_error &e) { + metaFile.close(); + throw runtime_error(StringFormat("set table path to disk with space error:{}", e.what())); + } if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + metaFile.close(); throw runtime_error("fail to create table directory"); } f->Save(curTablePath, step); @@ -128,6 +134,7 @@ void Table::Save(int step) metaFile.flush(); if (metaFile.fail()) { + metaFile.close(); throw runtime_error("fail to Save table meta file"); } @@ -211,25 +218,34 @@ void Table::Load(const string &metaFilePath, int step) uint32_t nameSize; metaFile->read(reinterpret_cast(&nameSize), sizeof(nameSize)); if (metaFile->fail()) { + metaFile->close(); throw invalid_argument("fail to read table name size"); } if (nameSize > maxNameSize) { + metaFile->close(); throw invalid_argument("table name too large, file may broken"); } char tmpArr[nameSize + 1]; auto ec = memset_s(tmpArr, nameSize + 1, '\0', nameSize + 1); if (ec != EOK) { + metaFile->close(); throw runtime_error("fail to init table name array"); } metaFile->read(tmpArr, static_cast(nameSize)); tmpArr[nameSize] = '\0'; 
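The patch-416 hunks close metaFile explicitly before every throw (the Load() hunk continues below). For comparison, a minimal sketch of the same reads written with RAII, where a locally owned ifstream closes itself on every exception path; it also swaps the variable-length `char tmpArr[nameSize + 1]` (a compiler extension, not standard C++) for std::string. Names here are illustrative, not the mxRec implementation:

```cpp
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>

// Reads the length-prefixed table name from a table meta file. The stream
// has automatic storage, so it is closed on return and on every throw
// without explicit close() calls.
std::string ReadTableName(const std::string& metaFilePath, uint32_t maxNameSize) {
    std::ifstream metaFile(metaFilePath, std::ios::binary);
    if (!metaFile.is_open()) {
        throw std::runtime_error("fail to open table meta file");
    }
    uint32_t nameSize = 0;
    metaFile.read(reinterpret_cast<char*>(&nameSize), sizeof(nameSize));
    if (metaFile.fail()) {
        throw std::invalid_argument("fail to read table name size");
    }
    if (nameSize > maxNameSize) {
        throw std::invalid_argument("table name too large, file may be broken");
    }
    std::string name(nameSize, '\0');  // no VLA, no manual terminator
    metaFile.read(name.data(), static_cast<std::streamsize>(nameSize));
    if (metaFile.fail()) {
        throw std::invalid_argument("fail to read table name");
    }
    return name;  // metaFile is closed here as well
}
```

In the real code metaFile is a shared stream handed on to LoadDataFileSet(), which is why the patch keeps the explicit close() calls rather than scoping the stream this way.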
string tbNameInFile = tmpArr; if (name != tbNameInFile) { + metaFile->close(); throw invalid_argument("table name not match"); } // construct file set - LoadDataFileSet(metaFile, step); + try { + LoadDataFileSet(metaFile, step); + } catch (exception &e) { + metaFile->close(); + throw runtime_error(StringFormat("load data file set error:{}", e.what())); + } metaFile->close(); if (metaFile->fail()) { throw runtime_error("fail to load table"); -- Gitee From 2a77030f2bbebaea76b95ece0c8c3f2759b502e4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 6 Nov 2023 23:16:58 +0800 Subject: [PATCH 417/551] Match-id-2f9bbe3a994df593d5aae03384dc6cbc07657e72 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 32 ++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2dcac774..c74a1f21 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -469,6 +469,10 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) { size_t embTableCount { 0 }; auto loadHostEmbs { loadData.hostEmbs }; + if (loadHostEmbs == nullptr) { + LOG_ERROR(MGMT + "Host Embedding of load checkpoint data is nullptr!"); + return false; + } for (EmbInfo setupHostEmbs : mgmtEmbInfo) { if (!IsLoadDataMatches(*loadHostEmbs, setupHostEmbs, embTableCount)) { return false; @@ -611,15 +615,16 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) } // 动态shape场景下,获取all2all向量(通信量矩阵) + TimeCost sendTensorsSyncTC; unique_ptr> all2all = nullptr; if (!mgmtRankInfo.useStatic) { all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - } - LOG_DEBUG("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); - - // 动态shape场景下,发送all2all向量(通信量矩阵) - TimeCost sendTensorsSyncTC; - if (!mgmtRankInfo.useStatic) { + LOG_DEBUG("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); + if (all2all == nullptr) { + LOG_ERROR("Information vector is nullptr!"); + return false; + } + sendTensorsSyncTC = TimeCost(); TimeCost sendAll2AllScSyncTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); LOG_DEBUG("sendAll2AllScSyncTC(ms):{}", sendAll2AllScSyncTC.ElapsedMS()); @@ -730,6 +735,11 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha { TimeCost getAndSendTensorsTC; TimeCost getTensorsTC; + + if (hostHashMaps->embHashMaps.find(embName) == hostHashMaps->embHashMaps.end()) { + LOG_ERROR("Failed to get embedding hash map with given name: {}", embName); + return false; + } auto& embHashMap = hostHashMaps->embHashMaps.at(embName); // 计数初始化 @@ -784,6 +794,10 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha hdTransfer->Send(TransferChannel::SWAP, ddrParam.tmpDataOut, channelId, embName); if (!mgmtRankInfo.useStatic) { auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); + if (all2all == nullptr) { + LOG_ERROR("Information vector is nullptr!"); + return false; + } hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } LOG_DEBUG("sendTensorsTC(ms):{} getAndSendTensorsTC(ms):{}, channelId:{}", @@ -907,6 +921,10 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) // 初始化host侧的emb auto& evictOffset = hostHashMaps->GetEvictPos(embName); vector evictOffset4Ddr; + if (hostHashMaps->embHashMaps.find(embName) == hostHashMaps->embHashMaps.end()) { + LOG_ERROR("Failed to get embedding hash map with given name: {}", 
embName); + return; + } auto devVocabSize = hostHashMaps->embHashMaps.at(embName).devVocabSize; for (auto& offsetInHostHashMap : evictOffset) { evictOffset4Ddr.emplace_back(offsetInHostHashMap - devVocabSize); @@ -1056,4 +1074,4 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const LOG_WARN(MGMT + "no dynamic expansion mode, get emb:[{}] capacity failed", embName); return -1; #endif -} \ No newline at end of file +} -- Gitee From 89f60c2c660e0095997d6db8bc026d274b745f90 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 6 Nov 2023 23:48:11 +0800 Subject: [PATCH 418/551] Match-id-88bb41513edaff6d25b05f61c7c7bebb4519491f --- mx_rec/constants/constants.py | 2 - mx_rec/util/global_env_conf.py | 6 - src/core/emb_hashmap/emb_hashmap.cpp | 170 ++++----------------------- src/core/utils/config.cpp | 16 --- src/core/utils/config.h | 4 - 5 files changed, 26 insertions(+), 172 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 137532bc..60a24901 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -110,8 +110,6 @@ class EnvOption(Enum): APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY" ACL_TIMEOUT = "AclTimeout" HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE" - FIND_OFFSET_V2 = "FIND_OFFSET_V2" - FIND_OFFSET_V3 = "FIND_OFFSET_V3" KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM" MAX_UNIQUE_THREAD_NUM = "MAX_UNIQUE_THREAD_NUM" FAST_UNIQUE = "FAST_UNIQUE" diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index b617fd62..355f1640 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -24,8 +24,6 @@ class RecEnv: apply_gradients_strategy: str acl_timeout: str hd_channel_size: str - find_offset_v2: str - find_offset_v3: str key_process_thread_num: str max_unique_thread_num: str fast_unique: str @@ -53,8 +51,6 @@ def get_global_env_conf() -> RecEnv: ApplyGradientsStrategy.DIRECT_APPLY.value), acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"), hd_channel_size=os.getenv(EnvOption.HD_CHANNEL_SIZE.value, DEFAULT_HD_CHANNEL_SIZE), - find_offset_v2=os.getenv(EnvOption.FIND_OFFSET_V2.value, Flag.FALSE.value), - find_offset_v3=os.getenv(EnvOption.FIND_OFFSET_V3.value, Flag.FALSE.value), key_process_thread_num=os.getenv(EnvOption.KEY_PROCESS_THREAD_NUM.value, DEFAULT_KP_THREAD_NUM), max_unique_thread_num=os.getenv(EnvOption.MAX_UNIQUE_THREAD_NUM.value, DEFAULT_FAST_UNIQUE_THREAD_NUM), fast_unique=os.getenv(EnvOption.FAST_UNIQUE.value, Flag.FALSE.value), @@ -76,8 +72,6 @@ def get_global_env_conf() -> RecEnv: ("acl_timeout", Convert2intValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), ("hd_channel_size", Convert2intValidator, {"min_value": MIN_HD_CHANNEL_SIZE, "max_value": MAX_HD_CHANNEL_SIZE}, ["check_value"]), - ("find_offset_v2", OptionValidator, {"options": [i.value for i in list(Flag)]}), - ("find_offset_v3", OptionValidator, {"options": [i.value for i in list(Flag)]}), ("key_process_thread_num", Convert2intValidator, {"min_value": MIN_KP_THREAD_NUM, "max_value": MAX_KP_THREAD_NUM}, ["check_value"]), ("max_unique_thread_num", Convert2intValidator, diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 203377b3..ff686b49 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -7,11 +7,8 @@ #include "emb_hashmap.h" #include -#include #include -#include "checkpoint/checkpoint.h" -#include "hd_transfer/hd_transfer.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" #include 
"utils/common.h" @@ -56,21 +53,19 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) TimeCost swapTimeCost; - auto& embHashMap = embHashMaps.at(embName); + auto it = embHashMaps.find(embName); + if (it == embHashMaps.end()) { + throw runtime_error("table not exist in embHashMaps"); + } + auto &embHashMap = it->second; embHashMap.devOffset2KeyOld.clear(); embHashMap.oldSwap.clear(); embHashMap.maxOffsetOld = embHashMap.maxOffset; auto keepBatch = swapId; // 处理batch的次数,多个预取一起处理算一次 - LOG_DEBUG("FindOffset version:{}", GlobalEnv::findOffsetV2); - // 找到所有key的偏移;dev和host需要交换的位置 - if (GlobalEnv::findOffsetV2) { - FindAndUpdateOffset(embName, keys, swapId, keepBatch, channelId); - } else { - FindOffset(embName, keys, swapId, keepBatch, channelId); - } + FindOffset(embName, keys, swapId, keepBatch, channelId); LOG_DEBUG("FindOffset end"); // 调用刷新频次数据方法 @@ -120,100 +115,6 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara #endif } -/* - * 从embHashMaps获取key对应的位置,并更新devOffset2Batch - */ -#ifndef GTEST -void EmbHashMap::FindAndUpdateOffset(const string& embName, vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId) -{ - EASY_FUNCTION() - size_t keySize = keys.size(); - auto& embHashMap = embHashMaps.at(embName); - FindAndUpdateBatchId(keys, currentBatchId, keySize, embHashMap); - const int devVocabSize = static_cast(embHashMap.devVocabSize); - for (size_t i = 0; i < keySize; i++) { - auto key = keys[i]; - if (key == -1) { - continue; - } - auto& offset = embHashMap.lookUpVec[i]; - if (offset == INVALID_KEY_VALUE && channelId == TRAIN_CHANNEL_ID) { - offset = FindNewOffset(key, embHashMap); - if (offset < devVocabSize) { - embHashMap.devOffset2KeyOld.emplace_back(offset, embHashMap.devOffset2Key[offset]); - embHashMap.devOffset2Key[offset] = key; - embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); - } - } - if (offset >= devVocabSize) { - embHashMap.missingKeysHostPos.emplace_back(offset - embHashMap.devVocabSize); - offset = FindSwapPosV2(embName, key, offset, currentBatchId, keepBatchId); - } - } -} - -int32_t EmbHashMap::FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) const -{ - int32_t offset; - const auto& iter = embHashMap.hostHashMap.find(key); - if (iter != embHashMap.hostHashMap.end()) { // 由于未全局去重,需要再次查询确保是新key - offset = static_cast(iter->second); - } else if (embHashMap.evictDevPos.size() != 0) { // 优先复用hbm表 - offset = static_cast(embHashMap.evictDevPos.back()); - embHashMap.hostHashMap[key] = offset; - LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictDevPos.size()); - embHashMap.evictDevPos.pop_back(); - } else if (embHashMap.evictPos.size() != 0) { // hbm不足,再复用ddr表 - offset = static_cast(embHashMap.evictPos.back()); - embHashMap.hostHashMap[key] = offset; - LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, embHashMap.evictPos.size()); - embHashMap.evictPos.pop_back(); - } else { - embHashMap.hostHashMap[key] = embHashMap.maxOffset; - offset = static_cast(embHashMap.maxOffset); - embHashMap.maxOffset++; - if (embHashMap.maxOffset == embHashMap.devVocabSize) { - LOG_INFO("start using host vocab!"); - } - if (embHashMap.maxOffset > embHashMap.hostVocabSize + embHashMap.devVocabSize) { - LOG_ERROR("hostVocabSize too small! 
dev:{} host:{}", embHashMap.devVocabSize, - embHashMap.hostVocabSize); - throw runtime_error("hostVocabSize too small"); - } - } - return offset; -} - -void EmbHashMap::FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize, - EmbHashMapInfo& embHashMap) const -{ - EASY_FUNCTION() - for (size_t i = 0; i < keySize; i++) { - int offset; - auto& key = keys[i]; - if (key == -1) { - continue; - } - const auto& iter = embHashMap.hostHashMap.find(key); - if (iter != embHashMap.hostHashMap.end()) { // found - if (GlobalEnv::findOffsetV3) { - key = -1; - } - offset = static_cast(iter->second); - embHashMap.lookUpVec.emplace_back(offset); // convert to offset(current) - - if (offset < static_cast(embHashMap.devVocabSize)) { - embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); - } - } else { - embHashMap.lookUpVec.emplace_back(INVALID_KEY_VALUE); - } - } -} - auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { LOG_DEBUG(HYBRID_BLOCKING + " start GetHashMaps"); @@ -317,42 +218,6 @@ void EmbHashMap::EvictDeleteEmb(const string& embName, const vector& LOG_TRACE("hostHashMap, {}", MapToString(embHashMaps[embName].hostHashMap)); } -/* - * 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys - */ -int EmbHashMap::FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, - size_t keepBatchId) -{ - bool notFind = true; - auto& embHashMap = embHashMaps.at(embName); - int newDevOffset; - while (notFind) { - if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { - embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); - embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); - newDevOffset = static_cast(embHashMap.currentUpdatePos); - embHashMap.hostHashMap[key] = embHashMap.currentUpdatePos; - embHashMap.devOffset2KeyOld.emplace_back(embHashMap.currentUpdatePos, - embHashMap.devOffset2Key[embHashMap.currentUpdatePos]); - auto& oldKey = embHashMap.devOffset2Key[embHashMap.currentUpdatePos]; - embHashMap.oldSwap.emplace_back(oldKey, key); - embHashMap.hostHashMap[oldKey] = hostOffset; - oldKey = key; - notFind = false; - } - embHashMap.currentUpdatePos++; - embHashMap.freeSize--; - if (embHashMap.currentUpdatePos == embHashMap.devVocabSize) { - embHashMap.currentUpdatePos = 0; - } - if (embHashMap.currentUpdatePos == embHashMap.currentUpdatePosStart) { - LOG_ERROR("devVocabSize is too small"); - throw runtime_error("devVocabSize is too small"); - } - } - return newDevOffset; -} -#endif /// 从embHashMaps获取key对应的位置,构造查询向量;更新devOffset2Batch;记录dev与host需要交换的偏移 /// \param embName 表名 @@ -365,7 +230,11 @@ void EmbHashMap::FindOffset(const string& embName, const vector& keys { EASY_FUNCTION() size_t keySize = keys.size(); - auto& embHashMap = embHashMaps.at(embName); + auto it = embHashMaps.find(embName); + if (it == embHashMaps.end()) { + throw runtime_error("table not exist in embHashMaps"); + } + auto &embHashMap = it->second; UpdateBatchId(keys, currentBatchId, keySize, embHashMap); for (size_t i = 0; i < keySize; i++) { auto key = keys[i]; @@ -465,6 +334,7 @@ void EmbHashMap::UpdateBatchId(const vector& keys, size_t currentBatc LOG_TRACE("key will be used, {} , offset , {}", key, offset); if (offset < embHashMap.devVocabSize) { + // devOffset2Batch size equal to devVocabSize, unnecessary to check index boundary embHashMap.devOffset2Batch[offset] = static_cast(currentBatchId); } } @@ -482,9 +352,17 @@ bool EmbHashMap::FindSwapPosOld(const string& embName, 
emb_key_t key, size_t hos size_t keepBatchId) { bool notFind = true; - auto& embHashMap = embHashMaps.at(embName); + auto it = embHashMaps.find(embName); + if (it == embHashMaps.end()) { + throw runtime_error("table not exist in embHashMaps"); + } + auto &embHashMap = it->second; while (notFind) { // 找到本次预取之前的偏移(保证所有预取batch的key都在HBM中) + if (embHashMap.currentUpdatePos >= embHashMap.devOffset2Batch.size()) { + throw runtime_error("currentUpdatePos out of range"); + } + if (embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] < static_cast(keepBatchId)) { embHashMap.devOffset2Batch[embHashMap.currentUpdatePos] = static_cast(currentBatchId); embHashMap.swapPos.emplace_back(embHashMap.currentUpdatePos); // 记录需要被换出的HBM偏移 @@ -545,7 +423,11 @@ void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHa } auto& hostMap = embHashMap.hostHashMap; auto& devSize = embHashMap.devVocabSize; - auto& lfu = cacheManager->ddrKeyFreqMap[embTableName]; + auto iter = cacheManager->ddrKeyFreqMap.find(embTableName); + if (iter == cacheManager->ddrKeyFreqMap.end()) { + throw runtime_error("table not in ddrKeyFreqMap"); + } + auto &lfu = iter->second; const auto& lfuTab = lfu.GetFreqTable(); if (lfuTab.empty()) { return; diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp index 255df2b9..46fb3c8a 100644 --- a/src/core/utils/config.cpp +++ b/src/core/utils/config.cpp @@ -22,8 +22,6 @@ namespace MxRec { string GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::DIRECT_APPLY; int GlobalEnv::aclTimeout = -1; // 默认阻塞方式,一直等待直到数据接收完成。 int GlobalEnv::hdChannelSize = 40; // 默认通道深度40 - bool GlobalEnv::findOffsetV2 = false; - bool GlobalEnv::findOffsetV3 = false; int GlobalEnv::keyProcessThreadNum = 6; // 默认6个线程 int GlobalEnv::maxUniqueThreadNum = 8; // 默认最大8个线程 bool GlobalEnv::fastUnique = false; @@ -54,18 +52,6 @@ namespace MxRec { GlobalEnv::hdChannelSize = std::stoi(envHDChannelSize); } - // 设置偏移查找策略V2 - const char *envFindOffsetV2 = getenv(RecEnvNames::FIND_OFFSET_V2); - if (envFindOffsetV2 != nullptr) { - GlobalEnv::findOffsetV2 = (std::stoi(envFindOffsetV2) == 1); - } - - // 设置偏移查找策略V3 - const char *envFindOffsetV3 = getenv(RecEnvNames::FIND_OFFSET_V3); - if (envFindOffsetV3 != nullptr) { - GlobalEnv::findOffsetV3 = (std::stoi(envFindOffsetV3) == 1); - } - // 设置数据处理线程数 const char *envKPNum = getenv(RecEnvNames::KEY_PROCESS_THREAD_NUM); if (envKPNum != nullptr) { @@ -122,8 +108,6 @@ namespace MxRec { RecEnvNames::APPLY_GRADIENTS_STRATEGY, GlobalEnv::applyGradientsStrategy, RecEnvNames::ACL_TIMEOUT, GlobalEnv::aclTimeout, RecEnvNames::HD_CHANNEL_SIZE, GlobalEnv::hdChannelSize, - RecEnvNames::FIND_OFFSET_V2, GlobalEnv::findOffsetV2, - RecEnvNames::FIND_OFFSET_V3, GlobalEnv::findOffsetV3, RecEnvNames::KEY_PROCESS_THREAD_NUM, GlobalEnv::keyProcessThreadNum, RecEnvNames::MAX_UNIQUE_THREAD_NUM, GlobalEnv::maxUniqueThreadNum, RecEnvNames::FAST_UNIQUE, GlobalEnv::fastUnique, diff --git a/src/core/utils/config.h b/src/core/utils/config.h index f29c2346..6b498df9 100644 --- a/src/core/utils/config.h +++ b/src/core/utils/config.h @@ -16,8 +16,6 @@ namespace MxRec { const char *const APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY"; const char *const ACL_TIMEOUT = "AclTimeout"; const char *const HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE"; - const char *const FIND_OFFSET_V2 = "FIND_OFFSET_V2"; - const char *const FIND_OFFSET_V3 = "FIND_OFFSET_V3"; const char *const KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM"; const char *const MAX_UNIQUE_THREAD_NUM = "MAX_UNIQUE_THREAD_NUM"; 
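A recurring fix in the emb_hashmap.cpp hunks of patch 418 replaces `embHashMaps.at(embName)` with a find()-then-throw sequence: `at()` raises a bare std::out_of_range with no context, while the guard reports which lookup failed before any reference is taken. A self-contained sketch of the pattern, with std::unordered_map standing in for absl::flat_hash_map and an illustrative payload struct:

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct EmbHashMapInfo {
    std::size_t devVocabSize = 0;
};

// find() first, then throw a domain-specific error instead of letting
// at() surface an anonymous std::out_of_range.
EmbHashMapInfo& GetTable(std::unordered_map<std::string, EmbHashMapInfo>& maps,
                         const std::string& embName) {
    auto it = maps.find(embName);
    if (it == maps.end()) {
        throw std::runtime_error("table not exist in embHashMaps: " + embName);
    }
    return it->second;
}

int main() {
    std::unordered_map<std::string, EmbHashMapInfo> maps{{"user_emb", {1024}}};
    std::cout << GetTable(maps, "user_emb").devVocabSize << '\n';  // 1024
    try {
        GetTable(maps, "item_emb");
    } catch (const std::runtime_error& e) {
        std::cout << e.what() << '\n';  // names the missing table
    }
    return 0;
}
```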
const char *const FAST_UNIQUE = "FAST_UNIQUE"; @@ -37,8 +35,6 @@ namespace MxRec { static std::string applyGradientsStrategy; static int aclTimeout; static int hdChannelSize; - static bool findOffsetV2; - static bool findOffsetV3; static int keyProcessThreadNum; static int maxUniqueThreadNum; static bool fastUnique; -- Gitee From a15574d4ca305f0d1efa3efd3f145e596f77ba21 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 7 Nov 2023 00:03:38 +0800 Subject: [PATCH 419/551] Match-id-e7ad04600354a3e9c3955c912fd1f808192d3300 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 42 ++++++++++++++++++++++++++++ src/core/hybrid_mgmt/hybrid_mgmt.h | 6 ++++ src/core/utils/common.h | 3 -- src/core/utils/safe_queue.h | 32 --------------------- 4 files changed, 48 insertions(+), 35 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 2dcac774..ff9a1e43 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -136,6 +136,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, LOG_INFO(MGMT + "end initialize, noDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.noDDR, rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); #endif + isInitialized = true; return true; } @@ -223,6 +224,11 @@ void HybridMgmt::RestoreFreq4Save(CkptData& saveData) const bool HybridMgmt::Save(const string savePath) { #ifndef GTEST + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return false; + } + // 数据处理线程上锁 preprocess->LoadSaveLock(); @@ -275,6 +281,11 @@ bool HybridMgmt::Save(const string savePath) bool HybridMgmt::Load(const string& loadPath) { #ifndef GTEST + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return false; + } + // 数据处理线程上锁 preprocess->LoadSaveLock(); @@ -362,6 +373,10 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, KeyOffsetMapT HybridMgmt::SendHostMap(const string tableName) { #ifndef GTEST + if (!isInitialized) { + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); + } + preprocess->LoadSaveLock(); KeyOffsetMemT keyOffsetMap; KeyOffsetMapT sendKeyOffsetMap; @@ -389,6 +404,10 @@ KeyOffsetMapT HybridMgmt::SendHostMap(const string tableName) void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) { #ifndef GTEST + if (!isInitialized) { + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); + } + preprocess->LoadSaveLock(); KeyOffsetMemT loadKeyOffsetMap; OffsetMemT loadMaxOffset; @@ -850,6 +869,11 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) bool HybridMgmt::Evict() { #ifndef GTEST + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return false; + } + // 配置了淘汰选项,则触发 auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { @@ -991,6 +1015,10 @@ int HybridMgmt::GetStepFromPath(const string& loadPath) const /// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 void HybridMgmt::NotifyBySessionRun(int channelID) const { + if (!isInitialized) { + throw runtime_error("HybridMgmt not initialized. 
Call Initialize first."); + } + hybridMgmtBlock->CheckAndNotifyWake(channelID); } @@ -999,6 +1027,10 @@ void HybridMgmt::NotifyBySessionRun(int channelID) const /// \param steps 运行的步数,由于可能存在循环下沉,所以1个session run 对应N步 void HybridMgmt::CountStepBySessionRun(int channelID, int steps) const { + if (!isInitialized) { + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); + } + hybridMgmtBlock->CountPythonStep(channelID, steps); } @@ -1008,6 +1040,11 @@ void HybridMgmt::CountStepBySessionRun(int channelID, int steps) const int64_t HybridMgmt::GetTableSize(const string& embName) const { #ifndef GTEST + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return -1; + } + if (mgmtRankInfo.useDynamicExpansion) { int64_t size = preprocess->GetExpansionTableSize(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] size:{}", embName, size); @@ -1048,6 +1085,11 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const int64_t HybridMgmt::GetTableCapacity(const string& embName) const { #ifndef GTEST + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return -1; + } + if (mgmtRankInfo.useDynamicExpansion) { int64_t capacity = preprocess->GetExpansionTableCapacity(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] capacity:{}", embName, capacity); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 7f533aa8..0defe7b8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -73,6 +73,11 @@ namespace MxRec { void Destroy() { + if (!isInitialized) { + LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); + return; + } + if (!isRunning) { return; } @@ -160,6 +165,7 @@ namespace MxRec { bool isSSDEnabled { false }; bool isRunning; bool isLoad { false }; + bool isInitialized { false }; void TrainTask(TaskType type); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index d1e1ca7f..ad21a7f7 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -162,12 +162,10 @@ namespace MxRec { } std::vector sample; - void *tensorAddr = nullptr; std::string name; size_t batchSize; int batchId; int channel = 0; - bool isInt64; // true int64 false int32 time_t timestamp { -1 }; }; @@ -179,7 +177,6 @@ namespace MxRec { int batchId; int channelId; time_t timestamp { -1 }; - bool flag; // true int64 false int32 const void *tensor; }; diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h index 8f7f2efc..5c77a797 100644 --- a/src/core/utils/safe_queue.h +++ b/src/core/utils/safe_queue.h @@ -54,22 +54,6 @@ namespace MxRec { } } - std::unique_ptr WaitAndGetOne() - { - { - std::lock_guard lk(mut); - if (creatNum < capacity) { - creatNum++; - return std::make_unique(); - } - } - std::unique_lock locker(mut); - dirtyCond.wait(locker, [this] { return !emptyQueue.empty(); }); - auto t = move(emptyQueue.back()); - emptyQueue.pop_back(); - return move(t); - } - void PutDirty(std::unique_ptr &&t) { std::lock_guard lk(mut); @@ -84,15 +68,6 @@ namespace MxRec { dataCond.notify_one(); } - std::unique_ptr WaitAndPop() - { - std::unique_lock lk(mut); - dataCond.wait(lk, [this] { return !dataQueue.empty(); }); - std::unique_ptr res = std::move(dataQueue.front()); - dataQueue.pop_front(); - return move(res); - } - std::unique_ptr TryPop() { std::lock_guard lk(mut); @@ -104,12 +79,6 @@ namespace MxRec { return move(res); } - bool Empty() const - { - std::lock_guard lk(mut); - return 
dataQueue.empty(); - } - size_t Size() const { std::lock_guard lk(mut); @@ -119,7 +88,6 @@ namespace MxRec { private: mutable std::mutex mut; uint64_t capacity = defaultCap; - std::atomic creatNum{}; std::list > dataQueue; std::list > emptyQueue; std::condition_variable dataCond; -- Gitee From 515e94b8c2c63a785a6154918b6db8441ddb4ce2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 7 Nov 2023 00:00:12 +0800 Subject: [PATCH 420/551] Match-id-08889f4e5ce7f865f45230f036f60de6978fb7f1 --- build/build.sh | 12 ++++++++++++ build/build_tf1.sh | 2 ++ build/build_tf2.sh | 2 ++ mx_rec/constants/constants.py | 2 +- mx_rec/util/communication/hccl_mgmt.py | 7 ++++++- mx_rec/util/global_env_conf.py | 4 ++-- mx_rec/util/initialize.py | 13 ++++++------- mx_rec/util/variable.py | 15 --------------- src/core/key_process/key_process.cpp | 5 ++++- 9 files changed, 35 insertions(+), 27 deletions(-) diff --git a/build/build.sh b/build/build.sh index 8a2dd5d4..65657ee4 100644 --- a/build/build.sh +++ b/build/build.sh @@ -59,6 +59,18 @@ gen_tar_file() mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" + # change dirs and files 's permission + chmod 550 ../build/"${pkg_dir}"/tf1_whl + chmod 550 ../build/"${pkg_dir}"/tf1_whl/mx_rec*.whl + chmod 550 ../build/"${pkg_dir}"/tf2_whl + chmod 550 ../build/"${pkg_dir}"/tf2_whl/mx_rec*.whl + chmod 550 ../build/"${pkg_dir}"/cust_op/ + chmod 550 ../build/"${pkg_dir}"/cust_op/cust_op_by_addr + cd ../build/"${pkg_dir}"/cust_op/cust_op_by_addr + chmod 550 *.sh + chmod 640 *.json + chmod 550 op_host op_kernel op_host/* op_kernel/* + cd - cd ../build tar -zvcf "${release_tar}" "${pkg_dir}" || { warn "compression failed, packages might be broken" diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 653c95c4..90b1c493 100644 --- a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -53,10 +53,12 @@ remove "${SCRIPT_DIR}/lib" get_version export VERSION echo "MindX SDK mxrec: ${VERSION}" >> ./version.info +chmod 640 ./version.info pkg_dir=mindxsdk-mxrec remove "${pkg_dir}" mkdir "${pkg_dir}" +chmod 750 "$pkg_dir" mv version.info "${pkg_dir}" opensource_path="${ROOT_DIR}"/../opensource/opensource diff --git a/build/build_tf2.sh b/build/build_tf2.sh index 69ce537d..dd586a31 100644 --- a/build/build_tf2.sh +++ b/build/build_tf2.sh @@ -53,10 +53,12 @@ remove "${SCRIPT_DIR}/lib" get_version export VERSION echo "MindX SDK mxrec: ${VERSION}" >> ./version.info +chmod 640 ./version.info pkg_dir=mindxsdk-mxrec remove "${pkg_dir}" mkdir "${pkg_dir}" +chmod 750 "$pkg_dir" mv version.info "${pkg_dir}" opensource_path="${ROOT_DIR}"/../opensource/opensource diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 60a24901..a856d9fe 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -44,7 +44,7 @@ MIN_HD_CHANNEL_SIZE = 2 # key process线程数 DEFAULT_KP_THREAD_NUM = 6 MIN_KP_THREAD_NUM = 1 -MAX_KP_THREAD_NUM = 32 +MAX_KP_THREAD_NUM = 10 # Fast unique去重最大线程数 DEFAULT_FAST_UNIQUE_THREAD_NUM = 8 diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index bf17c309..84b618b0 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -25,7 +25,12 @@ def parse_hccl_json(): rank_table_path = os.path.realpath(global_env.rank_table_file) with open(rank_table_path, "r", encoding="utf-8") as file: - table_hccl = json.load(file) + try: + table_hccl = json.load(file) + except 
FileNotFoundError as e: + raise ValueError("rank table file not found") from e + except json.JSONDecodeError as e: + raise ValueError("rank table file is unable to parse as json") from e if "server_list" not in table_hccl: raise AttributeError(f"Lack of attribute server_list.") if not table_hccl.get("server_list"): diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index 355f1640..a9886e06 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -8,7 +8,7 @@ from mx_rec.constants.constants import EnvOption, RecPyLogLevel, Flag, EMPTY_STR DEFAULT_HD_CHANNEL_SIZE, DEFAULT_KP_THREAD_NUM, DEFAULT_FAST_UNIQUE_THREAD_NUM, RecCPPLogLevel, MAX_INT32, \ MIN_HD_CHANNEL_SIZE, MAX_HD_CHANNEL_SIZE, MIN_KP_THREAD_NUM, MAX_KP_THREAD_NUM, \ MIN_FAST_UNIQUE_THREAD_NUM, MAX_FAST_UNIQUE_THREAD_NUM, DEFAULT_HOT_EMB_UPDATE_STEP, MIN_HOT_EMB_UPDATE_STEP, \ - MAX_HOT_EMB_UPDATE_STEP + MAX_HOT_EMB_UPDATE_STEP, TFDevice from mx_rec.validator.validator import para_checker_decorator, OptionValidator, DirectoryValidator, Convert2intValidator @@ -46,7 +46,7 @@ def get_global_env_conf() -> RecEnv: ascend_visible_devices=os.getenv(EnvOption.ASCEND_VISIBLE_DEVICES.value), cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value), cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value), - tf_device=os.getenv(EnvOption.TF_DEVICE.value), + tf_device=os.getenv(EnvOption.TF_DEVICE.value, TFDevice.NPU.value), apply_gradients_strategy=os.getenv(EnvOption.APPLY_GRADIENTS_STRATEGY.value, ApplyGradientsStrategy.DIRECT_APPLY.value), acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"), diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index eb624b24..45d95b80 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -42,8 +42,6 @@ class ConfigInitializer: ]) def __init__(self, use_mpi=True, **kwargs): self._use_mpi = use_mpi - self._rank_id = kwargs.get("rank_id", 0) - self._rank_size = kwargs.get("rank_size", 1) self._ascend_global_hashtable_collection = ASCEND_GLOBAL_HASHTABLE_COLLECTION self._comm = None self._asc_manager = None @@ -409,10 +407,6 @@ class ConfigInitializer: @ascend_global_hashtable_collection.setter def ascend_global_hashtable_collection(self, name): - string_validator = StringValidator(name="hashtable_collection", value=name, - max_len=HASHTABLE_COLLECTION_NAME_LENGTH, min_len=1) - if not string_validator.check_string_length().check_whitelist().is_valid(): - raise ValueError(string_validator.msg) self._ascend_global_hashtable_collection = name def get_initializer(self, is_training): @@ -448,7 +442,7 @@ class ConfigInitializer: @para_checker_decorator(check_option_list=[ ("name", ClassValidator, {"classes": (str, type(None))}), - ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length", "check_whitelist"]), ]) def set_ascend_global_hashtable_collection(name=ASCEND_GLOBAL_HASHTABLE_COLLECTION): ConfigInitializer.get_instance().ascend_global_hashtable_collection = name @@ -936,6 +930,11 @@ def bind_cpu(rank_id: int, local_rank_size: int): import math total_cpu, cpu_range_list = get_available_cpu_num_and_range() + + if local_rank_size <= 0: + logger.error(f"local rank size 's value less than or equal 0.") + return + avg_count = math.ceil(total_cpu / local_rank_size) while True: if avg_count == 0: diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py index 900d1dad..c74b8718 100644 --- 
a/mx_rec/util/variable.py +++ b/mx_rec/util/variable.py @@ -22,18 +22,3 @@ def check_and_get_config_via_var(variable, optimizer_type: str): f" init method of SparseEmbedding.") return table_instance - - -def check_param_range(name, value, min_border, max_border): - if value > max_border or value < min_border: - raise ValueError(f"Please offer a {name} between [{min_border}, {max_border}].") - - return - - -def check_param_type(name, value, legal_type): - if not isinstance(value, legal_type): - raise TypeError(f"Please offer a {name} within types: {legal_type}.") - - return - diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 63ca11cb..6809bd95 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -66,7 +66,10 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos } if (GlobalEnv::fastUnique) { - Factory::Create(factory); + int result = Factory::Create(factory); + if (result != 0) { + throw runtime_error(Logger::Format("create fast factory failed, error code:{}", result)); + } } LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", -- Gitee From f4b714692820be59bdc62a212ab1f4b4f70c2a2b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 6 Nov 2023 23:53:01 +0800 Subject: [PATCH 421/551] Match-id-d5856b2634876a792a66697789a5031a01430c36 --- mx_rec/core/embedding.py | 2 +- src/core/emb_table/emb_table.h | 6 +++--- .../truncated_normal_initializer.cpp | 21 ++++++++++++++++++- .../truncated_normal_initializer.h | 5 +++++ src/core/ssd_cache/lfu_cache.cpp | 7 ------- src/core/ssd_cache/lfu_cache.h | 2 -- src/tests/ssd_cache/cache_manager_test.cpp | 11 ++++++++-- src/tests/ssd_cache/lfu_cache_test.cpp | 15 +++++++++---- 8 files changed, 49 insertions(+), 20 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 80f14f2f..ec9d7604 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -45,7 +45,7 @@ from mx_rec.util.log import logger ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_data_path", ClassValidator, {"classes": (list, tuple)}), ("is_save", ClassValidator, {"classes": (bool, )}), - ("init_param", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("init_param", NumValidator, {"min_value": -10, "max_value": 10}, ["check_value"]), ("all2all_gradients_op", OptionValidator, {"options": [i.value for i in list(All2allGradientsOp)]}), ("value_dtype", OptionValidator, {"options": [tf.float32]}), ("shard_num", IntValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h index 8136165e..62200ada 100644 --- a/src/core/emb_table/emb_table.h +++ b/src/core/emb_table/emb_table.h @@ -53,10 +53,10 @@ namespace MxRec { constexpr static int TEST_EMB_SIZE = 12; EmbInfo embInfo; RankInfo rankInfo; - int blockSize = 1; + size_t blockSize = 1; int embSize = 1; - int totalCapacity = 1; - int usedCapacity = 0; + size_t totalCapacity = 1; + size_t usedCapacity = 0; int seed = 0; // embedding地址的列表 list embeddingList; diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index d3ea48fb..0ee9c336 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ 
b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -12,9 +12,28 @@ using namespace MxRec; TruncatedNormalInitializer::TruncatedNormalInitializer(int start, int len, NormalInitializerInfo& initInfo) - : start(start), len(len), mean(initInfo.mean), stddev(initInfo.stddev), seed(initInfo.seed) + : start(start), len(len), seed(initInfo.seed) { initParam = initInfo.initK; + // 校验stddev mean值范围 + if (initInfo.mean > TRUNCATED_NORMAL_MEAN_MAX) { + LOG_WARN("truncated normal mean param is greater than 1e9, and will use 10e9."); + mean = TRUNCATED_NORMAL_MEAN_MAX; + } else if (initInfo.mean < TRUNCATED_NORMAL_MEAN_MIN) { + LOG_WARN("truncated normal mean param is less than -1e9, and will use -10e9."); + mean = TRUNCATED_NORMAL_MEAN_MIN; + } else { + mean = initInfo.mean; + } + if (initInfo.stddev > TRUNCATED_NORMAL_STDDEV_MAX) { + LOG_WARN("truncated normal stddev param is greater than 100, and will use 100."); + stddev = TRUNCATED_NORMAL_STDDEV_MAX; + } else if (initInfo.stddev < TRUNCATED_NORMAL_STDDEV_MIN) { + LOG_WARN("truncated normal stddev param is less than -100, and will use -100."); + stddev = TRUNCATED_NORMAL_STDDEV_MIN; + } else { + stddev = initInfo.stddev; + } generator = std::default_random_engine(seed); distribution = std::normal_distribution(mean, stddev); diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index 8de6ad52..e7d9ea5f 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -15,6 +15,11 @@ namespace MxRec { using namespace std; + constexpr float TRUNCATED_NORMAL_STDDEV_MAX = 100; + constexpr float TRUNCATED_NORMAL_STDDEV_MIN = -100; + constexpr float TRUNCATED_NORMAL_MEAN_MAX = 1e9; + constexpr float TRUNCATED_NORMAL_MEAN_MIN = -1e9; + class TruncatedNormalInitializer : public Initializer { public: TruncatedNormalInitializer() = default; diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/ssd_cache/lfu_cache.cpp index 9e312c7a..2ceb5607 100644 --- a/src/core/ssd_cache/lfu_cache.cpp +++ b/src/core/ssd_cache/lfu_cache.cpp @@ -91,13 +91,6 @@ void LFUCache::Put(emb_key_t key) keyTable[key] = freqTable[freq + 1].begin(); } -void LFUCache::PutKeys(vector& keys) -{ - for (auto key : keys) { - Put(key); - } -} - /// 直接放入指定次数;用于初始化场景 /// \param key key /// \param freq 频次 diff --git a/src/core/ssd_cache/lfu_cache.h b/src/core/ssd_cache/lfu_cache.h index 170e0fc7..46584474 100644 --- a/src/core/ssd_cache/lfu_cache.h +++ b/src/core/ssd_cache/lfu_cache.h @@ -41,8 +41,6 @@ namespace MxRec { void Put(emb_key_t key); - void PutKeys(vector& keys); - bool Pop(emb_key_t key); void PutWithInit(emb_key_t key, freq_num_t freq); diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp index 45c589bf..767dee98 100644 --- a/src/tests/ssd_cache/cache_manager_test.cpp +++ b/src/tests/ssd_cache/cache_manager_test.cpp @@ -64,6 +64,13 @@ void InitDDREmbData(absl::flat_hash_map& loadData, string& loadData[embTableName] = hEmbTable; } +void PutKeyInfo(LFUCache& lfu, vector& embKeys) +{ + for (auto& key : embKeys) { + lfu.Put(key); + } +} + class CacheManagerTest : public testing::Test { protected: void SetUp() @@ -74,10 +81,10 @@ protected: GlogConfig::gRankId = to_string(workRankId); cacheManager.ddrKeyFreqMap[embTableName] = cache; - 
cacheManager.ddrKeyFreqMap[embTableName].PutKeys(input_keys); + PutKeyInfo(cacheManager.ddrKeyFreqMap[embTableName], input_keys); LFUCache cache2; cacheManager.ddrKeyFreqMap[embTableName2] = cache2; - cacheManager.ddrKeyFreqMap[embTableName2].PutKeys(input_keys); + PutKeyInfo(cacheManager.ddrKeyFreqMap[embTableName2], input_keys); unordered_map excludeDDRKeyFreq; excludeDDRKeyFreq[27] = 10; excludeDDRKeyFreq[30] = 10; diff --git a/src/tests/ssd_cache/lfu_cache_test.cpp b/src/tests/ssd_cache/lfu_cache_test.cpp index 6e4bd487..76d5f7fa 100644 --- a/src/tests/ssd_cache/lfu_cache_test.cpp +++ b/src/tests/ssd_cache/lfu_cache_test.cpp @@ -35,10 +35,17 @@ inline void CompareHandleRet(vector& leastFreqKeys, vector& embKeys) +{ + for (auto& key : embKeys) { + lfu.Put(key); + } +} + TEST(LFUCache, TestGetFreqTable) { LFUCache cache; - cache.PutKeys(INPUT_KEYS); + PutKeys(cache, INPUT_KEYS); auto ret = cache.GetFreqTable(); ASSERT_EQ(ret[9], 1); ASSERT_EQ(ret[6], 2); @@ -48,7 +55,7 @@ TEST(LFUCache, TestGetFreqTable) TEST(LFUCache, PopTest) { LFUCache cache; - cache.PutKeys(INPUT_KEYS); + PutKeys(cache, INPUT_KEYS); cache.Pop(8); cache.Pop(9); ASSERT_EQ(cache.minFreq, 2); @@ -79,7 +86,7 @@ TEST(LFUCache, PutInitTest) TEST(LFUCache, LFUDeleteTotalFreqListTest) { LFUCache cache; - cache.PutKeys(INPUT_KEYS); + PutKeys(cache, INPUT_KEYS); vector retainedKeys = {4, 6, 8, 9}; vector leastFreqKeys; vector leastFreq; @@ -92,7 +99,7 @@ TEST(LFUCache, LFUDeleteTotalFreqListTest) TEST(LFUCache, BaseCacheTest) { LFUCache cache; - cache.PutKeys(INPUT_KEYS); + PutKeys(cache, INPUT_KEYS); vector retainedKeys = {8, 4, 6, 2}; vector leastFreqKeys; vector leastFreq; -- Gitee From 5271ec4985cc318d28732f0c3df16392a7eb8166 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 7 Nov 2023 15:15:29 +0800 Subject: [PATCH 422/551] Match-id-05899f6a5ec8d39ca4c3756d41323870bb2c8177 --- mx_rec/graph/patch.py | 35 ++++++++++++++++++---------- mx_rec/saver/sparse.py | 2 +- src/core/hd_transfer/hd_transfer.cpp | 4 +++- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +++- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index d3f920f7..a28713c1 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -26,6 +26,9 @@ from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_ from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook from mx_rec.graph.merge_lookup import do_merge_lookup from mx_rec.util.log import logger +from mx_rec.validator.validator import para_checker_decorator, ClassValidator + +MAX_DEEP_RECUR = 500 def init_dataset(self, input_data): @@ -41,6 +44,12 @@ def init_dataset(self, input_data): self._graph_attr = ops.get_default_graph() +@para_checker_decorator(check_option_list=[ + ("fetches", ClassValidator, {"classes": (str, tf.Operation, tf.Tensor, tf.sparse.SparseTensor, list, tuple, dict)}), + ("feed_dict", ClassValidator, {"classes": (tf.Tensor, tf.sparse.SparseTensor, list, tuple, dict, type(None))}), + ("options", ClassValidator, {"classes": (tf.compat.v1.RunOptions, type(None))}), + ("run_metadata", ClassValidator, {"classes": (tf.compat.v1.RunMetadata, type(None))}), +]) def run(self, fetches, feed_dict=None, options=None, run_metadata=None): """ @@ -74,15 +83,17 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None): all_op = [] - def get_all_tensor(tensor_or_tensorlist): + def get_all_tensor(tensor_or_tensorlist, deep=0): + if deep >= MAX_DEEP_RECUR: + raise RuntimeError("Maximum recursion depth 
reached, fetches is too long to parse") # 把所有的tensor和Operation取出来 - if isinstance(tensor_or_tensorlist, (list, tuple)) : + if isinstance(tensor_or_tensorlist, (list, tuple)): for i in tensor_or_tensorlist: - get_all_tensor(i) + get_all_tensor(i, deep+1) elif isinstance(tensor_or_tensorlist, dict): for k in tensor_or_tensorlist.keys(): - get_all_tensor(tensor_or_tensorlist.get(k)) - elif isinstance(tensor_or_tensorlist, (tf.Tensor, tf.Operation)): + get_all_tensor(tensor_or_tensorlist.get(k), deep+1) + elif isinstance(tensor_or_tensorlist, (tf.Tensor, tf.Operation, tf.sparse.SparseTensor)): name = tensor_or_tensorlist.name if ":" in name: name = name[:name.find(":")] @@ -109,13 +120,13 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None): name2channel_cache = self.get_mxrec_name2channel_cache() # 查找相应的channel_id - get_all_tensor(fetches) + get_all_tensor(fetches, deep=0) try: channel_id = get_channel_id_by_sub_graph(all_op, name2channel_cache) except AssertionError: channel_id = -1 - if channel_id != -1: + if channel_id != -1 and get_asc_manager(): get_asc_manager().block_notify_wake(channel_id) if channel_id == constants.EVAL_CHANNEL_ID: @@ -127,7 +138,7 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None): # 调用tensorflow原生的方法 result = self.old_run_method(fetches, feed_dict, options, run_metadata) - if channel_id != -1: + if channel_id != -1 and get_asc_manager(): get_asc_manager().block_count_steps(channel_id, steps) return result @@ -141,15 +152,15 @@ def patch_for_session(): def get_mxrec_steps(self): try: # 不能在未调用非__init__函数之前调用非__init__中定义的实例化属性 - return self.steps + return self.mxrec_steps except AttributeError: - self.steps = 1 + self.mxrec_steps = 1 for custom_optimizer in self.get_config().graph_options.rewrite_options.custom_optimizers: if custom_optimizer.name == "NpuOptimizer" \ and custom_optimizer.parameter_map["iterations_per_loop"].i != 0: - self.steps = custom_optimizer.parameter_map["iterations_per_loop"].i + self.mxrec_steps = custom_optimizer.parameter_map["iterations_per_loop"].i break - return self.steps + return self.mxrec_steps def get_mxrec_name2channel_cache(self): try: diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 581acb93..a8d3321e 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -166,7 +166,7 @@ class SparseProcessor: @para_checker_decorator(check_option_list=[ - ("table_list", ClassValidator, {"classes": (list, )}) + ("table_list", ClassValidator, {"classes": (list, type(None))}) ]) def export(**kwargs): empty_value = 0 diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index 44736319..e92537dc 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -57,7 +57,9 @@ void HDTransfer::Destroy() LOG_INFO(HD + "destroy channel start"); for (auto& c: transferChannels) { LOG_INFO(HD + "start destroy channel:{}", c.first); - acltdtDestroyChannel(c.second); + if (acltdtStopChannel(c.second)!=ACL_ERROR_NONE || acltdtDestroyChannel(c.second)!=ACL_ERROR_NONE) { + throw runtime_error("Acl destroy channel failed."); + } LOG_INFO(HD + "destroy channel:{}", c.first); } for (auto& d: aclDatasets) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index ff9a1e43..47718fb1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -825,7 +825,9 @@ void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start) TimeCost hostEmbsTC; 
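The EmbHDTransWrap hunk that follows re-checks isRunning after the blocking Join(), because Destroy() can clear the flag while the worker is parked and the wake-up must not proceed into the transfer. A reduced, self-contained model of that shutdown race, using a condition variable in place of the host-embedding join (all names are illustrative):

```cpp
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

std::mutex mut;
std::condition_variable cv;
std::atomic<bool> isRunning{true};
bool dataReady = false;

// Worker parked on a blocking wait, like EmbHDTransWrap inside Join().
void Worker() {
    std::unique_lock<std::mutex> lk(mut);
    cv.wait(lk, [] { return dataReady || !isRunning.load(); });
    if (!isRunning.load()) {
        return;  // mirrors the added guard: this wake-up was a shutdown
    }
    std::cout << "transfer embeddings\n";  // only on a genuine data wake-up
}

int main() {
    std::thread t(Worker);
    {
        std::lock_guard<std::mutex> lk(mut);
        isRunning = false;  // shutdown requested before any data arrived
    }
    cv.notify_all();
    t.join();  // worker exits cleanly instead of starting a transfer
    return 0;
}
```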
hostEmbs->Join(channelId); LOG_DEBUG("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS()); - + if (!isRunning) { + return; + } EmbHDTrans(channelId, batchId); } -- Gitee From c24624329192bcb6c82a773233107fea26d8bbac Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 7 Nov 2023 14:51:16 +0800 Subject: [PATCH 423/551] Match-id-d9776973b4f93ea20e25ab6c6c4a802e17671a7d --- .../op_host/embedding_lookup_by_address.cpp | 91 ++++++++++++------- .../embedding_lookup_by_address_tiling.h | 5 + .../op_host/embedding_update_by_address.cpp | 71 ++++++++++----- .../embedding_update_by_address_tiling.h | 5 + .../op_kernel/embedding_lookup_by_address.cpp | 32 ++----- .../op_kernel/embedding_update_by_address.cpp | 62 +++++-------- 6 files changed, 150 insertions(+), 116 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 52ca55f4..763ecc8f 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -33,40 +33,67 @@ namespace optiling currentWorkspace[0] = sysWorkspaceSize + usrSize; - int32_t block_total_nums = 48; - int32_t ub_limit = 175 * 1024; + int32_t blockTotalNums = 48; + int32_t ubLimit = 175 * 1024; auto *attrs = context->GetAttrs(); - const auto *attr0_value = attrs->GetAttrPointer(0); - if (CheckNullPointer(attr0_value, " Lookup embbeding_type attr0_value") != ge::GRAPH_SUCCESS) { + if (CheckNullPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + const auto *attr0Value = attrs->GetAttrPointer(0); + if (CheckNullPointer(attr0Value, " Lookup embbedingType attr0Value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - int32_t embbeding_dim = *attr0_value; - if (embbeding_dim <= 0) { - printf("embbeding_dim must larger than 0\n"); + int32_t embbedingDim = *attr0Value; + if (embbedingDim <= 0) { + printf("embbedingDim must larger than 0\n"); return ge::GRAPH_FAILED; } - const auto *attr1_value = attrs->GetAttrPointer(1); - if (CheckNullPointer(attr1_value, "Lookup embbeding_type attr1_value") != ge::GRAPH_SUCCESS) { + const auto *attr1Value = attrs->GetAttrPointer(1); + if (CheckNullPointer(attr1Value, "Lookup embbedingType attr1Value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - int32_t embbeding_type = *attr1_value; + int32_t embbedingType = *attr1Value; auto inputTensor = context->GetInputTensor(0); if (CheckNullPointer(inputTensor, "inputTensor") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - int32_t input_shape = inputTensor->GetShapeSize(); + int32_t inputShape = inputTensor->GetShapeSize(); + int32_t singleDataSize = 4; + if (embbedingType == 2) { + singleDataSize = 2; + } + int32_t minMoveNum = 32 / singleDataSize; + + // onceMoveNums,(embbedingDim - 1 + minMoveNum) / min_move_num表示除以min_move_num向下取整 + int32_t onceMoveNums = minMoveNum * ((embbedingDim - 1 + minMoveNum) / minMoveNum); + + int32_t numToMove = (embbedingDim - 1 + onceMoveNums) / onceMoveNums; + // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 + int32_t pingPongNum = 1; + int32_t occupyAddressBytesNum = + sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * pingPongNum * 2; + // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 + int32_t addrMaxNum = (((ubLimit / occupyAddressBytesNum) / 4)) * 4; + if (addrMaxNum <= 0) { + return ge::GRAPH_FAILED; + } + + tiling.set_embbeding_type(embbedingType); + 
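The UB budgeting in the lookup tiling above is easier to audit with concrete numbers. A worked sketch of the same arithmetic for an assumed float32 table with embbedingDim = 100 (the dim value is illustrative; the 175 KiB limit, 32-byte minimum move unit, and round-down-to-4 alignment come from the code above):

```cpp
#include <cstdint>
#include <iostream>

int main() {
    const int32_t ubLimit = 175 * 1024;   // 179200 bytes of unified buffer
    const int32_t embbedingDim = 100;     // illustrative embedding width
    const int32_t singleDataSize = 4;     // float32
    const int32_t pingPongNum = 1;

    // A 32-byte DMA unit holds 8 float32 elements; round the dim up to it.
    const int32_t minMoveNum = 32 / singleDataSize;                              // 8
    const int32_t onceMoveNums =
        minMoveNum * ((embbedingDim - 1 + minMoveNum) / minMoveNum);             // 104
    const int32_t numToMove = (embbedingDim - 1 + onceMoveNums) / onceMoveNums;  // 1

    // 8-byte address plus a doubled payload copy per move.
    const int32_t occupyAddressBytesNum = static_cast<int32_t>(sizeof(int64_t)) +
        singleDataSize * onceMoveNums * numToMove * pingPongNum * 2;             // 840

    // Round down to a multiple of 4 so the 8-byte addresses stay 32-byte aligned.
    const int32_t addrMaxNum = ((ubLimit / occupyAddressBytesNum) / 4) * 4;      // 212

    std::cout << "onceMoveNums=" << onceMoveNums
              << " occupyAddressBytesNum=" << occupyAddressBytesNum
              << " addrMaxNum=" << addrMaxNum << '\n';
    return 0;
}
```

With those inputs one tiling round can service 212 addresses, each costing 840 bytes of UB: an 8-byte address plus two 416-byte copies of the padded row.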
tiling.set_update_dim(embbedingDim); + tiling.set_addr_nums(inputShape); + tiling.set_ub_limit(ubLimit); - tiling.set_embbeding_type(embbeding_type); - tiling.set_update_dim(embbeding_dim); - tiling.set_addr_nums(input_shape); - tiling.set_ub_limit(ub_limit); + tiling.set_addr_max_num(addrMaxNum); + tiling.set_ping_pong_num(pingPongNum); + tiling.set_single_data_size(singleDataSize); + tiling.set_once_move_nums(onceMoveNums); - context->SetBlockDim(block_total_nums); + context->SetBlockDim(blockTotalNums); tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); @@ -79,8 +106,8 @@ namespace ge static ge::graphStatus InferShape1(gert::InferShapeContext *context) { - gert::Shape *y_shape = context->GetOutputShape(0); - if (optiling::CheckNullPointer(y_shape, "y_shape") != ge::GRAPH_SUCCESS) { + gert::Shape *yShape = context->GetOutputShape(0); + if (optiling::CheckNullPointer(yShape, "yShape") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -89,23 +116,23 @@ namespace ge return ge::GRAPH_FAILED; } - const auto *attr0_value = attrs->GetAttrPointer(0); - if (optiling::CheckNullPointer(attr0_value, "Lookup embbeding_type attr0_value") != ge::GRAPH_SUCCESS) { + const auto *attr0Value = attrs->GetAttrPointer(0); + if (optiling::CheckNullPointer(attr0Value, "Lookup embbedingType attr0Value") != ge::GRAPH_SUCCESS) { return GRAPH_FAILED; } - int64_t update_dim = *attr0_value; + int64_t updateDim = *attr0Value; - int64_t input_shape = context->GetInputTensor(0)->GetShapeSize(); - y_shape->SetDimNum(2); - y_shape->SetDim(0, input_shape); - y_shape->SetDim(1, update_dim); + int64_t inputShape = context->GetInputTensor(0)->GetShapeSize(); + yShape->SetDimNum(2); + yShape->SetDim(0, inputShape); + yShape->SetDim(1, updateDim); return GRAPH_SUCCESS; } static ge::graphStatus InferDataType1(gert::InferDataTypeContext *context) { - int64_t embbeding_type; + int64_t embbedingType; if (optiling::CheckNullPointer(context, "context") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } @@ -115,21 +142,21 @@ namespace ge return ge::GRAPH_FAILED; } - const auto *attr1_value = attrs->GetAttrPointer(1); - if (optiling::CheckNullPointer(attr1_value, "Lookup embbeding_type") != ge::GRAPH_SUCCESS) { + const auto *attr1Value = attrs->GetAttrPointer(1); + if (optiling::CheckNullPointer(attr1Value, "Lookup embbedingType") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - embbeding_type = *attr1_value; - if (embbeding_type == 0) + embbedingType = *attr1Value; + if (embbedingType == 0) { context->SetOutputDataType(0, ge::DataType(DT_INT32)); } - else if (embbeding_type == 1) + else if (embbedingType == 1) { context->SetOutputDataType(0, ge::DataType(DT_FLOAT)); } - else if (embbeding_type == 2) + else if (embbedingType == 2) { context->SetOutputDataType(0, ge::DataType(DT_FLOAT16)); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h index b91f759b..596bd715 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h @@ -8,6 +8,11 @@ BEGIN_TILING_DATA_DEF(TilingData1) TILING_DATA_FIELD_DEF(int32_t, ub_limit); TILING_DATA_FIELD_DEF(int32_t, embbeding_type); TILING_DATA_FIELD_DEF(int32_t, update_type); + TILING_DATA_FIELD_DEF(int32_t, addr_max_num); + TILING_DATA_FIELD_DEF(int32_t, ping_pong_num); + 
TILING_DATA_FIELD_DEF(int32_t, single_data_size); + TILING_DATA_FIELD_DEF(int32_t, once_move_nums); + END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingLookupByAddress, TilingData1) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index 0d82535e..ee4dfea2 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -31,7 +31,7 @@ namespace optiling TilingData2 tiling; size_t usrSize = 256, sysWorkspaceSize = 16 * 1024 * 1024; - if (CheckPointer(context, "Update embbeding_type context") != ge::GRAPH_SUCCESS) + if (CheckPointer(context, "Update embbedingType context") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; size_t *currentWorkspace = context->GetWorkspaceSizes(1); @@ -40,22 +40,26 @@ namespace optiling currentWorkspace[0] = sysWorkspaceSize + usrSize; - int32_t block_total_nums = 48; - int32_t ub_limit = 175 * 1024; - int32_t update_dim, embbeding_type; + int32_t blockTotalNums = 48; + int32_t ubLimit = 175 * 1024; + auto inputTensor = context->GetInputTensor(0); if (CheckPointer(inputTensor, "GetInputTensor inputTensor") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - int32_t input_shape = inputTensor->GetShapeSize(); - if (CheckPositiveInt(input_shape, "input_shape") != ge::GRAPH_SUCCESS) + int32_t inputShape = inputTensor->GetShapeSize(); + if (CheckPositiveInt(inputShape, "inputShape") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; auto inputTensor1 = context->GetInputTensor(1); if (CheckPointer(inputTensor1, "GetInputTensor inputTensor1") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - int32_t input_dim = inputTensor1->GetShapeSize() / input_shape; + int32_t inputDim = inputTensor1->GetShapeSize() / inputShape; + if (CheckPositiveInt(inputDim, "inputDim") != ge::GRAPH_SUCCESS) { + return ge::GRAPH_FAILED; + } + auto attrs = context->GetAttrs(); if (CheckPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; @@ -64,26 +68,49 @@ namespace optiling if (CheckPointer(attrPointer, "attrPointer") != ge::GRAPH_SUCCESS) return ge::GRAPH_FAILED; - int32_t update_type = *(attrPointer); - ge::DataType input_datatype = inputTensor1->GetDataType(); - if (input_datatype == ge::DT_FLOAT16) { - embbeding_type = 2; - } else if (input_datatype == ge::DT_INT32) { - embbeding_type = 0; + int32_t updateType = *(attrPointer); + ge::DataType inputDatatype = inputTensor1->GetDataType(); + int32_t embbedingType; + if (inputDatatype == ge::DT_FLOAT16) { + embbedingType = 2; + } else if (inputDatatype == ge::DT_INT32) { + embbedingType = 0; } else { - embbeding_type = 1; + embbedingType = 1; } - update_dim = input_dim; - if (CheckPositiveInt(update_dim, "update_dim") != ge::GRAPH_SUCCESS) + int32_t singleDataSize = 4; + if (embbedingType == 2) { + singleDataSize = 2; + } + int32_t minMoveNum = 32 / singleDataSize; + + // onceMoveNums,(updateDim - 1 + minMoveNum) / min_move_num表示除以min_move_num向下取整 + int32_t onceMoveNums = minMoveNum * ((inputDim - 1 + minMoveNum) / minMoveNum); + + int32_t numToMove = (inputDim - 1 + onceMoveNums) / onceMoveNums; + // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 + int32_t pingPongNum = 1; + int32_t occupyAddressBytesNum = + sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * pingPongNum * 2; + // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 + int32_t addrMaxNum = ((int)((int)(ubLimit / 
occupyAddressBytesNum) / 4)) * 4; + if (CheckPositiveInt(addrMaxNum, "addrMaxNum") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } + + tiling.set_update_type(updateType); + tiling.set_embbeding_type(embbedingType); + tiling.set_update_dim(inputDim); + tiling.set_addr_nums(inputShape); + tiling.set_ub_limit(ubLimit); + + tiling.set_addr_max_num(addrMaxNum); + tiling.set_ping_pong_num(pingPongNum); + tiling.set_single_data_size(singleDataSize); + tiling.set_once_move_nums(onceMoveNums); - tiling.set_update_type(update_type); - tiling.set_embbeding_type(embbeding_type); - tiling.set_update_dim(update_dim); - tiling.set_addr_nums(input_shape); - tiling.set_ub_limit(ub_limit); - context->SetBlockDim(block_total_nums); + context->SetBlockDim(blockTotalNums); tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h index 323014d3..9cd630a1 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h @@ -9,6 +9,11 @@ BEGIN_TILING_DATA_DEF(TilingData2) TILING_DATA_FIELD_DEF(int32_t, ub_limit); TILING_DATA_FIELD_DEF(int32_t, embbeding_type); TILING_DATA_FIELD_DEF(int32_t, update_type); + TILING_DATA_FIELD_DEF(int32_t, addr_max_num); + TILING_DATA_FIELD_DEF(int32_t, ping_pong_num); + TILING_DATA_FIELD_DEF(int32_t, single_data_size); + TILING_DATA_FIELD_DEF(int32_t, once_move_nums); + END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2) diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 0adf5d2e..285519b0 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -1,4 +1,3 @@ - #include "kernel_operator.h" using namespace AscendC; template @@ -32,39 +31,24 @@ public: { GET_TILING_DATA(constData, tiling); // 数据的维度数 - int32_t updateDim = constData.update_dim; - int32_t embeddingType = constData.embbeding_type; + dim = constData.update_dim; int32_t blockTotalNums = block_num; - int32_t ubLimit = constData.ub_limit; addrNums = constData.addr_nums; - if (embeddingType == 2) - { - singleDataSize = 2; - } - else - { - singleDataSize = 4; - } // 缓冲区数量 - PingpongNum = 1; - int minMoveNum = 32 / singleDataSize; - // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + minMoveNum) / minMoveNum表示除以minMoveNum向下取整 - onceMoveNums = minMoveNum * ((int)(updateDim - 1 + minMoveNum) / minMoveNum); - int numToMove = (int32_t)(updateDim - 1 + onceMoveNums) / onceMoveNums; - // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * PingpongNum * 2; - // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int addrMaxNum = ((int)((int)(ubLimit / occupyAddressBytesNum) / 4)) * 4; + PingpongNum = constData.ping_pong_num; + singleDataSize = constData.single_data_size; + onceMoveNums = constData.once_move_nums; + roundSize = constData.addr_max_num; + int singleNum = (int)(addrNums / blockTotalNums); if (singleNum % 4) { singleNum -= singleNum % 4; } - roundSize = addrMaxNum; + Veclen = roundSize * singleDataSize * 
onceMoveNums; SingleCoreAddrLen = singleNum * sizeof(int64_t); cache = roundSize; - dim = updateDim; } __aicore__ inline void Process() @@ -184,7 +168,7 @@ private: public: int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, cache, Veclen, dim, PingpongNum; int32_t addrNums; - int32_t onceMoveNums, singleDataSize, update_type; + int32_t onceMoveNums, singleDataSize, updateType; private: TPipe pipe; diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index d13075e0..2acd79c0 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -12,7 +12,7 @@ public: NeedComputeAddrLen = SingleCoreAddrLen; if (block_idx == block_num - 1) { - NeedComputeAddrLen = addr_nums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + NeedComputeAddrLen = addrNums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); } round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); @@ -29,40 +29,26 @@ public: { GET_TILING_DATA(constData, tiling); // 数据的维度数 - int32_t update_dim = constData.update_dim; - int32_t embbeding_type = constData.embbeding_type; + dim = constData.update_dim; int32_t block_total_nums = block_num; - int32_t ub_limit = constData.ub_limit; - update_type = constData.update_type; - addr_nums = constData.addr_nums; - if (embbeding_type == 2) - { - singleDataSize = 2; - } - else - { - singleDataSize = 4; - } + updateType = constData.update_type; + addrNums = constData.addr_nums; + // 缓冲区数量 - PingpongNum = 1; - int min_move_num = 32 / singleDataSize; - // onceMoveNums表示每个数据维度需要移动的次数,(update_dim - 1 + min_move_num) / min_move_num表示除以min_move_num向下取整 - onceMoveNums = min_move_num * ((int)(update_dim - 1 + min_move_num) / min_move_num); - int num_to_move = (int32_t)(update_dim - 1 + onceMoveNums) / onceMoveNums; - // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int occupyAddressBytesNum = sizeof(int64_t) + singleDataSize * onceMoveNums * num_to_move * PingpongNum * 2; - // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int addrMaxNum = ((int)((int)(ub_limit / occupyAddressBytesNum) / 4)) * 4; - int singlenum = (int)(addr_nums / block_total_nums); + PingpongNum = constData.ping_pong_num; + singleDataSize = constData.single_data_size; + onceMoveNums = constData.once_move_nums; + roundSize = constData.addr_max_num; + + int singlenum = (int)(addrNums / block_total_nums); if (singlenum % 4) { singlenum -= singlenum % 4; } - roundSize = addrMaxNum; + Veclen = roundSize * singleDataSize * onceMoveNums; SingleCoreAddrLen = singlenum * sizeof(int64_t); cache = roundSize; - dim = update_dim; } __aicore__ inline void Process() @@ -83,13 +69,13 @@ public: if (unprocess) { - int unprocess_once_copyaddr = unprocess; - if (unprocess_once_copyaddr % 4 != 0) + int unprocessOnceCopyaddr = unprocess; + if (unprocessOnceCopyaddr % 4 != 0) { - unprocess_once_copyaddr += (4 - unprocess % 4); + unprocessOnceCopyaddr += (4 - unprocess % 4); } - DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocess_once_copyaddr); + DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocessOnceCopyaddr); MoveProcess(srcAddrLocal, round, unprocess); } } @@ -110,7 +96,7 @@ private: inQueue.EnQue(dataLocal); Compute(sizes); LocalTensor dstLocal = outQueue.DeQue(); - if (update_type == 0) + if (updateType == 0) { SetAtomicAdd(); } @@ -123,7 +109,7 @@ private: 
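
Functionally, update_type selects between accumulation and overwrite: when it is 0 the kernel brackets its copy back to global memory with SetAtomicAdd()/SetAtomicNone(), so gradients arriving at the same address from different cores sum instead of racing. The same semantics on the host side, sketched with NumPy:

    import numpy as np

    def update_by_address(table, rows, values, update_type):
        """table: (vocab, dim) embeddings; rows: indices behind the addresses."""
        if update_type == 0:
            np.add.at(table, rows, values)  # accumulates, duplicate rows included
        else:
            table[rows] = values            # plain overwrite
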
DataCopy(dstDataGm, dstLocal[i*onceMoveNums], onceMoveNums); } } - if (update_type == 0) + if (updateType == 0) { SetAtomicNone(); } @@ -166,7 +152,7 @@ private: { dstDataGm.SetGlobalBuffer((__gm__ T *)(address)); - if (update_type == 0) + if (updateType == 0) { SetAtomicAdd(); } @@ -187,7 +173,7 @@ private: DataCopy(dstDataGm, dstLocal, onceMoveNums); #endif } - if (update_type == 0) + if (updateType == 0) { SetAtomicNone(); } @@ -195,8 +181,8 @@ private: } public: - int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, addr_nums, cache, Veclen, dim, PingpongNum; - int32_t onceMoveNums, singleDataSize, update_type; + int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, addrNums, cache, Veclen, dim, PingpongNum; + int32_t onceMoveNums, singleDataSize, updateType; private: TPipe pipe; @@ -212,9 +198,9 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres { GET_TILING_DATA(constData, tiling); - int32_t embbeding_type = constData.embbeding_type; + int32_t embbedingType = constData.embbeding_type; - switch (embbeding_type) + switch (embbedingType) { case 0: { -- Gitee From 501b95154f7791d1f39dfbc441815bd274826a31 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 09:24:55 +0800 Subject: [PATCH 424/551] Match-id-349d424189439813bf3b46e16a6c8e81e25fc4b8 --- src/core/hd_transfer/acl_channel.h | 1 - src/ops_tf/hybrid_dataset_ops.cpp | 69 ------------------------------ 2 files changed, 70 deletions(-) diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h index 6cbd4e0c..9d0b5b49 100644 --- a/src/core/hd_transfer/acl_channel.h +++ b/src/core/hd_transfer/acl_channel.h @@ -21,7 +21,6 @@ namespace tensorflow { Status RecvTensorByAcl(const acltdtChannelHandle* aclHandle, std::vector& tensors); - Status StopRecvTensorByAcl(acltdtChannelHandle **handle, const std::string &channelName); #endif diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 1b93b31b..7d2e5ad4 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -46,11 +46,6 @@ namespace MxRec { LOG_INFO("clear channel init"); OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { - throw runtime_error(StringFormat( - "channelId is invalid, It should be in range [0, %d)", MAX_CHANNEL_NUM)); - } - if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat( "ClearChannel channelId invalid. 
It should be in range [0, MAX_CHANNEL_NUM:%d)", @@ -525,52 +520,6 @@ namespace MxRec { HybridMgmtBlock* hybridMgmtBlock; }; - class ReadEmbKeyDatasetDummy : public OpKernel { - public: - explicit ReadEmbKeyDatasetDummy(OpKernelConstructionPtr context) : OpKernel(context) - { - OP_REQUIRES_OK(context, context->GetAttr("max_lookup_len", &lookupLen)); - } - - ~ReadEmbKeyDatasetDummy() override = default; - - void Compute(OpKernelContextPtr context) override - { - EASY_FUNCTION(); - TimeCost tc = TimeCost(); - const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0); - auto input = inputTensor.flat(); - const int restoreLen = static_cast(input.size()); - - // write lookup & restore vec - Tensor* lookupVec = nullptr; - Tensor* restoreVecTensor = nullptr; - - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape { lookupLen }, &lookupVec)); - OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape { restoreLen }, &restoreVecTensor)); - auto l = lookupVec->flat(); - auto r = restoreVecTensor->flat(); - - // check whether lookupLen is zero - if (lookupLen == 0) { - throw runtime_error("lookupLen is 0, it causes the denominator to be 0 during division"); - } - - // dummy data - for (int i { 0 }; i < lookupLen; ++i) { - l(i) = i; - } - for (int i { 0 }; i < restoreLen; ++i) { - r(i) = i % lookupLen; - } - LOG_WARN("dummy read batch cost: {},elapsed from last {}", - tc.ElapsedMS(), g_staticSw.ElapsedMS()); - tc = TimeCost(); - } - - int lookupLen {}; - }; - class CustOps : public OpKernel { public: explicit CustOps(OpKernelConstructionPtr context) : OpKernel(context) @@ -644,24 +593,6 @@ return Status::OK(); REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2); -// ##################### ReadEmbKeyDatasetDummy ####################### -REGISTER_OP("ReadEmbKeyDatasetDummy") -.Input("sample: T") -.Output("lookup_vec: int32") -.Output("restore_vec: int32") -.Attr("T: {int64}") -.Attr("max_lookup_len: int") -.SetShapeFn([](InferenceContextPtr c) { -int temp; -TF_RETURN_IF_ERROR(c->GetAttr("max_lookup_len", &temp)); -c->set_output(TensorIndex::TENSOR_INDEX_0, c->Vector(temp)); -c->set_output(TensorIndex::TENSOR_INDEX_1, c->input(TensorIndex::TENSOR_INDEX_0)); -return Status::OK(); -}); - -REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyDatasetDummy").Device(DEVICE_CPU), MxRec::ReadEmbKeyDatasetDummy); - - REGISTER_OP("EmbeddingLookupByAddress") .Input("address: int64") .Attr("embedding_dim: int") -- Gitee From c68b85d563cc388c611eb42201e656dc234bb29d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 09:26:30 +0800 Subject: [PATCH 425/551] Match-id-114a1cfa9b8b128c6f4b2d48e12713929b3756b2 --- mx_rec/constants/constants.py | 3 ++ mx_rec/core/__init__.py | 2 +- mx_rec/core/asc/helper.py | 96 +--------------------------------- mx_rec/core/asc/merge_table.py | 5 ++ mx_rec/core/embedding.py | 11 ++-- mx_rec/core/feature_process.py | 2 +- mx_rec/graph/merge_lookup.py | 2 +- mx_rec/graph/modifier.py | 14 ++++- mx_rec/graph/utils.py | 12 ++++- mx_rec/util/initialize.py | 23 +++++++- 10 files changed, 65 insertions(+), 105 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 60a24901..f04ae201 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -36,6 +36,9 @@ ASCEND_TABLE_NAME_MUST_CONTAIN = None # to avoid op "scatter_nd_update" may get a None tensor for input AVOID_TENSOR_POS = 439999 +# while循环最大深度 +MAX_WHILE_SIZE = 800 + # acl通道数据深度 DEFAULT_HD_CHANNEL_SIZE = 40 
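
MAX_WHILE_SIZE, added just above, caps every previously unbounded graph-walking loop touched later in this patch (find_dangling_table, find_target_dataset_op, get_passing_tensors). The guard pattern those hunks share, as a standalone sketch:

    MAX_WHILE_SIZE = 800  # mirrors the constant defined above

    def bounded_walk(next_to_visit, expand):
        """expand() stands in for one BFS step over graph tensors."""
        while_num = 0
        while next_to_visit:
            while_num += 1
            if while_num > MAX_WHILE_SIZE:
                raise RuntimeError(
                    f"maximum cycle depth is greater than {MAX_WHILE_SIZE}")
            next_to_visit = expand(next_to_visit)
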
MAX_HD_CHANNEL_SIZE = 8192 diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py index 43336ebe..d711ee8e 100644 --- a/mx_rec/core/__init__.py +++ b/mx_rec/core/__init__.py @@ -2,4 +2,4 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["asc", "embedding"] \ No newline at end of file +__all__ = ["asc", "embedding", "feature_process"] diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index fd9d4214..1c1ee48b 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -24,8 +24,6 @@ from mx_rec.constants.constants import MAX_INT32 ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, list, tuple, type(None))}), ("args_index_list", ClassValidator, {"classes": (list, type(None))}), ("table_names", ClassValidator, {"classes": (list, type(None))}), - ("is_training", ClassValidator, {"classes": (bool, type(None))}), - ("dump_graph", ClassValidator, {"classes": (bool, type(None))}), ]) def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs): ''' @@ -265,6 +263,8 @@ def do_insert(args, insert_tensors, splits, table_names, input_dict): def export_read_emb_key_v2_op(args, pipeline_op): origin_batch = list(args) + if len(origin_batch) < 1: + raise ValueError("The length of args is less than 1.") if isinstance(origin_batch[0], dict): output_batch = origin_batch[0] valid_key = get_valid_op_key(output_batch) @@ -401,95 +401,3 @@ def is_feature_spec_list(specs): return False return True - - -def get_asc_read_raw_func(cfg_list): - batch = {} - int_name_order = [] - int_len_list = [] - float_name_order = [] - float_len_list = [] - line_per_sample_list = [] - host_pipeline_ops = get_host_pipeline_ops() - for cfg in cfg_list: - if cfg.data_type == "int64": - int_name_order.append(cfg.feature_name) - int_len_list.append(cfg.feature_len) - line_per_sample_list.append(cfg.line_per_sample) - - if cfg.data_type == "float": - float_name_order.append(cfg.feature_name) - float_len_list.append(cfg.feature_len) - line_per_sample_list.append(cfg.line_per_sample) - if len(set(line_per_sample_list)) != 1: - raise ValueError(f"Please check that each line_per_sample value should be equal.") - line_per_sample = line_per_sample_list[0] - - def read_raw_fn(data_src): - raw_int_sample, raw_float_sample = host_pipeline_ops.read_raw( - sample=data_src, - int_len=sum(int_len_list) * line_per_sample, - float_len=sum(float_len_list) * line_per_sample, - feat_order=int_name_order + float_name_order - ) - - int_split_res = tf.split(raw_int_sample, [i * line_per_sample_list[0] for i in int_len_list]) - - float_split_res = tf.split(raw_float_sample, [i * line_per_sample_list[0] for i in float_len_list]) - - - for name_id, name in enumerate(int_name_order): - batch[name] = int_split_res[name_id] - - for name_id, name in enumerate(float_name_order): - batch[name] = float_split_res[name_id] - return batch - - return read_raw_fn - - -class ParseConfig: - - def __init__(self, **kwargs): - self.input_keys = set(kwargs.keys()) - self._feature_name = kwargs.get("feature_name") - self._feature_len = int(kwargs.get("feature_len")) - self._data_type = kwargs.get("data_type") - self._line_per_sample = int(kwargs.get("line_per_sample")) - self.check_params() - - @property - def feature_name(self): - return self._feature_name - - @property - def feature_len(self): - return self._feature_len - - @property - def data_type(self): - return self._data_type - - @property - def line_per_sample(self): - return 
self._line_per_sample - - def check_params(self): - supported_keys = {"feature_name", "feature_len", "line_per_sample", "data_type"} - if self.input_keys != supported_keys: - raise KeyError("Please offer an expected keyword argument") - - if not isinstance(self._feature_name, str): - raise TypeError(f"Please offer a feature_name with string type.") - - if not isinstance(self._data_type, str): - raise TypeError(f"Please offer a data_type with string type.") - - if self._data_type not in ("int64", "float"): - raise TypeError(f"Please offer a data_type with int64 or float type") - - if self._feature_len <= 0: - raise ValueError(f"Please offer a feature_len greater than zero.") - - if self._line_per_sample <= 0: - raise ValueError(f"Please offer a line_per_sample greater than zero.") diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 38ba8c71..e06a7c7d 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -7,6 +7,7 @@ from typing import Dict, List import tensorflow as tf from tensorflow import Operation, Tensor +from mx_rec.constants.constants import MAX_WHILE_SIZE from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \ get_bool_gauge_set from mx_rec.util.log import logger @@ -103,7 +104,11 @@ def find_dangling_table(table_names: List[str]) -> List[str]: """ tensors_visited = set() op_visited = set() + while_num = 0 while next_to_visit: + while_num += 1 + if while_num > MAX_WHILE_SIZE: + raise RuntimeError(f"In bfs_lookup function, the maximum cycle depth is greater than {MAX_WHILE_SIZE}.") spread_tensors = [] for tensor in next_to_visit: if tensor in tensors_visited: diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 80f14f2f..e7834b53 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -455,7 +455,7 @@ class SparseEmbedding: # return the stub tensor of the lookup result if not get_use_static(): - kwargs["ids"] = ids + kwargs["lookup_ids"] = ids mock_lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value) if not kwargs.get("is_grad"): @@ -673,7 +673,7 @@ class SparseEmbedding: feature_spec_tensor = None if not self.modify_graph: feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key) - modify_graph_tensor = kwargs.get("ids") + modify_graph_tensor = kwargs.get("lookup_ids") tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor if tensor is None: raise KeyError(f"key or ids does not exist in batch, now modify graph is {self.modify_graph}.") @@ -794,7 +794,7 @@ class SparseEmbedding: ("name", ClassValidator, {"classes": (str, type(None))}), ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), ("modify_graph", ClassValidator, {"classes": (bool, type(None))}), - ("batch", ClassValidator, {"classes": (dict, list, tuple, type(None))}), + ("batch", ClassValidator, {"classes": (dict, type(None))}), ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}), ("is_grad", ClassValidator, {"classes": (bool, )}), ]) @@ -831,6 +831,11 @@ def sparse_lookup(hashtable: SparseEmbedding, kwargs["modify_graph"] = modify_graph kwargs["batch"] = batch kwargs["access_and_evict_config"] = access_and_evict_config + # 参数由内部创建,不使用外部入参,覆盖外部入参 + kwargs["feature_spec_name_ids_dict"] = None + kwargs["multi_lookup"] = False + 
kwargs["lookup_ids"] = None + scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name")) logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.", hashtable.table_name, name, is_grad) diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index c2c16f11..6dd6fc3d 100644 --- a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -12,7 +12,7 @@ from mx_rec.validator.validator import para_checker_decorator, ClassValidator, I from mx_rec.util.log import logger -class _EvictHook(tf.compat.v1.train.SessionRunHook): +class EvictHook(tf.compat.v1.train.SessionRunHook): """Sets evict based on global step or time.""" @para_checker_decorator( check_option_list=[ diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 79f2649a..380879a4 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -74,7 +74,7 @@ def do_merge_lookup(is_train: bool = True): continue send_count = table_instance.send_count - kwargs = dict(is_train=is_train, ids=cutting_point, multi_lookup=True, is_grad=is_grad) + kwargs = dict(is_train=is_train, lookup_ids=cutting_point, multi_lookup=True, is_grad=is_grad) if not get_use_static(): kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index d9c2855a..9226af56 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -15,7 +15,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ - ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME + ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \ get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ @@ -155,7 +155,12 @@ def find_target_dataset_op(base_ops, op_type): base_ops = check_input_list(base_ops, tf.Operation) parent_ops = base_ops + while_num = 0 while True: + while_num += 1 + if while_num > MAX_WHILE_SIZE: + raise RuntimeError(f"In find_target_dataset_op function, the maximum cycle depth is greater " + f"than {MAX_WHILE_SIZE}.") for parent_op in parent_ops: if parent_op.type == op_type: return parent_op @@ -204,7 +209,12 @@ def get_passing_tensor_list(src_tensors, target_op): def get_passing_tensors(src_tensor): passing_tensors = [] tensor_list = [src_tensor] + while_num = 0 while tensor_list: + while_num += 1 + if while_num > MAX_WHILE_SIZE: + raise RuntimeError(f"In get_passing_tensors function, the maximum cycle depth is greater " + f"than {MAX_WHILE_SIZE}.") last_tensor = tensor_list.pop() if last_tensor.op is target_op: passing_tensors.append(last_tensor) @@ -557,7 +567,7 @@ class GraphModifierHook(tf.estimator.SessionRunHook): ("modify_graph", ClassValidator, {"classes": (bool,)}) ] ) - def __init__(self, dump_graph=True, modify_graph=True): + def __init__(self, dump_graph=False, modify_graph=True): self._dump_graph = dump_graph self._modify_graph = modify_graph self._iterator_type = "" diff --git 
a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py index e0e81b70..8c063c26 100644 --- a/mx_rec/graph/utils.py +++ b/mx_rec/graph/utils.py @@ -6,9 +6,11 @@ from collections import defaultdict import os import tensorflow as tf +from tensorflow.python.framework.errors_impl import InvalidArgumentError from mx_rec.constants.constants import ASCAnchorAttr, DUMP_MIDIFY_GRAPH_FILE_MODE from mx_rec.core.embedding import SparseEmbedding +from mx_rec.util.log import logger def check_input_list(objs, obj_type): @@ -59,9 +61,15 @@ def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list): raise ValueError(f"Given replacement_specs and new_tensor_list must have the same length. " f"replacement_specs: {replacement_specs}, new_tensor_list: {new_tensor_list}") - for tensor_idx, (_, items) in enumerate(replacement_specs.items()): + for tensor_idx, (old_tensor, items) in enumerate(replacement_specs.items()): for input_idx, operator in items: - operator._update_input(input_idx, new_tensor_list[tensor_idx]) + try: + operator._update_input(input_idx, new_tensor_list[tensor_idx]) + except InvalidArgumentError as err: + logger.info("The replacement specs keys (old batch) is: %s. \n\t\t The new_tensor_list is: %s.", + replacement_specs.keys(), new_tensor_list) + raise RuntimeError(f"Cannot update edge, old tensor: {old_tensor}, " + f"new tensor: {new_tensor_list[tensor_idx]}.") from err def export_pb_graph(file_name, dump_graph, graph_def=None, export_path="./export_graph", as_text=False): diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index eb624b24..8e91ff2c 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -6,7 +6,6 @@ import os from collections import defaultdict import dataclasses import json - import psutil import mx_rec.constants.constants @@ -82,6 +81,7 @@ class ConfigInitializer: self._comm = MPI.COMM_WORLD self._rank_id = self._comm.Get_rank() self._rank_size = self._comm.Get_size() + self.check_mpi_params() else: raise ValueError("only mpi is supported for launching task.") @@ -240,6 +240,12 @@ class ConfigInitializer: ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) + def check_mpi_params(self): + if self._rank_size < 1: + raise ValueError("The length of the mpi rank_size is less than 1.") + if self._rank_id < 0: + raise ValueError("The length of the mpi rank_id is less than 0.") + def terminate(self): logger.info("python process run into terminate") if self._is_terminated: @@ -766,7 +772,19 @@ def set_initializer(is_training, initializer): ConfigInitializer.get_instance().set_initializer(is_training, initializer) +@para_checker_decorator(check_option_list=[ + ("name", ClassValidator, {"classes": (str, list)}), + ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) +]) def set_ascend_table_name_must_contain(name="merged"): + """ + 设置表名中必须包含的关键字 + Args: + name: 表名中必须包含的关键字 + + Returns: None + + """ mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name @@ -802,6 +820,9 @@ def set_target_batch(is_training: bool, batch: dict): ConfigInitializer.get_instance().set_target_batch(is_training, batch) +@para_checker_decorator(check_option_list=[ + ("is_training", ClassValidator, {"classes": (bool, )}) +]) def get_target_batch(is_training: bool) -> dict: """ 返回自动改图模式下生成新数据集中batch的记录. 
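
replace_anchor in the utils.py hunk of this patch rewires every consumer of an old batch tensor onto its replacement via TensorFlow's private Operation._update_input, and now logs both tensor lists when the rewiring is rejected. A minimal TF1 graph-mode illustration of that rewiring call (private API; the new input must have a matching dtype):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    a = tf.constant(1.0)
    b = tf.constant(2.0)
    c = tf.add(a, a)          # c reads `a` at input indices 0 and 1
    c.op._update_input(1, b)  # redirect input 1 to `b`; c now computes a + b
    with tf.Session() as sess:
        print(sess.run(c))    # -> 3.0
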
-- Gitee From 0b369f9e8e20e56f387f07ce6e7bbf0789a20583 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 09:29:17 +0800 Subject: [PATCH 426/551] Match-id-388215b8f5d428691f909e1cfa3e7234efd228e7 --- mx_rec/core/asc/helper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 1c1ee48b..cac318cf 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -24,6 +24,8 @@ from mx_rec.constants.constants import MAX_INT32 ("tgt_key_specs", ClassValidator, {"classes": (FeatureSpec, list, tuple, type(None))}), ("args_index_list", ClassValidator, {"classes": (list, type(None))}), ("table_names", ClassValidator, {"classes": (list, type(None))}), + ("is_training", ClassValidator, {"classes": (bool, type(None))}), + ("dump_graph", ClassValidator, {"classes": (bool, type(None))}), ]) def get_asc_insert_func(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs): ''' -- Gitee From 8e12146c9e933094d08623a3af0c39d70cb33d0d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 10:05:01 +0800 Subject: [PATCH 427/551] Match-id-52fad344486dee47c135ed99cc29f5f52b1a0a8f --- mx_rec/core/asc/helper.py | 7 +++++-- mx_rec/core/embedding.py | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index cac318cf..783871b1 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -6,7 +6,7 @@ from functools import reduce import tensorflow as tf -from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static +from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, get_modify_graph from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.merge_table import find_dangling_table, should_skip from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator, \ @@ -160,7 +160,10 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list): f"len(split_list): {len(split_list)}" f"len(table_name_list): {len(table_name_list)}") feature_id_requests = zip(feature_id_list, split_list, table_name_list) - feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name)) + if get_modify_graph(): + feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2])) + else: + feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name)) logger.debug("features to merge: %s", feature_id_requests) last_table_name = None diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 9be54609..fcd191ae 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -523,9 +523,9 @@ class SparseEmbedding: same_table_tensor_list.append(tensor) return same_table_tensor_list - # Ensure that tensors in the same table are sorted according to the lookup sequence (modify graph mode) or - # the sequence in which feature specs are created (feature spec mode). 
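
Both hunks in this commit relax ordering for the auto-modify-graph path: merge_feature_id_request sorts merged requests by table name alone, and the per-table feature specs (removed sort just below) keep creation order, which already matches lookup order there. Python's sorted() is stable, so requests for the same table retain their original sequence:

    requests = [("ids_b0", 8, "tableB"), ("ids_a0", 4, "tableA"), ("ids_a1", 2, "tableA")]
    merged = sorted(requests, key=lambda x: x[2])
    print([r[0] for r in merged])  # -> ['ids_a0', 'ids_a1', 'ids_b0']
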
- same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) + # 改图模式下FeatureSpec是按照lookup顺序创建的,无需对ids进行排序;fs模式下手动创建FeatureSpec,不一定有序 + if not self.modify_graph: + same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", table_name=table_name) if get_use_static(): -- Gitee From 527864156a73f775916d69fb438ed51e1bbc3e20 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 10:10:12 +0800 Subject: [PATCH 428/551] Match-id-4d979969bc61b100e3b5cd133f3bf6d59bc8e56f --- mx_rec/saver/patch.py | 8 ++--- mx_rec/saver/saver.py | 56 +++++++++++++++++------------- mx_rec/saver/sparse.py | 14 ++++---- src/core/checkpoint/checkpoint.cpp | 49 ++++++++++++++++++++------ src/core/utils/common.h | 2 +- 5 files changed, 82 insertions(+), 47 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 7ce8f5d8..fe9e7ffb 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -175,13 +175,13 @@ def build(self): @para_checker_decorator(check_option_list=[ ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), - ("save_path", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("save_path", StringValidator, {"min_len": 1, "max_len": 150}, ["check_string_length"]), ("global_step", ClassValidator, {"classes": (int, np.int64, type(None))}), ("global_step", OptionalIntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), ("latest_filename", ClassValidator, {"classes": (str, type(None))}), - ("latest_filename", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("latest_filename", OptionalStringValidator, {"min_len": 1, "max_len": 50}, ["check_string_length"]), ("meta_graph_suffix", ClassValidator, {"classes": (str, type(None))}), - ("meta_graph_suffix", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("meta_graph_suffix", OptionalStringValidator, {"min_len": 1, "max_len": 50}, ["check_string_length"]), ("write_meta_graph", ClassValidator, {"classes": (bool, type(None))}), ("write_state", ClassValidator, {"classes": (bool, type(None))}), ("strip_default_attrs", ClassValidator, {"classes": (bool, type(None))}), @@ -227,7 +227,7 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra @para_checker_decorator(check_option_list=[ ("sess", ClassValidator, {"classes": (tf.compat.v1.Session, tf.compat.v1.train.MonitoredSession)}), - ("save_path", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("save_path", StringValidator, {"min_len": 1, "max_len": 150}, ["check_string_length"]), ]) def restore(self, sess, save_path): if save_path is None: diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d2f6a962..52d465b6 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -11,12 +11,13 @@ import numpy as np import tensorflow as tf from tensorflow.python.util import compat -from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice +from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice, MAX_INT32 from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ send_host_data, get_ascend_global_hashtable_collection, 
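
This saver rework drops the HDFS code paths: sparse tables are now written with numpy.ndarray.tofile, which only works on a local filesystem, so save and restore reject any URI-style path up front. The check added later in this commit is just a scheme probe:

    def check_file_system_is_valid(file_path):
        # "://" marks a remote scheme such as hdfs://; only local paths pass
        return file_path.find("://") == -1

    assert check_file_system_is_valid("/data/ckpt/sparse-model")
    assert not check_file_system_is_valid("hdfs://ns1/ckpt/sparse-model")
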
set_sparse_dir, get_local_rank_size from mx_rec.util.perf import performance -from mx_rec.validator.validator import DirectoryValidator, FileValidator +from mx_rec.validator.validator import DirectoryValidator, FileValidator, para_checker_decorator, ClassValidator, \ + IntValidator, OptionalStringValidator from mx_rec.util.global_env_conf import global_env from mx_rec.util.log import logger @@ -37,6 +38,12 @@ class SaveModelThread(threading.Thread): class Saver(object): customized_ops = get_customized_ops() + @para_checker_decorator(check_option_list=[ + ("var_list", ClassValidator, {"classes": (list, type(None))}), + ("max_to_keep", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), + ("prefix_name", ClassValidator, {"classes": (str, type(None))}), + ("prefix_name", OptionalStringValidator, {"min_len": 1, "max_len": 50}, ["check_string_length"]), + ]) def __init__(self, var_list=None, max_to_keep=3, prefix_name="checkpoint"): self.max_to_keep = max_to_keep self._prefix_name = prefix_name @@ -95,6 +102,10 @@ class Saver(object): :return: None """ logger.debug("======== Start saving for rank id %s ========", self.rank_id) + if not check_file_system_is_valid(save_path): + raise ValueError(f"the path to save sparse embedding table data belong to invalid file system, " + f"only local file system supported. ") + save_path = save_path if save_path else self._prefix_name directory, base_name = os.path.split(save_path) @@ -109,11 +120,10 @@ class Saver(object): set_sparse_dir(saving_path) try: - if save_path.find("://") == -1: - directory_validator = DirectoryValidator("saving_path", saving_path) - directory_validator.check_not_soft_link() - directory_validator.with_blacklist(exact_compare=False) - directory_validator.check() + directory_validator = DirectoryValidator("saving_path", saving_path) + directory_validator.check_not_soft_link() + directory_validator.with_blacklist(exact_compare=False) + directory_validator.check() except ValueError as err: raise ValueError(f"The saving path {saving_path} cannot be a system directory " f"and cannot be soft link.") from err @@ -138,6 +148,10 @@ class Saver(object): @performance("Restore") def restore(self, sess, reading_path): logger.debug("======== Start restoring ========") + if not check_file_system_is_valid(reading_path): + raise ValueError(f"the path to save sparse embedding table data belong to invalid file system, " + f"only local file system supported. 
") + directory, base_name = os.path.split(reading_path) ckpt_name = f"sparse-{base_name}" @@ -409,14 +423,7 @@ def write_binary_data(writing_path, suffix, data, attributes=None): if tf.io.gfile.exists(target_attribute_dir): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} exists before writing.") - if target_data_dir.find("://") != -1: - logger.debug("use hdfs path %s to save sparse data.", target_data_dir) - with tf.io.gfile.GFile(target_data_dir, "wb") as file: - data = data.tostring() - file.write(data) - else: - logger.debug("use local file path %s to save sparse data.", target_data_dir) - data.tofile(target_data_dir) + data.tofile(target_data_dir) if attributes is not None: if not isinstance(attributes, dict): @@ -452,13 +459,7 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: with tf.io.gfile.GFile(target_data_dir, "rb") as file: validate_read_file(target_data_dir) - if target_data_dir.find("://") != -1: - logger.debug("use hdfs path %s to restore sparse data.", target_data_dir) - data_to_restore = file.read() - data_to_restore = np.fromstring(data_to_restore, dtype=attributes.pop(DataAttr.DATATYPE.value)) - else: - logger.debug("use local file path %s to restore sparse data.", target_data_dir) - data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) + data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) @@ -482,9 +483,8 @@ def validate_read_file(read_file_path): """ file_validator = FileValidator("read_file_path", read_file_path) file_validator.check_file_size(MAX_FILE_SIZE, MIN_SIZE) - # local file need to check soft link - if read_file_path.find("://") == -1: - file_validator.check_not_soft_link() + file_validator.check_user_group() + file_validator.check_not_soft_link() file_validator.check() @@ -512,3 +512,9 @@ def process_embedding_data(data_to_restore: np.ndarray, current_data_shape: list f"saved vocabulary size {vocab_size},which would loss the mapping between keys and embeddings ") return data_to_restore + + +def check_file_system_is_valid(file_path): + if file_path.find("://") == -1: + return True + return False diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 581acb93..a614cf9d 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -16,7 +16,7 @@ from mx_rec.util.log import logger class SparseProcessor: single_instance = None - def __init__(self, **kwargs): + def __init__(self, table_list): self.export_name = "key-emb" self.device_dir_list = ["HashTable", "HBM"] self.host_dir_list = ["HashTable", "DDR"] @@ -28,7 +28,7 @@ class SparseProcessor: self.attrib_suffix = ".attribute" self.json_attrib_dtype = "data_type" self.json_attrib_shape = "shape" - self.table_list = kwargs.get("table_list") + self.table_list = table_list self.default_table_list = list(export_table_name_set()) if not self.table_list: @@ -38,8 +38,8 @@ class SparseProcessor: self.table_list = check_table_param(self.table_list, self.default_table_list) @staticmethod - def set_instance(**kwargs): - SparseProcessor.single_instance = SparseProcessor(**kwargs) + def set_instance(table_list): + SparseProcessor.single_instance = SparseProcessor(table_list) @staticmethod def _get_data(data_dir, dtype, data_shape): @@ -166,11 +166,11 @@ class SparseProcessor: @para_checker_decorator(check_option_list=[ - ("table_list", ClassValidator, 
{"classes": (list, )}) + ("table_list", ClassValidator, {"classes": (list, type(None))}) ]) -def export(**kwargs): +def export(table_list=None): empty_value = 0 - SparseProcessor.set_instance(**kwargs) + SparseProcessor.set_instance(table_list) if SparseProcessor.single_instance.table_list: return SparseProcessor.single_instance.export_sparse_data() else: diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index e5c8b157..9fe750dc 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -229,16 +229,27 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da int64_t address = transArr.at(i + 1); float *floatPtr = reinterpret_cast(address); - aclError ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float), - floatPtr, embeddingSize * sizeof(float), - ACL_MEMCPY_DEVICE_TO_HOST); + aclError ret; + try { + ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float), + floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); + } catch (std::exception& e) { + writeFile.close(); + throw runtime_error(StringFormat("error happen when acl memory copy from device to host: %s", e.what())); + } + if (ret != ACL_SUCCESS) { LOG_ERROR("aclrtMemcpy failed, ret={}", ret); writeFile.close(); throw runtime_error(Logger::Format("aclrtMemcpy failed, ret={}", ret).c_str()); } - writeFile.write(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); + try { + writeFile.write(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); + } catch (std::exception& e) { + writeFile.close(); + throw runtime_error(StringFormat("error happen when write embedding to file: %s", e.what())); + } } #endif writeFile.close(); @@ -291,10 +302,20 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, auto &transArr = transData.int64Arr; for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { vector row(embeddingSize); - readFile.read(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); - - aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), - row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); + try { + readFile.read(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); + } catch (std::exception& e) { + readFile.close(); + throw runtime_error(StringFormat("error happen when reading embedding from file: %s", e.what())); + } + aclError ec; + try { + ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), + row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); + } catch (std::exception& e) { + readFile.close(); + throw runtime_error(StringFormat("error happen when acl memory copy from host to device: %s", e.what())); + } if (ec != ACL_SUCCESS) { LOG_ERROR("aclrtMemcpy failed, ret={}", ec); readFile.close(); @@ -369,6 +390,8 @@ void Checkpoint::WriterFn(BufferQueue& queue, int fd) ssize_t result = write(fd, writeBuffer.data(), writeBuffer.size()); if (result != writeBuffer.size()) { LOG_ERROR("Error writing to file"); + close(fd); + throw runtime_error(StringFormat("error happen when writing file. ")); } writeBuffer.clear(); } @@ -390,8 +413,9 @@ void Checkpoint::WriteDataset(CkptTransData& transData, } if (result != writeSize) { + close(fd); LOG_ERROR("Error writing to file, please check the disk buffer or temporary folder space or file permissions!"); - return; + throw runtime_error(StringFormat("error happen when write file. 
")); } } @@ -559,7 +583,12 @@ void Checkpoint::ReadStream(CkptTransData& transData, } else { readSize = datasetSize; } - ReadDataset(transData, readFile, readSize, dataType, idx); + try { + ReadDataset(transData, readFile, readSize, dataType, idx); + } catch (std::exception& e) { + readFile.close(); + throw runtime_error(StringFormat("error happen when reading data from file: %s", e.what())); + } datasetSize -= readSize; idx += readSize; } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index ad21a7f7..1fdb5333 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -55,7 +55,7 @@ namespace MxRec { constexpr int DEFAULT_KEY_PROCESS_THREAD = 6; constexpr int KEY_PROCESS_THREAD = 6; constexpr char SUM_SAME_ID[] = "sum_same_id_gradients_and_apply"; - constexpr int MAX_VOCABULARY_SIZE = 1e9; + constexpr size_t MAX_VOCABULARY_SIZE = 1e10; constexpr int SSD_SIZE_INDEX = 2; constexpr int MAX_FILE_NUM = 1000; // for GLOG -- Gitee From 70b92bcd5b850e36f159299fcdbfa6d74f999621 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 15:31:45 +0800 Subject: [PATCH 429/551] Match-id-cdff4e81b9b631d1676d88dea0c5973abf5c9329 --- mx_rec/graph/patch.py | 8 +++++--- mx_rec/validator/validator.py | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index a28713c1..b843db85 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -45,11 +45,13 @@ def init_dataset(self, input_data): @para_checker_decorator(check_option_list=[ - ("fetches", ClassValidator, {"classes": (str, tf.Operation, tf.Tensor, tf.sparse.SparseTensor, list, tuple, dict)}), - ("feed_dict", ClassValidator, {"classes": (tf.Tensor, tf.sparse.SparseTensor, list, tuple, dict, type(None))}), + ("fetches", ClassValidator, {"classes": (str, tf.Operation, tf.Variable, tf.Tensor, + tf.sparse.SparseTensor, list, tuple, dict)}), + ("feed_dict", ClassValidator, {"classes": (tf.Variable, tf.Tensor, tf.sparse.SparseTensor, + list, tuple, dict, type(None))}), ("options", ClassValidator, {"classes": (tf.compat.v1.RunOptions, type(None))}), ("run_metadata", ClassValidator, {"classes": (tf.compat.v1.RunMetadata, type(None))}), -]) +], output_log=False) def run(self, fetches, feed_dict=None, options=None, run_metadata=None): """ diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 4bd2012b..80cc8b28 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -54,9 +54,10 @@ class Validator: def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str], Type[Validator], Optional[Dict], - Optional[List[str]]]]): + Optional[List[str]]]], output_log=True): """ 函数参数校验装饰器 + :param output_log: 是否打印日志 :param check_option_list: 需要校验的参数及其相关校验器[“需要检验的参数或参数组合”, "使用的校验器", "校验器的参数", "校验器需要执行的方法(添加指定校验)"] :return: @@ -87,7 +88,8 @@ def para_checker_decorator(check_option_list: List[Tuple[Union[List[str], str], continue args_with_default.add(arg) kwargs.update({arg: default}) - logger.debug("[checker wrapper]func %s kwargs: %s", func.__name__, actual_args) + if output_log: + logger.debug("[checker wrapper]func %s kwargs: %s", func.__name__, actual_args) # 执行每一个检查项 for option in check_option_list: optional_check_list = None -- Gitee From 5fa1965009a73a508661daa88fd4f0b9cad79389 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 17:18:40 +0800 Subject: [PATCH 430/551] Match-id-b277a520b4da7ea9fad5f58a07b978c6435642b2 --- mx_rec/constants/constants.py | 15 --------------- mx_rec/logger/log.py 
| 36 ----------------------------------- mx_rec/logger/logger.yaml | 18 ------------------ 3 files changed, 69 deletions(-) delete mode 100644 mx_rec/logger/log.py delete mode 100644 mx_rec/logger/logger.yaml diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 4d6dc729..bda77f07 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -7,20 +7,11 @@ import numpy as np ASCEND_GLOBAL_HASHTABLE_COLLECTION = "ASCEND_GLOBAL_HASHTABLE_COLLECTION" ASCEND_CUTTING_POINT_INITIALIZER = "ASCEND_CUTTING_POINT_INITIALIZER" -ASCEND_CUTTING_POINT = "ASCEND_CUTTING_POINT" ASCEND_SPARSE_LOOKUP_ENTRANCE = "ASCEND_SPARSE_LOOKUP_ENTRANCE" ASCEND_SPARSE_LOOKUP_ID_OFFSET = "ASCEND_SPARSE_LOOKUP_ID_OFFSET" -ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR = "ASCEND_SPARSE_LOOKUP_RESTORE_VECTOR" ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS = "ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS" -ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT = "ASCEND_SPARSE_LOOKUP_LOOKUP_RESULT" -# dynamic shape identity -ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX = "ASCEND_SPARSE_LOOKUP_ALL2ALL_MATRIX" -# hot embed function identity -ASCEND_SPARSE_LOOKUP_HOT_POS = "ASCEND_SPARSE_LOOKUP_HOT_POS" ASCEND_TIMESTAMP = "ASCEND_TIMESTAMP" -CUSTOMIZED_OPS_LIB_PATH = "CUSTOMIZED_OPS_LIB_PATH" ASCEND_SPARSE_LOOKUP_LOCAL_EMB = "ASCEND_SPARSE_LOOKUP_LOCAL_EMB" - EMPTY_STR = "" # 获取ConfigInitializer对象实例失败提示信息 @@ -32,10 +23,6 @@ ANCHOR_DATASET_NAME = "PrefetchDataset" # the name of the embedding table merged by third party ASCEND_TABLE_NAME_MUST_CONTAIN = None -# this number is a temp plan to solve a problem -# to avoid op "scatter_nd_update" may get a None tensor for input -AVOID_TENSOR_POS = 439999 - # while循环最大深度 MAX_WHILE_SIZE = 800 @@ -72,9 +59,7 @@ MIN_SIZE = 1 MAX_CONFIG_SIZE = 10 * 1024 * 1024 MAX_SIZE = 1024 * 1024 * 1024 * 1024 MAX_FILE_SIZE = 500 * 1024 * 1024 * 1024 -MAX_DEVICE_NUM = 16 MAX_RANK_SIZE = 4095 -MIN_DEVICE_NUM = 1 MIN_RANK_SIZE = 1 LOG_MAX_SIZE = 1024 * 1024 diff --git a/mx_rec/logger/log.py b/mx_rec/logger/log.py deleted file mode 100644 index 2855e329..00000000 --- a/mx_rec/logger/log.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2023 Huawei Technologies Co., Ltd - -import logging.config -import os -import yaml - -from mx_rec.constants.constants import LOG_MAX_SIZE -from mx_rec.validator.validator import FileValidator -from mx_rec.util.global_env_conf import global_env - - -def init_sys_log(): - work_dir = os.path.dirname(os.path.dirname(__file__)) - log_cfg_file = os.path.join(work_dir, "logger.yaml") - with open(log_cfg_file, 'r', encoding='utf-8'): - if not FileValidator("log_cfg_file", log_cfg_file). \ - check_file_size(LOG_MAX_SIZE). \ - check_not_soft_link(). \ - check_user_group(). 
\ - is_valid(): - raise ValueError("Log config file is not valid.") - - real_config_path = os.path.realpath(log_cfg_file) - with open(real_config_path, 'r', encoding='utf-8') as open_file: - data = open_file.read(LOG_MAX_SIZE) - log_cfg = yaml.safe_load(data) - - logging.config.dictConfig(log_cfg) - - -init_sys_log() -srv_stream_log = logging.getLogger("logStream") -srv_log = srv_stream_log -srv_log.setLevel(global_env.mxrec_log_level) diff --git a/mx_rec/logger/logger.yaml b/mx_rec/logger/logger.yaml deleted file mode 100644 index 13eb3158..00000000 --- a/mx_rec/logger/logger.yaml +++ /dev/null @@ -1,18 +0,0 @@ -version: 1 -formatters: - simpleFmt: - format: '[%(asctime)s][%(levelname)s][%(message)s]' - wholeFmt: - format: '[%(asctime)s][%(levelname)s][%(message)s][%(filename)s, %(funcName)s:%(lineno)d][%(process)d, %(thread)d]' -handlers: - runStreamHandler: - class: logging.handlers.RotatingFileHandler - level: INFO - formatter: wholeFmt - stream: ext://sys.stdout - -loggers: - logStream: - level: INFO - handlers: [runStreamHandler] - propagate: no \ No newline at end of file -- Gitee From 2bea3d7b61fcd960619cf28e3eeba6e29d6dad35 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 17:22:23 +0800 Subject: [PATCH 431/551] Match-id-259ff22e78c7ad37e6dd0dba7a3ac9c747a66299 --- mx_rec/__init__.py | 4 ++-- mx_rec/constants/__init__.py | 4 +++- mx_rec/core/__init__.py | 5 ++++- mx_rec/core/asc/__init__.py | 6 +++++- mx_rec/graph/__init__.py | 5 ++++- mx_rec/saver/__init__.py | 5 +++++ mx_rec/util/__init__.py | 13 ++++++++++++- 7 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 9c1157fd..a2de58ea 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["constants", "core", "graph", "util", - "version", "__version__"] + +__all__ = ["version", "__version__"] from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py index 0270daf3..eab15375 100644 --- a/mx_rec/constants/__init__.py +++ b/mx_rec/constants/__init__.py @@ -2,4 +2,6 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["constants"] \ No newline at end of file +__all__ = ["ASCEND_TIMESTAMP", "ApplyGradientsStrategy"] + +from mx_rec.constants.constants import ASCEND_TIMESTAMP, ApplyGradientsStrategy diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py index d711ee8e..f904b242 100644 --- a/mx_rec/core/__init__.py +++ b/mx_rec/core/__init__.py @@ -2,4 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["asc", "embedding", "feature_process"] +__all__ = ["create_table", "sparse_lookup", "EvictHook"] + +from mx_rec.core.embedding import create_table, sparse_lookup +from mx_rec.core.feature_process import EvictHook diff --git a/mx_rec/core/asc/__init__.py b/mx_rec/core/asc/__init__.py index b9575dd5..4c0f202a 100644 --- a/mx_rec/core/asc/__init__.py +++ b/mx_rec/core/asc/__init__.py @@ -2,4 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
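
The __init__ rewrites in this commit (here and in the hunks above and below) turn each subpackage's __all__ from a list of module names into the curated public symbols, re-exported directly. After the patch, the supported entry points read:

    from mx_rec.core import create_table, sparse_lookup, EvictHook
    from mx_rec.core.asc import get_asc_insert_func, start_asc_pipeline, FeatureSpec
    from mx_rec.graph import GraphModifierHook, modify_graph_and_start_emb_cache
    from mx_rec.util import init, get_rank_id, get_rank_size
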
-__all__ = ["feature_spec", "helper", "manager"] \ No newline at end of file +__all__ = ["get_asc_insert_func", "start_asc_pipeline", "FeatureSpec"] + +from mx_rec.core.asc.feature_spec import FeatureSpec +from mx_rec.core.asc.manager import start_asc_pipeline +from mx_rec.core.asc.helper import get_asc_insert_func diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py index ee63132c..22f0e96f 100644 --- a/mx_rec/graph/__init__.py +++ b/mx_rec/graph/__init__.py @@ -2,4 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["modifier"] \ No newline at end of file +__all__ = ["modify_graph_and_start_emb_cache", "GraphModifierHook", "run"] + +from mx_rec.graph.modifier import GraphModifierHook, modify_graph_and_start_emb_cache +from mx_rec.graph.patch import run diff --git a/mx_rec/saver/__init__.py b/mx_rec/saver/__init__.py index 6924f767..bfd098f3 100644 --- a/mx_rec/saver/__init__.py +++ b/mx_rec/saver/__init__.py @@ -1,3 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +__all__ = ["export", "save", "restore"] + +from mx_rec.saver.patch import save, restore +from mx_rec.saver.sparse import export diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py index 4c2bd953..2d244a10 100644 --- a/mx_rec/util/__init__.py +++ b/mx_rec/util/__init__.py @@ -2,4 +2,15 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -__all__ = ["initialize", "variable"] +__all__ = [ + "init", "get_rank_id", "get_initializer", "terminate_config_initializer", "clear_channel", + "get_dense_and_sparse_variable", "set_if_load", "set_ascend_global_hashtable_collection", + "get_ascend_global_hashtable_collection", "get_rank_size", "get_host_pipeline_ops", + "get_use_dynamic_expansion", "set_ascend_table_name_must_contain", "get_target_batch" +] + +from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook +from mx_rec.util.initialize import init, get_rank_id, get_initializer, terminate_config_initializer, clear_channel, \ + set_if_load, set_ascend_global_hashtable_collection, get_ascend_global_hashtable_collection, get_rank_size, \ + get_host_pipeline_ops, get_use_dynamic_expansion, set_ascend_table_name_must_contain, get_target_batch +from mx_rec.util.variable import get_dense_and_sparse_variable -- Gitee From 82dfde513ad1744cd8cac7a3f8798d481ad52101 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 17:24:45 +0800 Subject: [PATCH 432/551] Match-id-96df9ce7804901b4881eb194c55fe46fa78afe1a --- src/core/ssd_engine/file.cpp | 45 +++++++++++++++++++++---- src/core/ssd_engine/table.cpp | 62 +++++++++++++++++++++++------------ src/core/ssd_engine/table.h | 2 ++ 3 files changed, 82 insertions(+), 27 deletions(-) diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index 7e5592f0..6b8805ba 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -16,8 +16,17 @@ File::File(uint64_t fileID, string &fileDir) : fileID(fileID), fileDir(fileDir) { LOG_DEBUG("start init file, fileID:{}", fileID); - if (!fs::exists(fs::absolute(fileDir)) && (!fs::create_directories(fs::absolute(fileDir)))) { - throw runtime_error("fail to create Save directory"); + if (!fs::exists(fs::absolute(fileDir))) { + if (!fs::create_directories(fs::absolute(fileDir))) { + throw runtime_error("fail to create Save directory"); + } + try { + fs::permissions(fileDir, 
fs::perms::owner_all | fs::perms::group_read | fs::perms::group_exec); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", fileDir.c_str()); + fs::remove_all(fileDir); + throw; + } } // latest file is temporary, unnecessary to check file existence and privilege @@ -27,12 +36,24 @@ File::File(uint64_t fileID, string &fileDir) : fileID(fileID), fileDir(fileDir) if (!localFileMeta.is_open()) { throw runtime_error("fail to create meta file"); } - fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); + try { + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", metaFilePath.c_str()); + fs::remove_all(metaFilePath); + throw; + } localFileData.open(dataFilePath, ios::out | ios::in | ios::trunc | ios::binary); if (!localFileData.is_open()) { throw runtime_error("fail to create data file"); } - fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); + try { + fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", dataFilePath.c_str()); + fs::remove_all(dataFilePath); + throw; + } LOG_DEBUG("end init file, fileID:{}", fileID); } @@ -75,12 +96,24 @@ File::File(uint64_t fileID, string &fileDir, string &loadDir, int step) : fileID if (!localFileMeta.is_open()) { throw runtime_error("fail to Load latest meta file"); } - fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); + try { + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", metaFilePath.c_str()); + fs::remove_all(metaFilePath); + throw; + } localFileData.open(dataFilePath, ios::out | ios::in | ios::binary); if (!localFileData.is_open()) { throw runtime_error("fail to Load latest data file"); } - fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write); + try { + fs::permissions(dataFilePath, fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", dataFilePath.c_str()); + fs::remove_all(dataFilePath); + throw; + } Load(); LOG_DEBUG("end init file with load, fileID:{}", fileID); diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index fd36975b..d294bb36 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -18,11 +18,10 @@ Table::Table(const string &name, vector &savePaths, uint64_t maxTableSiz maxTableSize(maxTableSize), compactThreshold(compactThreshold) { - curTablePath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + - saveDirPrefix + GlogConfig::gRankId + "/" + name).string(); - if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { - throw runtime_error("fail to create table directory"); - } + auto rankPath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + GlogConfig::gRankId); + CreateTableDir(rankPath); + curTablePath = fs::absolute(rankPath.string() + "/" + name).string(); + CreateTableDir(curTablePath); LOG_INFO("create table:{} at path:{}", name, curTablePath); } @@ -39,11 +38,10 @@ Table::Table(const string &name, vector &saveDirs, uint64_t maxTableSize compactThreshold(compactThreshold) { // always use first path to save until it's full - curTablePath = 
fs::absolute(savePaths.at(curSavePathIdx) + "/" + - saveDirPrefix + GlogConfig::gRankId + "/" + name).string(); - if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { - throw runtime_error("fail to create table directory"); - } + auto rankPath = fs::absolute(savePaths.at(curSavePathIdx) + "/" + saveDirPrefix + GlogConfig::gRankId); + CreateTableDir(rankPath); + curTablePath = fs::absolute(rankPath.string() + "/" + name).string(); + CreateTableDir(curTablePath); bool isMetaFileFound = false; for (const string &dirPath: saveDirs) { @@ -106,7 +104,13 @@ void Table::Save(int step) if (!metaFile.is_open()) { throw runtime_error("fail to create table meta file"); } - fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write); + try { + fs::permissions(metaFilePath, fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", metaFilePath.c_str()); + fs::remove_all(metaFilePath); + throw; + } // dump table name uint32_t nameSize = static_cast(name.size()); @@ -125,9 +129,11 @@ void Table::Save(int step) metaFile.close(); throw runtime_error(StringFormat("set table path to disk with space error:{}", e.what())); } - if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { + try { + CreateTableDir(curTablePath); + } catch (runtime_error &e) { metaFile.close(); - throw runtime_error("fail to create table directory"); + throw; } f->Save(curTablePath, step); } @@ -169,20 +175,18 @@ void Table::LoadDataFileSet(const shared_ptr &metaFile, int step) // try to find data file from each path string loadPath = p + "/" + saveDirPrefix + GlogConfig::gRankId + "/" + name; SetTablePathToDiskWithSpace(); - if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { - throw runtime_error("fail to create table directory"); - } + CreateTableDir(curTablePath); try { loadedFile = make_shared(fileID, curTablePath, loadPath, step); fileSet.insert(loadedFile); break; } catch (invalid_argument &e) { // do nothing because file may in other path - LOG_INFO("insert exception, do nothing because file may in other path"); + LOG_DEBUG("data file not found, id:{}, try other path", fileID); } } if (loadedFile == nullptr) { - throw invalid_argument("data file not found"); + throw invalid_argument(StringFormat("data file not found, id:%d", fileID)); } auto keys = loadedFile->GetKeys(); @@ -261,9 +265,7 @@ void Table::InsertEmbeddingsInner(vector &keys, vector> if (curFile == nullptr || (curFile != nullptr && curFile->GetDataCnt() >= maxDataNumInFile)) { SetTablePathToDiskWithSpace(); - if (!fs::exists(curTablePath) && !fs::create_directories(curTablePath)) { - throw runtime_error("fail to create table directory"); - } + CreateTableDir(curTablePath); curFile = make_shared(curMaxFileID, curTablePath); fileSet.insert(curFile); curMaxFileID++; @@ -411,3 +413,21 @@ uint64_t Table::GetTableUsage() return totalKeyCnt; } +void Table::CreateTableDir(const string &path) +{ + if (fs::exists(path)) { + return; + } + if (!fs::create_directories(path)) { + throw runtime_error(StringFormat("fail to create table directory:%s", path.c_str())); + } + try { + fs::permissions(path, fs::perms::owner_all | fs::perms::group_read | fs::perms::group_exec); + } catch (runtime_error &e) { + LOG_ERROR("fail to change permission of {}", path.c_str()); + fs::remove_all(path); + throw; + } + LOG_DEBUG("create table dir:{}", path); +} + diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h 
index cb743f15..4c7e1ad9 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -42,6 +42,8 @@ namespace MxRec { uint64_t GetTableUsage(); private: + static void CreateTableDir(const string& path); + void Load(const string& metaFilePath, int step); void InsertEmbeddingsInner(vector &keys, vector> &embeddings); -- Gitee From 78c554a50262be71e5638f2e338fe2ffd10bbc2f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 15:22:40 +0800 Subject: [PATCH 433/551] Match-id-d643069f5d5eb201a24419818f6e6ea92322438f --- mx_rec/saver/patch.py | 12 +++++------- mx_rec/saver/saver.py | 3 +-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index fe9e7ffb..d64d5d9d 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -25,7 +25,7 @@ from tensorflow.python.training.saving import saveable_object from tensorflow.python.training.saving import saveable_object_util import numpy as np -from mx_rec.saver.saver import Saver as SparseSaver +from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid from mx_rec.util.initialize import get_ascend_global_hashtable_collection, export_removing_var_list from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ OptionalStringValidator, DirectoryValidator @@ -189,9 +189,8 @@ def build(self): ]) def save(self, sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix="meta", write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False): - # since tf 2.6.0, tf needs tensorflow_io to support hdfs path - if tf.__version__.startswith("2") and save_path.find("://") != -1: - import tensorflow_io as tfio + if not check_file_system_is_valid(save_path): + raise ValueError("the path to save belongs to an invalid file system, only the local file system is supported.") if not self._is_built and not context.executing_eagerly(): raise RuntimeError("`build()` should be called before save if defer_build==True") @@ -232,9 +231,8 @@ def save(self, sess, save_path, global_step=None, latest_filename=None, meta_gra def restore(self, sess, save_path): if save_path is None: raise ValueError("Can't load save_path when it is None.") - # since tf 2.6.0, tf needs tensorflow_io to support hdfs path - if tf.__version__.startswith("2") and save_path.find("://") != -1: - import tensorflow_io as tfio + if not check_file_system_is_valid(save_path): + raise ValueError("the path to restore belongs to an invalid file system, only the local file system is supported.") if save_path.find("://") == -1: directory_validator = DirectoryValidator("reading_path", save_path) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 52d465b6..e07b5691 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -88,8 +88,7 @@ class Saver(object): @performance("Save") def save(self, sess, save_path="model", global_step=None): """ - Save sparse tables. For local save, both save_easy mode and normal mode is supported. For HDFS save, - only save_easy mode is supported. + Save sparse tables. For local save, both save_easy mode and normal mode are supported. 
For easy_save mode, the checkpoint is saved in the following format: ./rank_id/HashTable/HBM/embed_table_name/key/xxx.data ./rank_id/HashTable/HBM/embed_table_name/key/xxx.attribute -- Gitee From 415ed391319796accdd3932a1c5f52bc35f490f2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 8 Nov 2023 17:53:13 +0800 Subject: [PATCH 434/551] Match-id-9cb07dca224cb89d41f93db1450ac595368d7ec7 --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 15 +++++---------- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 42d5f0be..91b9b7bb 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -225,8 +225,7 @@ bool HybridMgmt::Save(const string savePath) { #ifndef GTEST if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); - return false; + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } // lock the data-processing threads @@ -282,8 +281,7 @@ bool HybridMgmt::Load(const string& loadPath) { #ifndef GTEST if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); - return false; + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } // lock the data-processing threads @@ -886,8 +884,7 @@ bool HybridMgmt::Evict() { #ifndef GTEST if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); - return false; + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } // trigger eviction if the eviction option is configured @@ -1061,8 +1058,7 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const { #ifndef GTEST if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); - return -1; + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } if (mgmtRankInfo.useDynamicExpansion) { @@ -1106,8 +1102,7 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const { #ifndef GTEST if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. Call Initialize first."); - return -1; + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); } if (mgmtRankInfo.useDynamicExpansion) { diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 0defe7b8..5ebfed4a 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -74,8 +74,7 @@ namespace MxRec { void Destroy() { if (!isInitialized) { - LOG_ERROR("HybridMgmt not initialized. 
Call Initialize first."); } if (!isRunning) { -- Gitee From e49d2686adeeb1b0c09225409674d5a092977df4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 9 Nov 2023 11:07:23 +0800 Subject: [PATCH 435/551] Match-id-f196e81c271c73e7e1bfc15f4f05b106da872680 --- src/core/utils/common.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 1fdb5333..f7f27a34 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -124,6 +124,7 @@ namespace MxRec { const int ASCEND910_PRO_A = 262144; const int ASCEND910_B = 262144; const int ASCEND910_A = 262144; + const int ASCEND910_B2C = 196608; }; inline int GetUBSize(int devID) @@ -134,7 +135,8 @@ namespace MxRec { {"910B1", UBSize::ASCEND910_B1}, {"910B2", UBSize::ASCEND910_B2}, {"910B3", UBSize::ASCEND910_B3}, - {"910B4", UBSize::ASCEND910_B4}}; + {"910B4", UBSize::ASCEND910_B4}, + {"910B2C", UBSize::ASCEND910_B2C}}; auto it = chipUbSizeList.find(GetChipName(devID)); if (it != chipUbSizeList.end()) { return it->second; -- Gitee From f9e364301caf8e0916e28531ac30efd50f1ea917 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 9 Nov 2023 15:04:07 +0800 Subject: [PATCH 436/551] Match-id-ab0694836d442c76abd6a658b79ff1679044880b --- cust_op/cust_op_by_addr/run.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh index 86ca88cf..f6b9a94e 100644 --- a/cust_op/cust_op_by_addr/run.sh +++ b/cust_op/cust_op_by_addr/run.sh @@ -28,6 +28,9 @@ if [ ! -f "CMakePresets.json" ]; then exit 1 fi +# 禁止生成CRC校验和 +sed -i 's/--nomd5/--nomd5 --nocrc/g' ./cmake/makeself.cmake + # 修改cann安装路径 sed -i 's:"/usr/local/Ascend/latest":"/usr/local/Ascend/ascend-toolkit/latest":g' CMakePresets.json -- Gitee From 8074a6e5db60beae4c21a10c35a2fca04da65848 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 31 Oct 2023 14:32:36 +0800 Subject: [PATCH 437/551] Match-id-c6c6974675939b28c04378836c659bd1f55b0205 --- mx_rec/__init__.py | 3 +-- mx_rec/graph/modifier.py | 20 ++------------------ mx_rec/graph/patch.py | 33 +-------------------------------- mx_rec/util/initialize.py | 22 ++-------------------- 4 files changed, 6 insertions(+), 72 deletions(-) diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index a2de58ea..c7261449 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -8,7 +8,7 @@ from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook from mx_rec.saver.patch import patch_for_saver from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \ - patch_for_end, patch_for_assert_eval_spec, patch_for_scale_loss, patch_for_session + patch_for_assert_eval_spec, patch_for_scale_loss, patch_for_session from mx_rec.optimizers.base import patch_for_optimizer patch_for_saver() @@ -17,7 +17,6 @@ patch_for_scale_loss() patch_for_chief_session_creator() patch_for_assert_eval_spec() patch_for_bool_gauge() -patch_for_end() patch_for_optimizer() patch_for_session() __version__ = "5.0.RC2" diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 9226af56..ff2391f9 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -17,8 +17,8 @@ from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE from 
mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ - terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, increase_run_times, \ - get_is_last_round, insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ + terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, \ + insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ set_iterator_type from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ @@ -588,19 +588,3 @@ class GraphModifierHook(tf.estimator.SessionRunHook): if self._modify_graph and self._iterator_type == "MakeIterator": session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) - def end(self, session): - bool_gauge_set = get_bool_gauge_set() - logger.debug("GraphModifierHook, bool_gauge_set: %s", bool_gauge_set) - - # In eval or predict mode, the initializer can be directly terminated. - if 'train' not in bool_gauge_set: - logger.debug("In evaluate or predict case, GraphModifierHook call 'terminate_config_initializer'...") - terminate_config_initializer() - return - - if 'train_and_evaluate' in bool_gauge_set: - increase_run_times() - # In 'train_and_evaluate' mode, the terminate function should be executed last. - if get_is_last_round(): - logger.debug("In train_and_evaluate case, GraphModifierHook call 'terminate_config_initializer'...") - terminate_config_initializer() diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index b843db85..ad9f2c42 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -21,8 +21,7 @@ from tensorflow.python.client.session import BaseSession from mx_rec.constants import constants from mx_rec.util.initialize import get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \ - get_bool_gauge_set, terminate_config_initializer, get_run_times, set_is_last_round, get_asc_manager, \ - export_table_instances + get_bool_gauge_set, terminate_config_initializer, get_asc_manager, export_table_instances from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook from mx_rec.graph.merge_lookup import do_merge_lookup from mx_rec.util.log import logger @@ -246,36 +245,6 @@ def patch_for_bool_gauge(): logger.debug("Function 'get_cell' in Class 'BoolGauge' has been patched.") -def end(self: NPUCheckpointSaverHook, session: tf.compat.v1.Session): - """ - Call at the end of session hook. - - Args: - self: An `NPUCheckpointSaverHook` instance. - session: A TensorFlow Session that will be soon closed. - - Returns: None - - """ - - logger.debug("Enter patch 'NPUCheckpointSaverHook.end'.") - logger.info("NPUCheckpointSaverHook end...") - basic_session_run_hooks.CheckpointSaverHook.end(self, session) - - if 'train_and_evaluate' in get_bool_gauge_set() and get_run_times() == 1: - set_is_last_round(True) - return - logger.debug("NPUCheckpointSaverHook call 'terminate_config_initializer'...") - terminate_config_initializer() - - -def patch_for_end(): - """Patch for 'NPUCheckpointSaverHook.end'.""" - - NPUCheckpointSaverHook.end = end - logger.debug("Function 'end' in Class 'NPUCheckpointSaverHook' has been patched.") - - def assert_eval_spec(eval_spec: EvalSpec): """ Raise error if `eval_spec` is not of the right type. 
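The hunks above delete the hook-based teardown bookkeeping (run counters, last-round flags, the patched NPUCheckpointSaverHook.end); the initialize.py diff below replaces all of it with one process-exit callback registered in init(). A minimal sketch of that pattern, with a stub standing in for the real terminate_config_initializer:

    import atexit

    def terminate_config_initializer():
        # stand-in for mx_rec.util.initialize.terminate_config_initializer
        print("tearing down host-side resources")

    def init():
        # ... existing ConfigInitializer setup ...
        # registered once; CPython invokes the callback at normal interpreter
        # exit, so teardown no longer depends on which Estimator hooks fire
        atexit.register(terminate_config_initializer)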
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index e0a0d5bc..5005f647 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -2,6 +2,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +import atexit import os from collections import defaultdict import dataclasses @@ -107,9 +108,6 @@ class ConfigInitializer: self.notify_hybrid_channel_sparse_id = [0, 0] self.stat_on = (global_env.stat_on == Flag.TRUE.value) - def __del__(self): - self.terminate() - @property def iterator_type(self): return self._iterator_type @@ -471,6 +469,7 @@ def init(use_mpi, **kwargs): json.dumps(dataclasses.asdict(global_env), ensure_ascii=False)) ConfigInitializer.set_instance(use_mpi, **kwargs) set_ascend_env() + atexit.register(terminate_config_initializer) def get_is_graph_modify_hook_running(): @@ -481,22 +480,6 @@ def set_is_graph_modify_hook_running(is_running): ConfigInitializer.get_instance().is_graph_modify_hook_running = is_running -def get_run_times(): - return ConfigInitializer.get_instance().run_times - - -def increase_run_times(): - ConfigInitializer.get_instance().run_times.increase() - - -def get_is_last_round(): - return ConfigInitializer.get_instance().is_last_round - - -def set_is_last_round(last_round): - ConfigInitializer.get_instance().is_last_round = last_round - - def get_bool_gauge_set(): return ConfigInitializer.get_instance().bool_gauge_set @@ -592,7 +575,6 @@ def restore_host_data(root_dir): raise RuntimeError("ASC manager does not exist.") if not ConfigInitializer.get_instance().get_asc_manager().load(root_dir): - terminate_config_initializer() raise TypeError("Asc load data does not match usr setups, \ please re-consider if you want to restore from this dir") logger.debug("Data from host pipeline has been restored.") -- Gitee From f3f9f56ce4bb0d9dbaaa235c5e949cc6c159e0f1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 9 Nov 2023 16:06:04 +0800 Subject: [PATCH 438/551] Match-id-b518309d4a6aa02897814a33769c6feaa3fe57bb --- cust_op/cust_op_by_addr/run.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh index f6b9a94e..8f539865 100644 --- a/cust_op/cust_op_by_addr/run.sh +++ b/cust_op/cust_op_by_addr/run.sh @@ -49,3 +49,10 @@ cd .. bash build.sh +# 安装编译成功的算子包 +bash ./build_out/custom_opp_centos*.run + +cd .. + +rm -rf ./custom_op + \ No newline at end of file -- Gitee From 9aadb5a7ec8b82b623e26b2ced26ecba36f21668 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 9 Nov 2023 19:22:03 +0800 Subject: [PATCH 439/551] Match-id-9082463538c83f57f22c89527302a815f852a689 --- cust_op/cust_op_by_addr/run.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh index 8f539865..2e7c0d67 100644 --- a/cust_op/cust_op_by_addr/run.sh +++ b/cust_op/cust_op_by_addr/run.sh @@ -50,9 +50,8 @@ cd .. bash build.sh # 安装编译成功的算子包 -bash ./build_out/custom_opp_centos*.run +bash ./build_out/custom_opp*.run cd .. 
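# custom_opp_centos*.run appears to be a makeself self-extracting archive
# produced by build.sh (the sed earlier in this script appends --nocrc to the
# existing --nomd5 so makeself skips checksum generation); once the package
# has been installed, the intermediate build tree can be removed: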
rm -rf ./custom_op - \ No newline at end of file -- Gitee From 201322e45765ab36cd3a8eda96e817a2f565217c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 10 Nov 2023 11:11:47 +0800 Subject: [PATCH 440/551] Match-id-1d34c7708fc4e1bb807981f13f84c4e856df6185 --- src/core/checkpoint/checkpoint.cpp | 4 ++-- src/core/checkpoint/checkpoint.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 9fe750dc..8db8fbba 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -213,7 +213,7 @@ void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& da { ofstream writeFile; writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary); - fs::permissions(dataDir.c_str(), fs::perms::owner_read | fs::perms::owner_write); + fs::permissions(dataDir.c_str(), fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read); #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); @@ -330,7 +330,7 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) { - int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, static_cast(0600)); + int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, static_cast(0640)); if (fd == -1) { LOG_ERROR("Error opening file for writing"); return; diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 3403a576..cb402182 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -35,7 +35,7 @@ namespace MxRec { const string attribFileType { ".attribute" }; const string dirSeparator { "/" }; const string ssdSymbol {"SSD"}; - const mode_t dirMode { 0500 }; + const mode_t dirMode { 0750 }; const string currDir { "." }; const string prevDir { ".." }; -- Gitee From 622483070032eab29190d181ec4dc9cb6052e68c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 10 Nov 2023 15:09:40 +0800 Subject: [PATCH 441/551] Match-id-31fcd328dd20107da6b4d4c272b1518d502d5c85 --- mx_rec/util/communication/hccl_mgmt.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 84b618b0..1a50cec3 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -10,6 +10,7 @@ from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFI from mx_rec.validator.validator import FileValidator, para_checker_decorator, StringValidator, \ Convert2intValidator from mx_rec.util.global_env_conf import global_env +from mx_rec.util.log import logger def parse_hccl_json(): @@ -89,19 +90,17 @@ def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_devic sorted_device_list = sorted(device_list) local_rank_size = len(sorted_device_list) - if rank_size < local_rank_size: - raise ValueError(f"Rank size {rank_size} is less than devices: {local_rank_size}.") + if rank_size > local_rank_size: + raise ValueError(f"Rank size {rank_size} is larger than local available devices: {local_rank_size}.") - rank_to_device_dict = {0: chief_device} + if chief_device not in sorted_device_list: + raise ValueError(f"The environment variable CM_CHIEF_DEVICE {chief_device} is not in the local device list. 
") - try: - sorted_device_list.pop(chief_device % local_rank_size) - except IndexError as err: - raise IndexError( - f"Config CM_CHIEF_DEVICE {chief_device} not in training container device list {sorted_device_list}.") \ - from err - except ZeroDivisionError as err: - raise ZeroDivisionError("sorted_device_list length can not equal to 0.") from err + + rank_to_device_dict = {} + chief_index = sorted_device_list.index(chief_device) + sorted_device_list = sorted_device_list[chief_index:] + sorted_device_list[0: chief_index] + sorted_device_list = sorted_device_list[:rank_size] for device_idx in sorted_device_list: import mxrec_pybind @@ -113,7 +112,7 @@ def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_devic if res > MAX_DEVICE_ID: raise ValueError(f"get logic id from physic id fail.") index = sorted_device_list.index(device_idx) - rank_to_device_dict[index + 1] = res + rank_to_device_dict[index] = res return rank_to_device_dict, local_rank_size @@ -137,4 +136,6 @@ def get_device_list(ascend_visible_devices): raise IndexError( f"Index of ascend_visible_devices {ascend_visible_devices.strip().split('-')[-1]} is out of range") \ from error + if not device_list: + raise ValueError("No device is available in the environment.") return device_list -- Gitee From 3e91b95eb0364a9d8999c81a2a84005fb2a2ed17 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 11:42:29 +0800 Subject: [PATCH 442/551] Match-id-3fa0079220103ced36150092a69cfe3b29d0ceab --- mx_rec/saver/saver.py | 16 +++++----- src/core/emb_hashmap/emb_hashmap.cpp | 2 +- src/core/hd_transfer/acl_channel.h | 1 - src/core/hd_transfer/hd_transfer.cpp | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +-- .../key_process/feature_admit_and_evict.cpp | 2 +- src/core/key_process/key_process.cpp | 2 +- src/core/utils/common.h | 1 - src/core/utils/logger.cpp | 4 +-- src/core/utils/logger.h | 30 +++++++++---------- src/core/utils/safe_queue.h | 4 +-- src/ops_tf/hybrid_dataset_ops.cpp | 10 ++++--- src/tests/emb_hashmap/emb_hashmap_test.cpp | 2 +- src/tests/key_process/key_process_test.cpp | 5 +++- src/tests/utils/log_test.cpp | 20 ++++++------- 15 files changed, 54 insertions(+), 51 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 15a34fbe..8004426d 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -38,6 +38,14 @@ class SaveModelThread(threading.Thread): class Saver(object): customized_ops = get_customized_ops() + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) + @para_checker_decorator(check_option_list=[ ("var_list", ClassValidator, {"classes": (list, type(None))}), ("max_to_keep", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), @@ -60,14 +68,6 @@ class Saver(object): self._last_checkponts = [] self.build() - @staticmethod - def _make_table_name_dir(root_dir, table_instance, table_name): - if table_instance.host_vocabulary_size > 0: - table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) - else: - table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) - tf.io.gfile.makedirs(table_dir) - def build(self): if self.var_list is None: self.var_list = [] diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index ff686b49..496c47a8 100644 --- 
a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -418,7 +418,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& /// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const { - if (Logger::GetLevel() != Logger::trace) { + if (Logger::GetLevel() != Logger::TRACE) { return; } auto& hostMap = embHashMap.hostHashMap; diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h index 9d0b5b49..efbadd77 100644 --- a/src/core/hd_transfer/acl_channel.h +++ b/src/core/hd_transfer/acl_channel.h @@ -9,7 +9,6 @@ #define ACL_CHANNEL_H #include -#include #include "acl/acl_tdt.h" #include "tensorflow/core/framework/tensor.h" diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index e92537dc..ea520d13 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -57,7 +57,7 @@ void HDTransfer::Destroy() LOG_INFO(HD + "destroy channel start"); for (auto& c: transferChannels) { LOG_INFO(HD + "start destroy channel:{}", c.first); - if (acltdtStopChannel(c.second)!=ACL_ERROR_NONE || acltdtDestroyChannel(c.second)!=ACL_ERROR_NONE) { + if (acltdtStopChannel(c.second) != ACL_ERROR_NONE || acltdtDestroyChannel(c.second) != ACL_ERROR_NONE) { throw runtime_error("Acl destroy channel failed."); } LOG_INFO(HD + "destroy channel:{}", c.first); diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 91b9b7bb..3ed87246 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -143,7 +143,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, // 比较hostHashMap和cacheManager的数据是否一致 void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData) { - if (Logger::GetLevel() != Logger::trace) { + if (Logger::GetLevel() != Logger::TRACE) { return; } auto& embHashMaps = saveData.embHashMaps; @@ -1079,7 +1079,7 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const } int64_t ssdSize = 0; if (mgmtRankInfo.isSSDEnabled) { - ssdSize= cacheManager->GetTableEmbeddingSize(embName); + ssdSize = cacheManager->GetTableEmbeddingSize(embName); } const auto& iter = hostHashMaps->embHashMaps.find(embName); diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 3476b8e0..09bc4249 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -108,7 +108,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con auto innerIt = historyRecordInfos.find(featureId); // isEnableSum = false或者eval,只查询count,不做累加,若是新key,则count使用初始值0 - if (channel == EVAL_CHANNEL_ID || m_table2Threshold[tableNameOrigin].isEnableSum == false) { + if (channel == EVAL_CHANNEL_ID || !m_table2Threshold[tableNameOrigin].isEnableSum) { if (innerIt != historyRecordInfos.end()) { currKeyCount = historyRecordInfos[featureId].count; } diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6809bd95..3f5619da 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1139,7 +1139,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vecto emb_key_t devId = abs(key % static_cast(rankInfo.rankSize)); if (restoreVec[i] >= hotPosSize) { restoreVec[i] += blockOffset[devId]; - 
} else if (Logger::GetLevel() >= Logger::debug) { + } else if (Logger::GetLevel() >= Logger::DEBUG) { hotNum += 1; } } diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 4bf79240..a54c14de 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -291,7 +291,6 @@ namespace MxRec { string StringFormat(const string& format, Args ... args) { auto size = static_cast(GLOG_MAX_BUF_SIZE); - unique_ptr buf(new char[size]); auto buf = std::make_unique(size); memset_s(buf.get(), size, 0, size); int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...); diff --git a/src/core/utils/logger.cpp b/src/core/utils/logger.cpp index b12f9870..59134dda 100644 --- a/src/core/utils/logger.cpp +++ b/src/core/utils/logger.cpp @@ -11,7 +11,7 @@ namespace MxRec { -int MxRec::Logger::level = MxRec::Logger::info; +int MxRec::Logger::level = MxRec::Logger::INFO; int MxRec::Logger::rank = 0; void Logger::SetRank(int logRank) @@ -31,7 +31,7 @@ int Logger::GetLevel() const char* Logger::LevelToStr(int logLevel) { - if (logLevel < trace || logLevel > error) { + if (logLevel < TRACE || logLevel > ERROR) { return "INVALID LEVEL"; } static const char* msg[] = { diff --git a/src/core/utils/logger.h b/src/core/utils/logger.h index dc760e60..f095599b 100644 --- a/src/core/utils/logger.h +++ b/src/core/utils/logger.h @@ -27,11 +27,11 @@ constexpr size_t DELIM_LEN = 2; class Logger { public: - static constexpr int trace = -2; - static constexpr int debug = -1; - static constexpr int info = 0; - static constexpr int warn = 1; - static constexpr int error = 2; + static constexpr int TRACE = -2; + static constexpr int DEBUG = -1; + static constexpr int INFO = 0; + static constexpr int WARN = 1; + static constexpr int ERROR = 2; static void SetRank(int logRank); @@ -103,20 +103,20 @@ private: static int rank; }; -#define LOG_TRACE(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::trace) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::trace, args) +#define LOG_TRACE(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::TRACE) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::TRACE, args) -#define LOG_DEBUG(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::debug) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::debug, args) +#define LOG_DEBUG(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::DEBUG, args) -#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::info) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::info, args) +#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::INFO) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::INFO, args) -#define LOG_WARN(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::warn) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::warn, args) +#define LOG_WARN(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::WARN) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::WARN, args) -#define LOG_ERROR(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::error) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::error, args) +#define LOG_ERROR(args...) 
if (MxRec::Logger::GetLevel() <= MxRec::Logger::ERROR) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::ERROR, args) } diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h index 5c77a797..79122038 100644 --- a/src/core/utils/safe_queue.h +++ b/src/core/utils/safe_queue.h @@ -19,7 +19,7 @@ namespace MxRec { template class SafeQueue { - static constexpr uint64_t defaultCap = 10; + static constexpr uint64_t DEFAULT_CAP = 10; public: SafeQueue() = default; @@ -87,7 +87,7 @@ namespace MxRec { private: mutable std::mutex mut; - uint64_t capacity = defaultCap; + uint64_t capacity = DEFAULT_CAP; std::list > dataQueue; std::list > emptyQueue; std::condition_variable dataCond; diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 86dcf99a..4cd5a85f 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -69,7 +69,7 @@ namespace MxRec { class SetThreshold : public OpKernel { public: - explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context) + explicit SetThreshold(OpKernelConstruction& context) : OpKernel(context) { LOG_INFO("SetThreshold init"); OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName)); @@ -117,7 +117,7 @@ namespace MxRec { int ParseThresholdAndCheck(const Tensor& inputTensor, int& threshold) const { // 前面8个字节、即占一个featureId位,是unix时间戳 - auto src = reinterpret_cast(inputTensor.tensor_data().data()); + const int* src = static_cast(inputTensor.tensor_data().data()); std::copy(src, src + 1, &threshold); if (threshold < 0) { @@ -540,7 +540,7 @@ REGISTER_OP("ClearChannel").Attr("channel_id : int"); REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel); // ##################### SetThreshold ####################### -REGISTER_OP("SetThreshold") +REGISTER_OP("SetThreshold") noexcept .Input("input: int32") .Attr("emb_name: string = ''") .Attr("ids_name: string = ''") @@ -549,7 +549,9 @@ REGISTER_OP("SetThreshold") c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); }); -REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), MxRec::SetThreshold); +namespace MxRec { + REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), SetThreshold) noexcept; +} // ##################### ReturnTimestamp ####################### REGISTER_OP("ReturnTimestamp") diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index fe46ac07..f55b3b21 100644 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp @@ -87,7 +87,7 @@ TEST(EmbHashMap, TestFindOffset) auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName]; auto logLevelTemp = Logger::GetLevel(); - Logger::SetLevel(Logger::trace); + Logger::SetLevel(Logger::TRACE); vector keys4 = {21, 21, 21, 21}; // 新key重复值, 且需要换入换出 hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId); RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index d406f36b..eae9f463 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -135,7 +135,10 @@ protected: { default_random_engine generator; uniform_int_distribution distribution(randMin, randMax); - int embSizeMin = 5, embSizeMax = 8, base = 2, vocabSize = 100; + int embSizeMin = 5; + int embSizeMax = 8; + int base = 2; + int vocabSize = 100; 
uniform_int_distribution embSizeDistribution(embSizeMin, embSizeMax); stringstream ss; for (unsigned int i = 0; i < embNums; ++i) { diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp index 9aa70bd8..1eb3aa10 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -21,7 +21,7 @@ TEST(Log, Format) TEST(Log, LogLevel) { - MxRec::Logger::SetLevel(Logger::debug); + MxRec::Logger::SetLevel(Logger::DEBUG); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -33,7 +33,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -45,7 +45,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::warn); + MxRec::Logger::SetLevel(Logger::WARN); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -57,7 +57,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::error); + MxRec::Logger::SetLevel(Logger::ERROR); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -72,7 +72,7 @@ TEST(Log, LogLevel) TEST(Log, LayzEvalution) { - MxRec::Logger::SetLevel(Logger::warn); + MxRec::Logger::SetLevel(Logger::WARN); testing::internal::CaptureStdout(); int flag1 = 0; int flag2 = 0; @@ -97,7 +97,7 @@ TEST(Log, LayzEvalution) TEST(Log, Basic) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("basictest"); std::string output = testing::internal::GetCapturedStdout(); @@ -106,7 +106,7 @@ TEST(Log, Basic) TEST(Log, TooManyArgs1) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{} {} {}", 0.1f, 'h', 'e', "llow"); std::string output = testing::internal::GetCapturedStdout(); @@ -116,7 +116,7 @@ TEST(Log, TooManyArgs1) TEST(Log, TooManyArgs2) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{}", "h", "h", "h", "h", "h", "h", "h"); std::string output = testing::internal::GetCapturedStdout(); @@ -126,7 +126,7 @@ TEST(Log, TooManyArgs2) TEST(Log, FewArgs) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{} {} {} {} {} {}", "hellow", "hellow"); std::string output = testing::internal::GetCapturedStdout(); @@ -136,7 +136,7 @@ TEST(Log, FewArgs) TEST(Log, CkptType) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("ckpt type={}", CkptDataType::EMB_DATA); std::string output = testing::internal::GetCapturedStdout(); -- Gitee From 09e9dcbc90574bbad614b8a8bd4eaa658b501f7f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 11:50:46 +0800 Subject: [PATCH 443/551] Match-id-75edef0cfc8e7073b19fd8d763e975a776edabaa --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 4cd5a85f..383e929c 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -69,7 +69,7 @@ namespace MxRec { class SetThreshold : public OpKernel { public: - explicit SetThreshold(OpKernelConstruction& context) : OpKernel(context) + explicit SetThreshold(OpKernelConstruction* context) : OpKernel(context) { LOG_INFO("SetThreshold init"); OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName)); -- Gitee From f0f06f7914e4656a0df205ad2676fdf033fd09ce Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 12:53:56 +0800 Subject: [PATCH 444/551] Match-id-0cd64a05d1d4596eb0de2740cc9daed69da57a88 --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 383e929c..cba00fb8 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -117,7 +117,7 @@ namespace MxRec { int ParseThresholdAndCheck(const Tensor& inputTensor, int& threshold) const { // 前面8个字节、即占一个featureId位,是unix时间戳 - const int* src = static_cast(inputTensor.tensor_data().data()); + auto src = reinterpret_cast(inputTensor.tensor_data().data()); std::copy(src, src + 1, &threshold); if (threshold < 0) { -- Gitee From 73203aacb44163b156241d4c6a99b30b2c57f177 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 14:15:33 +0800 Subject: [PATCH 445/551] Match-id-e25a5b38096a4e3d45c28715bf8cb5b432eb2163 --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index cba00fb8..d97a9ef3 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -540,7 +540,7 @@ REGISTER_OP("ClearChannel").Attr("channel_id : int"); REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel); // ##################### SetThreshold ####################### -REGISTER_OP("SetThreshold") noexcept +REGISTER_OP("SetThreshold") .Input("input: int32") .Attr("emb_name: string = ''") .Attr("ids_name: string = ''") -- Gitee From 78633b29e9224bd4b2a3d0007524fb918bada8c6 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 14:26:54 +0800 Subject: [PATCH 446/551] Match-id-c3c23012633808e263797f1fb75c724143d9376f --- mx_rec/optimizers/__init__.py | 3 ++- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 733d27b2..1cec4f26 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address"] +__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", + "create_hash_optimizer_by_address"] from mx_rec.optimizers.adagrad import create_hash_optimizer from mx_rec.optimizers.ftrl import create_hash_optimizer diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index d97a9ef3..ae07892a 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -550,7 +550,7 @@ c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); }); namespace MxRec { - REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), SetThreshold) noexcept; + REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(tensorflow::DEVICE_CPU), SetThreshold); } // ##################### ReturnTimestamp ####################### -- Gitee From 238b2644e6e60769fd98eca2d9beb2ee0a23b92f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 20:32:56 +0800 Subject: [PATCH 447/551] Match-id-be4fc0bde55e2fa0463c9109915c2dc16ea7fed7 --- src/ops_tf/hybrid_dataset_ops.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index ae07892a..fca54eb4 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -549,9 +549,7 @@ REGISTER_OP("SetThreshold") c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); }); -namespace MxRec { - REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(tensorflow::DEVICE_CPU), SetThreshold); -} +REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(tensorflow::DEVICE_CPU), SetThreshold); // ##################### ReturnTimestamp ####################### REGISTER_OP("ReturnTimestamp") -- Gitee From 72e6f3b539d77b3d4c7695855f8624e5c049df1d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 20:49:25 +0800 Subject: [PATCH 448/551] Match-id-3de31253ed1f25746918adc0a05b4364f7fd6cd4 --- mx_rec/constants/constants.py | 3 +- mx_rec/core/asc/manager.py | 29 ++++++----- mx_rec/optimizers/__init__.py | 4 +- mx_rec/saver/saver.py | 52 +++++++++---------- src/core/emb_hashmap/emb_hashmap.cpp | 2 +- src/core/hd_transfer/acl_channel.h | 1 - src/core/hd_transfer/hd_transfer.cpp | 2 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +- .../key_process/feature_admit_and_evict.cpp | 2 +- src/core/key_process/key_process.cpp | 2 +- src/core/utils/common.h | 2 +- src/core/utils/logger.cpp | 4 +- src/core/utils/logger.h | 30 +++++------ src/core/utils/safe_queue.h | 4 +- src/ops_tf/hybrid_dataset_ops.cpp | 16 +++--- src/tests/emb_hashmap/emb_hashmap_test.cpp | 2 +- src/tests/key_process/key_process_test.cpp | 5 +- src/tests/utils/log_test.cpp | 20 +++---- 18 files changed, 93 insertions(+), 91 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index bda77f07..64a98dc4 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -151,8 +151,7 @@ class OptimizerType(Enum): raise ValueError(f"Invalid mode value, please choose one from {list(map(lambda c: c.value, OptimizerType))}") -OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], - OptimizerType.SGD: []} +OPTIMIZER_STATE_META = {OptimizerType.LAZY_ADAM: ["momentum", "velocity"], OptimizerType.SGD: []} class All2allGradientsOp(BaseEnum): diff --git a/mx_rec/core/asc/manager.py 
b/mx_rec/core/asc/manager.py index 8c39338f..42d94476 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -110,20 +110,21 @@ def matched_truncated_normal_initializer(tabel_info): def matched_emb_initializer(tabel_info): - initializer_case_map = {"tf1/tf2_constant_initializer": - isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or - isinstance(tabel_info.emb_initializer, tf.constant_initializer), - "tf1/tf2_random_normal_initializer": - isinstance(tabel_info.emb_initializer, tf.keras.initializers.RandomNormal) or - isinstance(tabel_info.emb_initializer, tf.random_normal_initializer), - "tf1_truncated_normal_initializer": - tf.__version__.startswith("1") and - (isinstance(tabel_info.emb_initializer, tf.truncated_normal_initializer) or - isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal)), - "tf2_truncated_normal_initializer": - tf.__version__.startswith("2") and - isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), - } + initializer_case_map = { + "tf1/tf2_constant_initializer": + isinstance(tabel_info.emb_initializer, tf.keras.initializers.Constant) or + isinstance(tabel_info.emb_initializer, tf.constant_initializer), + "tf1/tf2_random_normal_initializer": + isinstance(tabel_info.emb_initializer, tf.keras.initializers.RandomNormal) or + isinstance(tabel_info.emb_initializer, tf.random_normal_initializer), + "tf1_truncated_normal_initializer": + tf.__version__.startswith("1") and + (isinstance(tabel_info.emb_initializer, tf.truncated_normal_initializer) or + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal)), + "tf2_truncated_normal_initializer": + tf.__version__.startswith("2") and + isinstance(tabel_info.emb_initializer, tf.keras.initializers.TruncatedNormal), + } if initializer_case_map.get("tf1/tf2_constant_initializer"): initializer = matched_constant_initializer(tabel_info) elif initializer_case_map.get("tf1/tf2_random_normal_initializer"): diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 22b891b6..1cec4f26 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -2,8 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", - "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address"] +__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", + "create_hash_optimizer_by_address"] from mx_rec.optimizers.adagrad import create_hash_optimizer from mx_rec.optimizers.ftrl import create_hash_optimizer diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index e07b5691..8004426d 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -38,6 +38,14 @@ class SaveModelThread(threading.Thread): class Saver(object): customized_ops = get_customized_ops() + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) + @para_checker_decorator(check_option_list=[ ("var_list", ClassValidator, {"classes": (list, type(None))}), ("max_to_keep", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), @@ -60,14 +68,6 @@ class Saver(object): self._last_checkponts = [] self.build() - @staticmethod - def _make_table_name_dir(root_dir, table_instance, table_name): - if table_instance.host_vocabulary_size > 0: - table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) - else: - table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) - tf.io.gfile.makedirs(table_dir) - def build(self): if self.var_list is None: self.var_list = [] @@ -183,6 +183,24 @@ class Saver(object): save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, self.rank_id) + @performance("_save") + def _save(self, sess, root_dir): + result = self.save_op_dict + threads = [] + for table_name in result.keys(): + thread = SaveModelThread(sess, result, root_dir, table_name) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + if is_asc_manager_initialized() and not self.save_easy_mode: + save_host_data(root_dir) + logger.debug(f"host data was saved.") + def _build_save(self): for var in self.var_list: if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: @@ -227,24 +245,6 @@ class Saver(object): assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) self.restore_fetch_list.append(assign_op) - @performance("_save") - def _save(self, sess, root_dir): - result = self.save_op_dict - threads = [] - for table_name in result.keys(): - thread = SaveModelThread(sess, result, root_dir, table_name) - threads.append(thread) - - for thread in threads: - thread.start() - - for thread in threads: - thread.join() - - if is_asc_manager_initialized() and not self.save_easy_mode: - save_host_data(root_dir) - logger.debug(f"host data was saved.") - def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name): host_data = get_host_data(table_name) key = np.array(list(host_data.keys())) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index ff686b49..496c47a8 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -418,7 +418,7 @@ void EmbHashMap::RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& /// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 void EmbHashMap::AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& 
embHashMap) const
{
-    if (Logger::GetLevel() != Logger::trace) {
+    if (Logger::GetLevel() != Logger::TRACE) {
         return;
     }
     auto& hostMap = embHashMap.hostHashMap;
diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h
index 9d0b5b49..efbadd77 100644
--- a/src/core/hd_transfer/acl_channel.h
+++ b/src/core/hd_transfer/acl_channel.h
@@ -9,7 +9,6 @@
 #define ACL_CHANNEL_H

 #include
-#include

 #include "acl/acl_tdt.h"
 #include "tensorflow/core/framework/tensor.h"
diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp
index e92537dc..ea520d13 100644
--- a/src/core/hd_transfer/hd_transfer.cpp
+++ b/src/core/hd_transfer/hd_transfer.cpp
@@ -57,7 +57,7 @@ void HDTransfer::Destroy()
     LOG_INFO(HD + "destroy channel start");
     for (auto& c: transferChannels) {
         LOG_INFO(HD + "start destroy channel:{}", c.first);
-        if (acltdtStopChannel(c.second)!=ACL_ERROR_NONE || acltdtDestroyChannel(c.second)!=ACL_ERROR_NONE) {
+        if (acltdtStopChannel(c.second) != ACL_ERROR_NONE || acltdtDestroyChannel(c.second) != ACL_ERROR_NONE) {
             throw runtime_error("Acl destroy channel failed.");
         }
         LOG_INFO(HD + "destroy channel:{}", c.first);
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 91b9b7bb..3ed87246 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -143,7 +143,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos,
 // Check whether the data in hostHashMap and in cacheManager are consistent
 void HybridMgmt::AddCacheManagerTraceLog(CkptData& saveData)
 {
-    if (Logger::GetLevel() != Logger::trace) {
+    if (Logger::GetLevel() != Logger::TRACE) {
         return;
     }
     auto& embHashMaps = saveData.embHashMaps;
@@ -1079,7 +1079,7 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const
     }
     int64_t ssdSize = 0;
     if (mgmtRankInfo.isSSDEnabled) {
-        ssdSize= cacheManager->GetTableEmbeddingSize(embName);
+        ssdSize = cacheManager->GetTableEmbeddingSize(embName);
     }

     const auto& iter = hostHashMaps->embHashMaps.find(embName);
diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp
index 3476b8e0..09bc4249 100644
--- a/src/core/key_process/feature_admit_and_evict.cpp
+++ b/src/core/key_process/feature_admit_and_evict.cpp
@@ -108,7 +108,7 @@ FeatureAdmitType FeatureAdmitAndEvict::FeatureAdmitHelper(const int channel, con
     auto innerIt = historyRecordInfos.find(featureId);

     // When isEnableSum = false or in eval, only query the count without accumulating; a new key keeps the initial count 0
-    if (channel == EVAL_CHANNEL_ID || m_table2Threshold[tableNameOrigin].isEnableSum == false) {
+    if (channel == EVAL_CHANNEL_ID || !m_table2Threshold[tableNameOrigin].isEnableSum) {
         if (innerIt != historyRecordInfos.end()) {
             currKeyCount = historyRecordInfos[featureId].count;
         }
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 6809bd95..3f5619da 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -1139,7 +1139,7 @@ void KeyProcess::BuildRestoreVec(const unique_ptr& batch, const vecto
             emb_key_t devId = abs(key % static_cast(rankInfo.rankSize));
             if (restoreVec[i] >= hotPosSize) {
                 restoreVec[i] += blockOffset[devId];
-            } else if (Logger::GetLevel() >= Logger::debug) {
+            } else if (Logger::GetLevel() >= Logger::DEBUG) {
                 hotNum += 1;
             }
         }
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index f7f27a34..a54c14de 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -291,7 +291,7 @@ namespace MxRec {
    string StringFormat(const
string& format, Args ... args) { auto size = static_cast(GLOG_MAX_BUF_SIZE); - unique_ptr buf(new char[size]); + auto buf = std::make_unique(size); memset_s(buf.get(), size, 0, size); int nChar = snprintf_s(buf.get(), size, size - 1, format.c_str(), args ...); if (nChar == -1) { diff --git a/src/core/utils/logger.cpp b/src/core/utils/logger.cpp index b12f9870..59134dda 100644 --- a/src/core/utils/logger.cpp +++ b/src/core/utils/logger.cpp @@ -11,7 +11,7 @@ namespace MxRec { -int MxRec::Logger::level = MxRec::Logger::info; +int MxRec::Logger::level = MxRec::Logger::INFO; int MxRec::Logger::rank = 0; void Logger::SetRank(int logRank) @@ -31,7 +31,7 @@ int Logger::GetLevel() const char* Logger::LevelToStr(int logLevel) { - if (logLevel < trace || logLevel > error) { + if (logLevel < TRACE || logLevel > ERROR) { return "INVALID LEVEL"; } static const char* msg[] = { diff --git a/src/core/utils/logger.h b/src/core/utils/logger.h index dc760e60..f095599b 100644 --- a/src/core/utils/logger.h +++ b/src/core/utils/logger.h @@ -27,11 +27,11 @@ constexpr size_t DELIM_LEN = 2; class Logger { public: - static constexpr int trace = -2; - static constexpr int debug = -1; - static constexpr int info = 0; - static constexpr int warn = 1; - static constexpr int error = 2; + static constexpr int TRACE = -2; + static constexpr int DEBUG = -1; + static constexpr int INFO = 0; + static constexpr int WARN = 1; + static constexpr int ERROR = 2; static void SetRank(int logRank); @@ -103,20 +103,20 @@ private: static int rank; }; -#define LOG_TRACE(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::trace) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::trace, args) +#define LOG_TRACE(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::TRACE) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::TRACE, args) -#define LOG_DEBUG(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::debug) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::debug, args) +#define LOG_DEBUG(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::DEBUG) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::DEBUG, args) -#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::info) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::info, args) +#define LOG_INFO(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::INFO) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::INFO, args) -#define LOG_WARN(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::warn) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::warn, args) +#define LOG_WARN(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::WARN) \ +MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::WARN, args) -#define LOG_ERROR(args...) if (MxRec::Logger::GetLevel() <= MxRec::Logger::error) \ -MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::error, args) +#define LOG_ERROR(args...) 
if (MxRec::Logger::GetLevel() <= MxRec::Logger::ERROR) \
+MxRec::Logger::Log(__FILE__, __LINE__, MxRec::Logger::ERROR, args)
 }
diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h
index 5c77a797..79122038 100644
--- a/src/core/utils/safe_queue.h
+++ b/src/core/utils/safe_queue.h
@@ -19,7 +19,7 @@ namespace MxRec {
     template
     class SafeQueue {
-        static constexpr uint64_t defaultCap = 10;
+        static constexpr uint64_t DEFAULT_CAP = 10;

     public:
         SafeQueue() = default;
@@ -87,7 +87,7 @@
     private:
         mutable std::mutex mut;
-        uint64_t capacity = defaultCap;
+        uint64_t capacity = DEFAULT_CAP;
         std::list > dataQueue;
         std::list > emptyQueue;
         std::condition_variable dataCond;
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index 7d2e5ad4..fca54eb4 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -54,7 +54,7 @@ namespace MxRec {
             }
         }

-        ~ClearChannel() = default;
+        ~ClearChannel() override = default;

         void Compute(OpKernelContextPtr context) override
         {
@@ -69,14 +69,14 @@ namespace MxRec {
     class SetThreshold : public OpKernel {
     public:
-        explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context)
+        explicit SetThreshold(OpKernelConstruction* context) : OpKernel(context)
         {
             LOG_INFO("SetThreshold init");
             OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName));
             OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName)); // queried by sparse_lookup
         }

-        ~SetThreshold() = default;
+        ~SetThreshold() override = default;

         void Compute(OpKernelContextPtr context) override
         {
@@ -140,7 +140,7 @@ namespace MxRec {
         explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context)
         {}

-        ~ReturnTimestamp() = default;
+        ~ReturnTimestamp() override = default;

         void Compute(OpKernelContextPtr context) override
         {
@@ -189,7 +189,7 @@ namespace MxRec {
             }
             maxStep = keyProcess->GetMaxStep(channelId);
         }
-        ~ReadEmbKeyV2Dynamic() = default;
+        ~ReadEmbKeyV2Dynamic() override = default;

         void Compute(OpKernelContextPtr context) override
         {
@@ -380,7 +380,7 @@ namespace MxRec {
             maxStep = keyProcess->GetMaxStep(channelId);
         }

-        ~ReadEmbKeyV2() = default;
+        ~ReadEmbKeyV2() override = default;

         void Compute(OpKernelContextPtr context) override
         {
@@ -532,7 +532,7 @@ namespace MxRec {
             std::cout << " Cust opp not installed!!"
<< std::endl; } - ~CustOps() = default; + ~CustOps() override = default; }; } @@ -549,7 +549,7 @@ REGISTER_OP("SetThreshold") c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); return Status::OK(); }); -REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), MxRec::SetThreshold); +REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(tensorflow::DEVICE_CPU), SetThreshold); // ##################### ReturnTimestamp ####################### REGISTER_OP("ReturnTimestamp") diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index fe46ac07..f55b3b21 100644 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp @@ -87,7 +87,7 @@ TEST(EmbHashMap, TestFindOffset) auto& ddrKeyMap = cacheManager.ddrKeyFreqMap[embTableName]; auto logLevelTemp = Logger::GetLevel(); - Logger::SetLevel(Logger::trace); + Logger::SetLevel(Logger::TRACE); vector keys4 = {21, 21, 21, 21}; // 新key重复值, 且需要换入换出 hostHashMaps.FindOffset(embTableName, keys4, currentBatchId++, keepBatchId++, channelId); RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index d406f36b..eae9f463 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -135,7 +135,10 @@ protected: { default_random_engine generator; uniform_int_distribution distribution(randMin, randMax); - int embSizeMin = 5, embSizeMax = 8, base = 2, vocabSize = 100; + int embSizeMin = 5; + int embSizeMax = 8; + int base = 2; + int vocabSize = 100; uniform_int_distribution embSizeDistribution(embSizeMin, embSizeMax); stringstream ss; for (unsigned int i = 0; i < embNums; ++i) { diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp index 9aa70bd8..1eb3aa10 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -21,7 +21,7 @@ TEST(Log, Format) TEST(Log, LogLevel) { - MxRec::Logger::SetLevel(Logger::debug); + MxRec::Logger::SetLevel(Logger::DEBUG); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -33,7 +33,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -45,7 +45,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::warn); + MxRec::Logger::SetLevel(Logger::WARN); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -57,7 +57,7 @@ TEST(Log, LogLevel) EXPECT_NE(output.find("warn log hellow"), string::npos); EXPECT_NE(output.find("error log hellow"), string::npos); - MxRec::Logger::SetLevel(Logger::error); + MxRec::Logger::SetLevel(Logger::ERROR); testing::internal::CaptureStdout(); LOG_DEBUG("debug log {}", "hellow"); LOG_INFO("info log {}", "hellow"); @@ -72,7 +72,7 @@ TEST(Log, LogLevel) TEST(Log, LayzEvalution) { - MxRec::Logger::SetLevel(Logger::warn); + MxRec::Logger::SetLevel(Logger::WARN); testing::internal::CaptureStdout(); int flag1 = 0; int flag2 = 0; @@ -97,7 +97,7 @@ TEST(Log, LayzEvalution) TEST(Log, Basic) { - MxRec::Logger::SetLevel(Logger::info); 
+ MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("basictest"); std::string output = testing::internal::GetCapturedStdout(); @@ -106,7 +106,7 @@ TEST(Log, Basic) TEST(Log, TooManyArgs1) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{} {} {}", 0.1f, 'h', 'e', "llow"); std::string output = testing::internal::GetCapturedStdout(); @@ -116,7 +116,7 @@ TEST(Log, TooManyArgs1) TEST(Log, TooManyArgs2) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{}", "h", "h", "h", "h", "h", "h", "h"); std::string output = testing::internal::GetCapturedStdout(); @@ -126,7 +126,7 @@ TEST(Log, TooManyArgs2) TEST(Log, FewArgs) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("{} {} {} {} {} {}", "hellow", "hellow"); std::string output = testing::internal::GetCapturedStdout(); @@ -136,7 +136,7 @@ TEST(Log, FewArgs) TEST(Log, CkptType) { - MxRec::Logger::SetLevel(Logger::info); + MxRec::Logger::SetLevel(Logger::INFO); testing::internal::CaptureStdout(); LOG_INFO("ckpt type={}", CkptDataType::EMB_DATA); std::string output = testing::internal::GetCapturedStdout(); -- Gitee From f3eb42c9bf4dd0d638375739141c102f414993ab Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 22:08:53 +0800 Subject: [PATCH 449/551] Match-id-d850ac0d403f6baadcf4a275b56b2a5e821e110d --- mx_rec/optimizers/__init__.py | 6 +- mx_rec/saver/saver.py | 16 +-- src/ops_tf/hybrid_dataset_ops.cpp | 190 +++++++++++++++--------------- 3 files changed, 108 insertions(+), 104 deletions(-) diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 1cec4f26..8006213a 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -2,8 +2,10 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
-__all__ = ["create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", - "create_hash_optimizer_by_address"] +__all__ = [ + "create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", + "create_hash_optimizer_by_address" +] from mx_rec.optimizers.adagrad import create_hash_optimizer from mx_rec.optimizers.ftrl import create_hash_optimizer diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 8004426d..15a34fbe 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -38,14 +38,6 @@ class SaveModelThread(threading.Thread): class Saver(object): customized_ops = get_customized_ops() - @staticmethod - def _make_table_name_dir(root_dir, table_instance, table_name): - if table_instance.host_vocabulary_size > 0: - table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) - else: - table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) - tf.io.gfile.makedirs(table_dir) - @para_checker_decorator(check_option_list=[ ("var_list", ClassValidator, {"classes": (list, type(None))}), ("max_to_keep", IntValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), @@ -68,6 +60,14 @@ class Saver(object): self._last_checkponts = [] self.build() + @staticmethod + def _make_table_name_dir(root_dir, table_instance, table_name): + if table_instance.host_vocabulary_size > 0: + table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) + else: + table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) + tf.io.gfile.makedirs(table_dir) + def build(self): if self.var_list is None: self.var_list = [] diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index fca54eb4..c5a9d76c 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -69,7 +69,7 @@ namespace MxRec { class SetThreshold : public OpKernel { public: - explicit SetThreshold(OpKernelConstruction* context) : OpKernel(context) + explicit SetThreshold(OpKernelConstruction& context) : OpKernel(context) { LOG_INFO("SetThreshold init"); OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName)); @@ -536,96 +536,98 @@ namespace MxRec { }; } -REGISTER_OP("ClearChannel").Attr("channel_id : int"); -REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel); - -// ##################### SetThreshold ####################### -REGISTER_OP("SetThreshold") -.Input("input: int32") -.Attr("emb_name: string = ''") -.Attr("ids_name: string = ''") -.Output("output: int32") -.SetShapeFn([](InferenceContextPtr c) { -c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); -return Status::OK(); -}); -REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(tensorflow::DEVICE_CPU), SetThreshold); - -// ##################### ReturnTimestamp ####################### -REGISTER_OP("ReturnTimestamp") -.Input("input: int64") -.Output("output: int64") -.SetShapeFn([](InferenceContextPtr c) { -c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); -return Status::OK(); -}); -REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), MxRec::ReturnTimestamp); - -// ##################### ReadEmbKeyV2Dynamic ####################### -REGISTER_OP("ReadEmbKeyV2Dynamic") -.Input("sample: T") -.Input("splits: int32") -.Output("output: int32") -.Attr("T: {int64, int32}") -.Attr("channel_id: int") -.Attr("emb_name: list(string)") // for which table to lookup -.Attr("timestamp: bool") // use for feature evict, (unix timestamp) -.SetShapeFn([](InferenceContextPtr c) { 
-c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); -return Status::OK(); -}); - -REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2Dynamic); - -// ##################### ReadEmbKeyV2 ####################### -REGISTER_OP("ReadEmbKeyV2") -.Input("sample: T") -.Output("output: int32") -.Attr("T: {int64, int32}") -.Attr("channel_id: int") -.Attr("splits: list(int)") -.Attr("emb_name: list(string)") // for which table to lookup -.Attr("timestamp: bool") // use for feature evict, (unix timestamp) -.SetShapeFn([](InferenceContextPtr c) { -c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); -return Status::OK(); -}); - -REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2); - -REGISTER_OP("EmbeddingLookupByAddress") -.Input("address: int64") -.Attr("embedding_dim: int") -.Attr("embedding_type: int") -.Output("y: float") -.SetIsStateful() -.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { -ShapeHandle addrShape; -TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); -int embSize; -TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embSize)); -tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); -c->set_output(TENSOR_INDEX_0, c->Matrix(rows, embSize)); -return Status::OK(); -}); - -REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupByAddress").Device(DEVICE_CPU), MxRec::CustOps); - -REGISTER_OP("EmbeddingUpdateByAddress") -.Input("address: int64") -.Input("embedding: float") -.Attr("update_type: int") -.Output("y: float") -.SetIsStateful() -.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { -ShapeHandle addrShape; -TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); -ShapeHandle embeddingShape; -TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &embeddingShape)); -tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); -tensorflow::shape_inference::DimensionHandle cols = c->Dim(embeddingShape, 1); -c->set_output(TENSOR_INDEX_0, c->Matrix(rows, cols)); -return Status::OK(); -}); - -REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), MxRec::CustOps); \ No newline at end of file +namespace tensorflow { + REGISTER_OP("ClearChannel").Attr("channel_id : int"); + REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel); + + // ##################### SetThreshold ####################### + REGISTER_OP("SetThreshold") + .Input("input: int32") + .Attr("emb_name: string = ''") + .Attr("ids_name: string = ''") + .Output("output: int32") + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), MxRec::SetThreshold); + + // ##################### ReturnTimestamp ####################### + REGISTER_OP("ReturnTimestamp") + .Input("input: int64") + .Output("output: int64") + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), MxRec::ReturnTimestamp); + + // ##################### ReadEmbKeyV2Dynamic ####################### + REGISTER_OP("ReadEmbKeyV2Dynamic") + .Input("sample: T") + .Input("splits: int32") + .Output("output: int32") + .Attr("T: {int64, int32}") + .Attr("channel_id: int") + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + 
.SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + + REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2Dynamic); + + // ##################### ReadEmbKeyV2 ####################### + REGISTER_OP("ReadEmbKeyV2") + .Input("sample: T") + .Output("output: int32") + .Attr("T: {int64, int32}") + .Attr("channel_id: int") + .Attr("splits: list(int)") + .Attr("emb_name: list(string)") // for which table to lookup + .Attr("timestamp: bool") // use for feature evict, (unix timestamp) + .SetShapeFn([](InferenceContextPtr c) { + c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar()); + return Status::OK(); + }); + + REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2); + + REGISTER_OP("EmbeddingLookupByAddress") + .Input("address: int64") + .Attr("embedding_dim: int") + .Attr("embedding_type: int") + .Output("y: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle addrShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); + int embSize; + TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embSize)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); + c->set_output(TENSOR_INDEX_0, c->Matrix(rows, embSize)); + return Status::OK(); + }); + + REGISTER_KERNEL_BUILDER(Name("EmbeddingLookupByAddress").Device(DEVICE_CPU), MxRec::CustOps); + + REGISTER_OP("EmbeddingUpdateByAddress") + .Input("address: int64") + .Input("embedding: float") + .Attr("update_type: int") + .Output("y: float") + .SetIsStateful() + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ShapeHandle addrShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &addrShape)); + ShapeHandle embeddingShape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &embeddingShape)); + tensorflow::shape_inference::DimensionHandle rows = c->Dim(addrShape, 0); + tensorflow::shape_inference::DimensionHandle cols = c->Dim(embeddingShape, 1); + c->set_output(TENSOR_INDEX_0, c->Matrix(rows, cols)); + return Status::OK(); + }); + + REGISTER_KERNEL_BUILDER(Name("EmbeddingUpdateByAddress").Device(DEVICE_CPU), MxRec::CustOps); +} \ No newline at end of file -- Gitee From 93368e15ee230435b52aa7d82e252694c495dea1 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 22:25:34 +0800 Subject: [PATCH 450/551] Match-id-137b1710cd4ebf96b3821a00ea3e10367001664d --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 62f939ac..79c1a32d 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -69,7 +69,7 @@ namespace MxRec { class SetThreshold : public OpKernel { public: - explicit SetThreshold(OpKernelConstruction& context) : OpKernel(context) + explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context) { LOG_INFO("SetThreshold init"); OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName)); -- Gitee From 803ecbe75e73412bed22c0be4c87aa5792dd8948 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 13 Nov 2023 22:32:11 +0800 Subject: [PATCH 451/551] Match-id-993bc9fb5c68bf15fe47c98e9afb2ce04c9fdedd --- src/ops_tf/hybrid_dataset_ops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 79c1a32d..dc535f20 100644 --- 
a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -54,7 +54,7 @@ namespace MxRec {
             }
         }

-        ~ClearChannel() = default;
+        ~ClearChannel() override = default;

         void Compute(OpKernelContextPtr context) override
         {
-- 
Gitee

From 3aba6b68eb21011095ae0554ca4748ce86252c2b Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 13 Nov 2023 22:33:03 +0800
Subject: [PATCH 452/551] Match-id-5fad0463261fd922ec77c0f063b7aa78cd920830

---
 src/ops_tf/hybrid_dataset_ops.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index dc535f20..c605f06b 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -76,7 +76,7 @@ namespace MxRec {
             OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName)); // queried by sparse_lookup
         }

-        ~SetThreshold() = default;
+        ~SetThreshold() override = default;

         void Compute(OpKernelContextPtr context) override
         {
-- 
Gitee

From a284305d42adc1cd7d8e414406f577afa4597ffc Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 16 Nov 2023 17:31:17 +0800
Subject: [PATCH 453/551] Match-id-9e3e18bd134574e404febcab2fdfd9b8b1d7b95b

---
 mx_rec/core/embedding.py            |  47 +++++++---
 tests/mx_rec/core/test_embedding.py | 112 ++++++++++++++++++++++++
 2 files changed, 149 insertions(+), 10 deletions(-)
 create mode 100644 tests/mx_rec/core/test_embedding.py

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index fcd191ae..e3b648bb 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -641,8 +641,11 @@ class SparseEmbedding:
         if not use_dynamic_expansion:
             id_offsets_abs = tf.abs(id_offsets)
             local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
-            local_embeddings = set_zero_for_non_valid_key(id_offsets, local_embeddings,
-                                                          feature_spec.access_threshold)
+            local_embeddings = set_specific_value_for_non_valid_key(id_offsets,
+                                                                    local_embeddings,
+                                                                    feature_spec.access_threshold,
+                                                                    kwargs.get("serving_default_value"),
+                                                                    is_training=is_training)
         else:
             local_embeddings = tf.identity(table, name="identity_local_emb")

@@ -797,6 +800,7 @@ class SparseEmbedding:
     ("batch", ClassValidator, {"classes": (dict, type(None))}),
     ("access_and_evict_config", ClassValidator, {"classes": (dict, type(None))}),
     ("is_grad", ClassValidator, {"classes": (bool, )}),
+    ("serving_default_value", ClassValidator, {"classes": (tf.Tensor, type(None))})
 ])
 def sparse_lookup(hashtable: SparseEmbedding,
                   ids: Union[FeatureSpec, tf.Tensor],
@@ -807,6 +811,7 @@ def sparse_lookup(hashtable: SparseEmbedding,
                   batch: Optional[dict] = None,
                   access_and_evict_config: Optional[dict] = None,
                   is_grad: bool = True,
+                  serving_default_value: Optional[tf.Tensor] = None,
                   **kwargs):
    """
    Args:
@@ -819,7 +824,10 @@ def sparse_lookup(hashtable: SparseEmbedding,
        batch: the value returned by the get_next() method of TF Dataset
        access_and_evict_config: the configuration for the feature of feature filtering and eviction
        is_grad: indicate whether this lookup requires update gradients
-
+        serving_default_value: the lookup result to return for ids that miss the hashtable, i.e. ids below the
+            admission threshold during training and previously unseen ids during prediction. Setting it ensures
+            that a new id returns the same value during training and during prediction. Defaults to None, in
+            which case the value returned for a missing id is determined by the initializer of the hashtable.
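+        Example (an illustrative sketch rather than part of the original interface doc; `table`, `ids` and
+            `send_count` are assumed to come from an earlier create_table call and the input pipeline, and 16
+            merely stands in for the embedding dimension):
+                default_value = tf.zeros(shape=(16,), dtype=tf.float32)
+                emb = sparse_lookup(table, ids, send_count, name="lookup_with_default",
+                                    serving_default_value=default_value)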
    Returns: Tensor for lookup result
    """
@@ -835,7 +843,7 @@ def sparse_lookup(hashtable: SparseEmbedding,
     kwargs["feature_spec_name_ids_dict"] = None
     kwargs["multi_lookup"] = False
     kwargs["lookup_ids"] = None
-
+    kwargs["serving_default_value"] = serving_default_value
     scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))
     logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.",
                 hashtable.table_name, name, is_grad)
@@ -856,23 +864,42 @@ def sparse_lookup(hashtable: SparseEmbedding,
     return hashtable.lookup_for_asc(ids, send_count, **kwargs)


-def set_zero_for_non_valid_key(id_offsets: Optional[tf.Tensor], embeddings: Optional[tf.Tensor],
-                               access_threshold: bool):
+def set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor],
+                                         embeddings: Optional[tf.Tensor],
+                                         access_threshold: Optional[int],
+                                         serving_default_value: Optional[tf.Tensor] = None,
+                                         is_training: bool = True):
     """
-    Set the embeddings of features whose key is -1 to zero
+    Set the embeddings of features whose key is -1 (an invalid value) to zero or to a specified value
     :param id_offsets: feature indices
     :param embeddings: the sparse table
     :param access_threshold: the admission threshold
+    :param serving_default_value: refer to the description in the create_table interface
+    :param is_training: whether the current flow is training or inference
     :return:
     """
-    if access_threshold is None or access_threshold <= 0:
+    # During training, invalid keys appear only when feature admission is enabled; during inference,
+    # invalid keys may appear whether or not admission is enabled
+    if is_training and (access_threshold is None or access_threshold < 0):
         return embeddings

+    if serving_default_value is None:
+        # When not set, the embeddings of invalid keys default to all zeros
+        default_value = tf.zeros_like(embeddings)
+    else:
+        try:
+            default_value = tf.broadcast_to(serving_default_value, tf.shape(embeddings))
+        except ValueError as e:
+            logger.error("failed to broadcast serving_default_value to target embedding, please check its shape.")
+            raise e
+        except Exception as e:
+            logger.error("failed to process serving_default_value.")
+            raise e
+
     if tf.__version__.startswith("1"):
         id_offsets_expand = tf.math.greater_equal(id_offsets, 0)
-        embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings))
+        embeddings = tf.where(id_offsets_expand, embeddings, default_value)
         return embeddings

     id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1)
-    embeddings = tf.where(id_offsets_expand, embeddings, tf.zeros_like(embeddings))
+    embeddings = tf.where(id_offsets_expand, embeddings, default_value)
     return embeddings
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
new file mode 100644
index 00000000..cea7c35e
--- /dev/null
+++ b/tests/mx_rec/core/test_embedding.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+
+
+import unittest
+import tensorflow as tf
+from mx_rec.core.embedding import set_specific_value_for_non_valid_key
+
+if tf.__version__.startswith("1"):
+    tf.enable_eager_execution()
+else:
+    tf.compat.v1.enable_eager_execution()
+
+
+class TestSetSpecificValueForNonValidKey(unittest.TestCase):
+    """
+    Test Suite for set_specific_value_for_non_valid_key.
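+
+    A minimal sketch of the contract under test (illustrative only; it assumes the eager mode enabled above):
+        ids = tf.constant([-1, 1])
+        emb = tf.ones(shape=(2, 2), dtype=tf.float32)
+        out = set_specific_value_for_non_valid_key(ids, emb, access_threshold=1)
+        # out[0] is all zeros (key -1 is invalid); out[1] keeps the original ones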
+ """ + + def setUp(self): + """ + 准备步骤 + :return:无 + """ + super().setUp() + + def tearDown(self): + """ + 销毁步骤 + :return: 无 + """ + super().tearDown() + + def test_given_admit_if_turned_off_then_return_raw_embedding(self): + # given + id_offsets = tf.constant([0, 1, 2, 3]) + embeddings = tf.ones(shape=(4, 10), dtype=tf.float32) + access_threshold = None + serving_default_value = tf.ones(shape=(4, 10), dtype=tf.float32) * 2 + + # when + modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, + access_threshold, serving_default_value) + + # then + result = bool(tf.reduce_all(tf.equal(embeddings, modified_emb))) + self.assertTrue(result) + + def test_given_no_default_value_and_all_valid_key_then_return_raw_embedding(self): + # given + id_offsets = tf.constant([0, 1, 2, 3]) + embeddings = tf.ones(shape=(4, 10), dtype=tf.float32) + access_threshold = 1 + serving_default_value = None + + # when + modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, + access_threshold, serving_default_value) + + # then + result = bool(tf.reduce_all(tf.equal(embeddings, modified_emb))) + self.assertTrue(result) + + def test_given_no_default_value_and_invalid_key_then_emb_of_invalid_key_set_to_zero(self): + # given + id_offsets = tf.constant([-1, 1, 2, 3]) + embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) + access_threshold = 1 + serving_default_value = None + + # when + modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, + access_threshold, serving_default_value) + + # then + result = modified_emb.numpy().tolist()[0] + self.assertEqual([0, 0], result) + + def test_given_default_value_and_invalid_key_then_emb_of_invalid_key_set_to_default_value(self): + # given + id_offsets = tf.constant([-1, 1, 2, 3]) + embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) + access_threshold = 1 + serving_default_value = tf.ones(shape=(4, 2), dtype=tf.float32) * 2 + + # when + modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, + access_threshold, serving_default_value) + + # then + result = modified_emb.numpy().tolist()[0] + self.assertEqual([2, 2], result) + + def test_given_default_value_and_with_all_invalid_key_then_emb_of_invalid_key_set_to_default_value(self): + # given + id_offsets = tf.constant([-1, -1, -1, -1]) + embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) + access_threshold = 1 + serving_default_value = tf.ones(shape=(4, 2), dtype=tf.float32) * 2 + + # when + modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, + access_threshold, serving_default_value) + + # then + result = bool(tf.reduce_all(tf.equal(serving_default_value, modified_emb))) + self.assertTrue(result) + + +if __name__ == '__main__': + unittest.main() -- Gitee From 3a5d1940b5e37b273e1db0b64b5dd144cac84ce8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 16 Nov 2023 18:58:03 +0800 Subject: [PATCH 454/551] Match-id-9f2a872f1848bcfdd565534d5d388d339725714c --- mx_rec/util/global_env_conf.py | 1 + src/core/file_system/file_system.h | 45 +++++ src/core/file_system/file_system_handler.cpp | 24 +++ src/core/file_system/file_system_handler.h | 25 +++ .../hdfs_file_system/hdfs_file_system.h | 45 +++++ .../hdfs_file_system/hdfs_wrapper.h | 184 ++++++++++++++++++ .../local_file_system/local_file_system.h | 49 +++++ 7 files changed, 373 insertions(+) create mode 100644 src/core/file_system/file_system.h create mode 100644 src/core/file_system/file_system_handler.cpp create mode 100644 src/core/file_system/file_system_handler.h create mode 
100644 src/core/file_system/hdfs_file_system/hdfs_file_system.h create mode 100644 src/core/file_system/hdfs_file_system/hdfs_wrapper.h create mode 100644 src/core/file_system/local_file_system/local_file_system.h diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index a9886e06..a3c38b93 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. import os diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h new file mode 100644 index 00000000..559d8442 --- /dev/null +++ b/src/core/file_system/file_system.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_FILE_SYSTEM_H +#define MX_REC_FILE_SYSTEM_H + +#include "checkpoint/buffer_queue.h" + +namespace MxRec { + using namespace std; + + class FileSystem { + public: + FileSystem() = default; + virtual ~FileSystem() = default; + + virtual void CreateDir(const string& dirName) = 0; + virtual vector ListDir(const string& dirName) = 0; + virtual size_t GetFileSize(const string& filePath) = 0; + + virtual ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) = 0; + virtual ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) = 0; + virtual void WriteEmbedding(const string& filePath, const int& embeddingSize, + const vector& addressArr, int deviceId) = 0; + + virtual ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) = 0; + virtual ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) = 0; + virtual void ReadEmbedding(const string& filePath, const int& embeddingSize, + vector& addressArr, int deviceId) = 0; + + // The parameter oneTimeReadWriteLen specifies the maximum length of a file read or write at a time. + // The parameter can be adjusted based on the service requirements. + const size_t oneTimeReadWriteLen = 32768; + const int embHashNum = 1; + const int keyAddrElem = 1; + std::vector buffer; + std::vector writeBuffer; + }; +} + +#endif //MX_REC_FILE_SYSTEM_H diff --git a/src/core/file_system/file_system_handler.cpp b/src/core/file_system/file_system_handler.cpp new file mode 100644 index 00000000..351cff2d --- /dev/null +++ b/src/core/file_system/file_system_handler.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-11-16 + */ + +#include "file_system_handler.h" + +using namespace std; +using namespace MxRec; + +inline unique_ptr FileSystemHandler::Create(const string &dataDir) +{ + if (dataDir.empty()) { + throw runtime_error("dataDir is Null. The pointer of the file system cannot be created."); + } + for (const auto &prefix: hdfsPrefixes) { + if (dataDir.substr(0, prefix.length()) == prefix) { + return make_unique(); + } + } + return make_unique(); +} diff --git a/src/core/file_system/file_system_handler.h b/src/core/file_system/file_system_handler.h new file mode 100644 index 00000000..59ae8738 --- /dev/null +++ b/src/core/file_system/file_system_handler.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_FILE_SYSTEM_HANDLER_H +#define MX_REC_FILE_SYSTEM_HANDLER_H + +#include "hdfs_file_system/hdfs_file_system.h" +#include "local_file_system/local_file_system.h" + +namespace MxRec { + using namespace std; + + class FileSystemHandler { + public: + inline unique_ptr Create(const string &dataDir); + private: + const vector hdfsPrefixes = {"hdfs://", "viewfs://"}; + }; +} + +#endif // MX_REC_FILE_SYSTEM_HANDLER_H \ No newline at end of file diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h new file mode 100644 index 00000000..9de39dce --- /dev/null +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_HDFS_FILE_SYSTEM_H +#define MX_REC_HDFS_FILE_SYSTEM_H + +#include "file_system/file_system.h" +#include "hdfs_wrapper.h" + +namespace MxRec { + using namespace std; + + class HdfsFileSystem : public FileSystem { + public: + HdfsFileSystem() + { + hdfs = make_unique(); + }; + ~HdfsFileSystem() override {} + + void CreateDir(const string& dirName) override; + vector ListDir(const string& dirName) override; + size_t GetFileSize(const string& filePath) override; + + ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; + ssize_t Write(const string& filePath, vector fileVector, size_t dataSize) override; + void WriteEmbedding(const string& filePath, const int& embeddingSize, + const vector& addressArr, int deviceId) override; + + ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override; + ssize_t Read(const string& filePath, vector>& fileVector, size_t datasetSize) override; + void ReadEmbedding(const string &filePath, const int& embeddingSize, + vector& addressArr, int deviceId) override; + + hdfsFS ConnectHdfs(); + + unique_ptr hdfs; + }; +} + +#endif // MX_REC_HDFS_FILE_SYSTEM_H \ No newline at end of file diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h new file mode 100644 index 00000000..e1a0ab5d --- /dev/null +++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h @@ -0,0 +1,184 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-10-27 + */ + +#ifndef MX_REC_HDFS_LOADER_H +#define MX_REC_HDFS_LOADER_H + +#include +#include + +#include "utils/common.h" + +namespace MxRec { + + // The following parameters are not named in large camel case to adapt to native HDFS interfaces. 
+    // Including tObjectKind, tPort, tSize, tTime, tOffset, hdfs_internal, hdfsFS, hdfsFile_internal, hdfsFile, hdfsFileInfo
+    enum tObjectKind {
+        kObjectKindFile = 'F',
+        kObjectKindDirectory = 'D',
+    };
+
+    using tPort = uint16_t;
+    using tSize = int32_t;
+    using tTime = time_t;
+    using tOffset = int64_t;
+
+    struct hdfs_internal;
+    using hdfsFS = struct hdfs_internal*;
+    struct hdfsFile_internal;
+    using hdfsFile = struct hdfsFile_internal*;
+
+    struct hdfsFileInfo {
+        tObjectKind mKind{}; /* file or directory */
+        char *mName{}; /* the name of the file */
+        tTime mLastMod{}; /* the last modification time for the file in seconds */
+        tOffset mSize{}; /* the size of the file in bytes */
+        short mReplication{}; /* the count of replicas */
+        tOffset mBlockSize{}; /* the block size for the file */
+        char *mOwner{}; /* the owner of the file */
+        char *mGroup{}; /* the group associated with the file */
+        short mPermissions{}; /* the permissions associated with the file */
+        tTime mLastAccess{}; /* the last access time for the file in seconds */
+    };
+
+    class HdfsWrapper {
+    public:
+        HdfsWrapper()
+        {
+            // Dynamically load the HDFS shared library
+            libhdfs = dlopen("libhdfs.so", RTLD_LAZY);
+            if (!libhdfs) {
+                LOG_ERROR("Init hdfs wrapper failed when loading libhdfs.so in environment.");
+                throw runtime_error("Init hdfs wrapper failed when loading libhdfs.so in environment.");
+            }
+
+            // Obtain the function pointers from the HDFS library
+            hdfsConnect = reinterpret_cast<HdfsConnectFunc>(dlsym(libhdfs, "hdfsConnect"));
+            hdfsDisconnect = reinterpret_cast<HdfsDisconnectFunc>(dlsym(libhdfs, "hdfsDisconnect"));
+            hdfsCreateDirectory = reinterpret_cast<HdfsCreateDirectoryFunc>(dlsym(libhdfs, "hdfsCreateDirectory"));
+            hdfsListDirectory = reinterpret_cast<HdfsListDirectoryFunc>(dlsym(libhdfs, "hdfsListDirectory"));
+            hdfsFreeFileInfo = reinterpret_cast<HdfsFreeFileInfoFunc>(dlsym(libhdfs, "hdfsFreeFileInfo"));
+            hdfsGetPathInfo = reinterpret_cast<HdfsGetPathInfoFunc>(dlsym(libhdfs, "hdfsGetPathInfo"));
+            hdfsOpenFile = reinterpret_cast<HdfsOpenFileFunc>(dlsym(libhdfs, "hdfsOpenFile"));
+            hdfsCloseFile = reinterpret_cast<HdfsCloseFileFunc>(dlsym(libhdfs, "hdfsCloseFile"));
+            hdfsRead = reinterpret_cast<HdfsReadFunc>(dlsym(libhdfs, "hdfsRead"));
+            hdfsWrite = reinterpret_cast<HdfsWriteFunc>(dlsym(libhdfs, "hdfsWrite"));
+        }
+
+        ~HdfsWrapper()
+        {
+            dlclose(libhdfs);
+        }
+
+        hdfsFS Connect(const char* host, tPort port)
+        {
+            if (hdfsConnect == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsConnect from the libhdfs.");
+            }
+            return hdfsConnect(host, port);
+        }
+
+        int Disconnect(hdfsFS fs)
+        {
+            if (hdfsDisconnect == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsDisconnect from the libhdfs.");
+            }
+            return hdfsDisconnect(fs);
+        }
+
+        int CreateDirectory(hdfsFS fs, const char* path)
+        {
+            if (hdfsCreateDirectory == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsCreateDirectory from libhdfs.");
+            }
+            return hdfsCreateDirectory(fs, path);
+        }
+
+        hdfsFileInfo* ListDirectory(hdfsFS fs, const char* path, int *numEntries)
+        {
+            if (hdfsListDirectory == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsListDirectory from the libhdfs.");
+            }
+            return hdfsListDirectory(fs, path, numEntries);
+        }
+
+        hdfsFileInfo* GetPathInfo(hdfsFS fs, const char* path)
+        {
+            if (hdfsGetPathInfo == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsGetPathInfo from the libhdfs.");
+            }
+            return hdfsGetPathInfo(fs, path);
+        }
+
+        void FreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries)
+        {
+            if (hdfsFreeFileInfo == nullptr) {
+                throw runtime_error("Failed to obtain the pointer of the function hdfsFreeFileInfo
from the libhdfs."); + } + return hdfsFreeFileInfo(hdfsFileInfo, numEntries); + } + + hdfsFile OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication, tSize blocksize) + { + if (hdfsOpenFile == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsOpenFile from the libhdfs."); + } + return hdfsOpenFile(fs, path, flags, bufferSize, replication, blocksize); + } + + int CloseFile(hdfsFS fs, hdfsFile file) + { + if (hdfsCloseFile == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsCloseFile from the libhdfs."); + } + return hdfsCloseFile(fs, file); + } + + tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length) + { + if (hdfsRead == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsRead from the libhdfs."); + } + return hdfsRead(fs, file, buffer, length); + } + + tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length) + { + if (hdfsWrite == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsWrite from the libhdfs."); + } + return hdfsWrite(fs, file, buffer, length); + } + + private: + void* libhdfs; + + using HdfsConnectFunc = hdfsFS (*)(const char*, tPort); + using HdfsDisconnectFunc = int (*)(hdfsFS); + using HdfsCreateDirectoryFunc = int (*)(hdfsFS fs, const char* path); + using HdfsListDirectoryFunc = hdfsFileInfo* (*)(hdfsFS fs, const char* path, int *numEntries); + using HdfsFreeFileInfoFunc = void (*)(hdfsFileInfo *hdfsFileInfo, int numEntries); + using HdfsGetPathInfoFunc = hdfsFileInfo* (*)(hdfsFS fs, const char* path); + using HdfsOpenFileFunc = hdfsFile (*)(hdfsFS, const char*, int, int, short, tSize); + using HdfsCloseFileFunc = int (*)(hdfsFS, hdfsFile); + using HdfsReadFunc = tSize (*)(hdfsFS, hdfsFile, void*, tSize); + using HdfsWriteFunc = tSize (*)(hdfsFS, hdfsFile, const void*, tSize); + + HdfsConnectFunc hdfsConnect; + HdfsDisconnectFunc hdfsDisconnect; + HdfsCreateDirectoryFunc hdfsCreateDirectory; + HdfsListDirectoryFunc hdfsListDirectory; + HdfsFreeFileInfoFunc hdfsFreeFileInfo; + HdfsGetPathInfoFunc hdfsGetPathInfo; + HdfsOpenFileFunc hdfsOpenFile; + HdfsCloseFileFunc hdfsCloseFile; + HdfsReadFunc hdfsRead; + HdfsWriteFunc hdfsWrite; + }; +} + +#endif // MX_REC_HDFS_LOADER_H \ No newline at end of file diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h new file mode 100644 index 00000000..d6346a1d --- /dev/null +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_LOCAL_FILE_SYSTEM_H +#define MX_REC_LOCAL_FILE_SYSTEM_H + +#include "file_system/file_system.h" + +namespace MxRec { + using namespace std; + + class LocalFileSystem : public FileSystem { + public: + LocalFileSystem() : dirMode(0750), fileMode(0640), currDir("."), prevDir("..") {} + ~LocalFileSystem() override {} + + void CreateDir(const string& dirName) override; + vector ListDir(const string& dirName) override; + size_t GetFileSize(const string& filePath) override; + + ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; + ssize_t Write(const string& filePath, vector fileVector, size_t dataSize) override; + void WriteEmbedding(const string& filePath, const int& embeddingSize, + const vector& addressArr, int deviceId) override; + + ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override; + ssize_t Read(const string& filePath, vector>& fileVector, size_t datasetSize) override; + void ReadEmbedding(const string& filePath, const int& embeddingSize, + vector& addressArr, int deviceId) override; + + void WriterFn(BufferQueue& queue, int fd, ssize_t& writerBytesNum); + void FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize); + void CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const; + void HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize, + vector>& dst, size_t cnt) const; + + private: + const mode_t dirMode; + const mode_t fileMode; + const string currDir; + const string prevDir; + }; +} + +#endif // MX_REC_LOCAL_FILE_SYSTEM_H \ No newline at end of file -- Gitee From 27470bcc10aa79e963cfa3742330be69c185895d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 16 Nov 2023 19:49:22 +0800 Subject: [PATCH 455/551] Match-id-f7c673dd851df961acbd92971fe4a26171f24804 --- src/core/checkpoint/checkpoint.cpp | 7 +- .../key_count_map_ckpt/key_count_map_ckpt.cpp | 87 +++++++++++++++++++ .../key_count_map_ckpt/key_count_map_ckpt.h | 43 +++++++++ src/core/hybrid_mgmt/hybrid_mgmt.cpp | 3 + src/core/key_process/key_process.cpp | 27 ++++++ src/core/key_process/key_process.h | 7 ++ src/core/utils/common.h | 8 +- 7 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp create mode 100644 src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 8db8fbba..e9e055f5 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -18,6 +18,7 @@ #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" #include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" #include "ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h" +#include "ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h" #include "utils/time_cost.h" #include "utils/common.h" #include "checkpoint.h" @@ -61,6 +62,9 @@ void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRa void Checkpoint::SetDataHandler(CkptData& ckptData) { dataHandlers.clear(); + if (!ckptData.keyCountMap.empty()) { + dataHandlers.push_back(make_unique()); + } if (ckptData.hostEmbs != nullptr) { dataHandlers.push_back(make_unique()); } @@ -90,7 +94,8 @@ void Checkpoint::SetDataHandler(const vector& featureTypes) {CkptFeatureType::MAX_OFFSET, [this] { 
dataHandlers.push_back(make_unique()); }},
        {CkptFeatureType::KEY_OFFSET_MAP, [this] { dataHandlers.push_back(make_unique()); }},
        {CkptFeatureType::FEAT_ADMIT_N_EVICT, [this] { dataHandlers.push_back(make_unique()); }},
-        {CkptFeatureType::DDR_KEY_FREQ_MAP, [this] { dataHandlers.push_back(make_unique()); }}
+        {CkptFeatureType::DDR_KEY_FREQ_MAP, [this] { dataHandlers.push_back(make_unique()); }},
+        {CkptFeatureType::KEY_COUNT_MAP, [this] { dataHandlers.push_back(make_unique<KeyCountMapCkpt>()); }}
    };

    for (const auto& featureType : featureTypes) {
diff --git a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp
new file mode 100644
index 00000000..d3ae9963
--- /dev/null
+++ b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-11-01
+ */
+
+#include "key_count_map_ckpt.h"
+
+using namespace std;
+using namespace MxRec;
+
+void KeyCountMapCkpt::SetProcessData(CkptData& processData)
+{
+    saveKeyCountMap.clear();
+    loadKeyCountMap.clear();
+    saveKeyCountMap = std::move(processData.keyCountMap);
+}
+
+void KeyCountMapCkpt::GetProcessData(CkptData& processData)
+{
+    processData.keyCountMap = std::move(loadKeyCountMap);
+    saveKeyCountMap.clear();
+    loadKeyCountMap.clear();
+}
+
+vector<CkptDataType> KeyCountMapCkpt::GetDataTypes()
+{
+    return saveDataTypes;
+}
+
+vector<string> KeyCountMapCkpt::GetDirNames()
+{
+    return fileDirNames;
+}
+
+vector<string> KeyCountMapCkpt::GetEmbNames()
+{
+    vector<string> embNames;
+    for (const auto& item : as_const(saveKeyCountMap)) {
+        embNames.push_back(item.first);
+    }
+    return embNames;
+}
+
+CkptTransData KeyCountMapCkpt::GetDataset(CkptDataType dataType, string embName)
+{
+    CleanTransfer();
+
+    auto& transArr = transferData.int64Arr;
+    auto& attribute = transferData.attribute;
+    auto embHashMapSize = saveKeyCountMap.at(embName).size();
+
+    attribute.push_back(embHashMapSize);
+    embHashMapSize = embHashMapSize * embHashElmtNum;
+
+    attribute.push_back(embHashElmtNum);
+    attribute.push_back(eightBytes);
+    transferData.datasetSize = embHashMapSize * eightBytes;
+    transferData.attributeSize = attribute.size() * eightBytes;
+
+    transArr.reserve(embHashMapSize);
+    for (const auto& it : saveKeyCountMap.at(embName)) {
+        transArr.push_back(it.first);
+        transArr.push_back(it.second);
+    }
+    LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::EMB_INFO, dataType);
+    return move(transferData);
+}
+
+void KeyCountMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData)
+{
+    CleanTransfer();
+    transferData = move(loadedData);
+
+    auto& singleKeyCountMap = loadKeyCountMap[embName];
+    const auto& transArr = transferData.int64Arr;
+
+    for (size_t i = 0; i < transArr.size(); i += embHashElmtNum) {
+        if (i + embHashElmtNum > transArr.size()) {
+            // An incomplete key/count pair at the tail would make the .at() below throw,
+            // so log the corruption and stop parsing here.
+            LOG_ERROR("KeyCountMapCkpt::SetDataset: incomplete key/count pair at index {}", i);
+            break;
+        }
+        int64_t key { transArr.at(i) };
+        singleKeyCountMap[key] = transArr.at(i + 1);
+    }
+    LOG_INFO("dataType:{}", dataType);
+}
diff --git a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h
new file mode 100644
index 00000000..16869a9d
--- /dev/null
+++ b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-11-01
+ */
+
+#ifndef MXREC_KEY_COUNT_MAP_CKPT_H
+#define MXREC_KEY_COUNT_MAP_CKPT_H
+
+#include "ckpt_data_handler/ckpt_data_handler.h"
+
+namespace MxRec {
+    using namespace std;
+
+    class KeyCountMapCkpt : public CkptDataHandler {
+    public:
+        KeyCountMapCkpt() = default;
+        ~KeyCountMapCkpt() override {}
+
+        void SetProcessData(CkptData& processData) override;
+        void GetProcessData(CkptData& processData) override;
+
+        vector<CkptDataType> GetDataTypes() override;
+
+        vector<string> GetDirNames() override;
+        vector<string> GetEmbNames() override;
+        CkptTransData GetDataset(CkptDataType dataType, string embName) override;
+
+        void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override;
+
+    private:
+        const vector<string> fileDirNames { "HashTable", "FEAT_INFO" };
+        const vector<CkptDataType> saveDataTypes { CkptDataType::KEY_COUNT_MAP };
+
+        const int embHashElmtNum = 2;
+
+        KeyCountMemT saveKeyCountMap;
+        KeyCountMemT loadKeyCountMap;
+    };
+}
+
+#endif // MXREC_KEY_COUNT_MAP_CKPT_H
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 3ed87246..c0c108c8 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -233,6 +233,7 @@ bool HybridMgmt::Save(const string savePath)
     CkptData saveData;
     Checkpoint saveCkpt;

+    saveData.keyCountMap = preprocess->GetKeyCountMap();
     if (!mgmtRankInfo.noDDR) {
         // In DDR mode, save the host-side emb table and hashmap
         LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap");
@@ -306,6 +307,7 @@ bool HybridMgmt::Load(const string& loadPath)
         return false;
     }

+    preprocess->LoadKeyCountMap(loadData.keyCountMap);
     if (!mgmtRankInfo.noDDR) {
         // In DDR mode, assign the loaded hash map
         LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap");
@@ -345,6 +347,7 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, const FeatureAdmitAndEvict& featAdmitNEvict)
 {
+    loadFeatures.push_back(CkptFeatureType::KEY_COUNT_MAP);
     if (!mgmtRankInfo.noDDR) {
         // In DDR mode, the types to load are the host-side emb table and hashmap
         loadFeatures.push_back(CkptFeatureType::HOST_EMB);
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 3f5619da..c75e2c36 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -130,6 +130,11 @@ auto KeyProcess::GetKeyOffsetMap() -> KeyOffsetMemT
     return keyOffsetMap;
 }

+auto KeyProcess::GetKeyCountMap() -> KeyCountMemT
+{
+    return keyCountMap;
+}
+
 auto KeyProcess::GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&
 {
     return m_featureAdmitAndEvict;
@@ -147,6 +152,11 @@ void KeyProcess::LoadKeyOffsetMap(KeyOffsetMemT& loadData)
     keyOffsetMap = std::move(loadData);
 }

+void KeyProcess::LoadKeyCountMap(KeyCountMemT& loadData)
+{
+    keyCountMap = std::move(loadData);
+}
+
 // Called from the Python side only after training finishes. If a deadlock occurs, simply terminate the
 // process; in tests, let the process wait long enough before calling this.
 void KeyProcess::Destroy()
 {
@@ -250,6 +260,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();

         TimeCost processDataTime = TimeCost();
+        RecordKeyCountMap(batch);
         InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique);
         if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) {
             break;
@@ -286,6 +297,8 @@ void KeyProcess::KeyProcessTask(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();

         TimeCost processDataTime = TimeCost();
+        RecordKeyCountMap(batch);
+
         if (!KeyProcessTaskHelper(batch, channel, threadId)) {
             break;
         }
@@ -1406,4 +1419,18 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName)
     std::lock_guard lk(mut); // lock for PROCESS_THREAD
     return embeddingTableMap[embName].GetTableCapacity();
 #endif
+}
+
+void KeyProcess::RecordKeyCountMap(const unique_ptr& batch)
+{
+    size_t miniBs = batch->Size();
+    auto* batchData = batch->sample.data();
+    auto& singleKeyCountMap = keyCountMap[batch->name];
+    for (size_t i = 0; i < miniBs; i++) {
+        const emb_key_t& key = batchData[i];
+        // Count each occurrence exactly once: a new key starts at 1, a known key is incremented.
+        if (singleKeyCountMap.find(key) == singleKeyCountMap.end()) {
+            singleKeyCountMap[key] = 1;
+        } else {
+            singleKeyCountMap[key]++;
+        }
+    }
}
\ No newline at end of file
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index cea6978d..43fa8c7a 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -94,12 +94,16 @@ namespace MxRec {
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 3f5619da..c75e2c36 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -130,6 +130,11 @@ auto KeyProcess::GetKeyOffsetMap() -> KeyOffsetMemT
     return keyOffsetMap;
 }
 
+auto KeyProcess::GetKeyCountMap() -> KeyCountMemT
+{
+    return keyCountMap;
+}
+
 auto KeyProcess::GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&
 {
     return m_featureAdmitAndEvict;
@@ -147,6 +152,11 @@ void KeyProcess::LoadKeyOffsetMap(KeyOffsetMemT& loadData)
     keyOffsetMap = std::move(loadData);
 }
 
+void KeyProcess::LoadKeyCountMap(KeyCountMemT& loadData)
+{
+    keyCountMap = std::move(loadData);
+}
+
 // Called only from the Python side after training ends; if a deadlock occurs, just terminate the process.
 // In tests, let the process wait long enough before calling this.
 void KeyProcess::Destroy()
 {
@@ -250,6 +260,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();
 
         TimeCost processDataTime = TimeCost();
+        RecordKeyCountMap(batch);
         InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique);
         if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) {
             break;
@@ -286,6 +297,8 @@ void KeyProcess::KeyProcessTask(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();
 
         TimeCost processDataTime = TimeCost();
+        RecordKeyCountMap(batch);
+
         if (!KeyProcessTaskHelper(batch, channel, threadId)) {
             break;
         }
@@ -1406,4 +1419,18 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName)
     std::lock_guard lk(mut); // lock for PROCESS_THREAD
     return embeddingTableMap[embName].GetTableCapacity();
 #endif
+}
+
+void KeyProcess::RecordKeyCountMap(const unique_ptr& batch)
+{
+    size_t miniBs = batch->Size();
+    auto* batchData = batch->sample.data();
+    auto& singleKeyCountMap = keyCountMap[batch->name];
+    for (size_t i = 0; i < miniBs; i++) {
+        const emb_key_t& key = batchData[i];
+        // operator[] value-initializes a missing count to 0, so a single
+        // increment per occurrence is exact; seeding new keys with 1 first would double-count them.
+        singleKeyCountMap[key]++;
+    }
 }
\ No newline at end of file
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index cea6978d..43fa8c7a 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -94,12 +94,16 @@ namespace MxRec {
 
         auto GetKeyOffsetMap() -> KeyOffsetMemT;
 
+        auto GetKeyCountMap() -> KeyCountMemT;
+
         auto GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&;
 
         void LoadMaxOffset(OffsetMemT& loadData);
 
         void LoadKeyOffsetMap(KeyOffsetMemT& loadData);
 
+        void LoadKeyCountMap(KeyCountMemT& loadData);
+
         void Destroy();
 
         void LoadSaveLock();
@@ -116,6 +120,8 @@ namespace MxRec {
 
         int64_t GetExpansionTableCapacity(const string& embName);
 
+        void RecordKeyCountMap(const unique_ptr& batch);
+
         template <typename T>
         void GlobalUnique(T& lookupKeys, T& uniqueKeys, vector& restoreVecSec)
         {
@@ -172,6 +178,7 @@ namespace MxRec {
         info_list_t all2AllList;
         map maxOffset {};
         map<string, map<emb_key_t, int64_t>> keyOffsetMap {};
+        map<string, map<emb_key_t, int64_t>> keyCountMap {};
         FeatureAdmitAndEvict m_featureAdmitAndEvict {};
         map> evictPosMap {};
         map> hotKey {};
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index a54c14de..9910eb43 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -505,6 +505,7 @@ namespace MxRec {
     using EmbHashMemT = absl::flat_hash_map;
     using OffsetMemT = std::map;
     using KeyOffsetMemT = std::map<std::string, std::map<emb_key_t, int64_t>>;
+    using KeyCountMemT = std::map<std::string, std::map<emb_key_t, int64_t>>;
     using Table2ThreshMemT = absl::flat_hash_map;
     using trans_serialize_t = uint8_t;
     using KeyOffsetMapT = std::map;
@@ -518,7 +519,8 @@ namespace MxRec {
         KEY_OFFSET_MAP = 3,
         FEAT_ADMIT_N_EVICT = 4,
         DDR_KEY_FREQ_MAP = 5,
-        EXCLUDE_DDR_KEY_FREQ_MAP = 6
+        EXCLUDE_DDR_KEY_FREQ_MAP = 6,
+        KEY_COUNT_MAP = 7
     };
 
     struct CkptData {
@@ -526,6 +528,7 @@ namespace MxRec {
         EmbHashMemT embHashMaps;
         OffsetMemT maxOffset;
         KeyOffsetMemT keyOffsetMap;
+        KeyCountMemT keyCountMap;
         Table2ThreshMemT table2Thresh;
         AdmitAndEvictData histRec;
         KeyFreqMemT ddrKeyFreqMaps;
@@ -555,7 +558,8 @@ namespace MxRec {
         ATTRIBUTE = 9,
         DDR_FREQ_MAP = 10,
         EXCLUDE_FREQ_MAP = 11,
-        EVICT_POS = 12
+        EVICT_POS = 12,
+        KEY_COUNT_MAP = 13
     };
 
     ostream& operator<<(ostream& ss, MxRec::CkptDataType type);
-- 
Gitee
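A note on the counting idiom in RecordKeyCountMap: std::map::operator[] value-initializes a missing mapped value to zero, so one unconditional increment per occurrence is exact, while seeding new keys with 1 before incrementing reports every count one too high. A minimal demonstration of both:

    // Counting with operator[] versus the seed-then-increment pitfall.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <vector>

    int main()
    {
        std::vector<int64_t> batch { 7, 42, 7 };

        std::map<int64_t, int64_t> count;
        for (int64_t key : batch) {
            ++count[key];               // first access creates the entry with value 0
        }
        std::cout << count[7] << "\n";  // 2, the true number of occurrences

        std::map<int64_t, int64_t> wrong;
        for (int64_t key : batch) {
            if (wrong.find(key) == wrong.end()) {
                wrong[key] = 1;         // seeds the entry...
            }
            wrong[key]++;               // ...then increments it again: new keys start at 2
        }
        std::cout << wrong[7] << "\n";  // 3, one too high
        return 0;
    }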
From b275e05e4fbb4627232b163cbdb8bb75b0487f6d Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 16 Nov 2023 21:55:39 +0800
Subject: [PATCH 456/551] Match-id-a41e1218dcb64c4ba760c71c5ca686b25067b3bc

---
 src/core/checkpoint/checkpoint.cpp             | 4 ----
 src/core/checkpoint/checkpoint.h               | 3 ++-
 src/core/ckpt_data_handler/ckpt_data_handler.h | 5 +++--
 src/core/key_process/key_process.cpp           | 6 +++---
 4 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index e9e055f5..a2a659fa 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -107,14 +107,11 @@ void Checkpoint::SaveProcess(CkptData& ckptData)
 {
     for (const auto& dataHandler : dataHandlers) {
         dataHandler->SetProcessData(ckptData);
-
         vector<string> embNames { dataHandler->GetEmbNames() };
         vector<string> dirNames { dataHandler->GetDirNames() };
         vector<CkptDataType> saveDataTypes { dataHandler->GetDataTypes() };
-
         MakeUpperLayerSaveDir(dirNames);
         MakeDataLayerSaveDir(embNames, saveDataTypes, dataHandler);
-
         SaveDataset(embNames, saveDataTypes, dataHandler);
     }
 }
@@ -186,7 +183,6 @@ void Checkpoint::SaveDataset(const vector<string>& embNames,
         if (!CheckEmbNames(embName)) {
             continue;
         }
-
         auto dataDir{innerDirPath + dirSeparator + embName};
         for (const auto& saveDataType: saveDataTypes) {
             auto datasetPath { dataDir + dirSeparator + dataHandler->GetDataDirName(saveDataType) };
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
index cb402182..5ada133f 100644
--- a/src/core/checkpoint/checkpoint.h
+++ b/src/core/checkpoint/checkpoint.h
@@ -54,7 +54,8 @@ namespace MxRec {
             CkptDataType::HIST_REC,
             CkptDataType::NDDR_FEATMAP,
             CkptDataType::DDR_FREQ_MAP,
-            CkptDataType::EXCLUDE_FREQ_MAP
+            CkptDataType::EXCLUDE_FREQ_MAP,
+            CkptDataType::KEY_COUNT_MAP
         };
         const set<CkptDataType> floatTransSet{
             CkptDataType::EMB_DATA
diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h
index 82ee2f3a..aa2314d9 100644
--- a/src/core/ckpt_data_handler/ckpt_data_handler.h
+++ b/src/core/ckpt_data_handler/ckpt_data_handler.h
@@ -52,9 +52,10 @@ namespace MxRec {
             "attribute",
             "ddr_key_freq_map",
             "exclude_ddr_key_freq_map",
-            "evict_pos"
+            "evict_pos",
+            "key_count_map"
         };
-        const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8, 8, 8, 8};
+        const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8, 8, 8, 8, 8};
 
         const uint32_t eightBytes { 8 };
         const uint32_t fourBytes { 4 };
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index c75e2c36..55356e94 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -260,7 +260,6 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();
 
         TimeCost processDataTime = TimeCost();
-        RecordKeyCountMap(batch);
         InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique);
         if (!KeyProcessTaskHelperWithFastUnique(batch, unique, channel, threadId)) {
             break;
@@ -297,8 +296,6 @@ void KeyProcess::KeyProcessTask(int channel, int threadId)
         auto getBatchTime = getBatchDataTC.ElapsedMS();
 
         TimeCost processDataTime = TimeCost();
-        RecordKeyCountMap(batch);
-
         if (!KeyProcessTaskHelper(batch, channel, threadId)) {
             break;
         }
@@ -358,6 +355,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch
     std::lock_guard lock(loadSaveMut[channel][threadId]);
     // without host, just device, all embedding vectors were stored in device
     // map key to offset directly by lookup keyOffsetMap (hashmap)
+    RecordKeyCountMap(batch);
     if (rankInfo.noDDR) {
         TimeCost key2OffsetTC;
         Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel);
@@ -405,6 +403,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel,
         countRecv = GetCountRecv(batch, threadId, keyCount, scAll, ss);
     }
     std::lock_guard lock(loadSaveMut[channel][threadId]);
+    RecordKeyCountMap(batch);
     BuildRestoreVec(batch, ss, restore, static_cast(hotPos.size()));
 
     // Feature admission & eviction
@@ -1422,6 +1422,7 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName)
 
 void KeyProcess::RecordKeyCountMap(const unique_ptr& batch)
 {
+    std::lock_guard lk(mut);
     size_t miniBs = batch->Size();
     auto* batchData = batch->sample.data();
     auto& singleKeyCountMap = keyCountMap[batch->name];
-- 
Gitee
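Patch 456 moves RecordKeyCountMap behind the per-channel load/save lock and adds a process-wide mutex around the counting loop. With several processing threads per channel, one shared mutex on a hot loop can become a contention point; if profiling ever shows that, thread-private shards merged once at checkpoint time are a common alternative. A sketch under that assumption (all names illustrative, not the MxRec code):

    // Sharded counters: each worker counts lock-free into its own map,
    // and shards are merged single-threaded while workers are quiescent
    // (as they would be at save time).
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <thread>
    #include <vector>

    using Counter = std::map<int64_t, int64_t>;

    int main()
    {
        const int shards = 4;
        std::vector<Counter> local(shards);

        std::vector<std::thread> workers;
        for (int t = 0; t < shards; ++t) {
            workers.emplace_back([t, &local] {
                for (int64_t key : {7L, 42L, 7L}) {
                    ++local[t][key];          // no lock: this shard is thread-private
                }
            });
        }
        for (auto& w : workers) { w.join(); }

        Counter merged;
        for (const auto& shard : local) {     // merge once, after join()
            for (const auto& kv : shard) { merged[kv.first] += kv.second; }
        }
        std::cout << merged[7] << "\n";       // 8: two hits per shard, four shards
        return 0;
    }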
From a59fc6d2b09e756c3bb81deda1fc81c389caae62 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 17 Nov 2023 12:01:58 +0800
Subject: [PATCH 457/551] Match-id-7635e7cfb411ae89e798b85dd6c5507ff8f823df

---
 mx_rec/constants/constants.py                  |   2 +
 mx_rec/saver/saver.py                          | 134 +++---
 src/core/checkpoint/checkpoint.cpp             | 408 ++---------------
 src/core/checkpoint/checkpoint.h               |  19 +-
 .../ckpt_data_handler/ckpt_data_handler.h      |   2 +-
 .../feat_admit_n_evict_ckpt.h                  |   2 +-
 .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp  |  25 +-
 .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.h    |   5 +-
 .../buffer_queue.cpp                           |   0
 .../buffer_queue.h                             |   0
 src/core/file_system/file_system.h             |  45 ++
 src/core/file_system/file_system_handler.cpp   |  24 +
 src/core/file_system/file_system_handler.h     |  25 ++
 .../hdfs_file_system/hdfs_file_system.cpp      | 333 ++++++++++++++
 .../hdfs_file_system/hdfs_file_system.h        |  45 ++
 .../hdfs_file_system/hdfs_wrapper.h            | 184 ++++++++
 .../local_file_system/local_file_system.cpp    | 424 ++++++++++++++++++
 .../local_file_system/local_file_system.h      |  49 ++
 src/core/hybrid_mgmt/hybrid_mgmt.cpp           |  36 +-
 src/core/hybrid_mgmt/hybrid_mgmt.h             |   3 +-
 src/core/utils/common.h                        |   6 +-
 21 files changed, 1291 insertions(+), 480 deletions(-)
 rename src/core/{checkpoint => file_system}/buffer_queue.cpp (100%)
 rename src/core/{checkpoint => file_system}/buffer_queue.h (100%)
 create mode 100644 src/core/file_system/file_system.h
 create mode 100644 src/core/file_system/file_system_handler.cpp
 create mode 100644 src/core/file_system/file_system_handler.h
 create mode 100644 src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
 create mode 100644 src/core/file_system/hdfs_file_system/hdfs_file_system.h
 create mode 100644 src/core/file_system/hdfs_file_system/hdfs_wrapper.h
 create mode 100644 src/core/file_system/local_file_system/local_file_system.cpp
 create mode 100644 src/core/file_system/local_file_system/local_file_system.h

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 64a98dc4..59b1ef27 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -69,6 +69,8 @@ MAX_INT32 = np.iinfo(np.int32).max
 DUMP_MIDIFY_GRAPH_FILE_MODE = 0o550
 MAX_DEVICE_ID = 15
 
+# path prefixes that mark a file as belonging to an HDFS file system
+HDFS_FILE_PREFIX = ["viewfs://", "hdfs://"]
 
 class BaseEnum(Enum):
     @classmethod
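HDFS_FILE_PREFIX drives the same scheme test on the Python side that FileSystemHandler::Create performs in C++ later in this patch: match a known URL prefix, otherwise fall back to the local file system. A compact sketch of that dispatch (the two concrete classes are trivial stand-ins, not the MxRec ones):

    // Prefix-based file system dispatch, mirroring FileSystemHandler::Create.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    struct FileSystem {
        virtual ~FileSystem() = default;
        virtual std::string Name() const = 0;
    };
    struct LocalFs : FileSystem { std::string Name() const override { return "local"; } };
    struct HdfsFs  : FileSystem { std::string Name() const override { return "hdfs"; } };

    std::unique_ptr<FileSystem> Create(const std::string& path)
    {
        static const std::vector<std::string> hdfsPrefixes { "hdfs://", "viewfs://" };
        for (const auto& prefix : hdfsPrefixes) {
            // compare() avoids allocating a temporary substring the way substr() would
            if (path.compare(0, prefix.size(), prefix) == 0) {
                return std::make_unique<HdfsFs>();
            }
        }
        return std::make_unique<LocalFs>();   // default: plain local paths
    }

    int main()
    {
        std::cout << Create("hdfs://nn/ckpt")->Name() << "\n";   // hdfs
        std::cout << Create("/tmp/ckpt")->Name() << "\n";        // local
        return 0;
    }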
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index 8004426d..80fb2a52 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -11,10 +11,11 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.util import compat
 
-from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice, MAX_INT32
+from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice, \
+    MAX_INT32, HDFS_FILE_PREFIX
 from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \
     get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \
-    send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size
+    send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size, \
+    get_use_dynamic_expansion
 from mx_rec.util.perf import performance
 from mx_rec.validator.validator import DirectoryValidator, FileValidator, para_checker_decorator, ClassValidator, \
     IntValidator, OptionalStringValidator
@@ -103,7 +104,7 @@ class Saver(object):
         logger.debug("======== Start saving for rank id %s ========", self.rank_id)
         if not check_file_system_is_valid(save_path):
             raise ValueError(f"the path to save sparse embedding table data belongs to an invalid file system, "
-                             f"only the local file system is supported. ")
+                             f"only the local file system and the HDFS file system are supported. ")
 
         save_path = save_path if save_path else self._prefix_name
         directory, base_name = os.path.split(save_path)
@@ -119,10 +120,11 @@
         set_sparse_dir(saving_path)
 
         try:
-            directory_validator = DirectoryValidator("saving_path", saving_path)
-            directory_validator.check_not_soft_link()
-            directory_validator.with_blacklist(exact_compare=False)
-            directory_validator.check()
+            if not check_file_system_is_hdfs(saving_path):
+                directory_validator = DirectoryValidator("saving_path", saving_path)
+                directory_validator.check_not_soft_link()
+                directory_validator.with_blacklist(exact_compare=False)
+                directory_validator.check()
         except ValueError as err:
             raise ValueError(f"The saving path {saving_path} cannot be a system directory "
                              f"and cannot be a soft link.") from err
@@ -149,7 +151,7 @@
         logger.debug("======== Start restoring ========")
         if not check_file_system_is_valid(reading_path):
             raise ValueError(f"the path to save sparse embedding table data belongs to an invalid file system, "
-                             f"only the local file system is supported. ")
+                             f"only the local file system and the HDFS file system are supported. ")
 
         directory, base_name = os.path.split(reading_path)
         ckpt_name = f"sparse-{base_name}"
@@ -165,26 +167,33 @@
 
     @performance("save_table_name_data")
     def save_table_name_data(self, sess, result, root_dir, table_name):
-        dump_data_dict = sess.run(result.get(table_name))
-
         table_instance = get_table_instance_by_name(table_name)
         self._make_table_name_dir(root_dir, table_instance, table_name)
-        # save key
-        if is_asc_manager_initialized() and self.save_easy_mode:
-            self._save_easy_mode_save_key_data(dump_data_dict, root_dir, table_name)
+
+        dump_data_dict = sess.run(result.get(table_name))
+        # when HBM mode is on, host offset data is needed to trim the dump down to the valid embeddings
+        if is_asc_manager_initialized() and table_instance.host_vocabulary_size == 0:
+            self._get_valid_dict_data(dump_data_dict, table_name)
+
+        # save embedding
         save_embedding_data(root_dir, table_name, dump_data_dict, self.rank_id)
+
         if table_instance.use_feature_mapping:
             save_feature_mapping_data(root_dir, table_name, dump_data_dict, self.rank_id)
-        save_offset_data(root_dir, table_name, dump_data_dict, self.rank_id)
+
+        # save optimizer data
         if "optimizer" in dump_data_dict:
             dump_optimizer_data_dict = dump_data_dict.get("optimizer")
             for optimizer_name, dump_optimizer_data in dump_optimizer_data_dict.items():
-                save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data,
-                                          self.rank_id)
+                save_optimizer_state_data(root_dir, table_name, optimizer_name, dump_optimizer_data, self.rank_id)
 
     @performance("_save")
     def _save(self, sess, root_dir):
+        if is_asc_manager_initialized():
+            save_host_data(root_dir)
+            logger.debug(f"host data was saved.")
+
+        if get_use_dynamic_expansion():
+            logger.info("use dynamic expansion, skip saving device-side table data.")
+            return
+
         result = self.save_op_dict
         threads = []
         for table_name in result.keys():
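The new _get_valid_dict_data path trims the HBM dump using the offsets held host-side, keeping only rows that are actually live. The gather itself is simple; a self-contained sketch of the idea (the flat row-major table and the names are illustrative assumptions, not the mxRec layout):

    // Keep only the embedding rows the host offset list marks as live.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<float> GatherRows(const std::vector<float>& table, size_t dim,
                                  const std::vector<int64_t>& offsets)
    {
        std::vector<float> out;
        out.reserve(offsets.size() * dim);
        for (int64_t row : offsets) {
            const float* src = table.data() + row * dim;  // start of one live row
            out.insert(out.end(), src, src + dim);
        }
        return out;
    }

    int main()
    {
        const size_t dim = 2;
        std::vector<float> table { 0, 0, 1, 1, 2, 2, 3, 3 };  // 4 rows of width 2
        std::vector<int64_t> offsets { 1, 3 };                // rows in use

        auto valid = GatherRows(table, dim, offsets);
        std::cout << valid.size() << " floats kept\n";        // 4
        return 0;
    }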
@@ -197,9 +206,11 @@
         for thread in threads:
             thread.join()
 
-        if is_asc_manager_initialized() and not self.save_easy_mode:
-            save_host_data(root_dir)
-            logger.debug(f"host data was saved.")
+    def _get_valid_dict_data(self, dump_data_dict, table_name):
+        host_data = get_host_data(table_name)
+        offset = list(host_data)
+
+        get_valid_dict_data(dump_data_dict, offset)
 
     def _build_save(self):
         for var in self.var_list:
@@ -245,21 +256,21 @@
                     assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state))
                     self.restore_fetch_list.append(assign_op)
 
-    def _save_easy_mode_save_key_data(self, dump_data_dict, root_dir, table_name):
-        host_data = get_host_data(table_name)
-        key = np.array(list(host_data.keys()))
-        offset = list(host_data.values())
-        get_valid_dict_data(dump_data_dict, offset)
-        save_key_data(root_dir, table_name, key, self.rank_id)
-
     def _restore(self, sess, reading_path):
+        if is_asc_manager_initialized():
+            restore_host_data(reading_path)
+            logger.info("host data was restored.")
+
+        if get_use_dynamic_expansion():
+            return
+
         restore_feed_dict = defaultdict(dict)
-        key_offset_dict = defaultdict(dict)
+
         for table_name, sub_placeholder_dict in self.placeholder_dict.items():
             fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id,
                              NameDescriptor(table_name, DataName.EMBEDDING.value))
-            if self.save_easy_mode:
-                fill_key_offset_dict(reading_path, self.rank_id, table_name, key_offset_dict)
+
             table_instance = get_table_instance_by_name(table_name)
 
             if table_instance.use_feature_mapping:
@@ -279,15 +290,10 @@
                                      name_descriptor=NameDescriptor(table_name, state_key,
                                                                     optimizer_name=optimizer_name))
 
-        if is_asc_manager_initialized() and self.save_easy_mode:
-            send_host_data(key_offset_dict)
-            logger.info("host data was sent to the host pipeline.")
-        if is_asc_manager_initialized() and not self.save_easy_mode:
-            restore_host_data(reading_path)
-            logger.info("host data was restored.")
 
         sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict)
 
+
 class NameDescriptor:
     def __init__(self, table_name, data_name, optimizer_name=None):
         self.table_name = table_name
@@ -313,23 +319,6 @@ def get_valid_dict_data(dump_data_dict: dict, offset: list):
 
     dump_data_dict["optimizer"] = dump_optimizer_data_dict
 
-
-def fill_key_offset_dict(reading_path: str, rank_id: int, table_name: str, key_offset_dict: dict):
""" - Filling data in the key-offset dictionary , which is sent to the host pipeline. - :param reading_path: the path restoring the model - :param rank_id: rank id - :param table_name: the sparse table name - :param key_offset_dict: key-offset dictionary saving mapping relationship - """ - target_path = generate_path(reading_path, "HashTable", "HBM", table_name, - DataName.KEY.value) - key = read_binary_data(target_path, rank_id, DataName.KEY.value, table_name) - key = key.get(DataName.KEY.value) - offsets = list(range(key.shape[0])) - key_offset_map = dict(zip(key, offsets)) - key_offset_dict[table_name] = key_offset_map - - def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_descriptor): if name_descriptor.optimizer_name: target_path = generate_path(reading_path, "Optimizer", name_descriptor.optimizer_name, "HBM", @@ -355,21 +344,6 @@ def save_embedding_data(root_dir, table_name, dump_data_dict, suffix): write_binary_data(target_path, suffix, data_to_write, attributes=attribute) -def save_key_data(root_dir: str, table_name: str, data_to_write: np.ndarray, suffix: int): - """ - Save the keys of the sparse table - :param root_dir: the root path saving the model - :param table_name: the sparse table name - :param data_to_write: the key array to be written - :param suffix: suffix of sparse data - """ - target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.KEY.value) - attribute = dict() - attribute[DataAttr.DATATYPE.value] = data_to_write.dtype.name - attribute[DataAttr.SHAPE.value] = data_to_write.shape - write_binary_data(target_path, suffix, data_to_write, attributes=attribute) - - def save_feature_mapping_data(root_dir, table_name, dump_data_dict, suffix): target_path = generate_path(root_dir, "HashTable", "HBM", table_name, DataName.FEATURE_MAPPING.value) data_to_write = dump_data_dict.get(DataName.FEATURE_MAPPING.value) @@ -422,7 +396,12 @@ def write_binary_data(writing_path, suffix, data, attributes=None): if tf.io.gfile.exists(target_attribute_dir): raise FileExistsError(f"Target_attribute_dir {target_attribute_dir} exists before writing.") - data.tofile(target_data_dir) + if check_file_system_is_hdfs(target_data_dir): + with tf.io.gfile.GFile(target_data_dir, "wb") as file: + data = data.tostring() + file.write(data) + else: + data.tofile(target_data_dir) if attributes is not None: if not isinstance(attributes, dict): @@ -458,7 +437,11 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name: with tf.io.gfile.GFile(target_data_dir, "rb") as file: validate_read_file(target_data_dir) - data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) + if check_file_system_is_hdfs(target_data_dir): + data_to_restore = file.read() + data_to_restore = np.fromstring(data_to_restore, dtype=attributes.pop(DataAttr.DATATYPE.value)) + else: + data_to_restore = np.fromfile(target_data_dir, dtype=attributes.pop(DataAttr.DATATYPE.value)) if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value: data_shape = attributes.pop(DataAttr.SHAPE.value) @@ -483,7 +466,8 @@ def validate_read_file(read_file_path): file_validator = FileValidator("read_file_path", read_file_path) file_validator.check_file_size(MAX_FILE_SIZE, MIN_SIZE) file_validator.check_user_group() - file_validator.check_not_soft_link() + if not check_file_system_is_hdfs(read_file_path): + file_validator.check_not_soft_link() file_validator.check() @@ -514,6 +498,12 @@ def process_embedding_data(data_to_restore: 
np.ndarray, current_data_shape: list
 
 
 def check_file_system_is_valid(file_path):
-    if file_path.find("://") == -1:
+    if file_path.find("://") == -1 or check_file_system_is_hdfs(file_path):
         return True
     return False
+
+
+def check_file_system_is_hdfs(file_path):
+    for prefix in HDFS_FILE_PREFIX:
+        if file_path.startswith(prefix):
+            return True
+    return False
diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index a2a659fa..57d17f24 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -21,6 +21,8 @@
 #include "ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h"
 #include "utils/time_cost.h"
 #include "utils/common.h"
+#include "file_system/file_system_handler.h"
+
 #include "checkpoint.h"
 
 using namespace std;
@@ -34,6 +36,9 @@ void Checkpoint::SaveModel(string savePath, CkptData& ckptData, RankInfo& mgmtRa
     useDynamicExpansion = mgmtRankInfo.useDynamicExpansion;
     mgmtEmbInfo = embInfo;
 
+    auto fileSystemHandler = make_unique<FileSystemHandler>();
+    fileSystemPtr = fileSystemHandler->Create(savePath);
+
     LOG_INFO("Start host side saving data.");
     LOG_DEBUG("==Start to create save data handler.");
     SetDataHandler(ckptData);
@@ -51,6 +56,9 @@ void Checkpoint::LoadModel(string loadPath, CkptData& ckptData, RankInfo& mgmtRa
     useDynamicExpansion = mgmtRankInfo.useDynamicExpansion;
     mgmtEmbInfo = embInfo;
 
+    auto fileSystemHandler = make_unique<FileSystemHandler>();
+    fileSystemPtr = fileSystemHandler->Create(loadPath);
+
     LOG_INFO("Start host side loading data.");
     LOG_DEBUG("==Start to create load data handler.");
     SetDataHandler(featureTypes);
@@ -71,9 +79,6 @@ void Checkpoint::SetDataHandler(CkptData& ckptData)
     if (!ckptData.embHashMaps.empty()) {
         dataHandlers.push_back(make_unique<EmbHashCkpt>());
     }
-    if (!ckptData.maxOffset.empty()) {
-        dataHandlers.push_back(make_unique<NddrOffsetCkpt>());
-    }
     if (!ckptData.keyOffsetMap.empty()) {
         dataHandlers.push_back(make_unique<NddrFeatMapCkpt>());
     }
@@ -145,11 +150,7 @@ void Checkpoint::MakeDataLayerSaveDir(const vector<string>& embNames,
 
 void Checkpoint::MakeSaveDir(const string& dirName) const
 {
-    if (access(dirName.c_str(), F_OK) == -1) {
-        if (mkdir(dirName.c_str(), dirMode) == -1) {
-            LOG_DEBUG("Unable to create directory: {}", dirName);
-        }
-    }
+    fileSystemPtr->CreateDir(dirName);
 }
 
 Checkpoint::EmbSizeInfo Checkpoint::GetEmbeddingSize(const string& embName)
@@ -194,7 +195,7 @@ void Checkpoint::SaveDataset(const vector<string>& embNames,
 
         // save embedding when dynamic expansion is open
         if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) {
-            auto embedPath { dataDir + dirSeparator + "key_embedding" };
+            auto embedPath { dataDir + dirSeparator + "embedding" };
             auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType };
             auto embeddingSizeInfo = GetEmbeddingSize(embName);
             MakeSaveDir(embedPath);
@@ -212,232 +213,48 @@ void Checkpoint::SaveDataset(const vector<string>& embNames,
 
 void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize)
 {
-    ofstream writeFile;
-    writeFile.open(dataDir.c_str(), std::ios::out | std::ios::trunc | std::ios::binary);
-    fs::permissions(dataDir.c_str(), fs::perms::owner_read | fs::perms::owner_write | fs::perms::group_read);
-
-#ifndef GTEST
-    auto res = aclrtSetDevice(static_cast(deviceId));
-    if (res != ACL_ERROR_NONE) {
-        LOG_ERROR("Set device failed, device_id:{}", deviceId);
-        writeFile.close();
-        throw runtime_error(Logger::Format("Set device failed, device_id:{}", deviceId).c_str());
-    }
-
-    auto &transArr = transData.int64Arr;
-    for (size_t i{0}; i <
transArr.size(); i += embHashNum) { - vector row(embeddingSize); - int64_t address = transArr.at(i + 1); - float *floatPtr = reinterpret_cast(address); - - aclError ret; - try { - ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float), - floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST); - } catch (std::exception& e) { - writeFile.close(); - throw runtime_error(StringFormat("error happen when acl memory copy from device to host: %s", e.what())); - } - - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMemcpy failed, ret={}", ret); - writeFile.close(); - throw runtime_error(Logger::Format("aclrtMemcpy failed, ret={}", ret).c_str()); - } - - try { - writeFile.write(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); - } catch (std::exception& e) { - writeFile.close(); - throw runtime_error(StringFormat("error happen when write embedding to file: %s", e.what())); - } - } -#endif - writeFile.close(); + auto &transArr = transData.addressArr; + fileSystemPtr->WriteEmbedding(dataDir, embeddingSize, transArr, deviceId); } void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName) { - std::ifstream readFile; - readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - size_t datasetSize = static_cast(readFile.tellg()); - readFile.seekg(0, std::ios::beg); - try { - ValidateReadFile(dataDir, datasetSize); - } catch (const std::invalid_argument& e) { - readFile.close(); - throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); - } - -#ifndef GTEST - auto res = aclrtSetDevice(static_cast(deviceId)); - if (res != ACL_ERROR_NONE) { - LOG_ERROR("Set device failed, device_id:{}", deviceId); - readFile.close(); - throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); - } - + auto datasetSize = fileSystemPtr->GetFileSize(dataDir); auto &attributeArr = transData.attribute; auto embHashMapSize = attributeArr.at(0); if (embHashMapSize <= 0) { - readFile.close(); + throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str()); } + auto embeddingSize = static_cast(datasetSize / sizeof(float) / embHashMapSize); + auto &transArr = transData.addressArr; + EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); if (embeddingSize != embSizeInfo.extEmbSize) { - readFile.close(); - throw runtime_error(StringFormat("Invalid embedding size to be read, may read file has been changed").c_str()); - } + throw runtime_error(StringFormat("Invalid embedding size to be read, may read file has been changed").c_str()); - aclError ret; - void *newBlock = nullptr; - ret = aclrtMalloc(&newBlock, static_cast(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); - if (ret != ACL_SUCCESS) { - LOG_ERROR("aclrtMalloc failed, ret={}", ret); - readFile.close(); - throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); } - float *floatPtr = static_cast(newBlock); - auto &transArr = transData.int64Arr; - for (size_t i = 0, j = 0; i < transArr.size(); i += keyAddrElem, ++j) { - vector row(embeddingSize); - try { - readFile.read(reinterpret_cast(row.data()), embeddingSize * sizeof(float)); - } catch (std::exception& e) { - readFile.close(); - throw runtime_error(StringFormat("error happen when reading embedding from file: %s", e.what())); - } - aclError ec; - try { - ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), - row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); - } catch (std::exception& e) 
{ - readFile.close(); - throw runtime_error(StringFormat("error happen when acl memory copy from host to device: %s", e.what())); - } - if (ec != ACL_SUCCESS) { - LOG_ERROR("aclrtMemcpy failed, ret={}", ec); - readFile.close(); - throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str()); - } - int64_t address = reinterpret_cast(floatPtr + j * embeddingSize); - transArr.at(i + 1) = address; - } -#endif - readFile.close(); -} + fileSystemPtr->ReadEmbedding(dataDir, embeddingSize, transArr, deviceId); -void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) -{ - int fd = open(dataDir.c_str(), O_RDWR | O_CREAT | O_TRUNC, static_cast(0640)); - if (fd == -1) { - LOG_ERROR("Error opening file for writing"); - return; - } - - buffer.reserve(BUFFER_SIZE); - - BufferQueue queue; - - std::thread writer(&Checkpoint::WriterFn, this, std::ref(queue), fd); - - int loops = 1; - if (dataType == CkptDataType::EMB_DATA) { - loops = static_cast(transData.floatArr.size()); - } - for (int i = 0; i < loops; i++) { - size_t idx = 0; - size_t writeSize = 0; - size_t dataCol = dataSize; - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - writeSize = oneTimeReadWriteLen; - } else { - writeSize = dataCol; - } - if (floatTransSet.find(dataType) != floatTransSet.end()) { - FillToBuffer(queue, reinterpret_cast(transData.floatArr[i]) + idx, writeSize); - } else { - WriteDataset(transData, fd, writeSize, dataType, idx); - } - - dataCol -= writeSize; - idx += writeSize; - } - } - - // After all data has been processed, check if there is any data left in the buffer - if (!buffer.empty()) { - queue.Push(std::move(buffer)); - buffer.clear(); - } - - queue.Push(std::vector()); - - writer.join(); - - close(fd); } -void Checkpoint::WriterFn(BufferQueue& queue, int fd) -{ - while (true) { - queue.Pop(writeBuffer); - if (writeBuffer.size() == 0) { - break; - } - ssize_t result = write(fd, writeBuffer.data(), writeBuffer.size()); - if (result != writeBuffer.size()) { - LOG_ERROR("Error writing to file"); - close(fd); - throw runtime_error(StringFormat("error happen when writing file. 
")); - } - writeBuffer.clear(); - } -} - -void Checkpoint::WriteDataset(CkptTransData& transData, - int fd, - size_t writeSize, - CkptDataType dataType, - size_t idx) +void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) { - ssize_t result = 0; - if (int32TransSet.find(dataType) != int32TransSet.end()) { - result = write(fd, reinterpret_cast(transData.int32Arr.data()) + idx, writeSize); + LOG_ERROR("lff debug write stream {} dataset_size {}", dataType, dataSize); + if (floatTransSet.find(dataType) != floatTransSet.end()) { + fileSystemPtr->Write(dataDir, transData.floatArr, dataSize); + } else if (int32TransSet.find(dataType) != int32TransSet.end()) { + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int32Arr.data()), dataSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - result = write(fd, reinterpret_cast(transData.int64Arr.data()) + idx, writeSize); + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int64Arr.data()), dataSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - result = write(fd, reinterpret_cast(transData.attribute.data()) + idx, writeSize); - } - - if (result != writeSize) { - close(fd); - LOG_ERROR("Error writing to file, please check the disk buffer or temporary folder space or file permissions!"); - throw runtime_error(StringFormat("error happen when write file. ")); + fileSystemPtr->Write(dataDir, reinterpret_cast(transData.attribute.data()), dataSize); } + LOG_ERROR("lff debug write stream {} dataset_dir {} over!", dataType, dataDir); } -void Checkpoint::FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize) -{ - size_t dataIdx = 0; - while (dataIdx < dataSize) { - size_t remainingSpace = BUFFER_SIZE - buffer.size(); - if (dataSize - dataIdx <= remainingSpace) { - buffer.insert(buffer.cend(), data + dataIdx, data + dataSize); - return; - } else { - buffer.insert(buffer.cend(), data + dataIdx, data + dataIdx + remainingSpace); - queue.Push(std::move(buffer)); - if (buffer.capacity() < BUFFER_SIZE) { - buffer.reserve(BUFFER_SIZE); - } - dataIdx += remainingSpace; - } - } -} void Checkpoint::LoadProcess(CkptData& ckptData) { @@ -480,25 +297,7 @@ vector Checkpoint::GetEmbedTableNames() vector Checkpoint::GetTableLayerLoadDir() { vector loadTableDir; - auto dir { opendir(innerDirPath.c_str()) }; - struct dirent* en; - if (dir != nullptr) { - int fileNum = 0; - while ((en = readdir(dir)) != nullptr) { - if (fileNum > MAX_FILE_NUM) { - closedir(dir); - throw std::runtime_error("The number of files has exceeded the limit " + std::to_string(MAX_FILE_NUM)); - } - if (strncmp(en->d_name, currDir.c_str(), strlen(currDir.c_str())) != 0 && - strncmp(en->d_name, prevDir.c_str(), strlen(prevDir.c_str())) != 0) { - loadTableDir.emplace_back(en->d_name); - } - fileNum++; - } - closedir(dir); - } else { - LOG_WARN("when loading data in ssd, there are no table files."); - } + loadTableDir = fileSystemPtr->ListDir(innerDirPath); return loadTableDir; } @@ -532,7 +331,7 @@ void Checkpoint::LoadDataset(const vector& embNames, // load embedding when use dynamic expansion is open if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { - auto embedPath { dataDir + dirSeparator + "key_embedding" }; + auto embedPath { dataDir + dirSeparator + "embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; LOG_DEBUG("====Start loading embedding data from: {}", datasetDir); ReadEmbedding(transData, 
embedDatasetDir, embName); @@ -557,47 +356,26 @@ void Checkpoint::ReadStream(CkptTransData& transData, LOG_WARN("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } - - std::ifstream readFile; - readFile.open(dataDir.c_str(), std::ios::in | std::ios::binary | std::ios::ate); - size_t datasetSize = static_cast(readFile.tellg()); - readFile.seekg(0, std::ios::beg); - try { - ValidateReadFile(dataDir, datasetSize); - } catch (const std::invalid_argument& e) { - readFile.close(); - throw runtime_error(Logger::Format("Invalid read file path: {}", e.what())); - } + size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); + auto resizeSize { datasetSize / dataElmtBytes }; + SetTransDataSize(transData, resizeSize, dataType); if (datasetSize % dataElmtBytes > 0) { LOG_DEBUG("data is missing or incomplete in load file: {}", dataDir); } + LOG_ERROR("lff debug start to enter read type:{}, size:{}", dataType, datasetSize); - auto resizeSize { datasetSize / dataElmtBytes }; - SetTransDataSize(transData, resizeSize, dataType); - if (readFile.is_open()) { - size_t idx = 0; - size_t readSize = 0; - while (datasetSize != 0) { - if (datasetSize > oneTimeReadWriteLen) { - readSize = oneTimeReadWriteLen; - } else { - readSize = datasetSize; - } - try { - ReadDataset(transData, readFile, readSize, dataType, idx); - } catch (std::exception& e) { - readFile.close(); - throw runtime_error(StringFormat("error happen when reading data from file: %s", e.what())); - } - datasetSize -= readSize; - idx += readSize; - } - } else { - LOG_DEBUG("unable to open load file: {}", dataDir); + + if (int32TransSet.find(dataType) != int32TransSet.end()) { + fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int32Arr.data()), datasetSize); + } else if (int64TransSet.find(dataType) != int64TransSet.end()) { + fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int64Arr.data()), datasetSize); + } else if (floatTransSet.find(dataType) != floatTransSet.end()) { + fileSystemPtr->Read(dataDir, reinterpret_cast(transData.floatArr.data()), datasetSize); + } else if (dataType == CkptDataType::ATTRIBUTE) { + fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); } - readFile.close(); } void Checkpoint::ValidateFile(int fd, const string& dataDir, size_t datasetSize) const @@ -610,51 +388,6 @@ void Checkpoint::ValidateFile(int fd, const string& dataDir, size_t datasetSize) } } -void Checkpoint::HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize, - vector>& dst, size_t cnt) const -{ -#pragma omp parallel for - for (size_t j = 0; j < mapRowNum; ++j) { - size_t idx = 0; - size_t readSize = 0; - size_t dataCol = onceReadByteSize; - while (dataCol != 0) { - if (dataCol > oneTimeReadWriteLen) { - readSize = oneTimeReadWriteLen; - } else { - readSize = dataCol; - } - - errno_t err = memcpy_s(dst[cnt + j].data() + idx, readSize, - mappedData + j * onceReadByteSize + idx, readSize); - if (err != 0) { - throw std::runtime_error("Error execution memcpy_s: " + std::to_string(err)); - } - dataCol -= readSize; - idx += readSize; - } - } -} - -void Checkpoint::CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const -{ - // 每次映射的字节数 - mapByteSize = MAP_BYTE_SIZE; - // 确保mapByteSize是onceReadByteSize和pageSize的整数倍,确保每次映射的offset是页大小的整数倍 - size_t pageSize = sysconf(_SC_PAGESIZE); - if (pageSize == -1) { - throw std::runtime_error("Failed to get page size: " + std::string(strerror(errno))); - } - size_t lcmVal = std::lcm(onceReadByteSize, 
pageSize); - mapByteSize = (mapByteSize / lcmVal) * lcmVal; - - // 如果文件大小小于每次映射的字节数,则一次性映射,映射大小不是页大小整数倍的时候,mmap会自动向上取整,额外的字节会初始化成零 - if (fileSize <= mapByteSize) { - mapByteSize = fileSize; - } - - mapRowNum = mapByteSize / onceReadByteSize; -} void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, @@ -667,27 +400,16 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, return ; } auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); + LOG_ERROR("1102debug read emb data :{}", embDataOuterSize); if (embDataOuterSize <= 0 || embDataOuterSize > MAX_VOCABULARY_SIZE) { throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str()); } - int fd = open(dataDir.c_str(), O_RDONLY); - if (fd == -1) { - throw runtime_error(StringFormat("Failed to open file: %s", dataDir).c_str()); - } - - off_t fileSize = lseek(fd, 0, SEEK_END); - if (fileSize == 0) { - close(fd); - throw runtime_error(StringFormat("emb data file's size is 0").c_str()); - } - - size_t datasetSize = fileSize; - ValidateFile(fd, dataDir, datasetSize); + size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { LOG_ERROR("data is missing or incomplete in load file: {}", dataDir); - close(fd); + throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } @@ -695,44 +417,8 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, auto& dst = (*loadHostEmbs)[embName].embData; dst.reserve(embDataOuterSize); - auto onceReadByteSize { datasetSize / embDataOuterSize }; - - size_t mapByteSize; - size_t mapRowNum; - CalculateMapSize(fileSize, mapByteSize, mapRowNum, onceReadByteSize); - - off_t offset = 0; - size_t remainBytes = fileSize; - - for (size_t i = 0; i < embDataOuterSize; i += mapRowNum) { - // 如果剩余字节数小于每次映射的字节数,则更新每次映射的字节数和行数 - if (remainBytes < mapByteSize) { - mapByteSize = remainBytes; - mapRowNum = mapByteSize / onceReadByteSize; - } - - void* tempMappedData = mmap(NULL, mapByteSize, PROT_READ, MAP_PRIVATE, fd, offset); - if (tempMappedData == MAP_FAILED) { - close(fd); - throw std::runtime_error("Failed to map file: " + dataDir + ", errno: " + std::to_string(errno)); - } - char* mappedData = static_cast(tempMappedData); - - // 处理映射的数据 - try { - HandleMappedData(mappedData, mapRowNum, onceReadByteSize, dst, i); - } catch (const std::runtime_error& e) { - close(fd); - munmap(mappedData, mapByteSize); - throw runtime_error(StringFormat("handle mapped data error: %s", e.what())); - } - munmap(mappedData, mapByteSize); - - offset += mapByteSize; - remainBytes -= mapByteSize; - } + fileSystemPtr->Read(dataDir, dst, datasetSize); - close(fd); } void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType) diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 5ada133f..1be72faa 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -13,7 +13,8 @@ #include #include "utils/common.h" #include "ckpt_data_handler/ckpt_data_handler.h" -#include "buffer_queue.h" +#include "file_system/buffer_queue.h" +#include "file_system/file_system_handler.h" namespace MxRec { using namespace std; @@ -37,9 +38,6 @@ namespace MxRec { const string ssdSymbol {"SSD"}; const mode_t dirMode { 0750 }; - const string currDir { "." }; - const string prevDir { ".." 
}; - const size_t oneTimeReadWriteLen { 32768 }; // 4096 * 8 const set int32TransSet { @@ -55,7 +53,8 @@ namespace MxRec { CkptDataType::NDDR_FEATMAP, CkptDataType::DDR_FREQ_MAP, CkptDataType::EXCLUDE_FREQ_MAP, - CkptDataType::KEY_COUNT_MAP + CkptDataType::KEY_COUNT_MAP, + CkptDataType::EVICT_POS }; const set floatTransSet{ CkptDataType::EMB_DATA @@ -70,6 +69,8 @@ namespace MxRec { bool useDynamicExpansion {false}; vector mgmtEmbInfo; + unique_ptr fileSystemPtr; + const int embHashNum { 2 }; const int attribEmbDataOuterIdx { 0 }; const int attribEmbDataInnerIdx { 1 }; @@ -86,14 +87,6 @@ namespace MxRec { void SaveDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler); void WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType); - void FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize); - void WriteDataset(CkptTransData& transData, - int fd, - size_t writeSize, - CkptDataType dataType, - size_t idx); - - void WriterFn(BufferQueue& queue, int fd); void WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize); void ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName); diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index aa2314d9..f65dae5e 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -46,7 +46,7 @@ namespace MxRec { "dev_offset_2_Batch_n_Key", "embedding_current_status", "max_offset", - "key_offset_map", + "key", "table_2_threshold", "history_record", "attribute", diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h index ee716abf..37d8623e 100644 --- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h +++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h @@ -31,7 +31,7 @@ namespace MxRec { void SetDataset(CkptDataType dataType, string embName, CkptTransData& loadedData) override; private: - const vector fileDirNames { "HashTable", "DDR" }; + const vector fileDirNames { "HashTable", "FEAT_INFO" }; const vector saveDataTypes { CkptDataType::TABLE_2_THRESH, CkptDataType::HIST_REC }; const int featItemInfoSaveNum { 3 }; diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index 46b8e8b2..e6698746 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -15,13 +15,17 @@ void NddrFeatMapCkpt::SetProcessData(CkptData& processData) saveKeyOffsetMap.clear(); loadKeyOffsetMap.clear(); saveKeyOffsetMap = std::move(processData.keyOffsetMap); + // 传递 offsetMap的引用 + offsetMapPtr = processData.offsetMapPtr; } void NddrFeatMapCkpt::GetProcessData(CkptData& processData) { processData.keyOffsetMap = std::move(loadKeyOffsetMap); + processData.maxOffset = std::move(loadMaxOffset); saveKeyOffsetMap.clear(); loadKeyOffsetMap.clear(); + loadMaxOffset.clear(); } vector NddrFeatMapCkpt::GetDataTypes() @@ -48,6 +52,7 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) CleanTransfer(); auto& transArr = transferData.int64Arr; + auto& addressArr = transferData.addressArr; auto& attribute = transferData.attribute; auto 
embHashMapSize = saveKeyOffsetMap.at(embName).size(); @@ -60,11 +65,14 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) transferData.attributeSize = attribute.size() * eightBytes; transArr.reserve(embHashMapSize); + (*offsetMapPtr)[embName].clear(); + LOG_ERROR("build offset map : first key offset {}", saveKeyOffsetMap[embName][0]); for (const auto& it : saveKeyOffsetMap.at(embName)) { transArr.push_back(it.first); transArr.push_back(it.second); + (*offsetMapPtr)[embName].push_back(it.second); } - LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::EMB_INFO, dataType); + LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::NDDR_FEATMAP, dataType); return move(transferData); } @@ -72,16 +80,25 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran { CleanTransfer(); transferData = move(loadedData); - + auto& maxOffset = loadMaxOffset[embName]; auto& hostHashMap = loadKeyOffsetMap[embName]; const auto& transArr = transferData.int64Arr; - + const auto& addressArr = transferData.addressArr; + int64_t offset { 0 }; for (size_t i { 0 }; i < transArr.size(); i += embHashElmtNum) { if (i + embHashElmtNum > transArr.size()) { // this is an error, need to log this } int64_t key { transArr.at(i) }; - hostHashMap[key] = transArr.at(i + 1); + if (addressArr.size() == 0) { + // no dynamic expansion + hostHashMap[key] = offset; + } else{ + // dynamic expansion + hostHashMap[key] = addressArr.at(i); + } + offset++; } + maxOffset = offset; LOG_INFO("dataType:{}", dataType); } diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h index dd7f4e16..57fb7d21 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h @@ -33,10 +33,13 @@ namespace MxRec { const vector fileDirNames { "HashTable", "HBM" }; const vector saveDataTypes { CkptDataType::NDDR_FEATMAP }; - const int embHashElmtNum { 2 }; + const int embHashElmtNum { 1 }; KeyOffsetMemT saveKeyOffsetMap; KeyOffsetMemT loadKeyOffsetMap; + + OffsetMemT loadMaxOffset; + OffsetMapT* offsetMapPtr; }; } diff --git a/src/core/checkpoint/buffer_queue.cpp b/src/core/file_system/buffer_queue.cpp similarity index 100% rename from src/core/checkpoint/buffer_queue.cpp rename to src/core/file_system/buffer_queue.cpp diff --git a/src/core/checkpoint/buffer_queue.h b/src/core/file_system/buffer_queue.h similarity index 100% rename from src/core/checkpoint/buffer_queue.h rename to src/core/file_system/buffer_queue.h diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h new file mode 100644 index 00000000..372f85a9 --- /dev/null +++ b/src/core/file_system/file_system.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_FILE_SYSTEM_H +#define MX_REC_FILE_SYSTEM_H + +#include "buffer_queue.h" + +namespace MxRec { + using namespace std; + + class FileSystem { + public: + FileSystem() = default; + virtual ~FileSystem() = default; + + virtual void CreateDir(const string& dirName) = 0; + virtual vector ListDir(const string& dirName) = 0; + virtual size_t GetFileSize(const string& filePath) = 0; + + virtual ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) = 0; + virtual ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) = 0; + virtual void WriteEmbedding(const string& filePath, const int& embeddingSize, + const vector& addressArr, int deviceId) = 0; + + virtual ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) = 0; + virtual ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) = 0; + virtual void ReadEmbedding(const string& filePath, const int& embeddingSize, + vector& addressArr, int deviceId) = 0; + + // The parameter oneTimeReadWriteLen specifies the maximum length of a file read or write at a time. + // The parameter can be adjusted based on the service requirements. + const size_t oneTimeReadWriteLen = 32768; + const int embHashNum = 1; + const int keyAddrElem = 1; + std::vector buffer; + std::vector writeBuffer; + }; +} + +#endif //MX_REC_FILE_SYSTEM_H diff --git a/src/core/file_system/file_system_handler.cpp b/src/core/file_system/file_system_handler.cpp new file mode 100644 index 00000000..faa8147f --- /dev/null +++ b/src/core/file_system/file_system_handler.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-11-16 + */ + +#include "file_system_handler.h" + +using namespace std; +using namespace MxRec; + +unique_ptr FileSystemHandler::Create(const string& filePath) +{ + if (filePath.empty()) { + throw runtime_error("dataDir is Null. The pointer of the file system cannot be created."); + } + for (const auto &prefix: hdfsPrefixes) { + if (filePath.substr(0, prefix.length()) == prefix) { + return make_unique(); + } + } + return make_unique(); +} diff --git a/src/core/file_system/file_system_handler.h b/src/core/file_system/file_system_handler.h new file mode 100644 index 00000000..b2a92999 --- /dev/null +++ b/src/core/file_system/file_system_handler.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#ifndef MX_REC_FILE_SYSTEM_HANDLER_H +#define MX_REC_FILE_SYSTEM_HANDLER_H + +#include "hdfs_file_system/hdfs_file_system.h" +#include "local_file_system/local_file_system.h" + +namespace MxRec { + using namespace std; + + class FileSystemHandler { + public: + unique_ptr Create(const string& filePath); + private: + const vector hdfsPrefixes = {"hdfs://", "viewfs://"}; + }; +} + +#endif // MX_REC_FILE_SYSTEM_HANDLER_H \ No newline at end of file diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp new file mode 100644 index 00000000..e0fb9827 --- /dev/null +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp @@ -0,0 +1,333 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: + * Author: MindX SDK + * Create: 2023-10-19 + */ + +#include "hdfs_file_system.h" + +#include +#include +#include +#include + +#include "hdfs_wrapper.h" +#include "utils/logger.h" + +using namespace std; +using namespace MxRec; + +void HdfsFileSystem::CreateDir(const string& dirName) +{ + hdfsFS fs = ConnectHdfs(); + int ret = hdfs->CreateDirectory(fs, dirName.c_str()); + if (ret == -1) { + LOG_DEBUG("Unable to create hdfs directory: {}", dirName); + } + hdfs->Disconnect(fs); +} + +vector HdfsFileSystem::ListDir(const string& dirName) +{ + vector dirs; + hdfsFS fs = ConnectHdfs(); + + int numEntries = 0; + hdfsFileInfo* subDirs = hdfs->ListDirectory(fs, dirName.c_str(), &numEntries); + for (int i = 0; i < numEntries; ++i) { + if (subDirs[i].mKind == kObjectKindDirectory) { + dirs.emplace_back(subDirs[i].mName); + } + } + + hdfs->FreeFileInfo(subDirs, numEntries); + hdfs->Disconnect(fs); + return dirs; +} + +size_t HdfsFileSystem::GetFileSize(const string& filePath) +{ + hdfsFS fs = ConnectHdfs(); + hdfsFileInfo* fileInfo = hdfs->GetPathInfo(fs, filePath.c_str()); + hdfs->Disconnect(fs); + if (fileInfo == nullptr) { + return 0; + } + auto fileSize = static_cast(fileInfo->mSize); + return fileSize; +} + +ssize_t HdfsFileSystem::Write(const string& filePath, const char* fileContent, size_t dataSize) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); + if (!file) { + hdfs->Disconnect(fs); + throw runtime_error("Error writing to hdfs file."); + } + + size_t dataCol = dataSize; + size_t writeSize = 0; + size_t idx = 0; + tSize writeBytesNum = 0; + + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { + writeSize = oneTimeReadWriteLen; + } else { + writeSize = dataCol; + } + + tSize res = hdfs->Write(fs, file, fileContent + idx, writeSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(res); + } + dataCol -= writeSize; + idx += writeSize; + writeBytesNum += res; + } + + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(writeBytesNum); +} + +ssize_t HdfsFileSystem::Write(const string& filePath, vector fileContent, size_t dataSize) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); + if (!file) { + hdfs->Disconnect(fs); + throw runtime_error("Error writing to hdfs file."); + } + + tSize writeBytesNum = 0; + size_t loops = fileContent.size(); + for (size_t i = 0; i < loops; i++) { + size_t dataCol = dataSize; + size_t writeSize = 0; + size_t idx = 0; + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { + writeSize = oneTimeReadWriteLen; + } else { + writeSize = dataCol; + } + tSize res = hdfs->Write(fs, file, fileContent[i] + idx, writeSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(res); + } + dataCol -= writeSize; + idx += writeSize; + writeBytesNum += res; + } + } + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(writeBytesNum); +} + +/// 用于动态扩容模式下,往hdfs文件中写embedding +/// \param filePath 文件路径 +/// \param embeddingSize embedding的长度 +/// \param addressArr 存放embedding的地址vector +/// \param deviceId 运行的卡的id +/// \return +void HdfsFileSystem::WriteEmbedding(const string& filePath, const int& embeddingSize, + const vector& addressArr, int deviceId) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0); + if (!file) { + 
hdfs->Disconnect(fs); + throw runtime_error("Error writing to hdfs file."); + } + +#ifndef GTEST + auto res = aclrtSetDevice(static_cast(deviceId)); + if (res != ACL_ERROR_NONE) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + } + + for (size_t i = 0; i < addressArr.size(); i += embHashNum) { + vector row(embeddingSize); + int64_t address = addressArr.at(i); + float *floatPtr = reinterpret_cast(address); + + aclError ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float), + floatPtr, embeddingSize * sizeof(float), + ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error("aclrtMemcpy failed"); + } + + auto numBytesWritten = hdfs->Write(fs, file, row.data(), embeddingSize * sizeof(float)); + if (numBytesWritten != embeddingSize * sizeof(float)) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error("Error writing to hdfs file."); + } + } +#endif + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); +} + +ssize_t HdfsFileSystem::Read(const string& filePath, char* fileContent, size_t datasetSize) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); + if (!file) { + hdfs->Disconnect(fs); + throw runtime_error("open hdfs file failed."); + } + + size_t dataCol = datasetSize; + size_t idx = 0; + size_t readSize = 0; + tSize readBytesNum = 0; + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { + readSize = oneTimeReadWriteLen; + } else { + readSize = dataCol; + } + tSize res = hdfs->Read(fs, file, fileContent + idx, readSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(res); + } + dataCol -= readSize; + idx += readSize; + readBytesNum += res; + } + + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(readBytesNum); +} + +ssize_t HdfsFileSystem::Read(const string& filePath, vector>& fileVector, size_t datasetSize) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); + if (!file) { + hdfs->Disconnect(fs); + throw runtime_error("open hdfs file failed."); + } + + size_t embDataOuterSize = fileVector.capacity(); + auto onceReadByteSize { datasetSize / embDataOuterSize }; + tSize readBytesNum = 0; + + for (size_t i = 0; i < embDataOuterSize; ++i) { + size_t idx = 0; + size_t readSize = 0; + size_t dataCol = onceReadByteSize; + while (dataCol != 0) { + if (dataCol > oneTimeReadWriteLen) { + readSize = oneTimeReadWriteLen; + } else { + readSize = dataCol; + } + tSize res = hdfs->Read(fs, file, reinterpret_cast(fileVector[i].data()) + idx, readSize); + if (res == -1) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(res); + } + dataCol -= readSize; + idx += readSize; + readBytesNum += res; + } + } + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + return static_cast(readBytesNum); +} + +/// 用于动态扩容模式下,从hdfs文件中读embedding +/// \param filePath 文件路径 +/// \param embeddingSize embedding的长度 +/// \param addressArr 存放embedding的地址vector +/// \param deviceId 运行的卡的id +/// \return +void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingSize, + vector& addressArr, int deviceId) +{ + hdfsFS fs = ConnectHdfs(); + + hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); + if (!file) { + hdfs->Disconnect(fs); + throw runtime_error("open hdfs file failed."); + } + + 
size_t datasetSize = GetFileSize(filePath); + +#ifndef GTEST + auto res = aclrtSetDevice(static_cast(deviceId)); + if (res != ACL_ERROR_NONE) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str()); + } + + aclError ret; + void *newBlock = nullptr; + ret = aclrtMalloc(&newBlock, static_cast(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != ACL_SUCCESS) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str()); + } + + float *floatPtr = static_cast(newBlock); + + for (size_t i = 0, j = 0; i < addressArr.size(); i += keyAddrElem, ++j) { + vector row(embeddingSize); + auto bytesRead = hdfs->Read(fs, file, row.data(), embeddingSize * sizeof(float)); + if (bytesRead != embeddingSize * sizeof(float)) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error("Error read hdfs file."); + } + + aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), + row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); + if (ec != ACL_SUCCESS) { + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); + throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str()); + } + int64_t address = reinterpret_cast(floatPtr + j * embeddingSize); + addressArr.at(i) = address; + } + hdfs->CloseFile(fs, file); + hdfs->Disconnect(fs); +#endif +} + +hdfsFS HdfsFileSystem::ConnectHdfs() +{ + hdfsFS fs = hdfs->Connect("default", 0); + if (!fs) { + throw runtime_error("Connect hdfs file system failed."); + } + return fs; +} \ No newline at end of file diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h new file mode 100644 index 00000000..9de39dce --- /dev/null +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-19
+ */
+
+#ifndef MX_REC_HDFS_FILE_SYSTEM_H
+#define MX_REC_HDFS_FILE_SYSTEM_H
+
+#include "file_system/file_system.h"
+#include "hdfs_wrapper.h"
+
+namespace MxRec {
+    using namespace std;
+
+    class HdfsFileSystem : public FileSystem {
+    public:
+        HdfsFileSystem()
+        {
+            hdfs = make_unique<HdfsWrapper>();
+        };
+        ~HdfsFileSystem() override {}
+
+        void CreateDir(const string& dirName) override;
+        vector<string> ListDir(const string& dirName) override;
+        size_t GetFileSize(const string& filePath) override;
+
+        ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override;
+        ssize_t Write(const string& filePath, vector<const char *> fileVector, size_t dataSize) override;
+        void WriteEmbedding(const string& filePath, const int& embeddingSize,
+                            const vector<int64_t>& addressArr, int deviceId) override;
+
+        ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override;
+        ssize_t Read(const string& filePath, vector<vector<float>>& fileVector, size_t datasetSize) override;
+        void ReadEmbedding(const string &filePath, const int& embeddingSize,
+                           vector<int64_t>& addressArr, int deviceId) override;
+
+        hdfsFS ConnectHdfs();
+
+        unique_ptr<HdfsWrapper> hdfs;
+    };
+}
+
+#endif // MX_REC_HDFS_FILE_SYSTEM_H
\ No newline at end of file
diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h
new file mode 100644
index 00000000..e1a0ab5d
--- /dev/null
+++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-27
+ */
+
+#ifndef MX_REC_HDFS_LOADER_H
+#define MX_REC_HDFS_LOADER_H
+
+#include <dlfcn.h>
+#include <ctime>
+
+#include "utils/common.h"
+
+namespace MxRec {
+
+    // The following parameters are not named in large camel case to adapt to native HDFS interfaces.
+    // Including tObjectKind, tPort, tSize, tTime, tOffset, hdfs_internal, hdfsFS, hdfsFile_internal, hdfsFile, hdfsFileInfo
+    enum tObjectKind {
+        kObjectKindFile = 'F',
+        kObjectKindDirectory = 'D',
+    };
+
+    using tPort = uint16_t;
+    using tSize = int32_t;
+    using tTime = time_t;
+    using tOffset = int64_t;
+
+    struct hdfs_internal;
+    using hdfsFS = struct hdfs_internal*;
+    struct hdfsFile_internal;
+    using hdfsFile = struct hdfsFile_internal*;
+
+    struct hdfsFileInfo {
+        tObjectKind mKind{};   /* file or directory */
+        char *mName{};         /* the name of the file */
+        tTime mLastMod{};      /* the last modification time for the file in seconds */
+        tOffset mSize{};       /* the size of the file in bytes */
+        short mReplication{};  /* the count of replicas */
+        tOffset mBlockSize{};  /* the block size for the file */
+        char *mOwner{};        /* the owner of the file */
+        char *mGroup{};        /* the group associated with the file */
+        short mPermissions{};  /* the permissions associated with the file */
+        tTime mLastAccess{};   /* the last access time for the file in seconds */
+    };
+
+    class HdfsWrapper {
+    public:
+        HdfsWrapper()
+        {
+            // Dynamically load the HDFS library
+            libhdfs = dlopen("libhdfs.so", RTLD_LAZY);
+            if (!libhdfs) {
+                LOG_ERROR("Init hdfs wrapper failed when loading libhdfs.so in environment.");
+                throw runtime_error("Init hdfs wrapper failed when loading libhdfs.so in environment.");
+            }
+
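+            // RTLD_LAZY defers symbol resolution until a symbol is first used; each wrapper
+            // method below therefore checks its function pointer against nullptr before the
+            // call, so a missing symbol fails with a clear error instead of a crash.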
"); + } + + // 获取hdfs库中的函数指针 + hdfsConnect = reinterpret_cast(dlsym(libhdfs, "hdfsConnect")); + hdfsDisconnect = reinterpret_cast(dlsym(libhdfs, "hdfsDisconnect")); + hdfsCreateDirectory = reinterpret_cast(dlsym(libhdfs, "hdfsCreateDirectory")); + hdfsListDirectory = reinterpret_cast(dlsym(libhdfs, "hdfsListDirectory")); + hdfsFreeFileInfo = reinterpret_cast(dlsym(libhdfs, "hdfsFreeFileInfo")); + hdfsGetPathInfo = reinterpret_cast(dlsym(libhdfs, "hdfsGetPathInfo")); + hdfsOpenFile = reinterpret_cast(dlsym(libhdfs, "hdfsOpenFile")); + hdfsCloseFile = reinterpret_cast(dlsym(libhdfs, "hdfsCloseFile")); + hdfsRead = reinterpret_cast(dlsym(libhdfs, "hdfsRead")); + hdfsWrite = reinterpret_cast(dlsym(libhdfs, "hdfsWrite")); + } + + ~HdfsWrapper() + { + dlclose(libhdfs); + } + + hdfsFS Connect(const char* host, tPort port) + { + if (hdfsConnect == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsConnect from the libhdfs."); + } + return hdfsConnect(host, port); + } + + int Disconnect(hdfsFS fs) + { + if (hdfsDisconnect == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsDisconnect from the libhdfs."); + } + return hdfsDisconnect(fs); + } + + int CreateDirectory(hdfsFS fs, const char* path) + { + if (hdfsCreateDirectory == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsCreateDirectory from libhdfs."); + } + return hdfsCreateDirectory(fs, path); + } + + hdfsFileInfo* ListDirectory(hdfsFS fs, const char* path, int *numEntries) + { + if (hdfsListDirectory == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsListDirectory from the libhdfs."); + } + return hdfsListDirectory(fs, path, numEntries); + } + + hdfsFileInfo* GetPathInfo(hdfsFS fs, const char* path) + { + if (hdfsGetPathInfo == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsGetPathInfo from the libhdfs."); + } + return hdfsGetPathInfo(fs, path); + } + + void FreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries) + { + if (hdfsFreeFileInfo == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsFreeFileInfo from the libhdfs."); + } + return hdfsFreeFileInfo(hdfsFileInfo, numEntries); + } + + hdfsFile OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication, tSize blocksize) + { + if (hdfsOpenFile == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsOpenFile from the libhdfs."); + } + return hdfsOpenFile(fs, path, flags, bufferSize, replication, blocksize); + } + + int CloseFile(hdfsFS fs, hdfsFile file) + { + if (hdfsCloseFile == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsCloseFile from the libhdfs."); + } + return hdfsCloseFile(fs, file); + } + + tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length) + { + if (hdfsRead == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsRead from the libhdfs."); + } + return hdfsRead(fs, file, buffer, length); + } + + tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length) + { + if (hdfsWrite == nullptr) { + throw runtime_error("Failed to obtain the pointer of the function hdfsWrite from the libhdfs."); + } + return hdfsWrite(fs, file, buffer, length); + } + + private: + void* libhdfs; + + using HdfsConnectFunc = hdfsFS (*)(const char*, tPort); + using HdfsDisconnectFunc = int (*)(hdfsFS); + using HdfsCreateDirectoryFunc = int (*)(hdfsFS fs, const char* 
+        using HdfsCreateDirectoryFunc = int (*)(hdfsFS fs, const char* path);
+        using HdfsListDirectoryFunc = hdfsFileInfo* (*)(hdfsFS fs, const char* path, int *numEntries);
+        using HdfsFreeFileInfoFunc = void (*)(hdfsFileInfo *hdfsFileInfo, int numEntries);
+        using HdfsGetPathInfoFunc = hdfsFileInfo* (*)(hdfsFS fs, const char* path);
+        using HdfsOpenFileFunc = hdfsFile (*)(hdfsFS, const char*, int, int, short, tSize);
+        using HdfsCloseFileFunc = int (*)(hdfsFS, hdfsFile);
+        using HdfsReadFunc = tSize (*)(hdfsFS, hdfsFile, void*, tSize);
+        using HdfsWriteFunc = tSize (*)(hdfsFS, hdfsFile, const void*, tSize);
+
+        HdfsConnectFunc hdfsConnect;
+        HdfsDisconnectFunc hdfsDisconnect;
+        HdfsCreateDirectoryFunc hdfsCreateDirectory;
+        HdfsListDirectoryFunc hdfsListDirectory;
+        HdfsFreeFileInfoFunc hdfsFreeFileInfo;
+        HdfsGetPathInfoFunc hdfsGetPathInfo;
+        HdfsOpenFileFunc hdfsOpenFile;
+        HdfsCloseFileFunc hdfsCloseFile;
+        HdfsReadFunc hdfsRead;
+        HdfsWriteFunc hdfsWrite;
+    };
+}
+
+#endif // MX_REC_HDFS_LOADER_H
\ No newline at end of file
diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp
new file mode 100644
index 00000000..7c6c1d1f
--- /dev/null
+++ b/src/core/file_system/local_file_system/local_file_system.cpp
@@ -0,0 +1,424 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-19
+ */
+
+#include "local_file_system.h"
+
+#include <dirent.h>
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <fstream>
+
+#include "file_system/buffer_queue.h"
+#include "utils/common.h"
+
+using namespace std;
+using namespace MxRec;
+
+void LocalFileSystem::CreateDir(const string& dirName)
+{
+    if (access(dirName.c_str(), F_OK) == -1) {
+        if (mkdir(dirName.c_str(), dirMode) == -1) {
+            LOG_DEBUG("Unable to create directory: {}", dirName);
+        }
+    }
+}
+
+vector<string> LocalFileSystem::ListDir(const string& dirName)
+{
+    vector<string> dirs;
+    DIR *dir = opendir(dirName.c_str());
+    struct dirent* en;
+    if (dir == nullptr) {
+        LOG_WARN("Open directory {} failed while trying to traverse the directory.", dirName);
+        return dirs;
+    }
+
+    while ((en = readdir(dir)) != nullptr) {
+        if (strncmp(en->d_name, currDir.c_str(), strlen(currDir.c_str())) != 0 &&
+            strncmp(en->d_name, prevDir.c_str(), strlen(prevDir.c_str())) != 0) {
+            dirs.emplace_back(en->d_name);
+        }
+    }
+    closedir(dir);
+    return dirs;
+}
+
+size_t LocalFileSystem::GetFileSize(const string& filePath)
+{
+    std::ifstream readFile;
+    readFile.open(filePath.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+    if (!readFile.is_open()) {
+        throw runtime_error(StringFormat("open file %s to get file size failed.", filePath.c_str()));
+    }
+    size_t datasetSize = static_cast<size_t>(readFile.tellg());
+    readFile.close();
+    return datasetSize;
+}
+
+ssize_t LocalFileSystem::Write(const string& filePath, const char* fileContent, size_t dataSize)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+    size_t dataCol = dataSize;
+    size_t writeSize = 0;
+    size_t idx = 0;
+    ssize_t writeBytesNum = 0;
+
+    while (dataCol != 0) {
+        if (dataCol > oneTimeReadWriteLen) {
+            writeSize = oneTimeReadWriteLen;
+        } else {
+            writeSize = dataCol;
+        }
+        ssize_t res = write(fd, fileContent + idx, writeSize);
+        if (res == -1) {
+            close(fd);
+            return res;
+        }
+        dataCol -= writeSize;
+        idx += writeSize;
+        writeBytesNum += res;
+    }
+    close(fd);
+    return writeBytesNum;
+}
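+
+// Unlike the char* overload above, the vector overload below stages data into a
+// BufferQueue drained by a dedicated writer thread (WriterFn), so buffer filling
+// overlaps with disk writes instead of issuing them row by row.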
+
+ssize_t LocalFileSystem::Write(const string& filePath, vector<const char *> fileContent, size_t dataSize)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+    buffer.reserve(BUFFER_SIZE);
+    BufferQueue queue;
+    ssize_t writeBytesNum = 0;
+    std::thread writer(&LocalFileSystem::WriterFn, this, std::ref(queue), fd, std::ref(writeBytesNum));
+
+    size_t loops = fileContent.size();
+    for (size_t i = 0; i < loops; i++) {
+        size_t idx = 0;
+        size_t writeSize = 0;
+        size_t dataCol = dataSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                writeSize = oneTimeReadWriteLen;
+            } else {
+                writeSize = dataCol;
+            }
+            FillToBuffer(queue, reinterpret_cast<const char *>(fileContent[i]) + idx, writeSize);
+            dataCol -= writeSize;
+            idx += writeSize;
+        }
+    }
+
+    // After all data has been processed, check if there is any data left in the buffer
+    if (!buffer.empty()) {
+        queue.Push(std::move(buffer));
+        buffer.clear();
+    }
+
+    queue.Push(std::vector<char>());
+    writer.join();
+    close(fd);
+    return writeBytesNum;
+}
+
+/// Writes embeddings to a local file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void LocalFileSystem::WriteEmbedding(const string& filePath, const int& embeddingSize,
+                                     const vector<int64_t>& addressArr, int deviceId)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        close(fd);
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    for (size_t i = 0; i < addressArr.size(); i += keyAddrElem) {
+        vector<float> row(embeddingSize);
+        int64_t address = addressArr.at(i);
+        float *floatPtr = reinterpret_cast<float *>(address);
+
+        aclError ret;
+        try {
+            ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float),
+                              floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+        } catch (std::exception& e) {
+            close(fd);
+            throw runtime_error(StringFormat("error happen when acl memory copy from device to host: %s", e.what()));
+        }
+
+        if (ret != ACL_SUCCESS) {
+            close(fd);
+            throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str());
+        }
+
+        ssize_t result = write(fd, row.data(), embeddingSize * sizeof(float));
+        if (result != embeddingSize * sizeof(float)) {
+            close(fd);
+            throw runtime_error("Error writing to local file, "
+                                "please check the disk buffer or temporary folder space or file permissions!");
+        }
+    }
+#endif
+    close(fd);
+}
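+
+// Read() mirrors the chunked write path: data is moved in oneTimeReadWriteLen chunks
+// so a single read() call never has to service an arbitrarily large request.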
+
+ssize_t LocalFileSystem::Read(const string& filePath, char* fileContent, size_t datasetSize)
+{
+    int fd = open(filePath.c_str(), O_RDONLY);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str()));
+    }
+
+    try {
+        ValidateReadFile(filePath, datasetSize);
+    } catch (const std::invalid_argument& e) {
+        close(fd);
+        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+    }
+
+    size_t idx = 0;
+    size_t readSize = 0;
+    ssize_t readBytesNum = 0;
+    while (datasetSize != 0) {
+        if (datasetSize > oneTimeReadWriteLen) {
+            readSize = oneTimeReadWriteLen;
+        } else {
+            readSize = datasetSize;
+        }
+        ssize_t res = read(fd, fileContent + idx, readSize);
+        if (res == -1) {
+            close(fd);
+            return res;
+        }
+        datasetSize -= readSize;
+        idx += readSize;
+        readBytesNum += res;
+    }
+    close(fd);
+    return readBytesNum;
+}
+
+ssize_t LocalFileSystem::Read(const string& filePath, vector<vector<float>>& fileContent, size_t datasetSize)
+{
+    size_t embDataOuterSize = fileContent.capacity();
+    auto onceReadByteSize { datasetSize / embDataOuterSize };
+
+    size_t mapByteSize;
+    size_t mapRowNum;
+    CalculateMapSize(datasetSize, mapByteSize, mapRowNum, onceReadByteSize);
+
+    off_t offset = 0;
+    size_t remainBytes = datasetSize;
+    ssize_t readBytesNum = 0;
+
+    int fd = open(filePath.c_str(), O_RDONLY);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str()));
+    }
+
+    for (size_t i = 0; i < embDataOuterSize; i += mapRowNum) {
+        // If the remaining bytes are fewer than the per-mapping byte count, update the per-mapping byte count and row count
+        if (remainBytes < mapByteSize) {
+            mapByteSize = remainBytes;
+            mapRowNum = mapByteSize / onceReadByteSize;
+        }
+
+        void* tempMappedData = mmap(nullptr, mapByteSize, PROT_READ, MAP_PRIVATE, fd, offset);
+        if (tempMappedData == MAP_FAILED) {
+            close(fd);
+            return -1;
+        }
+        readBytesNum += mapByteSize;
+
+        char* mappedData = static_cast<char *>(tempMappedData);
+
+        // Process the mapped data
+        try {
+            HandleMappedData(mappedData, mapRowNum, onceReadByteSize, fileContent, i);
+        } catch (const std::runtime_error& e) {
+            close(fd);
+            munmap(mappedData, mapByteSize);
+            throw runtime_error(StringFormat("handle mapped data error: %s", e.what()));
+        }
+        munmap(mappedData, mapByteSize);
+
+        offset += mapByteSize;
+        remainBytes -= mapByteSize;
+    }
+    close(fd);
+    return readBytesNum;
+}
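+
+// Note: the vector overload above bypasses read() entirely; it mmaps page-aligned
+// windows of the file and scatters whole rows into the destination vectors in parallel.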
+
+/// Reads embeddings from a local file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void LocalFileSystem::ReadEmbedding(const string& filePath, const int& embeddingSize,
+                                    vector<int64_t>& addressArr, int deviceId)
+{
+    std::ifstream readFile;
+    readFile.open(filePath.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+    if (!readFile.is_open()) {
+        throw runtime_error(StringFormat("open file %s to read failed.", filePath.c_str()));
+    }
+
+    size_t datasetSize = static_cast<size_t>(readFile.tellg());
+    auto embHashMapSize = static_cast<size_t>(datasetSize / sizeof(float) / embeddingSize);
+    readFile.seekg(0, std::ios::beg);
+
+    try {
+        ValidateReadFile(filePath, datasetSize);
+    } catch (const std::invalid_argument& e) {
+        readFile.close();
+        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+    }
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        readFile.close();
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    void *newBlock = nullptr;
+    aclError ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        readFile.close();
+        throw runtime_error(StringFormat("aclrtMalloc failed, ret=%d", ret).c_str());
+    }
+
+    float *floatPtr = static_cast<float *>(newBlock);
+    addressArr.reserve(embHashMapSize);
+
+    for (size_t i = 0, j = 0; i < embHashMapSize; i += keyAddrElem, ++j) {
+        vector<float> row(embeddingSize);
+        readFile.read(reinterpret_cast<char *>(row.data()), embeddingSize * sizeof(float));
+        aclError ec;
+        try {
+            ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float),
+                             row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
+        } catch (std::exception& e) {
+            readFile.close();
+            throw runtime_error(StringFormat("error happen when acl memory copy from host to device: %s", e.what()));
+        }
+
+        if (ec != ACL_SUCCESS) {
+            readFile.close();
+            throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str());
+        }
+        int64_t address = reinterpret_cast<int64_t>(floatPtr + j * embeddingSize);
+        addressArr.push_back(address);
+    }
+#endif
+    readFile.close();
+}
+
+void LocalFileSystem::WriterFn(BufferQueue& queue, int fd, ssize_t& writerBytesNum)
+{
+    while (true) {
+        queue.Pop(writeBuffer);
+        if (writeBuffer.size() == 0) {
+            break;
+        }
+        ssize_t res = write(fd, writeBuffer.data(), writeBuffer.size());
+        if (res == -1) {
+            close(fd);
+            writerBytesNum = -1;
+            break;
+        }
+        writerBytesNum += res;
+        writeBuffer.clear();
+    }
+}
+
+void LocalFileSystem::FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize)
+{
+    size_t dataIdx = 0;
+    while (dataIdx < dataSize) {
+        size_t remainingSpace = BUFFER_SIZE - buffer.size();
+        if (dataSize - dataIdx <= remainingSpace) {
+            buffer.insert(buffer.cend(), data + dataIdx, data + dataSize);
+            return;
+        } else {
+            buffer.insert(buffer.cend(), data + dataIdx, data + dataIdx + remainingSpace);
+            queue.Push(std::move(buffer));
+            if (buffer.capacity() < BUFFER_SIZE) {
+                buffer.reserve(BUFFER_SIZE);
+            }
+            dataIdx += remainingSpace;
+        }
+    }
+}
+
+void LocalFileSystem::CalculateMapSize(off_t fileSize, size_t& mapByteSize,
+                                       size_t& mapRowNum, size_t onceReadByteSize) const
+{
+    // Number of bytes mapped per round
+    mapByteSize = MAP_BYTE_SIZE;
+    // Make mapByteSize a multiple of both onceReadByteSize and pageSize so that every mapping offset is a multiple of the page size
+    long pageSize = sysconf(_SC_PAGESIZE);
+    if (pageSize == -1) {
+        throw std::runtime_error("Failed to get page size: " + std::string(strerror(errno)));
+    }
+    size_t lcmVal = std::lcm(onceReadByteSize, pageSize);
+    mapByteSize = (mapByteSize / lcmVal) * lcmVal;
+
+    // If the file is smaller than the per-round mapping size, map it in one go; when the mapping size is not
+    // a multiple of the page size, mmap rounds it up and the extra bytes are zero-initialized
+    if (fileSize <= mapByteSize) {
+        mapByteSize = fileSize;
+    }
+
+    mapRowNum = mapByteSize / onceReadByteSize;
+}
+
+
+void LocalFileSystem::HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize,
+                                       vector<vector<float>>& dst, size_t cnt) const
+{
+#pragma omp parallel for
+    for (size_t j = 0; j < mapRowNum; ++j) {
+        size_t idx = 0;
+        size_t readSize = 0;
+        size_t dataCol = onceReadByteSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                readSize = oneTimeReadWriteLen;
+            } else {
+                readSize = dataCol;
+            }
+
+            errno_t err = memcpy_s(dst[cnt + j].data() + idx, readSize,
+                                   mappedData + j * onceReadByteSize + idx, readSize);
+            if (err != 0) {
+                throw std::runtime_error("Error execution memcpy_s: " + std::to_string(err));
+            }
+            dataCol -= readSize;
+            idx += readSize;
+        }
+    }
+}
\ No newline at end of file
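A minimal standalone sketch (not part of the patch) of the window-size alignment CalculateMapSize performs; the row size, page size, and window size values below are made up for illustration:

#include <cstddef>
#include <cstdio>
#include <numeric>

int main()
{
    const std::size_t onceReadByteSize = 512;      // hypothetical row size in bytes
    const std::size_t pageSize = 4096;             // hypothetical page size
    std::size_t mapByteSize = 64 * 1024 * 1024;    // requested mapping window

    // Round the window down to a multiple of lcm(row size, page size): every mmap
    // offset then stays page-aligned while each window still covers whole rows.
    const std::size_t lcmVal = std::lcm(onceReadByteSize, pageSize);
    mapByteSize = (mapByteSize / lcmVal) * lcmVal;

    std::printf("window=%zu bytes, rows per window=%zu\n",
                mapByteSize, mapByteSize / onceReadByteSize);
    return 0;
}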
diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h
new file mode 100644
index 00000000..d6346a1d
--- /dev/null
+++ b/src/core/file_system/local_file_system/local_file_system.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-19
+ */
+
+#ifndef MX_REC_LOCAL_FILE_SYSTEM_H
+#define MX_REC_LOCAL_FILE_SYSTEM_H
+
+#include "file_system/file_system.h"
+
+namespace MxRec {
+    using namespace std;
+
+    class LocalFileSystem : public FileSystem {
+    public:
+        LocalFileSystem() : dirMode(0750), fileMode(0640), currDir("."), prevDir("..") {}
+        ~LocalFileSystem() override {}
+
+        void CreateDir(const string& dirName) override;
+        vector<string> ListDir(const string& dirName) override;
+        size_t GetFileSize(const string& filePath) override;
+
+        ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override;
+        ssize_t Write(const string& filePath, vector<const char *> fileVector, size_t dataSize) override;
+        void WriteEmbedding(const string& filePath, const int& embeddingSize,
+                            const vector<int64_t>& addressArr, int deviceId) override;
+
+        ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override;
+        ssize_t Read(const string& filePath, vector<vector<float>>& fileVector, size_t datasetSize) override;
+        void ReadEmbedding(const string& filePath, const int& embeddingSize,
+                           vector<int64_t>& addressArr, int deviceId) override;
+
+        void WriterFn(BufferQueue& queue, int fd, ssize_t& writerBytesNum);
+        void FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize);
+        void CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const;
+        void HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize,
+                              vector<vector<float>>& dst, size_t cnt) const;
+
+    private:
+        const mode_t dirMode;
+        const mode_t fileMode;
+        const string currDir;
+        const string prevDir;
+    };
+}
+
+#endif // MX_REC_LOCAL_FILE_SYSTEM_H
\ No newline at end of file
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index c0c108c8..48f274a2 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -8,6 +8,7 @@
 
 #include "utils/time_cost.h"
 #include "utils/logger.h"
+#include "utils/common.h"
 
 #include "checkpoint/checkpoint.h"
 
@@ -268,7 +269,7 @@ bool HybridMgmt::Save(const string savePath)
 
     // Perform the save operation
     saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo);
-
+    offsetMapToSend = std::move(saveData.offsetMap);
     // The data processing thread releases the lock
     preprocess->LoadSaveUnlock();
 #endif
@@ -315,7 +316,6 @@ bool HybridMgmt::Load(const string& loadPath)
     } else {
         // HBM mode: assign the loaded max offset (how much vocab capacity is actually used) and the feature-to-offset map
         LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap");
-        preprocess->LoadMaxOffset(loadData.maxOffset);
         preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap);
     }
 
@@ -354,7 +354,6 @@ void HybridMgmt::SetFeatureTypeForLoad(vector<CkptFeatureType>& loadFeatures,
         loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP);
     } else {
         // In HBM mode the loaded types are the max offset (vocab capacity actually used) and the feature-to-offset map
-        loadFeatures.push_back(CkptFeatureType::MAX_OFFSET);
         loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP);
     }
 
@@ -371,32 +370,19 @@ void HybridMgmt::SetFeatureTypeForLoad(vector<CkptFeatureType>& loadFeatures,
 /// Gets the offset for a key; called from the Python side
 /// \param tableName table name
 /// \return
-KeyOffsetMapT HybridMgmt::SendHostMap(const string tableName)
+OffsetT HybridMgmt::SendHostMap(const string tableName)
 {
 #ifndef GTEST
-    if (!isInitialized) {
-        throw runtime_error("HybridMgmt not initialized. Call Initialize first.");
-    }
-
-    preprocess->LoadSaveLock();
-    KeyOffsetMemT keyOffsetMap;
-    KeyOffsetMapT sendKeyOffsetMap;
-
-    if (!mgmtRankInfo.noDDR) {
-        LOG_DEBUG(MGMT + "Start send sparse data: ddr mode hashmap");
-    } else {
-        LOG_DEBUG(MGMT + "Start send sparse data: no ddr mode hashmap");
-        keyOffsetMap = preprocess->GetKeyOffsetMap();
-    }
-
-    if ((!keyOffsetMap.empty()) && keyOffsetMap.count(tableName) > 0) {
-        for (const auto& it : keyOffsetMap.at(tableName)) {
-            sendKeyOffsetMap[it.first] = it.second;
+    OffsetT OffsetMap;
+    // First check whether this map is empty
+    if ((!offsetMapToSend.empty()) && offsetMapToSend.count(tableName) > 0) {
+        LOG_ERROR("send offset map : table name =={} offset count {}", tableName.c_str(), offsetMapToSend.count(tableName));
+        LOG_ERROR("send offset map : first key offset {}", offsetMapToSend[tableName][0]);
+        for (auto& it : offsetMapToSend.at(tableName)) {
+            OffsetMap.push_back(it);
         }
     }
-
-    preprocess->LoadSaveUnlock();
-    return sendKeyOffsetMap;
+    return OffsetMap;
 #endif
 }
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 5ebfed4a..2828abd4 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -63,7 +63,7 @@ namespace MxRec {
         void SetFeatureTypeForLoad(vector<CkptFeatureType>& loadFeatures,
                                    const FeatureAdmitAndEvict& featAdmitNEvict);
 
-        KeyOffsetMapT SendHostMap(const string tableName);
+        OffsetT SendHostMap(const string tableName);
 
         void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap);
 
@@ -161,6 +161,7 @@ namespace MxRec {
         map> evictKeyMap {};
         KeyProcess *preprocess;
         HDTransfer *hdTransfer;
+        OffsetMapT offsetMapToSend;
         bool isSSDEnabled { false };
         bool isRunning;
         bool isLoad { false };
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 9910eb43..e6ea6934 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -508,7 +508,8 @@ namespace MxRec {
     using KeyCountMemT = std::map>;
     using Table2ThreshMemT = absl::flat_hash_map;
     using trans_serialize_t = uint8_t;
-    using KeyOffsetMapT = std::map;
+    using OffsetMapT = std::map>;
+    using OffsetT = std::vector;
     using AllKeyOffsetMapT = std::map>;
     using KeyFreqMemT = unordered_map>;
 
@@ -528,6 +529,8 @@ namespace MxRec {
         EmbHashMemT embHashMaps;
         OffsetMemT maxOffset;
         KeyOffsetMemT keyOffsetMap;
+        OffsetMapT offsetMap;
+        OffsetMapT* offsetMapPtr = &offsetMap;
         KeyCountMemT keyCountMap;
         Table2ThreshMemT table2Thresh;
         AdmitAndEvictData histRec;
@@ -537,6 +540,7 @@ namespace MxRec {
 
     struct CkptTransData {
         std::vector<int64_t> int64Arr;
+        std::vector<int64_t> addressArr;
        std::vector<float> floatArr;
        std::vector<int32_t> int32Arr;
        std::vector<trans_serialize_t> transDataset;  // may all use this to transfer data
-- Gitee

From f2cd38832b13c3aec0877e83e2adb84cb3a2ffcb Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 17 Nov 2023 14:26:24 +0800
Subject: [PATCH 458/551] Match-id-6ce12462bec7e5364639857277133f9cb104313a

---
 src/core/file_system/file_system_handler.cpp |   6 +-
 src/core/file_system/file_system_handler.h   |   2 +-
 .../hdfs_file_system/hdfs_file_system.cpp    | 333 ++++++++++++++
 .../local_file_system/local_file_system.cpp  | 424 ++++++++++++++++++
 .../file_system/local_file_system_test.cpp   |  30 ++
 5 files changed, 791 insertions(+), 4 deletions(-)
 create mode 100644 src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
 create mode 100644 src/core/file_system/local_file_system/local_file_system.cpp
 create mode 100644 src/tests/file_system/local_file_system_test.cpp
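For orientation, the factory touched below dispatches on the path prefix; a minimal usage sketch mirroring the unit test added later in this patch (the path is illustrative, and the MxRec headers are assumed to be on the include path):

#include <memory>
#include "file_system/file_system_handler.h"

int main()
{
    auto handler = std::make_unique<MxRec::FileSystemHandler>();
    // "hdfs://" or "viewfs://" prefixes yield an HdfsFileSystem; any other path is local.
    auto fs = handler->Create("/tmp/table0.data");
    return fs ? 0 : 1;
}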
diff --git a/src/core/file_system/file_system_handler.cpp b/src/core/file_system/file_system_handler.cpp
index 351cff2d..faa8147f 100644
--- a/src/core/file_system/file_system_handler.cpp
+++ b/src/core/file_system/file_system_handler.cpp
@@ -10,13 +10,13 @@
 using namespace std;
 using namespace MxRec;
 
-inline unique_ptr<FileSystem> FileSystemHandler::Create(const string &dataDir)
+unique_ptr<FileSystem> FileSystemHandler::Create(const string& filePath)
 {
-    if (dataDir.empty()) {
+    if (filePath.empty()) {
         throw runtime_error("dataDir is Null. The pointer of the file system cannot be created.");
     }
     for (const auto &prefix: hdfsPrefixes) {
-        if (dataDir.substr(0, prefix.length()) == prefix) {
+        if (filePath.substr(0, prefix.length()) == prefix) {
             return make_unique<HdfsFileSystem>();
         }
     }
diff --git a/src/core/file_system/file_system_handler.h b/src/core/file_system/file_system_handler.h
index 59ae8738..b2a92999 100644
--- a/src/core/file_system/file_system_handler.h
+++ b/src/core/file_system/file_system_handler.h
@@ -16,7 +16,7 @@ namespace MxRec {
 
     class FileSystemHandler {
     public:
-        inline unique_ptr<FileSystem> Create(const string &dataDir);
+        unique_ptr<FileSystem> Create(const string& filePath);
     private:
        const vector<string> hdfsPrefixes = {"hdfs://", "viewfs://"};
     };
diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
new file mode 100644
index 00000000..e0fb9827
--- /dev/null
+++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
@@ -0,0 +1,333 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-19
+ */
+
+#include "hdfs_file_system.h"
+
+#include <fcntl.h>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "hdfs_wrapper.h"
+#include "utils/logger.h"
+
+using namespace std;
+using namespace MxRec;
+
+void HdfsFileSystem::CreateDir(const string& dirName)
+{
+    hdfsFS fs = ConnectHdfs();
+    int ret = hdfs->CreateDirectory(fs, dirName.c_str());
+    if (ret == -1) {
+        LOG_DEBUG("Unable to create hdfs directory: {}", dirName);
+    }
+    hdfs->Disconnect(fs);
+}
+
+vector<string> HdfsFileSystem::ListDir(const string& dirName)
+{
+    vector<string> dirs;
+    hdfsFS fs = ConnectHdfs();
+
+    int numEntries = 0;
+    hdfsFileInfo* subDirs = hdfs->ListDirectory(fs, dirName.c_str(), &numEntries);
+    for (int i = 0; i < numEntries; ++i) {
+        if (subDirs[i].mKind == kObjectKindDirectory) {
+            dirs.emplace_back(subDirs[i].mName);
+        }
+    }
+
+    hdfs->FreeFileInfo(subDirs, numEntries);
+    hdfs->Disconnect(fs);
+    return dirs;
+}
+
+size_t HdfsFileSystem::GetFileSize(const string& filePath)
+{
+    hdfsFS fs = ConnectHdfs();
+    hdfsFileInfo* fileInfo = hdfs->GetPathInfo(fs, filePath.c_str());
+    hdfs->Disconnect(fs);
+    if (fileInfo == nullptr) {
+        return 0;
+    }
+    auto fileSize = static_cast<size_t>(fileInfo->mSize);
+    hdfs->FreeFileInfo(fileInfo, 1);
+    return fileSize;
+}
+
+ssize_t HdfsFileSystem::Write(const string& filePath, const char* fileContent, size_t dataSize)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("Error writing to hdfs file.");
+    }
+
+    size_t dataCol = dataSize;
+    size_t writeSize = 0;
+    size_t idx = 0;
+    tSize writeBytesNum = 0;
+
+    while (dataCol != 0) {
+        if (dataCol > oneTimeReadWriteLen) {
+            writeSize = oneTimeReadWriteLen;
+        } else {
+            writeSize = dataCol;
+        }
+
+        tSize res = hdfs->Write(fs, file, fileContent + idx, writeSize);
+        if (res == -1) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            return static_cast<ssize_t>(res);
+        }
+        dataCol -= writeSize;
+        idx += writeSize;
+        writeBytesNum += res;
+    }
+
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+    return static_cast<ssize_t>(writeBytesNum);
+}
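+
+// tSize is int32_t, so a single hdfs->Write call cannot move more than 2 GiB;
+// chunking by oneTimeReadWriteLen keeps every call within that limit.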
+
+ssize_t HdfsFileSystem::Write(const string& filePath, vector<const char *> fileContent, size_t dataSize)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("Error writing to hdfs file.");
+    }
+
+    tSize writeBytesNum = 0;
+    size_t loops = fileContent.size();
+    for (size_t i = 0; i < loops; i++) {
+        size_t dataCol = dataSize;
+        size_t writeSize = 0;
+        size_t idx = 0;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                writeSize = oneTimeReadWriteLen;
+            } else {
+                writeSize = dataCol;
+            }
+            tSize res = hdfs->Write(fs, file, fileContent[i] + idx, writeSize);
+            if (res == -1) {
+                hdfs->CloseFile(fs, file);
+                hdfs->Disconnect(fs);
+                return static_cast<ssize_t>(res);
+            }
+            dataCol -= writeSize;
+            idx += writeSize;
+            writeBytesNum += res;
+        }
+    }
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+    return static_cast<ssize_t>(writeBytesNum);
+}
+
+/// Writes embeddings to an HDFS file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void HdfsFileSystem::WriteEmbedding(const string& filePath, const int& embeddingSize,
+                                    const vector<int64_t>& addressArr, int deviceId)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_WRONLY | O_CREAT, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("Error writing to hdfs file.");
+    }
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        hdfs->CloseFile(fs, file);
+        hdfs->Disconnect(fs);
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    for (size_t i = 0; i < addressArr.size(); i += embHashNum) {
+        vector<float> row(embeddingSize);
+        int64_t address = addressArr.at(i);
+        float *floatPtr = reinterpret_cast<float *>(address);
+
+        aclError ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float),
+                                   floatPtr, embeddingSize * sizeof(float),
+                                   ACL_MEMCPY_DEVICE_TO_HOST);
+        if (ret != ACL_SUCCESS) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            throw runtime_error("aclrtMemcpy failed");
+        }
+
+        auto numBytesWritten = hdfs->Write(fs, file, row.data(), embeddingSize * sizeof(float));
+        if (numBytesWritten != embeddingSize * sizeof(float)) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            throw runtime_error("Error writing to hdfs file.");
+        }
+    }
+#endif
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+}
+
+ssize_t HdfsFileSystem::Read(const string& filePath, char* fileContent, size_t datasetSize)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("open hdfs file failed.");
+    }
+
+    size_t dataCol = datasetSize;
+    size_t idx = 0;
+    size_t readSize = 0;
+    tSize readBytesNum = 0;
+    while (dataCol != 0) {
+        if (dataCol > oneTimeReadWriteLen) {
+            readSize = oneTimeReadWriteLen;
+        } else {
+            readSize = dataCol;
+        }
+        tSize res = hdfs->Read(fs, file, fileContent + idx, readSize);
+        if (res == -1) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            return static_cast<ssize_t>(res);
+        }
+        dataCol -= readSize;
+        idx += readSize;
+        readBytesNum += res;
+    }
+
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+    return static_cast<ssize_t>(readBytesNum);
+}
+
+ssize_t HdfsFileSystem::Read(const string& filePath, vector<vector<float>>& fileVector, size_t datasetSize)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("open hdfs file failed.");
+    }
+
+    size_t embDataOuterSize = fileVector.capacity();
+    auto onceReadByteSize { datasetSize / embDataOuterSize };
+    tSize readBytesNum = 0;
+
+    for (size_t i = 0; i < embDataOuterSize; ++i) {
+        size_t idx = 0;
+        size_t readSize = 0;
+        size_t dataCol = onceReadByteSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                readSize = oneTimeReadWriteLen;
+            } else {
+                readSize = dataCol;
+            }
+            tSize res = hdfs->Read(fs, file, reinterpret_cast<char *>(fileVector[i].data()) + idx, readSize);
+            if (res == -1) {
+                hdfs->CloseFile(fs, file);
+                hdfs->Disconnect(fs);
+                return static_cast<ssize_t>(res);
+            }
+            dataCol -= readSize;
+            idx += readSize;
+            readBytesNum += res;
+        }
+    }
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+    return static_cast<ssize_t>(readBytesNum);
+}
+
+/// Reads embeddings from an HDFS file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingSize,
+                                   vector<int64_t>& addressArr, int deviceId)
+{
+    hdfsFS fs = ConnectHdfs();
+
+    hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0);
+    if (!file) {
+        hdfs->Disconnect(fs);
+        throw runtime_error("open hdfs file failed.");
+    }
+
+    size_t datasetSize = GetFileSize(filePath);
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        hdfs->CloseFile(fs, file);
+        hdfs->Disconnect(fs);
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    aclError ret;
+    void *newBlock = nullptr;
+    ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        hdfs->CloseFile(fs, file);
+        hdfs->Disconnect(fs);
+        throw runtime_error(StringFormat("aclrtMalloc failed, ret=%d", ret).c_str());
+    }
+
+    float *floatPtr = static_cast<float *>(newBlock);
+
+    for (size_t i = 0, j = 0; i < addressArr.size(); i += keyAddrElem, ++j) {
+        vector<float> row(embeddingSize);
+        auto bytesRead = hdfs->Read(fs, file, row.data(), embeddingSize * sizeof(float));
+        if (bytesRead != embeddingSize * sizeof(float)) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            throw runtime_error("Error reading hdfs file.");
+        }
+
+        aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float),
+                                  row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
+        if (ec != ACL_SUCCESS) {
+            hdfs->CloseFile(fs, file);
+            hdfs->Disconnect(fs);
+            throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str());
+        }
+        int64_t address = reinterpret_cast<int64_t>(floatPtr + j * embeddingSize);
+        addressArr.at(i) = address;
+    }
+    hdfs->CloseFile(fs, file);
+    hdfs->Disconnect(fs);
+#endif
+}
+
+hdfsFS HdfsFileSystem::ConnectHdfs()
+{
+    hdfsFS fs = hdfs->Connect("default", 0);
+    if (!fs) {
+        throw runtime_error("Connect hdfs file system failed.");
+    }
+    return fs;
+}
\ No newline at end of file
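All of the HDFS paths above follow the same connect/open/act/close/disconnect sequence through the wrapper; a minimal sketch (not part of the patch, error handling elided, file name illustrative):

#include <fcntl.h>
#include "hdfs_wrapper.h"

int main()
{
    MxRec::HdfsWrapper hdfs;  // dlopens libhdfs.so and resolves the symbols it needs
    MxRec::hdfsFS fs = hdfs.Connect("default", 0);
    MxRec::hdfsFile file = hdfs.OpenFile(fs, "/tmp/demo.bin", O_WRONLY | O_CREAT, 0, 0, 0);
    const char payload[] = "demo";
    hdfs.Write(fs, file, payload, sizeof(payload));
    hdfs.CloseFile(fs, file);
    hdfs.Disconnect(fs);
    return 0;
}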
diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp
new file mode 100644
index 00000000..51648fa4
--- /dev/null
+++ b/src/core/file_system/local_file_system/local_file_system.cpp
@@ -0,0 +1,424 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description:
+ * Author: MindX SDK
+ * Create: 2023-10-19
+ */
+
+#include "local_file_system.h"
+
+#include <dirent.h>
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <fstream>
+
+#include "checkpoint/buffer_queue.h"
+#include "utils/common.h"
+
+using namespace std;
+using namespace MxRec;
+
+void LocalFileSystem::CreateDir(const string& dirName)
+{
+    if (access(dirName.c_str(), F_OK) == -1) {
+        if (mkdir(dirName.c_str(), dirMode) == -1) {
+            LOG_DEBUG("Unable to create directory: {}", dirName);
+        }
+    }
+}
+
+vector<string> LocalFileSystem::ListDir(const string& dirName)
+{
+    vector<string> dirs;
+    DIR *dir = opendir(dirName.c_str());
+    struct dirent* en;
+    if (dir == nullptr) {
+        LOG_WARN("Open directory {} failed while trying to traverse the directory.", dirName);
+        return dirs;
+    }
+
+    while ((en = readdir(dir)) != nullptr) {
+        if (strncmp(en->d_name, currDir.c_str(), strlen(currDir.c_str())) != 0 &&
+            strncmp(en->d_name, prevDir.c_str(), strlen(prevDir.c_str())) != 0) {
+            dirs.emplace_back(en->d_name);
+        }
+    }
+    closedir(dir);
+    return dirs;
+}
+
+size_t LocalFileSystem::GetFileSize(const string& filePath)
+{
+    std::ifstream readFile;
+    readFile.open(filePath.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+    if (!readFile.is_open()) {
+        throw runtime_error(StringFormat("open file %s to get file size failed.", filePath.c_str()));
+    }
+    size_t datasetSize = static_cast<size_t>(readFile.tellg());
+    readFile.close();
+    return datasetSize;
+}
+
+ssize_t LocalFileSystem::Write(const string& filePath, const char* fileContent, size_t dataSize)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+    size_t dataCol = dataSize;
+    size_t writeSize = 0;
+    size_t idx = 0;
+    ssize_t writeBytesNum = 0;
+
+    while (dataCol != 0) {
+        if (dataCol > oneTimeReadWriteLen) {
+            writeSize = oneTimeReadWriteLen;
+        } else {
+            writeSize = dataCol;
+        }
+        ssize_t res = write(fd, fileContent + idx, writeSize);
+        if (res == -1) {
+            close(fd);
+            return res;
+        }
+        dataCol -= writeSize;
+        idx += writeSize;
+        writeBytesNum += res;
+    }
+    close(fd);
+    return writeBytesNum;
+}
+
+ssize_t LocalFileSystem::Write(const string& filePath, vector<const char *> fileContent, size_t dataSize)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+    buffer.reserve(BUFFER_SIZE);
+    BufferQueue queue;
+    ssize_t writeBytesNum = 0;
+    std::thread writer(&LocalFileSystem::WriterFn, this, std::ref(queue), fd, std::ref(writeBytesNum));
+
+    size_t loops = fileContent.size();
+    for (size_t i = 0; i < loops; i++) {
+        size_t idx = 0;
+        size_t writeSize = 0;
+        size_t dataCol = dataSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                writeSize = oneTimeReadWriteLen;
+            } else {
+                writeSize = dataCol;
+            }
+            FillToBuffer(queue, reinterpret_cast<const char *>(fileContent[i]) + idx, writeSize);
+            dataCol -= writeSize;
+            idx += writeSize;
+        }
+    }
+
+    // After all data has been processed, check if there is any data left in the buffer
+    if (!buffer.empty()) {
+        queue.Push(std::move(buffer));
+        buffer.clear();
+    }
+
+    queue.Push(std::vector<char>());
+    writer.join();
+    close(fd);
+    return writeBytesNum;
+}
+
+/// Writes embeddings to a local file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void LocalFileSystem::WriteEmbedding(const string& filePath, const int& embeddingSize,
+                                     const vector<int64_t>& addressArr, int deviceId)
+{
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, fileMode);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("open file %s to write failed.", filePath.c_str()));
+    }
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        close(fd);
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    for (size_t i = 0; i < addressArr.size(); i += keyAddrElem) {
+        vector<float> row(embeddingSize);
+        int64_t address = addressArr.at(i);
+        float *floatPtr = reinterpret_cast<float *>(address);
+
+        aclError ret;
+        try {
+            ret = aclrtMemcpy(row.data(), embeddingSize * sizeof(float),
+                              floatPtr, embeddingSize * sizeof(float), ACL_MEMCPY_DEVICE_TO_HOST);
+        } catch (std::exception& e) {
+            close(fd);
+            throw runtime_error(StringFormat("error happen when acl memory copy from device to host: %s", e.what()));
+        }
+
+        if (ret != ACL_SUCCESS) {
+            close(fd);
+            throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ret).c_str());
+        }
+
+        ssize_t result = write(fd, row.data(), embeddingSize * sizeof(float));
+        if (result != embeddingSize * sizeof(float)) {
+            close(fd);
+            throw runtime_error("Error writing to local file, "
+                                "please check the disk buffer or temporary folder space or file permissions!");
+        }
+    }
+#endif
+    close(fd);
+}
+
+ssize_t LocalFileSystem::Read(const string& filePath, char* fileContent, size_t datasetSize)
+{
+    int fd = open(filePath.c_str(), O_RDONLY);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str()));
+    }
+
+    try {
+        ValidateReadFile(filePath, datasetSize);
+    } catch (const std::invalid_argument& e) {
+        close(fd);
+        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+    }
+
+    size_t idx = 0;
+    size_t readSize = 0;
+    ssize_t readBytesNum = 0;
+    while (datasetSize != 0) {
+        if (datasetSize > oneTimeReadWriteLen) {
+            readSize = oneTimeReadWriteLen;
+        } else {
+            readSize = datasetSize;
+        }
+        ssize_t res = read(fd, fileContent + idx, readSize);
+        if (res == -1) {
+            close(fd);
+            return res;
+        }
+        datasetSize -= readSize;
+        idx += readSize;
+        readBytesNum += res;
+    }
+    close(fd);
+    return readBytesNum;
+}
+
+ssize_t LocalFileSystem::Read(const string& filePath, vector<vector<float>>& fileContent, size_t datasetSize)
+{
+    size_t embDataOuterSize = fileContent.capacity();
+    auto onceReadByteSize { datasetSize / embDataOuterSize };
+
+    size_t mapByteSize;
+    size_t mapRowNum;
+    CalculateMapSize(datasetSize, mapByteSize, mapRowNum, onceReadByteSize);
+
+    off_t offset = 0;
+    size_t remainBytes = datasetSize;
+    ssize_t readBytesNum = 0;
+
+    int fd = open(filePath.c_str(), O_RDONLY);
+    if (fd == -1) {
+        throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str()));
+    }
+
+    for (size_t i = 0; i < embDataOuterSize; i += mapRowNum) {
+        // If the remaining bytes are fewer than the per-mapping byte count, update the per-mapping byte count and row count
+        if (remainBytes < mapByteSize) {
+            mapByteSize = remainBytes;
+            mapRowNum = mapByteSize / onceReadByteSize;
+        }
+
+        void* tempMappedData = mmap(nullptr, mapByteSize, PROT_READ, MAP_PRIVATE, fd, offset);
+        if (tempMappedData == MAP_FAILED) {
+            close(fd);
+            return -1;
+        }
+        readBytesNum += mapByteSize;
+
+        char* mappedData = static_cast<char *>(tempMappedData);
+
+        // Process the mapped data
+        try {
+            HandleMappedData(mappedData, mapRowNum, onceReadByteSize, fileContent, i);
+        } catch (const std::runtime_error& e) {
+            close(fd);
+            munmap(mappedData, mapByteSize);
+            throw runtime_error(StringFormat("handle mapped data error: %s", e.what()));
+        }
+        munmap(mappedData, mapByteSize);
+
+        offset += mapByteSize;
+        remainBytes -= mapByteSize;
+    }
+    close(fd);
+    return readBytesNum;
+}
+
+/// Reads embeddings from a local file in dynamic expansion mode.
+/// \param filePath file path
+/// \param embeddingSize length of a single embedding row
+/// \param addressArr vector holding the device addresses of the embeddings
+/// \param deviceId ID of the device card to run on
+/// \return
+void LocalFileSystem::ReadEmbedding(const string& filePath, const int& embeddingSize,
+                                    vector<int64_t>& addressArr, int deviceId)
+{
+    std::ifstream readFile;
+    readFile.open(filePath.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
+    if (!readFile.is_open()) {
+        throw runtime_error(StringFormat("open file %s to read failed.", filePath.c_str()));
+    }
+
+    size_t datasetSize = static_cast<size_t>(readFile.tellg());
+    auto embHashMapSize = static_cast<size_t>(datasetSize / sizeof(float) / embeddingSize);
+    readFile.seekg(0, std::ios::beg);
+
+    try {
+        ValidateReadFile(filePath, datasetSize);
+    } catch (const std::invalid_argument& e) {
+        readFile.close();
+        throw runtime_error(StringFormat("Invalid read file path: %s", e.what()));
+    }
+
+#ifndef GTEST
+    auto res = aclrtSetDevice(static_cast<int32_t>(deviceId));
+    if (res != ACL_ERROR_NONE) {
+        readFile.close();
+        throw runtime_error(StringFormat("Set device failed, device_id:%d", deviceId).c_str());
+    }
+
+    void *newBlock = nullptr;
+    aclError ret = aclrtMalloc(&newBlock, static_cast<size_t>(datasetSize), ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ret != ACL_SUCCESS) {
+        readFile.close();
+        throw runtime_error(StringFormat("aclrtMalloc failed, ret=%d", ret).c_str());
+    }
+
+    float *floatPtr = static_cast<float *>(newBlock);
+    addressArr.reserve(embHashMapSize);
+
+    for (size_t i = 0, j = 0; i < embHashMapSize; i += keyAddrElem, ++j) {
+        vector<float> row(embeddingSize);
+        readFile.read(reinterpret_cast<char *>(row.data()), embeddingSize * sizeof(float));
+        aclError ec;
+        try {
+            ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float),
+                             row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
+        } catch (std::exception& e) {
+            readFile.close();
+            throw runtime_error(StringFormat("error happen when acl memory copy from host to device: %s", e.what()));
+        }
+
+        if (ec != ACL_SUCCESS) {
+            readFile.close();
+            throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str());
+        }
+        int64_t address = reinterpret_cast<int64_t>(floatPtr + j * embeddingSize);
+        addressArr.push_back(address);
+    }
+#endif
+    readFile.close();
+}
+
+void LocalFileSystem::WriterFn(BufferQueue& queue, int fd, ssize_t& writerBytesNum)
+{
+    while (true) {
+        queue.Pop(writeBuffer);
+        if (writeBuffer.size() == 0) {
+            break;
+        }
+        ssize_t res = write(fd, writeBuffer.data(), writeBuffer.size());
+        if (res == -1) {
+            close(fd);
+            writerBytesNum = -1;
+            break;
+        }
+        writerBytesNum += res;
+        writeBuffer.clear();
+    }
+}
+
+void LocalFileSystem::FillToBuffer(BufferQueue& queue, const char* data, size_t dataSize)
+{
+    size_t dataIdx = 0;
+    while (dataIdx < dataSize) {
+        size_t remainingSpace = BUFFER_SIZE - buffer.size();
+        if (dataSize - dataIdx <= remainingSpace) {
+            buffer.insert(buffer.cend(), data + dataIdx, data + dataSize);
+            return;
+        } else {
+            buffer.insert(buffer.cend(), data + dataIdx, data + dataIdx + remainingSpace);
+            queue.Push(std::move(buffer));
+            if (buffer.capacity() < BUFFER_SIZE) {
+                buffer.reserve(BUFFER_SIZE);
+            }
+            dataIdx += remainingSpace;
+        }
+    }
+}
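+
+// CalculateMapSize below sizes each mmap window as a multiple of lcm(onceReadByteSize,
+// pageSize): offsets stay page-aligned (a mmap requirement) while windows cover whole rows.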
+
+void LocalFileSystem::CalculateMapSize(off_t fileSize, size_t& mapByteSize,
+                                       size_t& mapRowNum, size_t onceReadByteSize) const
+{
+    // Number of bytes mapped per round
+    mapByteSize = MAP_BYTE_SIZE;
+    // Make mapByteSize a multiple of both onceReadByteSize and pageSize so that every mapping offset is a multiple of the page size
+    long pageSize = sysconf(_SC_PAGESIZE);
+    if (pageSize == -1) {
+        throw std::runtime_error("Failed to get page size: " + std::string(strerror(errno)));
+    }
+    size_t lcmVal = std::lcm(onceReadByteSize, pageSize);
+    mapByteSize = (mapByteSize / lcmVal) * lcmVal;
+
+    // If the file is smaller than the per-round mapping size, map it in one go; when the mapping size is not
+    // a multiple of the page size, mmap rounds it up and the extra bytes are zero-initialized
+    if (fileSize <= mapByteSize) {
+        mapByteSize = fileSize;
+    }
+
+    mapRowNum = mapByteSize / onceReadByteSize;
+}
+
+
+void LocalFileSystem::HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize,
+                                       vector<vector<float>>& dst, size_t cnt) const
+{
+#pragma omp parallel for
+    for (size_t j = 0; j < mapRowNum; ++j) {
+        size_t idx = 0;
+        size_t readSize = 0;
+        size_t dataCol = onceReadByteSize;
+        while (dataCol != 0) {
+            if (dataCol > oneTimeReadWriteLen) {
+                readSize = oneTimeReadWriteLen;
+            } else {
+                readSize = dataCol;
+            }
+
+            errno_t err = memcpy_s(dst[cnt + j].data() + idx, readSize,
+                                   mappedData + j * onceReadByteSize + idx, readSize);
+            if (err != 0) {
+                throw std::runtime_error("Error execution memcpy_s: " + std::to_string(err));
+            }
+            dataCol -= readSize;
+            idx += readSize;
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/tests/file_system/local_file_system_test.cpp b/src/tests/file_system/local_file_system_test.cpp
new file mode 100644
index 00000000..410bb63a
--- /dev/null
+++ b/src/tests/file_system/local_file_system_test.cpp
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ */
+
+#include <gtest/gtest.h>
+
+#include "file_system/file_system_handler.h"
+#include "file_system/local_file_system/local_file_system.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+
+TEST(LocalFileSystem, WriteAndReadFile)
+{
+    string filePath = "./write.data";
+    vector<int64_t> writeData = {0, 1, 2, 3, 4, 5};
+    auto fileSystemHandler = make_unique<FileSystemHandler>();
+    auto fileSystemPtr = fileSystemHandler->Create(filePath);
+    ssize_t res = fileSystemPtr->Write(filePath, reinterpret_cast<char *>(writeData.data()),
+                                       writeData.size() * sizeof(int64_t));
+
+    ASSERT_EQ(writeData.size() * sizeof(int64_t), res);
+    vector<int64_t> readData = {};
+    readData.resize(6);
+    res = fileSystemPtr->Read(filePath, reinterpret_cast<char *>(readData.data()),
+                              writeData.size() * sizeof(int64_t));
+    ASSERT_EQ(writeData.size() * sizeof(int64_t), res);
+}
+
-- Gitee

From 812d439ab9fad441c31b68efbdfe74d42a3664df Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 17 Nov 2023 17:46:08 +0800
Subject: [PATCH 459/551] Match-id-e79712cc082983c3854b2ee23bf0bb0432012ad0

---
 mx_rec/core/embedding.py     |  18 +
 mx_rec/graph/__init__.py     |   3 +-
 mx_rec/graph/acg_push_ops.py | 638 +++++++++++++++++++++++++++++++++++
 3 files changed, 658 insertions(+), 1 deletion(-)
 create mode 100644 mx_rec/graph/acg_push_ops.py

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index e3b648bb..e92c9458 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -848,6 +848,10 @@ def sparse_lookup(hashtable: SparseEmbedding,
     logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.",
                 hashtable.table_name, name, is_grad)
 
+    # Orphan ids whose upstream graph contains no IteratorGetNext must be tagged so that
+    # ACGPushOpsToDataset can find them later.
+    if isinstance(ids, tf.Tensor):
+        ids = _tag_orphan_ids(ids)
+
     with tf.compat.v1.variable_scope(scope_name):
         if isinstance(ids, FeatureSpec):
             # check whether the name of the table exists with FeatureSpec.
@@ -903,3 +907,17 @@ def set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor],
     id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1)
     embeddings = tf.where(id_offsets_expand, embeddings, default_value)
     return embeddings
+
+
+def _tag_orphan_ids(ids: tf.Tensor) -> tf.Tensor:
+    """
+    Tag orphan ids with an identity op whose name carries the ACG_PUSH_NODE prefix,
+    so that they can be located later when ops are pushed.
+    """
+    graph_def = tf.compat.v1.get_default_graph().as_graph_def()
+    subgraph = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [ids.op.name])
+    for node in subgraph.node:
+        if node.name == 'IteratorGetNext':
+            return ids
+    new_ids = tf.identity(ids, name=f"ACG_PUSH_NODE_{ids.op.name}")
+    logger.info('Tag orphan op node: %s with %s.', ids, new_ids)
+    return new_ids
diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py
index 22f0e96f..143dd645 100644
--- a/mx_rec/graph/__init__.py
+++ b/mx_rec/graph/__init__.py
@@ -2,7 +2,8 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 
-__all__ = ["modify_graph_and_start_emb_cache", "GraphModifierHook", "run"]
+__all__ = ["modify_graph_and_start_emb_cache", "GraphModifierHook", "run", "ACGPushOpsToDatasetHook"]
 
 from mx_rec.graph.modifier import GraphModifierHook, modify_graph_and_start_emb_cache
 from mx_rec.graph.patch import run
+from mx_rec.graph.acg_push_ops import ACGPushOpsToDatasetHook
diff --git a/mx_rec/graph/acg_push_ops.py b/mx_rec/graph/acg_push_ops.py
new file mode 100644
index 00000000..aa60a5d6
--- /dev/null
+++ b/mx_rec/graph/acg_push_ops.py
@@ -0,0 +1,638 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import os
+import weakref
+from dataclasses import dataclass
+from typing import Dict, Tuple, FrozenSet, List, Set
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.data.ops.dataset_ops import _VariantTracker
+from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
+from tensorflow.python.framework.ops import Operation
+from tensorflow.python.data.ops.dataset_ops import DatasetV2
+from tensorflow.python.util import nest as tf_nest
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import tensor_util
+
+from mx_rec.graph import modifier
+from mx_rec.util.log import logger
+from mx_rec.graph.utils import export_pb_graph
+from mx_rec.constants.constants import ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE
+from mx_rec.validator.validator import para_checker_decorator, ClassValidator
+
+tf.compat.v1.disable_eager_execution()
+
+_ACG_NEW_NODE_PREFIX = "ACG_"
+_ACG_NEW_ITERATOR = "ACG_NEW_ITERATOR"
+_ACG_NEW_INITIALIZER = "ACG_NEW_INITIALIZER"
+
+_OP_TYPE_TO_PUSH = frozenset(["StringSplit", "StringToNumber"])
+_OP_TYPE_TO_IGNORE = frozenset(["IteratorGetNext"])
+_OP_TYPE_CONTAIN_STRING_TO_IGNORE = frozenset(["Dataset", "Summary"])
+_OP_NAME_CONTAIN_STRING_TO_IGNORE = frozenset(["save", "report_", "loss"])
+_OP_NAME_CONTAIN_STRING_TO_PUSH = frozenset(["ACG_PUSH_NODE"])
+
+_TENSOR_TYPE_TO_IGNORE = frozenset([tf.variant, tf.resource])
+
+_VARIABLE_TYPES = frozenset(["Variable", "VariableV2", "VarHandleOp"])
+_IGNORE_REPLACE_NODE = frozenset(["Assign", "SaveV2"])
+
+
+@dataclass
+class SubgraphInfo:
+    subgraph_in: Dict[tf.Operation, Set[tf.Operation]]
+    subgraph_out: Dict[tf.Operation, Set[tf.Operation]]
+    subgraph_to_push: Set[tf.Operation]
+
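+
+# SubgraphInfo bundles what _find_subgraph_in_out discovers: for each external op,
+# subgraph_in / subgraph_out record which ops inside the extracted subgraph it feeds
+# or consumes, and subgraph_to_push is the extracted subgraph itself.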
+
+class ACGPushOpsToDatasetHook(tf.estimator.SessionRunHook):
+    @para_checker_decorator(
+        check_option_list=[
+            ("dump_graph", ClassValidator, {"classes": (bool,)}),
+        ]
+    )
+    def __init__(self, dump_graph: bool = False) -> None:
+        super().__init__()
+        self._dump_graph = dump_graph
+
+        modifier.get_src_dataset = _patched_get_src_dataset
+        logger.info("[ACGPushOpsToDatasetHook] The function `get_src_dataset` of modifier has been replaced!")
+
+    def begin(self):
+        logger.info("[ACGPushOpsToDatasetHook] Trigger at beginning!")
+        graph = tf.compat.v1.get_default_graph()
+        _find_ops_to_be_pushed(graph=graph, dump_graph=self._dump_graph)
+
+    def after_create_session(self, session, coord):
+        logger.info("[ACGPushOpsToDatasetHook] Trigger after create session!")
+        initializers = tf.compat.v1.get_collection(_ACG_NEW_INITIALIZER)
+        logger.info("[ACGPushOpsToDatasetHook] Got new initializers: %s.", initializers)
+        session.run(initializers)
+
+    def end(self, session):
+        logger.info("[ACGPushOpsToDatasetHook] Trigger in the end!")
+
+
+def _find_ops_to_be_pushed(graph: tf.Graph, dump_graph: bool = False):
+    export_pb_graph("before_push_graph.pbtxt", dump_graph, graph_def=graph.as_graph_def())
+    op_nodes = graph.get_operations()
+    nodes_to_push = set()
+
+    for op_node in op_nodes:
+        if op_node.type in _OP_TYPE_TO_IGNORE:
+            continue
+
+        pushable = False
+        if op_node.type in _OP_TYPE_TO_PUSH:
+            pushable = True
+
+        for ignore_type in _OP_TYPE_CONTAIN_STRING_TO_IGNORE:
+            if ignore_type in op_node.type:
+                pushable = False
+        for ignore_name in _OP_NAME_CONTAIN_STRING_TO_IGNORE:
+            if ignore_name in op_node.name:
+                pushable = False
+        for each_tensor in list(op_node.outputs) + list(op_node.inputs):
+            if each_tensor.dtype in _TENSOR_TYPE_TO_IGNORE:
+                pushable = False
+
+        # Explicitly tagged nodes (identity ops created by _tag_orphan_ids) are pushable
+        # regardless of the filters above; early-exiting before this name test would skip
+        # them, since their op type is not in _OP_TYPE_TO_PUSH.
+        for push_name in _OP_NAME_CONTAIN_STRING_TO_PUSH:
+            if push_name in op_node.name:
+                pushable = True
+                break
+
+        if pushable:
+            nodes_to_push.add(op_node)
+
+    if not nodes_to_push:
+        logger.info("No target op has to be pushed to dataset map func!")
+        return
+
+    logger.info("Found operations should be pushed: %s.", nodes_to_push)
+    subgraph_nodes = _find_subgraph_nodes(graph, nodes_to_push, tgt_op_type="IteratorGetNext", exclude_tgt_op=True)
+    _push_subgraph_to_dataset(graph, subgraph_nodes, dump_graph)
+    export_pb_graph("after_push_graph.pbtxt", dump_graph, graph_def=graph.as_graph_def())
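+
+# _find_subgraph_nodes walks input and control edges breadth-first from the tagged ops
+# up to (and excluding) IteratorGetNext; MAX_WHILE_SIZE bounds the walk so a malformed
+# graph cannot loop forever.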
+    subgraph_nodes.update(visited_nodes)
+    logger.info("Found subgraph from nodes_to_push: %s.", subgraph_nodes)
+    return subgraph_nodes
+
+
+def _warn_for_var_scope_nodes(all_nodes: List[tf.Operation], base_node: tf.Operation):
+    if base_node.type in _VARIABLE_TYPES:
+        variable_scope_nodes = [node for node in all_nodes if node.name.startswith(f"{base_node.name}/")]
+        logger.warning("Got base_node: %s and variable_scope_nodes: %s.", base_node, variable_scope_nodes)
+
+
+def _find_op_from_base_op(base_ops: tf.Operation, target_op_type: str) -> tf.Operation:
+    base_ops = modifier.check_input_list(base_ops, tf.Operation)
+    parent_ops = base_ops
+    while True:
+        for parent_op in parent_ops:
+            if parent_op.type == target_op_type:
+                return parent_op
+        base_ops = parent_ops
+        parent_ops = []
+        for base_op in base_ops:
+            parent_ops.extend(modifier.find_parent_op(base_op))
+        if not parent_ops:
+            raise ValueError(f"Op {target_op_type} was not found.")
+
+
+def _get_dataset_op(graph: tf.Graph, get_next_op: Operation) -> Operation:
+    if get_next_op.type != "IteratorGetNext":
+        raise TypeError(f"Op '{get_next_op}' must be one instance of IteratorGetNext.")
+    # looking for the MakeIterator operator which corresponds to given batch_tensor
+    base_op = modifier.find_make_iterator_op(get_next_op.outputs[0])
+    # looking for the op which is the one before OptimizeDataset operator
+    if tf.__version__.startswith("1"):
+        optimize_dataset_op = _find_op_from_base_op(base_op, "ModelDataset")
+        target_op = modifier.find_parent_op(optimize_dataset_op)
+        if not target_op:
+            raise RuntimeError("The parent op for 'ModelDataset' op was not found.")
+        if target_op[0].type != "OptimizeDataset":
+            raise TypeError("Op OptimizeDataset was not found.")
+        target_op = target_op[0]
+    else:
+        # 'OptimizeDataset' is not available in TensorFlow 2.x
+        raise RuntimeError("TensorFlow 2.x is not supported.")
+    return target_op
+
+
+def _ordered_output_from_subgraph(subgraph_out: Dict[tf.Operation, Set[tf.Operation]]) -> List[tf.Tensor]:
+    addition_funcgraph_output_tensor = []
+    for k, v in sorted(subgraph_out.items(), key=lambda x: x[0].name):
+        k_inputs = set(k.inputs)
+        for node in v:
+            _add_sorted_additional_tensors(addition_funcgraph_output_tensor, k_inputs, node)
+    return addition_funcgraph_output_tensor
+
+
+def _add_sorted_additional_tensors(addition_funcgraph_output_tensor, k_inputs, node):
+    for each_tensor in sorted(node.outputs, key=lambda x: x.name):
+        if each_tensor in k_inputs:
+            addition_funcgraph_output_tensor.append(each_tensor)
+
+
+def _get_tensor_consumers_unsafe(tensor: tf.Tensor) -> List[tf.Operation]:
+    if isinstance(tensor, tf.Operation):
+        raise RuntimeError(f"Unsupported input type, expected a tf.Tensor but got: {tensor}.")
+
+    from tensorflow.python import pywrap_tensorflow as c_api
+
+    consumer_names = c_api.TF_OperationOutputConsumers_wrapper(tensor._as_tf_output())
+    graph = tensor.graph
+    result = []
+    for name in consumer_names:
+        with graph._lock:
+            if name in graph._nodes_by_name:  # ignore deleted node
+                result.append(graph._nodes_by_name[name])
+
+    return result
+
+
+def _push_subgraph_to_dataset(graph: tf.Graph, subgraph_to_push: Set[tf.Operation], dump_graph: bool = False):
+    subgraph_in, subgraph_out = _find_subgraph_in_out(subgraph_to_push)
+    logger.info("Got input tensor of extracted subgraph: %s", subgraph_in)
+    logger.info("Got output tensor of extracted subgraph: %s", subgraph_out)
+
+    get_next_node = graph.get_operation_by_name("IteratorGetNext")
+    src_dataset = _get_src_dataset(graph, get_next_node)
+
+    def acg_func(*x):
+        old_x = x
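+        # This map function runs inside the dataset's FuncGraph: `x` arrives as the
+        # structure produced by the source dataset, is flattened to plain tensors
+        # below, the extracted subgraph is cloned into the FuncGraph, and the cloned
+        # subgraph's outputs are appended to the original batch.
+        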
logger.debug("Got old batch layout: %s", x) + + x = tf_nest.flatten(x) + for each_tensor in x: + if not isinstance(each_tensor, tf.Tensor): + raise RuntimeError(f"Expected tensor as input of mapfunc. but got: {x}!") + + funcgraph = tf.compat.v1.get_default_graph() + subgraph_info = SubgraphInfo(subgraph_in, subgraph_out, subgraph_to_push) + new_batch = _clone_subgraph_into_funcgraph( + funcgraph, + graph, + subgraph_info, + x, + old_x, + ) + + logger.debug("Got new batch layout: %s.", new_batch) + export_pb_graph("map_func_graph.pbtxt", dump_graph, graph_def=funcgraph.as_graph_def()) + return new_batch + + tgt_dataset = src_dataset.map(acg_func) + tgt_dataset = tgt_dataset.prefetch(0) + _update_iterator_getnext( + graph=graph, + get_next_op=get_next_node, + tgt_dataset=tgt_dataset, + subgraph_out=subgraph_out, + subgraph_to_push=subgraph_to_push, + ) + + +def _find_subgraph_in_out( + sub_graph_nodes: Set[tf.Operation], +) -> Tuple[Dict[tf.Operation, Set[tf.Operation]], Dict[tf.Operation, Set[tf.Operation]]]: + relay_input_nodes = set() + relay_output_nodes = set() + input_to_subnodes = dict() + output_to_subnodes = dict() + + for base_node in sub_graph_nodes: + _update_subgraph_in(base_node, input_to_subnodes, relay_input_nodes, sub_graph_nodes) + _update_subgraph_out(base_node, output_to_subnodes, relay_output_nodes, sub_graph_nodes) + + return input_to_subnodes, output_to_subnodes + + +def _update_subgraph_in( + base_node: tf.Operation, + input_to_subnodes: Dict[tf.Operation, Set[tf.Operation]], + relay_input_nodes: Set[tf.Operation], + sub_graph_nodes: Set[tf.Operation], +): + for input_tensor in base_node.inputs: + input_node = input_tensor.op + if input_node not in sub_graph_nodes: + relay_input_nodes.add(input_node) + res = input_to_subnodes.get(input_node, set()) + res.add(base_node) + input_to_subnodes[input_node] = res + + +def _update_subgraph_out( + base_node: tf.Operation, + output_to_subnodes: Dict[tf.Operation, Set[tf.Operation]], + relay_output_nodes: Set[tf.Operation], + sub_graph_nodes: Set[tf.Operation], +): + for output_tensor in base_node.outputs: + for output_consumer in output_tensor.consumers(): + if output_consumer not in sub_graph_nodes: + relay_output_nodes.add(output_consumer) + res = output_to_subnodes.get(output_consumer, set()) + res.add(base_node) + output_to_subnodes[output_consumer] = res + + +def _get_src_dataset(graph: tf.Graph, get_next_op: Operation) -> DatasetV1Adapter: + try: + target_op = _get_dataset_op(graph, get_next_op) + except (ValueError, TypeError, RuntimeError) as err: + logger.warning("The dataset op was not found, the error is %s. Start to traverse the operations.", err) + dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] + if len(dataset_op_list) != 1: + raise RuntimeError( + f"The `{ANCHOR_DATASET_NAME}` was not found from the operations, dataset_op_list: " + f"{dataset_op_list}." 
+ ) from err + target_op = dataset_op_list[0] + except Exception as err: + raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err + if not target_op.outputs: + raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") + logger.info("Find target op `%s`, and output is `%s`.", target_op.name, target_op.outputs) + src_dataset = modifier.find_target_instance_dataset(target_op.outputs[0]) + return src_dataset + + +def _clone_subgraph_into_funcgraph( + funcgraph: tf.Graph, + defaultgraph: tf.Graph, + subgraph_info: SubgraphInfo, + x: List[tf.Tensor], + old_x: Tuple[Dict[str, tf.Tensor]], +) -> Dict[str, tf.Tensor]: + topo_subgraph_list = _topo_subgraph(subgraph_info.subgraph_to_push) # node + tensor_mapping = {} # subgraph-tensor -> funcgraph-tensor + node_mapping = {} # subgraph-node -> funcgraph-node + for k, v in subgraph_info.subgraph_in.items(): + _get_mapping_for_subgraph_in(k, v, x, tensor_mapping) + for old_node in topo_subgraph_list: + _get_mapping_for_subgraph(funcgraph, defaultgraph, node_mapping, old_node, tensor_mapping) + + logger.info("Got node_mapping: %s", node_mapping) + logger.info("Got tensor_mapping: %s", tensor_mapping) + + ordered_output_subgraph_tensors = _ordered_output_from_subgraph(subgraph_info.subgraph_out) + addition_funcgraph_output_tensor = _get_mapping_tensor(tensor_mapping, ordered_output_subgraph_tensors) + new_funcgraph_output_tensor = list(x) + addition_funcgraph_output_tensor + logger.info("Got new_funcgraph_output_tensor: %s", new_funcgraph_output_tensor) + + new_x = old_x[0] + for tensor in addition_funcgraph_output_tensor: + last_key = f"{sorted(new_x)[-1]}_last_key" + new_x[last_key] = tensor + + return new_x + + +def _get_mapping_for_subgraph_in( + from_node: tf.Operation, to_nodes: Set[tf.Operation], x: List[tf.Tensor], tensor_mapping +): + if from_node.type != "IteratorGetNext": + raise RuntimeError(f"Expect IteratorGetNext for input tensor of subgraph, but got {from_node}") + for node in to_nodes: + for each_tensor in node.inputs: + if each_tensor.op.type != "IteratorGetNext": + continue + old_tensor_name = each_tensor.name + x_index = int(old_tensor_name.split(":")[-1]) + tensor_mapping[each_tensor] = x[x_index] + + +def _get_mapping_for_subgraph( + funcgraph: tf.Graph, + defaultgraph: tf.Graph, + node_mapping: Dict[tf.Operation, tf.Operation], + old_node: tf.Operation, + tensor_mapping: Dict[tf.Tensor, tf.Tensor], +): + logger.debug("old_node: %s \n old_node_inputs: %s", old_node, [x for x in old_node.inputs]) + node_def = old_node.node_def + for each_tensor in old_node.inputs: + if each_tensor not in tensor_mapping: + raise RuntimeError( + f"each_tensor(input) {each_tensor} need by {old_node.name} not in tensor_mapping.{tensor_mapping}" + ) + new_inputs = _get_mapping_tensor(tensor_mapping, old_node.inputs) + if old_node.type in _VARIABLE_TYPES: + node_def = _frozen_variable_node_to_func_const_node_def( + variable_node=old_node, funcgraph=funcgraph, defaultgraph=defaultgraph + ) + node_def.name = _ACG_NEW_NODE_PREFIX + node_def.name + new_node = tf.Operation(node_def=node_def, g=funcgraph, inputs=new_inputs) + node_mapping[old_node] = new_node + for old_out_tensor, new_out_tensor in zip(old_node.outputs, new_node.outputs): + tensor_mapping[old_out_tensor] = new_out_tensor + + +def _frozen_variable_node_to_func_const_node_def( + variable_node: tf.Operation, funcgraph: tf.Graph, defaultgraph: tf.Graph +) -> node_def_pb2.NodeDef: + def create_const_node_def(node_name, dtype, data, data_shape=None): + 
"""Creates a Const op.""" + output_node = node_def_pb2.NodeDef() + output_node.op = "Const" + output_node.name = node_name + output_node.attr["dtype"].CopyFrom(dtype) + output_node.attr["value"].CopyFrom( + attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data, dtype=dtype.type, shape=data_shape)) + ) + return output_node + + # NOTE: Variable node type is readonly in funcgraph, all nodes of this type have to be fronzen. + variable_name = variable_node.name + if variable_node.type == "VarHandleOp": + variable_name = f"{variable_name}/Read/ReadVariableOp:0" + else: + variable_name = f"{variable_name}:0" + initializer = defaultgraph.get_operation_by_name(f"{variable_node.name}/Assign") + logger.info(f"VariableV2: {variable_node.name}, initializer: {initializer.name} ") + defaultsession = tf.compat.v1.Session(graph=defaultgraph) + _ = defaultsession.run([initializer]) + logger.info(f"Start run variables data: {variable_name}") + returned_variable_data = defaultsession.run(variable_name) + logger.info(f"Start froze variables: {variable_name} {returned_variable_data}") + new_const_node = create_const_node_def( + variable_node.name, variable_node.node_def.attr["dtype"], returned_variable_data, returned_variable_data.shape + ) + return new_const_node + + +def _get_mapping_tensor(tsr2tsr: Dict[tf.Tensor, tf.Tensor], keys: List[tf.Tensor]) -> List[tf.Tensor]: + tensors = [] + for k in keys: + if k not in tsr2tsr: + raise KeyError(f"Failed to find key tensor: {k} from tensor map: {tsr2tsr}.") + tensors.append(tsr2tsr[k]) + return tensors + + +def _topo_subgraph(subgraph: Set[tf.Operation]) -> List[tf.Operation]: + topo_subgraph_list = [] + topo_subgraph_set = set() + start_nodes = set() + [start_nodes.add(x) for x in subgraph] + logger.info("Got topo_subgraph start nodes: %s", start_nodes) + + def topo_subgraph_dfs(curr_node, output_list, output_set): + if not isinstance(curr_node, tf.Operation): + raise RuntimeError(f"topo_subgraph_dfs input should be node(aka. tf.Operator). {curr_node}") + curr_inputs = curr_node.inputs + logger.debug("Got topo_dfs: %s <- %s", curr_node.name, [x.name for x in curr_inputs]) + current_control_inputs = curr_node.control_inputs + if len(current_control_inputs) > 0: + raise RuntimeError( + f"Control input are not supported: {curr_node.name}, control_inputs: {current_control_inputs}" + ) + if curr_node in output_set: + return + output_set.add(curr_node) + for tensor in curr_inputs: + node = tensor.op + if node.type != "IteratorGetNext" and node not in output_set: + topo_subgraph_dfs(node, output_list, output_set) + output_list.append(curr_node) + + [topo_subgraph_dfs(x, topo_subgraph_list, topo_subgraph_set) for x in start_nodes] + if len(topo_subgraph_list) != len(topo_subgraph_set): + raise RuntimeError(f"Got duplicated topo node: {sorted(topo_subgraph_list, key=lambda x: x.name)}.") + logger.info("Got topo_subgraph: %s", topo_subgraph_list) + return topo_subgraph_list + + +def _update_iterator_getnext( + graph: tf.Graph, + get_next_op: Operation, + tgt_dataset: DatasetV1Adapter, + subgraph_out: Dict[tf.Operation, Set[tf.Operation]], + subgraph_to_push: Set[tf.Operation], +): + if not get_next_op.outputs: + raise RuntimeError("There is no tensor in the dataset. 
Please check the dataset and data processing.") + iterator_type = "" + if get_next_op.inputs: + iterator_type = get_next_op.inputs[0].op.type + if iterator_type == "IteratorV2": + iterator_type = modifier.find_make_iterator_op(get_next_op.outputs[0]).type + if iterator_type not in ("MakeIterator", "OneShotIterator"): + raise RuntimeError( + f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " + f"but the current iterator is `{iterator_type}`." + ) + logger.info("The iterator type of dataset is %s.", iterator_type) + if iterator_type == "MakeIterator": + new_iterator = tgt_dataset.make_initializable_iterator() + logger.info("Got new_iterator: %s, new_iterator.initializer: %s.", new_iterator, new_iterator.initializer) + graph.add_to_collection(_ACG_NEW_INITIALIZER, new_iterator.initializer) + else: + new_iterator = tgt_dataset.make_one_shot_iterator() + new_batch = new_iterator.get_next(_ACG_NEW_ITERATOR) + if "timestamp" in new_batch.keys(): + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, new_batch["timestamp"]) + try: + new_batch_tensor = new_batch + while not isinstance(new_batch_tensor, tf.Tensor): + if isinstance(new_batch_tensor, tuple): + new_batch_tensor = new_batch_tensor[0] + elif isinstance(new_batch_tensor, dict): + new_batch_tensor = list(new_batch_tensor.values()) + elif isinstance(new_batch_tensor, list): + new_batch_tensor = new_batch_tensor[0] + elif isinstance(new_batch_tensor, tf.Tensor): + break + else: + raise RuntimeError( + f"Need to support new_batch_tensor{new_batch_tensor}, type: {type(new_batch_tensor)}" + ) + except IndexError as err: + raise IndexError("Cannot find a tensor from given batch.") from err + new_get_next_op = _find_op_from_base_op(new_batch_tensor.op, "IteratorGetNext") + logger.info("Got new_get_next_op: %s.", new_get_next_op) + _replace_get_next_op(graph, get_next_op, new_get_next_op, subgraph_out, subgraph_to_push) + + +def _replace_get_next_op( + graph: tf.Graph, + old_get_next_op: tf.Operation, + new_get_next_op: tf.Operation, + subgraph_out: Dict[tf.Operation, Set[tf.Operation]], + subgraph_to_push: Set[tf.Operation], +): + for output_tensor in old_get_next_op.outputs: + _update_old_consumer(graph, new_get_next_op, output_tensor, subgraph_to_push) + + old_get_next_op_output_size = len(old_get_next_op.outputs) + ordered_output_tensor = _ordered_output_from_subgraph(subgraph_out) + + for i, output_tensor in enumerate(ordered_output_tensor): + offset = old_get_next_op_output_size + i + _update_subgraph_out_consumer(graph, new_get_next_op, offset, output_tensor) + + +def _update_old_consumer( + graph: tf.Graph, new_get_next_op: tf.Operation, output_tensor: tf.Tensor, subgraph_to_push: List[tf.Operation] +): + old_tensor_name = output_tensor.name + output_index = old_tensor_name.split(":")[-1] + new_tensor_name = f"{new_get_next_op.name}:{output_index}" + logger.info("Replace old_tensor_name: %s to new_tensor_name: %s", old_tensor_name, new_tensor_name) + new_tensor = graph.get_tensor_by_name(new_tensor_name) + for output_consumer in _get_tensor_consumers_unsafe(output_tensor): + if output_consumer in subgraph_to_push: + logger.info( + "Ignore consumer in old subgraph %s, not let it connect to new IteratorGetNext.", output_consumer + ) + continue + for i, consumer_input in enumerate(output_consumer.inputs): + if consumer_input != output_tensor: + logger.debug("Not replace output_consumer: %s consumer_input: %s.", output_consumer, consumer_input) + continue + logger.info( + "Success replace output_consumer: 
%s type: %s from consumer_input: %s to new_tensor: %s", + output_consumer.name, + output_consumer.type, + consumer_input, + new_tensor, + ) + output_consumer._update_input(i, new_tensor) + + +def _update_subgraph_out_consumer( + graph: tf.Graph, new_get_next_op: tf.Operation, offset: int, output_tensor: tf.Tensor +): + new_tensor_name = f"{new_get_next_op.name}:{offset}" + logger.info("Replace old_tensor_name: %s to new_tensor_name: %s.", output_tensor.name, new_tensor_name) + new_tensor = graph.get_tensor_by_name(new_tensor_name) + for output_consumer in _get_tensor_consumers_unsafe(output_tensor): + if output_consumer.type in _IGNORE_REPLACE_NODE: + logger.info("Ignore replace output_consumer: %s, it's of type: %s.", output_consumer, output_consumer.type) + continue + for j, consumer_input in enumerate(output_consumer.inputs): + if consumer_input != output_tensor: + logger.debug("Not replace output_consumer: %s consumer_input: %s.", output_consumer, consumer_input) + continue + logger.info( + "Success replace output_consumer: %s type: %s from consumer_input: %s to new_tensor: %s", + output_consumer.name, + output_consumer.type, + consumer_input, + new_tensor, + ) + output_consumer._update_input(j, new_tensor) + + +def _patched_get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapter: + try: + target_op = modifier.get_dataset_op(get_next_op) + except (ValueError, TypeError, RuntimeError) as err: + logger.debug("In `OneShotIterator` mode, find `PrefetchDataset` from all ops in graph.") + graph = tf.compat.v1.get_default_graph() + dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name] + dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name) + logger.debug("Got sorted dataset_op_list: %s.", dataset_op_list) + if len(dataset_op_list) != 2: + raise RuntimeError( + f"Expect two `PrefetchDataset` ops in dataset_op_list, but got: {dataset_op_list}." 
+ ) from err + target_op = dataset_op_list[1] + except Exception as err: + raise RuntimeError(f"The source dataset can't be found, got error: {err}.") from err + + if not target_op.outputs: + raise ValueError(f"The length of the outputs of target op `{target_op}` is 0.") + + logger.debug("Find target dataset op: %s, and output is %s.", target_op, target_op.outputs) + src_dataset = modifier.find_target_instance_dataset(target_op.outputs[0]) + + return src_dataset -- Gitee From eb680b603e5e51583ed3a28e4015248e67f0495e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 17 Nov 2023 14:40:24 +0800 Subject: [PATCH 460/551] Match-id-1ac5ccf7408c18e1d42f3a5830281b23f51cb991 --- mx_rec/constants/constants.py | 1 + mx_rec/saver/saver.py | 10 +- src/core/checkpoint/checkpoint.cpp | 109 ++++++++++-------- src/core/checkpoint/checkpoint.h | 6 - .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 5 +- src/core/file_system/file_system.h | 2 +- .../local_file_system/local_file_system.cpp | 1 - .../local_file_system/local_file_system.h | 1 + src/core/hybrid_mgmt/hybrid_mgmt.cpp | 2 +- src/tests/checkpoint/checkpoint_test.cpp | 19 ++- 10 files changed, 87 insertions(+), 69 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 59b1ef27..5b555b2e 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -72,6 +72,7 @@ MAX_DEVICE_ID = 15 # HDFS file system's file prefix HDFS_FILE_PREFIX = ["viewfs://", "hdfs://"] + class BaseEnum(Enum): @classmethod def mapping(cls, key): diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 80fb2a52..a8a14037 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -15,7 +15,8 @@ from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SI MAX_INT32, HDFS_FILE_PREFIX from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ - send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size, get_use_dynamic_expansion + send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size, \ + get_use_dynamic_expansion from mx_rec.util.perf import performance from mx_rec.validator.validator import DirectoryValidator, FileValidator, para_checker_decorator, ClassValidator, \ IntValidator, OptionalStringValidator @@ -210,7 +211,7 @@ class Saver(object): host_data = get_host_data(table_name) offset = list(host_data) - get_valid_dict_data(dump_data_dict, offset) + get_valid_dict_data_from_host_offset(dump_data_dict, offset) def _build_save(self): for var in self.var_list: @@ -301,7 +302,7 @@ class NameDescriptor: self.optimizer_name = optimizer_name -def get_valid_dict_data(dump_data_dict: dict, offset: list): +def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list): """ Extract embedding and optimizer data from the dict based on offset. 
:param dump_data_dict: sparse data dict to be saved @@ -502,8 +503,9 @@ def check_file_system_is_valid(file_path): return True return False + def check_file_system_is_hdfs(file_path): for prefix in HDFS_FILE_PREFIX: - if file_path.startwith(prefix): + if file_path.startswith(prefix): return True return False diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 57d17f24..0f6b039e 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -150,6 +150,10 @@ void Checkpoint::MakeDataLayerSaveDir(const vector& embNames, void Checkpoint::MakeSaveDir(const string& dirName) const { + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. "); + } fileSystemPtr->CreateDir(dirName); } @@ -199,7 +203,7 @@ void Checkpoint::SaveDataset(const vector& embNames, auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; auto embeddingSizeInfo = GetEmbeddingSize(embName); MakeSaveDir(embedPath); - LOG_DEBUG("====Start saving embedding data to: {}", datasetDir); + LOG_DEBUG("====Start saving embedding data to: {}", embedPath); WriteEmbedding(transData, embedDatasetDir, embeddingSizeInfo.extEmbSize); } @@ -214,16 +218,24 @@ void Checkpoint::SaveDataset(const vector& embNames, void Checkpoint::WriteEmbedding(const CkptTransData& transData, const string& dataDir, const int& embeddingSize) { auto &transArr = transData.addressArr; + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. "); + } fileSystemPtr->WriteEmbedding(dataDir, embeddingSize, transArr, deviceId); } void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, const string& embName) { + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. "); + } + auto datasetSize = fileSystemPtr->GetFileSize(dataDir); auto &attributeArr = transData.attribute; auto embHashMapSize = attributeArr.at(0); if (embHashMapSize <= 0) { - throw runtime_error(StringFormat("Invalid EmbHashMapSize:%d, must be greater than 0", embHashMapSize).c_str()); } @@ -233,26 +245,36 @@ void Checkpoint::ReadEmbedding(CkptTransData& transData, const string& dataDir, EmbSizeInfo embSizeInfo = GetEmbeddingSize(embName); if (embeddingSize != embSizeInfo.extEmbSize) { throw runtime_error(StringFormat("Invalid embedding size to be read, may read file has been changed").c_str()); - } fileSystemPtr->ReadEmbedding(dataDir, embeddingSize, transArr, deviceId); - } void Checkpoint::WriteStream(CkptTransData& transData, const string& dataDir, size_t dataSize, CkptDataType dataType) { - LOG_ERROR("lff debug write stream {} dataset_size {}", dataType, dataSize); + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. 
"); + } + + ssize_t writeBytesNum; if (floatTransSet.find(dataType) != floatTransSet.end()) { - fileSystemPtr->Write(dataDir, transData.floatArr, dataSize); + writeBytesNum = fileSystemPtr->Write(dataDir, transData.floatArr, dataSize); } else if (int32TransSet.find(dataType) != int32TransSet.end()) { - fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int32Arr.data()), dataSize); + writeBytesNum = fileSystemPtr->Write(dataDir, + reinterpret_cast(transData.int32Arr.data()), dataSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - fileSystemPtr->Write(dataDir, reinterpret_cast(transData.int64Arr.data()), dataSize); + writeBytesNum = fileSystemPtr->Write(dataDir, + reinterpret_cast(transData.int64Arr.data()), dataSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - fileSystemPtr->Write(dataDir, reinterpret_cast(transData.attribute.data()), dataSize); + writeBytesNum = fileSystemPtr->Write(dataDir, + reinterpret_cast(transData.attribute.data()), dataSize); + } + + if (writeBytesNum == -1) { + LOG_ERROR("error happened when writing data to file."); + throw runtime_error("error happened when writing data to file."); } - LOG_ERROR("lff debug write stream {} dataset_dir {} over!", dataType, dataDir); } @@ -297,6 +319,10 @@ vector Checkpoint::GetEmbedTableNames() vector Checkpoint::GetTableLayerLoadDir() { vector loadTableDir; + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. "); + } loadTableDir = fileSystemPtr->ListDir(innerDirPath); return loadTableDir; } @@ -356,6 +382,12 @@ void Checkpoint::ReadStream(CkptTransData& transData, LOG_WARN("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } + + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. 
"); + } + size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); auto resizeSize { datasetSize / dataElmtBytes }; SetTransDataSize(transData, resizeSize, dataType); @@ -363,32 +395,24 @@ void Checkpoint::ReadStream(CkptTransData& transData, if (datasetSize % dataElmtBytes > 0) { LOG_DEBUG("data is missing or incomplete in load file: {}", dataDir); } - LOG_ERROR("lff debug start to enter read type:{}, size:{}", dataType, datasetSize); - + ssize_t readBytesNum; if (int32TransSet.find(dataType) != int32TransSet.end()) { - fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int32Arr.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int32Arr.data()), datasetSize); } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int64Arr.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.int64Arr.data()), datasetSize); } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - fileSystemPtr->Read(dataDir, reinterpret_cast(transData.floatArr.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.floatArr.data()), datasetSize); } else if (dataType == CkptDataType::ATTRIBUTE) { - fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); + readBytesNum = fileSystemPtr->Read(dataDir, reinterpret_cast(transData.attribute.data()), datasetSize); } -} - -void Checkpoint::ValidateFile(int fd, const string& dataDir, size_t datasetSize) const -{ - try { - ValidateReadFile(dataDir, datasetSize); - } catch (const std::invalid_argument& e) { - close(fd); - throw runtime_error(StringFormat("Invalid read file path: %s", e.what())); + if (readBytesNum == -1) { + LOG_ERROR("error happened when reading data from file."); + throw runtime_error("error happened when reading data from file."); } } - void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, @@ -399,17 +423,21 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, LOG_ERROR("dataElmtBytes is 0, don't handle [/ %] operation"); return ; } + + if (fileSystemPtr == nullptr) { + LOG_WARN("please init file system pointer before using. "); + throw runtime_error("Nullptr. file system pointer is not initialized. 
"); + } + auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); - LOG_ERROR("1102debug read emb data :{}", embDataOuterSize); + if (embDataOuterSize <= 0 || embDataOuterSize > MAX_VOCABULARY_SIZE) { throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str()); } size_t datasetSize = fileSystemPtr->GetFileSize(dataDir); - if (datasetSize % embDataOuterSize > 0 || datasetSize % dataElmtBytes > 0) { LOG_ERROR("data is missing or incomplete in load file: {}", dataDir); - throw runtime_error("unable to load EMB_DATA cause wrong-format saved emb data"); } @@ -417,8 +445,12 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, auto& dst = (*loadHostEmbs)[embName].embData; dst.reserve(embDataOuterSize); + ssize_t readBytesNum; fileSystemPtr->Read(dataDir, dst, datasetSize); - + if (readBytesNum == -1) { + LOG_ERROR("error happened when reading data from file."); + throw runtime_error("error happened when reading data from file."); + } } void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType) @@ -433,20 +465,3 @@ void Checkpoint::SetTransDataSize(CkptTransData& transData, size_t datasetSize, transData.attribute.resize(datasetSize); } } - -void Checkpoint::ReadDataset(CkptTransData& transData, - ifstream& readFile, - size_t readSize, - CkptDataType dataType, - size_t idx) -{ - if (int32TransSet.find(dataType) != int32TransSet.end()) { - readFile.read(reinterpret_cast(transData.int32Arr.data()) + idx, readSize); - } else if (int64TransSet.find(dataType) != int64TransSet.end()) { - readFile.read(reinterpret_cast(transData.int64Arr.data()) + idx, readSize); - } else if (floatTransSet.find(dataType) != floatTransSet.end()) { - readFile.read(reinterpret_cast(transData.floatArr.data()) + idx, readSize); - } else if (dataType == CkptDataType::ATTRIBUTE) { - readFile.read(reinterpret_cast(transData.attribute.data()) + idx, readSize); - } -} diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 1be72faa..94ca7e68 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -105,16 +105,10 @@ namespace MxRec { void LoadDataset(const vector& embNames, const vector& saveDataTypes, const unique_ptr& dataHandler, CkptData& ckptData); void ReadStream(CkptTransData& transData, const string& dataDir, CkptDataType dataType, uint32_t dataElmtBytes); - void ValidateFile(int fd, const string& dataDir, size_t datasetSize) const; - void HandleMappedData(char* mappedData, size_t mapRowNum, size_t onceReadByteSize, - vector>& dst, size_t cnt) const; - void CalculateMapSize(off_t fileSize, size_t& mapByteSize, size_t& mapRowNum, size_t onceReadByteSize) const; void ReadStreamForEmbData(CkptTransData& transData, const string& dataDir, uint32_t dataElmtBytes, CkptData& ckptData, string embName) const; void SetTransDataSize(CkptTransData& transData, size_t datasetSize, CkptDataType dataType); - void ReadDataset(CkptTransData& transData, ifstream& readFile, size_t readSize, CkptDataType dataType, - size_t idx); }; } diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp index e6698746..66eb4ecc 100644 --- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp +++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp @@ -66,11 +66,10 @@ CkptTransData NddrFeatMapCkpt::GetDataset(CkptDataType dataType, string embName) 
transArr.reserve(embHashMapSize); (*offsetMapPtr)[embName].clear(); - LOG_ERROR("build offset map : first key offset {}", saveKeyOffsetMap[embName][0]); for (const auto& it : saveKeyOffsetMap.at(embName)) { transArr.push_back(it.first); - transArr.push_back(it.second); (*offsetMapPtr)[embName].push_back(it.second); + addressArr.push_back(it.second); } LOG_INFO("CkptDataType::EMB_INFO:{}, dataType:{}", CkptDataType::NDDR_FEATMAP, dataType); return move(transferData); @@ -93,7 +92,7 @@ void NddrFeatMapCkpt::SetDataset(CkptDataType dataType, string embName, CkptTran if (addressArr.size() == 0) { // no dynamic expansion hostHashMap[key] = offset; - } else{ + } else { // dynamic expansion hostHashMap[key] = addressArr.at(i); } diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h index 559d8442..ab6a2204 100644 --- a/src/core/file_system/file_system.h +++ b/src/core/file_system/file_system.h @@ -8,7 +8,7 @@ #ifndef MX_REC_FILE_SYSTEM_H #define MX_REC_FILE_SYSTEM_H -#include "checkpoint/buffer_queue.h" +#include "utils/common.h" namespace MxRec { using namespace std; diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 51648fa4..493440a4 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -14,7 +14,6 @@ #include #include -#include "checkpoint/buffer_queue.h" #include "utils/common.h" using namespace std; diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h index d6346a1d..e951dd8f 100644 --- a/src/core/file_system/local_file_system/local_file_system.h +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -9,6 +9,7 @@ #define MX_REC_LOCAL_FILE_SYSTEM_H #include "file_system/file_system.h" +#include "file_system/buffer_queue.h" namespace MxRec { using namespace std; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 48f274a2..e678ce59 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -317,6 +317,7 @@ bool HybridMgmt::Load(const string& loadPath) // HBM模式 将加载的最大偏移(真正使用了多少vocab容量)、特征到偏移的映射,进行赋值 LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap"); preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap); + preprocess->LoadMaxOffset(loadData.maxOffset); } // 将加载的特征准入淘汰记录进行赋值 @@ -376,7 +377,6 @@ OffsetT HybridMgmt::SendHostMap(const string tableName) OffsetT OffsetMap; // 先校验这个map是不是空的 if ((!offsetMapToSend.empty()) && offsetMapToSend.count(tableName) > 0) { - LOG_ERROR("send offset map : table name =={} offset count {}", tableName.c_str(), offsetMapToSend.count(tableName)); LOG_ERROR("send offset map : first key offset {}", offsetMapToSend[tableName][0]); for (auto& it : offsetMapToSend.at(tableName)) { OffsetMap.push_back(it); diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 8bd950df..6681f055 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -36,6 +36,7 @@ protected: float floatMem { MEM_INIT_VALUE }; int64_t featMem { static_cast(UINT32_MAX) }; int32_t offsetMem { 0 }; + int32_t maxOffsetMem { 16 }; string name { "table" }; int sendCount { 8 }; @@ -153,7 +154,7 @@ protected: void SetMaxOffset(OffsetMemT& testMaxOffset) { for (const auto& testEmbInfo : testEmbInfos) { - 
testMaxOffset[testEmbInfo.name] = offsetMem; + testMaxOffset[testEmbInfo.name] = maxOffsetMem; } } @@ -407,8 +408,11 @@ TEST_F(CheckpointTest, KeyOffsetMaps) EXPECT_EQ(validLoadData.keyOffsetMap.size(), testLoadData.keyOffsetMap.size()); for (const auto& it : validLoadData.keyOffsetMap) { EXPECT_EQ(1, testLoadData.keyOffsetMap.count(it.first)); - const auto& maxOffset = testLoadData.keyOffsetMap.at(it.first); - EXPECT_EQ(it.second, maxOffset); + const auto& keyOffsetMap = testLoadData.keyOffsetMap.at(it.first); + const auto& validKeyOffsetMap = validLoadData.keyOffsetMap.at(it.first); + for (const auto& key: keyOffsetMap) { + EXPECT_EQ(validKeyOffsetMap.count(key.first), 1); + } } } @@ -440,7 +444,7 @@ TEST_F(CheckpointTest, AllMgmt) testLoadData, rankInfo, testEmbInfos, - { CkptFeatureType::MAX_OFFSET, CkptFeatureType::KEY_OFFSET_MAP }); + {CkptFeatureType::KEY_OFFSET_MAP }); EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); for (const auto& it : validLoadData.maxOffset) { @@ -454,9 +458,12 @@ TEST_F(CheckpointTest, AllMgmt) EXPECT_EQ(validLoadData.keyOffsetMap.size(), testLoadData.keyOffsetMap.size()); for (const auto& it : validLoadData.keyOffsetMap) { EXPECT_EQ(1, testLoadData.keyOffsetMap.count(it.first)); + const auto& keyOffsetMap = testLoadData.keyOffsetMap.at(it.first); + const auto& validKeyOffsetMap = validLoadData.keyOffsetMap.at(it.first); + for (const auto& key: keyOffsetMap) { + EXPECT_EQ(validKeyOffsetMap.count(key.first), 1); + } - const auto& maxOffset = testLoadData.keyOffsetMap.at(it.first); - EXPECT_EQ(it.second, maxOffset); } } -- Gitee From 33804150f8ef706276cc30b9e6fccaba80785e27 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 17 Nov 2023 10:26:20 +0800 Subject: [PATCH 461/551] Match-id-606839df98c87a4fdc38e165cc52f05eaf5db62b --- src/core/hd_transfer/hd_transfer.cpp | 5 + src/core/hd_transfer/hd_transfer.h | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 147 +++++++++++------- src/core/hybrid_mgmt/hybrid_mgmt.h | 25 ++- .../key_process/feature_admit_and_evict.cpp | 4 +- .../key_process/feature_admit_and_evict.h | 1 + src/core/key_process/key_process.cpp | 137 +++++++++++----- src/core/key_process/key_process.h | 14 +- src/tests/key_process/key_process_test.cpp | 20 ++- 9 files changed, 242 insertions(+), 115 deletions(-) diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index e92537dc..9c569df2 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -207,3 +207,8 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string& return acltdtGetDatasetSize(aclDatasets[embName]); #endif } + +std::unordered_map HDTransfer::GetTransChannel() +{ + return transferChannels; +} \ No newline at end of file diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index b7f9ea38..6c046da0 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -83,10 +83,10 @@ namespace MxRec { void Destroy(); + std::unordered_map GetTransChannel(); + private: -#ifndef GTEST std::unordered_map transferChannels; -#endif bool running; void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum); }; diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 91b9b7bb..b9f567b3 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -246,6 +246,7 @@ bool HybridMgmt::Save(const string savePath) } if 
(isSSDEnabled) { + LOG_DEBUG(MGMT + "Start host side save: ssd mode hashmap"); for (auto& it : cacheManager->ddrKeyFreqMap) { saveData.ddrKeyFreqMaps[it.first] = it.second.GetFreqTable(); } @@ -616,60 +617,48 @@ void HybridMgmt::EvalTask(TaskType type) /// \return bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) { - LOG_INFO(MGMT + "start parse keys HBM, nBatch:{} , [{}]:{}", mgmtRankInfo.nBatch, channelId, batchId); + LOG_INFO(MGMT + "nBatch:{} channelId:{} batchId:{}, ParseKeys with HBM mode start.", + mgmtRankInfo.nBatch, channelId, batchId); // 循环处理每个表的数据 for (const auto& embInfo: mgmtEmbInfo) { TimeCost parseKeysTc; - // get - TimeCost getTensorsSyncTC; - // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { - LOG_INFO(MGMT + "ParseKeys infoVecs empty ! batchId:{}, channelId:{}", batchId, channelId); + LOG_INFO(MGMT + "channelId:{} batchId:{}, ParseKeys infoVecs empty !", channelId, batchId); return false; } - + LOG_DEBUG("channelId:{} batchId:{}, ParseKeysHBM GetInfoVec end.", channelId, batchId); // 动态shape场景下,获取all2all向量(通信量矩阵) TimeCost sendTensorsSyncTC; unique_ptr> all2all = nullptr; if (!mgmtRankInfo.useStatic) { + TimeCost getTensorsSyncTC; all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); - LOG_DEBUG("getTensorsSyncTC(ms):{}", getTensorsSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, getTensorsSyncTC(ms):{}", + channelId, batchId, getTensorsSyncTC.ElapsedMS()); if (all2all == nullptr) { LOG_ERROR("Information vector is nullptr!"); return false; } - sendTensorsSyncTC = TimeCost(); + sendTensorsSyncTC = TimeCost(); // 重新初始化,不计算getTensors耗时 TimeCost sendAll2AllScSyncTC; hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embInfo.name); - LOG_DEBUG("sendAll2AllScSyncTC(ms):{}", sendAll2AllScSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, sendAll2AllScSyncTC(ms):{}", + channelId, batchId, sendAll2AllScSyncTC.ElapsedMS()); } // 发送查询向量 TimeCost sendLookupSyncTC; hdTransfer->Send(TransferChannel::LOOKUP, { infoVecs->back() }, channelId, embInfo.name); infoVecs->pop_back(); - LOG_DEBUG("sendLookupSyncTC(ms):{}", sendLookupSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, sendLookupSyncTC(ms):{}", channelId, batchId, sendLookupSyncTC.ElapsedMS()); // 训练时,使用全局去重聚合梯度,发送全局去重的key和对应的恢复向量 if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && channelId == TRAIN_CHANNEL_ID) { - TimeCost sendUnikeysSyncTC; - LOG_DEBUG("global unique, table name: {}, is grad: {}", embInfo.name, embInfo.isGrad); - if (embInfo.isGrad) { - hdTransfer->Send(TransferChannel::UNIQKEYS, { infoVecs->back() }, channelId, embInfo.name); - } - infoVecs->pop_back(); - LOG_DEBUG("sendUnikeysSyncTC(ms):{}", sendUnikeysSyncTC.ElapsedMS()); - - TimeCost sendRestoreVecSecSyncTC; - if (embInfo.isGrad) { - hdTransfer->Send(TransferChannel::RESTORE_SECOND, { infoVecs->back() }, channelId, embInfo.name); - } - infoVecs->pop_back(); - LOG_DEBUG("sendRestoreVecSecSyncTC(ms):{}", sendRestoreVecSecSyncTC.ElapsedMS()); + SendUniqKeysAndRestoreVecHBM(channelId, batchId, embInfo, infoVecs); } // 发送恢复向量 @@ -677,10 +666,35 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embInfo.name); LOG_DEBUG("sendRestoreSyncTC(ms):{}, sendTensorsSyncTC(ms):{}, parseKeysTc HBM mode (ms):{}", sendRestoreSyncTC.ElapsedMS(), 
sendTensorsSyncTC.ElapsedMS(), parseKeysTc.ElapsedMS()); + LOG_INFO(MGMT + "channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.", + channelId, batchId, embInfo.name); } batchId++; return true; } + +void HybridMgmt::SendUniqKeysAndRestoreVecHBM(int channelId, int &batchId, const EmbInfo &embInfo, + const unique_ptr> &infoVecs) +{ + TimeCost sendUniqueKeysSyncTC; + LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}", + channelId, batchId, embInfo.name, embInfo.isGrad); + if (embInfo.isGrad) { + hdTransfer->Send(TransferChannel::UNIQKEYS, {infoVecs->back()}, channelId, embInfo.name); + } + infoVecs->pop_back(); + LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}", + channelId, batchId, sendUniqueKeysSyncTC.ElapsedMS()); + + TimeCost sendUniqueRestoreVecSyncTC; + if (embInfo.isGrad) { + hdTransfer->Send(TransferChannel::RESTORE_SECOND, {infoVecs->back()}, channelId, embInfo.name); + } + infoVecs->pop_back(); + LOG_DEBUG("channelId:{} batchId:{}, sendUniqueRestoreVecSyncTC(ms):{}", + channelId, batchId, sendUniqueRestoreVecSyncTC.ElapsedMS()); +} + #endif /// 当前处理的batch是否是最后一个batch @@ -699,12 +713,11 @@ bool HybridMgmt::EndBatch(int batchId, int channelId) const bool HybridMgmt::ParseKeys(int channelId, int& batchId) { #ifndef GTEST - LOG_INFO(MGMT + "DDR mode, start parse keys, [{}]:{}", channelId, batchId); + LOG_INFO(MGMT + "channelId:{} batchId:{}, DDR mode, ParseKeys start.", channelId, batchId); TimeCost parseKeyTC; int start = batchId; bool remainBatch = true; // 是否从通道获取了数据 - LOG_INFO(MGMT + "parse keys, [{}]:{}", channelId, batchId); for (const auto& embInfo : mgmtEmbInfo) { ProcessEmbInfo(embInfo.name, batchId, channelId, remainBatch); // 通道数据已空 @@ -718,10 +731,9 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId) if (!isRunning) { return false; } - TimeCost embHdTrans2TC; EmbHDTransWrap(channelId, batchId - 1, start); - LOG_DEBUG("embHdTrans2TC TimeCost(ms):{}", embHdTrans2TC.ElapsedMS()); - LOG_DEBUG("[{}]-{}, parseKeyTC TimeCost(ms):{}", channelId, batchId, parseKeyTC.ElapsedMS()); + LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end, parseKeyTC(ms):{}", + channelId, batchId, parseKeyTC.ElapsedMS()); #endif return true; } @@ -766,20 +778,24 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); if (lookupKeys.empty()) { remainBatchOut = false; + LOG_ERROR("channelId:{} batchId:{}, embName:{}, GetLookupKeys result is empty.", + channelId, batchId, embName); return false; } - + LOG_DEBUG("channelId:{} batchId:{}, embName:{}, GetLookupKeys end.", channelId, batchId, embName); // 获取各类向量,如果为空指针,退出当前函数 auto infoVecs = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { return false; } - LOG_DEBUG("getTensorsTC(ms):{}", getTensorsTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, GetInfoVec end, getTensorsTC(ms):{}", + channelId, batchId, getTensorsTC.ElapsedMS()); TimeCost sendRestoreSyncTC; hdTransfer->Send(TransferChannel::RESTORE, *infoVecs, channelId, embName); - LOG_DEBUG("sendRestoreSyncTC(ms):{}", sendRestoreSyncTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, send restore end, sendRestoreSyncTC(ms):{}", + channelId, batchId, sendRestoreSyncTC.ElapsedMS()); // 调用SSD cache缓存处理流程 - PrepareDDRData(embName, embHashMap, lookupKeys, channelId); + PrepareDDRData(embName, embHashMap, lookupKeys, channelId, batchId); // 计算查询向量;记录需要被换出的HBM偏移 vector tmpData; @@ 
-787,22 +803,12 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha DDRParam ddrParam(tmpData, offsetsOut); TimeCost hostHashMapProcessTC; hostHashMaps->Process(embName, lookupKeys, ddrParam, channelId); - LOG_DEBUG("hostHashMapProcessTC(ms):{}", hostHashMapProcessTC.ElapsedMS()); + LOG_DEBUG("channelId:{} batchId:{}, hostHashMapProcessTC(ms):{}", + channelId, batchId, hostHashMapProcessTC.ElapsedMS()); if (GlobalEnv::applyGradientsStrategy == ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY && channelId == TRAIN_CHANNEL_ID && remainBatchOut) { - vector uniqueKeys; - vector restoreVecSec; - preprocess->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); - - TimeCost sendUnikeysSyncTC; - hdTransfer->Send(TransferChannel::UNIQKEYS, { mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : - Vec2TensorI32(uniqueKeys) }, channelId, embName); - - TimeCost sendRestoreVecSecSyncTC; - hdTransfer->Send(TransferChannel::RESTORE_SECOND, { Vec2TensorI32(restoreVecSec) }, channelId, embName); - LOG_DEBUG("sendUnikeysSyncTC(ms):{}sendRestoreVecSecSyncTC(ms):{}", - sendUnikeysSyncTC.ElapsedMS(), sendRestoreVecSecSyncTC.ElapsedMS()); + SendUniqKeysAndRestoreVecDDR(embName, batchId, channelId, ddrParam); } TimeCost sendTensorsTC; @@ -817,30 +823,54 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha } hdTransfer->Send(TransferChannel::ALL2ALL, *all2all, channelId, embName); } - LOG_DEBUG("sendTensorsTC(ms):{} getAndSendTensorsTC(ms):{}, channelId:{}", - sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS(), channelId); + LOG_DEBUG("channelId:{} batchId:{}, ProcessEmbInfo end, sendTensorsTC(ms):{}, getAndSendTensorsTC(ms):{}", + channelId, batchId, sendTensorsTC.ElapsedMS(), getAndSendTensorsTC.ElapsedMS()); if (!isSSDEnabled && embHashMap.HasFree(lookupKeys.size())) { // check free > next one batch - LOG_WARN(MGMT + "embName {}[{}]{}, freeSize not enough, {}", embName, channelId, batchId, lookupKeys.size()); + LOG_WARN(MGMT + "channelId:{} batchId:{}, embName:{}, freeSize not enough:{}", + channelId, batchId, embName, lookupKeys.size()); return false; } return true; } +void HybridMgmt::SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchId, int &channelId, DDRParam &ddrParam) +{ + LOG_DEBUG("channelId:{} batchId:{}, embName:{}, SendUniqKeysAndRestoreVecDDR start.", channelId, batchId, embName); + vector uniqueKeys; + vector restoreVecSec; + preprocess->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); + + TimeCost sendUniqueKeysSyncTC; + hdTransfer->Send(TransferChannel::UNIQKEYS, {mgmtRankInfo.useDynamicExpansion ? 
Vec2TensorI64(uniqueKeys) :
+                                Vec2TensorI32(uniqueKeys) }, channelId, embName);
+    LOG_DEBUG("channelId:{} batchId:{}, sendUniqueKeysSyncTC(ms):{}",
+              channelId, batchId, sendUniqueKeysSyncTC.ElapsedMS());
+
+    TimeCost sendRestoreVecSecSyncTC;
+    hdTransfer->Send(TransferChannel::RESTORE_SECOND, {Vec2TensorI32(restoreVecSec) }, channelId, embName);
+    LOG_DEBUG("channelId:{} batchId:{}, sendRestoreVecSecSyncTC(ms):{}",
+              channelId, batchId, sendRestoreVecSecSyncTC.ElapsedMS());
+}
+
 /// Send the H2D embedding vectors to the device and receive the D2H vectors
 /// \param channelId channel index (train/infer)
 /// \param batchId number of batches already processed
 /// \param start
 void HybridMgmt::EmbHDTransWrap(int channelId, const int& batchId, int start)
 {
-    LOG_INFO(MGMT + "trans emb, batchId:[{}-{}], channelId:{}", start, batchId, channelId);
+    LOG_INFO(MGMT + "start:{} channelId:{} batchId:{}, EmbHDTransWrap start.", start, channelId, batchId);
+    TimeCost embHDTransWrapTC;
     TimeCost hostEmbsTC;
     hostEmbs->Join(channelId);
-    LOG_DEBUG("hostEmbsTC(ms):{}", hostEmbsTC.ElapsedMS());
+    LOG_DEBUG("channelId:{} batchId:{}, hostEmbs Join end, hostEmbsTC(ms):{}",
+              channelId, batchId, hostEmbsTC.ElapsedMS());
     if (!isRunning) {
         return;
     }
     EmbHDTrans(channelId, batchId);
+    LOG_DEBUG("channelId:{} batchId:{}, EmbHDTransWrap end, embHDTransWrapTC(ms):{}",
+              channelId, batchId, embHDTransWrapTC.ElapsedMS());
 }
 
 /// Send H2D and receive D2H embedding vectors, then update the host embeddings
@@ -850,8 +880,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId)
 {
     EASY_FUNCTION(profiler::colors::Blue)
     EASY_VALUE("mgmtProcess", batchId)
-    LOG_DEBUG(MGMT + "trans emb, batchId:{}, channelId:{}", batchId, channelId);
-    TimeCost tr;
+    LOG_DEBUG(MGMT + "channelId:{} batchId:{}, EmbHDTrans start.", channelId, batchId);
     TimeCost h2dTC;
     // Send the embeddings the host needs to swap out
     for (const auto& embInfo: mgmtEmbInfo) {
@@ -860,7 +889,7 @@
         hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order!
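         // The "order!" marker flags an ordering dependency: h2dEmb must be filled
         // from missingKeys before the send below, and the per-table send order has
         // to stay stable, presumably so the receiving device side can match
         // embeddings to keys by position.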
         hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId);
     }
-    LOG_DEBUG("h2dTC(ms):{}", h2dTC.ElapsedMS());
+    LOG_DEBUG("channelId:{} batchId:{}, EmbHDTrans h2d end, h2dTC(ms):{}", channelId, batchId, h2dTC.ElapsedMS());
 
     TimeCost d2hTC;
     // Receive the embeddings swapped out by the device and update them on the host
@@ -873,8 +902,7 @@
         }
         hostHashMaps->ClearMissingKeys(embInfo.name);
     }
-    LOG_DEBUG("D2HTC(ms):{} EmbHDTrans TimeCost(ms):{} batchId: {} channelId:{}",
-              d2hTC.ElapsedMS(), tr.ElapsedMS(), batchId, channelId);
+    LOG_DEBUG("channelId:{} batchId:{}, EmbHDTrans d2h end, d2hTC(ms):{}", channelId, batchId, d2hTC.ElapsedMS());
 }
 #endif
 
@@ -977,18 +1005,19 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys)
 }
 
 inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInfo& embHashMap,
-    const vector& keys, int channelId) const
+    const vector& keys, int channelId, int batchId) const
 {
     if (!isSSDEnabled) {
         return;
     }
-    LOG_DEBUG("PrepareDDRData start.");
+    LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData start.", channelId, batchId, embTableName);
     TimeCost prepareDDRDataTc;
     TransferRet ret = cacheManager->TransferDDREmbWithSSD(embTableName, embHashMap, keys, channelId);
     if (ret != TransferRet::TRANSFER_OK) {
         HandlePrepareDDRDataRet(ret);
     }
-    LOG_DEBUG("PrepareDDRData end, TimeCost(ms):{}", prepareDDRDataTc.ElapsedMS());
+    LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData end, prepareDDRDataTc(ms):{}",
+              channelId, batchId, embTableName, prepareDDRDataTc.ElapsedMS());
 }
 
 void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) const
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 5ebfed4a..9382b9b1 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -21,6 +21,7 @@
 #include "utils/common.h"
 #include "utils/config.h"
 #include "utils/singleton.h"
+#include "utils/logger.h"
 
 #include "host_emb/host_emb.h"
 #include "emb_hashmap/emb_hashmap.h"
@@ -73,6 +74,7 @@ namespace MxRec {
 
         void Destroy()
         {
+            LOG_DEBUG(MGMT + "start Destroy hybrid_mgmt module");
             if (!isInitialized) {
                 throw runtime_error("HybridMgmt not initialized. 
Call Initialize first."); } @@ -81,12 +83,16 @@ namespace MxRec { return; } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 - isRunning = false; - // 先发送停止信号给preprocess,用于停止查询中lookup卡住状态 - preprocess->isRunning = false; - // 停止hdTransfer,用于停止mgmt的recv中卡住状态 - hdTransfer->Destroy(); + if (preprocess != nullptr) { + // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 + std::unique_lock lockGuard(preprocess->destroyMutex); + // 先发送停止信号给preprocess,用于停止查询中lookup卡住状态 + preprocess->isRunning = false; + // 停止hdTransfer,用于停止mgmt的recv中卡住状态 + hdTransfer->Destroy(); + LOG_DEBUG(MGMT + "destroy hdTransfer end."); + } hybridMgmtBlock->Destroy(); for (auto& t : procThreads) { t->join(); @@ -104,7 +110,9 @@ namespace MxRec { if (preprocess != nullptr) { preprocess->Destroy(); preprocess = nullptr; + LOG_DEBUG(MGMT + "invoke KeyProcess destroy end."); } + LOG_DEBUG(MGMT + "Destroy hybrid_mgmt module end."); }; bool ParseKeys(int channelId, int& batchId); @@ -138,7 +146,7 @@ namespace MxRec { void EvictSSDKeys(const string& embName, const vector& keys) const; void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, - const vector &keys, int channelId) const; + const vector &keys, int channelId, int batchId) const; int GetStepFromPath(const string& loadPath) const; @@ -177,6 +185,11 @@ namespace MxRec { bool LoadMatchesDDRSetup(const CkptData& loadData); void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const; + + void SendUniqKeysAndRestoreVecHBM(int channelId, int& batchId, const EmbInfo &embInfo, + const unique_ptr> &infoVecs); + + void SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchId, int &channelId, DDRParam &ddrParam); }; } #endif // MX_REC_EMB_MGMT_H diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index 3476b8e0..b78f1d83 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -44,7 +44,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, LOG_ERROR("splitKey.size {} != keyCount.size {}", splitKey.size(), keyCount.size()); return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_ERROR; } - + TimeCost featureAdmitAndEvictTC; std::string tableName = batch->name; if (m_isCombine) { tableName = COMBINE_HISTORY_NAME; @@ -90,7 +90,7 @@ FeatureAdmitReturnType FeatureAdmitAndEvict::FeatureAdmit(int channel, } LOG_TRACE("FeatureAdmit, name:[{}], channel:[{}], after admit, splitKey:[{}] ...", tableName, channel, VectorToString(splitKey)); - + LOG_DEBUG("featureAdmitAndEvictTC(ms):{}", featureAdmitAndEvictTC.ElapsedMS()); return FeatureAdmitReturnType::FEATURE_ADMIT_RETURN_OK; } diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 6b8ff4fe..4103c82a 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -21,6 +21,7 @@ #include "utils/common.h" #include "utils/safe_queue.h" #include "utils/singleton.h" +#include "utils/time_cost.h" namespace MxRec { enum class FeatureAdmitType { diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6809bd95..b50acfd5 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -151,12 +151,12 @@ void KeyProcess::LoadKeyOffsetMap(KeyOffsetMemT& loadData) void KeyProcess::Destroy() { isRunning = false; - LOG_INFO(KEY_PROCESS "rank {} begin destroy.", rankInfo.rankId); + LOG_INFO(KEY_PROCESS "rankId:{} KeyProcess begin 
destroy.", rankInfo.rankId); for (auto& i: procThreads) { i->join(); } procThreads.clear(); - LOG_INFO(KEY_PROCESS "rank {} destroy success.", rankInfo.rankId); + LOG_INFO(KEY_PROCESS "rankId:{} KeyProcess destroy success.", rankInfo.rankId); } /// 每个数据通道的所有数据处理线程上锁 @@ -530,18 +530,21 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) tc = TimeCost(); } if (!isRunning) { + LOG_WARN("channelId:{} threadId:{}, isRunning is false when GetBatchData", channel, commId); // 通信终止信号,同步退出,防止线程卡住 int exitFlag = isRunning; auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); } + LOG_DEBUG("channelId:{} threadId:{}, GetBatchData Allreduce end, receiveFlag:{}", + channel, commId, exitFlag); throw EndRunExit("GetBatchData end run."); } } EASY_END_BLOCK - LOG_DEBUG(KEY_PROCESS "rank {} thread {} get batch {}[{}]:{} done. bs:{} sample:[{}]", - rankInfo.rankId, commId, batch->name, batch->channel, batch->batchId, batch->Size(), batch->UnParse()); + LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, get batch data done, batchName:{}. bs:{} sample:[{}]", + batch->channel, commId, batch->batchId, batch->name, batch->Size(), batch->UnParse()); #if defined(PROFILING) && defined(BUILD_WITH_EASY_PROFILER) if (batch->batchId == PROFILING_START_BATCH_ID) { EASY_PROFILER_ENABLE @@ -606,7 +609,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, vector sc; HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize); - All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo); + All2All(sc, id, batch, keySendInfo, uniqueInfoOut.all2AllInfo); LOG_DEBUG(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{}," " channel:{}, name:{}, restore:{}, keyCount:{}", @@ -673,11 +676,12 @@ void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_ha } } -void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo, +void KeyProcess::All2All(vector& sc, int id, const unique_ptr &batch, KeySendInfo& keySendInfo, All2AllInfo& all2AllInfoOut) { TimeCost getScAllTC; - GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) + int channel = batch->channel; + GetScAllForUnique(sc, id, batch, all2AllInfoOut.scAll); // Allgather通信获取所有(不同rank相同thread id的) LOG_DEBUG("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS()); TimeCost all2allTC; @@ -691,19 +695,24 @@ void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keyS all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") auto retCode = MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, - all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]); + all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), + MPI_INT64_T, comm[channel][id]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode); } + LOG_DEBUG("channelId:{} threadId:{} batchId:{}, All2All MPI_Alltoallv end.", channel, id, batch->batchId); all2AllInfoOut.countRecv.resize(rs.back() + rc.back()); if (isWithFAAE) { retCode = MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T, - all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]); + all2AllInfoOut.countRecv.data(), rc.data(), + rs.data(), MPI_UINT32_T, comm[channel][id]); if (retCode != 
@@ -606,7 +609,7 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch,
    vector sc;
    HandleHotAndSendCount(batch, uniqueInfoOut, keySendInfo, sc, splitSize);

-    All2All(sc, id, batch->channel, keySendInfo, uniqueInfoOut.all2AllInfo);
+    All2All(sc, id, batch, keySendInfo, uniqueInfoOut.all2AllInfo);

    LOG_DEBUG(KEY_PROCESS "ProcessBatchWithFastUnique get batchId:{}, batchSize:{},"
              " channel:{}, name:{}, restore:{}, keyCount:{}",
@@ -673,11 +676,12 @@ void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_ha
    }
 }

-void KeyProcess::All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo,
+void KeyProcess::All2All(vector& sc, int id, const unique_ptr &batch, KeySendInfo& keySendInfo,
                         All2AllInfo& all2AllInfoOut)
 {
    TimeCost getScAllTC;
-    GetScAllForUnique(sc, id, channel, all2AllInfoOut.scAll); // Allgather obtains, from all ranks (same thread id),
+    int channel = batch->channel;
+    GetScAllForUnique(sc, id, batch, all2AllInfoOut.scAll); // Allgather obtains, from all ranks (same thread id),
    LOG_DEBUG("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS());

    TimeCost all2allTC;
@@ -691,19 +695,24 @@
    all2AllInfoOut.keyRecv.resize(rs.back() + rc.back());
    EASY_BLOCK("all2all")
    auto retCode = MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
-                                 all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[channel][id]);
+                                 all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(),
+                                 MPI_INT64_T, comm[channel][id]);
    if (retCode != MPI_SUCCESS) {
        LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
    }
+    LOG_DEBUG("channelId:{} threadId:{} batchId:{}, All2All MPI_Alltoallv end.", channel, id, batch->batchId);
    all2AllInfoOut.countRecv.resize(rs.back() + rc.back());
    if (isWithFAAE) {
        retCode = MPI_Alltoallv(keySendInfo.keyCount.data(), sc.data(), ss.data(), MPI_UINT32_T,
-                                all2AllInfoOut.countRecv.data(), rc.data(), rs.data(), MPI_UINT32_T, comm[channel][id]);
+                                all2AllInfoOut.countRecv.data(), rc.data(),
+                                rs.data(), MPI_UINT32_T, comm[channel][id]);
        if (retCode != MPI_SUCCESS) {
-            LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
+            LOG_ERROR("channelId:{} threadId:{} batchId:{}, MPI_Alltoallv failed:{}",
+                      channel, id, batch->batchId, retCode);
        }
    }
-    LOG_DEBUG("all2allTC TimeCost(ms):{}", all2allTC.ElapsedMS());
+    LOG_DEBUG("channelId:{} threadId:{} batchId:{}, All2All end, all2allTC TimeCost(ms):{}",
+              channel, id, batch->batchId, all2allTC.ElapsedMS());
    EASY_END_BLOCK
 }

@@ -713,7 +722,8 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id,
    TimeCost processSplitKeysTC;
    EASY_FUNCTION(profiler::colors::Purple)
    EASY_VALUE("batchId", batch->batchId)
-    LOG_INFO(KEY_PROCESS "ProcessSplitKeys start batchId:{}, channel:{}", batch->batchId, batch->channel);
+    LOG_INFO(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, ProcessSplitKeys start.",
+             batch->channel, id, batch->batchId);
    // Use static all2all communication: send/receive volumes are preset fixed values, scInfo[batch->name] = 65536 / rankSize (empirical)
    if (rankInfo.useStatic) { // maybe move after all2all
@@ -737,7 +747,7 @@
    KeysT keyRecv;

    TimeCost getScAllTC;
-    auto scAll = GetScAll(sc, id, batch->channel); // Allgather obtains the inter-thread communication matrix across ranks
+    auto scAll = GetScAll(sc, id, batch); // Allgather obtains the inter-thread communication matrix across ranks
    LOG_DEBUG("getScAllTC(ms)(AllReduce-AllGather):{}", getScAllTC.ElapsedMS());

    auto ss = Count2Start(sc); // send displacements/offsets: start offsets of the data to send
@@ -748,22 +758,20 @@
    }
    auto rs = Count2Start(rc); // receive displacements/offsets: start offsets of the data to receive
    keyRecv.resize(rs.back() + rc.back());
-    LOG_TRACE(KEY_PROCESS "MPI_Alltoallv begin. rank {} thread {} batch {} {}",
-              rankInfo.rankId, id, batch->batchId, batch->name);
    EASY_BLOCK("all2all")
    TimeCost uniqueAll2AllTC;
    auto retCode = MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
-                                 keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]);
+                                 keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]);
    if (retCode != MPI_SUCCESS) {
-        LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode);
+        LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode);
    }
    LOG_DEBUG("uniqueAll2AllTC(ms):{}", uniqueAll2AllTC.ElapsedMS());
    EASY_END_BLOCK
-    LOG_TRACE(KEY_PROCESS "MPI_Alltoallv finish. rank {} thread {} batch {} {}",
-              rankInfo.rankId, id, batch->batchId, batch->name);
-    LOG_DEBUG("processSplitKeysTC(ms):{}", processSplitKeysTC.ElapsedMS());
+    LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, batchName:{}, MPI_Alltoallv finish."
+              " processSplitKeysTC(ms):{}",
+              batch->channel, id, batch->batchId, batch->name, processSplitKeysTC.ElapsedMS());

    return { keyRecv, scAll, ss };
 }
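ProcessSplitKeys prepares the MPI_Alltoallv arguments the classic way: per-peer send counts, their exclusive prefix sums as displacements, and a receive buffer sized from the gathered receive counts. A self-contained sketch of that bookkeeping, mirroring the Count2Start helper (illustrative names and element types):

#include <mpi.h>
#include <cstdint>
#include <vector>

// The displacement for peer i is the exclusive prefix sum of the counts, and
// the receive buffer needs rs.back() + rc.back() elements in total.
std::vector<int> Count2Start(const std::vector<int>& counts)
{
    std::vector<int> starts(counts.size(), 0);
    for (size_t i = 1; i < counts.size(); ++i) {
        starts[i] = starts[i - 1] + counts[i - 1];
    }
    return starts;
}

void ExchangeKeys(const std::vector<int64_t>& keySend, const std::vector<int>& sc,
                  const std::vector<int>& rc, std::vector<int64_t>& keyRecv, MPI_Comm comm)
{
    std::vector<int> ss = Count2Start(sc);  // send offsets, one per peer rank
    std::vector<int> rs = Count2Start(rc);  // receive offsets, one per peer rank
    keyRecv.resize(static_cast<size_t>(rs.back()) + rc.back());
    MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T,
                  keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm);
}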
+ " processSplitKeysTC(ms):{}", + batch->channel, id, batch->batchId, batch->name, processSplitKeysTC.ElapsedMS()); return { keyRecv, scAll, ss }; } @@ -991,37 +999,93 @@ void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, * 将本地(rank)batch要发送的key数据量进行Allgather通信,获取所有(不同rank相同thread id的)线程间的通信量矩阵 * scAll返回:所有线程间的通信量矩阵(按行平铺的一维向量) */ -vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, int channel) const +vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, const unique_ptr& batch) { EASY_FUNCTION() vector scAll; scAll.resize(rankInfo.rankSize * rankInfo.rankSize); EASY_BLOCK("barrier"); + LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId); + // 通信终止信号,同步退出,防止线程卡住 TimeCost tc = TimeCost(); int exitFlag = isRunning; - auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + int receiveFlag = exitFlag; + auto retCode = MPI_Allreduce(&exitFlag, &receiveFlag, 1, MPI_INT, MPI_SUM, comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allreduce failed:{}", rankInfo.rankId, commId, retCode); } - if (exitFlag < rankInfo.rankSize) { - throw EndRunExit("GetScAll end run."); - } + LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAll MPI_Allreduce end, receiveFlag:{}" + " barrier time:{}", + batch->channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS()); + + // 处理其他rank线程退出的情况 + HandleRankExitScene(commId, batch, receiveFlag); + EASY_END_BLOCK; - LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS()); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, - scAll.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT, + comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode); } - LOG_DEBUG("rank {} key scAll matrix:\n{}", rankInfo.rankId, VectorToString(scAll)); + LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll MPI_Allgather end, key scAll matrix:\n{}", + batch->channel, commId, batch->batchId, VectorToString(scAll)); return scAll; } -void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const +void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag) +{ + if (!isRunning) { + throw EndRunExit("GetScAll end run, isRunning is false."); + } + if (receiveFlag < rankInfo.rankSize) { + unique_lock lockGuard(destroyMutex); + if (!isRunning) { + LOG_INFO("channelId:{} threadId:{} batchId:{}, isRunning is false after lock destroyMutex.", + batch->channel, commId, batch->batchId); + throw EndRunExit("GetScAll end run, isRunning is false after lock destroyMutex."); + } + SendEosInfo(commId, batch); + isRunning = false; + throw EndRunExit("has SendEosInfo, GetScAll end run."); + } +} + +void KeyProcess::SendEosInfo(int commId, const unique_ptr& batch) +{ + // 注: SendTensorsByAcl方法UT无法链接 需屏蔽 +#ifndef GTEST + auto trans = Singleton::GetInstance(); + auto transChannel = trans->GetTransChannel(); + LOG_INFO("channelId:{} threadId:{} batchId:{}, start send acl eos info.", + batch->channel, commId, batch->batchId); + vector tensors; + bool isNeedResend = true; + string all2all_sendName = StringFormat("%s_%s_%d", batch->name.c_str(), + 
-void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const
+void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag)
+{
+    if (!isRunning) {
+        throw EndRunExit("GetScAll end run, isRunning is false.");
+    }
+    if (receiveFlag < rankInfo.rankSize) {
+        unique_lock lockGuard(destroyMutex);
+        if (!isRunning) {
+            LOG_INFO("channelId:{} threadId:{} batchId:{}, isRunning is false after lock destroyMutex.",
+                     batch->channel, commId, batch->batchId);
+            throw EndRunExit("GetScAll end run, isRunning is false after lock destroyMutex.");
+        }
+        SendEosInfo(commId, batch);
+        isRunning = false;
+        throw EndRunExit("has SendEosInfo, GetScAll end run.");
+    }
+}
+
+void KeyProcess::SendEosInfo(int commId, const unique_ptr& batch)
+{
+    // Note: the SendTensorsByAcl method cannot be linked in UT builds and has to be masked out
+#ifndef GTEST
+    auto trans = Singleton::GetInstance();
+    auto transChannel = trans->GetTransChannel();
+    LOG_INFO("channelId:{} threadId:{} batchId:{}, start send acl eos info.",
+             batch->channel, commId, batch->batchId);
+    vector tensors;
+    bool isNeedResend = true;
+    string all2all_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
+                                           TransferChannel2Str(TransferChannel::ALL2ALL).c_str(),
+                                           batch->channel);
+    SendTensorsByAcl(transChannel[all2all_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
+                     isNeedResend);
+    string restore_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
+                                           TransferChannel2Str(TransferChannel::RESTORE).c_str(),
+                                           batch->channel);
+    SendTensorsByAcl(transChannel[restore_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
+                     isNeedResend);
+    string lookup_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
+                                          TransferChannel2Str(TransferChannel::LOOKUP).c_str(), batch->channel);
+    SendTensorsByAcl(transChannel[lookup_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
+                     isNeedResend);
+    LOG_INFO("channelId:{} threadId:{} batchId:{}, send acl eos info end.",
+             batch->channel, commId, batch->batchId);
+#endif
+}
+
+void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, const unique_ptr &batch,
+                                   vector &scAllOut)
 {
    EASY_FUNCTION()
+    int channel = batch->channel;
    scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize);
    EASY_BLOCK("barrier");
    // Communication termination signal: exit in sync so no thread is left hanging
@@ -1031,18 +1095,21 @@
    if (retCode != MPI_SUCCESS) {
        LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode);
    }
-    if (exitFlag < rankInfo.rankSize) {
-        throw EndRunExit("GetScAll end run.");
-    }
+    LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAllForUnique MPI AllReduce end, "
+              "receiveFlag:{}, barrier time:{}",
+              channel, commId, batch->batchId, exitFlag, tc.ElapsedMS());
+    // Handle the case where a thread on another rank has already exited
+    HandleRankExitScene(commId, batch, exitFlag);
+
    EASY_END_BLOCK;
-    LOG_DEBUG(KEY_PROCESS "barrier time:{}", tc.ElapsedMS());

    // allgather keyScLocal(key all2all keyScLocal = device all2all rc)
    retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
                            scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
    if (retCode != MPI_SUCCESS) {
        LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode);
    }
-    LOG_DEBUG("rank {} key scAllOut matrix:\n{}", rankInfo.rankId, VectorToString(scAllOut));
+    LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAllForUnique end, key scAllOut matrix:\n{}",
+              channel, commId, batch->batchId, VectorToString(scAllOut));
 }

 void KeyProcess::Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel)
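Throughout this file collectives are issued on comm[channel][commId], one communicator per (data channel, worker thread) pair. MPI orders collectives per communicator, so giving every thread its own duplicated communicator keeps the concurrent Allreduce/Allgather/Alltoallv calls from interleaving. A sketch of how such a grid could be built (assumed setup, for illustration only; the project's actual initialization is not shown here):

#include <mpi.h>
#include <vector>

// Duplicate one communicator per (data channel, worker thread) pair. Threads
// that each own a dedicated duplicate can run collectives concurrently
// without their operations interleaving on a shared communicator.
std::vector<std::vector<MPI_Comm>> MakeCommGrid(int channels, int threads)
{
    std::vector<std::vector<MPI_Comm>> comm(channels, std::vector<MPI_Comm>(threads));
    for (int c = 0; c < channels; ++c) {
        for (int t = 0; t < threads; ++t) {
            MPI_Comm_dup(MPI_COMM_WORLD, &comm[c][t]);
        }
    }
    return comm;
}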
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index cea6978d..f8261ce2 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -151,7 +151,7 @@
        }

        bool isRunning { false };
-
+        std::mutex destroyMutex;
        inline bool HasEmbName(const string& embName)
        {
            return embInfos.find(embName) != embInfos.end();
        }
@@ -177,7 +177,6 @@
        map> hotKey {};
        map hotEmbTotCount;
        map embeddingTableMap {};
-        FactoryPtr factory {};
        int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT;
        bool isWithFAAE;
@@ -206,7 +205,7 @@
        size_t GetKeySize(const unique_ptr &batch);

-        void All2All(vector& sc, int id, int channel, KeySendInfo& keySendInfo,
+        void All2All(vector& sc, int id, const unique_ptr &batch, KeySendInfo& keySendInfo,
                     All2AllInfo& all2AllInfoOut);

        auto HashSplit(const unique_ptr& batch) const -> tuple, vector>;
@@ -216,9 +215,10 @@
        auto HashSplitWithFAAE(const unique_ptr& batch) const
            -> tuple, vector, vector>>;

-        vector GetScAll(const vector& keyScLocal, int commId, int channel) const;
+        vector GetScAll(const vector& keyScLocal, int commId, const unique_ptr& batch);

-        void GetScAllForUnique(const vector& keyScLocal, int commId, int channel, vector &scAllOut) const;
+        void GetScAllForUnique(const vector& keyScLocal, int commId, const unique_ptr &batch,
+                               vector &scAllOut);

        void Key2Offset(const EmbNameT& embName, KeysT& splitKey, int channel);

@@ -272,6 +272,10 @@
        }

        string DumpSplitKeys(vector>& splitKeys) const;
+
+        void SendEosInfo(int commId, const unique_ptr& batch);
+
+        void HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag);
    };
 } // end namespace MxRec
 #endif // MX_REC_KEY_PROCESS_H
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index d406f36b..05682560 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -14,6 +14,7 @@

 #include "utils/common.h"
 #include "key_process/key_process.h"
+#include "hd_transfer/hd_transfer.h"
 #include "ock_ctr_common/include/unique.h"
 #include "ock_ctr_common/include/error_code.h"

@@ -72,6 +73,7 @@ protected:
        splits = fieldNums;
    }

+    // Data constructed by this method must be consumed, otherwise it will affect other test cases
    vector> PrepareBatch()
    {
        vector> result(KEY_PROCESS_THREAD * MAX_CHANNEL_NUM);
@@ -281,7 +283,6 @@ TEST_F(KeyProcessTest, HashSplit)
    ASSERT_THAT(restore, ElementsAreArray(expectRestore));
 }

-#ifndef GTEST
 TEST_F(KeyProcessTest, GetScAll)
 {
    vector keyScLocal(worldSize, worldRank + 1); // initialize the send volume with worldRank + 1
@@ -292,16 +293,18 @@
    }
    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
    ASSERT_EQ(process.isRunning, true);
-    vector scAll;
-    process.GetScAll(keyScLocal, 0, 0, scAll);
+    // Only used to obtain sendCount info through collective communication; constructing an EmbBatchT object is enough, pass channel 0, no batch data is needed
+    EmbBatchT tempBatch;
+    tempBatch.channel = 0;
+    unique_ptr batch = std::make_unique(tempBatch);
+    vector scAll = process.GetScAll(keyScLocal, 0, batch);
    ASSERT_THAT(scAll, ElementsAreArray(expectScAll));
 }
-#endif

 TEST_F(KeyProcessTest, GetScAllForUnique)
 {
    vector keyScLocal(worldSize, worldRank + 1); // initialize the send volume with worldRank + 1
-    LOG_DEBUG(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal));
+    LOG_INFO(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal));
    vector expectScAll(worldSize * worldSize);
    for (unsigned int i = 0; i < expectScAll.size(); ++i) {
        expectScAll[i] = floor(i / worldSize) + 1;
@@ -309,7 +312,12 @@
    ASSERT_EQ(process.Initialize(rankInfo, embInfos), true);
    ASSERT_EQ(process.isRunning, true);
    vector scAll;
-    process.GetScAllForUnique(keyScLocal, 0, 0, scAll);
+    // Only used to obtain sendCount info through collective communication; constructing an EmbBatchT object is enough, pass channel 0, no batch data is needed
+    EmbBatchT tempBatch;
+    tempBatch.channel = 0;
+    unique_ptr batch = std::make_unique(tempBatch);
+    process.GetScAllForUnique(keyScLocal, 0, batch, scAll);
+    LOG_INFO("scAll:{}", VectorToString(scAll));
    ASSERT_THAT(scAll, ElementsAreArray(expectScAll));
 }

-- Gitee
From 104bbb273e64630b1dbe75c003faf24139f75eb8 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 17 Nov 2023 21:23:45 +0800
Subject: [PATCH 462/551] Match-id-53e31b9ff70a2a01ffa851daf8d132e8630ecf52

---
 mx_rec/saver/saver.py                       |  3 +-
 mx_rec/saver/sparse.py                      | 15 +++++----
 src/core/checkpoint/checkpoint.cpp          |  3 +-
 src/core/checkpoint/checkpoint.h            |  1 -
 .../hdfs_file_system/hdfs_file_system.cpp   | 18 ++++------
 .../local_file_system/local_file_system.cpp |  4 +--
 src/core/hybrid_mgmt/hybrid_mgmt.cpp        |  1 -
 src/tests/checkpoint/checkpoint_test.cpp    | 33 -------------------
 8 files changed, 20 insertions(+), 58 deletions(-)

diff --git 
a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index a8a14037..d24d7a43 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -192,7 +192,7 @@ class Saver(object): logger.debug(f"host data was saved.") if get_use_dynamic_expansion(): - logger.error(f"use dynamic expansion.") + # Data related to dynamic expansion needs to be saved only on the host side. return result = self.save_op_dict @@ -264,6 +264,7 @@ class Saver(object): logger.info("host data was restored.") if get_use_dynamic_expansion: + # Data related to dynamic expansion needs to be restored only on the host side. return restore_feed_dict = defaultdict(dict) diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index a614cf9d..4f750a54 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -22,7 +22,7 @@ class SparseProcessor: self.host_dir_list = ["HashTable", "DDR"] self.device_emb_dir = "embedding" self.host_emb_dir = "embedding_data" - self.device_hashmap_dir = "key_offset_map" + self.device_hashmap_dir = "key" self.host_hashmap_dir = "embedding_hashmap" self.data_suffix = ".data" self.attrib_suffix = ".attribute" @@ -93,13 +93,14 @@ class SparseProcessor: device_table_dir = os.path.join(dev_dir, table) host_table_dir = os.path.join(host_dir, table) if table_instance.host_vocabulary_size != 0: - ddr = True out_dir = host_table_dir + key, offset = self._get_hashmap(host_table_dir, True) + emb_data = self.get_embedding(device_table_dir, host_table_dir, True) + emb_data = emb_data[offset] else: out_dir = device_table_dir - key, offset = self._get_hashmap(out_dir, ddr) - emb_data = self.get_embedding(device_table_dir, host_table_dir, ddr) - emb_data = emb_data[offset] + key, _ = self._get_hashmap(device_table_dir, False) + emb_data = self.get_embedding(device_table_dir, host_table_dir, False) transformed_data = dict(zip(key[:], emb_data[:])) save_path = os.path.join(out_dir, self.export_name + ".npy") np.save(save_path, transformed_data) @@ -143,7 +144,9 @@ class SparseProcessor: raise ValueError(f"the attribute data from file {attribute_file} is invalid") data_shape = shape_data[:2] raw_hashmap = self._get_data(data_file, np.uint64, data_shape) - offset = raw_hashmap[:, -1] + offset = [] + if ddr: + offset = raw_hashmap[:, -1] key = raw_hashmap[:, 0] return key, offset diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 0f6b039e..4c86f0fb 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -359,7 +359,7 @@ void Checkpoint::LoadDataset(const vector& embNames, if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { auto embedPath { dataDir + dirSeparator + "embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; - LOG_DEBUG("====Start loading embedding data from: {}", datasetDir); + LOG_DEBUG("====Start loading embedding data from: {}", embedPath); ReadEmbedding(transData, embedDatasetDir, embName); } @@ -430,7 +430,6 @@ void Checkpoint::ReadStreamForEmbData(CkptTransData& transData, } auto embDataOuterSize = transData.attribute.at(attribEmbDataOuterIdx); - if (embDataOuterSize <= 0 || embDataOuterSize > MAX_VOCABULARY_SIZE) { throw runtime_error(StringFormat("Invalid embDataOuterSize :%d", embDataOuterSize).c_str()); } diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h index 94ca7e68..b5aafbe1 100644 --- a/src/core/checkpoint/checkpoint.h +++ b/src/core/checkpoint/checkpoint.h @@ -13,7 +13,6 @@ #include #include 
"utils/common.h" #include "ckpt_data_handler/ckpt_data_handler.h" -#include "file_system/buffer_queue.h" #include "file_system/file_system_handler.h" namespace MxRec { diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp index e0fb9827..01a6507f 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp @@ -270,6 +270,8 @@ ssize_t HdfsFileSystem::Read(const string& filePath, vector>& file void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingSize, vector& addressArr, int deviceId) { + size_t datasetSize = GetFileSize(filePath); + auto embHashMapSize = static_cast(datasetSize / sizeof(float) / embeddingSize); hdfsFS fs = ConnectHdfs(); hdfsFile file = hdfs->OpenFile(fs, filePath.c_str(), O_RDONLY, 0, 0, 0); @@ -278,8 +280,6 @@ void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingS throw runtime_error("open hdfs file failed."); } - size_t datasetSize = GetFileSize(filePath); - #ifndef GTEST auto res = aclrtSetDevice(static_cast(deviceId)); if (res != ACL_ERROR_NONE) { @@ -298,15 +298,9 @@ void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingS } float *floatPtr = static_cast(newBlock); - - for (size_t i = 0, j = 0; i < addressArr.size(); i += keyAddrElem, ++j) { + for (size_t i = 0, j = 0; i < embHashMapSize; i += keyAddrElem, ++j) { vector row(embeddingSize); - auto bytesRead = hdfs->Read(fs, file, row.data(), embeddingSize * sizeof(float)); - if (bytesRead != embeddingSize * sizeof(float)) { - hdfs->CloseFile(fs, file); - hdfs->Disconnect(fs); - throw runtime_error("Error read hdfs file."); - } + hdfs->Read(fs, file, row.data(), embeddingSize * sizeof(float)); aclError ec = aclrtMemcpy(floatPtr + j * embeddingSize, embeddingSize * sizeof(float), row.data(), embeddingSize * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE); @@ -316,11 +310,11 @@ void HdfsFileSystem::ReadEmbedding(const string& filePath, const int& embeddingS throw runtime_error(StringFormat("aclrtMemcpy failed, ret=%d", ec).c_str()); } int64_t address = reinterpret_cast(floatPtr + j * embeddingSize); - addressArr.at(i) = address; + addressArr.push_back(address); } +#endif hdfs->CloseFile(fs, file); hdfs->Disconnect(fs); -#endif } hdfsFS HdfsFileSystem::ConnectHdfs() diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 493440a4..2335f94b 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -39,13 +39,13 @@ vector LocalFileSystem::ListDir(const string& dirName) return dirs; } - for (en = readdir(dir); dir != nullptr; en = readdir(dir)) { + for (en = readdir(dir); en != nullptr ; en = readdir(dir)) { if (strncmp(en->d_name, currDir.c_str(), strlen(currDir.c_str())) != 0 && strncmp(en->d_name, prevDir.c_str(), strlen(prevDir.c_str())) != 0) { dirs.emplace_back(en->d_name); } - closedir(dir); } + closedir(dir); return dirs; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e678ce59..fb695254 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -377,7 +377,6 @@ OffsetT HybridMgmt::SendHostMap(const string tableName) OffsetT OffsetMap; // 先校验这个map是不是空的 if ((!offsetMapToSend.empty()) && offsetMapToSend.count(tableName) > 0) { - LOG_ERROR("send 
offset map : first key offset {}", offsetMapToSend[tableName][0]); for (auto& it : offsetMapToSend.at(tableName)) { OffsetMap.push_back(it); } diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 6681f055..607a7c96 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -355,36 +355,6 @@ TEST_F(CheckpointTest, EmbHashMaps) } } -TEST_F(CheckpointTest, MaxOffset) -{ - OffsetMemT testMaxOffset; - OffsetMemT validMaxOffset; - - SetEmbInfo(); - SetMaxOffset(testMaxOffset); - validMaxOffset = testMaxOffset; - - CkptData testSaveData; - CkptData validLoadData; - CkptData testLoadData; - - testSaveData.maxOffset = std::move(testMaxOffset); - validLoadData.maxOffset = std::move(validMaxOffset); - - Checkpoint testCkpt; - testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); - testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::MAX_OFFSET }); - - EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); - for (const auto& it : validLoadData.maxOffset) { - EXPECT_EQ(1, testLoadData.maxOffset.count(it.first)); - - const auto& maxOffset = testLoadData.maxOffset.at(it.first); - - EXPECT_EQ(it.second, maxOffset); - } -} - TEST_F(CheckpointTest, KeyOffsetMaps) { KeyOffsetMemT testKeyOffsetMaps; @@ -449,9 +419,7 @@ TEST_F(CheckpointTest, AllMgmt) EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); for (const auto& it : validLoadData.maxOffset) { EXPECT_EQ(1, testLoadData.maxOffset.count(it.first)); - const auto& maxOffset = testLoadData.maxOffset.at(it.first); - EXPECT_EQ(it.second, maxOffset); } @@ -463,7 +431,6 @@ TEST_F(CheckpointTest, AllMgmt) for (const auto& key: keyOffsetMap) { EXPECT_EQ(validKeyOffsetMap.count(key.first), 1); } - } } -- Gitee From 2bd4cfc6ae46f3dd87e3c552eec00cb4ddb631e2 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 20 Nov 2023 11:48:04 +0800 Subject: [PATCH 463/551] Match-id-82d88150d3932cac6d415a806019c4eaa9c18757 --- mx_rec/saver/saver.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index d24d7a43..889d804e 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -104,8 +104,8 @@ class Saver(object): """ logger.debug("======== Start saving for rank id %s ========", self.rank_id) if not check_file_system_is_valid(save_path): - raise ValueError(f"the path to save sparse embedding table data belong to invalid file system, " - f"only local file system and hdfs file system supported. ") + raise ValueError("the path to save sparse embedding table data belong to invalid file system, " + "only local file system and hdfs file system supported. ") save_path = save_path if save_path else self._prefix_name directory, base_name = os.path.split(save_path) @@ -151,8 +151,8 @@ class Saver(object): def restore(self, sess, reading_path): logger.debug("======== Start restoring ========") if not check_file_system_is_valid(reading_path): - raise ValueError(f"the path to save sparse embedding table data belong to invalid file system, " - f"only local file system and hdfs file system supported. ") + raise ValueError("the path to save sparse embedding table data belong to invalid file system, " + "only local file system and hdfs file system supported. 
") directory, base_name = os.path.split(reading_path) ckpt_name = f"sparse-{base_name}" -- Gitee From c38e841c667c2c4eeb9203fd1cc8a54be3adfb6f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 20 Nov 2023 15:09:04 +0800 Subject: [PATCH 464/551] Match-id-ac28746327a33568a626e50b7b69bb6d02c896ed --- .../ckpt_data_handler/ckpt_data_handler.cpp | 4 +- .../ckpt_data_handler/ckpt_data_handler.h | 37 ++++++++++--------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp index 8faec17b..e9c93476 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp +++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp @@ -12,12 +12,12 @@ using namespace MxRec; uint32_t CkptDataHandler::GetDataElmtBytes(CkptDataType dataType) { - return dataElmtBytes.at(static_cast(dataType)); + return dataTypeInfoMap.at(dataType).second; } string CkptDataHandler::GetDataDirName(CkptDataType dataType) { - return dataDirNames.at(static_cast(dataType)); + return dataTypeInfoMap.at(dataType).first; } void CkptDataHandler::CleanTransfer() diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h index f65dae5e..460438f8 100644 --- a/src/core/ckpt_data_handler/ckpt_data_handler.h +++ b/src/core/ckpt_data_handler/ckpt_data_handler.h @@ -39,26 +39,27 @@ namespace MxRec { CkptDataType dataType, string embName, CkptTransData& loadedData, CkptData& ckptData); protected: - const vector dataDirNames { - "embedding_info", - "embedding_data", - "embedding_hashmap", - "dev_offset_2_Batch_n_Key", - "embedding_current_status", - "max_offset", - "key", - "table_2_threshold", - "history_record", - "attribute", - "ddr_key_freq_map", - "exclude_ddr_key_freq_map", - "evict_pos", - "key_count_map" + // The dataTypeInfoMap stores the directory name and number of bytes of storage elements + // corresponding to each data type. 
+ const map> dataTypeInfoMap { + {CkptDataType::EMB_INFO, make_pair("embedding_info", 4) }, + {CkptDataType::EMB_DATA, make_pair("embedding_data", 4) }, + {CkptDataType::EMB_HASHMAP, make_pair("embedding_hashmap", 8) }, + {CkptDataType::DEV_OFFSET, make_pair("dev_offset_2_Batch_n_Key", 8) }, + {CkptDataType::EMB_CURR_STAT, make_pair("embedding_current_status", 4)}, + {CkptDataType::NDDR_OFFSET, make_pair("max_offset", 4)}, + {CkptDataType::NDDR_FEATMAP, make_pair("key", 8)}, + {CkptDataType::TABLE_2_THRESH, make_pair("table_2_threshold", 4)}, + {CkptDataType::HIST_REC, make_pair("history_record", 8)}, + {CkptDataType::ATTRIBUTE, make_pair("attribute", 8)}, + {CkptDataType::DDR_FREQ_MAP, make_pair("ddr_key_freq_map", 8)}, + {CkptDataType::EXCLUDE_FREQ_MAP, make_pair("exclude_ddr_key_freq_map", 8)}, + {CkptDataType::EVICT_POS, make_pair("evict_pos", 8)}, + {CkptDataType::KEY_COUNT_MAP, make_pair("key_count_map", 8)} }; - const vector dataElmtBytes { 4, 4, 8, 8, 4, 4, 8, 4, 8, 8, 8, 8, 8, 8}; - const uint32_t eightBytes { 8 }; - const uint32_t fourBytes { 4 }; + const uint32_t eightBytes = 8; + const uint32_t fourBytes = 4; CkptTransData transferData; -- Gitee From 12ea0452337681ef1cb90434dd143a65db3074fa Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 20 Nov 2023 15:56:53 +0800 Subject: [PATCH 465/551] Match-id-d745037a9e2b282646b735755654ecfcb676433e --- mx_rec/core/embedding.py | 5 +++-- mx_rec/optimizers/adagrad.py | 6 +++--- mx_rec/optimizers/ftrl.py | 20 +++++++++---------- mx_rec/optimizers/gradient_descent.py | 4 ++-- mx_rec/optimizers/gradient_descent_by_addr.py | 6 +++--- mx_rec/optimizers/lazy_adam.py | 10 +++++----- mx_rec/optimizers/lazy_adam_by_addr.py | 10 +++++----- mx_rec/optimizers/momentum.py | 6 +++--- mx_rec/util/global_env_conf.py | 1 + mx_rec/validator/validator.py | 19 ++++++++++++++---- 10 files changed, 50 insertions(+), 37 deletions(-) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index e92c9458..91209285 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -25,7 +25,8 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, ge get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \ get_table_instance_by_name, get_asc_manager from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \ - para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, OptionalStringValidator + para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, \ + OptionalStringValidator, FloatValidator from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.util.global_env_conf import global_env @@ -45,7 +46,7 @@ from mx_rec.util.log import logger ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_data_path", ClassValidator, {"classes": (list, tuple)}), ("is_save", ClassValidator, {"classes": (bool, )}), - ("init_param", NumValidator, {"min_value": -10, "max_value": 10}, ["check_value"]), + ("init_param", FloatValidator, {"min_value": -10, "max_value": 10}, ["check_value"]), ("all2all_gradients_op", OptionValidator, {"options": [i.value for i in list(All2allGradientsOp)]}), ("value_dtype", OptionValidator, {"options": [tf.float32]}), ("shard_num", IntValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]), diff --git 
a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index c4b23cdd..5aaf1220 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -16,12 +16,12 @@ from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("initial_accumulator_value", NumValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("initial_accumulator_value", FloatValidator, {"min_value": 0, "max_value": MAX_INT32}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool, )}), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 5cbc83bd..827b04d1 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -22,21 +22,21 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, OptionalStringValidator, ClassValidator, NumValidator, \ - StringValidator +from mx_rec.validator.validator import para_checker_decorator, ClassValidator, NumValidator, StringValidator, \ + FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("initial_accumulator_value", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), - ("learning_rate_power", NumValidator, {"min_value": -MAX_INT32, "max_value": 0}, ["check_value"]), - ("l1_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), - ("l2_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), - ("l2_shrinkage_regularization_strength", NumValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("initial_accumulator_value", FloatValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("learning_rate_power", FloatValidator, {"min_value": -MAX_INT32, "max_value": 0}, ["check_value"]), + ("l1_regularization_strength", FloatValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("l2_regularization_strength", FloatValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), + ("l2_shrinkage_regularization_strength", FloatValidator, {"min_value": 0, "max_value": 1e4}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), - ("accum_name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), - ("linear_name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, 
["check_string_length"]) + ("accum_name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), + ("linear_name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 39f5725b..5456d1bb 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ b/mx_rec/optimizers/gradient_descent.py @@ -15,11 +15,11 @@ from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 7542121d..4e0bb7da 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -14,12 +14,12 @@ from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, NumValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("weight_decay", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("weight_decay", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index ec67fab0..578708fa 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -21,14 +21,14 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, 
["check_value_for_open_interval"]), - ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value_for_left_open_interval"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("beta1", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value_for_open_interval"]), + ("beta2", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("epsilon", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value_for_left_open_interval"]), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, name="LazyAdam"): diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index 1e601221..d42f8a7f 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -15,14 +15,14 @@ from tensorflow.python.training import adam from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("beta1", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("beta2", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("epsilon", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("beta1", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("beta2", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("epsilon", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py index dff1a14f..8c8737fc 100644 --- a/mx_rec/optimizers/momentum.py +++ b/mx_rec/optimizers/momentum.py @@ -17,12 +17,12 @@ from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.util.initialize import get_table_instance, insert_removing_var_list from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, NumValidator, ClassValidator +from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator, ClassValidator @para_checker_decorator(check_option_list=[ - ("learning_rate", NumValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("mom", NumValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), + ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), + ("mom", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), ("use_locking", ClassValidator, {"classes": (bool,)}), ("name", StringValidator, {"min_len": 1, "max_len": 255}, 
["check_string_length"]), ("enable_nesterov", ClassValidator, {"classes": (bool,)}), diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index a3c38b93..29c58bc9 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -69,6 +69,7 @@ def get_global_env_conf() -> RecEnv: ("mxrec_log_level", OptionValidator, {"options": [i.value for i in list(RecPyLogLevel)]}), ("save_easy", OptionValidator, {"options": [i.value for i in list(Flag)]}), ("rank_table_file", DirectoryValidator, {}, ["check_exists_if_not_empty"]), + ("tf_device", OptionValidator, {"options": [i.value for i in list(TFDevice)]}), ("apply_gradients_strategy", OptionValidator, {"options": [i.value for i in list(ApplyGradientsStrategy)]}), ("acl_timeout", Convert2intValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), ("hd_channel_size", Convert2intValidator, diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index 80cc8b28..d0ae61a3 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -334,8 +334,9 @@ class NumValidator(Validator): number validator float or int """ - def __init__(self, name: str, value: int, min_value: int = None, max_value: int = None, - invalid_options: List = None, constrained_options: List = None, msg: str = ""): + def __init__(self, name: str, value: Union[int, float], min_value: Union[int, float] = None, + max_value: Union[int, float] = None, invalid_options: List = None, + constrained_options: List = None, msg: str = ""): if isinstance(value, tf.TensorShape) and value.ndims == 1: value = value.as_list()[0] if isinstance(value, tf.Tensor): @@ -347,8 +348,6 @@ class NumValidator(Validator): self.max_value = max_value self.invalid_options = invalid_options self.constrained_options = constrained_options - self.register_checker(lambda: isinstance(self.value, (int, float)), - msg if msg else f"type of '{name}' is not int or float") def check_value(self): if self.min_value is not None: @@ -393,6 +392,18 @@ class NumValidator(Validator): return self +class FloatValidator(NumValidator): + """ + float type data validator + """ + + def __init__(self, name: str, value: float, min_value: float = None, max_value: float = None, + invalid_options: List = None, constrained_options: List = None, msg: str = ""): + super(FloatValidator, self).__init__(name, value, min_value, max_value, invalid_options, constrained_options, + msg) + self.register_checker(lambda: isinstance(self.value, float), msg if msg else f"type of '{name}' is not float") + + class IntValidator(NumValidator): """ Int type validator -- Gitee From 5f7e9600be8d8fdadc899667ab1157add037b115 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 21 Nov 2023 11:57:27 +0800 Subject: [PATCH 466/551] Match-id-25931d63eaac05b2e46ac57b2ab1046c7a28dea4 --- build/build_tf1.sh | 1 + tests/__init__.py | 0 tests/mx_rec/__init__.py | 0 tests/mx_rec/core/__init__.py | 0 tests/mx_rec/core/test_build_graph.py | 61 ----------------------- tests/mx_rec/validator/__init__.py | 0 tests/mx_rec/validator/test_validators.py | 39 +-------------- tests/run_python_dt.sh | 49 ++++++++++++++++++ 8 files changed, 51 insertions(+), 99 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/mx_rec/__init__.py create mode 100644 tests/mx_rec/core/__init__.py create mode 100644 tests/mx_rec/validator/__init__.py create mode 100644 tests/run_python_dt.sh diff --git a/build/build_tf1.sh b/build/build_tf1.sh index 90b1c493..14c07a9c 100644 --- 
a/build/build_tf1.sh +++ b/build/build_tf1.sh @@ -15,6 +15,7 @@ cd "$SCRIPT_DIR" if [ "$(uname -m)" = "x86_64" ] then source /opt/buildtools/tf1_env/bin/activate + pip3 install setuptools==65.6.3 tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core deactivate tf1_env fi diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/__init__.py b/tests/mx_rec/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/core/__init__.py b/tests/mx_rec/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py index 291d297e..91fb0b29 100644 --- a/tests/mx_rec/core/test_build_graph.py +++ b/tests/mx_rec/core/test_build_graph.py @@ -11,7 +11,6 @@ import tensorflow as tf from tests.mx_rec.core.mxrec_pybind_mock import MxRecPybindMock from tests.mx_rec.core.initializer_mock import InitializerMock -from mx_rec.core.asc import build_graph from mx_rec.util.tf_version_adapter import npu_ops sys.modules['mxrec_pybind'] = MxRecPybindMock @@ -109,47 +108,6 @@ class TestBuildGraph(unittest.TestCase): dtype=tf.float32) return input_table - @mock.patch("npu_bridge.hccl.hccl_ops") - @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") - def test_get_restore_vector(self, tf1_hccl_ops_mock, - tf1_save_mock): - with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: - from mx_rec.core.asc.build_graph import get_restore_vector - restore_vector_mock = tf.constant(value=1, shape=[2908800], dtype=tf.int32, - name="aicpu_getnext_restore_vector/GetNext") - hot_pos_mock = tf.constant( - value=1, - name="restore_vector/one_ascend_hash_embedding/GetNext", - shape=[2730, ], - dtype=tf.int32) - mock_npu_ops.get_next.return_value = [restore_vector_mock, hot_pos_mock] - tf1_hccl_ops_mock.return_value = None - tf1_save_mock.return_value = None - input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, - False) - input_config = self.get_input_config(input_config_instance) - res_restore_vector, res_hot_emb = get_restore_vector(input_config) - self.assertEqual(res_restore_vector, restore_vector_mock) - self.assertEqual(res_hot_emb, hot_pos_mock) - - @mock.patch("npu_bridge.hccl.hccl_ops") - @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") - def test_get_restore_vector_no_hot_embed(self, tf1_hccl_ops_mock, - tf1_save_mock): - with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops: - from mx_rec.core.asc.build_graph import get_restore_vector - restore_vector_mock = tf.constant(value=1, shape=[2908800], dtype=tf.int32, - name="aicpu_getnext_restore_vector/GetNext") - mock_npu_ops.get_next.return_value = [restore_vector_mock] - tf1_hccl_ops_mock.return_value = None - tf1_save_mock.return_value = None - input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, - False, 6, False) - input_config = self.get_input_config(input_config_instance) - res_restore_vector, hot_pos_vector = get_restore_vector(input_config) - self.assertEqual(res_restore_vector, restore_vector_mock) - self.assertIsNone(hot_pos_vector) - @mock.patch('npu_bridge.estimator.npu_ops') @mock.patch("npu_bridge.hccl.hccl_ops") @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") @@ -267,25 +225,6 @@ class TestBuildGraph(unittest.TestCase): self.assertEqual(res_all2all_args.dtype, 
tf.constant(value=1, shape=[8, 8], dtype=tf.int64, name="mul").dtype) - @mock.patch("npu_bridge.hccl.hccl_ops") - @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook") - def test_get_preprocessed_tensor_for_asc(self, tf1_hccl_ops_mock, tf1_save_mock): - with mock.patch.object(npu_ops, "gen_npu_ops", return_value=self.get_next_mock()), \ - mock.patch.object(build_graph, "get_id_offsets", return_value=self.get_id_offsets_mock()), \ - mock.patch.object(build_graph, "get_all2all_args", return_value=self.get_all2all_mock()), \ - mock.patch.object(build_graph, "get_restore_vector", return_value=self.get_restore_vector_mock()): - from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc - os.environ["use_static"] = "False" - tf1_hccl_ops_mock.return_value = None - tf1_save_mock.return_value = None - - input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6, - False) - input_config = self.get_input_config(input_config_instance) - sparse_table = tf.Variable(tf.zeros([875000, 8]), name='inference/one_ascend_hash_embedding', - dtype=tf.float32) - res = get_preprocessed_tensor_for_asc(sparse_table, input_config) - self.assertEqual(res.get("hot_pos"), self.get_restore_vector_mock()[1]) if __name__ == '__main__': diff --git a/tests/mx_rec/validator/__init__.py b/tests/mx_rec/validator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/validator/test_validators.py b/tests/mx_rec/validator/test_validators.py index 72f180f1..bd4ee631 100644 --- a/tests/mx_rec/validator/test_validators.py +++ b/tests/mx_rec/validator/test_validators.py @@ -31,19 +31,6 @@ class ParameterCheckerTest(unittest.TestCase): """ super().tearDown() - def test_validator_should_return_default_if_invalid(self): - validation = Validator("v1", 'aa') - validation.register_checker(lambda x: len(x) < 5, 'length of string should be less than 5') - self.assertTrue(validation.is_valid()) - try: - validation = Validator("v2", '123456') - validation.register_checker(lambda x: len(x) < 5, 'length of string should be less than 5') - validation.is_valid() - except ValueError as exp: - self.assertEqual(type(exp), ValueError) - else: - self.fail("ValueError not raised.") - def test_string_validator_max_len_parameter(self): try: StringValidator("val", 'aa.1245', max_len=3).check_string_length().check().is_valid() @@ -135,33 +122,9 @@ class ParameterCheckerTest(unittest.TestCase): os.remove(path) os.removedirs(temp_dir) - def test_directory_check(self): - - try: - DirectoryValidator("val", 'a/b/.././c/a.txt').check_not_soft_link().check().is_valid() - except ValueError as exp: - self.assertEqual(type(exp), ValueError) - else: - self.fail("ValueError not raised.") - - try: - DirectoryValidator("val", "").check_is_not_none().check().is_valid() - except ValueError as exp: - self.assertEqual(type(exp), ValueError) - else: - self.fail("ValueError not raised.") - - try: - DirectoryValidator("val", None).check_is_not_none().check().is_valid() - except ValueError as exp: - self.assertEqual(type(exp), ValueError) - else: - self.fail("ValueError not raised.") - self.assertTrue(DirectoryValidator("val", "a/bc/d").check().is_valid()) - def test_decorator(self): @para_checker_decorator(check_option_list=[ - ("class_arg", ClassValidator, {"classed": (bool,)}), + ("class_arg", ClassValidator, {"classes": (bool,)}), ("options_arg", OptionValidator, {"options": (1, 2, 3)}), (["options_arg", "int_arg"], ValueCompareValidator, {"target": -1}, 
["check_all_not_equal_to_target"]), ("string_arg", OptionalStringValidator, {"max_len": 255}, ["check_string_length"]), diff --git a/tests/run_python_dt.sh b/tests/run_python_dt.sh new file mode 100644 index 00000000..a0fcd861 --- /dev/null +++ b/tests/run_python_dt.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Description: start script. +# Author: MindX SDK +# Create: 2023 +# History: NA + +set -e + +CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run_python_dt.sh" ; exit ; } ; pwd) +TOP_PATH="${CUR_PATH}"/../ + +# build mxRec and get output directory +bash "$TOP_PATH"/build/build_tf1.sh + +# create libasc directory and copy so files into it +cd "$TOP_PATH"/mx_rec +mkdir -p libasc +cp -f "$TOP_PATH"/output/*.so ./libasc +cd - + +# set environment variable +export PYTHONPATH="${TOP_PATH}"/output:$PYTHONPATH +export LD_LIBRARY_PATH="${TOP_PATH}"/output:/usr/local/lib:$LD_LIBRARY_PATH + +rm -rf result +mkdir -p result + +function run_test_cases() { + echo "Get testcases final result." + pytest --cov="${CUR_PATH}"/../mx_rec --cov-report=html --cov-report=xml --junit-xml=./final.xml --html=./final.html --self-contained-html --durations=5 -vv + coverage xml -i --omit="build/*,cust_op/*,src/*" + cp coverage.xml final.xml final.html ./result + cp -r htmlcov ./result + rm -rf coverage.xml final.xml final.html htmlcov +} + +echo "************************************* Start MxRec LLT Test *************************************" +start=$(date +%s) +run_test_cases +ret=$? +end=$(date +%s) +echo "************************************* End MxRec LLT Test *************************************" +echo "LLT running take: $(expr "${end}" - "${start}") seconds" + +rm -rf "$TOP_PATH"/mx_rec/libasc + +exit "${ret}" -- Gitee From bf9e8324c99036efc010c5754e26f91b86aa5313 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 21 Nov 2023 14:32:29 +0800 Subject: [PATCH 467/551] Match-id-f64f76ce7f0e3b9de9b5a4b9489e8a46f7eaa619 --- mx_rec/graph/modifier.py | 76 ++++++++++++++++++++++++++++------------ mx_rec/graph/utils.py | 34 +++++++++++++----- 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index ff2391f9..38f02bad 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -3,12 +3,15 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
from collections import defaultdict -from typing import Any +from typing import Any, List, Dict, DefaultDict, Tuple, Union +from collections.abc import Callable import tensorflow as tf +from tensorflow import Tensor from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter from tensorflow.python.framework.ops import Operation from tensorflow.python.framework.errors_impl import InvalidArgumentError +from tensorflow.core.framework.graph_pb2 import GraphDef from mx_rec.core.asc.helper import get_asc_insert_func from mx_rec.core.asc.feature_spec import FeatureSpec @@ -22,14 +25,19 @@ from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_in set_iterator_type from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ - export_pb_graph, make_sorted_key_to_tensor_list + export_pb_graph, make_sorted_key_to_tensor_list, ReplacementSpec, AnchorRecord from mx_rec.graph.merge_lookup import do_merge_lookup from mx_rec.validator.validator import para_checker_decorator, ClassValidator from mx_rec.util.log import logger -def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tensor_names=None, - pipeline_input_indexes=None): +def get_preprocessing_map_func( + graph_def: GraphDef, + input_names: List[str], + output_names: List[str], + batch_tensor_names: List[str] = None, + pipeline_input_indexes: List[int] = None +) -> Callable: input_names = check_input_list(input_names, str) output_names = check_input_list(output_names, str) batch_tensor_names = check_input_list(batch_tensor_names, str) @@ -53,7 +61,7 @@ def get_preprocessing_map_func(graph_def, input_names, output_names, batch_tenso """ - def parse_tensor(data_tensor: tf.Tensor, data_batch: dict, key: str = None): + def parse_tensor(data_tensor: Tensor, data_batch: dict, key: str = None): """ 将待解析batch中的tensor写入解析后的batch中,如果key存在则使用原key,不存在则生成batch中字典序最小的key. 
 
        Args:
@@ -81,7 +89,7 @@
             for data_arg in data_args:
                 parse_batch(data_arg, data_batch, key)
             return
-        elif isinstance(data_args, tf.Tensor):
+        elif isinstance(data_args, Tensor):
             # add the tensor from the old batch into the dict
             parse_tensor(data_args, data_batch, key)
             return
@@ -119,7 +127,13 @@
     return map_func
 
 
-def get_input_index_list(cutting_point_list, replacement_specs, mapping_name_list, base_count, timestamp_index=None):
+def get_input_index_list(
+        cutting_point_list: List[Tensor],
+        replacement_specs: ReplacementSpec,
+        mapping_name_list: List[str],
+        base_count: int,
+        timestamp_index: int = None
+) -> List[int]:
     input_index_list = []
     for cutting_point in cutting_point_list:
         if cutting_point in replacement_specs:
@@ -137,7 +151,7 @@ def get_input_index_list(cutting_point_list, replacement_specs, mapping_name_lis
     return input_index_list
 
 
-def find_make_iterator_op(batch_tensor):
+def find_make_iterator_op(batch_tensor: Tensor) -> Operation:
     graph = tf.compat.v1.get_default_graph()
     operations = graph.get_operations()
     for each_op in operations:
@@ -151,7 +165,7 @@ def find_make_iterator_op(batch_tensor):
 
 
 @performance("find_target_dataset_op")
-def find_target_dataset_op(base_ops, op_type):
+def find_target_dataset_op(base_ops: Operation, op_type: str) -> Operation:
     base_ops = check_input_list(base_ops, tf.Operation)
 
     parent_ops = base_ops
@@ -205,7 +219,10 @@ def get_dataset_op(get_next_op: Operation) -> Operation:
     return target_op
 
 
-def get_passing_tensor_list(src_tensors, target_op):
+def get_passing_tensor_list(
+        src_tensors: List[Tensor],
+        target_op: Operation
+) -> Tuple[List[Tensor], List[int], List[Tensor]]:
     def get_passing_tensors(src_tensor):
         passing_tensors = []
         tensor_list = [src_tensor]
@@ -223,7 +240,7 @@ def get_passing_tensor_list(src_tensors, target_op):
 
         return passing_tensors
 
-    src_tensors = check_input_list(src_tensors, tf.Tensor)
+    src_tensors = check_input_list(src_tensors, Tensor)
     passing_tensor_list = []
     sub_src_tensors = []
     for tensor in src_tensors:
@@ -242,7 +259,7 @@ def get_passing_tensor_list(src_tensors, target_op):
     return passing_tensor_list, output_index_list, sub_src_tensors
 
 
-def find_target_instance_dataset(variant_tensor):
+def find_target_instance_dataset(variant_tensor: Tensor) -> DatasetV1Adapter:
     dataset_instance_list = tf.compat.v1.get_collection("dataset_group")
     for ins in dataset_instance_list:
         if ins._variant_tensor == variant_tensor:
@@ -259,7 +276,10 @@ def find_target_instance_dataset(variant_tensor):
     raise LookupError(f"Can not find target instance, whose variant_tensor is '{variant_tensor}' respectively.")
 
 
-def get_sub_graph(input_tensors, output_tensors):
+def get_sub_graph(
+        input_tensors: List[Tensor],
+        output_tensors: List[Tensor]
+) -> Tuple[GraphDef, List[str], List[str]]:
     input_tensors = check_input_list(input_tensors, tf.Tensor)
     output_tensors = check_input_list(output_tensors, tf.Tensor)
     input_op_name_list = [tensor.op.name for tensor in input_tensors]
@@ -285,7 +305,9 @@ def get_sub_graph(input_tensors, output_tensors):
     return sub_graph_def, input_name_list, output_name_list
 
 
-def update_input_tensor_with_new_batch(replacement_specs: dict, new_get_next_op_name: str, new_batch: dict):
+def update_input_tensor_with_new_batch(replacement_specs: ReplacementSpec,
+                                       new_get_next_op_name: str,
+                                       new_batch: Dict[str, Tensor]):
     """
     Replace the old batch's IteratorGetNext in the graph with the IteratorGetNext from the new batch.
 
@@ -335,12 +357,15 @@ def get_dataset_tensor_count(dataset: DatasetV1Adapter) -> int:
 @para_checker_decorator(
     check_option_list=[("dump_graph", ClassValidator, {"classes": (bool,)})]
 )
-def modify_graph_and_start_emb_cache(dump_graph=False):
+def modify_graph_and_start_emb_cache(dump_graph: bool = False):
     modify_graph_for_asc(dump_graph=dump_graph)
     start_asc_pipeline()
 
 
-def generate_get_next_op_specs(cutting_point_list, dump_graph):
+def generate_get_next_op_specs(
+        cutting_point_list: List[Tensor],
+        dump_graph: bool = False
+) -> Dict[Tensor, ReplacementSpec]:
     get_next_op_map = defaultdict(dict)
     for input_tensor in cutting_point_list:
         get_next_op = find_target_dataset_op(input_tensor.op, "IteratorGetNext")
@@ -409,8 +434,13 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt
     return src_dataset
 
 
-def get_tgt_dataset(src_dataset: DatasetV1Adapter, sub_cutting_point_list: list, records: dict,
-                    dump_graph: bool = False, prefetch: int = 10) -> DatasetV1Adapter:
+def get_tgt_dataset(
+        src_dataset: DatasetV1Adapter,
+        sub_cutting_point_list: List[Tensor],
+        records: AnchorRecord,
+        dump_graph: bool = False,
+        prefetch: int = 10
+) -> DatasetV1Adapter:
     """
     Create a new dataset instance from the original dataset.
 
@@ -445,7 +475,10 @@ def get_tgt_dataset(src_dataset: DatasetV1Adapter, sub_cutting_point_list: list,
     return tgt_dataset
 
 
-def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapter, is_training: bool, records: dict):
+def update_iterator_getnext(get_next_op: Operation,
+                            tgt_dataset: DatasetV1Adapter,
+                            is_training: bool,
+                            records: AnchorRecord):
     """
     Replace the original dataset's `IteratorGetNext` op in the graph with the new dataset's `IteratorGetNext` op, i.e. replace the original dataset's batch with the new dataset's batch.
 
@@ -458,7 +491,6 @@ def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapte
     Returns:
         None
     """
-
     if not get_next_op.outputs:
         raise RuntimeError("There is no tensor in the dataset. Please check the dataset and data processing.")
     iterator_type = ""
@@ -490,7 +522,7 @@ def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapte
 
 
 @performance("graph_modifier")
-def modify_graph_for_asc(dump_graph=False, prefetch=10):
+def modify_graph_for_asc(dump_graph: bool = False, prefetch: int = 10):
     cutting_point_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE)
     check_cutting_points(cutting_point_list)
     if not cutting_point_list:
@@ -539,7 +571,7 @@
     export_pb_graph("new_graph.pb", dump_graph)
 
 
-def get_timestamp_index(get_next_op, is_training):
+def get_timestamp_index(get_next_op: Operation, is_training: bool) -> int:
     timestamp_tensor_list = tf.compat.v1.get_collection(ASCEND_TIMESTAMP)
     timestamp_index = None
     for timestamp in timestamp_tensor_list:
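The annotated entry point keeps its two-phase contract: `modify_graph_and_start_emb_cache` first rewrites the graph around the registered cutting points, then starts the ASC pipeline. A minimal usage sketch under assumed scaffolding (the `build_model` body is hypothetical; the mx_rec call is the one typed above, and `dump_graph=True` writes new_graph.pb via `export_pb_graph`):

    import tensorflow as tf

    from mx_rec.graph.modifier import modify_graph_and_start_emb_cache

    def build_model():
        ...  # hypothetical user network containing mx_rec sparse lookups

    train_op = build_model()
    # Rewrite the dataset / IteratorGetNext plumbing and start the ASC
    # pipeline; only afterwards create the session, so it sees the
    # modified graph.
    modify_graph_and_start_emb_cache(dump_graph=True)
    with tf.compat.v1.Session() as sess:
        sess.run(train_op)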
diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py
index 8c063c26..e78803a8 100644
--- a/mx_rec/graph/utils.py
+++ b/mx_rec/graph/utils.py
@@ -2,10 +2,14 @@
 # -*- coding: utf-8 -*-
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 
-from collections import defaultdict
 import os
+from collections import defaultdict
+from typing import Any, List, Dict, DefaultDict, Tuple, Union
 
 import tensorflow as tf
+from tensorflow import Tensor
+from tensorflow import Operation
+from tensorflow.core.framework.graph_pb2 import GraphDef
 from tensorflow.python.framework.errors_impl import InvalidArgumentError
 
 from mx_rec.constants.constants import ASCAnchorAttr, DUMP_MIDIFY_GRAPH_FILE_MODE
@@ -13,7 +17,11 @@ from mx_rec.core.embedding import SparseEmbedding
 from mx_rec.util.log import logger
 
 
-def check_input_list(objs, obj_type):
+ReplacementSpec = DefaultDict[Tensor, List[Tuple[int, Operation]]]
+AnchorRecord = Dict[str, Union[ReplacementSpec, GraphDef, bool, List[Tensor], List[int], List[str]]]
+
+
+def check_input_list(objs: Union[object, List[object]], obj_type: type) -> Union[object, List[object]]:
     if isinstance(objs, obj_type):
         objs = [objs]
 
@@ -25,7 +33,7 @@ def check_input_list(objs, obj_type):
     return objs
 
 
-def find_parent_op(operator):
+def find_parent_op(operator: Operation) -> List[Operation]:
     parent_ops = []
     for input_tensor in operator.inputs:
         parent_op = input_tensor.op
@@ -34,7 +42,7 @@ def find_parent_op(operator):
     return parent_ops
 
 
-def check_cutting_points(cutting_point_list):
+def check_cutting_points(cutting_point_list: List[Tensor]):
     for tensor in cutting_point_list:
         if not isinstance(tensor, tf.Tensor):
             raise TypeError(f"Collection ASCEND_CUTTING_POINT can only contain Tensors, but '{tensor}' was found.")
@@ -43,7 +51,7 @@ def check_cutting_points(cutting_point_list):
             raise ValueError(f"Cutting point can only be the output of an Operator 'Identity'.")
 
 
-def record_ops_to_replace(src_op):
+def record_ops_to_replace(src_op: Operation) -> ReplacementSpec:
     replacement_specs = defaultdict(list)
     output_list = src_op.outputs
     op_list = tf.compat.v1.get_default_graph().get_operations()
@@ -56,7 +64,7 @@ def record_ops_to_replace(src_op):
     return replacement_specs
 
 
-def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list):
+def replace_anchor(replacement_specs: ReplacementSpec, new_tensor_list: List[Tensor]):
     if len(replacement_specs) != len(new_tensor_list):
         raise ValueError(f"Given replacement_specs and new_tensor_list must have the same length. "
" f"replacement_specs: {replacement_specs}, new_tensor_list: {new_tensor_list}") @@ -72,7 +80,11 @@ def replace_anchor(replacement_specs: defaultdict, new_tensor_list: list): f"new tensor: {new_tensor_list[tensor_idx]}.") from err -def export_pb_graph(file_name, dump_graph, graph_def=None, export_path="./export_graph", as_text=False): +def export_pb_graph(file_name: str, + dump_graph: bool = False, + graph_def: GraphDef = None, + export_path: str = "./export_graph", + as_text: bool = False): """ Save tensorflow graph before and after modifier graph :param file_name: FileName of the graph @@ -90,7 +102,11 @@ def export_pb_graph(file_name, dump_graph, graph_def=None, export_path="./export tf.io.write_graph(graph_def, export_path, file_name, as_text) -def make_sorted_key_to_tensor_list(element_spec, sorted_keys, prefix=""): +def make_sorted_key_to_tensor_list( + element_spec: List[Dict[str, Tensor]], + sorted_keys: List[str], + prefix: str = "" +) -> List[str]: if isinstance(element_spec, tf.TensorSpec): sorted_keys.append(prefix) return sorted_keys @@ -115,7 +131,7 @@ def make_sorted_key_to_tensor_list(element_spec, sorted_keys, prefix=""): raise TypeError(f"Given element_spec, whose type is {type(element_spec)}, is invalid.") -def replace_anchor_vec(cutting_point: tf.Tensor, attribute: ASCAnchorAttr, anchor: tf.Tensor): +def replace_anchor_vec(cutting_point: Tensor, attribute: ASCAnchorAttr, anchor: Tensor): """ 根据打桩节点的名字找到以此为输入的op,并将该op的输入替换为入参anchor. -- Gitee From 7a939e6e306bb83bf472d20d21071dfa89728480 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 21 Nov 2023 16:39:53 +0800 Subject: [PATCH 468/551] Match-id-98781a4a6c7e58a29126f746aa15a74b779802b3 --- mx_rec/constants/constants.py | 1 + mx_rec/util/global_env_conf.py | 7 +++++-- src/core/key_process/key_process.cpp | 4 ++++ src/core/utils/config.cpp | 10 +++++++++- src/core/utils/config.h | 2 ++ 5 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 5b555b2e..4eb18316 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -109,6 +109,7 @@ class EnvOption(Enum): GLOG_STDERRTHREAHOLD = "GLOG_stderrthreshold" USE_COMBINE_FAAE = "USE_COMBINE_FAAE" STAT_ON = "STAT_ON" + RECORD_KEY_COUNT = "RECORD_KEY_COUNT" class DataName(Enum): diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index 29c58bc9..f2b802e8 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -33,6 +33,7 @@ class RecEnv: glog_stderrthreahold: str use_combine_faae: str stat_on: str + record_key_count: str def get_global_env_conf() -> RecEnv: @@ -59,7 +60,8 @@ def get_global_env_conf() -> RecEnv: hot_emb_update_step=os.getenv(EnvOption.HOT_EMB_UPDATE_STEP.value, DEFAULT_HOT_EMB_UPDATE_STEP), glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value), use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value), - stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value) + stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value), + record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value) ) return rec_env @@ -84,7 +86,8 @@ def get_global_env_conf() -> RecEnv: {"min_value": MIN_HOT_EMB_UPDATE_STEP, "max_value": MAX_HOT_EMB_UPDATE_STEP}, ["check_value"]), ("glog_stderrthreahold", OptionValidator, {"options": [i.value for i in list(RecCPPLogLevel)]}), ("use_combine_faae", OptionValidator, {"options": [i.value for i in list(Flag)]}), - 
("stat_on", OptionValidator, {"options": [i.value for i in list(Flag)]}) + ("stat_on", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("record_key_count", OptionValidator, {"options": [i.value for i in list(Flag)]}) ]) def check_env(**kwargs): pass diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 92f641a0..efa2abf2 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -355,6 +355,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch std::lock_guard lock(loadSaveMut[channel][threadId]); // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) + RecordKeyCountMap(batch); if (rankInfo.noDDR) { TimeCost key2OffsetTC; @@ -1489,6 +1490,9 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) { + if (!GlobalEnv::recordKeyCount) { + return; + } std::lock_guard lk(mut); size_t miniBs = batch->Size(); auto* batchData = batch->sample.data(); diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp index 46fb3c8a..cabbfee4 100644 --- a/src/core/utils/config.cpp +++ b/src/core/utils/config.cpp @@ -30,6 +30,7 @@ namespace MxRec { int GlobalEnv::glogStderrthreshold = 0; // 默认info级别 bool GlobalEnv::useCombineFaae = false; bool GlobalEnv::statOn = false; + bool GlobalEnv::recordKeyCount = false; // 默认不打开记录key count的开关 /// 配置环境变量,Python侧已经做了变量值校验,CPP侧直接使用即可;bool类型,1代表true,0代表false void ConfigGlobalEnv() @@ -99,6 +100,12 @@ namespace MxRec { if (envStat != nullptr) { GlobalEnv::statOn = (std::stoi(envStat) == 1); } + + // 设置打开记录开关,记录batch中key与出现的count的数目 + const char *envRecordKeyCount = getenv(RecEnvNames::RECORD_KEY_COUNT); + if (envRecordKeyCount != nullptr) { + GlobalEnv::recordKeyCount = (std::stoi(envRecordKeyCount) == 1); + } } void LogGlobalEnv() @@ -115,6 +122,7 @@ namespace MxRec { RecEnvNames::HOT_EMB_UPDATE_STEP, GlobalEnv::hotEmbUpdateStep, RecEnvNames::GLOG_STDERR_THRESHOLD, GlobalEnv::glogStderrthreshold, RecEnvNames::USE_COMBINE_FAAE, GlobalEnv::useCombineFaae, - RecEnvNames::STAT_ON, GlobalEnv::statOn); + RecEnvNames::STAT_ON, GlobalEnv::statOn, + RecEnvNames::RECORD_KEY_COUNT, GlobalEnv::recordKeyCount); } } \ No newline at end of file diff --git a/src/core/utils/config.h b/src/core/utils/config.h index 6b498df9..49b4b501 100644 --- a/src/core/utils/config.h +++ b/src/core/utils/config.h @@ -24,6 +24,7 @@ namespace MxRec { const char *const GLOG_STDERR_THRESHOLD = "GLOG_stderrthreshold"; const char *const USE_COMBINE_FAAE = "USE_COMBINE_FAAE"; const char *const STAT_ON = "STAT_ON"; + const char *const RECORD_KEY_COUNT = "RECORD_KEY_COUNT"; }; namespace ApplyGradientsStrategyOptions { @@ -43,6 +44,7 @@ namespace MxRec { static int glogStderrthreshold; static bool useCombineFaae; static bool statOn; + static bool recordKeyCount; }; void ConfigGlobalEnv(); -- Gitee From ee83f05347f1fd2931ddb2b6f36a1ae21332ad30 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 22 Nov 2023 10:20:04 +0800 Subject: [PATCH 469/551] Match-id-b3dc21617135ffff0ca03e0cac399b074d92a51f --- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 5c3f5db9..12ea1b58 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -349,7 +349,9 @@ bool HybridMgmt::Load(const 
 void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, const FeatureAdmitAndEvict& featAdmitNEvict)
 {
-    loadFeatures.push_back(CkptFeatureType::KEY_COUNT_MAP);
+    if (GlobalEnv::recordKeyCount) {
+        loadFeatures.push_back(CkptFeatureType::KEY_COUNT_MAP);
+    }
     if (!mgmtRankInfo.noDDR) {
         // in DDR mode, load the host embedding table and the hashmap
         loadFeatures.push_back(CkptFeatureType::HOST_EMB);
-- Gitee

From 1a685653ca0e04dff7b35a2f4efb7da90990418a Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 22 Nov 2023 14:16:50 +0800
Subject: [PATCH 470/551] Match-id-a933301b412048aae5794d72f2994d74218b178c

---
 mx_rec/core/embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 91209285..bdb8ba5c 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -36,7 +36,7 @@ from mx_rec.util.log import logger
 @para_checker_decorator(check_option_list=[
     ("key_dtype", OptionValidator, {"options": (tf.int64, tf.int32, tf.string)}),
     ("dim", ClassValidator, {"classes": (int, tf.TensorShape)}),
-    ("dim", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]),
+    ("dim", NumValidator, {"min_value": 1, "max_value": 512}, ["check_value"]),
     ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length", "check_whitelist"]),
     ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}),
     ("optimizer_list", ClassValidator, {"classes": (list, type(None))}),
-- Gitee

From e270597b5c79059f28a5e8347c7e8a18805c4990 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 22 Nov 2023 17:56:05 +0800
Subject: [PATCH 471/551] Match-id-3aae365cc29b1e5b4f39a8af15a7f86d6482c8e7

---
 mx_rec/constants/constants.py |  1 +
 mx_rec/core/embedding.py      | 10 ++++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 4eb18316..acf756db 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -52,6 +52,7 @@ TRAIN_CHANNEL_ID = 0
 EVAL_CHANNEL_ID = 1
 HASHTABLE_COLLECTION_NAME_LENGTH = 30
 MAX_VOCABULARY_SIZE = 10**10
+MAX_DEVICE_VOCABULARY_SIZE = 256 * (10 ** 5)
 
 # RANK INFO
 VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"]
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index bdb8ba5c..025c1d22 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -18,7 +18,8 @@ from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temp
 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
     ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, \
-    ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE
+    ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE, \
+    MAX_DEVICE_VOCABULARY_SIZE
 from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, get_customized_ops, \
     insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
     clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \
@@ -36,12 +37,13 @@ from mx_rec.util.log import logger
 @para_checker_decorator(check_option_list=[
     ("key_dtype", OptionValidator, {"options": (tf.int64, tf.int32, tf.string)}),
     ("dim", ClassValidator, {"classes": (int, tf.TensorShape)}),
-    ("dim", NumValidator, {"min_value": 1, "max_value": 512}, ["check_value"]),
+    ("dim", NumValidator, {"min_value": 1, "max_value": 8192}, ["check_value"]),
     ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length", "check_whitelist"]),
     ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}),
     ("optimizer_list", ClassValidator, {"classes": (list, type(None))}),
     (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator),
-    ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
+    ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_DEVICE_VOCABULARY_SIZE},
+     ["check_value"]),
     ("host_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
     ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]),
     ("ssd_data_path", ClassValidator, {"classes": (list, tuple)}),
@@ -383,7 +385,7 @@ class SparseEmbedding:
                 raise ValueError("Send count must be an integer which is larger than 0.")
 
         check_params()
-        if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > MAX_INT32:
+        if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > MAX_VOCABULARY_SIZE:
             raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table "
                              f"'{self.table_name}', in which slice_device_vocabulary_size was "
                              f"{self.slice_device_vocabulary_size} and slice_host_vocabulary_size was "
-- Gitee

From 7c249c2f66107b9a09cf38efbd403b4e47c77236 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 23 Nov 2023 14:27:28 +0800
Subject: [PATCH 472/551] Match-id-4441953a753e35d9507c7968f01808ccd248010b

---
 src/core/key_process/key_process.cpp | 14 ++++++++++++++
 src/core/utils/common.h              |  1 +
 2 files changed, 15 insertions(+)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index efa2abf2..26b13b9f 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -830,6 +830,17 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple
+void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
+{
+    for (auto& keys : splitKeys){
+        if (keys.size() % ALLTOALLVC_ALIGN == 0){
+            continue;
+        }
+        padding_size = ALLTOALLVC_ALIGN - (key.size() % ALLTOALLVC_ALIGN);
+        std::fill_n(keys.back(), padding_size, INVALID_KEY_VALUE);
+    }
+}
+
 auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const
     -> tuple, vector, vector>>
 {
@@ -857,6 +868,9 @@ auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const
         }
     }
 
+    if (!rankInfo.useStatic) {
+        PaddingAlltoallVC(splitKeys);
+    }
     // handle the count that corresponds to each splitKeys entry
     for (int j = 0; j < rankInfo.rankSize; ++j) {
         vector count;
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index e6ea6934..33486448 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -85,6 +85,7 @@ namespace MxRec {
     constexpr size_t DEFAULT_RANDOM_SEED = 10086;
 
     constexpr int INVALID_KEY_VALUE = -1;
+    constexpr int ALLTOALLVC_ALIGN = 128;
    constexpr int PROFILING_START_BATCH_ID = 100;
    constexpr int PROFILING_END_BATCH_ID = 200;
    constexpr int MGMT_THREAD_BIND = 48;
-- Gitee
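The helper added above does not compile as committed: `padding_size` is undeclared, `key.size()` should be `keys.size()`, and `std::fill_n(keys.back(), ...)` writes through a value instead of appending. Patches 473 through 477 below repair it step by step. The logic they converge on, restated as a Python sketch for reference (the shipped implementation is the C++ in those patches; the function name is transliterated):

    INVALID_KEY_VALUE = -1
    ALLTOALLVC_ALIGN = 128

    def padding_alltoallvc(split_keys):
        # Pad every per-rank key list up to a multiple of ALLTOALLVC_ALIGN,
        # which the patch applies only in the non-static (dynamic-shape)
        # path; the filler entries reuse the INVALID_KEY_VALUE sentinel.
        for keys in split_keys:
            remainder = len(keys) % ALLTOALLVC_ALIGN
            if remainder:
                keys.extend([INVALID_KEY_VALUE] * (ALLTOALLVC_ALIGN - remainder))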
From 6a9ec74461366c69094a28571f6a6fcd993dad97 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 23 Nov 2023 14:53:35 +0800
Subject: [PATCH 473/551] Match-id-eddc7d19aeee14340b9ca8654b8a8405afcd2dca

---
 src/core/key_process/key_process.cpp | 6 +++---
 src/core/key_process/key_process.h   | 2 ++
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 26b13b9f..a46eb73f 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -830,10 +830,10 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple
-void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
+void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
 {
-    for (auto& keys : splitKeys){
-        if (keys.size() % ALLTOALLVC_ALIGN == 0){
+    for (auto& keys : splitKeys) {
+        if (keys.size() % ALLTOALLVC_ALIGN == 0) {
             continue;
         }
         padding_size = ALLTOALLVC_ALIGN - (key.size() % ALLTOALLVC_ALIGN);
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index 79dd844e..dd2c1b34 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -219,6 +219,8 @@
         auto HotHashSplit(const unique_ptr& batch)
             -> tuple, vector, vector>;
 
+        void PaddingAlltoallVC(vector& splitKeys);
+
         auto HashSplitWithFAAE(const unique_ptr& batch) const
             -> tuple, vector, vector>>;
 
-- Gitee

From 9abdee26b598cbfde435019d63cc0da687e38d3e Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 23 Nov 2023 15:01:32 +0800
Subject: [PATCH 474/551] Match-id-9fc685d930f6f70a4527a1fe88e3440a6f94ff98

---
 src/core/key_process/key_process.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index a46eb73f..93f43479 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -836,7 +836,7 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
         if (keys.size() % ALLTOALLVC_ALIGN == 0) {
             continue;
         }
-        padding_size = ALLTOALLVC_ALIGN - (key.size() % ALLTOALLVC_ALIGN);
+        int padding_size = ALLTOALLVC_ALIGN - (key.size() % ALLTOALLVC_ALIGN);
         std::fill_n(keys.back(), padding_size, INVALID_KEY_VALUE);
     }
 }
-- Gitee

From 9f455f4fbe6e5c99780ab119cf71a13231bcdb70 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 23 Nov 2023 15:07:55 +0800
Subject: [PATCH 475/551] Match-id-6a84de2934c82b665e5fedb684f8f2da900ae7fd

---
 src/core/key_process/key_process.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 93f43479..8123b7e7 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -836,7 +836,7 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
         if (keys.size() % ALLTOALLVC_ALIGN == 0) {
             continue;
         }
-        int padding_size = ALLTOALLVC_ALIGN - (key.size() % ALLTOALLVC_ALIGN);
+        int padding_size = ALLTOALLVC_ALIGN - (keys.size() % ALLTOALLVC_ALIGN);
         std::fill_n(keys.back(), padding_size, INVALID_KEY_VALUE);
     }
 }
-- Gitee

From 1dea04c489c529bd31e6b653a0329f1b22d69680 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 23 Nov 2023 15:17:42 +0800
Subject: [PATCH 476/551] Match-id-35428995c4420515616e408aedde01de85d98aaf

---
 src/core/key_process/key_process.cpp | 3 ++-
 src/core/key_process/key_process.h   | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index 8123b7e7..87c29bba 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -830,7 +830,7 @@ auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple
-void KeyProcess::PaddingAlltoallVC(vector& splitKeys)
+void KeyProcess::PaddingAlltoallVC(vector& splitKeys) const
 {
     for (auto& keys :
splitKeys) { if (keys.size() % ALLTOALLVC_ALIGN == 0) { @@ -839,6 +839,7 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys) int padding_size = ALLTOALLVC_ALIGN - (keys.size() % ALLTOALLVC_ALIGN); std::fill_n(keys.back(), padding_size, INVALID_KEY_VALUE); } + return; } auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index dd2c1b34..99d2d853 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -219,7 +219,7 @@ namespace MxRec { auto HotHashSplit(const unique_ptr& batch) -> tuple, vector, vector>; - void PaddingAlltoallVC(vector& splitKeys); + void PaddingAlltoallVC(vector& splitKeys) const; auto HashSplitWithFAAE(const unique_ptr& batch) const -> tuple, vector, vector>>; -- Gitee From b96a64b7ad3ea0a3b7ea393f9a25c8a75c5cdd40 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 23 Nov 2023 15:34:34 +0800 Subject: [PATCH 477/551] Match-id-d0787a9117999c67408dea0c15839e068cff954b --- src/core/key_process/key_process.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 87c29bba..b3847f4d 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -837,7 +837,7 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys) const continue; } int padding_size = ALLTOALLVC_ALIGN - (keys.size() % ALLTOALLVC_ALIGN); - std::fill_n(keys.back(), padding_size, INVALID_KEY_VALUE); + std::fill_n(std::back_inserter(keys), padding_size, INVALID_KEY_VALUE); } return; } -- Gitee From 9f2f97db76cab2ae916668eee542962f296049d4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 23 Nov 2023 16:25:37 +0800 Subject: [PATCH 478/551] Match-id-71682af5b6d6ace27d70968fcd62e94425567799 --- src/core/checkpoint/checkpoint.cpp | 13 ++++++++----- src/core/checkpoint/checkpoint.h | 2 ++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index 4c86f0fb..b3d24467 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -197,20 +197,23 @@ void Checkpoint::SaveDataset(const vector& embNames, LOG_DEBUG("====Start getting data from handler to: {}", datasetDir); auto transData { dataHandler->GetDataset(saveDataType, embName) }; + LOG_DEBUG("====Start saving data to: {}", datasetDir); + WriteStream(transData, datasetDir, transData.datasetSize, saveDataType); + LOG_DEBUG("====Start saving data to: {}", attributeDir); + WriteStream(transData, attributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE); + // save embedding when dynamic expansion is open if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { auto embedPath { dataDir + dirSeparator + "embedding" }; auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; + auto embedAttributeDir { embedPath + dirSeparator + datasetName + to_string(rankId) + attribFileType}; auto embeddingSizeInfo = GetEmbeddingSize(embName); + transData.attribute = {transData.int64Arr.size(), static_cast(embeddingSizeInfo.extEmbSize), fourBytes}; MakeSaveDir(embedPath); LOG_DEBUG("====Start saving embedding data to: {}", embedPath); WriteEmbedding(transData, embedDatasetDir, embeddingSizeInfo.extEmbSize); + WriteStream(transData, embedAttributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE); } - - LOG_DEBUG("====Start saving data to: {}", 
datasetDir);
-        WriteStream(transData, datasetDir, transData.datasetSize, saveDataType);
-        LOG_DEBUG("====Start saving data to: {}", attributeDir);
-        WriteStream(transData, attributeDir, transData.attributeSize, CkptDataType::ATTRIBUTE);
         }
     }
 }
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
index b5aafbe1..2872b6eb 100644
--- a/src/core/checkpoint/checkpoint.h
+++ b/src/core/checkpoint/checkpoint.h
@@ -75,6 +75,8 @@ namespace MxRec {
         const int attribEmbDataInnerIdx { 1 };
         const int keyAddrElem { 2 };
 
+        const uint32_t fourBytes = 4;
+
         void SetDataHandler(CkptData& ckptData);
 
         void SetDataHandler(const vector& featureTypes);
-- Gitee

From a834b42f2e7eece9954ff330762ae1f89e82cb68 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 22 Nov 2023 15:17:56 +0800
Subject: [PATCH 479/551] Match-id-7e8e4b5217588f7fcc8afb93ff070e4ca2270b03

---
 mx_rec/constants/constants.py  |  6 ++++++
 mx_rec/core/asc/manager.py     | 27 ++++++++++++++++++++++++++-
 mx_rec/graph/modifier.py       |  4 ++--
 mx_rec/util/global_env_conf.py |  9 ++++++---
 4 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 4eb18316..c261c7f5 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -72,6 +72,10 @@ MAX_DEVICE_ID = 15
 # HDFS file system's file prefix
 HDFS_FILE_PREFIX = ["viewfs://", "hdfs://"]
 
+# GetNext op names
+ITERATOR_GET_NEXT = "IteratorGetNext"
+NPU_GET_NEXT = "npuGetNext"
+
 
 class BaseEnum(Enum):
     @classmethod
@@ -110,6 +114,7 @@ class EnvOption(Enum):
     USE_COMBINE_FAAE = "USE_COMBINE_FAAE"
     STAT_ON = "STAT_ON"
     RECORD_KEY_COUNT = "RECORD_KEY_COUNT"
+    ADD_CONTROL_EDGE = "ADD_CONTROL_EDGE"
 
 
 class DataName(Enum):
@@ -186,6 +191,7 @@ class TFDevice(Enum):
     CPU = "CPU"
     NPU = "NPU"
     GPU = "GPU"
+    NONE = "NONE"
 
 
 class Flag(Enum):
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 42d94476..11265169 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -9,8 +9,11 @@ from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitiali
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
     is_asc_manager_initialized, get_train_steps, get_eval_steps, get_save_steps, \
     export_table_instances, export_feature_spec, get_if_load, get_use_static, \
-    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num
+    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num, \
+    get_modify_graph
+from mx_rec.constants.constants import ITERATOR_GET_NEXT, NPU_GET_NEXT, TFDevice, EnvOption, Flag
 from mx_rec.core.asc.merge_table import find_dangling_table, should_skip
+from mx_rec.util.global_env_conf import global_env
 from mx_rec.util.log import logger
 
 
@@ -225,7 +228,29 @@ def initialize_emb_cache(table_info_list, threshold_list):
     logger.debug("threshold_values are %s.", threshold_list)
 
 
+def add_control_edge():
+    iterator_get_next_op = None
+    get_next_name = ITERATOR_GET_NEXT
+    if get_modify_graph():
+        get_next_name = NPU_GET_NEXT
+    for op in tf.compat.v1.get_default_graph().get_operations():
+        if get_next_name == op.name and ITERATOR_GET_NEXT == op.type:
+            iterator_get_next_op = op
+            break
+    logger.info("iterator_get_next_op: %s", iterator_get_next_op)
+    if not iterator_get_next_op:
+        return
+    for op in tf.compat.v1.get_default_graph().get_operations():
+        if "GetNext" == op.type:
+            if global_env.tf_device == TFDevice.NPU.value and
"merged" not in op.name: + continue + op._add_control_input(iterator_get_next_op) + logger.info("_add_control_input: %s", op) + + def start_asc_pipeline(): + if global_env.add_control_edge == Flag.TRUE.value: + add_control_edge() table_info_list = generate_table_info_list() threshold_list = generate_threshold_list() diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index ff2391f9..b05634e5 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -15,7 +15,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ - ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE + ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, NPU_GET_NEXT from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, \ insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ @@ -478,7 +478,7 @@ def update_iterator_getnext(get_next_op: Operation, tgt_dataset: DatasetV1Adapte set_initializer(is_training, new_iterator.initializer) else: new_iterator = tgt_dataset.make_one_shot_iterator() - new_batch = new_iterator.get_next() + new_batch = new_iterator.get_next(NPU_GET_NEXT) set_target_batch(is_training, new_batch) try: diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index f2b802e8..45e4cda6 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -34,6 +34,7 @@ class RecEnv: use_combine_faae: str stat_on: str record_key_count: str + add_control_edge: str def get_global_env_conf() -> RecEnv: @@ -48,7 +49,7 @@ def get_global_env_conf() -> RecEnv: ascend_visible_devices=os.getenv(EnvOption.ASCEND_VISIBLE_DEVICES.value), cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value), cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value), - tf_device=os.getenv(EnvOption.TF_DEVICE.value, TFDevice.NPU.value), + tf_device=os.getenv(EnvOption.TF_DEVICE.value, TFDevice.NONE.value), apply_gradients_strategy=os.getenv(EnvOption.APPLY_GRADIENTS_STRATEGY.value, ApplyGradientsStrategy.DIRECT_APPLY.value), acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"), @@ -61,7 +62,8 @@ def get_global_env_conf() -> RecEnv: glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value), use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value), stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value), - record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value) + record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value), + add_control_edge=os.getenv(EnvOption.ADD_CONTROL_EDGE.value, Flag.TRUE.value) ) return rec_env @@ -87,7 +89,8 @@ def get_global_env_conf() -> RecEnv: ("glog_stderrthreahold", OptionValidator, {"options": [i.value for i in list(RecCPPLogLevel)]}), ("use_combine_faae", OptionValidator, {"options": [i.value for i in list(Flag)]}), ("stat_on", OptionValidator, {"options": [i.value for i in list(Flag)]}), - ("record_key_count", OptionValidator, {"options": [i.value for i in list(Flag)]}) + ("record_key_count", OptionValidator, {"options": [i.value for i in list(Flag)]}), + ("add_control_edge", OptionValidator, {"options": [i.value for i in list(Flag)]}) ]) 
def check_env(**kwargs): pass -- Gitee From eb29c3af66e0d2b0d1e4c4f60bc7e2a382192de3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 23 Nov 2023 16:31:20 +0800 Subject: [PATCH 480/551] Match-id-9dd575956f2c36cedb85872f5e72a12500fb2563 --- mx_rec/saver/sparse.py | 21 +++++++++++++-------- src/core/checkpoint/checkpoint.cpp | 9 +++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 4f750a54..085b9f5c 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -95,17 +95,19 @@ class SparseProcessor: if table_instance.host_vocabulary_size != 0: out_dir = host_table_dir key, offset = self._get_hashmap(host_table_dir, True) - emb_data = self.get_embedding(device_table_dir, host_table_dir, True) + emb_data = self.get_embedding(device_table_dir, host_table_dir, True, + table_instance.use_dynamic_expansion) emb_data = emb_data[offset] else: out_dir = device_table_dir key, _ = self._get_hashmap(device_table_dir, False) - emb_data = self.get_embedding(device_table_dir, host_table_dir, False) + emb_data = self.get_embedding(device_table_dir, host_table_dir, False, + table_instance.use_dynamic_expansion) transformed_data = dict(zip(key[:], emb_data[:])) save_path = os.path.join(out_dir, self.export_name + ".npy") np.save(save_path, transformed_data) - def get_embedding(self, device_table_dir, host_table_dir, ddr): + def get_embedding(self, device_table_dir, host_table_dir, ddr, use_dynamic_expansion): emb_dir = os.path.join(device_table_dir, self.device_emb_dir) data_file, attribute_file = self._get_file_names(emb_dir) if not os.path.exists(data_file): @@ -113,17 +115,20 @@ class SparseProcessor: if not os.path.exists(attribute_file): raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.") - temp = self._get_shape_form_attrib(attribute_file, is_json=True) - data_shape = temp.pop(self.json_attrib_shape) - data_dtype = temp.pop(self.json_attrib_dtype) - emb_data = self._get_data(data_file, data_dtype, data_shape) + if use_dynamic_expansion: + device_attribute = self._get_shape_form_attrib(attribute_file, is_json=False) + data_shape = [device_attribute[0], device_attribute[1]] + else: + device_attribute = self._get_shape_form_attrib(attribute_file, is_json=True) + data_shape = device_attribute.pop(self.json_attrib_shape) + emb_data = self._get_data(data_file, np.float32, data_shape) if ddr: emb_dir = os.path.join(host_table_dir, self.host_emb_dir) data_file, attribute_file = self._get_file_names(emb_dir) host_attribute = self._get_shape_form_attrib(attribute_file, is_json=False) host_data_shape = [host_attribute[0], host_attribute[1]] - host_data = self._get_data(data_file, data_dtype, host_data_shape) + host_data = self._get_data(data_file, np.float32, host_data_shape) host_data = host_data[:, :data_shape[1]] emb_data = np.append(emb_data, host_data, axis=0) return emb_data diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp index b3d24467..79c11c97 100644 --- a/src/core/checkpoint/checkpoint.cpp +++ b/src/core/checkpoint/checkpoint.cpp @@ -204,11 +204,12 @@ void Checkpoint::SaveDataset(const vector& embNames, // save embedding when dynamic expansion is open if ((saveDataType == CkptDataType::NDDR_FEATMAP) && useDynamicExpansion) { - auto embedPath { dataDir + dirSeparator + "embedding" }; - auto embedDatasetDir { embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType }; - auto embedAttributeDir { embedPath + dirSeparator + datasetName + 
to_string(rankId) + attribFileType}; + string embedPath = dataDir + dirSeparator + "embedding"; + string embedDatasetDir = embedPath + dirSeparator + datasetName + to_string(rankId) + dataFileType; + string embedAttributeDir = embedPath + dirSeparator + datasetName + to_string(rankId) + attribFileType; auto embeddingSizeInfo = GetEmbeddingSize(embName); - transData.attribute = {transData.int64Arr.size(), static_cast(embeddingSizeInfo.extEmbSize), fourBytes}; + transData.attribute = {transData.int64Arr.size(), + static_cast(embeddingSizeInfo.extEmbSize), fourBytes}; MakeSaveDir(embedPath); LOG_DEBUG("====Start saving embedding data to: {}", embedPath); WriteEmbedding(transData, embedDatasetDir, embeddingSizeInfo.extEmbSize); -- Gitee From a6d29df6def660a3aa6a9292ebf72c2e3e8e1831 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 24 Nov 2023 10:20:48 +0800 Subject: [PATCH 481/551] Match-id-d82b14763df48b8358da287e387696fbb0702b00 --- mx_rec/saver/sparse.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index 085b9f5c..f4f9cdaa 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -60,7 +60,7 @@ class SparseProcessor: return data @staticmethod - def _get_shape_form_attrib(attribute_dir, is_json): + def _get_shape_from_attrib(attribute_dir, is_json): if is_json: try: with open(attribute_dir, "r") as fin: @@ -116,17 +116,17 @@ class SparseProcessor: raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.") if use_dynamic_expansion: - device_attribute = self._get_shape_form_attrib(attribute_file, is_json=False) + device_attribute = self._get_shape_from_attrib(attribute_file, is_json=False) data_shape = [device_attribute[0], device_attribute[1]] else: - device_attribute = self._get_shape_form_attrib(attribute_file, is_json=True) + device_attribute = self._get_shape_from_attrib(attribute_file, is_json=True) data_shape = device_attribute.pop(self.json_attrib_shape) emb_data = self._get_data(data_file, np.float32, data_shape) if ddr: emb_dir = os.path.join(host_table_dir, self.host_emb_dir) data_file, attribute_file = self._get_file_names(emb_dir) - host_attribute = self._get_shape_form_attrib(attribute_file, is_json=False) + host_attribute = self._get_shape_from_attrib(attribute_file, is_json=False) host_data_shape = [host_attribute[0], host_attribute[1]] host_data = self._get_data(data_file, np.float32, host_data_shape) host_data = host_data[:, :data_shape[1]] @@ -144,7 +144,7 @@ class SparseProcessor: if not os.path.exists(attribute_file): raise FileExistsError(f"hashmap attribute file {attribute_file} does not exist when reading.") - shape_data = self._get_shape_form_attrib(attribute_file, is_json=False) + shape_data = self._get_shape_from_attrib(attribute_file, is_json=False) if len(shape_data) < 2: raise ValueError(f"the attribute data from file {attribute_file} is invalid") data_shape = shape_data[:2] -- Gitee From ee3701e4be01a3ccfeb60635ae017738f617d50c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 24 Nov 2023 10:38:10 +0800 Subject: [PATCH 482/551] Match-id-931b0366ff441eb783c5c1c31395cb18b36c5a45 --- mx_rec/constants/constants.py | 1 - mx_rec/core/embedding.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 089f0a50..c261c7f5 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -52,7 +52,6 @@ TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 
HASHTABLE_COLLECTION_NAME_LENGTH = 30 MAX_VOCABULARY_SIZE = 10**10 -MAX_DEVICE_VOCABULARY_SIZE = 256 * (10 ** 5) # RANK INFO VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 025c1d22..2d4fde9f 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -18,8 +18,7 @@ from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temp from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \ ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, \ - ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE, \ - MAX_DEVICE_VOCABULARY_SIZE + ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ @@ -42,7 +41,7 @@ from mx_rec.util.log import logger ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}), ("optimizer_list", ClassValidator, {"classes": (list, type(None))}), (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator), - ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_DEVICE_VOCABULARY_SIZE}, + ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("host_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), -- Gitee From 9caac9d4bc482a948c4397bde46a62ba972405ab Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 25 Nov 2023 09:48:08 +0800 Subject: [PATCH 483/551] Match-id-14cdf46a010784c2cf032357a46506f33baba07a --- mx_rec/saver/saver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 889d804e..0f2f0096 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -215,7 +215,7 @@ class Saver(object): def _build_save(self): for var in self.var_list: - if os.getenv("TF_DEVICE", " ") == "NPU" and "merged" not in var.name: + if global_env.tf_device == TFDevice.NPU.value and "merged" not in var.name: continue table_instance = get_table_instance(var) @@ -263,7 +263,7 @@ class Saver(object): restore_host_data(reading_path) logger.info("host data was restored.") - if get_use_dynamic_expansion: + if get_use_dynamic_expansion(): # Data related to dynamic expansion needs to be restored only on the host side. 
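The parenthesis fix in patch 483 above is the classic truthy-callable bug: `get_use_dynamic_expansion` without parentheses is a function object and therefore always true, so the early return fired even when dynamic expansion was disabled. In miniature (the stub body is illustrative only):

    def get_use_dynamic_expansion() -> bool:
        return False  # stand-in for the real config lookup

    if get_use_dynamic_expansion:    # bug: tests the function object, always True
        print("always taken")
    if get_use_dynamic_expansion():  # fix: tests the actual flag value
        print("never taken here")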
return -- Gitee From c1625e07aff3a998d8911d8dec4b8f5927b79ea0 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 25 Nov 2023 12:05:14 +0800 Subject: [PATCH 484/551] Match-id-8f9908e1cb78fe70c506d695ff1db4b99b77f030 --- mx_rec/optimizers/__init__.py | 2 - mx_rec/optimizers/ftrl_t.py | 245 ------------------------------ mx_rec/optimizers/ftrl_t_dense.py | 179 ---------------------- 3 files changed, 426 deletions(-) delete mode 100644 mx_rec/optimizers/ftrl_t.py delete mode 100644 mx_rec/optimizers/ftrl_t_dense.py diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 8006213a..8589e5f8 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -9,8 +9,6 @@ __all__ = [ from mx_rec.optimizers.adagrad import create_hash_optimizer from mx_rec.optimizers.ftrl import create_hash_optimizer -from mx_rec.optimizers.ftrl_t import create_hash_optimizer -from mx_rec.optimizers.ftrl_t_dense import create_ftrl_dense_optimizer from mx_rec.optimizers.gradient_descent import create_hash_optimizer from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr from mx_rec.optimizers.lazy_adam import create_hash_optimizer diff --git a/mx_rec/optimizers/ftrl_t.py b/mx_rec/optimizers/ftrl_t.py deleted file mode 100644 index ec5c38ce..00000000 --- a/mx_rec/optimizers/ftrl_t.py +++ /dev/null @@ -1,245 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict - -import tensorflow as tf - -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import slot_creator - -from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var - - -def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl_t", **kwargs): - - return CustomizedFtrlT(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) - - -class CustomizedFtrlT(optimizer.Optimizer, CustomizedOptimizer): - name_counter = defaultdict(int) - - def __init__(self, learning_rate, use_locking=False, name="Ftrl_t", **kwargs): - self.optimizer_type = "ftrl" - super(CustomizedFtrlT, self)._get_name(name=name) - - self._learning_rate = learning_rate - self._alpha = kwargs.get("alpha", 0.06) - self._beta = kwargs.get("beta", 1.0) - self._lambda1 = kwargs.get("lambda1", 0.0) - self._lambda2 = kwargs.get("lambda2", 0.0) - self._epsilon = kwargs.get("epsilon", 0.0) - self._grad_factor = kwargs.get("grad_factor", 0.0) - self._z_name = kwargs.get("z_name", None) - self._n_name = kwargs.get("n_name", None) - self._g_name = kwargs.get("g_name", None) - self._learning_rate_tensor = None - self._alpha_tensor = None - self._beta_tensor = None - self._lambda1_tensor = None - self._lambda2_tensor = None - self._epsilon_tensor = None - self._grad_factor_tensor = None - super(CustomizedFtrlT, self).__init__(use_locking, self.unique_name) - - def initialize_slots(self, var, table_instance): - z = slot_creator.create_zeros_slot(var, self._name + "/" + "z") - n = slot_creator.create_zeros_slot(var, self._name + "/" + 
"n") - g = slot_creator.create_zeros_slot(var, self._name + "/" + "g") - w = slot_creator.create_zeros_slot(var, self._name + "/" + "w") - insert_removing_var_list(z.name) - insert_removing_var_list(n.name) - insert_removing_var_list(g.name) - insert_removing_var_list(w.name) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) - return [{"slot": z, "named_slot_key": named_slot_key, "slot_name": "z", "optimizer": self}, - {"slot": n, "named_slot_key": named_slot_key, "slot_name": "n", "optimizer": self}, - {"slot": g, "named_slot_key": named_slot_key, "slot_name": "g", "optimizer": self}, - {"slot": w, "named_slot_key": named_slot_key, "slot_name": "w", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - initial_z_value = 0.0 - initial_n_value = 0.0 - initial_g_value = 0.0 - initial_w_value = 0.0 - return [initial_z_value, initial_n_value, initial_g_value, initial_w_value] - - def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor( - self._learning_rate, name="learning_rate") - self._alpha_tensor = ops.convert_to_tensor(self._alpha, name="alpha") - self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta") - self._lambda1_tensor = ops.convert_to_tensor(self._lambda1, name="lambda1") - self._lambda2_tensor = ops.convert_to_tensor(self._lambda2, name="lambda2") - self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") - self._grad_factor_tensor = ops.convert_to_tensor(self._grad_factor, name="grad_factor") - - def _apply_sparse_duplicate_indices(self, grad, var): - return self._apply_sparse(grad, var) - - def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices): - return self._resource_apply_sparse(grad, handle, indices) - - def _resource_apply_sparse(self, grad, handle, indices): - if self._lambda1 > 1e-10: - return self._apply_sparse_shared( - grad, - handle, - indices, - self._resource_scatter_nd_update) - else: - return self._apply_sparse_shared_v2( - grad, - handle, - indices, - self._resource_scatter_nd_update) - - def _apply_sparse(self, grad, var): - if self._lambda1 > 1e-10: - return self._apply_sparse_shared( - grad.values, - var, - grad.indices, - lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) - else: - return self._apply_sparse_shared_v2( - grad.values, - var, - grad.indices, - lambda x, i, v: tf.compat.v1.scatter_nd_update(x, i, v)) - - def _apply_sparse_shared(self, grad, var, indices, scatter_nd_update): - z = self.get_slot(var, "z") - n = self.get_slot(var, "n") - g = self.get_slot(var, "g") - w = self.get_slot(var, "w") - alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) - beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) - lambda1 = math_ops.cast(self._lambda1_tensor, var.dtype.base_dtype) - lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) - epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) - grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) - - abs_indices = 
tf.math.maximum(indices, 0) - nd_indices = tf.expand_dims(indices, 1) - with tf.control_dependencies([grad]): - z_old = tf.gather(z, abs_indices) - n_old = tf.gather(n, abs_indices) - g_old = tf.gather(g, abs_indices) - var_old = tf.gather(w, abs_indices) - - g_new = grad_factor * g_old + (1.0 - grad_factor) * grad - with tf.control_dependencies([g_new]): - g_update = scatter_nd_update(g, nd_indices, g_new) - - rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) - z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) - with tf.control_dependencies([z_new]): - z_update = scatter_nd_update(z, nd_indices, z_new) - - n_new = (1.0 - epsilon) * n_old + tf.square(g_new) - with tf.control_dependencies([n_new]): - n_update = scatter_nd_update(n, nd_indices, n_new) - - denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2 - numerator = lambda1 * tf.sign(z_new) - z_new - mask = math_ops.cast(tf.math.greater(tf.abs(z_new), lambda1), var.dtype.base_dtype) - var_new = tf.multiply(mask, tf.divide(numerator, denominator)) - with tf.control_dependencies([var_new]): - w_update = scatter_nd_update(w, nd_indices, var_new) - var_update = scatter_nd_update(var, nd_indices, var_new) - - return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update) - - def _apply_sparse_shared_v2(self, grad, var, indices, scatter_nd_update): - z = self.get_slot(var, "z") - n = self.get_slot(var, "n") - g = self.get_slot(var, "g") - w = self.get_slot(var, "w") - alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) - beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) - lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) - epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) - grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) - - abs_indices = tf.math.maximum(indices, 0) - nd_indices = tf.expand_dims(indices, 1) - with tf.control_dependencies([grad]): - z_old = tf.gather(z, abs_indices) - n_old = tf.gather(n, abs_indices) - g_old = tf.gather(g, abs_indices) - var_old = tf.gather(w, abs_indices) - - g_new = grad_factor * g_old + (1.0 - grad_factor) * grad - with tf.control_dependencies([g_new]): - g_update = scatter_nd_update(g, nd_indices, g_new) - - rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) - z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) - with tf.control_dependencies([z_new]): - z_update = scatter_nd_update(z, nd_indices, z_new) - - n_new = (1.0 - epsilon) * n_old + tf.square(g_new) - with tf.control_dependencies([n_new]): - n_update = scatter_nd_update(n, nd_indices, n_new) - - denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2 - var_new = tf.divide(-1.0 * z_new, denominator) - with tf.control_dependencies([var_new]): - w_update = scatter_nd_update(w, nd_indices, var_new) - var_update = scatter_nd_update(var, nd_indices, var_new) - - return control_flow_ops.group(g_update, z_update, n_update, w_update, var_update) - - def _resource_scatter_nd_update(self, x, i, v): - with ops.control_dependencies([ - gen_state_ops.resource_scatter_nd_update(x.handle, i, v)]): - return x.value() - - def _create_slots(self, var_list): - - # Create slots for the first and second moments. 
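For anyone migrating off the removed `ftrl_t` optimizer, its per-row update can be restated in NumPy; this is a paraphrase of the deleted TF1 code with its default hyperparameters, not a drop-in replacement:

    import numpy as np

    def ftrl_t_row_update(w, z, n, g, grad, alpha=0.06, beta=1.0,
                          lambda1=0.0, lambda2=0.0, epsilon=0.0, grad_factor=0.0):
        # Smoothed gradient, accumulator decay, and the two weight branches
        # mirror _apply_sparse_shared (lambda1 > 0) and _apply_sparse_shared_v2.
        g_new = grad_factor * g + (1.0 - grad_factor) * grad
        rho = (np.sqrt(n + g_new ** 2) - np.sqrt(n)) / alpha
        z_new = (1.0 - epsilon) * z + g_new - rho * w
        n_new = (1.0 - epsilon) * n + g_new ** 2
        denominator = (beta + np.sqrt(n_new)) / alpha + lambda2
        if lambda1 > 1e-10:
            mask = (np.abs(z_new) > lambda1).astype(np.float32)
            w_new = mask * (lambda1 * np.sign(z_new) - z_new) / denominator
        else:
            w_new = -z_new / denominator
        return w_new, z_new, n_new, g_new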
- z_state_name = self._name + "/" + "z" - n_state_name = self._name + "/" + "n" - g_state_name = self._name + "/" + "g" - w_state_name = self._name + "/" + "w" - for each_var in var_list: - with ops.colocate_with(each_var): - table_instance = check_and_get_config_via_var(each_var, self.optimizer_type) - - z = self._zeros_slot(each_var, "z", z_state_name) - n = self._zeros_slot(each_var, "n", n_state_name) - g = self._zeros_slot(each_var, "g", g_state_name) - w = self._zeros_slot(each_var, "w", w_state_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - insert_removing_var_list(z.name) - insert_removing_var_list(n.name) - insert_removing_var_list(g.name) - insert_removing_var_list(w.name) - - if self._name not in table_instance.optimizer: - table_instance.set_optimizer(self._name, {"z": z, "n": n, "g": g, "w": w}) diff --git a/mx_rec/optimizers/ftrl_t_dense.py b/mx_rec/optimizers/ftrl_t_dense.py deleted file mode 100644 index a59e793c..00000000 --- a/mx_rec/optimizers/ftrl_t_dense.py +++ /dev/null @@ -1,179 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict - -import tensorflow as tf - -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.training import optimizer - -from mx_rec.util.initialize import insert_removing_var_list - - -def create_ftrl_dense_optimizer(learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): - return CustomizedFtrlTZ(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) - - -class CustomizedFtrlTZ(optimizer.Optimizer): - name_counter = defaultdict(int) - - def __init__(self, learning_rate, use_locking=False, name="Ftrl_t_dense", **kwargs): - self.optimizer_type = "ftrl" - self._learning_rate = learning_rate - self._alpha = kwargs.get("alpha", 0.06) - self._beta = kwargs.get("beta", 1.0) - self._lambda1 = kwargs.get("lambda1", 0.0) - self._lambda2 = kwargs.get("lambda2", 0.0) - self._epsilon = kwargs.get("epsilon", 0.0) - self._grad_factor = kwargs.get("grad_factor", 0.0) - self._z_name = kwargs.get("z_name", None) - self._n_name = kwargs.get("n_name", None) - self._g_name = kwargs.get("g_name", None) - self._learning_rate_tensor = None - self._alpha_tensor = None - self._beta_tensor = None - self._lambda1_tensor = None - self._lambda2_tensor = None - self._epsilon_tensor = None - self._grad_factor_tensor = None - super(CustomizedFtrlTZ, self).__init__(use_locking, name) - - def _prepare(self): - self._learning_rate_tensor = ops.convert_to_tensor( - self._learning_rate, name="learning_rate") - self._alpha_tensor = ops.convert_to_tensor(self._alpha, name="alpha") - self._beta_tensor = ops.convert_to_tensor(self._beta, name="beta") - self._lambda1_tensor = ops.convert_to_tensor(self._lambda1, name="lambda1") - self._lambda2_tensor = ops.convert_to_tensor(self._lambda2, name="lambda2") - self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon") - self._grad_factor_tensor = ops.convert_to_tensor(self._grad_factor, name="grad_factor") - - def _resource_apply_dense(self, grad, handle): - if self._lambda1 > 1e-10: - return self._apply_dense_shared( - grad, - handle) - else: - return 
self._apply_dense_shared_v2( - grad, - handle) - - def _apply_dense(self, grad, var): - if self._lambda1 > 1e-10: - return self._apply_dense_shared( - grad, - var) - else: - return self._apply_dense_shared_v2( - grad, - var) - - def _apply_dense_shared(self, grad, var): - z_var = self.get_slot(var, "z") - n_var = self.get_slot(var, "n") - g_var = self.get_slot(var, "g") - w_var = self.get_slot(var, "w") - alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) - beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) - lambda1 = math_ops.cast(self._lambda1_tensor, var.dtype.base_dtype) - lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) - epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) - grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) - - z_old = tf.identity(z_var) - n_old = tf.identity(n_var) - g_old = tf.identity(g_var) - var_old = tf.identity(w_var) - - g_new = grad_factor * g_old + (1.0 - grad_factor) * grad - with tf.control_dependencies([g_new]): - g_update = tf.compat.v1.assign(g_var, g_new) - - rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) - z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) - with tf.control_dependencies([z_new]): - z_update = tf.compat.v1.assign(z_var, z_new) - - n_new = (1.0 - epsilon) * n_old + tf.square(g_new) - with tf.control_dependencies([n_new]): - n_update = tf.compat.v1.assign(n_var, n_new) - - denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2 - numerator = lambda1 * tf.sign(z_new) - z_new - mask = math_ops.cast(tf.math.greater(tf.abs(z_new), lambda1), var.dtype.base_dtype) - var_new = tf.multiply(mask, tf.divide(numerator, denominator)) - with tf.control_dependencies([var_new]): - w_update = tf.compat.v1.assign(w_var, var_new) - var_updata = tf.compat.v1.assign(var, var_new) - - return control_flow_ops.group(g_update, z_update, n_update, w_update, var_updata) - - def _apply_dense_shared_v2(self, grad, var): - z_var = self.get_slot(var, "z") - n_var = self.get_slot(var, "n") - g_var = self.get_slot(var, "g") - w_var = self.get_slot(var, "w") - alpha = math_ops.cast(self._alpha_tensor, var.dtype.base_dtype) - beta = math_ops.cast(self._beta_tensor, var.dtype.base_dtype) - lambda2 = math_ops.cast(self._lambda2_tensor, var.dtype.base_dtype) - epsilon = math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype) - grad_factor = math_ops.cast(self._grad_factor_tensor, var.dtype.base_dtype) - - z_old = tf.identity(z_var) - n_old = tf.identity(n_var) - g_old = tf.identity(g_var) - var_old = tf.identity(w_var) - - g_new = grad_factor * g_old + (1.0 - grad_factor) * grad - with tf.control_dependencies([g_new]): - g_updata = tf.compat.v1.assign(g_var, g_new) - - rho = tf.divide(tf.sqrt(n_old + tf.square(g_new)) - tf.sqrt(n_old), alpha) - z_new = (1.0 - epsilon) * z_old + g_new - tf.multiply(rho, var_old) - with tf.control_dependencies([z_new]): - z_updata = tf.compat.v1.assign(z_var, z_new) - - n_new = (1.0 - epsilon) * n_old + tf.square(g_new) - with tf.control_dependencies([n_new]): - n_updata = tf.compat.v1.assign(n_var, n_new) - - denominator = tf.divide((beta + tf.sqrt(n_new)), alpha) + lambda2 - var_new = tf.divide(-1.0 * z_new, denominator) - with tf.control_dependencies([var_new]): - w_updata = tf.compat.v1.assign(w_var, var_new) - var_updata = tf.compat.v1.assign(var, var_new) - - return control_flow_ops.group(g_updata, z_updata, n_updata, w_updata, var_updata) - - def _resource_scatter_nd_update(self, x_input, 
i_input, v_input): - with ops.control_dependencies([ - gen_state_ops.resource_scatter_nd_update(x_input.handle, i_input, v_input)]): - return x_input.value() - - def _create_slots(self, var_list): - - # Create slots for the first and second moments. - z_state_name = f"{self._name}/z" - n_state_name = f"{self._name}/n" - g_state_name = f"{self._name}/g" - w_state_name = f"{self._name}/w" - for each_var in var_list: - with ops.colocate_with(each_var): - z_zero = self._zeros_slot(each_var, "z", z_state_name) - n_zero = self._zeros_slot(each_var, "n", n_state_name) - g_zero = self._zeros_slot(each_var, "g", g_state_name) - w_zero = self._zeros_slot(each_var, "w", w_state_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - insert_removing_var_list(z_zero.name) - insert_removing_var_list(n_zero.name) - insert_removing_var_list(g_zero.name) - insert_removing_var_list(w_zero.name) -- Gitee From b0428b564f407cc463f085ec36268258d746fa21 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 25 Nov 2023 16:03:40 +0800 Subject: [PATCH 485/551] Match-id-bb918178a0c68444339d78c30343ecc40327b3cd --- mx_rec/optimizers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index 8589e5f8..de88c7e2 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -3,7 +3,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. __all__ = [ - "create_hash_optimizer", "create_ftrl_dense_optimizer", "create_hash_optimizer_by_addr", + "create_hash_optimizer", "create_hash_optimizer_by_addr", "create_hash_optimizer_by_address" ] -- Gitee From 224b2a4b1ef38a7c76f6b53224c6092e48f366ac Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 27 Nov 2023 17:47:27 +0800 Subject: [PATCH 486/551] Match-id-9ecbeacf165153d82b6b3e793e6f63eb07c7c30c --- mx_rec/optimizers/ftrl.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 827b04d1..8cda3643 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -130,7 +130,11 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): linear = self.get_slot(var, "linear") lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) l1 = math_ops.cast(self._l1_regularization_strength_tensor, var.dtype.base_dtype) - l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype) + + if tf.__version__.startswith("1"): + l2 = math_ops.cast(self._l2_regularization_strength_tensor, var.dtype.base_dtype) + else: + l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype) lr_power = math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype) abs_indices = tf.math.maximum(indices, 0) @@ -167,7 +171,10 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer): linear = self.get_slot(var, "linear") lr = math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype) l1 = math_ops.cast(self._l1_regularization_strength_tensor, var.dtype.base_dtype) - l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype) + if tf.__version__.startswith("1"): + l2 = math_ops.cast(self._l2_regularization_strength_tensor, var.dtype.base_dtype) + else: + l2 = math_ops.cast(self._adjusted_l2_regularization_strength_tensor, var.dtype.base_dtype) lr_power = math_ops.cast(self._learning_rate_power_tensor, var.dtype.base_dtype) 
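The version guard added above is needed because the attribute holding the L2 strength tensor differs across TensorFlow releases. A version-agnostic fallback could look like the sketch below (the helper name cast_l2_strength is hypothetical and not part of this patch; the two attribute names are the ones the patch itself selects between):

    from tensorflow.python.ops import math_ops

    def cast_l2_strength(opt, var):
        # Prefer the adjusted tensor kept by newer FtrlOptimizer builds and
        # fall back to the raw TF1 attribute when it is absent.
        l2_tensor = getattr(opt, "_adjusted_l2_regularization_strength_tensor", None)
        if l2_tensor is None:
            l2_tensor = opt._l2_regularization_strength_tensor
        return math_ops.cast(l2_tensor, var.dtype.base_dtype)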
l2_shrinkage = math_ops.cast(self._l2_shrinkage_regularization_strength_tensor, var.dtype.base_dtype) -- Gitee From 5ea44a49ae8d2480d8276aa90f45f485864598f4 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 27 Nov 2023 19:35:27 +0800 Subject: [PATCH 487/551] Match-id-79d71a725755c5074827f95c847a8a3b1e06c68a --- .../op_host/embedding_lookup_by_address.cpp | 79 ++++++------ .../embedding_lookup_by_address_tiling.h | 19 +-- .../op_host/embedding_update_by_address.cpp | 82 ++++++------ .../embedding_update_by_address_tiling.h | 15 ++- .../op_kernel/embedding_lookup_by_address.cpp | 122 +++++++++--------- .../op_kernel/embedding_update_by_address.cpp | 113 ++++++++-------- 6 files changed, 211 insertions(+), 219 deletions(-) diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 763ecc8f..8e8379e0 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -4,6 +4,14 @@ namespace optiling { + constexpr int32_t BLOCK_DIM = 48; // 910b一张卡48个vector核 + constexpr int32_t SIZE_OF_HALF = 2; + constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; + constexpr int32_t MIN_BLOCK_SIZE = 32; // ub空间的数据都要按照32对齐 + constexpr int32_t UB_LIMIT = 175 * 1024; + constexpr int32_t USR_SIZE = 256; + constexpr int32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; + constexpr int32_t PING_PONG_NUM = 1; template static ge::graphStatus CheckNullPointer(T *pointer, const char *errorMessage) @@ -18,23 +26,16 @@ namespace optiling static ge::graphStatus TilingFunc(gert::TilingContext *context) { - TilingData1 tiling; - - size_t usrSize = 256; - size_t sysWorkspaceSize = 16 * 1024 * 1024; if (CheckNullPointer(context, "Tiling context") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - size_t *currentWorkspace = context->GetWorkspaceSizes(1); + size_t *currentWorkspace = context->GetWorkspaceSizes(1); // 设备侧Global Memory上的一块内存 if (CheckNullPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } + currentWorkspace[0] = SYS_WORKSPACE_SIZE + USR_SIZE; - currentWorkspace[0] = sysWorkspaceSize + usrSize; - - int32_t blockTotalNums = 48; - int32_t ubLimit = 175 * 1024; auto *attrs = context->GetAttrs(); if (CheckNullPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; @@ -43,10 +44,9 @@ namespace optiling if (CheckNullPointer(attr0Value, " Lookup embbedingType attr0Value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - - int32_t embbedingDim = *attr0Value; - if (embbedingDim <= 0) { - printf("embbedingDim must larger than 0\n"); + int32_t embeddingDim = *attr0Value; + if (embeddingDim <= 0) { + printf("embeddingDim must larger than 0\n"); return ge::GRAPH_FAILED; } @@ -54,8 +54,10 @@ namespace optiling if (CheckNullPointer(attr1Value, "Lookup embbedingType attr1Value") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } - - int32_t embbedingType = *attr1Value; + int32_t embeddingType = *attr1Value; // 0:int32; 1:float; 2:half + if (embeddingType > 2 || embeddingType < 0) { + return ge::GRAPH_FAILED; + } auto inputTensor = context->GetInputTensor(0); if (CheckNullPointer(inputTensor, "inputTensor") != ge::GRAPH_SUCCESS) { @@ -63,40 +65,37 @@ namespace optiling } int32_t inputShape = inputTensor->GetShapeSize(); - int32_t singleDataSize = 4; - if (embbedingType == 2) { - singleDataSize = 2; - } - int32_t minMoveNum = 32 / singleDataSize; - - // onceMoveNums,(embbedingDim - 1 + minMoveNum) / 
min_move_num表示除以min_move_num向下取整 - int32_t onceMoveNums = minMoveNum * ((embbedingDim - 1 + minMoveNum) / minMoveNum); - int32_t numToMove = (embbedingDim - 1 + onceMoveNums) / onceMoveNums; - // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int32_t pingPongNum = 1; + int32_t typeSize = SIZE_OF_FLOAT_OR_INT; + if (embeddingType == 2) { + typeSize = SIZE_OF_HALF; + } + // shape需要对齐到的最小单位, MIN_BLOCK_SIZE=32 + int32_t alignNum = MIN_BLOCK_SIZE / typeSize; + // embeddingDimAligned,表示需要向上对齐到最小单位 + int32_t embeddingDimAligned = ((embeddingDim - 1 + alignNum) / alignNum) * alignNum; + // 每个地址需要占用sizeof(int64_t)个字节,typeSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 int32_t occupyAddressBytesNum = - sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * pingPongNum * 2; - // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int32_t addrMaxNum = (((ubLimit / occupyAddressBytesNum) / 4)) * 4; - if (addrMaxNum <= 0) { + sizeof(int64_t) + typeSize * embeddingDimAligned * PING_PONG_NUM * 2; + // 一轮计算中最多计算多少个addr,由于地址也要搬到ub,所以需要对齐32, + int32_t addrPerLoop = (UB_LIMIT / occupyAddressBytesNum) & (~3); // & (~3),保证地址数是4的倍数 + if (addrPerLoop <= 0) { return ge::GRAPH_FAILED; } - - tiling.set_embbeding_type(embbedingType); - tiling.set_update_dim(embbedingDim); + TilingData1 tiling; + tiling.set_ping_pong_num(PING_PONG_NUM); tiling.set_addr_nums(inputShape); - tiling.set_ub_limit(ubLimit); + tiling.set_embedding_type(embeddingType); + tiling.set_embedding_dim(embeddingDim); - tiling.set_addr_max_num(addrMaxNum); - tiling.set_ping_pong_num(pingPongNum); - tiling.set_single_data_size(singleDataSize); - tiling.set_once_move_nums(onceMoveNums); + tiling.set_addr_per_loop(addrPerLoop); + tiling.set_type_size(typeSize); + tiling.set_emb_dim_aligned(embeddingDimAligned); - context->SetBlockDim(blockTotalNums); + // 和tiling set 区别开, BlockDim就是BlockNum,可以理解为卡的核数 + context->SetBlockDim(BLOCK_DIM); tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); - return ge::GRAPH_SUCCESS; } } diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h index 596bd715..2a9d8951 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h @@ -1,19 +1,20 @@ +#ifndef EMBEDDING_LOOKUP_BY_ADDRESS_TILING_H +#define EMBEDDING_LOOKUP_BY_ADDRESS_TILING_H #include "register/tilingdata_base.h" namespace optiling { BEGIN_TILING_DATA_DEF(TilingData1) - TILING_DATA_FIELD_DEF(int32_t, update_dim); - TILING_DATA_FIELD_DEF(int32_t, addr_nums); - TILING_DATA_FIELD_DEF(int32_t, ub_limit); - TILING_DATA_FIELD_DEF(int32_t, embbeding_type); - TILING_DATA_FIELD_DEF(int32_t, update_type); - TILING_DATA_FIELD_DEF(int32_t, addr_max_num); TILING_DATA_FIELD_DEF(int32_t, ping_pong_num); - TILING_DATA_FIELD_DEF(int32_t, single_data_size); - TILING_DATA_FIELD_DEF(int32_t, once_move_nums); + TILING_DATA_FIELD_DEF(int32_t, addr_nums); + TILING_DATA_FIELD_DEF(int32_t, embedding_type); + TILING_DATA_FIELD_DEF(int32_t, embedding_dim); + TILING_DATA_FIELD_DEF(int32_t, addr_per_loop); + TILING_DATA_FIELD_DEF(int32_t, type_size); + TILING_DATA_FIELD_DEF(int32_t, emb_dim_aligned); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingLookupByAddress, TilingData1) -} \ No newline at end of file +} +#endif // 
EMBEDDING_LOOKUP_BY_ADDRESS_TILING_H \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index ee4dfea2..5a5cc953 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -4,6 +4,14 @@ namespace optiling { + constexpr int32_t BLOCK_DIM = 48; // 910b一张卡48个vector核 + constexpr int32_t SIZE_OF_HALF = 2; + constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; + constexpr int32_t MIN_BLOCK_SIZE = 32; // ub空间的数据都要按照32对齐 + constexpr int32_t UB_LIMIT = 175 * 1024; + constexpr int32_t USR_SIZE = 256; + constexpr int32_t SYS_WORKSPACE_SIZE = 16 * 1024 * 1024; + constexpr int32_t PING_PONG_NUM = 1; template static ge::graphStatus CheckPointer(T *pointer, const char *errorMessage) @@ -28,32 +36,30 @@ namespace optiling static ge::graphStatus TilingFunc(gert::TilingContext *context) { - TilingData2 tiling; - - size_t usrSize = 256, sysWorkspaceSize = 16 * 1024 * 1024; - if (CheckPointer(context, "Update embbedingType context") != ge::GRAPH_SUCCESS) + if (CheckPointer(context, "Update embeddingType context") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } size_t *currentWorkspace = context->GetWorkspaceSizes(1); - if (CheckPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) + if (CheckPointer(currentWorkspace, "currentWorkspace") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; - - currentWorkspace[0] = sysWorkspaceSize + usrSize; - - int32_t blockTotalNums = 48; - int32_t ubLimit = 175 * 1024; + } + currentWorkspace[0] = SYS_WORKSPACE_SIZE + USR_SIZE; auto inputTensor = context->GetInputTensor(0); - if (CheckPointer(inputTensor, "GetInputTensor inputTensor") != ge::GRAPH_SUCCESS) + if (CheckPointer(inputTensor, "GetInputTensor inputTensor") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } int32_t inputShape = inputTensor->GetShapeSize(); - if (CheckPositiveInt(inputShape, "inputShape") != ge::GRAPH_SUCCESS) + if (CheckPositiveInt(inputShape, "inputShape") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } auto inputTensor1 = context->GetInputTensor(1); - if (CheckPointer(inputTensor1, "GetInputTensor inputTensor1") != ge::GRAPH_SUCCESS) + if (CheckPointer(inputTensor1, "GetInputTensor inputTensor1") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } int32_t inputDim = inputTensor1->GetShapeSize() / inputShape; if (CheckPositiveInt(inputDim, "inputDim") != ge::GRAPH_SUCCESS) { @@ -61,56 +67,54 @@ namespace optiling } auto attrs = context->GetAttrs(); - if (CheckPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) + if (CheckPointer(attrs, "GetAttrs attrs") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } auto attrPointer = attrs->GetAttrPointer(0); - if (CheckPointer(attrPointer, "attrPointer") != ge::GRAPH_SUCCESS) + if (CheckPointer(attrPointer, "attrPointer") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; + } int32_t updateType = *(attrPointer); ge::DataType inputDatatype = inputTensor1->GetDataType(); - int32_t embbedingType; + int32_t embeddingType; if (inputDatatype == ge::DT_FLOAT16) { - embbedingType = 2; + embeddingType = 2; } else if (inputDatatype == ge::DT_INT32) { - embbedingType = 0; + embeddingType = 0; } else { - embbedingType = 1; + embeddingType = 1; } - int32_t singleDataSize = 4; - if (embbedingType == 2) { - singleDataSize = 2; + int32_t typeSize = SIZE_OF_FLOAT_OR_INT; + if (embeddingType == 2) { + typeSize = SIZE_OF_HALF; } - int32_t minMoveNum = 
32 / singleDataSize; + int32_t alignNum = MIN_BLOCK_SIZE / typeSize; - // onceMoveNums,(updateDim - 1 + minMoveNum) / min_move_num表示除以min_move_num向下取整 - int32_t onceMoveNums = minMoveNum * ((inputDim - 1 + minMoveNum) / minMoveNum); + int32_t inputDimAligned = alignNum * ((inputDim - 1 + alignNum) / alignNum); - int32_t numToMove = (inputDim - 1 + onceMoveNums) / onceMoveNums; // 每个地址需要占用sizeof(int64_t)个字节,singleDataSize表示每个数据的字节数,需要使用2倍的内存空间,因为每次移动都需要复制一份数据 - int32_t pingPongNum = 1; int32_t occupyAddressBytesNum = - sizeof(int64_t) + singleDataSize * onceMoveNums * numToMove * pingPongNum * 2; - // 计算一轮计算中最多计算多少个addr,最后的 /4 再*4 是为了与32对齐,因为sizeof(int64_t) = 8 - int32_t addrMaxNum = ((int)((int)(ubLimit / occupyAddressBytesNum) / 4)) * 4; - if (CheckPositiveInt(addrMaxNum, "addrMaxNum") != ge::GRAPH_SUCCESS) { + sizeof(int64_t) + typeSize * inputDimAligned * PING_PONG_NUM * 2; + // 一轮计算中最多计算多少个addr,由于地址也要搬到ub,所以需要对齐32 + int32_t addrPerLoop = (UB_LIMIT / occupyAddressBytesNum) & (~3); // & (~3),保证地址数是4的倍数 + if (CheckPositiveInt(addrPerLoop, "addrPerLoop") != ge::GRAPH_SUCCESS) { return ge::GRAPH_FAILED; } + TilingData2 tiling; + tiling.set_ping_pong_num(PING_PONG_NUM); tiling.set_update_type(updateType); - tiling.set_embbeding_type(embbedingType); + tiling.set_embedding_type(embeddingType); tiling.set_update_dim(inputDim); tiling.set_addr_nums(inputShape); - tiling.set_ub_limit(ubLimit); - - tiling.set_addr_max_num(addrMaxNum); - tiling.set_ping_pong_num(pingPongNum); - tiling.set_single_data_size(singleDataSize); - tiling.set_once_move_nums(onceMoveNums); + tiling.set_addr_per_loop(addrPerLoop); + tiling.set_type_size(typeSize); + tiling.set_input_dim_aligned(inputDimAligned); - context->SetBlockDim(blockTotalNums); + context->SetBlockDim(BLOCK_DIM); tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h index 9cd630a1..7bbeb16c 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h @@ -1,4 +1,5 @@ - +#ifndef EMBEDDING_UPDATE_BY_ADDRESS_TILING_H +#define EMBEDDING_UPDATE_BY_ADDRESS_TILING_H #include "register/tilingdata_base.h" namespace optiling @@ -6,15 +7,15 @@ namespace optiling BEGIN_TILING_DATA_DEF(TilingData2) TILING_DATA_FIELD_DEF(int32_t, update_dim); TILING_DATA_FIELD_DEF(int32_t, addr_nums); - TILING_DATA_FIELD_DEF(int32_t, ub_limit); - TILING_DATA_FIELD_DEF(int32_t, embbeding_type); + TILING_DATA_FIELD_DEF(int32_t, embedding_type); TILING_DATA_FIELD_DEF(int32_t, update_type); - TILING_DATA_FIELD_DEF(int32_t, addr_max_num); + TILING_DATA_FIELD_DEF(int32_t, addr_per_loop); TILING_DATA_FIELD_DEF(int32_t, ping_pong_num); - TILING_DATA_FIELD_DEF(int32_t, single_data_size); - TILING_DATA_FIELD_DEF(int32_t, once_move_nums); + TILING_DATA_FIELD_DEF(int32_t, type_size); + TILING_DATA_FIELD_DEF(int32_t, input_dim_aligned); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(EmbeddingUpdateByAddress, TilingData2) -} \ No newline at end of file +} +#endif // EMBEDDING_UPDATE_BY_ADDRESS_TILING_H \ No newline at end of file diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 285519b0..19a384c0 100644 --- 
a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -1,5 +1,10 @@ #include "kernel_operator.h" using namespace AscendC; + +constexpr int32_t SIZE_OF_HALF = 2; +constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; +constexpr int32_t PADDING_ZERO_NUM_PER_TIME = 8; + template class KernelEimtable { @@ -9,112 +14,101 @@ public: } __aicore__ inline void Init(GM_ADDR address, GM_ADDR y) { - - NeedComputeAddrLen = SingleCoreAddrLen; - if (block_idx == block_num - 1) + needComputeAddrLen = singleCoreAddrLen; + if (block_idx == block_num - 1) // 最后一个core,需要多计算的addr长度 { - NeedComputeAddrLen = addrNums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); } - round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); + loopCount = needComputeAddrLen / (addrNumPerLoop * sizeof(int64_t)); // 可能为0 + // pipe alloc memory to queue, the unit is Bytes - pipe.InitBuffer(tbuf, roundSize * sizeof(int64_t)); + pipe.InitBuffer(tbuf, addrNumPerLoop * sizeof(int64_t)); - pipe.InitBuffer(inQueue, PingpongNum, Veclen); - pipe.InitBuffer(outQueue, PingpongNum, Veclen); // + pipe.InitBuffer(inQueue, pingpongNum, veclen); + pipe.InitBuffer(outQueue, pingpongNum, veclen); // - // get start index for current core, core parallel block_indx block_dim - srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * SingleCoreAddrLen)); + // get start index for current core, core parallel block_indx block_dim,即使是最后一个核也应该多初始化一些,并对齐4的倍数 + srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * singleCoreAddrLen), needComputeAddrLen); dstDataGm.SetGlobalBuffer((__gm__ T *)(y)); } __aicore__ inline void Init_param(GM_ADDR tiling) { GET_TILING_DATA(constData, tiling); - // 数据的维度数 - dim = constData.update_dim; - int32_t blockTotalNums = block_num; - addrNums = constData.addr_nums; - // 缓冲区数量 - PingpongNum = constData.ping_pong_num; - singleDataSize = constData.single_data_size; - onceMoveNums = constData.once_move_nums; - roundSize = constData.addr_max_num; - - int singleNum = (int)(addrNums / blockTotalNums); - if (singleNum % 4) - { - singleNum -= singleNum % 4; - } - Veclen = roundSize * singleDataSize * onceMoveNums; - SingleCoreAddrLen = singleNum * sizeof(int64_t); - cache = roundSize; + pingpongNum = constData.ping_pong_num; + addrNums = constData.addr_nums; + dim = constData.embedding_dim; + addrNumPerLoop = constData.addr_per_loop; + typeSize = constData.type_size; + embDimAligned = constData.emb_dim_aligned; + + int singleCoreAddrNum = (int)(addrNums / block_num); // 有可能没有整除,最后的核会处理更多的数据 + singleCoreAddrNum = singleCoreAddrNum & (~3); // & (~3) 代表取4的倍数向下取整,处理的地址占8字节,对齐32B的话,数量需要是4倍数 + ASSERT(singleCoreAddrNum != 0 && "single num can not be zero!"); + + singleCoreAddrLen = singleCoreAddrNum * sizeof(int64_t); + veclen = addrNumPerLoop * typeSize * embDimAligned; // 向上对齐32B + cache = constData.addr_per_loop; } __aicore__ inline void Process() { - LocalTensor srcAddrLocal = tbuf.Get(roundSize); + LocalTensor srcAddrLocal = tbuf.Get(addrNumPerLoop); - if (round > 0) + if (loopCount > 0) { - for (int32_t i = 0; i < round; i++) + for (int32_t i = 0; i < loopCount; i++) { - DataCopy(srcAddrLocal, srcAddrGlobal[i * roundSize], roundSize); - MoveProcess(srcAddrLocal, i, roundSize); + DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); + MoveProcess(srcAddrLocal, i, addrNumPerLoop); } } - - int unProcess = 
(NeedComputeAddrLen / sizeof(int64_t)) % roundSize; + // 处理最后一张卡剩下的addr + int unProcess = (needComputeAddrLen / sizeof(int64_t)) % addrNumPerLoop; if (unProcess) { - // 处理 addresslist 不对齐32b - int unProcessOnceCopyAddr = unProcess; - if (unProcessOnceCopyAddr % 4 != 0) - { - unProcessOnceCopyAddr += (4 - unProcess % 4); - } - - DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unProcessOnceCopyAddr); - MoveProcess(srcAddrLocal, round, unProcess); + int unProcessAligned = (unProcess + 3) & (~3); // 处理 addressList 不对齐32b的情况 + // 地址列表访问越界,对齐考虑无问题,会自动多申请一部分,兼容 + DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); + MoveProcess(srcAddrLocal, loopCount, unProcess); } } private: - __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int sizes) + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) { set_flag(PIPE_MTE2, PIPE_S, 0); wait_flag(PIPE_MTE2, PIPE_S, 0); - LocalTensor dataLocal = inQueue.AllocTensor(); + LocalTensor dataLocal = inQueue.AllocTensor(); // Queue的大小可以容下一个循环的所有emb bool isFull = false; int nums = 0; int outIndex = 0; - int times = onceMoveNums / 8; - int tmpCache = cache - 1; + int times = embDimAligned >> 3; // >>3位运算:除以8。 embDimAligned一定是8的倍数,若地址无效时,每次填充8个0 + int tmpCache = cache - 1; // 设计初是一次cache执行多次copyin、一次compute和一次copyout,现状是一次loop就只对应一次cache - for (int i = 0; i < sizes; i++) + for (int i = 0; i < addrNum; i++) { - + // 多次copyIn, 对应一次compute和copyOut,由cache决定 dataLocal = isFull ? inQueue.AllocTensor() : dataLocal; int64_t address = srcAddrLocal.GetValue(i); if (address != 0) { - srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address)); - DataCopy(dataLocal[onceMoveNums * nums], srcDataBufferGm, onceMoveNums); + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(address), embDimAligned); + DataCopy(dataLocal[embDimAligned * nums], srcDataBufferGm, embDimAligned); } else { - for (int j = 0; j < times; j++) { - Duplicate(dataLocal[onceMoveNums * nums + j * 8], (T)0, 8); + Duplicate(dataLocal[embDimAligned * nums + j * PADDING_ZERO_NUM_PER_TIME], (T)0, PADDING_ZERO_NUM_PER_TIME); } - } nums++; - isFull = (i == tmpCache || i == sizes - 1); + isFull = (i == tmpCache || i == addrNum - 1); // cache满了,或者最后一个地址 if (isFull) { inQueue.EnQue(dataLocal); @@ -135,7 +129,7 @@ private: DataCopyParams copyParams; copyParams.blockCount = 1; - copyParams.blockLen = onceMoveNums * sizeof(T) * nums / 32; + copyParams.blockLen = (embDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 DataCopy(dstLocal, srcLocal, copyParams); outQueue.EnQue(dstLocal); @@ -146,29 +140,29 @@ private: { LocalTensor dstLocal = outQueue.DeQue(); - int offset = block_idx * dim * SingleCoreAddrLen / sizeof(int64_t) + (turns * roundSize * dim) + dim * index; + int offset = block_idx * dim * singleCoreAddrLen / sizeof(int64_t) + (turns * addrNumPerLoop * dim) + dim * index; #if defined(__DAV_C220_VEC__) - if (singleDataSize == 4) + if (typeSize == SIZE_OF_FLOAT_OR_INT) { copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums, dim * sizeof(T), 0, 0, 0, 0); } - else if (singleDataSize == 2) + else if (typeSize == SIZE_OF_HALF) { copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm[offset].GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, nums, dim * sizeof(T), 0, 0, 0, 0); } #else - DataCopy(dstDataGm[offset], dstLocal, onceMoveNums * nums); + DataCopy(dstDataGm[offset], dstLocal, embDimAligned * nums); #endif outQueue.FreeTensor(dstLocal); } public: 
- int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, cache, Veclen, dim, PingpongNum; + int32_t addrNumPerLoop, loopCount, singleCoreAddrLen, needComputeAddrLen, veclen, dim, pingpongNum, cache; int32_t addrNums; - int32_t onceMoveNums, singleDataSize, updateType; + int32_t embDimAligned, typeSize, updateType; private: TPipe pipe; @@ -184,7 +178,7 @@ extern "C" __global__ __aicore__ void embedding_lookup_by_address(GM_ADDR addres { GET_TILING_DATA(constData, tiling); - int32_t embeddingType = constData.embbeding_type; + int32_t embeddingType = constData.embedding_type; switch (embeddingType) { diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index 2acd79c0..ca00a5fe 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -1,5 +1,9 @@ #include "kernel_operator.h" using namespace AscendC; + +constexpr int32_t SIZE_OF_HALF = 2; +constexpr int32_t SIZE_OF_FLOAT_OR_INT = 4; + template class KernelEimtable_update { @@ -9,104 +13,95 @@ public: } __aicore__ inline void Init(GM_ADDR address, GM_ADDR embedding, GM_ADDR y) { - NeedComputeAddrLen = SingleCoreAddrLen; + needComputeAddrLen = singleCoreAddrLen; if (block_idx == block_num - 1) { - NeedComputeAddrLen = addrNums * sizeof(int64_t) - SingleCoreAddrLen * (block_num - 1); + needComputeAddrLen = addrNums * sizeof(int64_t) - singleCoreAddrLen * (block_num - 1); } - round = NeedComputeAddrLen / (roundSize * sizeof(int64_t)); + loopCount = needComputeAddrLen / (addrNumPerLoop * sizeof(int64_t)); + + pipe.InitBuffer(tbuf, addrNumPerLoop * sizeof(int64_t)); + pipe.InitBuffer(inQueue, pingpongNum, veclen); + pipe.InitBuffer(outQueue, pingpongNum, veclen); - pipe.InitBuffer(tbuf, roundSize * sizeof(int64_t)); - pipe.InitBuffer(inQueue, PingpongNum, Veclen); - pipe.InitBuffer(outQueue, PingpongNum, Veclen); // get start index for current core, core parallel block_indx block_dim - srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * SingleCoreAddrLen)); - srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(embedding + block_idx * SingleCoreAddrLen / sizeof(int64_t) * sizeof(T) * dim)); + srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * singleCoreAddrLen)); + srcDataBufferGm.SetGlobalBuffer((__gm__ T *)(embedding + block_idx * singleCoreAddrLen / sizeof(int64_t) * sizeof(T) * dim)); outDataGm.SetGlobalBuffer((__gm__ T *)(y)); } __aicore__ inline void Init_param(GM_ADDR tiling) { GET_TILING_DATA(constData, tiling); - // 数据的维度数 + + pingpongNum = constData.ping_pong_num; dim = constData.update_dim; - int32_t block_total_nums = block_num; updateType = constData.update_type; addrNums = constData.addr_nums; + typeSize = constData.type_size; + inputDimAligned = constData.input_dim_aligned; + addrNumPerLoop = constData.addr_per_loop; - // 缓冲区数量 - PingpongNum = constData.ping_pong_num; - singleDataSize = constData.single_data_size; - onceMoveNums = constData.once_move_nums; - roundSize = constData.addr_max_num; + int singleCoreAddrNum = (int)(addrNums / block_num); + singleCoreAddrNum = singleCoreAddrNum & (~3); // & (~3) 代表取4的倍数向下取整,处理的地址占8字节,对齐32B的话,数量需要是4倍数 + ASSERT(singleCoreAddrNum != 0 && "single num can not be zero!"); - int singlenum = (int)(addrNums / block_total_nums); - if (singlenum % 4) - { - singlenum -= singlenum % 4; - } - - Veclen = roundSize * singleDataSize * onceMoveNums; - SingleCoreAddrLen = singlenum * 
sizeof(int64_t); - cache = roundSize; + veclen = addrNumPerLoop * typeSize * inputDimAligned; + singleCoreAddrLen = singleCoreAddrNum * sizeof(int64_t); + cache = constData.addr_per_loop; } __aicore__ inline void Process() { - LocalTensor srcAddrLocal = tbuf.Get(roundSize); - - int unprocess = (NeedComputeAddrLen / sizeof(int64_t)) % roundSize; + LocalTensor srcAddrLocal = tbuf.Get(addrNumPerLoop); - if (round > 0) + if (loopCount > 0) { - for (int32_t i = 0; i < round; i++) + for (int32_t i = 0; i < loopCount; i++) { - DataCopy(srcAddrLocal, srcAddrGlobal[i * roundSize], roundSize); - MoveProcess(srcAddrLocal, i, roundSize); + DataCopy(srcAddrLocal, srcAddrGlobal[i * addrNumPerLoop], addrNumPerLoop); + MoveProcess(srcAddrLocal, i, addrNumPerLoop); } } - if (unprocess) + int unProcess = (needComputeAddrLen / sizeof(int64_t)) % addrNumPerLoop; + if (unProcess) { - int unprocessOnceCopyaddr = unprocess; - if (unprocessOnceCopyaddr % 4 != 0) - { - unprocessOnceCopyaddr += (4 - unprocess % 4); - } - - DataCopy(srcAddrLocal, srcAddrGlobal[round * roundSize], unprocessOnceCopyaddr); - MoveProcess(srcAddrLocal, round, unprocess); + int unProcessAligned = (unProcess + 3) & (~3); // 处理 addressList 不对齐32b的情况 + DataCopy(srcAddrLocal, srcAddrGlobal[loopCount * addrNumPerLoop], unProcessAligned); + MoveProcess(srcAddrLocal, loopCount, unProcess); } } private: - __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int sizes) + __aicore__ inline void MoveProcess(const LocalTensor srcAddrLocal, const int turns, int addrNum) { set_flag(PIPE_MTE2, PIPE_S, 0); wait_flag(PIPE_MTE2, PIPE_S, 0); LocalTensor dataLocal; - int out_index = 0; - int offset = 0; + int64_t address = 0; - if (dim == onceMoveNums) + if (dim == inputDimAligned) // copyIn 和 compute一次,copyOut多次 { dataLocal = inQueue.AllocTensor(); - DataCopy(dataLocal, srcDataBufferGm[turns * roundSize * dim], sizes * onceMoveNums); + DataCopy(dataLocal, srcDataBufferGm[turns * addrNumPerLoop * dim], addrNum * inputDimAligned); inQueue.EnQue(dataLocal); - Compute(sizes); + + Compute(addrNum); // 只有copyOut的管道支持拷贝到gm上 + LocalTensor dstLocal = outQueue.DeQue(); if (updateType == 0) { SetAtomicAdd(); } - for (int i = 0; i < sizes; i++) + for (int i = 0; i < addrNum; i++) { address = srcAddrLocal.GetValue(i); if (address != 0) { dstDataGm.SetGlobalBuffer((__gm__ T*)(address)); - DataCopy(dstDataGm, dstLocal[i*onceMoveNums], onceMoveNums); + DataCopy(dstDataGm, dstLocal[i * inputDimAligned], inputDimAligned); } } if (updateType == 0) @@ -117,10 +112,10 @@ private: } else { - for (int i = 0; i < sizes; i++) + for (int i = 0; i < addrNum; i++) { dataLocal = inQueue.AllocTensor(); - DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * roundSize * dim], onceMoveNums); + DataCopy(dataLocal, srcDataBufferGm[i * dim + turns * addrNumPerLoop * dim], inputDimAligned); inQueue.EnQue(dataLocal); Compute(1); address = srcAddrLocal.GetValue(i); @@ -136,7 +131,7 @@ private: LocalTensor dstLocal = outQueue.AllocTensor(); DataCopyParams copyparams; copyparams.blockCount = 1; - copyparams.blockLen = onceMoveNums * sizeof(T) * nums / 32; + copyparams.blockLen = (inputDimAligned * sizeof(T) * nums) >> 5; // >> 5, 除以32,ub空间对齐 DataCopy(dstLocal, srcLocal, copyparams); outQueue.EnQue(dstLocal); inQueue.FreeTensor(srcLocal); @@ -146,8 +141,6 @@ private: { LocalTensor dstLocal = outQueue.DeQue(); - int offset = block_idx * dim * SingleCoreAddrLen / sizeof(int64_t) + (turns * roundSize * dim) + dim * index; - if (address != 0) { 
dstDataGm.SetGlobalBuffer((__gm__ T *)(address)); @@ -158,19 +151,19 @@ private: } #if defined(__DAV_C220_VEC__) - if (singleDataSize == 4) + if (typeSize == SIZE_OF_FLOAT_OR_INT) { copy_ubuf_to_gm_align_b32((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, 1, dim * sizeof(T), 0, 0, 0, 0); } - else if (singleDataSize == 2) + else if (typeSize == SIZE_OF_HALF) { copy_ubuf_to_gm_align_b16((__gm__ T *)dstDataGm.GetPhyAddr(), (__ubuf__ T *)dstLocal.GetPhyAddr(), 0, 1, dim * sizeof(T), 0, 0, 0, 0); } #else - DataCopy(dstDataGm, dstLocal, onceMoveNums); + DataCopy(dstDataGm, dstLocal, inputDimAligned); #endif } if (updateType == 0) @@ -181,8 +174,8 @@ private: } public: - int32_t roundSize, round, SingleCoreAddrLen, NeedComputeAddrLen, addrNums, cache, Veclen, dim, PingpongNum; - int32_t onceMoveNums, singleDataSize, updateType; + int32_t addrNumPerLoop, loopCount, singleCoreAddrLen, needComputeAddrLen, addrNums, cache, veclen, dim, pingpongNum; + int32_t inputDimAligned, typeSize, updateType; private: TPipe pipe; @@ -198,9 +191,9 @@ extern "C" __global__ __aicore__ void embedding_update_by_address(GM_ADDR addres { GET_TILING_DATA(constData, tiling); - int32_t embbedingType = constData.embbeding_type; + int32_t embeddingType = constData.embedding_type; - switch (embbedingType) + switch (embeddingType) { case 0: { -- Gitee From a7841d9df542fde679a46c8de24c1cc590b4fa91 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 27 Nov 2023 19:46:34 +0800 Subject: [PATCH 488/551] Match-id-572a88c9497090ea807c8ee3e42f67cf1c69b881 --- tests/run_python_dt.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/run_python_dt.sh b/tests/run_python_dt.sh index a0fcd861..773cece7 100644 --- a/tests/run_python_dt.sh +++ b/tests/run_python_dt.sh @@ -29,7 +29,7 @@ mkdir -p result function run_test_cases() { echo "Get testcases final result." 
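The kernel rework above leans on two bit tricks: x & (~3) rounds an address count down to a multiple of 4 (four 8-byte addresses fill one 32-byte UB block), and (x + 3) & (~3) rounds the leftover count up. Below is a NumPy model of the lookup semantics, with table standing in for device global memory (an illustration under those assumptions, not the device code):

    import numpy as np

    def round_down4(x):
        return x & ~3          # the kernels' x & (~3)

    def round_up4(x):
        return (x + 3) & ~3    # the kernels' (x + 3) & (~3)

    def lookup_by_address(addresses, table, dim):
        # An address of 0 marks an invalid slot and yields a zero-filled row,
        # mirroring the Duplicate(..., (T)0, ...) padding in the kernel.
        out = np.zeros((len(addresses), dim), dtype=table.dtype)
        for i, addr in enumerate(addresses):
            if addr != 0:
                out[i] = table[addr]
        return out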
- pytest --cov="${CUR_PATH}"/../mx_rec --cov-report=html --cov-report=xml --junit-xml=./final.xml --html=./final.html --self-contained-html --durations=5 -vv + pytest --cov="${CUR_PATH}"/../mx_rec --cov-report=html --cov-report=xml --junit-xml=./final.xml --html=./final.html --self-contained-html --durations=5 -vv --cov-branch coverage xml -i --omit="build/*,cust_op/*,src/*" cp coverage.xml final.xml final.html ./result cp -r htmlcov ./result -- Gitee From 18eda94d7b4c749301df8728e61712a9d1b103fc Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 28 Nov 2023 09:42:13 +0800 Subject: [PATCH 489/551] Match-id-338c7e02fab162dfd4528be14b11a62dacef8ae9 --- mx_rec/graph/modifier.py | 4 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 13 ++++++ src/core/hybrid_mgmt/hybrid_mgmt.h | 4 +- src/core/key_process/key_process.cpp | 68 +++++++++++++++------------- src/core/key_process/key_process.h | 21 +++++++++ src/pybind/module_main.cpp | 3 +- 6 files changed, 78 insertions(+), 35 deletions(-) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 8231fe53..7df2a230 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -22,7 +22,7 @@ from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_ from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, \ insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ - set_iterator_type + set_iterator_type, get_asc_manager from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ export_pb_graph, make_sorted_key_to_tensor_list, ReplacementSpec, AnchorRecord @@ -620,3 +620,5 @@ class GraphModifierHook(tf.estimator.SessionRunHook): if self._modify_graph and self._iterator_type == "MakeIterator": session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) + def end(self, session): + get_asc_manager().set_mpi_send_abnormal_status() diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 12ea1b58..a563d037 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1133,3 +1133,16 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const return -1; #endif } + +void HybridMgmt::SetMpiSendAbnormalStatus() +{ + int sendValue = MPI_ABNORMAL_SEND_VALUE; + int channel = hybridMgmtBlock->lastRunChannelId; + if (channel < 0 || channel >= MAX_CHANNEL_NUM) { + LOG_WARN("channel is abnormal:{} when set mpi all reduce", channel); + return; + } + LOG_INFO(MGMT + "set mpi all reduce value:{}, channelId:{}", sendValue, channel); + preprocess->mpiAllReduceSend[channel] = sendValue; + preprocess->isNeedExit[channel] = true; +} diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 0755ad55..571a10ab 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -39,6 +39,8 @@ namespace MxRec { DDR }; + constexpr int MGMT_THREAD_ID = -1; + class HybridMgmt { public: HybridMgmt() = default; @@ -137,6 +139,7 @@ namespace MxRec { int64_t GetTableCapacity(const string& embName) const; + void SetMpiSendAbnormalStatus(); private: bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, const vector& thresholdValues, int seed); @@ -153,7 +156,6 @@ namespace MxRec { static void AddCacheManagerTraceLog(CkptData& saveData); 
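This patch coordinates shutdown through a summed all-reduce: each healthy rank contributes MPI_NORMAL_SEND_VALUE (1), an exiting rank contributes MPI_ABNORMAL_SEND_VALUE (0), and a sum below the world size tells every rank that a peer has reached end of stream. A minimal mpi4py sketch of that handshake (an illustration of the protocol, not the project's API):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    def peers_still_running(healthy: bool) -> bool:
        # Every rank sends 1 while healthy and 0 once it wants to exit, so the
        # sum drops below the world size as soon as any single rank is done.
        total = comm.allreduce(1 if healthy else 0, op=MPI.SUM)
        return total == comm.Get_size()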
void RestoreFreq4Save(CkptData& saveData) const; - private: int currentBatchId; int trainBatchId = 0; // 0-199, 200- diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b3847f4d..fffc5b48 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -160,6 +160,8 @@ void KeyProcess::LoadKeyCountMap(KeyCountMemT& loadData) // 只在python侧当训练结束时调用,如果出现死锁直接结束程序即可,测试时让进程等待足够长的时间再调用 void KeyProcess::Destroy() { + mpiAllReduceSend[0] = MPI_ABNORMAL_SEND_VALUE; + mpiAllReduceSend[1] = MPI_ABNORMAL_SEND_VALUE; isRunning = false; LOG_INFO(KEY_PROCESS "rankId:{} KeyProcess begin destroy.", rankInfo.rankId); for (auto& i: procThreads) { @@ -247,7 +249,6 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) throw runtime_error(Logger::Format("create fast unique failed, error code:{}", ret)); } GetUniqueConfig(uniqueConf); - try { while (true) { TimeCost getAndProcessTC; @@ -265,9 +266,9 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) break; } LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process with fast unique cost:{}," - " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", - getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name, batch->channel, batch->batchId); + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}", + getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, + batch->name, batch->channel, threadId, batch->batchId); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); @@ -276,8 +277,8 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) } catch (const EndRunExit &e) { LOG_INFO(KEY_PROCESS "abort run: {}", e.what()); } - LOG_INFO(KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:{} thread:{}, channel:{}", - rankInfo.rankId, threadId, channel); + LOG_INFO(KEY_PROCESS "KeyProcessTaskWithFastUnique exit. rank:{} channelId:{}, threadId:{}", + rankInfo.rankId, channel, threadId); } @@ -300,9 +301,9 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) break; } LOG_INFO(KEY_PROCESS "getAndProcessTC(ms):{}, key process cost:{}," - " get data time(ms):{}, batch name:{}, channel:{}, batchID:{}", - getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, - batch->name, batch->channel, batch->batchId); + " get data time(ms):{}, batch name:{}, channelId:{}, threadId:{}, batchId:{}", + getAndProcessTC.ElapsedMS(), processDataTime.ElapsedMS(), getBatchTime, + batch->name, batch->channel, threadId, batch->batchId); int queueIndex = threadId + (MAX_KEY_PROCESS_THREAD * batch->channel); auto batchQueue = SingletonQueue::GetInstances(queueIndex); batchQueue->PutDirty(move(batch)); @@ -310,7 +311,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) } catch (const EndRunExit &e) { LOG_INFO(KEY_PROCESS "abort run: {}", e.what()); } - LOG_INFO(KEY_PROCESS "KeyProcessTask exit. rank:{} thread:{}, channel:{}", rankInfo.rankId, threadId, channel); + LOG_INFO(KEY_PROCESS "KeyProcessTask exit. 
rank:{} channelId:{}, threadId:{}", rankInfo.rankId, channel, threadId); } void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & splitKeys, @@ -531,9 +532,8 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) batch = batchQueue->TryPop(); if (batch != nullptr) { break; - } else { - this_thread::sleep_for(100us); } + this_thread::sleep_for(100us); if (tc.ElapsedSec() > GET_BATCH_TIMEOUT) { if (commId == 0) { LOG_WARN(KEY_PROCESS "getting batch timeout! 1. check last 'read batch cost' print. " @@ -542,17 +542,20 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) this_thread::sleep_for(seconds(1)); tc = TimeCost(); } - if (!isRunning) { - LOG_WARN("channelId:{} threadId:{}, isRunning is false when GetBatchData", channel, commId); + + if (!isRunning || isNeedExit[channel]) { + LOG_WARN("channelId:{} threadId:{}, enter GetBatchData abnormal scene, isRunning:{}, isNeedExit:{}", + channel, commId, isRunning, isNeedExit[channel]); // 通信终止信号,同步退出,防止线程卡住 - int exitFlag = isRunning; - auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + int receiveFlag = 0; + int sendValue = 0; // 此处直接发送0,不使用mpiAllReduceSend值,防止多线程数据可见性问题 + auto retCode = MPI_Allreduce(&sendValue, &receiveFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); } LOG_DEBUG("channelId:{} threadId:{}, GetBatchData Allreduce end, receiveFlag:{}", - channel, commId, exitFlag); - throw EndRunExit("GetBatchData end run."); + channel, commId, receiveFlag); + throw EndRunExit("GetBatchData end run, thread will exit."); } } EASY_END_BLOCK @@ -942,7 +945,7 @@ tuple, vector, vector> } uKey[key] = restore[i]; } - + if (GlogConfig::gStatOn) { size_t uniqueKeyNum = 0; for (int devId = 0; devId < rankInfo.rankSize; ++devId) { @@ -1037,9 +1040,9 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons // 通信终止信号,同步退出,防止线程卡住 TimeCost tc = TimeCost(); - int exitFlag = isRunning; - int receiveFlag = exitFlag; - auto retCode = MPI_Allreduce(&exitFlag, &receiveFlag, 1, MPI_INT, MPI_SUM, comm[batch->channel][commId]); + int receiveFlag = 0; + auto retCode = MPI_Allreduce(&mpiAllReduceSend[batch->channel], &receiveFlag, 1, MPI_INT, MPI_SUM, + comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allreduce failed:{}", rankInfo.rankId, commId, retCode); } @@ -1069,14 +1072,14 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba } if (receiveFlag < rankInfo.rankSize) { unique_lock lockGuard(destroyMutex); - if (!isRunning) { - LOG_INFO("channelId:{} threadId:{} batchId:{}, isRunning is false after lock destroyMutex.", + if (isNeedExit[batch->channel]) { + LOG_INFO("channelId:{} threadId:{} batchId:{}, has send acl eos info, thread will exit.", batch->channel, commId, batch->batchId); - throw EndRunExit("GetScAll end run, isRunning is false after lock destroyMutex."); + throw EndRunExit("has send acl eos info, thread will exit."); } SendEosInfo(commId, batch); - isRunning = false; - throw EndRunExit("has SendEosInfo, GetScAll end run."); + isNeedExit[batch->channel] = true; + throw EndRunExit("has SendEosInfo, GetScAll end, thread will exit."); } } @@ -1118,16 +1121,17 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, co EASY_BLOCK("barrier"); // 通信终止信号,同步退出,防止线程卡住 TimeCost tc = TimeCost(); - int exitFlag = isRunning; - auto retCode = MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, 
MPI_SUM, comm[channel][commId]); + int receiveFlag = 0; + auto retCode = MPI_Allreduce(&mpiAllReduceSend[channel], &receiveFlag, 1, MPI_INT, MPI_SUM, + comm[channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); } LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAllForUnique MPI AllReduce end, " "receiveFlag:{}, barrier time:{}", - channel, commId, batch->batchId, exitFlag, tc.ElapsedMS()); + channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS()); // handle the case where a thread on another rank has exited - HandleRankExitScene(commId, batch, exitFlag); + HandleRankExitScene(commId, batch, receiveFlag); EASY_END_BLOCK; // allgather keyScLocal(key all2all keyScLocal = device all2all rc) @@ -1519,4 +1523,4 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) } singleKeyCountMap[key]++; } -} \ No newline at end of file +} diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 99d2d853..b9cecf44 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -58,6 +58,9 @@ namespace MxRec { INVALID }; + constexpr int MPI_ABNORMAL_SEND_VALUE = 0; // value sent when MPI communication is shutting down abnormally + constexpr int MPI_NORMAL_SEND_VALUE = 1; // value sent during normal MPI communication + class EndRunExit : public std::exception { public: explicit EndRunExit(const char* message) : errorMessage(message) {} @@ -71,6 +74,20 @@ namespace MxRec { const char* errorMessage; }; + // exception that ends the run and blocks + class EndRunBlock : public std::exception { + public: + explicit EndRunBlock(const char *message) : errorMessage(message) {} + + const char *what() const noexcept override + { + return errorMessage; + } + + private: + const char *errorMessage; + }; + class EmptyList : public std::exception { }; @@ -157,6 +174,10 @@ namespace MxRec { } bool isRunning { false }; + // whether the preprocessing thread bound to each channel should exit; set once EOS info has been sent + bool isNeedExit[2] = {false, false}; + // per-channel value sent in the MPI all-reduce + int mpiAllReduceSend[2] = {MPI_NORMAL_SEND_VALUE, MPI_NORMAL_SEND_VALUE}; std::mutex destroyMutex; inline bool HasEmbName(const string& embName) { diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp index d3793f1d..ed455647 100644 --- a/src/pybind/module_main.cpp +++ b/src/pybind/module_main.cpp @@ -192,7 +192,8 @@ namespace { .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, py::arg("channel_id"), py::arg("steps")=1) .def("get_table_size", &MxRec::HybridMgmt::GetTableSize, py::arg("table_name")) - .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name")); + .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name")) + .def("set_mpi_send_abnormal_status", &MxRec::HybridMgmt::SetMpiSendAbnormalStatus); } void GetThresholdValue(pybind11::module_& m) -- Gitee From 4e3f82b2bb2c224c064ba69c3d5f2fa4b50cd52a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 28 Nov 2023 11:27:04 +0800 Subject: [PATCH 490/551] Match-id-7663c09e05907aa51b2a5f80a867efa1cecda12c --- mx_rec/saver/saver.py | 2 +- mx_rec/saver/sparse.py | 60 +++++++++++++++--------------------------- 2 files changed, 22 insertions(+), 40 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 0f2f0096..361320f6 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -467,9 +467,9 @@ def validate_read_file(read_file_path): """ file_validator = FileValidator("read_file_path", read_file_path) file_validator.check_file_size(MAX_FILE_SIZE, MIN_SIZE) - file_validator.check_user_group() if not
check_file_system_is_hdfs(read_file_path): file_validator.check_not_soft_link() + file_validator.check_user_group() file_validator.check() diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py index f4f9cdaa..a844e17d 100644 --- a/mx_rec/saver/sparse.py +++ b/mx_rec/saver/sparse.py @@ -6,11 +6,13 @@ import os import json import numpy as np +import tensorflow as tf from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set, get_sparse_dir from mx_rec.validator.validator import FileValidator from mx_rec.validator.validator import para_checker_decorator, ClassValidator from mx_rec.util.log import logger +from mx_rec.saver.saver import validate_read_file class SparseProcessor: @@ -43,43 +45,24 @@ class SparseProcessor: @staticmethod def _get_data(data_dir, dtype, data_shape): - with open(data_dir, "rb") as file: - # check whether data file is valid - file_validator = FileValidator("data_dir", data_dir) - # 1.check whether data_dir is soft link - file_validator.check_not_soft_link() - # 2.check data file size - file_validator.check_file_size() - file_validator.check() - - try: - data = np.fromfile(data_dir, dtype=dtype) - except FileNotFoundError as err: - raise FileNotFoundError(f"data dir not found.") from err + with tf.io.gfile.GFile(data_dir, "rb") as file: + validate_read_file(data_dir) + data = file.read() + data = np.fromstring(data, dtype=dtype) + data = data.reshape(data_shape) return data @staticmethod def _get_shape_from_attrib(attribute_dir, is_json): if is_json: - try: - with open(attribute_dir, "r") as fin: - # check whether attribute file is valid - file_validator = FileValidator("attribute_dir", attribute_dir) - # 1.check whether attribute_dir is soft link - file_validator.check_not_soft_link() - # 2.check attribute file size - file_validator.check_file_size() - file_validator.check() - attributes = json.load(fin) - except FileNotFoundError as err: - raise FileNotFoundError("attribute dir not found.") from err + with tf.io.gfile.GFile(attribute_dir, "r") as file: + validate_read_file(attribute_dir) + attributes = json.load(file) else: - try: - attributes = np.fromfile(attribute_dir, np.uint64) - except FileNotFoundError as err: - raise FileNotFoundError("attribute dir not found.") from err - + with tf.io.gfile.GFile(attribute_dir, "rb") as file: + attributes = file.read() + attributes = np.fromstring(attributes, dtype=np.uint64) return attributes def export_sparse_data(self): @@ -105,14 +88,15 @@ class SparseProcessor: table_instance.use_dynamic_expansion) transformed_data = dict(zip(key[:], emb_data[:])) save_path = os.path.join(out_dir, self.export_name + ".npy") - np.save(save_path, transformed_data) + with tf.io.gfile.GFile(save_path, "wb") as file: + np.save(file, transformed_data) def get_embedding(self, device_table_dir, host_table_dir, ddr, use_dynamic_expansion): emb_dir = os.path.join(device_table_dir, self.device_emb_dir) data_file, attribute_file = self._get_file_names(emb_dir) - if not os.path.exists(data_file): + if not tf.io.gfile.exists(data_file): raise FileExistsError(f"embedding data file {data_file} does not exist when reading.") - if not os.path.exists(attribute_file): + if not tf.io.gfile.exists(attribute_file): raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.") if use_dynamic_expansion: @@ -139,9 +123,9 @@ class SparseProcessor: else: hashmap_dir = os.path.join(table_dir, self.host_hashmap_dir) data_file, attribute_file = self._get_file_names(hashmap_dir) - if not 
From 6fa3cb9a66384aac9a6970dda5f289bfea405646 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 28 Nov 2023 19:29:09 +0800
Subject: [PATCH 491/551] Match-id-1b48287569fa1ef2d922d6f0579af2eb6b355bdd

---
 mx_rec/graph/modifier.py       | 4 +---
 mx_rec/util/global_env_conf.py | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index 7df2a230..8231fe53 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -22,7 +22,7 @@ from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_
 from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
     terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, \
     insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \
-    set_iterator_type, get_asc_manager
+    set_iterator_type
 from mx_rec.util.perf import performance
 from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \
     export_pb_graph, make_sorted_key_to_tensor_list, ReplacementSpec, AnchorRecord
@@ -620,5 +620,3 @@ class GraphModifierHook(tf.estimator.SessionRunHook):
         if self._modify_graph and self._iterator_type == "MakeIterator":
             session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER))

-    def end(self, session):
-        get_asc_manager().set_mpi_send_abnormal_status()
diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py
index 45e4cda6..2a6bf186 100644
--- a/mx_rec/util/global_env_conf.py
+++ b/mx_rec/util/global_env_conf.py
@@ -63,7 +63,7 @@ def get_global_env_conf() -> RecEnv:
         use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value),
         stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value),
         record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value),
-        add_control_edge=os.getenv(EnvOption.ADD_CONTROL_EDGE.value, Flag.TRUE.value)
+        add_control_edge=os.getenv(EnvOption.ADD_CONTROL_EDGE.value, Flag.FALSE.value)
     )

     return rec_env
--
Gitee
From 3610034d5b2dae7dde85fb97ff01bf539d4c25c2 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 28 Nov 2023 16:30:27 +0800
Subject: [PATCH 492/551] Match-id-4cc05638169e4035e671182400405ebbd4a59809

---
 mx_rec/saver/sparse.py | 55 ++++++++++++++++++++++++------------------
 1 file changed, 32 insertions(+), 23 deletions(-)

diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py
index a844e17d..6f15c8f2 100644
--- a/mx_rec/saver/sparse.py
+++ b/mx_rec/saver/sparse.py
@@ -45,24 +45,32 @@ class SparseProcessor:

     @staticmethod
     def _get_data(data_dir, dtype, data_shape):
-        with tf.io.gfile.GFile(data_dir, "rb") as file:
-            validate_read_file(data_dir)
-            data = file.read()
-            data = np.frombuffer(data, dtype=dtype)
-
-        data = data.reshape(data_shape)
+        try:
+            with tf.io.gfile.GFile(data_dir, "rb") as file:
+                validate_read_file(data_dir)
+                data = file.read()
+                data = np.frombuffer(data, dtype=dtype)
+                data = data.reshape(data_shape)
+        except Exception as err:
+            raise RuntimeError(f"error happened when getting data from data file {data_dir}, "
+                               f"the error is `{err}`.") from err
         return data

     @staticmethod
     def _get_shape_from_attrib(attribute_dir, is_json):
-        if is_json:
-            with tf.io.gfile.GFile(attribute_dir, "r") as file:
-                validate_read_file(attribute_dir)
-                attributes = json.load(file)
-        else:
-            with tf.io.gfile.GFile(attribute_dir, "rb") as file:
-                attributes = file.read()
-                attributes = np.frombuffer(attributes, dtype=np.uint64)
+        try:
+            if is_json:
+                with tf.io.gfile.GFile(attribute_dir, "r") as file:
+                    validate_read_file(attribute_dir)
+                    attributes = json.load(file)
+            else:
+                with tf.io.gfile.GFile(attribute_dir, "rb") as file:
+                    validate_read_file(attribute_dir)
+                    attributes = file.read()
+                    attributes = np.frombuffer(attributes, dtype=np.uint64)
+        except Exception as err:
+            raise RuntimeError(f"error happened when getting the shape from attribute file {attribute_dir}, "
+                               f"the error is `{err}`.") from err
         return attributes

     def export_sparse_data(self):
@@ -94,10 +102,6 @@ class SparseProcessor:
     def get_embedding(self, device_table_dir, host_table_dir, ddr, use_dynamic_expansion):
         emb_dir = os.path.join(device_table_dir, self.device_emb_dir)
         data_file, attribute_file = self._get_file_names(emb_dir)
-        if not tf.io.gfile.exists(data_file):
-            raise FileExistsError(f"embedding data file {data_file} does not exist when reading.")
-        if not tf.io.gfile.exists(attribute_file):
-            raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.")

         if use_dynamic_expansion:
             device_attribute = self._get_shape_from_attrib(attribute_file, is_json=False)
@@ -123,10 +127,6 @@ class SparseProcessor:
         else:
             hashmap_dir = os.path.join(table_dir, self.host_hashmap_dir)
         data_file, attribute_file = self._get_file_names(hashmap_dir)
-        if not tf.io.gfile.exists(data_file):
-            raise FileExistsError(f"hashmap data file {data_file} does not exist when reading.")
-        if not tf.io.gfile.exists(attribute_file):
-            raise FileExistsError(f"hashmap attribute file {attribute_file} does not exist when reading.")

         shape_data = self._get_shape_from_attrib(attribute_file, is_json=False)
         if len(shape_data) < 2:
@@ -144,14 +144,23 @@ class SparseProcessor:
         attribute_file = None
         files = tf.io.gfile.listdir(directory)
         if not files:
-            raise FileExistsError(f"There are no files under the {directory} ")
+            raise FileExistsError(f"There are no files under the {directory}.")
         for file in files:
             if file.find(self.data_suffix) != -1:
                 data_file = file
             elif file.find(self.attrib_suffix) != -1:
                 attribute_file = file
+        if not data_file:
+            raise FileNotFoundError(f"There is no data file under the {directory}.")
+        if not attribute_file:
+            raise FileNotFoundError(f"There is no attribute file under the {directory}.")
+
         data_file = os.path.join(directory, data_file)
         attribute_file = os.path.join(directory, attribute_file)
+        if not tf.io.gfile.exists(data_file):
+            raise FileExistsError(f"embedding data file {data_file} does not exist when reading.")
+        if not tf.io.gfile.exists(attribute_file):
+            raise FileExistsError(f"attribute file {attribute_file} does not exist when reading.")
         return data_file, attribute_file
--
Gitee From affbe2bd54ab84685e95aa049b1ac8a9a5a8c426 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 30 Nov 2023 11:15:30 +0800 Subject: [PATCH 493/551] Match-id-35605da705d44cbbff571b0b1b5af95176f315ed --- mx_rec/optimizers/__init__.py | 1 - mx_rec/optimizers/momentum.py | 124 ---------------------------------- 2 files changed, 125 deletions(-) delete mode 100644 mx_rec/optimizers/momentum.py diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py index de88c7e2..660bdc14 100644 --- a/mx_rec/optimizers/__init__.py +++ b/mx_rec/optimizers/__init__.py @@ -13,4 +13,3 @@ from mx_rec.optimizers.gradient_descent import create_hash_optimizer from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr from mx_rec.optimizers.lazy_adam import create_hash_optimizer from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address -from mx_rec.optimizers.momentum import create_hash_optimizer \ No newline at end of file diff --git a/mx_rec/optimizers/momentum.py b/mx_rec/optimizers/momentum.py deleted file mode 100644 index 8c8737fc..00000000 --- a/mx_rec/optimizers/momentum.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict - -from tensorflow.python.ops import math_ops -from tensorflow.python.training import training_ops -from tensorflow.python.training import momentum -from tensorflow.python.training import slot_creator - -from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list -from mx_rec.util.variable import check_and_get_config_via_var -from mx_rec.constants.constants import MAX_INT32 -from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator, ClassValidator - - -@para_checker_decorator(check_option_list=[ - ("learning_rate", FloatValidator, {"min_value": -MAX_INT32, "max_value": MAX_INT32}, ["check_value"]), - ("mom", FloatValidator, {"min_value": 0, "max_value": 1}, ["check_value"]), - ("use_locking", ClassValidator, {"classes": (bool,)}), - ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]), - ("enable_nesterov", ClassValidator, {"classes": (bool,)}), -]) -def create_hash_optimizer(learning_rate=0.001, mom=0.9, use_locking=False, name="momentum", - enable_nesterov=False): - """ - Create an instance of hash optimizer - :param learning_rate: A `Tensor` or a floating point value. The learning rate. - :param mom: A `Tensor` or a floating point value. The momentum. - :param use_locking: If `True` use locks for update operations. - :param name: Optional name prefix for the operations created when applying gradients. - Defaults to "Momentum". - :param enable_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., 2013). This implementation always - computes gradients at the value of the variable(s) passed to the optimizer. Using Nesterov Momentum makes the - variable(s) track the values called `theta_t + mu*v_t` in the paper. This implementation is an approximation of - the original formula, valid for high values of momentum. 
It will compute the "adjusted gradient" in NAG by - assuming that the new gradient will be estimated by the current average gradient plus the product of momentum and - the change in the average gradient. - :return: momentum hash optimizer instance - """ - return CustomizedMomentum(learning_rate=learning_rate, - momentum_var=mom, - use_locking=use_locking, - name=name, - use_nesterov=enable_nesterov) - - -class CustomizedMomentum(momentum.MomentumOptimizer, CustomizedOptimizer): - name_counter = defaultdict(int) - - def __init__(self, - learning_rate, - momentum_var, - use_locking=False, - name="Momentum", - use_nesterov=False): - self.optimizer_type = "Momentum" - super(CustomizedMomentum, self)._get_name(name=name) - super(CustomizedMomentum, self).__init__(learning_rate=learning_rate, - momentum=momentum_var, - use_locking=use_locking, - name=self.unique_name, - use_nesterov=use_nesterov) - - def initialize_slots(self, var, table_instance): - # Create slots for the first and second moments. - def creat_one_single_slot(var, op_name): - new_slot_variable = slot_creator.create_zeros_slot(var, op_name) - # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - return new_slot_variable - - momentum_slot = creat_one_single_slot(var, self._name + "/" + "momentum") - insert_removing_var_list(momentum_slot.name) - named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) - if self._name in table_instance.optimizer: - raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") - - table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) - return [{"slot": momentum_slot, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self}] - - def insert_slot(self, slot, named_slots_key, slot_name): - named_slots = self._slot_dict(slot_name) - if named_slots_key in named_slots: - raise EnvironmentError(f"named_slots_key should be global unique, but it has been in use now, " - f"please double check.") - - named_slots[named_slots_key] = slot - - def get_slot_init_values(self): - # return state value list of momentum that needs to initialize in ASC DDR. 
- initial_momentum_value = 0.0 - return [initial_momentum_value] - - def _create_slots(self, var_list): - m_state_name = self._name + "/" + "momentum" - for var in var_list: - table_instance = check_and_get_config_via_var(var, self.optimizer_type) - momentum_slot = self._zeros_slot(var, "m", m_state_name) - insert_removing_var_list(momentum_slot.name) - if self._name not in table_instance.optimizer: - table_instance.set_optimizer(self._name, {"momentum": momentum_slot}) - - def _apply_sparse(self, grad, var): - mom = self.get_slot(var, "m") - return training_ops.sparse_apply_momentum( - var, mom, math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - grad.values, grad.indices, math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov).op - - def _resource_apply_sparse(self, grad, var, indices): - mom = self.get_slot(var, "m") - return training_ops.resource_sparse_apply_momentum( - var.handle, mom.handle, math_ops.cast(self._learning_rate_tensor, grad.dtype), - grad, indices, math_ops.cast(self._momentum_tensor, grad.dtype), - use_locking=self._use_locking, - use_nesterov=self._use_nesterov) -- Gitee From aee4f667d93454aecac084fbf743f7fe182a1c84 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 30 Nov 2023 10:03:33 +0800 Subject: [PATCH 494/551] Match-id-88b6569cc40ea1920ab044b40cc0f80b1f010904 --- src/core/file_system/file_system.h | 2 +- .../hdfs_file_system/hdfs_file_system.cpp | 6 +++--- .../hdfs_file_system/hdfs_file_system.h | 2 +- .../hdfs_file_system/hdfs_wrapper.h | 19 ++++++++++--------- .../local_file_system/local_file_system.h | 7 ++++--- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 8 ++++---- src/core/hybrid_mgmt/hybrid_mgmt.h | 2 +- src/core/key_process/key_process.cpp | 15 ++++++++------- src/core/key_process/key_process.h | 2 +- 9 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h index ab6a2204..2e8a788f 100644 --- a/src/core/file_system/file_system.h +++ b/src/core/file_system/file_system.h @@ -42,4 +42,4 @@ namespace MxRec { }; } -#endif //MX_REC_FILE_SYSTEM_H +#endif // MX_REC_FILE_SYSTEM_H diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp index 01a6507f..5ac3291d 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp @@ -221,7 +221,7 @@ ssize_t HdfsFileSystem::Read(const string& filePath, char* fileContent, size_t d return static_cast(readBytesNum); } -ssize_t HdfsFileSystem::Read(const string& filePath, vector>& fileVector, size_t datasetSize) +ssize_t HdfsFileSystem::Read(const string& filePath, vector>& fileContent, size_t datasetSize) { hdfsFS fs = ConnectHdfs(); @@ -231,7 +231,7 @@ ssize_t HdfsFileSystem::Read(const string& filePath, vector>& file throw runtime_error("open hdfs file failed."); } - size_t embDataOuterSize = fileVector.capacity(); + size_t embDataOuterSize = fileContent.capacity(); auto onceReadByteSize { datasetSize / embDataOuterSize }; tSize readBytesNum = 0; @@ -245,7 +245,7 @@ ssize_t HdfsFileSystem::Read(const string& filePath, vector>& file } else { readSize = dataCol; } - tSize res = hdfs->Read(fs, file, reinterpret_cast(fileVector[i].data()) + idx, readSize); + tSize res = hdfs->Read(fs, file, reinterpret_cast(fileContent[i].data()) + idx, readSize); if (res == -1) { hdfs->CloseFile(fs, file); hdfs->Disconnect(fs); diff 
--git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h index 9de39dce..c941ba44 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.h +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h @@ -32,7 +32,7 @@ namespace MxRec { const vector& addressArr, int deviceId) override; ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override; - ssize_t Read(const string& filePath, vector>& fileVector, size_t datasetSize) override; + ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) override; void ReadEmbedding(const string &filePath, const int& embeddingSize, vector& addressArr, int deviceId) override; diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h index e1a0ab5d..1b225ae7 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h +++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h @@ -16,7 +16,8 @@ namespace MxRec { // The following parameters are not named in large camel case to adapt to native HDFS interfaces. - // Including tObjectKind, tPort, tSize, tTime, tOffset, hdfs_internal, hdfsFS, hdfsFile_internal, hdfsFile, hdfsFileInfo + // Including: tObjectKind, tPort, tSize, tTime, tOffset, hdfs_internal, hdfsFS, hdfsFile_internal, + // hdfsFile, hdfsFileInfo enum tObjectKind { kObjectKindFile = 'F', kObjectKindDirectory = 'D', @@ -74,7 +75,7 @@ namespace MxRec { dlclose(libhdfs); } - hdfsFS Connect(const char* host, tPort port) + hdfsFS Connect(const char* host, tPort port) const { if (hdfsConnect == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsConnect from the libhdfs."); @@ -82,7 +83,7 @@ namespace MxRec { return hdfsConnect(host, port); } - int Disconnect(hdfsFS fs) + int Disconnect(hdfsFS fs) const { if (hdfsDisconnect == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsDisconnect from the libhdfs."); @@ -90,7 +91,7 @@ namespace MxRec { return hdfsDisconnect(fs); } - int CreateDirectory(hdfsFS fs, const char* path) + int CreateDirectory(hdfsFS fs, const char* path) const { if (hdfsCreateDirectory == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsCreateDirectory from libhdfs."); @@ -98,7 +99,7 @@ namespace MxRec { return hdfsCreateDirectory(fs, path); } - hdfsFileInfo* ListDirectory(hdfsFS fs, const char* path, int *numEntries) + hdfsFileInfo* ListDirectory(hdfsFS fs, const char* path, int *numEntries) const { if (hdfsListDirectory == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsListDirectory from the libhdfs."); @@ -106,7 +107,7 @@ namespace MxRec { return hdfsListDirectory(fs, path, numEntries); } - hdfsFileInfo* GetPathInfo(hdfsFS fs, const char* path) + hdfsFileInfo* GetPathInfo(hdfsFS fs, const char* path) const { if (hdfsGetPathInfo == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsGetPathInfo from the libhdfs."); @@ -114,7 +115,7 @@ namespace MxRec { return hdfsGetPathInfo(fs, path); } - void FreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries) + void FreeFileInfo(hdfsFileInfo *hdfsFileInfo, int numEntries) const { if (hdfsFreeFileInfo == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsFreeFileInfo from the libhdfs."); @@ -138,7 +139,7 @@ namespace MxRec { return hdfsCloseFile(fs, file); } - tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize 
length)
+        tSize Read(hdfsFS fs, hdfsFile file, void* buffer, tSize length) const
         {
             if (hdfsRead == nullptr) {
                 throw runtime_error("Failed to obtain the pointer of the function hdfsRead from the libhdfs.");
             }
             return hdfsRead(fs, file, buffer, length);
         }

-        tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length)
+        tSize Write(hdfsFS fs, hdfsFile file, const void* buffer, tSize length) const
         {
             if (hdfsWrite == nullptr) {
                 throw runtime_error("Failed to obtain the pointer of the function hdfsWrite from the libhdfs.");
             }
             return hdfsWrite(fs, file, buffer, length);
         }
diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h
index e951dd8f..2e743165 100644
--- a/src/core/file_system/local_file_system/local_file_system.h
+++ b/src/core/file_system/local_file_system/local_file_system.h
@@ -13,10 +13,11 @@ namespace MxRec {

     using namespace std;
-
+    const int DIR_RIGHT_MODE = 0750;
+    const int FILE_RIGHT_MODE = 0640;
     class LocalFileSystem : public FileSystem {
     public:
-        LocalFileSystem() : dirMode(0750), fileMode(0640), currDir("."), prevDir("..") {}
+        LocalFileSystem() : dirMode(DIR_RIGHT_MODE), fileMode(FILE_RIGHT_MODE), currDir("."), prevDir("..") {}
         ~LocalFileSystem() override {}

         void CreateDir(const string& dirName) override;
@@ -29,7 +30,7 @@
                             const vector& addressArr, int deviceId) override;

         ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override;
-        ssize_t Read(const string& filePath, vector>& fileVector, size_t datasetSize) override;
+        ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) override;

         void ReadEmbedding(const string& filePath, const int& embeddingSize, vector& addressArr,
                            int deviceId) override;
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index a563d037..f16292ad 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -377,14 +377,14 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures,
 OffsetT HybridMgmt::SendHostMap(const string tableName)
 {
 #ifndef GTEST
-    OffsetT OffsetMap;
+    OffsetT offsetMap;
     // First check whether this map is empty
     if ((!offsetMapToSend.empty()) && offsetMapToSend.count(tableName) > 0) {
         for (auto& it : offsetMapToSend.at(tableName)) {
-            OffsetMap.push_back(it);
+            offsetMap.push_back(it);
         }
     }
-    return OffsetMap;
+    return offsetMap;
 #endif
 }

@@ -664,7 +664,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId)
 }

 void HybridMgmt::SendUniqKeysAndRestoreVecHBM(int channelId, int &batchId, const EmbInfo &embInfo,
-                                              const unique_ptr> &infoVecs)
+                                              const unique_ptr> &infoVecs) const
 {
     TimeCost sendUniqueKeysSyncTC;
     LOG_DEBUG("channelId:{} batchId:{}, global unique, table name: {}, is grad: {}",
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 571a10ab..0fa94cb9 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -190,7 +190,7 @@ namespace MxRec {
         void HandlePrepareDDRDataRet(TransferRet prepareSSDRet) const;

         void SendUniqKeysAndRestoreVecHBM(int channelId, int& batchId, const EmbInfo &embInfo,
-                                          const unique_ptr> &infoVecs);
+                                          const unique_ptr> &infoVecs) const;

         void SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchId, int &channelId, DDRParam &ddrParam);
     };
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index fffc5b48..0af90acb 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -839,8 +839,8 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys) const
         if (keys.size() % ALLTOALLVC_ALIGN == 0) {
             continue;
         }
-        int padding_size = ALLTOALLVC_ALIGN - (keys.size() % ALLTOALLVC_ALIGN);
-        std::fill_n(std::back_inserter(keys), padding_size, INVALID_KEY_VALUE);
+        int paddingSize = ALLTOALLVC_ALIGN - (keys.size() % ALLTOALLVC_ALIGN);
+        std::fill_n(std::back_inserter(keys), paddingSize, INVALID_KEY_VALUE);
     }
     return;
 }
@@ -907,7 +907,7 @@ tuple, vector, vector>
     vector splitKeys(rankInfo.rankSize);
     vector restore(batch->Size());
     absl::flat_hash_map uKey; // used for dedup lookups
-    absl::flat_hash_map keyCountMap;
+    absl::flat_hash_map keyCountMapByEmbName;
     std::shared_lock lock(g_smut);
     auto hotMap = hotKey[batch->name];
     lock.unlock();
@@ -918,7 +918,7 @@
     for (size_t i = 0; i < miniBs; i++) { // for mini batch
         const emb_key_t& key = batchData[i];
         if (batch->batchId % hotEmbUpdateStep == 0) {
-            keyCountMap[key]++;
+            keyCountMapByEmbName[key]++;
         }
         emb_key_t devId = abs(key % static_cast(rankInfo.rankSize));
         auto result = uKey.find(key);
@@ -955,7 +955,8 @@
                   batch->channel, batch->batchId, rankInfo.rankId, batch->Size(), uniqueKeyNum);
     }

-    UpdateHotMap(keyCountMap, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0, batch->name);
+    UpdateHotMap(keyCountMapByEmbName, hotEmbTotCount[batch->name], batch->batchId % hotEmbUpdateStep == 0,
+                 batch->name);
     AddCountStartToHotPos(splitKeys, hotPos, hotPosDev, batch);
     return { splitKeys, restore, hotPos };
 }
@@ -1004,13 +1005,13 @@ void KeyProcess::UpdateHotMapForUnique(const KeysT &keySend, const vector& keyCount,
 }

-void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh,
+void KeyProcess::UpdateHotMap(absl::flat_hash_map& keyCountMapByEmbName, uint32_t count, bool refresh,
                               const string& embName)
 {
     auto& hotMap = hotKey[embName];
     if (refresh) {
         priority_queue> pq; // top k key
-        for (auto& p: keyCountMap) {
+        for (auto& p: keyCountMapByEmbName) {
             pq.push(pair(-p.second, p.first));
             if (pq.size() > count) {
                 pq.pop();
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index b9cecf44..af8c12af 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -265,7 +265,7 @@ namespace MxRec {

         void EvictInitDeviceEmb(const string& embName, vector offset);

-        void UpdateHotMap(absl::flat_hash_map& keyCountMap, uint32_t count, bool refresh,
+        void UpdateHotMap(absl::flat_hash_map& keyCountMapByEmbName, uint32_t count, bool refresh,
                           const string& embName);

         void UpdateHotMapForUnique(const KeysT &keySend, const vector &keyCount,
--
Gitee

From 051b9cc2d5295a4be763410cc19ec8ee52073cf9 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 30 Nov 2023 17:44:44 +0800
Subject: [PATCH 495/551] Match-id-3af8c3f10815844868f26ca4d6e2749d17aeb0b7

---
 src/CMakeLists.txt                   |   1 +
 src/core/hd_transfer/hd_transfer.cpp |   6 +
 src/core/hd_transfer/hd_transfer.h   |   2 +
 src/core/hybrid_mgmt/hybrid_mgmt.cpp |  13 --
 src/core/hybrid_mgmt/hybrid_mgmt.h   |   3 -
 src/core/key_process/key_process.cpp | 163 ++++++++-----------
 src/core/key_process/key_process.h   |  31 +---
 src/dataset_tf/CMakeLists.txt        |  15 ++
 src/dataset_tf/eos_dataset_op.cc     | 230 +++++++++++++++++++++++++++
 src/dataset_tf/eos_dataset_op.h      |  43 +++++
 src/pybind/module_main.cpp           |   3 +-
 11 files changed, 368 insertions(+), 142 deletions(-)
 create mode 100644 src/dataset_tf/CMakeLists.txt
 create mode 100644 src/dataset_tf/eos_dataset_op.cc
 create mode 100644 src/dataset_tf/eos_dataset_op.h
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 6c770bca..e5ab5996 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -112,6 +112,7 @@ endif()
 add_subdirectory(core)
 add_subdirectory(ops_tf)
 add_subdirectory(pybind)
+add_subdirectory(dataset_tf)
 if (CMAKE_BUILD_TYPE MATCHES "Release")
     message(STATUS "CMAKE_BUILD_TYPE is Release")
 else()
diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp
index 429e1c1a..fd0ce522 100644
--- a/src/core/hd_transfer/hd_transfer.cpp
+++ b/src/core/hd_transfer/hd_transfer.cpp
@@ -146,6 +146,7 @@ void HDTransfer::Send(TransferChannel channel, const vector &tensors, in
         }
         resendTime++;
     } while (isNeedResend);
+    usedChannelsNames[channelId].insert(TransferChannel2Str(channel));
 #endif
 }
@@ -211,4 +212,9 @@ size_t HDTransfer::RecvAcl(TransferChannel channel, int channelId, const string&
 std::unordered_map HDTransfer::GetTransChannel()
 {
     return transferChannels;
+}
+
+std::unordered_map> HDTransfer::GetUsedTransChannel()
+{
+    return usedChannelsNames;
 }
\ No newline at end of file
diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h
index 6c046da0..eceaa617 100644
--- a/src/core/hd_transfer/hd_transfer.h
+++ b/src/core/hd_transfer/hd_transfer.h
@@ -84,9 +84,11 @@ namespace MxRec {
         void Destroy();

         std::unordered_map GetTransChannel();
+        unordered_map> GetUsedTransChannel();

     private:
         std::unordered_map transferChannels;
+        std::unordered_map> usedChannelsNames; // keyed by channel 0 or 1
         bool running;
         void CreateChannel(const uint32_t localRankId, const string& embName, const int channelNum);
     };
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index a563d037..12ea1b58 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -1133,16 +1133,3 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const
     return -1;
 #endif
 }
-
-void HybridMgmt::SetMpiSendAbnormalStatus()
-{
-    int sendValue = MPI_ABNORMAL_SEND_VALUE;
-    int channel = hybridMgmtBlock->lastRunChannelId;
-    if (channel < 0 || channel >= MAX_CHANNEL_NUM) {
-        LOG_WARN("channel is abnormal:{} when set mpi all reduce", channel);
-        return;
-    }
-    LOG_INFO(MGMT + "set mpi all reduce value:{}, channelId:{}", sendValue, channel);
-    preprocess->mpiAllReduceSend[channel] = sendValue;
-    preprocess->isNeedExit[channel] = true;
-}
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h
index 571a10ab..84537bf0 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.h
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.h
@@ -39,8 +39,6 @@ namespace MxRec {
         DDR
     };

-    constexpr int MGMT_THREAD_ID = -1;
-
     class HybridMgmt {
     public:
         HybridMgmt() = default;
@@ -139,7 +137,6 @@
         int64_t GetTableCapacity(const string& embName) const;

-        void SetMpiSendAbnormalStatus();
     private:
         bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos,
                             const vector& thresholdValues, int seed);
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index fffc5b48..41663f57 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -160,8 +160,6 @@ void KeyProcess::LoadKeyCountMap(KeyCountMemT& loadData)
 // Called from the Python side only when training ends; if a deadlock occurs, simply terminate
 // the process. In tests, let the process wait long enough before calling this.
 void KeyProcess::Destroy()
 {
-    mpiAllReduceSend[0] = MPI_ABNORMAL_SEND_VALUE;
-    mpiAllReduceSend[1] = MPI_ABNORMAL_SEND_VALUE;
     isRunning = false;
     LOG_INFO(KEY_PROCESS "rankId:{} KeyProcess begin destroy.", rankInfo.rankId);
     for (auto& i: procThreads) {
@@ -249,6 +247,7 @@ void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId)
         throw runtime_error(Logger::Format("create fast unique failed, error code:{}", ret));
     }
     GetUniqueConfig(uniqueConf);
+
     try {
         while (true) {
             TimeCost getAndProcessTC;
@@ -542,20 +541,9 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId)
             this_thread::sleep_for(seconds(1));
             tc = TimeCost();
         }
-
-        if (!isRunning || isNeedExit[channel]) {
-            LOG_WARN("channelId:{} threadId:{}, enter GetBatchData abnormal scene, isRunning:{}, isNeedExit:{}",
-                     channel, commId, isRunning, isNeedExit[channel]);
-            // Termination signal for the communication; exit in sync so no thread gets stuck
-            int receiveFlag = 0;
-            // Send 0 directly here rather than the mpiAllReduceSend value, to avoid
-            // multi-thread data visibility problems
-            int sendValue = 0;
-            auto retCode = MPI_Allreduce(&sendValue, &receiveFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]);
-            if (retCode != MPI_SUCCESS) {
-                LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode);
-            }
-            LOG_DEBUG("channelId:{} threadId:{}, GetBatchData Allreduce end, receiveFlag:{}",
-                      channel, commId, receiveFlag);
-            throw EndRunExit("GetBatchData end run, thread will exit.");
+        if (!isRunning) {
+            LOG_WARN("channelId:{} threadId:{}, isRunning is false when GetBatchData", channel, commId);
+            throw EndRunExit("GetBatchData end run.");
         }
     }
     EASY_END_BLOCK
@@ -1035,28 +1023,11 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons
     EASY_FUNCTION()
     vector scAll;
     scAll.resize(rankInfo.rankSize * rankInfo.rankSize);
-    EASY_BLOCK("barrier");
     LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId);
-    // Termination signal for the communication; exit in sync so no thread gets stuck
-    TimeCost tc = TimeCost();
-    int receiveFlag = 0;
-    auto retCode = MPI_Allreduce(&mpiAllReduceSend[batch->channel], &receiveFlag, 1, MPI_INT, MPI_SUM,
-                                 comm[batch->channel][commId]);
-    if (retCode != MPI_SUCCESS) {
-        LOG_ERROR("rank {} commId {}, MPI_Allreduce failed:{}", rankInfo.rankId, commId, retCode);
-    }
-    LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAll MPI_Allreduce end, receiveFlag:{}"
-              " barrier time:{}",
-              batch->channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS());
-
-    // Handle the case where a thread on another rank has exited
-    HandleRankExitScene(commId, batch, receiveFlag);
-
-    EASY_END_BLOCK;
     // allgather keyScLocal(key all2all keyScLocal = device all2all rc)
-    retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT,
-                            comm[batch->channel][commId]);
+    auto retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize,
+                                 MPI_INT, comm[batch->channel][commId]);
     if (retCode != MPI_SUCCESS) {
         LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode);
     }
@@ -1065,78 +1036,16 @@
     return scAll;
 }

-void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag)
-{
-    if (!isRunning) {
-        throw EndRunExit("GetScAll end run, isRunning is false.");
-    }
-    if (receiveFlag < rankInfo.rankSize) {
-        unique_lock lockGuard(destroyMutex);
-        if (isNeedExit[batch->channel]) {
-            LOG_INFO("channelId:{} threadId:{} batchId:{}, has send acl eos info, thread will exit.",
-                     batch->channel, commId, batch->batchId);
-            throw EndRunExit("has send acl eos info, thread will exit.");
-        }
-        SendEosInfo(commId, batch);
-        isNeedExit[batch->channel] = true;
-        throw EndRunExit("has SendEosInfo, GetScAll end, thread will exit.");
-    }
-}
-
-void KeyProcess::SendEosInfo(int commId, const unique_ptr& batch)
-{
-    // Note: the SendTensorsByAcl method cannot be linked in UT builds, so it is masked out
-#ifndef GTEST
-    auto trans = Singleton::GetInstance();
-    auto transChannel = trans->GetTransChannel();
-    LOG_INFO("channelId:{} threadId:{} batchId:{}, start send acl eos info.",
-             batch->channel, commId, batch->batchId);
-    vector tensors;
-    bool isNeedResend = true;
-    string all2all_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
-                                           TransferChannel2Str(TransferChannel::ALL2ALL).c_str(),
-                                           batch->channel);
-    SendTensorsByAcl(transChannel[all2all_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
-                     isNeedResend);
-    string restore_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
-                                           TransferChannel2Str(TransferChannel::RESTORE).c_str(),
-                                           batch->channel);
-    SendTensorsByAcl(transChannel[restore_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
-                     isNeedResend);
-    string lookup_sendName = StringFormat("%s_%s_%d", batch->name.c_str(),
-                                          TransferChannel2Str(TransferChannel::LOOKUP).c_str(), batch->channel);
-    SendTensorsByAcl(transChannel[lookup_sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors,
-                     isNeedResend);
-    LOG_INFO("channelId:{} threadId:{} batchId:{}, send acl eos info end.",
-             batch->channel, commId, batch->batchId);
-#endif
-}
-
 void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId,
                                    const unique_ptr &batch, vector &scAllOut)
 {
     EASY_FUNCTION()
     int channel = batch->channel;
     scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize);
-    EASY_BLOCK("barrier");
-    // Termination signal for the communication; exit in sync so no thread gets stuck
-    TimeCost tc = TimeCost();
-    int receiveFlag = 0;
-    auto retCode = MPI_Allreduce(&mpiAllReduceSend[channel], &receiveFlag, 1, MPI_INT, MPI_SUM,
-                                 comm[channel][commId]);
-    if (retCode != MPI_SUCCESS) {
-        LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode);
-    }
-    LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAllForUnique MPI AllReduce end, "
-              "receiveFlag:{}, barrier time:{}",
-              channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS());
-    // Handle the case where a thread on another rank has exited
-    HandleRankExitScene(commId, batch, receiveFlag);
-    EASY_END_BLOCK;
     // allgather keyScLocal(key all2all keyScLocal = device all2all rc)
-    retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
-                            scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
+    auto retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT,
+                                 scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]);
     if (retCode != MPI_SUCCESS) {
         LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode);
     }
@@ -1305,6 +1214,49 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel)
     }
 }

+/// Proactively send eos when the data list is empty and the eos flag is true
+/// \param batchId number of batches processed so far
+/// \param channel channel index (training/inference)
+void KeyProcess::SendEos(int batchId, int channel)
+{
+#ifndef GTEST
+    LOG_INFO("channelId:{} batchId:{}, SendEos start.", channel, batchId);
+
+    auto trans = Singleton::GetInstance();
+    unordered_map transChannels = trans->GetTransChannel();
+    std::set usedChannelNames = trans->GetUsedTransChannel()[channel];
+
+    vector tensors;
+    bool isNeedResend = true;
+    // Once one table triggers, every other table sends eos too; the outer loop
+    // then receives null and exits this round.
+    for (const auto& emb:embInfos) {
+        LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos start.", channel, batchId, emb.first);
+        if (!isRunning) {
+            throw EndRunExit("SendEos end run, isRunning is false after lock destroyMutex.");
+        }
+
+        string randomSendName = StringFormat("%s_%s_%d", emb.first.c_str(), (*usedChannelNames.begin()).c_str(),
+                                             channel);
+
+        size_t channel_size; // keep eos from overtaking batches keyProcess has not finished handling
+        do {
+            acltdtQueryChannelSize(transChannels[randomSendName], &channel_size);
+            LOG_TRACE("Before SendEos, channelName:{}, unsolved channel_size {}", randomSendName, channel_size);
+            this_thread::sleep_for(1ms);
+        } while (channel_size != 0);
+
+        for (const string& transName : usedChannelNames) {
+            string sendName = StringFormat("%s_%s_%d", emb.first.c_str(), transName.c_str(), channel);
+            SendTensorsByAcl(transChannels[sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors, isNeedResend);
+            LOG_DEBUG("SendTensorsByAcl eos channelName:{}, batchId:{}", sendName, batchId);
+        }
+        LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos end.", channel, batchId, emb.first);
+    }
+
+    LOG_INFO("channelId:{} batchId:{}, SendEos end.", channel, batchId);
+    isNeedSendEos[channel] = false;
+#endif
+}
+
 /// In HBM mode, fetch the tensor vectors of the given type from the list
 /// \param batch number of batches processed so far
 /// \param embName table name
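SendEos deliberately drains before it sends: if the end-of-sequence record were enqueued while earlier batches were still sitting in the ACL channel, the device side could observe eos before the last real batch. A pseudo-Python sketch of that ordering; query_channel_size, send_eos_record and channel_name below are illustrative stand-ins for acltdtQueryChannelSize, SendTensorsByAcl and StringFormat, not real bindings:

    import time

    def send_eos(tables, used_channel_names, channel_id):
        probe_suffix = next(iter(used_channel_names))
        for table in tables:
            # Drain first, so eos cannot overtake in-flight batches.
            while query_channel_size(channel_name(table, probe_suffix, channel_id)) != 0:
                time.sleep(0.001)
            # Then fan eos out to every channel this table actually used.
            for name in used_channel_names:
                send_eos_record(channel_name(table, name, channel_id))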
@@ -1352,6 +1304,11 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa
             storage.erase(it);
             return uTensor;
         } catch (EmptyList&) {
+            unique_lock lockGuard(destroyMutex);
+            if (isNeedSendEos[channel]) {
+                SendEos(batch, channel);
+                return nullptr;
+            }
             LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch);
             this_thread::sleep_for(1ms);
         } catch (WrongListTop&) {
@@ -1524,3 +1481,11 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch)
         singleKeyCountMap[key]++;
     }
 }
+
+void KeyProcess::SetEos(int status, int channelId)
+{
+    unique_lock lockGuard(destroyMutex);
+    LOG_INFO("isNeedSendEos status is changed, before status:[{}], input status:{}, channel:[{}], ",
+             isNeedSendEos[channelId], status, channelId);
+    isNeedSendEos[channelId] = (status == 1);
+}
diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h
index b9cecf44..4e388956 100644
--- a/src/core/key_process/key_process.h
+++ b/src/core/key_process/key_process.h
@@ -58,9 +58,6 @@ namespace MxRec {
         INVALID
     };

-    constexpr int MPI_ABNORMAL_SEND_VALUE = 0; // sent over MPI when communication is abnormal
-    constexpr int MPI_NORMAL_SEND_VALUE = 1;  // sent over MPI when communication is normal
-
     class EndRunExit : public std::exception {
     public:
         explicit EndRunExit(const char* message) : errorMessage(message) {}
@@ -74,20 +71,6 @@
         const char* errorMessage;
     };

-    // Ends the run and blocks on an abnormal condition
-    class EndRunBlock : public std::exception {
-    public:
-        explicit EndRunBlock(const char *message) : errorMessage(message) {}
-
-        const char *what() const noexcept override
-        {
-            return errorMessage;
-        }
-
-    private:
-        const char *errorMessage;
-    };
-
     class EmptyList : public std::exception {
     };

@@ -173,11 +156,12 @@
         }

+        void SetEos(int status, int channelId);
+
+        void SendEos(int batchId, int channel);
+
         bool isRunning { false };
-        // Whether the preprocessing thread for this channel needs to exit (per channel);
-        // it must exit once the eos message has been sent
-        bool isNeedExit[2] = {false, false};
-        // Value each rank contributes to the MPI allreduce
-        int mpiAllReduceSend[2] = {MPI_NORMAL_SEND_VALUE, MPI_NORMAL_SEND_VALUE};
+
         std::mutex destroyMutex;

         inline bool HasEmbName(const string& embName)
         {
@@ -208,6 +192,7 @@
         FactoryPtr factory {};
         int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT;
         bool isWithFAAE;
+        vector isNeedSendEos { false, false }; // eos state of channels 0 and 1 respectively

         void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo);

@@ -302,10 +287,6 @@
         }

         string DumpSplitKeys(vector>& splitKeys) const;
-
-        void SendEosInfo(int commId, const unique_ptr& batch);
-
-        void HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag);
     };
} // end namespace MxRec
#endif // MX_REC_KEY_PROCESS_H
diff --git a/src/dataset_tf/CMakeLists.txt b/src/dataset_tf/CMakeLists.txt
new file mode 100644
index 00000000..70d42d88
--- /dev/null
+++ b/src/dataset_tf/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.12)
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_BUILD_TYPE "Release")
+
+file(GLOB_RECURSE MXREC_OP_EOS_DATA ./*.cc)
+add_library(rec_eos_ops SHARED ${MXREC_OP_EOS_DATA})
+
+set(TF_INTERNAL_LIB ${TF_PATH}/python/_pywrap_tensorflow_internal.so)
+
+target_link_libraries(rec_eos_ops PUBLIC ASC
+        ${TF_INTERNAL_LIB}
+        )
+
+
+install(TARGETS rec_eos_ops LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX})
\ No newline at end of file
diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc
new file mode 100644
index 00000000..aa5b4b62
--- /dev/null
+++ b/src/dataset_tf/eos_dataset_op.cc
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: dataset eos ops.
+ * Author: MindX SDK
+ * Create: 2023
+ * History: NA
+ */
+
+#include "eos_dataset_op.h"
+
+#include <mpi.h>
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
+#if defined(TF_VERSION_TF2)
+#include "tensorflow/core/data/name_utils.h"
+#endif
+
+#include "key_process/key_process.h"
+#include "utils/logger.h"
+
+using namespace std;
+using namespace MxRec;
+
+namespace tensorflow {
+namespace data {
+
+MPI_Comm comm;
+MPI_Group worldGroup;
+
+constexpr const char *const EosDatasetOp::kDatasetType;
+constexpr const char *const EosDatasetOp::kInputDataset;
+constexpr const char *const EosDatasetOp::kChannelId;
+constexpr const char *const EosDatasetOp::kOutputTypes;
+constexpr const char *const EosDatasetOp::kOutputShapes;
+
+// The immutable definition of the dataset; this class's MakeIterator() method tells
+// TensorFlow how to create an iterator object over the dataset.
+class EosDatasetOp::Dataset : public DatasetBase {
+public:
+    explicit Dataset(OpKernelContext *ctx, const DatasetBase *input, int32_t channelId)
+        : DatasetBase(DatasetContext(ctx)),
+          input_(input),
+          channelId_(channelId)
+    {
+        input_->Ref();
+        auto os_input = input->output_shapes();
+        output_shapes_ = os_input;
+        keyProcess = Singleton::GetInstance();
+        MPI_Comm_group(MPI_COMM_WORLD, &worldGroup);
+        MPI_Comm_create(MPI_COMM_WORLD, worldGroup, &comm);
+    }
+
+    ~Dataset() override
+    {
+        input_->Unref();
+    }
+
+    std::unique_ptr MakeIteratorInternal(const string &prefix) const override
+    {
+#if defined(TF_VERSION_TF2)
+        string prefix_para = name_utils::IteratorPrefix(kDatasetType, prefix);
+#else
+        string prefix_para = prefix + "::" + kDatasetType;
+#endif
+        return absl::make_unique(Iterator::Params{
+            this, prefix_para});
+    }
+
+    const DataTypeVector &output_dtypes() const override
+    {
+        return input_->output_dtypes();
+    }
+
+    const std::vector &output_shapes() const override
+    {
+        return output_shapes_;
+    }
+
+    string DebugString() const override
+    {
+#if defined(TF_VERSION_TF2)
+        return name_utils::DatasetDebugString(kDatasetType);
+#else
+        return "NpuMapDatasetOp::DataSet";
+#endif
+    }
+
+    int64 Cardinality() const override
+    {
+        return input_->Cardinality();
+    }
+
+    Status CheckExternalState() const override
+    {
+        return input_->CheckExternalState();
+    }
+
+protected:
+    Status AsGraphDefInternal(SerializationContext *ctx, DatasetGraphDefBuilder *b, Node **output) const override
+    {
+        Node *input_graph = nullptr;
+        TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
+        Node *channel_id_x = nullptr;
+        TF_RETURN_IF_ERROR(b->AddScalar(channelId_, &channel_id_x));
+        TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, channel_id_x}, output));
+        return Status::OK();
+    }
+
+private:
+    // The mutable state of an iterator over a specific dataset; this class's GetNextInternal()
+    // method tells TensorFlow how to fetch the iterator's next element.
+    class Iterator : public DatasetIterator {
+    public:
+        explicit Iterator(const Params &params) : DatasetIterator(params), i_(0) {}
+#if defined(TF_VERSION_TF2)
+        Status Initialize(IteratorContext* ctx) override
+        {
+            return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
+        }
+#else
+        Status Initialize(IteratorContext *ctx) override
+        {
+            return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+        }
+#endif
+        Status GetNextInternal(IteratorContext* ctx,
+                               std::vector* out_tensors,
+                               bool* end_of_sequence) override
+        {
+            mutex_lock l(mu_);
+            int exitFlag = 0;
+            if (!input_impl_) {
+                *end_of_sequence = true;
+                return Status::OK();
+            }
+
+            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+
+            // Normal data path
+            if (!*end_of_sequence) {
+                LOG_TRACE("GetNext, step in MPI_Allreduce, exitFlag:[{}]", exitFlag);
+                MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm);
+                LOG_TRACE("GetNext, step out MPI_Allreduce, exitFlag:[{}]", exitFlag);
+                // Data-imbalance case: another device has hit eos
+                if (exitFlag != 0) {
+                    i_ = 1;
+                    *end_of_sequence = true;
+                    LOG_INFO("GetNext, some rank eos, channelID:[{}]", dataset()->channelId_);
+                    dataset()->keyProcess->SetEos(1, dataset()->channelId_);
+                }
+                return Status::OK();
+            }
+            // Local data eos case
+            i_ = 1;
+            exitFlag = 1;
+            *end_of_sequence = true;
+            input_impl_.reset();
+            LOG_TRACE("GetNext eos, step in MPI_Allreduce, exitFlag:[{}]", exitFlag);
+            MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm);
+            LOG_TRACE("GetNext eos, step out MPI_Allreduce, exitFlag:[{}]", exitFlag);
+
+            LOG_INFO("GetNext eos, channelID:[{}]", dataset()->channelId_);
+            dataset()->keyProcess->SetEos(1, dataset()->channelId_);
+
+            return Status::OK();
+        }
+    protected:
+        std::shared_ptr CreateNode(
+            IteratorContext* ctx, model::Node::Args args) const override
+        {
+            return model::MakeKnownRatioNode(std::move(args), /* ratio= */ 1);
+        }
+#if defined(TF_VERSION_TF2)
+        Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override
+        {
+            TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
+            return Status::OK();
+        }
+#else
+        Status SaveInternal(IteratorStateWriter* writer) override
+        {
+            TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+            return Status::OK();
+        }
+#endif
+        Status RestoreInternal(IteratorContext* ctx,
+                               IteratorStateReader* reader) override
+        {
+            mutex_lock l(mu_);
+            TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+            return Status::OK();
+        }
+
+    private:
+        tensorflow::mutex mu_;
+        int64 i_ GUARDED_BY(mu_);
+        std::unique_ptr input_impl_ GUARDED_BY(mu_);
+    };
+    const DatasetBase *input_;
+    int32_t channelId_;
+    KeyProcess* keyProcess;
+    std::vector output_shapes_;
+};
+
+EosDatasetOp::EosDatasetOp(OpKernelConstruction *ctx) : UnaryDatasetOpKernel(ctx) {}
+
+void EosDatasetOp::MakeDataset(OpKernelContext *ctx, DatasetBase *input, DatasetBase **output)
+{
+    int32_t channel;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kChannelId, &channel));
+    *output = new Dataset(ctx, input, channel);
+}
+
+REGISTER_OP("EosDataset")
+.Input("input_dataset: variant")
+.Input("channel_id: int32")
+.Output("handle: variant")
+.Attr("output_types: list(type) >= 1")
+.Attr("output_shapes: list(shape) >= 1")
+.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_KERNEL_BUILDER(Name("EosDataset").Device(DEVICE_CPU),
+                        EosDatasetOp);
+
+} // namespace data
+} // namespace tensorflow
\ No newline at end of file
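GetNextInternal reaches eos consensus with one extra allreduce per step: a rank that produced a batch contributes 0, a rank whose upstream iterator ended contributes 1, so a non-zero sum stops every rank in the same round and data imbalance can no longer strand a subset of ranks inside a collective. A behavioural sketch of the same protocol, with mpi4py standing in for the C++ MPI calls (on_eos plays the role of keyProcess->SetEos; the whole block is illustrative, not part of the op):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    def eos_iterator(upstream, on_eos):
        # Mirrors EosDataset's GetNextInternal; illustrative only.
        for element in upstream:
            # Healthy step: contribute 0; a positive sum means a peer hit eos.
            if comm.allreduce(0, op=MPI.SUM) != 0:
                on_eos()
                return
            yield element
        # Local eos: contribute 1 so every other rank stops in this round.
        comm.allreduce(1, op=MPI.SUM)
        on_eos()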
diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h
new file mode 100644
index 00000000..4ece6723
--- /dev/null
+++ b/src/dataset_tf/eos_dataset_op.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: dataset eos ops.
+ * Author: MindX SDK
+ * Create: 2023
+ * History: NA
+ */
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EOS_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EOS_DATASET_OP_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/public/version.h"
+
+#if TF_MAJOR_VERSION == 2
+#define TF_VERSION_TF2
+#endif
+
+using namespace std;
+namespace tensorflow {
+namespace data {
+    // This class's MakeDataset() method tells TensorFlow how to build a dataset
+    // object from an op's inputs and attributes.
+    class EosDatasetOp : public UnaryDatasetOpKernel {
+    public:
+        static constexpr const char *const kDatasetType = "Eos";
+        static constexpr const char *const kInputDataset = "input_dataset";
+        static constexpr const char *const kChannelId = "channel_id";
+        static constexpr const char *const kOutputTypes = "output_types";
+        static constexpr const char *const kOutputShapes = "output_shapes";
+
+        explicit EosDatasetOp(OpKernelConstruction *ctx);
+
+    protected:
+        void MakeDataset(OpKernelContext *ctx, DatasetBase *input,
+                         DatasetBase **output) override;
+
+    private:
+        class Dataset;
+    };  // class EosDatasetOp
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EOS_DATASET_OP_H_
diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp
index ed455647..d3793f1d 100644
--- a/src/pybind/module_main.cpp
+++ b/src/pybind/module_main.cpp
@@ -192,8 +192,7 @@ namespace {
             .def("block_count_steps", &MxRec::HybridMgmt::CountStepBySessionRun, py::arg("channel_id"),
                  py::arg("steps")=1)
             .def("get_table_size", &MxRec::HybridMgmt::GetTableSize, py::arg("table_name"))
-            .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name"))
-            .def("set_mpi_send_abnormal_status", &MxRec::HybridMgmt::SetMpiSendAbnormalStatus);
+            .def("get_table_capacity", &MxRec::HybridMgmt::GetTableCapacity, py::arg("table_name"));
     }

     void GetThresholdValue(pybind11::module_& m)
--
Gitee

From f3d042e19dc2d9063fe9cf459fcc9c9817f963d0 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 30 Nov 2023 17:45:07 +0800
Subject: [PATCH 496/551] Match-id-7a2c4b1883e13d0b82246c6152564b35c6cb3200

---
 mx_rec/__init__.py             |  2 ++
 mx_rec/constants/constants.py  |  7 +++----
 mx_rec/core/asc/manager.py     | 27 +------------------------
 mx_rec/data/__init__.py        |  3 +++
 mx_rec/data/dataset.py         | 36 ++++++++++++++++++++++++++++++++++
 mx_rec/data/patch.py           | 19 ++++++++++++++++++
 mx_rec/graph/modifier.py       | 13 ++++++++----
 mx_rec/util/global_env_conf.py |  7 ++-----
 mx_rec/util/ops.py             | 19 ++++++++++++++----
 9 files changed, 90 insertions(+), 43 deletions(-)
 create mode 100644 mx_rec/data/__init__.py
 create mode 100644 mx_rec/data/dataset.py
 create mode 100644 mx_rec/data/patch.py

diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py
index c7261449..53c31414 100644
--- a/mx_rec/__init__.py
+++ b/mx_rec/__init__.py
@@ -9,10 +9,12 @@ from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaver
 from mx_rec.saver.patch import patch_for_saver
 from mx_rec.graph.patch import patch_for_dataset, patch_for_chief_session_creator, patch_for_bool_gauge, \
     patch_for_assert_eval_spec, patch_for_scale_loss, patch_for_session
+from mx_rec.data.patch import patch_for_dataset_eos_map
 from mx_rec.optimizers.base import patch_for_optimizer

 patch_for_saver()
 patch_for_dataset()
+patch_for_dataset_eos_map()
 patch_for_scale_loss()
 patch_for_chief_session_creator()
 patch_for_assert_eval_spec()
diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index c261c7f5..6a1d18c8 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -72,9 +72,9 @@ MAX_DEVICE_ID = 15
 # HDFS file system's file prefix
 HDFS_FILE_PREFIX = ["viewfs://", "hdfs://"]

-# GetNext op names
-ITERATOR_GET_NEXT = "IteratorGetNext"
-NPU_GET_NEXT = "npuGetNext"
+# Shared-object package names
+LIBASC_OPS_SO = "libasc_ops.so"
+LIBREC_EOS_OPS_SO = "librec_eos_ops.so"


 class BaseEnum(Enum):
@@ -114,7 +114,6 @@ class EnvOption(Enum):
     USE_COMBINE_FAAE = "USE_COMBINE_FAAE"
     STAT_ON = "STAT_ON"
     RECORD_KEY_COUNT = "RECORD_KEY_COUNT"
-    ADD_CONTROL_EDGE = "ADD_CONTROL_EDGE"


 class DataName(Enum):
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 11265169..42d94476 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -9,11 +9,8 @@ from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitiali
 from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
     is_asc_manager_initialized, get_train_steps, get_eval_steps, get_save_steps, \
     export_table_instances, export_feature_spec, get_if_load, get_use_static, \
-    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num, \
-    get_modify_graph
-from mx_rec.constants.constants import ITERATOR_GET_NEXT, NPU_GET_NEXT, TFDevice, EnvOption, Flag
+    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num
 from mx_rec.core.asc.merge_table import find_dangling_table, should_skip
-from mx_rec.util.global_env_conf import global_env
 from mx_rec.util.log import logger
@@ -228,29 +225,7 @@ def initialize_emb_cache(table_info_list, threshold_list):
     logger.debug("threshold_values are %s.", threshold_list)


-def add_control_edge():
-    iterator_get_next_op = None
-    get_next_name = ITERATOR_GET_NEXT
-    if get_modify_graph():
-        get_next_name = NPU_GET_NEXT
-    for op in tf.compat.v1.get_default_graph().get_operations():
-        if get_next_name == op.name and ITERATOR_GET_NEXT == op.type:
-            iterator_get_next_op = op
-            break
-    logger.info("iterator_get_next_op: %s", iterator_get_next_op)
-    if not iterator_get_next_op:
-        return
-    for op in tf.compat.v1.get_default_graph().get_operations():
-        if "GetNext" == op.type:
-            if global_env.tf_device == TFDevice.NPU.value and "merged" not in op.name:
-                continue
-            op._add_control_input(iterator_get_next_op)
-            logger.info("_add_control_input: %s", op)
-
-
 def start_asc_pipeline():
-    if global_env.add_control_edge == Flag.TRUE.value:
-        add_control_edge()
     table_info_list = generate_table_info_list()
     threshold_list = generate_threshold_list()
diff --git a/mx_rec/data/__init__.py b/mx_rec/data/__init__.py
new file mode 100644
index 00000000..6924f767
--- /dev/null
+++ b/mx_rec/data/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
diff --git a/mx_rec/data/dataset.py b/mx_rec/data/dataset.py
new file mode 100644
index 00000000..844fc967
--- /dev/null
+++ b/mx_rec/data/dataset.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+from tensorflow.python.data.ops.dataset_ops import get_legacy_output_types, get_legacy_output_classes, \
+    get_legacy_output_shapes, UnaryDataset
+from tensorflow.python.data.util import structure
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import dtypes
+
+
+class EosDataset(UnaryDataset):
+    """Dataset used to send end_of_sequence."""
+
+    def __init__(self, input_dataset, librec, channel_id):
+        self._input_dataset = input_dataset
+        output_types = get_legacy_output_types(input_dataset)
+        output_classes = get_legacy_output_classes(input_dataset)
+        input_shapes = get_legacy_output_shapes(self._input_dataset)
+        output_shapes = input_shapes
+
+        self._structure = structure.convert_legacy_structure(
+            output_types, output_shapes, output_classes)
+        channel_id = ops.convert_to_tensor(channel_id, dtype=dtypes.int32, name="channel_id")
+        self._input_datasets = [input_dataset]
+        variant_tensor = librec.eos_dataset(
+            input_dataset=input_dataset._variant_tensor, channel_id=channel_id,
+            output_shapes=self._flat_shapes, output_types=self._flat_types)
+        super(EosDataset, self).__init__(input_dataset, variant_tensor)
+
+    @property
+    def element_spec(self):
+        return self._structure
+
+    def _inputs(self):
+        return self._input_datasets
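Wired up by hand, the new dataset sits directly in front of the existing input pipeline; a sketch, assuming a normal mxRec install so that librec_eos_ops.so resolves (the eos_map patch that follows does the same thing behind the scenes, and the channel id is normally chosen by get_training_mode_channel_id):

    import tensorflow as tf
    from mx_rec.constants.constants import LIBREC_EOS_OPS_SO
    from mx_rec.data.dataset import EosDataset
    from mx_rec.util.ops import import_host_pipeline_ops

    librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO)
    ds = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
    # channel_id picks which of the two C++ channels (training/inference)
    # this pipeline reports eos on; 0 is used here only as an example.
    ds = EosDataset(ds, librec, channel_id=0)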
+ Returns: None + """ + + def eos_map_fn(self, librec, channel_id): + return DatasetV1Adapter(EosDataset(self, librec, channel_id)) + + DatasetV2.eos_map = eos_map_fn diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index 8231fe53..d936fa7c 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -18,11 +18,12 @@ from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.manager import start_asc_pipeline from mx_rec.core.embedding import SparseEmbedding from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \ - ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, NPU_GET_NEXT + ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, LIBREC_EOS_OPS_SO from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \ - terminate_config_initializer, set_is_graph_modify_hook_running, get_bool_gauge_set, \ + get_training_mode_channel_id, set_is_graph_modify_hook_running, get_bool_gauge_set, \ insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \ set_iterator_type +from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.util.perf import performance from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \ export_pb_graph, make_sorted_key_to_tensor_list, ReplacementSpec, AnchorRecord @@ -455,6 +456,11 @@ def get_tgt_dataset( """ + librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO) + channel_id = get_training_mode_channel_id(records.get("is_training")) + # 在数据读取完时,通过EosDataset向acl数据通道发送end_of_sequence + src_dataset = src_dataset.eos_map(librec, channel_id) + tgt_dataset = src_dataset.map(get_preprocessing_map_func(records.get("sub_graph_def"), records.get("input_name_list"), records.get("output_name_list"), @@ -510,7 +516,7 @@ def update_iterator_getnext(get_next_op: Operation, set_initializer(is_training, new_iterator.initializer) else: new_iterator = tgt_dataset.make_one_shot_iterator() - new_batch = new_iterator.get_next(NPU_GET_NEXT) + new_batch = new_iterator.get_next() set_target_batch(is_training, new_batch) try: @@ -619,4 +625,3 @@ class GraphModifierHook(tf.estimator.SessionRunHook): def after_create_session(self, session, coord): if self._modify_graph and self._iterator_type == "MakeIterator": session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) - diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index 2a6bf186..0ba5bdde 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -34,7 +34,6 @@ class RecEnv: use_combine_faae: str stat_on: str record_key_count: str - add_control_edge: str def get_global_env_conf() -> RecEnv: @@ -62,8 +61,7 @@ def get_global_env_conf() -> RecEnv: glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value), use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value), stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value), - record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value), - add_control_edge=os.getenv(EnvOption.ADD_CONTROL_EDGE.value, Flag.FALSE.value) + record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value) ) return rec_env @@ -89,8 +87,7 @@ def get_global_env_conf() -> RecEnv: ("glog_stderrthreahold", OptionValidator, {"options": [i.value for i in list(RecCPPLogLevel)]}), ("use_combine_faae", OptionValidator, 
{"options": [i.value for i in list(Flag)]}), ("stat_on", OptionValidator, {"options": [i.value for i in list(Flag)]}), - ("record_key_count", OptionValidator, {"options": [i.value for i in list(Flag)]}), - ("add_control_edge", OptionValidator, {"options": [i.value for i in list(Flag)]}) + ("record_key_count", OptionValidator, {"options": [i.value for i in list(Flag)]}) ]) def check_env(**kwargs): pass diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py index 121e65f3..7869ce56 100644 --- a/mx_rec/util/ops.py +++ b/mx_rec/util/ops.py @@ -3,20 +3,31 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. import os +from types import ModuleType import tensorflow as tf from mx_rec.util.log import logger +from mx_rec.constants.constants import LIBASC_OPS_SO -def import_host_pipeline_ops(): +def import_host_pipeline_ops(so_pkg_name: str = LIBASC_OPS_SO) -> ModuleType: + """ + 导入so包. + + Args: + so_pkg_name: so包的名称 + Returns: 返回用于调用op的module + """ + + so_pkg_path = 'mx_rec/libasc/' + so_pkg_name if os.path.exists( os.path.join(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), - 'mx_rec/libasc/libasc_ops.so')): + so_pkg_path)): default_so_path = os.path.join( os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")), - 'mx_rec/libasc/libasc_ops.so') + so_pkg_path) logger.debug("Using the DEFAULT PATH '%s' to get ops lib.", default_so_path) return tf.load_op_library(default_so_path) else: - raise ValueError("Please check if libasc_ops.so exists (mxRec correctly installed)") + raise ValueError(f"Please check if `{so_pkg_name}` exists (mxRec correctly installed).") -- Gitee From ae4ef1aa25736889262b9cc185c13438922fa4e5 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 30 Nov 2023 21:20:05 +0800 Subject: [PATCH 497/551] Match-id-fcad33dfcc8c249193a3eb527de68de92bf1dbc5 --- src/core/file_system/hdfs_file_system/hdfs_file_system.h | 2 +- src/core/file_system/hdfs_file_system/hdfs_wrapper.h | 5 +++-- src/core/file_system/local_file_system/local_file_system.h | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h index c941ba44..e6c0f6f1 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_file_system.h +++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h @@ -27,7 +27,7 @@ namespace MxRec { size_t GetFileSize(const string& filePath) override; ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; - ssize_t Write(const string& filePath, vector fileVector, size_t dataSize) override; + ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) override; void WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) override; diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h index 1b225ae7..2accf487 100644 --- a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h +++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h @@ -123,7 +123,8 @@ namespace MxRec { return hdfsFreeFileInfo(hdfsFileInfo, numEntries); } - hdfsFile OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication, tSize blocksize) + hdfsFile OpenFile(hdfsFS fs, const char* path, int flags, int bufferSize, short replication, + tSize blocksize) const { if (hdfsOpenFile == nullptr) { throw runtime_error("Failed to obtain the 
pointer of the function hdfsOpenFile from the libhdfs."); @@ -131,7 +132,7 @@ namespace MxRec { return hdfsOpenFile(fs, path, flags, bufferSize, replication, blocksize); } - int CloseFile(hdfsFS fs, hdfsFile file) + int CloseFile(hdfsFS fs, hdfsFile file) const { if (hdfsCloseFile == nullptr) { throw runtime_error("Failed to obtain the pointer of the function hdfsCloseFile from the libhdfs."); diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h index 2e743165..78ea4167 100644 --- a/src/core/file_system/local_file_system/local_file_system.h +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -25,7 +25,7 @@ namespace MxRec { size_t GetFileSize(const string& filePath) override; ssize_t Write(const string& filePath, const char* fileContent, size_t dataSize) override; - ssize_t Write(const string& filePath, vector fileVector, size_t dataSize) override; + ssize_t Write(const string& filePath, vector fileContent, size_t dataSize) override; void WriteEmbedding(const string& filePath, const int& embeddingSize, const vector& addressArr, int deviceId) override; -- Gitee From 52d9a7d93b4543b740a65d4d3f104fe861408368 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 1 Dec 2023 11:05:11 +0800 Subject: [PATCH 498/551] Match-id-42e63ee33f8555af752b6b731b050f3612c81a72 --- src/core/key_process/key_process.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 16fa109b..eaa042b4 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1235,16 +1235,6 @@ void KeyProcess::SendEos(int batchId, int channel) throw EndRunExit("SendEos end run, isRunning is false after lock destroyMutex."); } - string randomSendName = StringFormat("%s_%s_%d", emb.first.c_str(), (*usedChannelNames.begin()).c_str(), - channel); - - size_t channel_size; // 避免eos在keyProcess还未处理完数据时插队到通道前面 - do { - acltdtQueryChannelSize(transChannels[randomSendName], &channel_size); - LOG_TRACE("Before SendEos, channelName:{}, unsolved channel_size {}", randomSendName, channel_size); - this_thread::sleep_for(1ms); - } while (channel_size != 0); - for (const string& transName : usedChannelNames) { string sendName = StringFormat("%s_%s_%d", emb.first.c_str(), transName.c_str(), channel); SendTensorsByAcl(transChannels[sendName], ACL_TENSOR_DATA_END_OF_SEQUENCE, tensors, isNeedResend); @@ -1306,11 +1296,15 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa return uTensor; } catch (EmptyList&) { unique_lock lockGuard(destroyMutex); - if (isNeedSendEos[channel]) { + // readEmbKey真实的次数是readEmbedBatchId减1 + int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; + // 避免eos在keyProcess还未处理完数据时插队到通道前面 + if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { SendEos(batch, channel); return nullptr; } - LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch); + LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batch id: {}, readEmbKey batch id: {}.", + embName, channel, batch, readEmbKeyBatchId); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); -- Gitee From 2255e0c8aed6ff1ae5c88bce66052f2433ab978e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 29 Nov 2023 01:39:29 +0800 Subject: [PATCH 499/551] Match-id-9370247cb449510b5ae910ac2d2bc5a950e913a0 --- 
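Patch 498 above drops the busy-wait on `acltdtQueryChannelSize` and instead gates EOS on a batch-id comparison, so EOS can no longer overtake batches the key-process side has not finished. A toy model of the new guard in plain Python (names are illustrative, not the project API):

```python
def should_send_eos(eos_requested: bool, read_emb_batch_id: int, batch_id: int) -> bool:
    # read_emb_batch_id mirrors readEmbedBatchId - 1, i.e. the number of batches
    # actually read so far; EOS goes out only if the requested batch can never arrive.
    return eos_requested and read_emb_batch_id < batch_id
```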
src/core/emb_hashmap/emb_hashmap.h | 3 - .../local_file_system/local_file_system.cpp | 3 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 37 +++-- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- src/core/key_process/key_process.cpp | 153 +++++++++--------- src/core/key_process/key_process.h | 21 ++- src/tests/key_process/key_process_test.cpp | 20 +-- 7 files changed, 118 insertions(+), 122 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index df44e74c..b8b6c883 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -28,9 +28,6 @@ namespace MxRec { void Process(const string& embName, std::vector& keys, DDRParam& ddrParam, int channelId); - void FindAndUpdateOffset(const string& embName, vector& keys, size_t currentBatchId, - size_t keepBatchId, int channelId); - auto GetHashMaps() -> absl::flat_hash_map; void LoadHashMap(absl::flat_hash_map& loadData); diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 2335f94b..6e55072f 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -35,7 +35,6 @@ vector LocalFileSystem::ListDir(const string& dirName) struct dirent* en; if (dir == nullptr) { LOG_WARN("Open directory {} failed while trying to traverse the directory.", dirName); - closedir(dir); return dirs; } @@ -420,4 +419,4 @@ void LocalFileSystem::HandleMappedData(char* mappedData, size_t mapRowNum, size_ idx += readSize; } } -} \ No newline at end of file +} diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 4606cf22..e287ef1d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -500,27 +500,15 @@ void HybridMgmt::Start() { #ifndef GTEST if (mgmtRankInfo.noDDR) { - InsertThreadForHBM(); - } - - if (!mgmtRankInfo.noDDR) { - auto parseKeysTaskForTrain = [this]() { - TrainTask(TaskType::DDR); - LOG_INFO("parseKeysTaskForTrain done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); - - auto parseKeysTaskForEval = [this]() { - EvalTask(TaskType::DDR); - LOG_INFO("parseKeysTaskForEval done"); - }; - procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); + StartThreadForHBM(); + } else { + StartThreadForDDR(); } #endif } /// 启动HBM模式数据处理线程 -void HybridMgmt::InsertThreadForHBM() +void HybridMgmt::StartThreadForHBM() { #ifndef GTEST auto parseKeysTaskForHBMTrain = [this]() { @@ -537,6 +525,23 @@ void HybridMgmt::InsertThreadForHBM() #endif } +void HybridMgmt::StartThreadForDDR() +{ +#ifndef GTEST + auto parseKeysTaskForTrain = [this]() { + TrainTask(TaskType::DDR); + LOG_INFO("parseKeysTaskForTrain done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForTrain)); + + auto parseKeysTaskForEval = [this]() { + EvalTask(TaskType::DDR); + LOG_INFO("parseKeysTaskForEval done"); + }; + procThreads.emplace_back(std::make_unique(parseKeysTaskForEval)); +#endif +} + #ifndef GTEST /// 启动hybrid处理任务 /// \param type diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 3c045901..3e24cda1 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -70,7 +70,8 @@ namespace MxRec { void Start(); - void InsertThreadForHBM(); + void StartThreadForHBM(); + void StartThreadForDDR(); void Destroy() { diff --git a/src/core/key_process/key_process.cpp 
b/src/core/key_process/key_process.cpp index eaa042b4..b8b64108 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -14,7 +14,6 @@ using namespace std; using namespace chrono; using namespace MxRec; -using namespace ock::ctr; static shared_mutex g_smut; @@ -66,7 +65,7 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos } if (GlobalEnv::fastUnique) { - int result = Factory::Create(factory); + int result = ock::ctr::Factory::Create(factory); if (result != 0) { throw runtime_error(Logger::Format("create fast factory failed, error code:{}", result)); } @@ -112,7 +111,7 @@ int KeyProcess::Start() void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) { - auto embeddingSize = info.extEmbeddingSize; + int embeddingSize = info.extEmbeddingSize; if (rankInfo.useDynamicExpansion) { embeddingSize = info.embeddingSize; } @@ -120,22 +119,22 @@ void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) HOT_EMB_CACHE_PCT / static_cast(embeddingSize)); } -auto KeyProcess::GetMaxOffset() -> OffsetMemT +OffsetMemT KeyProcess::GetMaxOffset() { return maxOffset; } -auto KeyProcess::GetKeyOffsetMap() -> KeyOffsetMemT +KeyOffsetMemT KeyProcess::GetKeyOffsetMap() { return keyOffsetMap; } -auto KeyProcess::GetKeyCountMap() -> KeyCountMemT +KeyCountMemT KeyProcess::GetKeyCountMap() { return keyCountMap; } -auto KeyProcess::GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict& +FeatureAdmitAndEvict& KeyProcess::GetFeatAdmitAndEvict() { return m_featureAdmitAndEvict; } @@ -189,7 +188,7 @@ void KeyProcess::LoadSaveUnlock() } } -void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) +void KeyProcess::GetUniqueConfig(ock::ctr::UniqueConf& uniqueConf) { if (rankInfo.rankSize > 0) { uniqueConf.useSharding = true; @@ -204,13 +203,13 @@ void KeyProcess::GetUniqueConfig(UniqueConf& uniqueConf) } uniqueConf.useIdCount = true; - uniqueConf.outputType = OutputType::ENHANCED; + uniqueConf.outputType = ock::ctr::OutputType::ENHANCED; uniqueConf.minThreadNum = MIN_UNIQUE_THREAD_NUM; uniqueConf.maxThreadNum = GlobalEnv::maxUniqueThreadNum; } -void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, UniquePtr& unique) +void KeyProcess::InitializeUnique(ock::ctr::UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, + const unique_ptr & batch, ock::ctr::UniquePtr& unique) { uniqueConf.desiredSize = static_cast(batch->Size()); if (preBatchSize != batch->Size()) { @@ -226,7 +225,7 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, uniqueConf.maxIdVal = INT64_MAX; uniqueConf.dataType = ock::ctr::DataType::INT64; - auto ret = unique->Initialize(uniqueConf); + int ret = unique->Initialize(uniqueConf); if (ret != ock::ctr::H_OK) { throw runtime_error(Logger::Format("fast unique init failed, code:{}", ret)); } @@ -237,12 +236,12 @@ void KeyProcess::InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, void KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) { unique_ptr batch; - UniquePtr unique = nullptr; - UniqueConf uniqueConf; + ock::ctr::UniquePtr unique = nullptr; + ock::ctr::UniqueConf uniqueConf; size_t preBatchSize = 0; bool uniqueInitialize = false; - auto ret = factory->CreateUnique(unique); + int ret = factory->CreateUnique(unique); if (ret != ock::ctr::H_OK) { throw runtime_error(Logger::Format("create fast unique failed, error code:{}", ret)); } @@ -257,7 +256,7 @@ void 
KeyProcess::KeyProcessTaskWithFastUnique(int channel, int threadId) if (batch == nullptr) { break; } - auto getBatchTime = getBatchDataTC.ElapsedMS(); + size_t getBatchTime = getBatchDataTC.ElapsedMS(); TimeCost processDataTime = TimeCost(); InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); @@ -293,7 +292,7 @@ void KeyProcess::KeyProcessTask(int channel, int threadId) if (batch == nullptr) { break; } - auto getBatchTime = getBatchDataTC.ElapsedMS(); + size_t getBatchTime = getBatchDataTC.ElapsedMS(); TimeCost processDataTime = TimeCost(); if (!KeyProcessTaskHelper(batch, channel, threadId)) { @@ -331,7 +330,7 @@ void KeyProcess::HashSplitHelper(const unique_ptr & batch, vector & batch, UniquePtr& unique, +bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch, ock::ctr::UniquePtr& unique, int channel, int threadId) { // tuple for keyRec restore hotPos scAll countRecv @@ -488,11 +487,11 @@ vector KeyProcess::GetCountRecv(const unique_ptr& batch, in for (int i = 0; i < rankInfo.rankSize; ++i) { rc.push_back(scAll.at(i * rankInfo.rankSize + rankInfo.rankId)); } - auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + vector rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 vector countRecv; countRecv.resize(rs.back() + rc.back()); - auto retCode = MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), - rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); + int retCode = MPI_Alltoallv(countSend.data(), sc.data(), ss.data(), MPI_UINT32_T, countRecv.data(), + rc.data(), rs.data(), MPI_UINT32_T, comm[batch->channel][id]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode); } @@ -568,7 +567,7 @@ size_t KeyProcess::GetKeySize(const unique_ptr &batch) return size; } -void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, +void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, ock::ctr::UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut) { EASY_FUNCTION(profiler::colors::Purple) @@ -586,11 +585,11 @@ void KeyProcess::ProcessBatchWithFastUnique(const unique_ptr &batch, vector idCount(batch->Size()); keySendInfo.keyCount.resize(size); - UniqueIn uniqueIn; + ock::ctr::UniqueIn uniqueIn; uniqueIn.inputIdCnt = static_cast(batch->Size()); uniqueIn.inputId = reinterpret_cast(batch->sample.data()); - EnhancedUniqueOut uniqueOut; + ock::ctr::EnhancedUniqueOut uniqueOut; uniqueOut.uniqueId = reinterpret_cast(keySendInfo.keySend.data()); uniqueOut.index = reinterpret_cast(uniqueInfoOut.restore.data()); if (rankInfo.useStatic) { @@ -631,7 +630,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uniqu KeySendInfo& keySendInfo, vector& sc, vector& splitSize) { std::shared_lock lock(g_smut); - auto hotMap = hotKey[batch->name]; + absl::flat_hash_map hotMap = hotKey[batch->name]; lock.unlock(); if (rankInfo.useHot) { @@ -650,7 +649,7 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uniqu sc.resize(rankInfo.rankSize, embInfos[batch->name].sendCount); } else { sc.resize(rankInfo.rankSize); - for (int i = 0;i < rankInfo.rankSize; i++) { + for (int i = 0; i < rankInfo.rankSize; i++) { sc[i] = splitSize[i]; } } @@ -659,11 +658,11 @@ void KeyProcess::HandleHotAndSendCount(const unique_ptr &batch, Uniqu void KeyProcess::ComputeHotPos(const unique_ptr &batch, absl::flat_hash_map &hotMap, vector &hotPos, vector &restore, const int hotOffset) const { - auto* inputData = 
batch->sample.data(); + emb_key_t* inputData = batch->sample.data(); size_t miniBs = batch->Size(); int hotCount = 0; - for (size_t i = 0;i < miniBs; i++) { + for (size_t i = 0; i < miniBs; i++) { const emb_key_t& key = inputData[i]; auto hot = hotMap.find(key); if (hot != hotMap.end()) { @@ -689,18 +688,18 @@ void KeyProcess::All2All(vector& sc, int id, const unique_ptr &b LOG_DEBUG("GetScAll TimeCost(ms):{}", getScAllTC.ElapsedMS()); TimeCost all2allTC; - auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 + vector ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc(rankInfo.rankSize); // receive count for (int i = 0; i < rankInfo.rankSize; ++i) { // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 rc[i] = all2AllInfoOut.scAll.at(i * rankInfo.rankSize + rankInfo.rankId); } - auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + vector rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 all2AllInfoOut.keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") - auto retCode = MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, - all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), - MPI_INT64_T, comm[channel][id]); + int retCode = MPI_Alltoallv(keySendInfo.keySend.data(), sc.data(), ss.data(), MPI_INT64_T, + all2AllInfoOut.keyRecv.data(), rc.data(), rs.data(), + MPI_INT64_T, comm[channel][id]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode); } @@ -731,7 +730,7 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, // 使用静态all2all通信:发送或接受量为预置固定值 scInfo[batch->name] = 65536 / rankSize 经验值 if (rankInfo.useStatic) { // maybe move after all2all - for (auto& i: splitKeys) { + for (KeysT& i: splitKeys) { if (static_cast(i.size()) > embInfos[batch->name].sendCount) { LOG_ERROR("{}[{}]:{} overflow! 
set send count bigger than {}", batch->name, batch->channel, batch->batchId, i.size()); @@ -751,22 +750,22 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, KeysT keyRecv; TimeCost getScAllTC; - auto scAll = GetScAll(sc, id, batch); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 + vector scAll = GetScAll(sc, id, batch); // Allgather通信获取所有(不同rank相同thread id的)线程间通信量矩阵 LOG_DEBUG("getScAllTC(ms)(AllReduce-AllGather):{}", getScAllTC.ElapsedMS()); - auto ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 + vector ss = Count2Start(sc); // send displays/offset 发送数据的起始偏移量 vector rc; // receive count for (int i = 0; i < rankInfo.rankSize; ++i) { // 通信量矩阵某一列的和即为本地要从其他设备接受的key数据量 rc.push_back(scAll.at(i * rankInfo.rankSize + rankInfo.rankId)); } - auto rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 + vector rs = Count2Start(rc); // receive displays/offset 接受数据的起始偏移量 keyRecv.resize(rs.back() + rc.back()); EASY_BLOCK("all2all") TimeCost uniqueAll2AllTC; - auto retCode = MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, - keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); + int retCode = MPI_Alltoallv(keySend.data(), sc.data(), ss.data(), MPI_INT64_T, + keyRecv.data(), rc.data(), rs.data(), MPI_INT64_T, comm[batch->channel][id]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Alltoallv failed:{}", rankInfo.rankId, retCode); } @@ -784,10 +783,10 @@ auto KeyProcess::ProcessSplitKeys(const unique_ptr& batch, int id, * splitKeys返回:将数据的key切分到其所在dev id对应的桶中,并去重。 * restore返回:去重后key在桶内偏移量(用于计算恢复向量) */ -auto KeyProcess::HashSplit(const unique_ptr& batch) const -> tuple, vector> +tuple, vector> KeyProcess::HashSplit(const unique_ptr& batch) const { EASY_FUNCTION(profiler::colors::Gold) - auto* batchData = batch->sample.data(); + emb_key_t* batchData = batch->sample.data(); size_t miniBs = batch->Size(); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); @@ -833,11 +832,11 @@ void KeyProcess::PaddingAlltoallVC(vector& splitKeys) const return; } -auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const - -> tuple, vector, vector>> +tuple, vector, vector>> KeyProcess::HashSplitWithFAAE( + const unique_ptr& batch) const { EASY_FUNCTION(profiler::colors::Gold) - auto* batchData = batch->sample.data(); + emb_key_t* batchData = batch->sample.data(); size_t miniBs = batch->Size(); vector splitKeys(rankInfo.rankSize); vector> keyCount(rankInfo.rankSize); // splitKeys在原始batch中对应的频次 @@ -886,11 +885,10 @@ auto KeyProcess::HashSplitWithFAAE(const unique_ptr& batch) const return { splitKeys, restore, keyCount }; } -auto KeyProcess::HotHashSplit(const unique_ptr& batch) -> -tuple, vector, vector> +tuple, vector, vector> KeyProcess::HotHashSplit(const unique_ptr& batch) { EASY_FUNCTION(profiler::colors::Gold) - auto* batchData = batch->sample.data(); + emb_key_t* batchData = batch->sample.data(); size_t miniBs = batch->Size(); vector splitKeys(rankInfo.rankSize); vector restore(batch->Size()); @@ -952,17 +950,13 @@ tuple, vector, vector> void KeyProcess::AddCountStartToHotPos(vector& splitKeys, vector& hotPos, const vector& hotPosDev, const unique_ptr& batch) { - vector splitKeysSize {}; - if (rankInfo.useStatic) { - for (size_t i = 0; i < splitKeys.size(); i++) { - splitKeysSize.push_back(embInfos[batch->name].sendCount); - } - } else { - for (auto& splitKey: splitKeys) { - splitKeysSize.push_back(static_cast(splitKey.size())); - } + vector splitKeysSize; + for (auto& splitKey: splitKeys) { + int tmp = 
rankInfo.useStatic ? embInfos[batch->name].sendCount : static_cast(splitKey.size()); + splitKeysSize.push_back(tmp); } - auto cs = Count2Start(splitKeysSize); + + vector cs = Count2Start(splitKeysSize); for (size_t i = 0; i < hotPos.size(); ++i) { hotPos[i] += cs[hotPosDev[i]]; } @@ -996,23 +990,24 @@ void KeyProcess::UpdateHotMapForUnique(const KeysT &keySend, const vector& keyCountMapByEmbName, uint32_t count, bool refresh, const string& embName) { + if (!refresh) { + return; + } auto& hotMap = hotKey[embName]; - if (refresh) { - priority_queue> pq; // top k key - for (auto& p: keyCountMapByEmbName) { - pq.push(pair(-p.second, p.first)); - if (pq.size() > count) { - pq.pop(); - } - } - // gen new hot map - std::unique_lock lock(g_smut); - hotMap.clear(); - while (!pq.empty()) { - hotMap.insert(make_pair(pq.top().second, -1)); + priority_queue> pq; // top k key + for (auto& p: keyCountMapByEmbName) { + pq.push(pair(-p.second, p.first)); + if (pq.size() > count) { pq.pop(); } } + // gen new hot map + std::unique_lock lock(g_smut); + hotMap.clear(); + while (!pq.empty()) { + hotMap.insert(make_pair(pq.top().second, -1)); + pq.pop(); + } } /* @@ -1027,8 +1022,8 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - auto retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, - MPI_INT, comm[batch->channel][commId]); + int retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, + MPI_INT, comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode); } @@ -1045,8 +1040,8 @@ void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, co scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - auto retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, - scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + int retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, + scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode); } @@ -1116,7 +1111,7 @@ void KeyProcess::Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& spli // 新值 if (channel == TRAIN_CHANNEL_ID) { #ifndef GTEST - auto addr = curEmbTable.GetEmbAddress(); + int64_t addr = curEmbTable.GetEmbAddress(); key2Offset[key] = addr; key = addr; #endif @@ -1376,7 +1371,7 @@ void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector offset tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto evictLen = tmpDataOut.back().flat(); - auto evictSize = static_cast(offset.size()); + int evictSize = static_cast(offset.size()); evictLen(0) = evictSize; // evict key发送给dev侧,dev侧初始化emb @@ -1442,7 +1437,7 @@ int64_t KeyProcess::GetExpansionTableSize(const string& embName) return -1; } std::lock_guard lk(mut); // lock for PROCESS_THREAD - return embeddingTableMap[embName].GetTableSize(); + return iter->second.GetTableSize(); #endif } @@ -1455,7 +1450,7 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) return -1; } std::lock_guard lk(mut); // lock for PROCESS_THREAD - return 
embeddingTableMap[embName].GetTableCapacity(); + return iter->second.GetTableCapacity(); #endif } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 185be563..9c8f8eff 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -36,7 +36,6 @@ namespace MxRec { using namespace std; - using namespace ock::ctr; template struct Cmp { @@ -189,7 +188,7 @@ namespace MxRec { map> hotKey {}; map hotEmbTotCount; map embeddingTableMap {}; - FactoryPtr factory {}; + ock::ctr::FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; vector isNeedSendEos { false, false }; // 分别代表通道0、1的eos状态 @@ -202,18 +201,18 @@ namespace MxRec { bool KeyProcessTaskHelper(unique_ptr& batch, int channel, int threadId); - bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, UniquePtr& unique, + bool KeyProcessTaskHelperWithFastUnique(unique_ptr &batch, ock::ctr::UniquePtr& unique, int channel, int threadId); - auto ProcessSplitKeys(const unique_ptr& batch, int id, - vector& splitKeys) -> tuple, vector>; + tuple, vector> ProcessSplitKeys(const unique_ptr& batch, + int id, vector& splitKeys); - void GetUniqueConfig(UniqueConf& uniqueConf); + void GetUniqueConfig(ock::ctr::UniqueConf& uniqueConf); - void InitializeUnique(UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, - const unique_ptr & batch, UniquePtr& unique); + void InitializeUnique(ock::ctr::UniqueConf& uniqueConf, size_t& preBatchSize, bool& uniqueInitialize, + const unique_ptr & batch, ock::ctr::UniquePtr& unique); - void ProcessBatchWithFastUnique(const unique_ptr &batch, UniquePtr& unique, + void ProcessBatchWithFastUnique(const unique_ptr &batch, ock::ctr::UniquePtr& unique, int id, UniqueInfo& uniqueInfoOut); size_t GetKeySize(const unique_ptr &batch); @@ -227,8 +226,8 @@ namespace MxRec { void PaddingAlltoallVC(vector& splitKeys) const; - auto HashSplitWithFAAE(const unique_ptr& batch) const - -> tuple, vector, vector>>; + tuple, vector, vector>> + HashSplitWithFAAE(const unique_ptr& batch) const; vector GetScAll(const vector& keyScLocal, int commId, const unique_ptr& batch); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 749c3606..dbed8831 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -23,7 +23,7 @@ using namespace MxRec; using namespace testing; static constexpr size_t BATCH_NUM_EACH_THREAD = 3; -FactoryPtr factory; +ock::ctr::FactoryPtr factory; class SimpleThreadPool { public: @@ -251,7 +251,7 @@ TEST_F(KeyProcessTest, Initialize) ASSERT_NE(process.embInfos.find(info.name), process.embInfos.end()); } - Factory::Create(factory); + ock::ctr::Factory::Create(factory); } TEST_F(KeyProcessTest, Start) @@ -464,7 +464,7 @@ TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) TEST_F(KeyProcessTest, GetUniqueConfig) { - UniqueConf uniqueConf; + ock::ctr::UniqueConf uniqueConf; process.rankInfo.rankSize = worldSize; process.rankInfo.useStatic = true; process.GetUniqueConfig(uniqueConf); @@ -491,14 +491,14 @@ TEST_F(KeyProcessTest, ProcessPrefetchTask) TEST_F(KeyProcessTest, InitializeUnique) { - ASSERT_EQ(Factory::Create(factory), -1); - UniquePtr unique; + ASSERT_EQ(ock::ctr::Factory::Create(factory), -1); + ock::ctr::UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); PrepareBatch(); unique_ptr batch; batch = process.GetBatchData(0, 0); - UniqueConf uniqueConf; + ock::ctr::UniqueConf uniqueConf; 
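Throughout this commit, `Count2Start` turns per-rank element counts into the displacement arrays that `MPI_Alltoallv` and `MPI_Allgather` bookkeeping expect; it is simply an exclusive prefix sum. A NumPy sketch of the same computation (hypothetical helper, for illustration only):

```python
import numpy as np

def count2start(counts: np.ndarray) -> np.ndarray:
    # starts[i] = sum(counts[:i]); mirrors the C++ Count2Start helper above.
    starts = np.zeros_like(counts)
    np.cumsum(counts[:-1], out=starts[1:])
    return starts

# Per-rank send counts [3, 1, 4] yield send displacements [0, 3, 4].
assert count2start(np.array([3, 1, 4])).tolist() == [0, 3, 4]
```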
process.rankInfo.rankSize = worldSize; process.rankInfo.useStatic = true; bool uniqueInitialize = false; @@ -520,13 +520,13 @@ TEST_F(KeyProcessTest, GetKeySize) TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) { PrepareBatch(); - + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { - UniquePtr unique; - + ock::ctr::UniquePtr unique; + auto embName = embInfos[0].name; process.hotEmbTotCount[embName] = 10; vector splitKeys; @@ -537,7 +537,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue ASSERT_EQ(factory->CreateUnique(unique), ock::ctr::H_OK); - UniqueConf uniqueConf; + ock::ctr::UniqueConf uniqueConf; size_t preBatchSize = 0; bool uniqueInitialize = false; process.GetUniqueConfig(uniqueConf); -- Gitee From 9e0ccd28f952eea6dc4b6745e7ff538d4d9a1d89 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 1 Dec 2023 16:06:21 +0800 Subject: [PATCH 500/551] Match-id-f494b13736a3c9bc93710a1797a1cba7dbeb3ca5 --- src/core/key_process/key_process.cpp | 4 ++-- src/core/key_process/key_process.h | 2 +- src/dataset_tf/eos_dataset_op.cc | 12 ++++++------ src/dataset_tf/eos_dataset_op.h | 1 - 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index b8b64108..6acedec6 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -515,7 +515,7 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr中读取batch数据并返回。batch数据由 ReadEmbKeyV2 写入。 * commID为线程标识[0, KEY_PROCESS_THREAD-1],不同线程、训练或推理数据用不同的共享队列通信 */ -unique_ptr KeyProcess::GetBatchData(int channel, int commId) +unique_ptr KeyProcess::GetBatchData(int channel, int commId) const { EASY_FUNCTION() unique_ptr batch = nullptr; @@ -1224,7 +1224,7 @@ void KeyProcess::SendEos(int batchId, int channel) vector tensors; bool isNeedResend = true; - for (const auto& emb:embInfos) { // 一个表触发以后,其余表都发送eos,最后外层接收null退出此次循环 + for (const auto& emb: as_const(embInfos)) { // 一个表触发以后,其余表都发送eos,最后外层接收null退出此次循环 LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos start.", channel, batchId, emb.first); if (!isRunning) { throw EndRunExit("SendEos end run, isRunning is false after lock destroyMutex."); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 9c8f8eff..1a1c1332 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -238,7 +238,7 @@ namespace MxRec { void Key2OffsetDynamicExpansion(const EmbNameT& embName, KeysT& splitKey, int channel); - unique_ptr GetBatchData(int channel, int commId); + unique_ptr GetBatchData(int channel, int commId) const; void BuildRestoreVec(const unique_ptr& batch, const vector& blockOffset, vector& restoreVec, int hotPosSize = 0) const; diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index aa5b4b62..39999baf 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -31,8 +31,8 @@ using namespace MxRec; namespace tensorflow { namespace data { -MPI_Comm comm; -MPI_Group worldGroup; +MPI_Comm g_comm; +MPI_Group g_worldGroup; constexpr const char *const EosDatasetOp::kDatasetType; constexpr const char *const EosDatasetOp::kInputDataset; @@ -52,8 +52,8 @@ public: auto os_input = input->output_shapes(); output_shapes_ = os_input; keyProcess = Singleton::GetInstance(); - 
MPI_Comm_group(MPI_COMM_WORLD, &worldGroup); - MPI_Comm_create(MPI_COMM_WORLD, worldGroup, &comm); + MPI_Comm_group(MPI_COMM_WORLD, &g_worldGroup); + MPI_Comm_create(MPI_COMM_WORLD, g_worldGroup, &g_comm); } ~Dataset() override @@ -144,7 +144,7 @@ private: // 正常数据流程 if (!*end_of_sequence) { LOG_TRACE("GetNext, step in MPI_Allreduce, exitFlag:[{}]", exitFlag); - MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm); + MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm); LOG_TRACE("GetNext, step out MPI_Allreduce, exitFlag:[{}]", exitFlag); // 数据不均衡场景, 别的卡eos if (exitFlag != 0) { @@ -161,7 +161,7 @@ private: *end_of_sequence = true; input_impl_.reset(); LOG_TRACE("GetNext eos, step in MPI_Allreduce, exitFlag:[{}]", exitFlag); - MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, comm); + MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm); LOG_TRACE("GetNext eos, step out MPI_Allreduce, exitFlag:[{}]", exitFlag); LOG_INFO("GetNext eos, channelID:[{}]", dataset()->channelId_); diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h index 4ece6723..5f5383f5 100644 --- a/src/dataset_tf/eos_dataset_op.h +++ b/src/dataset_tf/eos_dataset_op.h @@ -16,7 +16,6 @@ #define TF_VERSION_TF2 #endif -using namespace std; namespace tensorflow { namespace data { // 这个类的 MakeDataset() 方法告诉 TensorFlow 怎样根据一个操作的输入和属性生成一个数据集的对象。 -- Gitee From 67a8e21e4dcf9943a9ebf1c4d01727a5f405dcc9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 1 Dec 2023 19:19:19 +0800 Subject: [PATCH 501/551] Match-id-9d2158308547781e316e627d5ad1fd17b8fbb966 --- src/dataset_tf/eos_dataset_op.cc | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index 39999baf..cf1b7be7 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -34,6 +34,9 @@ namespace data { MPI_Comm g_comm; MPI_Group g_worldGroup; +const int TIME_OUT_TIMES = 80000; +const int SLEEP_TIME = 1000; + constexpr const char *const EosDatasetOp::kDatasetType; constexpr const char *const EosDatasetOp::kInputDataset; constexpr const char *const EosDatasetOp::kChannelId; @@ -143,11 +146,25 @@ private: // 正常数据流程 if (!*end_of_sequence) { + MPI_Request request = MPI_REQUEST_NULL; + MPI_Status status; + LOG_TRACE("GetNext, step in MPI_Allreduce, exitFlag:[{}]", exitFlag); - MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm); + MPI_Iallreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm, &request); LOG_TRACE("GetNext, step out MPI_Allreduce, exitFlag:[{}]", exitFlag); - // 数据不均衡场景, 别的卡eos - if (exitFlag != 0) { + + int j = 0; + for (j = 0; j < TIME_OUT_TIMES; ++j) { + int flag = 0; + MPI_Test(&request, &flag, &status); + if (flag) { + MPI_Wait(&request, &status); + break; + } + usleep(SLEEP_TIME); + } + + if (j >= TIME_OUT_TIMES) { i_ = 1; *end_of_sequence = true; LOG_INFO("GetNext, some rank eos, channelID:[{}]", dataset()->channelId_); @@ -155,6 +172,7 @@ private: } return Status::OK(); } + // 数据eos场景 i_ = 1; exitFlag = 1; -- Gitee From 46caacaf7f6815683ff4a4731a515378af5a3736 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 2 Dec 2023 09:59:35 +0800 Subject: [PATCH 502/551] Match-id-5aa22078b80e78960ab6a603e8ca25aab9e37685 --- src/core/key_process/key_process.cpp | 91 +++++++++++++++++++++++++--- src/core/key_process/key_process.h | 10 +++ src/dataset_tf/eos_dataset_op.cc | 38 ------------ 3 files changed, 92 insertions(+), 47 deletions(-) diff --git 
a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6acedec6..6698abd4 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -540,9 +540,19 @@ unique_ptr KeyProcess::GetBatchData(int channel, int commId) const this_thread::sleep_for(seconds(1)); tc = TimeCost(); } - if (!isRunning) { - LOG_WARN("channelId:{} threadId:{}, isRunning is false when GetBatchData", channel, commId); - throw EndRunExit("GetBatchData end run."); + if (!isRunning || isNeedExit[channel]) { + LOG_WARN("channelId:{} threadId:{}, enter GetBatchData abnormal scene, isRunning:{}, isNeedExit:{}", + channel, commId, isRunning, isNeedExit[channel]); + // 通信终止信号,同步退出,防止线程卡住 + int receiveFlag = 0; + int sendValue = 0; // 此处直接发送0,不使用mpiAllReduceSend值,防止多线程数据可见性问题 + auto retCode = MPI_Allreduce(&sendValue, &receiveFlag, 1, MPI_INT, MPI_SUM, comm[channel][commId]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); + } + LOG_DEBUG("channelId:{} threadId:{}, GetBatchData Allreduce end, receiveFlag:{}", + channel, commId, receiveFlag); + throw EndRunExit("GetBatchData end run, thread will exit."); } } EASY_END_BLOCK @@ -1021,9 +1031,25 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons scAll.resize(rankInfo.rankSize * rankInfo.rankSize); LOG_DEBUG("channelId:{} threadId:{} batchId:{}, GetScAll start.", batch->channel, commId, batch->batchId); + // 通信终止信号,同步退出,防止线程卡住 + TimeCost tc = TimeCost(); + int receiveFlag = 0; + auto retCode = MPI_Allreduce(&mpiAllReduceSend[batch->channel], &receiveFlag, 1, MPI_INT, MPI_SUM, + comm[batch->channel][commId]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {} commId {}, MPI_Allreduce failed:{}", rankInfo.rankId, commId, retCode); + } + LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAll MPI_Allreduce end, receiveFlag:{}" + " barrier time:{}", + batch->channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS()); + + // 处理其他rank线程退出的情况 + HandleRankExitScene(commId, batch, receiveFlag); + + EASY_END_BLOCK; // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - int retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, - MPI_INT, comm[batch->channel][commId]); + retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, scAll.data(), rankInfo.rankSize, MPI_INT, + comm[batch->channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {} commId {}, MPI_Allgather failed:{}", rankInfo.rankId, commId, retCode); } @@ -1032,6 +1058,26 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons return scAll; } + +void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag) +{ + if (!isRunning) { + throw EndRunExit("GetScAll end run, isRunning is false."); + } + if (receiveFlag < rankInfo.rankSize) { + unique_lock lockGuard(destroyMutex); + if (isNeedExit[batch->channel]) { + LOG_INFO("channelId:{} threadId:{} batchId:{}, has send acl eos info, thread will exit.", + batch->channel, commId, batch->batchId); + throw EndRunExit("has send acl eos info, thread will exit."); + } + SendEos(batch->batchId, batch->channel); + isNeedExit[batch->channel] = true; + throw EndRunExit("has SendEosInfo, GetScAll end, thread will exit."); + } +} + + void KeyProcess::GetScAllForUnique(const vector& keyScLocal, int commId, const unique_ptr &batch, vector &scAllOut) { @@ -1039,9 +1085,24 @@ void KeyProcess::GetScAllForUnique(const 
vector& keyScLocal, int commId, co int channel = batch->channel; scAllOut.resize(rankInfo.rankSize * rankInfo.rankSize); + // 通信终止信号,同步退出,防止线程卡住 + TimeCost tc = TimeCost(); + int receiveFlag = 0; + auto retCode = MPI_Allreduce(&mpiAllReduceSend[channel], &receiveFlag, 1, MPI_INT, MPI_SUM, + comm[channel][commId]); + if (retCode != MPI_SUCCESS) { + LOG_ERROR("rank {}, MPI_Allreduce failed:{}", rankInfo.rankId, retCode); + } + LOG_DEBUG(KEY_PROCESS "channelId:{} threadId:{} batchId:{}, GetScAllForUnique MPI AllReduce end, " + "receiveFlag:{}, barrier time:{}", + channel, commId, batch->batchId, receiveFlag, tc.ElapsedMS()); + // 处理其他rank线程退出的情况 + HandleRankExitScene(commId, batch, receiveFlag); + + EASY_END_BLOCK; // allgather keyScLocal(key all2all keyScLocal = device all2all rc) - int retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, - scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); + retCode = MPI_Allgather(keyScLocal.data(), rankInfo.rankSize, MPI_INT, + scAllOut.data(), rankInfo.rankSize, MPI_INT, comm[channel][commId]); if (retCode != MPI_SUCCESS) { LOG_ERROR("rank {}, MPI_Allgather failed:{}", rankInfo.rankId, retCode); } @@ -1201,7 +1262,16 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) auto ret = GetInfo(lookupKeysList, batch, embName, channel); return get(ret); } catch (EmptyList&) { - LOG_TRACE("getting info failed {}[{}]:{}", embName, channel, batch); + unique_lock lockGuard(destroyMutex); + // readEmbKey真实的次数是readEmbedBatchId减1 + int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; + // 避免eos在keyProcess还未处理完数据时插队到通道前面 + if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { + SendEos(batch, channel); + return {}; + } + LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", + embName, channel, batch, readEmbKeyBatchId); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); @@ -1240,6 +1310,8 @@ void KeyProcess::SendEos(int batchId, int channel) LOG_INFO("channelId:{} batchId:{}, SendEos end.", channel, batchId); isNeedSendEos[channel] = false; + mpiAllReduceSend[channel] = 0; + isNeedExit[channel] = true; #endif } @@ -1282,6 +1354,7 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa LOG_WARN(KEY_PROCESS "getting lookup keys timeout! 
{}[{}]:{}", embName, channel, batch); return nullptr; } + try { auto ret = GetInfo(*list, batch, embName, channel); auto it = get>>::iterator>(ret); @@ -1298,7 +1371,7 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa SendEos(batch, channel); return nullptr; } - LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batch id: {}, readEmbKey batch id: {}.", + LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", embName, channel, batch, readEmbKeyBatchId); this_thread::sleep_for(1ms); } catch (WrongListTop&) { diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 1a1c1332..cd8c1ea6 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -70,6 +70,9 @@ namespace MxRec { const char* errorMessage; }; + constexpr int MPI_ABNORMAL_SEND_VALUE = 0; // MPI异常通信时发送0 + constexpr int MPI_NORMAL_SEND_VALUE = 1; // MPI正常通信时发送1 + class EmptyList : public std::exception { }; @@ -161,6 +164,11 @@ namespace MxRec { bool isRunning { false }; + // 是否需要退出当前通道对应预处理线程,区分channel;已发送eos信息时需退出 + bool isNeedExit[2] = {false, false}; + // MPI all reduce通信时发送数据 + int mpiAllReduceSend[2] = {MPI_NORMAL_SEND_VALUE, MPI_NORMAL_SEND_VALUE}; + std::mutex destroyMutex; inline bool HasEmbName(const string& embName) { @@ -275,6 +283,8 @@ namespace MxRec { vector & restore, vector & hotPos, vector >& keyCount); + void HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag); + template inline vector Count2Start(const vector& count) const { diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc index cf1b7be7..112c64b3 100644 --- a/src/dataset_tf/eos_dataset_op.cc +++ b/src/dataset_tf/eos_dataset_op.cc @@ -8,8 +8,6 @@ #include "eos_dataset_op.h" -#include - #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" @@ -31,12 +29,6 @@ using namespace MxRec; namespace tensorflow { namespace data { -MPI_Comm g_comm; -MPI_Group g_worldGroup; - -const int TIME_OUT_TIMES = 80000; -const int SLEEP_TIME = 1000; - constexpr const char *const EosDatasetOp::kDatasetType; constexpr const char *const EosDatasetOp::kInputDataset; constexpr const char *const EosDatasetOp::kChannelId; @@ -55,8 +47,6 @@ public: auto os_input = input->output_shapes(); output_shapes_ = os_input; keyProcess = Singleton::GetInstance(); - MPI_Comm_group(MPI_COMM_WORLD, &g_worldGroup); - MPI_Comm_create(MPI_COMM_WORLD, g_worldGroup, &g_comm); } ~Dataset() override @@ -146,30 +136,6 @@ private: // 正常数据流程 if (!*end_of_sequence) { - MPI_Request request = MPI_REQUEST_NULL; - MPI_Status status; - - LOG_TRACE("GetNext, step in MPI_Allreduce, exitFlag:[{}]", exitFlag); - MPI_Iallreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm, &request); - LOG_TRACE("GetNext, step out MPI_Allreduce, exitFlag:[{}]", exitFlag); - - int j = 0; - for (j = 0; j < TIME_OUT_TIMES; ++j) { - int flag = 0; - MPI_Test(&request, &flag, &status); - if (flag) { - MPI_Wait(&request, &status); - break; - } - usleep(SLEEP_TIME); - } - - if (j >= TIME_OUT_TIMES) { - i_ = 1; - *end_of_sequence = true; - LOG_INFO("GetNext, some rank eos, channelID:[{}]", dataset()->channelId_); - dataset()->keyProcess->SetEos(1, dataset()->channelId_); - } return Status::OK(); } @@ -178,10 +144,6 @@ private: exitFlag = 1; *end_of_sequence = true; input_impl_.reset(); - LOG_TRACE("GetNext eos, step in MPI_Allreduce, exitFlag:[{}]", exitFlag); 
- MPI_Allreduce(&exitFlag, &exitFlag, 1, MPI_INT, MPI_SUM, g_comm); - LOG_TRACE("GetNext eos, step out MPI_Allreduce, exitFlag:[{}]", exitFlag); - LOG_INFO("GetNext eos, channelID:[{}]", dataset()->channelId_); dataset()->keyProcess->SetEos(1, dataset()->channelId_); -- Gitee From 164b19d538ba5a91469a4fa55f0a5478d9c80f89 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Dec 2023 11:38:42 +0800 Subject: [PATCH 503/551] Match-id-cf5280f9f2ed264eb16c148f955b7413968bad36 --- mx_rec/optimizers/adagrad.py | 6 +++++- mx_rec/optimizers/ftrl.py | 5 ++++- mx_rec/optimizers/gradient_descent.py | 4 ++++ mx_rec/optimizers/gradient_descent_by_addr.py | 5 ++++- mx_rec/optimizers/lazy_adam.py | 5 ++++- mx_rec/optimizers/lazy_adam_by_addr.py | 6 +++++- 6 files changed, 26 insertions(+), 5 deletions(-) diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py index 5aaf1220..596b6ef3 100644 --- a/mx_rec/optimizers/adagrad.py +++ b/mx_rec/optimizers/adagrad.py @@ -14,7 +14,7 @@ from tensorflow.python.training import adagrad, training_ops from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @@ -37,6 +37,10 @@ def create_hash_optimizer(learning_rate=0.001, :param name: Optional name prefix for the operations created when applying gradients. Defaults to "Adagrad". :return: adagrad hash optimizer instance """ + + if get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") return CustomizedAdagrad(learning_rate=learning_rate, initial_accumulator_value=initial_accumulator_value, use_locking=use_locking, diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py index 8cda3643..a7424f7d 100644 --- a/mx_rec/optimizers/ftrl.py +++ b/mx_rec/optimizers/ftrl.py @@ -19,7 +19,7 @@ from tensorflow.python.training import ftrl from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, ClassValidator, NumValidator, StringValidator, \ @@ -39,6 +39,9 @@ from mx_rec.validator.validator import para_checker_decorator, ClassValidator, N ("linear_name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs): + if get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs) diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py index 5456d1bb..3487a8d9 100644 --- a/mx_rec/optimizers/gradient_descent.py +++ 
b/mx_rec/optimizers/gradient_descent.py @@ -15,6 +15,7 @@ from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import MAX_INT32 +from mx_rec.util.initialize import get_use_dynamic_expansion from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @@ -24,6 +25,9 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescent"): + if get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name) diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py index 4e0bb7da..33457205 100644 --- a/mx_rec/optimizers/gradient_descent_by_addr.py +++ b/mx_rec/optimizers/gradient_descent_by_addr.py @@ -12,7 +12,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.training import gradient_descent from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer +from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer, get_use_dynamic_expansion from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator @@ -24,6 +24,9 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator, ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) ]) def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescentByAddr"): + if not get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") optimizer_by_addr = CustomizedGradientDescentByAddr(learning_rate=learning_rate, weight_decay=weight_decay, use_locking=use_locking, diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index 578708fa..75ee91b8 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -18,7 +18,7 @@ from tensorflow.python.training import adam from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list +from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @@ -42,6 +42,9 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1 Returns: a customized optimizer instance """ + if get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") return CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py 
b/mx_rec/optimizers/lazy_adam_by_addr.py index d42f8a7f..b977ac5e 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -12,7 +12,7 @@ import tensorflow as tf from tensorflow.python.ops import math_ops from tensorflow.python.training import adam -from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer +from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer, get_use_dynamic_expansion from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @@ -38,6 +38,10 @@ def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999 Returns: a customized optimizer instance """ + if not get_use_dynamic_expansion(): + raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " + "expansion mode and optimizer correctly") + optimizer_by_addr = CustomizedLazyAdamByAddress(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name) insert_optimizer(optimizer_by_addr) -- Gitee From 5ce3bbd056e9f0841479f44049d7594085c2efa6 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 4 Dec 2023 19:32:40 +0800 Subject: [PATCH 504/551] Match-id-9818e1b55a431604f38b527df99ea4709381e441 --- src/core/key_process/key_process.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 6698abd4..0796b2c1 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1061,9 +1061,6 @@ vector KeyProcess::GetScAll(const vector& keyScLocal, int commId, cons void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &batch, int receiveFlag) { - if (!isRunning) { - throw EndRunExit("GetScAll end run, isRunning is false."); - } if (receiveFlag < rankInfo.rankSize) { unique_lock lockGuard(destroyMutex); if (isNeedExit[batch->channel]) { -- Gitee From c1d31b8d40be44bc29e92107de81291688736819 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 2 Dec 2023 00:45:28 +0800 Subject: [PATCH 505/551] Match-id-1e360701bba31631cb876ee0ef0aa26b7c2aaae2 --- src/core/emb_hashmap/emb_hashmap.cpp | 2 +- src/core/emb_hashmap/emb_hashmap.h | 18 +-- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 135 ++++++++++-------- src/core/hybrid_mgmt/hybrid_mgmt.h | 69 ++------- src/core/key_process/key_process.cpp | 8 ++ src/core/key_process/key_process.h | 29 ++-- .../ckpt_data_handler_test.cpp | 24 +++- 7 files changed, 136 insertions(+), 149 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 496c47a8..8905f6db 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -109,7 +109,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara LOG_INFO(STAT_INFO "channel_id {} batch_id {} rank_id {} swap_key_size {} swap_time_cost {}", channelId, swapId, rankInfo.rankId, swapSize, swapTimeCost.ElapsedMS()); } - + swapId++; EASY_END_BLOCK #endif diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index b8b6c883..f76b18b7 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -46,17 +46,11 @@ namespace MxRec { absl::flat_hash_map embHashMaps; - void FindOffset(const string& embName, const vector& keys, - size_t currentBatchId, size_t keepBatchId, int channelId); - bool 
FindOffsetHelper(const emb_key_t& key, EmbHashMapInfo& embHashMap, int channelId, size_t& offset) const; void UpdateBatchId(const vector& keys, size_t currentBatchId, size_t keySize, EmbHashMapInfo& embHashMap) const; - int FindSwapPosV2(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, - size_t keepBatchId); - bool FindSwapPosOld(const string& embName, emb_key_t key, size_t hostOffset, size_t currentBatchId, size_t keepBatchId); @@ -68,11 +62,10 @@ namespace MxRec { bool isSSDEnabled { false }; CacheManager* cacheManager; - private: - void FindAndUpdateBatchId(vector& keys, size_t currentBatchId, size_t keySize, - EmbHashMapInfo& embHashMap) const; + GTEST_PRIVATE: - int32_t FindNewOffset(const emb_key_t& key, EmbHashMapInfo& embHashMap) const; + void FindOffset(const string& embName, const vector& keys, + size_t currentBatchId, size_t keepBatchId, int channelId); void AddCacheManagerTraceLog(const string& embTableName, const EmbHashMapInfo& embHashMap) const; @@ -80,11 +73,10 @@ namespace MxRec { void ClearLookupAndSwapOffset(EmbHashMapInfo& embHashMap) const; + void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const; + RankInfo rankInfo; int swapId { 0 }; - - GTEST_PRIVATE: - void RefreshFreqInfoWithSwap(const string& embName, EmbHashMapInfo& embHashMap) const; }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e287ef1d..e45e074f 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -10,29 +10,13 @@ #include "utils/logger.h" #include "utils/common.h" #include "checkpoint/checkpoint.h" +#include "key_process/key_process.h" +#include "key_process/feature_admit_and_evict.h" using namespace MxRec; using namespace std; -/// 启动数据处理线程 -/// \param rankInfo 当前rank基本配置信息 -/// \param embInfos 表信息list -/// \param thresholdValues 准入淘汰相关配置 -/// \param seed 随机种子 -/// \return bool类型 启动成功/失败 -bool HybridMgmt::InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, - const vector& thresholdValues, int seed) -{ -#ifndef GTEST - // 初始化数据处理类,配置相关信息,启动处理线程 - preprocess = Singleton::GetInstance(); - preprocess->Initialize(rankInfo, embInfos, thresholdValues, seed); - preprocess->Start(); -#endif - return true; -} - /// Openmpi通信域进程数设置、计算所有表host特征数量总数、设置训练模式(HBM/DDR) /// \param rankInfo /// \param embInfos @@ -102,10 +86,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, hybridMgmtBlock->SetRankInfo(rankInfo); // 启动数据处理线程 - bool rc = InitKeyProcess(rankInfo, embInfos, thresholdValues, seed); - if (!rc) { - return false; - } + KEY_PROCESS_INSTANCE->Initialize(rankInfo, embInfos, thresholdValues, seed); isRunning = true; @@ -230,11 +211,11 @@ bool HybridMgmt::Save(const string savePath) } // 数据处理线程上锁 - preprocess->LoadSaveLock(); + KEY_PROCESS_INSTANCE->LoadSaveLock(); CkptData saveData; Checkpoint saveCkpt; - saveData.keyCountMap = preprocess->GetKeyCountMap(); + saveData.keyCountMap = KEY_PROCESS_INSTANCE->GetKeyCountMap(); if (!mgmtRankInfo.noDDR) { // DDR模式保存host的emb表以及hashmap LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap"); @@ -243,8 +224,8 @@ bool HybridMgmt::Save(const string savePath) } else { // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 LOG_DEBUG(MGMT + "Start host side save: no ddr mode hashmap"); - saveData.maxOffset = preprocess->GetMaxOffset(); - saveData.keyOffsetMap = preprocess->GetKeyOffsetMap(); + saveData.maxOffset = KEY_PROCESS_INSTANCE->GetMaxOffset(); + saveData.keyOffsetMap = 
KEY_PROCESS_INSTANCE->GetKeyOffsetMap(); } if (isSSDEnabled) { @@ -260,7 +241,7 @@ bool HybridMgmt::Save(const string savePath) } // 保存特征准入淘汰相关的数据 - auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); + FeatureAdmitAndEvict& featAdmitNEvict = KEY_PROCESS_INSTANCE->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { LOG_DEBUG(MGMT + "Start host side save: feature admit and evict"); saveData.table2Thresh = featAdmitNEvict.GetTableThresholds(); @@ -272,7 +253,7 @@ bool HybridMgmt::Save(const string savePath) saveCkpt.SaveModel(savePath, saveData, mgmtRankInfo, mgmtEmbInfo); offsetMapToSend = std::move(saveData.offsetMap); // 数据处理线程释放锁 - preprocess->LoadSaveUnlock(); + KEY_PROCESS_INSTANCE->LoadSaveUnlock(); #endif return true; } @@ -288,16 +269,14 @@ bool HybridMgmt::Load(const string& loadPath) } // 数据处理线程上锁 - preprocess->LoadSaveLock(); + KEY_PROCESS_INSTANCE->LoadSaveLock(); LOG_DEBUG(MGMT + "Start host side load process"); CkptData loadData; Checkpoint loadCkpt; vector loadFeatures; - - auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); - SetFeatureTypeForLoad(loadFeatures, featAdmitNEvict); + SetFeatureTypeForLoad(loadFeatures); loadData.hostEmbs = hostEmbs->GetHostEmbs(); // 获取已经初始化好的host emb // 执行加载操作 @@ -305,11 +284,11 @@ bool HybridMgmt::Load(const string& loadPath) // 检查DDR模式保存的模型和当前训练配置是否一致,不一致则退出 if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) { - preprocess->LoadSaveUnlock(); + KEY_PROCESS_INSTANCE->LoadSaveUnlock(); return false; } - preprocess->LoadKeyCountMap(loadData.keyCountMap); + KEY_PROCESS_INSTANCE->LoadKeyCountMap(loadData.keyCountMap); if (!mgmtRankInfo.noDDR) { // DDR模式 将加载的hash map进行赋值 LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap"); @@ -317,11 +296,12 @@ bool HybridMgmt::Load(const string& loadPath) } else { // HBM模式 将加载的最大偏移(真正使用了多少vocab容量)、特征到偏移的映射,进行赋值 LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap"); - preprocess->LoadKeyOffsetMap(loadData.keyOffsetMap); - preprocess->LoadMaxOffset(loadData.maxOffset); + KEY_PROCESS_INSTANCE->LoadKeyOffsetMap(loadData.keyOffsetMap); + KEY_PROCESS_INSTANCE->LoadMaxOffset(loadData.maxOffset); } // 将加载的特征准入淘汰记录进行赋值 + FeatureAdmitAndEvict& featAdmitNEvict = KEY_PROCESS_INSTANCE->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { LOG_DEBUG(MGMT + "Start host side load: feature admit and evict"); featAdmitNEvict.LoadTableThresholds(loadData.table2Thresh); @@ -336,7 +316,7 @@ bool HybridMgmt::Load(const string& loadPath) LOG_DEBUG(MGMT + "Finish host side load process"); - preprocess->LoadSaveUnlock(); + KEY_PROCESS_INSTANCE->LoadSaveUnlock(); // 执行训练 if (isLoad) { @@ -346,8 +326,7 @@ bool HybridMgmt::Load(const string& loadPath) return true; } -void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, - const FeatureAdmitAndEvict& featAdmitNEvict) +void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures) { if (GlobalEnv::recordKeyCount) { loadFeatures.push_back(CkptFeatureType::KEY_COUNT_MAP); @@ -362,6 +341,7 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures, } // 添加特征准入淘汰相关的数据类型的加载 + FeatureAdmitAndEvict& featAdmitNEvict = KEY_PROCESS_INSTANCE->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { loadFeatures.push_back(CkptFeatureType::FEAT_ADMIT_N_EVICT); } @@ -397,7 +377,7 @@ void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) throw runtime_error("HybridMgmt not initialized. 
Call Initialize first."); } - preprocess->LoadSaveLock(); + KEY_PROCESS_INSTANCE->LoadSaveLock(); KeyOffsetMemT loadKeyOffsetMap; OffsetMemT loadMaxOffset; if (!receiveKeyOffsetMap.empty()) { @@ -414,11 +394,11 @@ void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) LOG_DEBUG(MGMT + "Start receive sparse data: ddr mode hashmap"); } else { LOG_DEBUG(MGMT + "Start receive sparse data: no ddr mode hashmap"); - preprocess->LoadKeyOffsetMap(loadKeyOffsetMap); - preprocess->LoadMaxOffset(loadMaxOffset); + KEY_PROCESS_INSTANCE->LoadKeyOffsetMap(loadKeyOffsetMap); + KEY_PROCESS_INSTANCE->LoadMaxOffset(loadMaxOffset); } - preprocess->LoadSaveUnlock(); + KEY_PROCESS_INSTANCE->LoadSaveUnlock(); if (isLoad) { Start(); } @@ -542,6 +522,44 @@ void HybridMgmt::StartThreadForDDR() #endif } +void HybridMgmt::Destroy() +{ + LOG_DEBUG(MGMT + "start Destroy hybrid_mgmt module"); + if (!isInitialized) { + throw runtime_error("HybridMgmt not initialized. Call Initialize first."); + } + + if (!isRunning) { + return; + } + // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 + isRunning = false; + // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 + std::unique_lock lockGuard(KEY_PROCESS_INSTANCE->destroyMutex); + // 先发送停止信号给KEY_PROCESS_INSTANCE,用于停止查询中lookup卡住状态 + KEY_PROCESS_INSTANCE->isRunning = false; + // 停止hdTransfer,用于停止mgmt的recv中卡住状态 + hdTransfer->Destroy(); + LOG_DEBUG(MGMT + "destroy hdTransfer end."); + + hybridMgmtBlock->Destroy(); + for (auto& t : procThreads) { + t->join(); + } + if (cacheManager != nullptr) { + cacheManager = nullptr; + } + if (hostEmbs != nullptr) { + hostEmbs->Join(TRAIN_CHANNEL_ID); + hostEmbs->Join(EVAL_CHANNEL_ID); + hostEmbs = nullptr; + } + procThreads.clear(); + // 停止预处理 + KEY_PROCESS_INSTANCE->Destroy(); + LOG_DEBUG(MGMT + "Destroy hybrid_mgmt module end."); +}; + #ifndef GTEST /// 启动hybrid处理任务 /// \param type @@ -619,7 +637,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) for (const auto& embInfo: mgmtEmbInfo) { TimeCost parseKeysTc; // 获取各类向量,如果为空指针,退出当前函数 - auto infoVecs = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); + auto infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::RESTORE); if (infoVecs == nullptr) { LOG_INFO(MGMT + "channelId:{} batchId:{}, ParseKeys infoVecs empty !", channelId, batchId); return false; @@ -630,7 +648,7 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId) unique_ptr> all2all = nullptr; if (!mgmtRankInfo.useStatic) { TimeCost getTensorsSyncTC; - all2all = preprocess->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); + all2all = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embInfo.name, channelId, ProcessedInfo::ALL2ALL); LOG_DEBUG("channelId:{} batchId:{}, getTensorsSyncTC(ms):{}", channelId, batchId, getTensorsSyncTC.ElapsedMS()); if (all2all == nullptr) { @@ -770,7 +788,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha embHashMap.SetStartCount(); // 获取查询向量 - auto lookupKeys = preprocess->GetLookupKeys(batchId, embName, channelId); + auto lookupKeys = KEY_PROCESS_INSTANCE->GetLookupKeys(batchId, embName, channelId); if (lookupKeys.empty()) { remainBatchOut = false; LOG_ERROR("channelId:{} batchId:{}, embName:{}, GetLookupKeys result is empty.", @@ -779,8 +797,12 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha } LOG_DEBUG("channelId:{} batchId:{}, embName:{}, GetLookupKeys end.", channelId, batchId, embName); // 获取各类向量,如果为空指针,退出当前函数 - auto infoVecs = 
preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::RESTORE); - if (infoVecs == nullptr) { return false; } + unique_ptr> infoVecs = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embName, channelId, + ProcessedInfo::RESTORE); + if (infoVecs == nullptr) { + LOG_ERROR("Information vector is nullptr!"); + return false; + } LOG_DEBUG("channelId:{} batchId:{}, GetInfoVec end, getTensorsTC(ms):{}", channelId, batchId, getTensorsTC.ElapsedMS()); @@ -811,7 +833,8 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha ddrParam.tmpDataOut.erase(ddrParam.tmpDataOut.cbegin()); hdTransfer->Send(TransferChannel::SWAP, ddrParam.tmpDataOut, channelId, embName); if (!mgmtRankInfo.useStatic) { - auto all2all = preprocess->GetInfoVec(batchId, embName, channelId, ProcessedInfo::ALL2ALL); + unique_ptr> all2all = KEY_PROCESS_INSTANCE->GetInfoVec(batchId, embName, + channelId, ProcessedInfo::ALL2ALL); if (all2all == nullptr) { LOG_ERROR("Information vector is nullptr!"); return false; @@ -834,7 +857,7 @@ void HybridMgmt::SendUniqKeysAndRestoreVecDDR(const string &embName, int &batchI LOG_DEBUG("channelId:{} batchId:{}, embName:{}, SendUniqKeysAndRestoreVecDDR start.", channelId, batchId, embName); vector uniqueKeys; vector restoreVecSec; - preprocess->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); + KEY_PROCESS_INSTANCE->GlobalUnique(ddrParam.offsetsOut, uniqueKeys, restoreVecSec); TimeCost sendUniqueKeysSyncTC; hdTransfer->Send(TransferChannel::UNIQKEYS, {mgmtRankInfo.useDynamicExpansion ? Vec2TensorI64(uniqueKeys) : @@ -911,7 +934,7 @@ bool HybridMgmt::Evict() } // 配置了淘汰选项,则触发 - auto& featAdmitNEvict = preprocess->GetFeatAdmitAndEvict(); + FeatureAdmitAndEvict& featAdmitNEvict = KEY_PROCESS_INSTANCE->GetFeatAdmitAndEvict(); if (featAdmitNEvict.GetFunctionSwitch()) { featAdmitNEvict.FeatureEvict(evictKeyMap); } else { @@ -928,10 +951,10 @@ bool HybridMgmt::Evict() if (mgmtRankInfo.noDDR) { if (GlobalEnv::useCombineFaae) { - preprocess->EvictKeysCombine(evictKeyMap[COMBINE_HISTORY_NAME]); + KEY_PROCESS_INSTANCE->EvictKeysCombine(evictKeyMap[COMBINE_HISTORY_NAME]); } else { for (const auto& evict : as_const(evictKeyMap)) { - preprocess->EvictKeys(evict.first, evict.second); + KEY_PROCESS_INSTANCE->EvictKeys(evict.first, evict.second); } } } else { @@ -1086,12 +1109,12 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const } if (mgmtRankInfo.useDynamicExpansion) { - int64_t size = preprocess->GetExpansionTableSize(embName); + int64_t size = KEY_PROCESS_INSTANCE->GetExpansionTableSize(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] size:{}", embName, size); return size; } if (mgmtRankInfo.noDDR) { - auto maxOffset = preprocess->GetMaxOffset(); + auto maxOffset = KEY_PROCESS_INSTANCE->GetMaxOffset(); const auto& iter = maxOffset.find(embName); if (iter == maxOffset.end()) { LOG_ERROR(MGMT + "get maxOffset, wrong embName:{} ", embName); @@ -1130,7 +1153,7 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const } if (mgmtRankInfo.useDynamicExpansion) { - int64_t capacity = preprocess->GetExpansionTableCapacity(embName); + int64_t capacity = KEY_PROCESS_INSTANCE->GetExpansionTableCapacity(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] capacity:{}", embName, capacity); return capacity; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 3e24cda1..7a9fa84d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -8,10 +8,6 @@ #ifndef 
MX_REC_EMB_MGMT_H #define MX_REC_EMB_MGMT_H -#include - -#include - #include #include #include @@ -20,13 +16,10 @@ #include "utils/common.h" #include "utils/config.h" -#include "utils/singleton.h" -#include "utils/logger.h" #include "host_emb/host_emb.h" #include "emb_hashmap/emb_hashmap.h" #include "hd_transfer/hd_transfer.h" -#include "key_process/key_process.h" #include "ssd_cache/cache_manager.h" #include "hybrid_mgmt_block.h" @@ -61,9 +54,6 @@ namespace MxRec { bool Load(const string& loadPath); - void SetFeatureTypeForLoad(vector& loadFeatures, - const FeatureAdmitAndEvict& featAdmitNEvict); - OffsetT SendHostMap(const string tableName); void ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap); @@ -71,50 +61,10 @@ namespace MxRec { void Start(); void StartThreadForHBM(); + void StartThreadForDDR(); - void Destroy() - { - LOG_DEBUG(MGMT + "start Destroy hybrid_mgmt module"); - if (!isInitialized) { - throw runtime_error("HybridMgmt not initialized. Call Initialize first."); - } - - if (!isRunning) { - return; - } - // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 - isRunning = false; - if (preprocess != nullptr) { - // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 - std::unique_lock lockGuard(preprocess->destroyMutex); - // 先发送停止信号给preprocess,用于停止查询中lookup卡住状态 - preprocess->isRunning = false; - // 停止hdTransfer,用于停止mgmt的recv中卡住状态 - hdTransfer->Destroy(); - LOG_DEBUG(MGMT + "destroy hdTransfer end."); - } - hybridMgmtBlock->Destroy(); - for (auto& t : procThreads) { - t->join(); - } - if (cacheManager != nullptr) { - cacheManager = nullptr; - } - if (hostEmbs != nullptr) { - hostEmbs->Join(TRAIN_CHANNEL_ID); - hostEmbs->Join(EVAL_CHANNEL_ID); - hostEmbs = nullptr; - } - procThreads.clear(); - // 停止预处理 - if (preprocess != nullptr) { - preprocess->Destroy(); - preprocess = nullptr; - LOG_DEBUG(MGMT + "invoke KeyProcess destroy end."); - } - LOG_DEBUG(MGMT + "Destroy hybrid_mgmt module end."); - }; + void Destroy(); bool ParseKeys(int channelId, int& batchId); @@ -126,10 +76,6 @@ namespace MxRec { bool Evict(); - void EvictKeys(const string& embName, const vector& keys); - - bool IsLoadDataMatches(const EmbMemT& loadHostEmbs, const EmbInfo& setupHostEmbs, size_t& embTableCount) const; - void NotifyBySessionRun(int channelID) const; void CountStepBySessionRun(int channelID, int steps) const; @@ -138,9 +84,13 @@ namespace MxRec { int64_t GetTableCapacity(const string& embName) const; - private: - bool InitKeyProcess(const RankInfo& rankInfo, const vector& embInfos, - const vector& thresholdValues, int seed); + GTEST_PRIVATE: + + void SetFeatureTypeForLoad(vector& loadFeatures); + + bool IsLoadDataMatches(const EmbMemT& loadHostEmbs, const EmbInfo& setupHostEmbs, size_t& embTableCount) const; + + void EvictKeys(const string& embName, const vector& keys); void InitRankInfo(RankInfo& rankInfo, const vector& embInfos) const; @@ -167,7 +117,6 @@ namespace MxRec { unique_ptr hostHashMaps {}; vector> procThreads {}; map> evictKeyMap {}; - KeyProcess *preprocess; HDTransfer *hdTransfer; OffsetMapT offsetMapToSend; bool isSSDEnabled { false }; diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 0796b2c1..a1237634 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -7,6 +7,11 @@ #include "key_process.h" +#include +#include "utils/safe_queue.h" +#include "utils/time_cost.h" +#include "utils/config.h" +#include "host_emb/host_emb.h" #include "checkpoint/checkpoint.h" #include "hd_transfer/hd_transfer.h" #include 
"ock_ctr_common/include/error_code.h" @@ -73,6 +78,9 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); +#ifndef GTEST + Start(); +#endif return true; } diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index cd8c1ea6..5ed771ba 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -9,30 +9,21 @@ #define MX_REC_KEY_PROCESS_H #include -#include -#include #include #include #include #include -#include #include #include #include - #include "ock_ctr_common/include/factory.h" #include "utils/common.h" -#include "utils/config.h" -#include "utils/time_cost.h" -#include "utils/safe_queue.h" - -#include "host_emb/host_emb.h" #include "emb_table/emb_table.h" - #include "feature_admit_and_evict.h" #include "hybrid_mgmt/hybrid_mgmt_block.h" +#include "utils/singleton.h" namespace MxRec { using namespace std; @@ -90,15 +81,13 @@ namespace MxRec { int GetMaxStep(int channelId) const; - int Start(); - - auto GetMaxOffset() -> OffsetMemT; + OffsetMemT GetMaxOffset(); - auto GetKeyOffsetMap() -> KeyOffsetMemT; + KeyOffsetMemT GetKeyOffsetMap(); - auto GetKeyCountMap() -> KeyCountMemT; + KeyCountMemT GetKeyCountMap(); - auto GetFeatAdmitAndEvict() -> FeatureAdmitAndEvict&; + FeatureAdmitAndEvict& GetFeatAdmitAndEvict(); void LoadMaxOffset(OffsetMemT& loadData); @@ -175,6 +164,9 @@ namespace MxRec { return embInfos.find(embName) != embInfos.end(); }; GTEST_PRIVATE: + + int Start(); + template T GetInfo(info_list_t& list, int batch, const string& embName, int channel); @@ -199,7 +191,7 @@ namespace MxRec { ock::ctr::FactoryPtr factory {}; int hotEmbUpdateStep = HOT_EMB_UPDATE_STEP_DEFAULT; bool isWithFAAE; - vector isNeedSendEos { false, false }; // 分别代表通道0、1的eos状态 + bool isNeedSendEos[2] = { 0, 0 }; // 分别代表通道0、1的eos状态 void InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo); @@ -297,5 +289,8 @@ namespace MxRec { string DumpSplitKeys(vector>& splitKeys) const; }; + +#define KEY_PROCESS_INSTANCE Singleton::GetInstance() } // end namespace MxRec + #endif // MX_REC_KEY_PROCESS_H diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 0e81ea56..8de60233 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -11,7 +11,9 @@ #include "ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h" #include "ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h" #include "ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h" -#include "ckpt_data_handler//feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" +#include "ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h" +#include "ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h" +#include "utils/common.h" using namespace std; using namespace MxRec; @@ -294,4 +296,22 @@ TEST_F(CkptDataHandlerTest, FeatAdmitNEvict) // 测试load TestForLoad(args); -} \ No newline at end of file +} + +TEST_F(CkptDataHandlerTest, KeyCountMapCkpt) +{ + CkptData data; + KeyCountMapCkpt ckpt; + CkptTransData tranData; + const int testCount = 10; + const int exceptCount = 5; + for (int i = 0; i < testCount; ++i) { + tranData.int64Arr.push_back(i); + } + + ckpt.SetProcessData(data); + ckpt.SetDataset(CkptDataType::KEY_COUNT_MAP, std::string("test"), 
tranData); + ckpt.GetProcessData(data); + absl::flat_hash_map &testmap = data.keyCountMap["test"]; + EXPECT_EQ(testmap.size(), exceptCount); +} -- Gitee From c1f829bf1b9a8972e66728faecc451ec038de499 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 6 Dec 2023 17:18:44 +0800 Subject: [PATCH 506/551] Match-id-6b54c0b84a56db46d1619b566919aaf383464b93 --- mx_rec/constants/constants.py | 1 + mx_rec/core/embedding.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 6a1d18c8..07d2aa7d 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -52,6 +52,7 @@ TRAIN_CHANNEL_ID = 0 EVAL_CHANNEL_ID = 1 HASHTABLE_COLLECTION_NAME_LENGTH = 30 MAX_VOCABULARY_SIZE = 10**10 +MAX_DEVICE_VOCABULARY_SIZE = 10 ** 9 # RANK INFO VALID_DEVICE_ID_LIST = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15"] diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 2d4fde9f..025c1d22 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -18,7 +18,8 @@ from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temp from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \ ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, \ - ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE + ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE, \ + MAX_DEVICE_VOCABULARY_SIZE from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, get_customized_ops, \ insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \ clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \ @@ -41,7 +42,7 @@ from mx_rec.util.log import logger ("emb_initializer", ClassValidator, {"classes": (InitializerV1, InitializerV2)}), ("optimizer_list", ClassValidator, {"classes": (list, type(None))}), (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator), - ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_VOCABULARY_SIZE}, + ("device_vocabulary_size", IntValidator, {"min_value": 1, "max_value": MAX_DEVICE_VOCABULARY_SIZE}, ["check_value"]), ("host_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), ("ssd_vocabulary_size", IntValidator, {"min_value": 0, "max_value": MAX_VOCABULARY_SIZE}, ["check_value"]), -- Gitee From 1541456b50c586d77cfb658753c17bbbe7c7f13d Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 7 Dec 2023 20:40:51 +0800 Subject: [PATCH 507/551] Match-id-fad408a6e4963c34eb4102fbe4f333081ad62b2d --- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 1 + src/core/hybrid_mgmt/hybrid_mgmt_block.h | 3 ++ src/core/key_process/key_process.cpp | 34 ++++++++++++---- src/core/key_process/key_process.h | 1 + src/core/utils/common.h | 1 + src/tests/checkpoint/checkpoint_test.cpp | 39 +++++++++++++++++++ .../ckpt_data_handler_test.cpp | 12 ++++++ .../file_system/local_file_system_test.cpp | 12 ++++++ .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 14 +++++++ src/tests/key_process/key_process_test.cpp | 21 ++++++++++ 10 files changed, 130 insertions(+), 8 deletions(-) diff --git 
a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index bde5ce80..daeb07c8 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -84,6 +84,7 @@ void HybridMgmtBlock::CountPythonStep(int channelId, int steps) { // 相应的通知计数 pythonBatchId[channelId] += steps; + loop[channelId] = steps; } /// 检查是否进行了通道切换,检查当前的step是否合理 diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index 02d4a070..930df3ea 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -30,6 +30,9 @@ namespace MxRec { int pythonBatchId[2] = {0, 0}; // readEmbed算子侧将要处理的batch id int readEmbedBatchId[2] = {0, 0}; + + int loop[2] = {1, 1}; + bool isRunning = true; ~HybridMgmtBlock(); diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index a1237634..1eb2de21 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1076,6 +1076,22 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba batch->channel, commId, batch->batchId); throw EndRunExit("has send acl eos info, thread will exit."); } + LOG_INFO("channelId:{} batchId:{}, GetScAll HandleRankExitScene eos.", batch->channel, batch->batchId); + + int timeout = 0; + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + bool isExit = hybridMgmtBlock->pythonBatchId[batch->channel] < + (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1); + while (isExit && timeout < EOS_TIMEOUT) { + LOG_DEBUG("wait until hybridBatchId equal pythonBatchId before SendEos, channelId:{}, pyBatchId:{}, " + "mgmtBatchId:{}", batch->channel, hybridMgmtBlock->pythonBatchId[batch->channel], + hybridMgmtBlock->hybridBatchId[batch->channel]); + this_thread::sleep_for(seconds(1)); + isExit = hybridMgmtBlock->pythonBatchId[batch->channel] < + (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1); + timeout ++; + } + SendEos(batch->batchId, batch->channel); isNeedExit[batch->channel] = true; throw EndRunExit("has SendEosInfo, GetScAll end, thread will exit."); @@ -1267,11 +1283,13 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) auto ret = GetInfo(lookupKeysList, batch, embName, channel); return get(ret); } catch (EmptyList&) { - unique_lock lockGuard(destroyMutex); + unique_lock lockEosGuard(eosMutex); // readEmbKey真实的次数是readEmbedBatchId减1 int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; // 避免eos在keyProcess还未处理完数据时插队到通道前面 if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { + LOG_INFO("channelId:{} batchId:{}, GetLookupKeys eos.", channel, batch); + unique_lock lockDestroyGuard(destroyMutex); SendEos(batch, channel); return {}; } @@ -1368,16 +1386,16 @@ unique_ptr> KeyProcess::GetInfoVec(int batch, const string& embNa storage.erase(it); return uTensor; } catch (EmptyList&) { - unique_lock lockGuard(destroyMutex); - // readEmbKey真实的次数是readEmbedBatchId减1 - int readEmbKeyBatchId = hybridMgmtBlock->readEmbedBatchId[channel] - 1; - // 避免eos在keyProcess还未处理完数据时插队到通道前面 - if (isNeedSendEos[channel] && readEmbKeyBatchId < batch) { + unique_lock lockEosGuard(eosMutex); + // 避免eos在keyProcess还未处理完数据时插队到通道前面, readEmbKey真实的次数是readEmbedBatchId减1 + if (isNeedSendEos[channel] && (hybridMgmtBlock->readEmbedBatchId[channel] - 1) < batch) { + LOG_INFO("channelId:{} batchId:{}, GetInfoVec eos.", channel, batch); + 
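// [Editor's note -- annotation, not part of patch 507.] The hunk around this point
// replaces the old single destroyMutex guard with a two-lock scheme: eosMutex
// serializes every read and write of isNeedSendEos (SetEos further down in this
// patch now takes eosMutex as well), while destroyMutex is acquired only around the
// actual SendEos call, so channel teardown cannot race a manual end-of-stream send.
// A sketch of the ordering the diff establishes (isNeedSendEos/SendEos are from the
// diff; lastReadBatch is a stand-in for readEmbedBatchId[channel] - 1):
//
//     std::unique_lock<std::mutex> eosLock(eosMutex);             // flag first
//     if (isNeedSendEos[channel] && lastReadBatch < batch) {
//         std::unique_lock<std::mutex> destroyLock(destroyMutex); // send second
//         SendEos(batch, channel);
//     }
//
// Every path shown here takes eosMutex strictly before destroyMutex, which rules
// out a lock-order inversion between the two mutexes.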
unique_lock lockDestroyGuard(destroyMutex); SendEos(batch, channel); return nullptr; } LOG_TRACE("getting info failed {}[{}], list is empty, and mgmt batchId: {}, readEmbKey batchId: {}.", - embName, channel, batch, readEmbKeyBatchId); + embName, channel, batch, (hybridMgmtBlock->readEmbedBatchId[channel] - 1)); this_thread::sleep_for(1ms); } catch (WrongListTop&) { LOG_TRACE("getting info failed {}[{}]:{} wrong top", embName, channel, batch); @@ -1552,7 +1570,7 @@ void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) void KeyProcess::SetEos(int status, int channelId) { - unique_lock lockGuard(destroyMutex); + unique_lock lockGuard(eosMutex); LOG_INFO("isNeedSendEos status is changed, before status:[{}], input status:{}, channel:[{}], ", isNeedSendEos[channelId], status, channelId); isNeedSendEos[channelId] = (status == 1); diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index 5ed771ba..b28038ca 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -159,6 +159,7 @@ namespace MxRec { int mpiAllReduceSend[2] = {MPI_NORMAL_SEND_VALUE, MPI_NORMAL_SEND_VALUE}; std::mutex destroyMutex; + std::mutex eosMutex; inline bool HasEmbName(const string& embName) { return embInfos.find(embName) != embInfos.end(); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 33486448..30f1b17c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -82,6 +82,7 @@ namespace MxRec { constexpr int KEY_PROCESS_TIMEOUT = 120; constexpr int GET_BATCH_TIMEOUT = 300; + constexpr int EOS_TIMEOUT = 60; constexpr size_t DEFAULT_RANDOM_SEED = 10086; constexpr int INVALID_KEY_VALUE = -1; diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index 607a7c96..3754cdf2 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -183,6 +183,14 @@ protected: } } + void SetKeyCountMap(absl::flat_hash_map& testKeyCountMap) + { + for (int64_t i { 0 }; i < hostVocabSize; ++i) { + testKeyCountMap[featMem] = i; + featMem++; + } + } + void SetExcludeDDRKeyFreqMap(unordered_map& testExcludeDDRKeyFreqMap) { for (int64_t i { 0 }; i < hostVocabSize; ++i) { @@ -200,6 +208,15 @@ protected: } } + void SetKeyCountMaps(KeyCountMemT & testKeyCountMaps) + { + absl::flat_hash_map testKeyCountMap; + for (const auto& testEmbInfo : testEmbInfos) { + SetKeyCountMap(testKeyCountMap); + testKeyCountMaps[testEmbInfo.name] = std::move(testKeyCountMap); + } + } + void SetExcludeDDRKeyFreqMaps(KeyFreqMemT& testExcludeDDRKeyFreqMaps) { unordered_map testExcludeDDRKeyFreqMap; @@ -534,3 +551,25 @@ TEST_F(CheckpointTest, KeyFreqMaps) } } +TEST_F(CheckpointTest, KeyCountMapCkpt) +{ + KeyCountMemT testKeyCountMaps; + KeyCountMemT validKeyCountMaps; + + SetEmbInfo(); + SetKeyCountMaps(testKeyCountMaps); + + validKeyCountMaps = testKeyCountMaps; + + CkptData testSaveData; + CkptData validLoadData; + CkptData testLoadData; + + testSaveData.keyCountMap = std::move(testKeyCountMaps); + validLoadData.keyCountMap = std::move(validKeyCountMaps); + + Checkpoint testCkpt; + testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); + testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::KEY_COUNT_MAP }); + EXPECT_EQ(validLoadData.keyCountMap.size(), testLoadData.keyCountMap.size()); +} \ No newline at end of file diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp index 
8de60233..2ac145d4 100644 --- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp +++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp @@ -315,3 +315,15 @@ TEST_F(CkptDataHandlerTest, KeyCountMapCkpt) absl::flat_hash_map &testmap = data.keyCountMap["test"]; EXPECT_EQ(testmap.size(), exceptCount); } + +TEST_F(CkptDataHandlerTest, SetDatasetForLoadEmb) +{ + KeyCountMapCkpt ckpt; + CkptTransData tranData; + CkptData data; + try { + ckpt.SetDatasetForLoadEmb(CkptDataType::KEY_COUNT_MAP, std::string("test"), tranData, data); + } catch (runtime_error& e) { + LOG_INFO(KEY_PROCESS "success"); + } +} \ No newline at end of file diff --git a/src/tests/file_system/local_file_system_test.cpp b/src/tests/file_system/local_file_system_test.cpp index 410bb63a..359b86de 100644 --- a/src/tests/file_system/local_file_system_test.cpp +++ b/src/tests/file_system/local_file_system_test.cpp @@ -28,3 +28,15 @@ TEST(LocalFileSystem, WriteAndReadFile) ASSERT_EQ(writeData.size() * sizeof(int64_t), res); } +TEST(LocalFileSystem, WriteEmbedding) +{ + string filePath = "./write.data"; + float p[5] = {1.1, 2.2, 3.3, 4.4, 5.5}; + vector writeData = {p, p+1, p+2, p+3, p+4}; + + auto fileSystemHandler = make_unique(); + auto fileSystemPtr = fileSystemHandler->Create(filePath); + ssize_t res = fileSystemPtr->Write(filePath, writeData, sizeof(float)); + + ASSERT_EQ(writeData.size() * sizeof(float), res); +} diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index ff0bc6ba..f278875a 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -27,6 +27,7 @@ protected: { LOG_DEBUG("start initialize") ; } + int loop = 10; }; TEST_F(HybridMgmtBlockTest, CheckAndDoBlock) @@ -134,4 +135,17 @@ TEST_F(HybridMgmtBlockTest, CheckSaveEmbMapValid) ASSERT_EQ(status0, 0); ASSERT_EQ(status1, 1); ASSERT_EQ(status2, -1); +} + +TEST_F(HybridMgmtBlockTest, CountPythonStep) +{ + hybridMgmtBlock = std::make_unique(); + + hybridMgmtBlock->pythonBatchId[0] = 1; + hybridMgmtBlock->loop[0] = 1; + + hybridMgmtBlock->CountPythonStep(0, loop); + + ASSERT_EQ(hybridMgmtBlock->pythonBatchId[0], loop + 1); + ASSERT_EQ(hybridMgmtBlock->loop[0], loop); } \ No newline at end of file diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index dbed8831..c5708488 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -304,6 +304,27 @@ TEST_F(KeyProcessTest, GetScAll) ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); } +TEST_F(KeyProcessTest, HandleRankExitScene) +{ + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + // 仅用于集合通信获取sendCount信息,构造EmbBatchT对象即可,通道传0,不用构造batch数据 + EmbBatchT tempBatch; + tempBatch.channel = 0; + unique_ptr batch = std::make_unique(tempBatch); + + std::unique_ptr hybridMgmtBlock = std::make_unique(); + hybridMgmtBlock->pythonBatchId[0] = 1; + hybridMgmtBlock->hybridBatchId[0] = 1; + hybridMgmtBlock->loop[0] = 1; + + try { + process.HandleRankExitScene(0, batch, 0); + } catch (EndRunExit e) { + LOG_INFO(KEY_PROCESS "success"); + } +} + TEST_F(KeyProcessTest, GetScAllForUnique) { vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 -- Gitee From ad216dafe4335a0c22b9687ae1d0ce031567524f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 8 Dec 2023 17:02:51 +0800 Subject: [PATCH 508/551] Match-id-8770800d65fddb61f438a4faa5f4b6cc6091d1e6 --- 
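[Editor's note: patch 508 carries no commit message beyond its Match-id. Its effect is
visible in the diff that follows: HybridMgmt::Destroy now takes KeyProcess's
destroyMutex inside a nested block, so the lock is released before
hybridMgmtBlock->Destroy() and the thread joins run, and HandleRankExitScene also
bails out when isRunning is already false. A minimal, self-contained sketch of that
scoped-lock pattern -- hypothetical names, not code from this repository:]

    #include <mutex>
    #include <thread>
    #include <vector>

    void Shutdown(std::mutex& stopMutex, bool& running, std::vector<std::thread>& workers)
    {
        {
            std::lock_guard<std::mutex> guard(stopMutex); // held only while flipping the flag
            running = false;                              // workers observe this and wind down
        } // guard released here, so a worker blocked on stopMutex can proceed and exit
        for (auto& worker : workers) {
            worker.join();                                // join without holding the lock
        }
    }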
src/core/hybrid_mgmt/hybrid_mgmt.cpp | 17 ++++++++++------- src/core/key_process/key_process.cpp | 7 +++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index e45e074f..36d54e7b 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -534,13 +534,16 @@ void HybridMgmt::Destroy() } // 先发送停止信号mgmt,先停止新lookup查询, 解除queue的限制防止卡住 isRunning = false; - // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 - std::unique_lock lockGuard(KEY_PROCESS_INSTANCE->destroyMutex); - // 先发送停止信号给KEY_PROCESS_INSTANCE,用于停止查询中lookup卡住状态 - KEY_PROCESS_INSTANCE->isRunning = false; - // 停止hdTransfer,用于停止mgmt的recv中卡住状态 - hdTransfer->Destroy(); - LOG_DEBUG(MGMT + "destroy hdTransfer end."); + + { + // 获取锁 避免KeyProcess中手动发送结束信息时通道关闭 + std::unique_lock lockGuard(KEY_PROCESS_INSTANCE->destroyMutex); + // 先发送停止信号给KEY_PROCESS_INSTANCE,用于停止查询中lookup卡住状态 + KEY_PROCESS_INSTANCE->isRunning = false; + // 停止hdTransfer,用于停止mgmt的recv中卡住状态 + hdTransfer->Destroy(); + LOG_DEBUG(MGMT + "destroy hdTransfer end."); + } hybridMgmtBlock->Destroy(); for (auto& t : procThreads) { diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 1eb2de21..95b53a0a 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1071,10 +1071,10 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba { if (receiveFlag < rankInfo.rankSize) { unique_lock lockGuard(destroyMutex); - if (isNeedExit[batch->channel]) { - LOG_INFO("channelId:{} threadId:{} batchId:{}, has send acl eos info, thread will exit.", + if (isNeedExit[batch->channel] || !isRunning) { + LOG_INFO("channelId:{} threadId:{} batchId:{}, no need send acl eos, thread will exit.", batch->channel, commId, batch->batchId); - throw EndRunExit("has send acl eos info, thread will exit."); + throw EndRunExit("no need send acl eos, thread will exit."); } LOG_INFO("channelId:{} batchId:{}, GetScAll HandleRankExitScene eos.", batch->channel, batch->batchId); @@ -1093,7 +1093,6 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba } SendEos(batch->batchId, batch->channel); - isNeedExit[batch->channel] = true; throw EndRunExit("has SendEosInfo, GetScAll end, thread will exit."); } } -- Gitee From fcad725e306c19597c5f10096fe0560578543e8c Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sat, 9 Dec 2023 11:31:28 +0800 Subject: [PATCH 509/551] Match-id-1378ce20a09288b3cfff9bab6431828727fa8097 --- src/core/key_process/key_process.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 95b53a0a..3befd2f8 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1089,7 +1089,7 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba this_thread::sleep_for(seconds(1)); isExit = hybridMgmtBlock->pythonBatchId[batch->channel] < (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1); - timeout ++; + timeout++; } SendEos(batch->batchId, batch->channel); -- Gitee From 7975f6f0fc0b38ca62f7be083a201f23b50ec025 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Dec 2023 15:33:28 +0800 Subject: [PATCH 510/551] Match-id-d0286578f4a9bfca8a59b820cfa528f740a21267 --- src/tests/utils/common_h_test.cpp | 137 ++++++++++++++++++++++++++++ src/tests/utils/common_test.cpp | 47 +++++++++- src/tests/utils/config_test.cpp | 
100 ++++++++++++++++++++ src/tests/utils/log_test.cpp | 4 +- src/tests/utils/safe_queue_test.cpp | 104 +++++++++++++++++++++ 5 files changed, 386 insertions(+), 6 deletions(-) create mode 100644 src/tests/utils/common_h_test.cpp create mode 100644 src/tests/utils/config_test.cpp create mode 100644 src/tests/utils/safe_queue_test.cpp diff --git a/src/tests/utils/common_h_test.cpp b/src/tests/utils/common_h_test.cpp new file mode 100644 index 00000000..2fcd4083 --- /dev/null +++ b/src/tests/utils/common_h_test.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: common.h test + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + +#include +#include "utils/common.h" + +TEST(TestStringFormat, Basic) +{ + EXPECT_EQ(MxRec::StringFormat("%s %d", "test", 123), "test 123"); +} + +TEST(TestVectorToString, Basic) +{ + std::vector vec = {1, 2, 3}; + EXPECT_EQ(MxRec::VectorToString(vec), "[1, 2, 3]"); +} + +TEST(TestMapToString, Basic) +{ + std::map map = {{1, 2}, {3, 4}}; + EXPECT_EQ(MxRec::MapToString(map), "{1: 2, 3: 4}"); +} + +// 测试 MapToString 函数可以处理 absl::flat_hash_map +TEST(TestMapToString, AbseilFlatHashMap) +{ + absl::flat_hash_map map = {{1, 2}, {3, 4}}; + std::string result = MxRec::MapToString(map); + EXPECT_TRUE(result.find("1: 2") != std::string::npos); + EXPECT_TRUE(result.find("3: 4") != std::string::npos); +} + +TEST(TestVec2TensorI32, Basic) +{ + std::vector vec = {1, 2, 3}; + tensorflow::Tensor tensor = MxRec::Vec2TensorI32(vec); + auto tensor_data = tensor.flat(); + for (int i = 0; i < vec.size(); ++i) { + EXPECT_EQ(tensor_data(i), vec[i]); + } +} + +TEST(TestVec2TensorI64, Basic) { + std::vector vec = {1, 2, 3}; + tensorflow::Tensor tensor = MxRec::Vec2TensorI64(vec); + auto tensor_data = tensor.flat(); + for (int i = 0; i < vec.size(); ++i) { + EXPECT_EQ(tensor_data(i), vec[i]); + } +} + +TEST(TestGetUBSize, InvalidDeviceID) +{ + EXPECT_THROW(MxRec::GetUBSize(999), std::runtime_error); +} + +// 测试 Batch 结构的 Size 和 UnParse 方法 +TEST(TestBatch, SizeAndUnParse) +{ + MxRec::Batch batch; + batch.sample = {1, 2, 3}; + EXPECT_EQ(batch.Size(), 3); + EXPECT_EQ(batch.UnParse(), "1 2 3 "); +} + +// 测试 RankInfo 结构的默认构造函数 +TEST(TestRankInfo, DefaultConstructor) +{ + MxRec::RankInfo rankInfo; +} + +// 测试 ThresholdValue 结构的默认构造函数 +TEST(TestThresholdValue, DefaultConstructor) +{ + MxRec::ThresholdValue thresholdValue; +} + +// 测试 FeatureItemInfo 结构的默认构造函数和带参数的构造函数 +TEST(TestFeatureItemInfo, Constructors) +{ + MxRec::FeatureItemInfo featureItemInfo1; + + MxRec::FeatureItemInfo featureItemInfo2(123, 456); +} + +// 测试 AdmitAndEvictData 结构的默认构造函数 +TEST(TestAdmitAndEvictData, DefaultConstructor) +{ + MxRec::AdmitAndEvictData admitAndEvictData; +} + +// 测试 EmbInfo 结构的默认构造函数 +TEST(TestEmbInfo, DefaultConstructor) +{ + MxRec::EmbInfo embInfo; +} + +// 测试 HostEmbTable 结构的默认构造函数 +TEST(TestHostEmbTable, DefaultConstructor) +{ + MxRec::HostEmbTable hostEmbTable; +} + +// 测试 EmbHashMapInfo 结构的默认构造函数 +TEST(TestEmbHashMapInfo, DefaultConstructor) +{ + MxRec::EmbHashMapInfo embHashMapInfo; +} + +// 测试 All2AllInfo 结构的默认构造函数 +TEST(TestAll2AllInfo, DefaultConstructor) +{ + MxRec::All2AllInfo all2AllInfo; +} + +// 测试 UniqueInfo 结构的默认构造函数 +TEST(TestUniqueInfo, DefaultConstructor) +{ + MxRec::UniqueInfo uniqueInfo; +} + +// 测试 KeySendInfo 结构的默认构造函数 +TEST(TestKeySendInfo, DefaultConstructor) +{ + MxRec::KeySendInfo keySendInfo; +} + +// 测试 CkptTransData 结构的默认构造函数 +TEST(TestCkptTransData, DefaultConstructor) +{ + MxRec::CkptTransData 
ckptTransData; +} diff --git a/src/tests/utils/common_test.cpp b/src/tests/utils/common_test.cpp index 20813619..d918360b 100644 --- a/src/tests/utils/common_test.cpp +++ b/src/tests/utils/common_test.cpp @@ -1,8 +1,8 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: key process test + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: common.cpp test * Author: MindX SDK - * Create: 2022 + * Create: 2023 * History: NA */ @@ -37,4 +37,43 @@ TEST(common, InitializeInfo) isExceptionThrow = true; } ASSERT_EQ(isExceptionThrow, true); -} \ No newline at end of file +} + +// 测试 RandomInfo 构造函数 +TEST(TestRandomInfo, Basic) +{ + MxRec::RandomInfo info(0, 10, 1.0f, 0.0f, 1.0f); + EXPECT_EQ(info.start, 0); + EXPECT_EQ(info.len, 10); + EXPECT_EQ(info.constantVal, 1.0f); + EXPECT_EQ(info.randomMin, 0.0f); + EXPECT_EQ(info.randomMax, 1.0f); +} + +TEST(TestSetLog, Basic) +{ + // 在每次测试之前重置 gRankId + MxRec::GlogConfig::gRankId = ""; + + // 假设 GlobalEnv::glogStderrthreshold 已经被设置为一个有效的值 + MxRec::SetLog(0); + + // 检查 gGlogLevel 是否被正确设置 + EXPECT_EQ(MxRec::GlogConfig::gGlogLevel, GlobalEnv::glogStderrthreshold); + + // 检查 gRankId 是否被正确设置 + EXPECT_EQ(MxRec::GlogConfig::gRankId, "0"); +} + +TEST(TestGetThreadNumEnv, Basic) +{ + // 假设 GlobalEnv::keyProcessThreadNum 已经被设置为一个有效的值 + int num = MxRec::GetThreadNumEnv(); + // 检查返回的线程数是否正确 + EXPECT_EQ(num, GlobalEnv::keyProcessThreadNum); +} + +TEST(TestValidateReadFile, Basic) +{ + EXPECT_NO_THROW(MxRec::ValidateReadFile("/home/slice_0.data", 28000000)); +} diff --git a/src/tests/utils/config_test.cpp b/src/tests/utils/config_test.cpp new file mode 100644 index 00000000..24dd1c2e --- /dev/null +++ b/src/tests/utils/config_test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: config test + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + +#include +#include + +#include "utils/config.h" +#include "utils/logger.h" + +using namespace std; +using namespace MxRec; + +void SetEnvironmentVariables() +{ + setenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY, "sum_same_id_gradients_and_apply", 1); + setenv(RecEnvNames::ACL_TIMEOUT, "100", 1); + setenv(RecEnvNames::HD_CHANNEL_SIZE, "50", 1); + setenv(RecEnvNames::KEY_PROCESS_THREAD_NUM, "8", 1); + setenv(RecEnvNames::MAX_UNIQUE_THREAD_NUM, "10", 1); + setenv(RecEnvNames::FAST_UNIQUE, "1", 1); + setenv(RecEnvNames::UPDATE_EMB_V2, "1", 1); + setenv(RecEnvNames::HOT_EMB_UPDATE_STEP, "2000", 1); + setenv(RecEnvNames::GLOG_STDERR_THRESHOLD, "1", 1); + setenv(RecEnvNames::USE_COMBINE_FAAE, "1", 1); + setenv(RecEnvNames::STAT_ON, "1", 1); + setenv(RecEnvNames::RECORD_KEY_COUNT, "1", 1); +} + +void UnsetEnvironmentVariables() +{ + unsetenv(RecEnvNames::APPLY_GRADIENTS_STRATEGY); + unsetenv(RecEnvNames::ACL_TIMEOUT); + unsetenv(RecEnvNames::HD_CHANNEL_SIZE); + unsetenv(RecEnvNames::KEY_PROCESS_THREAD_NUM); + unsetenv(RecEnvNames::MAX_UNIQUE_THREAD_NUM); + unsetenv(RecEnvNames::FAST_UNIQUE); + unsetenv(RecEnvNames::UPDATE_EMB_V2); + unsetenv(RecEnvNames::HOT_EMB_UPDATE_STEP); + unsetenv(RecEnvNames::GLOG_STDERR_THRESHOLD); + unsetenv(RecEnvNames::USE_COMBINE_FAAE); + unsetenv(RecEnvNames::STAT_ON); + unsetenv(RecEnvNames::RECORD_KEY_COUNT); +} + +TEST(GlobalEnv, DefaultValues) +{ + ASSERT_EQ(GlobalEnv::applyGradientsStrategy, ApplyGradientsStrategyOptions::DIRECT_APPLY); + ASSERT_EQ(GlobalEnv::aclTimeout, -1); + ASSERT_EQ(GlobalEnv::hdChannelSize, 40); + ASSERT_EQ(GlobalEnv::keyProcessThreadNum, 6); + ASSERT_EQ(GlobalEnv::maxUniqueThreadNum, 8); + ASSERT_EQ(GlobalEnv::fastUnique, false); + ASSERT_EQ(GlobalEnv::updateEmbV2, false); + ASSERT_EQ(GlobalEnv::hotEmbUpdateStep, 1000); + ASSERT_EQ(GlobalEnv::glogStderrthreshold, 0); + ASSERT_EQ(GlobalEnv::useCombineFaae, false); + ASSERT_EQ(GlobalEnv::statOn, false); + ASSERT_EQ(GlobalEnv::recordKeyCount, false); +} + +TEST(GlobalEnv, ConfigGlobalEnv) +{ + SetEnvironmentVariables(); + + ConfigGlobalEnv(); + + // 验证环境变量是否已经被正确配置 + ASSERT_EQ(GlobalEnv::applyGradientsStrategy, "sum_same_id_gradients_and_apply"); + ASSERT_EQ(GlobalEnv::aclTimeout, 100); + ASSERT_EQ(GlobalEnv::hdChannelSize, 50); + ASSERT_EQ(GlobalEnv::keyProcessThreadNum, 8); + ASSERT_EQ(GlobalEnv::maxUniqueThreadNum, 10); + ASSERT_EQ(GlobalEnv::fastUnique, true); + ASSERT_EQ(GlobalEnv::updateEmbV2, true); + ASSERT_EQ(GlobalEnv::hotEmbUpdateStep, 2000); + ASSERT_EQ(GlobalEnv::glogStderrthreshold, 1); + ASSERT_EQ(GlobalEnv::useCombineFaae, true); + ASSERT_EQ(GlobalEnv::statOn, true); + ASSERT_EQ(GlobalEnv::recordKeyCount, true); + + // 清除环境变量 + UnsetEnvironmentVariables(); +} + +TEST(LogTest, LogGlobalEnv) +{ + SetEnvironmentVariables(); + + ConfigGlobalEnv(); + + // 使用ASSERT_NO_THROW宏,断言LogGlobalEnv函数是否没有抛出异常 + ASSERT_NO_THROW(LogGlobalEnv()); + + UnsetEnvironmentVariables(); +} diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp index 1eb3aa10..ebdb8487 100644 --- a/src/tests/utils/log_test.cpp +++ b/src/tests/utils/log_test.cpp @@ -1,6 +1,6 @@ /* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
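// [Editor's note -- annotation, not part of patch 510.] TEST(GlobalEnv, DefaultValues)
// above doubles as a reference list of the shipped defaults (hdChannelSize 40,
// keyProcessThreadNum 6, maxUniqueThreadNum 8, hotEmbUpdateStep 1000, aclTimeout -1,
// the boolean switches off). ConfigGlobalEnv() presumably resolves each field in
// roughly this shape -- an assumed sketch, not the actual implementation:
//
//     static int EnvToInt(const char* name, int fallback)
//     {
//         const char* raw = std::getenv(name);  // unset variable -> keep the default
//         return (raw != nullptr) ? std::atoi(raw) : fallback;
//     }
//     GlobalEnv::hdChannelSize = EnvToInt(RecEnvNames::HD_CHANNEL_SIZE, 40);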
+ * Description: log test * Author: MindX SDK * Create: 2023 * History: NA diff --git a/src/tests/utils/safe_queue_test.cpp b/src/tests/utils/safe_queue_test.cpp new file mode 100644 index 00000000..9b2e78f7 --- /dev/null +++ b/src/tests/utils/safe_queue_test.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: safe_queue test + * Author: MindX SDK + * Create: 2023 + * History: NA + */ + +#include +#include + +#include +#include + +#include "utils/safe_queue.h" + +void TestSafeQueue() +{ + MxRec::SafeQueue sq; + + // 测试入队操作 + sq.Pushv(std::make_unique(1)); + assert(sq.Size() == 1); + + // 测试出队操作 + auto item = sq.TryPop(); + assert(*item == 1); + assert(sq.Size() == 0); + + // 测试空队列出队 + item = sq.TryPop(); + assert(item == nullptr); + + // 测试多线程环境 + std::thread t1([&sq]() { + const int maxCount = 100; + for (int i = 0; i < maxCount; ++i) { + sq.Pushv(std::make_unique(i)); + } + }); + + std::thread t2([&sq]() { + const int maxCount = 100; + for (int i = 0; i < maxCount; ++i) { + auto item = sq.TryPop(); + if (item != nullptr) { + std::cout << *item << std::endl; + } + } + }); + + t1.join(); + t2.join(); +} + +void TestSingletonQueue() +{ + auto sq1 = MxRec::SingletonQueue::GetInstances(0); + auto sq2 = MxRec::SingletonQueue::GetInstances(0); + + // 测试是否为单例 + assert(sq1 == sq2); + + // 测试超出范围 + auto sq3 = MxRec::SingletonQueue::GetInstances(MxRec::MAX_QUEUE_NUM); + assert(sq3 == nullptr); +} + +void TestPutDirty() +{ + MxRec::SafeQueue sq; + + // 创建一个智能指针对象ptr + std::unique_ptr ptr = std::make_unique(1); + + // 使用 PutDirty 方法将 unique_ptr 放入 SafeQueue + sq.PutDirty(std::move(ptr)); + + // 使用 GetOne 方法检查 unique_ptr 是否已经被正确地放入 SafeQueue + auto item = sq.GetOne(); + assert(*item == 1); +} + +void TestGetOneFromEmptyQueue() +{ + MxRec::SafeQueue sq; + + // 直接从新创建的 SafeQueue 中调用 GetOne + auto item = sq.GetOne(); + + // 检查返回的 unique_ptr 是否非空 + assert(item != nullptr); + + // 检查返回的 unique_ptr 指向的值是否为默认构造的 int + assert(*item == int()); +} + +TEST(LogTest, TestHybridMgmt) +{ + TestSafeQueue(); + TestSingletonQueue(); + TestPutDirty(); + TestGetOneFromEmptyQueue(); +} \ No newline at end of file -- Gitee From bbf22b8954fdeb68c6177688d649580ccdab2c24 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 11 Dec 2023 16:16:48 +0800 Subject: [PATCH 511/551] Match-id-0478a5c7d51dbcace0ece44792bf8c3a4e58e674 --- mx_rec/constants/constants.py | 1 - mx_rec/saver/saver.py | 5 +---- mx_rec/util/global_env_conf.py | 3 --- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 07d2aa7d..92dab5c3 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -97,7 +97,6 @@ class BaseEnum(Enum): class EnvOption(Enum): MXREC_LOG_LEVEL = "MXREC_LOG_LEVEL" - SAVE_EASY = "SAVE_EASY" RANK_TABLE_FILE = "RANK_TABLE_FILE" ASCEND_VISIBLE_DEVICES = "ASCEND_VISIBLE_DEVICES" CM_CHIEF_DEVICE = "CM_CHIEF_DEVICE" diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index 361320f6..f53f198b 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -65,8 +65,6 @@ class Saver(object): self.save_op_dict = defaultdict(dict) self.restore_fetch_list = [] self.placeholder_dict = defaultdict(dict) - # save_easy_mode : only save the embedding and key data of sparse tables - self.save_easy_mode = (global_env.save_easy == Flag.TRUE.value) self._last_checkponts = [] self.build() @@ -90,8 +88,7 @@ class Saver(object): @performance("Save") def save(self, sess, 
save_path="model", global_step=None): """ - Save sparse tables. For local save, both save_easy mode and normal mode is supported. - For easy_save mode, checkpoint is saved in under format: + Save sparse tables. checkpoint is saved in under format: ./rank_id/HashTable/HBM/embed_table_name/key/xxx.data ./rank_id/HashTable/HBM/embed_table_name/key/xxx.attribute ./rank_id/HashTable/HBM/embed_table_name/embedding/xxx.data diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py index 0ba5bdde..5f4f2a48 100644 --- a/mx_rec/util/global_env_conf.py +++ b/mx_rec/util/global_env_conf.py @@ -16,7 +16,6 @@ from mx_rec.validator.validator import para_checker_decorator, OptionValidator, @dataclass class RecEnv: mxrec_log_level: str - save_easy: str rank_table_file: str ascend_visible_devices: str cm_chief_device: str @@ -43,7 +42,6 @@ def get_global_env_conf() -> RecEnv: """ rec_env = RecEnv( mxrec_log_level=os.getenv(EnvOption.MXREC_LOG_LEVEL.value, RecPyLogLevel.INFO.value), - save_easy=os.getenv(EnvOption.SAVE_EASY.value, Flag.FALSE.value), rank_table_file=os.getenv(EnvOption.RANK_TABLE_FILE.value, EMPTY_STR), ascend_visible_devices=os.getenv(EnvOption.ASCEND_VISIBLE_DEVICES.value), cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value), @@ -69,7 +67,6 @@ def get_global_env_conf() -> RecEnv: @para_checker_decorator(check_option_list=[ ("mxrec_log_level", OptionValidator, {"options": [i.value for i in list(RecPyLogLevel)]}), - ("save_easy", OptionValidator, {"options": [i.value for i in list(Flag)]}), ("rank_table_file", DirectoryValidator, {}, ["check_exists_if_not_empty"]), ("tf_device", OptionValidator, {"options": [i.value for i in list(TFDevice)]}), ("apply_gradients_strategy", OptionValidator, {"options": [i.value for i in list(ApplyGradientsStrategy)]}), -- Gitee From 76eb29ed2af18c47185aa5be4636055908d86909 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 13 Dec 2023 15:46:20 +0800 Subject: [PATCH 512/551] Match-id-c31029373dfa740fff05aed73d444eeedd5135b1 --- mx_rec/core/asc/feature_spec.py | 10 +- tests/mx_rec/core/generator_dataset.py | 68 +++++ tests/mx_rec/core/mock_class.py | 86 ++++++ tests/mx_rec/core/test_embedding.py | 112 -------- tests/mx_rec/core/test_feature_spec.py | 348 +++++++++++++++++++++++++ 5 files changed, 509 insertions(+), 115 deletions(-) create mode 100644 tests/mx_rec/core/generator_dataset.py create mode 100644 tests/mx_rec/core/mock_class.py delete mode 100644 tests/mx_rec/core/test_embedding.py create mode 100644 tests/mx_rec/core/test_feature_spec.py diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 1eb0c35a..52f2be23 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -42,14 +42,18 @@ class FeatureSpec: ]) def __init__(self, name: str, table_name: Optional[str] = None, - index_key: Optional[str] = None, + index_key: Union[None, int, str] = None, access_threshold: Optional[int] = None, eviction_threshold: Optional[int] = None, is_timestamp: Optional[bool] = None, - batch_size: Optional[int] = None, faae_coefficient: int = 1): + batch_size: Optional[int] = None, faae_coefficient: Optional[int] = 1): feature_spec_global_id.increase() spec_name = name + f"_{feature_spec_global_id}" self.name = spec_name - self._index_key = index_key if index_key else name + # 防止当index_key=0时,判断条件被误判为False + if isinstance(index_key, int): + self._index_key = index_key + else: + self._index_key = index_key if index_key else name self._table_name = fix_invalid_table_name(table_name if 
table_name else name)
         self._feat_cnt = None
         self._access_threshold = access_threshold
diff --git a/tests/mx_rec/core/generator_dataset.py b/tests/mx_rec/core/generator_dataset.py
new file mode 100644
index 00000000..cf53f471
--- /dev/null
+++ b/tests/mx_rec/core/generator_dataset.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+from typing import Callable
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
+
+
+class Config:
+    """
+    Configuration class.
+    """
+
+    def __init__(self, batch_size=32, batch_number=100):
+        self.batch_size = batch_size
+        self.batch_number = batch_number
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.float32
+        self.item_range = 16
+        self.item_feat_cnt = 8
+
+
+def __get_data_generator(cfg: Config) -> Callable:
+    """
+    Build the data generator.
+
+    Args:
+        cfg: Config instance
+
+    Returns: the data generator fn
+
+    """
+
+    def data_generator():
+        i = 0
+        while i < cfg.batch_number:
+            item_ids = np.random.randint(0, cfg.item_range, (cfg.batch_size, cfg.item_feat_cnt))
+            label_0 = np.random.randint(0, 2, (cfg.batch_size,))
+
+            yield {"item_ids": item_ids,
+                   "label_0": label_0}
+            i += 1
+
+    return data_generator
+
+
+def generate_dataset(cfg: Config) -> DatasetV1Adapter:
+    """
+    Build the dataset.
+
+    Args:
+        cfg: Config instance
+
+    Returns: dataset
+
+    """
+
+    dataset = tf.compat.v1.data.Dataset.from_generator(
+        generator=__get_data_generator(cfg),
+        output_types={"item_ids": cfg.key_type,
+                      "label_0": cfg.label_type},
+        output_shapes={"item_ids": tf.TensorShape([cfg.batch_size, cfg.item_feat_cnt]),
+                       "label_0": tf.TensorShape([cfg.batch_size])})
+    return dataset
diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py
new file mode 100644
index 00000000..5fe33363
--- /dev/null
+++ b/tests/mx_rec/core/mock_class.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
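For quick orientation, the generator_dataset.py helper added above is consumed like any TF1 dataset. A minimal sketch, assuming the TF1-compat graph mode used throughout these tests; the session setup and assertions are illustrative, not part of the patch:

import tensorflow as tf

from tests.mx_rec.core.generator_dataset import Config, generate_dataset

cfg = Config()  # defaults: batch_size=32, batch_number=100
dataset = generate_dataset(cfg)
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

with tf.compat.v1.Session() as sess:
    sess.run(iterator.initializer)
    first_batch = sess.run(batch)
    # Shapes follow the output_shapes declared in generate_dataset.
    assert first_batch["item_ids"].shape == (cfg.batch_size, cfg.item_feat_cnt)
    assert first_batch["label_0"].shape == (cfg.batch_size,)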
+
+import tensorflow as tf
+
+
+class MockSparseEmbedding:
+    """
+    The real SparseEmbedding pulls in many interfaces; MockSparseEmbedding avoids having to mock each of them.
+    """
+
+    def __init__(self, table_name="test_table", slice_device_vocabulary_size=10, embedding_size=5, init_param=1.,
+                 emb_initializer=tf.zeros_initializer()):
+        self.table_name = table_name
+        self.slice_device_vocabulary_size = slice_device_vocabulary_size
+        self.embedding_size = tf.TensorShape([embedding_size])
+        self.init_param = init_param
+        self.emb_initializer = emb_initializer
+        self.variable = tf.compat.v1.get_variable(table_name,
+                                                  shape=[slice_device_vocabulary_size, embedding_size],
+                                                  trainable=False, initializer=tf.ones_initializer())
+
+
+class MockHostPipeLineOps:
+    """
+    Mocks host_pipeline_ops; returns the static/dynamic readEmbKey variants.
+    """
+
+    def __init__(self):
+        def _mock_read_emb_key_v2_fn(concat_tensor, **kwargs):
+            return 0
+
+        def _mock_read_emb_key_v2_dynamic_fn(concat_tensor, tensorshape_split_list, **kwargs):
+            return 1
+
+        self.read_emb_key_v2 = _mock_read_emb_key_v2_fn
+        self.read_emb_key_v2_dynamic = _mock_read_emb_key_v2_dynamic_fn
+
+
+class MockHcclOps:
+    """
+    Mocks hccl_ops.
+    """
+
+    def __init__(self, shape=None):
+        def _mock_all_to_all_v(send_data, send_counts, send_displacements, recv_counts, recv_displacements):
+            if shape is None:
+                return tf.constant(1, dtype=tf.int64, name="all_to_all_v")
+            return tf.ones(shape, dtype=tf.int64, name="all_to_all_v")
+
+        def _mock_all_to_all_v_c(send_data, send_count_matrix, rank):
+            if shape is None:
+                return tf.constant(1, dtype=tf.int64, name="all_to_all_v_c")
+            return tf.ones(shape, dtype=tf.int64, name="all_to_all_v_c")
+
+        self.all_to_all_v = _mock_all_to_all_v
+        self.all_to_all_v_c = _mock_all_to_all_v_c
+
+
+class MockOptimizer:
+    """
+    Mocks the optimizer.
+    """
+
+    def __init__(self):
+        def _mock_insert_slot(slot, named_slot_key, slot_name):
+            return "mock_insert_slot"
+
+        self.insert_slot = _mock_insert_slot
+
+
+class MockAscManager:
+    """
+    Mocks get_asc_manager().
+    """
+
+    def __init__(self):
+        def _mock_get_table_size(self):
+            return 0
+
+        def _mock_get_table_capacity(self):
+            return 1
+
+        self.get_table_size = _mock_get_table_size
+        self.get_table_capacity = _mock_get_table_capacity
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
deleted file mode 100644
index cea7c35e..00000000
--- a/tests/mx_rec/core/test_embedding.py
+++ /dev/null
@@ -1,112 +0,0 @@
-#!/usr/bin/env python3
-# coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
-
-
-import unittest
-import tensorflow as tf
-from mx_rec.core.embedding import set_specific_value_for_non_valid_key
-
-if tf.__version__.startswith("1"):
-    tf.enable_eager_execution()
-else:
-    tf.compat.v1.enable_eager_execution()
-
-
-class TestSetSpecificValueForNonValidKey(unittest.TestCase):
-    """
-    Test Suite for set_specific_value_for_non_valid_key.
- """ - - def setUp(self): - """ - 准备步骤 - :return:无 - """ - super().setUp() - - def tearDown(self): - """ - 销毁步骤 - :return: 无 - """ - super().tearDown() - - def test_given_admit_if_turned_off_then_return_raw_embedding(self): - # given - id_offsets = tf.constant([0, 1, 2, 3]) - embeddings = tf.ones(shape=(4, 10), dtype=tf.float32) - access_threshold = None - serving_default_value = tf.ones(shape=(4, 10), dtype=tf.float32) * 2 - - # when - modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, - access_threshold, serving_default_value) - - # then - result = bool(tf.reduce_all(tf.equal(embeddings, modified_emb))) - self.assertTrue(result) - - def test_given_no_default_value_and_all_valid_key_then_return_raw_embedding(self): - # given - id_offsets = tf.constant([0, 1, 2, 3]) - embeddings = tf.ones(shape=(4, 10), dtype=tf.float32) - access_threshold = 1 - serving_default_value = None - - # when - modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, - access_threshold, serving_default_value) - - # then - result = bool(tf.reduce_all(tf.equal(embeddings, modified_emb))) - self.assertTrue(result) - - def test_given_no_default_value_and_invalid_key_then_emb_of_invalid_key_set_to_zero(self): - # given - id_offsets = tf.constant([-1, 1, 2, 3]) - embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) - access_threshold = 1 - serving_default_value = None - - # when - modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, - access_threshold, serving_default_value) - - # then - result = modified_emb.numpy().tolist()[0] - self.assertEqual([0, 0], result) - - def test_given_default_value_and_invalid_key_then_emb_of_invalid_key_set_to_default_value(self): - # given - id_offsets = tf.constant([-1, 1, 2, 3]) - embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) - access_threshold = 1 - serving_default_value = tf.ones(shape=(4, 2), dtype=tf.float32) * 2 - - # when - modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, - access_threshold, serving_default_value) - - # then - result = modified_emb.numpy().tolist()[0] - self.assertEqual([2, 2], result) - - def test_given_default_value_and_with_all_invalid_key_then_emb_of_invalid_key_set_to_default_value(self): - # given - id_offsets = tf.constant([-1, -1, -1, -1]) - embeddings = tf.ones(shape=(4, 2), dtype=tf.float32) - access_threshold = 1 - serving_default_value = tf.ones(shape=(4, 2), dtype=tf.float32) * 2 - - # when - modified_emb = set_specific_value_for_non_valid_key(id_offsets, embeddings, - access_threshold, serving_default_value) - - # then - result = bool(tf.reduce_all(tf.equal(serving_default_value, modified_emb))) - self.assertTrue(result) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/mx_rec/core/test_feature_spec.py b/tests/mx_rec/core/test_feature_spec.py new file mode 100644 index 00000000..7adad830 --- /dev/null +++ b/tests/mx_rec/core/test_feature_spec.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + +import unittest +from unittest import mock +from functools import reduce + +import tensorflow as tf + +from mx_rec.core.asc.feature_spec import FeatureSpec + + +class TestFeatureSpecClass(unittest.TestCase): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec'. 
+ """ + + def setUp(self): + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + FeatureSpec.instance_count_train = 0 + FeatureSpec.instance_count_eval = 0 + FeatureSpec.use_timestamp_train = False + FeatureSpec.use_timestamp_eval = False + + def test_init_case1(self): + """ + case1: 初始化实例成功 + """ + self.assertIsInstance(FeatureSpec("case1"), FeatureSpec) + + def test_init_case2(self): + """ + case2: 当access_threshold为None,eviction_threshold不为None时,初始化失败并抛出异常 + """ + with self.assertRaises(ValueError): + FeatureSpec("case2", eviction_threshold=20) + + +class TestIncludeTimestampFuncOfFeatureSpecClass(TestFeatureSpecClass): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.include_timestamp'. + """ + + def setUp(self): + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + super().setUp() + + def test_include_timestamp_case1(self): + """ + case1: is_training为False + """ + case1_feature_spec = FeatureSpec("case1") + is_training = False + case1_feature_spec.include_timestamp(is_training) + self.assertTrue(case1_feature_spec.use_timestamp_eval) + + def test_include_timestamp_case2(self): + """ + case2: is_training为True + """ + case2_feature_spec = FeatureSpec("case2") + is_training = True + case2_feature_spec.include_timestamp(is_training) + self.assertTrue(case2_feature_spec.use_timestamp_train) + + def test_include_timestamp_case3(self): + """ + case3: is_training为True并调用2次include_timestamp,抛出EnvironmentError异常 + """ + case3_feature_spec = FeatureSpec("case3") + is_training = True + case3_feature_spec.include_timestamp(is_training) + with self.assertRaises(EnvironmentError): + case3_feature_spec.include_timestamp(is_training) + + +class TestUseTimestampFuncOfFeatureSpecClass(TestFeatureSpecClass): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.use_timestamp'. + """ + + def setUp(self): + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + super().setUp() + + def test_use_timestamp_case1(self): + """ + case1: is_training为False + """ + + case1_feature_spec = FeatureSpec("case1") + is_training = False + self.assertEqual(case1_feature_spec.use_timestamp(is_training), FeatureSpec.use_timestamp_eval) + + def test_use_timestamp_case2(self): + """ + case2: is_training为True + """ + + case2_feature_spec = FeatureSpec("case2") + is_training = True + self.assertEqual(case2_feature_spec.use_timestamp(is_training), FeatureSpec.use_timestamp_train) + + +class TestSetFeatPosFuncOfFeatureSpecClass(TestFeatureSpecClass): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.set_feat_pos'. 
+ """ + + def setUp(self): + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + super().setUp() + + def test_set_feat_pos_case1(self): + """ + case1: is_training为False + """ + + case1_feature_spec = FeatureSpec("case1") + is_training = False + set_times = 5 + for _ in range(set_times): + case1_feature_spec.set_feat_pos(is_training) + # 因为feat_pos_eval/train = instance_count_eval/train ++,因此feat_pos_eval会少1 + self.assertEqual(case1_feature_spec.feat_pos_eval, set_times - 1) + self.assertEqual(case1_feature_spec.instance_count_eval, set_times) + + def test_set_feat_pos_case2(self): + """ + case2: is_training为True + """ + + case2_feature_spec = FeatureSpec("case2") + is_training = True + set_times = 5 + for _ in range(set_times): + case2_feature_spec.set_feat_pos(is_training) + # 因为feat_pos_eval/train = instance_count_eval/train ++,因此feat_pos_eval会少1 + self.assertEqual(case2_feature_spec.feat_pos_train, set_times - 1) + self.assertEqual(case2_feature_spec.instance_count_train, set_times) + + +class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.insert_pipeline_mode'. + """ + + def setUp(self): + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + super().setUp() + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None)) + def test_insert_pipeline_mode_case1(self): + """ + case1: mode为非bool类型,抛出异常 + """ + + case1_feature_spec = FeatureSpec("case1") + mode = "xxx" + with self.assertRaises(TypeError): + case1_feature_spec.insert_pipeline_mode(mode) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None)) + def test_insert_pipeline_mode_case2(self): + """ + case2: mode为False + """ + + case2_feature_spec = FeatureSpec("case2") + mode = False + case2_feature_spec.insert_pipeline_mode(mode) + self.assertSetEqual(case2_feature_spec.pipeline_mode, {False}) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None)) + def test_insert_pipeline_mode_case3(self): + """ + case3: mode为True + """ + + case3_feature_spec = FeatureSpec("case3") + mode = True + case3_feature_spec.insert_pipeline_mode(mode) + self.assertSetEqual(case3_feature_spec.pipeline_mode, {True}) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None)) + def test_insert_pipeline_mode_case4(self): + """ + case4: mode为True,在已经设置过一次True的情况下,设置第二次后无报错 + """ + + case4_feature_spec = FeatureSpec("case4") + mode = True + case4_feature_spec.insert_pipeline_mode(mode) + case4_feature_spec.insert_pipeline_mode(mode) + self.assertSetEqual(case4_feature_spec.pipeline_mode, {True}) + + +class TestSetFeatAttributeFuncOfFeatureSpecClass(TestFeatureSpecClass): + """ + Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.set_feat_attribute'. 
+ """ + + def setUp(self): + self.is_training = True + # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 + super().setUp() + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=True), + insert_feature_spec=mock.MagicMock(return_value=None)) + def test_set_feat_attribute_case1(self): + """ + case1: 未初始化initialized成员,静态shape,tensor的rank为0,抛出异常 + """ + case1_tensor = tf.ones([], tf.int32) + case1_feature_spec = FeatureSpec("case1") + with self.assertRaises(ValueError): + case1_feature_spec.set_feat_attribute(case1_tensor, self.is_training) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=True), + insert_feature_spec=mock.MagicMock(return_value=None)) + def test_set_feat_attribute_case2(self): + """ + case2: 未初始化initialized成员,静态shape,tensor的rank等于1 + """ + case2_tensor = tf.ones([32], tf.int32) + case2_feature_spec = FeatureSpec("case2") + case2_feature_spec.set_feat_attribute(case2_tensor, self.is_training) + self.assertEqual(case2_feature_spec.feat_cnt, 1) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=True), + insert_feature_spec=mock.MagicMock(return_value=None)) + def test_set_feat_attribute_case3(self): + """ + case3: 未初始化initialized成员,静态shape,tensor的rank大于1 + """ + test_tensor_shape = [32, 16, 2] + test_tensor = tf.ones(test_tensor_shape, tf.int32) + case3_feature_spec = FeatureSpec("case3") + case3_feature_spec.set_feat_attribute(test_tensor, self.is_training) + reduce_shape = reduce(lambda x, y: x * y, test_tensor_shape[1:]) + self.assertTrue(case3_feature_spec.initialized) + self.assertListEqual(case3_feature_spec.dims, test_tensor_shape) + self.assertEqual(case3_feature_spec.rank, len(test_tensor_shape)) + self.assertEqual(case3_feature_spec.batch_size, test_tensor_shape[0]) + self.assertEqual(case3_feature_spec.feat_cnt, reduce_shape) + self.assertEqual(case3_feature_spec.split, test_tensor_shape[0] * reduce_shape) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=False), + insert_feature_spec=mock.MagicMock(return_value=None)) + def test_set_feat_attribute_case4(self): + """ + case4: 未初始化initialized成员,动态shape,tensor的rank大于1 + """ + test_tensor_shape = [32, 3] + test_tensor = tf.ones(test_tensor_shape, tf.int32) + case4_feature_spec = FeatureSpec("case4") + case4_feature_spec.set_feat_attribute(test_tensor, self.is_training) + reduce_shape = reduce(lambda x, y: x * y, test_tensor_shape) + with tf.Session() as sess: + self.assertTrue(case4_feature_spec.initialized) + self.assertEqual(sess.run(case4_feature_spec.dims), reduce_shape) + self.assertEqual(case4_feature_spec.rank, 1) + self.assertEqual(sess.run(case4_feature_spec.split), reduce_shape) + self.assertEqual(sess.run(case4_feature_spec.batch_size), reduce_shape) + self.assertEqual(case4_feature_spec.feat_cnt, 1) + + @mock.patch.multiple("mx_rec.core.asc.feature_spec", + insert_training_mode_channel_id=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=True), + insert_feature_spec=mock.MagicMock(return_value=None)) + def test_set_feat_attribute_case5(self): + """ + case5: 静态shape,tensor的rank大于1,再次用不同tensor初始化initialized成员,抛出异常 + """ + 
case5_feature_spec = FeatureSpec("case5") + case5_feature_spec.set_feat_attribute(tf.ones([32, 3], tf.int32), self.is_training) + # 再次初始化 + with self.assertRaises(ValueError): + case5_feature_spec.set_feat_attribute(tf.ones([64, 3], tf.int32), self.is_training) + + +class TestGetFeatureSpecFunc(unittest.TestCase): + """ + Test for 'mx_rec.core.asc.feature_spec.get_feature_spec'. + """ + + def test_get_feature_spec_case1(self): + """ + case1: access_and_evict_config为None + """ + + from mx_rec.core.asc.feature_spec import get_feature_spec + + access_and_evict_config = None + case1_results = get_feature_spec("fake_name", access_and_evict_config) + self.assertIsInstance(case1_results, FeatureSpec) + self.assertEqual(case1_results.access_threshold, None) + self.assertEqual(case1_results.eviction_threshold, None) + self.assertEqual(case1_results.faae_coefficient, None) + + def test_get_feature_spec_case2(self): + """ + case2: access_and_evict_config不为None + """ + + from mx_rec.core.asc.feature_spec import get_feature_spec + + access_and_evict_config = dict(access_threshold=10, eviction_threshold=20, faae_coefficient=2) + case2_results = get_feature_spec("fake_name", access_and_evict_config) + self.assertIsInstance(case2_results, FeatureSpec) + self.assertEqual(case2_results.access_threshold, access_and_evict_config.get("access_threshold")) + self.assertEqual(case2_results.eviction_threshold, access_and_evict_config.get("eviction_threshold")) + self.assertEqual(case2_results.faae_coefficient, access_and_evict_config.get("faae_coefficient")) + + +class TestSetTemporaryFeatureSpecAttributeFunc(unittest.TestCase): + """ + Test for 'mx_rec.core.asc.feature_spec.set_temporary_feature_spec_attribute'. + """ + + def test_set_temporary_feature_spec_attribute(self): + from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute + + user_ids_feature_spec = FeatureSpec("user_ids") + total_feature_count = 200 + set_temporary_feature_spec_attribute(user_ids_feature_spec, total_feature_count) + self.assertEqual(user_ids_feature_spec.batch_size, total_feature_count) + self.assertEqual(user_ids_feature_spec.feat_cnt, 1) + self.assertListEqual(user_ids_feature_spec.dims, [total_feature_count, 1]) + self.assertTrue(user_ids_feature_spec.initialized) + self.assertSetEqual(user_ids_feature_spec.pipeline_mode, {True, False}) + + +if __name__ == '__main__': + unittest.main() -- Gitee From 23c050dfed1b32be684792e938c795579dba80c7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 13 Dec 2023 16:20:31 +0800 Subject: [PATCH 513/551] Match-id-34595b8ef77b0c925d3b70cfa21a07e9784181ff --- src/test_ut.sh | 2 +- src/tests/emb_hashmap/emb_hashmap_test.cpp | 66 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/test_ut.sh b/src/test_ut.sh index 98cdcf39..93c60f25 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -81,7 +81,7 @@ cd "$(dirname "${PWD}")" COVERAGE_FILE=coverage.info REPORT_FOLDER=coverage_report lcov --rc lcov_branch_coverage=1 -c -d build -o "${COVERAGE_FILE}"_tmp -lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '/usr1/mxRec/src/core/emb_table*' '7/ext*' '*7/bits*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" +lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' 
'/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '7/ext*' '*7/bits*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" genhtml --rc genhtml_branch_coverage=1 "${COVERAGE_FILE}" -o "${REPORT_FOLDER}" [ -d "${COVERAGE_FILE}"_tmp ] && rm -rf "${COVERAGE_FILE}"_tmp [ -d "${COVERAGE_FILE}" ] && rm -rf "${COVERAGE_FILE}" diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index f55b3b21..4c08e315 100644 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp @@ -9,7 +9,7 @@ #include #include "emb_hashmap/emb_hashmap.h" - +#include "hybrid_mgmt/hybrid_mgmt_block.h" #include "ssd_cache/cache_manager.h" #include "utils/common.h" @@ -108,4 +108,68 @@ TEST(EmbHashMap, TestFindOffset) ASSERT_EQ(ddrKeyMap.Get(1), NEGATIVE_INT_1); Logger::SetLevel(logLevelTemp); // 恢复日志级别 LOG_INFO("test TestFindOffset end."); +} + +TEST(EmbHashMap, TESTGetHashMaps) +{ + string embTableName = "table1"; + EmbHashMap hostHashMaps; + RankInfo rankInfo; + auto embInfo = GetEmbInfoList(); + hostHashMaps.Init(rankInfo, embInfo, false); + CacheManager cacheManager; + cacheManager.Init(nullptr, embInfo); + hostHashMaps.isSSDEnabled = true; + hostHashMaps.cacheManager = &cacheManager; + int channelId = 0; + size_t currentBatchId = 0; + size_t keepBatchId = 0; + int opTimes = 0; + + vector keys = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); + RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); + auto testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + hostHashMaps.embHashMaps.at(embTableName).maxOffsetOld = testEmbHashMap.maxOffset; + // 增加10个key, offset长度变为10 + ASSERT_EQ(testEmbHashMap.maxOffset, 10); + + keys = {11, 12, 13, 14, 15, 16, 17, 18, 19, 20}; + hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); + RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); + testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + // 再增加10个key,offset变为20 + ASSERT_EQ(testEmbHashMap.maxOffset, 20); + + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + hybridMgmtBlock->lastRunChannelId = channelId; + hybridMgmtBlock->hybridBatchId[0] = 1; + testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + // 回退一步,offset变回10 + ASSERT_EQ(testEmbHashMap.maxOffset, 10); + + hybridMgmtBlock->hybridBatchId[0] = 2; + // 回退2步,抛出异常 + ASSERT_THROW(hostHashMaps.GetHashMaps(), HybridMgmtBlockingException); + hybridMgmtBlock->hybridBatchId[0] = 0; + + keys = {10, 11}; + hostHashMaps.EvictDeleteEmb(embTableName, keys); + testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + // 淘汰1个hbm key和1个ddr key,表中无法查找到该key + ASSERT_EQ(testEmbHashMap.hostHashMap.find(10), testEmbHashMap.hostHashMap.end()); + ASSERT_EQ(testEmbHashMap.hostHashMap.find(11), testEmbHashMap.hostHashMap.end()); + ASSERT_EQ(cacheManager.excludeDDRKeyCountMap[embTableName][11], 0); + ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].Get(10), -1); + + keys = {1, 2}; + hostHashMaps.FindOffset(embTableName, keys, currentBatchId++, keepBatchId++, channelId); + RefreshSwapFreqInfoAndPrint(hostHashMaps, embTableName, opTimes++); + testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + // 从ddr中换回2个key到hbm,交换变量长度为2 + 
ASSERT_EQ(testEmbHashMap.ddr2HbmKeys.size(), 2); + hostHashMaps.ClearLookupAndSwapOffset(hostHashMaps.embHashMaps.at(embTableName)); + testEmbHashMap = hostHashMaps.GetHashMaps().at(embTableName); + // 清理后,交换变量长度为0 + ASSERT_EQ(testEmbHashMap.ddr2HbmKeys.size(), 0); } \ No newline at end of file -- Gitee From ee7f09398c21b57aeeff76cb216756a0745ef926 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 13 Dec 2023 20:16:52 +0800 Subject: [PATCH 514/551] Match-id-2ddc77afc761488c4ec483dbb346983ec91695c3 --- tests/mx_rec/graph/__init__.py | 1 + tests/mx_rec/graph/mock_dataset.py | 30 ++ tests/mx_rec/graph/test_merge_lookup.py | 79 +++++ tests/mx_rec/graph/test_modifier.py | 451 ++++++++++++++++++++++++ tests/mx_rec/graph/test_utils.py | 157 +++++++++ 5 files changed, 718 insertions(+) create mode 100644 tests/mx_rec/graph/__init__.py create mode 100644 tests/mx_rec/graph/mock_dataset.py create mode 100644 tests/mx_rec/graph/test_merge_lookup.py create mode 100644 tests/mx_rec/graph/test_modifier.py create mode 100644 tests/mx_rec/graph/test_utils.py diff --git a/tests/mx_rec/graph/__init__.py b/tests/mx_rec/graph/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/mx_rec/graph/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/mx_rec/graph/mock_dataset.py b/tests/mx_rec/graph/mock_dataset.py new file mode 100644 index 00000000..8148c397 --- /dev/null +++ b/tests/mx_rec/graph/mock_dataset.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + +from typing import Dict, Iterable + +import numpy as np +import tensorflow as tf + + +def gen_mock_dataset(batch_num: int = 100, batch_size: int = 4096) -> tf.compat.v1.data.Dataset: + def data_generator() -> Iterable[Dict[str, np.ndarray]]: + i = 0 + while i < batch_num: + mock_ids = np.random.randint(low=0, high=100, size=(batch_size, 8)) + mock_labels = np.random.randint(low=0, high=100, size=(batch_size, 1)) + mock_timestamp = np.random.randint(low=0, high=100, size=(batch_size, 1)) + yield {"mock_ids": mock_ids, "mock_labels": mock_labels, "mock_timestamp": mock_timestamp} + i += 1 + + dataset = tf.compat.v1.data.Dataset.from_generator( + generator=data_generator, + output_types={"mock_ids": tf.int64, "mock_labels": tf.int32, "mock_timestamp": tf.int32}, + output_shapes={ + "mock_ids": tf.TensorShape([batch_size, 8]), + "mock_labels": tf.TensorShape([batch_size, 1]), + "mock_timestamp": tf.TensorShape([batch_size, 1]), + }, + ) + return dataset diff --git a/tests/mx_rec/graph/test_merge_lookup.py b/tests/mx_rec/graph/test_merge_lookup.py new file mode 100644 index 00000000..1bf4311f --- /dev/null +++ b/tests/mx_rec/graph/test_merge_lookup.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
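The emb_hashmap_test.cpp changes in patch 513 above pin down FindOffset's bookkeeping: each previously unseen key gets the next consecutive offset, maxOffsetOld snapshots the previous state, and GetHashMaps can roll back by exactly one pipeline step. A minimal Python model of that bookkeeping, purely illustrative (the class and method names here are invented and do not mirror the real EmbHashMap API):

class TinyOffsetMap:
    """Toy model of consecutive offset assignment with a one-step snapshot."""

    def __init__(self):
        self.key_to_offset = {}
        self.max_offset = 0
        self.max_offset_old = 0  # snapshot taken before each lookup batch

    def find_offset(self, keys):
        self.max_offset_old = self.max_offset
        for key in keys:
            if key not in self.key_to_offset:
                self.key_to_offset[key] = self.max_offset
                self.max_offset += 1
        return [self.key_to_offset[k] for k in keys]

m = TinyOffsetMap()
m.find_offset(range(1, 11))
assert m.max_offset == 10      # 10 new keys -> offsets 0..9
m.find_offset(range(11, 21))
assert m.max_offset == 20      # 10 more keys -> offsets 10..19
assert m.max_offset_old == 10  # rolling back one step would restore 10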
+ +import unittest +from typing import Union +from unittest import TestCase +from unittest.mock import Mock, patch + +import tensorflow as tf +from tensorflow import Tensor +import mx_rec.graph.merge_lookup as merge_lookup +from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr + + +def mock_get_anchor_attribute(anchor: Tensor, attr: ASCAnchorAttr) -> Union[bool, Mock]: + if attr == ASCAnchorAttr.IS_TRAINING: + return True + if attr == ASCAnchorAttr.IS_GRAD: + return True + if attr == ASCAnchorAttr.TABLE_INSTANCE: + mock_table_instance = Mock() + mock_table_instance.table_name = "mock_table_name" + mock_table_instance.lookup_name_dict = {True: ["lookup_1", "lookup_2"]} + mock_table_instance.send_count = 4096 * 8 + return mock_table_instance + if attr == ASCAnchorAttr.FEATURE_SPEC: + mock_feature_spec = Mock() + mock_feature_spec.name = "mock_feature_spec_name" + return mock_feature_spec + + raise ValueError(f"Unsupported param 'attr' for enum class 'ASCAnchorAttr': attr={attr}.") + + +class DoMergeLookupTest(TestCase): + def tearDown(self): + tf.compat.v1.reset_default_graph() + + @patch.multiple( + "mx_rec.graph.merge_lookup", + get_modify_graph=Mock(return_value=True), + get_merged_multi_lookup=Mock(return_value=False), + get_use_static=Mock(return_value=False), + replace_anchor_vec=Mock(), + insert_merged_multi_lookup=Mock(), + ) + @patch.multiple("mx_rec.graph.merge_lookup.SparseEmbedding", get_anchor_attribute=mock_get_anchor_attribute) + def test_ok(self): + mock_cutting_point = tf.identity(tf.zeros(shape=(4096, 8))) + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) + merge_lookup.do_merge_lookup() + + @patch.multiple( + "mx_rec.graph.merge_lookup", + get_modify_graph=Mock(return_value=False), + ) + def test_ok_disable_modify_graph(self): + merge_lookup.do_merge_lookup() + + @patch.multiple( + "mx_rec.graph.merge_lookup", + get_modify_graph=Mock(return_value=True), + get_merged_multi_lookup=Mock(return_value=True), + ) + def test_ok_already_exec_merged_lookup(self): + merge_lookup.do_merge_lookup() + + @patch.multiple( + "mx_rec.graph.merge_lookup", + get_modify_graph=Mock(return_value=True), + get_merged_multi_lookup=Mock(return_value=False), + ) + def test_err_empty_cutting_point_list(self): + with self.assertRaises(RuntimeError): + merge_lookup.do_merge_lookup() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mx_rec/graph/test_modifier.py b/tests/mx_rec/graph/test_modifier.py new file mode 100644 index 00000000..14d1633e --- /dev/null +++ b/tests/mx_rec/graph/test_modifier.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
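The merge-lookup tests above lean on unittest.mock.patch.multiple to stub several module-level helpers with a single decorator. A self-contained sketch of the same pattern; the module and helper names here are invented so the example runs on its own:

import sys
import types
import unittest
from unittest import mock

# A stand-in module with two helpers, so the example is self-contained.
fake_mod = types.ModuleType("fake_mod")
fake_mod.get_flag = lambda: False
fake_mod.run_step = lambda: "real"
sys.modules["fake_mod"] = fake_mod


class PatchMultipleDemo(unittest.TestCase):
    @mock.patch.multiple("fake_mod",
                         get_flag=mock.Mock(return_value=True),
                         run_step=mock.Mock(return_value="stubbed"))
    def test_both_helpers_are_stubbed(self):
        # Inside the decorated test, both attributes are replaced;
        # they revert automatically when the test returns.
        self.assertTrue(sys.modules["fake_mod"].get_flag())
        self.assertEqual(sys.modules["fake_mod"].run_step(), "stubbed")


if __name__ == "__main__":
    unittest.main()

Because explicit replacement objects are supplied (rather than mock.DEFAULT), patch.multiple does not inject extra arguments into the test method, which is why the decorated tests above keep their plain signatures.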
+ +import os +import unittest +from collections import defaultdict +from unittest import TestCase +from unittest.mock import patch, Mock +from typing import Union, Callable + +import tensorflow as tf +from tensorflow import Tensor +from mx_rec.constants.constants import ( + ASCEND_CUTTING_POINT_INITIALIZER, + ASCEND_SPARSE_LOOKUP_ENTRANCE, + ASCEND_TIMESTAMP, + ASCAnchorAttr, +) +from mx_rec.graph.modifier import ( + GraphModifierHook, + find_make_iterator_op, + find_target_dataset_op, + find_target_instance_dataset, + generate_get_next_op_specs, + get_dataset_op, + get_input_index_list, + get_passing_tensor_list, + get_preprocessing_map_func, + get_src_dataset, + get_tgt_dataset, + get_timestamp_index, + modify_graph_for_asc, +) + +from tests.mx_rec.graph.mock_dataset import gen_mock_dataset + + +def _gen_mock_get_anchor_attribute(is_training: bool = True) -> Callable: + def mock_get_anchor_attribute(anchor: Tensor, attr: ASCAnchorAttr) -> Union[bool, Mock]: + if attr == ASCAnchorAttr.IS_TRAINING: + return is_training + if attr == ASCAnchorAttr.TABLE_INSTANCE: + mock_table_instance = Mock() + return mock_table_instance + if attr == ASCAnchorAttr.FEATURE_SPEC: + mock_feature_spec = Mock() + mock_feature_spec.name = "mock_feature_spec_name" + mock_feature_spec.table_name = "mock_table_name" + return mock_feature_spec + + raise ValueError(f"Unsupported param 'attr' for enum class 'ASCAnchorAttr': attr={attr}.") + + return mock_get_anchor_attribute + + +class GetPreprocessingMapFuncTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_err_none_names_and_indexes(self): + mock_graph_def = tf.compat.v1.GraphDef() + mock_input_names = [] + mock_output_names = [] + + with self.assertRaises(ValueError): + get_preprocessing_map_func(mock_graph_def, mock_input_names, mock_output_names) + + +class GetInputIndexListTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_err_no_matched_cutting_point(self): + mock_cutting_point_list = [tf.ones(shape=(4096, 8))] + mock_replace_ment_specs = {} + mock_mapping_name_list = [] + mock_base_count = 0 + + with self.assertRaises(ValueError): + get_input_index_list( + mock_cutting_point_list, mock_replace_ment_specs, mock_mapping_name_list, mock_base_count + ) + + +class FindMakeIteratorOpTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + found_iter_op = find_make_iterator_op(mock_ids) + self.assertEqual(found_iter_op.type, "MakeIterator") + + def test_err_no_tgt_dataset_op(self): + mock_ids = tf.zeros(shape=(4096, 8)) + with self.assertRaises(ValueError): + find_make_iterator_op(mock_ids) + + +class FindTargetDatasetOpTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_base_op = tf.identity(mock_ids).op + + found_tgt_dataset_op = find_target_dataset_op(base_ops=mock_base_op, op_type="IteratorGetNext") + self.assertEqual(found_tgt_dataset_op, mock_ids.op) + + def test_err_no_tgt_op_type(self): + mock_ids = tf.zeros(shape=(4096, 8)) + mock_base_op = mock_ids.op + with self.assertRaises(ValueError): + 
find_target_dataset_op(mock_base_op, "IteratorGetNext") + + +class GetDatasetOpTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + found_dataset_op = get_dataset_op(mock_get_next_op) + self.assertEqual(found_dataset_op.type, "OptimizeDataset") + + def test_err_invalid_op_type(self): + mock_get_next_op = tf.zeros(shape=(4096, 8)).op + with self.assertRaises(TypeError): + get_dataset_op(mock_get_next_op) + + +class GetPassingTensorList(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_tgt_op = mock_ids.op + mock_cutting_point = tf.identity(mock_ids) + mock_cutting_point_list = [mock_cutting_point] + + expected = { + "passing_tensor_list": [mock_ids], + "output_index_list": [0], + "sub_src_tensors": mock_cutting_point_list, + } + passing_tensor_list, output_index_list, sub_src_tensors = get_passing_tensor_list( + mock_cutting_point_list, mock_tgt_op + ) + self.assertEqual(passing_tensor_list, expected["passing_tensor_list"]) + self.assertEqual(output_index_list, expected["output_index_list"]) + self.assertEqual(sub_src_tensors, expected["sub_src_tensors"]) + + +class FindTargetInstanceDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_err_no_target_dataset_instance(self): + with self.assertRaises(LookupError): + find_target_instance_dataset(None) + + +class GenerateGetNextOpSpecsTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + @patch.multiple("mx_rec.graph.merge_lookup.SparseEmbedding", get_anchor_attribute=Mock(return_value=True)) + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_labels = mock_batch.get("mock_labels") + mock_cutting_point_list = [mock_ids, mock_labels] + + get_next_op = mock_ids.op + replacement_specs = defaultdict(dict) + passing_tensor_list = [mock_ids, mock_labels] + batch_tensor_index_list = [0, 1] + sub_cutting_point_list = [mock_ids, mock_labels] + sub_graph_def = tf.compat.v1.GraphDef() + input_name_list = [mock_ids.name, mock_labels.name] + output_name_list = [mock_ids.name, mock_labels.name] + is_training = True + + get_next_op_map = generate_get_next_op_specs(mock_cutting_point_list) + expected = defaultdict(dict) + expected[get_next_op] = { + "replacement_specs": replacement_specs, + "passing_tensor_list": passing_tensor_list, + "batch_tensor_index_list": batch_tensor_index_list, + "sub_cutting_point_list": sub_cutting_point_list, + "sub_graph_def": sub_graph_def, + "input_name_list": input_name_list, + "output_name_list": output_name_list, + "is_training": is_training, + } + self.assertEqual(get_next_op_map, expected) + + +class GetSrcDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok_one_shot(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_double_prefetch_dataset = mock_prefetch_dataset.prefetch(10) + mock_iterator = 
mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + src_dataset = get_src_dataset(mock_get_next_op, is_training=True) + self.assertEqual(src_dataset, mock_dataset) + + +class GetTgtDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + @patch.multiple( + "mx_rec.graph.modifier", + get_training_mode_channel_id=Mock(return_value=0), + get_asc_insert_func=Mock(return_value=lambda x, y: x), + ) + @patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_sub_cutting_point_list = [mock_ids] + mock_records = { + "sub_graph_def": tf.compat.v1.GraphDef(), + "input_name_list": [], + "output_name_list": [], + "batch_tensor_index_list": [], + } + + tgt_dataset = get_tgt_dataset(mock_dataset, mock_sub_cutting_point_list, mock_records) + new_iter = tgt_dataset.make_initializable_iterator() + new_batch = new_iter.get_next() + new_ids = new_batch.get("mock_ids") + with tf.compat.v1.Session() as sess: + sess.run(new_iter.initializer) + sess.run(new_ids) + + +class ModifyGraphForAscTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + @patch.multiple( + "mx_rec.graph.modifier", + get_training_mode_channel_id=Mock(return_value=True), + get_asc_insert_func=Mock(return_value=lambda x, y: x), + set_iterator_type=Mock(), + set_initializer=Mock(), + set_target_batch=Mock(), + get_merged_multi_lookup=Mock(return_value=True), + ) + @patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + def test_ok_train_mode(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) + + modify_graph_for_asc() + + @patch.multiple( + "mx_rec.graph.modifier", + get_training_mode_channel_id=Mock(return_value=True), + get_asc_insert_func=Mock(return_value=lambda x, y: x), + set_iterator_type=Mock(), + set_initializer=Mock(), + set_target_batch=Mock(), + get_merged_multi_lookup=Mock(return_value=True), + do_merge_lookup=Mock(), + get_bool_gauge_set=Mock(return_value={"evaluate"}), + insert_merged_multi_lookup=Mock(), + ) + @patch.multiple( + "mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute(is_training=False) + ) + def test_ok_eval_mode(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) + + modify_graph_for_asc() + + @patch.multiple( + "mx_rec.graph.modifier", + get_training_mode_channel_id=Mock(return_value=True), + get_asc_insert_func=Mock(return_value=lambda x, y: x), + set_iterator_type=Mock(), + set_initializer=Mock(), + set_target_batch=Mock(), + get_merged_multi_lookup=Mock(return_value=False), + insert_merged_multi_lookup=Mock(), + ) + 
@patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + def test_err_not_clear_flag(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) + + with self.assertRaises(RuntimeError): + modify_graph_for_asc() + + +class GetTimestampIndexTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + @patch.multiple( + "mx_rec.graph.modifier", + insert_feature_spec=Mock(), + get_feature_spec=Mock(return_value=None), + ) + @patch.multiple( + "mx_rec.graph.modifier.FeatureSpec", + include_timestamp=Mock(), + index_key=Mock(return_value=2), + ) + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_timestamp = mock_batch.get("mock_timestamp") + mock_get_next_op = mock_timestamp.op + + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, mock_timestamp) + + timestamp_index = get_timestamp_index(mock_get_next_op, is_training=True) + self.assertEqual(timestamp_index, 2) + + @patch.multiple( + "mx_rec.graph.modifier", + insert_feature_spec=Mock(), + get_feature_spec=Mock(), + ) + @patch.multiple( + "mx_rec.graph.modifier.FeatureSpec", + include_timestamp=Mock(), + index_key=Mock(return_value=0), + ) + def test_err_unmatched_timestamp_index(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_timestamp = mock_batch.get("mock_timestamp") + mock_get_next_op = mock_timestamp.op + + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, mock_timestamp) + + with self.assertRaises(ValueError): + get_timestamp_index(mock_get_next_op, is_training=True) + + +@patch.multiple( + "mx_rec.graph.patch", + get_modify_graph=Mock(return_value=True), + get_is_graph_modify_hook_running=Mock(return_value=True), +) +@patch.multiple( + "tensorflow.compat.v1.train.Saver", + __init__=Mock(return_value=None), + build=Mock(), +) +class GraphModifierHookTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + @patch.multiple( + "mx_rec.graph.modifier", + set_is_graph_modify_hook_running=Mock(), + modify_graph_and_start_emb_cache=Mock(), + start_asc_pipeline=Mock(), + get_iterator_type=Mock(return_value="MakeIterator"), + ) + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + mock_new_iterator = mock_dataset.make_initializable_iterator() + tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, mock_new_iterator.initializer) + + with tf.compat.v1.train.MonitoredSession(hooks=[GraphModifierHook(modify_graph=True)]) as sess: + sess.run(mock_iterator.initializer) + sess.run(mock_cutting_point) + + @patch.multiple( + "mx_rec.graph.modifier", + set_is_graph_modify_hook_running=Mock(), + modify_graph_and_start_emb_cache=Mock(), + start_asc_pipeline=Mock(), + get_iterator_type=Mock(return_value="InvalidIterator"), + ) + def test_err_invalid_iterator_type(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = 
mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + mock_new_iterator = mock_dataset.make_initializable_iterator() + tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, mock_new_iterator.initializer) + + with self.assertRaises(ValueError): + with tf.compat.v1.train.MonitoredSession(hooks=[GraphModifierHook(modify_graph=True)]) as sess: + sess.run(mock_iterator.initializer) + sess.run(mock_cutting_point) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mx_rec/graph/test_utils.py b/tests/mx_rec/graph/test_utils.py new file mode 100644 index 00000000..562679e5 --- /dev/null +++ b/tests/mx_rec/graph/test_utils.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + +import sys +import os +import pathlib +import shutil +import unittest +from unittest import TestCase + +import tensorflow as tf +from tensorflow import Tensor, TensorSpec +from mx_rec.constants.constants import ASCAnchorAttr +from mx_rec.core.embedding import SparseEmbedding +from mx_rec.graph.utils import ( + check_input_list, + find_parent_op, + check_cutting_points, + export_pb_graph, + make_sorted_key_to_tensor_list, + replace_anchor_vec, +) + + +class CheckInputListTest(TestCase): + def tearDown(self): + tf.compat.v1.reset_default_graph() + + def test_ok_single_object(self): + mock_obj = "obj" + obj_type = str + + checked_objs = check_input_list(mock_obj, obj_type) + self.assertEqual([mock_obj], checked_objs) + + def test_ok_object_list(self): + mock_objs = ["obj1", "obj2", "ojb3"] + obj_type = str + + checked_cutting_points = check_input_list(mock_objs, obj_type) + self.assertEqual(mock_objs, checked_cutting_points) + + def test_err_inconsistent_object_and_type(self): + mock_objs = ["obj1", "obj2", "ojb3"] + obj_type = Tensor + + with self.assertRaises(ValueError): + check_input_list(mock_objs, obj_type) + + +class FindParentOpTest(TestCase): + def tearDown(self): + tf.compat.v1.reset_default_graph() + + def test_ok(self): + tsr1 = tf.constant([1, 2, 3], dtype=tf.int64) + mock_parent_op = tsr1.op + tsr2 = tf.identity(tsr1) + mock_child_op = tsr2.op + + parent_op = find_parent_op(mock_child_op) + self.assertEqual([mock_parent_op], parent_op) + + +class CheckCuttingPointsTest(TestCase): + def setUp(self): + self._generator_iter_times = 3 + + def tearDown(self): + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_cutting_point_list = [tf.identity(tf.zeros(shape=(1,))) for _ in range(self._generator_iter_times)] + check_cutting_points(mock_cutting_point_list) + + def test_err_invalid_cutting_point_list(self): + mock_cutting_point_list = ["point" for _ in range(self._generator_iter_times)] + with self.assertRaises(TypeError): + check_cutting_points(mock_cutting_point_list) + + def test_err_invalid_cutting_point_operation(self): + mock_cutting_point_list = [tf.zeros(shape=(1,)) for _ in range(self._generator_iter_times)] + with self.assertRaises(ValueError): + check_cutting_points(mock_cutting_point_list) + + +class ExportPBGraphTest(TestCase): + def setUp(self) -> None: + self._dir_name = "./export_graph" + + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + if os.path.isdir(self._dir_name): + shutil.rmtree(self._dir_name) + + def test_ok(self): + mock_file_name = "test_graph.pbtxt" + dump_graph = True + mock_graph_def = tf.Graph().as_graph_def() + as_text = True + + export_pb_graph(mock_file_name, dump_graph, mock_graph_def, 
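GraphModifierHookTest above drives the hook through tf.compat.v1.train.MonitoredSession. For readers unfamiliar with the hook plumbing, here is a minimal custom SessionRunHook, a generic sketch unrelated to GraphModifierHook's real logic and assuming the TF1-compat graph mode used throughout these tests:

import tensorflow as tf


class CountingHook(tf.compat.v1.train.SessionRunHook):
    """Counts how many times session.run() completes."""

    def __init__(self):
        self.run_count = 0

    def after_run(self, run_context, run_values):
        self.run_count += 1


counter = CountingHook()
x = tf.compat.v1.placeholder(tf.float32, shape=())
y = x * 2.0

# MonitoredSession finalizes the graph and invokes hook callbacks
# around every sess.run() call.
with tf.compat.v1.train.MonitoredSession(hooks=[counter]) as sess:
    for i in range(3):
        sess.run(y, feed_dict={x: float(i)})

assert counter.run_count == 3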
self._dir_name, as_text) + path = pathlib.Path(self._dir_name + "/" + mock_file_name) + self.assertTrue(path.is_file()) + + +class MakeSortedKeyToTensorListTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_batch = { + "item_ids": TensorSpec(shape=(4096, 16), dtype=tf.int64), + "user_ids": TensorSpec(shape=(4096, 8), dtype=tf.int64), + "category_ids": TensorSpec(shape=(4096, 3), dtype=tf.int64), + "label_0": TensorSpec(shape=(4096,), dtype=tf.int64), + "label_1": TensorSpec(shape=(4096,), dtype=tf.int64), + "user_ids_last_key": TensorSpec(shape=(4096, 16), dtype=tf.int64), + "user_ids_last_key_last_key": TensorSpec(shape=(4096, 8), dtype=tf.int64), + } + mock_element_spec = [mock_batch] + mock_sorted_keys = [] + mock_prefix = "mock_prefix" + + expected = [ + "mock_prefix_0_item_ids", + "mock_prefix_0_item_ids_user_ids", + "mock_prefix_0_item_ids_user_ids_category_ids", + "mock_prefix_0_item_ids_user_ids_category_ids_label_0", + "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1", + "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1_user_ids_last_key", + "mock_prefix_0_item_ids_user_ids_category_ids_label_0_label_1_user_ids_last_key_user_ids_last_key_last_key", + ] + sorted_batch_keys = make_sorted_key_to_tensor_list(mock_element_spec, mock_sorted_keys, mock_prefix) + self.assertEqual(sorted_batch_keys, expected) + + +class ReplaceAnchorVecTest(TestCase): + def tearDown(self): + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_cutting_point = tf.zeros(shape=(4096, 8), dtype=tf.int64, name="ids") + mock_attribute = ASCAnchorAttr.MOCK_LOOKUP_RESULT + mock_anchor = tf.zeros(shape=(4096, 8), dtype=tf.float32, name="anchor") + + anchor_vec = tf.identity(mock_cutting_point, name="anchor_vec") + anchor_vec_output = tf.identity(anchor_vec, name="anchor_vec_output") + SparseEmbedding.anchor_tensor_specs[mock_cutting_point][mock_attribute] = anchor_vec + + replace_anchor_vec(mock_cutting_point, mock_attribute, mock_anchor) + self.assertEqual(anchor_vec_output.op.inputs[0], mock_anchor) + + +if __name__ == "__main__": + unittest.main() -- Gitee From 30c8c9719e348a7588cd16bb2cbcdfdd21f4cc7e Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Dec 2023 15:17:45 +0800 Subject: [PATCH 515/551] Match-id-27877c6d7e26ae8eaeb69bdaff9650da1d9dee34 --- mx_rec/core/asc/build_graph.py | 18 +- tests/mx_rec/core/initializer_mock.py | 13 - tests/mx_rec/core/mxrec_pybind_mock.py | 11 - tests/mx_rec/core/test_build_graph.py | 667 +++++++++++++++++-------- 4 files changed, 474 insertions(+), 235 deletions(-) delete mode 100644 tests/mx_rec/core/initializer_mock.py delete mode 100644 tests/mx_rec/core/mxrec_pybind_mock.py diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index e9d54115..db4f037e 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
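ExportPBGraphTest above only checks that a GraphDef lands on disk as a .pbtxt file. The equivalent round-trip with public TensorFlow APIs (not mxRec's export_pb_graph; the output directory name below is a hypothetical scratch path) looks like this:

import os

import tensorflow as tf

graph_def = tf.Graph().as_graph_def()
out_dir = "./export_graph_demo"  # hypothetical scratch directory
# as_text=True writes a human-readable .pbtxt instead of a binary .pb.
tf.io.write_graph(graph_def, out_dir, "demo_graph.pbtxt", as_text=True)
assert os.path.isfile(os.path.join(out_dir, "demo_graph.pbtxt"))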
+from typing import Optional + import tensorflow as tf import mxrec_pybind @@ -15,16 +17,16 @@ from mx_rec.util.log import logger def get_restore_vector(config): logger.debug('Channel %s_restore_%s was built for getnext', config.get("table_name"), config.get("channel_id")) if config.get("skip_emb_transfer"): - if not isinstance(config.get("emb_size"), int) or config.get("emb_size") < 1: - raise TypeError(f"emb_size must be a int") + if not isinstance(config.get("emb_size"), int): + raise TypeError("emb_size must be a int") if config.get("emb_size") < 1: - raise ValueError(f"emb_size is less than 1") + raise ValueError("emb_size is less than 1") emb_size = config.get("emb_size") else: - if not isinstance(config.get("ext_emb_size"), int) or config.get("ext_emb_size") < 1: - raise TypeError(f"ext_emb_size must be a int") + if not isinstance(config.get("ext_emb_size"), int): + raise TypeError("ext_emb_size must be a int") if config.get("ext_emb_size") < 1: - raise ValueError(f"ext_emb_size is less than 1") + raise ValueError("ext_emb_size is less than 1") emb_size = config.get("ext_emb_size") use_hot = config.get("use_hot") @@ -117,7 +119,7 @@ def get_unique_keys(max_lookup_vec_size: int, config: dict) -> tf.Tensor: return unique_keys -def get_all2all_args(use_static: bool, config: dict) -> list: +def get_all2all_args(use_static: bool, config: dict) -> Optional[list]: """ Get all2all parameters for dynamic condition :param use_static: dynamic or static @@ -219,4 +221,4 @@ def get_preprocessed_tensor_for_asc(table, config): with tf.compat.v1.variable_scope("unique_keys"): unique_keys = get_unique_keys(max_lookup_vec_size, config) result.update({'restore_vector_second': restore_vector_second, 'unique_keys': unique_keys}) - return result \ No newline at end of file + return result diff --git a/tests/mx_rec/core/initializer_mock.py b/tests/mx_rec/core/initializer_mock.py deleted file mode 100644 index a5c840da..00000000 --- a/tests/mx_rec/core/initializer_mock.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python3 -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. -import os - - -class InitializerMock: - """ - initializer mock module - """ - @staticmethod - def get_use_static(): - return os.getenv("use_static", True) diff --git a/tests/mx_rec/core/mxrec_pybind_mock.py b/tests/mx_rec/core/mxrec_pybind_mock.py deleted file mode 100644 index f65356c2..00000000 --- a/tests/mx_rec/core/mxrec_pybind_mock.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -# coding: UTF-8 -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. - - -class MxRecPybindMock: - """ - mxrec_pybind mock module - """ - def get_ub_hot_size(self): - return 21845 diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py index 91fb0b29..f6ac6e6d 100644 --- a/tests/mx_rec/core/test_build_graph.py +++ b/tests/mx_rec/core/test_build_graph.py @@ -1,230 +1,491 @@ #!/usr/bin/env python3 # coding: UTF-8 # Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. 
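The build_graph.py hunk above makes two small correctness fixes: the combined isinstance/range test used to raise TypeError even when emb_size was a valid int that was merely out of range, and get_all2all_args is re-annotated Optional[list] because it can return None. A toy restatement of both rules (illustrative functions, not the real mxRec ones):

from typing import Optional


def check_emb_size(emb_size) -> None:
    # Type check first, then range check, so each failure mode
    # raises its own exception type.
    if not isinstance(emb_size, int):
        raise TypeError("emb_size must be an int")
    if emb_size < 1:
        raise ValueError("emb_size is less than 1")


def pick_args(use_static: bool, counts: list) -> Optional[list]:
    # A function with a None-returning branch is annotated
    # Optional[list], not list.
    if use_static:
        return None
    return counts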
-import os -import sys + import unittest -from dataclasses import dataclass from unittest import mock import tensorflow as tf -from tests.mx_rec.core.mxrec_pybind_mock import MxRecPybindMock -from tests.mx_rec.core.initializer_mock import InitializerMock -from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.util.global_env_conf import global_env -sys.modules['mxrec_pybind'] = MxRecPybindMock -sys.modules['mx_rec.util.initialize'] = InitializerMock -os.environ[ - "HOST_PIPELINE_OPS_LIB_PATH"] = f"{os.getenv('so_path')}/libasc/libasc_ops.so" +class TestGetRestoreVectorFunc(unittest.TestCase): + """ + Test for 'mx_rec.core.asc.build_graph.get_restore_vector'. + """ + def setUp(self): + # 默认动态扩容、hot emb、HBM + self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, + use_hot=True, use_dynamic_expansion=True) -@dataclass -class InputConfig: - batch_size: int - feat_cnt: int - send_count: int - rank_size: int - channel_id: int - table_name: str - skip_emb_transfer: bool - ext_emb_size: int - emb_size: int - use_hot: bool - device_id: int - use_dynamic_expansion: bool + def tearDown(self): + # 恢复config + self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, + use_hot=True, use_dynamic_expansion=True) + def test_get_restore_vector_case1(self): + """ + case1: HBM,emb_size不为int,抛出异常 -class TestBuildGraph(unittest.TestCase): - """ - Test Suite for Exception Checkpoint. - """ + """ - def setUp(self): + from mx_rec.core.asc.build_graph import get_restore_vector + + self.config["emb_size"] = "xxx" + with self.assertRaises(TypeError): + get_restore_vector(self.config) + + def test_get_restore_vector_case2(self): """ - 准备步骤 - :return:无 + case2: HBM,emb_size小于1,抛出异常 """ - super().setUp() - def tearDown(self): + from mx_rec.core.asc.build_graph import get_restore_vector + + self.config["emb_size"] = 0 + with self.assertRaises(ValueError): + get_restore_vector(self.config) + + def test_get_restore_vector_case3(self): + """ + case3: 非HBM,ext_emb_size不为int,抛出异常 + """ + + from mx_rec.core.asc.build_graph import get_restore_vector + + self.config["skip_emb_transfer"] = False + self.config["ext_emb_size"] = "xxx" + with self.assertRaises(TypeError): + get_restore_vector(self.config) + + def test_get_restore_vector_case4(self): + """ + case4: 非HBM,ext_emb_size小于1,抛出异常 """ - 销毁步骤 - :return: 无 - """ - super().tearDown() - - @staticmethod - def get_next_mock(): - return tf.constant(value=1, name="inference/asecnd_lookup_one_big_embedding/all2all/mul", shape=[8, 8], - dtype=tf.int64) - - @staticmethod - def get_id_offsets_mock(): - return tf.constant(value=1, shape=[270412, 8], dtype=tf.float32, name="inference/gather_for_id_offsets"), [], 0 - - @staticmethod - def get_all2all_mock(): - return tf.constant(value=1, shape=[8, ], dtype=tf.int64, name="mul") - - @staticmethod - def get_restore_vector_mock(): - return [tf.constant(value=1, shape=[2908800], dtype=tf.int32, name="aicpu_getnext_restore_vector/GetNext"), - None] - - @staticmethod - def get_input_config(input_config_init: InputConfig): - batch_size = input_config_init.batch_size - feat_cnt = input_config_init.feat_cnt - send_count = input_config_init.send_count - rank_size = input_config_init.rank_size - channel_id = input_config_init.channel_id - table_name = input_config_init.table_name - skip_emb_transfer 
-        ext_emb_size = input_config_init.ext_emb_size
-        emb_size = input_config_init.emb_size
-        use_hot = input_config_init.use_hot
-        device_id = input_config_init.device_id
-        use_dynamic_expansion = input_config_init.use_dynamic_expansion
-
-        input_config = {'batch_size': batch_size,
-                        'feat_cnt': feat_cnt,
-                        'send_count': send_count,
-                        'rank_size': rank_size,
-                        'channel_id': channel_id,
-                        'table_name': table_name,
-                        'skip_emb_transfer': skip_emb_transfer,
-                        'ext_emb_size': ext_emb_size,
-                        'emb_size': emb_size,
-                        'use_hot': use_hot,
-                        'device_id': device_id,
-                        'use_dynamic_expansion': use_dynamic_expansion}
-        return input_config
-
-    @staticmethod
-    def get_input_table():
-        input_table = tf.Variable(tf.zeros([875000, 8]), name="inference/one_ascend_hash_embedding:0",
-                                  dtype=tf.float32)
-        return input_table
-
-    @mock.patch('npu_bridge.estimator.npu_ops')
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_restore_vector_use_hot(self, tf1_npu_ops_mock, tf1_hccl_ops_mock,
-                                        tf1_save_mock):
+        from mx_rec.core.asc.build_graph import get_restore_vector
 
-        tf1_npu_ops_mock.return_value = None
-        tf1_hccl_ops_mock.return_value = None
-        tf1_save_mock.return_value = None
-        input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, "8", True, 6,
-                                            False)
-        input_config = self.get_input_config(input_config_instance)
-        try:
-            get_restore_vector(input_config)
-        except TypeError as exp:
-            self.assertEqual(type(exp), TypeError)
-        else:
-            self.fail("TypeError not raised.")
-
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch('npu_bridge.estimator.npu_ops')
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_restore_vector_emb_size_value_error(self, tf1_hccl_ops_mock, tf1_npu_ops_mock,
-                                                     tf1_save_mock):
+
+        self.config["skip_emb_transfer"] = False
+        self.config["ext_emb_size"] = 0
+        with self.assertRaises(ValueError):
+            get_restore_vector(self.config)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.asc.build_graph.mxrec_pybind.get_ub_hot_size")
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_case5(self, mock_get_next, mock_get_ub_hot_size):
+        """
+        case5: HBM, static shape, hot emb
+        """
+        from mx_rec.core.asc.build_graph import get_restore_vector
 
-        tf1_save_mock.return_value = None
-        tf1_npu_ops_mock.return_value = None
-        tf1_hccl_ops_mock.return_value = None
-        input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, -1, True, 6,
-                                            False)
-        input_config = self.get_input_config(input_config_instance)
-        try:
-            get_restore_vector(input_config)
-        except TypeError as exp:
-            self.assertEqual(type(exp), TypeError)
-        else:
-            self.fail("ValueError not raised.")
-
-    @mock.patch('npu_bridge.estimator.npu_ops')
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_restore_vector_ext_emb_size_type_error(self, tf1_npu_ops_mock, tf1_hccl_ops_mock,
-                                                        tf1_save_mock):
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0, 1]
+            mock_get_ub_hot_size.return_value = 8
+            restore_vector, hot_pos = get_restore_vector(self.config)
+            self.assertEqual(restore_vector, 0)
+            self.assertEqual(hot_pos, 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.asc.build_graph.mxrec_pybind.get_ub_hot_size")
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_case6(self, mock_get_next, mock_get_ub_hot_size):
+        """
+        case6: HBM, dynamic shape, hot emb
+        """
+        from mx_rec.core.asc.build_graph import get_restore_vector
 
-        tf1_npu_ops_mock.return_value = None
-        tf1_hccl_ops_mock.return_value = None
-        tf1_save_mock.return_value = None
-        input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", False, "8", 8, True, 6,
-                                            False)
-        input_config = self.get_input_config(input_config_instance)
-        try:
-            get_restore_vector(input_config)
-        except TypeError as exp:
-            self.assertEqual(type(exp), TypeError)
-        else:
-            self.fail("TypeError not raised.")
-
-    @mock.patch('npu_bridge.estimator.npu_ops')
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_restore_vector_ext_emb_size_value_error(self, tf1_npu_ops_mock, tf1_hccl_ops_mock,
-                                                         tf1_save_mock):
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0, 1]
+            mock_get_ub_hot_size.return_value = 8
+            restore_vector, hot_pos = get_restore_vector(self.config)
+            self.assertEqual(restore_vector, 0)
+            self.assertEqual(hot_pos, 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_case7(self, mock_get_next):
+        """
+        case7: HBM, static shape
+        """
+        from mx_rec.core.asc.build_graph import get_restore_vector
 
-        tf1_npu_ops_mock.return_value = None
-        tf1_hccl_ops_mock.return_value = None
-        tf1_save_mock.return_value = None
-        input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", False, -1, 8, True, 6,
-                                            False)
-        input_config = self.get_input_config(input_config_instance)
-        try:
-            get_restore_vector(input_config)
-        except TypeError as exp:
-            self.assertEqual(type(exp), TypeError)
-        else:
-            self.fail("ValueError not raised.")
-
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_id_offsets(self, tf1_hccl_ops_mock,
-                            tf1_save_mock):
-        with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops:
-            from mx_rec.core.asc.build_graph import get_id_offsets
-            id_offset_mock = tf.constant(value=1, shape=[270412, 8], dtype=tf.float32,
-                                         name="inference/gather_for_id_offsets")
-            mock_npu_ops.get_next.return_value = [id_offset_mock]
-            tf1_hccl_ops_mock.return_value = None
-            tf1_save_mock.return_value = None
-            input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6,
-                                                False)
-            input_config = self.get_input_config(input_config_instance)
-            max_lookup_vec_size = None
-            res_id_offsets = get_id_offsets(max_lookup_vec_size, input_config)
-            self.assertEqual(res_id_offsets[0], id_offset_mock)
-
-    @mock.patch("npu_bridge.hccl.hccl_ops")
-    @mock.patch("npu_bridge.estimator.npu.npu_hook.NPUCheckpointSaverHook")
-    def test_get_all2all_args(self, tf1_hccl_ops_mock,
-                              tf1_save_mock):
-        with mock.patch.object(npu_ops, "gen_npu_ops") as mock_npu_ops:
-            from mx_rec.core.asc.build_graph import get_all2all_args
-            all2all_mock = tf.constant(
-                value=1,
-                name='mul',
-                shape=[8, 8], dtype=tf.int64)
-            mock_npu_ops.get_next.return_value = all2all_mock
-            tf1_hccl_ops_mock.return_value = None
-            tf1_save_mock.return_value = None
-            input_config_instance = InputConfig(9600, 1, None, 8, 0, "one_ascend_hash_embedding", True, 8, 8, True, 6,
-                                                False)
-            input_config = self.get_input_config(input_config_instance)
-            use_static = False
-            res_all2all_args = get_all2all_args(use_static, input_config)
-            self.assertEqual(res_all2all_args.shape, tf.constant(value=1, shape=[8, ], dtype=tf.int64,
-                                                                 name="mul").shape)
-            self.assertEqual(res_all2all_args.dtype, tf.constant(value=1, shape=[8, 8], dtype=tf.int64,
-                                                                 name="mul").dtype)
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            self.config["use_hot"] = False
+            restore_vector, hot_pos = get_restore_vector(self.config)
+            self.assertEqual(restore_vector, 0)
+            self.assertIsNone(hot_pos)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=False))
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_case8(self, mock_get_next):
+        """
+        case8: HBM, dynamic shape
+        """
+
+        from mx_rec.core.asc.build_graph import get_restore_vector
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            self.config["use_hot"] = False
+            restore_vector, hot_pos = get_restore_vector(self.config)
+            self.assertEqual(restore_vector, 0)
+            self.assertIsNone(hot_pos)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=False))
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_case9(self, mock_get_next):
+        """
+        case9: non-HBM, dynamic shape
+        """
+
+        from mx_rec.core.asc.build_graph import get_restore_vector
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            self.config["skip_emb_transfer"] = False
+            self.config["use_hot"] = False
+            restore_vector, hot_pos = get_restore_vector(self.config)
+            self.assertEqual(restore_vector, 0)
+            self.assertIsNone(hot_pos)
+
+
+class TestGetIdOffsetsFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_id_offsets'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+        self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size")
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_id_offsets_case1(self, mock_get_next):
+        """
+        case1: dynamic expansion
+        """
+
+        from mx_rec.core.asc.build_graph import get_id_offsets
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            id_offsets, swap_pos, swap_len = get_id_offsets(self.max_lookup_vec_size, self.config)
+            self.assertEqual(id_offsets, 0)
+            self.assertListEqual(swap_pos, [])
+            self.assertEqual(swap_len, 0)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_id_offsets_case2(self, mock_get_next):
+        """
+        case2: no dynamic expansion, HBM
+        """
+
+        from mx_rec.core.asc.build_graph import get_id_offsets
+
+        with tf.Graph().as_default():
+            self.config["use_dynamic_expansion"] = False
+            mock_get_next.return_value = [0]
+            id_offsets, swap_pos, swap_len = get_id_offsets(self.max_lookup_vec_size, self.config)
+            self.assertEqual(id_offsets, 0)
+            self.assertListEqual(swap_pos, [])
+            self.assertEqual(swap_len, 0)
+
+
+class TestGetRestoreVectorSecondFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_restore_vector_second'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+        self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size")
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_restore_vector_second(self, mock_get_next):
+        """
+        case: test get_restore_vector_second
+        """
+
+        from mx_rec.core.asc.build_graph import get_restore_vector_second
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            restore_vector_second = get_restore_vector_second(self.max_lookup_vec_size, self.config)
+            self.assertEqual(restore_vector_second, 0)
+
+
+class TestGetUniqueKeysFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_unique_keys'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+        self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size")
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_unique_keys_case1(self, mock_get_next):
+        """
+        case1: dynamic expansion
+        """
+
+        from mx_rec.core.asc.build_graph import get_unique_keys
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            unique_keys = get_unique_keys(self.max_lookup_vec_size, self.config)
+            self.assertEqual(unique_keys, 0)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_unique_keys_case2(self, mock_get_next):
+        """
+        case2: no dynamic expansion
+        """
+
+        from mx_rec.core.asc.build_graph import get_unique_keys
+
+        with tf.Graph().as_default():
+            self.config["use_dynamic_expansion"] = False
+            mock_get_next.return_value = [1]
+            unique_keys = get_unique_keys(self.max_lookup_vec_size, self.config)
+            self.assertEqual(unique_keys, 1)
+
+
+class TestGetAll2allArgsFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_all2all_args'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    def test_get_all2all_args_case1(self):
+        """
+        case1: static shape
+        """
+
+        from mx_rec.core.asc.build_graph import get_all2all_args
+
+        with tf.Graph().as_default():
+            all2all_args = get_all2all_args(True, self.config)
+            self.assertIsNone(all2all_args)
+
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_all2all_args_case2(self, mock_get_next):
+        """
+        case2: dynamic shape
+        """
+
+        from mx_rec.core.asc.build_graph import get_all2all_args
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = [0]
+            all2all_args = get_all2all_args(False, self.config)
+            self.assertEqual(all2all_args, 0)
+
+
+class TestGetSwapInfoFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_swap_info'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True))
+    def test_get_swap_info_case1(self):
+        """
+        case1: static shape, HBM
+        """
+
+        from mx_rec.core.asc.build_graph import get_swap_info
+
+        with tf.Graph().as_default():
+            swap_in = get_swap_info(self.config, None, None, None)
+            self.assertIsInstance(swap_in[0], type(tf.no_op()))
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next")
+    def test_get_swap_info_case2(self, mock_get_next):
+        """
+        case2: static shape, non-HBM, the table passed in is not a list, an exception is raised
+        """
+
+        from mx_rec.core.asc.build_graph import get_swap_info
+
+        with tf.Graph().as_default():
+            mock_get_next.return_value = tf.ones(shape=[8, 8], dtype=tf.float32)
+            swap_pos = tf.constant([8, 9], dtype=tf.int32)
+            swap_len = tf.constant(2, dtype=tf.int32)
+            table = tf.compat.v1.get_variable("test_table", shape=[10, 8], initializer=tf.ones_initializer())
+            self.config["skip_emb_transfer"] = False
+            with self.assertRaises(RuntimeError):
+                get_swap_info(self.config, swap_len, swap_pos, table)
+
+
+class TestGetPreProcessedTensorForAscFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.build_graph.get_preprocessed_tensor_for_asc'.
+    """
+
+    def setUp(self):
+        # Defaults: dynamic expansion, hot emb, HBM
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+
+    def tearDown(self):
+        # Restore the config
+        self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8,
+                           feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0,
+                           use_hot=True, use_dynamic_expansion=True)
+        global_env.apply_gradients_strategy = "direct_apply"
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=True),
+                         get_restore_vector=mock.MagicMock(return_value=[0, 0]),
+                         get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
+                         get_all2all_args=mock.MagicMock(return_value=0),
+                         get_swap_info=mock.MagicMock(return_value=0),
+                         get_restore_vector_second=mock.MagicMock(return_value=0),
+                         get_unique_keys=mock.MagicMock(return_value=0))
+    def test_get_preprocessed_tensor_for_asc_case1(self):
+        """
+        case1: static shape, global unique
+        """
+
+        from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+
+        global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply"
+        with tf.Graph().as_default():
+            result = get_preprocessed_tensor_for_asc(None, self.config)
+            self.assertIsNotNone(result.get("restore_vector"))
+            self.assertIsNotNone(result.get("restore_vector_second"))
+            self.assertIsNotNone(result.get("unique_keys"))
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=False),
+                         get_restore_vector=mock.MagicMock(return_value=[0, 0]),
+                         get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
+                         get_all2all_args=mock.MagicMock(return_value=0),
+                         get_swap_info=mock.MagicMock(return_value=0),
+                         get_restore_vector_second=mock.MagicMock(return_value=0),
+                         get_unique_keys=mock.MagicMock(return_value=0))
+    def test_get_preprocessed_tensor_for_asc_case2(self):
+        """
+        case2: dynamic shape, global unique
+        """
+
+        from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+
+        global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply"
+        with tf.Graph().as_default():
+            result = get_preprocessed_tensor_for_asc(None, self.config)
+            self.assertIsNotNone(result.get("restore_vector"))
+            self.assertIsNotNone(result.get("restore_vector_second"))
+            self.assertIsNotNone(result.get("unique_keys"))
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=False),
+                         get_restore_vector=mock.MagicMock(return_value=[0, 0]),
+                         get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
+                         get_all2all_args=mock.MagicMock(return_value=0),
+                         get_swap_info=mock.MagicMock(return_value=0),
+                         get_restore_vector_second=mock.MagicMock(return_value=0),
+                         get_unique_keys=mock.MagicMock(return_value=0))
+    def test_get_preprocessed_tensor_for_asc_case3(self):
+        """
+        case3: dynamic shape, global unique, channel_id=1
+        """
+
+        from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+
+        global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply"
+        with tf.Graph().as_default():
+            self.config["channel_id"] = 1
+            result = get_preprocessed_tensor_for_asc(None, self.config)
+            self.assertIsNotNone(result.get("restore_vector"))
+            self.assertIsNone(result.get("restore_vector_second"))
+
+    @mock.patch.multiple("mx_rec.core.asc.build_graph",
+                         get_use_static=mock.MagicMock(return_value=False),
+                         get_restore_vector=mock.MagicMock(return_value=[0, 0]),
+                         get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
+                         get_all2all_args=mock.MagicMock(return_value=0),
+                         get_swap_info=mock.MagicMock(return_value=0),
+                         get_restore_vector_second=mock.MagicMock(return_value=0),
+                         get_unique_keys=mock.MagicMock(return_value=0))
+    def test_get_preprocessed_tensor_for_asc_case4(self):
+        """
+        case4: dynamic shape
+        """
+
+        from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+
+        with tf.Graph().as_default():
+            result = get_preprocessed_tensor_for_asc(None, self.config)
+            self.assertIsNotNone(result.get("restore_vector"))
+            self.assertIsNone(result.get("restore_vector_second"))
 
 
 if __name__ == '__main__':
-- 
Gitee

From 68b8a14c4e45763a38ab347bd00f6ca58f5723c6 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 14 Dec 2023 19:25:27 +0800
Subject: [PATCH 516/551] Match-id-5e3ccdbe62d6560235ba6cbe849f833658935596

---
 mx_rec/core/asc/helper.py        |  54 +-
 tests/mx_rec/core/test_helper.py | 925 +++++++++++++++++++++++++++++++
 2 files changed, 947 insertions(+), 32 deletions(-)
 create mode 100644 tests/mx_rec/core/test_helper.py

diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py
index 783871b1..4b1e1fbf 100644
--- a/mx_rec/core/asc/helper.py
+++ b/mx_rec/core/asc/helper.py
@@ -9,11 +9,9 @@ import tensorflow as tf
 from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, get_modify_graph
 from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.merge_table import find_dangling_table, should_skip
-from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator, \
-    OptionalIntValidator
+from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator
 from mx_rec.util.log import logger
 from mx_rec.util.normalization import fix_invalid_table_name
-from mx_rec.constants.constants import MAX_INT32
 
 
 @para_checker_decorator(check_option_list=[
@@ -71,7 +69,6 @@ def create_asc_insert_func_with_acg(args_index_list, table_names, **kwargs):
 
 
 def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_names=None, **kwargs):
-
     is_training = kwargs.get("is_training", True)
     dump_graph = kwargs.get("dump_graph", False)
 
@@ -79,7 +76,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_na
     if not isinstance(tgt_key_specs, (list, tuple)):
         tgt_key_specs = [tgt_key_specs]
 
-    def insert_fn_for_feature_specs(*args):
+    def insert_fn_for_feature_specs(*args):  # pragma: no cover
         data_src = args
         if len(args) == 1:
             data_src = args[0]
@@ -109,7 +106,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_na
         logger.info("In insert found dangling table(s): %s which does not need to be provided to the EmbInfo.",
                     dangling_tables)
 
-    def insert_fn_for_arg_indexes(*args):
+    def insert_fn_for_arg_indexes(*args):  # pragma: no cover
         insert_tensors = get_target_tensors_with_args_indexes(args_index_list)
         logger.debug("do_insert without spec for %s", table_names)
 
@@ -193,18 +190,18 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list):
     logger.debug("merge request from %s %s to %s %s", table_name_list, split_list,
                  output_table_name_list, output_split_list)
 
-    list_set = {
+    output_dict = {
         'output_feature_id_list': output_feature_id_list,
         'output_split_list': output_split_list,
         'output_table_name_list': output_table_name_list,
         'output_tensorshape_split_list': output_tensorshape_split_list,
    }
-    return list_set
+    return output_dict
 
 
 def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict):
-    is_training = input_dict["is_training"]
-    timestamp = input_dict["timestamp"]
+    is_training = input_dict.get("is_training")
+    timestamp = input_dict.get("timestamp")
     host_pipeline_ops = get_host_pipeline_ops()
     use_static = get_use_static()
     timestamp_feature_id = []
@@ -213,11 +210,11 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list,
         timestamp_feature_id = feature_id_list[:1]
         feature_id_list = feature_id_list[1:]
 
-    list_set = merge_feature_id_request(feature_id_list, split_list, table_name_list)
-    feature_id_list = list_set.get("output_feature_id_list")
-    split_list = list_set.get("output_split_list")
-    table_name_list = list_set.get("output_table_name_list")
-    tensorshape_split_list = list_set.get("output_tensorshape_split_list")
+    merged_dict = merge_feature_id_request(feature_id_list, split_list, table_name_list)
+    feature_id_list = merged_dict.get("output_feature_id_list")
+    split_list = merged_dict.get("output_split_list")
+    table_name_list = merged_dict.get("output_table_name_list")
+    tensorshape_split_list = merged_dict.get("output_tensorshape_split_list")
 
     # check training mode order and ensure channel id
     channel_id = get_training_mode_channel_id(is_training=is_training)
@@ -242,11 +239,11 @@ def send_feature_id_request_async(feature_id_list, split_list, table_name_list,
 
 
 def do_insert(args, insert_tensors, splits, table_names, input_dict):
-    is_training = input_dict["is_training"]
-    dump_graph = input_dict["dump_graph"]
-    timestamp = input_dict["timestamp"]
-    feature_spec_names = input_dict["feature_spec_names"]
-    auto_change_graph = input_dict["auto_change_graph"]
+    is_training = input_dict.get("is_training")
+    dump_graph = input_dict.get("dump_graph")
+    timestamp = input_dict.get("timestamp")
+    feature_spec_names = input_dict.get("feature_spec_names")
+    auto_change_graph = input_dict.get("auto_change_graph")
 
     pipeline_op = \
         send_feature_id_request_async(feature_id_list=insert_tensors,
@@ -272,7 +269,10 @@ def export_read_emb_key_v2_op(args, pipeline_op):
         raise ValueError("The length of args is less than 1.")
     if isinstance(origin_batch[0], dict):
         output_batch = origin_batch[0]
-        valid_key = get_valid_op_key(output_batch)
+        # Find the lexicographically largest key in output_batch
+        sorted_keys = sorted(output_batch)
+        valid_key = f"{sorted_keys[-1]}_read_emb_key"
+        # Insert the output of the readEmbKey op into the batch; every time the dataset runs getnext,
+        # the readEmbKey op executes and produces its output
         output_batch[valid_key] = pipeline_op
     elif len(origin_batch) == 1 and isinstance(origin_batch[0], tf.Tensor):
@@ -301,17 +301,7 @@ def export_read_emb_key_v2_op(args, pipeline_op):
     return output_batch
 
 
-def get_valid_op_key(batch_dict: dict) -> str:
-    if not isinstance(batch_dict, dict):
-        raise TypeError(f"batch_dict must be a dict")
-
-    sorted_keys = sorted(batch_dict)
-    valid_key = f"{sorted_keys[-1]}_read_emb_key"
-
-    return valid_key
-
-
-def get_target_tensors_with_args_indexes(args_index_list):
+def get_target_tensors_with_args_indexes(args_index_list):  # pragma: no cover
     insert_tensors = []
     graph = tf.compat.v1.get_default_graph()
     for index in args_index_list:
diff --git a/tests/mx_rec/core/test_helper.py b/tests/mx_rec/core/test_helper.py
new file mode 100644
index 00000000..3af6dc1a
--- /dev/null
+++ b/tests/mx_rec/core/test_helper.py
@@ -0,0 +1,925 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import os
+import shutil
+import unittest
+from unittest import mock
+from typing import Callable
+
+import tensorflow as tf
+
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from tests.mx_rec.core.generator_dataset import generate_dataset, Config
+from tests.mx_rec.core.mock_class import MockHostPipeLineOps
+
+
+class TestGetAscInsertFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.get_asc_insert_func'.
+    """
+
+    def test_get_asc_insert_func_case1(self):
+        """
+        case1: tgt_key_specs and args_index_list are both None, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        with self.assertRaises(ValueError):
+            get_asc_insert_func(tgt_key_specs=None, args_index_list=None)
+
+    def test_get_asc_insert_func_case2(self):
+        """
+        case2: tgt_key_specs and args_index_list are both not None, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        with self.assertRaises(ValueError):
+            get_asc_insert_func(tgt_key_specs=[], args_index_list=[])
+
+    def test_get_asc_insert_func_case3(self):
+        """
+        case3: when tgt_key_specs is not None, table_names must be None, otherwise an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        with self.assertRaises(RuntimeError):
+            get_asc_insert_func(tgt_key_specs=[], table_names=[])
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_asc_insert_func_inner=mock.MagicMock(return_value=Callable))
+    def test_get_asc_insert_func_case4(self):
+        """
+        case4: tgt_key_specs is not None
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        self.assertTrue(callable(get_asc_insert_func(tgt_key_specs=[])))
+
+    def test_get_asc_insert_func_case5(self):
+        """
+        case5: when args_index_list is not None, table_names must not be None, otherwise an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        with self.assertRaises(RuntimeError):
+            get_asc_insert_func(args_index_list=[])
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_asc_insert_func_inner=mock.MagicMock(return_value=Callable))
+    def test_get_asc_insert_func_case6(self):
+        """
+        case6: args_index_list and table_names are both not None
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func
+
+        self.assertTrue(callable(get_asc_insert_func(args_index_list=[], table_names=["xxx"])))
+
+
+class TestGetAscInsertFuncInnerFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.get_asc_insert_func_inner'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_target_tensors_with_feature_specs=mock.MagicMock(return_value=None))
+    @mock.patch("mx_rec.core.asc.helper.do_insert")
+    def test_get_asc_insert_func_inner_case1(self, mock_do_insert):
+        """
+        case1: tgt_key_specs is not None, args_index_list is None
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func_inner
+
+        with tf.Graph().as_default():
+            mock_do_insert.return_value = {"xxx": tf.constant(1, dtype=tf.int64)}
+            map_fn = get_asc_insert_func_inner(tgt_key_specs="xxx", dump_graph=False)
+            self.assertTrue(callable(map_fn))
+
+            dataset = generate_dataset(Config(batch_size=2, batch_number=2))
+            dataset = dataset.map(map_fn)
+            iterator = dataset.make_initializable_iterator()
+            batch = iterator.get_next()
+            with tf.Session() as sess:
+                sess.run(iterator.initializer)
+                sess.run(tf.compat.v1.global_variables_initializer())
+                self.assertEqual(sess.run(batch.get("xxx")), 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_target_tensors_with_feature_specs=mock.MagicMock(return_value=None),
+                         find_dangling_table=mock.MagicMock(return_value=["table1"]))
+    @mock.patch("mx_rec.core.asc.helper.get_target_tensors_with_args_indexes")
+    @mock.patch("mx_rec.core.asc.helper.do_insert")
+    def test_get_asc_insert_func_inner_case2(self, mock_do_insert, mock_get_target_tensors_with_args_indexes):
+        """
+        case2: args_index_list is not None, tgt_key_specs is None
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func_inner
+
+        with tf.Graph().as_default():
+            mock_do_insert.return_value = {"xxx": tf.constant(1, dtype=tf.int64)}
+            mock_get_target_tensors_with_args_indexes.return_value = [
+                tf.constant([2, 1], dtype=tf.int64), tf.constant([2, 2], dtype=tf.int64)
+            ]
+            FeatureSpec.use_timestamp_train = True
+            map_fn = get_asc_insert_func_inner(args_index_list=[0], table_names=["table1", "table2"])
+            self.assertTrue(callable(map_fn))
+
+            dataset = generate_dataset(Config(batch_size=2, batch_number=2))
+            dataset = dataset.map(map_fn)
+            iterator = dataset.make_initializable_iterator()
+            batch = iterator.get_next()
+            with tf.Session() as sess:
+                sess.run(iterator.initializer)
+                sess.run(tf.compat.v1.global_variables_initializer())
+                self.assertEqual(sess.run(batch.get("xxx")), 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_target_tensors_with_feature_specs=mock.MagicMock(return_value=None),
+                         find_dangling_table=mock.MagicMock(return_value=["table1"]))
+    @mock.patch("mx_rec.core.asc.helper.get_target_tensors_with_args_indexes")
+    @mock.patch("mx_rec.core.asc.helper.do_insert")
+    def test_get_asc_insert_func_inner_case3(self, mock_do_insert, mock_get_target_tensors_with_args_indexes):
+        """
+        case3: args_index_list is not None, tgt_key_specs is None; splits smaller than 1 raises an exception
+        """
+
+        from mx_rec.core.asc.helper import get_asc_insert_func_inner
+
+        with tf.Graph().as_default():
+            mock_do_insert.return_value = {"xxx": tf.constant(1, dtype=tf.int64)}
+            mock_get_target_tensors_with_args_indexes.return_value = []
+            FeatureSpec.use_timestamp_train = True
+            map_fn = get_asc_insert_func_inner(args_index_list=[0], table_names=[])
+            self.assertTrue(callable(map_fn))
+
+            dataset = generate_dataset(Config(batch_size=2, batch_number=2))
+            with self.assertRaises(ValueError):
+                dataset.map(map_fn)
+
+
+class TestDoInsertFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.do_insert'.
+    """
+
+    def setUp(self) -> None:
+        self._dir_name = "./export_graph"
+
+    def tearDown(self) -> None:
+        tf.compat.v1.reset_default_graph()
+        if os.path.isdir(self._dir_name):
+            shutil.rmtree(self._dir_name)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         send_feature_id_request_async=mock.MagicMock(return_value=None),
+                         export_read_emb_key_v2_op=mock.MagicMock(return_value=dict()))
+    def test_do_insert_case(self):
+        """
+        case: test do_insert
+        """
+
+        from mx_rec.core.asc.helper import do_insert
+
+        args = dict()
+        insert_tensor = []
+        splits = []
+        table_names = []
+        input_dict = dict(dump_graph=True)
+
+        out_batch = do_insert(args, insert_tensor, splits, table_names, input_dict)
+        self.assertIsInstance(out_batch, dict)
+
+
+class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.send_feature_id_request_async'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_use_static=mock.MagicMock(return_value=True),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=0))
+    @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request")
+    @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops")
+    def test_send_feature_id_request_async_case1(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request):
+        """
+        case1: static shape
+        """
+
+        from mx_rec.core.asc.helper import send_feature_id_request_async
+
+        with tf.Graph().as_default():
+            mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps()
+            feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)]
+            mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]],
+                                                              output_split_list=[1], output_tensorshape_split_list=[1])
+            split_list = [2, 3]
+            table_name_list = ["table1", "table2"]
+            input_dict = dict(is_training=True, timestamp=True)
+
+            mock_res = send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict)
+            self.assertEqual(mock_res, 0)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_use_static=mock.MagicMock(return_value=False),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=0))
+    @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request")
+    @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops")
+    def test_send_feature_id_request_async_case2(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request):
+        """
+        case2: dynamic shape
+        """
+
+        from mx_rec.core.asc.helper import send_feature_id_request_async
+
+        with tf.Graph().as_default():
+            mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps()
+            feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)]
+            mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]],
+                                                              output_split_list=[1], output_tensorshape_split_list=[1])
+            split_list = [2, 3]
+            table_name_list = ["table1", "table2"]
+            input_dict = dict(is_training=True, timestamp=True)
+
+            mock_res = send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict)
+            self.assertEqual(mock_res, 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_use_static=mock.MagicMock(return_value=False),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=0))
+    @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request")
+    @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops")
+    def test_send_feature_id_request_async_case3(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request):
+        """
+        case3: split_list or tensorshape_split_list is empty, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import send_feature_id_request_async
+
+        with tf.Graph().as_default():
+            mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps()
+            feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)]
+            mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]],
+                                                              output_split_list=[], output_tensorshape_split_list=[1])
+            split_list = [2, 3]
+            table_name_list = ["table1", "table2"]
+            input_dict = dict(is_training=True, timestamp=True)
+
+            with self.assertRaises(RuntimeError):
+                send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict)
+
+
+class TestMergeFeatureIdRequestFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.merge_feature_id_request'.
+    """
+
+    def test_merge_feature_id_request_case1(self):
+        """
+        case1: the input lists have unequal lengths, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import merge_feature_id_request
+
+        feature_id_list = [1, 2]
+        split_list = [1, 2]
+        table_name_list = ["table1"]
+
+        with self.assertRaises(RuntimeError):
+            merge_feature_id_request(feature_id_list, split_list, table_name_list)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_modify_graph=mock.MagicMock(return_value=False))
+    def test_merge_feature_id_request_case2(self):
+        """
+        case2: without automatic graph modification
+        """
+
+        from mx_rec.core.asc.helper import merge_feature_id_request
+
+        with tf.Graph().as_default():
+            feature_id_list = [
+                tf.constant([2, ], dtype=tf.int64),
+                tf.constant([3, ], dtype=tf.int64),
+                tf.constant([4, ], dtype=tf.int64)
+            ]
+            split_list = [2, 3, 4]
+            table_name_list = ["table1", "table2", "table1"]
+
+            output_dict = merge_feature_id_request(feature_id_list, split_list, table_name_list)
+            '''
+            After sorting: table_name_list = ["table1", "table1", "table2"]
+                           split_list = [2, 4, 3]
+                           feature_id_list = [tf.constant([2, ], dtype=tf.int64),
+                                              tf.constant([4, ], dtype=tf.int64),
+                                              tf.constant([3, ], dtype=tf.int64)]
+            After merging: output_feature_id_list = [tf.constant([2, ], dtype=tf.int64),
+                                                     tf.constant([4, ], dtype=tf.int64),
+                                                     tf.constant([3, ], dtype=tf.int64)]
+                           output_split_list = [6, 3]
+                           output_table_name_list = ["table1", "table2"]
+                           output_tensorshape_split_list = [Tensor(2), Tensor(1)]  # shape
+            '''
+            output_feature_id_list = [
+                tf.constant([2, ], dtype=tf.int64),
+                tf.constant([4, ], dtype=tf.int64),
+                tf.constant([3, ], dtype=tf.int64)
+            ]
+            output_split_list = [6, 3]
+            output_table_name_list = ["table1", "table2"]
+            output_tensorshape_split_list = [
+                tf.math.reduce_prod(tf.shape(output_feature_id_list[0])) + tf.math.reduce_prod(
+                    tf.shape(output_feature_id_list[1])),
+                tf.math.reduce_prod(tf.shape(output_feature_id_list[2]))
+            ]
+            self.assertListEqual(output_dict.get("output_split_list"), output_split_list)
+            self.assertListEqual(output_dict.get("output_table_name_list"), output_table_name_list)
+
+            with tf.Session() as sess:
+                for real_output_feature_id, expected_output_feature_id in zip(output_dict.get("output_feature_id_list"),
+                                                                              output_feature_id_list):
+                    self.assertListEqual(sess.run(real_output_feature_id).tolist(),
+                                         sess.run(expected_output_feature_id).tolist())
+
+                for real_output_tensorshape_split, expected_output_tensorshape_split in zip(
+                        output_dict.get("output_tensorshape_split_list"), output_tensorshape_split_list):
+                    self.assertEqual(sess.run(real_output_tensorshape_split),
+                                     sess.run(expected_output_tensorshape_split))
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         get_modify_graph=mock.MagicMock(return_value=True))
+    def test_merge_feature_id_request_case3(self):
+        """
+        case3: with automatic graph modification
+        """
+
+        from mx_rec.core.asc.helper import merge_feature_id_request
+
+        with tf.Graph().as_default():
+            feature_id_list = [
+                tf.constant([2, ], dtype=tf.int64),
+                tf.constant([3, ], dtype=tf.int64),
+                tf.constant([4, ], dtype=tf.int64)
+            ]
+            split_list = [2, 3, 4]
+            table_name_list = ["table1", "table2", "table1"]
+
+            # The automatic graph modification flow is similar to the non-automatic one,
+            # except that sorting is based on table_name_list only
+            output_dict = merge_feature_id_request(feature_id_list, split_list, table_name_list)
+            output_feature_id_list = [
+                tf.constant([2, ], dtype=tf.int64),
+                tf.constant([4, ], dtype=tf.int64),
+                tf.constant([3, ], dtype=tf.int64)
+            ]
+            output_split_list = [6, 3]
+            output_table_name_list = ["table1", "table2"]
+            output_tensorshape_split_list = [
+                tf.math.reduce_prod(tf.shape(output_feature_id_list[0])) + tf.math.reduce_prod(
+                    tf.shape(output_feature_id_list[1])),
+                tf.math.reduce_prod(tf.shape(output_feature_id_list[2]))
+            ]
+            self.assertListEqual(output_dict.get("output_split_list"), output_split_list)
+            self.assertListEqual(output_dict.get("output_table_name_list"), output_table_name_list)
+
+            with tf.Session() as sess:
+                for real_output_feature_id, expected_output_feature_id in zip(output_dict.get("output_feature_id_list"),
+                                                                              output_feature_id_list):
+                    self.assertListEqual(sess.run(real_output_feature_id).tolist(),
+                                         sess.run(expected_output_feature_id).tolist())
+
+                for real_output_tensorshape_split, expected_output_tensorshape_split in zip(
+                        output_dict.get("output_tensorshape_split_list"), output_tensorshape_split_list):
+                    self.assertEqual(sess.run(real_output_tensorshape_split),
+                                     sess.run(expected_output_tensorshape_split))
+
+
+class TestExportReadEmbKeyV2OpFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.export_read_emb_key_v2_op'.
+    """
+
+    def test_export_read_emb_key_v2_op_case1(self):
+        """
+        case1: the length of the args argument is less than 1
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with self.assertRaises(ValueError):
+            export_read_emb_key_v2_op([], "")
+
+    def test_export_read_emb_key_v2_op_case2(self):
+        """
+        case2: args = ({"user_ids": tensor1, "item_ids": tensor2}, (tensor3, tensor4)), a common
+        graph-modification scenario; takes the `if isinstance(origin_batch[0], dict)` branch
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = (
+                dict(user_ids=tf.constant(1, dtype=tf.int64), item_ids=tf.constant(1, dtype=tf.int64)),
+                (tf.constant(3, dtype=tf.int64), tf.constant(3, dtype=tf.int64))
+            )
+            pipeline_op = tf.constant(0, dtype=tf.int64)
+
+            output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+            self.assertIsInstance(output_batch, dict)
+            read_emb_key = f"{sorted(output_batch)[-1]}"
+            self.assertEqual(read_emb_key, "user_ids_read_emb_key")
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(output_batch.get(read_emb_key)), 0)
+
+    def test_export_read_emb_key_v2_op_case3(self):
+        """
+        case3: args = [tensor1]
+        Takes the `elif len(origin_batch) == 1 and isinstance(origin_batch[0], tf.Tensor)` branch
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = [tf.constant(1, dtype=tf.int64)]
+            pipeline_op = tf.constant(0, dtype=tf.int64)
+
+            # pipeline_op is inserted as the last element of the batch
+            output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+            self.assertIsInstance(output_batch, tuple)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(output_batch[-1]), 0)
+
+    def test_export_read_emb_key_v2_op_case4(self):
+        """
+        case4: args = [[tensor1], tensor2]
+        Takes the `elif len(origin_batch) == 2` branch,
+        then the `if isinstance(origin_batch[0], (list, tuple))` branch
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = [[tf.constant(1, dtype=tf.int64)], tf.constant(2, dtype=tf.int64)]
+            pipeline_op = tf.constant(0, dtype=tf.int64)
+
+            # pipeline_op is inserted as the last element of batch[0]
+            output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+            self.assertIsInstance(output_batch, tuple)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(output_batch[0][-1]), 0)
+
+    def test_export_read_emb_key_v2_op_case5(self):
+        """
+        case5: args = [tensor1, tensor2]
+        Takes the `elif len(origin_batch) == 2` branch,
+        then the `elif isinstance(origin_batch[0], tf.Tensor)` branch
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = [tf.constant(1, dtype=tf.int64), tf.constant(2, dtype=tf.int64)]
+            pipeline_op = tf.constant(0, dtype=tf.int64)
+
+            # pipeline_op is inserted as the last element of batch[0]
+            output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+            self.assertIsInstance(output_batch, tuple)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(output_batch[0][-1]), 0)
+
+    def test_export_read_emb_key_v2_op_case6(self):
+        """
+        case6: args = [set(tensor1), tensor2]; args[0] is not a list, tuple, or tensor, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = [{tf.constant(1, dtype=tf.int64)}, tf.constant(2, dtype=tf.int64)]
+            pipeline_op = tf.constant(0, dtype=tf.int64)
+
+            with self.assertRaises(EnvironmentError):
+                export_read_emb_key_v2_op(args, pipeline_op)
+
+    def test_export_read_emb_key_v2_op_case7(self):
+        """
+        case7: args = [tensor1, tensor2, tensor3], pipeline_op is a list [tensor]
+        """
+
+        from mx_rec.core.asc.helper import export_read_emb_key_v2_op
+
+        with tf.Graph().as_default():
+            args = [tf.constant(1, dtype=tf.int64), tf.constant(2, dtype=tf.int64), tf.constant(3, dtype=tf.int64)]
+            pipeline_op = [tf.constant(0, dtype=tf.int64)]
+
+            # (pipeline_op,) is inserted as the last element of the batch
+            output_batch = export_read_emb_key_v2_op(args, pipeline_op)
+            self.assertIsInstance(output_batch, tuple)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(output_batch[-1][0]), 0)
+
+
+class TestGetTargetTensorsWithArgsIndexesFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.get_target_tensors_with_args_indexes'.
+    """
+
+    def test_get_target_tensors_with_args_indexes(self):
+        """
+        case: this function can only be tested through a map_fn; otherwise
+        `graph.get_tensor_by_name("args_%d:0" % index)` fails immediately
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_args_indexes
+
+        with tf.Graph().as_default():
+            args_index_list = [0, 1]
+            batch_size = 2
+            dataset_config = Config(batch_size=batch_size, batch_number=5)
+
+            def _map_fn(args):
+                insert_tensors = get_target_tensors_with_args_indexes(args_index_list)
+                args["new_item_ids"] = insert_tensors[0]
+                args["new_label_0"] = insert_tensors[1]
+                return args
+
+            dataset = generate_dataset(dataset_config)
+            dataset = dataset.map(_map_fn)
+            '''
+            Original batch:
+                batch = {"item_ids": Tensor([2, 8]), "label_0": Tensor([2, ])}
+            Batch after map:
+                batch = {"item_ids": Tensor([2, 8]), "label_0": Tensor([2, ]),
+                         "new_item_ids": Tensor([16, ]), "new_label_0": Tensor([2, ])}
+            '''
+            iterator = dataset.make_initializable_iterator()
+            batch = iterator.get_next()
+            with tf.Session() as sess:
+                sess.run(iterator.initializer)
+                sess.run(tf.compat.v1.global_variables_initializer())
+                self.assertEqual(sess.run(tf.shape(batch.get("new_item_ids"))),
+                                 batch_size * dataset_config.item_feat_cnt)
+                self.assertEqual(sess.run(tf.shape(batch.get("new_label_0"))), batch_size)
+
+
+class TestGetTargetTensorsWithFeatureSpecFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.get_target_tensors_with_feature_specs'.
+    """
+
+    def setUp(self):
+        # Before each test method runs, reset FeatureSpec's static members to their defaults
+        FeatureSpec.instance_count_train = 0
+        FeatureSpec.instance_count_eval = 0
+        FeatureSpec.use_timestamp_train = False
+        FeatureSpec.use_timestamp_eval = False
+
+    def test_get_target_tensors_with_feature_specs_case1(self):
+        """
+        case1: tgt_key_specs is none of dict, list, tuple, or FeatureSpec, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with self.assertRaises(ValueError):
+            get_target_tensors_with_feature_specs(tgt_key_specs=None, batch=dict(), is_training=True,
+                                                  read_emb_key_inputs_dict=dict())
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int32), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case2(self):
+        """
+        case2: tgt_key_specs is a list or tuple made up of FeatureSpec, batch is a dict
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = {"ids1": tf.constant(1, dtype=tf.int64), "timestamp": tf.constant(2, dtype=tf.int64)}
+            tgt_key_specs = [
+                FeatureSpec("ids1", table_name="table1"),
+                FeatureSpec("timestamp", table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+            self.assertListEqual(read_emb_key_inputs_dict.get("table_names"), ["table1"])
+            self.assertListEqual(read_emb_key_inputs_dict.get("feature_spec_names"), [tgt_key_specs[0].name])
+            self.assertListEqual(read_emb_key_inputs_dict.get("splits"), [1])
+            # timestamp comes first in insert_tensors, e.g. [timestamp, ids1]
+            insert_tensors = read_emb_key_inputs_dict.get("insert_tensors")
+            self.assertEqual(len(insert_tensors), len(tgt_key_specs))
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(insert_tensors[0]), 2)  # timestamp
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case3(self):
+        """
+        case3: tgt_key_specs is a list or tuple made up of FeatureSpec, batch is a dict;
+        a value in batch is not a tensor, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = {"ids1": 1, "timestamp": tf.constant(2, dtype=tf.int64)}
+            tgt_key_specs = [
+                FeatureSpec("ids1", table_name="table1"),
+                FeatureSpec("timestamp", table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            with self.assertRaises(TypeError):
+                get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case4(self):
+        """
+        case4: tgt_key_specs is a list or tuple made up of FeatureSpec, batch is a dict;
+        tgt_key_specs contains a timestamp but batch does not, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = {"ids1": tf.constant(1, dtype=tf.int64)}
+            tgt_key_specs = [
+                FeatureSpec("ids1", table_name="table1"),
+                FeatureSpec("timestamp", table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            with self.assertRaises(KeyError):
+                get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case5(self):
+        """
+        case5: tgt_key_specs is a list or tuple made up of FeatureSpec, batch is a set, an exception is raised
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = {tf.constant(1, dtype=tf.int64)}
+            tgt_key_specs = [
+                FeatureSpec("ids1", table_name="table1"),
+                FeatureSpec("timestamp", table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            with self.assertRaises(ValueError):
+                get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case6(self):
+        """
+        case6: tgt_key_specs is a list or tuple made up of FeatureSpec, batch is a list
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = [tf.constant(1, dtype=tf.int64), tf.constant(2, dtype=tf.int64)]
+            tgt_key_specs = [
+                FeatureSpec("ids1", index_key=0, table_name="table1"),
+                FeatureSpec("timestamp", index_key=1, table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+            self.assertListEqual(read_emb_key_inputs_dict.get("table_names"), ["table1"])
+            self.assertListEqual(read_emb_key_inputs_dict.get("feature_spec_names"), [tgt_key_specs[0].name])
+            self.assertListEqual(read_emb_key_inputs_dict.get("splits"), [1])
+            # timestamp comes first in insert_tensors, e.g. [timestamp, ids1]
+            insert_tensors = read_emb_key_inputs_dict.get("insert_tensors")
+            self.assertEqual(len(insert_tensors), len(tgt_key_specs))
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(insert_tensors[0]), 2)  # timestamp
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
"tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1})) + def test_get_target_tensors_with_feature_specs_case7(self): + """ + case7: tgt_key_specs为list或tuple,并由FeatureSpec组成,batch为list + index_key的长度大于len(batch),抛出异常 + """ + + from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs + + with tf.Graph().as_default(): + batch = [tf.constant(1, dtype=tf.int64), tf.constant(2, dtype=tf.int64)] + tgt_key_specs = [ + FeatureSpec("ids1", index_key=3, table_name="table1"), + FeatureSpec("timestamp", index_key=1, table_name="table2", is_timestamp=True) + ] + is_training = True + read_emb_key_inputs_dict = { + "insert_tensors": [], + "table_names": [], + "feature_spec_names": [], + "splits": [] + } + + with self.assertRaises(ValueError): + get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict) + + @mock.patch.multiple("mx_rec.core.asc.helper", + is_feature_spec_list=mock.MagicMock(return_value=True)) + @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec", + set_feat_attribute=mock.MagicMock(return_value={ + "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1})) + def test_get_target_tensors_with_feature_specs_case8(self): + """ + case8: tgt_key_specs为list或tuple,并由FeatureSpec组成,batch为list + batch中不为tensor,抛出异常 + """ + + from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs + + with tf.Graph().as_default(): + batch = [1, tf.constant(2, dtype=tf.int64)] + tgt_key_specs = [ + FeatureSpec("ids1", index_key=0, table_name="table1"), + FeatureSpec("timestamp", index_key=1, table_name="table2", is_timestamp=True) + ] + is_training = True + read_emb_key_inputs_dict = { + "insert_tensors": [], + "table_names": [], + "feature_spec_names": [], + "splits": [] + } + + with self.assertRaises(TypeError): + get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict) + + @mock.patch.multiple("mx_rec.core.asc.helper", + is_feature_spec_list=mock.MagicMock(return_value=True)) + @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec", + set_feat_attribute=mock.MagicMock(return_value={ + "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1})) + def test_get_target_tensors_with_feature_specs_case9(self): + """ + case9: tgt_key_specs为dict,并由FeatureSpec组成,batch为dict + """ + + from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs + + with tf.Graph().as_default(): + batch = { + "ids1": {"ids1": tf.constant(1, dtype=tf.int64), "timestamp": tf.constant(2, dtype=tf.int64)}, + "timestamp": {"ids1": tf.constant(1, dtype=tf.int64), "timestamp": tf.constant(2, dtype=tf.int64)} + } + tgt_key_specs = { + "ids1": FeatureSpec("ids1", table_name="table1"), + "timestamp": FeatureSpec("timestamp", table_name="table1", is_timestamp=True), + } + is_training = True + read_emb_key_inputs_dict = { + "insert_tensors": [], + "table_names": [], + "feature_spec_names": [], + "splits": [] + } + + get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict) + self.assertListEqual(read_emb_key_inputs_dict.get("table_names"), ["table1"]) + self.assertListEqual(read_emb_key_inputs_dict.get("feature_spec_names"), [tgt_key_specs.get("ids1").name]) + self.assertListEqual(read_emb_key_inputs_dict.get("splits"), [1]) + # insert_tensors中timestamp会在第一个,如:[timestamp, ids1] + insert_tensors = read_emb_key_inputs_dict.get("insert_tensors") + self.assertEqual(len(insert_tensors), len(tgt_key_specs)) + with 
tf.Session() as sess:
+                self.assertEqual(sess.run(insert_tensors[0]), 2)  # timestamp
+
+    @mock.patch.multiple("mx_rec.core.asc.helper",
+                         is_feature_spec_list=mock.MagicMock(return_value=True))
+    @mock.patch.multiple("mx_rec.core.asc.helper.FeatureSpec",
+                         set_feat_attribute=mock.MagicMock(return_value={
+                             "tensor": tf.constant(1, dtype=tf.int64), "table_name": "table1", "split": 1}))
+    def test_get_target_tensors_with_feature_specs_case10(self):
+        """
+        case10: tgt_key_specs is a list or tuple of FeatureSpec objects and batch is also a list or tuple
+        """
+
+        from mx_rec.core.asc.helper import get_target_tensors_with_feature_specs
+
+        with tf.Graph().as_default():
+            batch = [tf.constant(1, dtype=tf.int64), tf.constant(2, dtype=tf.int64)]
+            tgt_key_specs = [
+                FeatureSpec("ids1", index_key=0, table_name="table1"),
+                FeatureSpec("timestamp", index_key=1, table_name="table2", is_timestamp=True)
+            ]
+            is_training = True
+            read_emb_key_inputs_dict = {
+                "insert_tensors": [],
+                "table_names": [],
+                "feature_spec_names": [],
+                "splits": []
+            }
+
+            get_target_tensors_with_feature_specs(tgt_key_specs, batch, is_training, read_emb_key_inputs_dict)
+            self.assertListEqual(read_emb_key_inputs_dict.get("table_names"), ["table1"])
+            self.assertListEqual(read_emb_key_inputs_dict.get("feature_spec_names"), [tgt_key_specs[0].name])
+            self.assertListEqual(read_emb_key_inputs_dict.get("splits"), [1])
+            # the timestamp tensor comes first in insert_tensors, e.g. [timestamp, ids1]
+            insert_tensors = read_emb_key_inputs_dict.get("insert_tensors")
+            self.assertEqual(len(insert_tensors), len(tgt_key_specs))
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(insert_tensors[0]), 2)  # timestamp
+
+
+class TestIsFeatureSpecListFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.helper.is_feature_spec_list'.
+    """
+
+    def test_is_feature_spec_list_case1(self):
+        """
+        case1: specs is not a list or tuple, so False is returned
+        """
+
+        from mx_rec.core.asc.helper import is_feature_spec_list
+
+        self.assertFalse(is_feature_spec_list(dict()))
+
+    def test_is_feature_spec_list_case2(self):
+        """
+        case2: specs is a list or tuple whose elements are not FeatureSpec, so False is returned
+        """
+
+        from mx_rec.core.asc.helper import is_feature_spec_list
+
+        self.assertFalse(is_feature_spec_list(["xxx"]))
+
+    def test_is_feature_spec_list_case3(self):
+        """
+        case3: specs is a list or tuple of FeatureSpec, so True is returned
+        """
+
+        from mx_rec.core.asc.helper import is_feature_spec_list
+
+        self.assertTrue(is_feature_spec_list([FeatureSpec("xxx")]))
+
+
+if __name__ == '__main__':
+    unittest.main()
--
Gitee

From f85b94c15c5a1accbe5c5f42e229d6dcf43544e3 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 14 Dec 2023 21:02:11 +0800
Subject: [PATCH 517/551] Match-id-99e244561816f95ee1261e4fe9535a8d1579d308

---
 mx_rec/core/embedding.py            |  26 +-
 mx_rec/util/initialize.py           |  25 +
 tests/mx_rec/core/mock_class.py     |  66 ++-
 tests/mx_rec/core/test_embedding.py | 831 ++++++++++++++++++++++++++++
 4 files changed, 927 insertions(+), 21 deletions(-)
 create mode 100644 tests/mx_rec/core/test_embedding.py

diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 025c1d22..7d96eb6c 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -24,11 +24,11 @@ from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, ge
     insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
     clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \
     get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \
-    
get_table_instance_by_name, get_asc_manager
+    get_asc_manager, get_table_name_to_feature_spec, clear_same_table_feature_spec
 from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \
     para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, \
     OptionalStringValidator, FloatValidator
-from mx_rec.util.tf_version_adapter import npu_ops
+from mx_rec.util.tf_version_adapter import hccl_ops
 from mx_rec.util.normalization import fix_invalid_table_name
 from mx_rec.util.global_env_conf import global_env
 from mx_rec.util.log import logger
@@ -122,7 +122,6 @@ class SparseEmbedding:
         self._slot_num = dict()
         self._send_count = 0
         self.same_table_send_count = 0
-        self._use_feature_mapping = False
         self.skip_emb_transfer = True if self.host_vocabulary_size <= 0 else False
         self._default_name_count = -1
         self.emb_size = None
@@ -155,10 +154,6 @@ class SparseEmbedding:
             self.set_ext_emb_size()
         tf.compat.v1.add_to_collection(get_ascend_global_hashtable_collection(), self.variable)
 
-    @property
-    def use_feature_mapping(self):
-        return self._use_feature_mapping
-
     @property
     def scalar_emb_size(self):
         return self.emb_size
@@ -181,7 +176,7 @@ class SparseEmbedding:
         """
         Args:
             channel_id: channel id, 0 for train, 1 for eval
-        Returns: npu_ops.outfeed_enqueue_op notify preprocess step
+        Returns: tf.no_op notify preprocess step
         """
         channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id)
         notify_hybridmgmt_op = tf.no_op(channel_name)
@@ -190,14 +185,14 @@ class SparseEmbedding:
     @staticmethod
     def get_anchor_attribute(anchor, attr):
         if not isinstance(anchor, tf.Tensor):
-            raise ValueError("Anchor must be a Tensor.")
+            raise TypeError("Anchor must be a Tensor.")
 
         if attr not in ASCAnchorAttr:
             raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.")
 
         specs = SparseEmbedding.anchor_tensor_specs.get(anchor)
         if specs is None:
-            raise ValueError(f"Given anchor '{anchor}' was not registered.")
+            raise KeyError(f"Given anchor '{anchor}' was not registered.")
 
         return specs.get(attr)
 
@@ -220,7 +215,7 @@ class SparseEmbedding:
         :param use_static: enable static shape training or not
         :return: local embedding after all2all
         """
-        from mx_rec.util.tf_version_adapter import hccl_ops
+
         rank_size = get_rank_size()
         rank_id = get_rank_id()
 
@@ -309,9 +304,6 @@ class SparseEmbedding:
         logger.debug("getting one default lookup name %s", default_name)
         return default_name
 
-    def set_using_feature_mapping(self):
-        self._use_feature_mapping = True
-
     def set_emb_size(self):
         self.emb_size = self.embedding_size.as_list()[0]
 
@@ -491,7 +483,7 @@ class SparseEmbedding:
         if not get_use_static() and not self.modify_graph and kwargs.get("batch") is None:
             raise RuntimeError("When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.")
         table_name = feature_spec.table_name
-        same_table_feature_spec = ConfigInitializer.get_instance().table_name_to_feature_spec[table_name][is_training]
+        same_table_feature_spec = get_table_name_to_feature_spec(table_name, is_training)
         logger.debug("The feature spec of the same table is %s, table name is %s.",
                      [fs.name for fs in same_table_feature_spec], self.table_name)
         same_table_spec_count = len(same_table_feature_spec)
@@ -548,7 +540,7 @@ class SparseEmbedding:
             self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result,
                                      is_training)
         # after multi-lookup on one table finishes, clear this table's feature specs so that estimator-mode multi-round eval does not accumulate the previous round's feature specs
-        
ConfigInitializer.get_instance().clear_same_table_feature_spec(self.table_name, is_training)
+        clear_same_table_feature_spec(self.table_name, is_training)
 
         if not self.modify_graph:
             self.check_multi_lookup_times(is_training)
@@ -686,7 +678,7 @@ class SparseEmbedding:
             dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0)
             lookup_result = array_ops.reshape(embeddings, dest_shape)
 
-        def grad(lookup_diff):
+        def grad(lookup_diff):  # pragma: no cover
             logger.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
             embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size])
             unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff,
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index 5005f647..e33a12c3 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -835,6 +835,31 @@ def set_iterator_type(iterator_type: str):
     ConfigInitializer.get_instance().iterator_type = iterator_type
 
 
+def get_table_name_to_feature_spec(table_name: str, is_training: bool):
+    """
+    Get all FeatureSpecs registered for the same table.
+    Args:
+        table_name: table name
+        is_training: whether in training mode
+    Returns: list of FeatureSpecs
+    """
+
+    same_table_feature_spec_dict = ConfigInitializer.get_instance().table_name_to_feature_spec.get(table_name)
+    return same_table_feature_spec_dict.get(is_training)
+
+
+def clear_same_table_feature_spec(table_name: str, is_training: bool):
+    """
+    Clear the feature spec list associated with the given table.
+    Args:
+        table_name: table name
+        is_training: whether in training mode
+    Returns: None
+    """
+
+    ConfigInitializer.get_instance().clear_same_table_feature_spec(table_name, is_training)
+
+
 def set_ascend_env():
     """
     Configure Ascend-related parameters and environment variables, and generate the hccl configuration
diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py
index 5fe33363..58737058 100644
--- a/tests/mx_rec/core/mock_class.py
+++ b/tests/mx_rec/core/mock_class.py
@@ -3,6 +3,9 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 
 import tensorflow as tf
+from tensorflow_core.python.training import slot_creator
+
+from mx_rec.optimizers.lazy_adam import CustomizedLazyAdam
 
 
 class MockSparseEmbedding:
@@ -58,16 +61,59 @@ class MockHcclOps:
         self.all_to_all_v_c = _mock_all_to_all_v_c
 
 
-class MockOptimizer:
+class MockOptimizer(CustomizedLazyAdam):
     """
     Used to mock an optimizer
     """
 
     def __init__(self):
-        def _mock_insert_slot(slot, named_slot_key, slot_name):
-            return "mock_insert_slot"
+        super(MockOptimizer, self)._get_name(name="MockLazyAdam")
+        super(MockOptimizer, self).__init__(learning_rate=0.001, beta1=0.9, beta2=0.999,
+                                            epsilon=1e-8, use_locking=False, name="MockLazyAdam")
+        self.slot_num = 2
+
+    def initialize_slots(self, var, table_instance):
+        # Create slots for the first and second moments.
+        def creat_one_single_slot(var, op_name):
+            new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+            return new_slot_variable
+
+        momentum = creat_one_single_slot(var, self._name + "/" + "momentum")
+        velocity = creat_one_single_slot(var, self._name + "/" + "velocity")
+        named_slot_key = (var.op.graph, var.op.name)
+
+        table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+        return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self},
+                {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        pass
+
+    def get_slot_init_values(self):
+        initial_momentum_value = 0.0
+        initial_velocity_value = 0.0
+        return [initial_momentum_value, initial_velocity_value]
+
+    def update_op(self, optimizer, g):
+        return super().update_op(optimizer, g)
+
+    def _apply_spare_duplicate_indices(self, grad, var):
+        return self._apply_sparse(grad, var)
+
+    def _apply_sparse(self, grad, var):
+        return super()._apply_sparse(grad, var)
 
-        self.insert_slot = _mock_insert_slot
+    def _resource_apply_sparse(self, grad, handle, indices):
+        return super()._resource_apply_sparse(grad, handle, indices)
+
+    def _apply_dense(self, grad, var):
+        return super()._apply_dense(grad, var)
+
+    def _apply_sparse_duplicate_indices(self, grad, var):
+        return self._apply_sparse(grad, var)
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
+        return self._resource_apply_sparse(grad, handle, indices)
 
 
 class MockAscManager:
@@ -84,3 +130,15 @@
 
         self.get_table_size = _mock_get_table_size
         self.get_table_capacity = _mock_get_table_capacity
+
+
+class MockHybridMgmt:
+    """
+    Used to mock HybridMgmt()
+    """
+
+    def __init__(self, is_initialized=True):
+        def _mock_initialize(rank_info=0, emb_info=1, if_load=False, threshold_values=3):
+            return is_initialized
+
+        self.initialize = _mock_initialize
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
new file mode 100644
index 00000000..0ce32500
--- /dev/null
+++ b/tests/mx_rec/core/test_embedding.py
@@ -0,0 +1,831 @@
+#!/usr/bin/env python3
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+
+import os
+import unittest
+from unittest import mock
+
+import tensorflow as tf
+
+from mx_rec.core.asc import FeatureSpec
+from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute
+from mx_rec.core.embedding import SparseEmbedding
+from mx_rec.constants.constants import All2allGradientsOp, ASCAnchorAttr
+from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockOptimizer, MockHcclOps, MockAscManager
+
+
+class TestCreateTableFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.embedding.create_table'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         fix_invalid_table_name=mock.MagicMock(return_value="table1"),
+                         SparseEmbedding=mock.MagicMock(return_value=MockSparseEmbedding()))
+    def test_create_table(self):
+        """
+        case: test create_table
+        """
+
+        from mx_rec.core.embedding import create_table
+
+        test_table = create_table(key_dtype=tf.int64,
+                                  dim=tf.TensorShape([8]),
+                                  name='test_table',
+                                  emb_initializer=tf.compat.v1.truncated_normal_initializer())
+        self.assertIsInstance(test_table, MockSparseEmbedding)
+
+
+class TestSparseEmbeddingClass(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.embedding.SparseEmbedding'.
+ """ + + def setUp(self): + key_dtype = tf.int64 + dim = 8 + name = 'test_table' + emb_initializer = tf.compat.v1.truncated_normal_initializer() + optimizer_list = [MockOptimizer()] + device_vocabulary_size = 1 + host_vocabulary_size = 2 + ssd_vocabulary_size = 3 + ssd_data_path = (os.getcwd(),) + is_save = True + init_param = 1. + all2all_gradients_op = All2allGradientsOp.SUM_GRADIENTS.value + + self.config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, + device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, + ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, + optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, + all2all_gradients_op=all2all_gradients_op) + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None)) + def test_init(self): + """ + case: test create SparseEmbedding + + """ + + with tf.Graph().as_default(): + test_sparse_emb = SparseEmbedding(self.config) + self.assertIsInstance(test_sparse_emb, SparseEmbedding) + + +class TestGenerateLookupIdNotifyHybridFuncOfSparseEmbeddingClass(unittest.TestCase): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.generate_lookup_id_notify_hybrid'. + """ + + def test_generate_lookup_id_notify_hybrid(self): + """ + case: test generate_lookup_id_notify_hybrid + """ + + with tf.Graph().as_default(): + self.assertEqual(SparseEmbedding.generate_lookup_id_notify_hybrid(0).name, "d2h_notify_hybridmgmt_0") + + +class TestGetAnchorAttributeFuncOfSparseEmbeddingClass(unittest.TestCase): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.get_anchor_attribute'. + """ + + def test_get_anchor_attribute_case1(self): + """ + case1: 功能正常 + """ + + with tf.Graph().as_default(): + anchor_ids = tf.constant(1, dtype=tf.int64) + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = True + self.assertTrue(SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING)) + + def test_get_anchor_attribute_case2(self): + """ + case2: anchor_ids不是tensor,抛出异常 + """ + + with tf.Graph().as_default(): + anchor_ids = 1 + SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = True + with self.assertRaises(TypeError): + SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING) + + def test_get_anchor_attribute_case3(self): + """ + case3: attr不是ASCAnchorAttr,抛出异常 + """ + + with tf.Graph().as_default(): + anchor_ids = tf.constant(1, dtype=tf.int64) + SparseEmbedding.anchor_tensor_specs[anchor_ids]["xxx"] = True + with self.assertRaises(ValueError): + SparseEmbedding.get_anchor_attribute(anchor_ids, "xxx") + + def test_get_anchor_attribute_case4(self): + """ + case4: 没有set直接get,抛出异常 + """ + + with tf.Graph().as_default(): + anchor_ids = tf.constant(1, dtype=tf.int64) + with self.assertRaises(KeyError): + SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING) + + +class TestGetOwnEmbFuncOfSparseEmbeddingClass(unittest.TestCase): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding._get_own_emb'. 
+ """ + + @mock.patch.multiple("mx_rec.core.embedding", + get_rank_size=mock.MagicMock(return_value=1), + get_rank_id=mock.MagicMock(return_value=0), + hccl_ops=MockHcclOps()) + def test_get_own_emb_case1(self): + """ + case1: rank=1,静态shape + """ + + with tf.Graph().as_default(): + src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") + all2all_args = 2 + emb_size = 1 + use_static = True + + # reshape_info为[2, 1] + own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) + self.assertListEqual(own_emb.shape.as_list(), [2, 1]) + + @mock.patch.multiple("mx_rec.core.embedding", + get_rank_size=mock.MagicMock(return_value=8), + get_rank_id=mock.MagicMock(return_value=0), + hccl_ops=MockHcclOps(shape=[2 * 8, 1])) + def test_get_own_emb_case2(self): + """ + case2: rank=8,静态shape + """ + + with tf.Graph().as_default(): + src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") + all2all_args = 2 + emb_size = 1 + use_static = True + mock_shape = [all2all_args * 8, emb_size] + + own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) + self.assertListEqual(own_emb.shape.as_list(), mock_shape) + + @mock.patch.multiple("mx_rec.core.embedding", + get_rank_size=mock.MagicMock(return_value=8), + get_rank_id=mock.MagicMock(return_value=0), + hccl_ops=MockHcclOps(shape=[2 * 8, 1])) + def test_get_own_emb_case3(self): + """ + case3: rank=8,动态shape + """ + + with tf.Graph().as_default(): + src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") + all2all_args = 2 + emb_size = 1 + use_static = False + mock_shape = [all2all_args * 8, emb_size] + + own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) + self.assertListEqual(own_emb.shape.as_list(), mock_shape) + + +class TestSizeFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.size'. + """ + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None), + get_asc_manager=mock.MagicMock(return_value=MockAscManager())) + def test_size(self): + """ + case: test size + """ + + with tf.Graph().as_default(): + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.size(), 0) + + +class TestCapacityFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.capacity'. 
+ """ + + def tearDown(self): + self.config["device_vocabulary_size"] = 1 + self.config["host_vocabulary_size"] = 2 + self.config["ssd_vocabulary_size"] = 3 + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=True), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None), + get_asc_manager=mock.MagicMock(return_value=MockAscManager())) + def test_capacity_case1(self): + """ + case1: 开启动态扩容,HBM + + """ + + with tf.Graph().as_default(): + self.config["host_vocabulary_size"] = 0 + self.config["ssd_vocabulary_size"] = 0 + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.capacity(), 1) + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None)) + def test_capacity_case2(self): + """ + case2: 关闭动态扩容,HBM + """ + + with tf.Graph().as_default(): + self.config["host_vocabulary_size"] = 0 + self.config["ssd_vocabulary_size"] = 0 + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.capacity(), 1) + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None)) + def test_capacity_case3(self): + """ + case3: 关闭动态扩容,DDR + """ + + with tf.Graph().as_default(): + self.config["ssd_vocabulary_size"] = 0 + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.capacity(), 3) + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None)) + def test_capacity_case4(self): + """ + case4: 关闭动态扩容,SSD + """ + + with tf.Graph().as_default(): + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.capacity(), 6) + + +class TestGetDefaultLookupNameFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.get_default_lookup_name'. 
+ """ + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None)) + def test_get_default_lookup_name(self): + """ + case: test get_default_lookup_name + """ + + with tf.Graph().as_default(): + test_sparse_emb = SparseEmbedding(self.config) + self.assertEqual(test_sparse_emb.get_default_lookup_name(), "sparse_lookup_0") + + +class TestLookupForAscFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): + """ + Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc'. + """ + + def tearDown(self): + self.config["device_vocabulary_size"] = 1 + self.config["host_vocabulary_size"] = 2 + self.config["ssd_vocabulary_size"] = 3 + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=False), + get_name_to_var_dict=mock.MagicMock(return_value=None), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=True)) + @mock.patch.multiple("mx_rec.core.embedding.FeatureSpec", + set_feat_attribute=mock.MagicMock(return_value=None)) + def test_lookup_for_asc_case1(self): + """ + case1: test lookup_for_asc,静态shape + """ + + with tf.Graph().as_default(): + def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): + return tf.constant(1, dtype=tf.int64) + + self.config["device_vocabulary_size"] = 100 * 8 + self.config["host_vocabulary_size"] = 100 * 8 + test_sparse_emb = SparseEmbedding(self.config) + test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner + ids = tf.ones(shape=[2, 1], dtype=tf.int64, name="ids") + send_count = 1 + kwargs = {"is_train": True} + + lookup_res = test_sparse_emb.lookup_for_asc(ids, send_count, **kwargs) + with tf.Session() as sess: + self.assertEqual(sess.run(lookup_res), 1) + + @mock.patch.multiple("mx_rec.core.embedding", + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + is_asc_frozen=mock.MagicMock(return_value=True), + get_name_to_var_dict=mock.MagicMock(return_value={"test_table": 1}), + get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + get_rank_size=mock.MagicMock(return_value=8), + insert_removing_var_list=mock.MagicMock(return_value=None), + insert_table_instance=mock.MagicMock(return_value=None), + get_training_mode_channel_id=mock.MagicMock(return_value=None), + clear_channel=mock.MagicMock(return_value=None), + get_use_static=mock.MagicMock(return_value=False)) + @mock.patch.multiple("mx_rec.core.embedding.FeatureSpec", + set_feat_attribute=mock.MagicMock(return_value=None)) + def test_lookup_for_asc_case2(self): + """ + case2: test lookup_for_asc,动态shape,is_training=False + """ + + with tf.Graph().as_default(): + def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): + return tf.constant(1, dtype=tf.int64) + + self.config["device_vocabulary_size"] = 100 * 8 + 
self.config["host_vocabulary_size"] = 100 * 8
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner
+            ids = tf.ones(shape=[2, 1], dtype=tf.int64, name="ids")
+            send_count = 1
+            kwargs = {"is_train": False}
+
+            lookup_res = test_sparse_emb.lookup_for_asc(ids, send_count, **kwargs)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(lookup_res), 1)
+
+
+class TestLookupForAscWithFeatureSpecFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass):
+    """
+    Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc_with_feature_spec'.
+    """
+
+    def tearDown(self):
+        self.config["device_vocabulary_size"] = 1
+        self.config["host_vocabulary_size"] = 2
+        self.config["ssd_vocabulary_size"] = 3
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec")
+    def test_lookup_for_asc_with_feature_spec_case1(self, mock_get_table_name_to_feature_spec):
+        """
+        case1: test lookup_for_asc_with_feature_spec, static shape, len(same_table_feature_spec)=1
+        """
+
+        with tf.Graph().as_default():
+            def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs):
+                return tf.constant(1, dtype=tf.int64)
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 100 * 8
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner
+            case1_feat = FeatureSpec("case1_feat", table_name="test_table")
+            set_temporary_feature_spec_attribute(case1_feat, 1)
+            mock_get_table_name_to_feature_spec.return_value = [case1_feat]
+            send_count = 1
+            kwargs = {"is_train": True}
+
+            lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec(case1_feat, send_count, **kwargs)
+            with tf.Session() as sess:
+                self.assertEqual(sess.run(lookup_res), 1)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec")
+    def test_lookup_for_asc_with_feature_spec_case2(self, mock_get_table_name_to_feature_spec):
+        """
+        case2: test lookup_for_asc_with_feature_spec, static shape, len(same_table_feature_spec)=0, so an exception is raised
+        """
+
+        with tf.Graph().as_default():
+            def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs):
+                return tf.constant(1, dtype=tf.int64)
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            
self.config["host_vocabulary_size"] = 100 * 8
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner
+            case2_feat = FeatureSpec("case2_feat", table_name="test_table")
+            set_temporary_feature_spec_attribute(case2_feat, 1)
+            mock_get_table_name_to_feature_spec.return_value = []
+            send_count = 1
+            kwargs = {"is_train": True}
+
+            with self.assertRaises(RuntimeError):
+                test_sparse_emb.lookup_for_asc_with_feature_spec(case2_feat, send_count, **kwargs)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         clear_same_table_feature_spec=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec")
+    def test_lookup_for_asc_with_feature_spec_case3(self, mock_get_table_name_to_feature_spec):
+        """
+        case3: test lookup_for_asc_with_feature_spec, static shape, len(same_table_feature_spec)>1
+        """
+
+        with tf.Graph().as_default():
+            def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs):
+                return tf.ones(shape=[16, ], dtype=tf.int64)
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 100 * 8
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner
+            case3_feat = FeatureSpec("case3_feat", table_name="test_table")
+            case3_feat_multi = FeatureSpec("case3_feat_multi", table_name="test_table")
+            set_temporary_feature_spec_attribute(case3_feat, 1)
+            set_temporary_feature_spec_attribute(case3_feat_multi, 1)
+            case3_feat.split = 8
+            case3_feat_multi.split = 8
+            mock_get_table_name_to_feature_spec.return_value = [case3_feat, case3_feat_multi]
+            send_count = 1
+            kwargs = {"is_train": True}
+
+            test_sparse_emb.lookup_for_asc_with_feature_spec(case3_feat, send_count, **kwargs)
+            self.assertGreater(len(test_sparse_emb.lookup_result), 0)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         clear_same_table_feature_spec=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=False))
+    @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec")
+    def test_lookup_for_asc_with_feature_spec_case4(self, mock_get_table_name_to_feature_spec):
+        """
+        case4: test lookup_for_asc_with_feature_spec, dynamic shape, len(same_table_feature_spec)>1
+        """
+
+        with tf.Graph().as_default():
+            def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs):
+                return tf.ones(shape=[16, ],
dtype=tf.int64)
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 100 * 8
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner
+            case4_feat = FeatureSpec("case4_feat", table_name="test_table")
+            case4_feat_multi = FeatureSpec("case4_feat_multi", table_name="test_table")
+            set_temporary_feature_spec_attribute(case4_feat, 1)
+            set_temporary_feature_spec_attribute(case4_feat_multi, 1)
+            case4_feat.split = 8
+            case4_feat_multi.split = 8
+            mock_get_table_name_to_feature_spec.return_value = [case4_feat, case4_feat_multi]
+            send_count = 1
+            kwargs = {
+                "is_train": True,
+                "batch": {
+                    "case4_feat": tf.ones(shape=[8, ], dtype=tf.int64),
+                    "case4_feat_multi": tf.ones(shape=[8, ], dtype=tf.int64)
+                }
+            }
+
+            test_sparse_emb.emb_size = 1
+            test_sparse_emb.lookup_for_asc_with_feature_spec(case4_feat, send_count, **kwargs)
+            self.assertGreater(len(test_sparse_emb.lookup_result), 0)
+
+
+class TestLookupForAscWithFeatureSpecInnerFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass):
+    """
+    Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc_with_feature_spec_inner'.
+    """
+
+    def tearDown(self):
+        self.config["device_vocabulary_size"] = 1
+        self.config["host_vocabulary_size"] = 2
+        self.config["ssd_vocabulary_size"] = 3
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         get_device_id=mock.MagicMock(return_value=0),
+                         get_use_hot=mock.MagicMock(return_value=1),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc")
+    def test_lookup_for_asc_with_feature_spec_inner_case1(self, mock_get_preprocessed_tensor_for_asc):
+        """
+        case1: test lookup_for_asc_with_feature_spec_inner, static shape, dynamic expansion disabled
+        """
+
+        with tf.Graph().as_default():
+            mock_get_preprocessed_tensor_for_asc.return_value = {
+                "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64),
+                "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64),
+                "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64),
+                "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64),
+                "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "swap_in": [tf.no_op()]
+            }
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 0
+            test_sparse_emb = SparseEmbedding(self.config)
+            case1_feat = FeatureSpec("case1_feat", table_name="test_table")
+            set_temporary_feature_spec_attribute(case1_feat, 1)
+            case1_feat.dims = [8, 8]
+            send_count = 1
+            kwargs = {"is_train": True}
+
+            def _mock_get_own_emb(emb, all2all_args, emb_size, use_static):
+                return test_sparse_emb.variable
+
+            test_sparse_emb._get_own_emb = _mock_get_own_emb
+
+            lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case1_feat, send_count, **kwargs)
+            self.assertIsInstance(lookup_res, tf.Tensor)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         
is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         get_device_id=mock.MagicMock(return_value=0),
+                         get_use_hot=mock.MagicMock(return_value=1),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=False))
+    @mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc")
+    def test_lookup_for_asc_with_feature_spec_inner_case2(self, mock_get_preprocessed_tensor_for_asc):
+        """
+        case2: test lookup_for_asc_with_feature_spec_inner, dynamic shape, dynamic expansion disabled
+        """
+
+        with tf.Graph().as_default():
+            mock_get_preprocessed_tensor_for_asc.return_value = {
+                "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64),
+                "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64),
+                "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64),
+                "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64),
+                "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "swap_in": [tf.no_op()]
+            }
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 0
+            test_sparse_emb = SparseEmbedding(self.config)
+            case2_feat = FeatureSpec("case2_feat", table_name="test_table")
+            set_temporary_feature_spec_attribute(case2_feat, 1)
+            case2_feat.dims = [8, 8]
+            send_count = 1
+            kwargs = {"is_train": True, "batch": {"case2_feat": tf.ones(shape=[8, 8], dtype=tf.int64)}}
+
+            def _mock_get_own_emb(emb, all2all_args, emb_size, use_static):
+                return test_sparse_emb.variable
+
+            test_sparse_emb._get_own_emb = _mock_get_own_emb
+
+            lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case2_feat, send_count, **kwargs)
+            self.assertIsInstance(lookup_res, tf.Tensor)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         get_device_id=mock.MagicMock(return_value=0),
+                         get_use_hot=mock.MagicMock(return_value=1),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None),
+                         get_training_mode_channel_id=mock.MagicMock(return_value=None),
+                         get_use_static=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc")
+    def test_lookup_for_asc_with_feature_spec_inner_case3(self, mock_get_preprocessed_tensor_for_asc):
+        """
+        case3: test lookup_for_asc_with_feature_spec_inner, static shape, dynamic expansion disabled;
+        access_threshold > 0, covering set_specific_value_for_non_valid_key()
+        """
+
+        with tf.Graph().as_default():
+            mock_get_preprocessed_tensor_for_asc.return_value = {
+                "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64),
+                "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64),
+                "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64),
+                "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64),
+                "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64),
+                "swap_in": [tf.no_op()]
+            }
+
+            self.config["device_vocabulary_size"] = 100 * 8
+            self.config["host_vocabulary_size"] = 0
+            
test_sparse_emb = SparseEmbedding(self.config)
+            case3_feat = FeatureSpec("case3_feat", table_name="test_table", access_threshold=10)
+            set_temporary_feature_spec_attribute(case3_feat, 1)
+            case3_feat.dims = [8, 8]
+            send_count = 1
+            kwargs = {"is_train": True}
+
+            def _mock_get_own_emb(emb, all2all_args, emb_size, use_static):
+                return test_sparse_emb.variable
+
+            test_sparse_emb._get_own_emb = _mock_get_own_emb
+
+            lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case3_feat, send_count, **kwargs)
+            self.assertIsInstance(lookup_res, tf.Tensor)
+
+
+class TestSparseLookupFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.embedding.sparse_lookup'.
+    """
+
+    def setUp(self):
+        key_dtype = tf.int64
+        dim = 8
+        name = 'test_table'
+        emb_initializer = tf.compat.v1.truncated_normal_initializer()
+        optimizer_list = [MockOptimizer()]
+        device_vocabulary_size = 1
+        host_vocabulary_size = 2
+        ssd_vocabulary_size = 3
+        ssd_data_path = (os.getcwd(),)
+        is_save = True
+        init_param = 1.
+        all2all_gradients_op = All2allGradientsOp.SUM_GRADIENTS.value
+
+        self.config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer,
+                           device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size,
+                           ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path,
+                           optimizer_list=optimizer_list, init_param=init_param, is_save=is_save,
+                           all2all_gradients_op=all2all_gradients_op)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None))
+    def test_sparse_lookup_case1(self):
+        """
+        case1: test sparse_lookup, FeatureSpec mode
+        """
+
+        from mx_rec.core.embedding import sparse_lookup
+
+        def _mock_lookup_for_asc_with_feature_spec(ids, send_count, **kwargs):
+            return 0
+
+        with tf.Graph().as_default():
+            case1_feat = FeatureSpec("case1_feat", table_name="test_table")
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc_with_feature_spec = _mock_lookup_for_asc_with_feature_spec
+
+            self.assertEqual(sparse_lookup(test_sparse_emb, case1_feat), 0)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         set_modify_graph=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None))
+    def test_sparse_lookup_case2(self):
+        """
+        case2: test sparse_lookup, automatic graph-modification mode
+        """
+
+        from mx_rec.core.embedding import sparse_lookup
+
+        def _mock_lookup_for_asc(ids, send_count, **kwargs):
+            return 1
+
+        with tf.Graph().as_default():
+            ids = tf.constant(1, tf.int64)
+            test_sparse_emb = SparseEmbedding(self.config)
+            test_sparse_emb.lookup_for_asc = _mock_lookup_for_asc
+
+            self.assertEqual(sparse_lookup(test_sparse_emb, ids, modify_graph=True), 1)
+
+    @mock.patch.multiple("mx_rec.core.embedding",
+                         
get_use_dynamic_expansion=mock.MagicMock(return_value=False),
+                         is_asc_frozen=mock.MagicMock(return_value=False),
+                         get_name_to_var_dict=mock.MagicMock(return_value=None),
+                         get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"),
+                         get_rank_size=mock.MagicMock(return_value=8),
+                         insert_removing_var_list=mock.MagicMock(return_value=None),
+                         set_modify_graph=mock.MagicMock(return_value=None),
+                         insert_table_instance=mock.MagicMock(return_value=None))
+    def test_sparse_lookup_case3(self):
+        """
+        case3: test sparse_lookup, automatic graph-modification mode; the modify_graph argument is not passed, so an exception is raised
+        """
+
+        from mx_rec.core.embedding import sparse_lookup
+
+        with tf.Graph().as_default():
+            ids = tf.constant(1, tf.int64)
+            test_sparse_emb = SparseEmbedding(self.config)
+
+            with self.assertRaises(ValueError):
+                sparse_lookup(test_sparse_emb, ids)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
Gitee

From 218b845014cdab05ad266a58b6c81f13c7e624f6 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 15 Dec 2023 11:13:53 +0800
Subject: [PATCH 518/551] Match-id-aa5a103eccd84a864271577a679c7fb38cdf5758

---
 tests/mx_rec/data/__init__.py     |  3 +++
 tests/mx_rec/data/mock_class.py   | 15 +++++++++++++
 tests/mx_rec/data/test_dataset.py | 36 +++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+)
 create mode 100644 tests/mx_rec/data/__init__.py
 create mode 100644 tests/mx_rec/data/mock_class.py
 create mode 100644 tests/mx_rec/data/test_dataset.py

diff --git a/tests/mx_rec/data/__init__.py b/tests/mx_rec/data/__init__.py
new file mode 100644
index 00000000..6924f767
--- /dev/null
+++ b/tests/mx_rec/data/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
diff --git a/tests/mx_rec/data/mock_class.py b/tests/mx_rec/data/mock_class.py
new file mode 100644
index 00000000..a3f5c416
--- /dev/null
+++ b/tests/mx_rec/data/mock_class.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+
+class MockEosOpsLib:
+    """
+    mock librec.eos_dataset
+    """
+
+    def __init__(self, variant_tensor):
+        def _mock_eos_dataset_fn(**kwargs):
+            return variant_tensor
+
+        self.eos_dataset = _mock_eos_dataset_fn
diff --git a/tests/mx_rec/data/test_dataset.py b/tests/mx_rec/data/test_dataset.py
new file mode 100644
index 00000000..4db7f63f
--- /dev/null
+++ b/tests/mx_rec/data/test_dataset.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import unittest
+
+import tensorflow as tf
+
+from tests.mx_rec.core.generator_dataset import generate_dataset, Config
+from tests.mx_rec.data.mock_class import MockEosOpsLib
+
+
+class TestEosDatasetClass(unittest.TestCase):
+    """
+    Test for 'mx_rec.data.dataset.EosDataset'.
+ """ + + def test_init(self): + """ + case: 实例化EosDataset,使用eos_map + """ + + with tf.Graph().as_default(): + dataset_ori = generate_dataset(Config(batch_size=2, batch_number=2)) + dataset = dataset_ori.eos_map(MockEosOpsLib(dataset_ori._variant_tensor), 0) + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + with tf.Session() as sess: + sess.run(iterator.initializer) + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(batch) + self.assertIsNotNone(batch.get("item_ids")) + + +if __name__ == '__main__': + unittest.main() -- Gitee From 9f469e89a290a0145a5a6d598616b51a24c3c911 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 14 Dec 2023 20:53:05 +0800 Subject: [PATCH 519/551] Match-id-09f447fb2940f0f028a2e0b285f51b7861b699d2 --- mx_rec/core/asc/manager.py | 8 +- mx_rec/core/asc/merge_table.py | 9 +- mx_rec/core/feature_process.py | 2 +- tests/mx_rec/core/mock_class.py | 66 +++- tests/mx_rec/core/test_feature_process.py | 161 ++++++++ tests/mx_rec/core/test_manager.py | 455 ++++++++++++++++++++++ tests/mx_rec/core/test_merge_table.py | 219 +++++++++++ 7 files changed, 904 insertions(+), 16 deletions(-) create mode 100644 tests/mx_rec/core/test_feature_process.py create mode 100644 tests/mx_rec/core/test_manager.py create mode 100644 tests/mx_rec/core/test_merge_table.py diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 42d94476..7b999ac8 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -4,7 +4,8 @@ import tensorflow as tf -from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo +from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo, EmbInfo, EmbInfoParams, \ + ThresholdValue, HybridMgmt, RankInfo, USE_STATIC, USE_HOT, USE_DYNAMIC_EXPANSION from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \ is_asc_manager_initialized, get_train_steps, get_eval_steps, get_save_steps, \ @@ -27,8 +28,6 @@ def check_dangling_table(): def generate_table_info_list(): - from mxrec_pybind import EmbInfo, EmbInfoParams - from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN # table_name is corresponding to channel_name which is in used in operator gen_npu_ops.get_next table_info_list = [] @@ -162,7 +161,6 @@ def matched_opt_slot_initializers(table_instance): def generate_threshold_list(): - from mxrec_pybind import ThresholdValue threshold_list = [] for _, feature_spec in export_feature_spec().items(): @@ -187,8 +185,6 @@ def generate_threshold_list(): def initialize_emb_cache(table_info_list, threshold_list): - from mxrec_pybind import HybridMgmt, RankInfo, USE_STATIC, USE_HOT, USE_DYNAMIC_EXPANSION - rank_id = get_rank_id() device_id = get_device_id() rank_size = get_rank_size() diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index e06a7c7d..9cb59bcc 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -7,7 +7,7 @@ from typing import Dict, List import tensorflow as tf from tensorflow import Operation, Tensor -from mx_rec.constants.constants import MAX_WHILE_SIZE +from mx_rec.constants.constants import MAX_WHILE_SIZE, ASCEND_TABLE_NAME_MUST_CONTAIN from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \ get_bool_gauge_set from mx_rec.util.log import logger @@ -63,7 +63,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]: def find_table_op(table_name: str, the_op: 
Operation, table_lookup_op: Dict[str, List[Operation]],
-                      table_reachable_tensor: Dict[str, List[Tensor]]) -> None:
+                      table_reachable_tensor: Dict[str, List[Tensor]]) -> None:  # pragma: no cover
         """ find all the table lookup ops.
        :param table_name: tables' names
         :param the_op: the op to be
@@ -84,7 +84,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]:
 
     def extend(op_list: List[Operation],
                tensor: Tensor,
-               spread_tensors: List[Tensor]) -> None:
+               spread_tensors: List[Tensor]) -> None:  # pragma: no cover
         """extend the tensors which table lookup op can reach
 
         :param op_list: all op in the graph
@@ -96,7 +96,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]:
             if tensor in the_op.inputs:
                 spread_tensors.extend(the_op.outputs)
 
-    def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool):
+    def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool):  # pragma: no cover
         """find all the tensors which table lookup op can reach
 
         :param next_to_visit: the tensor list to be visited by bfs
@@ -158,7 +158,6 @@ def find_dangling_table(table_names: List[str]) -> List[str]:
 
 
 def should_skip(table_name) -> bool:
-    from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN
     if ASCEND_TABLE_NAME_MUST_CONTAIN is not None \
             and isinstance(ASCEND_TABLE_NAME_MUST_CONTAIN, str) \
             and ASCEND_TABLE_NAME_MUST_CONTAIN not in table_name:
diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py
index 6dd6fc3d..0fb95682 100644
--- a/mx_rec/core/feature_process.py
+++ b/mx_rec/core/feature_process.py
@@ -6,6 +6,7 @@ import time
 
 import tensorflow as tf
 
+from mx_rec.util.tf_version_adapter import npu_ops
 from mx_rec.constants.constants import DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MAX_INT32
 from mx_rec.util.initialize import trigger_evict, get_table_instance_by_name, export_feature_spec
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator, IntValidator, OptionalIntValidator
@@ -51,7 +52,7 @@ class EvictHook(tf.compat.v1.train.SessionRunHook):
 
         with tf.compat.v1.variable_scope(scope_name):
             logger.debug('Channel %s_evict_%d was built for op getnext', instance.table_name, TRAIN_CHANNEL_ID)
-            from mx_rec.util.tf_version_adapter import npu_ops
             evict_pos, evict_len = npu_ops.gen_npu_ops.get_next(
                 output_types=[tf.int32, tf.int32],
                 output_shapes=[[None], []],
diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py
index 5fe33363..58737058 100644
--- a/tests/mx_rec/core/mock_class.py
+++ b/tests/mx_rec/core/mock_class.py
@@ -3,6 +3,9 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 
 import tensorflow as tf
+from tensorflow_core.python.training import slot_creator
+
+from mx_rec.optimizers.lazy_adam import CustomizedLazyAdam
 
 
 class MockSparseEmbedding:
@@ -58,16 +61,59 @@ class MockHcclOps:
         self.all_to_all_v_c = _mock_all_to_all_v_c
 
 
-class MockOptimizer:
+class MockOptimizer(CustomizedLazyAdam):
     """
     Used to mock an optimizer
     """
 
    def __init__(self):
-        def _mock_insert_slot(slot, named_slot_key, slot_name):
-            return "mock_insert_slot"
+        super(MockOptimizer, self)._get_name(name="MockLazyAdam")
+        super(MockOptimizer, self).__init__(learning_rate=0.001, beta1=0.9, beta2=0.999,
+                                            epsilon=1e-8, use_locking=False, name="MockLazyAdam")
+        self.slot_num = 2
+
+    def initialize_slots(self, var, table_instance):
+        # Create slots for the first and second moments.
+        def creat_one_single_slot(var, op_name):
+            new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
+            return new_slot_variable
+
+        momentum = creat_one_single_slot(var, self._name + "/" + "momentum")
+        velocity = creat_one_single_slot(var, self._name + "/" + "velocity")
+        named_slot_key = (var.op.graph, var.op.name)
+
+        table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+        return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self},
+                {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}]
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        pass
+
+    def get_slot_init_values(self):
+        initial_momentum_value = 0.0
+        initial_velocity_value = 0.0
+        return [initial_momentum_value, initial_velocity_value]
+
+    def update_op(self, optimizer, g):
+        return super().update_op(optimizer, g)
+
+    def _apply_spare_duplicate_indices(self, grad, var):
+        return self._apply_sparse(grad, var)
+
+    def _apply_sparse(self, grad, var):
+        return super()._apply_sparse(grad, var)
 
-        self.insert_slot = _mock_insert_slot
+    def _resource_apply_sparse(self, grad, handle, indices):
+        return super()._resource_apply_sparse(grad, handle, indices)
+
+    def _apply_dense(self, grad, var):
+        return super()._apply_dense(grad, var)
+
+    def _apply_sparse_duplicate_indices(self, grad, var):
+        return self._apply_sparse(grad, var)
+
+    def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
+        return self._resource_apply_sparse(grad, handle, indices)
 
 
 class MockAscManager:
@@ -84,3 +130,15 @@
 
         self.get_table_size = _mock_get_table_size
         self.get_table_capacity = _mock_get_table_capacity
+
+
+class MockHybridMgmt:
+    """
+    Used to mock HybridMgmt()
+    """
+
+    def __init__(self, is_initialized=True):
+        def _mock_initialize(rank_info=0, emb_info=1, if_load=False, threshold_values=3):
+            return is_initialized
+
+        self.initialize = _mock_initialize
diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py
new file mode 100644
index 00000000..2acaf56b
--- /dev/null
+++ b/tests/mx_rec/core/test_feature_process.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import time
+import unittest
+from unittest import mock
+
+import tensorflow as tf
+
+from mx_rec.core.feature_process import EvictHook
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from tests.mx_rec.core.mock_class import MockSparseEmbedding
+
+
+class TestEvictHookClass(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.feature_process.EvictHook'.
+    """
+
+    def test_init_case1(self):
+        """
+        case1: evict_step_interval is None
+        """
+        self.assertIsInstance(EvictHook(), EvictHook)
+
+    def test_init_case2(self):
+        """
+        case2: evict_step_interval is not None
+        """
+        self.assertIsInstance(EvictHook(evict_step_interval=5), EvictHook)
+
+
+@mock.patch.multiple(
+    "mx_rec.graph.patch",
+    get_modify_graph=mock.Mock(return_value=True),
+    get_is_graph_modify_hook_running=mock.Mock(return_value=True),
+)
+@mock.patch.multiple(
+    "tensorflow.compat.v1.train.Saver",
+    __init__=mock.Mock(return_value=None),
+    build=mock.Mock(),
+)
+class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass):
+    """
+    Test for 'mx_rec.core.feature_process.EvictHook.after_run'.
+    """
+
+    def setUp(self):
+        self.ori_var_assert = [[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]
+        self.evict_var_assert = [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]]
+
+    @mock.patch.multiple("mx_rec.core.feature_process",
+                         trigger_evict=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
+    @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name")
+    @mock.patch("mx_rec.core.feature_process.export_feature_spec")
+    def test_after_run_case1(self, mock_export_feature_spec, mock_get_table_instance_by_name, mock_get_next):
+        """
+        case1: evict_enable is True; both the Python and C++ sides trigger eviction normally
+        """
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding()
+            mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
+            mock_get_table_instance_by_name.return_value = test_table
+            mock_export_feature_spec.return_value = dict(
+                test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10))
+
+            evict_hook = EvictHook(evict_enable=True, evict_time_interval=1)
+            with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:
+                sess.graph._unsafe_unfinalize()
+                sess.run(tf.compat.v1.global_variables_initializer())
+
+                # sleep 1s so the eviction interval evict_time_interval elapses
+                time.sleep(1)
+
+                # fetch the original variable; eviction happens after this session run
+                ori_variable = sess.run(test_table.variable)
+                # rows 9 and 10 of ori_variable are all 1
+                self.assertListEqual(ori_variable[8:].tolist(), self.ori_var_assert)
+
+                # fetch the variable after eviction
+                evict_variable = sess.run(test_table.variable)
+                # rows 9 and 10 of evict_variable become 0 after eviction; the other rows are unchanged
+                self.assertListEqual(evict_variable[8:].tolist(), self.evict_var_assert)
+                self.assertListEqual(evict_variable[:2].tolist(), self.ori_var_assert)
+
+    @mock.patch.multiple("mx_rec.core.feature_process",
+                         trigger_evict=mock.MagicMock(return_value=False))
+    @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
+    @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name")
+    @mock.patch("mx_rec.core.feature_process.export_feature_spec")
+    def test_after_run_case2(self, mock_export_feature_spec, mock_get_table_instance_by_name, mock_get_next):
+        """
+        case2: evict_enable is True, but the C++ side fails
+        """
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding()
+            mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
+            mock_get_table_instance_by_name.return_value = test_table
+            mock_export_feature_spec.return_value = dict(
+                test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10))
+
+            evict_hook = EvictHook(evict_enable=True, evict_time_interval=1)
+            with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:
+                sess.graph._unsafe_unfinalize()
+                sess.run(tf.compat.v1.global_variables_initializer())
+
+                # sleep 1s so the eviction interval evict_time_interval elapses
+                time.sleep(1)
+
+                # fetch the original variable; eviction would happen after this session run
+                ori_variable = sess.run(test_table.variable)
+                # rows 9 and 10 of ori_variable are all 1
+                self.assertListEqual(ori_variable[8:].tolist(), self.ori_var_assert)
+
+                # fetch the variable after the eviction attempt
+                evict_variable = sess.run(test_table.variable)
+                # the C++ side failed, so no eviction is performed and the last two rows of evict_variable stay 1
+                self.assertListEqual(evict_variable[8:].tolist(), self.ori_var_assert)
+
+    @mock.patch.multiple("mx_rec.core.feature_process",
+                         trigger_evict=mock.MagicMock(return_value=True))
+    @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
+    @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name")
+
@mock.patch("mx_rec.core.feature_process.export_feature_spec")
+    def test_after_run_case3(self, mock_export_feature_spec, mock_get_table_instance_by_name, mock_get_next):
+        """
+        case3: evict_enable is False
+        """
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding()
+            mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
+            mock_get_table_instance_by_name.return_value = test_table
+            mock_export_feature_spec.return_value = dict(
+                test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10))
+
+            evict_hook = EvictHook(evict_enable=False, evict_time_interval=1)
+            with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:
+                sess.graph._unsafe_unfinalize()
+                sess.run(tf.compat.v1.global_variables_initializer())
+
+                # sleep 1s so the eviction interval evict_time_interval elapses
+                time.sleep(1)
+
+                # fetch the original variable; eviction would happen after this session run
+                ori_variable = sess.run(test_table.variable)
+                # rows 9 and 10 of ori_variable are all 1
+                self.assertListEqual(ori_variable[8:].tolist(), self.ori_var_assert)
+
+                # fetch the variable after the eviction attempt
+                evict_variable = sess.run(test_table.variable)
+                # evict_enable is False, so no eviction is performed and the last two rows of evict_variable stay 1
+                self.assertListEqual(evict_variable[8:].tolist(), self.ori_var_assert)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/core/test_manager.py b/tests/mx_rec/core/test_manager.py
new file mode 100644
index 00000000..72920bdd
--- /dev/null
+++ b/tests/mx_rec/core/test_manager.py
@@ -0,0 +1,455 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import unittest
+from unittest import mock
+
+import tensorflow as tf
+
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockOptimizer, MockHybridMgmt
+
+
+class TestCheckDanglingTableFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.check_dangling_table'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         export_dangling_table=mock.MagicMock(return_value=[]),
+                         export_table_instances=mock.MagicMock(return_value={}),
+                         find_dangling_table=mock.MagicMock(return_value=["test_table"]))
+    def test_check_dangling_table(self):
+        """
+        case: test check_dangling_table
+        """
+
+        from mx_rec.core.asc.manager import check_dangling_table
+
+        self.assertListEqual(check_dangling_table(), ["test_table"])
+
+
+class TestGenerateTableInfoListFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.generate_table_info_list'.
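+    Covers the DDR consistency check, the dangling-table/skip branches and the static-shape path.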
+    """
+
+    @mock.patch("mx_rec.core.asc.manager.export_table_instances")
+    def test_generate_table_info_list_case1(self, mock_export_table_instances):
+        """
+        case1: one table has DDR enabled while another does not, so an exception is raised
+        """
+
+        from mx_rec.core.asc.manager import generate_table_info_list
+
+        with tf.Graph().as_default():
+            test_table1 = MockSparseEmbedding("test_table1")
+            test_table1.host_vocabulary_size = 1
+            test_table2 = MockSparseEmbedding("test_table2")
+            test_table2.host_vocabulary_size = 0
+            mock_export_table_instances.return_value = {
+                "test_table1": test_table1,
+                "test_table2": test_table2
+            }
+
+            with self.assertRaises(ValueError):
+                generate_table_info_list()
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         export_optimizer=mock.MagicMock(return_value=None),
+                         should_skip=mock.MagicMock(return_value=True),
+                         check_dangling_table=mock.MagicMock(return_value=["test_table"]))
+    @mock.patch("mx_rec.core.asc.manager.export_table_instances")
+    def test_generate_table_info_list_case2(self, mock_export_table_instances):
+        """
+        case2: test_table is a dangling table and skip is True
+        """
+
+        from mx_rec.core.asc.manager import generate_table_info_list
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding("test_table")
+            test_table.host_vocabulary_size = 1
+            mock_export_table_instances.return_value = dict(test_table=test_table)
+
+            table_info_list = generate_table_info_list()
+            self.assertListEqual(table_info_list, [])
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         export_optimizer=mock.MagicMock(return_value=None),
+                         should_skip=mock.MagicMock(return_value=True),
+                         check_dangling_table=mock.MagicMock(return_value=[]))
+    @mock.patch("mx_rec.core.asc.manager.export_table_instances")
+    def test_generate_table_info_list_case3(self, mock_export_table_instances):
+        """
+        case3: test_table is not a dangling table and skip is True
+        """
+
+        from mx_rec.core.asc.manager import generate_table_info_list
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding("test_table")
+            test_table.host_vocabulary_size = 1
+            mock_export_table_instances.return_value = dict(test_table=test_table)
+
+            table_info_list = generate_table_info_list()
+            self.assertListEqual(table_info_list, [])
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         export_optimizer=mock.MagicMock(return_value=None),
+                         EmbInfoParams=mock.MagicMock(return_value=None),
+                         EmbInfo=mock.MagicMock(return_value="test_table_info"),
+                         matched_emb_initializer=mock.MagicMock(return_value=[]),
+                         matched_opt_slot_initializers=mock.MagicMock(return_value=[]),
+                         should_skip=mock.MagicMock(return_value=False),
+                         get_use_static=mock.MagicMock(return_value=True),
+                         check_dangling_table=mock.MagicMock(return_value=[]))
+    @mock.patch("mx_rec.core.asc.manager.export_table_instances")
+    def test_generate_table_info_list_case4(self, mock_export_table_instances):
+        """
+        case4: static shape, test_table is not a dangling table and skip is False
+        """
+
+        from mx_rec.core.asc.manager import generate_table_info_list
+
+        with tf.Graph().as_default():
+            test_table = MockSparseEmbedding("test_table")
+            test_table.host_vocabulary_size = 8
+            test_table.send_count = 1
+            test_table.slice_device_vocabulary_size = 1
+            test_table.slice_host_vocabulary_size = 1
+            test_table.slice_ssd_vocabulary_size = 0
+            test_table.is_grad = True
+            test_table.is_save = True
+            test_table.scalar_emb_size = 8
+            test_table.ext_emb_size = 8
+            test_table.ssd_data_path = ""
+            mock_export_table_instances.return_value = dict(test_table=test_table)
+
+            table_info_list = generate_table_info_list()
+            self.assertListEqual(table_info_list,
["test_table_info"])
+
+
+class TestMatchedConstantInitializerFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.matched_constant_initializer'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value=[]),
+                         ConstantInitializerInfo=mock.MagicMock(return_value=[]))
+    def test_matched_constant_initializer(self):
+        """
+        case: test matched_constant_initializer
+        """
+
+        from mx_rec.core.asc.manager import matched_constant_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.init_param = 1.
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer.value = 0
+
+            self.assertListEqual(matched_constant_initializer(table_info), [])
+
+
+class TestMatchedRandomNormalInitializerFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.matched_random_normal_initializer'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value=[]),
+                         NormalInitializerInfo=mock.MagicMock(return_value=[]))
+    def test_matched_random_normal_initializer_case1(self):
+        """
+        case1: emb_initializer.seed is None
+        """
+
+        from mx_rec.core.asc.manager import matched_random_normal_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.init_param = 1.
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer.seed = None
+            table_info.emb_initializer.mean = 1
+            table_info.emb_initializer.stddev = 1
+
+            self.assertListEqual(matched_random_normal_initializer(table_info), [])
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value=[]),
+                         NormalInitializerInfo=mock.MagicMock(return_value=[]))
+    def test_matched_random_normal_initializer_case2(self):
+        """
+        case2: emb_initializer.seed is not None
+        """
+
+        from mx_rec.core.asc.manager import matched_random_normal_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.init_param = 1.
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer.seed = 1
+            table_info.emb_initializer.mean = 1
+            table_info.emb_initializer.stddev = 1
+
+            self.assertListEqual(matched_random_normal_initializer(table_info), [])
+
+
+class TestMatchedTruncatedNormalInitializerFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.matched_truncated_normal_initializer'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value=[]),
+                         NormalInitializerInfo=mock.MagicMock(return_value=[]))
+    def test_matched_truncated_normal_initializer_case1(self):
+        """
+        case1: emb_initializer.seed is None
+        """
+
+        from mx_rec.core.asc.manager import matched_truncated_normal_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.init_param = 1.
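+            # seed=None exercises the unseeded branch of the initializer matching.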
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer.seed = None
+            table_info.emb_initializer.mean = 1
+            table_info.emb_initializer.stddev = 1
+
+            self.assertListEqual(matched_truncated_normal_initializer(table_info), [])
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value=[]),
+                         NormalInitializerInfo=mock.MagicMock(return_value=[]))
+    def test_matched_truncated_normal_initializer_case2(self):
+        """
+        case2: emb_initializer.seed is not None
+        """
+
+        from mx_rec.core.asc.manager import matched_truncated_normal_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.init_param = 1.
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer.seed = 1
+            table_info.emb_initializer.mean = 1
+            table_info.emb_initializer.stddev = 1
+
+            self.assertListEqual(matched_truncated_normal_initializer(table_info), [])
+
+
+class TestMatchedEmbInitializerFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.matched_emb_initializer'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         matched_constant_initializer=mock.MagicMock(return_value=1))
+    def test_matched_emb_initializer_case1(self):
+        """
+        case1: the initializer is tf.constant_initializer
+        """
+
+        from mx_rec.core.asc.manager import matched_emb_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer = tf.constant_initializer()
+
+            self.assertEqual(matched_emb_initializer(table_info), 1)
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         matched_random_normal_initializer=mock.MagicMock(return_value=2))
+    def test_matched_emb_initializer_case2(self):
+        """
+        case2: the initializer is tf.random_normal_initializer
+        """
+
+        from mx_rec.core.asc.manager import matched_emb_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer = tf.random_normal_initializer()
+
+            self.assertEqual(matched_emb_initializer(table_info), 2)
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         matched_truncated_normal_initializer=mock.MagicMock(return_value=3))
+    def test_matched_emb_initializer_case3(self):
+        """
+        case3: the initializer is tf.truncated_normal_initializer
+        """
+
+        from mx_rec.core.asc.manager import matched_emb_initializer
+
+        with tf.Graph().as_default():
+            table_info = MockSparseEmbedding("test_table")
+            table_info.scalar_emb_size = 8
+            table_info.emb_initializer = tf.truncated_normal_initializer()
+
+            self.assertEqual(matched_emb_initializer(table_info), 3)
+
+
+class TestMatchedOptSlotInitializersFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.matched_opt_slot_initializers'.
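+    One InitializeInfo is expected per optimizer slot; the mocked LazyAdam exposes two (m and v).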
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         InitializeInfo=mock.MagicMock(return_value="slot_initializer"))
+    def test_matched_opt_slot_initializers(self):
+        """
+        case: test matched_opt_slot_initializers
+        """
+
+        from mx_rec.core.asc.manager import matched_opt_slot_initializers
+
+        with tf.Graph().as_default():
+            table_instance = MockSparseEmbedding("test_table")
+            table_instance.scalar_emb_size = 8
+            table_instance.ext_emb_size = 8
+            table_instance.optimizer_instance_list = [MockOptimizer()]
+
+            slot_initializers = matched_opt_slot_initializers(table_instance)
+            self.assertListEqual(slot_initializers, ["slot_initializer", "slot_initializer"])
+
+
+class TestGenerateThresholdListFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.generate_threshold_list'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         ThresholdValue=mock.MagicMock(return_value=0))
+    @mock.patch("mx_rec.core.asc.manager.export_feature_spec")
+    def test_generate_threshold_list(self, mock_export_feature_spec):
+        """
+        case: feature specs with both eviction and admission thresholds
+        """
+
+        from mx_rec.core.asc.manager import generate_threshold_list
+
+        with tf.Graph().as_default():
+            test_feature_spec1 = FeatureSpec("test_feature_spec1",
+                                             access_threshold=5, eviction_threshold=10, faae_coefficient=None)
+            test_feature_spec2 = FeatureSpec("test_feature_spec2",
+                                             access_threshold=5, faae_coefficient=None)
+            mock_export_feature_spec.return_value = {
+                "test_feature_spec1": test_feature_spec1,
+                "test_feature_spec2": test_feature_spec2
+            }
+
+            self.assertListEqual(generate_threshold_list(), [0, 0])
+
+
+class TestInitializeEmbCacheFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.initialize_emb_cache'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         get_rank_id=mock.MagicMock(return_value=0),
+                         get_device_id=mock.MagicMock(return_value=0),
+                         get_rank_size=mock.MagicMock(return_value=0),
+                         get_train_steps=mock.MagicMock(return_value=0),
+                         get_eval_steps=mock.MagicMock(return_value=0),
+                         get_save_steps=mock.MagicMock(return_value=0),
+                         get_if_load=mock.MagicMock(return_value=False),
+                         get_use_static=mock.MagicMock(return_value=True),
+                         get_use_hot=mock.MagicMock(return_value=True),
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=True),
+                         USE_STATIC=mock.MagicMock(return_value=0),
+                         USE_HOT=mock.MagicMock(return_value=1),
+                         USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2),
+                         RankInfo=mock.MagicMock(return_value="mock_info"),
+                         HybridMgmt=mock.MagicMock(return_value=MockHybridMgmt(is_initialized=False)))
+    def test_initialize_emb_cache_case1(self):
+        """
+        case1: initialization fails
+        """
+
+        from mx_rec.core.asc.manager import initialize_emb_cache
+
+        with self.assertRaises(RuntimeError):
+            initialize_emb_cache([], [])
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         get_rank_id=mock.MagicMock(return_value=0),
+                         get_device_id=mock.MagicMock(return_value=0),
+                         get_rank_size=mock.MagicMock(return_value=0),
+                         get_train_steps=mock.MagicMock(return_value=0),
+                         get_eval_steps=mock.MagicMock(return_value=0),
+                         get_save_steps=mock.MagicMock(return_value=0),
+                         get_if_load=mock.MagicMock(return_value=False),
+                         get_use_static=mock.MagicMock(return_value=True),
+                         get_use_hot=mock.MagicMock(return_value=True),
+                         get_use_dynamic_expansion=mock.MagicMock(return_value=True),
+                         USE_STATIC=mock.MagicMock(return_value=0),
+                         USE_HOT=mock.MagicMock(return_value=1),
+                         USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2),
+                         RankInfo=mock.MagicMock(return_value="mock_info"),
+                         set_asc_manager=mock.MagicMock(return_value=None))
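+    # HybridMgmt is patched separately below so the test can inject a MockHybridMgmt that reports success.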
+    @mock.patch("mx_rec.core.asc.manager.HybridMgmt")
+    def test_initialize_emb_cache_case2(self, mock_hybrid_mgmt):
+        """
+        case2: initialization succeeds
+        """
+
+        from mx_rec.core.asc.manager import initialize_emb_cache
+
+        mock_mgmt = MockHybridMgmt(is_initialized=True)
+        mock_hybrid_mgmt.return_value = mock_mgmt
+        initialize_emb_cache([], [])
+        self.assertTrue(mock_mgmt.initialize())
+
+
+class TestStartAscPipeLineFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.manager.start_asc_pipeline'.
+    """
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         generate_table_info_list=mock.MagicMock(return_value=[]),
+                         generate_threshold_list=mock.MagicMock(return_value=[]))
+    def test_start_asc_pipeline_case1(self):
+        """
+        case1: table_info_list is []
+        """
+
+        from mx_rec.core.asc.manager import start_asc_pipeline
+
+        with self.assertRaises(RuntimeError):
+            start_asc_pipeline()
+
+    @mock.patch.multiple("mx_rec.core.asc.manager",
+                         generate_table_info_list=mock.MagicMock(return_value=["test_table"]),
+                         generate_threshold_list=mock.MagicMock(return_value=[]),
+                         get_stat_on=mock.MagicMock(return_value=True),
+                         is_asc_manager_initialized=mock.MagicMock(return_value=False),
+                         export_table_num=mock.MagicMock(return_value=None),
+                         initialize_emb_cache=mock.MagicMock(return_value=None))
+    def test_start_asc_pipeline_case2(self):
+        """
+        case2: table_info_list is ["test_table"] and stat_on is True
+        """
+
+        from mx_rec.core.asc.manager import start_asc_pipeline
+
+        # the function returns nothing and every function it calls is mocked, so expect no exception and no return value
+        start_asc_pipeline()
+        self.assertTrue(callable(start_asc_pipeline))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/core/test_merge_table.py b/tests/mx_rec/core/test_merge_table.py
new file mode 100644
index 00000000..56bbc357
--- /dev/null
+++ b/tests/mx_rec/core/test_merge_table.py
@@ -0,0 +1,219 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+import unittest
+from unittest import mock
+
+import tensorflow as tf
+
+import mx_rec.core.asc.merge_table
+from tests.mx_rec.core.mock_class import MockSparseEmbedding
+
+
+class TestAffirmFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.merge_table.affirm'.
+    """
+
+    def test_affirm_case1(self):
+        """
+        case1: reach_op is []
+        """
+
+        from mx_rec.core.asc.merge_table import affirm
+
+        self.assertTrue(affirm([]))
+
+    def test_affirm_case2(self):
+        """
+        case2: none of the ops in reach_op has a type in ("IdentityN", "Reshape", "Identity")
+        """
+
+        from mx_rec.core.asc.merge_table import affirm
+
+        with tf.Graph().as_default():
+            self.assertFalse(affirm([tf.no_op()]))
+
+
+class TestCheckOpFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.merge_table.check_op'.
+    """
+
+    def test_check_op_case1(self):
+        """
+        case1: the type of table_reachable_op is ApplyAdam
+        """
+
+        from mx_rec.core.asc.merge_table import check_op
+
+        with tf.Graph().as_default():
+            x = tf.Variable(0.0)
+            optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
+            apply_op = optimizer.apply_gradients([(tf.constant(1.0), x)])
+            self.assertTrue(check_op(apply_op.control_inputs[0]))
+
+    def test_check_op_case2(self):
+        """
+        case2: table_reachable_op is tf.no_op()
+        """
+
+        from mx_rec.core.asc.merge_table import check_op
+
+        with tf.Graph().as_default():
+            self.assertFalse(check_op(tf.no_op()))
+
+
+class TestIsTrainTaskFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.merge_table.is_train_task'.
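+    The result depends on the recorded bool gauge set and on optimizer ops detected via check_op.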
+    """
+
+    def setUp(self):
+        # add ops to the TensorFlow default graph
+        a = tf.constant(1)
+        b = tf.constant(2)
+        c = tf.add(a, b)
+        with tf.Session() as sess:
+            sess.run(c)
+
+    def tearDown(self):
+        # remove the ops added to the TensorFlow default graph
+        tf.reset_default_graph()
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         get_bool_gauge_set=mock.MagicMock(return_value=[]))
+    def test_is_train_task_case1(self):
+        """
+        case1: bool_gauge_set is []
+        """
+
+        from mx_rec.core.asc.merge_table import is_train_task
+
+        self.assertFalse(is_train_task())
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         get_bool_gauge_set=mock.MagicMock(return_value=["train"]),
+                         check_op=mock.MagicMock(return_value=True))
+    def test_is_train_task_case2(self):
+        """
+        case2: bool_gauge_set is ["train"] and check_op returns True
+        """
+
+        from mx_rec.core.asc.merge_table import is_train_task
+
+        self.assertTrue(is_train_task())
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         get_bool_gauge_set=mock.MagicMock(return_value=["train"]),
+                         check_op=mock.MagicMock(return_value=False))
+    def test_is_train_task_case3(self):
+        """
+        case3: bool_gauge_set is ["train"] and check_op returns False
+        """
+
+        from mx_rec.core.asc.merge_table import is_train_task
+
+        self.assertTrue(is_train_task())
+
+
+class TestFindDanglingTableFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.merge_table.find_dangling_table'.
+    """
+
+    def setUp(self):
+        # add ops to the TensorFlow default graph
+        a = tf.constant(1)
+        b = tf.constant(2)
+        c = tf.add(a, b)
+        with tf.Session() as sess:
+            sess.run(c)
+
+    def tearDown(self):
+        # remove the ops added to the TensorFlow default graph
+        tf.reset_default_graph()
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         is_train_task=mock.MagicMock(return_value=False))
+    def test_find_dangling_table_case1(self):
+        """
+        case1: is_train_task is False
+        """
+
+        from mx_rec.core.asc.merge_table import find_dangling_table
+
+        self.assertListEqual(find_dangling_table([]), [])
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         is_train_task=mock.MagicMock(return_value=True),
+                         get_enable_table_merge=mock.MagicMock(return_value=False))
+    def test_find_dangling_table_case2(self):
+        """
+        case2: is_train_task is True and merge is False
+        """
+
+        from mx_rec.core.asc.merge_table import find_dangling_table
+
+        self.assertListEqual(find_dangling_table([]), [])
+
+    @mock.patch.multiple("mx_rec.core.asc.merge_table",
+                         is_train_task=mock.MagicMock(return_value=True),
+                         affirm=mock.MagicMock(return_value=True),
+                         get_enable_table_merge=mock.MagicMock(return_value=True),
+                         insert_dangling_table=mock.MagicMock(return_value=None))
+    @mock.patch("mx_rec.core.asc.merge_table.export_table_instances")
+    def test_find_dangling_table_case3(self, mock_export_table_instances):
+        """
+        case3: is_train_task is True and merge is True
+        """
+
+        from mx_rec.core.asc.merge_table import find_dangling_table
+
+        mock_export_table_instances.return_value = {"table1": MockSparseEmbedding("table1")}
+        dangling_table = find_dangling_table(["table2"])
+        self.assertListEqual(dangling_table, ["table2", "table1"])
+
+
+class TestShouldSkipFunc(unittest.TestCase):
+    """
+    Test for 'mx_rec.core.asc.merge_table.should_skip'.
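+    A table is skipped when ASCEND_TABLE_NAME_MUST_CONTAIN is a string that its name does not contain.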
+    """
+
+    def tearDown(self):
+        mx_rec.core.asc.merge_table.ASCEND_TABLE_NAME_MUST_CONTAIN = None
+
+    def test_should_skip_case1(self):
+        """
+        case1: table_name does not contain "merged" and ASCEND_TABLE_NAME_MUST_CONTAIN is a str
+        """
+
+        from mx_rec.core.asc.merge_table import should_skip
+
+        mx_rec.core.asc.merge_table.ASCEND_TABLE_NAME_MUST_CONTAIN = "merged"
+        self.assertTrue(should_skip("test_table"))
+
+    def test_should_skip_case2(self):
+        """
+        case2: table_name contains "merged" and ASCEND_TABLE_NAME_MUST_CONTAIN is a str
+        """
+
+        from mx_rec.core.asc.merge_table import should_skip
+
+        mx_rec.core.asc.merge_table.ASCEND_TABLE_NAME_MUST_CONTAIN = "merged"
+        self.assertFalse(should_skip("merged_table"))
+
+    def test_should_skip_case3(self):
+        """
+        case3: table_name contains "merged" and ASCEND_TABLE_NAME_MUST_CONTAIN is a list
+        """
+
+        from mx_rec.core.asc.merge_table import should_skip
+
+        mx_rec.core.asc.merge_table.ASCEND_TABLE_NAME_MUST_CONTAIN = ["merged"]
+        self.assertFalse(should_skip("merged_table"))
+
+
+if __name__ == '__main__':
+    unittest.main()
--
Gitee

From 1e3b0dc9f55455c5fb14bb7f1b944e44f419d767 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 14 Dec 2023 10:41:45 +0800
Subject: [PATCH 520/551] Match-id-7537dc9480187befe2e4a168f9a909d8703ccded

---
 mx_rec/saver/saver.py | 9 +-
 mx_rec/saver/sparse.py | 6 +-
 tests/mx_rec/saver/sparse_embedding_mock.py | 28 +++++
 tests/mx_rec/saver/test_saver.py | 82 +++++++++++++++
 tests/mx_rec/saver/test_sparse.py | 109 ++++++++++++++++++++
 5 files changed, 227 insertions(+), 7 deletions(-)
 create mode 100644 tests/mx_rec/saver/sparse_embedding_mock.py
 create mode 100644 tests/mx_rec/saver/test_saver.py
 create mode 100644 tests/mx_rec/saver/test_sparse.py

diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index f53f198b..e5071252 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -26,15 +26,16 @@ from mx_rec.util.log import logger
 
 # define save model thread
 class SaveModelThread(threading.Thread):
-    def __init__(self, sess, result, root_dir, table_name):
+    def __init__(self, saver, sess, result, root_dir, table_name):
         super().__init__()
         self.result = result
         self.root_dir = root_dir
         self.table_name = table_name
         self.sess = sess
+        self.saver = saver
 
     def run(self):
-        Saver().save_table_name_data(self.sess, self.result, self.root_dir, self.table_name)
+        self.saver.save_table_name_data(self.sess, self.result, self.root_dir, self.table_name)
 
 
 class Saver(object):
@@ -61,7 +62,6 @@ class Saver(object):
         self.rank_id = get_rank_id()
         self.local_rank_size = get_local_rank_size()
         self.local_rank_id = self.rank_id % self.local_rank_size
-        self.rank_size = get_rank_size()
         self.save_op_dict = defaultdict(dict)
         self.restore_fetch_list = []
         self.placeholder_dict = defaultdict(dict)
@@ -195,7 +195,7 @@ class Saver(object):
         result = self.save_op_dict
         threads = []
         for table_name in result.keys():
-            thread = SaveModelThread(sess, result, root_dir, table_name)
+            thread = SaveModelThread(self, sess, result, root_dir, table_name)
             threads.append(thread)
 
         for thread in threads:
@@ -214,7 +214,6 @@ class Saver(object):
         for var in self.var_list:
             if global_env.tf_device == TFDevice.NPU.value and "merged" not in var.name:
                 continue
-
             table_instance = get_table_instance(var)
             table_name = table_instance.table_name
             with tf.compat.v1.variable_scope(table_name):
diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py
index 6f15c8f2..a75fec13 100644
--- a/mx_rec/saver/sparse.py
+++ b/mx_rec/saver/sparse.py
@@ -180,9 +180,11 @@ def export(table_list=None):
 
 
 def check_table_param(table_list,
default_table_list): out_list = [] for table in table_list: - if table not in default_table_list: + if table in default_table_list: + out_list.append(table) + else: logger.warning("%s not be created , please check your table name.", table) - out_list.append(table) + return out_list diff --git a/tests/mx_rec/saver/sparse_embedding_mock.py b/tests/mx_rec/saver/sparse_embedding_mock.py new file mode 100644 index 00000000..a32b5e2d --- /dev/null +++ b/tests/mx_rec/saver/sparse_embedding_mock.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + + +import os + + +class SparseEmbeddingMock: + """ + sparse embedding mock module + """ + + def __init__(self, host_vocab_size=0): + self.is_save = True + self.table_name = "test_table" + self.slice_device_vocabulary_size = 10 + self.scalar_emb_size = 4 + self.host_vocabulary_size = host_vocab_size + self.use_feature_mapping = None + self.optimizer = dict() + self.use_dynamic_expansion = False + + def set_optimizer(self, key, state_dict): + if key in self.optimizer: + raise ValueError(f"Optimizer {key} has been set for hash table {self.table_name}") + + self.optimizer[key] = state_dict diff --git a/tests/mx_rec/saver/test_saver.py b/tests/mx_rec/saver/test_saver.py new file mode 100644 index 00000000..27bed04d --- /dev/null +++ b/tests/mx_rec/saver/test_saver.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + +import os +import unittest +from unittest import mock + +import tensorflow as tf + +from mx_rec.saver.saver import Saver +from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from tests.mx_rec.saver.sparse_embedding_mock import SparseEmbeddingMock + +table_instance = SparseEmbeddingMock() + + +class TestSaver(unittest.TestCase): + """ + Test the function of saving and loading sparse tables. 
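+    Builds a one-table graph with mocked LazyAdam slot variables, saves it, then restores and compares.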
+ """ + + @mock.patch.multiple("mx_rec.saver.saver", + get_rank_id=mock.MagicMock(return_value=0), + get_local_rank_size=mock.MagicMock(return_value=1), + get_ascend_global_hashtable_collection=mock.MagicMock( + return_value=ASCEND_GLOBAL_HASHTABLE_COLLECTION), + get_table_instance=mock.MagicMock(return_value=table_instance)) + def setUp(self): + self.table_name = "test_table" + self.optim_m_name = "test_table/LazyAdam/m" + self.optim_v_name = "test_table/LazyAdam/v" + self.graph = self.build_graph() + + with self.graph.as_default(): + self.saver = Saver() + + @mock.patch.multiple("mx_rec.saver.saver", + set_sparse_dir=mock.MagicMock(), + is_asc_manager_initialized=mock.MagicMock(return_value=True), + save_host_data=mock.MagicMock(), + get_use_dynamic_expansion=mock.MagicMock(return_value=False), + get_table_instance_by_name=mock.MagicMock(return_value=table_instance), + get_host_data=mock.MagicMock(return_value=[0, 1, 4, 6, 8]), + restore_host_data=mock.MagicMock()) + def test_save_and_load_is_consistent(self): + with tf.compat.v1.Session(graph=self.graph) as sess: + embedding_directory = "./sparse-model/HashTable/HBM/test_table/embedding" + data_file = os.path.join(embedding_directory, "slice_0.data") + attribute_file = os.path.join(embedding_directory, "slice_0.attribute") + sess.run(tf.global_variables_initializer()) + origin_embedding = sess.run(self.var)[[0, 1, 4, 6, 8], :] + + self.saver.save(sess) + self.assertTrue(os.path.exists(embedding_directory), "embedding目录已创建") + self.assertTrue(os.path.exists(data_file), "embedding的data文件存储成功") + self.assertTrue(os.path.exists(attribute_file), "embedding的attribute文件存储成功") + + self.saver.restore(sess, "./model") + load_embedding = sess.run(self.var)[:5, :] + self.assertEqual(load_embedding.all(), origin_embedding.all()) + + def build_graph(self): + self.graph = tf.compat.v1.Graph() + with self.graph.as_default(): + self.shape = tf.TensorShape([10, 4]) + emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=128) + initialized_tensor = emb_initializer(self.shape) + self.var = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) + + optim_m_tensor = emb_initializer(self.shape) + self.optimizer_m = tf.compat.v1.get_variable(self.optim_m_name, trainable=False, initializer=optim_m_tensor) + optim_v_tensor = emb_initializer(self.shape) + self.optimizer_v = tf.compat.v1.get_variable(self.optim_v_name, trainable=False, initializer=optim_v_tensor) + + table_instance.set_optimizer("LazyAdam", {"momentum": self.optimizer_m, "velocity": self.optimizer_v}) + tf.compat.v1.add_to_collection(ASCEND_GLOBAL_HASHTABLE_COLLECTION, self.var) + return self.graph + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/mx_rec/saver/test_sparse.py b/tests/mx_rec/saver/test_sparse.py new file mode 100644 index 00000000..f827aa9c --- /dev/null +++ b/tests/mx_rec/saver/test_sparse.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + +import os +from unittest import mock +import unittest + +import tensorflow as tf +import numpy as np + +from mx_rec.saver.saver import write_binary_data, generate_file_name +from tests.mx_rec.saver.sparse_embedding_mock import SparseEmbeddingMock +from mx_rec.saver.sparse import export, set_upper_dir, check_table_param +from mx_rec.constants.constants import DataAttr + + +class TestSparseProcessor(unittest.TestCase): + """ + Test the function of exporting sparse tables. 
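+    Fakes HBM and DDR checkpoint layouts on disk and checks that export() produces key-emb.npy.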
+ """ + + def setUp(self): + self.table_name = "test_table" + self.device_dir_list = ["HashTable", "HBM"] + self.host_dir_list = ["HashTable", "DDR"] + self.fake_hbm_sparse_dir = "./test_export_hbm/sparse-model" + self.fake_ddr_sparse_dir = "./test_export_ddr/sparse-model" + self.hbm_npy_path = None + self.ddr_npy_path = None + + @mock.patch.multiple("mx_rec.saver.sparse", + export_table_name_set=mock.MagicMock(return_value={"test_table"}), + get_sparse_dir=mock.MagicMock(return_value="./test_export_hbm/sparse-model"), + get_table_instance_by_name=mock.MagicMock(return_value=SparseEmbeddingMock())) + def test_export_interface_on_hbm_mode(self): + self.build_fake_hbm_save() + export() + self.assertTrue(os.path.exists(self.hbm_npy_path)) + tf.io.gfile.rmtree("./test_export_hbm") + + + @mock.patch.multiple("mx_rec.saver.sparse", + export_table_name_set=mock.MagicMock(return_value={"test_table"}), + get_sparse_dir=mock.MagicMock(return_value="./test_export_ddr/sparse-model"), + get_table_instance_by_name=mock.MagicMock( + return_value=SparseEmbeddingMock(host_vocab_size=10))) + def test_export_interface_on_ddr_mode(self): + self.build_fake_ddr_save() + export() + self.assertTrue(os.path.exists(self.ddr_npy_path)) + tf.io.gfile.rmtree("./test_export_ddr") + + + def test_check_table_param(self): + table_list = ["test_table_1", "test_table_0"] + default_table_list = ["test_table_1", "test_table_2", "test_table_3"] + expect_table_list = ["test_table_1"] + result_table_list = check_table_param(table_list, default_table_list) + self.assertEqual(result_table_list, expect_table_list) + + def build_fake_hbm_save(self): + table_dir = os.path.join(set_upper_dir(self.fake_hbm_sparse_dir, self.device_dir_list), self.table_name) + fake_key = np.array([1, 2, 3, 4, 5]) + fake_emb = np.random.rand(5, 4).astype(np.float32) + # build HBM fake file + self.write_device_data(fake_emb, table_dir) + attribute = np.array([5, 1, 4]) + self.write_host_data(fake_key, attribute, "key", table_dir) + + self.hbm_npy_path = os.path.join(table_dir, "key-emb.npy") + + def build_fake_ddr_save(self): + table_dir = os.path.join(set_upper_dir(self.fake_ddr_sparse_dir, self.host_dir_list), self.table_name) + fake_key_offset_map = np.array([1, 0, 2, 6, 3, 2, 4, 9, 5, 4]) + key_offset_attribute = np.array([5, 2, 4]) + fake_embedding = np.random.rand(10, 4).astype(np.float32) + embedding_attribute = np.array([10, 4, 4]) + self.write_host_data(fake_key_offset_map, key_offset_attribute, "embedding_hashmap", table_dir) + self.write_host_data(fake_embedding, embedding_attribute, "embedding_data", table_dir) + + device_table_dir = os.path.join(set_upper_dir(self.fake_ddr_sparse_dir, self.device_dir_list), self.table_name) + fake_device_emb = np.random.rand(5, 4).astype(np.float32) + self.write_device_data(fake_device_emb, device_table_dir) + + self.ddr_npy_path = os.path.join(table_dir, "key-emb.npy") + + def write_device_data(self, embedding, table_dir): + attribute = dict() + attribute[DataAttr.DATATYPE.value] = embedding.dtype.name + attribute[DataAttr.SHAPE.value] = embedding.shape + + embedding_dir = os.path.join(table_dir, "embedding") + write_binary_data(embedding_dir, 0, embedding, attributes=attribute) + + def write_host_data(self, data, attribute, data_type, table_dir): + data_dir = os.path.join(table_dir, data_type) + tf.io.gfile.makedirs(data_dir) + data_file, attribute_file = generate_file_name(0) + target_data_dir = os.path.join(data_dir, data_file) + target_attribute_dir = os.path.join(data_dir, attribute_file) + + with 
tf.io.gfile.GFile(target_data_dir, "wb") as file: + data = data.tostring() + file.write(data) + + with tf.io.gfile.GFile(target_attribute_dir, "wb") as file: + attribute = attribute.tostring() + file.write(attribute) -- Gitee From ac546c623ca82040f57bcbf233f62f9f616eee9f Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 18 Dec 2023 14:57:24 +0800 Subject: [PATCH 521/551] Match-id-e29b39230493626d2dc417b7dbfa90b5b64ae6cf --- mx_rec/graph/acg_push_ops.py | 2 +- mx_rec/graph/modifier.py | 2 +- tests/mx_rec/graph/test_acg_push_ops.py | 499 ++++++++++++++++++++++++ 3 files changed, 501 insertions(+), 2 deletions(-) create mode 100644 tests/mx_rec/graph/test_acg_push_ops.py diff --git a/mx_rec/graph/acg_push_ops.py b/mx_rec/graph/acg_push_ops.py index aa60a5d6..1622c483 100644 --- a/mx_rec/graph/acg_push_ops.py +++ b/mx_rec/graph/acg_push_ops.py @@ -245,7 +245,7 @@ def _push_subgraph_to_dataset(graph: tf.Graph, subgraph_to_push: Set[tf.Operatio get_next_node = graph.get_operation_by_name("IteratorGetNext") src_dataset = _get_src_dataset(graph, get_next_node) - def acg_func(*x): + def acg_func(*x): # pragma: no cover old_x = x logger.debug("Got old batch layout: %s", x) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index d936fa7c..c883ce4d 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -49,7 +49,7 @@ def get_preprocessing_map_func( raise ValueError("It is legal when and only when one of the parameters 'batch_tensor_names' and " "'pipeline_input_indexes' was given.") - def map_func(*args): + def map_func(*args): # pragma: no cover def parse_batch(data_args: Any, data_batch: dict, key: str = None): """ 解析原始数据集中的batch,并将非dict格式的batch转为dict格式. diff --git a/tests/mx_rec/graph/test_acg_push_ops.py b/tests/mx_rec/graph/test_acg_push_ops.py new file mode 100644 index 00000000..f24ab78e --- /dev/null +++ b/tests/mx_rec/graph/test_acg_push_ops.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python3 +# coding: UTF-8 +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
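+# Tests for mx_rec.graph.acg_push_ops, which clones eligible preprocessing
+# subgraphs out of the main graph and pushes them into the tf.data input
+# pipeline via dataset.map().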
+ +from unittest import TestCase +from unittest.mock import patch, Mock + +import numpy as np +import tensorflow as tf +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.data.ops.dataset_ops import DatasetV1 +from mx_rec.graph.acg_push_ops import ( + ACGPushOpsToDatasetHook, + SubgraphInfo, + _OP_NAME_CONTAIN_STRING_TO_PUSH, + _ACG_NEW_INITIALIZER, + _find_ops_to_be_pushed, + _find_op_from_base_op, + _find_subgraph_nodes, + _get_mapping_tensor, + _topo_subgraph, + _get_dataset_op, + _clone_subgraph_into_funcgraph, + _update_subgraph_out_consumer, + _get_src_dataset, + _update_iterator_getnext, + _find_subgraph_in_out, + _push_subgraph_to_dataset, + _warn_for_var_scope_nodes, + _frozen_variable_node_to_func_const_node_def, + _update_old_consumer, + _get_mapping_for_subgraph, + _get_mapping_for_subgraph_in, + _ordered_output_from_subgraph, + _replace_get_next_op, + _patched_get_src_dataset, +) + +from tests.mx_rec.graph.mock_dataset import gen_mock_dataset + + +@patch.multiple( + "mx_rec.graph.patch", + get_modify_graph=Mock(return_value=True), + get_is_graph_modify_hook_running=Mock(return_value=True), +) +@patch.multiple( + "tensorflow.compat.v1.train.Saver", + __init__=Mock(return_value=None), + build=Mock(), +) +@patch.multiple("mx_rec.graph.acg_push_ops", _find_ops_to_be_pushed=Mock()) +class ACGPushOpsToDatasetHookTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_cutting_point = tf.identity(mock_ids) + + mock_new_iterator = mock_dataset.make_initializable_iterator() + tf.compat.v1.add_to_collection(_ACG_NEW_INITIALIZER, mock_new_iterator.initializer) + + with tf.compat.v1.train.MonitoredSession(hooks=[ACGPushOpsToDatasetHook()]) as sess: + sess.run(mock_iterator.initializer) + sess.run(mock_cutting_point) + + +@patch.multiple( + "mx_rec.graph.acg_push_ops", + _find_subgraph_nodes=Mock(return_value=set()), + _push_subgraph_to_dataset=Mock(), +) +class FindOpsToBePushedTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok_op_contain_str_to_push(self): + tensor = tf.constant(value=[1, 2, 3], name="MOCK" + list(_OP_NAME_CONTAIN_STRING_TO_PUSH)[0]) + mock_graph = tf.compat.v1.get_default_graph() + _find_ops_to_be_pushed(mock_graph) + + def test_ok_op_type_to_push(self): + const_tensor = tf.constant(value=[1, 2, 3], dtype=tf.int32) + str_tensor = tf.compat.v1.as_string(const_tensor) + num_tensor = tf.compat.v1.string_to_number(str_tensor) + mock_graph = tf.compat.v1.get_default_graph() + _find_ops_to_be_pushed(mock_graph) + + def test_ok_no_node_to_push(self): + mock_graph = tf.compat.v1.get_default_graph() + _find_ops_to_be_pushed(mock_graph) + + +class FindSubgraphNodesTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + tensor_in_subgraph = tf.identity(mock_ids) + tensor_out_subgraph = tf.identity(tensor_in_subgraph) + mock_base_nodes = {tensor_out_subgraph.op} + + subgraph_nodes = _find_subgraph_nodes( + tf.compat.v1.get_default_graph(), mock_base_nodes, tgt_op_type="IteratorGetNext" + ) + self.assertEqual(subgraph_nodes, {tensor_in_subgraph.op, 
tensor_out_subgraph.op}) + + +class WarnForVarScopeNodesTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + with tf.compat.v1.variable_scope("mock_var_scope"): + var1 = tf.compat.v1.get_variable("var", shape=(3, 3), initializer=tf.random_normal_initializer()) + + mock_all_nodes = tf.compat.v1.get_default_graph().get_operations() + mock_base_node = var1.op + _warn_for_var_scope_nodes(mock_all_nodes, mock_base_node) + + +class FindOpFromBaseOpTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_err_no_tgt_op_type(self): + parent_tensor = tf.ones(shape=(3, 3)) + child_tensor = tf.identity(parent_tensor) + with self.assertRaises(ValueError): + _find_op_from_base_op(child_tensor.op, "IteratorGetNext") + + +class GetDatasetOpTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(buffer_size=10) + mock_iterator = mock_prefetch_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + mock_graph = tf.compat.v1.get_default_graph() + expected = mock_graph.get_operation_by_name("OptimizeDataset") + + tgt_dataset_op = _get_dataset_op(mock_graph, mock_get_next_op) + self.assertEqual(tgt_dataset_op, expected) + + def test_err_invalid_get_next_op_type(self): + mock_get_next_op = tf.zeros(shape=(3,)).op + mock_graph = tf.compat.v1.get_default_graph() + + with self.assertRaises(TypeError): + _get_dataset_op(mock_graph, mock_get_next_op) + + @patch.multiple("mx_rec.graph.acg_push_ops", _find_op_from_base_op=Mock(return_value=None)) + @patch.multiple("mx_rec.graph.acg_push_ops.modifier", find_parent_op=Mock(return_value=None)) + def test_err_no_tgt_op_found(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + mock_graph = tf.compat.v1.get_default_graph() + + with self.assertRaises(RuntimeError): + _get_dataset_op(mock_graph, mock_get_next_op) + + +class OrderedOutputFromSubgraphTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next(name="IteratorGetNext") + mock_ids = mock_batch.get("mock_ids") + + mock_subgraph_out = {tf.identity(mock_ids).op: {mock_ids.op}} + + addition_funcgraph_output_tensor = _ordered_output_from_subgraph(mock_subgraph_out) + self.assertEqual(addition_funcgraph_output_tensor, [mock_ids]) + + +class PushSubgraphToDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + tensor_in_subgraph = tf.identity(mock_ids) + tensor_out_subgraph = tf.identity(tensor_in_subgraph) + mock_subgraph_to_push = {tensor_in_subgraph.op} + _push_subgraph_to_dataset(tf.compat.v1.get_default_graph(), mock_subgraph_to_push) + + +class FindSubgraphInOutTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = 
mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + tensor_in_subgraph = tf.identity(mock_ids) + tensor_out_subgraph = tf.identity(tensor_in_subgraph) + mock_subgraph_nodes = {tensor_in_subgraph.op} + + ( + subgraph_in, + subgraph_out, + ) = _find_subgraph_in_out(mock_subgraph_nodes) + self.assertEqual(subgraph_in, {mock_ids.op: {tensor_in_subgraph.op}}) + self.assertEqual(subgraph_out, {tensor_out_subgraph.op: {tensor_in_subgraph.op}}) + + +class GetSrcDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok_make_iterator(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + src_dataset = _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) + self.assertEqual(src_dataset, mock_dataset) + + def test_ok_one_shot_iterator(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + src_dataset = _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) + self.assertEqual(src_dataset, mock_dataset) + + def test_err_no_anchor_dataset(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + with self.assertRaises(RuntimeError): + _get_src_dataset(tf.compat.v1.get_default_graph(), mock_get_next_op) + + +class CloneSubgraphIntoFuncgraphTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + mock_subgraph_in = {mock_ids.op: {tf.identity(mock_ids).op}} + mock_subgraph_out = {tf.identity(mock_ids).op: {mock_ids.op}} + mock_subgraph_to_push = set() + mock_subgraph_info = SubgraphInfo(mock_subgraph_in, mock_subgraph_out, mock_subgraph_to_push) + + mock_new_ids = tf.ones_like(mock_ids) + mock_x = [mock_new_ids] + mock_old_x = ({"mock_new_ids": mock_new_ids},) + + mock_defaultgraph = tf.compat.v1.get_default_graph() + with tf.Graph().as_default(): + mock_funcgraph = tf.compat.v1.get_default_graph() + _clone_subgraph_into_funcgraph(mock_funcgraph, mock_defaultgraph, mock_subgraph_info, mock_x, mock_old_x) + + +class GetMappingForSubgraphInTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + + mock_from_node = mock_ids.op + mock_to_nodes = {tf.identity(mock_ids).op} + mock_new_ids = tf.zeros_like(mock_ids) + mock_x = [mock_new_ids] + tensor_mapping = dict() + + _get_mapping_for_subgraph_in(mock_from_node, mock_to_nodes, mock_x, tensor_mapping) + self.assertEqual(tensor_mapping, {mock_ids: mock_new_ids}) + + +class GetMappingForSubgraphTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() 
+ + def test_ok(self): + mock_defaultgraph = tf.compat.v1.get_default_graph() + + # NOTE: Simulate independent graph environment while executing `dataset.map()` method. + with tf.Graph().as_default(): + key_tensor = tf.zeros(shape=(1)) + val_tensor = tf.zeros(shape=(1)) + mock_tensor_mapping = {key_tensor: val_tensor} + + mock_node_mapping = dict() + mock_old_node = tf.identity(key_tensor).op + mock_funcgraph = tf.compat.v1.get_default_graph() + + _get_mapping_for_subgraph( + mock_funcgraph, mock_defaultgraph, mock_node_mapping, mock_old_node, mock_tensor_mapping + ) + + self.assertEqual(len(mock_node_mapping), 1) + self.assertEqual(len(mock_tensor_mapping), 2) + + +class FrozenVariableNodeToFuncConstNodeDefTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + var_tensor = tf.Variable(initial_value=[1], shape=(1,)) + tf.compat.v1.assign(ref=var_tensor, value=[1]) + + mock_funcgraph = tf.Graph() + mock_defaultgraph = tf.compat.v1.get_default_graph() + new_const_node: node_def_pb2.NodeDef = _frozen_variable_node_to_func_const_node_def( + var_tensor.op, mock_funcgraph, mock_defaultgraph + ) + self.assertEqual(new_const_node.op, "Const") + + +class GetMappingTensorTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + key_tensor = tf.zeros(shape=(3, 3)) + val_tensor = tf.ones(shape=(3, 3)) + tsr2tsr = {key_tensor: val_tensor} + keys = [key_tensor] + + mapped_tensors = _get_mapping_tensor(tsr2tsr, keys) + self.assertEqual(mapped_tensors, [val_tensor]) + + def test_err_key_tensor_not_exist(self): + tsr2tsr = {tf.zeros(shape=(3, 3)): tf.ones(shape=(3, 3))} + keys = [tf.ones(shape=(3, 3))] + + with self.assertRaises(KeyError): + _get_mapping_tensor(tsr2tsr, keys) + + +class TopoSubgraphTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + tensor1 = tf.identity(mock_ids) + tensor2 = tf.add(tensor1, 1) + mock_subgraph = {tensor1.op, tensor2.op} + + const_op_for_add = None + for tensor in tensor2.op.inputs: + if tensor.op.name != "Add/y": + continue + const_op_for_add = tensor.op + + if not const_op_for_add: + self.fail( + f"Failed to find input of add operation, input tensor of add op: {[x.op for x in tensor2.op.inputs]}" + ) + + topo_subgraph_list = _topo_subgraph(mock_subgraph) + self.assertEqual(topo_subgraph_list, [tensor1.op, const_op_for_add, tensor2.op]) + + +class UpdateIteratorGetNextTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_old_dataset = gen_mock_dataset() + mock_old_iterator = mock_old_dataset.make_initializable_iterator() + mock_old_batch = mock_old_iterator.get_next(name="OldIteratorGetNext") + mock_old_ids = mock_old_batch.get("mock_ids") + mock_old_get_next_op = mock_old_ids.op + + mock_new_dataset: DatasetV1 = mock_old_dataset.map(lambda x: x) + mock_subgraph_out = {tf.identity(mock_old_ids).op: {mock_old_ids.op}} + + _update_iterator_getnext( + graph=tf.compat.v1.get_default_graph(), + get_next_op=mock_old_get_next_op, + tgt_dataset=mock_new_dataset, + subgraph_out=mock_subgraph_out, + subgraph_to_push=set(), + ) + + +class UpdateOldConsumerTest(TestCase): + def tearDown(self) -> None: + 
tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next(name="NewIteratorGetNext") + mock_ids = mock_batch.get("mock_ids") + mock_new_get_next_op = mock_ids.op + mock_output_tensor = tf.identity(mock_ids) + + _update_old_consumer( + graph=tf.compat.v1.get_default_graph(), + new_get_next_op=mock_new_get_next_op, + output_tensor=mock_ids, + subgraph_to_push=set(), + ) + + +class UpdateSubgraphOutConsumerTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_iterator = mock_dataset.make_initializable_iterator() + mock_batch = mock_iterator.get_next(name="NewIteratorGetNext") + mock_ids = mock_batch.get("mock_ids") + mock_new_get_next_op = mock_ids.op + mock_output_tensor = tf.identity(mock_ids) + + _update_subgraph_out_consumer( + graph=tf.compat.v1.get_default_graph(), + new_get_next_op=mock_new_get_next_op, + offset=0, + output_tensor=mock_ids, + ) + + +class PatchedGetSrcDatasetTest(TestCase): + def tearDown(self) -> None: + tf.compat.v1.reset_default_graph() + + def test_ok(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_double_prefetch_dataset = mock_prefetch_dataset.prefetch(10) + mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + src_dataset = _patched_get_src_dataset(mock_get_next_op, is_training=True) + self.assertEqual(src_dataset, mock_prefetch_dataset) + + def test_err_single_prefetch_dataset(self): + mock_dataset = gen_mock_dataset() + mock_prefetch_dataset = mock_dataset.prefetch(10) + mock_iterator = mock_prefetch_dataset.make_one_shot_iterator() + mock_batch = mock_iterator.get_next() + mock_ids = mock_batch.get("mock_ids") + mock_get_next_op = mock_ids.op + + with self.assertRaises(RuntimeError): + _patched_get_src_dataset(mock_get_next_op, is_training=True) -- Gitee From 9aefdabbabf4631a0cd6ddd6f25e783bb4b359c6 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Mon, 18 Dec 2023 20:12:01 +0800 Subject: [PATCH 522/551] Match-id-afba0856e34a52c9e2247db6b7c58578d343e1f8 --- mx_rec/util/communication/hccl_mgmt.py | 15 +-- tests/mx_rec/util/__init__.py | 0 tests/mx_rec/util/communication/__init__.py | 0 .../util/communication/test_hccl_mgmt.py | 127 ++++++++++++++++++ tests/mx_rec/util/test_atomic.py | 57 ++++++++ tests/mx_rec/util/test_normalization.py | 18 +++ tests/mx_rec/util/test_perf.py | 19 +++ tests/mx_rec/util/test_variable.py | 73 ++++++++++ tests/mx_rec/validator/test_validators.py | 59 +++++++- 9 files changed, 356 insertions(+), 12 deletions(-) create mode 100644 tests/mx_rec/util/__init__.py create mode 100644 tests/mx_rec/util/communication/__init__.py create mode 100644 tests/mx_rec/util/communication/test_hccl_mgmt.py create mode 100644 tests/mx_rec/util/test_atomic.py create mode 100644 tests/mx_rec/util/test_normalization.py create mode 100644 tests/mx_rec/util/test_perf.py create mode 100644 tests/mx_rec/util/test_variable.py diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index 1a50cec3..89a50400 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -5,12 +5,11 @@ import json import os -from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, 
MAX_CONFIG_SIZE, MAX_DEVICE_ID, \ - MIN_RANK_SIZE, MAX_RANK_SIZE -from mx_rec.validator.validator import FileValidator, para_checker_decorator, StringValidator, \ - Convert2intValidator +from mxrec_pybind import get_logic_id +from mx_rec.constants.constants import MAX_CONFIG_SIZE, MAX_DEVICE_ID, MAX_RANK_SIZE, MIN_RANK_SIZE, MIN_SIZE, \ + VALID_DEVICE_ID_LIST from mx_rec.util.global_env_conf import global_env -from mx_rec.util.log import logger +from mx_rec.validator.validator import Convert2intValidator, FileValidator, para_checker_decorator, StringValidator def parse_hccl_json(): @@ -54,8 +53,7 @@ def parse_hccl_json(): if "device_id" not in device or not device.get("device_id").isdigit(): raise ValueError(f"hccl_json device_id wrong.") - import mxrec_pybind - res = mxrec_pybind.get_logic_id(int(device.get("device_id"))) + res = get_logic_id(int(device.get("device_id"))) if res < 0: raise RuntimeError( f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") @@ -103,8 +101,7 @@ def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_devic sorted_device_list = sorted_device_list[:rank_size] for device_idx in sorted_device_list: - import mxrec_pybind - res = mxrec_pybind.get_logic_id(int(device_idx)) + res = get_logic_id(int(device_idx)) if res < 0: raise RuntimeError( f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") diff --git a/tests/mx_rec/util/__init__.py b/tests/mx_rec/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/util/communication/__init__.py b/tests/mx_rec/util/communication/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mx_rec/util/communication/test_hccl_mgmt.py b/tests/mx_rec/util/communication/test_hccl_mgmt.py new file mode 100644 index 00000000..b9afb1c3 --- /dev/null +++ b/tests/mx_rec/util/communication/test_hccl_mgmt.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
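+# The cases below drive parse_hccl_json(), which reads the HCCL rank table
+# pointed to by global_env.rank_table_file and maps ranks to logic device ids
+# via get_logic_id. For reference while reading the mocked payloads, a minimal
+# well-formed table (IP values are placeholders, exactly as in the tests)
+# looks like:
+#
+#     {
+#         "server_count": "1",
+#         "server_list": [
+#             {
+#                 "device": [
+#                     {"device_id": "0", "device_ip": "xxx.xxx.xx.xxx", "rank_id": "0"}
+#                 ],
+#                 "server_id": "xxx.xxx.xx.xxx"
+#             }
+#         ],
+#         "status": "completed",
+#         "version": "1.0"
+#     }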
+
+import sys
+import unittest
+from unittest import mock
+from unittest.mock import mock_open, patch
+
+from mx_rec.util.communication.hccl_mgmt import parse_hccl_json
+from mx_rec.util.communication.hccl_mgmt import set_hccl_info_without_json
+from mx_rec.util.global_env_conf import global_env
+
+
+class HCCLMGMTTest(unittest.TestCase):
+    def setUp(self):
+        """
+        Set up: save and override global_env.rank_table_file.
+        :return: None
+        """
+        self.rank_table_file = global_env.rank_table_file
+        global_env.rank_table_file = __file__
+
+    def tearDown(self):
+        """
+        Tear down: restore global_env.rank_table_file.
+        :return: None
+        """
+        global_env.rank_table_file = self.rank_table_file
+
+    @patch.multiple("mx_rec.util.communication.hccl_mgmt",
+                    get_logic_id=mock.MagicMock(return_value=1))
+    def test_parse_hccl_json_when_success(self):
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "server_list":[
+                {
+                    "device":[
+                        { "device_id":"0", "device_ip":"xxx.xxx.xx.xxx", "rank_id":"0" }
+                    ],
+                    "server_id":"xxx.xxx.xx.xxx"
+                }
+            ],
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            rank_to_device_dict, local_rank_size = parse_hccl_json()
+            self.assertEqual(1, local_rank_size)
+
+    def test_parse_hccl_json_when_attribute_error(self):
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            with self.assertRaises(AttributeError):
+                rank_to_device_dict, local_rank_size = parse_hccl_json()
+
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "server_list":[
+                {
+                    "server_id":"xxx.xxx.xx.xxx"
+                }
+            ],
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            with self.assertRaises(AttributeError):
+                rank_to_device_dict, local_rank_size = parse_hccl_json()
+
+    def test_parse_hccl_json_when_value_error(self):
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "server_list":[],
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            with self.assertRaises(ValueError):
+                rank_to_device_dict, local_rank_size = parse_hccl_json()
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "server_list":[
+                {
+                    "device":[
+                        { "device_id":"0", "device_ip":"xxx.xxx.xx.xxx"}
+                    ],
+                    "server_id":"xxx.xxx.xx.xxx"
+                }
+            ],
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            with self.assertRaises(ValueError):
+                rank_to_device_dict, local_rank_size = parse_hccl_json()
+        with patch("builtins.open", mock_open(read_data="""{
+            "server_count":"1",
+            "server_list":[
+                {
+                    "device":[
+                        {"device_ip":"xxx.xxx.xx.xxx", "rank_id":"0" }
+                    ],
+                    "server_id":"xxx.xxx.xx.xxx"
+                }
+            ],
+            "status":"completed",
+            "version":"1.0"
+        }""")) as mock_file:
+            with self.assertRaises(ValueError):
+                rank_to_device_dict, local_rank_size = parse_hccl_json()
+
+    def test_set_hccl_info_without_json(self):
+        rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0-7", "8", "0")
+        self.assertEqual(8, local_rank_size)
+        rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0,1", "2", "0")
+        self.assertEqual(2, local_rank_size)
+        rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0", "1", "0")
+        self.assertEqual(1, local_rank_size)
+        with self.assertRaises(ValueError):
+            rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0", "8", "0")
+        with self.assertRaises(ValueError):
+            rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0-2", "8", "3")
+        with self.assertRaises(ValueError):
+            rank_to_device_dict, local_rank_size = 
set_hccl_info_without_json("17", "1", "1")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/util/test_atomic.py b/tests/mx_rec/util/test_atomic.py
new file mode 100644
index 00000000..2a6b37c2
--- /dev/null
+++ b/tests/mx_rec/util/test_atomic.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+import unittest
+from threading import Thread
+
+from mx_rec.util.atomic import AtomicInteger
+
+
+class AtomicTest(unittest.TestCase):
+    def setUp(self):
+        """
+        Set up: fix the iteration and thread counts for the concurrency cases.
+        :return: None
+        """
+        self.iter_times = 100_000
+        self.num_thread = 2
+
+    def test_str(self):
+        num = AtomicInteger(1)
+        self.assertEqual("1", str(num))
+
+    def test_increase(self):
+        num = AtomicInteger(0)
+        thread_pool = []
+
+        def runnable():
+            for _ in range(self.iter_times):
+                num.increase()
+
+        for _ in range(self.num_thread):
+            thread = Thread(target=runnable)
+            thread.start()
+            thread_pool.append(thread)
+        for thread in thread_pool:
+            thread.join()
+        self.assertEqual(200_000, num.value())
+
+    def test_decrease(self):
+        num = AtomicInteger(200_000)
+        thread_pool = []
+
+        def runnable():
+            for _ in range(self.iter_times):
+                num.decrease()
+
+        for _ in range(self.num_thread):
+            thread = Thread(target=runnable)
+            thread.start()
+            thread_pool.append(thread)
+        for thread in thread_pool:
+            thread.join()
+        self.assertEqual(0, num.value())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/util/test_normalization.py b/tests/mx_rec/util/test_normalization.py
new file mode 100644
index 00000000..7bb2c967
--- /dev/null
+++ b/tests/mx_rec/util/test_normalization.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+import unittest
+
+from mx_rec.util.normalization import fix_invalid_table_name
+
+
+class NormalizationTest(unittest.TestCase):
+
+    def test_fix_invalid_table_name(self):
+        self.assertEqual("user112", fix_invalid_table_name("user1#12"))
+        with self.assertRaises(ValueError):
+            fix_invalid_table_name("####@@@")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/util/test_perf.py b/tests/mx_rec/util/test_perf.py
new file mode 100644
index 00000000..2b50a2ac
--- /dev/null
+++ b/tests/mx_rec/util/test_perf.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+import unittest
+
+from mx_rec.util.perf import performance
+
+
+class PerfTest(unittest.TestCase):
+
+    def test_performance(self):
+        def func():
+            return 9
+        decorated = performance("func")
+        self.assertEqual(9, decorated(func)())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/util/test_variable.py b/tests/mx_rec/util/test_variable.py
new file mode 100644
index 00000000..b6ac5bd8
--- /dev/null
+++ b/tests/mx_rec/util/test_variable.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
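+# Note: each case below mutates process-wide global_env state, so setUp
+# snapshots every field it overrides and tearDown restores it, keeping the
+# cases order-independent. The pattern in miniature, using a field from this
+# file:
+#
+#     def setUp(self):
+#         self.cm_worker_size = global_env.cm_worker_size   # snapshot
+#         global_env.cm_worker_size = "8"                   # override
+#
+#     def tearDown(self):
+#         global_env.cm_worker_size = self.cm_worker_size   # restore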
+
+import unittest
+from unittest import mock
+from unittest.mock import patch
+
+import tensorflow as tf
+from mx_rec.util.global_env_conf import global_env
+from mx_rec.util.variable import check_and_get_config_via_var
+from mx_rec.util.variable import get_dense_and_sparse_variable
+
+
+class MockTableInstance:
+    def __init__(self):
+        self.skip_emb_transfer = False
+        self.optimizer = False
+
+
+class VariableTest(unittest.TestCase):
+    def setUp(self):
+        """
+        Set up: save and override the global_env device settings.
+        :return: None
+        """
+        self.cm_worker_size = global_env.cm_worker_size
+        self.cm_chief_device = global_env.cm_chief_device
+        self.ascend_visible_devices = global_env.ascend_visible_devices
+        global_env.cm_worker_size = "8"
+        global_env.cm_chief_device = "0"
+        global_env.ascend_visible_devices = "0-7"
+
+    def tearDown(self):
+        """
+        Tear down: restore the global_env device settings.
+        :return: None
+        """
+        global_env.cm_worker_size = self.cm_worker_size
+        global_env.cm_chief_device = self.cm_chief_device
+        global_env.ascend_visible_devices = self.ascend_visible_devices
+
+    @patch.multiple("mx_rec.util.variable",
+                    get_ascend_global_hashtable_collection=mock.MagicMock(return_value="sparse_hashtable"))
+    def test_get_dense_and_sparse_variable(self):
+        dense_layer = tf.Variable([1, 2], trainable=True)
+        sparse_emb = tf.Variable([4, 5], trainable=False)
+        tf.compat.v1.add_to_collection("sparse_hashtable", sparse_emb)
+        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, dense_layer)
+        dense_variables, sparse_variables = get_dense_and_sparse_variable()
+        with tf.compat.v1.Session() as sess:
+            result = tf.reduce_all(tf.equal(dense_layer, dense_variables))
+            sess.run(tf.compat.v1.global_variables_initializer())
+            result_run = sess.run([result])
+
+        self.assertTrue(result_run)
+        tf.compat.v1.reset_default_graph()
+
+    @patch.multiple("mx_rec.util.variable",
+                    get_table_instance=mock.MagicMock(return_value=MockTableInstance()))
+    def test_check_and_get_config_via_var_when_environment_error(self):
+        with self.assertRaises(EnvironmentError):
+            self.assertEqual(MockTableInstance(), check_and_get_config_via_var("1", "optimize"))
+
+    def test_check_and_get_config_via_var_when_success(self):
+        table_instance = MockTableInstance()
+        table_instance.skip_emb_transfer = True
+        table_instance.optimizer = True
+        with patch("mx_rec.util.variable.get_table_instance") as mock_get_table_instance:
+            mock_get_table_instance.return_value = mock.MagicMock(table_instance)
+            self.assertEqual(mock_get_table_instance.return_value, check_and_get_config_via_var("1", "optimize"))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/validator/test_validators.py b/tests/mx_rec/validator/test_validators.py
index bd4ee631..aa97c3b6 100644
--- a/tests/mx_rec/validator/test_validators.py
+++ b/tests/mx_rec/validator/test_validators.py
@@ -6,9 +6,9 @@ import sys
 import tempfile
 import unittest
 
-from mx_rec.validator.validator import Validator, StringValidator, DirectoryValidator, para_checker_decorator, \
-    ClassValidator, OptionValidator, ValueCompareValidator, OptionalStringValidator, \
-    OptionalIntValidator, NumValidator, IntValidator, Convert2intValidator
+from mx_rec.validator.validator import ClassValidator, Convert2intValidator, DirectoryValidator, FloatValidator, \
+    IntValidator, NumValidator, OptionalIntValidator, OptionalStringValidator, OptionValidator, \
+    para_checker_decorator, SSDFeatureValidator, StringValidator, ValueCompareValidator
 
 sys.modules['mxrec_pybind'] = __import__('os')
 
@@ -142,6 +142,59 @@ class ParameterCheckerTest(unittest.TestCase):
             result = False
         self.assertTrue(result)
 
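+    # The cases below follow one pattern: para_checker_decorator binds each
+    # named parameter to a validator class (plus options and check names), and
+    # a failed check surfaces as ValueError when the decorated function is
+    # called. In miniature, reusing a binding from this file:
+    #
+    #     @para_checker_decorator(check_option_list=[
+    #         ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255},
+    #          ["check_string_length", "check_whitelist"])])
+    #     def demo_func(name):
+    #         return True
+    #
+    #     demo_func(name="host")  # True; an invalid name raises ValueError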
+ def test_ssd_feature_validator_when_size_0(self): + @para_checker_decorator(check_option_list=[ + ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, + ["check_string_length", "check_whitelist"]), + (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator)]) + def demo_func(name, host_vocabulary_size=0, + ssd_vocabulary_size=0, + ssd_data_path="./"): + return True + + try: + result = demo_func(name="host", host_vocabulary_size=0, + ssd_vocabulary_size=0, + ssd_data_path="./") + except ValueError: + result = False + self.assertTrue(result) + + def test_ssd_feature_validator_when_size_not_0(self): + @para_checker_decorator(check_option_list=[ + ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, + ["check_string_length", "check_whitelist"]), + (["ssd_vocabulary_size", "ssd_data_path", "host_vocabulary_size"], SSDFeatureValidator)]) + def demo_func(name, host_vocabulary_size=0, + ssd_vocabulary_size=0, + ssd_data_path="./"): + return True + + try: + result = demo_func(name="host", host_vocabulary_size=0, + ssd_vocabulary_size=1, + ssd_data_path="./") + except ValueError: + result = False + self.assertFalse(result) + + def test_check_value_for_open_interval(self): + @para_checker_decorator(check_option_list=[ + ("beta1", FloatValidator, {"min_value": 0, "max_value": 1}, + ["check_value_for_open_interval", "check_value_for_right_open_interval", + "check_value_for_left_open_interval"])]) + def demo_func(beta1): + return True + + try: + result = demo_func(beta1=0.5) + except ValueError: + result = False + self.assertTrue(result) + + def test_is_valid(self): + self.assertTrue(StringValidator("val", 'aa.1245', max_len=30).is_valid()) + if __name__ == '__main__': unittest.main() -- Gitee From c7114733d5f473192e10186901ad0532344fabdb Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 19 Dec 2023 10:12:15 +0800 Subject: [PATCH 523/551] Match-id-54395ae3a775e8c6e2577c917327342b7e6ae65e --- mx_rec/saver/saver.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index e5071252..daba4542 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -253,7 +253,6 @@ class Saver(object): assign_op = state.assign(sub_optimizer_placeholder_dict.get(key_state)) self.restore_fetch_list.append(assign_op) - def _restore(self, sess, reading_path): if is_asc_manager_initialized(): restore_host_data(reading_path) @@ -269,29 +268,14 @@ class Saver(object): fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, NameDescriptor(table_name, DataName.EMBEDDING.value)) - table_instance = get_table_instance_by_name(table_name) - - if table_instance.use_feature_mapping: - fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, - NameDescriptor(table_name, DataName.FEATURE_MAPPING.value)) - fill_placeholder(reading_path, sub_placeholder_dict, restore_feed_dict, self.rank_id, - NameDescriptor(table_name, DataName.OFFSET.value)) - if "optimizer" in sub_placeholder_dict: optimizer_state_placeholder_dict_group = sub_placeholder_dict.get("optimizer") - for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items(): - for state_key in optimizer_state_placeholder_dict: - fill_placeholder(reading_path=reading_path, - placeholder_dict=optimizer_state_placeholder_dict, - feed_dict=restore_feed_dict, - suffix=self.rank_id, - name_descriptor=NameDescriptor(table_name, state_key, 
- optimizer_name=optimizer_name)) + fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, + restore_feed_dict, self.rank_id, table_name) sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict) - class NameDescriptor: def __init__(self, table_name, data_name, optimizer_name=None): self.table_name = table_name @@ -299,6 +283,18 @@ class NameDescriptor: self.optimizer_name = optimizer_name +def fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, restore_feed_dict, suffix, + table_name): + for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items(): + for state_key in optimizer_state_placeholder_dict: + fill_placeholder(reading_path=reading_path, + placeholder_dict=optimizer_state_placeholder_dict, + feed_dict=restore_feed_dict, + suffix=suffix, + name_descriptor=NameDescriptor(table_name, state_key, + optimizer_name=optimizer_name)) + + def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list): """ Extract embedding and optimizer data from the dict based on offset. -- Gitee From e544382a80a9e6114e96dd0ed6173a7362de141a Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 19 Dec 2023 14:37:01 +0800 Subject: [PATCH 524/551] Match-id-494ca2fbc1a23d58bc90219843535b498731f150 --- mx_rec/saver/saver.py | 23 ++++++++++----------- tests/mx_rec/saver/sparse_embedding_mock.py | 1 - 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index daba4542..9c08d0a0 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -283,18 +283,6 @@ class NameDescriptor: self.optimizer_name = optimizer_name -def fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, restore_feed_dict, suffix, - table_name): - for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items(): - for state_key in optimizer_state_placeholder_dict: - fill_placeholder(reading_path=reading_path, - placeholder_dict=optimizer_state_placeholder_dict, - feed_dict=restore_feed_dict, - suffix=suffix, - name_descriptor=NameDescriptor(table_name, state_key, - optimizer_name=optimizer_name)) - - def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list): """ Extract embedding and optimizer data from the dict based on offset. 
@@ -313,6 +301,17 @@ def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list): dump_data_dict["optimizer"] = dump_optimizer_data_dict +def fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, restore_feed_dict, suffix, + table_name): + for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items(): + for state_key in optimizer_state_placeholder_dict: + fill_placeholder(reading_path=reading_path, + placeholder_dict=optimizer_state_placeholder_dict, + feed_dict=restore_feed_dict, + suffix=suffix, + name_descriptor=NameDescriptor(table_name, state_key, optimizer_name=optimizer_name)) + + def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_descriptor): if name_descriptor.optimizer_name: target_path = generate_path(reading_path, "Optimizer", name_descriptor.optimizer_name, "HBM", diff --git a/tests/mx_rec/saver/sparse_embedding_mock.py b/tests/mx_rec/saver/sparse_embedding_mock.py index a32b5e2d..6a3d12df 100644 --- a/tests/mx_rec/saver/sparse_embedding_mock.py +++ b/tests/mx_rec/saver/sparse_embedding_mock.py @@ -17,7 +17,6 @@ class SparseEmbeddingMock: self.slice_device_vocabulary_size = 10 self.scalar_emb_size = 4 self.host_vocabulary_size = host_vocab_size - self.use_feature_mapping = None self.optimizer = dict() self.use_dynamic_expansion = False -- Gitee From 524bbaeddb65348237b693f05710a06df65cb048 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 19 Dec 2023 19:14:13 +0800 Subject: [PATCH 525/551] Match-id-b41ff655d12f4a0c7b274d5ef38fcc0b4c8a6e8b --- src/core/emb_table/emb_table.cpp | 4 - src/core/key_process/key_process.cpp | 16 +- src/core/utils/common.h | 5 +- src/test_ut.sh | 2 +- src/tests/CMakeLists.txt | 9 +- src/tests/key_process/key_process_test.cpp | 554 +++++++++++++++++++-- 6 files changed, 518 insertions(+), 72 deletions(-) diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp index 148796fc..af617e3c 100644 --- a/src/core/emb_table/emb_table.cpp +++ b/src/core/emb_table/emb_table.cpp @@ -144,14 +144,10 @@ void EmbTable::PrintStatus() const int64_t EmbTable::GetTableSize() const { -#ifndef GTEST return static_cast(usedCapacity); -#endif } int64_t EmbTable::GetTableCapacity() const { -#ifndef GTEST return static_cast(totalCapacity); -#endif } diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 3befd2f8..f1e8295e 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -78,9 +78,7 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); -#ifndef GTEST Start(); -#endif return true; } @@ -1307,9 +1305,8 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) /// \param channel 通道索引(训练/推理) void KeyProcess::SendEos(int batchId, int channel) { -#ifndef GTEST LOG_INFO("channelId:{} batchId:{}, SendEos start.", channel, batchId); - +#ifndef GTEST auto trans = Singleton::GetInstance(); unordered_map transChannels = trans->GetTransChannel(); std::set usedChannelNames = trans->GetUsedTransChannel()[channel]; @@ -1329,12 +1326,11 @@ void KeyProcess::SendEos(int batchId, int channel) } LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos end.", channel, batchId, emb.first); } - +#endif LOG_INFO("channelId:{} batchId:{}, 
SendEos end.", channel, batchId); isNeedSendEos[channel] = false; mpiAllReduceSend[channel] = 0; isNeedExit[channel] = true; -#endif } /// HBM模式下,从list中获取指定类型的tensor向量 @@ -1493,7 +1489,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset embName, offset.size(), embInfos[embName].devVocabSize ).c_str()); } - +#ifndef GTEST vector tmpDataOut; Tensor tmpData = Vec2TensorI32(offset); tmpDataOut.emplace_back(tmpData); @@ -1506,7 +1502,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset // evict key发送给dev侧,dev侧初始化emb auto trans = Singleton::GetInstance(); trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); - +#endif LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size()); } @@ -1525,7 +1521,6 @@ string KeyProcess::DumpSplitKeys(vector> &splitKeys) const int64_t KeyProcess::GetExpansionTableSize(const string& embName) { -#ifndef GTEST const auto& iter = embeddingTableMap.find(embName); if (iter == embeddingTableMap.end()) { LOG_ERROR(KEY_PROCESS "GetExpansionEmbSize, wrong embName:{} ", embName); @@ -1533,12 +1528,10 @@ int64_t KeyProcess::GetExpansionTableSize(const string& embName) } std::lock_guard lk(mut); // lock for PROCESS_THREAD return iter->second.GetTableSize(); -#endif } int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) { -#ifndef GTEST const auto& iter = embeddingTableMap.find(embName); if (iter == embeddingTableMap.end()) { LOG_ERROR(KEY_PROCESS "GetExpansionEmbSize, wrong embName:{} ", embName); @@ -1546,7 +1539,6 @@ int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) } std::lock_guard lk(mut); // lock for PROCESS_THREAD return iter->second.GetTableCapacity(); -#endif } void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 30f1b17c..d353ea0d 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -79,8 +79,11 @@ namespace MxRec { constexpr int FILE_MIN_SIZE = 0; constexpr size_t BUFFER_SIZE{1024 * 1024 * 64}; constexpr size_t MAP_BYTE_SIZE{static_cast(10) * 1024 * 1024 * 1024}; - +#ifdef GTEST + constexpr int KEY_PROCESS_TIMEOUT = 3; +#else constexpr int KEY_PROCESS_TIMEOUT = 120; +#endif constexpr int GET_BATCH_TIMEOUT = 300; constexpr int EOS_TIMEOUT = 60; diff --git a/src/test_ut.sh b/src/test_ut.sh index 93c60f25..3a2987f0 100644 --- a/src/test_ut.sh +++ b/src/test_ut.sh @@ -81,7 +81,7 @@ cd "$(dirname "${PWD}")" COVERAGE_FILE=coverage.info REPORT_FOLDER=coverage_report lcov --rc lcov_branch_coverage=1 -c -d build -o "${COVERAGE_FILE}"_tmp -lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/key_process*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '7/ext*' '*7/bits*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" +lcov -r "${COVERAGE_FILE}"_tmp 'ut/*' '/usr1/mxRec/src/core/hybrid_mgmt*' '/usr1/mxRec/src/core/host_emb*' '7/ext*' '*7/bits*' 'platform/*' '/usr/local/*' '/usr/include/*' '/opt/buildtools/python-3.7.5/lib/python3.7/site-packages/tensorflow*' 'tests/*' '/usr1/mxRec/src/core/ock_ctr_common/include*' --rc lcov_branch_coverage=1 -o "${COVERAGE_FILE}" genhtml --rc genhtml_branch_coverage=1 "${COVERAGE_FILE}" -o "${REPORT_FOLDER}" [ -d "${COVERAGE_FILE}"_tmp ] && rm -rf "${COVERAGE_FILE}"_tmp [ -d "${COVERAGE_FILE}" ] && rm -rf 
"${COVERAGE_FILE}" diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 9854324a..cbf85dd9 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -5,6 +5,13 @@ find_package(GTest REQUIRED) include_directories(${GTEST_INCLUDE_DIRS}) add_definitions(-DGTEST) +set(EMOCK_INCLUDE_DIRS /usr/local/include/emock) +include_directories(${EMOCK_INCLUDE_DIRS}) +find_library(EMOCK_LIBRARY emock) +if(NOT EMOCK_LIBRARY) + set(EMOCK_LIBRARY /usr/local/lib/libemock.a) +endif() +message("EMOCK_LIBRARY: " ${EMOCK_LIBRARY}) # src file(GLOB_RECURSE MXREC_SRC ${PROJECT_SOURCE_DIR}/core/*.cpp) message("MXREC_SRC: " ${MXREC_SRC}) @@ -62,4 +69,4 @@ target_link_libraries(test_main PUBLIC MPI::MPI_CXX) target_link_libraries(test_main PUBLIC ascendcl msprofiler ge_executor gert runtime ge_common register graph ascend_protobuf - profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host) + profapi opt_feature error_manager exe_graph acl_tdt_channel acl_tdt_queue securec drvdsmi_host ${EMOCK_LIBRARY}) diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index c5708488..0e14a8f7 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -8,13 +8,13 @@ #include +#include +#include #include -#include #include #include "utils/common.h" #include "key_process/key_process.h" -#include "hd_transfer/hd_transfer.h" #include "ock_ctr_common/include/unique.h" #include "ock_ctr_common/include/error_code.h" @@ -66,11 +66,15 @@ protected: rankInfo.localRankSize = worldSize; rankInfo.useStatic = useStatic; rankInfo.localRankId = rankInfo.rankId % rankInfo.localRankSize; + rankInfo.deviceId = rankInfo.localRankId; rankInfo.noDDR = false; rankInfo.maxStep = { 1, -1 }; + rankInfo.useHot = false; // 初始化emb信息 GenEmbInfos(embNum, embInfos, fieldNums); splits = fieldNums; + BuildExpect(); + GlobalMockObject::verify(); } // 使用该方法构造的数据需要使用掉,否则会影响其他用例 @@ -122,21 +126,10 @@ protected: } } - template - inline vector Count2Start(const vector& count) - { - vector start = { 0 }; - for (size_t i = 0; i < count.size() - 1; ++i) { - start.push_back(count[i] + start.back()); - } - return start; - } - // 生成emb表信息 bool GenEmbInfos(size_t embNums, vector& allEmbInfos, vector& geFieldNums) { default_random_engine generator; - uniform_int_distribution distribution(randMin, randMax); int embSizeMin = 5; int embSizeMax = 8; int base = 2; @@ -149,11 +142,13 @@ protected: temp.name = "emb" + ss.str(); ss.str(""); ss.clear(); - temp.sendCount = distribution(generator); - temp.extEmbeddingSize = pow(base, embSizeDistribution(generator)); + temp.sendCount = sendCount; // 10~25 + temp.extEmbeddingSize = pow(base, embSizeDistribution(generator)); // 2^5~2^8 temp.devVocabSize = vocabSize; geFieldNums.push_back(sampleSize); allEmbInfos.push_back(move(temp)); + LOG_INFO("GenEmbInfos, emb Name: {}, sendCount:{}, extEmbeddingSize: {}, devVocabSize: {}", + temp.name, temp.sendCount, temp.extEmbeddingSize, temp.devVocabSize = vocabSize); } return true; } @@ -199,46 +194,168 @@ protected: } } - enum class A2A { - SC, SS, RC, RS, INVALID - }; + unique_ptr GenBatch(string batchName, int batchId, int channelId) // 用于端到端test + { + unique_ptr batch = std::make_unique(); + vector allBatchKeys = { { 11, 11, 6, 16, 14, 8, 6, 5, 8, 6, 14, 11, 4, 12, 1, 13 }, + { 8, 6, 2, 4, 3, 8, 13, 2, 1, 4, 2, 2, 11, 8, 14, 5 }, + { 16, 3, 2, 12, 4, 12, 12, 2, 6, 4, 1, 5, 9, 3, 5, 14 }, + { 2, 8, 2, 12, 1, 14, 9, 8, 14, 16, 11, 15, 
1, 7, 5, 2 } }; + batch->sample = std::move(allBatchKeys[worldRank]); + batch->name = batchName; + batch->batchId = batchId; + batch->channel = channelId; + LOG_INFO(KEY_PROCESS "test GenExpect: rank {}, batchKeys {}", + worldRank, VectorToString(batch->sample)); + + return batch; + } + void BuildExpect() + { + allExpectSs = { { 0, 4, 7, 9 }, { 0, 2, 5, 8 }, { 0, 3, 6, 9 }, { 0, 3, 6, 8 } }; + allExpectRestore = { { 9, 9, 7, 0, 8, 1, 7, 4, 1, 7, 8, 9, 2, 3, 5, 6 }, + { 0, 5, 6, 1, 8, 0, 2, 6, 3, 1, 6, 6, 9, 0, 7, 4 }, + { 0, 9, 6, 1, 2, 1, 1, 6, 7, 2, 3, 4, 5, 9, 4, 8 }, + { 6, 0, 6, 1, 3, 7, 4, 0, 7, 2, 8, 9, 3, 10, 5, 6 } }; + // sendCount = 10, 按照10padding的结果 + allExpectRestoreStatic = { { 30, 30, 20, 0, 21, 1, 20, 10, 1, 20, 21, 30, 2, 3, 11, 12 }, + { 0, 20, 21, 1, 30, 0, 10, 21, 11, 1, 21, 21, 31, 0, 22, 12 }, + { 0, 30, 20, 1, 2, 1, 1, 20, 21, 2, 10, 11, 12, 30, 11, 22 }, + { 20, 0, 20, 1, 10, 21, 11, 0, 21, 2, 30, 31, 10, 32, 12, 20 } }; + allExpectLookupKeys = { { 16, 8, 4, 12, 8, 4, 16, 12, 4, 8, 12, 16 }, + { 5, 1, 13, 13, 1, 5, 1, 5, 9, 1, 9, 5 }, + { 6, 14, 6, 2, 14, 2, 6, 14, 2, 14 }, + { 11, 3, 11, 3, 11, 15, 7 } }; + allExpectOffset = { { 0, 1, 2, 3, 1, 2, 0, 3, 2, 1, 3, 0 }, + { 0, 1, 2, 2, 1, 0, 1, 0, 3, 1, 3, 0 }, + { 0, 1, 0, 2, 1, 2, 0, 1, 2, 1 }, + { 0, 1, 0, 1, 0, 2 } }; + allExpectCount = { { 1, 2, 1, 1, 3, 2, 1, 3, 2, 2, 1, 1 }, + { 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1 }, + { 3, 2, 1, 4, 1, 2, 1, 1, 3, 2 }, + { 3, 1, 1, 2, 1, 1, 1 } }; + allExpectAll2all = { { 4, 2, 3, 3 }, { 3, 3, 3, 3 }, { 2, 3, 3, 2 }, { 1, 2, 1, 3 } }; + } + + bool CheckMatrixTensor(const vector& actual, const vector>& expect) // 主要用于all2all的校验 + { + int row = expect.size(); + int col = expect[0].size(); + auto tmpTensor = actual.at(0); + auto tmpData = tmpTensor.matrix(); + for (int i = 0; i < row; ++i) { + for (int j = 0; j < col; ++j) { + if (!(tmpData(i, j) == expect[i][j])) { + return false; + } + } + } + return true; + } + + bool CheckFlatTensor(const vector& actual, const vector& expect) // 主要用于lookup和restore的校验 + { + int num = expect.size(); + auto tmpTensor = actual.at(0); + auto tmpData = tmpTensor.flat(); + for (int i = 0; i < num; ++i) { + if (!(tmpData(i) == expect[i])) { + return false; + } + } + return true; + } + + bool CheckPaddingVec(const vector& actual, const vector& expect) // 主要用于lookup静态下padding校验 + { + for (int i = 0, j = 0; i < actual.size(); i++) { + if (actual[i] == -1) { + continue; + } + if (!(actual[i] == expect[j])) { + return false; + } + j++; + } + return true; + } RankInfo rankInfo; + vector embInfos; int worldRank {}; int worldSize {}; vector splits; - int sampleSize = 20; + int sampleSize = 20; // dim维度 int channel = 0; - int randMin = 10; - int randMax = 25; // 最大随机数范围 - // RankInfo rankInfo + int sendCount = 10; + int randMax = 25; // GenData生成数据0~max范围 + int batchSize = 5; - int localRankSize = 2; bool useStatic = true; - int staticSendCount = 65536; - - int maxRankSize = 8; // vector embInfos int embNum = 1; - vector fieldNums; + vector fieldNums; // 多个表的dim维度 - vector src; - vector allRankInfo; - vector embInfos; - unique_ptr batchData; vector splitKeys; vector restore; + vector> keyCount; KeyProcess process; + vector> allExpectSs; + vector> allExpectAll2all; + vector> allExpectRestore; + vector> allExpectRestoreStatic; + vector> allExpectLookupKeys; + vector> allExpectOffset; + vector> allExpectCount; void TearDown() { + GlobalMockObject::verify(); // delete } }; +int EMOCK_API mockStart(KeyProcess* obj) { + return 1; +} + +void EMOCK_API 
mockDestroy(KeyProcess* obj) { + // 等待线程主动处理结束,再isRunning = false + for (auto& i: obj->procThreads) { + i->join(); + } + obj->isRunning = false; + obj->procThreads.clear(); + return; +} + +void EMOCK_API mockEmptyDestroy(KeyProcess* obj) { + auto batchQueue = SingletonQueue::GetInstances(0); + do { + LOG_INFO("wait for thread running"); + this_thread::sleep_for(2s); + } while (batchQueue->TryPop() != nullptr); + // 通过Queue中数据是否取完了来判断是否需要退出;可能会出现取完却未处理完的情况、出数据不均衡 + obj->isRunning = false; + for (auto& i: obj->procThreads) { + i->join(); + } + return; +} + +void EMOCK_API mockInitExpansionEmb(EmbTable* obj, const EmbInfo& embInfos, const RankInfo&, int) +{ + obj->totalCapacity = embInfos.devVocabSize; + obj->embSize = embInfos.extEmbeddingSize; + obj->usedCapacity = 1; +} + TEST_F(KeyProcessTest, Initialize) { + EMOCK(&KeyProcess::Start) + .expects(exactly(1)) + .will(invoke(mockStart)); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); ASSERT_EQ(process.rankInfo.rankId, rankInfo.rankId); @@ -254,6 +371,42 @@ TEST_F(KeyProcessTest, Initialize) ock::ctr::Factory::Create(factory); } +TEST_F(KeyProcessTest, InitializeHot) +{ + EMOCK(&KeyProcess::Start) + .expects(exactly(1)) + .will(invoke(mockStart)); + EMOCK(GetChipName) + .stubs() + .with(emock::any()) + .will(returnValue(string("910B"))); // 调用GetChipName时返回910B + EMOCK(&EmbTable::Init) + .stubs() + .will(invoke(mockInitExpansionEmb)); + rankInfo.useHot = true; + rankInfo.useDynamicExpansion = true; + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + ASSERT_EQ(process.hotEmbUpdateStep, GlobalEnv::hotEmbUpdateStep); + + for (const EmbInfo& info: embInfos) { + ASSERT_NE(process.hotEmbTotCount.find(info.name), process.hotEmbTotCount.end()); + } +} + +TEST_F(KeyProcessTest, GetExpansionTableSizeOrCapacity) +{ + EMOCK(&EmbTable::Init) + .stubs() + .will(invoke(mockInitExpansionEmb)); + + for (const EmbInfo& info: embInfos) { + process.embeddingTableMap[info.name].Init(info, rankInfo, 0); + ASSERT_EQ(process.GetExpansionTableSize(info.name), 1); + ASSERT_EQ(process.GetExpansionTableCapacity(info.name), info.devVocabSize); + } +} + TEST_F(KeyProcessTest, Start) { ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); @@ -268,6 +421,9 @@ TEST_F(KeyProcessTest, Start) TEST_F(KeyProcessTest, HashSplit) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); int rankSize = 4; auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); @@ -286,8 +442,102 @@ TEST_F(KeyProcessTest, HashSplit) ASSERT_THAT(restore, ElementsAreArray(expectRestore)); } +TEST_F(KeyProcessTest, HashSplitWithFAAE) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + int rankSize = 4; + auto queue = SingletonQueue::GetInstances(0); + auto batch = queue->GetOne(); + KeysT batchKeys = { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }; + vector expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 }; + vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } }; + vector > expectCount = {{1, 1}, {1, 2, 1}, {1, 1}, {1, 1}}; + batch->sample = std::move(batchKeys); + LOG_DEBUG(KEY_PROCESS "batch sample: {}", VectorToString(batch->sample)); + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + process.rankInfo.rankSize = rankSize; + auto [splitKeys, restore, keyCount] = process.HashSplitWithFAAE(batch); + LOG_INFO(KEY_PROCESS "HashSplitWithFAAE, batch splitKeys: {}, keyCount: {}", 
VectorToString(splitKeys[0]), + VectorToString(keyCount[0])); + + for (unsigned int i = 0; i < splitKeys.size(); ++i) { + ASSERT_THAT(splitKeys[i], ElementsAreArray(expectSplitKeys[i])); + ASSERT_THAT(keyCount[i], ElementsAreArray(expectCount[i])); + } + ASSERT_THAT(restore, ElementsAreArray(expectRestore)); +} + +// 准入+动态shape下,有padding +TEST_F(KeyProcessTest, PaddingHashSplitWithFAAE) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + int rankSize = 4; + auto queue = SingletonQueue::GetInstances(0); + auto batch = queue->GetOne(); + KeysT batchKeys = { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }; + vector expectRestore = { 0, 0, 0, 0, 1, 1, 1, 1, 1, 2 }; + vector> expectSplitKeys = { { 4, 16 }, { 1, 21, 29 }, { 14, 2 }, { 23, 7 } }; + vector > expectCount = {{1, 1}, {1, 2, 1}, {1, 1}, {1, 1}}; + batch->sample = std::move(batchKeys); + LOG_DEBUG(KEY_PROCESS "batch sample: {}", VectorToString(batch->sample)); + + rankInfo.useStatic = false; + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + process.rankInfo.rankSize = rankSize; + auto [splitKeys, restore, keyCount] = process.HashSplitWithFAAE(batch); + LOG_INFO(KEY_PROCESS "HashSplitWithFAAE Padding, batch splitKeys: {}, keyCount: {}", VectorToString(splitKeys[0]), + VectorToString(keyCount[0])); + + for (unsigned int i = 0; i < splitKeys.size(); ++i) { + ASSERT_EQ(splitKeys[i].size(), ALLTOALLVC_ALIGN); + ASSERT_EQ(keyCount[i].size(), ALLTOALLVC_ALIGN); + } +} + +TEST_F(KeyProcessTest, HotHashSplit) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + EMOCK(&KeyProcess::Destroy) + .expects(exactly(1)) + .will(invoke(mockDestroy)); + PrepareBatch(); + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + LOG_INFO("CPU Core Num: %{}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + + auto fn = [this](int channel, int id) { + auto embName = embInfos[0].name; + process.hotEmbTotCount[embName] = 10; + vector splitKeys; + vector restore; + vector hotPos; + unique_ptr batch; + batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue + LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); + LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); + }; // for clean code + for (int channel = 0; channel < 1; ++channel) { + for (int id = 0; id < 1; ++id) { + // use lambda expression initialize thread + process.procThreads.emplace_back(std::make_unique(fn, channel, id)); + } + } + process.Destroy(); +} + TEST_F(KeyProcessTest, GetScAll) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 LOG_DEBUG(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); @@ -306,6 +556,9 @@ TEST_F(KeyProcessTest, GetScAll) TEST_F(KeyProcessTest, HandleRankExitScene) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); // 仅用于集合通信获取sendCount信息,构造EmbBatchT对象即可,通道传0,不用构造batch数据 @@ -323,10 +576,20 @@ TEST_F(KeyProcessTest, HandleRankExitScene) } catch (EndRunExit e) { LOG_INFO(KEY_PROCESS "success"); } + + // 测试第二个线程进入,由于上一个sendEos,这个线程则不应该发送 + try { + process.HandleRankExitScene(1, batch, 0); + } catch (EndRunExit e) { + LOG_INFO(KEY_PROCESS "success"); + } } TEST_F(KeyProcessTest, 
GetScAllForUnique) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 LOG_INFO(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); @@ -345,8 +608,12 @@ TEST_F(KeyProcessTest, GetScAllForUnique) ASSERT_THAT(scAll, ElementsAreArray(expectScAll)); } +// 非hot、非准入模式,固定batch输入,校验restore TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); vector allBatchKeys = { { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }, @@ -377,37 +644,59 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) ASSERT_THAT(restore, ElementsAreArray(allExpectRestore[worldRank])); } -TEST_F(KeyProcessTest, ProcessKeySplit_rebuilt) +// hot模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、restore +TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + EMOCK(&KeyProcess::Destroy) + .expects(exactly(1)) + .will(invoke(mockDestroy)); PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - LOG_INFO("CPU Core Num: %{}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 + LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 auto fn = [this](int channel, int id) { auto embName = embInfos[0].name; - process.hotEmbTotCount[embName] = 10; vector splitKeys; vector restore; vector hotPos; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG_INFO("rankid :{},batchid: {}", rankInfo.rankId, batch->batchId); + LOG_INFO("rankid :{}, batchid: {}", rankInfo.rankId, batch->batchId); tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - LOG_INFO("rankid :{},batchid: {}, hotPos {}", rankInfo.rankId, batch->batchId, VectorToString(hotPos)); + auto [lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); + process.BuildRestoreVec(batch, ss, restore, hotPos.size()); + LOG_INFO("rankid :{}, batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), + VectorToString(scAll), VectorToString(restore)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { - for (int id = 0; id < 1; ++id) { + for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { // use lambda expression initialize thread process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } - this_thread::sleep_for(20s); + process.Destroy(); } -TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) +// 准入模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、count +TEST_F(KeyProcessTest, GetCountRecv) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + EMOCK(&KeyProcess::Destroy) + .expects(exactly(1)) + .will(invoke(mockDestroy)); PrepareBatch(); + process.m_featureAdmitAndEvict.m_isEnableFunction = true; + for (size_t i = 0; i < embInfos.size(); i++) { + FeatureAdmitAndEvict::m_embStatus[embInfos[i].name] = SingleEmbTableStatus::SETS_BOTH; + } + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 @@ -415,16 +704,17 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) auto embName = embInfos[0].name; vector splitKeys; vector restore; - vector hotPos; + vector> count; unique_ptr batch; batch = process.GetBatchData(channel, id); // get batch data from SingletonQueue - LOG_INFO("rankid 
:{},batchid: {}", rankInfo.rankId, batch->batchId); - tie(splitKeys, restore, hotPos) = process.HotHashSplit(batch); - auto[lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); - process.BuildRestoreVec(batch, ss, restore, hotPos.size()); - LOG_INFO("rankid :{},batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), - VectorToString(scAll), VectorToString(restore)); + LOG_INFO("rankid :{}, batchid: {}", rankInfo.rankId, batch->batchId); + tie(splitKeys, restore, count) = process.HashSplitWithFAAE(batch); + auto [lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); + vector countRecv = process.GetCountRecv(batch, id, count, scAll, ss); + + LOG_INFO("rankid :{}, batchid: {}, lookupKeys: {}, scAll: {}, count after build {}", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), + VectorToString(scAll), VectorToString(countRecv)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { @@ -432,12 +722,14 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } - this_thread::sleep_for(20s); process.Destroy(); } TEST_F(KeyProcessTest, Key2Offset) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); KeysT lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 }; KeysT expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); @@ -466,6 +758,9 @@ TEST_F(KeyProcessTest, Key2Offset) TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); KeysT lookupKeys = { 4, 16, 28, -1, 24, -1, 20, 24 }; KeysT expectOffset = { 0, 0, 0, 0, 0, 0, 0, 0 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); @@ -497,21 +792,136 @@ TEST_F(KeyProcessTest, GetUniqueConfig) // 边界值、重复度测试 TEST_F(KeyProcessTest, ProcessPrefetchTask) { + EMOCK(&KeyProcess::Destroy) + .expects(exactly(1)) + .will(invoke(mockEmptyDestroy)); PrepareBatch(); + rankInfo.noDDR = true; + GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); process.rankInfo.rankSize = worldSize; process.rankInfo.localRankId = process.rankInfo.rankId % process.rankInfo.localRankSize; ASSERT_EQ(process.isRunning, true); - ASSERT_EQ(process.Start(), 0); // 所有线程处理完(训练结束)后调用 - this_thread::sleep_for(5s); - LOG_INFO("wait 20s for thread running"); - this_thread::sleep_for(20s); + process.Destroy(); + GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::DIRECT_APPLY; +} + +// HBM端到端测试,动态shape,固定batch输入 +TEST_F(KeyProcessTest, KeyProcessTaskHelper) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + rankInfo.noDDR = true; + rankInfo.useStatic = false; + rankInfo.useHot = false; + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + int batchId = 0; + int channelId = 0; + auto batch = GenBatch(embInfos[0].name, batchId, channelId); // 测试一个表 + + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}", rankInfo.rankId, batch->batchId); + + ASSERT_EQ(process.KeyProcessTaskHelper(batch, channelId, 0), true); // threadId = 0 + auto infoVecs = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE); + ASSERT_NE(infoVecs, nullptr); + auto all2all = process.GetInfoVec(batchId, embInfos[0].name, channelId, 
ProcessedInfo::ALL2ALL); + ASSERT_NE(all2all, nullptr); + + ASSERT_EQ(CheckMatrixTensor(*all2all, allExpectAll2all), true); + ASSERT_EQ(CheckFlatTensor({infoVecs->back()}, allExpectOffset[worldRank]), true); + infoVecs->pop_back(); + ASSERT_EQ(CheckFlatTensor(*infoVecs, allExpectRestore[worldRank]), true); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, normal status success", rankInfo.rankId, batch->batchId); + // 测试batchId错误 + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + hybridMgmtBlock->hybridBatchId[0] = 1; + ASSERT_EQ(process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE), nullptr); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchId exception success", + rankInfo.rankId, batch->batchId); + // 测试empty场景 + hybridMgmtBlock->pythonBatchId[1] = 1; + hybridMgmtBlock->hybridBatchId[1] = 1; + hybridMgmtBlock->readEmbedBatchId[1] = 1; + hybridMgmtBlock->loop[1] = 1; + ASSERT_EQ(process.GetInfoVec(batchId + 1, embInfos[0].name, channelId + 1, ProcessedInfo::RESTORE), nullptr); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batch empty success", rankInfo.rankId, batch->batchId); + // eos + process.SetEos(1, 1); + ASSERT_EQ(process.GetInfoVec(batchId + 1, embInfos[0].name, channelId + 1, ProcessedInfo::RESTORE), nullptr); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); + + process.Destroy(); +} + +// DDR端到端测试,静态shape,固定batch输入 +TEST_F(KeyProcessTest, KeyProcessTaskHelperDDR) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + rankInfo.noDDR = false; + rankInfo.useStatic = true; + rankInfo.useHot = false; + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + int batchId = 0; + int channelId = 0; + auto batch = GenBatch(embInfos[0].name, batchId, channelId); // 测试第一个表 + HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); + hybridMgmtBlock->hybridBatchId[0] = 0; + LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}", rankInfo.rankId, batch->batchId); + + ASSERT_EQ(process.KeyProcessTaskHelper(batch, channelId, 0), true); // threadId = 0 + + auto lookupKeys = process.GetLookupKeys(batchId, embInfos[0].name, channelId); // lookup list返回的不是tensor + ASSERT_EQ(lookupKeys.size(), sendCount * worldSize); + LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, lookupKeys: {}", + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys)); + ASSERT_EQ(CheckPaddingVec(lookupKeys, allExpectLookupKeys[worldRank]), true); + + auto infoVecs = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE); + ASSERT_NE(infoVecs, nullptr); + int col = allExpectRestore[worldRank].size(); + auto tmpTensor = (*infoVecs).at(0); + auto tmpData = tmpTensor.flat(); + + vector actualGetRestore(col); + for (int j = 0; j < col; j++) { + actualGetRestore[j] = tmpData(j); + } + LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, Restore: {}", + rankInfo.rankId, batch->batchId, VectorToString(actualGetRestore)); + ASSERT_THAT(actualGetRestore, ElementsAreArray(allExpectRestoreStatic[worldRank])); + LOG_INFO("KeyProcessTaskHelperDDR, rankid: {}, batchid: {}, normal status success", + rankInfo.rankId, batch->batchId); + + // 测试batchId错误 + hybridMgmtBlock->hybridBatchId[0] = 1; + ASSERT_EQ(process.GetLookupKeys(batchId, embInfos[0].name, channelId).empty(), true); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchId exception success", + rankInfo.rankId, 
batch->batchId); + // 测试empty场景 + hybridMgmtBlock->pythonBatchId[1] = 1; + hybridMgmtBlock->hybridBatchId[1] = 1; + hybridMgmtBlock->readEmbedBatchId[1] = 1; + hybridMgmtBlock->loop[1] = 1; + ASSERT_EQ(process.GetLookupKeys(batchId + 1, embInfos[0].name, channelId + 1).empty(), true); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batch empty success", rankInfo.rankId, batch->batchId); + // eos + process.SetEos(1, 1); + ASSERT_EQ(process.GetLookupKeys(batchId + 1, embInfos[0].name, channelId + 1).empty(), true); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); process.Destroy(); } TEST_F(KeyProcessTest, InitializeUnique) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); ASSERT_EQ(ock::ctr::Factory::Create(factory), -1); ock::ctr::UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); @@ -540,6 +950,12 @@ TEST_F(KeyProcessTest, GetKeySize) TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) { + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + EMOCK(&KeyProcess::Destroy) + .expects(exactly(1)) + .will(invoke(mockDestroy)); PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); @@ -575,6 +991,38 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } - this_thread::sleep_for(20s); process.Destroy(); } + +TEST_F(KeyProcessTest, LoadSaveLock) +{ + process.LoadSaveLock(); + process.LoadSaveUnlock(); +} + +TEST_F(KeyProcessTest, EvictKeys) +{ + EMOCK(&KeyProcess::Start) + .stubs() + .will(invoke(mockStart)); + ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); + ASSERT_EQ(process.isRunning, true); + absl::flat_hash_map flatTmp0 { {0, 0}, {4, 1}, {8, 2}, {12, 3} }; + absl::flat_hash_map flatTmp1 { {1, 0}, {5, 1}, {9, 2}, {13, 3} }; + absl::flat_hash_map flatTmp2 { {2, 0}, {6, 1}, {10, 2}, {14, 3} }; + absl::flat_hash_map flatTmp3 { {3, 0}, {7, 1}, {11, 2}, {15, 3} }; + vector> allHashMap {flatTmp0, flatTmp1, flatTmp2, flatTmp3}; + process.keyOffsetMap.emplace(embInfos[0].name, allHashMap[worldRank]); + process.evictPosMap.emplace(embInfos[0].name, vector{}); + + vector> allEvictKeys {{4, 8}, {1, 5}, {10, 14}, {3, 11}}; + vector> allEvictPos {{1, 2}, {0, 1}, {2, 3}, {0, 2}}; + process.EvictKeys(embInfos[0].name, allEvictKeys[worldRank]); + ASSERT_THAT(process.evictPosMap.at(embInfos[0].name), ElementsAreArray(allEvictPos[worldRank])); + + // 测试并表统计情况下的淘汰 + vector> allEvictKeysCom {{0}, {}, {18}, {}}; + vector> allEvictPosCom {{1, 2, 0}, {0, 1}, {2, 3}, {0, 2}}; + process.EvictKeysCombine(allEvictKeysCom[worldRank]); + ASSERT_THAT(process.evictPosMap.at(embInfos[0].name), ElementsAreArray(allEvictPosCom[worldRank])); +} -- Gitee From f09d7c1295d3eb3adb5473f594139ce053026c8b Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 20 Dec 2023 15:17:18 +0800 Subject: [PATCH 526/551] Match-id-28545454d5d6daec590c072719b099391b49e039 --- src/core/host_emb/host_emb.cpp | 7 -- src/tests/host_emb/host_emb_test.cpp | 113 +++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 7 deletions(-) diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index eba0b971..e75ba892 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -41,7 +41,6 @@ void HostEmb::Initialize(const vector& embInfos, int seed) void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, int embeddingSize, vector> &embData) const { -#ifndef GTEST 
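+    // EmbDataGenerator sizes embData to vocabSize rows of embeddingSize floats
+    // and runs each configured initializer over its slice; dropping the GTEST
+    // guard lets the unit tests exercise this generation path directly.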
LOG_INFO(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); @@ -53,7 +52,6 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in } } LOG_INFO(HOSTEMB + "GenerateEmbData End, seed:{}", seed); -#endif } /// 停止用于异步更新D2H emb的线程 @@ -85,7 +83,6 @@ void HostEmb::Join(int channelId) } } -#ifndef GTEST /// 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 /// missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 /// \param missingKeysHostPos 当前batch在host上需要换出的偏移 @@ -154,7 +151,6 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI throw runtime_error("Acl get tensor data from dataset failed."); } float* ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); - size_t elementSize = acltdtGetDataSizeFromItem(aclData); size_t dimNum = acltdtGetDimNumFromItem(aclData); LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," @@ -232,16 +228,13 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve } } } -#endif /// 利用initializer初始化emb淘汰的位置 /// \param embName 表名 /// \param offset 淘汰的偏移列表 void HostEmb::EvictInitEmb(const string& embName, const vector& offset) { -#ifndef GTEST auto& hostEmb = GetEmb(embName); EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); -#endif } \ No newline at end of file diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp index dc7baf29..2632e0df 100644 --- a/src/tests/host_emb/host_emb_test.cpp +++ b/src/tests/host_emb/host_emb_test.cpp @@ -7,14 +7,75 @@ */ #include +#include #include "host_emb/host_emb.h" #include "tensorflow/core/framework/tensor.h" +#include "hd_transfer/hd_transfer.h" +#include "utils/singleton.h" using namespace std; using namespace tensorflow; using namespace MxRec; +namespace { +bool operator==(const Tensor& tensor1, const Tensor& tensor2) +{ + if (tensor1.shape() != tensor2.shape()) { + return false; + } + auto tensor1_data = tensor1.flat(); + auto tensor2_data = tensor2.flat(); + for (int j = 0; j < tensor1_data.size(); j++) { + if (tensor1_data(j) != tensor2_data(j)) { + return false; + } + } + return true; +} + +bool operator==(const vector& p1, const vector& p2) +{ + if (p1.size() != p2.size()) { + return false; + } + for (int i = 0; i tensors; + Tensor tmpTensor(tensorflow::DT_FLOAT, { 32 }); + auto tmpData = tmpTensor.flat(); + for (int j = 0; j < 32; j++) { + tmpData(j) = 0.1f*j; + } + tensors.emplace_back(tmpTensor); + + EMOCK(&HDTransfer::Recv).expects(exactly(1)).will(returnValue(tensors)); + HostEmb h; + EmbInfo embInfo; + embInfo.name = "TestEmb"; + embInfo.devVocabSize = 100; + embInfo.hostVocabSize = 200; + embInfo.extEmbeddingSize = 32; + std::string name = "random_normal_initializer"; + InitializeInfo info(name, 0, embInfo.extEmbeddingSize, NormalInitializerInfo(0, 1, 7, 1.0)); + embInfo.initializeInfos.emplace_back(info); + vector embInfos {embInfo}; + h.Initialize(embInfos, 7); + vector missingKeysHostPos{199}; + h.UpdateEmb(missingKeysHostPos, TRAIN_CHANNEL_ID, embInfo.name); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][0], 0); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][31], 0.1f*31); +} + TEST(HostEmb, Tensor2Float) { shared_ptr>> lookups; @@ -61,4 +122,56 @@ TEST(HostEmb, DefaultConstructor) h.procThreadsForEval.emplace_back(make_unique([] {})); h.Join(EVAL_CHANNEL_ID); 
ASSERT_EQ(h.procThreadsForEval.size(), 0); +} + +TEST(HostEmb, InitializerAndEvict) +{ + HostEmb h; + EmbInfo embInfo; + embInfo.name = "TestEmb"; + embInfo.devVocabSize = 100; + embInfo.hostVocabSize = 200; + embInfo.extEmbeddingSize = 32; + std::string name = "constant_initializer"; + float initVal = 0.05f; + InitializeInfo info(name, 0, embInfo.extEmbeddingSize, ConstantInitializerInfo(initVal, 1.0)); + embInfo.initializeInfos.emplace_back(info); + vector embInfos {embInfo}; + h.Initialize(embInfos, 7); + + ASSERT_EQ(h.hostEmbs[embInfo.name].embData.size(), embInfo.hostVocabSize); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0].size(), embInfo.extEmbeddingSize); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0][0], initVal); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0][embInfo.extEmbeddingSize-1], initVal); + + float initVal1 = 100.89f; + InitializeInfo info1(name, 0, embInfo.extEmbeddingSize, ConstantInitializerInfo(initVal1, 1.0)); + embInfo.initializeInfos.clear(); + embInfo.initializeInfos.emplace_back(info1); + vector offset{1, 199}; + h.hostEmbs[embInfo.name].hostEmbInfo = embInfo; + h.EvictInitEmb(embInfo.name, offset); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[1][0], initVal1); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][embInfo.extEmbeddingSize-1], initVal1); +} + +TEST(HostEmb, GetH2DEmb) +{ + HostEmb h; + EmbInfo embInfo; + embInfo.name = "TestEmb"; + embInfo.devVocabSize = 100; + embInfo.hostVocabSize = 200; + embInfo.extEmbeddingSize = 32; + std::string name = "random_normal_initializer"; + InitializeInfo info(name, 0, embInfo.extEmbeddingSize, NormalInitializerInfo(0, 1, 7, 1.0)); + embInfo.initializeInfos.emplace_back(info); + vector embInfos {embInfo}; + h.Initialize(embInfos, 7); + vector missingKeysHostPos{1, 199}; + vector h2dEmbOut; + h.GetH2DEmb(missingKeysHostPos, embInfo.name, h2dEmbOut); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[1][0], h2dEmbOut[0].flat()(0)); + ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][0], h2dEmbOut[0].flat()(32)); +} } \ No newline at end of file -- Gitee From 0499ee4975f401d24c741e604f43a67b2ee62725 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 20 Dec 2023 15:42:56 +0800 Subject: [PATCH 527/551] Match-id-962c05d68cf02df7d5df2e8ed42000e3ed54a8fb --- .../aclnn_lookup_test/inc/common.h | 44 +++++ .../aclnn_lookup_test/inc/op_runner.h | 181 ++++++++++++++++++ .../aclnn_lookup_test/inc/operator_desc.h | 58 ++++++ .../aclnn_lookup_test/input/.keep | 0 .../aclnn_lookup_test/output/.keep | 0 .../cust_op_by_addr/aclnn_lookup_test/run.sh | 106 ++++++++++ .../aclnn_lookup_test/scripts/gen_data.py | 0 .../scripts/verify_result.py | 0 .../aclnn_lookup_test/src/common.cpp | 78 ++++++++ .../aclnn_lookup_test/src/main.cpp | 0 .../aclnn_lookup_test/src/op_runner.cpp | 0 .../aclnn_lookup_test/src/operator_desc.cpp | 0 .../aclnn_update_test/inc/common.h | 44 +++++ .../aclnn_update_test/inc/op_runner.h | 181 ++++++++++++++++++ .../aclnn_update_test/inc/operator_desc.h | 57 ++++++ .../aclnn_update_test/input/.keep | 0 .../aclnn_update_test/output/.keep | 0 .../cust_op_by_addr/aclnn_update_test/run.sh | 106 ++++++++++ .../aclnn_update_test/scripts/gen_data.py | 0 .../scripts/verify_result.py | 0 .../aclnn_update_test/src/common.cpp | 78 ++++++++ .../aclnn_update_test/src/main.cpp | 0 .../aclnn_update_test/src/op_runner.cpp | 0 .../aclnn_update_test/src/operator_desc.cpp | 0 .../op_host/embedding_lookup_by_address.cpp | 2 +- .../op_kernel/embedding_lookup_by_address.cpp | 4 +- 26 files changed, 936 insertions(+), 3 deletions(-) create mode 
100644 cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/input/.keep create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/output/.keep create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/gen_data.py create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/verify_result.py create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/src/main.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/src/op_runner.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_lookup_test/src/operator_desc.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/input/.keep create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/output/.keep create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/run.sh create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/scripts/gen_data.py create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/scripts/verify_result.py create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/src/main.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/src/op_runner.cpp create mode 100644 cust_op/cust_op_by_addr/aclnn_update_test/src/operator_desc.cpp diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h new file mode 100644 index 00000000..5b22736d --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * Author: MindX SDK + * Date: 2023/12/20 + */ +#ifndef COMMON_H +#define COMMON_H + +#include +#include +#include +#include +#include + +#include "acl/acl.h" + +#define SUCCESS 0 +#define FAILED 1 + +#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args) +#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args) +#define ERROR_LOG(fmt, args...) 
fprintf(stderr, "[ERROR] " fmt "\n", ##args) + +/** + * @brief Read data from file + * @param [in] filePath: file path + * @param [out] fileSize: file size + * @return read result + */ +bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize); + +/** + * @brief Write data to file + * @param [in] filePath: file path + * @param [in] buffer: data to write to file + * @param [in] size: size to write + * @return write result + */ +bool WriteFile(const std::string &filePath, const void *buffer, size_t size); + +#endif // COMMON_H diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h new file mode 100644 index 00000000..d6f415a0 --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * Author: MindX SDK + * Date: 2023/12/20 + */ +#ifndef OP_RUNNER_H +#define OP_RUNNER_H + +#include "aclnn/acl_meta.h" +#include "acl/acl.h" +#include "common.h" +#include "operator_desc.h" + +/** + * Op Runner + */ +class OpRunner { +public: + /** + * @brief Constructor + * @param [in] opDesc: op description + */ + explicit OpRunner(OperatorDesc *opDesc); + + /** + * @brief Destructor + */ + virtual ~OpRunner(); + + /** + * @brief Init op runner + */ + bool Init(); + + /** + * @brief Get number of inputs + * @return number of inputs + */ + const size_t NumInputs(); + + /** + * @brief Get number of outputs + * @return number of outputs + */ + const size_t NumOutputs(); + + /** + * @brief Get input size by index + * @param [in] index: input index + * @return size of the input + */ + const size_t GetInputSize(size_t index) const; + const size_t GetInputNumDims(size_t index) const; + aclDataType GetInputDataType(size_t index) const; + aclFormat GetInputFormat(size_t index) const; + + /** + * @brief Get output size by index + * @param [in] index: output index + * @return size of the output + */ + size_t GetOutputSize(size_t index) const; + const size_t GetOutputNumDims(size_t index) const; + aclDataType GetOutputDataType(size_t index) const; + aclFormat GetOutputFormat(size_t index) const; + + /** + * @brief Get input element count by index + * @param i[in] ndex: input index + * @return element count of the input + */ + size_t GetInputElementCount(size_t index) const; + + /** + * @brief Get output element count by index + * @param [in] index: output index + * @return element count of the output + */ + size_t GetOutputElementCount(size_t index) const; + + /** + * @brief Get input shape by index + * @param [in] index: input index + * @return shape of the output + */ + std::vector GetInputShape(size_t index) const; + + /** + * @brief Get output shape by index + * @param [in] index: output index + * @return shape of the output + */ + std::vector GetOutputShape(size_t index) const; + + /** + * @brief Get input buffer(host memory) by index + * @tparam T: data type + * @param [in] index: input index + * @return host address of the input + */ + template + T *GetInputBuffer(size_t index) + { + if (index >= numInputs_) { + ERROR_LOG("index out of range. 
index = %zu, numInputs = %zu", index, numInputs_); + return nullptr; + } + return reinterpret_cast(hostInputs_[index]); + } + + /** + * @brief Get output buffer(host memory) by index + * @tparam T: data type + * @param [in] index: output index + * @return host address of the output + */ + template + const T *GetOutputBuffer(size_t index) + { + if (index >= numOutputs_) { + ERROR_LOG("index out of range. index = %zu, numOutputs = %zu", index, numOutputs_); + return nullptr; + } + + return reinterpret_cast(hostOutputs_[index]); + } + + /** + * @brief Print readable input by index + * @param [in] index: input index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintInput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Print readable output by index + * @param [in] index: output index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintOutput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Compile static op + * @return compile result + */ + bool CompileStaticOp(); + + /** + * @brief Compile dynamic op + * @return compile result + */ + bool CompileDynamicOp(); + + /** + * @brief Run op + * @return run result + */ + bool RunOp(); + +private: + size_t numInputs_; + size_t numOutputs_; + + std::vector inputBuffers_; + std::vector outputBuffers_; + + std::vector devInputs_; + std::vector devOutputs_; + + std::vector hostInputs_; + std::vector hostOutputs_; + + std::vector inputTensor_; + std::vector outputTensor_; + OperatorDesc *opDesc_; +}; + +#endif // OP_RUNNER_H diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h new file mode 100644 index 00000000..434353d4 --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ * Author: MindX SDK
+ * Date: 2023/12/20
+ */
+#ifndef OPERATOR_DESC_H
+#define OPERATOR_DESC_H
+
+#include
+#include
+
+#include "acl/acl.h"
+
+/**
+ * Op description
+ */
+struct OperatorDesc {
+    /**
+     * Constructor
+     */
+    explicit OperatorDesc(int64_t embDim, int64_t embType);
+
+    /**
+     * Destructor
+     */
+    virtual ~OperatorDesc();
+
+    /**
+     * Add an input tensor description
+     * @param [in] dataType: data type
+     * @param [in] numDims: number of dims
+     * @param [in] dims: dims
+     * @param [in] format: format
+     * @return OperatorDesc
+     */
+    OperatorDesc &AddInputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format);
+
+    /**
+     * Add an output tensor description
+     * @param [in] dataType: data type
+     * @param [in] numDims: number of dims
+     * @param [in] dims: dims
+     * @param [in] format: format
+     * @return OperatorDesc
+     */
+    OperatorDesc &AddOutputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format);
+
+    std::string opType;
+    std::vector inputDesc;
+    std::vector outputDesc;
+    int64_t embeddingDim;
+    int64_t embeddingType;
+};
+
+#endif // OPERATOR_DESC_H
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/input/.keep b/cust_op/cust_op_by_addr/aclnn_lookup_test/input/.keep
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/output/.keep b/cust_op/cust_op_by_addr/aclnn_lookup_test/output/.keep
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh b/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh
new file mode 100644
index 00000000..4f1f92c5
--- /dev/null
+++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh
@@ -0,0 +1,106 @@
+#!/bin/bash
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+# Description: build test script.
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+export ASCEND_GLOBAL_LOG_LEVEL=0
+
+CURRENT_DIR=$(
+    cd $(dirname ${BASH_SOURCE:-$0})
+    pwd
+)
+cd $CURRENT_DIR
+
+# Export environment variables
+SHORT=v:,
+LONG=dtype:,
+OPTS=$(getopt -a --options $SHORT --longoptions $LONG -- "$@")
+eval set -- "$OPTS"
+while :
+do
+    case "$1" in
+        # float16, float, int32
+        (-v | --dtype)
+            DTYPE="$2"
+            shift 2;;
+        (--)
+            shift;
+            break;;
+        (*)
+            echo "[ERROR] Unexpected option: $1";
+            break;;
+    esac
+done
+
+if [ ! "$ASCEND_HOME_DIR" ]; then
+    if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then
+        export ASCEND_HOME_DIR=$HOME/Ascend/ascend-toolkit/latest
+    else
+        export ASCEND_HOME_DIR=/usr/local/Ascend/ascend-toolkit/latest
+    fi
+fi
+source $ASCEND_HOME_DIR/bin/setenv.bash
+
+export DDK_PATH=$ASCEND_HOME_DIR
+arch=$(uname -m)
+export NPU_HOST_LIB=$ASCEND_HOME_DIR/${arch}-linux/lib64
+
+main()
+{
+    # 1. Clean up leftover generated files and log files
+    rm -rf $HOME/ascend/log/*
+    rm ./input/*.bin
+    rm ./output/*.bin
+
+    # 2. Generate the input data and the golden data
+    cd $CURRENT_DIR
+    python3 scripts/gen_data.py
+    if [ $? -ne 0 ]; then
+        echo "ERROR: generate input data failed!"
+        return 1
+    fi
+    echo "INFO: generate input data success!"
+
+    # 3. Build the acl executable
+    cd $CURRENT_DIR; rm -rf build; mkdir -p build; cd build
+    cmake ../src
+    if [ $? -ne 0 ]; then
+        echo "ERROR: cmake failed!"
+        return 1
+    fi
+    echo "INFO: cmake success!"
+    make
+    if [ $? -ne 0 ]; then
+        echo "ERROR: make failed!"
+        return 1
+    fi
+    echo "INFO: make success!"
+
+    # 4. Run the executable
+    cd $CURRENT_DIR/output
+    echo "INFO: execute op!"
+    ./execute_lookup_op
+
+    if [ $? -ne 0 ]; then
+        echo "ERROR: acl executable run failed! please check your project!"
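+        # a non-zero exit status from ./execute_lookup_op above aborts the test
+        # before the result verification in step 5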
+        return 1
+    fi
+    echo "INFO: acl executable run success!"
+
+    # 5. Compare with the golden file
+    cd $CURRENT_DIR
+    ret=$(python3 scripts/verify_result.py output/output_z.bin output/golden.bin)
+    echo $ret
+    if [ "x$ret" == "xtest pass" ]; then
+        echo ""
+        echo "#####################################"
+        echo "INFO: you have passed the Precision!"
+        echo "#####################################"
+        echo ""
+    fi
+}
+
+main
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/gen_data.py b/cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/gen_data.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/verify_result.py b/cust_op/cust_op_by_addr/aclnn_lookup_test/scripts/verify_result.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp
new file mode 100644
index 00000000..5388c370
--- /dev/null
+++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * Author: MindX SDK
+ * Date: 2023/12/20
+ */
+#include "common.h"
+
+#include
+#include
+#include
+#include
+
+extern bool g_isDevice;
+
+bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize)
+{
+    struct stat sBuf;
+    int fileStatus = stat(filePath.data(), &sBuf);
+    if (fileStatus == -1) {
+        ERROR_LOG("failed to get file %s", filePath.c_str());
+        return false;
+    }
+    if (S_ISREG(sBuf.st_mode) == 0) {
+        ERROR_LOG("%s is not a file, please enter a file", filePath.c_str());
+        return false;
+    }
+
+    std::ifstream file;
+    file.open(filePath, std::ios::binary);
+    if (!file.is_open()) {
+        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
+        return false;
+    }
+
+    std::filebuf *buf = file.rdbuf();
+    size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in);
+    if (size == 0) {
+        ERROR_LOG("file size is 0");
+        file.close();
+        return false;
+    }
+    if (size > bufferSize) {
+        ERROR_LOG("file size is larger than buffer size");
+        file.close();
+        return false;
+    }
+    buf->pubseekpos(0, std::ios::in);
+    buf->sgetn(static_cast(buffer), size);
+    fileSize = size;
+    file.close();
+    return true;
+}
+
+bool WriteFile(const std::string &filePath, const void *buffer, size_t size)
+{
+    if (buffer == nullptr) {
+        ERROR_LOG("Write file failed. buffer is nullptr");
+        return false;
+    }
+
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE);
+    if (fd < 0) {
+        ERROR_LOG("Open file failed.
path = %s", filePath.c_str()); + return false; + } + + auto writeSize = write(fd, buffer, size); + (void) close(fd); + if (writeSize != size) { + ERROR_LOG("Write file Failed."); + return false; + } + + return true; +} diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/main.cpp b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/main.cpp new file mode 100644 index 00000000..e69de29b diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/op_runner.cpp b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/op_runner.cpp new file mode 100644 index 00000000..e69de29b diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/operator_desc.cpp b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/operator_desc.cpp new file mode 100644 index 00000000..e69de29b diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h new file mode 100644 index 00000000..5b22736d --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * Author: MindX SDK + * Date: 2023/12/20 + */ +#ifndef COMMON_H +#define COMMON_H + +#include +#include +#include +#include +#include + +#include "acl/acl.h" + +#define SUCCESS 0 +#define FAILED 1 + +#define INFO_LOG(fmt, args...) fprintf(stdout, "[INFO] " fmt "\n", ##args) +#define WARN_LOG(fmt, args...) fprintf(stdout, "[WARN] " fmt "\n", ##args) +#define ERROR_LOG(fmt, args...) fprintf(stderr, "[ERROR] " fmt "\n", ##args) + +/** + * @brief Read data from file + * @param [in] filePath: file path + * @param [out] fileSize: file size + * @return read result + */ +bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize); + +/** + * @brief Write data to file + * @param [in] filePath: file path + * @param [in] buffer: data to write to file + * @param [in] size: size to write + * @return write result + */ +bool WriteFile(const std::string &filePath, const void *buffer, size_t size); + +#endif // COMMON_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h new file mode 100644 index 00000000..d6f415a0 --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ * Author: MindX SDK + * Date: 2023/12/20 + */ +#ifndef OP_RUNNER_H +#define OP_RUNNER_H + +#include "aclnn/acl_meta.h" +#include "acl/acl.h" +#include "common.h" +#include "operator_desc.h" + +/** + * Op Runner + */ +class OpRunner { +public: + /** + * @brief Constructor + * @param [in] opDesc: op description + */ + explicit OpRunner(OperatorDesc *opDesc); + + /** + * @brief Destructor + */ + virtual ~OpRunner(); + + /** + * @brief Init op runner + */ + bool Init(); + + /** + * @brief Get number of inputs + * @return number of inputs + */ + const size_t NumInputs(); + + /** + * @brief Get number of outputs + * @return number of outputs + */ + const size_t NumOutputs(); + + /** + * @brief Get input size by index + * @param [in] index: input index + * @return size of the input + */ + const size_t GetInputSize(size_t index) const; + const size_t GetInputNumDims(size_t index) const; + aclDataType GetInputDataType(size_t index) const; + aclFormat GetInputFormat(size_t index) const; + + /** + * @brief Get output size by index + * @param [in] index: output index + * @return size of the output + */ + size_t GetOutputSize(size_t index) const; + const size_t GetOutputNumDims(size_t index) const; + aclDataType GetOutputDataType(size_t index) const; + aclFormat GetOutputFormat(size_t index) const; + + /** + * @brief Get input element count by index + * @param i[in] ndex: input index + * @return element count of the input + */ + size_t GetInputElementCount(size_t index) const; + + /** + * @brief Get output element count by index + * @param [in] index: output index + * @return element count of the output + */ + size_t GetOutputElementCount(size_t index) const; + + /** + * @brief Get input shape by index + * @param [in] index: input index + * @return shape of the output + */ + std::vector GetInputShape(size_t index) const; + + /** + * @brief Get output shape by index + * @param [in] index: output index + * @return shape of the output + */ + std::vector GetOutputShape(size_t index) const; + + /** + * @brief Get input buffer(host memory) by index + * @tparam T: data type + * @param [in] index: input index + * @return host address of the input + */ + template + T *GetInputBuffer(size_t index) + { + if (index >= numInputs_) { + ERROR_LOG("index out of range. index = %zu, numInputs = %zu", index, numInputs_); + return nullptr; + } + return reinterpret_cast(hostInputs_[index]); + } + + /** + * @brief Get output buffer(host memory) by index + * @tparam T: data type + * @param [in] index: output index + * @return host address of the output + */ + template + const T *GetOutputBuffer(size_t index) + { + if (index >= numOutputs_) { + ERROR_LOG("index out of range. 
index = %zu, numOutputs = %zu", index, numOutputs_); + return nullptr; + } + + return reinterpret_cast(hostOutputs_[index]); + } + + /** + * @brief Print readable input by index + * @param [in] index: input index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintInput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Print readable output by index + * @param [in] index: output index + * @param [in] elementsPerRow: number of elements per row + */ + void PrintOutput(size_t index, size_t elementsPerRow = 16); + + /** + * @brief Compile static op + * @return compile result + */ + bool CompileStaticOp(); + + /** + * @brief Compile dynamic op + * @return compile result + */ + bool CompileDynamicOp(); + + /** + * @brief Run op + * @return run result + */ + bool RunOp(); + +private: + size_t numInputs_; + size_t numOutputs_; + + std::vector inputBuffers_; + std::vector outputBuffers_; + + std::vector devInputs_; + std::vector devOutputs_; + + std::vector hostInputs_; + std::vector hostOutputs_; + + std::vector inputTensor_; + std::vector outputTensor_; + OperatorDesc *opDesc_; +}; + +#endif // OP_RUNNER_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h new file mode 100644 index 00000000..097b432d --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h @@ -0,0 +1,57 @@ +/* +* Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +* Description: This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +* Author: MindX SDK +* Date: 2023/12/20 +*/ +#ifndef OPERATOR_DESC_H +#define OPERATOR_DESC_H + +#include +#include + +#include "acl/acl.h" + +/** + * Op description + */ +struct OperatorDesc { + /** + * Constructor + */ + explicit OperatorDesc(int64_t update); + + /** + * Destructor + */ + virtual ~OperatorDesc(); + + /** + * Add an input tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddInputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + /** + * Add an output tensor description + * @param [in] dataType: data type + * @param [in] numDims: number of dims + * @param [in] dims: dims + * @param [in] format: format + * @return OperatorDesc + */ + OperatorDesc &AddOutputTensorDesc(aclDataType dataType, int numDims, const int64_t *dims, aclFormat format); + + std::string opType; + std::vector inputDesc; + std::vector outputDesc; + int64_t updateType; +}; + +#endif // OPERATOR_DESC_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/input/.keep b/cust_op/cust_op_by_addr/aclnn_update_test/input/.keep new file mode 100644 index 00000000..e69de29b diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/output/.keep b/cust_op/cust_op_by_addr/aclnn_update_test/output/.keep new file mode 100644 index 00000000..e69de29b diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/run.sh b/cust_op/cust_op_by_addr/aclnn_update_test/run.sh new file mode 100644 index 00000000..6367be33 --- /dev/null +++ b/cust_op/cust_op_by_addr/aclnn_update_test/run.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. +# Description: build test script. 
+# Author: MindX SDK
+# Create: 2023
+# History: NA
+
+export ASCEND_SLOG_PRINT_TO_STDOUT=0
+export ASCEND_GLOBAL_LOG_LEVEL=0
+
+CURRENT_DIR=$(
+    cd $(dirname ${BASH_SOURCE:-$0})
+    pwd
+)
+cd $CURRENT_DIR
+
+# Export environment variables
+SHORT=v:,
+LONG=dtype:,
+OPTS=$(getopt -a --options $SHORT --longoptions $LONG -- "$@")
+eval set -- "$OPTS"
+while :
+do
+    case "$1" in
+        # float16, float, int32
+        (-v | --dtype)
+            DTYPE="$2"
+            shift 2;;
+        (--)
+            shift;
+            break;;
+        (*)
+            echo "[ERROR] Unexpected option: $1";
+            break;;
+    esac
+done
+
+if [ ! "$ASCEND_HOME_DIR" ]; then
+    if [ -d "$HOME/Ascend/ascend-toolkit/latest" ]; then
+        export ASCEND_HOME_DIR=$HOME/Ascend/ascend-toolkit/latest
+    else
+        export ASCEND_HOME_DIR=/usr/local/Ascend/ascend-toolkit/latest
+    fi
+fi
+source $ASCEND_HOME_DIR/bin/setenv.bash
+
+export DDK_PATH=$ASCEND_HOME_DIR
+arch=$(uname -m)
+export NPU_HOST_LIB=$ASCEND_HOME_DIR/${arch}-linux/lib64
+
+main()
+{
+    # 1. Clean up leftover generated files and log files
+    rm -rf $HOME/ascend/log/*
+    rm ./input/*.bin
+    rm ./output/*.bin
+
+    # 2. Generate the input data and the golden data
+    cd $CURRENT_DIR
+    python3 scripts/gen_data.py
+    if [ $? -ne 0 ]; then
+        echo "ERROR: generate input data failed!"
+        return 1
+    fi
+    echo "INFO: generate input data success!"
+
+    # 3. Build the acl executable
+    cd $CURRENT_DIR; rm -rf build; mkdir -p build; cd build
+    cmake ../src
+    if [ $? -ne 0 ]; then
+        echo "ERROR: cmake failed!"
+        return 1
+    fi
+    echo "INFO: cmake success!"
+    make
+    if [ $? -ne 0 ]; then
+        echo "ERROR: make failed!"
+        return 1
+    fi
+    echo "INFO: make success!"
+
+    # 4. Run the executable
+    cd $CURRENT_DIR/output
+    echo "INFO: execute op!"
+    ./execute_update_op
+
+    if [ $? -ne 0 ]; then
+        echo "ERROR: acl executable run failed! please check your project!"
+        return 1
+    fi
+    echo "INFO: acl executable run success!"
+
+    # 5. Compare with the golden file
+    cd $CURRENT_DIR
+    ret=$(python3 scripts/verify_result.py output/output_z.bin output/golden.bin)
+    echo $ret
+    if [ "x$ret" == "xtest pass" ]; then
+        echo ""
+        echo "#####################################"
+        echo "INFO: you have passed the Precision!"
+        echo "#####################################"
+        echo ""
+    fi
+}
+
+main
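+
+# Note: scripts/gen_data.py and scripts/verify_result.py are committed as empty
+# placeholders in this patch; verify_result.py must print "test pass" for the
+# precision check in step 5 to report success.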
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/scripts/gen_data.py b/cust_op/cust_op_by_addr/aclnn_update_test/scripts/gen_data.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/scripts/verify_result.py b/cust_op/cust_op_by_addr/aclnn_update_test/scripts/verify_result.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp b/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp
new file mode 100644
index 00000000..5388c370
--- /dev/null
+++ b/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * Author: MindX SDK
+ * Date: 2023/12/20
+ */
+#include "common.h"
+
+#include
+#include
+#include
+#include
+
+extern bool g_isDevice;
+
+bool ReadFile(const std::string &filePath, size_t fileSize, void *buffer, size_t bufferSize)
+{
+    struct stat sBuf;
+    int fileStatus = stat(filePath.data(), &sBuf);
+    if (fileStatus == -1) {
+        ERROR_LOG("failed to get file %s", filePath.c_str());
+        return false;
+    }
+    if (S_ISREG(sBuf.st_mode) == 0) {
+        ERROR_LOG("%s is not a file, please enter a file", filePath.c_str());
+        return false;
+    }
+
+    std::ifstream file;
+    file.open(filePath, std::ios::binary);
+    if (!file.is_open()) {
+        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
+        return false;
+    }
+
+    std::filebuf *buf = file.rdbuf();
+    size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in);
+    if (size == 0) {
+        ERROR_LOG("file size is 0");
+        file.close();
+        return false;
+    }
+    if (size > bufferSize) {
+        ERROR_LOG("file size is larger than buffer size");
+        file.close();
+        return false;
+    }
+    buf->pubseekpos(0, std::ios::in);
+    buf->sgetn(static_cast(buffer), size);
+    fileSize = size;
+    file.close();
+    return true;
+}
+
+bool WriteFile(const std::string &filePath, const void *buffer, size_t size)
+{
+    if (buffer == nullptr) {
+        ERROR_LOG("Write file failed. buffer is nullptr");
+        return false;
+    }
+
+    int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE);
+    if (fd < 0) {
+        ERROR_LOG("Open file failed. path = %s", filePath.c_str());
+        return false;
+    }
+
+    auto writeSize = write(fd, buffer, size);
+    (void) close(fd);
+    if (writeSize != size) {
+        ERROR_LOG("Write file Failed.");
+        return false;
+    }
+
+    return true;
+}
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/src/main.cpp b/cust_op/cust_op_by_addr/aclnn_update_test/src/main.cpp
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/src/op_runner.cpp b/cust_op/cust_op_by_addr/aclnn_update_test/src/op_runner.cpp
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/src/operator_desc.cpp b/cust_op/cust_op_by_addr/aclnn_update_test/src/operator_desc.cpp
new file mode 100644
index 00000000..e69de29b
diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
index 8e8379e0..445fd6c1 100644
--- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp
@@ -74,7 +74,7 @@ namespace optiling
     int32_t alignNum = MIN_BLOCK_SIZE / typeSize;
     // embeddingDimAligned: round embeddingDim up to the minimum alignment unit
     int32_t embeddingDimAligned = ((embeddingDim - 1 + alignNum) / alignNum) * alignNum;
-    // each address takes sizeof(int64_t) bytes; typeSize is the byte size of each element; 2x the memory is needed because every move copies the data
+    // LocalTensor space: tbuf holds the int64 addresses, plus inQueue/outQueue holding the emb twice over, because every move copies the data
     int32_t occupyAddressBytesNum = sizeof(int64_t) + typeSize * embeddingDimAligned * PING_PONG_NUM * 2;
     // max number of addrs computed in one round; the addresses are also moved to UB, so align to 32,
diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
index 19a384c0..55273f1e 100644
--- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
+++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp
@@ -25,7 +25,7 @@ public:
         pipe.InitBuffer(tbuf, addrNumPerLoop * sizeof(int64_t));
         pipe.InitBuffer(inQueue, pingpongNum, veclen);
-        pipe.InitBuffer(outQueue, pingpongNum, veclen); //
+        pipe.InitBuffer(outQueue, pingpongNum, veclen); // get start index for current core, core parallel block_indx block_dim; even the last core should initialize a few extra entries, aligned to a multiple of 4
         srcAddrGlobal.SetGlobalBuffer((__gm__ int64_t *)(address + block_idx * singleCoreAddrLen), needComputeAddrLen);
@@ -169,7 +169,7 @@ private:
     TBuf tbuf;
     TQue inQueue;
     TQue outQueue;
-    GlobalTensor srcDataBufferGm, dstDataGm, outDataGm;
+    GlobalTensor srcDataBufferGm, dstDataGm;
     GlobalTensor srcAddrGlobal;
 };
-- Gitee
From 5db876329c3fa7292c36c0a65b6b7bd92652dc01 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 28 Dec 2023 14:09:18 +0800
Subject: [PATCH 528/551] Match-id-88ee8d36f8ee026163c8b8a78de91c469b842514

---
 src/core/hybrid_mgmt/hybrid_mgmt.cpp |  9 +++++++++
 src/core/key_process/key_process.cpp | 22 +++++++++++++++-------
 src/core/utils/common.h              |  2 +-
 3 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 36d54e7b..26ac9c75 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -685,6 +685,10 @@ bool HybridMgmt::ParseKeysHBM(int channelId, int& batchId)
         LOG_INFO(MGMT + "channelId:{} batchId:{}, embName:{}, ParseKeys with HBM mode end.",
             channelId, batchId, embInfo.name);
     }
+    if (KEY_PROCESS_INSTANCE->isNeedExit[channelId]) {
+        LOG_WARN(MGMT + "can not send data after eos, channelId:{} batchId:{}!", channelId, batchId);
+        return false;
+    }
     batchId++;
     return true;
 }
@@ -750,6 +754,11 @@ bool HybridMgmt::ParseKeys(int channelId, int& batchId)
     EmbHDTransWrap(channelId, batchId - 1, start);
     LOG_DEBUG(MGMT + "channelId:{} batchId:{}, ParseKeys end, parseKeyTC(ms):{}",
         channelId, batchId, parseKeyTC.ElapsedMS());
+
+    if (KEY_PROCESS_INSTANCE->isNeedExit[channelId]) {
+        LOG_WARN(MGMT + "can not send data after eos, channelId:{} batchId:{}!", channelId, batchId--);
+        return false;
+    }
 #endif
     return true;
 }
diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp
index f1e8295e..8b1ab8e3 100644
--- a/src/core/key_process/key_process.cpp
+++ b/src/core/key_process/key_process.cpp
@@ -1076,18 +1076,26 @@ void KeyProcess::HandleRankExitScene(int commId, const unique_ptr &ba
     }
     LOG_INFO("channelId:{} batchId:{}, GetScAll HandleRankExitScene eos.", batch->channel, batch->batchId);

-    int timeout = 0;
+    int timeCount = 0;
+    int timeout = 120; // send eos if we are still waiting after 120s
     HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance();
-    bool isExit = hybridMgmtBlock->pythonBatchId[batch->channel] <
-        (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1);
-    while (isExit && timeout < EOS_TIMEOUT) {
+    bool isWait = hybridMgmtBlock->pythonBatchId[batch->channel] <
+        (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1);
+
+    if (!isWait) { // double check
+        this_thread::sleep_for(seconds(EOS_TIMEOUT));
+        isWait = hybridMgmtBlock->pythonBatchId[batch->channel] <
+            (hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1);
+    }
+
+    while (isWait && timeCount * EOS_TIMEOUT < timeout) {
         LOG_DEBUG("wait until hybridBatchId equal pythonBatchId before SendEos, channelId:{}, pyBatchId:{}, "
             "mgmtBatchId:{}", batch->channel, hybridMgmtBlock->pythonBatchId[batch->channel],
             hybridMgmtBlock->hybridBatchId[batch->channel]);
-        this_thread::sleep_for(seconds(1));
-        isExit = hybridMgmtBlock->pythonBatchId[batch->channel] <
+        this_thread::sleep_for(seconds(EOS_TIMEOUT));
+        isWait = hybridMgmtBlock->pythonBatchId[batch->channel] <
(hybridMgmtBlock->hybridBatchId[batch->channel] - hybridMgmtBlock->loop[batch->channel] + 1); - timeout++; + timeCount++; } SendEos(batch->batchId, batch->channel); diff --git a/src/core/utils/common.h b/src/core/utils/common.h index d353ea0d..21132730 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -85,7 +85,7 @@ namespace MxRec { constexpr int KEY_PROCESS_TIMEOUT = 120; #endif constexpr int GET_BATCH_TIMEOUT = 300; - constexpr int EOS_TIMEOUT = 60; + constexpr int EOS_TIMEOUT = 5; constexpr size_t DEFAULT_RANDOM_SEED = 10086; constexpr int INVALID_KEY_VALUE = -1; -- Gitee From 4bba4e20bc1ff7e0699016145dcaa82932f6d6c7 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 16 Jan 2024 17:11:38 +0800 Subject: [PATCH 529/551] Match-id-758bfe576d70ece65fc9fd30ee69be3ef6641540 --- CMakeLists.txt | 15 ++ LICENSE | 217 ++++++++++++++++++ README.md | 85 +++++++ build.sh | 16 +- build/build.sh | 19 +- build/build_all.sh | 20 +- build/build_tf1.sh | 19 +- build/build_tf1_with_opensource.sh | 21 +- build/build_tf2.sh | 19 +- .../aclnn_lookup_test/inc/common.h | 23 +- .../aclnn_lookup_test/inc/op_runner.h | 23 +- .../aclnn_lookup_test/inc/operator_desc.h | 23 +- .../cust_op_by_addr/aclnn_lookup_test/run.sh | 19 +- .../aclnn_lookup_test/src/common.cpp | 22 +- .../aclnn_update_test/inc/common.h | 22 +- .../aclnn_update_test/inc/op_runner.h | 22 +- .../aclnn_update_test/inc/operator_desc.h | 22 +- .../cust_op_by_addr/aclnn_update_test/run.sh | 19 +- .../aclnn_update_test/src/common.cpp | 22 +- .../op_host/embedding_lookup_by_address.cpp | 14 ++ .../embedding_lookup_by_address_tiling.h | 15 ++ .../op_host/embedding_update_by_address.cpp | 14 ++ .../embedding_update_by_address_tiling.h | 15 ++ .../op_kernel/embedding_lookup_by_address.cpp | 15 ++ .../op_kernel/embedding_update_by_address.cpp | 15 ++ cust_op/cust_op_by_addr/run.sh | 20 +- mx_rec/__init__.py | 15 +- mx_rec/constants/__init__.py | 15 +- mx_rec/constants/constants.py | 15 +- mx_rec/core/__init__.py | 15 +- mx_rec/core/asc/__init__.py | 15 +- mx_rec/core/asc/build_graph.py | 15 +- mx_rec/core/asc/feature_spec.py | 15 +- mx_rec/core/asc/helper.py | 15 +- mx_rec/core/asc/manager.py | 15 +- mx_rec/core/asc/merge_table.py | 15 +- mx_rec/core/embedding.py | 15 +- mx_rec/core/feature_process.py | 15 +- mx_rec/data/__init__.py | 15 +- mx_rec/data/dataset.py | 15 +- mx_rec/data/patch.py | 15 +- mx_rec/graph/__init__.py | 15 +- mx_rec/graph/acg_push_ops.py | 15 +- mx_rec/graph/merge_lookup.py | 15 +- mx_rec/graph/modifier.py | 15 +- mx_rec/graph/patch.py | 15 +- mx_rec/graph/utils.py | 15 +- mx_rec/optimizers/__init__.py | 15 +- mx_rec/optimizers/adagrad.py | 15 +- mx_rec/optimizers/base.py | 15 +- mx_rec/optimizers/ftrl.py | 15 +- mx_rec/optimizers/gradient_descent.py | 15 +- mx_rec/optimizers/gradient_descent_by_addr.py | 15 +- mx_rec/optimizers/lazy_adam.py | 15 +- mx_rec/optimizers/lazy_adam_by_addr.py | 15 +- mx_rec/saver/__init__.py | 15 +- mx_rec/saver/patch.py | 31 ++- mx_rec/saver/saver.py | 15 +- mx_rec/saver/sparse.py | 15 +- mx_rec/util/__init__.py | 15 +- mx_rec/util/atomic.py | 15 +- mx_rec/util/communication/__init__.py | 15 +- mx_rec/util/communication/hccl_mgmt.py | 15 +- mx_rec/util/global_env_conf.py | 15 +- mx_rec/util/initialize.py | 15 +- mx_rec/util/log.py | 15 +- mx_rec/util/normalization.py | 16 +- mx_rec/util/ops.py | 15 +- mx_rec/util/perf.py | 15 +- mx_rec/util/tf_version_adapter.py | 15 +- mx_rec/util/variable.py | 15 +- mx_rec/validator/__init__.py | 15 +- mx_rec/validator/validator.py | 15 +- setup.py | 
19 +- src/CMakeLists.txt | 15 ++ src/build.sh | 20 +- src/core/CMakeLists.txt | 15 ++ src/core/checkpoint/checkpoint.cpp | 20 +- src/core/checkpoint/checkpoint.h | 20 +- .../ckpt_data_handler/ckpt_data_handler.cpp | 21 +- .../ckpt_data_handler/ckpt_data_handler.h | 20 +- .../emb_hash_ckpt/emb_hash_ckpt.cpp | 20 +- .../emb_hash_ckpt/emb_hash_ckpt.h | 20 +- .../feat_admit_n_evict_ckpt.cpp | 20 +- .../feat_admit_n_evict_ckpt.h | 20 +- .../host_emb_ckpt/host_emb_ckpt.cpp | 20 +- .../host_emb_ckpt/host_emb_ckpt.h | 20 +- .../key_count_map_ckpt/key_count_map_ckpt.cpp | 19 +- .../key_count_map_ckpt/key_count_map_ckpt.h | 20 +- .../key_freq_map_ckpt/key_freq_map_ckpt.cpp | 19 +- .../key_freq_map_ckpt/key_freq_map_ckpt.h | 20 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp | 21 +- .../nddr_feat_map_ckpt/nddr_feat_map_ckpt.h | 20 +- .../nddr_offset_ckpt/nddr_offset_ckpt.cpp | 20 +- .../nddr_offset_ckpt/nddr_offset_ckpt.h | 20 +- src/core/emb_hashmap/emb_hashmap.cpp | 20 +- src/core/emb_hashmap/emb_hashmap.h | 20 +- src/core/emb_table/emb_table.cpp | 20 +- src/core/emb_table/emb_table.h | 20 +- src/core/file_system/buffer_queue.cpp | 21 +- src/core/file_system/buffer_queue.h | 21 +- src/core/file_system/file_system.h | 20 +- src/core/file_system/file_system_handler.cpp | 20 +- src/core/file_system/file_system_handler.h | 20 +- .../hdfs_file_system/hdfs_file_system.cpp | 20 +- .../hdfs_file_system/hdfs_file_system.h | 20 +- .../hdfs_file_system/hdfs_wrapper.h | 20 +- .../local_file_system/local_file_system.cpp | 20 +- .../local_file_system/local_file_system.h | 20 +- src/core/hd_transfer/acl_channel.h | 20 +- src/core/hd_transfer/hd_transfer.cpp | 21 +- src/core/hd_transfer/hd_transfer.h | 20 +- src/core/host_emb/host_emb.cpp | 20 +- src/core/host_emb/host_emb.h | 20 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 21 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 20 +- src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 22 +- src/core/hybrid_mgmt/hybrid_mgmt_block.h | 22 +- .../constant_initializer.cpp | 20 +- .../constant_initializer.h | 20 +- src/core/initializer/initializer.cpp | 21 +- src/core/initializer/initializer.h | 20 +- .../random_normal_initializer.cpp | 20 +- .../random_normal_initializer.h | 20 +- .../truncated_normal_initializer.cpp | 20 +- .../truncated_normal_initializer.h | 20 +- .../key_process/feature_admit_and_evict.cpp | 20 +- .../key_process/feature_admit_and_evict.h | 20 +- src/core/key_process/key_process.cpp | 20 +- src/core/key_process/key_process.h | 20 +- src/core/ock_ctr_common/include/error_code.h | 17 +- src/core/ock_ctr_common/include/factory.h | 23 +- .../include/ock_ctr_common_def.h | 23 +- src/core/ock_ctr_common/include/unique.h | 23 +- src/core/ssd_cache/cache_manager.cpp | 21 +- src/core/ssd_cache/cache_manager.h | 20 +- src/core/ssd_cache/lfu_cache.cpp | 21 +- src/core/ssd_cache/lfu_cache.h | 21 +- src/core/ssd_engine/file.cpp | 18 +- src/core/ssd_engine/file.h | 18 +- src/core/ssd_engine/ssd_engine.cpp | 18 +- src/core/ssd_engine/ssd_engine.h | 18 +- src/core/ssd_engine/table.cpp | 16 +- src/core/ssd_engine/table.h | 18 +- src/core/utils/common.cpp | 21 +- src/core/utils/common.h | 21 +- src/core/utils/config.cpp | 21 +- src/core/utils/config.h | 21 +- src/core/utils/logger.cpp | 21 +- src/core/utils/logger.h | 21 +- src/core/utils/safe_queue.h | 21 +- src/core/utils/singleton.h | 21 +- src/core/utils/time_cost.h | 21 +- src/dataset_tf/CMakeLists.txt | 15 ++ src/dataset_tf/eos_dataset_op.cc | 21 +- src/dataset_tf/eos_dataset_op.h | 21 +- src/ops_tf/CMakeLists.txt | 15 ++ 
src/ops_tf/hybrid_dataset_ops.cpp | 20 +- src/ops_tf/tf_ops.h | 21 +- src/pybind/CMakeLists.txt | 15 ++ src/pybind/module_main.cpp | 21 +- src/test_ut.sh | 20 +- src/tests/CMakeLists.txt | 15 ++ src/tests/checkpoint/checkpoint_test.cpp | 20 +- .../ckpt_data_handler_test.cpp | 20 +- src/tests/emb_hashmap/emb_hashmap_test.cpp | 20 +- src/tests/emb_mgmt/emb_mgmt_test.cpp | 21 +- src/tests/emb_table/emb_table_test.cpp | 20 +- .../file_system/local_file_system_test.cpp | 17 +- src/tests/gtest_main.cpp | 21 +- src/tests/host_emb/host_emb_test.cpp | 21 +- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 21 +- src/tests/initializer/initializer_test.cpp | 21 +- .../feature_admit_and_evict_test.cpp | 20 +- src/tests/key_process/key_process_test.cpp | 21 +- src/tests/ssd_cache/cache_manager_test.cpp | 20 +- src/tests/ssd_cache/lfu_cache_test.cpp | 20 +- src/tests/ssd_engine/engine_test.cpp | 17 +- src/tests/ssd_engine/file_test.cpp | 16 +- src/tests/ssd_engine/table_test.cpp | 17 +- src/tests/utils/common_h_test.cpp | 21 +- src/tests/utils/common_test.cpp | 21 +- src/tests/utils/config_test.cpp | 21 +- src/tests/utils/log_test.cpp | 21 +- src/tests/utils/safe_queue_test.cpp | 21 +- tests/mx_rec/core/generator_dataset.py | 15 +- tests/mx_rec/core/mock_class.py | 15 +- tests/mx_rec/core/test_build_graph.py | 15 +- tests/mx_rec/core/test_embedding.py | 15 +- tests/mx_rec/core/test_feature_process.py | 15 +- tests/mx_rec/core/test_feature_spec.py | 15 +- tests/mx_rec/core/test_helper.py | 15 +- tests/mx_rec/core/test_manager.py | 15 +- tests/mx_rec/core/test_merge_table.py | 15 +- tests/mx_rec/data/__init__.py | 15 +- tests/mx_rec/data/mock_class.py | 16 +- tests/mx_rec/data/test_dataset.py | 15 +- tests/mx_rec/graph/mock_dataset.py | 15 +- tests/mx_rec/graph/test_acg_push_ops.py | 15 +- tests/mx_rec/graph/test_merge_lookup.py | 15 +- tests/mx_rec/graph/test_modifier.py | 15 +- tests/mx_rec/graph/test_utils.py | 15 +- tests/mx_rec/saver/sparse_embedding_mock.py | 16 +- tests/mx_rec/saver/test_saver.py | 15 +- tests/mx_rec/saver/test_sparse.py | 15 +- .../util/communication/test_hccl_mgmt.py | 16 +- tests/mx_rec/util/test_atomic.py | 16 +- tests/mx_rec/util/test_normalization.py | 16 +- tests/mx_rec/util/test_perf.py | 16 +- tests/mx_rec/util/test_variable.py | 16 +- tests/mx_rec/validator/test_validators.py | 16 +- tests/run_python_dt.sh | 19 +- 212 files changed, 3294 insertions(+), 825 deletions(-) create mode 100644 LICENSE diff --git a/CMakeLists.txt b/CMakeLists.txt index d56a5b63..8e031fb6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,2 +1,17 @@ +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + cmake_minimum_required(VERSION 3.20) project(MxRec LANGUAGES CXX) diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..47f7187f --- /dev/null +++ b/LICENSE @@ -0,0 +1,217 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+ ## Some of mxRec's code is derived from TensorFlow, which is subject to the following copyright notice:
+
+ Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index 086c739d..6a787dab 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,87 @@
 # mxRec
+## Background
+
+As artificial intelligence technology evolves, industries such as e-commerce, long- and short-form video, and social networking place ever higher demands on the effectiveness of their search, recommendation, and advertising systems. In today's highly connected internet era, the flood of user data, product data, and video material has caused an explosion of information and made the value of search, recommendation, and advertising systems even more prominent. Growing demand for these systems inevitably drives demand for compute power, and how to deploy more compute and use it to the fullest has become a key concern for system operators.
+
+## Product Definition
+
+mxRec is an application-enablement SDK for search, recommendation, and advertising in the internet market. For model-training scenarios in these domains, it provides a search/recommendation/advertising framework built on the Ascend platform, supporting large-scale workloads and enabling efficient training of search, recommendation, and advertising models. mxRec's features cover:
+
+1. Basic model-training features: single-node single-card training and multi-node multi-card distributed training, with model development based on TensorFlow.
+2. Recommendation-specific features: built on mxRec's sparse-table design, mxRec provides essential capabilities such as feature saving and loading, feature admission, and feature eviction.
+3. Large-scale sparse-table features: multi-level storage across accelerator memory, host memory, and host disk; multi-node storage; and dynamic capacity expansion. Tables can exceed 10 TB.
+
+## Installation
+
+Before installing, follow the *CANN Software Installation Guide* to install the CANN development toolkit package and the TensorFlow adapter plugin for Ascend.
+
+The CANN software provides process-level environment setup scripts that user processes can source to set the required environment variables automatically; the settings expire when the process exits. Run the following commands in the shell script that launches your program, or directly on the command line (the root user's default installation path "/usr/local/Ascend" is used as an example):
+```shell
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+source /usr/local/Ascend/tfplugin/set_env.sh
+```
+
+Install the dependencies. If you develop directly on a physical machine instead of building a container image, install the following Python dependencies:
+```shell
+pip3 install numpy decorator sympy==1.4 cffi==1.12.3 pyyaml pathlib2 grpcio grpcio-tools protobuf==3.20.0 scipy requests mpi4py easydict scikit-learn==0.20.0 attrs
+```
+
+Before installing horovod, set "HOROVOD_WITH_MPI" and "HOROVOD_WITH_TENSORFLOW". A reference installation command:
+```shell
+HOROVOD_WITH_MPI=1 HOROVOD_WITH_TENSORFLOW=1 pip3.7 install horovod --no-cache-dir
+```
+
+### Binary package installation
+
+Obtain the prebuilt, packaged product package directly from the Ascend open-source community. After extraction it contains whl packages for both tf1 and tf2; install the whl with pip (choose the wheel that matches your TensorFlow version):
+```shell
+pip3 install mx_rec-{version}-py3-none-linux_{arch}.whl
+```
+
+The wheel installs into Python's "site-packages" directory by default. If you install into a custom directory via "--target", add the mxRec paths to the "PYTHONPATH" environment variable after installation.
+
+```shell
+export PYTHONPATH={mxrec_install_path}:{mxrec_install_path}/mxRec:$PYTHONPATH
+```
+
+To use dynamic capacity expansion, change into the "mindxsdk-mxrec/cust_op/cust_op_by_addr" directory of the extracted mxRec package, then build and install the dynamic-expansion operator package with the following command:
+```shell
+bash run.sh
+```
+
+### Building from source
+
+Build environment requirements:
+- Python3.7.5
+- GCC 7.3.0
+- CMake 3.20.6
+
+Open-source dependencies:
+- pybind11 v2.10.3
+- securec
+- openmpi 4.1.1: install it in the build environment as described in its documentation
+- tensorflow 1.15/2.6.5: choose the version that matches your needs
+
+Place the pybind11 archive in an opensource/opensource directory at the same level as the MxRec source tree; if that directory does not exist, create opensource/opensource manually next to MxRec first. Then extract the archive there and rename the extracted directory to pybind11.
+
+securec is Huawei's open-source secure C function library. After downloading it, prepare it as follows (a consolidated command sketch follows the list):
+1. Delete the eSDK_LogAPI_V2.1.10 folder under platform
+2. Rename huaweisecurec under platform to securec
+3. Under the securec folder there are three folders, src, lib, and include; delete all files under lib
+4. Move the platform folder into the MxRec source directory
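+
+Under some illustrative assumptions (working from the directory that contains the MxRec checkout, with the pybind11 archive downloaded as pybind11-2.10.3.tar.gz and the securec download extracted to a platform directory in the same place), the preparation described above amounts to roughly the following sketch; adjust the names and paths to what you actually downloaded:
+```shell
+# Assumed layout: run from the directory that contains the MxRec source tree.
+# Create the opensource/opensource directory next to MxRec.
+mkdir -p opensource/opensource
+# Unpack pybind11 there and rename the extracted directory to pybind11.
+tar -xzf pybind11-2.10.3.tar.gz -C opensource/opensource
+mv opensource/opensource/pybind11-2.10.3 opensource/opensource/pybind11
+# securec: steps 1-4 from the list above.
+rm -rf platform/eSDK_LogAPI_V2.1.10
+mv platform/huaweisecurec platform/securec
+rm -rf platform/securec/lib/*
+mv platform MxRec/
+```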
+
+To build whl packages for multiple versions, the build scripts install the matching TensorFlow version in a Python virtual environment. You can adapt the build scripts to your setup and point them at your own TensorFlow installation path. Build options:
+- build/build.sh: builds and packages both the tf1 and tf2 whl packages. Before running it, create the corresponding virtual environments as described in build/build_tf1.sh and build/build_tf2.sh, install the matching TensorFlow version in each, and update the activation commands accordingly.
+- build/build_tf1.sh: builds the tf1 whl package; on success the whl is placed in the tf1_whl subdirectory. Before running it, create a tf1 virtual environment, install tensorflow 1.15.0 in it, and update the activation command accordingly.
+- build/build_tf2.sh: builds the tf2 whl package; on success the whl is placed in the tf2_whl subdirectory. Before running it, create a tf2 virtual environment, install tensorflow 2.6.5 in it, and update the activation command accordingly.
+
+To use dynamic capacity expansion, change into the "./cust_op/cust_op_by_addr" directory, then build and install the dynamic-expansion operator package with the following command:
+```shell
+bash run.sh
+```
+
+## Usage
+
+For mxRec's supported environments, features, API reference, and usage samples, see the MindX SDK product documentation in the Ascend open-source community.
+
diff --git a/build.sh b/build.sh
index 87183edc..a2d7564e 100644
--- a/build.sh
+++ b/build.sh
@@ -1,6 +1,18 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
-# Description: build entrance script.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 set -e
 ROOT_DIR=$(dirname "$(readlink -f "$0")")
diff --git a/build/build.sh b/build/build.sh
index 65657ee4..8bf383de 100644
--- a/build/build.sh
+++ b/build/build.sh
@@ -1,9 +1,18 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2021
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 export GLOG_CUSTOM_PREFIX_SUPPORT=1
diff --git a/build/build_all.sh b/build/build_all.sh
index bd33cc5f..8f8d275c 100644
--- a/build/build_all.sh
+++ b/build/build_all.sh
@@ -1,10 +1,18 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2021
-# History: NA
-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 set -e
 warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; }
 ARCH="$(uname -m)"
diff --git a/build/build_tf1.sh b/build/build_tf1.sh
index 14c07a9c..11b57c89 100644
--- a/build/build_tf1.sh
+++ b/build/build_tf1.sh
@@ -1,9 +1,18 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2021
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 set -e
 warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; }
diff --git a/build/build_tf1_with_opensource.sh b/build/build_tf1_with_opensource.sh
index 78cd24df..487b6d85 100644
--- a/build/build_tf1_with_opensource.sh
+++ b/build/build_tf1_with_opensource.sh
@@ -1,12 +1,21 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2023
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 ##################################################################
-# build_tf1_with_opensource.sh: used by the Meituan customer to build MxRec and the dynamic-expansion operator
+# build_tf1_with_opensource.sh: builds MxRec and the dynamic-expansion operator
 # Build environment: Python3.7.5 GCC 7.3.0 CMake 3.20.6
 # The code consists of four main parts:
 # 1. Prepare the dependencies needed to build MxRec: pybind11 (v2.10.3), securec
diff --git a/build/build_tf2.sh b/build/build_tf2.sh
index dd586a31..42321e89 100644
--- a/build/build_tf2.sh
+++ b/build/build_tf2.sh
@@ -1,9 +1,18 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2023
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h index 5b22736d..ba754761 100644 --- a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/common.h @@ -1,11 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef COMMON_H #define COMMON_H diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h index d6f415a0..67619207 100644 --- a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/op_runner.h @@ -1,11 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + #ifndef OP_RUNNER_H #define OP_RUNNER_H diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h index 434353d4..5fa232ac 100644 --- a/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/inc/operator_desc.h @@ -1,11 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef OPERATOR_DESC_H #define OPERATOR_DESC_H diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh b/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh index 4f1f92c5..2a4a862b 100644 --- a/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/run.sh @@ -1,9 +1,18 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: build test script. -# Author: MindX SDK -# Create: 2023 -# History: NA +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== export ASCEND_SLOG_PRINT_TO_STDOUT=0 export ASCEND_GLOBAL_LOG_LEVEL=0 diff --git a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp index 5388c370..d2746d80 100644 --- a/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp +++ b/cust_op/cust_op_by_addr/aclnn_lookup_test/src/common.cpp @@ -1,11 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "common.h" #include diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h index 5b22736d..dbf3d76a 100644 --- a/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/common.h @@ -1,11 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef COMMON_H #define COMMON_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h index d6f415a0..e7c2faca 100644 --- a/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/op_runner.h @@ -1,11 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ #ifndef OP_RUNNER_H #define OP_RUNNER_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h b/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h index 097b432d..d5e45b10 100644 --- a/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h +++ b/cust_op/cust_op_by_addr/aclnn_update_test/inc/operator_desc.h @@ -1,11 +1,17 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -* Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of -* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. -* Author: MindX SDK -* Date: 2023/12/20 -*/ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef OPERATOR_DESC_H #define OPERATOR_DESC_H diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/run.sh b/cust_op/cust_op_by_addr/aclnn_update_test/run.sh index 6367be33..6c14149c 100644 --- a/cust_op/cust_op_by_addr/aclnn_update_test/run.sh +++ b/cust_op/cust_op_by_addr/aclnn_update_test/run.sh @@ -1,9 +1,18 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: build test script. -# Author: MindX SDK -# Create: 2023 -# History: NA +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== export ASCEND_SLOG_PRINT_TO_STDOUT=0 export ASCEND_GLOBAL_LOG_LEVEL=0 diff --git a/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp b/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp index 5388c370..d2746d80 100644 --- a/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp +++ b/cust_op/cust_op_by_addr/aclnn_update_test/src/common.cpp @@ -1,11 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. - * Author: MindX SDK - * Date: 2023/12/20 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "common.h" #include diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp index 445fd6c1..f7164d44 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address.cpp @@ -1,3 +1,17 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "embedding_lookup_by_address_tiling.h" #include "register/op_def_registry.h" diff --git a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h index 2a9d8951..d1476702 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_lookup_by_address_tiling.h @@ -1,3 +1,18 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef EMBEDDING_LOOKUP_BY_ADDRESS_TILING_H #define EMBEDDING_LOOKUP_BY_ADDRESS_TILING_H #include "register/tilingdata_base.h" diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp index 5a5cc953..80be3a66 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address.cpp @@ -1,3 +1,17 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "embedding_update_by_address_tiling.h" #include "register/op_def_registry.h" diff --git a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h index 7bbeb16c..32ac2843 100644 --- a/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h +++ b/cust_op/cust_op_by_addr/op_host/embedding_update_by_address_tiling.h @@ -1,3 +1,18 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef EMBEDDING_UPDATE_BY_ADDRESS_TILING_H #define EMBEDDING_UPDATE_BY_ADDRESS_TILING_H #include "register/tilingdata_base.h" diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp index 55273f1e..e0bd181a 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_lookup_by_address.cpp @@ -1,3 +1,18 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "kernel_operator.h" using namespace AscendC; diff --git a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp index ca00a5fe..26beac24 100644 --- a/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp +++ b/cust_op/cust_op_by_addr/op_kernel/embedding_update_by_address.cpp @@ -1,3 +1,18 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "kernel_operator.h" using namespace AscendC; diff --git a/cust_op/cust_op_by_addr/run.sh b/cust_op/cust_op_by_addr/run.sh index 2e7c0d67..ddb932ca 100644 --- a/cust_op/cust_op_by_addr/run.sh +++ b/cust_op/cust_op_by_addr/run.sh @@ -1,9 +1,19 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. -# Description: Build for cust_op by address -# Author: MindX SDK -# Create: 2023 -# History: NA +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + set -e source /etc/profile diff --git a/mx_rec/__init__.py b/mx_rec/__init__.py index 53c31414..bdb85131 100644 --- a/mx_rec/__init__.py +++ b/mx_rec/__init__.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== __all__ = ["version", "__version__"] diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py index eab15375..e32874b8 100644 --- a/mx_rec/constants/__init__.py +++ b/mx_rec/constants/__init__.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== __all__ = ["ASCEND_TIMESTAMP", "ApplyGradientsStrategy"] diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py index 92dab5c3..891d65e4 100644 --- a/mx_rec/constants/constants.py +++ b/mx_rec/constants/constants.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from enum import Enum import numpy as np diff --git a/mx_rec/core/__init__.py b/mx_rec/core/__init__.py index f904b242..f45b5e9d 100644 --- a/mx_rec/core/__init__.py +++ b/mx_rec/core/__init__.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== __all__ = ["create_table", "sparse_lookup", "EvictHook"] diff --git a/mx_rec/core/asc/__init__.py b/mx_rec/core/asc/__init__.py index 4c0f202a..71e4b693 100644 --- a/mx_rec/core/asc/__init__.py +++ b/mx_rec/core/asc/__init__.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== __all__ = ["get_asc_insert_func", "start_asc_pipeline", "FeatureSpec"] diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py index db4f037e..a4858d84 100644 --- a/mx_rec/core/asc/build_graph.py +++ b/mx_rec/core/asc/build_graph.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from typing import Optional diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 52f2be23..71454f1b 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from typing import Union, Optional from functools import reduce diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 4b1e1fbf..6cab56c1 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== from functools import reduce diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 7b999ac8..1aeeb573 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import tensorflow as tf diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py index 9cb59bcc..d3b5dec8 100644 --- a/mx_rec/core/asc/merge_table.py +++ b/mx_rec/core/asc/merge_table.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from typing import Dict, List diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 7d96eb6c..6b2adc83 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import math import os diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index 0fb95682..6c19d41d 100644 --- a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. 
All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import time diff --git a/mx_rec/data/__init__.py b/mx_rec/data/__init__.py index 6924f767..a0260bd0 100644 --- a/mx_rec/data/__init__.py +++ b/mx_rec/data/__init__.py @@ -1,3 +1,16 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/mx_rec/data/dataset.py b/mx_rec/data/dataset.py index 844fc967..1b64a5cd 100644 --- a/mx_rec/data/dataset.py +++ b/mx_rec/data/dataset.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from tensorflow.python.data.ops.dataset_ops import get_legacy_output_types, get_legacy_output_classes, \ get_legacy_output_shapes, UnaryDataset diff --git a/mx_rec/data/patch.py b/mx_rec/data/patch.py index 71655512..47c5bea3 100644 --- a/mx_rec/data/patch.py +++ b/mx_rec/data/patch.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== from tensorflow.python.data.ops.dataset_ops import DatasetV2, DatasetV1Adapter diff --git a/mx_rec/graph/__init__.py b/mx_rec/graph/__init__.py index 143dd645..f4d2642c 100644 --- a/mx_rec/graph/__init__.py +++ b/mx_rec/graph/__init__.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== __all__ = ["modify_graph_and_start_emb_cache", "GraphModifierHook", "run", "ACGPushOpsToDatasetHook"] diff --git a/mx_rec/graph/acg_push_ops.py b/mx_rec/graph/acg_push_ops.py index 1622c483..06d89a45 100644 --- a/mx_rec/graph/acg_push_ops.py +++ b/mx_rec/graph/acg_push_ops.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import os import weakref diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 380879a4..4e54b2a2 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -1,6 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import tensorflow as tf
diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py
index c883ce4d..e41723a1 100644
--- a/mx_rec/graph/modifier.py
+++ b/mx_rec/graph/modifier.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from collections import defaultdict
 from typing import Any, List, Dict, DefaultDict, Tuple, Union
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
index ad9f2c42..5c875281 100644
--- a/mx_rec/graph/patch.py
+++ b/mx_rec/graph/patch.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import weakref
 from typing import Any
diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py
index e78803a8..04e32ebe 100644
--- a/mx_rec/graph/utils.py
+++ b/mx_rec/graph/utils.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 from collections import defaultdict
diff --git a/mx_rec/optimizers/__init__.py b/mx_rec/optimizers/__init__.py
index 660bdc14..a467b8fe 100644
--- a/mx_rec/optimizers/__init__.py
+++ b/mx_rec/optimizers/__init__.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 __all__ = [
     "create_hash_optimizer",
     "create_hash_optimizer_by_addr",
diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py
index 596b6ef3..13459b14 100644
--- a/mx_rec/optimizers/adagrad.py
+++ b/mx_rec/optimizers/adagrad.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/base.py b/mx_rec/optimizers/base.py
index 313fded8..a5d68a70 100644
--- a/mx_rec/optimizers/base.py
+++ b/mx_rec/optimizers/base.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index a7424f7d..1ff1c052 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py
index 3487a8d9..b1edb4ed 100644
--- a/mx_rec/optimizers/gradient_descent.py
+++ b/mx_rec/optimizers/gradient_descent.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py
index 33457205..2838fb4a 100644
--- a/mx_rec/optimizers/gradient_descent_by_addr.py
+++ b/mx_rec/optimizers/gradient_descent_by_addr.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index 75ee91b8..cc51ab5c 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py
index b977ac5e..e78e24c0 100644
--- a/mx_rec/optimizers/lazy_adam_by_addr.py
+++ b/mx_rec/optimizers/lazy_adam_by_addr.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from __future__ import absolute_import
 from __future__ import division
diff --git a/mx_rec/saver/__init__.py b/mx_rec/saver/__init__.py
index bfd098f3..3697f282 100644
--- a/mx_rec/saver/__init__.py
+++ b/mx_rec/saver/__init__.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 __all__ = ["export", "save", "restore"]
diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py
index d64d5d9d..d5d30946 100644
--- a/mx_rec/saver/patch.py
+++ b/mx_rec/saver/patch.py
@@ -1,6 +1,35 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Some code is derived from Tensorflow, which is subject to the following copyright notice:
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ==============================================================================
 import os
 import time
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index 9c08d0a0..aebe621f 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import json
 import os
diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py
index a75fec13..9e66769a 100644
--- a/mx_rec/saver/sparse.py
+++ b/mx_rec/saver/sparse.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 import json
diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py
index 2d244a10..6c919515 100644
--- a/mx_rec/util/__init__.py
+++ b/mx_rec/util/__init__.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 __all__ = [
     "init", "get_rank_id", "get_initializer", "terminate_config_initializer", "clear_channel",
diff --git a/mx_rec/util/atomic.py b/mx_rec/util/atomic.py
index 4c7242dc..8c3e1870 100644
--- a/mx_rec/util/atomic.py
+++ b/mx_rec/util/atomic.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import threading
diff --git a/mx_rec/util/communication/__init__.py b/mx_rec/util/communication/__init__.py
index 9da1fcf6..05fd37d4 100644
--- a/mx_rec/util/communication/__init__.py
+++ b/mx_rec/util/communication/__init__.py
@@ -1,5 +1,18 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 __all__ = ["hccl_mgmt"]
\ No newline at end of file
diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py
index 89a50400..aaf15267 100644
--- a/mx_rec/util/communication/hccl_mgmt.py
+++ b/mx_rec/util/communication/hccl_mgmt.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import json
 import os
diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py
index 5f4f2a48..1852d21b 100644
--- a/mx_rec/util/global_env_conf.py
+++ b/mx_rec/util/global_env_conf.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 import dataclasses
 from dataclasses import dataclass
diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py
index e33a12c3..b1021b05 100644
--- a/mx_rec/util/initialize.py
+++ b/mx_rec/util/initialize.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import atexit
 import os
diff --git a/mx_rec/util/log.py b/mx_rec/util/log.py
index 9fb1c678..e15d06f5 100644
--- a/mx_rec/util/log.py
+++ b/mx_rec/util/log.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 import logging
diff --git a/mx_rec/util/normalization.py b/mx_rec/util/normalization.py
index 45f9e2fd..dc9dd2c1 100644
--- a/mx_rec/util/normalization.py
+++ b/mx_rec/util/normalization.py
@@ -1,6 +1,20 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 import re
 from mx_rec.util.log import logger
diff --git a/mx_rec/util/ops.py b/mx_rec/util/ops.py
index 7869ce56..f0bd2c2e 100644
--- a/mx_rec/util/ops.py
+++ b/mx_rec/util/ops.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 from types import ModuleType
diff --git a/mx_rec/util/perf.py b/mx_rec/util/perf.py
index 5070773e..3feb7332 100644
--- a/mx_rec/util/perf.py
+++ b/mx_rec/util/perf.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import time
diff --git a/mx_rec/util/tf_version_adapter.py b/mx_rec/util/tf_version_adapter.py
index 7d13f96e..a0c60b72 100644
--- a/mx_rec/util/tf_version_adapter.py
+++ b/mx_rec/util/tf_version_adapter.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import tensorflow as tf
diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py
index c74b8718..ff8f6989 100644
--- a/mx_rec/util/variable.py
+++ b/mx_rec/util/variable.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import tensorflow as tf
 from tensorflow.python.framework import ops
diff --git a/mx_rec/validator/__init__.py b/mx_rec/validator/__init__.py
index 8f75c6b6..a0260bd0 100644
--- a/mx_rec/validator/__init__.py
+++ b/mx_rec/validator/__init__.py
@@ -1,3 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py
index d0ae61a3..af064a81 100644
--- a/mx_rec/validator/validator.py
+++ b/mx_rec/validator/validator.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 from typing import List, Tuple, Any, Callable, Dict, Optional, Union, Type
diff --git a/setup.py b/setup.py
index 55573ac2..efb4c994 100644
--- a/setup.py
+++ b/setup.py
@@ -1,10 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-# Description: setup script.
-# Author: MindX SDK
-# Create: 2022
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 import os
 import stat
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e5ab5996..9965bfba 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 cmake_minimum_required(VERSION 3.20)
 project(MxRec LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 14)
diff --git a/src/build.sh b/src/build.sh
index 1b824dad..8250f350 100644
--- a/src/build.sh
+++ b/src/build.sh
@@ -1,9 +1,19 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-# Description: build script.
-# Author: MindX SDK
-# Create: 2022
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 set -e
 [ -d build ] && rm -rf build; mkdir build && cd build || exit 1
diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt
index e8eda15c..dd1052f2 100644
--- a/src/core/CMakeLists.txt
+++ b/src/core/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 cmake_minimum_required(VERSION 3.12)
 
 set(CMAKE_CXX_STANDARD 17)
diff --git a/src/core/checkpoint/checkpoint.cpp b/src/core/checkpoint/checkpoint.cpp
index 79c11c97..92745f3d 100644
--- a/src/core/checkpoint/checkpoint.cpp
+++ b/src/core/checkpoint/checkpoint.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include 
 #include 
diff --git a/src/core/checkpoint/checkpoint.h b/src/core/checkpoint/checkpoint.h
index 2872b6eb..c9ff9000 100644
--- a/src/core/checkpoint/checkpoint.h
+++ b/src/core/checkpoint/checkpoint.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: use to manage model saving and loading process
- * Author: MindX SDK
- * Create: 2022-11-15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_CHECKPOINT_H
 #define MX_REC_CHECKPOINT_H
diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.cpp b/src/core/ckpt_data_handler/ckpt_data_handler.cpp
index e9c93476..18f1a090 100644
--- a/src/core/ckpt_data_handler/ckpt_data_handler.cpp
+++ b/src/core/ckpt_data_handler/ckpt_data_handler.cpp
@@ -1,9 +1,18 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-12
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
 #include "ckpt_data_handler.h"
diff --git a/src/core/ckpt_data_handler/ckpt_data_handler.h b/src/core/ckpt_data_handler/ckpt_data_handler.h
index 460438f8..383317d9 100644
--- a/src/core/ckpt_data_handler/ckpt_data_handler.h
+++ b/src/core/ckpt_data_handler/ckpt_data_handler.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-10
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_CKPT_DATA_HANDLER_H
 #define MX_REC_CKPT_DATA_HANDLER_H
diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp
index f334e49c..783f46fb 100644
--- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp
+++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-14
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "emb_hash_ckpt.h"
 #include 
diff --git a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h
index c7cd8aec..06b12c6c 100644
--- a/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h
+++ b/src/core/ckpt_data_handler/emb_hash_ckpt/emb_hash_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-14
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_EMB_HASH_CKPT_H
 #define MX_REC_EMB_HASH_CKPT_H
diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
index 918e026e..be35044b 100644
--- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
+++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-22
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "feat_admit_n_evict_ckpt.h"
diff --git a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h
index 37d8623e..96afbfc0 100644
--- a/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h
+++ b/src/core/ckpt_data_handler/feat_admit_n_evict_ckpt/feat_admit_n_evict_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-22
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MXREC_FEAT_ADMIT_N_EVICT_CKPT_H
 #define MXREC_FEAT_ADMIT_N_EVICT_CKPT_H
diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp
index 9eb8d2b7..ce176bfd 100644
--- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp
+++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-12
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "host_emb_ckpt.h"
diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h
index f51a5120..fea5dd5d 100644
--- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h
+++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-12
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_HOST_EMB_CKPT_H
 #define MX_REC_HOST_EMB_CKPT_H
diff --git a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp
index d3ae9963..7acfcfe4 100644
--- a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp
+++ b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-11-01
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "key_count_map_ckpt.h"
 using namespace std;
diff --git a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h
index 16869a9d..117d957b 100644
--- a/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h
+++ b/src/core/ckpt_data_handler/key_count_map_ckpt/key_count_map_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-11-01
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MXREC_KEY_COUNT_MAP_CKPT_H
 #define MXREC_KEY_COUNT_MAP_CKPT_H
diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp
index 1609faaf..da6ee35f 100644
--- a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp
+++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-08-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "key_freq_map_ckpt.h"
 using namespace std;
diff --git a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h
index cde71e68..1afd2ce9 100644
--- a/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h
+++ b/src/core/ckpt_data_handler/key_freq_map_ckpt/key_freq_map_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-08-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_KEY_FREQ_MAP_CKPT_H
 #define MX_REC_KEY_FREQ_MAP_CKPT_H
diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp
index 66eb4ecc..147b72e8 100644
--- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp
+++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.cpp
@@ -1,9 +1,18 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
+
 #include "nddr_feat_map_ckpt.h"
diff --git a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h
index 57fb7d21..99837c1e 100644
--- a/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h
+++ b/src/core/ckpt_data_handler/nddr_feat_map_ckpt/nddr_feat_map_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_NDDR_FEAT_MAP_CKPT_H
 #define MX_REC_NDDR_FEAT_MAP_CKPT_H
diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp
index 41e81a0c..c5379ef5 100644
--- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp
+++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "nddr_offset_ckpt.h"
diff --git a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h
index 0e414462..25b6347e 100644
--- a/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h
+++ b/src/core/ckpt_data_handler/nddr_offset_ckpt/nddr_offset_ckpt.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-17
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_NDDR_OFFSET_CKPT_H
 #define MX_REC_NDDR_OFFSET_CKPT_H
diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp
index 8905f6db..9bdd5b04 100644
--- a/src/core/emb_hashmap/emb_hashmap.cpp
+++ b/src/core/emb_hashmap/emb_hashmap.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: common module
- * Author: MindX SDK
- * Date: 2022/11/15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "emb_hashmap.h"
 #include 
diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h
index f76b18b7..3ad51442 100644
--- a/src/core/emb_hashmap/emb_hashmap.h
+++ b/src/core/emb_hashmap/emb_hashmap.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: common module
- * Author: MindX SDK
- * Date: 2022/11/15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_EMB_HASHMAP_H
 #define MX_REC_EMB_HASHMAP_H
diff --git a/src/core/emb_table/emb_table.cpp b/src/core/emb_table/emb_table.cpp
index af617e3c..1c24eb2b 100644
--- a/src/core/emb_table/emb_table.cpp
+++ b/src/core/emb_table/emb_table.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: emb table
- * Author: MindX SDK
- * Date: 2023/5/6
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include 
 #include 
diff --git a/src/core/emb_table/emb_table.h b/src/core/emb_table/emb_table.h
index 62200ada..2d30818c 100644
--- a/src/core/emb_table/emb_table.h
+++ b/src/core/emb_table/emb_table.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: emb table
- * Author: MindX SDK
- * Date: 2023/5/6
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_EMB_TABLE_H
 #define MX_REC_EMB_TABLE_H
diff --git a/src/core/file_system/buffer_queue.cpp b/src/core/file_system/buffer_queue.cpp
index 0d289359..87e6a98b 100644
--- a/src/core/file_system/buffer_queue.cpp
+++ b/src/core/file_system/buffer_queue.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: checkpoint module
- * Author: MindX SDK
- * Date: 2023/9/28
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "buffer_queue.h"
diff --git a/src/core/file_system/buffer_queue.h b/src/core/file_system/buffer_queue.h
index cf38dff8..93636a51 100644
--- a/src/core/file_system/buffer_queue.h
+++ b/src/core/file_system/buffer_queue.h
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
- * Description: checkpoint module
- * Author: MindX SDK
- * Date: 2023/9/28
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MXREC_BUFFER_QUEUE_H
 #define MXREC_BUFFER_QUEUE_H
diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h
index 2e8a788f..6b08b6f6 100644
--- a/src/core/file_system/file_system.h
+++ b/src/core/file_system/file_system.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-10-19
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_FILE_SYSTEM_H
 #define MX_REC_FILE_SYSTEM_H
diff --git a/src/core/file_system/file_system_handler.cpp b/src/core/file_system/file_system_handler.cpp
index faa8147f..9c6bd890 100644
--- a/src/core/file_system/file_system_handler.cpp
+++ b/src/core/file_system/file_system_handler.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-11-16
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "file_system_handler.h"
diff --git a/src/core/file_system/file_system_handler.h b/src/core/file_system/file_system_handler.h
index b2a92999..643d6ca9 100644
--- a/src/core/file_system/file_system_handler.h
+++ b/src/core/file_system/file_system_handler.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-10-19
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_FILE_SYSTEM_HANDLER_H
 #define MX_REC_FILE_SYSTEM_HANDLER_H
diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
index 5ac3291d..6dc14cc6 100644
--- a/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
+++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-10-19
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #include "hdfs_file_system.h"
diff --git a/src/core/file_system/hdfs_file_system/hdfs_file_system.h b/src/core/file_system/hdfs_file_system/hdfs_file_system.h
index e6c0f6f1..8eaf8582 100644
--- a/src/core/file_system/hdfs_file_system/hdfs_file_system.h
+++ b/src/core/file_system/hdfs_file_system/hdfs_file_system.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-10-19
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+ limitations under the License.
+==============================================================================*/
 #ifndef MX_REC_HDFS_FILE_SYSTEM_H
 #define MX_REC_HDFS_FILE_SYSTEM_H
diff --git a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h
index 2accf487..806107b4 100644
--- a/src/core/file_system/hdfs_file_system/hdfs_wrapper.h
+++ b/src/core/file_system/hdfs_file_system/hdfs_wrapper.h
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2023-10-27
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_HDFS_LOADER_H #define MX_REC_HDFS_LOADER_H diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 6e55072f..36cd0c57 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: - * Author: MindX SDK - * Create: 2023-10-19 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "local_file_system.h" diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h index 78ea4167..06a8d18a 100644 --- a/src/core/file_system/local_file_system/local_file_system.h +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: - * Author: MindX SDK - * Create: 2023-10-19 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_LOCAL_FILE_SYSTEM_H #define MX_REC_LOCAL_FILE_SYSTEM_H diff --git a/src/core/hd_transfer/acl_channel.h b/src/core/hd_transfer/acl_channel.h index efbadd77..dc55a4a2 100644 --- a/src/core/hd_transfer/acl_channel.h +++ b/src/core/hd_transfer/acl_channel.h @@ -1,9 +1,17 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. -* Description: acl channel api -* Author: MindX SDK -* Date: 2022/11/15 -*/ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef ACL_CHANNEL_H #define ACL_CHANNEL_H diff --git a/src/core/hd_transfer/hd_transfer.cpp b/src/core/hd_transfer/hd_transfer.cpp index fd0ce522..328819a4 100644 --- a/src/core/hd_transfer/hd_transfer.cpp +++ b/src/core/hd_transfer/hd_transfer.cpp @@ -1,9 +1,18 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. -* Description: common module -* Author: MindX SDK -* Date: 2022/11/15 -*/ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "hd_transfer.h" #include #include "utils/common.h" diff --git a/src/core/hd_transfer/hd_transfer.h b/src/core/hd_transfer/hd_transfer.h index eceaa617..9ad822b9 100644 --- a/src/core/hd_transfer/hd_transfer.h +++ b/src/core/hd_transfer/hd_transfer.h @@ -1,9 +1,17 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. -* Description: common module -* Author: MindX SDK -* Date: 2022/11/15 -*/ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_HD_TRANSFER_H #define MX_REC_HD_TRANSFER_H diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index e75ba892..7f885fb0 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "host_emb.h" #include diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index d61b7cad..3b7e6942 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_HOSTEMB_H #define MX_REC_HOSTEMB_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index 26ac9c75..c35a7443 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -1,9 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "hybrid_mgmt.h" #include "utils/time_cost.h" diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index 7a9fa84d..ce55172e 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_EMB_MGMT_H #define MX_REC_EMB_MGMT_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index daeb07c8..4ce54bb4 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -1,10 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: hybrid mgmt module,Record the number of program running steps, - * manage blocking and wakeup - * Author: MindX SDK - * Date: 2023/08/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include #include "utils/common.h" diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.h b/src/core/hybrid_mgmt/hybrid_mgmt_block.h index 930df3ea..2b9322af 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.h @@ -1,10 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: hybrid mgmt module,Record the number of program running steps, - * manage blocking and wakeup - * Author: MindX SDK - * Date: 2023/08/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef MX_REC_HYBRID_BLOCKING_H #define MX_REC_HYBRID_BLOCKING_H diff --git a/src/core/initializer/constant_initializer/constant_initializer.cpp b/src/core/initializer/constant_initializer/constant_initializer.cpp index 2c2e1489..fac007a0 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.cpp +++ b/src/core/initializer/constant_initializer/constant_initializer.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: constant initializer module - * Author: MindX SDK - * Date: 2022/12/22 - */ +/* Copyright 2024. 
Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "constant_initializer.h" #include "utils/common.h" diff --git a/src/core/initializer/constant_initializer/constant_initializer.h b/src/core/initializer/constant_initializer/constant_initializer.h index 6ca1c3e5..a9ffd970 100644 --- a/src/core/initializer/constant_initializer/constant_initializer.h +++ b/src/core/initializer/constant_initializer/constant_initializer.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: constant initializer module - * Author: MindX SDK - * Date: 2022/12/22 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_CONSTANT_INITIALIZER_H #define MX_REC_CONSTANT_INITIALIZER_H diff --git a/src/core/initializer/initializer.cpp b/src/core/initializer/initializer.cpp index 51586eac..ce7a4291 100644 --- a/src/core/initializer/initializer.cpp +++ b/src/core/initializer/initializer.cpp @@ -1,9 +1,18 @@ -/* -* Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. -* Description: initializer module -* Author: MindX SDK -* Date: 2022/12/22 -*/ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "initializer.h" #include diff --git a/src/core/initializer/initializer.h b/src/core/initializer/initializer.h index dbe59ac9..647672fd 100644 --- a/src/core/initializer/initializer.h +++ b/src/core/initializer/initializer.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: initializer module - * Author: MindX SDK - * Date: 2022/12/22 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. 
All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_INITIALIZER_H #define MX_REC_INITIALIZER_H diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp index c10b7e46..1ea0084f 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: random normal initializer module - * Author: MindX SDK - * Date: 2022/12/23 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include #include "utils/common.h" diff --git a/src/core/initializer/random_normal_initializer/random_normal_initializer.h b/src/core/initializer/random_normal_initializer/random_normal_initializer.h index fedb0b5f..9d5f9942 100644 --- a/src/core/initializer/random_normal_initializer/random_normal_initializer.h +++ b/src/core/initializer/random_normal_initializer/random_normal_initializer.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: random normal initializer module - * Author: MindX SDK - * Date: 2022/12/23 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ #ifndef MX_REC_RANDOM_NORMAL_INITIALIZER_H #define MX_REC_RANDOM_NORMAL_INITIALIZER_H diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp index 0ee9c336..d50a7a97 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: truncated normal initializer module - * Author: MindX SDK - * Date: 2022/12/22 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include #include "utils/common.h" diff --git a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h index e7d9ea5f..923eca18 100644 --- a/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h +++ b/src/core/initializer/truncated_normal_initializer/truncated_normal_initializer.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: truncated normal initializer module - * Author: MindX SDK - * Date: 2022/12/22 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_TRUNCATED_NORMAL_INITIALIZER_H #define MX_REC_TRUNCATED_NORMAL_INITIALIZER_H diff --git a/src/core/key_process/feature_admit_and_evict.cpp b/src/core/key_process/feature_admit_and_evict.cpp index ec8fc476..fe7295b2 100644 --- a/src/core/key_process/feature_admit_and_evict.cpp +++ b/src/core/key_process/feature_admit_and_evict.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - * Description: operator module - * Author: MindX SDK - * Date: 2022/11/23 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "feature_admit_and_evict.h" #include diff --git a/src/core/key_process/feature_admit_and_evict.h b/src/core/key_process/feature_admit_and_evict.h index 4103c82a..cb633e1e 100644 --- a/src/core/key_process/feature_admit_and_evict.h +++ b/src/core/key_process/feature_admit_and_evict.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - * Description: operator module - * Author: MindX SDK - * Date: 2022/11/23 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef FEATURE_ADMIT_AND_EVICT_H #define FEATURE_ADMIT_AND_EVICT_H diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 8b1ab8e3..03a0dbb7 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "key_process.h" diff --git a/src/core/key_process/key_process.h b/src/core/key_process/key_process.h index b28038ca..0bf704d3 100644 --- a/src/core/key_process/key_process.h +++ b/src/core/key_process/key_process.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: - * Author: MindX SDK - * Date: 2022/11/15 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MX_REC_KEY_PROCESS_H #define MX_REC_KEY_PROCESS_H diff --git a/src/core/ock_ctr_common/include/error_code.h b/src/core/ock_ctr_common/include/error_code.h index b8616b46..a9b98f23 100644 --- a/src/core/ock_ctr_common/include/error_code.h +++ b/src/core/ock_ctr_common/include/error_code.h @@ -1,6 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef OCK_CTR_ERROR_CODE_H #define OCK_CTR_ERROR_CODE_H diff --git a/src/core/ock_ctr_common/include/factory.h b/src/core/ock_ctr_common/include/factory.h index 6753462c..44a2fce0 100644 --- a/src/core/ock_ctr_common/include/factory.h +++ b/src/core/ock_ctr_common/include/factory.h @@ -1,12 +1,17 @@ -/* - * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - * @Description: - * @Version: 1.0 - * @Author: dev - * @Date: 2023-05-5 09:50:00 - * @LastEditors: dev - * @LastEditTime: 2023-05-5 09:50:00 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef UNIQUE_OCK_CTR_COMMON_H #define UNIQUE_OCK_CTR_COMMON_H diff --git a/src/core/ock_ctr_common/include/ock_ctr_common_def.h b/src/core/ock_ctr_common/include/ock_ctr_common_def.h index 66a50c8b..e8b3f0b5 100644 --- a/src/core/ock_ctr_common/include/ock_ctr_common_def.h +++ b/src/core/ock_ctr_common/include/ock_ctr_common_def.h @@ -1,12 +1,17 @@ -/* - * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - * @Description: - * @Version: 1.0 - * @Author: dev - * @Date: 2023-05-5 09:50:00 - * @LastEditors: dev - * @LastEditTime: 2023-05-5 09:50:00 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef OCK_OCK_CTR_COMMON_DEF_H #define OCK_OCK_CTR_COMMON_DEF_H diff --git a/src/core/ock_ctr_common/include/unique.h b/src/core/ock_ctr_common/include/unique.h index 59ed98b5..cb8960e7 100644 --- a/src/core/ock_ctr_common/include/unique.h +++ b/src/core/ock_ctr_common/include/unique.h @@ -1,12 +1,17 @@ -/* - * @Copyright: Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - * @Description: - * @Version: 1.0 - * @Author: dev - * @Date: 2023-05-5 09:50:00 - * @LastEditors: dev - * @LastEditTime: 2023-05-5 09:50:00 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef OCK_UNIQUE_H #define OCK_UNIQUE_H diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index fdf3eac3..a75834df 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -1,9 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: ssd cache module - * Author: MindX SDK - * Date: 2023/8/10 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "cache_manager.h" #include diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h index 1995556a..26ca3682 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/ssd_cache/cache_manager.h @@ -1,9 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: ssd cache module head - * Author: MindX SDK - * Date: 2023/8/10 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MXREC_CACHE_MANAGER_H #define MXREC_CACHE_MANAGER_H diff --git a/src/core/ssd_cache/lfu_cache.cpp b/src/core/ssd_cache/lfu_cache.cpp index 2ceb5607..c204e336 100644 --- a/src/core/ssd_cache/lfu_cache.cpp +++ b/src/core/ssd_cache/lfu_cache.cpp @@ -1,9 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: lfu cache module - * Author: MindX SDK - * Date: 2023/8/10 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "lfu_cache.h" #include diff --git a/src/core/ssd_cache/lfu_cache.h b/src/core/ssd_cache/lfu_cache.h index 46584474..247e490e 100644 --- a/src/core/ssd_cache/lfu_cache.h +++ b/src/core/ssd_cache/lfu_cache.h @@ -1,9 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: lfu cache module - * Author: MindX SDK - * Date: 2023/8/10 - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef MXREC_LFU_CACHE_H #define MXREC_LFU_CACHE_H diff --git a/src/core/ssd_engine/file.cpp b/src/core/ssd_engine/file.cpp index 6b8805ba..83395f36 100644 --- a/src/core/ssd_engine/file.cpp +++ b/src/core/ssd_engine/file.cpp @@ -1,6 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "file.h" #include diff --git a/src/core/ssd_engine/file.h b/src/core/ssd_engine/file.h index 1c234d11..3f0f7d1a 100644 --- a/src/core/ssd_engine/file.h +++ b/src/core/ssd_engine/file.h @@ -1,6 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef MXREC_FILE_H #define MXREC_FILE_H diff --git a/src/core/ssd_engine/ssd_engine.cpp b/src/core/ssd_engine/ssd_engine.cpp index 25d2da38..65708792 100644 --- a/src/core/ssd_engine/ssd_engine.cpp +++ b/src/core/ssd_engine/ssd_engine.cpp @@ -1,6 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #include "ssd_engine.h" using namespace MxRec; diff --git a/src/core/ssd_engine/ssd_engine.h b/src/core/ssd_engine/ssd_engine.h index b6ad644d..10f89d57 100644 --- a/src/core/ssd_engine/ssd_engine.h +++ b/src/core/ssd_engine/ssd_engine.h @@ -1,6 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/ + #ifndef MXREC_ENGINE_H #define MXREC_ENGINE_H diff --git a/src/core/ssd_engine/table.cpp b/src/core/ssd_engine/table.cpp index d294bb36..c7ed5363 100644 --- a/src/core/ssd_engine/table.cpp +++ b/src/core/ssd_engine/table.cpp @@ -1,7 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "table.h" diff --git a/src/core/ssd_engine/table.h b/src/core/ssd_engine/table.h index 4c7e1ad9..87fa6f35 100644 --- a/src/core/ssd_engine/table.h +++ b/src/core/ssd_engine/table.h @@ -1,6 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + #ifndef MXREC_TABLE_H #define MXREC_TABLE_H diff --git a/src/core/utils/common.cpp b/src/core/utils/common.cpp index 48c3a8f4..7ba37e54 100644 --- a/src/core/utils/common.cpp +++ b/src/core/utils/common.cpp @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Create: 2021 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "common.h" diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 21132730..e857eb5c 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. 
- * Description: common module - * Author: MindX SDK - * Create: 2021 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef COMMON_H #define COMMON_H diff --git a/src/core/utils/config.cpp b/src/core/utils/config.cpp index cabbfee4..03951ce1 100644 --- a/src/core/utils/config.cpp +++ b/src/core/utils/config.cpp @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: config module - * Author: MindX SDK - * Create: 2023 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "config.h" diff --git a/src/core/utils/config.h b/src/core/utils/config.h index 49b4b501..4c56c0d4 100644 --- a/src/core/utils/config.h +++ b/src/core/utils/config.h @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. - * Description: config module - * Author: MindX SDK - * Create: 2023 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MXREC_CONFIG_H #define MXREC_CONFIG_H diff --git a/src/core/utils/logger.cpp b/src/core/utils/logger.cpp index 59134dda..23e84bd0 100644 --- a/src/core/utils/logger.cpp +++ b/src/core/utils/logger.cpp @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Create: 2023 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #include "utils/logger.h" diff --git a/src/core/utils/logger.h b/src/core/utils/logger.h index f095599b..82321acd 100644 --- a/src/core/utils/logger.h +++ b/src/core/utils/logger.h @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: common module - * Author: MindX SDK - * Create: 2023 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef MXREC_LOGGER_H #define MXREC_LOGGER_H diff --git a/src/core/utils/safe_queue.h b/src/core/utils/safe_queue.h index 79122038..95bb0b6d 100644 --- a/src/core/utils/safe_queue.h +++ b/src/core/utils/safe_queue.h @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved. - * Description: safe queue class - * Author: MindX SDK - * Create: 2022 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ #ifndef SAFE_QUEUE_H #define SAFE_QUEUE_H diff --git a/src/core/utils/singleton.h b/src/core/utils/singleton.h index 7a265a29..956c57c4 100644 --- a/src/core/utils/singleton.h +++ b/src/core/utils/singleton.h @@ -1,10 +1,17 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. - * Description: singleton module. - * Author: MindX SDK - * Create: 2022 - * History: NA - */ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #ifndef RC_UTILS_SINGLETON_H
 #define RC_UTILS_SINGLETON_H
diff --git a/src/core/utils/time_cost.h b/src/core/utils/time_cost.h
index 495282c1..55f5f1e0 100644
--- a/src/core/utils/time_cost.h
+++ b/src/core/utils/time_cost.h
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description: time cost profile module.
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #ifndef TIMECOST_H
 #define TIMECOST_H
diff --git a/src/dataset_tf/CMakeLists.txt b/src/dataset_tf/CMakeLists.txt
index 70d42d88..495c8384 100644
--- a/src/dataset_tf/CMakeLists.txt
+++ b/src/dataset_tf/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 cmake_minimum_required(VERSION 3.12)
 set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_BUILD_TYPE "Release")
diff --git a/src/dataset_tf/eos_dataset_op.cc b/src/dataset_tf/eos_dataset_op.cc
index 112c64b3..2acd3a68 100644
--- a/src/dataset_tf/eos_dataset_op.cc
+++ b/src/dataset_tf/eos_dataset_op.cc
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: dataset eos ops.
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include "eos_dataset_op.h"
diff --git a/src/dataset_tf/eos_dataset_op.h b/src/dataset_tf/eos_dataset_op.h
index 5f5383f5..802fc641 100644
--- a/src/dataset_tf/eos_dataset_op.h
+++ b/src/dataset_tf/eos_dataset_op.h
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: dataset eos ops.
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_EOS_DATASET_OP_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_EOS_DATASET_OP_H_
diff --git a/src/ops_tf/CMakeLists.txt b/src/ops_tf/CMakeLists.txt
index 055cb842..c534cc97 100644
--- a/src/ops_tf/CMakeLists.txt
+++ b/src/ops_tf/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 cmake_minimum_required(VERSION 3.12)
 set(CMAKE_CXX_STANDARD 14)
diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp
index c605f06b..7d5cb1a4 100644
--- a/src/ops_tf/hybrid_dataset_ops.cpp
+++ b/src/ops_tf/hybrid_dataset_ops.cpp
@@ -1,11 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: dataset ops.
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/ops_tf/tf_ops.h b/src/ops_tf/tf_ops.h
index 71b011bc..f3a7ed32 100644
--- a/src/ops_tf/tf_ops.h
+++ b/src/ops_tf/tf_ops.h
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: tf ops.
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #ifndef MX_REC_TF_OPS_H
 #define MX_REC_TF_OPS_H
diff --git a/src/pybind/CMakeLists.txt b/src/pybind/CMakeLists.txt
index 63131cd3..9d4682c3 100644
--- a/src/pybind/CMakeLists.txt
+++ b/src/pybind/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 cmake_minimum_required(VERSION 3.20)
 
 pybind11_add_module(mxrec_pybind module_main.cpp)
diff --git a/src/pybind/module_main.cpp b/src/pybind/module_main.cpp
index d3793f1d..8d00000c 100644
--- a/src/pybind/module_main.cpp
+++ b/src/pybind/module_main.cpp
@@ -1,9 +1,18 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: pybind module
- * Author: MindX SDK
- * Date: 2022/11/15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 #include
 #include
 #include
diff --git a/src/test_ut.sh b/src/test_ut.sh
index 3a2987f0..7dde8ac2 100644
--- a/src/test_ut.sh
+++ b/src/test_ut.sh
@@ -1,9 +1,19 @@
 #!/bin/bash
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-# Description: NA
-# Author: MindX SDK
-# Create: 2022
-# History: NA
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 set -e
 
 # add mpirun env
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index cbf85dd9..ec959f0d 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -1,3 +1,18 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 # Enable testing
 enable_testing()
 set(CMAKE_CXX_STANDARD 17)
diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp
index 3754cdf2..cc18eb14 100644
--- a/src/tests/checkpoint/checkpoint_test.cpp
+++ b/src/tests/checkpoint/checkpoint_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-11-15
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
index 2ac145d4..390e5c76 100644
--- a/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
+++ b/src/tests/ckpt_data_handler/ckpt_data_handler_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description:
- * Author: MindX SDK
- * Create: 2022-12-03
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp
index 4c08e315..960538d5 100644
--- a/src/tests/emb_hashmap/emb_hashmap_test.cpp
+++ b/src/tests/emb_hashmap/emb_hashmap_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: emb_hashmap test
- * Author: MindX SDK
- * Date: 2023/9/18
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/emb_mgmt/emb_mgmt_test.cpp b/src/tests/emb_mgmt/emb_mgmt_test.cpp
index 20735b94..e47f3b4f 100644
--- a/src/tests/emb_mgmt/emb_mgmt_test.cpp
+++ b/src/tests/emb_mgmt/emb_mgmt_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: emb mgmt test
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include "hybrid_mgmt/hybrid_mgmt.h"
diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp
index d669c7c6..52a0c169 100644
--- a/src/tests/emb_table/emb_table_test.cpp
+++ b/src/tests/emb_table/emb_table_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: emb table test
- * Author: MindX SDK
- * Date: 2023/5/6
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/file_system/local_file_system_test.cpp b/src/tests/file_system/local_file_system_test.cpp
index 359b86de..06f8b273 100644
--- a/src/tests/file_system/local_file_system_test.cpp
+++ b/src/tests/file_system/local_file_system_test.cpp
@@ -1,6 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
diff --git a/src/tests/gtest_main.cpp b/src/tests/gtest_main.cpp
index 50747a2f..882feb00 100644
--- a/src/tests/gtest_main.cpp
+++ b/src/tests/gtest_main.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: gtest main
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp
index 2632e0df..3bcc34f7 100644
--- a/src/tests/host_emb/host_emb_test.cpp
+++ b/src/tests/host_emb/host_emb_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: host emb test
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp
index f278875a..2bb86c42 100644
--- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp
+++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: key process test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/initializer/initializer_test.cpp b/src/tests/initializer/initializer_test.cpp
index b27a6a26..ae8c8702 100644
--- a/src/tests/initializer/initializer_test.cpp
+++ b/src/tests/initializer/initializer_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: initializer test
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/key_process/feature_admit_and_evict_test.cpp b/src/tests/key_process/feature_admit_and_evict_test.cpp
index 4ef990f2..09cadc7f 100644
--- a/src/tests/key_process/feature_admit_and_evict_test.cpp
+++ b/src/tests/key_process/feature_admit_and_evict_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
- * Description: operator module
- * Author: MindX SDK
- * Date: 2022/12/08
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp
index 0e14a8f7..3b91e726 100644
--- a/src/tests/key_process/key_process_test.cpp
+++ b/src/tests/key_process/key_process_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2021-2022. All rights reserved.
- * Description: key process test
- * Author: MindX SDK
- * Create: 2022
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 767dee98..31c953fa 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: ssd cache module head
- * Author: MindX SDK
- * Date: 2023/8/19
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/ssd_cache/lfu_cache_test.cpp b/src/tests/ssd_cache/lfu_cache_test.cpp
index 76d5f7fa..1adf4aad 100644
--- a/src/tests/ssd_cache/lfu_cache_test.cpp
+++ b/src/tests/ssd_cache/lfu_cache_test.cpp
@@ -1,9 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: lfu cache module
- * Author: MindX SDK
- * Date: 2023/8/10
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/ssd_engine/engine_test.cpp b/src/tests/ssd_engine/engine_test.cpp
index a09b2078..aad64a99 100644
--- a/src/tests/ssd_engine/engine_test.cpp
+++ b/src/tests/ssd_engine/engine_test.cpp
@@ -1,6 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/ssd_engine/file_test.cpp b/src/tests/ssd_engine/file_test.cpp
index 1b60a801..599b5975 100644
--- a/src/tests/ssd_engine/file_test.cpp
+++ b/src/tests/ssd_engine/file_test.cpp
@@ -1,7 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/ssd_engine/table_test.cpp b/src/tests/ssd_engine/table_test.cpp
index 6fdb06b8..2e180c13 100644
--- a/src/tests/ssd_engine/table_test.cpp
+++ b/src/tests/ssd_engine/table_test.cpp
@@ -1,6 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/utils/common_h_test.cpp b/src/tests/utils/common_h_test.cpp
index 2fcd4083..2e86b88d 100644
--- a/src/tests/utils/common_h_test.cpp
+++ b/src/tests/utils/common_h_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: common.h test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include "utils/common.h"
diff --git a/src/tests/utils/common_test.cpp b/src/tests/utils/common_test.cpp
index d918360b..8377c738 100644
--- a/src/tests/utils/common_test.cpp
+++ b/src/tests/utils/common_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: common.cpp test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/utils/config_test.cpp b/src/tests/utils/config_test.cpp
index 24dd1c2e..fc1b7b95 100644
--- a/src/tests/utils/config_test.cpp
+++ b/src/tests/utils/config_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: config test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/src/tests/utils/log_test.cpp b/src/tests/utils/log_test.cpp
index ebdb8487..98111850 100644
--- a/src/tests/utils/log_test.cpp
+++ b/src/tests/utils/log_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: log test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include "utils/common.h"
diff --git a/src/tests/utils/safe_queue_test.cpp b/src/tests/utils/safe_queue_test.cpp
index 9b2e78f7..4a84d030 100644
--- a/src/tests/utils/safe_queue_test.cpp
+++ b/src/tests/utils/safe_queue_test.cpp
@@ -1,10 +1,17 @@
-/*
- * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
- * Description: safe_queue test
- * Author: MindX SDK
- * Create: 2023
- * History: NA
- */
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
 
 #include
 #include
diff --git a/tests/mx_rec/core/generator_dataset.py b/tests/mx_rec/core/generator_dataset.py
index cf53f471..14e6e585 100644
--- a/tests/mx_rec/core/generator_dataset.py
+++ b/tests/mx_rec/core/generator_dataset.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from typing import Callable
diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py
index 58737058..491578fd 100644
--- a/tests/mx_rec/core/mock_class.py
+++ b/tests/mx_rec/core/mock_class.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import tensorflow as tf
 from tensorflow_core.python.training import slot_creator
diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py
index f6ac6e6d..41706ce7 100644
--- a/tests/mx_rec/core/test_build_graph.py
+++ b/tests/mx_rec/core/test_build_graph.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
 from unittest import mock
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
index 0ce32500..f5284b8d 100644
--- a/tests/mx_rec/core/test_embedding.py
+++ b/tests/mx_rec/core/test_embedding.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import os
 import unittest
diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py
index 2acaf56b..91cdbe65 100644
--- a/tests/mx_rec/core/test_feature_process.py
+++ b/tests/mx_rec/core/test_feature_process.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import time
 import unittest
diff --git a/tests/mx_rec/core/test_feature_spec.py b/tests/mx_rec/core/test_feature_spec.py
index 7adad830..bef5da5c 100644
--- a/tests/mx_rec/core/test_feature_spec.py
+++ b/tests/mx_rec/core/test_feature_spec.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
 from unittest import mock
diff --git a/tests/mx_rec/core/test_helper.py b/tests/mx_rec/core/test_helper.py
index 3af6dc1a..3444b47b 100644
--- a/tests/mx_rec/core/test_helper.py
+++ b/tests/mx_rec/core/test_helper.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import os
 import shutil
diff --git a/tests/mx_rec/core/test_manager.py b/tests/mx_rec/core/test_manager.py
index 72920bdd..038ea60e 100644
--- a/tests/mx_rec/core/test_manager.py
+++ b/tests/mx_rec/core/test_manager.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
 from unittest import mock
diff --git a/tests/mx_rec/core/test_merge_table.py b/tests/mx_rec/core/test_merge_table.py
index 56bbc357..01d6736e 100644
--- a/tests/mx_rec/core/test_merge_table.py
+++ b/tests/mx_rec/core/test_merge_table.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
 from unittest import mock
diff --git a/tests/mx_rec/data/__init__.py b/tests/mx_rec/data/__init__.py
index 6924f767..a0260bd0 100644
--- a/tests/mx_rec/data/__init__.py
+++ b/tests/mx_rec/data/__init__.py
@@ -1,3 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
diff --git a/tests/mx_rec/data/mock_class.py b/tests/mx_rec/data/mock_class.py
index a3f5c416..08b50140 100644
--- a/tests/mx_rec/data/mock_class.py
+++ b/tests/mx_rec/data/mock_class.py
@@ -1,7 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 class MockEosOpsLib:
     """
diff --git a/tests/mx_rec/data/test_dataset.py b/tests/mx_rec/data/test_dataset.py
index 4db7f63f..d1047e17 100644
--- a/tests/mx_rec/data/test_dataset.py
+++ b/tests/mx_rec/data/test_dataset.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
diff --git a/tests/mx_rec/graph/mock_dataset.py b/tests/mx_rec/graph/mock_dataset.py
index 8148c397..517476d9 100644
--- a/tests/mx_rec/graph/mock_dataset.py
+++ b/tests/mx_rec/graph/mock_dataset.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from typing import Dict, Iterable
diff --git a/tests/mx_rec/graph/test_acg_push_ops.py b/tests/mx_rec/graph/test_acg_push_ops.py
index f24ab78e..362eb296 100644
--- a/tests/mx_rec/graph/test_acg_push_ops.py
+++ b/tests/mx_rec/graph/test_acg_push_ops.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from unittest import TestCase
 from unittest.mock import patch, Mock
diff --git a/tests/mx_rec/graph/test_merge_lookup.py b/tests/mx_rec/graph/test_merge_lookup.py
index 1bf4311f..19bf0686 100644
--- a/tests/mx_rec/graph/test_merge_lookup.py
+++ b/tests/mx_rec/graph/test_merge_lookup.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import unittest
 from typing import Union
diff --git a/tests/mx_rec/graph/test_modifier.py b/tests/mx_rec/graph/test_modifier.py
index 14d1633e..a648f5cf 100644
--- a/tests/mx_rec/graph/test_modifier.py
+++ b/tests/mx_rec/graph/test_modifier.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import os
 import unittest
diff --git a/tests/mx_rec/graph/test_utils.py b/tests/mx_rec/graph/test_utils.py
index 562679e5..0b810fac 100644
--- a/tests/mx_rec/graph/test_utils.py
+++ b/tests/mx_rec/graph/test_utils.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import sys
 import os
diff --git a/tests/mx_rec/saver/sparse_embedding_mock.py b/tests/mx_rec/saver/sparse_embedding_mock.py
index 6a3d12df..e46dd86c 100644
--- a/tests/mx_rec/saver/sparse_embedding_mock.py
+++ b/tests/mx_rec/saver/sparse_embedding_mock.py
@@ -1,7 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import os
diff --git a/tests/mx_rec/saver/test_saver.py b/tests/mx_rec/saver/test_saver.py
index 27bed04d..4df0c025 100644
--- a/tests/mx_rec/saver/test_saver.py
+++ b/tests/mx_rec/saver/test_saver.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import os
 import unittest
diff --git a/tests/mx_rec/saver/test_sparse.py b/tests/mx_rec/saver/test_sparse.py
index f827aa9c..34bb8419 100644
--- a/tests/mx_rec/saver/test_sparse.py
+++ b/tests/mx_rec/saver/test_sparse.py
@@ -1,6 +1,19 @@
 #!/usr/bin/env python3
 # coding: UTF-8
-# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== import os from unittest import mock diff --git a/tests/mx_rec/util/communication/test_hccl_mgmt.py b/tests/mx_rec/util/communication/test_hccl_mgmt.py index b9afb1c3..3da291f1 100644 --- a/tests/mx_rec/util/communication/test_hccl_mgmt.py +++ b/tests/mx_rec/util/communication/test_hccl_mgmt.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import sys import unittest from unittest import mock diff --git a/tests/mx_rec/util/test_atomic.py b/tests/mx_rec/util/test_atomic.py index 2a6b37c2..dedf697d 100644 --- a/tests/mx_rec/util/test_atomic.py +++ b/tests/mx_rec/util/test_atomic.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import unittest from threading import Thread diff --git a/tests/mx_rec/util/test_normalization.py b/tests/mx_rec/util/test_normalization.py index 7bb2c967..dea5b305 100644 --- a/tests/mx_rec/util/test_normalization.py +++ b/tests/mx_rec/util/test_normalization.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + import unittest from mx_rec.util.normalization import fix_invalid_table_name diff --git a/tests/mx_rec/util/test_perf.py b/tests/mx_rec/util/test_perf.py index 2b50a2ac..c295155f 100644 --- a/tests/mx_rec/util/test_perf.py +++ b/tests/mx_rec/util/test_perf.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import unittest from mx_rec.util.perf import performance diff --git a/tests/mx_rec/util/test_variable.py b/tests/mx_rec/util/test_variable.py index b6ac5bd8..6b5e0c0e 100644 --- a/tests/mx_rec/util/test_variable.py +++ b/tests/mx_rec/util/test_variable.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import unittest from unittest import mock from unittest.mock import patch diff --git a/tests/mx_rec/validator/test_validators.py b/tests/mx_rec/validator/test_validators.py index aa97c3b6..ed102b05 100644 --- a/tests/mx_rec/validator/test_validators.py +++ b/tests/mx_rec/validator/test_validators.py @@ -1,6 +1,20 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + import os import sys import tempfile diff --git a/tests/run_python_dt.sh b/tests/run_python_dt.sh index 773cece7..a487fb04 100644 --- a/tests/run_python_dt.sh +++ b/tests/run_python_dt.sh @@ -1,10 +1,19 @@ #!/bin/bash -# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. -# Description: start script. -# Author: MindX SDK -# Create: 2023 -# History: NA +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== set -e -- Gitee From 9dc344a6c9a76a0bcc618f7f3277cdb2ae022276 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 16 Jan 2024 19:28:29 +0800 Subject: [PATCH 530/551] Match-id-a62e5b07ebbefd68a3bcabe55722c1ab5e532dd5 --- mx_rec/graph/patch.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index 5c875281..1e5c8d6a 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -13,6 +13,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# Some code is derived from Tensorflow, which is subject to the following copyright notice: +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # ============================================================================== import weakref -- Gitee From 1e3d6b0d12cb069398b3e331853f947fffe4b3f3 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 18 Jan 2024 17:04:21 +0800 Subject: [PATCH 531/551] Match-id-9765a1f6acfc1a672e00c9d531d6f82409af8186 --- examples/demo/little_demo/config.py | 127 +++++ examples/demo/little_demo/dataset.py | 75 +++ examples/demo/little_demo/main.py | 311 ++++++++++++ examples/demo/little_demo/model.py | 108 +++++ examples/demo/little_demo/op_impl_mode.ini | 3 + examples/demo/little_demo/optimizer.py | 35 ++ .../demo/little_demo/random_data_generator.py | 65 +++ examples/demo/little_demo/run.sh | 174 +++++++ examples/demo/little_demo/run_mode.py | 196 ++++++++ examples/demo/little_demo_estimator/config.py | 92 ++++ .../demo/little_demo_estimator/dataset.py | 76 +++ examples/demo/little_demo_estimator/main.py | 204 ++++++++ .../little_demo_estimator/nn_model_build.py | 221 +++++++++ .../little_demo_estimator/nn_model_input.py | 68 +++ .../demo/little_demo_estimator/nn_optim.py | 103 ++++ .../demo/little_demo_estimator/nn_reader.py | 50 ++ .../little_demo_estimator/op_precision.ini | 3 + .../random_data_generator.py | 67 +++ examples/demo/little_demo_estimator/run.sh | 176 +++++++ .../demo/little_demo_estimator/tf_adapter.py | 23 + examples/demo/little_demo_estimator/utils.py | 72 +++ examples/dlrm/criteo_tb/README.md | 19 + examples/dlrm/criteo_tb/gen_ttf.py | 402 ++++++++++++++++ examples/dlrm/model/config.py | 234 +++++++++ examples/dlrm/model/delay_loss_scale.py | 48 ++ examples/dlrm/model/gradient_descent_w.py | 74 +++ examples/dlrm/model/main_mxrec.py | 448 ++++++++++++++++++ examples/dlrm/model/mean_auc.py | 40 ++ examples/dlrm/model/model.py | 94 ++++ examples/dlrm/model/op_impl_mode.ini | 1 + examples/dlrm/model/optimizer.py | 35 ++ examples/dlrm/model/run.sh | 116 +++++ mx_rec/graph/patch.py | 18 +- mx_rec/saver/patch.py | 16 +- 34 files changed, 3765 insertions(+), 29 deletions(-) create mode 100644 examples/demo/little_demo/config.py create mode 100644 examples/demo/little_demo/dataset.py create mode 100644 examples/demo/little_demo/main.py create mode 100644 examples/demo/little_demo/model.py create mode 100644 examples/demo/little_demo/op_impl_mode.ini create mode 100644 examples/demo/little_demo/optimizer.py create mode 100644 examples/demo/little_demo/random_data_generator.py create mode 100644 examples/demo/little_demo/run.sh create mode 100644 examples/demo/little_demo/run_mode.py create mode 100644 examples/demo/little_demo_estimator/config.py create mode 100644 examples/demo/little_demo_estimator/dataset.py create mode 100644 examples/demo/little_demo_estimator/main.py create mode 100644 examples/demo/little_demo_estimator/nn_model_build.py create mode 100644 examples/demo/little_demo_estimator/nn_model_input.py create mode 100644 examples/demo/little_demo_estimator/nn_optim.py create mode 100644 examples/demo/little_demo_estimator/nn_reader.py create mode 100644 examples/demo/little_demo_estimator/op_precision.ini create mode 100644 examples/demo/little_demo_estimator/random_data_generator.py create mode 100644 examples/demo/little_demo_estimator/run.sh create mode 100644 examples/demo/little_demo_estimator/tf_adapter.py create mode 100644 examples/demo/little_demo_estimator/utils.py create mode 100644 examples/dlrm/criteo_tb/README.md create mode 100644 examples/dlrm/criteo_tb/gen_ttf.py create mode 100644 examples/dlrm/model/config.py create mode 100644 
examples/dlrm/model/delay_loss_scale.py
 create mode 100644 examples/dlrm/model/gradient_descent_w.py
 create mode 100644 examples/dlrm/model/main_mxrec.py
 create mode 100644 examples/dlrm/model/mean_auc.py
 create mode 100644 examples/dlrm/model/model.py
 create mode 100644 examples/dlrm/model/op_impl_mode.ini
 create mode 100644 examples/dlrm/model/optimizer.py
 create mode 100644 examples/dlrm/model/run.sh

diff --git a/examples/demo/little_demo/config.py b/examples/demo/little_demo/config.py
new file mode 100644
index 00000000..0e098a9c
--- /dev/null
+++ b/examples/demo/little_demo/config.py
@@ -0,0 +1,127 @@
+# coding: UTF-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import math
+import tensorflow as tf
+
+from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
+
+from mx_rec.util.initialize import get_rank_size
+
+
+class Config:
+    def __init__(self, mode="simple", task_name="default"):
+        self.task_name = task_name
+        if mode == "simple":
+            self.generate_simple_config()
+        else:
+            self.generate_large_scale_config()
+
+    def generate_simple_config(self):
+        self.batch_number = 8192
+        self.batch_size = 4096
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.float32
+
+        self.item_range = 80000 * get_rank_size()
+        self.user_range = 200000 * get_rank_size()
+        self.category_range = 5000 * get_rank_size()
+        self.item_feat_cnt = 16
+        self.user_feat_cnt = 8
+        self.category_feat_cnt = 3
+        self.access_threshold = 2
+        self.eviction_threshold = 2
+
+        rank_size = get_rank_size()
+        coefficient = 1.1
+        if rank_size != 0:
+            max_ui_send_cnt = max(self.item_feat_cnt, self.user_feat_cnt)
+            max_ui_range = max(self.item_range, self.user_range)
+            self.item_send_cnt = min(int(self.batch_size * self.item_feat_cnt * coefficient),
+                                     math.ceil(self.item_range / rank_size))
+            self.item_vocab_size = max(self.item_send_cnt * rank_size * rank_size, self.item_range)
+            self.user_send_cnt = min(int(self.batch_size * max_ui_send_cnt * coefficient),
+                                     math.ceil(max_ui_range / rank_size))
+            self.user_vocab_size = max(self.user_send_cnt * rank_size * rank_size, self.user_range)
+            self.category_send_cnt = min(int(self.batch_size * self.category_feat_cnt * coefficient),
+                                         math.ceil(self.category_range / rank_size))
+        else:
+            raise ZeroDivisionError("rank size must be an integer greater than zero.")
+
+        self.user_hashtable_dim = 32
+        self.user_hashtable_threshold = 1
+        self.item_hashtable_dim = 8
+        self.item_hashtable_threshold = 1
+
+        self.learning_rate = 0.01
+
+    def generate_large_scale_config(self):
+        self.lookup_count = 40
+        self.tensor_name_list = ["sparse_tensor_%d" % i for i in range(self.lookup_count)]
+        self.hashtable_name_list = ["hashtable_%d" % i for i in range(self.lookup_count)]
+        self.batch_size = 9600
+
+        self.key_type = tf.int64
+        self.label_type = tf.float32
+        self.value_type = tf.float32
+
+        self.vocabulary_size = 500000
+        self.feat_cnt = 1
+
+        rank_size = get_rank_size()
+        coefficient = 1.1
+        if rank_size != 0:
+            self.send_cnt = min(int(self.batch_size * self.feat_cnt * coefficient),
+                                math.ceil(self.vocabulary_size / rank_size))
+        else:
+            raise ZeroDivisionError("rank size must be an integer greater than zero.")
+
+        self.hashtable_dim = 8
+        self.learning_rate = 0.01
+
+
+def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"):
+    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=False,
+                                              log_device_placement=False)
+
+    session_config.gpu_options.allow_growth = True
+    custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add()
+    custom_op.name = "NpuOptimizer"
+    custom_op.parameter_map["mix_compile_mode"].b = False
+    custom_op.parameter_map["use_off_line"].b = True
+    custom_op.parameter_map["min_group_size"].b = 1
+    custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise")
+    custom_op.parameter_map["enable_data_pre_proc"].b = True
+    custom_op.parameter_map["iterations_per_loop"].i = 1
+    custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
+    custom_op.parameter_map["hcom_parallel"].b = False
+    custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini")
+    custom_op.parameter_map["op_execute_timeout"].i = 2000
+    if dump_data:
+        """
+        To see the details, please refer to the descriptions on the official website.
+        """
+        custom_op.parameter_map["enable_dump"].b = True
+        custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path)
+        custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps)
+        custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all")
+
+    session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
+
+    return session_config

diff --git a/examples/demo/little_demo/dataset.py b/examples/demo/little_demo/dataset.py
new file mode 100644
index 00000000..d5ede53f
--- /dev/null
+++ b/examples/demo/little_demo/dataset.py
@@ -0,0 +1,75 @@
+# coding: UTF-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
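A worked sizing example may help here. The send-count logic in config.py above caps each rank's exchange buffer at either the worst-case per-batch demand or that rank's share of the key range, whichever is smaller. A minimal standalone sketch, assuming rank_size = 8 and the simple-config values above (the concrete numbers are illustrative only):

import math

# Recompute item_send_cnt the way generate_simple_config above does,
# under an assumed rank_size of 8 (hypothetical value for this sketch).
rank_size = 8
batch_size = 4096
item_feat_cnt = 16
coefficient = 1.1
item_range = 80000 * rank_size                             # 640000

per_batch = int(batch_size * item_feat_cnt * coefficient)  # 72089 keys, worst case per batch
per_rank = math.ceil(item_range / rank_size)               # 80000 keys, one rank's share
item_send_cnt = min(per_batch, per_rank)                   # 72089

# the table is padded so every rank can hold the keys it may receive
item_vocab_size = max(item_send_cnt * rank_size * rank_size, item_range)
print(item_send_cnt, item_vocab_size)                      # 72089 4613696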
+# ============================================================================== + +import tensorflow as tf + +from random_data_generator import get_data_generator, get_large_scale_data_generator +from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops + + +def generate_dataset(cfg, use_timestamp=False, batch_number=100): + dataset = tf.compat.v1.data.Dataset.from_generator( + generator=get_data_generator(cfg, batch_number=batch_number), + output_types={"item_ids": cfg.key_type, + "user_ids": cfg.key_type, + "category_ids": cfg.key_type, + "label_0": cfg.label_type, + "label_1": cfg.label_type}, + output_shapes={"item_ids": tf.TensorShape([cfg.batch_size, cfg.item_feat_cnt]), + "user_ids": tf.TensorShape([cfg.batch_size, cfg.user_feat_cnt]), + "category_ids": tf.TensorShape([cfg.batch_size, cfg.category_feat_cnt]), + "label_0": tf.TensorShape([cfg.batch_size]), + "label_1": tf.TensorShape([cfg.batch_size])}) + if use_timestamp: + dataset = dataset.map(add_timestamp_func) + + rank_size = get_rank_size() + rank_id = get_rank_id() + if rank_size > 1: + dataset = dataset.shard(rank_size, rank_id) + + return dataset + + +def add_timestamp_func(batch): + host_pipeline_ops = get_host_pipeline_ops() + timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64)) + batch["timestamp"] = timestamp + return batch + + +def generate_large_scale_data(cfg): + key_type_list = [cfg.key_type for _ in range(cfg.lookup_count)] + output_type_dict = dict(zip(cfg.tensor_name_list, key_type_list)) + output_type_dict["label_0"] = cfg.label_type + output_type_dict["label_1"] = cfg.label_type + + tensor_shape_list = [tf.TensorShape([cfg.batch_size]) for _ in range(cfg.lookup_count)] + output_shape_dict = dict(zip(cfg.tensor_name_list, tensor_shape_list)) + output_shape_dict["label_0"] = tf.TensorShape([cfg.batch_size]) + output_shape_dict["label_1"] = tf.TensorShape([cfg.batch_size]) + + dataset = tf.data.Dataset.from_generator(generator=get_large_scale_data_generator(cfg), + output_types=output_type_dict, + output_shapes=output_shape_dict) + rank_size = get_rank_size() + rank_id = get_rank_id() + if rank_size > 1: + dataset = dataset.shard(rank_size, rank_id) + + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + return batch, iterator diff --git a/examples/demo/little_demo/main.py b/examples/demo/little_demo/main.py new file mode 100644 index 00000000..3363c0fb --- /dev/null +++ b/examples/demo/little_demo/main.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
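The shard(rank_size, rank_id) call in generate_dataset above is what splits the input stream across workers: each rank keeps every rank_size-th element, starting at its own rank_id. A self-contained sketch with plain tf.data (the rank values are assumed here; the demo obtains them from the cluster at run time):

import tensorflow as tf

rank_size, rank_id = 2, 0                 # assumed values for this sketch
dataset = tf.data.Dataset.range(8)
sharded = dataset.shard(rank_size, rank_id)
print(list(sharded.as_numpy_iterator()))  # rank 0 sees [0, 2, 4, 6]; rank 1 would see [1, 3, 5, 7]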
+# ============================================================================== + +import enum +import os +import shutil +import warnings +from glob import glob + +import tensorflow as tf +from config import Config +from dataset import generate_dataset +from optimizer import create_dense_and_sparse_optimizer +from model import MyModel +from run_mode import RunMode, UseMode + +from mx_rec.core.asc.feature_spec import FeatureSpec +from mx_rec.core.asc.helper import get_asc_insert_func +from mx_rec.core.asc.manager import start_asc_pipeline +from mx_rec.core.embedding import create_table, sparse_lookup +from mx_rec.graph.modifier import modify_graph_and_start_emb_cache +from mx_rec.constants.constants import ASCEND_TIMESTAMP +from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load, get_rank_size +from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.util.log import logger + +tf.compat.v1.disable_eager_execution() + +_SSD_SAVE_PATH = ["ssd_data"] + + +class CacheModeEnum(enum.Enum): + HBM = "HBM" + DDR = "DDR" + SSD = "SSD" + + +def make_batch_and_iterator(is_training, feature_spec_list=None, + use_timestamp=False, dump_graph=False, batch_number=100): + dataset = generate_dataset(cfg, use_timestamp=use_timestamp, batch_number=batch_number) + if not MODIFY_GRAPH_FLAG: + insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph) + dataset = dataset.map(insert_fn) + dataset = dataset.prefetch(100) + if USE_ONE_SHOT: + iterator = dataset.make_one_shot_iterator() + else: + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + return batch, iterator + + +def model_forward(input_list, batch, is_train, modify_graph, config_dict=None): + embedding_list = [] + feature_list, hash_table_list, send_count_list, is_grad_list, dim_list = input_list + for feature, hash_table, send_count, is_grad, dim in zip(feature_list, hash_table_list, send_count_list, + is_grad_list, dim_list): + access_and_evict_config = None + if isinstance(config_dict, dict): + access_and_evict_config = config_dict.get(hash_table.table_name) + embedding = sparse_lookup(hash_table, feature, send_count, is_train=is_train, + access_and_evict_config=access_and_evict_config, is_grad=is_grad, + name=hash_table.table_name + "_lookup", modify_graph=modify_graph, batch=batch, + serving_default_value = tf.ones(shape=(dim), dtype=tf.float32) * 2) + + reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False) + embedding_list.append(reduced_embedding) + + my_model = MyModel() + my_model(embedding_list, batch["label_0"], batch["label_1"]) + return my_model + + +def build_graph(hash_table_list, is_train, feature_spec_list=None, config_dict=None, batch_number=100): + batch, iterator = make_batch_and_iterator(is_train, feature_spec_list=feature_spec_list, + use_timestamp=USE_TIMESTAMP, dump_graph=is_train, + batch_number=batch_number) + if MODIFY_GRAPH_FLAG: + input_list = [[batch["user_ids"], batch["item_ids"]], + [hash_table_list[0], hash_table_list[1]], + [cfg.user_send_cnt, cfg.item_send_cnt], + [True, True], + [cfg.user_hashtable_dim, cfg.item_hashtable_dim]] + if use_multi_lookup: + # add `MULTI_LOOKUP_TIMES` times + for i, _ in enumerate(input_list): + input_list[i].extend([input_list[i][0]] * MULTI_LOOKUP_TIMES) + if USE_TIMESTAMP: + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, batch["timestamp"]) + model = model_forward(input_list, batch, + is_train=is_train, modify_graph=True, 
+                                  config_dict=config_dict)
+    else:
+        input_list = [feature_spec_list,
+                      [hash_table_list[0], hash_table_list[1]],
+                      [cfg.user_send_cnt, cfg.item_send_cnt],
+                      [True, True],
+                      [cfg.user_hashtable_dim, cfg.item_hashtable_dim]]
+        if use_multi_lookup:
+            # add `MULTI_LOOKUP_TIMES` times
+            for i, _ in enumerate(input_list):
+                if i == 0:
+                    continue
+                input_list[i].extend([input_list[i][0]] * MULTI_LOOKUP_TIMES)
+
+        model = model_forward(input_list, batch,
+                              is_train=is_train, modify_graph=False, config_dict=config_dict)
+
+    return iterator, model, batch
+
+
+def create_feature_spec_list(use_timestamp=False):
+    access_threshold = cfg.access_threshold if use_timestamp else None
+    eviction_threshold = cfg.eviction_threshold if use_timestamp else None
+    feature_spec_list = [FeatureSpec("user_ids", table_name="user_table",
+                                     access_threshold=access_threshold,
+                                     eviction_threshold=eviction_threshold,
+                                     faae_coefficient=1),
+                         FeatureSpec("item_ids", table_name="item_table",
+                                     access_threshold=access_threshold,
+                                     eviction_threshold=eviction_threshold,
+                                     faae_coefficient=4)]
+    if use_multi_lookup:
+        # add `MULTI_LOOKUP_TIMES` times
+        for _ in range(MULTI_LOOKUP_TIMES):
+            feature_spec_list.append(FeatureSpec("user_ids", table_name="user_table",
+                                                 access_threshold=access_threshold,
+                                                 eviction_threshold=eviction_threshold,
+                                                 faae_coefficient=1))
+    if use_timestamp:
+        feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True))
+    return feature_spec_list
+
+
+def clear_saved_model():
+    mode = UseMode.mapping(os.getenv("USE_MODE"))
+    if mode == UseMode.TRAIN:
+        logger.info("current mode is train, will delete previously saved model data if it exists.")
+        save_model_path = os.path.join(os.getcwd(), "saved-model")
+        shutil.rmtree(save_model_path, ignore_errors=True)
+    if not (os.getenv("CACHE_MODE", "") == CacheModeEnum.SSD.value and mode == UseMode.TRAIN):
+        return
+
+    # ssd does not allow overwriting files, so clear them before training
+    logger.info("current cache mode is SSD, will delete previously saved ssd data if it exists.")
+    for part_path in _SSD_SAVE_PATH:
+        if "/" not in part_path and "\\" not in part_path:
+            part_path = os.path.join(os.getcwd(), part_path)
+        shutil.rmtree(part_path, ignore_errors=True)
+        try:
+            os.mkdir(part_path)
+        except OSError:
+            logger.warning("ssd path already exists")  # processes run in parallel; ignore the error
+
+
+if __name__ == "__main__":
+    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+    warnings.filterwarnings("ignore")
+
+    use_mode = UseMode.mapping(os.getenv("USE_MODE"))
+    clear_saved_model()
+    # maximum number of training steps
+    MAX_TRAIN_STEPS = 200
+    # number of training steps before switching to evaluation
+    TRAIN_STEPS = 100
+    # number of evaluation steps before switching back to training
+    EVAL_STEPS = 10
+    # number of training steps between checkpoints
+    SAVING_INTERVAL = 100
+
+    # get init configuration
+    try:
+        use_mpi = bool(int(os.getenv("USE_MPI", 1)))
+        use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
+        use_hot = bool(int(os.getenv("USE_HOT", 0)))
+        use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0)))
+        use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1)))
+        MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0)))
+        USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0)))
+        USE_ONE_SHOT = bool(int(os.getenv("USE_ONE_SHOT", 0)))
+    except ValueError as err:
+        raise ValueError("please correctly configure USE_MPI, USE_DYNAMIC, USE_HOT, USE_DYNAMIC_EXPANSION, "
+                         "USE_MULTI_LOOKUP, USE_MODIFY_GRAPH, USE_TIMESTAMP and USE_ONE_SHOT; "
+                         "only 0 or 1 is supported.") from err
+
+    try:
+        MULTI_LOOKUP_TIMES = int(os.getenv("MULTI_LOOKUP_TIMES", 2))
+    except ValueError as err:
+        raise ValueError("please correctly configure MULTI_LOOKUP_TIMES; only int is supported.") from err
+
+    # the nbatch function needs to be used together with prefetch and host_vocabulary_size != 0
+    init(use_mpi=use_mpi,
+         train_steps=TRAIN_STEPS,
+         eval_steps=EVAL_STEPS,
+         save_steps=SAVING_INTERVAL,
+         use_dynamic=use_dynamic,
+         use_hot=use_hot,
+         use_dynamic_expansion=use_dynamic_expansion)
+    IF_LOAD = False
+    rank_id = get_rank_id()
+    file_list = glob(f"./saved-model/sparse-model-{rank_id}-*")
+    if file_list:
+        IF_LOAD = True
+    set_if_load(IF_LOAD)
+
+    cfg = Config()
+    # multi lookup config, batch size: 32 * 128 = 4096
+    if use_multi_lookup and MULTI_LOOKUP_TIMES > 2:
+        cfg.batch_size = 32
+
+    # access_threshold unit: counts; eviction_threshold unit: seconds
+    ACCESS_AND_EVICT = None
+    if USE_TIMESTAMP:
+        config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold,
+                                     faae_coefficient=1)
+        config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold,
+                                     faae_coefficient=4)
+        ACCESS_AND_EVICT = dict(user_table=config_for_user_table, item_table=config_for_item_table)
+    train_feature_spec_list = None
+    eval_feature_spec_list = None
+    if not MODIFY_GRAPH_FLAG:
+        train_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
+        eval_feature_spec_list = create_feature_spec_list(use_timestamp=USE_TIMESTAMP)
+
+    optimizer_list = [create_dense_and_sparse_optimizer(cfg)]
+    sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]
+
+    # To verify DDR mode, size the device and host tables according to the key count and the unique-key count per batch.
+    # DDR reference configuration: dynamic shape is recommended to avoid parameter tuning. The total number of keys in
+    # the dataset is larger than the device table but smaller than device + host; the unique keys in one batch fit in
+    # the device table.
+    # SSD reference configuration: dynamic shape is recommended to avoid parameter tuning. The total number of keys in
+    # the dataset is larger than device + host; the unique keys in one batch fit in the device table.
+    hbm_test_cfg = {"device_vocabulary_size": cfg.user_vocab_size * 10 * get_rank_size(), "host_vocabulary_size": 0}
+    ddr_test_cfg = {"device_vocabulary_size": int(cfg.user_vocab_size * 0.5 * get_rank_size()),
+                    "host_vocabulary_size": cfg.user_vocab_size * 10 * get_rank_size()}
+    ssd_test_cfg = {
+        "device_vocabulary_size": 60000 * get_rank_size(), "host_vocabulary_size": 60000 * get_rank_size(),
+        "ssd_vocabulary_size": 10000000, "ssd_data_path": _SSD_SAVE_PATH
+    }
+    cache_mode_dict = {CacheModeEnum.HBM.value: hbm_test_cfg, CacheModeEnum.DDR.value: ddr_test_cfg,
+                       CacheModeEnum.SSD.value: ssd_test_cfg}
+
+    cache_mode = os.getenv("CACHE_MODE")
+    if cache_mode not in cache_mode_dict.keys():
+        raise ValueError(f"cache mode must be in {list(cache_mode_dict.keys())}, got: {cache_mode}")
+    if cache_mode in ["DDR", "SSD"] and not use_dynamic:
+        logger.warning("when cache_mode is in [DDR, SSD], use_dynamic=true is suggested to avoid tuning size parameters")
+
+    user_hashtable = create_table(key_dtype=tf.int64,
+                                  dim=tf.TensorShape([cfg.user_hashtable_dim]),
+                                  name='user_table',
+                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
+                                  optimizer_list=sparse_optimizer_list,
+                                  all2all_gradients_op="sum_gradients_and_div_by_ranksize",
+                                  **cache_mode_dict[cache_mode])
+
+    item_hashtable = create_table(key_dtype=tf.int64,
+                                  dim=tf.TensorShape([cfg.item_hashtable_dim]),
+                                  name='item_table',
+                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
+                                  optimizer_list=sparse_optimizer_list,
+                                  **cache_mode_dict[cache_mode])
+
+    # in predict mode, the train model does not need to be built
+    train_iterator = None
+    train_model = None
+    train_batch = None
+    table_list = [user_hashtable, item_hashtable]
+    if use_mode == UseMode.TRAIN:
+        train_iterator, train_model, train_batch = build_graph(table_list, is_train=True,
+                                                               feature_spec_list=train_feature_spec_list,
+                                                               config_dict=ACCESS_AND_EVICT,
+                                                               batch_number=MAX_TRAIN_STEPS * get_rank_size())
+    eval_iterator, eval_model, eval_batch = build_graph(table_list, is_train=False,
+                                                        feature_spec_list=eval_feature_spec_list,
+                                                        config_dict=ACCESS_AND_EVICT,
+                                                        batch_number=EVAL_STEPS * get_rank_size())
+    dense_variables, sparse_variables = get_dense_and_sparse_variable()
+
+    params = {"train_batch": train_batch, "eval_batch": eval_batch, "use_one_shot": USE_ONE_SHOT}
+    run_mode = RunMode(
+        MODIFY_GRAPH_FLAG, USE_TIMESTAMP, table_list, optimizer_list, train_model, eval_model, train_iterator,
+        eval_iterator, MAX_TRAIN_STEPS, EVAL_STEPS, params
+    )
+
+    # start host pipeline
+    if not MODIFY_GRAPH_FLAG:
+        start_asc_pipeline()
+    # start modify graph
+    if MODIFY_GRAPH_FLAG and use_mode != UseMode.TRAIN:
+        logger.info("start to modify graph")
+        modify_graph_and_start_emb_cache(dump_graph=True)
+
+    if use_mode == UseMode.TRAIN:
+        run_mode.train(TRAIN_STEPS, SAVING_INTERVAL)
+    elif use_mode == UseMode.PREDICT:
+        run_mode.predict()
+
+    terminate_config_initializer()
+    logger.info("Demo done!")

diff --git a/examples/demo/little_demo/model.py b/examples/demo/little_demo/model.py
new file mode 100644
index 00000000..526f421c
--- /dev/null
+++ b/examples/demo/little_demo/model.py
@@ -0,0 +1,108 @@
+# coding: UTF-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
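Condensing the wiring that main.py performs above, the core flow per table is: create_table once, then sparse_lookup per feature, then pool the looked-up embeddings. A sketch with placeholder sizes, not runnable outside an initialized mxRec environment (init() must have run; sparse_optimizer_list, send_count and user_ids stand in for the demo's real objects):

import tensorflow as tf
from mx_rec.core.embedding import create_table, sparse_lookup

# placeholder stand-ins for objects the demo builds elsewhere
sparse_optimizer_list = []   # normally from create_dense_and_sparse_optimizer
send_count = 1000            # normally cfg.user_send_cnt
user_ids = tf.compat.v1.placeholder(tf.int64, shape=[None, 8])

table = create_table(key_dtype=tf.int64,
                     dim=tf.TensorShape([32]),
                     name="user_table",
                     emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                     optimizer_list=sparse_optimizer_list,
                     device_vocabulary_size=1000000,
                     host_vocabulary_size=0)          # 0 keeps the table purely in HBM
emb = sparse_lookup(table, user_ids, send_count, is_train=True, name="user_table_lookup")
pooled = tf.reduce_sum(emb, axis=1, keepdims=False)   # same pooling as model_forward above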
+# ============================================================================== + +from __future__ import print_function + +import tensorflow as tf + + +class MyModel: + def __init__(self): + self.layer_dims = [1024, 512, 256, 128] + self.act_func = 'relu' + self.keep_prob = 0.8 + self._lambda = 4.91e-7 + self.emb_dim = None + self.loss_list = [] + self.predict_list = [] + self.all_layer_dims = None + self.h_w, self.h_b = [], [] + self.h_w_head_0, self.h_w_head_1, self.h_b_head_0, self.h_b_head_1 = None, None, None, None + + def __call__(self, embedding_list, label_0, label_1, is_training=True): + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + embedding = tf.concat(embedding_list, axis=1) + self.emb_dim = embedding.shape.as_list()[-1] + self.all_layer_dims = [self.emb_dim] + self.layer_dims + [1] + + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + for i in range(len(self.all_layer_dims) - 2): + self.h_w.append(tf.compat.v1.get_variable('h%d_w' % (i + 1), shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"])) + self.h_b.append( + tf.compat.v1.get_variable('h%d_b' % (i + 1), shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"])) + i += 1 + self.h_w_head_0 = tf.compat.v1.get_variable('h_w_head_0', shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) + self.h_b_head_0 = tf.compat.v1.get_variable('h_b_head_0', shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) + self.h_w_head_1 = tf.compat.v1.get_variable('h_w_head_1', shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_wts"]) + self.h_b_head_1 = tf.compat.v1.get_variable('h_b_head_1', shape=[self.all_layer_dims[i + 1]], + initializer=tf.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", "mlp_bias"]) + + logit_list = self.forward(embedding, self.act_func, self.keep_prob, training=is_training) + + for logit, label in zip(logit_list, (label_0, label_1)): + train_preds = tf.sigmoid(logit) + + basic_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label) + + deep_loss = tf.reduce_mean(basic_loss) # + _lambda * tf.nn.l2_loss(embedding) + self.predict_list.append(train_preds) + self.loss_list.append(deep_loss) + + + def forward(self, embedding, act_func, keep_prob, training): + hidden_output = tf.reshape(embedding, [-1, self.emb_dim]) # *512 + for i, h_w_var in enumerate(self.h_w): + hidden_output = tf.matmul(self.activate(act_func, hidden_output), h_w_var) + hidden_output = hidden_output + self.h_b[i] + + def output_head(hidden_output, h_w, h_b): + hidden_output_branch = tf.matmul(self.activate(act_func, hidden_output), h_w) + logit = hidden_output_branch + h_b + logit = tf.reshape(logit, [-1, ]) + + return logit + + logit_0 = output_head(hidden_output, self.h_w_head_0, self.h_b_head_0) + logit_1 = output_head(hidden_output, self.h_w_head_1, self.h_b_head_1) + logit_list = [logit_0, logit_1] + + 
return logit_list + + @staticmethod + def activate(act_func, input_x): + if act_func == 'tanh': + return tf.tanh(input_x) + elif act_func == 'relu': + return tf.nn.relu(input_x) + else: + return tf.sigmoid(input_x) diff --git a/examples/demo/little_demo/op_impl_mode.ini b/examples/demo/little_demo/op_impl_mode.ini new file mode 100644 index 00000000..4a744500 --- /dev/null +++ b/examples/demo/little_demo/op_impl_mode.ini @@ -0,0 +1,3 @@ +ScatterNdAdd=support_out_of_bound_index +GatherV2=high_performance +UnsortedSegmentSum=high_performance \ No newline at end of file diff --git a/examples/demo/little_demo/optimizer.py b/examples/demo/little_demo/optimizer.py new file mode 100644 index 00000000..f6eeabdb --- /dev/null +++ b/examples/demo/little_demo/optimizer.py @@ -0,0 +1,35 @@ +# coding: UTF-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf + +from mx_rec.optimizers.lazy_adam import create_hash_optimizer +from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address +from mx_rec.util.initialize import get_use_dynamic_expansion +from mx_rec.util.log import logger + + +def create_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate) + use_dynamic_expansion = get_use_dynamic_expansion() + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate) + logger.info("optimizer lazy_adam_by_addr") + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate) + logger.info("optimizer lazy_adam") + + return dense_optimizer, sparse_optimizer diff --git a/examples/demo/little_demo/random_data_generator.py b/examples/demo/little_demo/random_data_generator.py new file mode 100644 index 00000000..628700d4 --- /dev/null +++ b/examples/demo/little_demo/random_data_generator.py @@ -0,0 +1,65 @@ +# coding: UTF-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
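One detail of MyModel above that is easy to miss: the two heads share the whole MLP trunk and diverge only at the last matmul, so the model produces one sigmoid cross-entropy loss per label. A tiny self-contained sketch of that per-head loss (logits and labels here are random placeholders):

import tensorflow as tf

logits = [tf.constant([0.3, -1.2]), tf.constant([0.5, 0.1])]   # one tensor per head
labels = [tf.constant([1.0, 0.0]), tf.constant([0.0, 0.0])]
loss_list = [tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=lg, labels=lb))
             for lg, lb in zip(logits, labels)]                # mirrors MyModel.loss_list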
+# ==============================================================================
+
+import numpy as np
+
+from mx_rec.util.initialize import get_rank_id
+from mx_rec.util.log import logger
+
+
+def get_data_generator(config, batch_number):
+    rank_id = get_rank_id()
+
+    def data_generator():
+        i = 0
+        while i < batch_number:
+            item_ids = np.random.randint(0, config.item_range, (config.batch_size, config.item_feat_cnt))
+            user_ids = np.random.randint(0, config.user_range, (config.batch_size, config.user_feat_cnt))
+            category_ids = np.random.randint(0, config.category_range, (config.batch_size, config.category_feat_cnt))
+            label_0 = np.random.randint(0, 2, (config.batch_size,))
+            label_1 = np.random.randint(0, 2, (config.batch_size,))
+
+            yield {"item_ids": item_ids,
+                   "user_ids": user_ids,
+                   "category_ids": category_ids,
+                   "label_0": label_0,
+                   "label_1": label_1}
+            i += 1
+
+        logger.debug(f"================ end of data generator for {config.task_name} task | rank id {rank_id} "
+                     f"================")
+
+    return data_generator
+
+
+def get_large_scale_data_generator(config):
+    def data_generator():
+        i = 0
+        while True:
+            id_list = [np.random.randint(0, config.vocabulary_size, (config.batch_size,))
+                       for _ in range(config.lookup_count)]
+
+            data_block = dict(zip(config.tensor_name_list, id_list))
+
+            label_0 = np.random.randint(0, 2, (config.batch_size,))
+            label_1 = np.random.randint(0, 2, (config.batch_size,))
+            data_block["label_0"] = label_0
+            data_block["label_1"] = label_1
+
+            logger.debug(f"================ generate NO.{i} step ================")
+            yield data_block
+            i += 1
+
+    return data_generator

diff --git a/examples/demo/little_demo/run.sh b/examples/demo/little_demo/run.sh
new file mode 100644
index 00000000..41e285f4
--- /dev/null
+++ b/examples/demo/little_demo/run.sh
@@ -0,0 +1,174 @@
+#!/bin/bash
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1
+rm -rf /root/ascend/log/*
+rm -rf ./kernel*
+rm -rf ./export_graph/*
+
+export USE_MODE="train"  # supported values: [train, predict]
+
+# cache mode support: HBM, DDR, SSD
+export CACHE_MODE="HBM"
+
+# read the input arguments: py, ip
+if [ $# -ge 1 ]; then
+    py=$1
+    ip=$2
+else
+    echo "for example: bash run.sh main.py 10.10.10.10 or bash run.sh main.py"
+    exit 1
+fi
+
+# check whether the given python file name is valid
+if [[ $py =~ ^[a-z0-9_]+\.py$ ]]; then
+    echo "File $py is a valid Python file"
+else
+    echo "File $py is not a Python file"
+    exit 1
+fi
+
+# check whether the IP address is valid
+if [ -n "$ip" ]; then
+    if [[ $ip =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]]; then
+        # split the IP address into four numbers
+        ip_array=(${ip//./ })
+        # check that each number is between 0 and 255
+        valid=true
+        for i in "${ip_array[@]}"; do
+            if ((i < 0 || i > 255)); then
+                valid=false
+                break
+            fi
+        done
+        if $valid; then
+            echo "ip: $ip is valid"
+        else
+            echo "ip: $ip is not valid"
+            exit 1
+        fi
+    else
+        echo "ip: $ip is not valid."
+        exit 1
+    fi
+fi
+
+cur_path=`pwd`
+mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec"  # please configure
+so_path=${mx_rec_package_path}/libasc
+# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2:ERROR, defaults to INFO
+mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
+interface="lo"
+local_rank_size=8  # number of NPUs used on each node
+num_server=1  # number of training nodes
+num_process=$((${num_server} * ${local_rank_size}))  # total number of training processes, equal to the total number of NPUs used
+
+export HCCL_CONNECT_TIMEOUT=1200  # HCCL collective-communication link setup timeout, in the range [120, 7200]
+export PYTHONPATH=${so_path}:$PYTHONPATH  # Python installation path of the environment
+export LD_PRELOAD=/usr/lib64/libgomp.so.1  # path of the GNU OpenMP shared library; loading it via LD_PRELOAD like this should be avoided!
+export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
+# Collective-communication ranktable file; for the format, see the "preparing the resource configuration file"
+# section of the CANN documentation on the Ascend website.
+export JOB_ID=10086
+# total number of NPUs used by the training job
+export MXREC_LOG_LEVEL="DEBUG"  # framework log level
+export TF_CPP_MIN_LOG_LEVEL=3  # tensorflow log level; 3 corresponds to FATAL
+# Set the global and per-module log levels of application logs; see the "setting the log level" section
+# of the CANN documentation on the Ascend website.
+export ASCEND_GLOBAL_LOG_LEVEL=3  # 0:debug, 1:info, 2:warning, 3:error, 4:NULL
+export MXREC_MODE="ASC"
+export USE_MPI=1
+
+# configure the gradient strategy
+apply_gradient_strategy="sum_same_id_gradients_and_apply"
+#apply_gradient_strategy="direct_apply"
+export APPLY_GRADIENTS_STRATEGY=${apply_gradient_strategy}
+
+################# parameter configuration ######################
+export USE_DYNAMIC=1  # 0: static shape; 1: dynamic shape
+export USE_HOT=0  # 0: disable hot emb; 1: enable hot emb
+export USE_DYNAMIC_EXPANSION=0  # 0: disable dynamic expansion; 1: enable dynamic expansion
+export USE_MULTI_LOOKUP=1  # 0: one lookup per table; 1: multiple lookups per table
+export MULTI_LOOKUP_TIMES=2  # extra lookups per table: default 2, upper limit 127 (the table already has one lookup); only effective when USE_MULTI_LOOKUP=1
+export USE_MODIFY_GRAPH=0  # 0: feature spec mode; 1: automatic graph-modification mode
+export USE_TIMESTAMP=0  # 0: disable feature admission and eviction; 1: enable feature admission and eviction
+export USE_ONE_SHOT=0  # 0: MakeIterator; 1: OneShotIterator
+export UpdateEmb_V2=1  # 0: synchronous UpdateEmb; 1: asynchronous UpdateEmb_V2
+export USE_COMBINE_FAAE=0  # 0: separate history when faae; 1: combine history when faae
+################# performance tuning ####################
+export KEY_PROCESS_THREAD_NUM=6  # default 6, max 10
+export FAST_UNIQUE=0  # if use fast unique
+export MGMT_HBM_TASK_MODE=0  # if async h2d (get and send tensors)
+################################################
+
+# help message; no need to modify
+if [[ $1 == --help || $1 == -h ]];then
+    echo "Usage: ./run.sh [OPTION]... [IP]..."
+    echo " "
+    echo "parameter explanation:
+    [OPTION]       main.py
+    [IP]           IP address of the host
+    -h/--help      show help message
+    "
+    exit 1
+fi
+
+# use the ranktable solution
+function rankTableSolution() {
+    echo "The ranktable solution"
+    export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json"
+    export RANK_SIZE=$num_process
+    echo "RANK_TABLE_FILE=$RANK_TABLE_FILE"
+    if [ ! -f "$RANK_TABLE_FILE" ];then
+        echo "the rank table file does not exist. Please reference {hccl_json_${local_rank_size}p.json} to correctly config the rank table file"
+        exit 1
+    fi
+}
+
+if [ ! -n "$ip" ]; then
+    rankTableSolution
else
+    VALID_CHECK=$(echo $ip|awk -F. '$1<=255&&$2<=255&&$3<=255&&$4<=255{print "yes"}')
+    if echo $ip|grep -E "^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$">/dev/null; then
+        if [ "$VALID_CHECK" == "yes" ]; then
+            ################# enabled when using the ranktable-free solution ######################
+            echo "ip: $ip available."
+            echo "The ranktable solution is removed."
+            export CM_CHIEF_IP=$ip  # IP address of the chief node
+            export CM_CHIEF_PORT=6000  # listening port of the chief node
+            export CM_CHIEF_DEVICE=0  # device id of the chief node
+            export CM_WORKER_IP=$ip  # IP address of the current node
+            export CM_WORKER_SIZE=$num_process  # number of devices participating in the cluster training
+            echo "CM_CHIEF_IP=$CM_CHIEF_IP"
+            echo "CM_CHIEF_PORT=$CM_CHIEF_PORT"
+            echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE"
+            echo "CM_WORKER_IP=$CM_WORKER_IP"
+            echo "CM_WORKER_SIZE=$CM_WORKER_SIZE"
+            echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES"
+            #########################################################
+        else
+            echo "ip: $ip not available!"  # use the ranktable solution
+            rankTableSolution
+        fi
+    else
+        echo "ip: $ip not available!"  # use the ranktable solution
+        rankTableSolution
+    fi
fi
+
+echo "use horovod to start tasks"
+DATE=$(date +%Y-%m-%d-%H-%M-%S)
+horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \
+python3.7 ${py} 2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${USE_MODE}_${CACHE_MODE}_${DATE}.log"

diff --git a/examples/demo/little_demo/run_mode.py b/examples/demo/little_demo/run_mode.py
new file mode 100644
index 00000000..c43ef4a6
--- /dev/null
+++ b/examples/demo/little_demo/run_mode.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
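For the ranktable-free branch above, the comments imply that on a multi-node job every node exports the same CM_CHIEF_* values while CM_WORKER_IP is that node's own address. A hedged sketch of the same settings applied from Python before init() (addresses and sizes here are hypothetical; the shell script above remains the authoritative path):

import os

os.environ["CM_CHIEF_IP"] = "10.10.10.10"   # chief node address (assumed)
os.environ["CM_CHIEF_PORT"] = "6000"        # chief listening port
os.environ["CM_CHIEF_DEVICE"] = "0"         # device id on the chief
os.environ["CM_WORKER_IP"] = "10.10.10.11"  # this node's own address (assumed)
os.environ["CM_WORKER_SIZE"] = "16"         # devices participating across the cluster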
+# ==============================================================================
+
+import os
+
+import tensorflow as tf
+from config import sess_config
+
+from mx_rec.util.initialize import get_initializer, get_rank_id, get_rank_size, clear_channel
+from mx_rec.util.variable import get_dense_and_sparse_variable
+from mx_rec.util.tf_version_adapter import hccl_ops
+from mx_rec.constants.constants import BaseEnum
+from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
+from mx_rec.constants.constants import ApplyGradientsStrategy
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.log import logger
+
+
+class UseMode(BaseEnum):
+    TRAIN = "train"
+    PREDICT = "predict"
+
+
+class RunMode:
+
+    def __init__(
+            self, is_modify_graph: bool, is_faae: bool, table_list: list, optimizer_list: list, train_model,
+            eval_model, train_iterator, eval_iterator, max_train_steps: int, infer_steps: int, params: dict):
+        self.is_modify_graph = is_modify_graph
+        self.is_faae = is_faae
+        self.session = tf.compat.v1.Session(config=sess_config(dump_data=False))
+        self.train_model = train_model
+        self.train_iterator = train_iterator
+        self.eval_model = eval_model
+        self.eval_iterator = eval_iterator
+        self.rank_id = get_rank_id()
+        self.train_ops = []
+        self.table_list = table_list
+        self.optimizer_list = optimizer_list
+        self.epoch = 1
+        self.max_train_steps = max_train_steps
+        self.infer_steps = infer_steps
+        self.use_one_shot = params.get("use_one_shot")
+        self.train_batch = params.get("train_batch")
+        self.eval_batch = params.get("eval_batch")
+
+    def _infer(self):
+        if not self.use_one_shot:
+            initializer = self.eval_iterator.initializer if not self.is_modify_graph else get_initializer(False)
+            self.session.run(initializer)
+        else:
+            logger.debug(f"use one shot iterator and modify graph is `{self.is_modify_graph}`.")
+            clear_channel(is_train_channel=False)
+
+        for i in range(1, self.infer_steps + 1):
+            logger.info("############### infer at step %d ################", i)
+            try:
+                self.session.run(self.eval_model.loss_list)
+            except tf.errors.OutOfRangeError:
+                logger.info("Encountered the end of sequence for eval.")
+                break
+
+    def set_train_ops(self):
+        dense_variables, sparse_variables = get_dense_and_sparse_variable()
+
+        # multi task training
+        for loss, (dense_optimizer, sparse_optimizer) in zip(self.train_model.loss_list, self.optimizer_list):
+            # do dense optimization
+            grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables)
+            avg_grads = []
+            for grad, var in grads:
+                if get_rank_size() > 1:
+                    grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None
+                if grad is not None:
+                    avg_grads.append((grad, var))
+            # apply gradients: update variables
+            self.train_ops.append(dense_optimizer.apply_gradients(avg_grads))
+
+            if bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))):
+                from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
+                    ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+
+                train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
+
+                if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) ==
+                        ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY):
+                    train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
+                else:
+                    train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+
+                # do sparse optimization by addr
+                local_grads = tf.gradients(loss, train_emb_list)  # local_embedding
+                grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)]
+                self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+            else:
+                # do sparse optimization
+                sparse_grads = tf.gradients(loss, sparse_variables)
+                grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
+                self.train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+
+    def train(self, train_interval: int, saving_interval: int):
+        self.set_train_ops()
+        # In train mode, graph modification needs to be performed after computing gradients
+        if self.is_modify_graph:
+            logger.info("start to modify graph")
+            modify_graph_and_start_emb_cache(dump_graph=True)
+
+        if not self.use_one_shot:
+            initializer = self.train_iterator.initializer if not self.is_modify_graph else get_initializer(True)
+            self.session.run(initializer)
+        else:
+            logger.debug(f"use one shot iterator and modify graph is `{self.is_modify_graph}`.")
+
+        self.session.run(tf.compat.v1.global_variables_initializer())
+        self.saver = tf.compat.v1.train.Saver()
+        for i in range(1, self.max_train_steps + 1):
+            logger.info("################ training at step %d ################", i)
+            try:
+                self.session.run([self.train_ops, self.train_model.loss_list])
+            except tf.errors.OutOfRangeError:
+                logger.info("Encountered the end of sequence for training.")
+                break
+            else:
+                for t in self.table_list:
+                    logger.info(f"training at step:{i}, table[{t.table_name}], table size:{t.size()}, "
+                                f"table capacity:{t.capacity()}")
+
+            if i % train_interval == 0:
+                self.evaluate()
+
+            if i % saving_interval == 0:
+                self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
+
+            if self.is_faae and i == train_interval // 2:
+                logger.info("############### set_threshold at step:%d ################", i)
+                self.change_threshold()
+
+        # save the last step without duplication
+        if i % saving_interval != 0:
+            self.saver.save(self.session, f"./saved-model/model-{self.rank_id}", global_step=i)
+
+        logger.info("################ training end ################")
+
+    def evaluate(self):
+        logger.info("############### start evaluate, epoch:%d ################", self.epoch)
+        self._infer()
+        logger.info("############### evaluate end, epoch:%d ################", self.epoch)
+        self.epoch += 1
+
+    def predict(self):
+        logger.info("############### start predict ################")
+        import glob
+        import re
+
+        model_file = glob.glob(f"./saved-model/sparse-model-{self.rank_id}-*")
+        if len(model_file) == 0:
+            raise ValueError("model file does not exist")
+
+        # get the latest model
+        pattern = f".*sparse-model-{self.rank_id}-([0-9]+).*"
+        latest_step = -1
+        for file_path in model_file:
+            match = re.match(pattern, file_path)
+            if match and match.groups():
+                step = int(match.groups()[0])
+
+                if step > latest_step:
+                    latest_step = step
+        if latest_step == -1:
+            raise RuntimeError("latest model not found")
+
+        self.saver = tf.compat.v1.train.Saver()
+        self.saver.restore(self.session, f"./saved-model/model-{self.rank_id}-{latest_step}")
+        self._infer()
+        logger.info("############### predict end ################")
+
+    def change_threshold(self):
+        thres_tensor = tf.constant(60, dtype=tf.int32)
+        set_threshold_op = ConfigInitializer.get_instance().host_pipeline_ops. \
\ + set_threshold(thres_tensor, emb_name=self.table_list[0].table_name, + ids_name=self.table_list[0].table_name + "_lookup") + self.session.run([set_threshold_op]) diff --git a/examples/demo/little_demo_estimator/config.py b/examples/demo/little_demo_estimator/config.py new file mode 100644 index 00000000..beb18942 --- /dev/null +++ b/examples/demo/little_demo_estimator/config.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math + +import tensorflow as tf +from mx_rec.util.initialize import get_rank_size + + +class Config: + def __init__(self, mode="simple", task_name="default"): + self.task_name = task_name + if mode == "simple": + self.generate_simple_config() + else: + self.generate_large_scale_config() + + def generate_simple_config(self): + self.batch_numbers = 8192 + self.batch_size = 4096 + + self.key_type = tf.int64 + self.label_type = tf.float32 + self.value_type = tf.float32 + + self.item_range = 10000 + self.user_range = 200000 + self.category_range = 5000 + self.item_feat_cnt = 16 + self.user_feat_cnt = 8 + self.category_feat_cnt = 3 + self.access_threshold = 100 + self.eviction_threshold = 60 + + rank_size = get_rank_size() + coefficient = 1.1 + if rank_size != 0: + self.item_send_cnt = min(int(self.batch_size * self.item_feat_cnt * coefficient), + math.ceil(self.item_range / rank_size)) + self.item_vocab_size = max(self.item_send_cnt * rank_size * rank_size, self.item_range) + self.user_send_cnt = min(int(self.batch_size * self.user_feat_cnt * coefficient), + math.ceil(self.user_range / rank_size)) + self.user_vocab_size = max(self.user_send_cnt * rank_size * rank_size, self.user_range) + self.category_send_cnt = min(int(self.batch_size * self.category_feat_cnt * coefficient), + math.ceil(self.category_range / rank_size)) + else: + raise ZeroDivisionError("rank size must be an integer which is greater value zero.") + + self.user_hashtable_dim = 32 + self.user_hashtable_threshold = 1 + self.item_hashtable_dim = 8 + self.item_hashtable_threshold = 1 + + self.learning_rate = 0.01 + + def generate_large_scale_config(self): + self.lookup_count = 40 + self.tensor_name_list = ["sparse_tensor_%d" % i for i in range(self.lookup_count)] + self.hashtable_name_list = ["hashtable_%d" % i for i in range(self.lookup_count)] + self.batch_size = 9600 + + self.key_type = tf.int64 + self.label_type = tf.float32 + self.value_type = tf.float32 + + self.vocabulary_size = 500000 + self.feat_cnt = 1 + + rank_size = get_rank_size() + coefficient = 1.1 + if rank_size != 0: + self.send_cnt = min(int(self.batch_size * self.feat_cnt * coefficient), + math.ceil(self.vocabulary_size / rank_size)) + else: + raise ZeroDivisionError("rank size must be an integer which is greater value zero.") + + self.hashtable_dim = 8 + self.learning_rate = 0.01 diff --git 
a/examples/demo/little_demo_estimator/dataset.py b/examples/demo/little_demo_estimator/dataset.py new file mode 100644 index 00000000..c052eca1 --- /dev/null +++ b/examples/demo/little_demo_estimator/dataset.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf + +from random_data_generator import get_data_generator, get_large_scale_data_generator +from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops + + +def generate_dataset(cfg, use_timestamp=False, batch_number=100): + dataset = tf.compat.v1.data.Dataset.from_generator( + generator=get_data_generator(cfg, batch_number=batch_number), + output_types={"item_ids": cfg.key_type, + "user_ids": cfg.key_type, + "category_ids": cfg.key_type, + "label_0": cfg.label_type, + "label_1": cfg.label_type}, + output_shapes={"item_ids": tf.TensorShape([cfg.batch_size, cfg.item_feat_cnt]), + "user_ids": tf.TensorShape([cfg.batch_size, cfg.user_feat_cnt]), + "category_ids": tf.TensorShape([cfg.batch_size, cfg.category_feat_cnt]), + "label_0": tf.TensorShape([cfg.batch_size]), + "label_1": tf.TensorShape([cfg.batch_size])}) + if use_timestamp: + dataset = dataset.map(add_timestamp_func) + + rank_size = get_rank_size() + rank_id = get_rank_id() + if rank_size > 1: + dataset = dataset.shard(rank_size, rank_id) + + return dataset + + +def add_timestamp_func(batch): + host_pipeline_ops = get_host_pipeline_ops() + timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64)) + batch["timestamp"] = timestamp + return batch + + +def generate_large_scale_data(cfg): + key_type_list = [cfg.key_type for _ in range(cfg.lookup_count)] + output_type_dict = dict(zip(cfg.tensor_name_list, key_type_list)) + output_type_dict["label_0"] = cfg.label_type + output_type_dict["label_1"] = cfg.label_type + + tensor_shape_list = [tf.TensorShape([cfg.batch_size]) for _ in range(cfg.lookup_count)] + output_shape_dict = dict(zip(cfg.tensor_name_list, tensor_shape_list)) + output_shape_dict["label_0"] = tf.TensorShape([cfg.batch_size]) + output_shape_dict["label_1"] = tf.TensorShape([cfg.batch_size]) + + dataset = tf.data.Dataset.from_generator(generator=get_large_scale_data_generator(cfg), + output_types=output_type_dict, + output_shapes=output_shape_dict) + rank_size = get_rank_size() + rank_id = get_rank_id() + if rank_size > 1: + dataset = dataset.shard(rank_size, rank_id) + + iterator = dataset.make_initializable_iterator() + batch = iterator.get_next() + return batch, iterator diff --git a/examples/demo/little_demo_estimator/main.py b/examples/demo/little_demo_estimator/main.py new file mode 100644 index 00000000..1a8852fb --- /dev/null +++ b/examples/demo/little_demo_estimator/main.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. 
Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import os + +import tensorflow as tf +from mx_rec.util.initialize import init, get_rank_id, terminate_config_initializer +from mx_rec.core.asc.helper import FeatureSpec +from mx_rec.graph.modifier import GraphModifierHook +from mx_rec.graph.acg_push_ops import ACGPushOpsToDatasetHook +from mx_rec.core.feature_process import EvictHook +from mx_rec.util.log import logger + +from tf_adapter import NPURunConfig, NPUEstimator, npu_hooks_append, DumpConfig +from nn_reader import input_fn +from nn_model_input import get_model_fn +from config import Config +from utils import FeatureSpecIns + +tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) + + +def main(params, cfg): + mg_session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False) + run_config = NPURunConfig( + model_dir=params.model_dir, + save_summary_steps=1000, # tf.summary运行周期 + save_checkpoints_steps=params.save_checkpoints_steps, + keep_checkpoint_max=5, + session_config=mg_session_config, + log_step_count_steps=1000, # tf.logging运行周期 + precision_mode='allow_mix_precision', + enable_data_pre_proc=True, + iterations_per_loop=1, + op_precision_mode='./op_precision.ini', # high performance + op_compiler_cache_mode="enable", + op_compiler_cache_dir="./op_cache", + HCCL_algorithm="level0:pairwise;level1:pairwise" + ) + + # access_threshold unit counts; eviction_threshold unit seconds + access_and_evict = None + + if not params.enable_push_ops_test: + hooks_list = [GraphModifierHook(modify_graph=params.modify_graph)] + else: + hooks_list = [ACGPushOpsToDatasetHook(dump_graph=True), GraphModifierHook(modify_graph=params.modify_graph)] + + if params.use_timestamp: + config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold) + access_and_evict = dict(user_table=config_for_user_table, item_table=config_for_item_table) + + evict_hook = EvictHook(evict_enable=True, evict_time_interval=10) + hooks_list.append(evict_hook) + create_fs_params = dict(cfg=cfg, use_timestamp=params.use_timestamp, + use_multi_lookup=use_multi_lookup, multi_lookup_times=MULTI_LOOKUP_TIMES) + est = NPUEstimator( + model_fn=get_model_fn(create_fs_params, cfg, access_and_evict), + params=params, + model_dir=params.model_dir, + config=run_config + ) + + if params.run_mode == 'train': + est.train(input_fn=lambda: input_fn(params, create_fs_params, cfg), max_steps=params.max_steps, + hooks=npu_hooks_append(hooks_list)) + + elif params.run_mode == 'train_and_evaluate': + train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(params, create_fs_params, cfg, + use_one_shot=args.use_one_shot), + max_steps=params.max_steps, hooks=npu_hooks_append(hooks_list)) + # 
在开启evict时,eval时不支持淘汰,所以无需加入evict hook + + if not params.enable_push_ops_test: + eval_hook_list = [GraphModifierHook(modify_graph=params.modify_graph)] + else: + eval_hook_list = [ACGPushOpsToDatasetHook(dump_graph=True), + GraphModifierHook(modify_graph=params.modify_graph)] + + eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(params, create_fs_params, cfg, is_eval=True, + use_one_shot=args.use_one_shot), + steps=params.eval_steps, hooks=npu_hooks_append(eval_hook_list), + throttle_secs=0) + tf.estimator.train_and_evaluate(est, train_spec=train_spec, eval_spec=eval_spec) + + elif params.run_mode == 'predict': + results = est.predict(input_fn=lambda: input_fn(params, create_fs_params, cfg), + hooks=npu_hooks_append(hooks_list=hooks_list), yield_single_examples=False) + output_pred1 = [] + output_pred2 = [] + labels = [] + + for res in results: + output_pred1.append(res['task_1'][0]) + output_pred2.append(res['task_2'][0]) + labels.append(res['label'][0]) + + terminate_config_initializer() + logger.info("Demo done!") + + +def create_feature_spec_list(use_timestamp=False): + access_threshold = cfg.access_threshold if use_timestamp else None + eviction_threshold = cfg.eviction_threshold if use_timestamp else None + feature_spec_list = [FeatureSpec("user_ids", table_name="user_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold), + FeatureSpec("item_ids", table_name="item_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)] + if use_multi_lookup: + # add `MULTI_LOOKUP_TIMES` times + for _ in range(MULTI_LOOKUP_TIMES): + feature_spec_list.append(FeatureSpec("user_ids", table_name="user_table", + access_threshold=access_threshold, + eviction_threshold=eviction_threshold, + faae_coefficient=1)) + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--run_mode', type=str, default='train_and_evaluate') # 运行模式,在run.sh中进行配置 + parser.add_argument('--model_ckpt_dir', type=str, default='') + parser.add_argument('--learning_rate', type=float, default=0.0008) + parser.add_argument('--use_timestamp', type=bool, default=False) # 是否开启特征准入与淘汰 + parser.add_argument('--modify_graph', type=bool, default=False) # 是否开启自动改图 + parser.add_argument('--use_multi_lookup', type=bool, default=True) # 是否一表多查 + parser.add_argument('--multi_lookup_times', type=int, default=2) # 一表多查次数 + parser.add_argument('--max_steps', type=int, default=200) # train的最大步数 + parser.add_argument('--train_steps', type=int, default=100) # 训练train_steps步后进行eval + parser.add_argument('--eval_steps', type=int, default=10) # 每次eval的步数 + # 每隔step保存一次模型, 若在train_and_evaluate模式, 还会进行eval, 注: 若设为None, NPURunConfig内部会设默认值100 + parser.add_argument('--save_checkpoints_steps', type=int, default=200) + parser.add_argument('--use_one_shot', type=bool, default=False) # 是否使用one shot iterator + + args, unknowns = parser.parse_known_args() + # get init configuration + try: + use_mpi = bool(int(os.getenv("USE_MPI", 1))) + use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) + use_hot = bool(int(os.getenv("USE_HOT", 0))) + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 1))) + MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) + USE_TIMESTAMP = bool(int(os.getenv("USE_TIMESTAMP", 0))) + 
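+        # All of these switches follow the same "0/1 env string -> bool" pattern;
+        # a small helper (hypothetical, not part of this demo) would keep the
+        # parsing and the error reporting in one place, e.g.
+        #   def env_flag(name, default="0"):
+        #       value = os.getenv(name, default)
+        #       if value not in ("0", "1"):
+        #           raise ValueError(f"{name} must be 0 or 1, got {value!r}")
+        #       return value == "1"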
args.use_one_shot = bool(int(os.getenv("USE_ONE_SHOT", 0))) + args.enable_push_ops_test = bool(int(os.getenv("ENABLE_PUSH_OPS_TEST", 0))) + except ValueError as err: + raise ValueError(f"please correctly config USE_MPI or USE_DYNAMIC or USE_HOT or USE_DYNAMIC_EXPANSION or " + f"USE_MULTI_LOOKUP or USE_MODIFY_GRAPH or USE_TIMESTAMP or USE_ONE_SHOT " + f"only 0 or 1 is supported.") from err + + try: + MULTI_LOOKUP_TIMES = int(os.getenv("MULTI_LOOKUP_TIMES", 2)) + except ValueError as err: + raise ValueError(f"please correctly config MULTI_LOOKUP_TIMES only int is supported.") from err + + if args.run_mode == 'train': + args.train_steps = -1 + args.eval_steps = -1 + elif args.run_mode == 'predict': + args.eval_steps = -1 + elif args.run_mode == 'train_and_evaluate': + args.save_checkpoints_steps = args.train_steps + + # set init + init(use_mpi=use_mpi, + train_steps=args.train_steps, + eval_steps=args.eval_steps, + use_dynamic=use_dynamic, + use_hot=use_hot, + use_dynamic_expansion=use_dynamic_expansion) + + args.model_dir = f"{args.model_ckpt_dir}_rank{get_rank_id()}" + args.modify_graph = MODIFY_GRAPH_FLAG + args.use_timestamp = USE_TIMESTAMP + args.use_multi_lookup = use_multi_lookup + args.multi_lookup_times = MULTI_LOOKUP_TIMES + cfg = Config() + # multi lookup config, batch size: 32 * 128 = 4096 + if use_multi_lookup and MULTI_LOOKUP_TIMES > 2: + cfg.batch_size = 32 + # init FeatureSpecIns + FeatureSpecIns.set_instance() + main(args, cfg) diff --git a/examples/demo/little_demo_estimator/nn_model_build.py b/examples/demo/little_demo_estimator/nn_model_build.py new file mode 100644 index 00000000..e715f930 --- /dev/null +++ b/examples/demo/little_demo_estimator/nn_model_build.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
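The `NPURunConfig` block in `main()` above is the part most adopters need to adapt. A minimal sketch of the same wiring, keeping only options this demo actually relies on; `my_model_fn`, the directory, and the step count are placeholders, not values from the original patch:

```python
# Minimal NPU estimator wiring, mirroring the options used in main() above.
from tf_adapter import NPURunConfig, NPUEstimator

run_config = NPURunConfig(
    model_dir="./ckpt_rank0",             # one checkpoint directory per rank
    save_checkpoints_steps=200,
    iterations_per_loop=1,                # keep at 1 so hooks fire every step
    precision_mode="allow_mix_precision",
)
estimator = NPUEstimator(model_fn=my_model_fn, config=run_config)
```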
+# ============================================================================== + +import tensorflow as tf +from tensorflow import Tensor +from mx_rec.util.tf_version_adapter import npu_ops +from mx_rec.core.embedding import create_table, sparse_lookup +from mx_rec.constants.constants import ASCEND_TIMESTAMP + +from nn_optim import get_dense_and_sparse_optimizer +from utils import FeatureSpecIns + + +class LittleModel: + def __init__(self, params, cfg, mode, features, create_fs_params=None, access_and_evict_config_dict=None): + self.layer_dims = [1024, 512, 256, 128] + self.act_func = 'relu' + self.keep_prob = 0.8 + self._lambda = 4.91e-7 + self.emb_dim = None + self.loss_list = [] + self.predict_list = [] + self.all_layer_dims = None + self.h_w, self.h_b = [], [] + self.h_w_head_0, self.h_w_head_1, self.h_b_head_0, self.h_b_head_1 = None, None, None, None + + self.is_train = mode == tf.estimator.ModeKeys.TRAIN + self.cfg = cfg + self.params = params + self.features = features + self.create_fs_params = create_fs_params + self.access_and_evict_config_dict = access_and_evict_config_dict + + @staticmethod + def activate(act_func, input_x): + if act_func == 'tanh': + return tf.tanh(input_x) + elif act_func == 'relu': + return tf.nn.relu(input_x) + else: + return tf.sigmoid(input_x) + + def inference(self, label_0, label_1): + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + embedding_list = self._get_embedding_list() + embedding = tf.concat(embedding_list, axis=1) + self.emb_dim = embedding.shape.as_list()[-1] + self.all_layer_dims = [self.emb_dim] + self.layer_dims + [1] + + with tf.compat.v1.variable_scope("mlp", reuse=tf.compat.v1.AUTO_REUSE): + for i in range(len(self.all_layer_dims) - 2): + self.h_w.append(tf.compat.v1.get_variable('h%d_w' % (i + 1), shape=self.all_layer_dims[i: i + 2], + initializer=tf.random_uniform_initializer(-0.01, 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, + "deep", "mlp_wts"])) + self.h_b.append( + tf.compat.v1.get_variable('h%d_b' % (i + 1), shape=[self.all_layer_dims[i + 1]], + initializer=tf.compat.v1.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, "deep", + "mlp_bias"])) + i += 1 + self.h_w_head_0 = tf.compat.v1.get_variable('h_w_head_0', shape=self.all_layer_dims[i: i + 2], + initializer=tf.compat.v1.random_uniform_initializer(-0.01, + 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, + "deep", "mlp_wts"]) + self.h_b_head_0 = tf.compat.v1.get_variable('h_b_head_0', shape=[self.all_layer_dims[i + 1]], + initializer=tf.compat.v1.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, + "deep", "mlp_bias"]) + self.h_w_head_1 = tf.compat.v1.get_variable('h_w_head_1', shape=self.all_layer_dims[i: i + 2], + initializer=tf.compat.v1.random_uniform_initializer(-0.01, + 0.01), + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, + "deep", "mlp_wts"]) + self.h_b_head_1 = tf.compat.v1.get_variable('h_b_head_1', shape=[self.all_layer_dims[i + 1]], + initializer=tf.compat.v1.zeros_initializer, + dtype=tf.float32, + collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, + "deep", "mlp_bias"]) + + logit_list = self._forward(embedding, self.act_func, self.keep_prob, training=self.is_train) + + for logit, label in zip(logit_list, (label_0, label_1)): + train_preds = tf.sigmoid(logit) + basic_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label) + 
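+            # Each head is trained with its own sigmoid cross-entropy against its
+            # own label; the per-task losses are averaged into one scalar below.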
self.predict_list.append(train_preds) + self.loss_list.append(basic_loss) + + loss = tf.reduce_mean(self.loss_list, keepdims=False) + return loss, self.predict_list + + def _forward(self, embedding, act_func, keep_prob, training): + hidden_output = tf.reshape(embedding, [-1, self.emb_dim]) # *512 + for i, h_w_var in enumerate(self.h_w): + if training: + hidden_output = tf.matmul(npu_ops.dropout(self.activate(act_func, hidden_output), + keep_prob=keep_prob), h_w_var) + else: + hidden_output = tf.matmul(self.activate(act_func, hidden_output), h_w_var) + hidden_output = hidden_output + self.h_b[i] + + def output_head(hidden_output, h_w, h_b): + if training: + hidden_output_branch = tf.matmul(npu_ops.dropout(self.activate(act_func, hidden_output), + keep_prob=keep_prob), h_w) + else: + hidden_output_branch = tf.matmul(self.activate(act_func, hidden_output), h_w) + logit = hidden_output_branch + h_b + logit = tf.reshape(logit, [-1, ]) + + return logit + + logit_0 = output_head(hidden_output, self.h_w_head_0, self.h_b_head_0) + logit_1 = output_head(hidden_output, self.h_w_head_1, self.h_b_head_1) + logit_list = [logit_0, logit_1] + + return logit_list + + def _get_embedding_list(self): + optimizer_list = [get_dense_and_sparse_optimizer(self.cfg)] + sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list] + user_hashtable = create_table(key_dtype=tf.int64, + dim=tf.TensorShape([self.cfg.user_hashtable_dim]), + name='user_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=self.cfg.user_vocab_size * 10, + host_vocabulary_size=self.cfg.user_vocab_size * 0, + optimizer_list=sparse_optimizer_list) + item_hashtable = create_table(key_dtype=tf.int64, + dim=tf.TensorShape([self.cfg.item_hashtable_dim]), + name='item_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=self.cfg.item_vocab_size * 10, + host_vocabulary_size=self.cfg.item_vocab_size * 0, + optimizer_list=sparse_optimizer_list) + + if self.params.modify_graph: + if not self.params.enable_push_ops_test: + input_list = [[self.features["user_ids"], self.features["item_ids"]], + [user_hashtable, item_hashtable], + [self.cfg.user_send_cnt, self.cfg.item_send_cnt], + [True, True]] + else: + const_ids = _make_ids_with_const_ops(self.features["user_ids"]) + str_ids = _make_ids_with_str_ops(self.features["item_ids"]) + input_list = [[const_ids, str_ids], + [user_hashtable, item_hashtable], + [self.cfg.user_send_cnt, self.cfg.item_send_cnt], + [True, True]] + + if self.params.use_multi_lookup: + # add `MULTI_LOOKUP_TIMES` times + input_list[0].extend([self.features["user_ids"]] * self.params.multi_lookup_times) + input_list[1].extend([user_hashtable] * self.params.multi_lookup_times) + input_list[2].extend([self.cfg.user_send_cnt] * self.params.multi_lookup_times) + input_list[3].extend([False] * self.params.multi_lookup_times) + if self.params.use_timestamp: + tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, self.features["timestamp"]) + else: + if self.is_train: + feature_spec_list = FeatureSpecIns.get_instance().get_train_feature_spec_list() + else: + feature_spec_list = FeatureSpecIns.get_instance().get_eval_feature_spec_list() + input_list = [feature_spec_list, + [user_hashtable, item_hashtable], + [self.cfg.user_send_cnt, self.cfg.item_send_cnt], + [True, True]] + if self.params.use_multi_lookup: + # add `MULTI_LOOKUP_TIMES` times + input_list[1].extend([user_hashtable] * self.params.multi_lookup_times) + 
input_list[2].extend([self.cfg.user_send_cnt] * self.params.multi_lookup_times) + input_list[3].extend([False] * self.params.multi_lookup_times) + + embedding_list = [] + feature_list, hash_table_list, send_count_list, is_grad_list = input_list + for feature, hash_table, send_count, is_grad in zip(feature_list, hash_table_list, send_count_list, is_grad_list): + access_and_evict_config = None + if isinstance(self.access_and_evict_config_dict, dict): + access_and_evict_config = self.access_and_evict_config_dict.get(hash_table.table_name) + embedding = sparse_lookup(hash_table, feature, send_count, dim=None, is_train=self.is_train, is_grad=is_grad, + name=hash_table.table_name + "_lookup", modify_graph=self.params.modify_graph, + access_and_evict_config=access_and_evict_config, batch=self.features) + + reduced_embedding = tf.reduce_sum(embedding, axis=1, keepdims=False) + embedding_list.append(reduced_embedding) + + return embedding_list + + +def _make_ids_with_const_ops(input: Tensor) -> Tensor: + const_ids = tf.constant(1, shape=input.shape, dtype=input.dtype) + const_ids = tf.compat.v1.add(const_ids, 1) + const_ids = tf.compat.v1.subtract(const_ids, 1) + + return const_ids + +def _make_ids_with_str_ops(input: Tensor) -> Tensor: + str_ids = tf.compat.v1.strings.as_string(input) + str_ids = tf.compat.v1.strings.to_number(str_ids) + + return str_ids diff --git a/examples/demo/little_demo_estimator/nn_model_input.py b/examples/demo/little_demo_estimator/nn_model_input.py new file mode 100644 index 00000000..2ce70d41 --- /dev/null +++ b/examples/demo/little_demo_estimator/nn_model_input.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
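The embedding workflow in `_get_embedding_list` above reduces to two library calls. A stripped-down sketch with a single table and a single lookup, using only arguments that appear in this demo; `sparse_optimizer`, `user_ids`, and `send_count` are assumed to come from the surrounding demo code, and the sizes are illustrative placeholders:

```python
import tensorflow as tf
from mx_rec.core.embedding import create_table, sparse_lookup

# One hash table plus one lookup, then sum-pool the per-feature embeddings.
table = create_table(key_dtype=tf.int64,
                     dim=tf.TensorShape([32]),
                     name="user_table",
                     emb_initializer=tf.compat.v1.truncated_normal_initializer(),
                     device_vocabulary_size=2_000_000,   # placeholder size
                     host_vocabulary_size=0,
                     optimizer_list=[sparse_optimizer])
emb = sparse_lookup(table, user_ids, send_count, dim=None,
                    is_train=True, is_grad=True,
                    name="user_table_lookup", modify_graph=True)
pooled = tf.reduce_sum(emb, axis=1)   # [batch, feat_cnt, dim] -> [batch, dim]
```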
+# ============================================================================== + +import tensorflow as tf +from mx_rec.constants.constants import ASCEND_TIMESTAMP + +from nn_model_build import LittleModel +from nn_optim import get_train_op +from mx_rec.util.log import logger + + +def get_model_fn(create_fs_params, cfg, access_and_evict_config_dict=None): + def model_fn(features, labels, mode, params): + if params.modify_graph: + if params.use_timestamp: + model = LittleModel(params, cfg, mode, features, create_fs_params, + access_and_evict_config_dict=access_and_evict_config_dict) + tf.add_to_collection(ASCEND_TIMESTAMP, features["timestamp"]) + else: + model = LittleModel(params, cfg, mode, features, create_fs_params) + else: + model = LittleModel(params, cfg, mode, features, create_fs_params) + + loss, prediction = model.inference(features["label_0"], features["label_1"]) + + loss_dict = {} + if mode == tf.estimator.ModeKeys.TRAIN: + logger.info(f"use estimator train mode") + loss_dict['loss'] = [['train_loss', loss]] + return tf.estimator.EstimatorSpec(mode=mode, + loss=loss, + train_op=get_train_op(params, loss_dict.get('loss'))) + + if mode == tf.estimator.ModeKeys.EVAL: + logger.info("use estimator eval mode") + return tf.estimator.EstimatorSpec(mode=mode, + loss=loss) + + if mode == tf.estimator.ModeKeys.PREDICT: + logger.info("use estimator predict mode") + loss_dict['task_1'] = prediction[0] + + loss_dict['task_2'] = prediction[1] + if params.run_mode != 'export_pb': + loss_dict['label'] = features["label_0"] + + export_outputs = { + 'predictor': tf.estimator.export.PredictOutput(loss_dict) + } + return tf.estimator.EstimatorSpec(mode=mode, + predictions=loss_dict, + export_outputs=export_outputs) + + return model_fn diff --git a/examples/demo/little_demo_estimator/nn_optim.py b/examples/demo/little_demo_estimator/nn_optim.py new file mode 100644 index 00000000..1bf1ea3c --- /dev/null +++ b/examples/demo/little_demo_estimator/nn_optim.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
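The optimizer code that follows, like `set_train_ops` earlier, allreduces dense gradients with a `"sum"` reduction, so the applied update scales with the number of ranks even though the variable is named `avg_grads`. Dividing by the rank count yields a true average; whether to sum or to average is a modelling choice. A hedged variant of the same loop, assuming `grads` comes from `dense_optimizer.compute_gradients(...)` as in the surrounding code:

```python
from mx_rec.util.tf_version_adapter import hccl_ops
from mx_rec.util.initialize import get_rank_size

rank_size = get_rank_size()
avg_grads = []
for grad, var in grads:
    if grad is not None and rank_size > 1:
        # allreduce returns the sum across ranks; divide to get the mean
        grad = hccl_ops.allreduce(grad, "sum") / rank_size
    if grad is not None:
        avg_grads.append((grad, var))
```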
+# ============================================================================== + + +import os + +import tensorflow as tf +from mx_rec.util.tf_version_adapter import hccl_ops +from mx_rec.util.initialize import get_rank_size, get_use_dynamic_expansion +from mx_rec.util.variable import get_dense_and_sparse_variable +from mx_rec.optimizers.gradient_descent import create_hash_optimizer +from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr +from mx_rec.util.log import logger + + +def get_dense_and_sparse_optimizer(cfg): + dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate) + if get_use_dynamic_expansion(): + sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate) + logger.info("optimizer create_hash_optimizer_by_addr") + else: + sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate) + logger.info("optimizer create_hash_optimizer") + + return dense_optimizer, sparse_optimizer + + +def get_train_op_list(losses, learning_rate): + train_ops_list = [] + update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) + name = None + + dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) + use_dynamic_expansion = get_use_dynamic_expansion() + if use_dynamic_expansion: + sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=learning_rate) + else: + sparse_optimizer = create_hash_optimizer(learning_rate=learning_rate) + + dense_variables, sparse_variables = get_dense_and_sparse_variable() + trainable_variables = [dense_variables] + + for i in range(len(losses)): + name = losses[i][0] + loss = losses[i][1] + with tf.control_dependencies(update_ops): + # do dense grad + grads = dense_optimizer.compute_gradients(loss, var_list=trainable_variables) + dense_grads = grads[:len(dense_variables)] + avg_grads = [] + for grad, var in dense_grads: + if get_rank_size() > 1: + grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None + if grad is not None: + avg_grads.append((grad, var)) + # apply gradients: update variables + train_ops_list.append(dense_optimizer.apply_gradients(avg_grads, name="dense_optimizer")) + + # do sparse optimization + if use_dynamic_expansion: + from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \ + ApplyGradientsStrategy, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS + train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB) + + if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) == + ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY): + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS) + else: + train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET) + + local_grads = tf.gradients(loss, train_emb_list) # local_embedding + grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)] + train_ops_list.append(sparse_optimizer.apply_gradients(grads_and_vars, name='hashtable_optimizer')) + else: + sparse_grads = tf.gradients(loss, sparse_variables) + grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)] + train_ops_list.append(sparse_optimizer.apply_gradients(grads_and_vars, name="sparse_optimizer")) + + global_step_op = tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1) + train_ops_list.append(global_step_op) + return train_ops_list, name + + +def get_train_op(params, losses): + train_ops = [] + 
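+    # get_train_op_list returns one op per task plus the global-step increment;
+    # grouping them with the loss tensors yields the single train_op that the
+    # estimator runs each step.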
op_list, name = get_train_op_list(losses, params.learning_rate) + train_ops.append([name + '_train', tf.group(*op_list)]) + ops = [loss[1] for loss in losses] + [train_op[1] for train_op in train_ops] + return tf.group(*ops) diff --git a/examples/demo/little_demo_estimator/nn_reader.py b/examples/demo/little_demo_estimator/nn_reader.py new file mode 100644 index 00000000..a13468eb --- /dev/null +++ b/examples/demo/little_demo_estimator/nn_reader.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from mx_rec.util.initialize import get_rank_size, clear_channel +from mx_rec.core.asc.helper import get_asc_insert_func + +from dataset import generate_dataset +from utils import FeatureSpecIns, create_feature_spec_list + + +def input_fn(params, create_fs_params, cfg, is_eval=False, use_one_shot=False): + dataset = generate_dataset(cfg, + use_timestamp=params.use_timestamp, + batch_number=params.max_steps * get_rank_size()) + + if not params.modify_graph: + feature_spec_list = create_feature_spec_list(create_fs_params.get("cfg"), + create_fs_params.get("use_timestamp"), + create_fs_params.get("use_multi_lookup"), + create_fs_params.get("multi_lookup_times")) + if is_eval: + FeatureSpecIns.get_instance().set_eval_feature_spec_list(feature_spec_list) + dataset = dataset.map(get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=False)) + clear_channel(is_train_channel=False) + else: + FeatureSpecIns.get_instance().set_train_feature_spec_list(feature_spec_list) + dataset = dataset.map(get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=True)) + + dataset = dataset.prefetch(100) + + if not use_one_shot: + return dataset + + iterator = dataset.make_one_shot_iterator() + batch = iterator.get_next() + return batch diff --git a/examples/demo/little_demo_estimator/op_precision.ini b/examples/demo/little_demo_estimator/op_precision.ini new file mode 100644 index 00000000..094319e7 --- /dev/null +++ b/examples/demo/little_demo_estimator/op_precision.ini @@ -0,0 +1,3 @@ +ScatterNdAdd = support_out_of_bound_index +GatherV2 = high_performance +UnsortedSegmentSum = high_performance \ No newline at end of file diff --git a/examples/demo/little_demo_estimator/random_data_generator.py b/examples/demo/little_demo_estimator/random_data_generator.py new file mode 100644 index 00000000..23acf22b --- /dev/null +++ b/examples/demo/little_demo_estimator/random_data_generator.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np + +from mx_rec.util.initialize import get_rank_id +from mx_rec.util.log import logger + + +def get_data_generator(config, batch_number): + rank_id = get_rank_id() + + def data_generator(): + i = 0 + while i < batch_number: + item_ids = np.random.randint(0, config.item_range, (config.batch_size, config.item_feat_cnt)) + user_ids = np.random.randint(0, config.user_range, (config.batch_size, config.user_feat_cnt)) + category_ids = np.random.randint(0, config.category_range, (config.batch_size, config.category_feat_cnt)) + label_0 = np.random.randint(0, 2, (config.batch_size,)) + label_1 = np.random.randint(0, 2, (config.batch_size,)) + + yield {"item_ids": item_ids, + "user_ids": user_ids, + "category_ids": category_ids, + "label_0": label_0, + "label_1": label_1} + i += 1 + + logger.debug(f"================ end of data generator for {config.task_name} task | rank id {rank_id} " + f"================") + + return data_generator + + +def get_large_scale_data_generator(config): + def data_generator(): + i = 0 + while True: + id_list = [np.random.randint(0, config.vocabulary_size, (config.batch_size,)) + for _ in range(config.lookup_count)] + + data_block = dict(zip(config.tensor_name_list, id_list)) + + label_0 = np.random.randint(0, 2, (config.batch_size,)) + label_1 = np.random.randint(0, 2, (config.batch_size,)) + data_block["label_0"] = label_0 + data_block["label_1"] = label_1 + + logger.debug(f"================ generate NO.{i} step ================") + yield data_block + i += 1 + + return data_generator diff --git a/examples/demo/little_demo_estimator/run.sh b/examples/demo/little_demo_estimator/run.sh new file mode 100644 index 00000000..f4ae3b03 --- /dev/null +++ b/examples/demo/little_demo_estimator/run.sh @@ -0,0 +1,176 @@ +#!/bin/bash +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
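The generators above draw from `np.random` without an explicit seed, so every run produces different batches and all ranks share one global stream; the `shard()` call already keeps the ranks' batches disjoint. A hypothetical tweak, not part of the demo, that makes runs reproducible per rank:

```python
import numpy as np
from mx_rec.util.initialize import get_rank_id

# Seed each worker's generator with its rank id so per-rank streams are
# deterministic and distinct. The base seed and shapes are placeholders.
rng = np.random.default_rng(seed=1234 + get_rank_id())
item_ids = rng.integers(0, 10000, size=(4096, 16))  # mirrors item_range / feat_cnt
```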
+# ============================================================================== + +kill -9 `ps -ef | grep python | grep -v grep | awk '{print $2}'` > /dev/null 2>&1 +rm -rf /root/ascend/log/* +rm -rf ./kernel* +rm -rf ./export_graph/* + +# 获取输入参数:py、ip +if [ $# -ge 1 ]; then + py=$1 + ip=$2 +else + echo "for example: bash run.sh main.py 10.10.10.10 or bash run.sh main.py" + exit 1 +fi + +# 检查输入的python文件是否合法 +if [[ $py =~ ^[a-z0-9_]+\.py$ ]]; then + echo "File $py is a valid Python file" +else + echo "File $py is not a Python file" + exit 1 +fi + +# 判断IP地址是否有效 +if [ -n "$ip" ]; then + if [[ $ip =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]]; then + # 将IP地址拆分成四个数字 + ip_array=(${ip//./ }) + # 判断每个数字是否在0-255之间 + valid=true + for i in "${ip_array[@]}"; do + if ((i < 0 || i > 255)); then + valid=false + break + fi + done + if $valid; then + echo "ip: $ip is valid" + else + echo "ip: $ip is not valid" + exit 1 + fi + else + echo "ip: $ip is not valid." + exit 1 + fi +fi + +cur_path=`pwd` +mx_rec_package_path="/usr/local/python3.7.5/lib/python3.7/site-packages/mx_rec" # please config +so_path=${mx_rec_package_path}/libasc +# GLOG_stderrthreshold -2:TRACE -1:DEBUG 0:INFO 1:WARN 2.ERROR, 默认为INFO +mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=0 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0' +interface="lo" +local_rank_size=8 # 每个节点使用的NPU卡数 +num_server=1 # 训练节点数 +num_process=$((${num_server} * ${local_rank_size})) # 训练总的进程数,等于使用的NPU卡的总数 + +export HCCL_CONNECT_TIMEOUT=1200 # HCCL集合通信 建链超时时间,取值范围[120,7200] +export PYTHONPATH=${so_path}:$PYTHONPATH # 环境python安装路径 +export LD_PRELOAD=/usr/lib64/libgomp.so.1 # GNU OpenMP动态库路径. 不应该使用LD_PRELOAD这种方式加载! +export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH +# 集合通信文件,格式请参考昇腾官网CANN文档,“准备资源配置文件”章节。 +export JOB_ID=10086 +# 训练任务使用的NPU卡数总数 +export MXREC_LOG_LEVEL="DEBUG" # 框架日志等级 +export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL +# 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档 +export ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL +export MXREC_MODE="ASC" +export USE_MPI=1 +export USE_MODE="train_and_evaluate" # 支持[train, predict, train_and_evaluate] + +if [ $USE_MODE = "train" ] || [ $USE_MODE = "train_and_evaluate" ];then + echo "train mode: saved-model will be deleted" + rm -rf ./_rank* +fi + +################# 配置梯度策略 ###################### +apply_gradient_strategy="sum_same_id_gradients_and_apply" +# apply_gradient_strategy="direct_apply" +export APPLY_GRADIENTS_STRATEGY=${apply_gradient_strategy} +################# 参数配置 ###################### +export USE_DYNAMIC=1 # 0:静态shape;1:动态shape +export USE_HOT=1 # 0:关闭hot emb;1: 开启hot emb +export USE_DYNAMIC_EXPANSION=0 # 0:关闭动态扩容;1: 开启动态扩容 +export USE_MULTI_LOOKUP=1 # 0:一表一查;1:一表多查 +export MULTI_LOOKUP_TIMES=2 # 一表多查次数:默认2,上限127(因为一表已经有一查);仅当export USE_MULTI_LOOKUP=1时生效 +export USE_MODIFY_GRAPH=1 # 0:feature spec模式;1:自动改图模式 +export USE_TIMESTAMP=0 # 0:关闭特征准入淘汰;1:开启特征准入淘汰 +export UpdateEmb_V2=0 # 0: UpdateEmb同步更新;1:UpdateEmb_V2异步更新 +export USE_ONE_SHOT=0 # 0:MakeIterator;1:OneShotIterator +################# 性能调优相关 #################### +export KEY_PROCESS_THREAD_NUM=6 #default 6, max 10 +export FAST_UNIQUE=0 #if use fast unique +export MGMT_HBM_TASK_MODE=0 #if async h2d (get and send tensors) +################## 测试配置项 ##################### +# NOTE: 仅在测试constant、string相关op作为稀疏表输入时启用,当前版本只支持TF1。 +export ENABLE_PUSH_OPS_TEST=0 + +# 帮助信息,不需要修改 +if [[ $1 == --help || $1 == -h ]];then + echo "Usage: ./run.sh 
[OPTION]... [IP]..." + echo " " + echo "parameter explain: + [OPTION] main.py + [IP] IP address of the host + -h/--help show help message + " + exit 1 +fi + +# 使用ranktable方案 +function rankTableSolution() { + echo "The ranktable solution" + export RANK_TABLE_FILE="${cur_path}/hccl_json_${local_rank_size}p.json" + export RANK_SIZE=$num_process + echo "RANK_TABLE_FILE=$RANK_TABLE_FILE" + if [ ! -f "$RANK_TABLE_FILE" ];then + echo "the rank table file does not exit. Please reference {hccl_json_${local_rank_size}p.json} to correctly config rank table file" + exit 1 + fi +} + +if [ ! -n "$ip" ]; then + rankTableSolution +else + VALID_CHECK=$(echo $ip|awk -F. '$1<=255&&$2<=255&&$3<=255&&$4<=255{print "yes"}') + if echo $ip|grep -E "^[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}$">/dev/null; then + if [ "$VALID_CHECK" == "yes" ]; then + #################使用去除ranktable方案时开启###################### + echo "ip: $ip available." + echo "The ranktable solution is removed." + export CM_CHIEF_IP=$ip # 主节点ip + export CM_CHIEF_PORT=6000 # 主节点监听端口 + export CM_CHIEF_DEVICE=0 # 主节点device id + export CM_WORKER_IP=$ip # 当前节点ip + export CM_WORKER_SIZE=$num_process # 参与集群训练的device数量 + echo "CM_CHIEF_IP=$CM_CHIEF_IP" + echo "CM_CHIEF_PORT=$CM_CHIEF_PORT" + echo "CM_CHIEF_DEVICE=$CM_CHIEF_DEVICE" + echo "CM_WORKER_IP=$CM_WORKER_IP" + echo "CM_WORKER_SIZE=$CM_WORKER_SIZE" + echo "ASCEND_VISIBLE_DEVICES=$ASCEND_VISIBLE_DEVICES" + ######################################################### + else + echo "ip: $ip not available!" # 使用ranktable方案 + rankTableSolution + fi + else + echo "ip: $ip not available!" # 使用ranktable方案 + rankTableSolution + fi +fi + +echo "use horovod to start tasks" +DATE=$(date +%Y-%m-%d-%H-%M-%S) +horovodrun --network-interface ${interface} -np ${num_process} --mpi-args "${mpi_args}" --mpi -H localhost:${local_rank_size} \ +python3.7 ${py} \ +--run_mode=$USE_MODE \ +2>&1 | tee "temp_${local_rank_size}p_${KEY_PROCESS_THREAD_NUM}t_${DATE}.log" diff --git a/examples/demo/little_demo_estimator/tf_adapter.py b/examples/demo/little_demo_estimator/tf_adapter.py new file mode 100644 index 00000000..f011d253 --- /dev/null +++ b/examples/demo/little_demo_estimator/tf_adapter.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf + +if tf.__version__.startswith("1"): + from npu_bridge.npu_init import NPURunConfig, NPUEstimator, npu_hooks_append, DumpConfig +else: + from npu_device.compat.v1.npu_init import NPURunConfig, NPUEstimator, npu_hooks_append, DumpConfig diff --git a/examples/demo/little_demo_estimator/utils.py b/examples/demo/little_demo_estimator/utils.py new file mode 100644 index 00000000..436dac01 --- /dev/null +++ b/examples/demo/little_demo_estimator/utils.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024. 
Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from mx_rec.core.asc.helper import FeatureSpec
+
+
+class FeatureSpecIns:
+    _single_instance = None
+    _train_feature_spec_list = None
+    _eval_feature_spec_list = None
+
+    @staticmethod
+    def set_instance():
+        if FeatureSpecIns._single_instance is not None:
+            raise RuntimeError("`FeatureSpecIns` single instance can only be set once.")
+        FeatureSpecIns._single_instance = FeatureSpecIns()
+
+    @staticmethod
+    def get_instance():
+        if FeatureSpecIns._single_instance is None:
+            raise RuntimeError("Please set `FeatureSpecIns` before getting it.")
+        return FeatureSpecIns._single_instance
+
+    @staticmethod
+    def set_train_feature_spec_list(fs_list):
+        FeatureSpecIns._train_feature_spec_list = fs_list
+
+    @staticmethod
+    def get_train_feature_spec_list():
+        if FeatureSpecIns._train_feature_spec_list is None:
+            raise RuntimeError("Please set `train_feature_spec_list` before getting it.")
+        return FeatureSpecIns._train_feature_spec_list
+
+    @staticmethod
+    def set_eval_feature_spec_list(fs_list):
+        FeatureSpecIns._eval_feature_spec_list = fs_list
+
+    @staticmethod
+    def get_eval_feature_spec_list():
+        if FeatureSpecIns._eval_feature_spec_list is None:
+            raise RuntimeError("Please set `eval_feature_spec_list` before getting it.")
+        return FeatureSpecIns._eval_feature_spec_list
+
+
+def create_feature_spec_list(cfg, use_timestamp=False, use_multi_lookup=False, multi_lookup_times=2):
+    access_threshold = cfg.access_threshold if use_timestamp else None
+    eviction_threshold = cfg.eviction_threshold if use_timestamp else None
+    feature_spec_list = [FeatureSpec("user_ids", table_name="user_table",
+                                     access_threshold=access_threshold, eviction_threshold=eviction_threshold),
+                         FeatureSpec("item_ids", table_name="item_table",
+                                     access_threshold=access_threshold, eviction_threshold=eviction_threshold)]
+    if use_multi_lookup:
+        for _ in range(multi_lookup_times):
+            feature_spec_list.append(FeatureSpec("user_ids", table_name="user_table", access_threshold=access_threshold,
+                                                 eviction_threshold=eviction_threshold, faae_coefficient=1))
+    if use_timestamp:
+        feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True))
+    return feature_spec_list
diff --git a/examples/dlrm/criteo_tb/README.md b/examples/dlrm/criteo_tb/README.md
new file mode 100644
index 00000000..c8825e5b
--- /dev/null
+++ b/examples/dlrm/criteo_tb/README.md
@@ -0,0 +1,19 @@
+# Preparing the raw criteo_TB data
+First download the 24 days of raw data from the
+[official site](https://labs.criteo.com/2013/12/download-terabyte-click-logs/) with the command below. Storing all files takes roughly 365 GB.
+
+`curl -O https://storage.googleapis.com/criteo-cail-datasets/day_{`seq -s "," 0 23`}.gz`
+
+Then decompress the 24 downloaded files; the extracted files occupy about 1035 GB.
+
+# Converting the raw dataset to tfrecord
+Run the conversion script:
+
+`python3.7 gen_ttf.py --train_data_dir train_dir --test_data_dir test_dir --tf_base_dir save_base_dir`
+
+Parameter description:
+
+- train_data_dir: path of the extracted training set, containing day_0, day_1, ..., day_22
+- test_data_dir: path of the extracted test set, containing day_23
+- tf_base_dir: output path for the tfrecord files; at least 633 GB of disk is required
+
diff --git a/examples/dlrm/criteo_tb/gen_ttf.py b/examples/dlrm/criteo_tb/gen_ttf.py
new file mode 100644
index 00000000..92fabb3d
--- /dev/null
+++ b/examples/dlrm/criteo_tb/gen_ttf.py
@@ -0,0 +1,402 @@
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import pickle
+import logging
+import argparse
+import sys
+import time
+from multiprocessing import Process
+from collections import OrderedDict
+from glob import glob
+
+import numpy as np
+from tqdm import tqdm
+
+import tensorflow as tf
+
+
+class Logger(object):
+    level_relations = {
+        'debug': logging.DEBUG,
+        'info': logging.INFO,
+        'warning': logging.WARNING,
+        'error': logging.ERROR,
+        'crit': logging.CRITICAL
+    }  # mapping from level name to logging level
+
+    def __init__(self, filename, level='info',
+                 fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
+        self.logger = logging.getLogger(filename)
+        format_str = logging.Formatter(fmt)  # log line format
+        self.logger.setLevel(self.level_relations.get(level))  # set the log level
+        sh = logging.StreamHandler()  # print to the screen
+        sh.setFormatter(format_str)  # format used on screen
+        th = logging.FileHandler(filename=filename)  # also write into a file
+        th.setFormatter(format_str)  # format used in the file
+        self.logger.addHandler(sh)  # attach both handlers to the logger
+        self.logger.addHandler(th)
+
+    def info(self, *args):
+        if len(args) == 1:
+            self.logger.info(*args)
+        else:
+            self.logger.info([*args])
+
+
+class CriteoStatsDict:
+    def __init__(self):
+        self.field_size = 39  # value_1-13; cat_1-26;
+        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
+        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
+
+        self.val_min_dict = {col: 0 for col in self.val_cols}
+        self.val_max_dict = {col: 0 for col in self.val_cols}
+        self.global_idx_range_dict = {col: 0 for col in self.cat_cols}  # cat_0: v_0, cat_1: v_1, ...
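+        # Each of the 26 categorical fields is first hashed into a shared
+        # 40M-bucket space (hash_bucket below); slot_size_array then caps each
+        # field's cardinality, and offset_size_list shifts every field into a
+        # disjoint global id range.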
+        self.hist_map = {col: set() for col in self.cat_cols}
+
+        self.hash_bucket = 40000000
+        self.dense_bias = 1
+        self.hash_bucket_offset = [self.hash_bucket * i for i in range(26)]  # voc_size=1040000000
+        self.slot_size_array = [39884407, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532952,
+                                2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979772, 25641295,
+                                39664985, 585935, 12972, 108, 36]
+        self.offset_size_list = np.cumsum([0] + self.slot_size_array[:-1])
+
+    def stats_cats(self, cat_list):
+        def map_cat_count(i, cat):
+            capped_value = int(cat, 16) % self.hash_bucket if cat else self.hash_bucket
+
+            key_col = self.cat_cols[i]
+            if capped_value not in self.hist_map[key_col]:
+                self.hist_map[key_col].add(capped_value)
+
+        for i, cat in enumerate(cat_list):
+            map_cat_count(i, cat)
+
+    def save_dict(self, output_path, hist_map, prefix=""):
+        with open(os.path.join(output_path, "{}hist_map.pkl".format(prefix)), "wb") as file_wrt:
+            pickle.dump(hist_map, file_wrt)
+
+    def load_dict(self, dict_path, prefix=""):
+        with open(os.path.join(dict_path, "{}hist_map.pkl".format(prefix)), "rb") as file_wrt:
+            self.hist_map = pickle.load(file_wrt)
+
+    def map_cat2id(self, denses, cats):
+        # dense features: missing values fall back to the bias, so the minimum
+        # shifted dense value is 1.0
+        dense_list = [int(d) + self.dense_bias if d else self.dense_bias for d in denses]
+
+        # cat features: handle missing values
+        cat_list = []
+
+        def map_cat_count(i, cat):
+            capped_value = int(cat, 16) % self.hash_bucket if cat else self.hash_bucket
+
+            key_col = self.cat_cols[i]
+            if capped_value in self.hist_map[key_col]:
+                cat_list.append(self.hist_map[key_col][capped_value])
+            else:
+                print(f"error: {key_col}, {cat}, {capped_value}")
+
+        for i, cat in enumerate(cats):
+            map_cat_count(i, cat)
+        # cat preprocess: mod and offset
+        cat_list = [0 if cat < 0 else cat % (self.slot_size_array[idx] + 1) for idx, cat in enumerate(cat_list)]
+        cat_list = [cat + offset for cat, offset in zip(cat_list, self.offset_size_list)]
+
+        return dense_list, cat_list
+
+
+def statsdata_multiprocess(process_num, process_id, data_file_path, output_path, criteo_stats):
+    start_time = time.time()
+    with open(data_file_path, encoding="utf-8") as file_in:
+        errorline_list = []
+        count = 0
+        for i, line in enumerate(file_in):
+            if i % process_num != process_id:
+                continue
+            count += 1
+            line = line.strip("\n")
+            items = line.split("\t")
+            if len(items) != 40:
+                errorline_list.append(count)
+                print("line: {}".format(line))
+                continue
+            if count % 1000000 == 0:
+                print("Handled {}w lines.".format(count // 10000))
+            cats = items[14:]
+            criteo_stats.stats_cats(cats)
+    criteo_stats.save_dict(output_path, criteo_stats.hist_map)
+    print('statsdata time cost: {:.2f}s'.format(time.time() - start_time))
+
+
+def get_unique_id_multiprocess(process_num, process_id, data_file_path, output_path, criteo_stats):
+    if os.path.exists(os.path.join(output_path, "unique_id.pkl")):
+        return
+    start_time = time.time()
+    cat_sets = [OrderedDict() for col in criteo_stats.cat_cols]
+    cat_global_id_nums = [0 for col in criteo_stats.cat_cols]
+    hash_bucket = criteo_stats.hash_bucket
+    line_num = 0
+    # first pass: count lines so the file can be split evenly across processes
+    with open(data_file_path, encoding="utf-8") as file_in:
+        for line in file_in:
+            line_num += 1
+    start_line = process_id * ((line_num + process_num) // process_num)
+    end_line = (process_id + 1) * ((line_num + process_num) // process_num)
+    with open(data_file_path, encoding="utf-8") as file_in:
+        errorline_list = []
+        count = 0
+        for i, line in enumerate(file_in):
+            if i < start_line or i >= end_line:
+                continue
+            count += 1
+            line = line.strip("\n")
+            items = line.split("\t")
+            if len(items) != 40:
+                errorline_list.append(count)
+                print("line: {}".format(line))
+                continue
+            if count % 10000 == 0:
+                print("Handled {}w lines.".format(count // 10000))
+                sys.stdout.flush()
+            cats = items[14:]
+            for k, cat in enumerate(cats):
+                capped_value = int(cat, 16) % hash_bucket if cat else hash_bucket
+                # membership must be checked against the per-field dict, not the
+                # outer list, otherwise every value is re-inserted with a new id
+                if capped_value not in cat_sets[k]:
+                    cat_sets[k][capped_value] = cat_global_id_nums[k]
+                    cat_global_id_nums[k] += 1
+    with open(os.path.join(output_path, "unique_id.pkl"), "wb") as file_wrt:
+        pickle.dump(cat_sets, file_wrt)
+    print('statsdata time cost: {:.2f}s'.format(time.time() - start_time))
+
+
+def merge_stats_count(stats_dir, criteo_stats):
+    if os.path.exists(f'{stats_dir}/hist_map.pkl'):
+        return
+    stats_sub_dirs = sorted(glob(f'{stats_dir}/*[0-9]'))
+    with open(f'{stats_sub_dirs[0]}/unique_id.pkl', 'rb') as f:
+        all_hist_map = pickle.load(f)
+
+    for i in tqdm(range(1, len(stats_sub_dirs))):
+        with open(f'{stats_sub_dirs[i]}/unique_id.pkl', 'rb') as f:
+            others_count = pickle.load(f)
+        for k, _ in enumerate(criteo_stats.cat_cols):
+            all_count_1, others_count_1 = all_hist_map[k], others_count[k]
+            all_count_1.update(others_count_1)
+            all_hist_map[k] = all_count_1
+    hist_map = {}
+    for i, col in enumerate(criteo_stats.cat_cols):
+        hist_map[col] = dict(zip(list(all_hist_map[i].keys()), range(len(all_hist_map[i]))))
+
+    criteo_stats.save_dict(stats_dir, hist_map)
+
+
+def mkdir_path(file_path):
+    if not os.path.exists(file_path):
+        os.makedirs(file_path)
+
+
+def make_example(label_list, dense_feat_list, sparse_feat_list):
+    dense_feature = np.array(dense_feat_list, dtype=np.float32).reshape(-1)
+    sparse_feature = np.array(sparse_feat_list, dtype=np.int64).reshape(-1)
+    label = np.array(label_list, dtype=np.int64).reshape(-1)
+    feature_dict = {"dense_feature": tf.train.Feature(float_list=tf.train.FloatList(value=dense_feature)),
+                    "sparse_feature": tf.train.Feature(int64_list=tf.train.Int64List(value=sparse_feature)),
+                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=label))
+                    }
+    example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
+
+    return example
+
+
+def convert_input2tfrd_multiprocess(process_num, process_id, in_file_path, output_path, criteo_stats,
+                                    line_per_sample=1024, part_rows=2000000, mode="train_"):
+    start_time = time.time()
+    print("----------" * 10 + "\n" * 2)
+
+    part_number = 0
+    file_name = output_path + "part_{:0>8d}.tfrecord"
+
+    file_writer = tf.python_io.TFRecordWriter(file_name.format(part_number))
+    sample_count = 0
+    part_count = 0
+    line_num = 0
+    # first pass: count lines so the file can be split evenly across processes
+    with open(in_file_path, encoding="utf-8") as file_in:
+        for line in tqdm(file_in):
+            line_num += 1
+    print(f'line_num: {line_num}')
+    start_line = process_id * ((line_num + process_num) // process_num)
+    end_line = (process_id + 1) * ((line_num + process_num) // process_num)
+    dense_res_list = []
+    cat_res_list = []
+    label_res_list = []
+    # second pass: convert this process's slice into tfrecord parts
+    with open(in_file_path, encoding="utf-8") as file_in:
+        total_count = 0
+        for i, line in enumerate(file_in):
+            if i < start_line or i >= end_line:
+                continue
+
+            total_count += 1
+            if total_count % 10000 == 0:
+                print("Handled {}w lines.".format(total_count // 10000))
+                sys.stdout.flush()
+
+
+def convert_input2tfrd_multiprocess(process_num, process_id, in_file_path, output_path, criteo_stats,
+                                    line_per_sample=1024, part_rows=2000000, mode="train_"):
+    start_time = time.time()
+    print("----------" * 10 + "\n" * 2)
+
+    part_number = 0
+    file_name = output_path + "part_{:0>8d}.tfrecord"
+
+    file_writer = tf.python_io.TFRecordWriter(file_name.format(part_number))
+    sample_count = 0
+    part_count = 0
+    line_num = 0
+    # first pass: count the input lines so this process can compute its shard
+    with open(in_file_path, encoding="utf-8") as file_in:
+        for _ in tqdm(file_in):
+            line_num += 1
+    print(f'line_num: {line_num}')
+    start_line = process_id * ((line_num + process_num) // process_num)
+    end_line = (process_id + 1) * ((line_num + process_num) // process_num)
+    dense_res_list = []
+    cat_res_list = []
+    label_res_list = []
+    # second pass: convert this process's line range into TFRecord parts
+    with open(in_file_path, encoding="utf-8") as file_in:
+        total_count = 0
+        part_number = 0
+        for i, line in enumerate(file_in):
+            if i < start_line or i >= end_line:
+                continue
+
+            total_count += 1
+            if total_count % 10000 == 0:
+                print("Handled {}0k lines.".format(total_count // 10000))
+                sys.stdout.flush()
+            line = line.strip("\n")
+            items = line.split("\t")
+            if len(items) != 40:
+                continue
+            label = int(items[0])
+            values = items[1:14]
+            cats = items[14:]
+            assert len(values) == 13, "values.size: {}".format(len(values))
+            assert len(cats) == 26, "cats.size: {}".format(len(cats))
+            val_list, cat_list = criteo_stats.map_cat2id(values, cats)
+            dense_res_list.append(val_list)
+            cat_res_list.append(cat_list)
+            label_res_list.append(label)
+            sample_count += 1
+            if sample_count % line_per_sample == 0 and sample_count > 0:
+                ex = make_example(label_res_list, dense_res_list, cat_res_list)
+                serialized = ex.SerializeToString()
+                file_writer.write(serialized)
+                part_count += line_per_sample
+                sample_count = 0
+                dense_res_list = []
+                cat_res_list = []
+                label_res_list = []
+                if part_count >= part_rows:
+                    part_number += 1
+                    file_writer.close()
+                    file_writer = tf.python_io.TFRecordWriter(file_name.format(part_number))
+                    part_count = 0
+
+    # trailing samples short of a full line_per_sample block are dropped by design
+    file_writer.close()
+    if sample_count > 0:
+        part_number += 1
+
+    print('convert_input2tfrd time cost: {:.2f}s'.format(time.time() - start_time))
+    return part_number
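+
+
+# Illustrative sketch of the sharding arithmetic used above: each process gets
+# the half-open line range [start, end) with a ceiling-style chunk size, so the
+# last shard may be partly empty but no line is skipped or read twice.
+def shard_range(line_num, process_num, process_id):
+    chunk = (line_num + process_num) // process_num
+    return process_id * chunk, (process_id + 1) * chunk
+
+# e.g. 10 lines over 3 processes -> (0, 4), (4, 8), (8, 12): disjoint ranges
+# whose union covers all 10 lines.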
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Get and Process datasets')
+    parser.add_argument('--train_data_dir', default="train",
+                        help='day_0, ..., day_22 file path')
+    parser.add_argument('--test_data_dir', default="test",
+                        help='day_23 file path')
+    parser.add_argument('--tf_base_dir', default="tf_base_dir",
+                        help='base path for the saved tfrecords; more than 720G of disk is recommended')
+    parser.add_argument('--stats_process_num', default=72, type=int,
+                        help='process num of stats')
+    parser.add_argument('--train_process_num', default=69, type=int,
+                        help='process num of train tfrecord generation')
+    parser.add_argument('--test_process_num', default=24, type=int,
+                        help='process num of test tfrecord generation')
+    args, _ = parser.parse_known_args()
+    train_data_dir = args.train_data_dir
+    test_data_dir = args.test_data_dir
+    criteo_stats = CriteoStatsDict()
+
+    base_path = "./"
+    train_data_files = sorted(glob(f'{train_data_dir}/*'))
+    test_data_files = sorted(glob(f'{test_data_dir}/*'))
+    data_files = train_data_files + test_data_files
+    print("train data files: ", train_data_files)
+    print("test data files: ", test_data_files)
+    print("data files: ", data_files)
+    process_num = args.stats_process_num
+    processes = []
+    for process_id in range(process_num):
+        sub_process_num = process_num // len(data_files)
+        data_file = data_files[process_id // sub_process_num]
+        stats_output_path = base_path + f"/stats_dict_mp/{process_id:02}/"
+        mkdir_path(stats_output_path)
+        p = Process(target=get_unique_id_multiprocess, args=(
+            sub_process_num, process_id % sub_process_num, data_file, stats_output_path, criteo_stats))
+        processes.append(p)
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    merge_stats_count(base_path + "/stats_dict_mp/", criteo_stats)
+
+    print("----------" * 10)
+    stats_output_path = base_path + "/stats_dict_mp/"
+    criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
+
+    spe_num = 1024
+    tf_base_dir = args.tf_base_dir
+
+    # generate train tfrecords
+    dataset_mode = "train"
+    save_tfrecord_path = os.path.join(tf_base_dir, "tfrecord", dataset_mode)
+    mkdir_path(save_tfrecord_path)
+    processes = []
+    process_num = args.train_process_num
+    assert process_num % len(train_data_files) == 0, \
+        f'process_num {process_num} must be divisible by the number of train data files {len(train_data_files)}'
+
+    for process_id in range(process_num):
+        sub_process_num = process_num // len(train_data_files)
+        data_file = train_data_files[process_id // sub_process_num]
+        output_path = f'{save_tfrecord_path}/{process_id:04}_'
+        p = Process(target=convert_input2tfrd_multiprocess,
+                    args=(sub_process_num, process_id % sub_process_num, data_file, output_path,
+                          criteo_stats, spe_num, 5000000))
+        processes.append(p)
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+
+    # generate test tfrecords
+    dataset_mode = "test"
+    save_tfrecord_path = os.path.join(tf_base_dir, "tfrecord", dataset_mode)
+    mkdir_path(save_tfrecord_path)
+    processes = []
+    process_num = args.test_process_num
+    assert process_num % len(test_data_files) == 0, \
+        f'process_num {process_num} must be divisible by the number of test data files {len(test_data_files)}'
+
+    for process_id in range(process_num):
+        sub_process_num = process_num // len(test_data_files)
+        data_file = test_data_files[process_id // sub_process_num]
+        output_path = f'{save_tfrecord_path}/{process_id:04}_'
+        p = Process(target=convert_input2tfrd_multiprocess,
+                    args=(sub_process_num, process_id % sub_process_num, data_file, output_path,
+                          criteo_stats, spe_num, 5000000))
+        processes.append(p)
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
\ No newline at end of file
diff --git a/examples/dlrm/model/config.py b/examples/dlrm/model/config.py
new file mode 100644
index 00000000..452b2a7f
--- /dev/null
+++ b/examples/dlrm/model/config.py
@@ -0,0 +1,234 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+
+import tensorflow as tf
+from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
+from npu_bridge.estimator.npu.npu_config import NPURunConfig
+
+
+class LearningRateScheduler:
+    """
+    LR Scheduler combining Polynomial Decay with Warmup at the beginning.
+    TF-based cond operations are necessary for performance in graph mode.
+ """ + + def __init__(self, base_lr_dense, base_lr_sparse, warmup_steps, decay_start_step, decay_steps): + self.warmup_steps = tf.constant(warmup_steps, dtype=tf.int32) + self.decay_start_step = tf.constant(decay_start_step, dtype=tf.int32) + self.decay_steps = tf.constant(decay_steps) + self.decay_end_step = decay_start_step + decay_steps # 65041 + self.poly_power = 2.0 + self.base_lr_dense = base_lr_dense + self.base_lr_sparse = base_lr_sparse + + def calc(self, global_step): + # used for the warmup stage + warmup_step = tf.cast(1 / self.warmup_steps, tf.float32) + lr_factor_warmup = 1 - tf.cast(self.warmup_steps - global_step, tf.float32) * warmup_step + # lr_factor_warmup = tf.cast(global_step, tf.float32) / tf.cast(self.warmup_steps, tf.float32) #hx + lr_factor_warmup = tf.cast(lr_factor_warmup, tf.float32) + # used for the constant stage + lr_factor_constant = tf.cast(1.0, tf.float32) + + # used for the decay stage + lr_factor_decay = (self.decay_end_step - global_step) / self.decay_steps + lr_factor_decay = tf.math.pow(lr_factor_decay, self.poly_power) + lr_factor_decay = tf.cast(lr_factor_decay, tf.float32) + sparse_after_decay = tf.cast(1 / self.decay_steps, tf.float32) + + lr_factor_decay_sparse = tf.cond( + global_step < self.decay_end_step, + lambda: lr_factor_decay, + lambda: sparse_after_decay, + # lambda: 0.000 #hx + ) + + lr_factor_decay_dense = tf.cond( + global_step < self.decay_end_step, + lambda: lr_factor_decay, + lambda: sparse_after_decay, + ) + + poly_schedule_sparse = tf.cond( + global_step < self.decay_start_step, + lambda: lr_factor_constant, + lambda: lr_factor_decay_sparse, + ) + + poly_schedule_dense = tf.cond( + global_step < self.decay_start_step, + lambda: lr_factor_constant, + lambda: lr_factor_decay_dense, + ) + + lr_factor_sparse = tf.cond( + global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_sparse + ) + + lr_factor_dense = tf.cond( + global_step < self.warmup_steps, lambda: lr_factor_warmup, lambda: poly_schedule_dense + ) + + lr_sparse = self.base_lr_sparse * lr_factor_sparse + lr_dense = self.base_lr_dense * lr_factor_dense + return lr_dense, lr_sparse + + +class Config: + def __init__(self, ): + self.rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None + tmp = os.getenv("RANK_SIZE") + if tmp is None: + raise ValueError("please export RANK_SIZE") + self.rank_size = int(tmp) + + self.data_path = os.getenv("DLRM_CRITEO_DATA_PATH") + self.train_file_pattern = "train" + self.test_file_pattern = "test" + + self.batch_size = 8192 + self.line_per_sample = 1024 + self.train_epoch = 3 + self.test_epoch = 1 + self.perform_shuffle = False + + self.key_type = tf.int64 + self.label_type = tf.float32 + self.value_type = tf.int64 + + self.feat_cnt = 26 + self.__set_emb_table_size() + + self.field_num = 26 + self.send_count = 46000 // self.rank_size + + self.emb_dim = 128 + self.hashtable_threshold = 1 + # self.learning_rate = 0.01 + + self.USE_PIPELINE_TEST = False + + # 动态学习率 + GLOBAL_BATCH_SIZE = 8192 * 8 + LR_SCHEDULE_STEPS = [ + int(2750 * 55296 / GLOBAL_BATCH_SIZE), + int(49315 * 55296 / GLOBAL_BATCH_SIZE), + int(27772 * 55296 / GLOBAL_BATCH_SIZE), + ] + self.global_step = tf.Variable(0, trainable=False) + _lr_scheduler = LearningRateScheduler( + 28.443, + 33.71193, + LR_SCHEDULE_STEPS[0], + LR_SCHEDULE_STEPS[1], + LR_SCHEDULE_STEPS[2], + ) + self.learning_rate = _lr_scheduler.calc(self.global_step) + + def __set_emb_table_size(self): + self.cache_mode = os.getenv("CACHE_MODE") + if self.cache_mode is None: + 
raise ValueError("please export CACHE_MODE environment variable, support:[HBM, DDR, SSD]") + + if self.cache_mode == "HBM": + self.dev_vocab_size = 24_000_000 * self.rank_size + self.host_vocab_size = 0 + elif self.cache_mode == "DDR": + self.dev_vocab_size = 500_000 * self.rank_size + self.host_vocab_size = 24_000_000 * self.rank_size + elif self.cache_mode == "SSD": + self.dev_vocab_size = 100_000 * self.rank_size + self.host_vocab_size = 2_000_000 * self.rank_size + self.ssd_vocab_size = 24_000_000 * self.rank_size + else: + raise ValueError(f"get CACHE_MODE:{self.cache_mode}, expect in [HBM, DDR, SSD]") + + def get_emb_table_cfg(self) -> dict: + if self.cache_mode == "HBM": + return {"device_vocabulary_size": self.dev_vocab_size} + elif self.cache_mode == "DDR": + return {"device_vocabulary_size": self.dev_vocab_size, + "host_vocabulary_size": self.host_vocab_size} + elif self.cache_mode == "SSD": + return {"device_vocabulary_size": self.dev_vocab_size, + "host_vocabulary_size": self.host_vocab_size, + "ssd_vocabulary_size": self.ssd_vocab_size, + "ssd_data_path": ["ssd_data"]} + else: + raise RuntimeError(f"get CACHE_MODE:{self.cache_mode}, check Config.__set_emb_table_size implementation") + + +def sess_config(dump_data=False, dump_path="./dump_output", dump_steps="0|1|2"): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["mix_compile_mode"].b = False + custom_op.parameter_map["use_off_line"].b = True + custom_op.parameter_map["min_group_size"].b = 1 + custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:fullmesh;level1:fullmesh") + # custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes("level0:pairwise;level1:pairwise") + custom_op.parameter_map["enable_data_pre_proc"].b = True + custom_op.parameter_map["iterations_per_loop"].i = 10 + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + custom_op.parameter_map["hcom_parallel"].b = False + custom_op.parameter_map["op_precision_mode"].s = tf.compat.as_bytes("op_impl_mode.ini") + custom_op.parameter_map["op_execute_timeout"].i = 2000 + custom_op.parameter_map["variable_memory_max_size"].s = tf.compat.as_bytes( + str(13 * 1024 * 1024 * 1024)) # total 31 need 13; + custom_op.parameter_map["graph_memory_max_size"].s = tf.compat.as_bytes(str(18 * 1024 * 1024 * 1024)) # need 25 + custom_op.parameter_map["stream_max_parallel_num"].s = tf.compat.as_bytes("DNN_VM_AICPU:3,AIcoreEngine:3") + + if dump_data: + custom_op.parameter_map["enable_dump"].b = True + custom_op.parameter_map["dump_path"].s = tf.compat.as_bytes(dump_path) + custom_op.parameter_map["dump_step"].s = tf.compat.as_bytes(dump_steps) + custom_op.parameter_map["dump_mode"].s = tf.compat.as_bytes("all") + + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF + + return session_config + + +def get_npu_run_config(): + session_config = tf.ConfigProto(allow_soft_placement=False, + log_device_placement=False) + + session_config.gpu_options.allow_growth = True + custom_op = session_config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + session_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + 
+    session_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
+
+    run_config = NPURunConfig(
+        save_summary_steps=1000,
+        save_checkpoints_steps=100,
+        keep_checkpoint_max=5,
+        session_config=session_config,
+        log_step_count_steps=20,
+        precision_mode='allow_mix_precision',
+        enable_data_pre_proc=True,
+        iterations_per_loop=1,
+        jit_compile=False,
+        op_compiler_cache_mode="enable",
+        HCCL_algorithm="level0:fullmesh;level1:fullmesh"
+        # alternative: HCCL_algorithm="level0:pairwise;level1:pairwise"
+    )
+    return run_config
diff --git a/examples/dlrm/model/delay_loss_scale.py b/examples/dlrm/model/delay_loss_scale.py
new file mode 100644
index 00000000..0cb50688
--- /dev/null
+++ b/examples/dlrm/model/delay_loss_scale.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+from tensorflow.python.training import optimizer
+
+
+class DenseLossScaleOptimizer:
+    def __init__(self, opt, loss_scale):
+        if not isinstance(opt, optimizer.Optimizer):
+            raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt))
+        self._optimizer = opt
+        self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
+        self._optimizer._learning_rate = self._optimizer._learning_rate / self._loss_scale
+
+    def compute_gradients(self, loss, var_list=None):
+        return self._optimizer.compute_gradients(loss * self._loss_scale, var_list=var_list)
+
+    def apply_gradients(self, avg_grads):
+        return self._optimizer.apply_gradients(avg_grads)
+
+
+class SparseLossScaleOptimizer:
+    def __init__(self, opt, loss_scale):
+        if not isinstance(opt, optimizer.Optimizer):
+            raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % type(opt))
+        self._optimizer = opt
+        self._loss_scale = tf.convert_to_tensor(loss_scale, tf.float32)
+        self._optimizer._learning_rate = self._optimizer._learning_rate / self._loss_scale
+
+    def compute_gradients(self, loss, var_list=None):
+        return tf.gradients(loss * self._loss_scale, var_list)
+
+    def apply_gradients(self, grads_and_vars):
+        return self._optimizer.apply_gradients(grads_and_vars)
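+
+
+# Illustrative usage (hypothetical learning rate and variable list): the
+# wrapper multiplies the loss by loss_scale before differentiation and
+# divides the learning rate by the same factor, so the applied update is
+# numerically unchanged while small gradients survive mixed-precision
+# accumulation.
+def example_dense_loss_scale(loss, var_list):
+    base_opt = tf.train.GradientDescentOptimizer(learning_rate=24.0)
+    opt = DenseLossScaleOptimizer(base_opt, loss_scale=1024)
+    grads_and_vars = opt.compute_gradients(loss, var_list=var_list)
+    return opt.apply_gradients(grads_and_vars)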
diff --git a/examples/dlrm/model/gradient_descent_w.py b/examples/dlrm/model/gradient_descent_w.py
new file mode 100644
index 00000000..b66ec67d
--- /dev/null
+++ b/examples/dlrm/model/gradient_descent_w.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+import tensorflow as tf
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import gradient_descent
+from mx_rec.optimizers.base import CustomizedOptimizer
+from mx_rec.util.log import logger
+
+
+def create_hash_optimizer(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescent"):
+    return CustomizedGradientDescentWithWeightDecay(learning_rate=learning_rate,
+                                                    weight_decay=weight_decay,
+                                                    use_locking=use_locking,
+                                                    name=name)
+
+
+class CustomizedGradientDescentWithWeightDecay(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer):
+    name_counter = defaultdict(int)
+
+    def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescent"):
+        self.optimizer_type = "gradient_descent_with_weight_decay"
+        self.weight_decay = weight_decay
+        super(CustomizedGradientDescentWithWeightDecay, self)._get_name(name=name)
+        super(CustomizedGradientDescentWithWeightDecay, self).__init__(
+            learning_rate=learning_rate, use_locking=use_locking, name=self.unique_name
+        )
+
+    def initialize_slots(self, var, table_instance):
+        logger.info("no slot for gradient descent")
+        return []
+
+    def insert_slot(self, slot, named_slots_key, slot_name):
+        logger.info("no slot for gradient descent")
+        return dict()
+
+    def get_slot_init_values(self):
+        logger.info("no slot for gradient descent")
+        return []
+
+    def _apply_sparse_duplicate_indices(self, grad, var):
+        logger.debug(">>>> Enter _apply_sparse_duplicate_indices")
+        nd_indices = tf.expand_dims(grad.indices, 1)
+        logger.info(f"weight_decay={self.weight_decay}")
+        if self.weight_decay is None:
+            nd_value = grad.values * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
+        else:
+            decayed = grad.values + math_ops.cast(self.weight_decay, var.dtype.base_dtype) * tf.gather(var, grad.indices)
+            nd_value = decayed * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
+        var_update_op = tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking)
+        return var_update_op
+
+    def _apply_dense(self, grad, var):
+        logger.debug(">>>> Enter _apply_dense")
+        raise NotImplementedError("You are using a wrong type of variable.")
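+
+
+# Illustrative sketch of the decayed update implemented above, in plain
+# Python: w <- w - lr * (g + weight_decay * w). With w=1.0, g=0.5, lr=0.1
+# and weight_decay=0.0001 this yields 1.0 - 0.1 * 0.5001 = 0.94999.
+def decayed_sgd_step(w, g, lr, weight_decay=0.0001):
+    return w - lr * (g + weight_decay * w)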
diff --git a/examples/dlrm/model/main_mxrec.py b/examples/dlrm/model/main_mxrec.py
new file mode 100644
index 00000000..24a1a1b2
--- /dev/null
+++ b/examples/dlrm/model/main_mxrec.py
@@ -0,0 +1,448 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import time
+import warnings
+import random
+from glob import glob
+
+from sklearn.metrics import roc_auc_score
+import numpy as np
+import tensorflow as tf
+
+from mx_rec.core.asc.helper import FeatureSpec, get_asc_insert_func
+from mx_rec.core.asc.manager import start_asc_pipeline
+from mx_rec.core.embedding import create_table, sparse_lookup
+from mx_rec.core.feature_process import EvictHook
+from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook
+from mx_rec.constants.constants import ASCEND_TIMESTAMP, ApplyGradientsStrategy
+from mx_rec.util.initialize import get_rank_size, init, clear_channel, get_rank_id, set_if_load, \
+    terminate_config_initializer, get_host_pipeline_ops, get_initializer, get_target_batch
+import mx_rec.util as mxrec_util
+from mx_rec.util.variable import get_dense_and_sparse_variable
+from mx_rec.util.log import logger
+from npu_bridge.npu_init import *
+
+from model import MyModel
+from config import sess_config, Config
+from optimizer import get_dense_and_sparse_optimizer
+
+dense_hashtable_seed = 128
+sparse_hashtable_seed = 128
+shuffle_seed = 128
+random.seed(shuffle_seed)
+
+
+def add_timestamp_func(batch):
+    host_pipeline_ops = get_host_pipeline_ops()
+    timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label'], dtype=tf.int64))
+    batch["timestamp"] = timestamp
+    return batch
+
+
+def make_batch_and_iterator(cfg, feature_spec_list, is_training, dump_graph, use_faae=False):
+    if cfg.USE_PIPELINE_TEST:
+        num_parallel = 1
+    else:
+        num_parallel = 8
+
+    def extract_fn(data_record):
+        features = {
+            # extract features using the keys set during creation
+            'label': tf.compat.v1.FixedLenFeature(shape=(cfg.line_per_sample,), dtype=tf.int64),
+            'sparse_feature': tf.compat.v1.FixedLenFeature(shape=(26 * cfg.line_per_sample,), dtype=tf.int64),
+            'dense_feature': tf.compat.v1.FixedLenFeature(shape=(13 * cfg.line_per_sample,), dtype=tf.float32),
+        }
+        sample = tf.compat.v1.parse_single_example(data_record, features)
+        return sample
+
+    def reshape_fn(batch):
+        batch['label'] = tf.reshape(batch['label'], [-1, 1])
+        batch['dense_feature'] = tf.reshape(batch['dense_feature'], [-1, 13])
+        batch['dense_feature'] = tf.math.log(batch['dense_feature'] + 3.0)
+        batch['sparse_feature'] = tf.reshape(batch['sparse_feature'], [-1, 26])
+        return batch
+
+    if is_training:
+        files_list = glob(os.path.join(cfg.data_path, cfg.train_file_pattern) + '/*.tfrecord')
+    else:
+        files_list = glob(os.path.join(cfg.data_path, cfg.test_file_pattern) + '/*.tfrecord')
+    dataset = tf.data.TFRecordDataset(files_list, num_parallel_reads=num_parallel)
+    batch_size = cfg.batch_size // cfg.line_per_sample
+
+    dataset = dataset.shard(cfg.rank_size, cfg.rank_id)
+    if is_training:
+        dataset = dataset.shuffle(batch_size * 1000, seed=shuffle_seed)
+        dataset = dataset.repeat(cfg.train_epoch)
+    else:
+        dataset = dataset.repeat(cfg.test_epoch)
+    dataset = dataset.map(extract_fn, num_parallel_calls=num_parallel).batch(batch_size,
+                                                                             drop_remainder=True)
+    dataset = dataset.map(reshape_fn, num_parallel_calls=num_parallel)
+    if use_faae:
+        dataset = dataset.map(add_timestamp_func)
+
+    if not MODIFY_GRAPH_FLAG:
+        insert_fn = get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=is_training, dump_graph=dump_graph)
+        dataset = dataset.map(insert_fn)
+
+    dataset = dataset.prefetch(100)
+
+    iterator = dataset.make_initializable_iterator()
+    batch = iterator.get_next()
+    return batch, iterator
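+
+
+# Illustrative sketch (not used by the training flow below): how a
+# batch/iterator pair returned above would be consumed in a plain TF1
+# session, outside the MODIFY_GRAPH and ASC pipeline paths.
+def drain_one_batch(batch, iterator):
+    with tf.compat.v1.Session() as sess:
+        sess.run(iterator.initializer)
+        values = sess.run(batch)
+        # after reshape_fn, values["label"] has shape [batch_size, 1]
+        return values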
+
+
+def model_forward(feature_list, hash_table_list, batch, is_train, modify_graph):
+    embedding_list = []
+    logger.debug(f"In model_forward function, is_train: {is_train}, feature_list: {len(feature_list)}, "
+                 f"hash_table_list: {len(hash_table_list)}")
+    for feature, hash_table in zip(feature_list, hash_table_list):
+        if MODIFY_GRAPH_FLAG:
+            feature = batch["sparse_feature"]
+        embedding = sparse_lookup(hash_table, feature, cfg.send_count, dim=None, is_train=is_train,
+                                  name="user_embedding_lookup", modify_graph=modify_graph, batch=batch,
+                                  access_and_evict_config=None)
+        embedding_list.append(embedding)
+
+    if len(embedding_list) == 1:
+        emb = embedding_list[0]
+    elif len(embedding_list) > 1:
+        emb = tf.reduce_sum(embedding_list, axis=0, keepdims=False)
+    else:
+        raise ValueError("The length of embedding_list must be greater than or equal to 1.")
+    my_model = MyModel()
+    model_output = my_model.build_model(embedding=emb,
+                                        dense_feature=batch["dense_feature"],
+                                        label=batch["label"],
+                                        is_training=is_train,
+                                        seed=dense_hashtable_seed)
+    return model_output
+
+
+def evaluate():
+    print("reading test dataset")
+    if not MODIFY_GRAPH_FLAG:
+        eval_label = eval_model.get("label")
+        sess.run([eval_iterator.initializer, clear_channel(False)])
+    else:
+        # In sess.run mode, reusing the label tensor from the original batch makes
+        # get_next time out, so fetch the label from the new dataset's batch instead.
+        eval_label = get_target_batch(False).get("label")
+        sess.run([get_initializer(False), clear_channel(False)])
+    log_loss_list = []
+    pred_list = []
+    label_list = []
+    eval_current_steps = 0
+    finished = False
+    print("eval begin")
+
+    while not finished:
+        try:
+            eval_current_steps += 1
+            eval_start = time.time()
+            eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_label])
+            eval_cost = time.time() - eval_start
+            qps = (1 / eval_cost) * rank_size * cfg.batch_size
+            log_loss_list += list(eval_loss.reshape(-1))
+            pred_list += list(pred.reshape(-1))
+            label_list += list(label.reshape(-1))
+            print(f"eval current_steps: {eval_current_steps}, qps: {qps}")
+            if eval_current_steps == eval_steps:
+                finished = True
+        except tf.errors.OutOfRangeError:
+            finished = True
+    auc = roc_auc_score(label_list, pred_list)
+    mean_log_loss = np.mean(log_loss_list)
+    return auc, mean_log_loss
+
+
+def evaluate_fix(step):
+    print("reading test dataset in evaluate_fix")
+    if not MODIFY_GRAPH_FLAG:
+        sess.run([eval_iterator.initializer, clear_channel(False)])
+    else:
+        sess.run([get_initializer(False), clear_channel(False)])
+    log_loss_list = []
+    pred_list = []
+    label_list = []
+    eval_current_steps = 0
+    finished = False
+    print("eval begin")
+    while not finished:
+        try:
+            eval_current_steps += 1
+            eval_loss, pred, label = sess.run([eval_model["loss"], eval_model["pred"], eval_model["label"]])
+            log_loss_list += list(eval_loss.reshape(-1))
+            pred_list += list(pred.reshape(-1))
+            label_list += list(label.reshape(-1))
+            print(f"eval current_steps: {eval_current_steps}")
+
+            if eval_current_steps == eval_steps:
+                finished = True
+        except tf.errors.OutOfRangeError:
+            finished = True
+
+    label_numpy = np.array(label_list)
+    pred_numpy = np.array(pred_list)
+    numpy_dir = os.path.abspath(".") + f"/interval_{interval}/numpy_{step}"
+    if not os.path.exists(numpy_dir):
+        os.makedirs(numpy_dir)
+
+    if os.path.exists(numpy_dir + f"/label_{rank_id}.npy"):
+        os.remove(numpy_dir + f"/label_{rank_id}.npy")
f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy") + if os.path.exists(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy"): + os.remove(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy") + if os.path.exists(f"flag_{rank_id}.txt"): + os.remove(f"flag_{rank_id}.txt") + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/label_{rank_id}.npy", label_numpy) + np.save(os.path.abspath(".") + f"/interval_{interval}/numpy_{step}/pred_{rank_id}.npy", pred_numpy) + os.mknod(f"flag_{rank_id}.txt") + while True: + file_exists_list = [os.path.exists(f"flag_{i}.txt") for i in range(rank_size)] + # print(file_exists_list) + if sum(file_exists_list) == rank_size: + print("All saved!!!!!!!!!!") + break + else: + print("Waitting for saving numpy!!!!!!!!") + time.sleep(1) + continue + + auc = roc_auc_score(label_list, pred_list) + mean_log_loss = np.mean(log_loss_list) + return auc, mean_log_loss + + +def create_feature_spec_list(use_timestamp=False): + access_threshold = None + eviction_threshold = None + if use_timestamp: + access_threshold = 1000 + eviction_threshold = 180 + + feature_spec_list = [FeatureSpec("sparse_feature", table_name="sparse_embeddings", batch_size=cfg.batch_size, + access_threshold=access_threshold, eviction_threshold=eviction_threshold)] + if use_multi_lookup: + feature_spec_list.append(FeatureSpec("sparse_feature", table_name="sparse_embeddings", + batch_size=cfg.batch_size, + access_threshold=access_threshold, + eviction_threshold=eviction_threshold)) + if use_timestamp: + feature_spec_list.append(FeatureSpec("timestamp", is_timestamp=True)) + return feature_spec_list + + +if __name__ == "__main__": + tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) + warnings.filterwarnings("ignore") + + use_mpi = bool(int(os.getenv("USE_MPI"))) + rank_id = int(os.getenv("RANK_ID")) if os.getenv("RANK_ID") else None + rank_size = int(os.getenv("RANK_SIZE")) if os.getenv("RANK_SIZE") else None + interval = int(os.getenv("INTERVAL")) if os.getenv("INTERVAL") else None + train_steps = 10000 + eval_steps = 1360 + + try: + use_dynamic_expansion = bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))) + use_multi_lookup = bool(int(os.getenv("USE_MULTI_LOOKUP", 0))) + MODIFY_GRAPH_FLAG = bool(int(os.getenv("USE_MODIFY_GRAPH", 0))) + use_faae = bool(int(os.getenv("USE_FAAE", 0))) + except ValueError as err: + raise ValueError(f"please correctly config USE_DYNAMIC_EXPANSION or USE_MULTI_LOOKUP or USE_FAAE " + f"or USE_MODIFY_GRAPH only 0 or 1 is supported.") from err + + use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0))) + logger.info(f"USE_DYNAMIC:{use_dynamic}") + init(use_mpi, rank_id=rank_id, rank_size=rank_size, train_steps=train_steps, eval_steps=eval_steps, + use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion) + IF_LOAD = False + rank_id = mxrec_util.initialize.get_rank_id() + filelist = glob(f"./saved-model/sparse-model-{rank_id}-0") + if filelist: + IF_LOAD = True + set_if_load(IF_LOAD) + + cfg = Config() + feature_spec_list_train = None + feature_spec_list_eval = None + if use_faae: + feature_spec_list_train = create_feature_spec_list(use_timestamp=True) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=True) + else: + feature_spec_list_train = create_feature_spec_list(use_timestamp=False) + feature_spec_list_eval = create_feature_spec_list(use_timestamp=False) + + train_batch, train_iterator = make_batch_and_iterator(cfg, feature_spec_list_train, is_training=True, + 
                                                          dump_graph=True, use_faae=use_faae)
+    eval_batch, eval_iterator = make_batch_and_iterator(cfg, feature_spec_list_eval, is_training=False,
+                                                        dump_graph=False, use_faae=use_faae)
+    logger.info(f"train_batch: {train_batch}")
+
+    if use_faae:
+        cfg.dev_vocab_size = cfg.dev_vocab_size // 2
+
+    optimizer_list = [get_dense_and_sparse_optimizer(cfg)]
+    sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]
+
+    # note: variance_scaling_initializer only supports HBM mode
+    emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) \
+        if cfg.cache_mode != "HBM" or use_dynamic_expansion else \
+        tf.compat.v1.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=sparse_hashtable_seed)
+    sparse_hashtable = create_table(
+        key_dtype=cfg.key_type,
+        dim=tf.TensorShape([cfg.emb_dim]),
+        name="sparse_embeddings",
+        emb_initializer=emb_initializer,
+        optimizer_list=[sparse_optimizer_list[0]._optimizer],
+        **cfg.get_emb_table_cfg()
+    )
+    if use_faae:
+        tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, train_batch["timestamp"])
+
+    sparse_hashtable_list = [sparse_hashtable, sparse_hashtable] if use_multi_lookup else [sparse_hashtable]
+    train_model = model_forward(feature_spec_list_train, sparse_hashtable_list, train_batch,
+                                is_train=True, modify_graph=MODIFY_GRAPH_FLAG)
+    eval_model = model_forward(feature_spec_list_eval, sparse_hashtable_list, eval_batch,
+                               is_train=False, modify_graph=MODIFY_GRAPH_FLAG)
+
+    dense_variables, sparse_variables = get_dense_and_sparse_variable()
+
+    rank_size = mxrec_util.initialize.get_rank_size()
+    train_ops = []
+    # multi-task training
+    for loss, (dense_optimizer, sparse_optimizer) in zip([train_model["loss"]], optimizer_list):
+        # dense optimization
+        grads = dense_optimizer.compute_gradients(loss, var_list=dense_variables)
+        avg_grads = []
+        for grad, var in grads:
+            if rank_size > 1:
+                grad = hccl_ops.allreduce(grad, "sum") if grad is not None else None
+            if grad is not None:
+                # average the allreduce-summed gradient over all ranks
+                avg_grads.append((grad / rank_size, var))
+        # apply gradients: update variables
+        train_ops.append(dense_optimizer.apply_gradients(avg_grads))
+
+        if use_dynamic_expansion:
+            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
+                ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+
+            if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) ==
+                    ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY):
+                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
+            else:
+                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+            train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
+            # sparse optimization by address
+            sparse_grads = sparse_optimizer.compute_gradients(loss, train_emb_list)  # local_embedding
+            grads_and_vars = [(grad, address) for grad, address in zip(sparse_grads, train_address_list)]
+            train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+        else:
+            # sparse optimization
+            sparse_grads = sparse_optimizer.compute_gradients(loss, sparse_variables)
+            print("sparse_grads_tensor:", sparse_grads)
+            grads_and_vars = [(grad, variable) for grad, variable in zip(sparse_grads, sparse_variables)]
+            train_ops.append(sparse_optimizer.apply_gradients(grads_and_vars))
+
+    # dynamic learning rate update
+    train_ops.extend([cfg.global_step.assign(cfg.global_step + 1), cfg.learning_rate[0], cfg.learning_rate[1]])
+
+    with tf.control_dependencies(train_ops):
+        train_ops = tf.no_op()
+    cfg.learning_rate = [cfg.learning_rate[0], cfg.learning_rate[1]]
+
+    saver = tf.train.Saver()
+    if MODIFY_GRAPH_FLAG:
+        modify_graph_and_start_emb_cache(dump_graph=True)
+    else:
+        start_asc_pipeline()
+
+    hook_list = []
+    if use_faae:
+        hook_evict = EvictHook(evict_enable=True, evict_time_interval=120)
+        hook_list.append(hook_evict)
+    if MODIFY_GRAPH_FLAG:  # add the hook in this scenario to handle graph-validation issues
+        hook_list.append(GraphModifierHook(modify_graph=False))
+
+    if use_faae:
+        sess = tf.compat.v1.train.MonitoredTrainingSession(
+            hooks=hook_list,
+            config=sess_config(dump_data=False)
+        )
+        sess.graph._unsafe_unfinalize()
+        if not MODIFY_GRAPH_FLAG:
+            sess.run(train_iterator.initializer)
+        else:
+            sess.run(get_initializer(True))
+    else:
+        sess = tf.compat.v1.Session(config=sess_config(dump_data=False))
+        sess.run(tf.compat.v1.global_variables_initializer())
+        if not MODIFY_GRAPH_FLAG:
+            sess.run(train_iterator.initializer)
+        else:
+            sess.run(get_initializer(True))
+
+    cost_sum = 0
+    best_auc = 0
+    iteration_per_loop = 10
+
+    train_ops = util.set_iteration_per_loop(sess, train_ops, iteration_per_loop)
+
+    i = 0
+    while True:
+        i += 1
+        logger.info(f"################ training at step {i * iteration_per_loop} ################")
+        start_time = time.time()
+
+        try:
+            grad, loss = sess.run([train_ops, train_model["loss"]])
+            lr = sess.run(cfg.learning_rate)
+            global_step = sess.run(cfg.global_step)
+        except tf.errors.OutOfRangeError:
+            logger.info("Encountered the end of sequence for training.")
+            break
+
+        end_time = time.time()
+        cost_time = end_time - start_time
+        qps = (1 / cost_time) * rank_size * cfg.batch_size * iteration_per_loop
+        cost_sum += cost_time
+        logger.info(f"step: {i * iteration_per_loop}; training loss: {loss}")
+        logger.info(f"step: {i * iteration_per_loop}; grad: {grad}")
+        logger.info(f"step: {i * iteration_per_loop}; lr: {lr}")
+        logger.info(f"global step: {global_step}")
+        logger.info(f"step: {i * iteration_per_loop}; current sess cost time: {cost_time:.10f}; current QPS: {qps}")
+        logger.info(f"training at step:{i * iteration_per_loop}, table[{sparse_hashtable.table_name}], "
+                    f"table size:{sparse_hashtable.size()}, table capacity:{sparse_hashtable.capacity()}")
+
+        if i % (train_steps // iteration_per_loop) == 0:
+            if interval is not None:
+                test_auc, test_mean_log_loss = evaluate_fix(i * iteration_per_loop)
+            else:
+                test_auc, test_mean_log_loss = evaluate()
+            print("Test auc: {}; log_loss: {} ".format(test_auc, test_mean_log_loss))
+            best_auc = max(best_auc, test_auc)
+            logger.info(f"training step: {i * iteration_per_loop}, best auc: {best_auc}")
+
+    sess.close()
+
+    terminate_config_initializer()
+    logger.info("Demo done!")
diff --git a/examples/dlrm/model/mean_auc.py b/examples/dlrm/model/mean_auc.py
new file mode 100644
index 00000000..1116ebd5
--- /dev/null
+++ b/examples/dlrm/model/mean_auc.py
@@ -0,0 +1,40 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import numpy as np +from glob import glob + + +def split_auc(log_input): + with open(log_input, 'r') as log: + all_auc = [] + for line in log.readlines(): + if 'Test' in line: + all_auc.append(float(line.split(';')[0].split(':')[-1].strip())) + all_auc_len = len(all_auc) + all_auc_arr = np.array(all_auc)[:all_auc_len - all_auc_len%8] + test_auc = np.mean(all_auc_arr.reshape(-1, 8), axis=-1) + return test_auc + + +log_path_all = 'latest_*.log' +log_path_list = glob(log_path_all) + +for log_path in log_path_list: + print(os.path.basename(log_path)) + print(split_auc(log_path)) + print('*'*20) \ No newline at end of file diff --git a/examples/dlrm/model/model.py b/examples/dlrm/model/model.py new file mode 100644 index 00000000..037fb276 --- /dev/null +++ b/examples/dlrm/model/model.py @@ -0,0 +1,94 @@ +# coding=utf-8 +# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import time +from easydict import EasyDict as edict + +import tensorflow as tf + + +model_cfg = edict() +model_cfg.loss_mode = "batch" +LOSS_OP_NAME = "loss" +LABEL_OP_NAME = "label" +VAR_LIST = "variable" +PRED_OP_NAME = "pred" + + +class MyModel: + def __init__(self): + self.kernel_init = None + self._loss_fn = None + self.is_training = None + + @classmethod + def _dot_interaction(cls, _input): + num_features = tf.shape(_input)[1] + batch_size = tf.shape(_input)[0] + xactions = tf.matmul(_input, _input, transpose_b=True) + ones = tf.ones_like(xactions, dtype=tf.float32) + upper_tri_mask = tf.linalg.band_part(ones, 0, -1) + + activations = tf.where(condition=tf.cast(upper_tri_mask, tf.bool), + x=tf.zeros_like(xactions), + y=xactions) + out_dim = num_features * num_features + activations = tf.reshape(activations, (batch_size, out_dim)) + return activations + + def build_model(self, + embedding=None, + dense_feature=None, + label=None, + is_training=True, + seed=None): + with tf.variable_scope("mlp", reuse=tf.AUTO_REUSE): + self._loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True) + self.is_training = is_training + dense_embedding_vec = self.bottom_stack(dense_feature, seed) + dense_embedding = tf.expand_dims(dense_embedding_vec, 1) + interaction_args = tf.concat([dense_embedding, embedding], axis=1) + interaction_output = self._dot_interaction(interaction_args) + feature_interaction_output = tf.concat([dense_embedding_vec, interaction_output], axis=1) + # (8192, 857) + logits = self.top_stack(feature_interaction_output, seed) + loss = self._loss_fn(label, logits) + prediction = tf.sigmoid(logits) + trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='mlp') + return {LOSS_OP_NAME: loss, + PRED_OP_NAME: prediction, + LABEL_OP_NAME: label, + VAR_LIST: 
trainable_variables}
+
+    def bottom_stack(self, _input, seed):
+        kernel_init = tf.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=seed)
+        bias_init = tf.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=seed)
+        l1_reg = tf.contrib.layers.l1_regularizer(1e-2)
+        dnn1 = tf.layers.dense(_input, 512, activation='relu', name='bs1',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn2 = tf.layers.dense(dnn1, 256, activation='relu', name='bs2',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn3 = tf.layers.dense(dnn2, 128, activation='relu', name='bs3',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        return dnn3
+
+    def top_stack(self, _input, seed):
+        kernel_init = tf.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=seed)
+        bias_init = tf.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=seed)
+        l1_reg = tf.contrib.layers.l1_regularizer(1e-2)
+        dnn1 = tf.layers.dense(_input, 1024, activation='relu', name='ts1',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn2 = tf.layers.dense(dnn1, 1024, activation='relu', name='ts2',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn3 = tf.layers.dense(dnn2, 512, activation='relu', name='ts3',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn4 = tf.layers.dense(dnn3, 256, activation='relu', name='ts4',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        dnn5 = tf.layers.dense(dnn4, 1, activation=None, name='ts5',
+                               kernel_initializer=kernel_init, bias_initializer=bias_init,
+                               kernel_regularizer=l1_reg)
+        return dnn5
+
+
+my_model = MyModel()
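+
+
+# Illustrative sketch (a numpy mirror of _dot_interaction above, assuming
+# numpy is available; not used by the model) showing the output layout: the
+# strictly lower triangle of the feature-by-feature Gram matrix, flattened
+# to num_features * num_features values with zeros elsewhere.
+def dot_interaction_reference(x):
+    import numpy as np
+    # x: [batch, num_features, emb_dim]
+    gram = np.matmul(x, np.transpose(x, (0, 2, 1)))
+    lower = np.tril(gram, k=-1)  # zero the diagonal and upper triangle
+    return lower.reshape(x.shape[0], -1)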
diff --git a/examples/dlrm/model/op_impl_mode.ini b/examples/dlrm/model/op_impl_mode.ini
new file mode 100644
index 00000000..579dea43
--- /dev/null
+++ b/examples/dlrm/model/op_impl_mode.ini
@@ -0,0 +1 @@
+ScatterNdAdd=support_out_of_bound_index
\ No newline at end of file
diff --git a/examples/dlrm/model/optimizer.py b/examples/dlrm/model/optimizer.py
new file mode 100644
index 00000000..8e33fa97
--- /dev/null
+++ b/examples/dlrm/model/optimizer.py
@@ -0,0 +1,35 @@
+# coding=utf-8
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tensorflow as tf
+from delay_loss_scale import DenseLossScaleOptimizer, SparseLossScaleOptimizer
+from gradient_descent_w import create_hash_optimizer
+from mx_rec.util.initialize import get_use_dynamic_expansion
+from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr
+
+
+def get_dense_and_sparse_optimizer(cfg):
+    dense_optimizer = tf.train.GradientDescentOptimizer(learning_rate=cfg.learning_rate[0])
+    use_dynamic_expansion = get_use_dynamic_expansion()
+    if use_dynamic_expansion:
+        sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate[1], weight_decay=0.0001)
+    else:
+        sparse_optimizer = create_hash_optimizer(learning_rate=cfg.learning_rate[1], weight_decay=0.0001)
+    sparse_optimizer = SparseLossScaleOptimizer(sparse_optimizer, 1024)
+    dense_optimizer = DenseLossScaleOptimizer(dense_optimizer, 1024)
+
+    return dense_optimizer, sparse_optimizer
diff --git a/examples/dlrm/model/run.sh b/examples/dlrm/model/run.sh
new file mode 100644
index 00000000..5f3d6c5e
--- /dev/null
+++ b/examples/dlrm/model/run.sh
@@ -0,0 +1,116 @@
+#!/bin/bash
+# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+cur_path=$(dirname "$(readlink -f "$0")")
+
+so_path=$1
+mx_rec_package_path=$2
+hccl_cfg_json=$3
+dlrm_criteo_data_path=$4
+
+export RANK_SIZE=8
+echo "RANK_SIZE=${RANK_SIZE}, please make sure the hccl configuration json file matches this parameter"
+export RANK_TABLE_FILE=${hccl_cfg_json}
+
+################# parameter configuration ######################
+export USE_DYNAMIC=0            # 0: static shape; 1: dynamic shape
+export CACHE_MODE="HBM"         # HBM; DDR; SSD
+export USE_FAAE=0               # 0: disable feature admission/eviction; 1: enable it
+export USE_DYNAMIC_EXPANSION=0  # 0: disable dynamic expansion; 1: enable it
+export USE_MULTI_LOOKUP=0       # 0: one lookup per table; 1: multiple lookups per table
+export USE_MODIFY_GRAPH=0       # 0: feature spec mode; 1: automatic graph-modification mode
+################################################################
+
+echo "CACHE_MODE:${CACHE_MODE}"
+if [ ${CACHE_MODE} = "SSD" ]; then
+    echo "SSD training mode does not allow the data directory to exist before training;
+    deleting dir ${cur_path}/ssd_data and recreating it for the SSD use case"
+    rm -rf ssd_data
+    mkdir ssd_data
+fi
+
+export HCCL_CONNECT_TIMEOUT=1200
+
+export DLRM_CRITEO_DATA_PATH=${dlrm_criteo_data_path}
+export PYTHONPATH=${mx_rec_package_path}:${so_path}:$PYTHONPATH
+export LD_PRELOAD=/usr/lib64/libgomp.so.1
+export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH
+
+rm -rf kernel*
+rm -rf /root/ascend/log/*
+rm -rf model_dir_rank* op_cache
+
+export ASCEND_DEVICE_ID=0
+export RANK_ID_START=0
+export JOB_ID=10086
+export CUSTOMIZED_OPS_LIB_PATH=${so_path}/libcust_ops.so  # Todo: please configure
+export MXREC_LOG_LEVEL="INFO"
+export TF_CPP_MIN_LOG_LEVEL=3
+export ASCEND_GLOBAL_LOG_LEVEL=3
+export ENABLE_FORCE_V2_CONTROL=1
+
+#apply_gradient_strategy="direct_apply"
+apply_gradient_strategy="sum_same_id_gradients_and_apply"
+export APPLY_GRADIENTS_STRATEGY=${apply_gradient_strategy}
+
+export PROFILING_OPTIONS='{"output":"/home/yz/profiling",
+    "training_trace":"on",
+    "task_trace":"on",
+    "aicpu":"on",
+    "fp_point":"",
+    "bp_point":"",
+    "aic_metrics":"PipeUtilization"}'
+
+RANK_ID_START=0
+
+export MXREC_MODE="ASC"
+echo "MXREC_MODE is $MXREC_MODE"
+export USE_MPI=1
+echo "USE_MPI is $USE_MPI"
+export py=main_mxrec.py
+echo "py is $py"
+
+
+if [ $USE_MPI -eq 0 ]; then
+    echo "use for loop to start tasks"
+    for ((RANK_ID = $RANK_ID_START; RANK_ID < $((RANK_SIZE + RANK_ID_START)); RANK_ID++)); do
+        # set the per-device environment variables; no modification needed
+        echo "Device ID: $RANK_ID"
+        export RANK_ID=$RANK_ID
+        export ASCEND_DEVICE_ID=$RANK_ID
+        ASCEND_DEVICE_ID=$RANK_ID
+        if [ -d $cur_path/output/${ASCEND_DEVICE_ID} ]; then
+            rm -rf $cur_path/output/${ASCEND_DEVICE_ID}
+        fi
+        mkdir -p $cur_path/output/${ASCEND_DEVICE_ID}
+        nohup python3 ${py} > $cur_path/output/$ASCEND_DEVICE_ID/test_$ASCEND_DEVICE_ID.log 2>&1 &
+    done
+else
+    echo "use horovod to start tasks"
+    # GLOG_stderrthreshold: -2 TRACE, -1 DEBUG, 0 INFO, 1 WARN, 2 ERROR; the default is INFO
+    mpi_args='-x BIND_INFO="0:12 12:48 60:48" -x GLOG_stderrthreshold=2 -x GLOG_logtostderr=true -bind-to none -x NCCL_SOCKET_IFNAME=docker0 -mca btl_tcp_if_exclude docker0'
+    interface="lo"
+
+    horovodrun --network-interface ${interface} -np ${RANK_SIZE} --mpi-args "${mpi_args}" --mpi -H localhost:${RANK_SIZE} \
+        python3.7 ${py} 2>&1 | tee temp_${CACHE_MODE}_${RANK_SIZE}p.log
+fi
+
+
diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py
index 1e5c8d6a..cdf2d255 100644
--- a/mx_rec/graph/patch.py
+++ b/mx_rec/graph/patch.py
@@ -1,27 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # Copyright 2024. Huawei Technologies Co.,Ltd.
 All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 # Some code is derived from Tensorflow, which is subject to the following copyright notice:
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
+# We reuse TensorFlow code here to keep the mxRec API compatible with TensorFlow for model execution.
+
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py
index d5d30946..38fd2c3e 100644
--- a/mx_rec/saver/patch.py
+++ b/mx_rec/saver/patch.py
@@ -1,27 +1,15 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 # Some code is derived from Tensorflow, which is subject to the following copyright notice:
 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# We reuse TensorFlow code here to keep the mxRec API compatible with TensorFlow for model saving and loading.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
--
Gitee

From 8733165ac8d5f4440f249a5658506e25e312bf06 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 26 Jan 2024 19:02:28 +0800
Subject: [PATCH 532/551] Match-id-d55f2dc43ba55372b52d228189bbce1c7f7b3033

---
 cust_op/cust_op_by_addr/README.md             | 134 ++++++++++++++
 docs/build_mxRec_images/README.md             |  94 ++++++++++
 .../centos_build/Dockerfile                   | 170 ++++++++++++++++++
 .../build_mxRec_images/mxrec-build/Dockerfile |  35 ++++
 4 files changed, 433 insertions(+)
 create mode 100644 cust_op/cust_op_by_addr/README.md
 create mode 100644 docs/build_mxRec_images/README.md
 create mode 100644 docs/build_mxRec_images/centos_build/Dockerfile
 create mode 100644 docs/build_mxRec_images/mxrec-build/Dockerfile

diff --git a/cust_op/cust_op_by_addr/README.md b/cust_op/cust_op_by_addr/README.md
new file mode 100644
index 00000000..d7cee5a8
--- /dev/null
+++ b/cust_op/cust_op_by_addr/README.md
@@ -0,0 +1,134 @@
+# Sparse-table auto-expansion operators and sample description
+
+## Expansion operator file structure
+```shell
+├── aclnn_lookup_test    # single-operator test case for lookup
+├── aclnn_update_test    # single-operator test case for update
+├── emb_custom.json      # operator configuration
+├── op_host              # host-side implementation of the expansion operators
+├── op_kernel            # kernel-side implementation of the expansion operators
+├── README.md            # this document
+└── run.sh               # installation script for the expansion operators
+```
+
+## Ascend C reference design
+For more details, see the official CANN Ascend C operator development guide: [Ascend C operator development](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0001.html).
+
+mxRec ships two Ascend C operators for the dynamic-expansion feature: the **lookup operator embedding_lookup_by_addr** and the **update operator embedding_update_by_addr**.
+The embedding_lookup_by_addr operator is described in detail below; embedding_update_by_addr works analogously.
+
+## Lookup operator embedding_lookup_by_addr
+
+1. Operator analysis
+
+a) The operator's main job is to take a list of addresses as input and replace the tf.gather operator;
+
+b) It supports embedding-table lookups for the int32, float32 and float16 element types;
+
+c) Its inputs are: address, the list of addresses of the embeddings to look up; embedding_dim, the dimension of the embeddings;
+and embedding_type, the embedding element type, where 0 = int32, 1 = float32 and 2 = float16.
+
+2. Host-side implementation
+
+The host-side implementation lives under cust_op_by_addr/op_host and consists of embedding_lookup_by_address.cpp and
+embedding_lookup_by_address_tiling.h.
+
+a) Tiling
+
+The TilingFunc function in the optiling namespace reads the external inputs from the context, validates them, computes the
+intermediate variables the kernel side needs (such as embeddingDimAligned and addrPerLoop), sets BlockDim, and finally passes
+the attributes on through TilingData.
+
+b) Shape inference
+
+The InferShape and InferDataType functions in the ge namespace derive the output tensorShape and tensorType from the input
+tensorShape and tensorType.
+
+c) Prototype registration
+
+The EmbeddingLookupByAddress class in the ops namespace.
+3. Kernel-side implementation
+
+The kernel-side implementation lives under cust_op_by_addr/op_kernel in embedding_lookup_by_address.cpp.
+
+a) The kernel entry point: extern "C" __global__ __aicore__ void embedding_lookup_by_address
+
+b) GET_TILING_DATA(constData, tiling) retrieves the host-side data from TilingData
+
+c) A type-specific op object is built from the KernelEimtable template class, and its Init_param, Init and Process functions are
+called in turn to move and compute the data;
+
+d) KernelEimtable::Init_param uses the retrieved TilingData to compute variables such as singleCoreAddrNum and veclen
+
+e) For operators with non-aligned shapes, KernelEimtable::Init uses the intermediate variables from Init_param to compute each
+core's offset and block size, and initializes and binds the buffers
+
+f) KernelEimtable::Process performs the operator's data movement and computation and writes the result to dstDataGm, i.e. GM_ADDR y
+
+## AclNN single-operator test reference design
+
+For more details, see the official CANN documentation: [Ascend C single-operator invocation overview](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0036.html).
+
+Single operators can be invoked in two ways: single-operator API execution and model execution. mxRec provides single-operator API
+execution for reference.
+
+The single-operator test cases live under cust_op_by_addr/aclnn_lookup_test and cust_op_by_addr/aclnn_update_test, where:
+* inc holds the header files
+* scripts holds the Python scripts that generate and verify data
+* input holds the bin files with the operator inputs
+* output holds the generated executable execute_op, the operator output bin files and the golden bin files used for verification
+* src holds the common helpers (common), the operator input/output descriptor class (oprator_desc), the main single-operator
+invocation flow (op_runner) and the entry file (main)
+
+Run the single-operator test:
+```shell
+bash run.sh
+```
+
+### Prerequisites
+
+1. Create the operator project as described in [creating an operator project with msopgen](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0023.html),
+prepare the kernel-side implementation as described in [kernel-side operator implementation](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0024.html),
+and prepare the host-side implementation as described in [host-side operator implementation](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0026.html).
+2. Build and deploy the operator as described in [operator build and deployment](https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/operatordev/Ascendcopdevg/atlas_ascendc_10_0031.html).
+Enable the operator's binary compilation: in the project's build configuration file CMakePresets.json, set
+ENABLE_BINARY_PACKAGE to True. Deploying the operator binary into the current environment makes later invocation easier.
+3. Check that the headers and libraries needed for API execution were generated automatically: for mxRec, check that the
+cust_op/cust_op_by_addr/custom_op/build_out/autogen directory contains aclnn_embedding_lookup_by_address.cpp,
+aclnn_embedding_lookup_by_address.h, etc.
+
+Note: the cust_op/cust_op_by_addr/run.sh script deletes the build directory after installing the operator. When running the
+single-operator test, comment out the `rm -rf ./custom_op` step to keep prerequisite 3 satisfied.
+
+### Lookup operator embedding_lookup_by_addr
+For the embedding_lookup_by_addr operator, in the entry file src/main.cpp:
+
+1. The InitResource function initializes AscendCL and acquires the runtime resources; no changes needed
+2. RunLookupOp runs the operator:
+
+a) CreateOpDescLookup builds the operator input/output description. The class inherits from OperatorDesc and mainly adds the
+embeddingDim and embeddingType members for later use in op_runner; the OperatorDesc base class needs no changes;
+
+b) An OpRunnerLookup object is created and the following are executed in turn:
+* opRunner.Init(): allocates the memory for the operator's input and output data
+* SetLookupInputData(): loads the input bin files into the OpRunner buffers for the operator execution to consume
+* RunOp(): executes the operator; the core call is OpRunnerLookup::RunOpHelper
+* ProcessLookupOutputData(): post-processes the operator output and writes it to disk for the later golden comparison
+
+OpRunnerLookup overrides the base class OpRunner's virtual RunOpHelper function to issue the concrete aclnn call; the OpRunner
+base class needs no changes;
+
+3. The DestoryResource function releases the memory; no changes needed
+
+### Run script
+run.sh executes, in order:
+1. Remove leftover generated files and log files
+2. Generate the input data and the golden data
+3. Build the acl executable
+4. Run the executable
+5. Compare against the golden files
+### Run script
+run.sh performs the following steps in order:
+1. Remove leftover generated files and log files
+2. Generate the input data and the golden data
+3. Build the acl executable
+4. Run the executable
+5. Compare the output against the golden files
+
+### scripts directory
+* gen_data.py: generates the input data for the embedding_lookup_by_addr operator (taking embedding_lookup_by_addr as the example here) and the golden data used for accuracy
+verification. Users can adjust the test scale, such as the table size, the number of lookups, and the table dim.
+* verify_result.py: compares the operator's output against the golden data generated by the script, with the allowed tolerance loss = 1e-4, considering:
+
+a) absolute error
+b) relative error
+c) the relative number of out-of-tolerance elements
+
+The operator is judged as failing the accuracy check only when all of the following hold at once: not all absolute errors are below loss, not all relative errors are below loss, and the counts of elements whose absolute error and whose relative error exceed loss each surpass loss of the total element count, i.e.
+1/10000 of all elements (one in ten thousand for both). In every other case the operator passes. A minimal sketch of this rule follows below.
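+
+This sketch restates the pass/fail rule above (loss = 1e-4); the function name and container types are illustrative, not taken from verify_result.py:
+
+```cpp
+// Hypothetical restatement of the tolerance rule described above.
+#include <cmath>
+#include <cstddef>
+#include <vector>
+
+bool PrecisionPass(const std::vector<float> &golden,
+                   const std::vector<float> &actual, double loss = 1e-4)
+{
+    std::size_t absBad = 0;  // elements whose absolute error exceeds loss
+    std::size_t relBad = 0;  // elements whose relative error exceeds loss
+    for (std::size_t i = 0; i < golden.size(); ++i) {
+        const double absErr = std::fabs(static_cast<double>(golden[i]) - actual[i]);
+        if (absErr > loss) {
+            ++absBad;
+        }
+        if (absErr > loss * std::fabs(golden[i])) {
+            ++relBad;
+        }
+    }
+    // Fail only when BOTH bad counts exceed loss * total, i.e. 1/10000 of all elements.
+    const double threshold = loss * static_cast<double>(golden.size());
+    return !(static_cast<double>(absBad) > threshold &&
+             static_cast<double>(relBad) > threshold);
+}
+```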
+
+Users can adjust the allowed tolerance loss as needed.
\ No newline at end of file
diff --git a/docs/build_mxRec_images/README.md b/docs/build_mxRec_images/README.md
new file mode 100644
index 00000000..54f6b8a8
--- /dev/null
+++ b/docs/build_mxRec_images/README.md
@@ -0,0 +1,94 @@
+# Overview
+This document guides users through building an mxRec training image from an existing base image.
+
+## Directory structure
+```shell
+└── build_mxRec_images
+    ├── centos_build          # based on the CentOS open-source image on AscendHub or the user's own image
+    │   └── Dockerfile
+    ├── mxrec-build           # based on the mxRec open-source image on AscendHub
+    │   └── Dockerfile
+    └── README.md             # this document
+```
+
+## Prerequisites
+The driver and firmware matching the target CANN version are already installed on the physical machine.
+
+Docker is installed on the physical machine, and the Docker network is usable.
+
+Prepare a base image. If you have not prepared one, you can pull a base image from the [Ascend image repository](https://ascendhub.huawei.com/#/index); the following are recommended:
+* Prefer the mxRec training image: the mxRec training image on AscendHub already contains base dependencies such as gcc and cmake, so they need not be installed again.
+It also contains CANN and the mxRec package, but in older versions, so with the mxRec image as the base image you only need to update its CANN and mxRec packages.
+* Otherwise, pull the [CentOS7.6.1810](https://ascendhub.huawei.com/#/detail/centos) image from AscendHub.
+* Finally, if neither of the two images above is used, prepare your own base image, preferably one based on CentOS 7.6.1810.
+
+## Preparing the dependencies
+Which dependencies must be downloaded depends on the base image.
+1. With the mxRec training image from AscendHub as the base image, you only need to download the latest matching CANN and mxRec from the [Ascend community](https://www.hiascend.com/developer/download/community/result?module=sdk+cann), where CANN means the
+toolkit and tfplugin packages. The matching CANN and mxRec versions can be downloaded via:
+
+https://www.hiascend.com/zh/developer/download/community/result?module=sdk+cann
+
+https://www.hiascend.com/developer/download/community/result?module=tf+cann&tf=7.0.0.beta1&cann=7.0.0.beta1
+
+For the concrete image build steps, see the Dockerfile under mxrec-build.
+
+2. With CentOS7.6.1810 or the user's own image as the base image, considerably more dependencies are required, and you need to check whether your image already contains the
+dependencies below. Since many of them must be installed, it is recommended to **install them manually** following the steps in the Dockerfile, for example gcc and cmake.
+
+* gcc-7.3.0
+
+Download: [https://mirrors.ustc.edu.cn/gnu/gcc/gcc-7.3.0/gcc-7.3.0.tar.gz](https://mirrors.ustc.edu.cn/gnu/gcc/gcc-7.3.0/gcc-7.3.0.tar.gz)
+
+* cmake-3.20.6
+
+Download: [https://cmake.org/files/v3.20/cmake-3.20.6.tar.gz](https://cmake.org/files/v3.20/cmake-3.20.6.tar.gz)
+
+* ucx
+
+Download: [https://github.com/openucx/ucx/archive/master.zip](https://github.com/openucx/ucx/archive/master.zip)
+
+* openmpi-4.1.5
+
+Download: [https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz](https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz)
+
+* easy_profiler-2.1.0
+
+Download: [https://codeload.github.com/yse/easy_profiler/tar.gz/refs/tags/v2.1.0](https://codeload.github.com/yse/easy_profiler/tar.gz/refs/tags/v2.1.0)
+
+* python-3.7.5
+
+Download: [https://repo.huaweicloud.com/python/3.7.5/Python-3.7.5.tar.xz](https://repo.huaweicloud.com/python/3.7.5/Python-3.7.5.tar.xz)
+
+* hdf5-1.10.5
+
+Download: [https://support.hdfgroup.org/ftp/HDF5/releases/hdf5-1.10/hdf5-1.10.5/src/hdf5-1.10.5.tar.gz](https://support.hdfgroup.org/ftp/HDF5/releases/hdf5-1.10/hdf5-1.10.5/src/hdf5-1.10.5.tar.gz)
+
+* CANN and mxRec
+
+The mxRec packages released in the [Ascend community](https://www.hiascend.com/developer/download/community/result?module=sdk+cann)
+are version-matched with CANN, so download matching versions of mxRec and CANN from the community; the required CANN packages are toolkit and tfplugin.
+Matching versions of mxRec and CANN can be selected via:
+
+https://www.hiascend.com/zh/developer/download/community/result?module=sdk+cann
+
+https://www.hiascend.com/developer/download/community/result?module=tf+cann&tf=7.0.0.beta1&cann=7.0.0.beta1
+
+* TensorFlow (1.15.0/2.6.5)
+
+mxRec is currently developed on top of TensorFlow, so TensorFlow must be installed in the environment. On x86 it can be installed directly with pip or pip3.
+On arm, however, there is no matching TensorFlow whl package, so it cannot be installed directly with pip or pip3; an arm build of TensorFlow can be downloaded from:
+
+[https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindX/OpenSource/python/index.html](https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindX/OpenSource/python/index.html)
+
+* version.info and ascend_install.info needed when installing CANN
+
+Installing the CANN packages requires two files: version.info (the driver version file) and ascend_install.info (the firmware/driver installation parameters). Copy them
+from the corresponding files on the physical machine into one directory. By default, version.info is installed at /usr/local/Ascend/driver/version.info, and the
+default path of ascend_install.info is /etc/ascend_install.info.
+
+For the concrete image build steps, see the Dockerfile under centos_build.
+
+**Tip**: as needed, **download the dependencies above into a single directory**; this makes them easier to handle. Also, before building with a Dockerfile, read the
+corresponding Dockerfile carefully, because it must be adapted to the actual situation; the build steps are explained in detail in the Dockerfile.
\ No newline at end of file
diff --git a/docs/build_mxRec_images/centos_build/Dockerfile b/docs/build_mxRec_images/centos_build/Dockerfile
new file mode 100644
index 00000000..e218c1c0
--- /dev/null
+++ b/docs/build_mxRec_images/centos_build/Dockerfile
@@ -0,0 +1,170 @@
+# please configure: choose the base image according to the actual situation
+FROM ascendhub.huawei.com/public-ascendhub/centos:7.6.1810
+
+WORKDIR /tmp
+COPY . ./
+
+RUN chmod 777 /tmp
+
+# please configure: install only the dependencies you actually need; remove or comment out the code for any dependency you do not need
+
+# 1. Install the build environment
+RUN yum makecache && \
+    yum -y install centos-release-scl && \
+    yum -y install devtoolset-7 && \
+    yum -y install devtoolset-7-gcc-c++ && \
+    yum -y install epel-release && \
+    yum -y install wget zlib-devel bzip2 bzip2-devel openssl-devel ncurses-devel openssh-clients sqlite-devel openmpi-devel \
+    readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel libffi-devel hdf5-devel patch pciutils lcov vim dos2unix gcc-c++ \
+    autoconf automake libtool git && \
+    yum clean all && \
+    rm -rf /var/cache/yum && \
+    echo "source /opt/rh/devtoolset-7/enable" >> /etc/profile
+
+# 2. Install gcc-7.3.0
+RUN source /etc/profile && \
+    tar -zxvf gcc-7.3.0.tar.gz && \
+    cd gcc-7.3.0 && \
+    wget https://mirrors.huaweicloud.com/gnu/gmp/gmp-6.1.0.tar.bz2 --no-check-certificate && \
+    wget --no-check-certificate https://mirrors.huaweicloud.com/gnu/mpfr/mpfr-3.1.4.tar.bz2 && \
+    wget --no-check-certificate https://mirrors.huaweicloud.com/gnu/mpc/mpc-1.0.3.tar.gz && \
+    wget --no-check-certificate https://mindx.obs.cn-south-1.myhuaweicloud.com/opensource/isl-0.16.1.tar.bz2 && \
+    sed -i "246s/tar -xf "${ar}"/tar --no-same-owner -xf "${ar}"/" contrib/download_prerequisites && \
+    ./contrib/download_prerequisites && \
+    ./configure --enable-languages=c,c++ --disable-multilib --with-system-zlib --prefix=/usr/local/gcc7.3.0 && \
+    make -j && make -j install && cd .. && \
+    find gcc-7.3.0/ -name libstdc++.so.6.0.24 -exec cp {} /lib64/ \; && \
+    rm -rf gcc-7.3.0*
+
+ENV LD_LIBRARY_PATH=/usr/local/gcc7.3.0/lib64:$LD_LIBRARY_PATH \
+    PATH=/usr/local/gcc7.3.0/bin:$PATH
+
+# 3. Install cmake
+RUN source /etc/profile && gcc -v && tar -zxf cmake-3.20.6.tar.gz && \
+    cd cmake-3.20.6 && \
+    ./bootstrap && make && make install && cd .. && \
+    rm -rf cmake-3.20.6*
+
+
+# 4. Install ucx
+RUN source /etc/profile && gcc -v && unzip ucx-master.zip && \
+    cd ucx-master && \
+    ./autogen.sh && \
+    ./contrib/configure-release --prefix=/usr/local/ucx && \
+    make && make install && cd .. && \
+    rm -rf ucx-master*
+
+# 5. Install openmpi; ucx must be configured for it
+RUN source /etc/profile && gcc -v && tar -zxvf openmpi-4.1.5.tar.gz && \
+    cd openmpi-4.1.5 && \
+    ./configure --enable-orterun-prefix-by-default --prefix=/usr/local/openmpi --with-ucx=/usr/local/ucx && \
+    make -j 16 && make install && cd .. && \
+    rm -rf openmpi-4.1.5*
+
+
+ENV LD_LIBRARY_PATH=/usr/local/openmpi/lib:$LD_LIBRARY_PATH \
+    PATH=/usr/local/openmpi/bin:$PATH
+
+# 6. Install easy_profiler
+RUN source /etc/profile && gcc -v && tar -zxf v2.1.0 && \
+    cd easy_profiler-2.1.0 && mkdir -p build && cd build && cmake .. && make -j && make install && \
+    cd ../../ && rm -rf easy_profiler-2.1.0*
+
+SHELL ["/usr/bin/scl", "enable", "devtoolset-7"]
+
+# 7. Install python3.7.5
+RUN source /etc/profile && gcc -v && tar -xvf Python-3.7.5.tar.xz && \
+    cd Python-3.7.5 && \
+    mkdir -p build && cd build && \
+    ../configure --enable-shared --prefix=/usr/local/python3.7.5 && \
+    make -j && make install && \
+    cd ../../ && rm -rf Python-3.7.5* && \
+    ldconfig
+
+ENV PATH=$PATH:/usr/local/python3.7.5/bin \
+    LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/python3.7.5/lib
+
+# Configure the pip package index
+RUN mkdir ~/.pip && touch ~/.pip/pip.conf && \
+    echo "[global]" > ~/.pip/pip.conf && \
+    echo "trusted-host=pypi.douban.com" >> ~/.pip/pip.conf && \
+    echo "index-url=http://pypi.douban.com/simple/" >> ~/.pip/pip.conf && \
+    echo "timeout=200" >> ~/.pip/pip.conf
+
+# 8. Install hdf5
+RUN source /etc/profile && gcc -v && tar -zxvf hdf5-1.10.5.tar.gz && \
+    cd hdf5-1.10.5 && \
+    ./configure --prefix=/usr/local/hdf5 && \
+    make && make install && cd .. && rm -rf hdf5-1.10.5*
+
+ENV CPATH=/usr/local/hdf5/include/:/usr/local/hdf5/lib/
+
+RUN ln -s /usr/local/hdf5/lib/libhdf5.so /usr/lib/libhdf5.so && \
+    ln -s /usr/local/hdf5/lib/libhdf5_hl.so /usr/lib/libhdf5_hl.so
+
+ENV CC=/usr/lib64/openmpi/bin/mpicc
+
+# 9. Install the Python packages
+RUN pip3.7 install -U pip && \
+    pip3.7 install numpy && \
+    pip3.7 install decorator && \
+    pip3.7 install sympy==1.4 && \
+    pip3.7 install cffi==1.12.3 && \
+    pip3.7 install pyyaml && \
+    pip3.7 install pathlib2 && \
+    pip3.7 install grpcio && \
+    pip3.7 install grpcio-tools && \
+    pip3.7 install protobuf==3.20.0 && \
+    pip3.7 install scipy && \
+    pip3.7 install requests && \
+    pip3.7 install mpi4py && \
+    pip3.7 install scikit-learn && \
+    pip3.7 install easydict && \
+    pip3.7 install attrs && \
+    pip3.7 install pytest==7.1.1 && \
+    pip3.7 install pytest-cov==4.1.0 && \
+    pip3.7 install pytest-html && \
+    pip3.7 install Cython && \
+    pip3.7 install h5py==3.1.0 && \
+    rm -rf /root/.cache/pip
+
+# 10. Set the driver path environment variables
+ARG ASCEND_BASE=/usr/local/Ascend
+ENV LD_LIBRARY_PATH=$ASCEND_BASE/driver/lib64:$ASCEND_BASE/driver/lib64/common:$ASCEND_BASE/driver/lib64/driver:$LD_LIBRARY_PATH
+
+# 11. CANN-related build arguments
+ARG TOOLKIT_PKG=Ascend-cann-toolkit*.run
+ARG TOOLKIT_PATH=$ASCEND_BASE/ascend-toolkit/latest
+
+# 12. TF-related build arguments
+ARG TFPLUGIN_PKG=Ascend-cann-tfplugin*.run
+# MODIFIED TF=1.15.0 or TF=2.6.5; on arm replace this with the matching whl package
+ARG TF_PKG=tensorflow-cpu==
+
+# 13. Install ascend-toolkit and tfplugin, plus the remaining Python dependencies
+RUN umask 0022 && \
+    mkdir -p $ASCEND_BASE/driver && \
+    cp version.info $ASCEND_BASE/driver/ && \
+    cp ascend_install.info /etc/ && \
+    chmod +x $TOOLKIT_PKG && \
+    bash $TOOLKIT_PKG --quiet --install --install-path=$ASCEND_BASE && \
+    source $ASCEND_BASE/ascend-toolkit/set_env.sh && \
+    chmod +x ./$TFPLUGIN_PKG && \
+    bash $TFPLUGIN_PKG --quiet --install --install-for-all && \
+    source $ASCEND_BASE/tfplugin/set_env.sh && \
+    rm -f ./$TFPLUGIN_PKG && \
+    pip3.7 install $TF_PKG && \
+    HOROVOD_WITH_MPI=1 HOROVOD_WITH_TENSORFLOW=1 pip3.7 install horovod --no-cache-dir && \
+    pip3.7 install tf_slim && \
+    pip3.7 install funcsigs && \
+    rm -rf /root/.cache/pip && \
+    rm -f $TOOLKIT_PKG && \
+    rm -rf $ASCEND_BASE/driver && \
+    rm -rf /etc/ascend_install.info
+
+# 14. Install mxRec; choose the tf1 or tf2 wheel
+RUN tar -zxvf Ascend-mindxsdk-mxrec*.tar.gz && \
+    pip3 install mindxsdk-mxrec/{tf1|tf2}_whl/mx_rec-*.whl --force-reinstall
+
+# 15. Clean up the temporary directory
+RUN rm -rf ./*
\ No newline at end of file
diff --git a/docs/build_mxRec_images/mxrec-build/Dockerfile b/docs/build_mxRec_images/mxrec-build/Dockerfile
new file mode 100644
index 00000000..424da3cd
--- /dev/null
+++ b/docs/build_mxRec_images/mxrec-build/Dockerfile
@@ -0,0 +1,35 @@
+# please configure: choose the base image according to the actual situation
+FROM mxrec-tf1:6.0.RC1
+
+WORKDIR /tmp
+COPY . ./
+
+RUN chmod 777 /tmp
+
+# please configure: install only the dependencies you actually need; remove or comment out the code for any dependency you do not need
+
+# Set the driver path environment variable
+ARG ASCEND_BASE=/usr/local/Ascend
+
+# CANN-related build arguments
+ARG TOOLKIT_PKG=Ascend-cann-toolkit*.run
+ARG TFPLUGIN_PKG=Ascend-cann-tfplugin*.run
+
+# Remove the old CANN
+RUN rm -rf $ASCEND_BASE/ascend-toolkit
+
+# Install ascend-toolkit and tfplugin
+RUN umask 0022 && \
+    chmod +x $TOOLKIT_PKG && \
+    bash $TOOLKIT_PKG --quiet --install --install-path=$ASCEND_BASE && \
+    source $ASCEND_BASE/ascend-toolkit/set_env.sh && \
+    chmod +x ./$TFPLUGIN_PKG && \
+    bash $TFPLUGIN_PKG --quiet --install --install-for-all && \
+    source $ASCEND_BASE/tfplugin/set_env.sh && \
+    rm -f ./$TFPLUGIN_PKG && \
+    rm -rf /root/.cache/pip && \
+    rm -f $TOOLKIT_PKG
+
+# Install mxRec; choose the tf1 or tf2 wheel
+RUN tar -zxvf Ascend-mindxsdk-mxrec*.tar.gz && \
+    pip3 install mindxsdk-mxrec/{tf1|tf2}_whl/mx_rec-*.whl --force-reinstall
-- 
Gitee


From a96293b59f81a01a5b4cfd9fcf7f409daf988c21 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 30 Jan 2024 15:33:02 +0800
Subject: [PATCH 533/551] Match-id-6bd1917f91ac8491dfd89687ffca21a5b2e6dde4

---
 .gitmodules                                   |  12 -
 build/build.sh                                |  26 +-
 build/build_all.sh                            | 226 ---
 build/build_tf1_with_opensource.sh            |  61 +-
 build/build_tf2.sh                            | 180 --
 ...ld_tf1.sh => build_tf2_with_opensource.sh} | 200 +--
 src/AccCTR/3rdparty/CMakeLists.txt            |  47 +
 src/AccCTR/CMakeLists.txt                     | 120 ++
 src/AccCTR/README.md                          |   9 +
 src/AccCTR/build.sh                           |  56 +
 src/AccCTR/build/build_3rdparty.sh            |  82 +
 src/AccCTR/build/build_env.sh                 |  27 +
 src/AccCTR/build/build_pkg.sh                 |  41 +
 src/AccCTR/build/build_src.sh                 |  45 +
 src/AccCTR/build/build_test.sh                |  90 +
 src/AccCTR/config.conf                        |   0
 src/AccCTR/dependency.xml                     |   0
 src/AccCTR/src/CMakeLists.txt                 |  64 +
 src/AccCTR/src/common/CMakeLists.txt          |  15 +
 src/AccCTR/src/common/util/CMakeLists.txt     |  17 +
 src/AccCTR/src/common/util/common_execption.h |  26 +
 src/AccCTR/src/common/util/common_includes.h  |  33 +
 src/AccCTR/src/common/util/defines.h          |  32 +
 src/AccCTR/src/common/util/error_code.h       |  37 +
 .../src/common/util/external_logger.cpp       |  48 +
 src/AccCTR/src/common/util/external_logger.h  |  69 +
 .../src/common/util/external_thread.cpp       |  33 +
 .../src/common/util/external_threader.h       |  73 +
 src/AccCTR/src/common/util/lock.h             | 214 +++
 src/AccCTR/src/common/util/singleton.h        |  47 +
 src/AccCTR/src/common/util/spinlock.h         | 123 ++
 src/AccCTR/src/common/util/time_cost.h        |  49 +
 src/AccCTR/src/factory_impl.cpp               |  69 +
 src/AccCTR/src/factory_impl.h                 |  39 +
 src/AccCTR/src/include/CMakeLists.txt         |  19 +
 src/AccCTR/src/include/factory.h              |  64 +
 src/AccCTR/src/include/ock_ctr_common_def.h   |  62 +
 src/AccCTR/src/include/unique.h               | 128 ++
 src/AccCTR/src/unique/CMakeLists.txt          |  27 +
 src/AccCTR/src/unique/unique_func.cpp         | 193 +++
 src/AccCTR/src/unique/unique_func.h           | 575 +++++++
 src/AccCTR/src/unique/unique_impl.cpp         | 313 ++++
 src/AccCTR/src/unique/unique_impl.h           |  51 +
 src/AccCTR/tests/CMakeLists.txt               |  21 +
 src/AccCTR/tests/tools/create_fake_id.py      | 106 ++
 src/AccCTR/tests/ut/CMakeLists.txt            |  15 +
 src/AccCTR/tests/ut/src/CMakeLists.txt        |  45 +
 src/AccCTR/tests/ut/src/gtest_main.cpp        |  21 +
 src/AccCTR/tests/ut/src/unique_test.cpp       | 1520 +++++++++++++++++
 src/AccCTR/tests/ut/src/unique_test.h         |  60 +
 src/CMakeLists.txt                            |   7 +-
 src/build.sh                                  |   4 +-
 src/platform/AccCTR                           |   1 -
 src/test_ut.sh                                |  91 +-
 tests/run_python_dt.sh                        |   3 +-
 55 files changed, 4949 insertions(+), 587
deletions(-) delete mode 100644 .gitmodules delete mode 100644 build/build_all.sh delete mode 100644 build/build_tf2.sh rename build/{build_tf1.sh => build_tf2_with_opensource.sh} (32%) create mode 100644 src/AccCTR/3rdparty/CMakeLists.txt create mode 100644 src/AccCTR/CMakeLists.txt create mode 100644 src/AccCTR/README.md create mode 100644 src/AccCTR/build.sh create mode 100644 src/AccCTR/build/build_3rdparty.sh create mode 100644 src/AccCTR/build/build_env.sh create mode 100644 src/AccCTR/build/build_pkg.sh create mode 100644 src/AccCTR/build/build_src.sh create mode 100644 src/AccCTR/build/build_test.sh create mode 100644 src/AccCTR/config.conf create mode 100644 src/AccCTR/dependency.xml create mode 100644 src/AccCTR/src/CMakeLists.txt create mode 100644 src/AccCTR/src/common/CMakeLists.txt create mode 100644 src/AccCTR/src/common/util/CMakeLists.txt create mode 100644 src/AccCTR/src/common/util/common_execption.h create mode 100644 src/AccCTR/src/common/util/common_includes.h create mode 100644 src/AccCTR/src/common/util/defines.h create mode 100644 src/AccCTR/src/common/util/error_code.h create mode 100644 src/AccCTR/src/common/util/external_logger.cpp create mode 100644 src/AccCTR/src/common/util/external_logger.h create mode 100644 src/AccCTR/src/common/util/external_thread.cpp create mode 100644 src/AccCTR/src/common/util/external_threader.h create mode 100644 src/AccCTR/src/common/util/lock.h create mode 100644 src/AccCTR/src/common/util/singleton.h create mode 100644 src/AccCTR/src/common/util/spinlock.h create mode 100644 src/AccCTR/src/common/util/time_cost.h create mode 100644 src/AccCTR/src/factory_impl.cpp create mode 100644 src/AccCTR/src/factory_impl.h create mode 100644 src/AccCTR/src/include/CMakeLists.txt create mode 100644 src/AccCTR/src/include/factory.h create mode 100644 src/AccCTR/src/include/ock_ctr_common_def.h create mode 100644 src/AccCTR/src/include/unique.h create mode 100644 src/AccCTR/src/unique/CMakeLists.txt create mode 100644 src/AccCTR/src/unique/unique_func.cpp create mode 100644 src/AccCTR/src/unique/unique_func.h create mode 100644 src/AccCTR/src/unique/unique_impl.cpp create mode 100644 src/AccCTR/src/unique/unique_impl.h create mode 100644 src/AccCTR/tests/CMakeLists.txt create mode 100644 src/AccCTR/tests/tools/create_fake_id.py create mode 100644 src/AccCTR/tests/ut/CMakeLists.txt create mode 100644 src/AccCTR/tests/ut/src/CMakeLists.txt create mode 100644 src/AccCTR/tests/ut/src/gtest_main.cpp create mode 100644 src/AccCTR/tests/ut/src/unique_test.cpp create mode 100644 src/AccCTR/tests/ut/src/unique_test.h delete mode 160000 src/platform/AccCTR diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 57a5bc65..00000000 --- a/.gitmodules +++ /dev/null @@ -1,12 +0,0 @@ -[submodule "src/thirdparty/googletest"] - path = src/thirdparty/googletest - url = https://codehub-dg-y.huawei.com/OpenSourceCenter/googletest.git -[submodule "src/thirdparty/spdlog"] - path = src/thirdparty/spdlog - url = https://codehub-dg-y.huawei.com/OpenSourceCenter/spdlog.git -[submodule "src/thirdparty/pybind11"] - path = src/thirdparty/pybind11 - url = https://codehub-dg-y.huawei.com/OpenSourceCenter/pybind11.git -[submodule "src/platform/AccCTR"] - path = src/platform/AccCTR - url = https://szv-y.codehub.huawei.com/ComputingFoundationSoftware/ock-ascend-domain/AccCTR.git diff --git a/build/build.sh b/build/build.sh index 8bf383de..ddbb777f 100644 --- a/build/build.sh +++ b/build/build.sh @@ -14,8 +14,6 @@ # limitations under the License. 
# ============================================================================== -export GLOG_CUSTOM_PREFIX_SUPPORT=1 - set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } ARCH="$(uname -m)" @@ -32,7 +30,7 @@ get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.rc3" + VERSION="5.0.0" fi } @@ -65,8 +63,6 @@ release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz gen_tar_file() { cd "${src_path}" - mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" - mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" # change dirs and files 's permission chmod 550 ../build/"${pkg_dir}"/tf1_whl @@ -106,8 +102,14 @@ clean() if [ "$(uname -m)" = "x86_64" ] then echo "-----Build gen tar -----" - bash ${ROOT_DIR}/build/build_tf1.sh - bash ${ROOT_DIR}/build/build_tf2.sh + source /opt/buildtools/tf1_env/bin/activate + pip3 install setuptools==65.6.3 + bash ${ROOT_DIR}/build/build_tf1_with_opensource.sh + deactivate tf1_env + source /opt/buildtools/tf2_env/bin/activate + pip3 install setuptools==65.6.3 + bash ${ROOT_DIR}/build/build_tf2_with_opensource.sh + deactivate tf2_env gen_tar_file echo "-----Build gen tar finished-----" @@ -118,8 +120,14 @@ fi if [ "$(uname -m)" = "aarch64" ] then echo "-----Build gen tar -----" - bash ${ROOT_DIR}/build/build_tf1.sh - bash ${ROOT_DIR}/build/build_tf2.sh + source /opt/buildtools/tf1_env/bin/activate + pip3 install setuptools==65.6.3 + bash ${ROOT_DIR}/build/build_tf1_with_opensource.sh + deactivate tf1_env + source /opt/buildtools/tf2_env/bin/activate + pip3 install setuptools==65.6.3 + bash ${ROOT_DIR}/build/build_tf2_with_opensource.sh + deactivate tf2_env gen_tar_file echo "-----Build gen tar finished-----" diff --git a/build/build_all.sh b/build/build_all.sh deleted file mode 100644 index 8f8d275c..00000000 --- a/build/build_all.sh +++ /dev/null @@ -1,226 +0,0 @@ -#!/bin/bash -# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -set -e -warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } -ARCH="$(uname -m)" -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -ROOT_DIR=$(dirname "${SCRIPT_DIR}") -cd "$SCRIPT_DIR" -if [ "$(uname -m)" = "aarch64" ] -then - source tf2_env/bin/activate - tf265="tensorflow-2.6.5-cp37-cp37m-manylinux2014_aarch64.whl" - [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ - tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env - - source tf1_env/bin/activate - tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2014_aarch64.whl" - [ ! 
-f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env -fi - -if [ "$(uname -m)" = "x86_64" ] -then - source tf2_env/bin/activate - tf265="tensorflow_cpu-2.6.5-cp37-cp37m-manylinux2010_x86_64.whl" - [ ! -f "${tf265}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf265}" -ap ./ - tf2_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env - - source tf1_env/bin/activate - tf115="tensorflow-1.15.0-cp37-cp37m-manylinux2010_x86_64.whl" - [ ! -f "${tf115}" ] && artget pull "mindx_img_tools 1.0.0" -ru software -rp "${tf115}" -ap ./ - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env -fi - -VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml -get_version() { - if [ -f "$VERSION_FILE" ]; then - VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") - if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then - VERSION=${VERSION%.*} - fi - else - VERSION="5.0.rc3" - fi -} - -remove() -{ - if [ -d "$1" ]; then - rm -rf "$1" - elif [ -f "$1" ]; then - rm -f "$1" - fi -} - -project_output_path="${ROOT_DIR}"/output/ -remove "${project_output_path}" -remove "${SCRIPT_DIR}/lib" -get_version -export VERSION -echo "MindX SDK mxrec: ${VERSION}" >> ./version.info - -pkg_dir=mindxsdk-mxrec -remove "${pkg_dir}" -mkdir "${pkg_dir}" -mv version.info "${pkg_dir}" - -opensource_path="${ROOT_DIR}"/../opensource/opensource -abseil_src_path=${opensource_path}/abseil -echo "${abseil_src_path}" -abseil_install_path="${ROOT_DIR}"/install/abseil - -src_path="${ROOT_DIR}"/src -acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR -cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c -cd "${ROOT_DIR}" - -release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz - -install_abseil() -{ - remove "${abseil_install_path}" - echo "${abseil_install_path}" - if [[ ! -d "${abseil_install_path}" ]] - then mkdir -p "${abseil_install_path}" - fi - - cd "${abseil_src_path}" - echo "${abseil_src_path}" - remove CMakeCache.txt - cmake -DCMAKE_INSTALL_PREFIX="${abseil_install_path}" . && make -j8 && make install - - echo "${project_output_path}"/abseil - mkdir -p "${project_output_path}"/abseil - if [ -d "${abseil_install_path}"/lib64/ ]; then - cp -rf "${abseil_install_path}"/lib64/libabsl* "${project_output_path}"/abseil - elif [ -d "${abseil_install_path}"/lib/ ]; then - cp -rf "${abseil_install_path}"/lib/libabsl* "${project_output_path}"/abseil - else - echo "${abseil_install_path}"/lib64/ not exist - exit 1 - fi -} - -compile_securec() -{ - if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then - echo "securec is not exist" - exit 1 - fi - - if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${ROOT_DIR}"/platform/securec/src - make -j - fi -} - -compile_so_file() -{ - cd "${src_path}" - chmod u+x build.sh - ./build.sh "$1" "${ROOT_DIR}" - cd .. 
-} - -compile_acc_ctr_so_file() -{ - cd "${acc_ctr_path}" - chmod u+x build.sh - ./build.sh "release" -} - -collect_so_file() -{ - cd "${src_path}" - remove "${src_path}"/libasc - mkdir -p "${src_path}"/libasc - chmod u+x libasc - - cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc - cp -df "${ROOT_DIR}"/output/*.so* libasc - cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc -} - -gen_wheel_file() -{ - cd "${ROOT_DIR}" - touch "${src_path}"/libasc/__init__.py - remove "${ROOT_DIR}"/mx_rec/libasc - mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - python3 setup.py bdist_wheel --plat-name=linux_$(arch) - mkdir -p "$1" - mv dist/mx_rec*.whl "$1" - remove "${ROOT_DIR}"/mx_rec/libasc -} - -gen_tar_file() -{ - cd "${src_path}" - mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" - mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" - cd ../build - tar -zvcf "${release_tar}" "${pkg_dir}" || { - warn "compression failed, packages might be broken" - } - - mv "${release_tar}" "${SCRIPT_DIR}"/../output/ - -} - -clean() -{ - remove "${ROOT_DIR}"/dist - remove "${ROOT_DIR}"/install - remove "${ROOT_DIR}"/mx_rec.egg-info - remove "${ROOT_DIR}"/src/build - remove "${ROOT_DIR}"/build/bdist.linux-"$(arch)" - remove "${ROOT_DIR}"/build/tf1_env - remove "${ROOT_DIR}"/build/tf2_env - remove "${ROOT_DIR}"/build/lib - remove "${ROOT_DIR}"/build/mindxsdk-mxrec -} - -install_abseil -compile_securec - -echo "-----Build AccCTR -----" -compile_acc_ctr_so_file - -echo "-----Build Start tf1 -----" -source "${SCRIPT_DIR}"/tf1_env/bin/activate -compile_so_file "${tf1_path}" -collect_so_file -gen_wheel_file "${ROOT_DIR}"/tf1_whl -deactivate tf1_env - -echo "-----Build Start tf2 -----" -source "${SCRIPT_DIR}"/tf2_env/bin/activate -compile_so_file "${tf2_path}" -collect_so_file -gen_wheel_file "${ROOT_DIR}"/tf2_whl -deactivate tf2_env - -echo "-----Build gen tar -----" -gen_tar_file - -clean -echo "-----Done-----" diff --git a/build/build_tf1_with_opensource.sh b/build/build_tf1_with_opensource.sh index 487b6d85..37cfcf64 100644 --- a/build/build_tf1_with_opensource.sh +++ b/build/build_tf1_with_opensource.sh @@ -30,37 +30,34 @@ ARCH="$(uname -m)" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") MxRec_DIR=$(dirname "${SCRIPT_DIR}") -opensource_path="${MxRec_DIR}"/opensource - -function prepare_pybind_and_securec() { +opensource_path="${MxRec_DIR}"/../opensource +if [ ! -d ${opensource_path} ]; then + echo "user should download dependency packages to mxRec/../opensource directory, see README.md" + exit -1 +fi + +function prepare_pybind(){ + cd "${opensource_path}" if [ ! -d pybind11 ]; then - if [ ! -d pybind11-v2.10.3 ]; then - unzip pybind11-v2.10.3.zip - fi - mv pybind11-v2.10.3 pybind11 - fi - - if [ ! -d glog ]; then - if [ ! -d glog-0.6.0 ]; then - tar -zxvf glog-0.6.0.tar.gz - fi - mv glog-0.6.0 glog + unzip pybind11-2.10.3.zip + mv pybind11-2.10.3 pybind11 fi +} +function prepare_securec(){ + cd "${opensource_path}" if [ ! -d securec ]; then - unzip securec.zip + unzip huaweicloud-sdk-c-obs-3.23.9.zip + mv huaweicloud-sdk-c-obs-3.23.9/platform/huaweisecurec securec + rm -rf huaweicloud-sdk-c-obs-3.23.9 rm -rf securec/lib/* - if [ ! 
-d ../platform ]; then - mkdir -p ../platform - cp -rf securec ../platform - fi fi } # 准备pybind11和securec -cd "${opensource_path}" -prepare_pybind_and_securec -cd - +echo "opensource path:${opensource_path}" +prepare_pybind +prepare_securec # 配置tf1路径 tf1_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow_core @@ -75,7 +72,7 @@ function get_version() { VERSION=${VERSION%.*} fi else - VERSION="5.0.rc3" + VERSION="5.0.0" fi } @@ -94,20 +91,19 @@ mv version.info "${pkg_dir}" # 配置MxRec C++代码路径和AccCTR路径 src_path="${MxRec_DIR}"/src -acc_ctr_path="${MxRec_DIR}"/src/platform/AccCTR -cp -rf "${MxRec_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c +acc_ctr_path="${MxRec_DIR}"/src/AccCTR cd "${MxRec_DIR}" function compile_securec() { - if [[ ! -d "${MxRec_DIR}"/platform/securec ]]; then - echo "securec is not exist" - exit 1 + if [[ ! -d "${opensource_path}"/securec ]]; then + echo "securec is not exist" + exit 1 fi - if [[ ! -f "${MxRec_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${MxRec_DIR}"/platform/securec/src - make -j + if [[ ! -f "${opensource_path}"/securec/lib/libsecurec.so ]]; then + cd "${opensource_path}"/securec/src + make -j4 fi } @@ -135,7 +131,7 @@ function collect_so_file() cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc cp -df "${MxRec_DIR}"/output/*.so* libasc - cp "${MxRec_DIR}"/platform/securec/lib/libsecurec.so libasc + cp "${opensource_path}"/securec/lib/libsecurec.so libasc } function gen_wheel_file() @@ -146,6 +142,7 @@ function gen_wheel_file() mv "${src_path}"/libasc "${MxRec_DIR}"/mx_rec python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" + echo "moving whl file $1" mv dist/mx_rec*.whl "$1" rm -rf "${MxRec_DIR}"/mx_rec/libasc } diff --git a/build/build_tf2.sh b/build/build_tf2.sh deleted file mode 100644 index 42321e89..00000000 --- a/build/build_tf2.sh +++ /dev/null @@ -1,180 +0,0 @@ -#!/bin/bash -# Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -set -e -warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } -ARCH="$(uname -m)" -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -ROOT_DIR=$(dirname "${SCRIPT_DIR}") -cd "$SCRIPT_DIR" - -if [ "$(uname -m)" = "x86_64" ] -then - source /opt/buildtools/tf2_env/bin/activate - tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env -fi - -if [ "$(uname -m)" = "aarch64" ] -then - source /opt/buildtools/tf2_env/bin/activate - tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow - deactivate tf2_env -fi - -VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml -get_version() { - if [ -f "$VERSION_FILE" ]; then - VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") - if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then - VERSION=${VERSION%.*} - fi - else - VERSION="5.0.rc3" - fi -} - -remove() -{ - if [ -d "$1" ]; then - rm -rf "$1" - elif [ -f "$1" ]; then - rm -f "$1" - fi -} - -project_output_path="${ROOT_DIR}"/output/ -remove "${project_output_path}" -remove "${SCRIPT_DIR}/lib" -get_version -export VERSION -echo "MindX SDK mxrec: ${VERSION}" >> ./version.info -chmod 640 ./version.info - -pkg_dir=mindxsdk-mxrec -remove "${pkg_dir}" -mkdir "${pkg_dir}" -chmod 750 "$pkg_dir" -mv version.info "${pkg_dir}" - -opensource_path="${ROOT_DIR}"/../opensource/opensource - -src_path="${ROOT_DIR}"/src -acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR -cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c -cd "${ROOT_DIR}" - -release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz - -compile_securec() -{ - if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then - echo "securec is not exist" - exit 1 - fi - - if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${ROOT_DIR}"/platform/securec/src - make -j - fi -} - -compile_so_file() -{ - cd "${src_path}" - chmod u+x build.sh - ./build.sh "$1" "${ROOT_DIR}" "NO" - cd .. 
-} - -compile_acc_ctr_so_file() -{ - cd "${acc_ctr_path}" - chmod u+x build.sh - ./build.sh "release" -} - -collect_so_file() -{ - cd "${src_path}" - remove "${src_path}"/libasc - mkdir -p "${src_path}"/libasc - chmod u+x libasc - - cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc - cp -df "${ROOT_DIR}"/output/*.so* libasc - cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc -} - -gen_wheel_file() -{ - cd "${ROOT_DIR}" - touch "${src_path}"/libasc/__init__.py - remove "${ROOT_DIR}"/mx_rec/libasc - mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec - python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) - mkdir -p "$1" - mv dist/mx_rec*.whl "$1" - # remove "${ROOT_DIR}"/mx_rec/libasc -} - -gen_tar_file() -{ - cd "${src_path}" - mv "${ROOT_DIR}"/tf2_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" - cd ../build - tar -zvcf "${release_tar}" "${pkg_dir}" || { - warn "compression failed, packages might be broken" - } - - mv "${release_tar}" "${SCRIPT_DIR}"/../output/ - -} - -if [ "$(uname -m)" = "x86_64" ] -then - compile_securec - - echo "-----Build AccCTR -----" - compile_acc_ctr_so_file - - echo "-----Build Start tf2 -----" - source /opt/buildtools/tf2_env/bin/activate - compile_so_file "${tf2_path}" - collect_so_file - gen_wheel_file "${ROOT_DIR}"/tf2_whl - - deactivate tf2_env - echo "-----Build tf2 finished -----" -fi - -if [ "$(uname -m)" = "aarch64" ] -then - compile_securec - - echo "-----Build AccCTR -----" - compile_acc_ctr_so_file - - echo "-----Build Start tf2 -----" - source /opt/buildtools/tf2_env/bin/activate - compile_so_file "${tf2_path}" - collect_so_file - gen_wheel_file "${ROOT_DIR}"/tf2_whl - - deactivate tf2_env - echo "-----Build tf2 finished -----" -fi \ No newline at end of file diff --git a/build/build_tf1.sh b/build/build_tf2_with_opensource.sh similarity index 32% rename from build/build_tf1.sh rename to build/build_tf2_with_opensource.sh index 11b57c89..bf4a5b03 100644 --- a/build/build_tf1.sh +++ b/build/build_tf2_with_opensource.sh @@ -14,167 +14,155 @@ # limitations under the License. # ============================================================================== +################################################################## +# build_tf2_with_opensource.sh 编译MxRec和动态扩容算子 +# 编译环境:Python3.7.5 GCC 7.3.0 CMake 3.20.6 +# 代码主要分为四部分: +# 1、准备编译MxRec所需依赖:pybind11(v2.10.3) securec +# 2、编译securec、AccCTR以及MxRec +# 3、生成MxRec Wheel包,生成的whl包在当前目录下的mindxsdk-mxrec/tf2_whl +# 4、编译动态扩容算子 +################################################################## + set -e warn() { echo >&2 -e "\033[1;31m[WARN ][Depend ] $1\033[1;37m" ; } ARCH="$(uname -m)" SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -ROOT_DIR=$(dirname "${SCRIPT_DIR}") -cd "$SCRIPT_DIR" - -if [ "$(uname -m)" = "x86_64" ] -then - source /opt/buildtools/tf1_env/bin/activate - pip3 install setuptools==65.6.3 - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env -fi +MxRec_DIR=$(dirname "${SCRIPT_DIR}") -if [ "$(uname -m)" = "aarch64" ] -then - source /opt/buildtools/tf1_env/bin/activate - tf1_path=$(dirname "$(dirname "$(which python3)")")/lib/python3.7/site-packages/tensorflow_core - deactivate tf1_env +opensource_path="${MxRec_DIR}"/../opensource +if [ ! 
-d ${opensource_path} ]; then + echo "user should download dependency packages to mxRec/../opensource directory, see README.md" + exit -1 fi -VERSION_FILE="${ROOT_DIR}"/../mindxsdk/build/conf/config.yaml -get_version() { +function prepare_pybind(){ + cd "${opensource_path}" + if [ ! -d pybind11 ]; then + unzip pybind11-2.10.3.zip + mv pybind11-2.10.3 pybind11 + fi +} + +function prepare_securec(){ + cd "${opensource_path}" + if [ ! -d securec ]; then + unzip huaweicloud-sdk-c-obs-3.23.9.zip + mv huaweicloud-sdk-c-obs-3.23.9/platform/huaweisecurec securec + rm -rf huaweicloud-sdk-c-obs-3.23.9 + rm -rf securec/lib/* + fi +} + +# 准备pybind11和securec +echo "opensource path:${opensource_path}" +prepare_pybind +prepare_securec + +# 配置tf2路径 +tf2_path=$(dirname "$(dirname "$(which python3.7)")")/lib/python3.7/site-packages/tensorflow + +project_output_path="${MxRec_DIR}"/output/ +VERSION_FILE="${MxRec_DIR}"/../mindxsdk/build/conf/config.yaml + +function get_version() { if [ -f "$VERSION_FILE" ]; then VERSION=$(sed '/.*mindxsdk:/!d;s/.*: //' "$VERSION_FILE") if [[ "$VERSION" == *.[b/B]* ]] && [[ "$VERSION" != *.[RC/rc]* ]]; then VERSION=${VERSION%.*} fi else - VERSION="5.0.rc3" + VERSION="5.0.0" fi } -remove() -{ - if [ -d "$1" ]; then - rm -rf "$1" - elif [ -f "$1" ]; then - rm -f "$1" - fi -} +rm -rf "${project_output_path}" +rm -rf "${SCRIPT_DIR}/lib" -project_output_path="${ROOT_DIR}"/output/ -remove "${project_output_path}" -remove "${SCRIPT_DIR}/lib" +# 获取MxRec版本信息 get_version export VERSION -echo "MindX SDK mxrec: ${VERSION}" >> ./version.info -chmod 640 ./version.info +echo "MindX SDK MxRec: ${VERSION}" >> ./version.info pkg_dir=mindxsdk-mxrec -remove "${pkg_dir}" +rm -rf "${pkg_dir}" mkdir "${pkg_dir}" -chmod 750 "$pkg_dir" mv version.info "${pkg_dir}" -opensource_path="${ROOT_DIR}"/../opensource/opensource +# 配置MxRec C++代码路径和AccCTR路径 +src_path="${MxRec_DIR}"/src +acc_ctr_path="${MxRec_DIR}"/src/AccCTR +cd "${MxRec_DIR}" -src_path="${ROOT_DIR}"/src -acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR -cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c -cd "${ROOT_DIR}" - -release_tar=Ascend-"${pkg_dir}"_"${VERSION}"_linux-"${ARCH}".tar.gz - -compile_securec() +function compile_securec() { - if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then - echo "securec is not exist" - exit 1 + if [[ ! -d "${opensource_path}"/securec ]]; then + echo "securec is not exist" + exit 1 fi - if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${ROOT_DIR}"/platform/securec/src - make -j + if [[ ! -f "${opensource_path}"/securec/lib/libsecurec.so ]]; then + cd "${opensource_path}"/securec/src + make -j4 fi } -compile_so_file() +function compile_so_file() { cd "${src_path}" chmod u+x build.sh - ./build.sh "$1" "${ROOT_DIR}" "NO" + ./build.sh "$1" "${MxRec_DIR}" "YES" cd .. 
} -compile_acc_ctr_so_file() +function compile_acc_ctr_so_file() { cd "${acc_ctr_path}" chmod u+x build.sh ./build.sh "release" } -collect_so_file() +function collect_so_file() { cd "${src_path}" - remove "${src_path}"/libasc + rm -rf "${src_path}"/libasc mkdir -p "${src_path}"/libasc chmod u+x libasc cp ${acc_ctr_path}/output/ock_ctr_common/lib/* libasc - cp -df "${ROOT_DIR}"/output/*.so* libasc - cp "${ROOT_DIR}"/platform/securec/lib/libsecurec.so libasc + cp -df "${MxRec_DIR}"/output/*.so* libasc + cp "${opensource_path}"/securec/lib/libsecurec.so libasc } -gen_wheel_file() +function gen_wheel_file() { - cd "${ROOT_DIR}" + cd "${MxRec_DIR}" touch "${src_path}"/libasc/__init__.py - remove "${ROOT_DIR}"/mx_rec/libasc - mv "${src_path}"/libasc "${ROOT_DIR}"/mx_rec + rm -rf "${MxRec_DIR}"/mx_rec/libasc + mv "${src_path}"/libasc "${MxRec_DIR}"/mx_rec python3.7 setup.py bdist_wheel --plat-name=linux_$(arch) mkdir -p "$1" + echo "moving whl file $1" mv dist/mx_rec*.whl "$1" - remove "${ROOT_DIR}"/mx_rec/libasc + rm -rf "${MxRec_DIR}"/mx_rec/libasc } -gen_tar_file() -{ - cd "${src_path}" - mv "${ROOT_DIR}"/tf1_whl ../build/"${pkg_dir}" - cp -r "${src_path}"/../cust_op ../build/"${pkg_dir}" - cd ../build - tar -zvcf "${release_tar}" "${pkg_dir}" || { - warn "compression failed, packages might be broken" - } - - mv "${release_tar}" "${SCRIPT_DIR}"/../output/ - -} - - -if [ "$(uname -m)" = "x86_64" ] -then - compile_securec - - echo "-----Build AccCTR -----" - compile_acc_ctr_so_file - - echo "-----Build Start tf1 -----" - source /opt/buildtools/tf1_env/bin/activate - compile_so_file "${tf1_path}" - collect_so_file - gen_wheel_file "${ROOT_DIR}"/tf1_whl - deactivate tf1_env - echo "-----Build tf1 finished-----" -fi - -if [ "$(uname -m)" = "aarch64" ] -then - compile_securec - - echo "-----Build AccCTR -----" - compile_acc_ctr_so_file - - echo "-----Build Start tf1 -----" - source /opt/buildtools/tf1_env/bin/activate - compile_so_file "${tf1_path}" - collect_so_file - gen_wheel_file "${ROOT_DIR}"/tf1_whl - deactivate tf1_env - echo "-----Build tf1 finished-----" -fi \ No newline at end of file +# start to build MxRec +echo "---------------- compile securec ----------------" +compile_securec +echo "---------------- compile AccCTR ----------------" +compile_acc_ctr_so_file +echo "---------------- compile MxRec so files ----------------" +compile_so_file "${tf2_path}" +echo "---------------- collect so files and mv them to libasc ----------------" +collect_so_file +echo "---------------- generate MxRec wheel package ----------------" +gen_wheel_file "$SCRIPT_DIR"/"${pkg_dir}"/tf2_whl +echo "---------------- compile MxRec success!!!! ----------------" + +# start to compile cust op +echo "---------------- start to compile cust op ----------------" +cd "${MxRec_DIR}"/cust_op/cust_op_by_addr +chmod u+x run.sh +./run.sh +echo "---------------- compile cust op success!!!! 
----------------" \ No newline at end of file diff --git a/src/AccCTR/3rdparty/CMakeLists.txt b/src/AccCTR/3rdparty/CMakeLists.txt new file mode 100644 index 00000000..a17e472c --- /dev/null +++ b/src/AccCTR/3rdparty/CMakeLists.txt @@ -0,0 +1,47 @@ +message("build mode " ${BUILD_MODE}) + +set(PLATFORM_UTILITIES_3RDPARTY_SOURCE_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) +set(PLATFORM_UTILITIES_3RDPARTY_BUILD_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) +set(PLATFORM_UTILITIES_3RDPARTY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) +set(GTEST_SOURCE_DIR ${PLATFORM_UTILITIES_3RDPARTY_SOURCE_DIR}/googletest-release-1.8.1) +set(GTEST_BUILD_DIR ${PLATFORM_UTILITIES_3RDPARTY_BUILD_DIR}/googletest-release-1.8.1) +set(GTEST_INSTALL_DIR ${PLATFORM_UTILITIES_3RDPARTY_INSTALL_DIR}/googletest-release-1.8.1) +set(SECUREC_SOURCE_DIR ${PLATFORM_UTILITIES_3RDPARTY_SOURCE_DIR}/securec) +set(SECUREC_BUILD_DIR ${PLATFORM_UTILITIES_3RDPARTY_BUILD_DIR}/securec) +set(SECUREC_INSTALL_DIR ${PLATFORM_UTILITIES_3RDPARTY_INSTALL_DIR}/securec) + +add_definitions(_DDOFUN) +set(DOFUN "FALSE") + +if (${BUILD_MODE} MATCHES "ut") + set(DOFUN "TRUE") +endif (${BUILD_MODE} MATCHES "ut") + +message("build securec") +# create build dir +exec_program(mkdir ${PLATFORM_UTILITIES_3RDPARTY_BUILD_DIR} ARGS -p ${SECUREC_BUILD_DIR}) +exec_program(mkdir ${PLATFORM_UTILITIES_3RDPARTY_INSTALL_DIR} ARGS -p ${SECUREC_INSTALL_DIR}) + +# execute make && make install +exec_program(make ${SECUREC_SOURCE_DIR}/src ARGS -j) + +# scp -r ${SECUREC_SRC_PATH}/../include ${SECUREC_INSTALL_PATH}/ +# scp -r ${SECUREC_SRC_PATH}/../lib ${SECUREC_INSTALL_PATH}/ +exec_program(scp ARGS -r ${SECUREC_SOURCE_DIR}/include ${SECUREC_INSTALL_DIR}) +exec_program(scp ARGS -r ${SECUREC_SOURCE_DIR}/lib ${SECUREC_INSTALL_DIR}) + + +message(============ ${DOFUN}) +if (${DOFUN} MATCHES TRUE) + message("build gTest") + # create build dir + exec_program(mkdir ${PLATFORM_UTILITIES_3RDPARTY_BUILD_DIR} ARGS -p ${GTEST_BUILD_DIR}) + + # configure + exec_program(cmake ${GTEST_BUILD_DIR} ARGS -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${GTEST_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR=lib64 ${GTEST_SOURCE_DIR}) + + # execute make && make install + exec_program(make ${GTEST_BUILD_DIR} ARGS clean) + exec_program(make ${GTEST_BUILD_DIR} ARGS -j8) + exec_program("make install" ${GTEST_BUILD_DIR}) +endif () \ No newline at end of file diff --git a/src/AccCTR/CMakeLists.txt b/src/AccCTR/CMakeLists.txt new file mode 100644 index 00000000..0cb63176 --- /dev/null +++ b/src/AccCTR/CMakeLists.txt @@ -0,0 +1,120 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +cmake_minimum_required(VERSION 3.14.1) + +project(ctr CXX C) + +if (${BUILD_MODE} MATCHES "release") + message("======BUILD_MODE release") + set(CXX_FLAGS + -O3 + -Wall + -fPIC + -fms-extensions + -Wno-unused-parameter + -Wno-unused-function + -Wunused-variable + -Wunused-value + -Wcast-align + -Wcast-qual + -Winvalid-pch + -Wwrite-strings + -Wsign-compare + -Wfloat-equal + -Wextra + -D_FORTIFY_SOURCE=2 + -std=c++17 + -fstack-protector-all + -fstack-protector-strong + ) +elseif (${BUILD_MODE} MATCHES "debug") + message("======BUILD_MODE debug") + set(CXX_FLAGS + -g + -O0 + -Wall + -fPIC + -fms-extensions + -Wno-unused-parameter + -Wno-unused-function + -Wunused-variable + -Wunused-value + -Winvalid-pch + -Wcast-align + -Wcast-qual + -Wwrite-strings + -Wsign-compare + -Wfloat-equal + -Wextra + -std=c++17 + ) +elseif (${BUILD_MODE} MATCHES "ut") + message("======BUILD_MODE ut") + set(CXX_FLAGS + -g + -Wall + -fPIC + -fms-extensions + -Wno-unused-parameter + -Wno-unused-function + -Wunused-variable + -Wunused-value + -Winvalid-pch + -Wcast-align + -Wcast-qual + -Wwrite-strings + -Wsign-compare + -Wfloat-equal + -Wextra + -std=c++17 + #-fsanitize=address + #-fno-omit-frame-pointer + #-fstack-protector-all + #-fstack-protector-strong + ) +else () + message(FATAL_ERROR "======BUILD_MODE not found") +endif (${BUILD_MODE} MATCHES "release") + +string(REPLACE ";" " " CMAKE_CXX_FLAGS "${CXX_FLAGS} ${CMAKE_CXX_FLAGS}") + +if (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + add_definitions(-D__AARCH64__) + set(CMAKE_ARC linux-aarch64) +elseif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "x86_64") + add_definitions(-D__X86_64__) + set(CMAKE_ARC linux-x86_64) + set(CXX_FLAGS + ${CXX_FLAGS} + -msse2 + -mavx + #-w + ) +else () + message(FATAL_ERROR "don't support ${CMAKE_HOST_SYSTEM_PROCESSOR}") +endif (${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64") + +set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) +message(===============${OCK_CTR_PLATFORM_UTIL_DIR}) +include_directories(${OCK_CTR_PLATFORM_UTIL_DIR}/securec/include) + +add_subdirectory(3rdparty) +add_subdirectory(src) + +if (${BUILD_MODE} MATCHES "release") +elseif (${BUILD_MODE} MATCHES "debug") +else () + add_subdirectory(tests) +endif (${BUILD_MODE} MATCHES "release") \ No newline at end of file diff --git a/src/AccCTR/README.md b/src/AccCTR/README.md new file mode 100644 index 00000000..1a394699 --- /dev/null +++ b/src/AccCTR/README.md @@ -0,0 +1,9 @@ +# AccCTR + +使用方法: + +1、bash build.sh release //编译release + +2、bash build.sh debug //编译debug + +3、bash build.sh ut //编译并运行ut,覆盖率在tests/build/cov/gen目录下 diff --git a/src/AccCTR/build.sh b/src/AccCTR/build.sh new file mode 100644 index 00000000..2f0b57f2 --- /dev/null +++ b/src/AccCTR/build.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# Used for memory pool kit builds +# BUILD_TYPE 1、release 2、ut 3、perf +set -ex +readonly CURRENT_PATH=$(cd "$(dirname "$0")"; pwd) +BUILD_TYPE=$1 + +TOP_DIR=${CURRENT_PATH} +OUTPUT_PATH=${TOP_DIR}/output +OCK_CTR_PATH=${TOP_DIR} +OCK_CTR_DIAGNOSE_PATH=${TOP_DIR}/tests +BUILD_PATH=${CURRENT_PATH}/build + +# default is use build Release version +if [ "${BUILD_TYPE}" == "debug" ];then + BUILD_MODE="debug" +elif [ "${BUILD_TYPE}" == "ut" ];then + BUILD_MODE="ut" +else + BUILD_MODE="release" +fi + +cd ${BUILD_PATH} +cmake ${OCK_CTR_PATH} -DCMAKE_INSTALL_PREFIX:STRING=${OUTPUT_PATH}/ock_ctr_common -DCTR_ENV=${CPU_TYPE} -DBUILD_MODE=${BUILD_MODE} +if [ 0 != $? ];then + echo "Failed to build_src" + exit 1 +fi + +make clean; make -j 4; make install; +if [ 0 != $? ];then + echo "Failed to build_src" + exit 1 +fi +cd - + +if [[ "${BUILD_TYPE}" == "ut" ]];then + cp ${CURRENT_PATH}/../../../opensource/securec/lib/libsecurec.so ${CURRENT_PATH}/output/ock_ctr_common/lib/ + export LD_LIBRARY_PATH=${CURRENT_PATH}/output/ock_ctr_common/lib:$LD_LIBRARY_PATH +fi + +echo "build end!" \ No newline at end of file diff --git a/src/AccCTR/build/build_3rdparty.sh b/src/AccCTR/build/build_3rdparty.sh new file mode 100644 index 00000000..8c119c6d --- /dev/null +++ b/src/AccCTR/build/build_3rdparty.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# build 3rdparty +set -e +CURRENT_PATH=$( + cd "$(dirname "$0")" + pwd +) +BUILD_MODE=$1 +INSTALL_PATH=${CURRENT_PATH}/../install +OPENSOURCE_PATH=${CURRENT_PATH}/../../../../opensource + +if [ ! -d "${INSTALL_PATH}" ]; then + mkdir -p ${INSTALL_PATH} +else + rm -rf ${INSTALL_PATH}/* +fi + +GTEST_SRC_PATH=${OPENSOURCE_PATH}/googletest-release-1.8.1 +echo "${GTEST_SRC_PATH}" +GTEST_INSTALL_PATH=${INSTALL_PATH}/googletest-release-1.8.1 + +install_gtest() { + [ -n "${GTEST_INSTALL_PATH}" ] && rm -rf "${GTEST_INSTALL_PATH}" + echo "${GTEST_INSTALL_PATH}" + if [[ ! -d "${GTEST_INSTALL_PATH}" ]]; then + mkdir -p "${GTEST_INSTALL_PATH}" + fi + + cd "${GTEST_SRC_PATH}" + echo "${GTEST_SRC_PATH}" + cmake -DCMAKE_INSTALL_PREFIX="${GTEST_INSTALL_PATH}" -DCMAKE_INSTALL_LIBDIR=lib64 . && make && make install +} + +function prepare_securec(){ + cd ${OPENSOURCE_PATH} + if [ ! -d securec ]; then + unzip huaweicloud-sdk-c-obs-3.23.9.zip + mv huaweicloud-sdk-c-obs-3.23.9/platform/huaweisecurec securec + rm -rf huaweicloud-sdk-c-obs-3.23.9 + rm -rf securec/lib/* + fi +} + +prepare_securec +SECUREC_SRC_PATH=${OPENSOURCE_PATH}/securec/src +echo "${SECUREC_SRC_PATH}" +SECUREC_INSTALL_PATH=${INSTALL_PATH}/securec +compile_securec() { + [ -n "${SECUREC_INSTALL_PATH}" ] && rm -rf "${SECUREC_INSTALL_PATH}" + echo "${SECUREC_INSTALL_PATH}" + if [[ ! 
-d "${SECUREC_INSTALL_PATH}" ]]; then + mkdir -p "${SECUREC_INSTALL_PATH}" + fi + cd "${SECUREC_SRC_PATH}" + make -j + scp -r ${SECUREC_SRC_PATH}/../include ${SECUREC_INSTALL_PATH}/ + scp -r ${SECUREC_SRC_PATH}/../lib ${SECUREC_INSTALL_PATH}/ +} + + +if [[ "${BUILD_MODE}" == "ut" ]];then + BUILD_MODE="debug" + install_gtest + echo "compiled GTest" +fi + +compile_securec +echo "compiled huawei securec" \ No newline at end of file diff --git a/src/AccCTR/build/build_env.sh b/src/AccCTR/build/build_env.sh new file mode 100644 index 00000000..d38f41ad --- /dev/null +++ b/src/AccCTR/build/build_env.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# set env for building project +set -e +CURRENT_PATH=$(cd "$(dirname "$0")"; pwd) +TOP_DIR=${CURRENT_PATH}/.. +OUTPUT_PATH=${TOP_DIR}/output + +OCK_CTR_PATH=${TOP_DIR}/src +OCK_CTR_DIAGNOSE_PATH=${TOP_DIR}/tests +OCK_CTR_OPENSOURCE_PATH=${TOP_DIR}/../../../opensource + +CPU_TYPE=$(arch) +OCK_VERSION=22.0.0 diff --git a/src/AccCTR/build/build_pkg.sh b/src/AccCTR/build/build_pkg.sh new file mode 100644 index 00000000..573380df --- /dev/null +++ b/src/AccCTR/build/build_pkg.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# build pkg +set -e +readonly CURRENT_PATH=$(cd "$(dirname "$0")"; pwd) +PKG_PATH=${CURRENT_PATH}/../pkg +OUTPUT_PATH=${CURRENT_PATH}/../output +if [ ! -d "${OUTPUT_PATH}" ]; then + echo "${OUTPUT_PATH} not exist, user should call build_src.sh first" +fi + +INSTALL_PATH=${CURRENT_PATH}/../install + +echo "${PKG_PATH}" +if [ ! -d "${PKG_PATH}" ]; then + mkdir -p ${PKG_PATH} +else + rm -rf ${PKG_PATH}/* +fi + +scp -r ${OUTPUT_PATH}/* ${PKG_PATH}/ +scp -r ${INSTALL_PATH}/securec/include/* ${PKG_PATH}/ock_ctr_common/include/ +scp -r ${INSTALL_PATH}/securec/lib/* ${PKG_PATH}/ock_ctr_common/lib/ + + + + + diff --git a/src/AccCTR/build/build_src.sh b/src/AccCTR/build/build_src.sh new file mode 100644 index 00000000..5ffd6ab7 --- /dev/null +++ b/src/AccCTR/build/build_src.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# build src +set -e +CURRENT_PATH=$(dirname "$0") +BUILD_PATH=${CURRENT_PATH}/src +CPU_TYPE=$(arch) +BUILD_MODE=$1 +echo "${BUILD_PATH}" +if [ ! -d "${BUILD_PATH}" ]; then + mkdir -p ${BUILD_PATH} +else + rm -rf ${BUILD_PATH}/* +fi + +source ${CURRENT_PATH}/build_env.sh +cd ${BUILD_PATH} +cmake ${OCK_CTR_PATH} -DCMAKE_INSTALL_PREFIX:STRING=${OUTPUT_PATH}/ock_ctr_common -DCTR_ENV=${CPU_TYPE} -DBUILD_MODE=${BUILD_MODE} + +if [ 0 != $? ];then + echo "cmake failed." + exit 1 +fi +echo "cmake success." + +make clean; make -j 4; make install; + +if [ 0 != $? ];then + echo "make failed." + exit 1 +fi +echo "make success." diff --git a/src/AccCTR/build/build_test.sh b/src/AccCTR/build/build_test.sh new file mode 100644 index 00000000..9441efe3 --- /dev/null +++ b/src/AccCTR/build/build_test.sh @@ -0,0 +1,90 @@ +#!/bin/bash +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# build test +set -e +readonly CURRENT_PATH=$(cd "$(dirname "$0")"; pwd) +DATA_PATH=${CURRENT_PATH}/../data +BUILD_PATH=${CURRENT_PATH}/tests +TOOL_PATH=${CURRENT_PATH}/../tests/tools +UT_PATH=${CURRENT_PATH}/tests/ut +TOOL_FILE="create_fake_id.py" +CPU_TYPE=$(arch) +BUILD_MODE=$1 + +create_data() +{ + cd ${TOOL_PATH} + python3 $TOOL_FILE +} + +ut_cover() +{ + cd ${UT_PATH}/src + scp -r ${TOOL_PATH}/*.txt . + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${CURRENT_PATH}/src + ./test_unique_files --gtest_output=xml:./ + GENERATE_DIR=${BUILD_PATH}/cov/gen + rm -rf ${BUILD_PATH}/cov/; mkdir -p ${GENERATE_DIR} + + echo "================"${CURRENT_PATH} + find ${CURRENT_PATH}/.. -name "*.gcda" | xargs -i mv {} ${GENERATE_DIR} + find ${CURRENT_PATH}/.. -name "*.gcno" | xargs -i mv {} ${GENERATE_DIR} + + lcov --d ${GENERATE_DIR} --c --output-file ${GENERATE_DIR}/coverage.info --rc lcov_branch_coverage=1 + if [ 0 != $? ];then + echo "Failed to generate all coverage info" + exit 1 + fi + + lcov -r ${GENERATE_DIR}/coverage.info "*7.3.0*" -o ${GENERATE_DIR}/coverage.info --rc lcov_branch_coverage=1 + if [ 0 != $? ];then + echo "Failed to remove *7.3.0* from coverage info" + exit 1 + fi + + lcov -r ${GENERATE_DIR}/coverage.info "*tests/ut*" -o ${GENERATE_DIR}/coverage.info --rc lcov_branch_coverage=1 + if [ 0 != $? 
];then + echo "Failed to remove *tests/ut* from coverage info" + exit 1 + fi + + lcov -r ${GENERATE_DIR}/coverage.info "*install*" -o ${GENERATE_DIR}/coverage.info --rc lcov_branch_coverage=1 + if [ 0 != $? ];then + echo "Failed to remove *install* from coverage info" + exit 1 + fi + + genhtml -o ${GENERATE_DIR}/result ${GENERATE_DIR}/coverage.info --show-details --legend --rc lcov_branch_coverage=1 + if [ 0 != $? ];then + echo "Failed to generate all coverage info with html format" + exit 1 + fi +} + +if [ "${BUILD_MODE}" == "ut" ]; then + create_data + ut_cover + if [ 0 != $? ];then + echo "Failed to ut_cover" + exit 1 + fi +elif [ "${BUILD_MODE}" == "debug" ];then + echo "BUILD_MODE ${BUILD_MODE} skip" +elif [ "${BUILD_MODE}" == "release" ];then + echo "BUILD_MODE "${BUILD_MODE}" skip" +else + echo "BUILD_MODE "${BUILD_MODE}" not exists" +fi \ No newline at end of file diff --git a/src/AccCTR/config.conf b/src/AccCTR/config.conf new file mode 100644 index 00000000..e69de29b diff --git a/src/AccCTR/dependency.xml b/src/AccCTR/dependency.xml new file mode 100644 index 00000000..e69de29b diff --git a/src/AccCTR/src/CMakeLists.txt b/src/AccCTR/src/CMakeLists.txt new file mode 100644 index 00000000..09da4670 --- /dev/null +++ b/src/AccCTR/src/CMakeLists.txt @@ -0,0 +1,64 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +set(OCK_ACCCTR_LINK_FLAGS -Wl,-z,relro,-z,now,-z,noexecstack -s) + +set(OCK_CTR_SRC_DIR ${PROJECT_SOURCE_DIR}) +set(OCK_CTR_SRC_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/src/include) +set(OCK_CTR_COMMON_DIR ${PROJECT_SOURCE_DIR}/src/common) +set(OCK_CTR_BUILD_PATH ${PROJECT_SOURCE_DIR}/build) + +set(OUTPUT ${PROJECT_SOURCE_DIR}/output) +set(OCK_CTR_PLATFORM_UTIL_DIR ${PROJECT_SOURCE_DIR}/../../../opensource) +set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install) + + +if (${BUILD_MODE} MATCHES "ut") + add_compile_options(-ftest-coverage -fprofile-arcs) + link_libraries(gcov) +endif (${BUILD_MODE} MATCHES "ut") + + +message("include : " ${OCK_CTR_SRC_INCLUDE_DIR}) + +set(LIB_HW_SECURE ${OCK_CTR_PLATFORM_UTIL_DIR}/securec/lib/libsecurec.so) + +add_subdirectory(include) +add_subdirectory(common) +add_subdirectory(unique) + + +file(GLOB_RECURSE CTR_SRC factory_impl.cpp) + +add_library(_ock_ctr_common SHARED ${CTR_SRC}) + +target_include_directories(_ock_ctr_common + PUBLIC + ${PROJECT_SOURCE_DIR} + ${OCK_CTR_SRC_INCLUDE_DIR} + ${OCK_CTR_COMMON_DIR}/util) + +target_link_libraries(_ock_ctr_common PUBLIC + -Wl,--start-group + unique + dl + utils + ${LIB_HW_SECURE} + ${OCK_ACCCTR_LINK_FLAGS} + -Wl,--end-group) + +set(TARGET_INSTALL_LIB ${OUTPUT}/ock_ctr_common/lib) +install(TARGETS _ock_ctr_common DESTINATION ${TARGET_INSTALL_LIB}/ PERMISSIONS OWNER_WRITE OWNER_READ OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) + + diff --git a/src/AccCTR/src/common/CMakeLists.txt b/src/AccCTR/src/common/CMakeLists.txt new file mode 100644 index 00000000..625b3eeb --- /dev/null +++ b/src/AccCTR/src/common/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +add_subdirectory(util) \ No newline at end of file diff --git a/src/AccCTR/src/common/util/CMakeLists.txt b/src/AccCTR/src/common/util/CMakeLists.txt new file mode 100644 index 00000000..ab2dd3d8 --- /dev/null +++ b/src/AccCTR/src/common/util/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +file(GLOB UTIL_SRCS *.cpp *.h) + +add_library(utils STATIC ${UTIL_SRCS}) \ No newline at end of file diff --git a/src/AccCTR/src/common/util/common_execption.h b/src/AccCTR/src/common/util/common_execption.h new file mode 100644 index 00000000..82eb7747 --- /dev/null +++ b/src/AccCTR/src/common/util/common_execption.h @@ -0,0 +1,26 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef ACCCTR_COMMON_EXECPTION_H +#define ACCCTR_COMMON_EXECPTION_H + +namespace ock { +namespace ctr { +class CommonException {}; +class AllocError : public std::exception {}; +class NullptrError : public std::exception {}; +} +} + +#endif // ACCCTR_COMMON_EXECPTION_H diff --git a/src/AccCTR/src/common/util/common_includes.h b/src/AccCTR/src/common/util/common_includes.h new file mode 100644 index 00000000..8b07914f --- /dev/null +++ b/src/AccCTR/src/common/util/common_includes.h @@ -0,0 +1,33 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_CTR_COMMON_INCLUDES_H +#define OCK_CTR_COMMON_INCLUDES_H + +#include +#include +#include +#include + +#include "error_code.h" +#include "external_logger.h" +#include "external_threader.h" +#include "securec.h" +#include "time_cost.h" +#include "spinlock.h" +#include "lock.h" +#include "defines.h" +#include "common_execption.h" + +#endif // OCK_CTR_COMMON_INCLUDES_H diff --git a/src/AccCTR/src/common/util/defines.h b/src/AccCTR/src/common/util/defines.h new file mode 100644 index 00000000..f982ee39 --- /dev/null +++ b/src/AccCTR/src/common/util/defines.h @@ -0,0 +1,32 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef OCK_DEFINES_H +#define OCK_DEFINES_H + +namespace ock { +namespace ctr { +using HResult = int32_t; +constexpr int FACTOR_BIT = 2; +constexpr int FACTOR = 4; +constexpr int HASH_L_L = 16; +constexpr int HASH_L = 32; +constexpr int HASH_H = 48; +constexpr int DEFAULT_NUM = 256; +constexpr int MAX_ID_COUNT = 1 << 29; +constexpr int MAX_DESIRED_SIZE = 1431655765; // (2^32 -1)/2/1.5 +} +} + +#endif // OCK_DEFINES_H diff --git a/src/AccCTR/src/common/util/error_code.h b/src/AccCTR/src/common/util/error_code.h new file mode 100644 index 00000000..04d26a57 --- /dev/null +++ b/src/AccCTR/src/common/util/error_code.h @@ -0,0 +1,37 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_CTR_ERRNO_H +#define OCK_CTR_ERRNO_H + +namespace ock { +namespace ctr { +using CTRCode = enum : int { + H_OK = 0, + H_ERROR = 1, + H_NEW_OBJECT_FAILED = 2, + H_ADDRESS_NULL = 3, + H_NUM_SMALL = 4, + H_COPY_ERROR = 5, + H_ID_LARGE = 6, + H_PADDING_SMALL = 7, + H_OUTPUT_TYPE_ERROR = 8, + H_SCENE_ERROR = 9, + H_MEMORY_ALLOC_ERROR = 10, + H_UNIQUE_UNINITIALIZED_ERROR = 11 +}; +} +} + +#endif // OCK_CTR_ERRNO_H diff --git a/src/AccCTR/src/common/util/external_logger.cpp b/src/AccCTR/src/common/util/external_logger.cpp new file mode 100644 index 00000000..0cbcdd0d --- /dev/null +++ b/src/AccCTR/src/common/util/external_logger.cpp @@ -0,0 +1,48 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#include "external_logger.h" + +namespace ock { +void ExternalLogger::SetExternalLogFunction(ExternalLog func) +{ + if (mLogFunc == nullptr) { + mLogFunc = func; + } +} + +void ExternalLogger::Log(const int level, const std::ostringstream &oss) const +{ + if (mLogFunc != nullptr) { + mLogFunc(level, oss.str().c_str()); + } +} + +void ExternalLogger::PrintLog(LogLevel level, const std::string &message) +{ + std::ostringstream oss; + oss << message; + auto logger = ExternalLogger::Instance(); + if (logger != nullptr) { + logger->Log(static_cast(level), oss); + } +} + +void ExternalLogger::PrintLog(LogLevel level, const std::string &message, bool flag) +{ + if (flag) { + PrintLog(level, message); + } +} +} diff --git a/src/AccCTR/src/common/util/external_logger.h b/src/AccCTR/src/common/util/external_logger.h new file mode 100644 index 00000000..20cb4edd --- /dev/null +++ b/src/AccCTR/src/common/util/external_logger.h @@ -0,0 +1,69 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_EXTERNAL_LOG_H +#define OCK_EXTERNAL_LOG_H + +#include +#include +#include +#include "singleton.h" + +using ExternalLog = void (*)(int level, const char *msg); + +namespace ock { +enum class LogLevel { + DEBUG = 0, + INFO = 1, + WARN = 2, + ERROR = 3, +}; + +class ExternalLogger { +public: + ExternalLogger() = default; + + static ExternalLogger *Instance() + { + return Singleton::GetInstance(); + } + + void SetExternalLogFunction(ExternalLog func); + + void Log(const int level, const std::ostringstream &oss) const; + + static void PrintLog(LogLevel level, const std::string &message); + + static void PrintLog(LogLevel level, const std::string &message, bool flag); + + ExternalLogger(const ExternalLogger &) = delete; + ExternalLogger &operator = (const ExternalLogger &) = delete; + ExternalLogger(ExternalLogger &&) = delete; + ExternalLogger &operator = (const ExternalLogger &&) = delete; + + ~ExternalLogger() + { + mLogFunc = nullptr; + } + +private: +private: + static ExternalLogger *gLogger; + static std::mutex gMutex; + + ExternalLog mLogFunc = nullptr; +}; +} + +#endif // OCK_EXTERNAL_LOG_H diff --git a/src/AccCTR/src/common/util/external_thread.cpp b/src/AccCTR/src/common/util/external_thread.cpp new file mode 100644 index 00000000..0b787e58 --- /dev/null +++ b/src/AccCTR/src/common/util/external_thread.cpp @@ -0,0 +1,33 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include "external_threader.h" + +namespace ock { +void ExternalThreader::SetExternalLogFunction(ExternalThread func) +{ + if (mThreadFunc == nullptr) { + mThreadFunc = func; + } +} + +void ExternalThreader::Run(const std::vector> &tasks) const +{ + if (mThreadFunc != nullptr) { + mThreadFunc(tasks); + } else { + SimpleThreadPool::SyncRun(tasks); + } +} +} diff --git a/src/AccCTR/src/common/util/external_threader.h b/src/AccCTR/src/common/util/external_threader.h new file mode 100644 index 00000000..5a1132af --- /dev/null +++ b/src/AccCTR/src/common/util/external_threader.h @@ -0,0 +1,73 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_EXTERNAL_THREADER_H +#define OCK_EXTERNAL_THREADER_H + +#include +#include +#include +#include +#include +#include "singleton.h" + +using ExternalThread = void (*)(const std::vector> &tasks); + +namespace ock { +class SimpleThreadPool { +public: + static void SyncRun(const std::vector> &tasks) + { + std::vector> futs; + for (auto &task : tasks) { + futs.push_back(std::async(task)); + } + for (auto &fut : futs) { + fut.wait(); + } + } +}; + +class ExternalThreader { +public: + ExternalThreader() = default; + + static ExternalThreader *Instance() + { + return Singleton::GetInstance(); + } + + void SetExternalLogFunction(ExternalThread func); + + void Run(const std::vector> &tasks) const; + + ExternalThreader(const ExternalThreader &) = delete; + ExternalThreader &operator = (const ExternalThreader &) = delete; + ExternalThreader(ExternalThreader &&) = delete; + ExternalThreader &operator = (const ExternalThreader &&) = delete; + + ~ExternalThreader() + { + mThreadFunc = nullptr; + } + +private: + static ExternalThreader *gThread; + static std::mutex gMutex; + + ExternalThread mThreadFunc = nullptr; +}; +} + +#endif // OCK_EXTERNAL_THREADER_H diff --git a/src/AccCTR/src/common/util/lock.h b/src/AccCTR/src/common/util/lock.h new file mode 100644 index 00000000..e0a26be6 --- /dev/null +++ b/src/AccCTR/src/common/util/lock.h @@ -0,0 +1,214 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_CTR_LOCK_H +#define OCK_CTR_LOCK_H + +#include +#include +#include + +namespace ock { +namespace ctr { +class Lock { +public: + Lock() = default; + ~Lock() = default; + + Lock(const Lock &) = delete; + Lock &operator = (const Lock &) = delete; + Lock(Lock &&) = delete; + Lock &operator = (Lock &&) = delete; + + inline void DoLock() + { + mLock.lock(); + } + + inline void UnLock() + { + mLock.unlock(); + } + +private: + std::mutex mLock; +}; + +class RecursiveLock { +public: + RecursiveLock() = default; + ~RecursiveLock() = default; + + RecursiveLock(const RecursiveLock &) = delete; + RecursiveLock &operator = (const RecursiveLock &) = delete; + RecursiveLock(RecursiveLock &&) = delete; + RecursiveLock &operator = (RecursiveLock &&) = delete; + + inline void DoLock() + { + mLock.lock(); + } + inline void UnLock() + { + mLock.unlock(); + } + +private: + std::recursive_mutex mLock; +}; + +class ReadWriteLock { +public: + ReadWriteLock() + { + pthread_rwlock_init(&mLock, nullptr); + } + ~ReadWriteLock() + { + pthread_rwlock_destroy(&mLock); + } + + ReadWriteLock(const ReadWriteLock &) = delete; + ReadWriteLock &operator = (const ReadWriteLock &) = delete; + ReadWriteLock(ReadWriteLock &&) = delete; + ReadWriteLock &operator = (ReadWriteLock &&) = delete; + + inline void LockRead() + { + pthread_rwlock_rdlock(&mLock); + } + + inline void LockWrite() + { + pthread_rwlock_wrlock(&mLock); + } + + inline void UnLock() + { + pthread_rwlock_unlock(&mLock); + } + +private: + pthread_rwlock_t mLock {}; +}; + + +class SpinLock { +public: + SpinLock() = default; + ~SpinLock() = default; + + SpinLock(const SpinLock &) = delete; + SpinLock &operator = (const SpinLock &) = delete; + SpinLock(SpinLock &&) = delete; + SpinLock &operator = (SpinLock &&) = delete; + + inline void TryLock() + { + mFlag.test_and_set(std::memory_order_acquire); + } + + inline void Lock() + { + while (mFlag.test_and_set(std::memory_order_acquire)) { + } + } + + inline void UnLock() + { + mFlag.clear(std::memory_order_release); + } + +private: + std::atomic_flag mFlag = ATOMIC_FLAG_INIT; +}; + +template class Locker { +public: + explicit Locker(T *lock) : mLock(lock) + { + if (mLock != nullptr) { + mLock->DoLock(); + } + } + + ~Locker() + { + if (mLock != nullptr) { + mLock->UnLock(); + } + } + + Locker(const Locker &) = delete; + Locker &operator = (const Locker &) = delete; + Locker(Locker &&) = delete; + Locker &operator = (Locker &&) = delete; + +private: + T *mLock; +}; + +template class ReadLocker { +public: + explicit ReadLocker(T *lock) : mLock(lock) + { + if (mLock != nullptr) { + mLock->LockRead(); + } + } + + ~ReadLocker() + { + if (mLock != nullptr) { + mLock->UnLock(); + } + } + + ReadLocker(const ReadLocker &) = delete; + ReadLocker &operator = (const ReadLocker &) = delete; + ReadLocker(ReadLocker &&) noexcept = delete; + ReadLocker &operator = (ReadLocker &&) noexcept = delete; + +private: + T *mLock; +}; + +template class WriteLocker { +public: + explicit WriteLocker(T *lock) : mLock(lock) + { + if (mLock != NULL) { + 
mLock->LockWrite(); + } + } + + ~WriteLocker() + { + if (mLock != NULL) { + mLock->UnLock(); + } + } + + WriteLocker(const WriteLocker &) = delete; + WriteLocker &operator = (const WriteLocker &) = delete; + WriteLocker(WriteLocker &&) noexcept = delete; + WriteLocker &operator = (WriteLocker &&) noexcept = delete; + +private: + T *mLock; +}; +} +} + +#endif // OCK_CTR_LOCK_H diff --git a/src/AccCTR/src/common/util/singleton.h b/src/AccCTR/src/common/util/singleton.h new file mode 100644 index 00000000..c645eb7e --- /dev/null +++ b/src/AccCTR/src/common/util/singleton.h @@ -0,0 +1,47 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + + +#ifndef ACCCTR_SRC_COMMON_UTIL_SINGLETON_H_ +#define ACCCTR_SRC_COMMON_UTIL_SINGLETON_H_ + +#include +#include + +/** + * T must be destructed + * @tparam T + */ +namespace ock { +template class Singleton { +public: + Singleton() = delete; + + Singleton(const Singleton &singleton) = delete; + + Singleton &operator = (const Singleton &singleton) = delete; + + static T *GetInstance() + { + try { + static T instance; + return &instance; + } catch (std::exception &e) { + std::cout << " create singleton error" << std::endl; + return nullptr; + } + } +}; +} +#endif // ACCCTR_SRC_COMMON_UTIL_SINGLETON_H_ diff --git a/src/AccCTR/src/common/util/spinlock.h b/src/AccCTR/src/common/util/spinlock.h new file mode 100644 index 00000000..7cecf036 --- /dev/null +++ b/src/AccCTR/src/common/util/spinlock.h @@ -0,0 +1,123 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef SRC_UTILS_SPINLOCK_H +#define SRC_UTILS_SPINLOCK_H + +#include +#include +#include // NOLINT + +namespace ock { +namespace ctr { +#ifdef LOCK_NOTHING + +class SpinLock final { +public: + void lock() noexcept {} + bool try_lock() noexcept + { + return true; + } + void unlock() noexcept {} +}; + +#elif defined(USE_MUTEX) + +class SpinLock final { +public: + void lock() noexcept + { + mt_.lock(); + } + bool try_lock() noexcept + { + return mt_.try_lock(); + } + void unlock() noexcept + { + mt_.unlock(); + } + +private: + std::mutex mt_; +}; + +#else + +class SpinLockG final { +public: + const int maxSpinCountBeforeThreadYield = 64; + SpinLockG() = default; + SpinLockG(SpinLockG const &) = delete; + SpinLockG(SpinLockG &&) noexcept = delete; + SpinLockG &operator = (SpinLockG const &) = delete; + SpinLockG &operator = (SpinLockG const &&) = delete; + + static __inline void CpuPause() + { +#ifdef __GNUC__ +#ifdef __aarch64__ + __asm volatile("yield" ::: "memory"); +#elif defined(__i386__) || defined(__x86_64__) + __asm__ __volatile__("rep;nop;nop" ::: "memory"); +#else +#error "unknown architecture" +#endif +#else +#error "unknown architecture" +#endif + } + inline void lock() noexcept + { + bool flag = true; + while (flag) { + if (!lock_.exchange(true, std::memory_order_acquire)) { + flag = false; + break; + } + + uint16_t counter = 0; + while (lock_.load(std::memory_order_relaxed)) { + CpuPause(); + if (++counter > maxSpinCountBeforeThreadYield) { + std::this_thread::yield(); + // reset counter + counter = 0; + } + } + } + } + + inline bool try_lock() noexcept + { + if (lock_.load(std::memory_order_relaxed)) { + return false; + } + return !lock_.exchange(true, std::memory_order_acquire); + } + + inline void unlock() noexcept + { + lock_.store(false, std::memory_order_release); + } + +private: + std::atomic lock_ { false }; +}; +} +} + +#endif +#endif diff --git a/src/AccCTR/src/common/util/time_cost.h b/src/AccCTR/src/common/util/time_cost.h new file mode 100644 index 00000000..0b303169 --- /dev/null +++ b/src/AccCTR/src/common/util/time_cost.h @@ -0,0 +1,49 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef TIMECOST_H +#define TIMECOST_H + +#include + +namespace ock { +namespace ctr { +class TimeCost { +public: + TimeCost() + { + start_ = std::chrono::high_resolution_clock::now(); + } + + double ElapsedSec() + { + std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); + std::chrono::duration d = std::chrono::duration_cast>(end - start_); + return d.count(); + } + + size_t ElapsedMS() + { + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::milliseconds d = std::chrono::duration_cast(end - start_); + return d.count(); + } + +private: + std::chrono::high_resolution_clock::time_point start_; +}; +} +} + +#endif \ No newline at end of file diff --git a/src/AccCTR/src/factory_impl.cpp b/src/AccCTR/src/factory_impl.cpp new file mode 100644 index 00000000..f0f5cdac --- /dev/null +++ b/src/AccCTR/src/factory_impl.cpp @@ -0,0 +1,69 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#include "factory_impl.h" + +namespace ock { +namespace ctr { +Factory *FactoryImpl::gGlobalFactory = nullptr; +Lock FactoryImpl::gLock; + +#ifdef __cplusplus +extern "C" { +#endif + +int CTR_CreateFactory(uintptr_t *outFactory) +{ + if (FactoryImpl::gGlobalFactory == nullptr) { + Locker locker(&FactoryImpl::gLock); + if (FactoryImpl::gGlobalFactory == nullptr) { + auto tmp = new (std::nothrow) FactoryImpl(); + if (tmp == nullptr) { + return H_NEW_OBJECT_FAILED; + } + + FactoryImpl::gGlobalFactory = tmp; + } + } + *outFactory = reinterpret_cast(FactoryImpl::gGlobalFactory); + return H_OK; +} +#ifdef __cplusplus +} +#endif + +int FactoryImpl::CreateUnique(std::shared_ptr &out) +{ + auto tmp = new (std::nothrow) UniqueImpl(); + if (tmp == nullptr) { + return H_NEW_OBJECT_FAILED; + } + + out.reset(dynamic_cast(tmp)); + return H_OK; +} + +int FactoryImpl::SetExternalLogFuncInner(ExternalLog logFunc) +{ + auto logger = ExternalLogger::Instance(); + if (logger == nullptr) { + std::cout << "Failed to create logger instance" << std::endl; + return H_NEW_OBJECT_FAILED; + } + + logger->SetExternalLogFunction(logFunc); + return H_OK; +} +} +} diff --git a/src/AccCTR/src/factory_impl.h b/src/AccCTR/src/factory_impl.h new file mode 100644 index 00000000..cc1c025a --- /dev/null +++ b/src/AccCTR/src/factory_impl.h @@ -0,0 +1,39 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_FACTORY_IMPL_H +#define OCK_FACTORY_IMPL_H + +#include "include/factory.h" +#include "unique/unique_impl.h" + +namespace ock { +namespace ctr { +class FactoryImpl : public Factory { +public: + FactoryImpl() = default; + ~FactoryImpl() override = default; + +public: + int CreateUnique(std::shared_ptr &out) override; + int SetExternalLogFuncInner(ExternalLog logFunc) override; + +public: + static Factory *gGlobalFactory; + static Lock gLock; +}; +} +} + +#endif // OCK_FACTORY_IMPL_H \ No newline at end of file diff --git a/src/AccCTR/src/include/CMakeLists.txt b/src/AccCTR/src/include/CMakeLists.txt new file mode 100644 index 00000000..c9d2b215 --- /dev/null +++ b/src/AccCTR/src/include/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set(INCLUDE_HEADERS factory.h ock_ctr_common_def.h unique.h) + +set(TARGET_INSTALL_INCLUDE ${OUTPUT}/ock_ctr_common/include) + +install(FILES ${INCLUDE_HEADERS} DESTINATION ${TARGET_INSTALL_INCLUDE} PERMISSIONS OWNER_WRITE OWNER_READ GROUP_READ WORLD_READ) \ No newline at end of file diff --git a/src/AccCTR/src/include/factory.h b/src/AccCTR/src/include/factory.h new file mode 100644 index 00000000..14732cf9 --- /dev/null +++ b/src/AccCTR/src/include/factory.h @@ -0,0 +1,64 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ ==============================================================================*/ + +#ifndef UNIQUE_OCK_CTR_COMMON_H +#define UNIQUE_OCK_CTR_COMMON_H + +#include +#include +#include +#include "unique.h" + + +#ifdef __cplusplus +extern "C" { +#endif + +using ExternalLog = void (*)(int level, const char *msg); + +#ifdef __cplusplus +} +#endif + +#include "ock_ctr_common_def.h" + +namespace ock { +namespace ctr { +class Factory; + +using FactoryPtr = std::shared_ptr; +using UniquePtr = std::shared_ptr; + +class Factory { +public: + virtual ~Factory() = default; + virtual int CreateUnique(UniquePtr &out) = 0; + virtual int SetExternalLogFuncInner(ExternalLog logFunc) = 0; + +public: + static int Create(FactoryPtr &out) + { + int result = 0; + uintptr_t factory = 0; + /* dynamic load function */ + if ((result = OckCtrCommonDef::CreatFactory(&factory)) == 0) { + out.reset(reinterpret_cast(factory)); + } + return result; + } +}; +} +} + +#endif // UNIQUE_OCK_CTR_COMMON_H diff --git a/src/AccCTR/src/include/ock_ctr_common_def.h b/src/AccCTR/src/include/ock_ctr_common_def.h new file mode 100644 index 00000000..ed955996 --- /dev/null +++ b/src/AccCTR/src/include/ock_ctr_common_def.h @@ -0,0 +1,62 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + ==============================================================================*/ + +#ifndef OCK_OCK_CTR_COMMON_DEF_H +#define OCK_OCK_CTR_COMMON_DEF_H + +#include +#include +#include + +using CTR_CREATE_FACTORY_FUNCTION = int (*)(uintptr_t *); + +namespace ock { +namespace ctr { +class OckCtrCommonDef { +public: + static int CreatFactory(uintptr_t *factory) + { + static void *handle = nullptr; + static std::mutex m; + std::unique_lock lock(m); + if (handle != nullptr) { + std::cout << "can't create factory more than 1 time." << std::endl; + return -1; + } + + handle = dlopen(LIBRARY_NAME, RTLD_NOW); + if (handle == nullptr) { + std::cout << "Failed to call dlopen to load library '" << LIBRARY_NAME << "', error " << dlerror() << + std::endl; + return -1; + } + + auto fun = (CTR_CREATE_FACTORY_FUNCTION)dlsym(handle, "CTR_CreateFactory"); + if (fun == nullptr) { + std::cout << "Failed to call dlsym to load function 'CTR_CreateFactory', error " << dlerror() << std::endl; + dlclose(handle); + return -1; + } + + fun(factory); + return 0; + } + +private: + constexpr static const char *LIBRARY_NAME = "lib_ock_ctr_common.so"; +}; +} +} + +#endif // OCK_OCK_CTR_COMMON_DEF_H diff --git a/src/AccCTR/src/include/unique.h b/src/AccCTR/src/include/unique.h new file mode 100644 index 00000000..3154a784 --- /dev/null +++ b/src/AccCTR/src/include/unique.h @@ -0,0 +1,128 @@ +/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef OCK_UNIQUE_H
+#define OCK_UNIQUE_H
+#include <functional>
+#include <vector>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+using ExternalThread = void (*)(const std::vector<std::function<void()>> &tasks);
+
+#ifdef __cplusplus
+}
+#endif
+
+namespace ock {
+namespace ctr {
+using BucketStrategy = enum class BucketStrategy {
+    MODULO
+};
+
+using DataType = enum class DataType {
+    INT64 = 0,
+    INT32
+};
+
+using OutputType = enum class OutputType {
+    NORMAL = 0,
+    ENHANCED
+};
+
+using UniqueConf = struct UniqueConfCTR {
+    BucketStrategy bucketStrategy = BucketStrategy::MODULO;
+    OutputType outputType = OutputType::NORMAL; // whether this is a plain unique
+    DataType dataType = DataType::INT64;        // type of the input ids
+    bool usePadding = false;                    // enable padding (sharding must be enabled first)
+    bool useIdCount = false;                    // enable id counting
+    bool useSharding = false;                   // enable sharding
+    int shardingNum = -1;                       // number of buckets
+    uint32_t desiredSize = 256;                 // estimated memory needed for the ids
+    int paddingSize = 0;                        // length after padding
+    int paddingVal = -1;                        // padding value
+    uint32_t minThreadNum = 1;                  // minimum number of worker threads
+    uint32_t maxThreadNum = 8;                  // maximum number of worker threads
+    int64_t maxIdVal = 0;                       // maximum id value
+    bool trace = false;                         // enable performance tracing (requires an external log sink)
+} __attribute__((packed));
+
+using UniqueIn = struct UniqueInCTR {
+    void *inputId = nullptr;  // address of the input ids (allocated by the caller), mandatory
+    uint32_t inputIdCnt = 0;  // number of input ids
+};
+
+using UniqueOut = struct UniqueOutCTR {
+    void *uniqueId = nullptr;  // final ids after dedup/sharding/padding (allocated by the caller), mandatory
+    uint32_t *index = nullptr; // index of each input id after dedup (allocated by the caller), mandatory
+    int uniqueIdCnt = 0;       // number of ids after dedup
+};
+
+using EnhancedUniqueOut = struct EnhancedUniqueOutCTR {
+    void *uniqueId = nullptr;           // final ids after dedup/sharding/padding (allocated by the caller), mandatory
+    uint32_t *index = nullptr;          // index of each input id after dedup (allocated by the caller), mandatory
+    void *uniqueIdInBucket = nullptr;   // deduplicated ids per bucket (allocated by the caller), mandatory once sharding is enabled
+    int *uniqueIdCntInBucket = nullptr; // unique id count per bucket (allocated by the caller), mandatory once sharding is enabled
+    int uniqueIdCnt = 0;                // number of ids after dedup
+    int *idCnt = nullptr;               // repeat count of each id (allocated by the caller), mandatory once idCount is enabled
+    int *idCntFill = nullptr;           // repeat count of each id with padding (allocated by the caller), mandatory once idCount and padding are enabled
+};
+
+class Unique {
+public:
+    virtual ~Unique() = default;
+    /* *
+     * Initialize the configuration needed by unique.
+     *
+     * @param conf the configuration for unique
+     * @return error_code
+     */
+    virtual int Initialize(const UniqueConf &conf) = 0;
+
+    /* *
+     * Release the resources held by unique.
+     */
+    virtual void UnInitialize() = 0;
+
+    /* *
+     * Deduplicate ids.
+     *
+     * @param UniqueIn input: caller-provided input
+     * @param UniqueOut output: caller-provided output
+     * @return errorCode
+     */
+    virtual int DoUnique(UniqueIn &uniqueIn, UniqueOut &uniqueOut) = 0;
+
+    /* *
+     * Unique with additional outputs.
+     *
+     * @param uniqueIn
+     * @param EnhancedUniqueOut
+     * @return errorCode
+     */
+    virtual int DoEnhancedUnique(UniqueIn &uniqueIn, EnhancedUniqueOut &enhancedUniqueOut) = 0;
+
+    /* *
+     * Set an external thread-pool function.
+     *
+     * @return
+     */
+    virtual int SetExternalThreadFuncInner(ExternalThread threadFunc) = 0;
+};
+}
+}
+
+#endif // OCK_UNIQUE_H
diff --git a/src/AccCTR/src/unique/CMakeLists.txt b/src/AccCTR/src/unique/CMakeLists.txt
new file mode 100644
index 00000000..6666ed44
--- /dev/null
+++ b/src/AccCTR/src/unique/CMakeLists.txt
@@ -0,0 +1,27 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+file(GLOB_RECURSE SRCS *.cpp *.h)
+
+add_library(unique STATIC ${SRCS})
+
+target_link_libraries(unique
+        -Wl,--start-group
+        -Wl,--end-group
+        )
+
+target_include_directories(unique
+        PUBLIC
+        ${PROJECT_SOURCE_DIR}/src/common/util
+        ${PROJECT_SOURCE_DIR}/src/include)
\ No newline at end of file
diff --git a/src/AccCTR/src/unique/unique_func.cpp b/src/AccCTR/src/unique/unique_func.cpp
new file mode 100644
index 00000000..64ad6d52
--- /dev/null
+++ b/src/AccCTR/src/unique/unique_func.cpp
@@ -0,0 +1,193 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#include "unique_func.h"
+
+namespace ock {
+namespace ctr {
+void Dedup::Insert(uint64_t val)
+{
+    auto h = static_cast<uint64_t>(Hash(val) & bucketCountMask_);
+    Meta<n> *bucket = &table_[h];
+
+    int8_t count = bucket->count;
+
+    int8_t totalCount = 0;
+
+    for (int8_t i = 0; i < count; ++i) {
+        if (bucket->data[totalCount] == val) {
+            TryIncreaseIdCount(bucket->idCount[totalCount]);
+            // found one
+            return;
+        }
+        totalCount++;
+    }
+    // try again, this time with lock acquired
+    if (count < n) {
+        std::lock_guard<SpinLockG> lg(bucket->lock);
+        for (int8_t j = totalCount; j < bucket->count; ++j) {
+            if (bucket->data[totalCount] == val) {
+                TryIncreaseIdCount(bucket->idCount[totalCount]);
+                // found one
+                return;
+            }
+            totalCount++;
+        }
+        if (totalCount < n) {
+            bucket->data[totalCount] = val;
+            bucket->count++;
+            TryIncreaseIdCount(bucket->idCount[totalCount]);
+            return;
+        }
+    }
+    // shift to the overflow reservoir
+    InsertOverflow(val);
+}
+
+inline void Dedup::TryIncreaseIdCount(std::atomic<int32_t> &val)
+{
+    if (idCountEnable_) {
+        val++;
+    }
+}
+
+int32_t Dedup::GetReplaceOffsetUnsafe(uint64_t val)
+{
+    auto h = static_cast<uint64_t>(Hash(val) & bucketCountMask_);
+    Meta<n> *bucket = &table_[h];
+
+    int8_t totalCount = 0;
+    for (int8_t i = 0; i < bucket->count; ++i) {
+        if (bucket->data[totalCount] == val) {
+            // found one
+            return bucket->replaceBase + totalCount;
+        }
+        totalCount++;
+    }
+    if (totalCount < n) {
+        return -1;
+    }
+    return GetReplaceOffsetFromOverflowUnsafe(val);
+}
+
+void Dedup::InitTable()
+{
+    void *area = aligned_alloc(64, sizeof(Meta<n>) * bucketCount_);
+    if (area == nullptr) {
+        throw AllocError();
+    } else {
+        table_ = reinterpret_cast<Meta<n> *>(area);
+    }
+}
+
+void Dedup::Clear(uint64_t newBucketCountPowerOf2)
+{
+    std::lock_guard<SpinLockG> lg(overflowMutex_);
+    if (newBucketCountPowerOf2 > 0 && newBucketCountPowerOf2 != bucketCount_) {
+        if (table_ != nullptr) {
+            free(table_);
+            table_ = nullptr;
+        }
+        bucketCount_ = newBucketCountPowerOf2;
+        bucketCountMask_ = bucketCount_ - 1;
+        table_ = reinterpret_cast<Meta<n> *>(aligned_alloc(K_ALIGNMENT, sizeof(Meta<n>) * bucketCount_));
+        if (table_ == nullptr) {
+            throw AllocError();
+        }
+    }
+    bzero(table_, sizeof(Meta<n>) * bucketCount_);
+    overflow_.clear();
+    idCountOverflow_.clear();
+}
+
+void Dedup::NewParameter()
+{
+    uint64_t newBucketCountPowerOf2 = bucketCount_;
+
+    if (stats_.totalUniques > 0 && stats_.totalOverflowUniques > K_MINIMAL_WORKLOAD_PER_WORKER) {
+        // Time to check the proper size of sharded tables for performance
+        // sake.
+        uint64_t shardedTableSize = 0;
+        if (std::numeric_limits<uint64_t>::max() / n / groupCount_ < newBucketCountPowerOf2) {
+            shardedTableSize = std::numeric_limits<uint64_t>::max();
+        } else {
+            shardedTableSize = newBucketCountPowerOf2 * n * groupCount_;
+        }
+
+        int largeCount = 0;
+        while (shardedTableSize > stats_.totalUniques * FACTOR && largeCount_ != 1) {
+            // too large
+            newBucketCountPowerOf2 >>= 1;
+            shardedTableSize >>= 1;
+            largeCount++;
+        }
+
+        int count = ((largeCount == 1) && (largeCount != largeCount_)) ? 2 : 1;
+        for (int i = 0; i < count; i++) {
+            if (stats_.totalOverflowUniques > K_MINIMAL_WORKLOAD_PER_WORKER) {
+                newBucketCountPowerOf2 <<= 1;
+                shardedTableSize <<= 1;
+            }
+        }
+
+        while (shardedTableSize < stats_.totalUniques + (stats_.totalUniques >> FACTOR_BIT)) {
+            newBucketCountPowerOf2 <<= 1;
+            shardedTableSize <<= 1;
+        }
+
+        if (largeCount_ != 1) {
+            largeCount_ = largeCount;
+        }
+    }
+
+    Clear(newBucketCountPowerOf2);
+    bucketCount_ = newBucketCountPowerOf2;
+    stats_.totalUniques = 0;
+    stats_.totalOverflowUniques = 0;
+}
+
+int32_t ShardedDedup::GetFillOffset(const std::vector<int32_t> &totalUniqueSize, int64_t val, int32_t group)
+{
+    if (!conf.usePadding) {
+        return dedupShards_[group]->GetReplaceOffsetUnsafe(val);
+    } else {
+        return dedupShards_[group]->GetReplaceOffsetUnsafe(val) + conf.paddingSize * group - totalUniqueSize[group];
+    }
+}
+
+
+size_t ShardedDedup::CalThreadNum() const
+{
+    uint32_t threadNum = (conf.desiredSize + K_MINIMAL_WORKLOAD_PER_WORKER - 1) / K_MINIMAL_WORKLOAD_PER_WORKER;
+    threadNum = std::min(conf.maxThreadNum, std::max(threadNum, conf.minThreadNum));
+    return threadNum;
+}
+
+bool ShardedDedup::IsPaddingValid(UniqueOutSelf &uniqueOut)
+{
+    if (conf.outputType == OutputType::ENHANCED && conf.usePadding) {
+        for (int i = 0; i < conf.shardingNum; i++) {
+            if (conf.paddingSize < uniqueOut.uniqueIdCntInBucket[i]) {
+                std::stringstream ssm;
+                ssm << "paddingSize should not be smaller than uniqueSize, paddingSize " << conf.paddingSize <<
+                    " , uniqueSize " << uniqueOut.uniqueIdCntInBucket[i];
+                ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str());
+                return false;
+            }
+        }
+    }
+    return true;
+}
+}
+}
\ No newline at end of file
diff --git a/src/AccCTR/src/unique/unique_func.h b/src/AccCTR/src/unique/unique_func.h
new file mode 100644
index 00000000..39e5a6b3
--- /dev/null
+++ b/src/AccCTR/src/unique/unique_func.h
@@ -0,0 +1,575 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef SRC_UTILS_UNIQUE_H
+#define SRC_UTILS_UNIQUE_H
+
+#include <algorithm>
+#include <atomic>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <functional>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <strings.h>
+#include <unordered_map>
+#include <vector>
+
+#include "securec.h"
+#include "common_includes.h"
+#include "factory.h"
+
+namespace ock {
+namespace ctr {
+using UniqueOutSelf = struct UniqueSelf {
+    void *uniqueId = nullptr;           // final ids after dedup/sharding/padding (allocated by the caller), mandatory
+    uint32_t *index = nullptr;          // index of each input id after dedup (allocated by the caller), mandatory
+    void *uniqueIdInBucket = nullptr;   // deduplicated ids per bucket (allocated by the caller), mandatory with sharding
+    int *uniqueIdCntInBucket = nullptr; // unique id count per bucket (allocated by the caller), mandatory with sharding
+    int *idCnt = nullptr;               // repeat count of each id (allocated by the caller), mandatory with idCnt
+    int *idCntFill = nullptr;           // repeat count of each id with padding (allocated by the caller), mandatory with idCnt and padding
+    int uniqueIdCnt = 0;                // total number of ids after dedup
+};
+
+constexpr int UNIQUE_MAX_BUCKET_WIDTH = 5;
+
+template <DataType T> struct Map {};
+template <> struct Map<DataType::INT64> {
+    using type = int64_t;
+};
+
+template <> struct Map<DataType::INT32> {
+    using type = int32_t;
+};
+
+template <DataType T> typename Map<T>::type *TypeTrans(void *input)
+{
+    return reinterpret_cast<typename Map<T>::type *>(input);
+}
+
+using StrategyFun = int (*)(const uint64_t &val, const int &groupCount);
+
+class BucketStrategies {
+public:
+    static int SimpleGroupFun(const uint64_t &val, const int &groupCount)
+    {
+        return val % groupCount;
+    }
+};
+
+
+class GroupMethod {
+public:
+    inline int GroupCount()
+    {
+        return groupCount_;
+    }
+
+    inline int GroupId(uint64_t val)
+    {
+        return strategyFun_(val, groupCount_);
+    }
+    void SetGroupCount(int count)
+    {
+        groupCount_ = count;
+    }
+
+    void SetStrategyFun(StrategyFun strategyFun)
+    {
+        strategyFun_ = strategyFun;
+    }
+
+    void SetStrategyFunByConf(BucketStrategy bucketStrategy)
+    {
+        if (bucketStrategy == BucketStrategy::MODULO) {
+            SetStrategyFun(BucketStrategies::SimpleGroupFun);
+        }
+    }
+
+private:
+    int groupCount_;
+    StrategyFun strategyFun_;
+};
+
+class Dedup {
+    static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 12;
+    static constexpr size_t K_ALIGNMENT = 64;
+    static const int kDefaultBucketCount = 1 << 24;
+    static const int8_t n = 4;
+
+    template <int M> struct Meta {
+        static_assert(M <= UNIQUE_MAX_BUCKET_WIDTH, "should be no larger than max bucket width");
+        SpinLockG lock;
+        volatile int8_t count {};
+        uint32_t replaceBase {};
+        volatile uint64_t data[M] {};
+        std::atomic<int32_t> idCount[M] {};
+    } __attribute__((__aligned__(64)));
+
+    struct Statistics {
+        uint64_t totalUniques = 0;
+        uint64_t totalOverflowUniques = 0;
+    };
+
+public:
+    explicit Dedup(int bucketCountPower2, int groups, bool idCountEnable)
+        : bucketCount_(bucketCountPower2),
+          bucketCountMask_(bucketCountPower2 - 1),
+          groupCount_(groups),
+          idCountEnable_(idCountEnable)
+    {
+        bucketCount_ = static_cast<uint64_t>(bucketCountPower2);
+        bucketCountMask_ = bucketCount_ - 1;
+        groupCount_ = groups;
+        InitTable();
+        Clear(bucketCount_);
+    }
+
+    ~Dedup()
+    {
+        if (table_ != nullptr) {
+            free(table_);
+        }
+    }
+
+public:
+    void Insert(uint64_t val);
+    int32_t GetReplaceOffsetUnsafe(uint64_t val);
+    void InitTable();
+    void TryIncreaseIdCount(std::atomic<int32_t> &val);
+    void Clear(uint64_t newBucketCountPowerOf2);
+    void NewParameter();
+
+    template <DataType T> uint32_t UniqueRaw(void *output, uint32_t priorTotal, int32_t *idCount)
+    {
+        uint32_t total = priorTotal;
+        uint32_t replaceOffset = priorTotal;
+        auto out = TypeTrans<T>(output);
+        for (uint64_t i = 0; i < bucketCount_; ++i) {
+            Meta<n> *bucket = &table_[i];
+            if (bucket->count == 0) {
+                continue;
+            }
+            bucket->replaceBase = replaceOffset;
+            for (int j = 0; j < bucket->count; ++j) {
+                if (idCountEnable_) {
+                    idCount[total] = bucket->idCount[j];
+                }
+                out[total++] = bucket->data[j];
+            }
+            replaceOffset += bucket->count;
+        }
+        auto it = overflow_.begin();
+        int32_t totalOverflow = 0;
+        while (it != overflow_.end()) {
+            if (idCountEnable_) {
+                idCount[total] = idCountOverflow_[it->first];
+            }
+            out[total++] = it->first;
+            it->second = replaceOffset++;
+            ++it;
+            ++totalOverflow;
+        }
+
+        // set total overflow count
+        stats_.totalUniques = static_cast<uint64_t>(total - priorTotal);
+        stats_.totalOverflowUniques = totalOverflow;
+        return total - priorTotal;
+    }
+
+private:
+    uint64_t bucketCount_ = static_cast<uint64_t>(kDefaultBucketCount);
+    uint64_t bucketCountMask_;
+    int groupCount_ = 1;
+    int largeCount_ { 0 };
+    Meta<n> *table_ {};
+    std::unordered_map<uint64_t, int32_t> overflow_;
+    std::unordered_map<uint64_t, int32_t> idCountOverflow_;
+    SpinLockG overflowMutex_;
+    Statistics stats_;
+    bool idCountEnable_ { false };
+
+    static inline uint64_t Hash(uint64_t val)
+    {
+        return val ^ (val >> HASH_L_L) ^ (val >> HASH_L) ^ (val >> HASH_H);
+    }
+
+    void InsertOverflow(uint64_t val)
+    {
+        std::lock_guard<SpinLockG> lg(overflowMutex_);
+        auto it = overflow_.find(val);
+        if (it == overflow_.end()) {
+            overflow_[val] = 0;
+        }
+
+        if (idCountEnable_) {
+            idCountOverflow_[val]++;
+        }
+    }
+
+    int32_t GetReplaceOffsetFromOverflowUnsafe(uint64_t val)
+    {
+        auto it = overflow_.find(val);
+        return (it != overflow_.end()) ? it->second : -1;
+    }
+}; // Dedup
+
+class ShardedDedup {
+    static constexpr uint32_t K_MINIMAL_WORKLOAD_PER_WORKER = 1 << 13;
+    static constexpr int K_DEFAULT_DUPLICATE_RATIO = 4;
+    static constexpr int K_BUCKET_WIDTH = 4;
+
+public:
+    using DedupT = Dedup;
+
+    ShardedDedup(const GroupMethod &groupMethod, const UniqueConf &uniqueConf,
+        int estimatedDuplicateRatio = K_DEFAULT_DUPLICATE_RATIO)
+        : groupMethod_(groupMethod), bucketCountPower2_(DEFAULT_NUM), conf(uniqueConf), partSize(0)
+    {
+        const int numOfGroupsInShard = groupMethod_.GroupCount();
+        uint32_t totalSize = conf.desiredSize + (conf.desiredSize >> 1);
+        while (bucketCountPower2_ * K_BUCKET_WIDTH * numOfGroupsInShard * estimatedDuplicateRatio < totalSize) {
+            bucketCountPower2_ <<= 1;
+        }
+
+        idCountEnable_ = (conf.outputType == OutputType::ENHANCED) && conf.useIdCount;
+        for (int32_t i = 0; i < numOfGroupsInShard; ++i) {
+            auto obj = new DedupT(bucketCountPower2_, numOfGroupsInShard, idCountEnable_);
+            if (obj == nullptr) {
+                ExternalLogger::PrintLog(LogLevel::ERROR, "create object error");
+                throw NullptrError();
+            }
+            dedupShards_.emplace_back(obj);
+        }
+    }
+
+    ~ShardedDedup() = default;
+
+    void StartNewRound()
+    {
+        for (auto &s : dedupShards_) {
+            s->NewParameter();
+        }
+    }
+
+public:
+    template <DataType T> int Compute(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut)
+    {
+        try {
+            if (!firstEnterFlag_) {
+                StartNewRound();
+            }
+        } catch (AllocError &) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "memory alloc error");
+            return H_MEMORY_ALLOC_ERROR;
+        }
+        firstEnterFlag_ = false;
+        size_t threadNum = CalThreadNum();
+        partSize = (uniqueIn.inputIdCnt + threadNum - 1) / threadNum;
+
+        int ret = InsertVal<T>(uniqueIn, threadNum);
+        if (ret != H_OK) {
+            return ret;
+        }
+
+        DoUniqueRaw<T>(uniqueOut);
+
+        partSize = CacheLineAlign(partSize);
+
+        if (!IsPaddingValid(uniqueOut)) {
+            return H_PADDING_SMALL;
+        }
+
+        std::vector<int32_t> totalUniqueSize;
+        totalUniqueSize.resize(conf.shardingNum);
+
+        if (conf.outputType == OutputType::ENHANCED) {
+            int totalNumber = 0;
+            for (int i = 0; i < conf.shardingNum; i++) {
+                totalUniqueSize[i] = totalNumber;
+                if (conf.useSharding) {
+                    totalNumber += uniqueOut.uniqueIdCntInBucket[i];
+                }
+            }
+        }
+
+        ret = CalUniqueOut<T>(uniqueIn, uniqueOut, totalUniqueSize);
+        if (ret != H_OK) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "CalUniqueOut ERROR");
+            return ret;
+        }
+
+        if (conf.outputType == OutputType::ENHANCED) {
+            HandleTileAndFill<T>(uniqueIn, uniqueOut);
+        }
+
+        return H_OK;
+    }
+
+private:
+    template <typename T> T CacheLineAlign(T size)
+    {
+        return (((size) + 63uL) & ~63uL);
+    }
+
+    bool IsPaddingValid(UniqueOutSelf &uniqueOut);
+
+    size_t CalThreadNum() const;
+
+    int32_t GetFillOffset(const std::vector<int32_t> &totalUniqueSize, int64_t val, int32_t group);
+
+    template <DataType T> int HandleTileAndFill(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut)
+    {
+        int ret = H_OK;
+        if (conf.useSharding) { // sharding enabled
+            ret = TileAndFill<T>(uniqueOut.uniqueIdInBucket, uniqueOut.uniqueIdCntInBucket, uniqueOut.uniqueId,
+                uniqueOut.idCnt, uniqueOut.idCntFill);
+        } else if (!conf.useSharding && conf.useIdCount) { // sharding disabled, id counting enabled
+            std::vector<int32_t> count;
+            count.emplace_back(uniqueOut.uniqueIdCnt); // record the number of ids after dedup
+            ret = TileAndFill<T>(uniqueOut.uniqueId, count.data(), uniqueOut.uniqueId, uniqueOut.idCnt,
+                uniqueOut.idCntFill);
+        }
+
+        if (ret != H_OK) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "TileAndFill ERROR");
+            return ret;
+        }
+
+        return H_OK;
+    }
+
+    template <DataType T> void DoUniqueRaw(UniqueOutSelf &uniqueOut)
+    {
+        // Collect Unique and base vectors
+        int32_t total = 0;
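+        // Assumed layout (from UniqueRaw above): shards are drained one after another with
+        // `total` as the running write offset, so the output buffer ends up as
+        // [shard0 uniques | shard1 uniques | ...] and each bucket's replaceBase stays global.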
+        for (int j = 0; j < groupMethod_.GroupCount(); ++j) {
+            uint64_t inGroupTotal;
+            if (conf.outputType == OutputType::ENHANCED) {
+                if (conf.useSharding && conf.useIdCount) {
+                    inGroupTotal = dedupShards_[j]->UniqueRaw<T>(uniqueOut.uniqueIdInBucket, total,
+                        uniqueOut.idCnt); // both id counting and sharding enabled
+                    uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal;
+                } else if (!conf.useSharding && conf.useIdCount) {
+                    inGroupTotal = dedupShards_[j]->UniqueRaw<T>(uniqueOut.uniqueId, total,
+                        uniqueOut.idCnt); // id counting enabled, sharding disabled
+                } else if (conf.useSharding && !conf.useIdCount) {
+                    inGroupTotal = dedupShards_[j]->UniqueRaw<T>(uniqueOut.uniqueIdInBucket, total,
+                        nullptr); // sharding enabled, id counting disabled
+                    uniqueOut.uniqueIdCntInBucket[j] = inGroupTotal;
+                } else {
+                    inGroupTotal = dedupShards_[j]->UniqueRaw<T>(uniqueOut.uniqueId, total,
+                        nullptr); // neither id counting nor sharding: equivalent to plain unique
+                }
+            } else {
+                inGroupTotal = dedupShards_[j]->UniqueRaw<T>(uniqueOut.uniqueId, total, nullptr);
+            }
+            total += inGroupTotal;
+        }
+        uniqueOut.uniqueIdCnt = total;
+    }
+
+    template <DataType T>
+    int TileAndFill(void *uniqueIdInBucket, const int32_t *uniqueSizeInBucket, void *uniqueIds, const int32_t *idCnt,
+        int32_t *idCntFill)
+    {
+        int start = 0;
+        int index = 0;
+
+        auto uIdInBucket = TypeTrans<T>(uniqueIdInBucket);
+        auto uIds = TypeTrans<T>(uniqueIds);
+
+        for (int i = 0; i < conf.shardingNum; i++) {
+            GetIndexAndStart(uniqueSizeInBucket, conf.usePadding, i, start, index);
+
+            uint32_t memSize = 0;
+            if (T == DataType::INT64) {
+                memSize = uniqueSizeInBucket[i] * sizeof(int64_t);
+            } else if (T == DataType::INT32) {
+                memSize = uniqueSizeInBucket[i] * sizeof(int32_t);
+            }
+
+            if (memSize == 0) {
+                continue;
+            }
+
+            auto rc = memcpy_s(uIds + start, memSize, uIdInBucket + index, memSize);
+            int ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/uniqueIds]");
+            if (ret != 0) {
+                return ret;
+            }
+
+            if (conf.useIdCount && conf.usePadding) {
+                memSize = uniqueSizeInBucket[i] * sizeof(int32_t);
+                rc = memcpy_s(idCntFill + start, memSize, idCnt + index, memSize);
+                ret = PrintMemCpyLog(rc, memSize, "[TileAndFill/idCntFill]");
+            }
+            if (ret != 0) {
+                return ret;
+            }
+        }
+
+        if (conf.usePadding) {
+            HandleFill<T>(uIds, uniqueSizeInBucket, idCntFill);
+        }
+
+        return H_OK;
+    }
+
+    int PrintMemCpyLog(int rc, const uint32_t dstSize, const std::string &logMsg)
+    {
+        if (rc != 0) {
+            std::stringstream ssm;
+            ssm << logMsg << " memcpy_s failed... dstSize: " << dstSize;
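+            // memcpy_s comes from the bundled securec library; a nonzero rc means the
+            // copy was rejected (for example, inconsistent sizes), so it is surfaced
+            // to the caller as H_COPY_ERROR below.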
dstSize: " << dstSize; + ExternalLogger::PrintLog(LogLevel::ERROR, ssm.str()); + return H_COPY_ERROR; + } else { + return H_OK; + } + } + + template + void HandleFill(typename Map::type *uIds, const int32_t *uniqueSizeInBucket, int32_t *idCntFill) + { + int start = 0; + int index = 0; + + for (int i = 0; i < conf.shardingNum; i++) { + GetIndexAndStart(uniqueSizeInBucket, conf.usePadding, i, start, index); + + int fillLen = conf.paddingSize - uniqueSizeInBucket[i]; + for (int j = 0; j < fillLen; j++) { + uIds[start + uniqueSizeInBucket[i] + j] = conf.paddingVal; // padding填充 + } + + if (idCntFill != nullptr) { + for (int y = 0; y < fillLen; y++) { + idCntFill[start + uniqueSizeInBucket[i] + y] = 0; // 特征计数填充 + } + } + } + } + + void GetIndexAndStart(const int32_t *uniqueSizeInBucket, bool usePadding, int shardingNumber, int &start, + int &index) + { + if (shardingNumber > 0) { + index += uniqueSizeInBucket[shardingNumber - 1]; + } + + if (usePadding) { + start = shardingNumber * conf.paddingSize; + } else { + start = index; + } + } + + template int InsertVal(UniqueIn &uniqueIn, size_t threadNum) + { + auto val = TypeTrans(uniqueIn.inputId); + std::vector> tasks; + int ret = H_OK; + for (uint32_t i = 0; i < threadNum; ++i) { + uint32_t start = i * partSize; + uint32_t end = std::min(uniqueIn.inputIdCnt, (i + 1) * partSize); + tasks.push_back([this, val, start, end, &ret]() { + for (uint64_t j = start; j < end; ++j) { + auto value = val[j]; + if (value > conf.maxIdVal) { + ExternalLogger::PrintLog(LogLevel::ERROR, "id val is larger than maxIdVal"); + ret = H_ID_LARGE; + break; + } + auto group = groupMethod_.GroupId(value); + dedupShards_[group]->Insert(value); + } + }); + } + + try { + if (!tasks.empty()) { + auto threader = ExternalThreader::Instance(); + if (threader == nullptr) { + return H_ADDRESS_NULL; + } + threader->Run(tasks); + } + } catch (NullptrError &) { + return H_ADDRESS_NULL; + } + + return ret; + } + + template + int CalUniqueOut(UniqueIn &uniqueIn, UniqueOutSelf &uniqueOut, std::vector &totalUniqueSize) + { + uint32_t *beginPtr = uniqueOut.index; + uint32_t *finishPtr = beginPtr + uniqueIn.inputIdCnt; + uint32_t *partBeginPtr = beginPtr; + auto *partEndPtr = + reinterpret_cast(CacheLineAlign(reinterpret_cast(partBeginPtr + partSize))); + std::vector> tasks; + auto val = TypeTrans(uniqueIn.inputId); + while (partBeginPtr < finishPtr) { + if (partEndPtr > finishPtr) { + partEndPtr = finishPtr; + } + if (partBeginPtr < partEndPtr) { + // Due to cacheline alignment computation, the actual number of + // threads created here may not match threadNum exactly but + // should be +/-1 off. 
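+                // Each task rewrites its slice of `index` in place: for input
+                // position p, index[p] becomes the offset of val[p] in the
+                // deduplicated output (GetFillOffset shifts bucket-local
+                // offsets by totalUniqueSize[group] when sharding is enabled).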
+                tasks.push_back([this, val, beginPtr, partBeginPtr, partEndPtr, totalUniqueSize]() {
+                    for (uint32_t *ptr = partBeginPtr; ptr < partEndPtr; ++ptr) {
+                        auto group = groupMethod_.GroupId(val[ptr - beginPtr]);
+                        int32_t fillOffset = GetFillOffset(totalUniqueSize, val[ptr - beginPtr], group);
+                        *ptr = fillOffset;
+                    }
+                });
+            }
+            partBeginPtr = partEndPtr;
+            partEndPtr += partSize;
+        }
+
+        try {
+            if (!tasks.empty()) {
+                auto threader = ExternalThreader::Instance();
+                if (threader == nullptr) {
+                    return H_ADDRESS_NULL;
+                }
+                threader->Run(tasks);
+            }
+        } catch (NullptrError &) {
+            return H_ADDRESS_NULL;
+        }
+        return H_OK;
+    }
+
+private:
+    GroupMethod groupMethod_;
+    uint32_t bucketCountPower2_;
+    UniqueConf conf;
+    std::vector<std::unique_ptr<DedupT>> dedupShards_ {};
+    uint32_t partSize;
+    bool firstEnterFlag_ = false;
+    bool idCountEnable_ { false };
+};
+}
+}
+#endif
\ No newline at end of file
diff --git a/src/AccCTR/src/unique/unique_impl.cpp b/src/AccCTR/src/unique/unique_impl.cpp
new file mode 100644
index 00000000..77113214
--- /dev/null
+++ b/src/AccCTR/src/unique/unique_impl.cpp
@@ -0,0 +1,313 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#include <iostream>
+#include <thread>
+#include "unique_impl.h"
+
+namespace ock {
+namespace ctr {
+UniqueImpl::UniqueImpl() {}
+
+int UniqueImpl::Initialize(const UniqueConf &conf)
+{
+    uniqueConf = conf;
+    int ret = CheckConf(uniqueConf);
+    if (ret != H_OK) {
+        return ret;
+    }
+
+    if ((uniqueConf.outputType == OutputType::NORMAL) || (!uniqueConf.useSharding)) {
+        uniqueConf.shardingNum = 1; // default: a single bucket
+    }
+
+    GroupMethod groupMethod {};
+    groupMethod.SetGroupCount(uniqueConf.shardingNum);
+    groupMethod.SetStrategyFunByConf(uniqueConf.bucketStrategy);
+    try {
+        UnInitialize();
+        unique = new ShardedDedup(groupMethod, uniqueConf);
+        if (unique == nullptr) {
+            return H_ADDRESS_NULL;
+        }
+    } catch (AllocError &) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "memory alloc error");
+        return H_MEMORY_ALLOC_ERROR;
+    } catch (NullptrError &) {
+        return H_ADDRESS_NULL;
+    }
+
+    return H_OK;
+}
+
+int UniqueImpl::DoUnique(UniqueIn &uniqueIn, UniqueOut &uniqueOut)
+{
+    TimeCost doUniqueTimeCost;
+    if (!IsInitialized()) {
+        return H_UNIQUE_UNINITIALIZED_ERROR;
+    }
+
+    if (uniqueConf.outputType != OutputType::NORMAL) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "output type error, should be NORMAL");
+        return H_OUTPUT_TYPE_ERROR;
+    }
+
+    int ret = CheckInput(uniqueIn, uniqueOut);
+    if (ret != H_OK) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "invalid input or configuration");
+        return ret;
+    }
+
+    UniqueOutSelf uniqueOutSelf;
+    uniqueOutSelf.uniqueId = uniqueOut.uniqueId;
+    uniqueOutSelf.index = uniqueOut.index;
+    uniqueOutSelf.uniqueIdCnt = uniqueOut.uniqueIdCnt;
+
+    if (uniqueConf.dataType == DataType::INT64) {
+        ret = unique->Compute<DataType::INT64>(uniqueIn, uniqueOutSelf);
+    } else if (uniqueConf.dataType == DataType::INT32) {
+        ret = unique->Compute<DataType::INT32>(uniqueIn, uniqueOutSelf);
+    }
+
+    if (ret != H_OK) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "do unique error");
+        return ret;
+    }
+    uniqueOut.uniqueIdCnt = uniqueOutSelf.uniqueIdCnt;
+
+    std::stringstream sm;
+    sm << "input id count " << uniqueIn.inputIdCnt << "; id count after unique " << uniqueOut.uniqueIdCnt <<
+        "; unique id time cost " << doUniqueTimeCost.ElapsedMS() << " (ms)";
+    ExternalLogger::PrintLog(LogLevel::INFO, sm.str(), uniqueConf.trace);
+
+    // resource cleanup
+    return ret;
+}
+
+int UniqueImpl::DoEnhancedUnique(UniqueIn &uniqueIn, EnhancedUniqueOut &uniqueOut)
+{
+    TimeCost doEnhancedUniqueTimeCost;
+    if (!IsInitialized()) {
+        return H_UNIQUE_UNINITIALIZED_ERROR;
+    }
+
+    if (uniqueConf.outputType != OutputType::ENHANCED) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "output type error, should be ENHANCED");
+        return H_OUTPUT_TYPE_ERROR;
+    }
+
+    int ret = CheckInput(uniqueIn, uniqueOut);
+    if (ret != H_OK) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "invalid input or configuration");
+        return ret;
+    }
+
+    UniqueOutSelf uniqueOutSelf;
+    uniqueOutSelf.uniqueId = uniqueOut.uniqueId;
+    uniqueOutSelf.index = uniqueOut.index;
+    uniqueOutSelf.uniqueIdCntInBucket = uniqueOut.uniqueIdCntInBucket;
+    uniqueOutSelf.uniqueIdInBucket = uniqueOut.uniqueIdInBucket;
+    if (uniqueConf.useIdCount) {
+        uniqueOutSelf.idCnt = uniqueOut.idCnt;
+        uniqueOutSelf.idCntFill = uniqueOut.idCntFill;
+    }
+
+    if (uniqueConf.dataType == DataType::INT64) {
+        ret = unique->Compute<DataType::INT64>(uniqueIn, uniqueOutSelf);
+    } else if (uniqueConf.dataType == DataType::INT32) {
+        ret = unique->Compute<DataType::INT32>(uniqueIn, uniqueOutSelf);
+    }
+
+    if (ret != H_OK) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "do unique error");
+        return ret;
+    }
+
+    if (uniqueConf.usePadding) {
+        uniqueOut.uniqueIdCnt = uniqueConf.paddingSize * uniqueConf.shardingNum;
+    } else {
+        uniqueOut.uniqueIdCnt = uniqueOutSelf.uniqueIdCnt;
+    }
+
+    std::stringstream sm;
+    sm << "input id count " << uniqueIn.inputIdCnt << "; id count after unique " << uniqueOutSelf.uniqueIdCnt <<
+        "; unique id time cost " << doEnhancedUniqueTimeCost.ElapsedMS() << " (ms)";
+    ExternalLogger::PrintLog(LogLevel::INFO, sm.str(), uniqueConf.trace);
+    // resource cleanup
+    return ret;
+}
+
+void UniqueImpl::UnInitialize()
+{
+    if (unique != nullptr) {
+        delete unique;
+        unique = nullptr;
+    }
+}
+
+int UniqueImpl::SetExternalThreadFuncInner(ExternalThread threadFunc)
+{
+    auto threader = ExternalThreader::Instance();
+    if (threader == nullptr) {
+        std::cout << "Failed to create threader instance" << std::endl;
+        return H_NEW_OBJECT_FAILED;
+    }
+
+    threader->SetExternalLogFunction(threadFunc);
+    return H_OK;
+}
+
+int UniqueImpl::CheckConf(const UniqueConf &conf)
+{
+    int ret = CheckNormalConf(conf);
+    if (ret != H_OK) {
+        return ret;
+    }
+
+    if (conf.outputType == OutputType::ENHANCED) {
+        ret = CheckEnhancedUniqueConf(conf);
+        if (ret != H_OK) {
+            return ret;
+        }
+    }
+
+    return H_OK;
+}
+
+int UniqueImpl::CheckNormalConf(const UniqueConf &conf)
+{
+    if (CheckInputZero(conf.maxIdVal, "maxIdVal") || CheckInputZero(conf.desiredSize, "desiredSize")) {
+        return H_NUM_SMALL;
+    }
+    uint32_t processCoreNum = std::thread::hardware_concurrency();
+    if (conf.maxThreadNum > processCoreNum) {
+        std::stringstream sm;
+        sm << "maxThreadNum cannot be larger than " << processCoreNum;
+        ExternalLogger::PrintLog(LogLevel::ERROR, sm.str());
+        return H_ERROR;
+    }
+
+    if (conf.maxThreadNum == 0 || conf.minThreadNum == 0 || conf.minThreadNum > conf.maxThreadNum) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "please check minThreadNum and maxThreadNum");
+        return H_ERROR;
+    }
+
+    if (conf.desiredSize > MAX_DESIRED_SIZE) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "desiredSize cannot be larger than 1431655765");
+        return H_ERROR;
+    }
+
+    return H_OK;
+}
+
+int UniqueImpl::CheckEnhancedUniqueConf(const UniqueConf &conf)
+{
+    if (conf.usePadding) {
+        if (!conf.useSharding) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "sharding is not enabled."); // padding requires sharding
+            return H_SCENE_ERROR;
+        }
+
+        if (CheckInputZero(conf.paddingSize, "paddingSize")) {
+            ExternalLogger::PrintLog(LogLevel::ERROR, "if usePadding is true, paddingSize cannot be zero");
+            return H_NUM_SMALL;
+        }
+    }
+
+    if (conf.useSharding) {
+        if (CheckInputZero(conf.shardingNum, "shardingNum")) {
+            return H_NUM_SMALL;
+        }
+    }
+
+    return H_OK;
+}
+
+int UniqueImpl::CheckInput(UniqueIn &uniqueIn, EnhancedUniqueOut &uniqueOut)
+{
+    if (CheckInputNull(uniqueIn.inputId, "inputId") || CheckInputNull(uniqueOut.uniqueId, "uniqueId") ||
+        CheckInputNull(uniqueOut.index, "index")) {
+        return H_ADDRESS_NULL;
+    }
+
+    if (uniqueConf.useSharding) {
+        if (CheckInputNull(uniqueOut.uniqueIdInBucket, "uniqueIdInBucket") ||
+            CheckInputNull(uniqueOut.uniqueIdCntInBucket, "uniqueIdCntInBucket")) {
+            return H_ADDRESS_NULL;
+        }
+    }
+
+    if (uniqueConf.useIdCount) {
+        if (CheckInputNull(uniqueOut.idCnt, "idCnt")) {
+            return H_ADDRESS_NULL;
+        }
+        if (uniqueConf.usePadding) {
+            if (CheckInputNull(uniqueOut.idCntFill, "idCntFill")) {
+                return H_ADDRESS_NULL;
+            }
+        }
+    }
+
+    if (CheckInputZero(uniqueIn.inputIdCnt, "inputIdCnt")) {
+        return H_NUM_SMALL;
+    }
+
+    if (uniqueIn.inputIdCnt > MAX_ID_COUNT) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "inputIdCnt cannot be larger than 2^28");
+        return H_ERROR;
+    }
+
+    return H_OK;
+}
+
+int UniqueImpl::CheckInput(UniqueIn &uniqueIn, UniqueOut &uniqueOut)
+{
+    if (CheckInputNull(uniqueIn.inputId, "inputId") || CheckInputNull(uniqueOut.uniqueId, "uniqueId") ||
+        CheckInputNull(uniqueOut.index, "index")) {
+        return H_ADDRESS_NULL;
+    }
+    return H_OK;
+}
+
+bool UniqueImpl::CheckInputNull(void *ptr, const std::string &name)
+{
+    if (ptr == nullptr) {
+        std::stringstream sm;
+        sm << name << " cannot be nullptr";
+        ExternalLogger::PrintLog(LogLevel::ERROR, sm.str());
+        return true;
+    }
+    return false;
+}
+
+bool UniqueImpl::CheckInputZero(int64_t in, const std::string &name)
+{
+    if (in <= 0) {
+        std::stringstream sm;
+        sm << name << " cannot be zero or negative";
+        ExternalLogger::PrintLog(LogLevel::ERROR, sm.str());
+        return true;
+    }
+    return false;
+}
+
+bool UniqueImpl::IsInitialized()
+{
+    if (unique == nullptr) {
+        ExternalLogger::PrintLog(LogLevel::ERROR, "please call Initialize before DoUnique");
+        return false;
+    }
+    return true;
+}
+}
+}
diff --git a/src/AccCTR/src/unique/unique_impl.h b/src/AccCTR/src/unique/unique_impl.h
new file mode 100644
index 00000000..f4c45fde
--- /dev/null
+++ b/src/AccCTR/src/unique/unique_impl.h
@@ -0,0 +1,51 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef OCK_UNIQUE_IMPL_H
+#define OCK_UNIQUE_IMPL_H
+
+#include "unique_func.h"
+#include "factory.h"
+namespace ock {
+namespace ctr {
+class UniqueImpl : public Unique {
+public:
+    explicit UniqueImpl();
+    ~UniqueImpl() override = default;
+
+public:
+    int Initialize(const UniqueConf &conf) override;
+    void UnInitialize() override;
+    int DoUnique(UniqueIn &uniqueIn, UniqueOut &uniqueOut) override;
+    int DoEnhancedUnique(UniqueIn &uniqueIn, EnhancedUniqueOut &uniqueOut) override;
+    int SetExternalThreadFuncInner(ExternalThread threadFunc) override;
+
+private:
+    int CheckInput(UniqueIn &uniqueIn, UniqueOut &uniqueOut);
+    bool CheckInputNull(void *ptr, const std::string &name);
+    bool CheckInputZero(int64_t in, const std::string &name);
+    bool IsInitialized();
+    int CheckConf(const UniqueConf &conf);
+    int CheckInput(UniqueIn &uniqueIn, EnhancedUniqueOut &uniqueOut);
+    int CheckNormalConf(const UniqueConf &conf);
+    int CheckEnhancedUniqueConf(const UniqueConf &conf);
+
+private:
+    ShardedDedup *unique = nullptr;
+    UniqueConf uniqueConf {};
+};
+}
+}
+
+#endif // OCK_UNIQUE_IMPL_H
diff --git a/src/AccCTR/tests/CMakeLists.txt b/src/AccCTR/tests/CMakeLists.txt
new file mode 100644
index 00000000..ba38343d
--- /dev/null
+++ b/src/AccCTR/tests/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+message("build mode " ${BUILD_MODE})
+add_compile_options(-ftest-coverage -fprofile-arcs)
+link_libraries(gcov)
+
+if (${BUILD_MODE} MATCHES "ut")
+    add_subdirectory(ut)
+endif (${BUILD_MODE} MATCHES "ut")
diff --git a/src/AccCTR/tests/tools/create_fake_id.py b/src/AccCTR/tests/tools/create_fake_id.py
new file mode 100644
index 00000000..fc0f1f8e
--- /dev/null
+++ b/src/AccCTR/tests/tools/create_fake_id.py
@@ -0,0 +1,106 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import random
+import os
+
+
+def write_data(file_name, x, y, dup):
+    clear_file(file_name)
+    length = 200
+    interval = 100000
+    count = 0
+    ids = []
+    all_ids = []
+    for j in range(0, int(x / length)):
+        for i in range(0, length):
+            ids.append(random.randrange(0 + interval * j, interval * (j + 1)))
+        if len(ids) == 200:
+            for k in range(0, dup):
+                count = int(count) + len(ids)
+                for val in ids:
+                    all_ids.append(val)
+            ids = []
+
+    if int(len(ids)) > 0:
+        for val in ids:
+            all_ids.append(val)
+
+    ids = []
+    my_set = set()
+    for j in range(0, int(y / length)):
+        for i in range(0, length):
+            ids.append(random.randrange(0 + interval * (int(x / length) + j), interval * (int(x / length) + j + 1)))
+            count = count + 1
+            all_ids.append(ids[i])
+        if len(ids) == 200:
+            ids = []
+    if int(len(ids)) > 0:
+        for val in ids:
+            all_ids.append(val)
+
+    random.shuffle(all_ids)
+
+    ids = []
+    for j in range(0, len(all_ids)):
+        ids.append(all_ids[j])
+        if len(ids) % 200 == 0:
+            write_file(ids, file_name)
+            ids = []
+
+    write_file(ids, file_name)
+
+    for val in all_ids:
+        my_set.add(val)
+
+    print("count: ", count, "all_ids len:", len(all_ids), " set size: ", len(my_set))
+
+
+def main():
+    # 3,000,000 ids with a 20% unique rate (units of 10,000):
+    # 6x + y = 300, x + y = 60  =>  x = 48, y = 12
+    write_data('data20.txt', 48*10000, 12*10000, 6)
+
+    # 3,000,000 ids with a 30% unique rate:
+    # 6x + y = 300, x + y = 90  =>  x = 42, y = 48
+    write_data('data30.txt', 42*10000, 48*10000, 6)
+
+    # 3,000,000 ids with a 40% unique rate:
+    # 6x + y = 300, x + y = 120  =>  x = 36, y = 84
+    write_data('data40.txt', 36*10000, 84*10000, 6)
+
+
+def write_file(ids, file_name):
+    w = ""
+    for val in ids:
+        w += str(val) + ", "
+    with open(file_name, 'a') as f:
+        f.write(w + "\n")
+
+
+def clear_file(file_name):
+    if os.path.exists(file_name):
+        with open(file_name, "r+") as f:
+            f.truncate(0)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/AccCTR/tests/ut/CMakeLists.txt b/src/AccCTR/tests/ut/CMakeLists.txt
new file mode 100644
index 00000000..1ed27503
--- /dev/null
+++ b/src/AccCTR/tests/ut/CMakeLists.txt
@@ -0,0 +1,15 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+add_subdirectory(src)
\ No newline at end of file
diff --git a/src/AccCTR/tests/ut/src/CMakeLists.txt b/src/AccCTR/tests/ut/src/CMakeLists.txt
new file mode 100644
index 00000000..a4c631e8
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/CMakeLists.txt
@@ -0,0 +1,45 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+add_compile_options(-ftest-coverage -fprofile-arcs)
+link_libraries(gcov)
+
+set(OCK_CTR_UTIL_INSTALL_DIR ${PROJECT_SOURCE_DIR}/install)
+set(OCK_CTR_SRC_DIR ${PROJECT_SOURCE_DIR}/src)
+message("src" ${OCK_CTR_SRC_DIR})
+
+file(GLOB_RECURSE TEST_UNIQUE_FILES *.cpp *.h)
+add_executable(test_unique_files ${TEST_UNIQUE_FILES})
+include_directories(${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/include)
+link_directories(${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/lib64)
+
+SET(LIB_3RD_GMOCK ${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/lib64/libgmock.a)
+SET(LIB_3RD_GTEST ${OCK_CTR_UTIL_INSTALL_DIR}/googletest-release-1.8.1/lib64/libgtest.a)
+
+
+message(${OCK_CTR_SRC_DIR}/include)
+
+target_include_directories(test_unique_files
+    PUBLIC
+    ${OCK_CTR_SRC_DIR}/include)
+
+target_link_libraries(test_unique_files
+    PUBLIC
+    -Wl,--start-group
+    pthread
+    dl
+    ${LIB_3RD_GTEST}
+    ${LIB_3RD_GMOCK}
+    -Wl,--end-group)
+
diff --git a/src/AccCTR/tests/ut/src/gtest_main.cpp b/src/AccCTR/tests/ut/src/gtest_main.cpp
new file mode 100644
index 00000000..068d08d7
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/gtest_main.cpp
@@ -0,0 +1,21 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#include <gtest/gtest.h>
+
+int main(int argc, char *argv[])
+{
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}
diff --git a/src/AccCTR/tests/ut/src/unique_test.cpp b/src/AccCTR/tests/ut/src/unique_test.cpp
new file mode 100644
index 00000000..ef6846f8
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/unique_test.cpp
@@ -0,0 +1,1520 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#include <fstream>
+#include <sstream>
+#include "unique_test.h"
+
+FactoryPtr factory;
+
+void UniqueTest::SetUpTestCase()
+{
+    Factory::Create(factory);
+}
+
+void UniqueTest::TearDownTestCase() {}
+
+TEST_F(UniqueTest, Conf)
+{
+    std::cout << "===========Conf start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    ASSERT_EQ(Factory::Create(factory), -1); // duplicate creation check
+
+    UniqueConf conf;
+    conf.usePadding = true;
+    conf.useSharding = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 1;
+    conf.maxIdVal = 0;
+    conf.shardingNum = 2;
+    conf.paddingSize = 1;
+    conf.outputType = OutputType::ENHANCED;
+    ASSERT_EQ(unique->Initialize(conf), 4); // error: maxIdVal is 0
+    conf.maxIdVal = 9;
+    conf.shardingNum = 0;
+    ASSERT_EQ(unique->Initialize(conf), 4); // error: shardingNum is 0
+    conf.shardingNum = 4;
+    conf.maxThreadNum = 0;
+    ASSERT_EQ(unique->Initialize(conf), 1); // error: maxThreadNum is 0
+    conf.maxThreadNum = 2;
+    conf.paddingSize = 0;
+    ASSERT_EQ(unique->Initialize(conf), 4); // error: paddingSize is 0
+    conf.desiredSize = 0;
+    ASSERT_EQ(unique->Initialize(conf), 4); // error: desiredSize is 0
+    conf.desiredSize = 1431655766;
+    ASSERT_EQ(unique->Initialize(conf), 1); // error: desiredSize 1431655766 exceeds the maximum
+    conf.desiredSize = 6;
+    conf.minThreadNum = 100;
+    conf.maxThreadNum = 100;
+    ASSERT_EQ(unique->Initialize(conf), 1); // error: minThreadNum too large
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 1;
+    conf.paddingSize = 1;
+    ASSERT_EQ(unique->Initialize(conf), 0); // valid configuration
+
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(6);
+    vector<uint32_t> index(6);
+    int *idCnt = new int[6];
+    int *idCntFill = new int[6];
+    int *uniqueIdCntInBucket = new int[6];
+    int64_t *uniqueIdInBucket = new int64_t[6];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = 6;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = nullptr;
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill;
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket;
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 3); // uniqueId is nullptr
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.uniqueIdCntInBucket = nullptr;
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 3); // uniqueIdCntInBucket is nullptr
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket;
+    uniqueOut.idCntFill = nullptr;
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 3); // idCntFill is nullptr
+    uniqueOut.idCntFill = idCntFill;
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 7); // padding size too small
+    std::cout << "===========Conf end=============" << std::endl;
+}
+
+// usePadding without useSharding (padding depends on sharding): expect scene error (9)
+TEST_F(UniqueTest, usePaddingNoShardingErr)
+{
+    std::cout << "===========usePaddingNoShardingErr start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.usePadding = true;
+    conf.useSharding = false;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.paddingSize = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 9);
+    std::cout << "===========usePaddingNoShardingErr end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, useNegativeDesiredSize)
+{
+    std::cout << "===========useNegativeDesiredSize start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = -1;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.outputType = OutputType::NORMAL;
+
+    ASSERT_EQ(unique->Initialize(conf), 1);
+
+    std::cout << "===========useNegativeDesiredSize end=============" << std::endl;
+}
+
+// DoUnique in NORMAL mode returns only the deduplicated vector, the restore index, and the count
+TEST_F(UniqueTest, DoUniqueNormal)
+{
+    std::cout << "===========DoUniqueNormal start=============" << std::endl;
+    char *path = get_current_dir_name();
+    std::string input_path(path);
+    std::cout << "input_path:" + input_path + "/data30.txt" << std::endl;
+    std::ifstream input(input_path + "/data30.txt");
+
+    std::vector<int64_t> numbers;
+    std::string line;
+    while (std::getline(input, line, ',')) {
+        std::istringstream in(line);
+        std::copy(std::istream_iterator<int64_t>(in), std::istream_iterator<int64_t>(), std::back_inserter(numbers));
+    }
+    input.close();
+    std::cout << "read data close, numbers size:" << numbers.size() << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.trace = true;
+    conf.desiredSize = numbers.size();
+    conf.dataType = DataType::INT64;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 8;
+    conf.maxIdVal = 10000000000;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::NORMAL;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = numbers.size();
+    vector<int64_t> inputId;
+    for (size_t i = 0; i < numbers.size(); i++) {
+        inputId.emplace_back(numbers[i]);
+    }
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = nullptr;
+
+    UniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.uniqueIdCnt = 0;
+    ASSERT_EQ(unique->DoUnique(uniqueIn, uniqueOut), 3); // inputId is nullptr
+
+    uniqueIn.inputId = inputId.data();
+    unique->DoUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    for (uint32_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+    }
+
+    unordered_set<int64_t> idsSet;
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        idsSet.insert(numbers[i]);
+    }
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size());
+
+    unique->UnInitialize();
+    std::cout << "===========DoUniqueNormal end=============" << std::endl;
+}
+
+// ENHANCED conf but calling the NORMAL interface
+TEST_F(UniqueTest, UseErrOutputTypeEnhanced)
+{
+    std::cout << "===========UseErrOutputTypeEnhanced start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 1;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 1 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    UniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.uniqueIdCnt = 0;
+
+    int ret = unique->DoUnique(uniqueIn, uniqueOut);
+    ASSERT_EQ(ret, 8);
+
+    unique->UnInitialize();
+    std::cout << "===========UseErrOutputTypeEnhanced end=============" << std::endl;
+}
+
+// NORMAL conf but calling the enhanced interface
+TEST_F(UniqueTest, UseErrOutputTypeNormal)
+{
+    std::cout << "===========UseErrOutputTypeNormal start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 1;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::NORMAL;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+
+    int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+    ASSERT_EQ(ret, 8);
+
+    unique->UnInitialize();
+    std::cout << "===========UseErrOutputTypeNormal end=============" << std::endl;
+}
+
+// Basic scenario through the enhanced interface: returns only the deduplicated vector, restore index, and count
+TEST_F(UniqueTest, DoEnhancedUnique)
+{
+    std::cout << "===========DoEnhancedUnique start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+    }
+
+    unordered_set<int64_t> idsSet;
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        idsSet.insert(inputId[i]);
+    }
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size());
+
+    unique->UnInitialize();
+    std::cout << "===========DoEnhancedUnique end=============" << std::endl;
+}
+
+// id counting enabled without padding, with idCntFill configured anyway (system-fault scenario)
+TEST_F(UniqueTest, DoEnhancedUniqueErr)
+{
+    std::cout << "===========DoEnhancedUniqueErr start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.useIdCount = true;
+    conf.useSharding = true;
+    conf.outputType = OutputType::ENHANCED;
+
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    int size = 6;
+    vector<int> idCntFill(size);
+    vector<int> uniqueIdCntInBucket(size);
+    int64_t *uniqueIdInBucket = new int64_t[size];
+    int *idCnt = new int[size];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+
+    // padding is off, but fill/count buffers are passed in anyway
+    uniqueOut.uniqueIdInBucket = uniqueIdInBucket;
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.idCntFill = idCntFill.data();
+    uniqueOut.idCnt = idCnt;
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+    }
+
+    unordered_set<int64_t> idsSet;
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        idsSet.insert(inputId[i]);
+    }
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size());
+
+    unique->UnInitialize();
+    std::cout << "===========DoEnhancedUniqueErr end=============" << std::endl;
+}
+
+// Basic scenario through the enhanced interface: returned-length test with padding off
+TEST_F(UniqueTest, DoEnhancedUnique_UniqueIdSize)
+{
+    std::cout << "===========DoEnhancedUnique_UniqueIdSize start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+    }
+
+    unordered_set<int64_t> idsSet;
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        idsSet.insert(inputId[i]);
+    }
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size());
+
+    unique->UnInitialize();
+    std::cout << "===========DoEnhancedUnique_UniqueIdSize end=============" << std::endl;
+}
+
+// Enhanced interface with id counting configured but a null id-count vector
+TEST_F(UniqueTest, idCntIsNull)
+{
+    std::cout << "===========idCntIsNull start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    int *idCnt = nullptr;
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+
+    int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+    ASSERT_EQ(ret, 3);
+
+    unique->UnInitialize();
+    std::cout << "===========idCntIsNull end=============" << std::endl;
+}
+
+// Enhanced interface with padding and id counting configured but null count vectors
+TEST_F(UniqueTest, idCntIsNullSharding)
+{
+    std::cout << "===========idCntIsNullSharding start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useSharding = true;
+    conf.usePadding = true;
+    conf.useIdCount = true;
+    conf.paddingSize = 5;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    int *idCnt = nullptr;
+    int *idCntFill = nullptr;
+    int *uniqueIdCntInBucket = new int[inputLen];
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill;
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket;
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    int ret = unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+    ASSERT_EQ(ret, 3);
+
+    unique->UnInitialize();
+    std::cout << "===========idCntIsNullSharding end=============" << std::endl;
+}
+
+// Enhanced interface with sharding and id counting
+TEST_F(UniqueTest, DoUniqueShard)
+{
+    std::cout << "===========DoUniqueShard start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.useIdCount = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.paddingSize = 1;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
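+    // The expectation arrays below assume the default modulo bucketing
+    // (bucket = id % shardingNum): with shardingNum = 2, the unique ids
+    // {5, 3, 1} land in bucket 1 and {2} in bucket 0.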
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> idCnt(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    vector<int> expectedIdCnt(uniqueOut.uniqueIdCnt);
+    for (int i = 0; i < uniqueOut.uniqueIdCnt; i++) {
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+    expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+    unique->UnInitialize();
+
+    std::cout << "===========DoUniqueShard end=============" << std::endl;
+}
+
+// Enhanced interface with sharding only
+TEST_F(UniqueTest, DoUniqueOnlyShard)
+{
+    std::cout << "===========DoUniqueOnlyShard start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.paddingSize = 1;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+
+    unordered_set<int64_t> uniqueIdSet;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    unique->UnInitialize();
+
+    std::cout << "===========DoUniqueOnlyShard end=============" << std::endl;
+}
+
+// Enhanced interface with sharding, padding, and id counting
+TEST_F(UniqueTest, DoUniquePadding)
+{
+    std::cout << "===========DoUniquePadding start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.usePadding = true;
+    conf.useSharding = true;
+    conf.paddingVal = -1;
+    conf.paddingSize = 4;
+    conf.desiredSize = 1; // configured size smaller than the actual input length; verify normal operation
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 4, 4, 5, 6, 4, 0 };
+    vector<int64_t> uniqueId(conf.paddingSize * conf.shardingNum);
+    vector<uint32_t> index(inputLen);
+    int *idCnt = new int[inputLen];
+    vector<int> idCntFill(conf.paddingSize * conf.shardingNum);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = 6;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, conf.shardingNum * conf.paddingSize);
+
+    vector<int> expectedIdCnt(conf.paddingSize * conf.shardingNum);
+    for (int i = 0; i < conf.paddingSize * conf.shardingNum; i++) {
+        if (uniqueId[i] == -1) {
+            continue;
+        }
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+
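+    // Padding slots hold paddingVal (-1) in uniqueId and a count of 0 in
+    // idCntFill; expectedIdCnt is zero-initialized, so skipping the -1 slots
+    // above reproduces exactly that layout.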
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCntFill, testing::ElementsAreArray(expectedIdCnt));
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, conf.paddingSize * conf.shardingNum);
+    unique->UnInitialize();
+    std::cout << "===========DoUniquePadding end=============" << std::endl;
+}
+
+// Enhanced interface with id counting only, no thread pool configured
+TEST_F(UniqueTest, DoUniqueNoThreadPool)
+{
+    std::cout << "===========DoUniqueNoThreadPool start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 20; // configured size larger than the actual input length; verify normal operation
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 9;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> idCnt(inputLen);
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = 6;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+    }
+
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        uniqueIdSet.insert(inputId[i]);
+        expectedIdCntMap[inputId[i]]++;
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    vector<int> expectedIdCnt(uniqueOut.uniqueIdCnt);
+    for (int i = 0; i < uniqueOut.uniqueIdCnt; i++) {
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+    expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+
+    unique->UnInitialize();
+    std::cout << "===========DoUniqueNoThreadPool end=============" << std::endl;
+}
+
+// Enhanced interface with sharding and id counting; bucket count exceeds the data volume
+TEST_F(UniqueTest, DoUniqueShardNumberOversize)
+{
+    std::cout << "===========DoUniqueShardNumberOversize start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 100;
+    conf.shardingNum = 7;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> idCnt(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    int uniqueSum = 0;
+
+    for (int i = 0; i < conf.shardingNum; i++) {
+        uniqueSum += uniqueIdCntInBucket[i];
+    }
+
+    vector<int> expectedIdCnt(uniqueSum);
+    for (int i = 0; i < uniqueSum; i++) {
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+    expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+    unique->UnInitialize();
+
+    std::cout << "===========DoUniqueShardNumberOversize end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, DoUniqueSpecial)
+{
+    std::cout << "===========DoUniqueSpecial start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    int count = 1000000;
+    UniqueConf conf;
+    conf.paddingVal = -1;
+    conf.usePadding = false;
+    conf.useSharding = true;
+    conf.desiredSize = 100;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 8;
+    conf.maxIdVal = 0;
+    conf.paddingSize = 4;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 4); // error: maxIdVal is 0
+    conf.maxIdVal = count + 1;
+    ASSERT_EQ(unique->Initialize(conf), 0);
+
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    std::vector<int64_t> inputData;
+    inputData.resize(count);
+    for (int i = 0; i < count; i++) {
+        inputData[i] = count;
+    }
+    int64_t *uniqueData = new int64_t[count];
+    uint32_t *index = new uint32_t[count];
+    int *idCnt = new int[count];
+    int *idCntFill = new int[count];
+    int *uniqueIdCntInBucket = new int[count];
+    int64_t *uniqueIdInBucket = new int64_t[count];
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = count;
+    uniqueIn.inputId = inputData.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueData);
+    uniqueOut.index = index;
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill;
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket;
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    for (int i = 0; i < 2; i++) {
+        ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 0);
+
+        vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+        for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+            restoreIds[i] = uniqueData[index[i]];
+        }
+
+        ASSERT_THAT(inputData, testing::ElementsAreArray(restoreIds));
+    }
+
+    unique->UnInitialize();
+
+    std::cout << "===========DoUniqueSpecial end=============" << std::endl;
+}
+
+// Enhanced interface: id value exceeds maxIdVal
+TEST_F(UniqueTest, IdLarge)
+{
+    std::cout << "===========IdLarge start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 1;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    int *idCnt = new int[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 6); // id too large
+    std::cout << "===========IdLarge end=============" << std::endl;
+}
+
+// Enhanced interface with int32 input data, sharding, and id counting
+TEST_F(UniqueTest, DoUniqueNormalInt32)
+{
+    std::cout << "===========DoUniqueNormalInt32 start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT32;
+    conf.useIdCount = true;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int32_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int32_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> idCnt(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int32_t *uniqueIdInBucket = new int32_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    unordered_set<int32_t> uniqueIdSet;
+    map<int32_t, int> expectedIdCntMap;
+
+    vector<int32_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    vector<int> expectedIdCnt(uniqueOut.uniqueIdCnt);
+    for (int i = 0; i < uniqueOut.uniqueIdCnt; i++) {
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+    expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+
+    unique->UnInitialize();
+    std::cout << "===========DoUniqueNormalInt32 end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, DoUniqueMultipleTimes)
+{
+    std::cout << "===========DoUniqueMultipleTimes start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.maxIdVal = 9;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+
+    for (int i = 0; i < 1000; i++) {
+        unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+        vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+        for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+            restoreIds[i] = uniqueId[index[i]];
+        }
+
+        unordered_set<int64_t> idsSet;
+        for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+            idsSet.insert(inputId[i]);
+        }
+
+        ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+        ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)idsSet.size());
+    }
+
+    unique->UnInitialize();
+    std::cout << "===========DoUniqueMultipleTimes end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, DoUniqueShardMultipleTimes)
+{
+    std::cout << "===========DoUniqueShardMultipleTimes start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<uint32_t> index(inputLen);
+    vector<int> idCnt(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    for (int i = 0; i < 1000; i++) {
+        unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+        vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+        vector<int> expectedUniqueIdCnt(conf.shardingNum);
+
+        unordered_set<int64_t> uniqueIdSet;
+        map<int64_t, int> expectedIdCntMap;
+
+        for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+            restoreIds[i] = uniqueId[index[i]];
+            expectedIdCntMap[inputId[i]]++;
+            if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+                continue;
+            } else {
+                uniqueIdSet.insert(inputId[i]);
+                expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+            }
+        }
+
+        ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+        int uniqueSum = 0;
+
+        for (int i = 0; i < conf.shardingNum; i++) {
+            uniqueSum += uniqueIdCntInBucket[i];
+        }
+
+        vector<int> expectedIdCnt(uniqueSum);
+        for (int i = 0; i < uniqueSum; i++) {
+            expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+        }
+        expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+        ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+        ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+        ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+    }
+    unique->UnInitialize();
+
+    std::cout << "===========DoUniqueShardMultipleTimes end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, DoUniquePaddingMultipleTimes)
+{
+    std::cout << "===========DoUniquePaddingMultipleTimes start=============" << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.usePadding = true;
+    conf.useSharding = true;
+    conf.paddingVal = -1;
+    conf.paddingSize = 4;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 1;
+    conf.maxIdVal = 9;
+    conf.shardingNum = 2;
+    conf.outputType = OutputType::ENHANCED;
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 4, 4, 5, 6, 4, 0 };
+    int inputLen = 6;
+    vector<int64_t> inputId = { 4, 4, 5, 6, 4, 0 };
+    vector<int64_t> uniqueId(conf.paddingSize * conf.shardingNum);
+    vector<int> index(inputLen);
+    int *idCnt = new int[inputLen];
+    vector<int> idCntFill(conf.paddingSize * conf.shardingNum);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = 6;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    for (int i = 0; i < 1000; i++) {
+        unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+        vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+        vector<int> expectedUniqueIdCnt(conf.shardingNum);
+        unordered_set<int64_t> uniqueIdSet;
+        map<int64_t, int> expectedIdCntMap;
+
+        for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+            restoreIds[i] = uniqueId[index[i]];
+            expectedIdCntMap[inputId[i]]++;
+
+            if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+                continue;
+            } else {
+                uniqueIdSet.insert(inputId[i]);
+                expectedUniqueIdCnt[inputId[i] % conf.shardingNum]++;
+            }
+        }
+
+        ASSERT_EQ(uniqueOut.uniqueIdCnt, conf.shardingNum * conf.paddingSize);
+
+        vector<int> expectedIdCnt(conf.paddingSize * conf.shardingNum);
+        for (int i = 0; i < conf.paddingSize * conf.shardingNum; i++) {
+            if (uniqueId[i] == -1) {
+                continue;
+            }
+            expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+        }
+
+        ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+        ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+        ASSERT_THAT(idCntFill, testing::ElementsAreArray(expectedIdCnt));
+    }
+
+    unique->UnInitialize();
+    std::cout << "===========DoUniquePaddingMultipleTimes end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, IdCntSmall)
+{
+    std::cout << "===========IdCntSmall start=============" << std::endl;
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    UniqueConf conf;
+    conf.desiredSize = 6;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.maxIdVal = 100;
+    conf.outputType = OutputType::ENHANCED;
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    int inputLen = 6;
+    vector<int64_t> inputId = { 5, 5, 3, 1, 5, 2 };
+    vector<int64_t> uniqueId(inputLen);
+    vector<int> index(inputLen);
+    int *idCnt = new int[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = 0;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 4); // idCnt is too small
+    std::cout << "===========IdCntSmall end=============" << std::endl;
+}
+
+TEST_F(UniqueTest, DoUniqueLotsDataFunction)
+{
+    std::cout << "===========DoUniqueLotsDataFunction start=============" << std::endl;
+    char *path = get_current_dir_name();
+    std::string input_path(path);
+    std::cout << "input_path:" + input_path + "/data40.txt" << std::endl;
+    std::ifstream input(input_path + "/data40.txt");
+
+    std::vector<int64_t> numbers;
+    std::string line;
+    while (std::getline(input, line, ',')) {
+        std::istringstream in(line);
+        std::copy(std::istream_iterator<int64_t>(in), std::istream_iterator<int64_t>(), std::back_inserter(numbers));
+    }
+    input.close();
+    std::cout << "read data close, numbers size:" << numbers.size() << std::endl;
+
+    UniquePtr unique;
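+    // Note: this test reads comma-separated int64 ids from data40.txt in the
+    // current working directory; the file is assumed to exist, otherwise
+    // numbers stays empty and the checks below are vacuous.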
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    int inputLen = numbers.size();
+    UniqueConf conf;
+    conf.useSharding = true;
+    conf.desiredSize = 1; // configured capacity smaller than the actual input length, to verify normal operation
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 8;
+    conf.maxIdVal = 10000000000;
+    conf.shardingNum = 8;
+    conf.outputType = OutputType::ENHANCED;
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    vector<int64_t> inputId;
+
+    for (size_t i = 0; i < numbers.size(); i++) {
+        inputId.emplace_back(numbers[i]);
+    }
+
+    vector<int64_t> uniqueId(inputLen);
+    vector<int> index(inputLen);
+    vector<int> idCnt(inputLen);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = uniqueIdInBucket;
+
+    unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[(uint64_t)(inputId[i]) % conf.shardingNum]++;
+        }
+    }
+
+    ASSERT_EQ(uniqueOut.uniqueIdCnt, (int)uniqueIdSet.size());
+
+    int uniqueSum = 0;
+
+    for (int i = 0; i < conf.shardingNum; i++) {
+        uniqueSum += uniqueIdCntInBucket[i];
+    }
+
+    vector<int> expectedIdCnt(uniqueSum);
+    for (int i = 0; i < uniqueSum; i++) {
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+    expectedIdCnt.resize(uniqueIn.inputIdCnt);
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCnt, testing::ElementsAreArray(expectedIdCnt));
+
+    unique->UnInitialize();
+    if (path) {
+        free(path);
+    }
+    std::cout << "===========DoUniqueLotsDataFunction end=============" << std::endl;
+}
+
+
+TEST_F(UniqueTest, DoUniqueLotsDataPaddingFunction)
+{
+    std::cout << "===========DoUniqueLotsDataPaddingFunction start=============" << std::endl;
+    char *path = get_current_dir_name();
+    std::string input_path(path);
+    std::cout << "input_path:" + input_path + "/data30.txt" << std::endl;
+    std::ifstream input(input_path + "/data30.txt");
+
+    std::vector<int64_t> numbers;
+    std::string line;
+    while (std::getline(input, line, ',')) {
+        std::istringstream in(line);
+        std::copy(std::istream_iterator<int64_t>(in), std::istream_iterator<int64_t>(), std::back_inserter(numbers));
+    }
+    input.close();
+    std::cout << "read data close, numbers size:" << numbers.size() << std::endl;
+
+    UniquePtr unique;
+    ASSERT_EQ(factory->CreateUnique(unique), 0);
+
+    int inputLen = numbers.size();
+    UniqueConf conf;
+    conf.trace = true;
+    conf.usePadding = true;
+    conf.useSharding = true;
+    conf.paddingVal = -1;
+    conf.paddingSize = 150000;
+    conf.desiredSize = inputLen;
+    conf.dataType = DataType::INT64;
+    conf.useIdCount = true;
+    conf.minThreadNum = 1;
+    conf.maxThreadNum = 8;
+    conf.maxIdVal = 10000000000;
+    conf.shardingNum = 8;
+    conf.outputType = OutputType::ENHANCED;
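+    // Note: paddingSize (150000) is assumed to be an upper bound on the unique
+    // ids per shard in data30.txt; the DoEnhancedUnique call issued after
+    // UnInitialize at the end of this test is expected to fail with code 11.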
+    ASSERT_EQ(unique->Initialize(conf), 0);
+    unique->SetExternalThreadFuncInner(SimpleThreadPool::SyncRun);
+
+    vector<int64_t> inputId;
+
+    for (size_t i = 0; i < numbers.size(); i++) {
+        inputId.emplace_back(numbers[i]);
+    }
+
+    vector<int64_t> uniqueId(conf.paddingSize * conf.shardingNum);
+    vector<int> index(inputLen);
+    int *idCnt = new int[inputLen];
+    vector<int> idCntFill(conf.paddingSize * conf.shardingNum);
+    vector<int> uniqueIdCntInBucket(conf.shardingNum);
+    int64_t *uniqueIdInBucket = new int64_t[inputLen];
+
+    UniqueIn uniqueIn;
+    uniqueIn.inputIdCnt = inputLen;
+    uniqueIn.inputId = inputId.data();
+
+    EnhancedUniqueOut uniqueOut;
+    uniqueOut.uniqueId = reinterpret_cast<void *>(uniqueId.data());
+    uniqueOut.index = index.data();
+    uniqueOut.idCnt = idCnt;
+    uniqueOut.idCntFill = idCntFill.data();
+    uniqueOut.uniqueIdCntInBucket = uniqueIdCntInBucket.data();
+    uniqueOut.uniqueIdInBucket = reinterpret_cast<void *>(uniqueIdInBucket);
+
+    for (int i = 0; i < 3; i++) {
+        unique->DoEnhancedUnique(uniqueIn, uniqueOut);
+    }
+
+    vector<int64_t> restoreIds(uniqueIn.inputIdCnt);
+    vector<int> expectedUniqueIdCnt(conf.shardingNum);
+    unordered_set<int64_t> uniqueIdSet;
+    map<int64_t, int> expectedIdCntMap;
+
+    for (size_t i = 0; i < uniqueIn.inputIdCnt; i++) {
+        restoreIds[i] = uniqueId[index[i]];
+        expectedIdCntMap[inputId[i]]++;
+
+        if (uniqueIdSet.find(inputId[i]) != uniqueIdSet.end()) {
+            continue;
+        } else {
+            uniqueIdSet.insert(inputId[i]);
+            expectedUniqueIdCnt[(uint64_t)(inputId[i]) % conf.shardingNum]++;
+        }
+    }
+
+    vector<int> expectedIdCnt(conf.paddingSize * conf.shardingNum);
+    for (int i = 0; i < conf.paddingSize * conf.shardingNum; i++) {
+        if (uniqueId[i] == -1) {
+            continue;
+        }
+        expectedIdCnt[i] = expectedIdCntMap[uniqueId[i]];
+    }
+
+    ASSERT_THAT(inputId, testing::ElementsAreArray(restoreIds));
+    ASSERT_THAT(uniqueIdCntInBucket, testing::ElementsAreArray(expectedUniqueIdCnt));
+    ASSERT_THAT(idCntFill, testing::ElementsAreArray(expectedIdCnt));
+
+    unique->UnInitialize();
+    ASSERT_EQ(unique->DoEnhancedUnique(uniqueIn, uniqueOut), 11);
+    if (path) {
+        free(path);
+    }
+    std::cout << "===========DoUniqueLotsDataPaddingFunction end=============" << std::endl;
+}
\ No newline at end of file
diff --git a/src/AccCTR/tests/ut/src/unique_test.h b/src/AccCTR/tests/ut/src/unique_test.h
new file mode 100644
index 00000000..0243f262
--- /dev/null
+++ b/src/AccCTR/tests/ut/src/unique_test.h
@@ -0,0 +1,60 @@
+/* Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+ ==============================================================================*/
+
+#ifndef OCK_UNIQUE_TEST_H
+#define OCK_UNIQUE_TEST_H
+
+#include <vector>
+#include <functional>
+#include <future>
+#include <memory>
+#include "factory.h"
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+#include "unique.h"
+
+using namespace std;
+using namespace ock::ctr;
+
+
+class SimpleThreadPool {
+public:
+    static void SyncRun(const std::vector<std::function<void()>> &tasks)
+    {
+        std::vector<std::future<void>> futs;
+        for (auto &task : tasks) {
+            futs.push_back(std::async(task));
+        }
+        for (auto &fut : futs) {
+            fut.wait();
+        }
+    }
+};
+
+
+class UniqueTest : public testing::Test {
+protected:
+    UniqueTest() {};
+    ~UniqueTest() {};
+    static void SetUpTestCase();
+    static void TearDownTestCase();
+
+
+    void SetUp() {}
+
+    void TearDown() {}
+};
+
+
+#endif // OCK_UNIQUE_TEST_H
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9965bfba..84505d15 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -109,13 +109,8 @@ if (OPENMP_FOUND)
     set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
 endif ()
 add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0)
-set(SPDLOG_BUILD_SHARED ON)
 if(NOT OPENSOURCE_DIR)
-    if(BUILD_CUST STREQUAL "YES")
-        set(OPENSOURCE_DIR ${PROJECT_SOURCE_DIR}/../opensource/)
-    else()
-        set(OPENSOURCE_DIR ${PROJECT_SOURCE_DIR}/../../opensource/opensource/)
-    endif()
+    set(OPENSOURCE_DIR ${PROJECT_SOURCE_DIR}/../../opensource/)
 endif()
 
 if(IS_DIRECTORY ${OPENSOURCE_DIR})
diff --git a/src/build.sh b/src/build.sh
index 8250f350..ed55e213 100644
--- a/src/build.sh
+++ b/src/build.sh
@@ -34,8 +34,8 @@ cmake -DCMAKE_BUILD_TYPE=Release \
       -DPYTHON_PATH="$python_path" \
       -DEASY_PROFILER_PATH=/ \
       -DASCEND_PATH="$ascend_path" \
-      -DABSEIL_PATH="$python_path"/lib/python3.7/site-packages/tensorflow_core/ \
-      -DSECUREC_PATH="$2"/platform/securec \
+      -DABSEIL_PATH="$1" \
+      -DSECUREC_PATH="$2"/../opensource/securec \
       -DCMAKE_INSTALL_PREFIX="$2"/output \
       -DBUILD_CUST="$3" ..
make -j
diff --git a/src/platform/AccCTR b/src/platform/AccCTR
deleted file mode 160000
index 62ab674f..00000000
--- a/src/platform/AccCTR
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 62ab674f0a42d8de8398eafb4799e506fe99549d
diff --git a/src/test_ut.sh b/src/test_ut.sh
index 7dde8ac2..0517f809 100644
--- a/src/test_ut.sh
+++ b/src/test_ut.sh
@@ -16,6 +16,16 @@
 
 set -e
 
+TF_VERSION=$1
+if [ "${TF_VERSION}" == "tf1" ]; then
+    TF_DIR=tensorflow_core
+elif [ "${TF_VERSION}" == "tf2" ]; then
+    TF_DIR=tensorflow
+else
+    echo "TF_VERSION should be tf1 or tf2"
+    exit 1
+fi
+
 # add mpirun env
 export OMPI_ALLOW_RUN_AS_ROOT=1
 export OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
@@ -25,22 +35,78 @@
 source /opt/rh/devtoolset-7/enable
 CUR_DIR=$(dirname "$(readlink -f "$0")")
 ROOT_DIR=$(dirname "${CUR_DIR}")
-acc_ctr_path="${ROOT_DIR}"/src/platform/AccCTR
-cp -rf "${ROOT_DIR}"/platform/securec/* "${acc_ctr_path}"/3rdparty/huawei_secure_c
+opensource_path="${ROOT_DIR}"/../opensource
+acc_ctr_path="${ROOT_DIR}"/src/AccCTR
 
 export LD_LIBRARY_PATH="${acc_ctr_path}"/output/ock_ctr_common/lib:$LD_LIBRARY_PATH
 
-compile_securec()
-{
-    if [[ ! -d "${ROOT_DIR}"/platform/securec ]]; then
+function prepare_googletest(){
+    cd ${opensource_path}
+    if [ ! -d googletest-release-1.8.1 ]; then
+        unzip googletest-release-1.8.1.zip
+    fi
+    cd googletest-release-1.8.1
+    if [ ! -d build ]; then
+        mkdir build
+    fi
+    cd build
+    rm -f CMakeCache.txt
+    cmake -DBUILD_SHARED_LIBS=ON ..
+    make -j8
+    make install
+}
+
+function prepare_emock(){
+    cd ${opensource_path}
+    if [ !
-d emock-0.9.0 ]; then + unzip emock-0.9.0.zip + fi + cd emock-0.9.0 + if [ ! -d build ]; then + mkdir build + fi + cd build + rm -f CMakeCache.txt + cmake .. + make -j8 + make install +} + +function prepare_securec(){ + cd "${opensource_path}" + if [ ! -d securec ]; then + unzip huaweicloud-sdk-c-obs-3.23.9.zip + mv huaweicloud-sdk-c-obs-3.23.9/platform/huaweisecurec securec + rm -rf huaweicloud-sdk-c-obs-3.23.9 + rm -rf securec/lib/* + fi +} + +function compile_securec(){ + cd ${opensource_path} + if [[ ! -d "${opensource_path}/securec" ]]; then echo "securec is not exist" exit 1 fi - if [[ ! -f "${ROOT_DIR}"/platform/securec/lib/libsecurec.so ]]; then - cd "${ROOT_DIR}"/platform/securec/src - make -j + if [[ ! -f "${opensource_path}/securec/lib/libsecurec.so" ]]; then + cd "${opensource_path}/securec/src" + make -j4 fi } + +function prepare_pybind(){ + cd "${opensource_path}" + if [ ! -d pybind11 ]; then + unzip pybind11-2.10.3.zip + mv pybind11-2.10.3 pybind11 + fi +} + +prepare_pybind +echo "opensource path:${opensource_path}" +prepare_googletest +prepare_emock +prepare_securec compile_securec compile_acc_ctr_so_file() @@ -63,14 +128,16 @@ find ./ -name "*.sh" -exec chmod +x {} \; mkdir build cd build +python_path="$(dirname "$(dirname "$(which python3.7)")")" + cmake -DCMAKE_BUILD_TYPE=Debug \ - -DTF_PATH="$(dirname "$(dirname "$(which python3.7)")")"/lib/python3.7/site-packages/tensorflow_core \ + -DTF_PATH="${python_path}"/lib/python3.7/site-packages/"${TF_DIR}" \ -DOMPI_PATH=/usr/local/openmpi/ \ - -DPYTHON_PATH="$(dirname "$(dirname "$(which python3.7)")")" \ + -DPYTHON_PATH="${python_path}" \ -DEASY_PROFILER_PATH=/opt/buildtools/ \ -DASCEND_PATH=/usr/local/Ascend/ascend-toolkit/latest \ - -DABSEIL_PATH="$python_path"/lib/python3.7/site-packages/tensorflow_core/ \ - -DSECUREC_PATH="${ROOT_DIR}"/platform/securec \ + -DABSEIL_PATH="${python_path}"/lib/python3.7/site-packages/"${TF_DIR}" \ + -DSECUREC_PATH="${ROOT_DIR}"/../opensource/securec \ -DBUILD_TESTS=on -DCOVERAGE=on "$(dirname "${PWD}")" make -j diff --git a/tests/run_python_dt.sh b/tests/run_python_dt.sh index a487fb04..e0d92666 100644 --- a/tests/run_python_dt.sh +++ b/tests/run_python_dt.sh @@ -21,7 +21,8 @@ CUR_PATH=$(cd "$(dirname "$0")" || { warn "Failed to check path/to/run_python_dt TOP_PATH="${CUR_PATH}"/../ # build mxRec and get output directory -bash "$TOP_PATH"/build/build_tf1.sh +pip3 install setuptools==65.6.3 +bash "$TOP_PATH"/build/build_tf1_with_opensource.sh # create libasc directory and copy so files into it cd "$TOP_PATH"/mx_rec -- Gitee From f341c555a0314f96502df415eae6adc8a1e62550 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 6 Feb 2024 17:25:55 +0800 Subject: [PATCH 534/551] Match-id-9808c6faa0cc8a1695c707c877865ee519d750da --- docs/build_mxRec_images/centos_build/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/build_mxRec_images/centos_build/Dockerfile b/docs/build_mxRec_images/centos_build/Dockerfile index e218c1c0..e7180e39 100644 --- a/docs/build_mxRec_images/centos_build/Dockerfile +++ b/docs/build_mxRec_images/centos_build/Dockerfile @@ -14,12 +14,13 @@ RUN yum makecache && \ yum -y install devtoolset-7 && \ yum -y install devtoolset-7-gcc-c++ && \ yum -y install epel-release && \ - yum -y install wget zlib-devel bzip2 bzip2-devel openssl-devel ncurses-devel openssh-clients sqlite-devel openmpi-devel \ + yum -y install wget zlib-devel bzip2 bzip2-devel openssl-devel ncurses-devel openssh-clients openssh-server sqlite-devel openmpi-devel \ readline-devel 
tk-devel gdbm-devel db4-devel libpcap-devel xz-devel libffi-devel hdf5-devel patch pciutils lcov vim dos2unix gcc-c++ \
     autoconf automake libtool git && \
     yum clean all && \
     rm -rf /var/cache/yum && \
     echo "source /opt/rh/devtoolset-7/enable" >> /etc/profile
+# Note: openssh-server is required by the two-node training example; it can be removed for single-node training.
 
 # 2. Install gcc-7.3.0
 RUN source /etc/profile && \
-- 
Gitee

From 1895b2fc6ea88fe9b00f61f9d416ec11db69be36 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Wed, 7 Feb 2024 14:56:05 +0800
Subject: [PATCH 535/551] Match-id-61615020b56f82aa08d7b9e16e88ad9fd635d047

---
 mx_rec/constants/__init__.py | 4 +-
 mx_rec/constants/constants.py | 22 +-
 mx_rec/core/asc/build_graph.py | 50 +-
 mx_rec/core/asc/feature_spec.py | 22 +-
 mx_rec/core/asc/helper.py | 63 +-
 mx_rec/core/asc/manager.py | 81 +-
 mx_rec/core/asc/merge_table.py | 38 +-
 mx_rec/core/emb/__init__.py | 0
 mx_rec/core/emb/base_sparse_embedding.py | 565 ++++++++++
 mx_rec/core/emb/dynamic_sparse_embedding.py | 108 ++
 mx_rec/core/emb/emb_factory.py | 54 +
 mx_rec/core/emb/sparse_embedding.py | 205 ++++
 mx_rec/core/embedding.py | 793 +-------------
 mx_rec/core/feature_process.py | 13 +-
 mx_rec/graph/acg_push_ops.py | 40 +-
 mx_rec/graph/graph_typing.py | 35 +
 mx_rec/graph/merge_lookup.py | 30 +-
 mx_rec/graph/modifier.py | 294 +++---
 mx_rec/graph/patch.py | 24 +-
 mx_rec/graph/utils.py | 28 +-
 mx_rec/optimizers/adagrad.py | 11 +-
 mx_rec/optimizers/emb_optimizer.py | 76 ++
 mx_rec/optimizers/ftrl.py | 14 +-
 mx_rec/optimizers/gradient_descent.py | 4 +-
 mx_rec/optimizers/gradient_descent_by_addr.py | 12 +-
 mx_rec/optimizers/lazy_adam.py | 15 +-
 mx_rec/optimizers/lazy_adam_by_addr.py | 10 +-
 mx_rec/saver/patch.py | 13 +-
 mx_rec/saver/saver.py | 96 +-
 mx_rec/saver/sparse.py | 15 +-
 mx_rec/util/__init__.py | 9 +-
 mx_rec/util/communication/__init__.py | 2 +-
 mx_rec/util/communication/hccl_mgmt.py | 37 +-
 mx_rec/util/communication/hccl_ops.py | 84 ++
 mx_rec/util/config_utils/__init__.py | 3 +
 mx_rec/util/config_utils/embedding_utils.py | 76 ++
 mx_rec/util/config_utils/feature_spec_utils.py | 40 +
 mx_rec/util/config_utils/hybrid_mgmt_utils.py | 67 ++
 mx_rec/util/config_utils/optimizer_utils.py | 16 +
 mx_rec/util/config_utils/train_param.py | 101 ++
 mx_rec/util/cpu.py | 98 ++
 mx_rec/util/framework_npu_env/__init__.py | 7 +
 mx_rec/util/framework_npu_env/tfa_env.py | 28 +
 mx_rec/util/global_env_conf.py | 14 +-
 mx_rec/util/initialize.py | 976 ++----------------
 mx_rec/util/perf_factory/__init__.py | 5 +
 mx_rec/util/perf_factory/bind_cpu.py | 116 +++
 mx_rec/util/tf_version_adapter.py | 5 +
 mx_rec/util/variable.py | 10 +-
 mx_rec/validator/emb_validator.py | 103 ++
 .../host_emb_ckpt/host_emb_ckpt.cpp | 2 +-
 src/core/emb_hashmap/emb_hashmap.cpp | 47 +-
 src/core/emb_hashmap/emb_hashmap.h | 10 -
 src/core/emb_table/embedding_ddr.cpp | 567 ++++++++++
 src/core/emb_table/embedding_ddr.h | 95 ++
 src/core/emb_table/embedding_dynamic.cpp | 119 +++
 src/core/emb_table/embedding_dynamic.h | 46 +
 src/core/emb_table/embedding_mgmt.cpp | 152 +++
 src/core/emb_table/embedding_mgmt.h | 104 ++
 src/core/emb_table/embedding_static.cpp | 63 ++
 src/core/emb_table/embedding_static.h | 33 +
 src/core/emb_table/embedding_table.cpp | 171 +++
 src/core/emb_table/embedding_table.h | 116 +++
 src/core/file_system/file_system.h | 9 +-
 .../local_file_system/local_file_system.cpp | 44 +-
 .../local_file_system/local_file_system.h | 5 +-
 src/core/host_emb/host_emb.cpp | 38 +-
 src/core/host_emb/host_emb.h | 7 +-
 src/core/hybrid_mgmt/hybrid_mgmt.cpp | 124 +--
 src/core/hybrid_mgmt/hybrid_mgmt_block.cpp | 4 +-
src/core/key_process/key_process.cpp | 74 +- src/core/utils/common.h | 5 +- src/tests/checkpoint/checkpoint_test.cpp | 180 ---- src/tests/emb_hashmap/emb_hashmap_test.cpp | 2 + src/tests/emb_table/emb_table_test.cpp | 2 +- src/tests/emb_table/embedding_ddr_test.cpp | 200 ++++ src/tests/emb_table/embedding_mgmt_test.cpp | 112 ++ src/tests/emb_table/embedding_static_test.cpp | 149 +++ src/tests/host_emb/host_emb_test.cpp | 77 -- .../hybrid_mgmt/hybrid_mgmt_block_test.cpp | 25 - src/tests/key_process/key_process_test.cpp | 273 +---- tests/mx_rec/core/mock_class.py | 140 ++- tests/mx_rec/core/test_build_graph.py | 181 +--- tests/mx_rec/core/test_cpu.py | 26 + tests/mx_rec/core/test_embedding.py | 877 +++------------- tests/mx_rec/core/test_feature_process.py | 50 +- tests/mx_rec/core/test_feature_spec.py | 76 +- tests/mx_rec/core/test_helper.py | 77 +- tests/mx_rec/core/test_manager.py | 138 ++- tests/mx_rec/core/test_merge_table.py | 64 +- tests/mx_rec/data/test_dataset.py | 6 + tests/mx_rec/graph/test_acg_push_ops.py | 12 +- tests/mx_rec/graph/test_merge_lookup.py | 47 +- tests/mx_rec/graph/test_modifier.py | 171 ++- tests/mx_rec/graph/test_utils.py | 4 +- tests/mx_rec/saver/sparse_embedding_mock.py | 4 +- tests/mx_rec/saver/test_saver.py | 30 +- tests/mx_rec/saver/test_sparse.py | 24 +- .../util/communication/test_hccl_mgmt.py | 33 - tests/mx_rec/util/test_variable.py | 30 +- 100 files changed, 5245 insertions(+), 4132 deletions(-) create mode 100644 mx_rec/core/emb/__init__.py create mode 100644 mx_rec/core/emb/base_sparse_embedding.py create mode 100644 mx_rec/core/emb/dynamic_sparse_embedding.py create mode 100644 mx_rec/core/emb/emb_factory.py create mode 100644 mx_rec/core/emb/sparse_embedding.py create mode 100644 mx_rec/graph/graph_typing.py create mode 100644 mx_rec/optimizers/emb_optimizer.py create mode 100644 mx_rec/util/communication/hccl_ops.py create mode 100644 mx_rec/util/config_utils/__init__.py create mode 100644 mx_rec/util/config_utils/embedding_utils.py create mode 100644 mx_rec/util/config_utils/feature_spec_utils.py create mode 100644 mx_rec/util/config_utils/hybrid_mgmt_utils.py create mode 100644 mx_rec/util/config_utils/optimizer_utils.py create mode 100644 mx_rec/util/config_utils/train_param.py create mode 100644 mx_rec/util/cpu.py create mode 100644 mx_rec/util/framework_npu_env/__init__.py create mode 100644 mx_rec/util/framework_npu_env/tfa_env.py create mode 100644 mx_rec/util/perf_factory/__init__.py create mode 100644 mx_rec/util/perf_factory/bind_cpu.py create mode 100644 mx_rec/validator/emb_validator.py create mode 100644 src/core/emb_table/embedding_ddr.cpp create mode 100644 src/core/emb_table/embedding_ddr.h create mode 100644 src/core/emb_table/embedding_dynamic.cpp create mode 100644 src/core/emb_table/embedding_dynamic.h create mode 100644 src/core/emb_table/embedding_mgmt.cpp create mode 100644 src/core/emb_table/embedding_mgmt.h create mode 100644 src/core/emb_table/embedding_static.cpp create mode 100644 src/core/emb_table/embedding_static.h create mode 100644 src/core/emb_table/embedding_table.cpp create mode 100644 src/core/emb_table/embedding_table.h create mode 100644 src/tests/emb_table/embedding_ddr_test.cpp create mode 100644 src/tests/emb_table/embedding_mgmt_test.cpp create mode 100644 src/tests/emb_table/embedding_static_test.cpp create mode 100644 tests/mx_rec/core/test_cpu.py diff --git a/mx_rec/constants/__init__.py b/mx_rec/constants/__init__.py index e32874b8..59e77d8b 100644 --- a/mx_rec/constants/__init__.py +++ 
b/mx_rec/constants/__init__.py
@@ -15,6 +15,6 @@
 # limitations under the License.
 # ==============================================================================
 
-__all__ = ["ASCEND_TIMESTAMP", "ApplyGradientsStrategy"]
+__all__ = ["ASCEND_TIMESTAMP"]
 
-from mx_rec.constants.constants import ASCEND_TIMESTAMP, ApplyGradientsStrategy
+from mx_rec.constants.constants import ASCEND_TIMESTAMP
diff --git a/mx_rec/constants/constants.py b/mx_rec/constants/constants.py
index 891d65e4..8e23438e 100644
--- a/mx_rec/constants/constants.py
+++ b/mx_rec/constants/constants.py
@@ -115,7 +115,6 @@ class EnvOption(Enum):
     CM_CHIEF_DEVICE = "CM_CHIEF_DEVICE"
     CM_WORKER_SIZE = "CM_WORKER_SIZE"
     TF_DEVICE = "TF_DEVICE"
-    APPLY_GRADIENTS_STRATEGY = "APPLY_GRADIENTS_STRATEGY"
     ACL_TIMEOUT = "AclTimeout"
     HD_CHANNEL_SIZE = "HD_CHANNEL_SIZE"
     KEY_PROCESS_THREAD_NUM = "KEY_PROCESS_THREAD_NUM"
@@ -128,6 +127,11 @@ class EnvOption(Enum):
     STAT_ON = "STAT_ON"
     RECORD_KEY_COUNT = "RECORD_KEY_COUNT"
 
+    # MPI env
+    OMPI_COMM_WORLD_SIZE = "OMPI_COMM_WORLD_SIZE"
+    OMPI_COMM_WORLD_LOCAL_SIZE = "OMPI_COMM_WORLD_LOCAL_SIZE"
+    OMPI_COMM_WORLD_RANK = "OMPI_COMM_WORLD_RANK"
+
 
 class DataName(Enum):
     KEY = "key"
@@ -180,11 +184,6 @@ class All2allGradientsOp(BaseEnum):
     SUM_GRADIENTS_AND_DIV_BY_RANKSIZE = "sum_gradients_and_div_by_ranksize"
 
 
-class ApplyGradientsStrategy(BaseEnum):
-    DIRECT_APPLY = "direct_apply"
-    SUM_SAME_ID_GRADIENTS_AND_APPLY = "sum_same_id_gradients_and_apply"
-
-
 class RecPyLogLevel(Enum):
     DEBUG = "DEBUG"
     INFO = "INFO"
@@ -211,3 +210,14 @@ class Flag(Enum):
     FALSE = "0"
 
 
+class AnchorDatasetOp(Enum):
+    MODEL_DATASET = "ModelDataset"
+    OPTIMIZE_DATASET = "OptimizeDataset"
+    PREFETCH_DATASET = "PrefetchDataset"
+
+
+class AnchorIteratorOp(Enum):
+    ITERATOR_GET_NEXT = "IteratorGetNext"
+    MAKE_ITERATOR = "MakeIterator"
+    ONE_SHOT_ITERATOR = "OneShotIterator"
+
diff --git a/mx_rec/core/asc/build_graph.py b/mx_rec/core/asc/build_graph.py
index a4858d84..13ddad4a 100644
--- a/mx_rec/core/asc/build_graph.py
+++ b/mx_rec/core/asc/build_graph.py
@@ -20,50 +20,39 @@ from typing import Optional
 
 import tensorflow as tf
 
 import mxrec_pybind
-from mx_rec.util.initialize import get_use_static
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.tf_version_adapter import npu_ops
-from mx_rec.constants.constants import ApplyGradientsStrategy, TRAIN_CHANNEL_ID
-from mx_rec.util.global_env_conf import global_env
+from mx_rec.constants.constants import TRAIN_CHANNEL_ID
 from mx_rec.util.log import logger
 
 
 def get_restore_vector(config):
     logger.debug('Channel %s_restore_%s was built for getnext', config.get("table_name"), config.get("channel_id"))
-    if config.get("skip_emb_transfer"):
+    if config.get("is_hbm"):
         if not isinstance(config.get("emb_size"), int):
-            raise TypeError("emb_size must be a int")
+            raise TypeError("emb_size must be an int")
         if config.get("emb_size") < 1:
             raise ValueError("emb_size is less than 1")
         emb_size = config.get("emb_size")
     else:
         if not isinstance(config.get("ext_emb_size"), int):
-            raise TypeError("ext_emb_size must be a int")
+            raise TypeError("ext_emb_size must be an int")
         if config.get("ext_emb_size") < 1:
             raise ValueError("ext_emb_size is less than 1")
         emb_size = config.get("ext_emb_size")
 
-    use_hot = config.get("use_hot")
-    hot_pos = None
-
-    if get_use_static():
+    if ConfigInitializer.get_instance().use_static:
        restore_size = 
config.get("batch_size") * config.get("feat_cnt") else: restore_size = None with tf.compat.v1.variable_scope(config.get("table_name"), reuse=tf.compat.v1.AUTO_REUSE): - if use_hot and emb_size: - device_id = int(config.get("device_id")) - hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) - restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32, tf.int32], - output_shapes=[restore_size, [hot_size]], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}' - ) - else: - restore_vector = npu_ops.gen_npu_ops.get_next( - output_types=[tf.int32], - output_shapes=[restore_size], - channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}')[0] + device_id = int(config.get("device_id")) + hot_size = int(mxrec_pybind.get_ub_hot_size(device_id) / emb_size) + restore_vector, hot_pos = npu_ops.gen_npu_ops.get_next( + output_types=[tf.int32, tf.int32], + output_shapes=[restore_size, [hot_size]], + channel_name=f'{config.get("table_name")}_restore_{config.get("channel_id")}') return restore_vector, hot_pos @@ -83,7 +72,7 @@ def get_id_offsets(max_lookup_vec_size, config): output_types=[tf.int32], output_shapes=[[max_lookup_vec_size]], channel_name=f'{config.get("table_name")}_lookup_{config.get("channel_id")}') - if config.get("skip_emb_transfer"): + if config.get("is_hbm"): return id_offsets, [], 0 swap_pos, swap_len = npu_ops.gen_npu_ops.get_next( output_types=[tf.int32, tf.int32], @@ -164,12 +153,12 @@ def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variabl :param table: the instance to do swap :return: swap info """ - use_static = get_use_static() + use_static = ConfigInitializer.get_instance().use_static max_lookup_vec_size = None if use_static: max_lookup_vec_size = config.get("send_count") * config.get("rank_size") - if config.get("skip_emb_transfer"): + if config.get("is_hbm"): swap_in = [tf.no_op()] else: with tf.compat.v1.variable_scope("h2d_emb"): @@ -199,7 +188,7 @@ def get_swap_info(config: dict, swap_len: int, swap_pos: list, table: tf.Variabl def get_preprocessed_tensor_for_asc(table, config): - use_static = get_use_static() + use_static = ConfigInitializer.get_instance().use_static max_lookup_vec_size = None if use_static: max_lookup_vec_size = config.get("send_count") * config.get("rank_size") @@ -222,9 +211,6 @@ def get_preprocessed_tensor_for_asc(table, config): 'all2all_args': all2all_args, } - if global_env.apply_gradients_strategy != ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value: - return result - if config.get("channel_id") != TRAIN_CHANNEL_ID: return result diff --git a/mx_rec/core/asc/feature_spec.py b/mx_rec/core/asc/feature_spec.py index 71454f1b..ca8bab10 100644 --- a/mx_rec/core/asc/feature_spec.py +++ b/mx_rec/core/asc/feature_spec.py @@ -21,7 +21,7 @@ from functools import reduce import tensorflow as tf from mx_rec.util.atomic import AtomicInteger -from mx_rec.util.initialize import insert_feature_spec, insert_training_mode_channel_id, get_use_static +from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.normalization import fix_invalid_table_name from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import ClassValidator, StringValidator, para_checker_decorator, \ @@ -76,7 +76,7 @@ class FeatureSpec: self.feat_pos_train = None self.feat_pos_eval = None self.dims = None - self.rank = None + self.tensor_rank = None self.batch_size = batch_size self.split = None # usually split == batch_size * feature_count 
self.initialized = False @@ -150,7 +150,7 @@ class FeatureSpec: logger.info("FeatureSpec%s. Is training mode [%s] has been set.", self.name, mode) return - insert_training_mode_channel_id(is_training=mode) + ConfigInitializer.get_instance().train_params_config.insert_training_mode_channel_id(is_training=mode) self._pipeline_mode.add(mode) @@ -160,13 +160,13 @@ class FeatureSpec: if not self.initialized: self.initialized = True - if get_use_static(): + if ConfigInitializer.get_instance().use_static: self.dims = tensor.shape.as_list() - self.rank = tensor.shape.rank - if self.rank < 1: - raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {self.rank} now.") + self.tensor_rank = tensor.shape.rank + if self.tensor_rank < 1: + raise ValueError(f"Given tensor rank cannot be smaller than 1, which is {self.tensor_rank} now.") - inferred_feat_cnt = 1 if self.rank == 1 else reduce(lambda x, y: x * y, self.dims[1:]) + inferred_feat_cnt = 1 if self.tensor_rank == 1 else reduce(lambda x, y: x * y, self.dims[1:]) logger.debug("update feature_spec[%s] feature_count to %s via %s", self.name, inferred_feat_cnt, self.dims) self.batch_size = self.dims[0] @@ -175,14 +175,14 @@ class FeatureSpec: else: tensor = tf.reshape(tensor, [-1]) self.dims = tf.shape(tensor) - self.rank = 1 + self.tensor_rank = 1 self.split = tf.math.reduce_prod(tf.shape(tensor)) self.batch_size = self.split self._feat_cnt = 1 else: logger.debug("The initialized Feature Spec was set once again.") - if get_use_static(): + if ConfigInitializer.get_instance().use_static: if self.dims != tensor.shape.as_list(): raise ValueError(f"Given static Tensor shape mismatches with the last one, whose is_training mode " f"is not {is_training}. ") @@ -191,7 +191,7 @@ class FeatureSpec: raise ValueError(f"Given dynamic Tensor shape mismatches with the last one, whose is_training mode " f"is not {is_training}. 
") - insert_feature_spec(self, is_training) + ConfigInitializer.get_instance().feature_spec_config.insert_feature_spec(self, is_training) result = { 'tensor': tensor, 'table_name': self.table_name, diff --git a/mx_rec/core/asc/helper.py b/mx_rec/core/asc/helper.py index 6cab56c1..771f359f 100644 --- a/mx_rec/core/asc/helper.py +++ b/mx_rec/core/asc/helper.py @@ -19,12 +19,13 @@ from functools import reduce import tensorflow as tf -from mx_rec.util.initialize import get_host_pipeline_ops, get_training_mode_channel_id, get_use_static, get_modify_graph from mx_rec.core.asc.feature_spec import FeatureSpec from mx_rec.core.asc.merge_table import find_dangling_table, should_skip -from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator +from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger from mx_rec.util.normalization import fix_invalid_table_name +from mx_rec.util.ops import import_host_pipeline_ops +from mx_rec.validator.validator import para_checker_decorator, ValueCompareValidator, ClassValidator @para_checker_decorator(check_option_list=[ @@ -89,7 +90,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_na if not isinstance(tgt_key_specs, (list, tuple)): tgt_key_specs = [tgt_key_specs] - def insert_fn_for_feature_specs(*args): # pragma: no cover + def insert_fn_for_feature_specs(*args): data_src = args if len(args) == 1: data_src = args[0] @@ -119,7 +120,7 @@ def get_asc_insert_func_inner(tgt_key_specs=None, args_index_list=None, table_na logger.info("In insert found dangling table(s): %s which does not need to be provided to the EmbInfo.", dangling_tables) - def insert_fn_for_arg_indexes(*args): # pragma: no cover + def insert_fn_for_arg_indexes(*args): insert_tensors = get_target_tensors_with_args_indexes(args_index_list) logger.debug("do_insert without spec for %s", table_names) @@ -170,7 +171,7 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list): f"len(split_list): {len(split_list)}" f"len(table_name_list): {len(table_name_list)}") feature_id_requests = zip(feature_id_list, split_list, table_name_list) - if get_modify_graph(): + if ConfigInitializer.get_instance().modify_graph: feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2])) else: feature_id_requests = sorted(feature_id_requests, key=lambda x: (x[2], x[0].name)) @@ -203,34 +204,35 @@ def merge_feature_id_request(feature_id_list, split_list, table_name_list): logger.debug("merge request from %s %s to %s %s", table_name_list, split_list, output_table_name_list, output_split_list) - output_dict = { + list_set = { 'output_feature_id_list': output_feature_id_list, 'output_split_list': output_split_list, 'output_table_name_list': output_table_name_list, 'output_tensorshape_split_list': output_tensorshape_split_list, } - return output_dict + return list_set def send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict): - is_training = input_dict.get("is_training") - timestamp = input_dict.get("timestamp") - host_pipeline_ops = get_host_pipeline_ops() - use_static = get_use_static() + is_training = input_dict["is_training"] + timestamp = input_dict["timestamp"] + host_pipeline_ops = import_host_pipeline_ops() + use_static = ConfigInitializer.get_instance().use_static timestamp_feature_id = [] if timestamp: timestamp_feature_id = feature_id_list[:1] feature_id_list = feature_id_list[1:] - merged_dict = merge_feature_id_request(feature_id_list, split_list, 
table_name_list)
+    feature_id_list = list_set.get("output_feature_id_list")
+    split_list = list_set.get("output_split_list")
+    table_name_list = list_set.get("output_table_name_list")
+    tensorshape_split_list = list_set.get("output_tensorshape_split_list")
 
     # check training mode order and ensure channel id
-    channel_id = get_training_mode_channel_id(is_training=is_training)
+    channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(
+        is_training)
 
     if timestamp:
         feature_id_list = timestamp_feature_id + feature_id_list
@@ -252,11 +254,11 @@
 
 def do_insert(args, insert_tensors, splits, table_names, input_dict):
-    is_training = input_dict.get("is_training")
-    dump_graph = input_dict.get("dump_graph")
-    timestamp = input_dict.get("timestamp")
-    feature_spec_names = input_dict.get("feature_spec_names")
-    auto_change_graph = input_dict.get("auto_change_graph")
+    is_training = input_dict["is_training"]
+    dump_graph = input_dict["dump_graph"]
+    timestamp = input_dict["timestamp"]
+    feature_spec_names = input_dict["feature_spec_names"]
+    auto_change_graph = input_dict["auto_change_graph"]
 
     pipeline_op = \
         send_feature_id_request_async(feature_id_list=insert_tensors,
@@ -282,10 +284,7 @@ def export_read_emb_key_v2_op(args, pipeline_op):
         raise ValueError("The length of args is less than 1.")
     if isinstance(origin_batch[0], dict):
         output_batch = origin_batch[0]
-        # find the lexicographically largest key in output_batch
-        sorted_keys = sorted(output_batch)
-        valid_key = f"{sorted_keys[-1]}_read_emb_key"
-        # insert the output of the readEmbKey op into the batch, so that every dataset get_next runs readEmbKey to fetch its output
+        valid_key = get_valid_op_key(output_batch)
 
         output_batch[valid_key] = pipeline_op
     elif len(origin_batch) == 1 and isinstance(origin_batch[0], tf.Tensor):
@@ -314,7 +313,17 @@
     return output_batch
 
 
-def get_target_tensors_with_args_indexes(args_index_list):  # pragma: no cover
+def get_valid_op_key(batch_dict: dict) -> str:
+    if not isinstance(batch_dict, dict):
+        raise TypeError("batch_dict must be a dict")
+
+    sorted_keys = sorted(batch_dict)
+    valid_key = f"{sorted_keys[-1]}_read_emb_key"
+
+    return valid_key
+
+
+def get_target_tensors_with_args_indexes(args_index_list):
     insert_tensors = []
     graph = tf.compat.v1.get_default_graph()
     for index in args_index_list:
diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 1aeeb573..2f555a56 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -20,45 +20,31 @@
 import tensorflow as tf
 
 from mxrec_pybind import InitializeInfo, ConstantInitializerInfo, NormalInitializerInfo, EmbInfo, EmbInfoParams, \
     ThresholdValue, HybridMgmt, RankInfo, USE_STATIC, USE_HOT, USE_DYNAMIC_EXPANSION
-from mx_rec.util.initialize import get_rank_id, get_device_id, get_rank_size, set_asc_manager, \
-    is_asc_manager_initialized, get_train_steps, get_eval_steps, get_save_steps, \
-    export_table_instances, export_feature_spec, get_if_load, get_use_static, \
-    get_use_hot, get_stat_on, get_use_dynamic_expansion, export_optimizer, export_dangling_table, export_table_num
-from mx_rec.core.asc.merge_table import 
find_dangling_table, should_skip +from mx_rec.util.communication.hccl_ops import get_rank_id, get_device_id, get_rank_size +from mx_rec.util.initialize import ConfigInitializer +from mx_rec.core.asc.merge_table import should_skip, check_dangling_table from mx_rec.util.log import logger -def check_dangling_table(): - """ - If the dangling_table list is empty(maybe feature_spec mode), try to find again - :return: list of dangling_table - """ - dangling_table = export_dangling_table() - if not dangling_table: - dangling_table = find_dangling_table([table_instance.table_name - for _, table_instance in export_table_instances().items()]) - return dangling_table - - def generate_table_info_list(): # table_name is corresponding to channel_name which is in used in operator gen_npu_ops.get_next table_info_list = [] # check whether DDR is enabled or disabled for all tables. - host_voc_sizes = [table_instance.host_vocabulary_size for table_instance in export_table_instances().values()] - total_host_voc_size = sum(host_voc_sizes) - if total_host_voc_size != 0 and 0 in host_voc_sizes: - raise ValueError(f"The host-side DDR function of all tables must be used or not used at the same time. " - f"However, host voc size of each table is {host_voc_sizes}.") + table_instance_dict = ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict + is_hbm_list = [table_instance.is_hbm for table_instance in table_instance_dict.values()] + if len(set(is_hbm_list)) != 1: + raise ValueError(f"The DDR mode of all tables must be used or not used at the same time. However, is_hbm " + f"of each table `{table_instance_dict.keys()}` is `{is_hbm_list}`.") - optimizer = export_optimizer() + optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance # generate table info dangling_table = check_dangling_table() - for _, table_instance in export_table_instances().items(): + for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items(): # When dynamic expansion mode, ext_emb_size is set by optimizer if optimizer is not None: - table_instance.ext_emb_size = table_instance.scalar_emb_size * (1 + optimizer.slot_num) + table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num) logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size) skip = should_skip(table_instance.table_name) if table_instance.table_name in dangling_table or skip: @@ -66,8 +52,8 @@ def generate_table_info_list(): skip, table_instance.table_name) continue - static_shape_rec_flag = get_use_static() and table_instance.send_count > 0 - dynamic_shape_rec_flag = not get_use_static() + static_shape_rec_flag = ConfigInitializer.get_instance().use_static and table_instance.send_count > 0 + dynamic_shape_rec_flag = not ConfigInitializer.get_instance().use_static if static_shape_rec_flag or dynamic_shape_rec_flag: logger.debug("table_instance.slice_device_vocabulary_size: %s", table_instance.slice_device_vocabulary_size) @@ -75,7 +61,7 @@ def generate_table_info_list(): logger.debug("table_instance.slice_ssd_vocabulary_size: %s", table_instance.slice_ssd_vocabulary_size) logger.debug("EmbInfoParams: The table name is %s, and the value of `is_grad` in this table is %s.", table_instance.table_name, table_instance.is_grad) - params = EmbInfoParams(table_instance.table_name, table_instance.send_count, table_instance.scalar_emb_size, + params = EmbInfoParams(table_instance.table_name, table_instance.send_count, table_instance.emb_size, 
table_instance.ext_emb_size, table_instance.is_save, table_instance.is_grad) table_info = EmbInfo(params, [table_instance.slice_device_vocabulary_size, @@ -90,7 +76,7 @@ def generate_table_info_list(): def matched_constant_initializer(tabel_info): init_param = tabel_info.init_param logger.debug("constant_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) - return InitializeInfo(name="constant_initializer", start=0, len=tabel_info.scalar_emb_size, + return InitializeInfo(name="constant_initializer", start=0, len=tabel_info.emb_size, constant_initializer_info=ConstantInitializerInfo( constant_val=tabel_info.emb_initializer.value, initK=init_param)) @@ -99,7 +85,7 @@ def matched_random_normal_initializer(tabel_info): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed init_param = tabel_info.init_param logger.debug("random_normal_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) - return InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + return InitializeInfo(name="random_normal_initializer", start=0, len=tabel_info.emb_size, normal_initializer_info=NormalInitializerInfo( mean=tabel_info.emb_initializer.mean, stddev=tabel_info.emb_initializer.stddev, @@ -112,7 +98,7 @@ def matched_truncated_normal_initializer(tabel_info): random_seed = 0 if tabel_info.emb_initializer.seed is None else tabel_info.emb_initializer.seed init_param = tabel_info.init_param logger.debug("truncated_normal_initializer, tabel: %s, initK is %s.", tabel_info.table_name, init_param) - return InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + return InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.emb_size, normal_initializer_info=NormalInitializerInfo( mean=tabel_info.emb_initializer.mean, stddev=tabel_info.emb_initializer.stddev, @@ -145,7 +131,7 @@ def matched_emb_initializer(tabel_info): initializer_case_map.get("tf2_truncated_normal_initializer"): initializer = matched_truncated_normal_initializer(tabel_info) else: - initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.scalar_emb_size, + initializer = InitializeInfo(name="truncated_normal_initializer", start=0, len=tabel_info.emb_size, normal_initializer_info=NormalInitializerInfo( mean=0.0, stddev=1.0, @@ -155,7 +141,7 @@ def matched_emb_initializer(tabel_info): def matched_opt_slot_initializers(table_instance): - start_index = table_instance.scalar_emb_size + start_index = table_instance.emb_size slot_initializers = [] logger.debug("matched_opt_slot_initializers, scalar emb size:%s, optimizer_instance_list size:%s", table_instance.ext_emb_size, len(table_instance.optimizer_instance_list)) @@ -163,12 +149,12 @@ def matched_opt_slot_initializers(table_instance): for slot_init_value in optimizer.get_slot_init_values(): slot_initializer = InitializeInfo(name="constant_initializer", start=start_index, - len=table_instance.scalar_emb_size, + len=table_instance.emb_size, constant_initializer_info=ConstantInitializerInfo( constant_val=slot_init_value )) slot_initializers.append(slot_initializer) - start_index += table_instance.scalar_emb_size + start_index += table_instance.emb_size return slot_initializers @@ -176,7 +162,7 @@ def matched_opt_slot_initializers(table_instance): def generate_threshold_list(): threshold_list = [] - for _, feature_spec in export_feature_spec().items(): + for _, feature_spec in 
ConfigInitializer.get_instance().feature_spec_config.feature_spec_dict.items():
         coef = 1 if feature_spec.faae_coefficient is None else feature_spec.faae_coefficient
         if feature_spec.eviction_threshold:
             threshold = ThresholdValue(feature_spec.table_name,
@@ -201,17 +187,17 @@
     rank_id = get_rank_id()
     device_id = get_device_id()
     rank_size = get_rank_size()
-    train_steps = get_train_steps()
-    eval_steps = get_eval_steps()
-    save_steps = get_save_steps()
+    train_steps = ConfigInitializer.get_instance().train_steps
+    eval_steps = ConfigInitializer.get_instance().eval_steps
+    save_steps = ConfigInitializer.get_instance().save_steps
 
-    if_load = get_if_load()
+    if_load = ConfigInitializer.get_instance().if_load
     option = 0
-    if get_use_static():
+    if ConfigInitializer.get_instance().use_static:
         option = option | USE_STATIC
-    if get_use_hot():
-        option = option | USE_HOT
-    if get_use_dynamic_expansion():
+    # use hot is always True
+    option = option | USE_HOT
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
         option = option | USE_DYNAMIC_EXPANSION
 
     # [train_steps, eval_steps, save_steps] pass step information to HybridMgmt for data process loop
@@ -226,7 +212,7 @@
         logger.error("Failed to init emb_cache!")
         raise RuntimeError("emb_cache has not been initialized successfully.")
 
-    set_asc_manager(emb_cache)
+    ConfigInitializer.get_instance().hybrid_manager_config.set_asc_manager(emb_cache)
     logger.info("Preprocessing has been sunk into the host pipeline.")
     logger.debug("Flag if load is %s.", if_load)
     logger.debug("train_steps is %s.", train_steps)
@@ -241,7 +227,6 @@
     if not table_info_list:
         logger.error("table_info_list is empty!")
         raise RuntimeError("table_info_list is empty!")
-    if get_stat_on():
-        logger.info("[StatInfo] current_table_num %s", export_table_num())
-    if not is_asc_manager_initialized() and table_info_list:
+
+    if not ConfigInitializer.get_instance().hybrid_manager_config.asc_manager and table_info_list:
         initialize_emb_cache(table_info_list, threshold_list)
diff --git a/mx_rec/core/asc/merge_table.py b/mx_rec/core/asc/merge_table.py
index d3b5dec8..776a72c4 100644
--- a/mx_rec/core/asc/merge_table.py
+++ b/mx_rec/core/asc/merge_table.py
@@ -20,9 +20,9 @@ from typing import Dict, List
 
 import tensorflow as tf
 from tensorflow import Operation, Tensor
 
-from mx_rec.constants.constants import MAX_WHILE_SIZE, ASCEND_TABLE_NAME_MUST_CONTAIN
-from mx_rec.util.initialize import get_enable_table_merge, export_table_instances, insert_dangling_table, \
-    get_bool_gauge_set
+from mx_rec.constants.constants import MAX_WHILE_SIZE, TFDevice, ASCEND_TABLE_NAME_MUST_CONTAIN
+from mx_rec.util.global_env_conf import global_env
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.log import logger
 
 
@@ -52,7 +52,7 @@
 
 def is_train_task():
-    bool_gauge_set = get_bool_gauge_set()
+    bool_gauge_set = ConfigInitializer.get_instance().train_params_config.bool_gauge_set
     if not bool_gauge_set:
         op_list = tf.compat.v1.get_default_graph().get_operations()
         for t_op in op_list:
@@ -76,7 +76,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]:
 
     def find_table_op(table_name: str, the_op: Operation, table_lookup_op: Dict[str, List[Operation]],
-                      table_reachable_tensor: Dict[str, List[Tensor]]) -> None:  # pragma: no cover
+                      table_reachable_tensor: Dict[str, List[Tensor]]) -> None:
        """find all 
the table lookup op. :param table_name: tables' names :param the_op: the op to be @@ -97,7 +97,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]: def extend(op_list: List[Operation], tensor: Tensor, - spread_tensors: List[Tensor]) -> None: # pragma: no cover + spread_tensors: List[Tensor]) -> None: """extend the tensors which table lookup op can reach :param op_list: all op in the graph @@ -109,7 +109,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]: if tensor in the_op.inputs: spread_tensors.extend(the_op.outputs) - def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool): # pragma: no cover + def bfs_lookup(next_to_visit: List[Tensor]) -> (set, bool): """find all the tensors which table lookup op can reach :param next_to_visit: the tensor list to be visited by bfs @@ -137,7 +137,9 @@ def find_dangling_table(table_names: List[str]) -> List[str]: if not is_train_task(): logger.info("!!merge table only available in train task.") return [] - if not get_enable_table_merge(): + + enable_table_merge = True if global_env.tf_device == TFDevice.NPU.value else False + if not enable_table_merge: return [] op_list = tf.compat.v1.get_default_graph().get_operations() @@ -145,7 +147,7 @@ def find_dangling_table(table_names: List[str]) -> List[str]: table_lookup_op = {} table_reachable_tensor = {} - for _, table_instance in export_table_instances().items(): + for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items(): if table_instance.table_name not in table_names: table_names.append(table_instance.table_name) @@ -160,13 +162,13 @@ def find_dangling_table(table_names: List[str]) -> List[str]: if table_name not in table_lookup_op: logger.debug("*********** created table %s but never look up***********", table_name) dangling_table.append(table_name) - insert_dangling_table(table_name) + ConfigInitializer.get_instance().sparse_embed_config.insert_dangling_table(table_name) for table_name, table_op in table_reachable_tensor.items(): reach_op, found = bfs_lookup(table_op) if not found and affirm(reach_op): dangling_table.append(table_name) - insert_dangling_table(table_name) + ConfigInitializer.get_instance().sparse_embed_config.insert_dangling_table(table_name) return dangling_table @@ -184,3 +186,17 @@ def should_skip(table_name) -> bool: break return skip return False + + +def check_dangling_table(): + """ + If the dangling_table list is empty(maybe feature_spec mode), try to find again + :return: list of dangling_table + """ + config_instance = ConfigInitializer.get_instance() + dangling_table = config_instance.sparse_embed_config.dangling_table + if not dangling_table: + dangling_table = find_dangling_table([table_instance.table_name + for _, table_instance in + config_instance.sparse_embed_config.table_instance_dict.items()]) + return dangling_table diff --git a/mx_rec/core/emb/__init__.py b/mx_rec/core/emb/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mx_rec/core/emb/base_sparse_embedding.py b/mx_rec/core/emb/base_sparse_embedding.py new file mode 100644 index 00000000..b2cecf2a --- /dev/null +++ b/mx_rec/core/emb/base_sparse_embedding.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+
+import abc
+from collections import defaultdict
+from typing import Optional, Union, Callable
+
+import tensorflow as tf
+from tensorflow.python.ops import array_ops
+
+from mx_rec.constants.constants import All2allGradientsOp, ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr
+from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute, get_feature_spec, FeatureSpec
+from mx_rec.util.communication.hccl_ops import get_rank_size, get_rank_id, get_device_id
+from mx_rec.util.tf_version_adapter import hccl_ops
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.log import logger
+from mx_rec.validator.emb_validator import check_emb_init_params, check_emb_lookup_params
+
+
+class BaseSparseEmbedding(metaclass=abc.ABCMeta):
+    """
+    Base class for sparse embedding tables.
+    """
+    # global dict used by automatic graph modification, keyed by the lookup ids
+    # tensor and holding the attributes to be preserved for each anchor
+    anchor_tensor_specs = defaultdict(dict)
+
+    def __init__(self, config: dict):
+        self._embedding_size = config.get("embedding_size")
+        if isinstance(self._embedding_size, int):
+            self._embedding_size = tf.TensorShape([self._embedding_size])
+        self._table_name = config.get("table_name")
+        self._key_dtype = config.get("key_dtype")
+        self._emb_initializer = config.get("emb_initializer")
+        self._is_save = config.get("is_save")
+        self._init_param = config.get("init_param")
+        self._is_hbm = config.get("host_vocabulary_size") <= 0
+        self._ssd_data_path = list(config.get("ssd_data_path"))
+        self._send_count = 0
+        self._slice_device_vocabulary_size = 0
+        self._slice_host_vocabulary_size = 0
+        self._slice_ssd_vocabulary_size = 0
+        self._emb_size = self._embedding_size.as_list()[0]
+        self._is_grad = False
+        self._ext_emb_size = None
+        self._variable = None
+        self._multi_lookup_times = {True: 0, False: 0}
+
+        self._all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op"))
+        self._device_vocabulary_size = config.get("device_vocabulary_size")
+        self._host_vocabulary_size = config.get("host_vocabulary_size")
+        self._ssd_vocabulary_size = config.get("ssd_vocabulary_size")
+        self._ext_coefficient = 1
+        self._default_name_count = -1
+        self._same_table_send_count = 0
+        self._lookup_result = dict()
+        self._modify_graph = False
+
+        self._rank_size = get_rank_size()
+        self._rank_id = get_rank_id()
+        self._device_id = get_device_id()
+        self._use_static = ConfigInitializer.get_instance().use_static
+
+        # init variable
+        self._set_slice_vocab_size()
+
+        if ConfigInitializer.get_instance().hybrid_manager_config.freeze and \
+                self._table_name in ConfigInitializer.get_instance().sparse_embed_config.name_to_var_dict:
+            self._variable = tf.compat.v1.get_variable(self._table_name,
+                                                       shape=(self._slice_device_vocabulary_size, self._emb_size))
+        else:
+            check_emb_init_params(self._is_hbm, self._embedding_size)
+            self.__initialize_variables()
+        tf.compat.v1.add_to_collection(
+            ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection, self._variable)
+        self._set_ext_emb_size()
+
+    @property
+    def optimizer_instance_list(self):
+        return []
+
+    @property
+    def optimizer(self):
+        return dict()
+
+    @property
+    def embedding_size(self):
+        return self._embedding_size
+
+    @property
+    def table_name(self):
+        return self._table_name
+
+    @property
+    def key_dtype(self):
+        return self._key_dtype
+
+    @property
+    def emb_initializer(self):
+        return self._emb_initializer
+
+    @property
+    def is_save(self):
+        return self._is_save
+
+    @property
+    def init_param(self):
+        return self._init_param
+
+    @property
+    def send_count(self):
+        return self._send_count
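+    # Note: send_count is assigned through the setter below and later consumed
+    # when manager.py builds EmbInfoParams; _same_table_send_count separately
+    # accumulates the per-lookup send sizes recorded in lookup().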
+
+ @property
+ def slice_device_vocabulary_size(self):
+ return self._slice_device_vocabulary_size
+
+ @property
+ def slice_host_vocabulary_size(self):
+ return self._slice_host_vocabulary_size
+
+ @property
+ def slice_ssd_vocabulary_size(self):
+ return self._slice_ssd_vocabulary_size
+
+ @property
+ def emb_size(self):
+ return self._emb_size
+
+ @property
+ def is_grad(self):
+ return self._is_grad
+
+ @property
+ def ext_emb_size(self):
+ return self._ext_emb_size
+
+ @property
+ def variable(self):
+ return self._variable
+
+ @property
+ def multi_lookup_times(self):
+ return self._multi_lookup_times
+
+ @property
+ def ssd_data_path(self):
+ return self._ssd_data_path
+
+ @property
+ def is_hbm(self):
+ return self._is_hbm
+
+ @send_count.setter
+ def send_count(self, send_count: int):
+ self._send_count = send_count
+
+ @ext_emb_size.setter
+ def ext_emb_size(self, ext_emb_size: int):
+ self._ext_emb_size = ext_emb_size
+
+ @is_grad.setter
+ def is_grad(self, is_grad: bool):
+ self._is_grad = is_grad
+
+ @staticmethod
+ def get_anchor_attribute(anchor: tf.Tensor, attr: ASCAnchorAttr) -> \
+ Union['BaseSparseEmbedding', FeatureSpec, bool]:
+ """
+ Get the attribute registered for the given anchor ids.
+
+ Args:
+ anchor: the ids tensor passed to lookup
+ attr: name of the attribute to fetch
+
+ Returns: the attribute stored under `attr` in anchor_tensor_specs.
+ """
+ if not isinstance(anchor, tf.Tensor):
+ raise TypeError("Anchor must be a Tensor.")
+
+ if attr not in ASCAnchorAttr:
+ raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.")
+
+ specs = BaseSparseEmbedding.anchor_tensor_specs.get(anchor)
+ if specs is None:
+ raise KeyError(f"Given anchor '{anchor}' was not registered.")
+
+ return specs.get(attr)
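+
+ # A minimal usage sketch (hypothetical `anchor_ids` tensor): once an anchor
+ # has been registered via register_anchor_attribute(), its metadata can be
+ # recovered anywhere in the graph-rewriting code:
+ # table = BaseSparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.TABLE_INSTANCE)
+ # spec = BaseSparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.FEATURE_SPEC)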
+
+ @abc.abstractmethod
+ def capacity(self) -> int:
+ """
+ Get the capacity of the sparse table.
+ Returns: capacity of the sparse table
+ """
+ pass
+
+ @abc.abstractmethod
+ def set_optimizer(self, key: str, state_dict: dict):
+ """
+ Set optimizer state.
+
+ Args:
+ key: optimizer name
+ state_dict: optimizer state
+
+ Returns: None
+ """
+ pass
+
+ @abc.abstractmethod
+ def _set_slice_vocab_size(self):
+ pass
+
+ @abc.abstractmethod
+ def _set_ext_emb_size(self):
+ pass
+
+ @abc.abstractmethod
+ def _build_optimizer_states(self):
+ pass
+
+ @abc.abstractmethod
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ pass
+
+ @abc.abstractmethod
+ def _get_update_grad(self, local_grad: tf.Tensor, result: dict,
+ table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]:
+ pass
+
+ @abc.abstractmethod
+ def _get_local_embeddings(self, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict,
+ feature_spec: FeatureSpec, **kwargs) -> tf.Tensor:
+ pass
+
+ @abc.abstractmethod
+ def _get_sparse_forward_result(self, sparse_forward_fn: Callable, table: Union[tf.compat.v1.Variable, tf.Tensor],
+ result: dict, is_training: bool) -> tf.Tensor:
+ pass
+
+ def size(self) -> int:
+ """
+ Get the current size of the sparse table.
+ Returns: size of the sparse table
+ """
+ return ConfigInitializer.get_instance().hybrid_manager_config.asc_manager.get_table_size(self._table_name)
+
+ def register_anchor_attribute(self, anchor_ids: tf.Tensor, feature_spec: FeatureSpec, kwargs: dict):
+ """
+ Register the attributes associated with the anchor ids.
+
+ Args:
+ anchor_ids: the ids tensor passed to lookup
+ feature_spec: the FeatureSpec created from the ids
+ kwargs: lookup keyword arguments
+
+ Returns: None
+ """
+ self.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self
+ self.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train")
+ self.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec
+ self.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_GRAD] = kwargs.get("is_grad")
+
+ def get_default_lookup_name(self) -> str:
+ """
+ Get the default name for this table's current lookup.
+ Returns: the default lookup name
+ """
+ self._default_name_count += 1
+ default_name = "sparse_lookup_%d" % self._default_name_count
+ logger.debug("getting one default lookup name %s.", default_name)
+ return default_name
+
+ def increase_multi_lookup_times(self, is_training: bool):
+ """
+ Increment the lookup counter of this table, used to validate how many times one table is looked up.
+
+ Args:
+ is_training: whether the current flow is training or inference
+
+ Returns: None
+ """
+ self._multi_lookup_times[is_training] = self._multi_lookup_times.get(is_training) + 1
+
+ def lookup(self, ids: tf.Tensor, send_count: Optional[int], **kwargs) -> tf.Tensor:
+ """
+ Sparse table lookup, automatic graph-modification mode.
+
+ Args:
+ ids: the tensor to look up in this call
+ send_count: all2all communication parameter
+ **kwargs: lookup keyword arguments
+
+ Returns: the lookup result
+ """
+ is_training = kwargs.get("is_train")
+ if ConfigInitializer.get_instance().hybrid_manager_config.freeze and is_training:
+ raise RuntimeError("Cannot build new sparse forward graph after emb cache management was built.")
+
+ # record send count
+ eval_mode = not is_training and \
+ ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(True) is None
+ if is_training or eval_mode or \
+ "train_and_evaluate" in ConfigInitializer.get_instance().train_params_config.bool_gauge_set:
+ self._same_table_send_count += send_count if send_count is not None else 0
+
+ # create feature spec
+ feature_spec = get_feature_spec(self._table_name, kwargs.get("access_and_evict_config"))
+ feature_spec.set_feat_attribute(ids, is_training)
+
+ # record anchor ids
+ anchor_ids = tf.identity(ids, name="ids")
+ tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids)
+ self.register_anchor_attribute(anchor_ids, feature_spec, kwargs)
+
+ # set modify graph
+ self._modify_graph = kwargs.get("modify_graph", True)
+
+ # return the stub tensor of the lookup result
+ if not self._use_static:
+ kwargs["lookup_ids"] = ids
+ mock_lookup_result = self._lookup_forward(feature_spec, send_count, **kwargs)
+ mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value)
+ if not kwargs.get("is_grad"):
+ mock_lookup_result = tf.stop_gradient(mock_lookup_result, name="mock_stop_grad_lookup_res")
+ self.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result
+ logger.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self._table_name)
+ return mock_lookup_result
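+
+ # Hedged usage sketch of this graph-modification entry point (the tensor
+ # and counts are hypothetical); the keyword names follow the kwargs read
+ # above:
+ # emb = table.lookup(ids_tensor, send_count=4096, is_train=True, is_grad=True)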
+
+ def lookup_for_feat_spec(self, feature_spec: FeatureSpec, send_count: Optional[int], **kwargs) -> tf.Tensor:
+ """
+ Sparse table lookup, FeatureSpec mode.
+
+ Args:
+ feature_spec: the wrapper of the tensor to look up in this call
+ send_count: all2all communication parameter
+ **kwargs: lookup keyword arguments
+
+ Returns: the lookup result
+ """
+ spec_name = feature_spec.name
+ is_training = kwargs.get("is_train")
+ if spec_name in self._lookup_result and is_training in self._lookup_result.get(spec_name):
+ lookup_result = self._lookup_result.get(spec_name).get(is_training)
+ if not kwargs.get("is_grad"):
+ return tf.stop_gradient(lookup_result, name="stop_grad_lookup_result")
+ return lookup_result
+
+ if not self._use_static and not self._modify_graph and kwargs.get("batch") is None:
+ raise RuntimeError("When 'feature spec' mode is used with dynamic shapes, 'batch' is required.")
+ table_name = feature_spec.table_name
+ same_table_feature_spec = \
+ ConfigInitializer.get_instance().feature_spec_config.table_name_to_feature_spec[table_name][is_training]
+ logger.debug("The feature specs of the same table are %s, table name is %s.",
+ [fs.name for fs in same_table_feature_spec], self._table_name)
+
+ same_table_spec_count = len(same_table_feature_spec)
+ if same_table_spec_count == 0:
+ raise RuntimeError(f"spec_name {spec_name} not in table {table_name}.")
+
+ if same_table_spec_count == 1:
+ lookup_result = self._lookup_forward(feature_spec, send_count, **kwargs)
+ if spec_name not in self._lookup_result:
+ self._lookup_result[spec_name] = {}
+ self._lookup_result[spec_name][is_training] = lookup_result
+ return lookup_result
+
+ # In graph-modification mode, FeatureSpecs are created in lookup order, so the ids need no sorting;
+ # in feature-spec mode, FeatureSpecs are created manually and are not necessarily ordered.
+ if not self._modify_graph:
+ same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name)
+ mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", table_name=table_name)
+
+ if self._use_static:
+ tensor_list = []
+ tensor_split_list = [feat_spec.split for feat_spec in same_table_feature_spec]
+ total_feature_count = sum(tensor_split_list)
+ else:
+ tensor_list = self.__get_tensor_list(same_table_feature_spec, **kwargs)
+ tensor_split_list = [tf.math.reduce_prod(array_ops.shape(tensor)) for tensor in tensor_list]
+ total_feature_count = tf.add_n(tensor_split_list)
+ set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count)
+
+ kwargs["multi_lookup"] = True
+ total_send_count = send_count * same_table_spec_count
+ lookup_result = self._lookup_forward(mock_feature_spec, total_send_count, **kwargs)
+ logger.debug("multi lookup table %s via %s.", table_name, tensor_split_list)
+ self.__split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, is_training)
+
+ # After a multi-lookup on one table finishes, clear the table's feature spec list so that multiple
+ # eval rounds in estimator mode do not accumulate the feature specs of the previous round.
+ ConfigInitializer.get_instance().feature_spec_config.clear_same_table_feature_spec(self.table_name, is_training)
+ if not kwargs.get("is_grad"):
+ return tf.stop_gradient(self._lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res")
+ return self._lookup_result.get(spec_name).get(is_training)
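+
+ # Hedged sketch of feature-spec mode (names are hypothetical): with dynamic
+ # shapes and no graph modification, the raw batch dict must be supplied so
+ # the result can be reshaped back to the input layout:
+ # emb = table.lookup_for_feat_spec(user_spec, send_count=4096,
+ # is_train=True, batch=batch_dict)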
+
+ def _lookup_forward(self, feature_spec: FeatureSpec, send_count: Optional[int], **kwargs) -> tf.Tensor:
+ is_training = kwargs.get("is_train")
+ hashtable_params = dict(slice_device_vocabulary_size=self._slice_device_vocabulary_size,
+ slice_host_vocabulary_size=self._slice_host_vocabulary_size, send_count=send_count,
+ table_name=self._table_name, is_hbm=self._is_hbm)
+ check_emb_lookup_params(hashtable_params, feature_spec, send_count, is_training)
+ if ConfigInitializer.get_instance().use_static:
+ self._send_count = send_count
+ result = self._get_preprocessed_tensor(feature_spec, is_training, send_count)
+
+ @tf.custom_gradient
+ def sparse_forward(table):
+ def grad(lookup_grad):
+ logger.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name)
+ embedding_grad = tf.reshape(lookup_grad, [-1, self._emb_size])
+ unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_grad,
+ result.get("restore_vector"),
+ unique_embeddings_shape[0])
+ bp_all2all_args = all2all_args if self._use_static else tf.transpose(all2all_args)
+ hot, cold = tf.split(unique_grads,
+ [tf.shape(result.get("hot_pos"))[0],
+ tf.shape(unique_grads)[0] - tf.shape(result.get("hot_pos"))[0]], axis=0)
+ unique_grads = tf.tensor_scatter_nd_add(cold, tf.expand_dims(result.get("hot_pos"), 1), hot)
+ local_grad = self.__get_own_emb(unique_grads, bp_all2all_args)
+
+ if self._all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE:
+ try:
+ local_grad = local_grad / get_rank_size()
+ except ZeroDivisionError as exp:
+ raise ZeroDivisionError("Rank size cannot be zero.") from exp
+
+ return self._get_update_grad(local_grad, result, table)
+
+ logger.debug("fp rank size: %s", self._rank_size)
+ local_embeddings = self._get_local_embeddings(table, result, feature_spec, **kwargs)
+ all2all_args = send_count if self._use_static else result.get("all2all_args")
+
+ unique_embeddings = self.__get_own_emb(local_embeddings, all2all_args)
+ unique_embeddings = tf.concat([tf.gather(unique_embeddings, result.get("hot_pos"), name="hot_pos"),
+ unique_embeddings], axis=0)
+
+ if self._use_static:
+ unique_embeddings_shape = unique_embeddings.shape.as_list()
+ else:
+ unique_embeddings_shape = tf.shape(unique_embeddings)
+
+ notify_hybridmgmt_op = self.__generate_lookup_id_notify_hybrid(is_training)
+ with tf.control_dependencies([notify_hybridmgmt_op]):
+ embeddings = tf.gather(unique_embeddings, result.get("restore_vector"), axis=0,
+ name="gather_for_restore_vector")
+
+ if self._use_static:
+ return tf.reshape(embeddings, feature_spec.dims + [self._emb_size]), grad
+
+ if kwargs.get("multi_lookup"):
+ return tf.reshape(embeddings, [-1, self._emb_size]), grad
+
+ feature_spec_tensor = None
+ if not self._modify_graph:
+ feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key)
+ modify_graph_tensor = kwargs.get("lookup_ids")
+ tensor = feature_spec_tensor if not self._modify_graph else modify_graph_tensor
+ if tensor is None:
+ raise KeyError(f"key or ids does not exist in batch, now modify graph is {self._modify_graph}.")
+ dest_shape = array_ops.concat([array_ops.shape(tensor), [self._emb_size]], 0)
+
+ return array_ops.reshape(embeddings, dest_shape), grad
+
+ with tf.control_dependencies(result.get("swap_in")):
+ return self._get_sparse_forward_result(sparse_forward, self._variable, result, is_training)
+
+ def __initialize_variables(self):
+ initialized_tensor = self._emb_initializer(
+ self._slice_device_vocabulary_size + self._embedding_size) * self._init_param
+ self._variable = tf.compat.v1.get_variable(self._table_name, trainable=False, initializer=initialized_tensor)
+
+ # Make sure the sparse table variable will not be saved and restored within the tf checkpoint.
+ ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(self._variable.name)
+
+ self.__record()
+ self._build_optimizer_states()
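+
+ # Illustrative note: _emb_initializer is expected to behave like a TF
+ # initializer, i.e. be callable with a shape and return a tensor, which
+ # __initialize_variables then scales by init_param. A hypothetical setup:
+ # emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.01)
+ # init_param = 1.0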
+
+ def __record(self):
+ ConfigInitializer.get_instance().sparse_embed_config.insert_table_instance(
+ self._table_name, self._variable, self)
+ logger.debug("Device vocabulary size for table %s is %s.", self._table_name, self._device_vocabulary_size)
+ logger.debug("Slice device vocabulary size for table %s is %s.",
+ self._table_name, self._slice_device_vocabulary_size)
+ logger.debug("Host vocabulary size for table %s is %s.", self._table_name, self._host_vocabulary_size)
+ logger.debug("Slice host vocabulary size for table %s is %s.",
+ self._table_name, self._slice_host_vocabulary_size)
+ logger.debug("SSD vocabulary size for table %s is %s.", self._table_name, self._ssd_vocabulary_size)
+ logger.debug("Slice ssd vocabulary size for table %s is %s.",
+ self._table_name, self._slice_ssd_vocabulary_size)
+
+ def __get_own_emb(self, emb: tf.Tensor, all2all_args: Union[int, tf.Tensor]) -> tf.Tensor:
+ src_emb = emb
+ reshape_info = [all2all_args * self._rank_size, self._emb_size] if self._use_static else \
+ [-1, self._emb_size]
+
+ if self._rank_size == 1 and self._use_static:
+ return tf.reshape(src_emb, reshape_info)
+
+ if self._use_static:
+ emb_send_cnt = tf.constant([all2all_args * self._emb_size] * self._rank_size, dtype=tf.int64)
+ emb_send_offset = tf.constant([all2all_args * self._emb_size * i for i in range(self._rank_size)],
+ dtype=tf.int64)
+ src_emb = hccl_ops.all_to_all_v(send_data=emb,
+ send_counts=emb_send_cnt,
+ send_displacements=emb_send_offset,
+ recv_counts=emb_send_cnt,
+ recv_displacements=emb_send_offset)
+ else:
+ src_emb = hccl_ops.all_to_all_v_c(send_data=emb,
+ send_count_matrix=all2all_args,
+ rank=self._rank_id)
+
+ return tf.reshape(src_emb, reshape_info)
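+
+ # Worked example of the static-shape all2all bookkeeping above (numbers are
+ # hypothetical): with rank_size=2, all2all_args(send_count)=3 and emb_size=8,
+ # every rank sends and receives 3 * 8 = 24 floats per peer at displacements
+ # [0, 24], and the exchanged buffer is reshaped to [3 * 2, 8].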
+
+ def __get_tensor_list(self, same_table_feature_spec: list, **kwargs) -> list:
+ same_table_tensor_list = []
+ for feat_spec in same_table_feature_spec:
+ feature_spec_tensor_dict = kwargs.get("batch")
+ modify_graph_tensor_dict = kwargs.get("feature_spec_name_ids_dict")
+ batch_tensor_dict = feature_spec_tensor_dict if not self._modify_graph else modify_graph_tensor_dict
+ if batch_tensor_dict is None:
+ raise KeyError(f"The tensor dict of batch does not exist in kwargs, and modify graph "
+ f"is `{self._modify_graph}`.")
+
+ feature_spec_tensor = batch_tensor_dict.get(feat_spec.index_key)
+ modify_graph_tensor = batch_tensor_dict.get(feat_spec.name)
+ tensor = feature_spec_tensor if not self._modify_graph else modify_graph_tensor
+ if tensor is None:
+ tensor_key = feat_spec.index_key if not self._modify_graph else feat_spec.name
+ raise KeyError(f"Key `{tensor_key}` does not exist in batch_tensor_dict.")
+ same_table_tensor_list.append(tensor)
+ return same_table_tensor_list
+
+ def __split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list,
+ lookup_result: tf.Tensor, is_training: bool):
+ lookup_result_split = tf.split(lookup_result, tensor_split_list)
+ if len(lookup_result_split) != len(same_table_feature_spec) or (
+ not self._use_static and len(same_table_feature_spec) != len(tensor_list)):
+ raise RuntimeError(f"shape mismatch. len(lookup_result_split): {len(lookup_result_split)}, "
+ f"len(same_table_feature_spec): {len(same_table_feature_spec)}, "
+ f"len(tensor_list): {len(tensor_list)}")
+ for idx, (one_feature_spec, one_result) in enumerate(zip(same_table_feature_spec, lookup_result_split)):
+ if one_feature_spec.name not in self._lookup_result:
+ self._lookup_result[one_feature_spec.name] = {}
+ if self._use_static:
+ dest_shape = one_feature_spec.dims + [self._emb_size]
+ else:
+ dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self._emb_size]], 0)
+ self._lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape)
+
+ def __generate_lookup_id_notify_hybrid(self, is_training: bool):
+ """
+ A stub op whose name marks whether this sparse lookup belongs to train or eval. During session run,
+ this op is found by searching the subgraph backwards; its name then tells which channel the session
+ run is using, and the C++ side is notified to count and wake up accordingly.
+
+ Args:
+ is_training: whether the current flow is training or inference
+
+ Returns: a tf.no_op() with the designated name
+ """
+ channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training)
+ channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id)
+ notify_hybridmgmt_op = tf.no_op(channel_name)
+ logger.debug("The notify hybridmgmt op of table `%s` is `%s`.", self._table_name, notify_hybridmgmt_op.name)
+ return notify_hybridmgmt_op
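For orientation, a concrete subclass of BaseSparseEmbedding only has to fill in the nine abstract hooks above. The sketch below is a minimal, hypothetical illustration of that contract (ToySparseEmbedding and its trivial hook bodies are invented for illustration; the real subclasses in the following files back these hooks with NPU preprocessing ops):

# Minimal sketch, assuming only the BaseSparseEmbedding interface above;
# the hook bodies are placeholders, not the library's real logic.
from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding


class ToySparseEmbedding(BaseSparseEmbedding):
    def capacity(self) -> int:
        return self._device_vocabulary_size

    def set_optimizer(self, key, state_dict):
        pass  # this toy keeps no optimizer state

    def _set_slice_vocab_size(self):
        # single-rank toy: no sharding of the vocabulary
        self._slice_device_vocabulary_size = self._device_vocabulary_size

    def _set_ext_emb_size(self):
        self._ext_emb_size = self._emb_size  # no extra optimizer columns

    def _build_optimizer_states(self):
        pass

    def _get_preprocessed_tensor(self, feature_spec, is_training, send_count):
        return {}  # real subclasses return id_offsets, restore vectors, swap_in ops, ...

    def _get_update_grad(self, local_grad, result, table):
        return local_grad

    def _get_local_embeddings(self, table, result, feature_spec, **kwargs):
        return table

    def _get_sparse_forward_result(self, sparse_forward_fn, table, result, is_training):
        return sparse_forward_fn(table)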
diff --git a/mx_rec/core/emb/dynamic_sparse_embedding.py b/mx_rec/core/emb/dynamic_sparse_embedding.py
new file mode 100644
index 00000000..194b2795
--- /dev/null
+++ b/mx_rec/core/emb/dynamic_sparse_embedding.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+import abc
+from typing import Optional, Union, Callable
+
+import tensorflow as tf
+from tensorflow.python.ops import array_ops
+
+from mx_rec.constants.constants import ASCEND_TABLE_NAME_MUST_CONTAIN, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, \
+ ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.log import logger
+from mx_rec.util.ops import import_host_pipeline_ops
+
+
+class DynamicSparseEmbedding(BaseSparseEmbedding):
+ """
+ Sparse table whose size is not fixed; supports dynamic expansion.
+ """
+
+ def __init__(self, config: dict):
+ super(DynamicSparseEmbedding, self).__init__(config)
+
+ def capacity(self) -> int:
+ return ConfigInitializer.get_instance().hybrid_manager_config.asc_manager.get_table_capacity(self._table_name)
+
+ @abc.abstractmethod
+ def set_optimizer(self, key: str, state_dict: dict):
+ pass
+
+ @abc.abstractmethod
+ def _build_optimizer_states(self):
+ pass
+
+ @abc.abstractmethod
+ def _set_ext_emb_size(self):
+ pass
+
+ @abc.abstractmethod
+ def _set_slice_vocab_size(self):
+ pass
+
+ @abc.abstractmethod
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ pass
+
+ def _get_update_grad(self, local_grad: tf.Tensor, result: dict,
+ table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]:
+ return tf.compat.v1.unsorted_segment_sum(local_grad,
+ result.get("restore_vector_second"),
+ array_ops.shape(result.get("unique_keys"))[0])
+
+ def _get_local_embeddings(self, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict,
+ feature_spec: FeatureSpec, **kwargs) -> tf.Tensor:
+ return tf.identity(table, name="identity_local_emb")
+
+ def _get_sparse_forward_result(self, sparse_forward_fn: Callable, table: Union[tf.compat.v1.Variable, tf.Tensor],
+ result: dict, is_training: bool) -> tf.Tensor:
+ local_embeddings = import_host_pipeline_ops().embedding_lookup_by_address(
+ result.get("id_offsets"), embedding_dim=self._emb_size, embedding_type=1)
+
+ add_collection_condition = is_training and (
+ ASCEND_TABLE_NAME_MUST_CONTAIN is None or ASCEND_TABLE_NAME_MUST_CONTAIN in self._table_name)
+ logger.debug("feature spec mode, table_name: %s, ASCEND_TABLE_NAME_MUST_CONTAIN: %s",
+ self._table_name, ASCEND_TABLE_NAME_MUST_CONTAIN)
+ if not add_collection_condition:
+ return sparse_forward_fn(local_embeddings)
+
+ tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings)
+ tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, result.get("unique_keys"))
+ return sparse_forward_fn(local_embeddings)
+
+
+class HBMDynamicSparseEmbedding(DynamicSparseEmbedding):
+ """
+ Sparse table whose size is not fixed; supports dynamic expansion, HBM mode.
+ """
+
+ def __init__(self, config: dict):
+ super(HBMDynamicSparseEmbedding, self).__init__(config)
+
+ def set_optimizer(self, key: str, state_dict: dict):
+ pass
+
+ def _build_optimizer_states(self):
+ pass
+
+ def _set_ext_emb_size(self):
+ self._ext_emb_size = self._emb_size * self._ext_coefficient
+ logger.debug("init table, ext_emb_size is set to be %s.", self._ext_emb_size)
+
+ def _set_slice_vocab_size(self):
+ # In dynamic-expansion mode, keep the device-side variable and set its size to 1.
+ self._slice_device_vocabulary_size = 1
+
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training)
+ config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count,
+ rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name,
+ is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size,
+ emb_size=self._emb_size, device_id=self._device_id, use_dynamic_expansion=True)
+
+ return get_preprocessed_tensor_for_asc(self._variable, config)
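The size()/capacity() split above is the part most callers interact with; a hedged sketch of how it might be monitored (the `table` variable and the 90% threshold are hypothetical):

# Hypothetical monitoring snippet; `table` is assumed to be a constructed
# DynamicSparseEmbedding with the ASC manager already initialized.
used = table.size()        # features currently stored (from the ASC manager)
limit = table.capacity()   # current capacity; grows under dynamic expansion
if limit and used / limit > 0.9:
    print(f"table {table.table_name} is {used / limit:.0%} full")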
diff --git a/mx_rec/core/emb/emb_factory.py b/mx_rec/core/emb/emb_factory.py
new file mode 100644
index 00000000..275390d9
--- /dev/null
+++ b/mx_rec/core/emb/emb_factory.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+import abc
+
+from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
+from mx_rec.core.emb.dynamic_sparse_embedding import HBMDynamicSparseEmbedding
+from mx_rec.core.emb.sparse_embedding import HBMSparseEmbedding, ExternalStorageSparseEmbedding
+
+
+class BaseSparseEmbeddingFactory(metaclass=abc.ABCMeta):
+ """
+ Base factory class for creating embeddings.
+ """
+
+ @abc.abstractmethod
+ def create_embedding(self, config: dict) -> BaseSparseEmbedding:
+ """
+ Create an embedding instance.
+
+ Args:
+ config: the parameter dict required to create the embedding.
+
+ Returns: the embedding instance
+ """
+ pass
+
+
+class HBMDynamicSparseEmbeddingFactory(BaseSparseEmbeddingFactory):
+ """
+ Factory for HBMDynamicSparseEmbedding.
+ """
+
+ def create_embedding(self, config: dict) -> HBMDynamicSparseEmbedding:
+ return HBMDynamicSparseEmbedding(config)
+
+
+class HBMSparseEmbeddingFactory(BaseSparseEmbeddingFactory):
+ """
+ Factory for HBMSparseEmbedding.
+ """
+
+ def create_embedding(self, config: dict) -> HBMSparseEmbedding:
+ return HBMSparseEmbedding(config)
+
+
+class ExternalStorageSparseEmbeddingFactory(BaseSparseEmbeddingFactory):
+ """
+ Factory for ExternalStorageSparseEmbedding.
+ """
+
+ def create_embedding(self, config: dict) -> ExternalStorageSparseEmbedding:
+ return ExternalStorageSparseEmbedding(config)
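The factories above are intentionally thin wrappers; a caller such as create_table can pick one along the lines of this sketch (the dispatch conditions are assumptions inferred from the config keys, not a documented rule):

# Hypothetical dispatch; `config` follows the keys used by the embeddings.
from mx_rec.core.emb.emb_factory import (HBMDynamicSparseEmbeddingFactory,
                                         HBMSparseEmbeddingFactory,
                                         ExternalStorageSparseEmbeddingFactory)


def pick_factory(config, use_dynamic_expansion=False):
    if use_dynamic_expansion:
        return HBMDynamicSparseEmbeddingFactory()
    if config.get("host_vocabulary_size", 0) > 0:
        return ExternalStorageSparseEmbeddingFactory()
    return HBMSparseEmbeddingFactory()


# usage: embedding = pick_factory(config).create_embedding(config)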
diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py
new file mode 100644
index 00000000..d8ce63b1
--- /dev/null
+++ b/mx_rec/core/emb/sparse_embedding.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+import abc
+import math
+from typing import Optional, Union, Callable
+
+import tensorflow as tf
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
+from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
+from mx_rec.optimizers.emb_optimizer import EmbOptimizer
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.log import logger
+
+
+class SparseEmbedding(BaseSparseEmbedding):
+ """
+ Sparse table of fixed size; dynamic expansion is not supported.
+ """
+
+ def __init__(self, config: dict):
+ super(SparseEmbedding, self).__init__(config)
+
+ @abc.abstractmethod
+ def capacity(self) -> int:
+ pass
+
+ @abc.abstractmethod
+ def set_optimizer(self, key: str, state_dict: dict):
+ pass
+
+ @abc.abstractmethod
+ def _set_ext_emb_size(self):
+ pass
+
+ @abc.abstractmethod
+ def _build_optimizer_states(self):
+ pass
+
+ @abc.abstractmethod
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ pass
+
+ def _set_slice_vocab_size(self):
+ self._slice_device_vocabulary_size = math.ceil(self._device_vocabulary_size / self._rank_size)
+ self._slice_host_vocabulary_size = math.ceil(self._host_vocabulary_size / self._rank_size)
+ self._slice_ssd_vocabulary_size = math.ceil(self._ssd_vocabulary_size / self._rank_size)
+
+ def _get_update_grad(self, local_grad: tf.Tensor, result: dict,
+ table: Union[tf.compat.v1.Variable, tf.Tensor]) -> Union[tf.IndexedSlices, tf.Tensor]:
+ unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad,
+ result.get("restore_vector_second"),
+ array_ops.shape(result.get("unique_keys"))[0])
+ return ops.IndexedSlices(values=unique_local_grad,
+ indices=result.get("unique_keys"),
+ dense_shape=tf.shape(table))
+
+ def _get_local_embeddings(self, table: Union[tf.compat.v1.Variable, tf.Tensor], result: dict,
+ feature_spec: FeatureSpec, **kwargs) -> tf.Tensor:
+ id_offsets_abs = tf.abs(result.get("id_offsets"))
+ local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets")
+ local_embeddings = _set_specific_value_for_non_valid_key(result.get("id_offsets"),
+ local_embeddings,
+ feature_spec.access_threshold,
+ kwargs.get("serving_default_value"),
+ is_training=kwargs.get("is_train"))
+ return local_embeddings
+
+ def _get_sparse_forward_result(self, sparse_forward_fn: Callable, table: Union[tf.compat.v1.Variable, tf.Tensor],
+ result: dict, is_training: bool) -> tf.Tensor:
+ return sparse_forward_fn(self._variable)
+
+
+class HBMSparseEmbedding(SparseEmbedding):
+ """
+ Sparse table of fixed size, HBM mode.
+ """
+
+ def __init__(self, config: dict):
+ super(HBMSparseEmbedding, self).__init__(config)
+
+ def capacity(self) -> int:
+ return self._device_vocabulary_size
+
+ def set_optimizer(self, key: str, state_dict: dict):
+ pass
+
+ def _build_optimizer_states(self):
+ pass
+
+ def _set_ext_emb_size(self):
+ self._ext_emb_size = self._emb_size * self._ext_coefficient
+ logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size)
+
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training)
+ config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count,
+ rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name,
+ is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size,
+ emb_size=self._emb_size, device_id=self._device_id)
+
+ return get_preprocessed_tensor_for_asc(self._variable, config)
+
+
+class ExternalStorageSparseEmbedding(SparseEmbedding):
+ """
+ Sparse table of fixed size, DDR/SSD mode.
+ """
+
+ def __init__(self, config: dict):
+ self.emb_optimizer = EmbOptimizer(config.get("optimizer_list"))
+ self.emb_optimizer.check_optimizer_instance_list()
+
+ super(ExternalStorageSparseEmbedding, self).__init__(config)
+
+ @property
+ def optimizer(self):
+ return self.emb_optimizer.optimizer
+
+ @property
+ def optimizer_instance_list(self):
+ return self.emb_optimizer.optimizer_instance_list
+
+ def capacity(self) -> int:
+ # DDR
+ if not self._ssd_vocabulary_size:
+ return self._device_vocabulary_size + self._host_vocabulary_size
+ # SSD
+ return self._device_vocabulary_size + self._host_vocabulary_size + self._ssd_vocabulary_size
+
+ def set_optimizer(self, key: str, state_dict: dict):
+ self.emb_optimizer.set_optimizer(key, state_dict, self._table_name)
+
+ def _set_ext_emb_size(self):
+ self._ext_coefficient += len(self.emb_optimizer.optimizer_slot_info_list)
+ self._ext_emb_size = self._emb_size * self._ext_coefficient
+ logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size)
+
+ def _build_optimizer_states(self):
+ for sparse_optimizer_instance in self.emb_optimizer.optimizer_instance_list:
+ slot_info_list = sparse_optimizer_instance.initialize_slots(self._variable, self)
+ self.emb_optimizer.optimizer_slot_info_list.extend(slot_info_list)
+
+ for slot_info in self.emb_optimizer.optimizer_slot_info_list:
+ self.emb_optimizer.set_optimizer_slot(slot_info)
+
+ def _get_preprocessed_tensor(self, feature_spec: FeatureSpec, is_training: bool, send_count: Optional[int]) -> dict:
+ channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(is_training)
+ config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count,
+ rank_size=self._rank_size, channel_id=channel_id, table_name=self._table_name,
+ is_hbm=self._is_hbm, ext_emb_size=self._ext_emb_size,
+ emb_size=self._emb_size, device_id=self._device_id)
+
+ variable_list = [self._variable] + \
+ [slot_info.get("slot") for slot_info in self.emb_optimizer.optimizer_slot_info_list]
+ return get_preprocessed_tensor_for_asc(variable_list, config)
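+
+# Worked example of the ext_emb_size bookkeeping above (hypothetical numbers):
+# with emb_size=16 and two optimizer slot infos, ext_coefficient = 1 + 2 = 3,
+# so ext_emb_size = 16 * 3 = 48, i.e. each row stores the embedding plus both
+# optimizer states side by side in host/SSD storage.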
+
+
+def _set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor],
+ embeddings: Optional[tf.Tensor],
+ access_threshold: Optional[int],
+ serving_default_value: Optional[tf.Tensor] = None,
+ is_training: bool = True) -> tf.Tensor:
+ """
+ Set the embeddings of features whose key is -1 (invalid) to zeros or to a designated value.
+
+ Args:
+ id_offsets: feature indices
+ embeddings: embeddings gathered from the sparse table
+ access_threshold: feature admission threshold
+ serving_default_value: see the description of the sparse_lookup API
+ is_training: whether the current flow is training or inference
+
+ Returns: embeddings
+ """
+ # During training, invalid keys only appear when feature admission is enabled;
+ # during inference, invalid keys may appear whether or not admission is enabled.
+ if is_training and (access_threshold is None or access_threshold < 0):
+ return embeddings
+
+ if serving_default_value is None:
+ # When unset, the embeddings of invalid keys default to all zeros.
+ default_value = tf.zeros_like(embeddings)
+ else:
+ try:
+ default_value = tf.broadcast_to(serving_default_value, tf.shape(embeddings))
+ except ValueError as e:
+ logger.error("failed to broadcast serving_default_value to target embedding, please check its shape.")
+ raise e
+ except Exception as e:
+ logger.error("failed to process serving_default_value.")
+ raise e
+
+ if tf.__version__.startswith("1"):
+ id_offsets_expand = tf.math.greater_equal(id_offsets, 0)
+ embeddings = tf.where(id_offsets_expand, embeddings, default_value)
+ return embeddings
+
+ id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1)
+ embeddings = tf.where(id_offsets_expand, embeddings, default_value)
+ return embeddings
diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py
index 6b2adc83..0438f281 100644
--- a/mx_rec/core/embedding.py
+++ b/mx_rec/core/embedding.py
@@ -15,35 +15,25 @@
 # limitations under the License.
 # ==============================================================================
-import math
 import os
-from collections import defaultdict
 from typing import Optional, Union
 
 import tensorflow as tf
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.init_ops import Initializer as InitializerV1
 from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
 
-from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
-from mx_rec.core.asc.feature_spec import FeatureSpec, get_feature_spec, set_temporary_feature_spec_attribute
-from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
- ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_LOCAL_EMB, MULTI_LOOKUP_TIMES, \
- ASCEND_TABLE_NAME_MUST_CONTAIN, MAX_INT32, All2allGradientsOp, ApplyGradientsStrategy, MAX_VOCABULARY_SIZE, \
- MAX_DEVICE_VOCABULARY_SIZE
-from mx_rec.util.initialize import get_rank_id, get_rank_size, is_asc_frozen, get_customized_ops, \
- insert_table_instance, get_training_mode_channel_id, get_use_static, get_name_to_var_dict, \
- clear_channel, get_use_hot, get_device_id, ConfigInitializer, get_ascend_global_hashtable_collection, \
- get_host_pipeline_ops, get_use_dynamic_expansion, set_modify_graph, insert_removing_var_list, get_bool_gauge_set, \
- get_asc_manager, get_table_name_to_feature_spec, clear_same_table_feature_spec
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
+from mx_rec.core.emb.emb_factory import HBMDynamicSparseEmbeddingFactory, HBMSparseEmbeddingFactory, \
+ ExternalStorageSparseEmbeddingFactory
+from mx_rec.graph.utils import tag_orphan_ids
+from mx_rec.constants.constants import MAX_INT32, All2allGradientsOp, MAX_VOCABULARY_SIZE, MAX_DEVICE_VOCABULARY_SIZE
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.validator.validator import ClassValidator, StringValidator, SSDFeatureValidator, \
 para_checker_decorator, IntValidator, NumValidator, OptionValidator, OptionalIntValidator, \
 OptionalStringValidator, FloatValidator
-from mx_rec.util.tf_version_adapter import hccl_ops +from mx_rec.validator.emb_validator import check_emb_multi_lookup_times from mx_rec.util.normalization import fix_invalid_table_name -from mx_rec.util.global_env_conf import global_env from mx_rec.util.log import logger @@ -107,697 +97,18 @@ def create_table(key_dtype, dim, name, emb_initializer, ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, all2all_gradients_op=all2all_gradients_op) - embedding = SparseEmbedding(config) - return embedding - - -class SparseEmbedding: - """ - each feat_name has its own sparse_embedding_layer. - """ - customized_ops = get_customized_ops() - anchor_tensor_specs = defaultdict(dict) - - def __init__(self, config): - self.embedding_size = config.get("embedding_size") - if isinstance(self.embedding_size, int): - self.embedding_size = tf.TensorShape([self.embedding_size]) - self.device_vocabulary_size = config.get("device_vocabulary_size") - self.host_vocabulary_size = config.get("host_vocabulary_size") - self.ssd_vocabulary_size = config.get("ssd_vocabulary_size") - self.ssd_data_path = list(config.get("ssd_data_path")) - self.table_name = config.get("table_name") - self.key_dtype = config.get("key_dtype") - self._optimizer_instance_list = config.get("optimizer_list") - self.emb_initializer = config.get("emb_initializer") - self.is_save = config.get("is_save") - self.optimizer_slot_info_list = [] - self._slot_num = dict() - self._send_count = 0 - self.same_table_send_count = 0 - self.skip_emb_transfer = True if self.host_vocabulary_size <= 0 else False - self._default_name_count = -1 - self.emb_size = None - self.ext_emb_size = None - self.ext_coefficient = 1 - self._optimizer = dict() - self.slice_device_vocabulary_size = 0 - self.slice_host_vocabulary_size = 0 - self.slice_ssd_vocabulary_size = 0 - self.variable = None - self.lookup_info = set() - self.lookup_result = dict() - self.use_dynamic_expansion = get_use_dynamic_expansion() - self.lookup_name_dict = {True: [], False: []} - self.modify_graph = False - self.init_param = config.get("init_param") - self.all2all_gradients_op = All2allGradientsOp.mapping(config.get("all2all_gradients_op")) - self.is_grad = False - - self.set_slice_vocab_size() - self.set_emb_size() - if is_asc_frozen() and self.table_name in get_name_to_var_dict(): - self.variable = tf.compat.v1.get_variable(self.table_name, - shape=(self.slice_device_vocabulary_size, self.emb_size)) - if not self.skip_emb_transfer: - self.set_ext_emb_size() - else: - self.check_and_format_init_params() - self._initialize_variables() - self.set_ext_emb_size() - tf.compat.v1.add_to_collection(get_ascend_global_hashtable_collection(), self.variable) - - @property - def scalar_emb_size(self): - return self.emb_size - - @property - def send_count(self): - return self._send_count - - @property - def optimizer(self): - return self._optimizer - - @property - def optimizer_instance_list(self): - return self._optimizer_instance_list - - @staticmethod - def generate_lookup_id_notify_hybrid(channel_id: int): - - """ - Args: - channel_id: channel id 0 for train,1 for eval - Returns: tf.no_op notify preprocess step - """ - channel_name = "d2h_notify_hybridmgmt_{}".format(channel_id) - notify_hybridmgmt_op = tf.no_op(channel_name) - return notify_hybridmgmt_op - - @staticmethod - def get_anchor_attribute(anchor, attr): - if not isinstance(anchor, tf.Tensor): - raise TypeError("Anchor must be a Tensor.") - - if attr not in 
ASCAnchorAttr: - raise ValueError("Given attr must be limited in Enum 'ASCAnchorAttr'.") - - specs = SparseEmbedding.anchor_tensor_specs.get(anchor) - if specs is None: - raise KeyError(f"Given anchor '{anchor}' was not registered.") - - return specs.get(attr) - - @staticmethod - def set_optimizer_slot(slot_info): - slot = slot_info.get("slot") - slot_name = slot_info.get("slot_name") - optimizer = slot_info.get("optimizer") - named_slot_key = slot_info.get("named_slot_key") - - optimizer.insert_slot(slot, named_slot_key, slot_name) - - @staticmethod - def _get_own_emb(emb, all2all_args, emb_size, use_static): - """ - obtain embedding of source data - :param emb: origin embeddding - :param all2all_args: dynamic shape condition parameters - :param emb_size: size of embedding table - :param use_static: enable static shape training or not - :return: local embedding after all2all - """ - - rank_size = get_rank_size() - rank_id = get_rank_id() - - src_emb = emb - - reshape_info = [all2all_args * rank_size, emb_size] if use_static else [-1, emb_size] - - if rank_size == 1 and use_static: - return tf.reshape(src_emb, reshape_info) - - if use_static: - emb_send_cnt = tf.constant([all2all_args * emb_size] * rank_size, dtype=tf.int64) - emb_send_offset = tf.constant([all2all_args * emb_size * i for i in range(rank_size)], dtype=tf.int64) - src_emb = hccl_ops.all_to_all_v(send_data=emb, - send_counts=emb_send_cnt, - send_displacements=emb_send_offset, - recv_counts=emb_send_cnt, - recv_displacements=emb_send_offset) - else: - src_emb = hccl_ops.all_to_all_v_c(send_data=emb, - send_count_matrix=all2all_args, - rank=rank_id) - - return tf.reshape(src_emb, reshape_info) - - def size(self) -> int: - """ - For HBM or DDR or SSD mode, return the size of sparse table - """ - return get_asc_manager().get_table_size(self.table_name) - - def capacity(self) -> int: - """ - For HBM or DDR or SSD mode, return the capacity of sparse table - """ - if get_use_dynamic_expansion(): - return get_asc_manager().get_table_capacity(self.table_name) - - if not self.host_vocabulary_size and not self.ssd_vocabulary_size: - return self.device_vocabulary_size - if not self.ssd_vocabulary_size: - return self.device_vocabulary_size + self.host_vocabulary_size - return self.device_vocabulary_size + self.host_vocabulary_size + self.ssd_vocabulary_size - - def check_optimizer_instance(self): - for optimizer_instance in self._optimizer_instance_list: - if tf.__version__.startswith("1"): - from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer - if isinstance(optimizer_instance, NPULossScaleOptimizer): - optimizer_instance = getattr(optimizer_instance, '_opt') - else: - from npu_device.train.optimizer.npu_loss_scale_optimizer import NpuLossScaleOptimizer - if isinstance(optimizer_instance, NpuLossScaleOptimizer): - optimizer_instance = getattr(optimizer_instance, '_opt') - - if not isinstance(optimizer_instance, CustomizedOptimizer): - raise ValueError(f"args optimizer list must be a list or an instance of CustomizedOptimizer.") - - def check_and_format_init_params(self): - if self.embedding_size.ndims != 1: - raise ValueError("Parameter 'embedding_size' can only be one dim shape.") - - if is_asc_frozen(): - raise EnvironmentError(f"Emb cache management has been established, you cannot build new ASC hash table.") - - if not self.skip_emb_transfer and not self._optimizer_instance_list: - raise ValueError("ASC with DDR mode should config optimizers before instantiating sparse table, " - "but nothing was 
configured.") - - if not self.skip_emb_transfer and self.use_dynamic_expansion: - raise ValueError("DDR mode do not support embedding dynamic_expansion for now.") - - self._optimizer_instance_list = [] if self._optimizer_instance_list is None else self._optimizer_instance_list - if isinstance(self._optimizer_instance_list, CustomizedOptimizer): - self._optimizer_instance_list = [self._optimizer_instance_list] - - if not isinstance(self._optimizer_instance_list, (tuple, list)): - raise ValueError(f"args optimizer list must be a list or an instance of CustomizedOptimizer.") - self._optimizer_instance_list = list(self._optimizer_instance_list) - - self.check_optimizer_instance() - - def get_default_lookup_name(self): - self._default_name_count += 1 - default_name = "sparse_lookup_%d" % self._default_name_count - logger.debug("getting one default lookup name %s", default_name) - return default_name - - def set_emb_size(self): - self.emb_size = self.embedding_size.as_list()[0] - - def set_ext_emb_size(self): - self.ext_coefficient += len(self.optimizer_slot_info_list) - if self.use_dynamic_expansion and len(self._optimizer_instance_list) != 0: - self.ext_coefficient += self._slot_num.get(self.table_name) - self.ext_emb_size = self.emb_size * self.ext_coefficient - logger.debug("init table, ext_emb_size is set to be %s", self.ext_emb_size) - - def set_slice_vocab_size(self): - rank_size = get_rank_size() - if self.use_dynamic_expansion: - self.slice_device_vocabulary_size = 1 # 动态扩容模式下,保留device侧variable,大小设置为 1 - self.slice_host_vocabulary_size = 0 - else: - self.slice_device_vocabulary_size = math.ceil(self.device_vocabulary_size / rank_size) - self.slice_host_vocabulary_size = math.ceil(self.host_vocabulary_size / rank_size) - self.slice_ssd_vocabulary_size = math.ceil(self.ssd_vocabulary_size / rank_size) - - def register_anchor_attribute(self, anchor_ids, feature_spec, kwargs): - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.TABLE_INSTANCE] = self - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = kwargs.get("is_train") - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.FEATURE_SPEC] = feature_spec - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_GRAD] = kwargs.get("is_grad") - - def check_multi_lookup_times(self, is_training): - lookup_times = len(self.lookup_name_dict.get(is_training)) if self.modify_graph else len(self.lookup_result) - if not self.modify_graph and get_training_mode_channel_id(True) is not None and \ - get_training_mode_channel_id(False) is not None: - lookup_times = int(lookup_times / 2) - if lookup_times > MULTI_LOOKUP_TIMES: - run_mode = "Modify Graph" if self.modify_graph else "Feature Spec" - raise RuntimeError(f"In '{run_mode}' mode, the number of multiple sparse lookup for a table" - f"({self.table_name}) is {MULTI_LOOKUP_TIMES}, and current times is {lookup_times}.") - - def check_and_format_lookup_params(self, feature, send_count, is_training): - logger.debug("sparse lookup for table %s with is_training %s", self.table_name, is_training) - - def check_params(): - if not isinstance(is_training, bool): - raise ValueError("Arg is_train should be a boolean.") - - if isinstance(feature, FeatureSpec): - if not feature.initialized: - raise ValueError(f"Feature Spec has not been initialized.") - if is_training not in feature.pipeline_mode: - raise ValueError(f"You have not config feature for is training mode '{is_training}', please config " - f"feature with func sparse_lookup at first.") - - elif 
isinstance(feature, tf.Tensor): - logger.debug("Input feature is a Tensor.") - - else: - raise TypeError(f"Given feature must be a FeatureSpec or tf.Tensor.") - - if is_training not in self.lookup_info: - self.lookup_info.add(is_training) - - if not isinstance(self.init_param, float): - raise ValueError("Arg init_param should be a float.") - - if get_use_static(): - if isinstance(send_count, int) and send_count > 0: - if self._send_count and self._send_count != send_count: - logger.warning("A new send count %s will be used to replace the old one (%s).", - send_count, self._send_count) - - self._send_count = send_count - else: - raise ValueError("Send count must be a integer which is larger than 0.") - - check_params() - if self.slice_host_vocabulary_size + self.slice_device_vocabulary_size > MAX_VOCABULARY_SIZE: - raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table " - f"'{self.table_name}', in which slice_device_vocabulary_size was " - f"{self.slice_device_vocabulary_size} and slice_host_vocabulary_size was " - f"{self.slice_host_vocabulary_size} ") - - is_check_mode = not self.skip_emb_transfer and not self.use_dynamic_expansion - if is_check_mode and self.slice_device_vocabulary_size < self.send_count * get_rank_size(): - raise ValueError(f"Given device_vocabulary_size was too small for table '{self.table_name}', in which " - f"slice_device_vocabulary_size was {self.slice_device_vocabulary_size} and " - f"send_count({self.send_count}) * rank_size({get_rank_size()}) was " - f"{self.send_count * get_rank_size()}") - - if is_check_mode and self.slice_host_vocabulary_size < self.send_count * get_rank_size(): - raise ValueError(f"Given host_vocabulary_size was too small for table '{self.table_name}', in which " - f"slice_host_vocabulary_size was {self.slice_host_vocabulary_size} and " - f"send_count({self.send_count}) * rank_size({get_rank_size()}) was " - f"{self.send_count * get_rank_size()}") - - def set_optimizer(self, key, state_dict): - if key in self._optimizer: - raise ValueError(f"Optimizer {key} has been set for hash table {self.table_name}") - - self._optimizer[key] = state_dict - - def lookup_for_asc(self, ids: tf.Tensor, send_count, **kwargs): - """ - - Args: - ids: Tensor to lookup from hashtable - send_count: int, used to config all2all communication parameters - kwargs: - dim: not in use - is_train: - name: not in use - modify_graph: if True, the original graph will be modified before building a Session instance - - Returns: Tensor for lookup result - - """ - logger.debug(f"Enter ASC Branch.") - is_training = kwargs.get("is_train") - self.check_and_format_lookup_params(ids, send_count, is_training) - if is_asc_frozen() and is_training: - raise RuntimeError(f"Cannot build new sparse forward graph after emb cache management was built.") - - # record send count - eval_mode = not is_training and get_training_mode_channel_id(True) is None - if is_training or eval_mode or "train_and_evaluate" in get_bool_gauge_set(): - self.same_table_send_count += send_count if send_count is not None else 0 - - # create feature spec - feature_spec = get_feature_spec(self.table_name, kwargs.get("access_and_evict_config")) - feature_spec.set_feat_attribute(ids, is_training) - # 'clear_channel()' function needs to be executed after 'set_feat_attribute()' function - if is_asc_frozen() and not is_training: - clear_channel(is_train_channel=False) - - # record anchor ids - anchor_ids = tf.identity(ids, name="ids") - 
tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, anchor_ids) - self.register_anchor_attribute(anchor_ids, feature_spec, kwargs) - - # record multi lookup info - ids_lookup_name = feature_spec.name + "_lookup_ids" - if self.lookup_name_dict.get(is_training) is None: - self.lookup_name_dict[is_training] = [] - self.lookup_name_dict.get(is_training).append(ids_lookup_name) - self.modify_graph = kwargs.get("modify_graph", True) - self.check_multi_lookup_times(is_training) - - # return the stub tensor of the lookup result - if not get_use_static(): - kwargs["lookup_ids"] = ids - mock_lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) - mock_lookup_result = tf.identity(mock_lookup_result, name=ASCAnchorAttr.MOCK_LOOKUP_RESULT.value) - if not kwargs.get("is_grad"): - mock_lookup_result = tf.stop_gradient(mock_lookup_result, name="mock_stop_grad_lookup_res") - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.MOCK_LOOKUP_RESULT] = mock_lookup_result - logger.debug("Return the stub tensor `%s` of the `%s` table.", mock_lookup_result, self.table_name) - return mock_lookup_result - - def lookup_for_asc_with_feature_spec(self, feature_spec: FeatureSpec, send_count: int, **kwargs): - """ - Args: - feature_spec: an instance of FeatureSpec to lookup from hashtable - send_count: int, used to config all2all communication parameters - kwargs: - dim: not in use - is_train: - name: not in use - modify_graph: if True, the original graph will be modified before building a Session instance - - Returns: Tensor for lookup result - - """ - spec_name = feature_spec.name - is_training = kwargs.get("is_train") - if spec_name in self.lookup_result and is_training in self.lookup_result.get(spec_name): - if not kwargs.get("is_grad"): - return tf.stop_gradient(self.lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res") - return self.lookup_result.get(spec_name).get(is_training) - - if not get_use_static() and not self.modify_graph and kwargs.get("batch") is None: - raise RuntimeError("When the 'feature spec' mode and 'dynamic shape' are used, the 'batch' is required.") - table_name = feature_spec.table_name - same_table_feature_spec = get_table_name_to_feature_spec(table_name, is_training) - logger.debug("The feature spec of the same table is %s, table name is %s.", - [fs.name for fs in same_table_feature_spec], self.table_name) - same_table_spec_count = len(same_table_feature_spec) - if same_table_spec_count == 0: - raise RuntimeError(f"spec_name {spec_name} not in table {table_name}.") - if same_table_spec_count == 1: - lookup_result = self.lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs) - if spec_name not in self.lookup_result: - self.lookup_result[spec_name] = {} - self.lookup_result[spec_name][is_training] = lookup_result - else: - def get_tensor_list() -> list: - """ - Use 'feature spec' to find the corresponding tensor from batch. - Returns: Tensor list in batch. 
- """ - same_table_tensor_list = [] - for feat_spec in same_table_feature_spec: - feature_spec_tensor_dict = kwargs.get("batch") - modify_graph_tensor_dict = kwargs.get("feature_spec_name_ids_dict") - batch_tensor_dict = feature_spec_tensor_dict if not self.modify_graph else modify_graph_tensor_dict - if batch_tensor_dict is None: - raise KeyError(f"The tensor dict of batch does not exist in kwargs, and modify graph " - f"is `{self.modify_graph}`.") - - feature_spec_tensor = batch_tensor_dict.get(feat_spec.index_key) - modify_graph_tensor = batch_tensor_dict.get(feat_spec.name) - tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor - if tensor is None: - tensor_key = feat_spec.index_key if not self.modify_graph else feat_spec.name - raise KeyError(f"Key `{tensor_key}` does not exist in batch_tensor_dict.") - same_table_tensor_list.append(tensor) - return same_table_tensor_list - - # 改图模式下FeatureSpec是按照lookup顺序创建的,无需对ids进行排序;fs模式下手动创建FeatureSpec,不一定有序 - if not self.modify_graph: - same_table_feature_spec = sorted(same_table_feature_spec, key=lambda x: x.name) - mock_feature_spec = FeatureSpec(f"mock_feature_spec_{table_name}", table_name=table_name) - - if get_use_static(): - tensor_list = [] - tensor_split_list = [feat_spec.split for feat_spec in same_table_feature_spec] - total_feature_count = sum(tensor_split_list) - else: - tensor_list = get_tensor_list() - tensor_split_list = [tf.math.reduce_prod(array_ops.shape(tensor)) for tensor in tensor_list] - total_feature_count = tf.add_n(tensor_split_list) - set_temporary_feature_spec_attribute(mock_feature_spec, total_feature_count) - - kwargs["multi_lookup"] = True - total_send_count = self.same_table_send_count if self.modify_graph else send_count * same_table_spec_count - lookup_result = self.lookup_for_asc_with_feature_spec_inner(mock_feature_spec, total_send_count, **kwargs) - logger.debug("multi lookup table %s via %s.", table_name, tensor_split_list) - self.split_lookup_result(same_table_feature_spec, tensor_split_list, tensor_list, lookup_result, - is_training) - # 当一表多查完成后,将此表对应的feature specs列表清空,便于estimator模式下多轮eval时不会累加上轮eval的feature specs - clear_same_table_feature_spec(self.table_name, is_training) - - if not self.modify_graph: - self.check_multi_lookup_times(is_training) - if not kwargs.get("is_grad"): - return tf.stop_gradient(self.lookup_result.get(spec_name).get(is_training), name="stop_grad_lookup_res") - return self.lookup_result.get(spec_name).get(is_training) - - def split_lookup_result(self, same_table_feature_spec: list, tensor_split_list: list, tensor_list: list, - lookup_result: tf.Tensor, is_training: bool): - """ - Splits the result of the merge sparse lookup. - - Args: - same_table_feature_spec: a list of feature specs in a same table - tensor_split_list: a list of tensor split in a same table - tensor_list: a list of tensor in a same table - lookup_result: results of the sparse lookup - is_training: indicates whether the training mode is used. - - Returns: None - - """ - lookup_result_split = tf.split(lookup_result, tensor_split_list) - if len(lookup_result_split) != len(same_table_feature_spec) or ( - not get_use_static() and len(same_table_feature_spec) != len(tensor_list)): - raise RuntimeError(f"shape not match. 
len(lookup_result_split): {len(lookup_result_split)}," - f"len(same_table_feature_spec): {len(same_table_feature_spec)}" - f"len(tensor_list): {len(tensor_list)}") - for idx, (one_feature_spec, one_result) in enumerate(zip(same_table_feature_spec, lookup_result_split)): - if one_feature_spec.name not in self.lookup_result: - self.lookup_result[one_feature_spec.name] = {} - if get_use_static(): - dest_shape = one_feature_spec.dims + [self.scalar_emb_size] - else: - dest_shape = array_ops.concat([array_ops.shape(tensor_list[idx]), [self.scalar_emb_size]], 0) - self.lookup_result[one_feature_spec.name][is_training] = array_ops.reshape(one_result, dest_shape) - - def lookup_for_asc_with_feature_spec_inner(self, feature_spec: FeatureSpec, send_count: int, **kwargs): - """ - Args: - feature_spec: an instance of FeatureSpec to lookup from hashtable - send_count: int, used to config all2all communication parameters - kwargs: - dim: not in use - is_train: - name: not in use - modify_graph: if True, the original graph will be modified before building a Session instance - - Returns: Tensor for lookup result - - """ - logger.debug(f"Enter ASC Branch, looking up with FeatureSpec.") - is_training = kwargs.get("is_train") - self.check_and_format_lookup_params(feature_spec, send_count, is_training) - rank_size = get_rank_size() - device_id = get_device_id() - use_hot = get_use_hot() - use_dynamic_expansion = get_use_dynamic_expansion() - - # check training mode order and ensure channel id - channel_id = get_training_mode_channel_id(is_training=is_training) - logger.debug("get preprocessed tensor for asc for table %s with skip emb transfer %s is_training: %s, " - "channel_id: %s .", self.table_name, self.skip_emb_transfer, is_training, channel_id) - config = dict(batch_size=feature_spec.batch_size, feat_cnt=feature_spec.feat_cnt, send_count=send_count, - rank_size=rank_size, channel_id=channel_id, table_name=self.table_name, - skip_emb_transfer=self.skip_emb_transfer, ext_emb_size=self.ext_emb_size, - emb_size=self.emb_size, use_hot=use_hot, device_id=device_id, - use_dynamic_expansion=use_dynamic_expansion) - - if self.skip_emb_transfer: - result = get_preprocessed_tensor_for_asc(self.variable, config) - else: - variable_list = [self.variable] + [slot_info.get("slot") for slot_info in self.optimizer_slot_info_list] - result = get_preprocessed_tensor_for_asc(variable_list, config) - restore_vector = result.get("restore_vector") - restore_vector_second = result.get("restore_vector_second") - hot_pos = result.get("hot_pos") - id_offsets = result.get("id_offsets") - unique_keys = result.get("unique_keys") - swap_in = result.get("swap_in") - all2all_matrix = result.get("all2all_args") - control_ops = swap_in - - id_offsets = tf.identity(id_offsets, name="identity_addr") - restore_vector = tf.identity(restore_vector, name="identity_restore") - - use_static = get_use_static() - host_pipeline_ops = get_host_pipeline_ops() - - @tf.custom_gradient - def sparse_forward(table): - logger.debug("fp rank size: %s", rank_size) - if not use_dynamic_expansion: - id_offsets_abs = tf.abs(id_offsets) - local_embeddings = tf.gather(table, id_offsets_abs, axis=0, name="gather_for_id_offsets") - local_embeddings = set_specific_value_for_non_valid_key(id_offsets, - local_embeddings, - feature_spec.access_threshold, - kwargs.get("serving_default_value"), - is_training=is_training) - else: - local_embeddings = tf.identity(table, name="identity_local_emb") - - all2all_args = send_count if use_static else all2all_matrix - 
unique_embeddings = self._get_own_emb(local_embeddings, all2all_args, self.scalar_emb_size, use_static) - - if hot_pos is not None: - unique_embeddings = tf.concat([tf.gather(unique_embeddings, hot_pos, name="hot_pos"), - unique_embeddings], axis=0) - if use_static: - unique_embeddings_shape = unique_embeddings.shape.as_list() - else: - unique_embeddings_shape = tf.shape(unique_embeddings) - - # 用于打桩的op节点,它的name用于标识此次的sparse lookup是train还是eval - # 后续在session run的时候,通过图反向查找该子图中查找到此op - # 最后通过名称判断session run是调用的哪个通道,并通知c++侧进行计数和唤醒操作 - notify_hybridmgmt_op = self.generate_lookup_id_notify_hybrid(channel_id) - with tf.control_dependencies([notify_hybridmgmt_op]): - embeddings = tf.gather(unique_embeddings, restore_vector, axis=0, name="gather_for_restore_vector") - - if use_static: - lookup_result = tf.reshape(embeddings, feature_spec.dims + [self.scalar_emb_size]) - else: - if kwargs.get("multi_lookup"): - lookup_result = tf.reshape(embeddings, [-1, self.scalar_emb_size]) - else: - feature_spec_tensor = None - if not self.modify_graph: - feature_spec_tensor = kwargs.get("batch").get(feature_spec.index_key) - modify_graph_tensor = kwargs.get("lookup_ids") - tensor = feature_spec_tensor if not self.modify_graph else modify_graph_tensor - if tensor is None: - raise KeyError(f"key or ids does not exist in batch, now modify graph is {self.modify_graph}.") - dest_shape = array_ops.concat([array_ops.shape(tensor), [self.scalar_emb_size]], 0) - lookup_result = array_ops.reshape(embeddings, dest_shape) - - def grad(lookup_diff): # pragma: no cover - logger.debug("Into lookup grad function, feature spec name: %s.", feature_spec.name) - embedding_diff = tf.reshape(lookup_diff, [-1, self.scalar_emb_size]) - unique_grads = tf.compat.v1.unsorted_segment_sum(embedding_diff, - restore_vector, - unique_embeddings_shape[0]) - bp_all2all_args = all2all_args if use_static else tf.transpose(all2all_args) - if hot_pos is not None: - hot, cold = tf.split(unique_grads, [tf.shape(hot_pos)[0], - tf.shape(unique_grads)[0] - tf.shape(hot_pos)[0]], axis=0) - unique_grads = tf.tensor_scatter_nd_add(cold, tf.expand_dims(hot_pos, 1), hot) - local_grad = self._get_own_emb(unique_grads, bp_all2all_args, self.scalar_emb_size, use_static) - if self.all2all_gradients_op == All2allGradientsOp.SUM_GRADIENTS_AND_DIV_BY_RANKSIZE: - try: - local_grad = local_grad / get_rank_size() - except ZeroDivisionError as exp: - raise ZeroDivisionError("Rank size cannot be zero.") from exp - - if use_dynamic_expansion: - if global_env.apply_gradients_strategy == \ - ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value: - update_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - restore_vector_second, - array_ops.shape(unique_keys)[0]) - else: - update_grad = local_grad - else: - if global_env.apply_gradients_strategy == \ - ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value: - unique_local_grad = tf.compat.v1.unsorted_segment_sum(local_grad, - restore_vector_second, - array_ops.shape(unique_keys)[0]) - update_grad = ops.IndexedSlices(values=unique_local_grad, - indices=unique_keys, - dense_shape=tf.shape(table)) - else: - update_grad = ops.IndexedSlices(values=local_grad, - indices=id_offsets, - dense_shape=tf.shape(table)) - return update_grad - - return lookup_result, grad - - with tf.control_dependencies(control_ops): - if not use_dynamic_expansion: - return sparse_forward(self.variable) - - local_embeddings = \ - host_pipeline_ops.embedding_lookup_by_address(id_offsets, embedding_dim=self.emb_size, - embedding_type=1) - 
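
The deleted `sparse_forward`/`grad` pair above is the core of the lookup: the forward pass gathers
deduplicated embedding rows and restores the per-sample order through `restore_vector`, while the
backward pass folds the per-sample gradients back onto the unique rows with an unsorted segment sum.
A minimal, self-contained sketch of that pairing follows; `restore_lookup` is a hypothetical name,
not part of the mxRec API:

    import tensorflow as tf

    @tf.custom_gradient
    def restore_lookup(unique_embeddings, restore_vector):
        # forward: expand the deduplicated rows back to per-sample order
        output = tf.gather(unique_embeddings, restore_vector, axis=0)

        def grad(dy):
            # backward: rows fetched several times receive the summed gradient
            num_unique = tf.shape(unique_embeddings)[0]
            d_unique = tf.math.unsorted_segment_sum(dy, restore_vector, num_unique)
            return d_unique, None  # no gradient w.r.t. the integer indices

        return output, grad
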
- is_table_name_valid = ASCEND_TABLE_NAME_MUST_CONTAIN is None or \ - ASCEND_TABLE_NAME_MUST_CONTAIN in self.table_name - - def add_to_collection(): - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB, local_embeddings) - if global_env.apply_gradients_strategy == ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY.value: - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS, unique_keys) - else: - tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET, id_offsets) - - logger.debug("feature spec mode, table_name: %s, ASCEND_TABLE_NAME_MUST_CONTAIN: %s", - self.table_name, ASCEND_TABLE_NAME_MUST_CONTAIN) - - if is_training and is_table_name_valid: - add_to_collection() - - return sparse_forward(local_embeddings) - - def _record(self): - insert_table_instance(self.table_name, self.variable, self) - logger.debug("Device vocabulary_size for table %s is %s.", self.table_name, self.device_vocabulary_size) - logger.debug("Slice_device_vocabulary_size for table %s is %s.", - self.table_name, self.slice_device_vocabulary_size) - logger.debug(f"Host vocabulary size for table %s is %s.", self.table_name, self.host_vocabulary_size) - logger.debug(f"Slice host vocabulary_size for table %s is %s.", - self.table_name, self.slice_host_vocabulary_size) - logger.debug(f"SSD vocabulary size for table %s is %s.", self.table_name, self.ssd_vocabulary_size) - logger.debug(f"Slice ssd vocabulary_size for table %s is %s.", self.table_name, self.slice_ssd_vocabulary_size) - - def _initialize_variables(self): - initialized_tensor = \ - self.emb_initializer(self.slice_device_vocabulary_size + self.embedding_size) * self.init_param - - self.variable = tf.compat.v1.get_variable(self.table_name, trainable=False, initializer=initialized_tensor) - # make sure sparse table variable will not be saved and restored within tf checkpoint. 
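
The comment above marks a pattern worth spelling out: the sparse table is created as a non-trainable
variable (the custom sparse optimizers update it directly), and its name is recorded so the patched
saver skips it. A hedged sketch under those assumptions; `REMOVED_FROM_CKPT` and `create_sparse_table`
are illustrative stand-ins, not mxRec names:

    import tensorflow as tf

    REMOVED_FROM_CKPT = set()

    def create_sparse_table(table_name, vocab_rows, emb_dim, init_param=1.0):
        initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05 * init_param)
        variable = tf.compat.v1.get_variable(
            table_name, shape=[vocab_rows, emb_dim],
            trainable=False,  # updated by the custom sparse optimizers, not by gradients
            initializer=initializer)
        REMOVED_FROM_CKPT.add(variable.name)  # the saver patch filters these names out
        return variable
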
-        insert_removing_var_list(self.variable.name)
-        self._record()
-
-        if self.use_dynamic_expansion:
-            for sparse_optimizer_instance in self._optimizer_instance_list:
-                self._slot_num[self.table_name] = sparse_optimizer_instance.slot_num
-                logger.info("init emb, table name: %s, slot_num: %s",
-                            self.table_name, sparse_optimizer_instance.slot_num)
-
-        if not self.skip_emb_transfer:
-            # build optimizer states
-            for sparse_optimizer_instance in self._optimizer_instance_list:
-                slot_info_list = sparse_optimizer_instance.initialize_slots(self.variable, self)
-                self.optimizer_slot_info_list.extend(slot_info_list)
-
-            for slot_info in self.optimizer_slot_info_list:
-                self.set_optimizer_slot(slot_info)
+    # dynamic expansion
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
+        return HBMDynamicSparseEmbeddingFactory().create_embedding(config)
+    # DDR or SSD
+    if host_vocabulary_size > 0:
+        return ExternalStorageSparseEmbeddingFactory().create_embedding(config)
+    # HBM
+    return HBMSparseEmbeddingFactory().create_embedding(config)


 @para_checker_decorator(check_option_list=[
-    ("hashtable", ClassValidator, {"classes": (SparseEmbedding, )}),
+    ("hashtable", ClassValidator, {"classes": (BaseSparseEmbedding, )}),
     ("ids", ClassValidator, {"classes": (FeatureSpec, tf.Tensor)}),
     ("is_train", ClassValidator, {"classes": (bool, )}),
     ("send_count", ClassValidator, {"classes": (int, type(None))}),
@@ -810,7 +121,7 @@ class SparseEmbedding:
     ("is_grad", ClassValidator, {"classes": (bool, )}),
     ("serving_default_value", ClassValidator, {"classes": (tf.Tensor, type(None))})
 ])
-def sparse_lookup(hashtable: SparseEmbedding,
+def sparse_lookup(hashtable: BaseSparseEmbedding,
                   ids: Union[FeatureSpec, tf.Tensor],
                   send_count: Optional[int] = None,
                   is_train: bool = True,
@@ -847,85 +158,29 @@ def sparse_lookup(hashtable: SparseEmbedding,
     kwargs["modify_graph"] = modify_graph
     kwargs["batch"] = batch
     kwargs["access_and_evict_config"] = access_and_evict_config
+    kwargs["serving_default_value"] = serving_default_value
     # the following parameters are created internally; external inputs are ignored and overridden
     kwargs["feature_spec_name_ids_dict"] = None
     kwargs["multi_lookup"] = False
     kwargs["lookup_ids"] = None
-    kwargs["serving_default_value"] = serving_default_value
-
-    scope_name = "{0}//{1}".format(hashtable.table_name, kwargs.get("name"))

     logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.",
                 hashtable.table_name, name, is_grad)
     # orphan ids with no IteratorGetNext upstream must be tagged so that ACGPushOpsToDataset can find them later
     if isinstance(ids, tf.Tensor):
-        ids = _tag_orphan_ids(ids)
+        ids = tag_orphan_ids(ids)

-    with tf.compat.v1.variable_scope(scope_name):
+    with tf.compat.v1.variable_scope("{0}//{1}".format(hashtable.table_name, kwargs.get("name"))):
         if isinstance(ids, FeatureSpec):
             # check whether the name of the table exists with FeatureSpec.
if hashtable.table_name != ids.table_name: raise ValueError(f"The table name '{ids.table_name}' specified by FeatureSpec is inconsistent with" f" the SparseEmbedding table name '{hashtable.table_name}'.") - return hashtable.lookup_for_asc_with_feature_spec(ids, send_count, **kwargs) + return hashtable.lookup_for_feat_spec(ids, send_count, **kwargs) if not modify_graph: raise ValueError("'ids' is type of tf.Tensor, 'modify_graph' should be set to True") - set_modify_graph(modify_graph) - return hashtable.lookup_for_asc(ids, send_count, **kwargs) - - -def set_specific_value_for_non_valid_key(id_offsets: Optional[tf.Tensor], - embeddings: Optional[tf.Tensor], - access_threshold: Optional[int], - serving_default_value: Optional[tf.Tensor] = None, - is_training: bool = True): - """ - 将key为-1(无效值)的特征对应的emb置为0或者指定值 - :param id_offsets: 特征索引 - :param embeddings: 稀疏表 - :param access_threshold: 准入阈值 - :param serving_default_value: 参考create_table接口描述 - :param is_training: 当前流程是训练还是推理 - :return: - """ - # 在训练时,仅当开启准入功能才会出现无效值;推理时,是否开启准入都可能存在无效值 - if is_training and (access_threshold is None or access_threshold < 0): - return embeddings - - if serving_default_value is None: - # 未设置时,默认无效值的emb为全0 - default_value = tf.zeros_like(embeddings) - else: - try: - default_value = tf.broadcast_to(serving_default_value, tf.shape(embeddings)) - except ValueError as e: - logger.error("failed to broadcast serving_default_value to target embedding , please check its shape.") - raise e - except Exception as e: - logger.error("failed to process serving_default_value.") - raise e - - if tf.__version__.startswith("1"): - id_offsets_expand = tf.math.greater_equal(id_offsets, 0) - embeddings = tf.where(id_offsets_expand, embeddings, default_value) - return embeddings - - id_offsets_expand = tf.compat.v1.expand_dims(id_offsets >= 0, axis=-1) - embeddings = tf.where(id_offsets_expand, embeddings, default_value) - return embeddings - - -def _tag_orphan_ids(ids: tf.Tensor) -> tf.Tensor: - """ - 将孤儿ids使用identity操作创建ACG_PUSH_NODE前缀命名的标记节点,以便在PushOps时能找到。 - """ - graph_def = tf.compat.v1.get_default_graph().as_graph_def() - subgraph = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [ids.op.name]) - for node in subgraph.node: - if node.name == 'IteratorGetNext': - return ids - new_ids = tf.identity(ids, name=f"ACG_PUSH_NODE_{ids.op.name}") - logger.info('Tag orphan op node: %s with %s.', ids, new_ids) - return new_ids + ConfigInitializer.get_instance().modify_graph = modify_graph + return hashtable.lookup(ids, send_count, **kwargs) diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py index 6c19d41d..2b663ce2 100644 --- a/mx_rec/core/feature_process.py +++ b/mx_rec/core/feature_process.py @@ -21,9 +21,9 @@ import tensorflow as tf from mx_rec.util.tf_version_adapter import npu_ops from mx_rec.constants.constants import DEFAULT_EVICT_TIME_INTERVAL, TRAIN_CHANNEL_ID, MAX_INT32 -from mx_rec.util.initialize import trigger_evict, get_table_instance_by_name, export_feature_spec -from mx_rec.validator.validator import para_checker_decorator, ClassValidator, IntValidator, OptionalIntValidator +from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger +from mx_rec.validator.validator import para_checker_decorator, ClassValidator, IntValidator, OptionalIntValidator class EvictHook(tf.compat.v1.train.SessionRunHook): @@ -96,15 +96,16 @@ class EvictHook(tf.compat.v1.train.SessionRunHook): if cur_time - self._start_time > self._evict_time_interval or \ (self._evict_step_interval is not 
None and self._global_step % self._evict_step_interval == 0):
             logger.info("_EvictHook - > evict switch on!!! after_run step: %d", self._global_step)
-            if not trigger_evict():
+            if not ConfigInitializer.get_instance().hybrid_manager_config.trigger_evict():
                 return
             self._start_time = cur_time
             for name in self._hash_table_instance.keys():
                 run_context.session.run(self._evict_op.get(name))

     def check_name_and_get_hashtable(self):
-        for _, feature_spec in export_feature_spec().items():
+        for _, feature_spec in ConfigInitializer.get_instance().feature_spec_config.feature_spec_dict.items():
             if feature_spec.eviction_threshold:
                 logger.debug("_EvictHook - > check and get instance: table_names %s", feature_spec.table_name)
-                self._hash_table_instance[feature_spec.table_name] = get_table_instance_by_name(feature_spec.table_name)
-
+                self._hash_table_instance[feature_spec.table_name] = \
+                    ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(
+                        feature_spec.table_name)
diff --git a/mx_rec/graph/acg_push_ops.py b/mx_rec/graph/acg_push_ops.py
index 06d89a45..625ef92f 100644
--- a/mx_rec/graph/acg_push_ops.py
+++ b/mx_rec/graph/acg_push_ops.py
@@ -15,17 +15,11 @@
 # limitations under the License.
 # ==============================================================================

-import os
-import weakref
-from dataclasses import dataclass
-from typing import Dict, Tuple, FrozenSet, List, Set
+from typing import Dict, Tuple, List, Set

 import tensorflow as tf
-from tensorflow.python.framework import ops
-from tensorflow.python.data.ops.dataset_ops import _VariantTracker
 from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
 from tensorflow.python.framework.ops import Operation
-from tensorflow.python.data.ops.dataset_ops import DatasetV2
 from tensorflow.python.util import nest as tf_nest
 from tensorflow.core.framework import node_def_pb2
 from tensorflow.core.framework import attr_value_pb2
@@ -34,7 +28,8 @@ from tensorflow.python.framework import tensor_util
 from mx_rec.graph import modifier
 from mx_rec.util.log import logger
 from mx_rec.graph.utils import export_pb_graph
-from mx_rec.constants.constants import ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE
+from mx_rec.graph.graph_typing import SubgraphInfo
+from mx_rec.constants.constants import ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, AnchorIteratorOp
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator

 tf.compat.v1.disable_eager_execution()
@@ -44,7 +39,7 @@
 _ACG_NEW_ITERATOR = "ACG_NEW_ITERATOR"
 _ACG_NEW_INITIALIZER = "ACG_NEW_INITIALIZER"

 _OP_TYPE_TO_PUSH = frozenset(["StringSplit", "StringToNumber"])
-_OP_TYPE_TO_IGNORE = frozenset(["IteratorGetNext"])
+_OP_TYPE_TO_IGNORE = frozenset([AnchorIteratorOp.ITERATOR_GET_NEXT.value])
 _OP_TYPE_CONTAIN_STRING_TO_IGNORE = frozenset(["Dataset", "Summary"])
 _OP_NAME_CONTAIN_STRING_TO_IGNORE = frozenset(["save", "report_", "loss"])
 _OP_NAME_CONTAIN_STRING_TO_PUSH = frozenset(["ACG_PUSH_NODE"])
@@ -55,13 +50,6 @@
 _VARIABLE_TYPES = frozenset(["Variable", "VariableV2", "VarHandleOp"])
 _IGNORE_REPLACE_NODE = frozenset(["Assign", "SaveV2"])

-@dataclass
-class SubgraphInfo:
-    subgraph_in: Dict[tf.Operation, Set[tf.Operation]]
-    subgraph_out: Dict[tf.Operation, Set[tf.Operation]]
-    subgraph_to_push: Set[tf.Operation]
-
-
 class ACGPushOpsToDatasetHook(tf.estimator.SessionRunHook):
     @para_checker_decorator(
         check_option_list=[
@@ -132,7 +120,9 @@ def _find_ops_to_be_pushed(graph: tf.Graph, dump_graph: bool = False):
         return
     logger.info("Found 
operations should be pushed: %s.", nodes_to_push)

-    subgraph_nodes = _find_subgraph_nodes(graph, nodes_to_push, tgt_op_type="IteratorGetNext", exclude_tgt_op=True)
+    subgraph_nodes = _find_subgraph_nodes(
+        graph, nodes_to_push, tgt_op_type=AnchorIteratorOp.ITERATOR_GET_NEXT.value, exclude_tgt_op=True
+    )
     _push_subgraph_to_dataset(graph, subgraph_nodes, dump_graph)
     export_pb_graph("after_push_graph.pbtxt", dump_graph, graph_def=graph.as_graph_def())
@@ -199,7 +189,7 @@ def _find_op_from_base_op(base_ops: tf.Operation, target_op_type: str) -> tf.Ope

 def _get_dataset_op(graph: tf.Graph, get_next_op: Operation) -> Operation:
-    if get_next_op.type != "IteratorGetNext":
+    if get_next_op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value:
         raise TypeError(f"Op '{get_next_op}' must be an instance of IteratorGetNext.")
     # looking for the MakeIterator operator which corresponds to given batch_tensor
     base_op = modifier.find_make_iterator_op(get_next_op.outputs[0])
@@ -255,7 +245,7 @@ def _push_subgraph_to_dataset(graph: tf.Graph, subgraph_to_push: Set[tf.Operatio
     logger.info("Got input tensor of extracted subgraph: %s", subgraph_in)
     logger.info("Got output tensor of extracted subgraph: %s", subgraph_out)

-    get_next_node = graph.get_operation_by_name("IteratorGetNext")
+    get_next_node = graph.get_operation_by_name(AnchorIteratorOp.ITERATOR_GET_NEXT.value)
     src_dataset = _get_src_dataset(graph, get_next_node)

     def acg_func(*x):  # pragma: no cover
@@ -392,11 +382,11 @@ def _clone_subgraph_into_funcgraph(
 def _get_mapping_for_subgraph_in(
     from_node: tf.Operation, to_nodes: Set[tf.Operation], x: List[tf.Tensor], tensor_mapping
 ):
-    if from_node.type != "IteratorGetNext":
+    if from_node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value:
         raise RuntimeError(f"Expect IteratorGetNext for input tensor of subgraph, but got {from_node}")
     for node in to_nodes:
         for each_tensor in node.inputs:
-            if each_tensor.op.type != "IteratorGetNext":
+            if each_tensor.op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value:
                 continue
             old_tensor_name = each_tensor.name
             x_index = int(old_tensor_name.split(":")[-1])
@@ -493,7 +483,7 @@ def _topo_subgraph(subgraph: Set[tf.Operation]) -> List[tf.Operation]:
         output_set.add(curr_node)
         for tensor in curr_inputs:
             node = tensor.op
-            if node.type != "IteratorGetNext" and node not in output_set:
+            if node.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value and node not in output_set:
                 topo_subgraph_dfs(node, output_list, output_set)
         output_list.append(curr_node)

@@ -518,13 +508,13 @@ def _update_iterator_getnext(
     iterator_type = get_next_op.inputs[0].op.type
     if iterator_type == "IteratorV2":
         iterator_type = modifier.find_make_iterator_op(get_next_op.outputs[0]).type
-    if iterator_type not in ("MakeIterator", "OneShotIterator"):
+    if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value):
         raise RuntimeError(
             f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, "
             f"but the current iterator is `{iterator_type}`."
) logger.info("The iterator type of dataset is %s.", iterator_type) - if iterator_type == "MakeIterator": + if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: new_iterator = tgt_dataset.make_initializable_iterator() logger.info("Got new_iterator: %s, new_iterator.initializer: %s.", new_iterator, new_iterator.initializer) graph.add_to_collection(_ACG_NEW_INITIALIZER, new_iterator.initializer) @@ -550,7 +540,7 @@ def _update_iterator_getnext( ) except IndexError as err: raise IndexError("Cannot find a tensor from given batch.") from err - new_get_next_op = _find_op_from_base_op(new_batch_tensor.op, "IteratorGetNext") + new_get_next_op = _find_op_from_base_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value) logger.info("Got new_get_next_op: %s.", new_get_next_op) _replace_get_next_op(graph, get_next_op, new_get_next_op, subgraph_out, subgraph_to_push) diff --git a/mx_rec/graph/graph_typing.py b/mx_rec/graph/graph_typing.py new file mode 100644 index 00000000..c11bd4c0 --- /dev/null +++ b/mx_rec/graph/graph_typing.py @@ -0,0 +1,35 @@ +# !/usr/bin/env python3 +# -- coding: utf-8 -- +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + +import dataclasses +from typing import Dict, DefaultDict, List, Tuple, Set + +from tensorflow import Operation, Tensor +from tensorflow.core.framework.graph_pb2 import GraphDef + + +# DefaultDict: +# Key: Tensor => Represent output tensor of `IteratorGetNext` operation. +# Val: List[Tuple[int, Operation]] => Contains target operation of output tensor and it's corresponding index. +ReplacementSpec = DefaultDict[Tensor, List[Tuple[int, Operation]]] + + +@dataclasses.dataclass +class AnchorRecord: + replacement_spec: ReplacementSpec + passing_tensors: List[Tensor] + batch_tensor_indexs: List[int] + sub_cutting_points: List[Tensor] + sub_graph_def: GraphDef + input_names: List[str] + output_names: List[str] + is_training: bool + input_indexs: List[int] = None + + +@dataclasses.dataclass +class SubgraphInfo: + subgraph_in: Dict[Operation, Set[Operation]] + subgraph_out: Dict[Operation, Set[Operation]] + subgraph_to_push: Set[Operation] diff --git a/mx_rec/graph/merge_lookup.py b/mx_rec/graph/merge_lookup.py index 4e54b2a2..8a11e515 100644 --- a/mx_rec/graph/merge_lookup.py +++ b/mx_rec/graph/merge_lookup.py @@ -18,9 +18,9 @@ import tensorflow as tf from mx_rec.constants.constants import ASCAnchorAttr, ASCEND_SPARSE_LOOKUP_ENTRANCE -from mx_rec.core.embedding import SparseEmbedding from mx_rec.graph.utils import check_cutting_points, replace_anchor_vec -from mx_rec.util.initialize import get_modify_graph, get_merged_multi_lookup, insert_merged_multi_lookup, get_use_static +from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding +from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.log import logger @@ -39,10 +39,10 @@ def do_merge_lookup(is_train: bool = True): """ - if not get_modify_graph(): + if not ConfigInitializer.get_instance().modify_graph: logger.debug("The `do_merge_multi_lookup` function is called only for `modify graph` mode.") return - if get_merged_multi_lookup(is_train): + if ConfigInitializer.get_instance().train_params_config.get_merged_multi_lookup(is_train): logger.debug("The merge multi lookup has been executed once and does not need to be executed again.") return logger.info("start to merge multi lookup, mode(train: True, eval: False): %s.", is_train) @@ -57,15 +57,15 @@ def do_merge_lookup(is_train: bool = True): sub_cutting_points_dict = dict() 
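
The loop that follows groups the lookup anchors by training mode before the merged lookup is issued;
condensed, the bookkeeping looks like the sketch below, where `get_attr` is a hypothetical stand-in
for `BaseSparseEmbedding.get_anchor_attribute`:

    from collections import defaultdict

    def group_cutting_points(cutting_points, is_train, get_attr):
        groups = defaultdict(list)  # is_training flag -> cutting points of that mode
        name_to_ids = {}            # feature spec name -> its anchor ids tensor
        for point in cutting_points:
            if get_attr(point, "is_training") != is_train:
                continue  # only merge lookups that belong to the requested mode
            name_to_ids[get_attr(point, "feature_spec").name] = point
            groups[is_train].append(point)
        return groups, name_to_ids
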
feature_spec_name_ids_dict = dict() for cutting_point in cutting_point_list: - is_training = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING) + is_training = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_TRAINING) if is_training != is_train: logger.debug("Skip! The current mode(train: True, eval: False) is %s, but the mode of %s is %s.", is_train, cutting_point, is_training) continue - table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) - if not get_use_static() and len(table_instance.lookup_name_dict.get(is_train)) > 1: - feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) + table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) + if not ConfigInitializer.get_instance().use_static and table_instance.multi_lookup_times.get(is_train) > 1: + feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) feature_spec_name_ids_dict[feature_spec.name] = cutting_point if sub_cutting_points_dict.get(is_training) is None: sub_cutting_points_dict[is_training] = [] @@ -78,22 +78,22 @@ def do_merge_lookup(is_train: bool = True): f"have anchor ids.") for cutting_point in sub_cutting_point_list: - table_instance = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) - feature_spec = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) - is_grad = SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_GRAD) - if len(table_instance.lookup_name_dict.get(is_train)) == 1: + table_instance = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.TABLE_INSTANCE) + feature_spec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC) + is_grad = BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.IS_GRAD) + if table_instance.multi_lookup_times.get(is_train) == 1: logger.debug("The origin lookup result of %s for %s does not need to be replaced.", feature_spec.name, table_instance.table_name) continue send_count = table_instance.send_count kwargs = dict(is_train=is_train, lookup_ids=cutting_point, multi_lookup=True, is_grad=is_grad) - if not get_use_static(): + if not ConfigInitializer.get_instance().use_static: kwargs["feature_spec_name_ids_dict"] = feature_spec_name_ids_dict - lookup_result = table_instance.lookup_for_asc_with_feature_spec(feature_spec, send_count, **kwargs) + lookup_result = table_instance.lookup_for_feat_spec(feature_spec, send_count, **kwargs) replace_anchor_vec(cutting_point, ASCAnchorAttr.MOCK_LOOKUP_RESULT, lookup_result) logger.debug("The mock lookup result of %s for %s was replaced.", feature_spec.name, table_instance.table_name) # records whether the current mode has been merged or restored lookup - insert_merged_multi_lookup(is_train, True) + ConfigInitializer.get_instance().train_params_config.insert_merged_multi_lookup(is_train, True) logger.info("finish to merge multi lookup, mode(train: True, eval: False): %s.", is_train) diff --git a/mx_rec/graph/modifier.py b/mx_rec/graph/modifier.py index e41723a1..b936f7e5 100644 --- a/mx_rec/graph/modifier.py +++ b/mx_rec/graph/modifier.py @@ -16,41 +16,39 @@ # ============================================================================== from collections import defaultdict -from typing import Any, List, Dict, DefaultDict, Tuple, Union from collections.abc import Callable +from typing 
import Any, List, Dict, Tuple

 import tensorflow as tf
-from tensorflow import Tensor
+from tensorflow import Operation, Tensor
+from tensorflow.core.framework.graph_pb2 import GraphDef
 from tensorflow.python.data.ops.dataset_ops import DatasetV1Adapter
-from tensorflow.python.framework.ops import Operation
 from tensorflow.python.framework.errors_impl import InvalidArgumentError
-from tensorflow.core.framework.graph_pb2 import GraphDef

-from mx_rec.core.asc.helper import get_asc_insert_func
+from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
+    ASCAnchorAttr, ASCEND_TIMESTAMP, MAX_WHILE_SIZE, LIBREC_EOS_OPS_SO, AnchorDatasetOp, \
+    AnchorIteratorOp
 from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.core.asc.helper import get_asc_insert_func
 from mx_rec.core.asc.manager import start_asc_pipeline
-from mx_rec.core.embedding import SparseEmbedding
-from mx_rec.constants.constants import ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, \
-    ASCAnchorAttr, ASCEND_TIMESTAMP, ANCHOR_DATASET_NAME, MAX_WHILE_SIZE, LIBREC_EOS_OPS_SO
-from mx_rec.util.initialize import get_feature_spec, insert_feature_spec, set_initializer, \
-    get_training_mode_channel_id, set_is_graph_modify_hook_running, get_bool_gauge_set, \
-    insert_merged_multi_lookup, get_merged_multi_lookup, set_target_batch, get_iterator_type, \
-    set_iterator_type
+from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding
+from mx_rec.graph.merge_lookup import do_merge_lookup
+from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \
+    export_pb_graph, make_sorted_key_to_tensor_list
+from mx_rec.graph.graph_typing import AnchorRecord, ReplacementSpec
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.log import logger
 from mx_rec.util.ops import import_host_pipeline_ops
 from mx_rec.util.perf import performance
-from mx_rec.graph.utils import check_input_list, find_parent_op, check_cutting_points, record_ops_to_replace, \
-    export_pb_graph, make_sorted_key_to_tensor_list, ReplacementSpec, AnchorRecord
-from mx_rec.graph.merge_lookup import do_merge_lookup
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator
-from mx_rec.util.log import logger


 def get_preprocessing_map_func(
-    graph_def: GraphDef,
-    input_names: List[str],
-    output_names: List[str],
-    batch_tensor_names: List[str] = None,
-    pipeline_input_indexes: List[int] = None
+        graph_def: GraphDef,
+        input_names: List[str],
+        output_names: List[str],
+        batch_tensor_names: List[str] = None,
+        pipeline_input_indexes: List[int] = None
 ) -> Callable:
     input_names = check_input_list(input_names, str)
     output_names = check_input_list(output_names, str)
@@ -62,54 +60,7 @@ def get_preprocessing_map_func(
         raise ValueError("Exactly one of the parameters 'batch_tensor_names' and "
                          "'pipeline_input_indexes' must be given.")

-    def map_func(*args):  # pragma: no cover
-        def parse_batch(data_args: Any, data_batch: dict, key: str = None):
-            """
-            解析原始数据集中的batch,并将非dict格式的batch转为dict格式.
-            Args:
-                data_args: 待解析的batch
-                data_batch: 解析后的batch
-                key: batch中的key
-
-            Returns: None
-
-            """
-
-            def parse_tensor(data_tensor: Tensor, data_batch: dict, key: str = None):
-                """
-                将待解析batch中的tensor写入解析后的batch中,如果key存在则使用原key,不存在则生成batch中字典序最小的key.
-                Args:
-                    data_tensor: 待解析batch中的tensor
-                    data_batch: 解析后的batch
-                    key: batch中的key
-
-                Returns: None
-
-                """
-
-                if key is not None:
-                    data_batch[key] = data_tensor
-                    return
-
-                last_key = f"{sorted(data_batch)[-1]}_last_key"
-                data_batch[last_key] = data_tensor
-
-            # 开始解析old batch
-            if isinstance(data_args, dict):
-                for key, data_tensor in data_args.items():
-                    parse_batch(data_tensor, data_batch, key)
-                return
-            elif isinstance(data_args, (list, tuple)):
-                for data_arg in data_args:
-                    parse_batch(data_arg, data_batch, key)
-                return
-            elif isinstance(data_args, Tensor):
-                # 将old batch中的tensor加入到dict中
-                parse_tensor(data_args, data_batch, key)
-                return
-            else:
-                raise ValueError("Encounter a invalid batch.")
-
+    def map_func(*args):
         logger.debug("In get_preprocessing_map_func, the old batch is: %s.", args)
         batch = dict()
         parse_batch(args, batch, key=None)
@@ -141,12 +92,60 @@ def get_preprocessing_map_func(
     return map_func


+def parse_batch(data_args: Any, data_batch: dict, key: str = None):
+    """
+    Parse a batch from the original dataset and convert non-dict batches into dict format.
+    Args:
+        data_args: the batch to be parsed
+        data_batch: the parsed batch
+        key: the key in the batch
+
+    Returns: None
+
+    """
+
+    def parse_tensor(data_tensor: Tensor, data_batch: dict, key: str = None):
+        """
+        Write a tensor from the batch being parsed into the parsed batch; reuse the original key if it
+        exists, otherwise derive a new key from the lexicographically largest key already in the batch.
+        Args:
+            data_tensor: a tensor from the batch being parsed
+            data_batch: the parsed batch
+            key: the key in the batch
+
+        Returns: None
+
+        """
+
+        if key is not None:
+            data_batch[key] = data_tensor
+            return
+
+        last_key = f"{sorted(data_batch)[-1]}_last_key"
+        data_batch[last_key] = data_tensor

+    # start parsing the old batch
+    if isinstance(data_args, dict):
+        for key, data_tensor in data_args.items():
+            parse_batch(data_tensor, data_batch, key)
+        return
+    if isinstance(data_args, (list, tuple)):
+        for data_arg in data_args:
+            parse_batch(data_arg, data_batch, key)
+        return
+    if isinstance(data_args, Tensor):
+        # add the tensor from the old batch into the dict
+        parse_tensor(data_args, data_batch, key)
+        return
+
+    raise ValueError(f"Invalid batch type, expected: (dict, list, tuple, Tensor), got: {type(data_args)}.")


 def get_input_index_list(
-    cutting_point_list: List[Tensor],
-    replacement_specs: ReplacementSpec,
-    mapping_name_list: List[str],
-    base_count: int,
-    timestamp_index: int = None
+        cutting_point_list: List[Tensor],
+        replacement_specs: ReplacementSpec,
+        mapping_name_list: List[str],
+        base_count: int,
+        timestamp_index: int = None
 ) -> List[int]:
     input_index_list = []
     for cutting_point in cutting_point_list:
@@ -171,7 +170,7 @@ def find_make_iterator_op(batch_tensor: Tensor) -> Operation:
     for each_op in operations:
         for input_tensor in batch_tensor.op.inputs:
             if input_tensor.op.outputs and input_tensor.op.outputs[0] in list(
-                    each_op.inputs) and each_op.type == "MakeIterator":
+                    each_op.inputs) and each_op.type == AnchorIteratorOp.MAKE_ITERATOR.value:
                 logger.debug("Op MakeIterator '%s' was found.", each_op.name)
                 return each_op

@@ -204,7 +203,8 @@ def find_target_dataset_op(base_ops: Operation, op_type: str) -> Operation:

 def get_dataset_op(get_next_op: Operation) -> Operation:
     """
-    根据`IteratorGetNext`算子从图中找到`OptimizeDataset`的dataset op. 注: TF2没有`OptimizeDataset`,则找的是dataset的默认锚点.
+    Find the `OptimizeDataset` dataset op in the graph from the `IteratorGetNext` op.
+    Note: TF2 has no `OptimizeDataset`, so the dataset's default anchor op is located instead.
    Args:
        get_next_op: the `IteratorGetNext` op

    Returns: dataset op

    """

-    if get_next_op.type != "IteratorGetNext":
+    if get_next_op.type != AnchorIteratorOp.ITERATOR_GET_NEXT.value:
         raise TypeError(f"Op '{get_next_op}' must be an instance of IteratorGetNext.")
     # looking for the MakeIterator operator which corresponds to given batch_tensor
     base_op = find_make_iterator_op(get_next_op.outputs[0])
     # looking for the op which is the one before OptimizeDataset operator
     if tf.__version__.startswith("1"):
-        optimize_dataset_op = find_target_dataset_op(base_op, "ModelDataset")
+        optimize_dataset_op = find_target_dataset_op(base_op, AnchorDatasetOp.MODEL_DATASET.value)
         target_op = find_parent_op(optimize_dataset_op)
         if not target_op:
             raise RuntimeError(f"The parent op for 'ModelDataset' op was not found.")
-        if target_op[0].type != "OptimizeDataset":
+        if target_op[0].type != AnchorDatasetOp.OPTIMIZE_DATASET.value:
             raise TypeError(f"Op OptimizeDataset was not found.")
         target_op = target_op[0]
     else:
         # 'OptimizeDataset' is not available in TensorFlow2.X
-        target_op = find_target_dataset_op(base_op, ANCHOR_DATASET_NAME)
+        target_op = find_target_dataset_op(base_op, AnchorDatasetOp.PREFETCH_DATASET.value)
     return target_op


 def get_passing_tensor_list(
-    src_tensors: List[Tensor],
-    target_op: Operation
+        src_tensors: List[Tensor],
+        target_op: Operation
 ) -> Tuple[List[Tensor], List[int], List[Tensor]]:
     def get_passing_tensors(src_tensor):
         passing_tensors = []
@@ -291,8 +291,8 @@ def find_target_instance_dataset(variant_tensor: Tensor) -> DatasetV1Adapter:

 def get_sub_graph(
-    input_tensors: List[Tensor],
-    output_tensors: List[Tensor]
+        input_tensors: List[Tensor],
+        output_tensors: List[Tensor]
 ) -> Tuple[GraphDef, List[str], List[str]]:
     input_tensors = check_input_list(input_tensors, tf.Tensor)
     output_tensors = check_input_list(output_tensors, tf.Tensor)
@@ -377,29 +377,33 @@ def modify_graph_and_start_emb_cache(dump_graph: bool = False):

 def generate_get_next_op_specs(
-    cutting_point_list: List[Tensor],
-    dump_graph: bool = False
-) -> Dict[Tensor, ReplacementSpec]:
+        cutting_point_list: List[Tensor],
+        dump_graph: bool = False
+) -> Dict[Tensor, AnchorRecord]:
     get_next_op_map = defaultdict(dict)
+
     for input_tensor in cutting_point_list:
-        get_next_op = find_target_dataset_op(input_tensor.op, "IteratorGetNext")
+        get_next_op = find_target_dataset_op(input_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value)
         if get_next_op not in get_next_op_map:
             logger.debug("find a new get_next_op named '%s'", get_next_op.name)
+
             replacement_specs = record_ops_to_replace(get_next_op)
-            get_next_op_map[get_next_op]["replacement_specs"] = replacement_specs
-            passing_tensor_list, batch_tensor_index_list, sub_cutting_point_list = \
+            passing_tensors, batch_tensor_indexs, sub_cutting_points = \
                 get_passing_tensor_list(cutting_point_list, get_next_op)
-            get_next_op_map[get_next_op]["passing_tensor_list"] = passing_tensor_list
-            get_next_op_map[get_next_op]["batch_tensor_index_list"] = batch_tensor_index_list
-            get_next_op_map[get_next_op]["sub_cutting_point_list"] = sub_cutting_point_list
-
-            sub_graph_def, input_name_list, output_name_list = get_sub_graph(passing_tensor_list,
-                                                                             sub_cutting_point_list)
-            get_next_op_map[get_next_op]["sub_graph_def"] = sub_graph_def
-            get_next_op_map[get_next_op]["input_name_list"] = input_name_list
-            get_next_op_map[get_next_op]["output_name_list"] = output_name_list
-            get_next_op_map[get_next_op]["is_training"] = \
-                SparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING)
+            sub_graph_def, input_names, output_names = get_sub_graph(passing_tensors, sub_cutting_points)
+            is_training = BaseSparseEmbedding.get_anchor_attribute(input_tensor, ASCAnchorAttr.IS_TRAINING)
+
+            record = AnchorRecord(
+                replacement_specs,
+                passing_tensors,
+                batch_tensor_indexs,
+                sub_cutting_points,
+                sub_graph_def,
+                input_names,
+                output_names,
+                is_training
+            )
+            get_next_op_map[get_next_op] = record

             export_pb_graph(f"cut_graph_{get_next_op.name}.pb", dump_graph, graph_def=sub_graph_def)

@@ -423,7 +427,7 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt
     except (ValueError, TypeError, RuntimeError) as err:
         logger.warning("The dataset op was not found, the error is `%s`. Start to traverse the operations.", err)
         graph = tf.compat.v1.get_default_graph()
-        dataset_op_list = [op for op in graph.get_operations() if ANCHOR_DATASET_NAME in op.name]
+        dataset_op_list = [op for op in graph.get_operations() if AnchorDatasetOp.PREFETCH_DATASET.value in op.name]
         logger.debug("In get_src_dataset function, current mode(train: True, eval: False): %s, dataset_op_list: %s.",
                      is_training, dataset_op_list)
@@ -436,7 +440,7 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt
             prefetch_dataset_op_list = sorted(dataset_op_list, key=lambda op: op.name)
             target_op = prefetch_dataset_op_list[1]
         else:
-            raise RuntimeError(f"The `{ANCHOR_DATASET_NAME}` was not found from the operations, dataset_op_list: "
+            raise RuntimeError(f"`{AnchorDatasetOp.PREFETCH_DATASET.value}` not found, got dataset_op_list: "
                                f"{dataset_op_list}.") from err
     except Exception as err:
         raise RuntimeError(f"The dataset was not found, the error is `{err}`.") from err
@@ -449,11 +453,11 @@ def get_src_dataset(get_next_op: Operation, is_training: bool) -> DatasetV1Adapt

 def get_tgt_dataset(
-    src_dataset: DatasetV1Adapter,
-    sub_cutting_point_list: List[Tensor],
-    records: AnchorRecord,
-    dump_graph: bool = False,
-    prefetch: int = 10
+        src_dataset: DatasetV1Adapter,
+        sub_cutting_point_list: List[Tensor],
+        record: AnchorRecord,
+        dump_graph: bool = False,
+        prefetch: int = 10
 ) -> DatasetV1Adapter:
     """
     Generate a new dataset instance from the source dataset.
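
The hunk below rebuilds the input pipeline in three map stages; stripped of the mxRec plumbing, its
shape is roughly the following sketch. `preprocess_fn` and `asc_insert_fn` stand in for the closures
returned by get_preprocessing_map_func and get_asc_insert_func, and the real code first calls a
patched `eos_map` to signal end-of-sequence, omitted here:

    import tensorflow as tf

    def build_target_dataset(src_dataset, preprocess_fn, asc_insert_fn, prefetch=10):
        tgt = src_dataset.map(preprocess_fn)  # replay the extracted preprocessing subgraph
        tgt = tgt.map(asc_insert_fn)          # append the ASC pipeline inputs to each batch
        return tgt.prefetch(prefetch)         # keep the host pipeline ahead of the device
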
@@ -470,24 +474,24 @@ def get_tgt_dataset(
     """
     librec = import_host_pipeline_ops(LIBREC_EOS_OPS_SO)
-    channel_id = get_training_mode_channel_id(records.get("is_training"))
+    channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(
+        record.is_training)
     # when the data is exhausted, EosDataset sends end_of_sequence to the ACL data channel
     src_dataset = src_dataset.eos_map(librec, channel_id)
-    tgt_dataset = src_dataset.map(get_preprocessing_map_func(records.get("sub_graph_def"),
-                                                             records.get("input_name_list"),
-                                                             records.get("output_name_list"),
-                                                             pipeline_input_indexes=records.get(
-                                                                 "batch_tensor_index_list")))
+    tgt_dataset = src_dataset.map(get_preprocessing_map_func(record.sub_graph_def,
+                                                             record.input_names,
+                                                             record.output_names,
+                                                             pipeline_input_indexes=record.batch_tensor_indexs))

-    feature_numbers = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for
+    feature_numbers = [BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).feat_cnt for
                        cutting_point in sub_cutting_point_list]
-    table_names = [SparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for
+    table_names = [BaseSparseEmbedding.get_anchor_attribute(cutting_point, ASCAnchorAttr.FEATURE_SPEC).table_name for
                    cutting_point in sub_cutting_point_list]

     tgt_dataset = tgt_dataset.map(get_asc_insert_func(feature_numbers=feature_numbers,
                                                       table_names=table_names,
-                                                      args_index_list=records.get("input_index_list"),
-                                                      is_training=records.get("is_training"),
+                                                      args_index_list=record.input_indexs,
+                                                      is_training=record.is_training,
                                                       dump_graph=dump_graph))

     tgt_dataset = tgt_dataset.prefetch(prefetch)
@@ -497,7 +501,7 @@ def update_iterator_getnext(get_next_op: Operation,
                             tgt_dataset: DatasetV1Adapter,
                             is_training: bool,
-                            records: AnchorRecord):
+                            record: AnchorRecord):
     """
     Replace the original dataset's `IteratorGetNext` op in the graph with the one from the new dataset, i.e. replace the original batch with the new dataset's batch.
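
The iterator swap described in the docstring above boils down to building a fresh iterator on the
rebuilt dataset and splicing its batch into the old graph; a hedged sketch, with the hypothetical
`rewire_consumers` standing in for update_input_tensor_with_new_batch:

    import tensorflow as tf

    def swap_iterator(tgt_dataset, rewire_consumers):
        new_iterator = tf.compat.v1.data.make_initializable_iterator(tgt_dataset)
        new_batch = new_iterator.get_next()  # batch structure from the new dataset
        rewire_consumers(new_batch)          # re-point consumers of the old IteratorGetNext
        return new_iterator.initializer      # must be run once the session exists
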
@@ -517,27 +521,27 @@ def update_iterator_getnext(get_next_op: Operation, iterator_type = get_next_op.outputs[0].op.inputs[0].op.type if iterator_type == "IteratorV2": iterator_type = find_make_iterator_op(get_next_op.outputs[0]).type - if iterator_type not in ("MakeIterator", "OneShotIterator"): + if iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, AnchorIteratorOp.ONE_SHOT_ITERATOR.value): raise RuntimeError(f"Only iterators `MakeIterator` and `OneShotIterator` are supported in `graph modify` mode, " f"but the current iterator is `{iterator_type}`.") - set_iterator_type(iterator_type) + ConfigInitializer.get_instance().train_params_config.iterator_type = iterator_type logger.info("The iterator type of dataset is `%s`.", iterator_type) - if iterator_type == "MakeIterator": + if iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: new_iterator = tgt_dataset.make_initializable_iterator() tf.compat.v1.add_to_collection(ASCEND_CUTTING_POINT_INITIALIZER, new_iterator.initializer) - set_initializer(is_training, new_iterator.initializer) + ConfigInitializer.get_instance().train_params_config.set_initializer(is_training, new_iterator.initializer) else: new_iterator = tgt_dataset.make_one_shot_iterator() new_batch = new_iterator.get_next() - set_target_batch(is_training, new_batch) + ConfigInitializer.get_instance().train_params_config.set_target_batch(is_training, new_batch) try: new_batch_tensor = list(new_batch.values())[0] except IndexError as err: raise IndexError("Cannot find a tensor from given batch.") from err - new_get_next_op_name = find_target_dataset_op(new_batch_tensor.op, "IteratorGetNext").name - update_input_tensor_with_new_batch(records.get("replacement_specs"), new_get_next_op_name, new_batch) + new_get_next_op_name = find_target_dataset_op(new_batch_tensor.op, AnchorIteratorOp.ITERATOR_GET_NEXT.value).name + update_input_tensor_with_new_batch(record.replacement_spec, new_get_next_op_name, new_batch) @performance("graph_modifier") @@ -553,8 +557,8 @@ def modify_graph_for_asc(dump_graph: bool = False, prefetch: int = 10): logger.debug("In modify_graph_for_asc function, get_next_op_map.len: %d, get_next_op_map.key: %s.", len(get_next_op_map), get_next_op_map.keys()) - for get_next_op, records in get_next_op_map.items(): - is_training = records.get("is_training") + for get_next_op, record in get_next_op_map.items(): + is_training = record.is_training # get source dataset src_dataset = get_src_dataset(get_next_op, is_training) @@ -562,27 +566,27 @@ def modify_graph_for_asc(dump_graph: bool = False, prefetch: int = 10): # generate target dataset timestamp_index = get_timestamp_index(get_next_op, is_training) original_batch_tensor_count = get_dataset_tensor_count(src_dataset) - sub_cutting_point_list = records.get("sub_cutting_point_list") - input_index_list = get_input_index_list(sub_cutting_point_list, - records.get("replacement_specs"), - records.get("output_name_list"), + sub_cutting_points = record.sub_cutting_points + input_index_list = get_input_index_list(sub_cutting_points, + record.replacement_spec, + record.output_names, original_batch_tensor_count, timestamp_index=timestamp_index) - records["input_index_list"] = input_index_list - tgt_dataset = get_tgt_dataset(src_dataset, sub_cutting_point_list, records, + record.input_indexs = input_index_list + tgt_dataset = get_tgt_dataset(src_dataset, sub_cutting_points, record, dump_graph=dump_graph, prefetch=prefetch) # update the batch of dataset - update_iterator_getnext(get_next_op, tgt_dataset, is_training, records) 
+ update_iterator_getnext(get_next_op, tgt_dataset, is_training, record) # In eval mode, backward is not required. In addition, compute gradients is not executed when # only eval is used. Therefore, `do_merge_lookup` needs to be invoked during modify graph. if not is_training: do_merge_lookup(is_train=False) - if 'evaluate' in get_bool_gauge_set(): + if 'evaluate' in ConfigInitializer.get_instance().train_params_config.bool_gauge_set: logger.debug("In estimator mode, eval re-creates graph each time, so the flag needs to be cleared.") - insert_merged_multi_lookup(is_training, False) + ConfigInitializer.get_instance().train_params_config.insert_merged_multi_lookup(is_training, False) # In training mode, `do_merge_lookup` should have been executed in compute gradients phase. - if is_training and not get_merged_multi_lookup(True): + if is_training and not ConfigInitializer.get_instance().train_params_config.get_merged_multi_lookup(True): raise RuntimeError("In training mode, `do_merge_lookup` should have been executed in compute gradients " "phase. Please check whether compute gradients is performed.") @@ -596,11 +600,12 @@ def get_timestamp_index(get_next_op: Operation, is_training: bool) -> int: for timestamp in timestamp_tensor_list: if timestamp in get_next_op.outputs: timestamp_index = int(timestamp.name.split(":")[1]) - timestamp_feature_spec = get_feature_spec("timestamp") + timestamp_feature_spec = ConfigInitializer.get_instance().feature_spec_config.get_feature_spec("timestamp") if timestamp_feature_spec is None: timestamp_feature_spec = FeatureSpec("timestamp", index_key=timestamp_index, is_timestamp=True) timestamp_feature_spec.include_timestamp(is_training) - insert_feature_spec(timestamp_feature_spec, is_training) + ConfigInitializer.get_instance().feature_spec_config.insert_feature_spec(timestamp_feature_spec, + is_training) break if timestamp_feature_spec.index_key != timestamp_index: @@ -622,7 +627,7 @@ class GraphModifierHook(tf.estimator.SessionRunHook): self._dump_graph = dump_graph self._modify_graph = modify_graph self._iterator_type = "" - set_is_graph_modify_hook_running(True) + ConfigInitializer.get_instance().train_params_config.is_graph_modify_hook_running = True def begin(self): if self._modify_graph: @@ -630,11 +635,12 @@ class GraphModifierHook(tf.estimator.SessionRunHook): else: start_asc_pipeline() - self._iterator_type = get_iterator_type() - if self._modify_graph and self._iterator_type not in ("MakeIterator", "OneShotIterator"): + self._iterator_type = ConfigInitializer.get_instance().train_params_config.iterator_type + if self._modify_graph and self._iterator_type not in (AnchorIteratorOp.MAKE_ITERATOR.value, + AnchorIteratorOp.ONE_SHOT_ITERATOR.value): raise ValueError("The value of iterator type should be like `MakeIterator` or `OneShotIterator`.") logger.debug("In GraphModifierHook, iterator type is `%s`.", self._iterator_type) def after_create_session(self, session, coord): - if self._modify_graph and self._iterator_type == "MakeIterator": + if self._modify_graph and self._iterator_type == AnchorIteratorOp.MAKE_ITERATOR.value: session.run(tf.compat.v1.get_collection(ASCEND_CUTTING_POINT_INITIALIZER)) diff --git a/mx_rec/graph/patch.py b/mx_rec/graph/patch.py index cdf2d255..f1a989b2 100644 --- a/mx_rec/graph/patch.py +++ b/mx_rec/graph/patch.py @@ -37,8 +37,7 @@ from tensorflow.python.training.optimizer import Optimizer from tensorflow.python.client.session import BaseSession from mx_rec.constants import constants -from mx_rec.util.initialize import 
get_is_graph_modify_hook_running, get_modify_graph, insert_bool_gauge, \
-    get_bool_gauge_set, terminate_config_initializer, get_asc_manager, export_table_instances
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.tf_version_adapter import NPUCheckpointSaverHook
 from mx_rec.graph.merge_lookup import do_merge_lookup
 from mx_rec.util.log import logger
@@ -69,7 +68,6 @@ def init_dataset(self, input_data):
     ("run_metadata", ClassValidator, {"classes": (tf.compat.v1.RunMetadata, type(None))}),
 ], output_log=False)
 def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
-
     """
     Replace TensorFlow's native session run method; it notifies the hybridMgmt side to
     wake up and counts steps each time sess.run is called.
@@ -144,8 +142,9 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
     except AssertionError:
         channel_id = -1

-    if channel_id != -1 and get_asc_manager():
-        get_asc_manager().block_notify_wake(channel_id)
+    asc_manager = ConfigInitializer.get_instance().hybrid_manager_config.asc_manager
+    if channel_id != -1 and asc_manager:
+        asc_manager.block_notify_wake(channel_id)

     if channel_id == constants.EVAL_CHANNEL_ID:
         # no loop sinking during eval
@@ -156,8 +155,8 @@ def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
     # call TensorFlow's native run method
     result = self.old_run_method(fetches, feed_dict, options, run_metadata)

-    if channel_id != -1 and get_asc_manager():
-        get_asc_manager().block_count_steps(channel_id, steps)
+    if channel_id != -1 and asc_manager:
+        asc_manager.block_count_steps(channel_id, steps)

     return result

@@ -214,7 +213,8 @@ def chief_session_creator_init(self, scaffold=None, master='', config=None, chec
     Returns: None
     """
     logger.debug("Enter the mxrec init function of Class 'monitored_session.ChiefSessionCreator'.")
-    if get_modify_graph() and not get_is_graph_modify_hook_running():
+    if ConfigInitializer.get_instance().modify_graph and \
+            not ConfigInitializer.get_instance().train_params_config.is_graph_modify_hook_running:
         raise RuntimeError(
             f"When 'modify_graph' is True, 'GraphModifierHook' must be configured. Example: \n"
             f"\t from mx_rec.graph.modifier import GraphModifierHook \n"
@@ -251,7 +251,7 @@ def get_cell(self: BoolGauge, *labels: Any) -> Any:
     logger.debug("Enter patch 'BoolGauge.get_cell'.")
     if len(labels) > 0:
         logger.debug("BoolGauge insert: %s.", labels[0])
-        insert_bool_gauge(labels[0])
+        ConfigInitializer.get_instance().train_params_config.insert_bool_gauge(labels[0])
     return BoolGaugeCell(super(BoolGauge, self).get_cell(*labels))

@@ -277,8 +277,8 @@ def assert_eval_spec(eval_spec: EvalSpec):
     if not isinstance(eval_spec, EvalSpec):
         raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. Got: {}'.format(type(eval_spec)))

-    if 'train_and_evaluate' not in get_bool_gauge_set():
-        insert_bool_gauge('train_and_evaluate')
+    if 'train_and_evaluate' not in ConfigInitializer.get_instance().train_params_config.bool_gauge_set:
+        ConfigInitializer.get_instance().train_params_config.insert_bool_gauge('train_and_evaluate')
     logger.debug("assert_eval_spec: add 'train_and_evaluate' to BoolGaugeCell.")

@@ -310,7 +310,7 @@ def scale_loss(self: Optimizer, loss_value: tf.Tensor) -> tf.Tensor:
     # in training, at least one variable must participate in the backward pass; otherwise an error is raised
     is_grad = False
     table_var_list = []
-    for _, table_instance in export_table_instances().items():
+    for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items():
         is_grad |= table_instance.is_grad
         table_var_list.append(table_instance.variable)
     if not is_grad:
diff --git a/mx_rec/graph/utils.py b/mx_rec/graph/utils.py
index 04e32ebe..c010d80d 100644
--- a/mx_rec/graph/utils.py
+++ b/mx_rec/graph/utils.py
@@ -17,23 +17,19 @@

 import os
 from collections import defaultdict
-from typing import Any, List, Dict, DefaultDict, Tuple, Union
+from typing import List, Dict, Union

 import tensorflow as tf
-from tensorflow import Tensor
-from tensorflow import Operation
+from tensorflow import Operation, Tensor
 from tensorflow.core.framework.graph_pb2 import GraphDef
 from tensorflow.python.framework.errors_impl import InvalidArgumentError

 from mx_rec.constants.constants import ASCAnchorAttr, DUMP_MIDIFY_GRAPH_FILE_MODE
-from mx_rec.core.embedding import SparseEmbedding
+from mx_rec.core.embedding import BaseSparseEmbedding
+from mx_rec.graph.graph_typing import ReplacementSpec
 from mx_rec.util.log import logger

-ReplacementSpec = DefaultDict[Tensor, List[Tuple[int, Operation]]]
-AnchorRecord = Dict[str, Union[ReplacementSpec, GraphDef, bool, List[Tensor], List[int], List[str]]]
-
-
 def check_input_list(objs: Union[object, List[object]], obj_type: type) -> Union[object, List[object]]:
     if isinstance(objs, obj_type):
         objs = [objs]
@@ -158,7 +154,7 @@ def replace_anchor_vec(cutting_point: Tensor, attribute: ASCAnchorAttr, anchor:
     """
     # get stub node
-    anchor_vec = SparseEmbedding.get_anchor_attribute(cutting_point, attribute)
+    anchor_vec = BaseSparseEmbedding.get_anchor_attribute(cutting_point, attribute)
     if anchor_vec is None:
         raise RuntimeError(f"Node `{attribute.value}` does not exist. Check whether the sparse lookup interface "
                            f"is correctly invoked.")
@@ -166,3 +162,17 @@
     replacement_specs_for_anchor_vec = record_ops_to_replace(anchor_vec.op)
     # replace anchor_vec with anchor
     replace_anchor(replacement_specs_for_anchor_vec, [anchor])
+
+
+def tag_orphan_ids(ids: tf.Tensor) -> tf.Tensor:
+    """
+    Tag orphan ids via an identity op that creates a marker node named with the ACG_PUSH_NODE prefix,
+    so it can be found during PushOps.
+    """
+    graph_def = tf.compat.v1.get_default_graph().as_graph_def()
+    subgraph = tf.compat.v1.graph_util.extract_sub_graph(graph_def, [ids.op.name])
+    for node in subgraph.node:
+        if node.op == 'IteratorGetNext':
+            return ids
+    new_ids = tf.identity(ids, name=f"ACG_PUSH_NODE_{ids.op.name}")
+    logger.info('Tag orphan op node: %s with %s.', ids, new_ids)
+    return new_ids
diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py
index 13459b14..c7829145 100644
--- a/mx_rec/optimizers/adagrad.py
+++ b/mx_rec/optimizers/adagrad.py
@@ -27,7 +27,7 @@ from tensorflow.python.training import adagrad, training_ops
 from tensorflow.python.training import slot_creator

 from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.constants.constants import MAX_INT32
 from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator

@@ -50,8 +50,7 @@ def create_hash_optimizer(learning_rate=0.001,
     :param name: Optional name prefix for the operations created when applying gradients. Defaults to "Adagrad".
     :return: adagrad hash optimizer instance
     """
-
-    if get_use_dynamic_expansion():
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
     return CustomizedAdagrad(learning_rate=learning_rate,
@@ -83,11 +82,11 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
             return new_slot_variable

         accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator")
-        insert_removing_var_list(accumulator.name)
+        ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accumulator.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = get_table_instance(var)
+        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
         if self._name in table_instance.optimizer:
-            raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.")
+            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
         table_instance.set_optimizer(self._name, {"accumulator": accumulator})
         return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}]

diff --git a/mx_rec/optimizers/emb_optimizer.py b/mx_rec/optimizers/emb_optimizer.py
new file mode 100644
index 00000000..c7f1b64a
--- /dev/null
+++ b/mx_rec/optimizers/emb_optimizer.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+from mx_rec.optimizers.base import CustomizedOptimizer
+from mx_rec.util.tf_version_adapter import NPULossScaleOptimizer
+
+
+class EmbOptimizer:
+    """
+    Optimizer manager for a sparse embedding table.
+    """
+
+    def __init__(self, optimizer_list):
+        self._optimizer_instance_list = optimizer_list
+        self._optimizer_slot_info_list = []
+        self._optimizer = dict()
+
+    @property
+    def optimizer_instance_list(self):
+        return self._optimizer_instance_list
+
+    @property
+    def optimizer_slot_info_list(self):
+        return self._optimizer_slot_info_list
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @staticmethod
+    def set_optimizer_slot(slot_info: dict):
+        """
+        Set the slot info of the sparse table optimizer.
+
+        Args:
+            slot_info: optimizer slot info
+
+        Returns: None
+        """
+        slot = slot_info.get("slot")
+        slot_name = slot_info.get("slot_name")
+        optimizer = slot_info.get("optimizer")
+        named_slot_key = slot_info.get("named_slot_key")
+
+        optimizer.insert_slot(slot, named_slot_key, slot_name)
+
+    def set_optimizer(self, key: str, state_dict: dict, table_name: str):
+        """
+        Set the optimizer state.
+
+        Args:
+            key: optimizer name
+            state_dict: optimizer state
+            table_name: sparse table name
+
+        Returns: None
+        """
+        if key in self._optimizer:
+            raise ValueError(f"Optimizer {key} has been set for hash table {table_name}.")
+        self._optimizer[key] = state_dict
+
+    def check_optimizer_instance_list(self):
+        """
+        Validate the optimizer instance list.
+        """
+        if not self._optimizer_instance_list:
+            raise ValueError("External storage mode requires optimizers to be configured before instantiating "
+                             "the sparse table, but none were configured.")
+
+        for optimizer_instance in self._optimizer_instance_list:
+            if isinstance(optimizer_instance, NPULossScaleOptimizer):
+                optimizer_instance = getattr(optimizer_instance, '_opt')
+
+            if not isinstance(optimizer_instance, CustomizedOptimizer):
+                raise TypeError("The optimizer instance must be an instance of CustomizedOptimizer.")
diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index 1ff1c052..a25471ac 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -32,7 +32,7 @@ from tensorflow.python.training import ftrl
 from tensorflow.python.training import slot_creator

 from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.variable import check_and_get_config_via_var
 from mx_rec.constants.constants import MAX_INT32
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator, NumValidator, StringValidator, \
@@ -52,7 +52,7 @@ from mx_rec.validator.validator import para_checker_decorator, ClassValidator, N
     ("linear_name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"])
 ])
 def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwargs):
-    if get_use_dynamic_expansion():
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
     return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs)
@@ -83,10 +83,10 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):

         accum = slot_creator.create_slot(var, val, self._name + "/" + "accum")
         linear = slot_creator.create_zeros_slot(var, self._name + "/" + "linear")
-        insert_removing_var_list(accum.name)
-        insert_removing_var_list(linear.name)
+        ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
+        ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = get_table_instance(var)
+        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)

         if self._name in table_instance.optimizer:
             raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
@@ -246,8 +246,8 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
         accum = self._get_or_make_slot(each_var, val, "accum", accum_state_name)
         linear = self._zeros_slot(each_var, "linear", linear_state_name)
         # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
-        insert_removing_var_list(accum.name)
-        insert_removing_var_list(linear.name)
+        ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
+        ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
         if self._name not in table_instance.optimizer:
             table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear})
diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py
index b1edb4ed..598571c8 100644
--- a/mx_rec/optimizers/gradient_descent.py
+++ b/mx_rec/optimizers/gradient_descent.py
@@ -28,7 +28,7 @@ from tensorflow.python.training import gradient_descent

 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.constants.constants import MAX_INT32
-from mx_rec.util.initialize import get_use_dynamic_expansion
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator

@@ -38,7 +38,7 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator,
     ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"])
 ])
 def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescent"):
-    if get_use_dynamic_expansion():
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
     return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name)
diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py
index 2838fb4a..e2de8903 100644
--- a/mx_rec/optimizers/gradient_descent_by_addr.py
+++ b/mx_rec/optimizers/gradient_descent_by_addr.py
@@ -25,8 +25,9 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.training import gradient_descent

 from mx_rec.optimizers.base import CustomizedOptimizer
-from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer, get_use_dynamic_expansion
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.constants.constants import MAX_INT32
+from mx_rec.util.ops import import_host_pipeline_ops
 from mx_rec.validator.validator import para_checker_decorator, StringValidator, ClassValidator, FloatValidator

@@ -37,14 +38,14 @@ from mx_rec.validator.validator import para_checker_decorator, StringValidator,
     ("name", StringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"])
 ])
 def create_hash_optimizer_by_addr(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescentByAddr"):
-    if not get_use_dynamic_expansion():
+    if not ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("the by-address optimizer requires dynamic expansion mode, please config dynamic "
                          "expansion mode and optimizer correctly")
     optimizer_by_addr = 
CustomizedGradientDescentByAddr(learning_rate=learning_rate, weight_decay=weight_decay, use_locking=use_locking, name=name) - insert_optimizer(optimizer_by_addr) + ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer_by_addr return optimizer_by_addr @@ -71,7 +72,7 @@ class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, return [] def _apply_sparse(self, grad, addr): - host_pipeline_ops = get_host_pipeline_ops() + host_pipeline_ops = import_host_pipeline_ops() dim = grad.shape.as_list()[-1] if self.weight_decay is None: nd_value = grad * math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype) @@ -86,6 +87,3 @@ class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer, def _apply_dense(self, grad, var): raise NotImplementedError("You are using a wrong type of variable.") - - - diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py index cc51ab5c..e4268fed 100644 --- a/mx_rec/optimizers/lazy_adam.py +++ b/mx_rec/optimizers/lazy_adam.py @@ -31,7 +31,7 @@ from tensorflow.python.training import adam from tensorflow.python.training import slot_creator from mx_rec.optimizers.base import CustomizedOptimizer -from mx_rec.util.initialize import get_table_instance, insert_removing_var_list, get_use_dynamic_expansion +from mx_rec.util.initialize import ConfigInitializer from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.constants.constants import MAX_INT32 from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @@ -55,7 +55,7 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1 Returns: a customized optimizer instance """ - if get_use_dynamic_expansion(): + if ConfigInitializer.get_instance().use_dynamic_expansion: raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " "expansion mode and optimizer correctly") return CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name) @@ -66,6 +66,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam"): self.optimizer_type = "LazyAdam" + self.config_instance = ConfigInitializer.get_instance() super(CustomizedLazyAdam, self)._get_name(name=name) super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, use_locking=use_locking, name=self.unique_name) @@ -79,10 +80,10 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): momentum = creat_one_single_slot(var, self._name + "/" + "momentum") velocity = creat_one_single_slot(var, self._name + "/" + "velocity") - insert_removing_var_list(momentum.name) - insert_removing_var_list(velocity.name) + self.config_instance.sparse_embed_config.insert_removing_var_list(momentum.name) + self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name) named_slot_key = (var.op.graph, var.op.name) - table_instance = get_table_instance(var) + table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var) if self._name in table_instance.optimizer: raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.") @@ -192,8 +193,8 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer): momentum = self._zeros_slot(each_var, "m", m_state_name) velocity = self._zeros_slot(each_var, "v", 
v_state_name) # make sure sparse optimizer statements will not be saved and restored within tf checkpoint. - insert_removing_var_list(momentum.name) - insert_removing_var_list(velocity.name) + self.config_instance.sparse_embed_config.insert_removing_var_list(momentum.name) + self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name) if self._name not in table_instance.optimizer: table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity}) diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py index e78e24c0..c338b592 100644 --- a/mx_rec/optimizers/lazy_adam_by_addr.py +++ b/mx_rec/optimizers/lazy_adam_by_addr.py @@ -25,9 +25,10 @@ import tensorflow as tf from tensorflow.python.ops import math_ops from tensorflow.python.training import adam -from mx_rec.util.initialize import get_host_pipeline_ops, insert_optimizer, get_use_dynamic_expansion +from mx_rec.util.initialize import ConfigInitializer from mx_rec.optimizers.base import CustomizedOptimizer from mx_rec.constants.constants import MAX_INT32 +from mx_rec.util.ops import import_host_pipeline_ops from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator @@ -50,14 +51,13 @@ def create_hash_optimizer_by_address(learning_rate=0.001, beta1=0.9, beta2=0.999 Returns: a customized optimizer instance """ - - if not get_use_dynamic_expansion(): + if not ConfigInitializer.get_instance().use_dynamic_expansion: raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic " "expansion mode and optimizer correctly") optimizer_by_addr = CustomizedLazyAdamByAddress(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name) - insert_optimizer(optimizer_by_addr) + ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer_by_addr return optimizer_by_addr @@ -124,7 +124,7 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer): temp_epsilon = temp.get("temp_epsilon") learning_rate = tf.divide(temp_lr * math_ops.sqrt(1 - power_b2), (1 - power_b1)) - host_pipeline_ops = get_host_pipeline_ops() + host_pipeline_ops = import_host_pipeline_ops() dim = grad.shape.as_list()[-1] combined_tensor = \ host_pipeline_ops.embedding_lookup_by_address(addr, embedding_dim=3 * dim, embedding_type=1) diff --git a/mx_rec/saver/patch.py b/mx_rec/saver/patch.py index 38fd2c3e..8f40ef1c 100644 --- a/mx_rec/saver/patch.py +++ b/mx_rec/saver/patch.py @@ -43,7 +43,7 @@ from tensorflow.python.training.saving import saveable_object_util import numpy as np from mx_rec.saver.saver import Saver as SparseSaver, check_file_system_is_valid -from mx_rec.util.initialize import get_ascend_global_hashtable_collection, export_removing_var_list +from mx_rec.util.initialize import ConfigInitializer from mx_rec.validator.validator import para_checker_decorator, ClassValidator, StringValidator, OptionalIntValidator, \ OptionalStringValidator, DirectoryValidator from mx_rec.util.log import logger @@ -56,12 +56,14 @@ def get_sparse_vars(var_list): if var_list is not None: if not isinstance(var_list, (list, tuple)): raise TypeError("A non-None var_list must be a list or tuple.") - ascend_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + ascend_variables = tf.compat.v1.get_collection( + ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection) for var in var_list: if var in ascend_variables: 
sparse_var_list.append(var) else: - sparse_var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + sparse_var_list = tf.compat.v1.get_collection( + ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection) return sparse_var_list @@ -80,7 +82,6 @@ def saver_init(self, var_list=None, reshape=False, sharded=False, max_to_keep=5, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=saver_pb2.SaverDef.V2, pad_step_number=False, save_relative_paths=False, filename=None, fid_version=0): - self._var_list = var_list self._last_checkpoints = [] self._checkpoints_to_be_deleted = [] @@ -372,7 +373,7 @@ def saver_from_object_based_checkpoint(checkpoint_path, var_list=None, builder=N def build_var_list(): save_var_list = [] tmp_list = variables._all_saveable_objects() - removing_var_list = export_removing_var_list() + removing_var_list = ConfigInitializer.get_instance().sparse_embed_config.removing_var_list for var in tmp_list: if var.name not in removing_var_list: save_var_list.append(var) @@ -426,5 +427,3 @@ def patch_for_saver(): dense_saver.restore = restore dense_saver.build = build logger.debug("Class tf.train.Saver has been patched.") - - diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py index aebe621f..1fa47827 100644 --- a/mx_rec/saver/saver.py +++ b/mx_rec/saver/saver.py @@ -26,10 +26,9 @@ from tensorflow.python.util import compat from mx_rec.constants.constants import DataName, DataAttr, MIN_SIZE, MAX_FILE_SIZE, Flag, TFDevice, \ MAX_INT32, HDFS_FILE_PREFIX -from mx_rec.util.initialize import get_rank_id, get_rank_size, get_customized_ops, get_table_instance, \ - get_table_instance_by_name, is_asc_manager_initialized, save_host_data, restore_host_data, get_host_data, \ - send_host_data, get_ascend_global_hashtable_collection, set_sparse_dir, get_local_rank_size, \ - get_use_dynamic_expansion +from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size, get_local_rank_size +from mx_rec.util.initialize import ConfigInitializer + from mx_rec.util.perf import performance from mx_rec.validator.validator import DirectoryValidator, FileValidator, para_checker_decorator, ClassValidator, \ IntValidator, OptionalStringValidator @@ -52,11 +51,9 @@ class SaveModelThread(threading.Thread): class Saver(object): - customized_ops = get_customized_ops() - @staticmethod def _make_table_name_dir(root_dir, table_instance, table_name): - if table_instance.host_vocabulary_size > 0: + if not table_instance.is_hbm: table_dir = os.path.join(root_dir, "HashTable", "DDR", table_name) else: table_dir = os.path.join(root_dir, "HashTable", "HBM", table_name) @@ -79,15 +76,18 @@ class Saver(object): self.restore_fetch_list = [] self.placeholder_dict = defaultdict(dict) self._last_checkponts = [] + self.config_instance = ConfigInitializer.get_instance() self.build() def build(self): if self.var_list is None: self.var_list = [] - logger.debug("optimizer collection name: %s", get_ascend_global_hashtable_collection()) - temp_var_list = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection()) + logger.debug("optimizer collection name: %s", + self.config_instance.train_params_config.ascend_global_hashtable_collection) + temp_var_list = tf.compat.v1.get_collection( + self.config_instance.train_params_config.ascend_global_hashtable_collection) for var in temp_var_list: - table_instance = get_table_instance(var) + table_instance = 
self.config_instance.sparse_embed_config.get_table_instance(var) if table_instance.is_save: self.var_list.append(var) @@ -128,7 +128,7 @@ class Saver(object): ckpt_name = f"sparse-{base_name}" saving_path = os.path.join(directory, ckpt_name) - set_sparse_dir(saving_path) + self.config_instance.train_params_config.sparse_dir = saving_path try: if not check_file_system_is_hdfs(saving_path): @@ -168,7 +168,7 @@ class Saver(object): ckpt_name = f"sparse-{base_name}" reading_path = os.path.join(directory, ckpt_name) - set_sparse_dir(reading_path) + self.config_instance.train_params_config.sparse_dir = reading_path if not tf.io.gfile.exists(reading_path): raise FileExistsError(f"Given dir {reading_path} does not exist, please double check.") @@ -178,12 +178,12 @@ class Saver(object): @performance("save_table_name_data") def save_table_name_data(self, sess, result, root_dir, table_name): - table_instance = get_table_instance_by_name(table_name) + table_instance = self.config_instance.sparse_embed_config.get_table_instance_by_name(table_name) self._make_table_name_dir(root_dir, table_instance, table_name) dump_data_dict = sess.run(result.get(table_name)) # when HBM mode is on, need to get host offset data, to process dump data dict for saving valid embedding. - if is_asc_manager_initialized() and table_instance.host_vocabulary_size == 0: + if self.config_instance.hybrid_manager_config.asc_manager and table_instance.is_hbm: self._get_valid_dict_data(dump_data_dict, table_name) # save embedding @@ -197,11 +197,11 @@ class Saver(object): @performance("_save") def _save(self, sess, root_dir): - if is_asc_manager_initialized(): - save_host_data(root_dir) + if self.config_instance.hybrid_manager_config.asc_manager: + self.config_instance.hybrid_manager_config.save_host_data(root_dir) logger.debug(f"host data was saved.") - if get_use_dynamic_expansion(): + if self.config_instance.use_dynamic_expansion: # Data related to dynamic expansion needs to be saved only on the host side. 
return @@ -218,7 +218,7 @@ class Saver(object): thread.join() def _get_valid_dict_data(self, dump_data_dict, table_name): - host_data = get_host_data(table_name) + host_data = self.config_instance.hybrid_manager_config.get_host_data(table_name) offset = list(host_data) get_valid_dict_data_from_host_offset(dump_data_dict, offset) @@ -227,7 +227,8 @@ class Saver(object): for var in self.var_list: if global_env.tf_device == TFDevice.NPU.value and "merged" not in var.name: continue - table_instance = get_table_instance(var) + + table_instance = self.config_instance.sparse_embed_config.get_table_instance(var) table_name = table_instance.table_name with tf.compat.v1.variable_scope(table_name): sub_dict = self.save_op_dict[table_name] @@ -239,12 +240,12 @@ class Saver(object): for var in self.var_list: if global_env.tf_device == TFDevice.NPU.value and "merged" not in var.name: continue - table_instance = get_table_instance(var) + table_instance = self.config_instance.sparse_embed_config.get_table_instance(var) sub_placeholder_dict = self.placeholder_dict[table_instance.table_name] with tf.compat.v1.variable_scope(table_instance.table_name): sub_placeholder_dict[DataName.EMBEDDING.value] = variable = \ tf.compat.v1.placeholder(dtype=tf.float32, shape=[table_instance.slice_device_vocabulary_size, - table_instance.scalar_emb_size], + table_instance.emb_size], name=DataName.EMBEDDING.value) assign_op = var.assign(variable) self.restore_fetch_list.append(assign_op) @@ -259,7 +260,7 @@ class Saver(object): optimizer_placeholder_dict[optimizer_name] = sub_optimizer_placeholder_dict = \ dict([(state_key, tf.compat.v1.placeholder(dtype=tf.float32, shape=[table_instance.slice_device_vocabulary_size, - table_instance.scalar_emb_size], + table_instance.emb_size], name=state_key)) for state_key, state in optimizer_state_dict.items()]) for key_state, state in optimizer_state_dict.items(): @@ -267,11 +268,11 @@ class Saver(object): self.restore_fetch_list.append(assign_op) def _restore(self, sess, reading_path): - if is_asc_manager_initialized(): - restore_host_data(reading_path) + if self.config_instance.hybrid_manager_config.asc_manager: + self.config_instance.hybrid_manager_config.restore_host_data(reading_path) logger.info("host data was restored.") - if get_use_dynamic_expansion(): + if self.config_instance.use_dynamic_expansion: # Data related to dynamic expansion needs to be restored only on the host side. 
        return
@@ -283,8 +284,8 @@
             if "optimizer" in sub_placeholder_dict:
                 optimizer_state_placeholder_dict_group = sub_placeholder_dict.get("optimizer")
-                fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path,
-                                               restore_feed_dict, self.rank_id, table_name)
+                _fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path,
+                                                restore_feed_dict, self.rank_id, table_name)
 
         sess.run(self.restore_fetch_list, feed_dict=restore_feed_dict)
 
@@ -314,17 +315,6 @@ def get_valid_dict_data_from_host_offset(dump_data_dict: dict, offset: list):
     dump_data_dict["optimizer"] = dump_optimizer_data_dict
 
 
-def fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group, reading_path, restore_feed_dict, suffix,
-                                   table_name):
-    for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items():
-        for state_key in optimizer_state_placeholder_dict:
-            fill_placeholder(reading_path=reading_path,
-                             placeholder_dict=optimizer_state_placeholder_dict,
-                             feed_dict=restore_feed_dict,
-                             suffix=suffix,
-                             name_descriptor=NameDescriptor(table_name, state_key, optimizer_name=optimizer_name))
-
-
 def fill_placeholder(reading_path, placeholder_dict, feed_dict, suffix, name_descriptor):
     if name_descriptor.optimizer_name:
         target_path = generate_path(reading_path, "Optimizer", name_descriptor.optimizer_name, "HBM",
@@ -452,8 +442,8 @@ def read_binary_data(reading_path: str, suffix: int, data_name: str, table_name:
     if DataAttr.SHAPE.value in attributes and data_name != DataName.KEY.value:
         data_shape = attributes.pop(DataAttr.SHAPE.value)
         data_to_restore = data_to_restore.reshape(data_shape)
-        table_instance = get_table_instance_by_name(table_name)
-        current_data_shape = [table_instance.slice_device_vocabulary_size, table_instance.scalar_emb_size]
+        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(table_name)
+        current_data_shape = [table_instance.slice_device_vocabulary_size, table_instance.emb_size]
         if data_shape != current_data_shape:
             data_to_restore = process_embedding_data(data_to_restore, current_data_shape, data_shape)
 
@@ -498,7 +488,8 @@ def process_embedding_data(data_to_restore: np.ndarray, current_data_shape: list
 
     elif restore_vocab_size < vocab_size:
         raise Exception(f"restore vocabulary size {restore_vocab_size} cannot be less than "
-                        f"saved vocabulary size {vocab_size},which would loss the mapping between keys and embeddings ")
+                        f"saved vocabulary size {vocab_size}, "
+                        f"which would lose the mapping between keys and embeddings.")
 
     return data_to_restore
 
@@ -514,3 +505,26 @@ def check_file_system_is_hdfs(file_path):
         if file_path.startswith(prefix):
             return True
     return False
+
+
+def _fill_placeholder_for_optimizer(optimizer_state_placeholder_dict_group: dict, reading_path: str,
+                                    restore_feed_dict: dict, suffix: int, table_name: str):
+    """
+    Fill the restored data into the optimizer placeholders.
+
+    Args:
+        optimizer_state_placeholder_dict_group: dict of optimizer placeholders to be filled
+        reading_path: path to read the data from
+        restore_feed_dict: feed dict for the session run
+        suffix: rank id
+        table_name: sparse table name
+
+    Returns: None
+    """
+    for optimizer_name, optimizer_state_placeholder_dict in optimizer_state_placeholder_dict_group.items():
+        for state_key in optimizer_state_placeholder_dict:
+            fill_placeholder(reading_path=reading_path,
+                             placeholder_dict=optimizer_state_placeholder_dict,
+                             feed_dict=restore_feed_dict,
+                             suffix=suffix,
+                             name_descriptor=NameDescriptor(table_name, state_key, optimizer_name=optimizer_name))
diff --git a/mx_rec/saver/sparse.py b/mx_rec/saver/sparse.py
index 9e66769a..c08d20ec 100644
--- a/mx_rec/saver/sparse.py
+++ b/mx_rec/saver/sparse.py
@@ -21,8 +21,7 @@
 import json
 import numpy as np
 import tensorflow as tf
 
-from mx_rec.util.initialize import get_table_instance_by_name, export_table_name_set, get_sparse_dir
-from mx_rec.validator.validator import FileValidator
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator
 from mx_rec.util.log import logger
 from mx_rec.saver.saver import validate_read_file
@@ -44,7 +43,7 @@ class SparseProcessor:
         self.json_attrib_dtype = "data_type"
         self.json_attrib_shape = "shape"
         self.table_list = table_list
-        self.default_table_list = list(export_table_name_set())
+        self.default_table_list = list(ConfigInitializer.get_instance().sparse_embed_config.table_name_set)
 
         if not self.table_list:
             logger.debug("table list not be set, use default value : all table created ")
@@ -88,25 +87,25 @@ class SparseProcessor:
     def export_sparse_data(self):
         logger.info("table list to be exported is %s", self.table_list)
-        sparse_dir = get_sparse_dir()
+        sparse_dir = ConfigInitializer.get_instance().train_params_config.sparse_dir
         ddr = False
         dev_dir = set_upper_dir(sparse_dir, self.device_dir_list)
         host_dir = set_upper_dir(sparse_dir, self.host_dir_list)
         for table in self.table_list:
-            table_instance = get_table_instance_by_name(table)
+            table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(table)
             device_table_dir = os.path.join(dev_dir, table)
             host_table_dir = os.path.join(host_dir, table)
-            if table_instance.host_vocabulary_size != 0:
+            if not table_instance.is_hbm:
                 out_dir = host_table_dir
                 key, offset = self._get_hashmap(host_table_dir, True)
                 emb_data = self.get_embedding(device_table_dir, host_table_dir, True,
-                                              table_instance.use_dynamic_expansion)
+                                              ConfigInitializer.get_instance().use_dynamic_expansion)
                 emb_data = emb_data[offset]
             else:
                 out_dir = device_table_dir
                 key, _ = self._get_hashmap(device_table_dir, False)
                 emb_data = self.get_embedding(device_table_dir, host_table_dir, False,
-                                              table_instance.use_dynamic_expansion)
+                                              ConfigInitializer.get_instance().use_dynamic_expansion)
             transformed_data = dict(zip(key[:], emb_data[:]))
             save_path = os.path.join(out_dir, self.export_name + ".npy")
             with tf.io.gfile.GFile(save_path, "wb") as file:
diff --git a/mx_rec/util/__init__.py b/mx_rec/util/__init__.py
index 6c919515..41d41654 100644
--- a/mx_rec/util/__init__.py
+++ b/mx_rec/util/__init__.py
@@ -16,14 +16,9 @@
 # ==============================================================================
 
 __all__ = [
-    "init", "get_rank_id", "get_initializer", "terminate_config_initializer", "clear_channel",
-    "get_dense_and_sparse_variable", "set_if_load", "set_ascend_global_hashtable_collection",
-    "get_ascend_global_hashtable_collection", "get_rank_size", "get_host_pipeline_ops",
"get_use_dynamic_expansion", "set_ascend_table_name_must_contain", "get_target_batch" + "init", "terminate_config_initializer" ] from mx_rec.util.tf_version_adapter import npu_ops, hccl_ops, NPUCheckpointSaverHook -from mx_rec.util.initialize import init, get_rank_id, get_initializer, terminate_config_initializer, clear_channel, \ - set_if_load, set_ascend_global_hashtable_collection, get_ascend_global_hashtable_collection, get_rank_size, \ - get_host_pipeline_ops, get_use_dynamic_expansion, set_ascend_table_name_must_contain, get_target_batch +from mx_rec.util.initialize import init, terminate_config_initializer from mx_rec.util.variable import get_dense_and_sparse_variable diff --git a/mx_rec/util/communication/__init__.py b/mx_rec/util/communication/__init__.py index 05fd37d4..c731b156 100644 --- a/mx_rec/util/communication/__init__.py +++ b/mx_rec/util/communication/__init__.py @@ -15,4 +15,4 @@ # limitations under the License. # ============================================================================== -__all__ = ["hccl_mgmt"] \ No newline at end of file +__all__ = ["hccl_mgmt", "hccl_ops"] diff --git a/mx_rec/util/communication/hccl_mgmt.py b/mx_rec/util/communication/hccl_mgmt.py index aaf15267..c245dfc5 100644 --- a/mx_rec/util/communication/hccl_mgmt.py +++ b/mx_rec/util/communication/hccl_mgmt.py @@ -18,11 +18,11 @@ import json import os -from mxrec_pybind import get_logic_id -from mx_rec.constants.constants import MAX_CONFIG_SIZE, MAX_DEVICE_ID, MAX_RANK_SIZE, MIN_RANK_SIZE, MIN_SIZE, \ - VALID_DEVICE_ID_LIST +from mx_rec.constants.constants import VALID_DEVICE_ID_LIST, MIN_SIZE, MAX_CONFIG_SIZE, MAX_DEVICE_ID, \ + MIN_RANK_SIZE, MAX_RANK_SIZE +from mx_rec.validator.validator import FileValidator, para_checker_decorator, StringValidator, \ + Convert2intValidator from mx_rec.util.global_env_conf import global_env -from mx_rec.validator.validator import Convert2intValidator, FileValidator, para_checker_decorator, StringValidator def parse_hccl_json(): @@ -52,13 +52,11 @@ def parse_hccl_json(): raise AttributeError(f"Lack of attribute device.") rank_to_device_dict = dict() - local_rank_size = -1 for server_list in table_hccl.get("server_list"): devices = server_list.get("device") if devices is None: raise ValueError("device is empty") - local_rank_size = len(devices) for device in devices: if "rank_id" not in device or not device.get("rank_id").isdigit(): raise ValueError(f"hccl_json rank_id wrong.") @@ -66,7 +64,8 @@ def parse_hccl_json(): if "device_id" not in device or not device.get("device_id").isdigit(): raise ValueError(f"hccl_json device_id wrong.") - res = get_logic_id(int(device.get("device_id"))) + import mxrec_pybind + res = mxrec_pybind.get_logic_id(int(device.get("device_id"))) if res < 0: raise RuntimeError( f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") @@ -74,26 +73,18 @@ def parse_hccl_json(): raise ValueError(f"get logic id from physic id fail, the device id is invalid.") rank_to_device_dict[rank_id] = res - return rank_to_device_dict, local_rank_size + return rank_to_device_dict -@para_checker_decorator(check_option_list=[ - ("visible_devices", StringValidator, {"msg": "please config ASCEND_VISIBLE_DEVICES in docker container start"}), - ("rank_size", StringValidator, {"msg": "please config CM_WORKER_SIZE in docker container start"}), - ("chief_device", StringValidator, {"msg": "please config CM_CHIEF_DEVICE in docker container start"}), - ("rank_size", Convert2intValidator, {"min_value": MIN_RANK_SIZE, 
"max_value": MAX_RANK_SIZE, - "constrained_options": [1, 2, 4, 8, 16]}, ["check_value"]), - ("chief_device", Convert2intValidator, {"min_value": 0, "max_value": 15}, ["check_value"]), -]) -def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_device: str): +def set_hccl_info_without_json(): """ Used for no rank table file configured training situation. Now, only less than or equal 8p training job is supported. - :param visible_devices: 昇腾处理器可见的设备,来指定程序只使用其中的部分设备。 - :param rank_size: 参与集群训练的device数量。 - :param chief_device: 主节点device id。 :return: """ + visible_devices = global_env.ascend_visible_devices + rank_size = global_env.cm_worker_size + chief_device = global_env.cm_chief_device device_list = get_device_list(visible_devices) chief_device = int(chief_device) rank_size = int(rank_size) @@ -107,14 +98,14 @@ def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_devic if chief_device not in sorted_device_list: raise ValueError(f"The environment variable CM_CHIEF_DEVICE {chief_device} is not in the local device list. ") - rank_to_device_dict = {} chief_index = sorted_device_list.index(chief_device) sorted_device_list = sorted_device_list[chief_index:] + sorted_device_list[0: chief_index] sorted_device_list = sorted_device_list[:rank_size] for device_idx in sorted_device_list: - res = get_logic_id(int(device_idx)) + import mxrec_pybind + res = mxrec_pybind.get_logic_id(int(device_idx)) if res < 0: raise RuntimeError( f"get logic id from physic id fail, error code is {res}, please check if dsmi api is functional.") @@ -123,7 +114,7 @@ def set_hccl_info_without_json(visible_devices: str, rank_size: str, chief_devic raise ValueError(f"get logic id from physic id fail.") index = sorted_device_list.index(device_idx) rank_to_device_dict[index] = res - return rank_to_device_dict, local_rank_size + return rank_to_device_dict def get_device_list(ascend_visible_devices): diff --git a/mx_rec/util/communication/hccl_ops.py b/mx_rec/util/communication/hccl_ops.py new file mode 100644 index 00000000..0bd6e16a --- /dev/null +++ b/mx_rec/util/communication/hccl_ops.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
+import os
+
+from mx_rec.constants.constants import EnvOption
+from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json
+from mx_rec.util.global_env_conf import global_env, get_global_env_conf
+from mx_rec.util.log import logger
+
+
+def get_rank_id() -> int:
+    """
+    Get the rank id of the calling process in the collective communication group.
+    Note: this method should be used after mpi init.
+    :return: int, the rank id of the calling process
+    """
+    rank_id = os.getenv(EnvOption.OMPI_COMM_WORLD_RANK.value)
+    if rank_id is None:
+        raise RuntimeError("Environment variable OMPI_COMM_WORLD_RANK has not been exported, "
+                           "please init mpi/hccl first")
+    try:
+        rank_id_int = int(rank_id)
+    except ValueError as e:
+        raise ValueError(f"Environment variable OMPI_COMM_WORLD_RANK should be a number, "
+                         f"but got: {rank_id}") from e
+
+    return rank_id_int
+
+
+def get_device_id() -> int:
+    """
+    Get the device id of the calling process.
+    Note: this method should be used after mpi init.
+    :return: int, the device id of the calling process
+    """
+    if global_env.rank_table_file:
+        rank_to_device_dict = parse_hccl_json()
+    else:
+        rank_to_device_dict = set_hccl_info_without_json()
+    device_id = rank_to_device_dict.get(get_rank_id())
+    if device_id is None:
+        raise RuntimeError("No device id was found for the current rank, please init mpi/hccl first "
+                           "and check the rank-to-device configuration")
+    try:
+        device_id_int = int(device_id)
+    except ValueError as e:
+        raise ValueError(f"The resolved device id should be a number, but got: {device_id}.") from e
+    return device_id_int
+
+
+def get_rank_size() -> int:
+    """
+    Get the rank size of the default collective communication group.
+    Note: this method should be used after mpi init.
+    :return: int, the rank size of the group
+    """
+    rank_size = os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value)
+    if rank_size is None:
+        raise RuntimeError("Environment variable OMPI_COMM_WORLD_LOCAL_SIZE has not been exported, "
+                           "please init mpi/hccl first")
+    try:
+        rank_size_int = int(rank_size)
+    except ValueError as e:
+        raise ValueError(f"Environment variable OMPI_COMM_WORLD_LOCAL_SIZE should be a number, "
+                         f"but got: {rank_size}.") from e
+
+    return rank_size_int
+
+
+def get_local_rank_size() -> int:
+    """
+    Get the local rank size of the default collective communication group.
+    Note: this method should be used after mpi init.
+    :return: int, the local rank size of the group
+    """
+    local_rank_size = os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value)
+    if local_rank_size is None:
+        raise RuntimeError("Environment variable OMPI_COMM_WORLD_LOCAL_SIZE has not been exported, "
+                           "please init mpi/hccl first")
+    try:
+        local_rank_size_int = int(local_rank_size)
+    except ValueError as e:
+        raise ValueError(f"Environment variable OMPI_COMM_WORLD_LOCAL_SIZE should be a number, "
+                         f"but got: {local_rank_size}.") from e
+
+    return local_rank_size_int
diff --git a/mx_rec/util/config_utils/__init__.py b/mx_rec/util/config_utils/__init__.py
new file mode 100644
index 00000000..6924f767
--- /dev/null
+++ b/mx_rec/util/config_utils/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
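For orientation, a minimal usage sketch of the hccl_ops helpers introduced above; this is illustrative only, and assumes the process was launched through mpirun so that the OMPI_COMM_WORLD_* variables are exported:

    from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size, get_device_id

    rank = get_rank_id()       # parsed from OMPI_COMM_WORLD_RANK
    size = get_rank_size()     # parsed from the MPI size environment variable
    device = get_device_id()   # resolved via the rank table json, or the CM_* fallback
    print("rank %d of %d runs on logic device %d" % (rank, size, device))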
diff --git a/mx_rec/util/config_utils/embedding_utils.py b/mx_rec/util/config_utils/embedding_utils.py new file mode 100644 index 00000000..b47120f7 --- /dev/null +++ b/mx_rec/util/config_utils/embedding_utils.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. +from typing import Optional + +from tensorflow import Variable + +from mx_rec.util.log import logger + + +class SparseEmbedConfig: + """ + Sparse table related configurations. + """ + def __init__(self): + self._table_instance_dict = dict() + self._dangling_table = [] + self._table_name_set = set() + self._removing_var_list = [] + self._name_to_var_dict = dict() + + @property + def table_instance_dict(self): + return self._table_instance_dict + + @property + def dangling_table(self): + return self._dangling_table + + @property + def table_name_set(self): + return self._table_name_set + + @property + def name_to_var_dict(self): + return self._name_to_var_dict + + @property + def removing_var_list(self): + return self._removing_var_list + + def get_table_instance(self, key) -> object: + if key not in self._table_instance_dict: + raise KeyError(f"Given key does not exist.") + + return self._table_instance_dict.get(key) + + def get_table_instance_by_name(self, table_name: Optional[str]) -> object: + if table_name not in self._name_to_var_dict: + raise KeyError(f"Given table name does not exist.") + + key = self._name_to_var_dict.get(table_name) + return self._table_instance_dict.get(key) + + def insert_dangling_table(self, table_name: Optional[str]) -> None: + if table_name not in self._dangling_table: + self._dangling_table.append(table_name) + + def insert_removing_var_list(self, var_name) -> None: + if var_name not in self._removing_var_list: + self._removing_var_list.append(var_name) + + def insert_table_instance(self, name: str, key: Variable, instance: object) -> None: + if key in self._table_instance_dict: + raise KeyError(f"Given key {key} has been used.") + + if name in self._table_name_set: + raise ValueError(f"Duplicated hashtable name '{name}' was used.") + + logger.debug("Record one hash table, with name: %s, key: %s.", name, key) + self._table_name_set.add(name) + self._name_to_var_dict[name] = key + self._table_instance_dict[key] = instance + + def export_table_num(self) -> int: + return len(self.table_instance_dict) if self.table_instance_dict else 0 diff --git a/mx_rec/util/config_utils/feature_spec_utils.py b/mx_rec/util/config_utils/feature_spec_utils.py new file mode 100644 index 00000000..4c40996c --- /dev/null +++ b/mx_rec/util/config_utils/feature_spec_utils.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
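+# Tracks, per sparse table name, the FeatureSpec objects registered for the
+# training graph (key True) and the evaluation graph (key False).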
+from typing import Optional
+
+from mx_rec.util.log import logger
+
+
+class FeatureSpecConfig:
+    """
+    Feature Spec configurations, including relationship between table name and feature spec
+    """
+    def __init__(self):
+        self._table_name_to_feature_spec = dict()
+        self._feature_spec_dict = dict()
+
+    @property
+    def table_name_to_feature_spec(self):
+        return self._table_name_to_feature_spec
+
+    @property
+    def feature_spec_dict(self):
+        return self._feature_spec_dict
+
+    def clear_same_table_feature_spec(self, table_name: Optional[str], is_training: bool) -> None:
+        if self.table_name_to_feature_spec.get(table_name) is None or \
+                self.table_name_to_feature_spec.get(table_name).get(is_training) is None:
+            raise KeyError(f"The table name `{table_name}` does not exist in table_name_to_feature_spec, "
+                           f"please check whether the insert_feature_spec(...) is invoked.")
+        self.table_name_to_feature_spec.get(table_name)[is_training] = []
+        logger.debug("The feature spec of the table name `%s` has been cleared.", table_name)
+
+    def insert_feature_spec(self, feature_spec: object, is_training: bool) -> None:
+        self._feature_spec_dict[feature_spec.name] = feature_spec
+        if feature_spec.table_name not in self._table_name_to_feature_spec:
+            self._table_name_to_feature_spec[feature_spec.table_name] = {True: [], False: []}
+        self._table_name_to_feature_spec[feature_spec.table_name][is_training].append(feature_spec)
+
+    def get_feature_spec(self, feature_spec_name: Optional[str]) -> object:
+        return self._feature_spec_dict.get(feature_spec_name)
diff --git a/mx_rec/util/config_utils/hybrid_mgmt_utils.py b/mx_rec/util/config_utils/hybrid_mgmt_utils.py
new file mode 100644
index 00000000..ec5a802c
--- /dev/null
+++ b/mx_rec/util/config_utils/hybrid_mgmt_utils.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
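+# Holds the pybind HybridMgmt instance (the ASC manager) and forwards
+# evict/save/load/send requests to it; callers must register the manager via
+# set_asc_manager() before any of these operations are used.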
+from typing import Optional
+
+from mx_rec.util.log import logger
+
+
+class HybridManagerConfig:
+    def __init__(self):
+        self._asc_manager = None
+        self._is_freeze = False
+
+    @property
+    def asc_manager(self):
+        return self._asc_manager
+
+    @property
+    def freeze(self):
+        return self._is_freeze
+
+    def set_asc_manager(self, manager) -> None:
+        from mxrec_pybind import HybridMgmt
+        if not isinstance(manager, HybridMgmt):
+            raise ValueError(f"Given manager must be the instance of {HybridMgmt}, which is {type(manager)} "
+                             f"type currently.")
+        self._asc_manager = manager
+        self._is_freeze = True
+
+    def del_asc_manager(self) -> None:
+        if self.asc_manager:
+            self._asc_manager.destroy()
+            self._asc_manager = None
+            self._is_freeze = False
+            logger.debug("ASC manager has been destroyed.")
+
+    def trigger_evict(self) -> bool:
+        if not self._asc_manager:
+            raise RuntimeError("ASC manager does not exist.")
+
+        if self.asc_manager.evict():
+            logger.debug("Feature evict is triggered by ops.")
+            return True
+        logger.warning("Feature evict did not succeed, skipping this time!")
+        return False
+
+    def get_host_data(self, table_name: str) -> object:
+        if self.asc_manager is None:
+            raise RuntimeError("ASC manager does not exist.")
+        logger.debug("start to get host data.")
+        return self.asc_manager.send(table_name)
+
+    def save_host_data(self, root_dir: Optional[str]) -> None:
+        if self.asc_manager is None:
+            raise RuntimeError("ASC manager does not exist.")
+
+        self.asc_manager.save(root_dir)
+        logger.debug("Data from host pipeline has been saved.")
+
+    def restore_host_data(self, root_dir: Optional[str]) -> None:
+        if self.asc_manager is None:
+            raise RuntimeError("ASC manager does not exist.")
+
+        if not self.asc_manager.load(root_dir):
+            raise TypeError("Asc load data does not match user setups, "
+                            "please reconsider whether you want to restore from this dir")
+        logger.debug("Data from host pipeline has been restored.")
diff --git a/mx_rec/util/config_utils/optimizer_utils.py b/mx_rec/util/config_utils/optimizer_utils.py
new file mode 100644
index 00000000..f1d6a4f9
--- /dev/null
+++ b/mx_rec/util/config_utils/optimizer_utils.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+
+
+class OptimizerConfig:
+    def __init__(self):
+        self._optimizer_instance = None
+
+    @property
+    def optimizer_instance(self):
+        return self._optimizer_instance
+
+    @optimizer_instance.setter
+    def optimizer_instance(self, optimizer):
+        self._optimizer_instance = optimizer
diff --git a/mx_rec/util/config_utils/train_param.py b/mx_rec/util/config_utils/train_param.py
new file mode 100644
index 00000000..3396d688
--- /dev/null
+++ b/mx_rec/util/config_utils/train_param.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+from typing import Optional
+
+from tensorflow.python.framework.ops import Operation
+
+from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID
+
+
+class TrainParamsConfig:
+    """
+    Configuration of training job parameters, such as dataset iterator type.
+    """
+
+    def __init__(self):
+        self._ascend_global_hashtable_collection = ASCEND_GLOBAL_HASHTABLE_COLLECTION
+        self._training_mode_channel_dict = dict()
+        self._bool_gauge_set = set()
+        self._is_graph_modify_hook_running = False
+        self._is_last_round = False
+        self._merged_multi_lookup = dict()
+        self._target_batch = dict()
+        self._iterator_type = ""
+        self._sparse_dir = ""
+        self._initializer_dict = dict()
+
+    @property
+    def iterator_type(self):
+        return self._iterator_type
+
+    @property
+    def is_last_round(self):
+        return self._is_last_round
+
+    @property
+    def is_graph_modify_hook_running(self):
+        return self._is_graph_modify_hook_running
+
+    @property
+    def sparse_dir(self):
+        return self._sparse_dir
+
+    @property
+    def ascend_global_hashtable_collection(self):
+        return self._ascend_global_hashtable_collection
+
+    @iterator_type.setter
+    def iterator_type(self, iterator_type):
+        self._iterator_type = iterator_type
+
+    @is_graph_modify_hook_running.setter
+    def is_graph_modify_hook_running(self, is_hook_running):
+        self._is_graph_modify_hook_running = is_hook_running
+
+    @sparse_dir.setter
+    def sparse_dir(self, sparse_dir):
+        self._sparse_dir = sparse_dir
+
+    @is_last_round.setter
+    def is_last_round(self, last_round):
+        self._is_last_round = last_round
+
+    @ascend_global_hashtable_collection.setter
+    def ascend_global_hashtable_collection(self, name):
+        self._ascend_global_hashtable_collection = name
+
+    @property
+    def bool_gauge_set(self):
+        return self._bool_gauge_set
+
+    def insert_training_mode_channel_id(self, is_training: bool) -> None:
+        if is_training not in self._training_mode_channel_dict:
+            # mx_rec has 2 channels for data input:
+            # train_model binds to channel TRAIN_CHANNEL_ID,
+            # eval_model binds to channel EVAL_CHANNEL_ID.
+            self._training_mode_channel_dict[is_training] = TRAIN_CHANNEL_ID if is_training else EVAL_CHANNEL_ID
+
+    def get_training_mode_channel_id(self, is_training: bool) -> Optional[int]:
+        return self._training_mode_channel_dict.get(is_training)
+
+    def insert_bool_gauge(self, name: Optional[str]) -> None:
+        self._bool_gauge_set.add(name)
+
+    def insert_merged_multi_lookup(self, is_training: bool, value: bool = True) -> None:
+        self._merged_multi_lookup[is_training] = value
+
+    def get_merged_multi_lookup(self, is_training: bool) -> Optional[bool]:
+        return self._merged_multi_lookup.get(is_training)
+
+    def set_target_batch(self, is_training: bool, batch: dict) -> None:
+        self._target_batch[is_training] = batch
+
+    def get_target_batch(self, is_training: bool) -> Optional[dict]:
+        return self._target_batch.get(is_training)
+
+    def get_initializer(self, is_training: bool) -> Optional[Operation]:
+        return self._initializer_dict.get(is_training)
+
+    def set_initializer(self, is_training: bool, initializer: Optional[Operation]) -> None:
+        self._initializer_dict[is_training] = initializer
diff --git a/mx_rec/util/cpu.py b/mx_rec/util/cpu.py
new file mode 100644
index 00000000..f4d299ed
--- /dev/null
+++ b/mx_rec/util/cpu.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
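+# NUMA-aware CPU binding: resolve the chip's card/device ids through DCMI,
+# read its PCIe BDF, look up the matching NUMA node in sysfs, and pin the
+# current process to that node's CPU list via psutil.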
+
+import ctypes
+import psutil
+
+from mx_rec.util.log import logger
+
+PCIE_INFO_ALL_RESERVE_LEN = 32
+g_dcmi = None
+
+
+class PcieInfo(ctypes.Structure):
+    _fields_ = [
+        ("venderid", ctypes.c_uint),
+        ("subvenderid", ctypes.c_uint),  # assumed sub-vendor id field of the DCMI PCIe info struct
+        ("deviceid", ctypes.c_uint),
+        ("subdeviceid", ctypes.c_uint),
+        ("domain", ctypes.c_int),
+        ("bdf_busid", ctypes.c_uint),
+        ("bdf_deviceid", ctypes.c_uint),
+        ("bdf_funcid", ctypes.c_uint),
+        ("reserve", ctypes.c_char * PCIE_INFO_ALL_RESERVE_LEN)
+    ]
+
+
+def get_card_and_device(logic_id):
+    """
+    Get the card id and device id of a chip from its logic id.
+    A card may host multiple chips, each with its own device_id,
+    but every chip has a unique logic id.
+    """
+    card = ctypes.c_int(0)
+    device = ctypes.c_int(0)
+    logic = ctypes.c_uint(logic_id)
+    ret = g_dcmi.dcmi_get_card_id_device_id_from_logicid(ctypes.pointer(card),
+                                                         ctypes.pointer(device),
+                                                         logic)
+    if ret != 0:
+        raise OSError(f"logic id {logic_id} does not exist")
+    return card.value, device.value
+
+
+def get_pcie_id(card_id, device_id):
+    """
+    Get the PCIe id in domain:bus:device.function format.
+    """
+    info = PcieInfo()
+    card = ctypes.c_int(card_id)
+    dev = ctypes.c_int(device_id)
+    ret = g_dcmi.dcmi_get_device_pcie_info_v2(card, dev, ctypes.pointer(info))
+    if ret != 0:
+        raise OSError(f"cannot get pcie info of device {card_id}:{device_id}")
+    pcie_id = f'{info.domain:04X}:{info.bdf_busid:02x}:'
+    pcie_id += f'{info.bdf_deviceid:02x}.{info.bdf_funcid}'
+    return pcie_id
+
+
+def get_numa_by_pcie(pcie_id):
+    """
+    Get the NUMA node by PCIe id.
+    """
+    with open(f'/sys/bus/pci/devices/{pcie_id}/numa_node') as f:
+        numa_node = f.read()
+    return int(numa_node)
+
+
+def get_cpu_list_by_numa(node):
+    with open(f'/sys/devices/system/node/node{node}/cpulist') as f:
+        cpulistinfo = f.read()
+    cpulist_first = cpulistinfo.split(",")[0]
+    [cpu_start, cpu_end] = cpulist_first.split("-")
+    # sysfs cpulist ranges are inclusive of the end cpu
+    return list(range(int(cpu_start), int(cpu_end) + 1))
+
+
+def bind_cpu_by_device_logic_id(logic_id):
+    global g_dcmi
+    if g_dcmi is None:
+        try:
+            g_dcmi = ctypes.CDLL("libdcmi.so")
+            if g_dcmi.dcmi_init() != 0:
+                logger.error("dcmi init failed")
+                return False
+        except OSError as e:
+            logger.error(e)
+            return False
+    try:
+        card_id, device_id = get_card_and_device(logic_id)
+        pcie_id = get_pcie_id(card_id, device_id)
+        numa = get_numa_by_pcie(pcie_id)
+        cpu_list = get_cpu_list_by_numa(numa)
+        psutil.Process().cpu_affinity(cpu_list)
+    except Exception as e:
+        logger.error(e)
+        return False
+    return True
diff --git a/mx_rec/util/framework_npu_env/__init__.py b/mx_rec/util/framework_npu_env/__init__.py
new file mode 100644
index 00000000..768c6907
--- /dev/null
+++ b/mx_rec/util/framework_npu_env/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+
+__all__ = ["set_ascend_env"]
+
+from mx_rec.util.framework_npu_env.tfa_env import set_ascend_env
diff --git a/mx_rec/util/framework_npu_env/tfa_env.py b/mx_rec/util/framework_npu_env/tfa_env.py
new file mode 100644
index 00000000..a00fd7ce
--- /dev/null
+++ b/mx_rec/util/framework_npu_env/tfa_env.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
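+# Exports the RANK_ID / DEVICE_ID / ASCEND_DEVICE_ID / DEVICE_INDEX /
+# RANK_SIZE environment variables expected by the Ascend TensorFlow stack,
+# derived from MPI and the rank-to-device mapping.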
+import os
+
+from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size, get_device_id
+from mx_rec.util.global_env_conf import global_env
+from mx_rec.util.log import logger
+
+
+def set_ascend_env():
+    """
+    Configure Ascend-related parameters and environment variables.
+    """
+    logger.debug("Ascend env set start.")
+    os.environ["RANK_ID"] = str(get_rank_id())
+
+    device_id = str(get_device_id())
+    os.environ["DEVICE_ID"] = device_id
+    os.environ["ASCEND_DEVICE_ID"] = device_id
+    os.environ["DEVICE_INDEX"] = device_id
+
+    if global_env.rank_table_file:
+        rank_size = get_rank_size()
+        os.environ["RANK_SIZE"] = str(rank_size)
+
+    os.environ["JOB_ID"] = "10086"
+    logger.debug("Ascend env has been set.")
diff --git a/mx_rec/util/global_env_conf.py b/mx_rec/util/global_env_conf.py
index 1852d21b..52b5af46 100644
--- a/mx_rec/util/global_env_conf.py
+++ b/mx_rec/util/global_env_conf.py
@@ -18,7 +18,7 @@
 import os
 import dataclasses
 from dataclasses import dataclass
 
-from mx_rec.constants.constants import EnvOption, RecPyLogLevel, Flag, EMPTY_STR, ApplyGradientsStrategy, \
+from mx_rec.constants.constants import EnvOption, RecPyLogLevel, Flag, EMPTY_STR, \
     DEFAULT_HD_CHANNEL_SIZE, DEFAULT_KP_THREAD_NUM, DEFAULT_FAST_UNIQUE_THREAD_NUM, RecCPPLogLevel, MAX_INT32, \
     MIN_HD_CHANNEL_SIZE, MAX_HD_CHANNEL_SIZE, MIN_KP_THREAD_NUM, MAX_KP_THREAD_NUM, \
     MIN_FAST_UNIQUE_THREAD_NUM, MAX_FAST_UNIQUE_THREAD_NUM, DEFAULT_HOT_EMB_UPDATE_STEP, MIN_HOT_EMB_UPDATE_STEP, \
@@ -34,7 +34,6 @@ class RecEnv:
     cm_chief_device: str
     cm_worker_size: str
     tf_device: str
-    apply_gradients_strategy: str
     acl_timeout: str
     hd_channel_size: str
     key_process_thread_num: str
@@ -46,6 +45,9 @@ class RecEnv:
     use_combine_faae: str
     stat_on: str
     record_key_count: str
+    rank_id_env: str
+    rank_size_env: str
+    local_rank_size_env: str
 
 
 def get_global_env_conf() -> RecEnv:
@@ -60,8 +62,6 @@ def get_global_env_conf() -> RecEnv:
         cm_chief_device=os.getenv(EnvOption.CM_CHIEF_DEVICE.value),
         cm_worker_size=os.getenv(EnvOption.CM_WORKER_SIZE.value),
         tf_device=os.getenv(EnvOption.TF_DEVICE.value, TFDevice.NONE.value),
-        apply_gradients_strategy=os.getenv(EnvOption.APPLY_GRADIENTS_STRATEGY.value,
-                                           ApplyGradientsStrategy.DIRECT_APPLY.value),
         acl_timeout=os.getenv(EnvOption.ACL_TIMEOUT.value, "-1"),
         hd_channel_size=os.getenv(EnvOption.HD_CHANNEL_SIZE.value, DEFAULT_HD_CHANNEL_SIZE),
         key_process_thread_num=os.getenv(EnvOption.KEY_PROCESS_THREAD_NUM.value, DEFAULT_KP_THREAD_NUM),
@@ -72,7 +72,10 @@ def get_global_env_conf() -> RecEnv:
         glog_stderrthreahold=os.getenv(EnvOption.GLOG_STDERRTHREAHOLD.value, RecCPPLogLevel.INFO.value),
         use_combine_faae=os.getenv(EnvOption.USE_COMBINE_FAAE.value, Flag.FALSE.value),
         stat_on=os.getenv(EnvOption.STAT_ON.value, Flag.FALSE.value),
-        record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value)
+        record_key_count=os.getenv(EnvOption.RECORD_KEY_COUNT.value, Flag.FALSE.value),
+        rank_id_env=os.getenv(EnvOption.OMPI_COMM_WORLD_RANK.value),
+        rank_size_env=os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value),
+        local_rank_size_env=os.getenv(EnvOption.OMPI_COMM_WORLD_LOCAL_SIZE.value),
     )
     return rec_env
@@ -82,7 +85,6 @@
     ("mxrec_log_level", OptionValidator, {"options": [i.value for i in list(RecPyLogLevel)]}),
     ("rank_table_file", DirectoryValidator, {}, ["check_exists_if_not_empty"]),
     ("tf_device", OptionValidator, {"options": [i.value for i in list(TFDevice)]}),
-    ("apply_gradients_strategy", OptionValidator, {"options": [i.value for i in list(ApplyGradientsStrategy)]}),
    ("acl_timeout",
Convert2intValidator, {"min_value": -1, "max_value": MAX_INT32}, ["check_value"]), ("hd_channel_size", Convert2intValidator, {"min_value": MIN_HD_CHANNEL_SIZE, "max_value": MAX_HD_CHANNEL_SIZE}, ["check_value"]), diff --git a/mx_rec/util/initialize.py b/mx_rec/util/initialize.py index b1021b05..e3cf0847 100644 --- a/mx_rec/util/initialize.py +++ b/mx_rec/util/initialize.py @@ -16,29 +16,25 @@ # ============================================================================== import atexit -import os -from collections import defaultdict import dataclasses import json -import psutil - -import mx_rec.constants.constants -from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION, HASHTABLE_COLLECTION_NAME_LENGTH, \ - TRAIN_CHANNEL_ID, EVAL_CHANNEL_ID, MIN_SIZE, MAX_CONFIG_SIZE, MAX_INT32, TFDevice, Flag, \ - GET_CONFIG_INSTANCE_ERR_MSG -from mx_rec.util.communication.hccl_mgmt import parse_hccl_json, set_hccl_info_without_json -from mx_rec.util.ops import import_host_pipeline_ops -from mx_rec.validator.validator import StringValidator, FileValidator, para_checker_decorator, ClassValidator, \ - IntValidator, ValueCompareValidator, OptionalStringValidator -from mx_rec.util.atomic import AtomicInteger + +from mx_rec.constants.constants import MAX_INT32, GET_CONFIG_INSTANCE_ERR_MSG +from mx_rec.util.config_utils.embedding_utils import SparseEmbedConfig +from mx_rec.util.config_utils.feature_spec_utils import FeatureSpecConfig +from mx_rec.util.config_utils.hybrid_mgmt_utils import HybridManagerConfig +from mx_rec.util.config_utils.optimizer_utils import OptimizerConfig +from mx_rec.util.config_utils.train_param import TrainParamsConfig +from mx_rec.util.framework_npu_env.tfa_env import set_ascend_env from mx_rec.util.global_env_conf import global_env from mx_rec.util.log import logger +from mx_rec.util.perf_factory.bind_cpu import bind_cpu +from mx_rec.validator.validator import para_checker_decorator, ClassValidator, \ + IntValidator, ValueCompareValidator class ConfigInitializer: _single_instance = None - customized_ops = None - host_pipeline_ops = import_host_pipeline_ops() @para_checker_decorator(check_option_list=[ ("use_mpi", ClassValidator, {"classes": (bool,)}), @@ -49,168 +45,46 @@ class ConfigInitializer: ["check_at_least_one_not_equal_to_target"]), ("if_load", ClassValidator, {"classes": (bool,)}), ("use_dynamic", ClassValidator, {"classes": (bool,)}), - ("use_hot", ClassValidator, {"classes": (bool,)}), ("use_dynamic_expansion", ClassValidator, {"classes": (bool,)}), ("bind_cpu", ClassValidator, {"classes": (bool,)}), ]) - def __init__(self, use_mpi=True, **kwargs): - self._use_mpi = use_mpi - self._ascend_global_hashtable_collection = ASCEND_GLOBAL_HASHTABLE_COLLECTION - self._comm = None - self._asc_manager = None - self._mpi = None - self._is_frozen = False - self._train_steps = None - self._eval_steps = None - self._save_steps = None - self._if_load = None - self._table_instance_dict = dict() - self._dangling_table = [] - self._removing_var_list = [] - self._name_to_var_dict = dict() - self._table_name_set = set() - self._table_name_to_feature_spec = dict() - self._feature_spec_dict = dict() - self._training_mode_channel_dict = dict() - self._rank_to_device_dict = dict() - self._initializer_dict = {} - self._bool_gauge_set = set() - self._optimizer_instance = None - self._is_graph_modify_hook_running = False + @bind_cpu + def __init__(self, **kwargs): self._modify_graph = False - self._is_terminated = False - self._is_last_round = False - self._run_times = 
AtomicInteger() - self._merged_multi_lookup = dict() - self._target_batch = dict() - self._iterator_type = "" - self._sparse_dir = "" - - if self._use_mpi: - logger.debug(f"Using mpi to launch task.") - from mpi4py import MPI - self._mpi = MPI - self._comm = MPI.COMM_WORLD - self._rank_id = self._comm.Get_rank() - self._rank_size = self._comm.Get_size() - self.check_mpi_params() - else: - raise ValueError("only mpi is supported for launching task.") - - if global_env.rank_table_file: - self._rank_to_device_dict, self._local_rank_size = parse_hccl_json() - else: - self._rank_to_device_dict, self._local_rank_size = set_hccl_info_without_json( - visible_devices=global_env.ascend_visible_devices, - rank_size=global_env.cm_worker_size, - chief_device=global_env.cm_chief_device) - - self.train_steps = kwargs.get("train_steps", -1) - self.eval_steps = kwargs.get("eval_steps", -1) - self.save_steps = kwargs.get("save_steps", -1) - - self.if_load = kwargs.get("if_load", False) - - self.use_static = not kwargs.get("use_dynamic", True) - self.use_hot = kwargs.get("use_hot", True) - self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False) - if kwargs.get("bind_cpu", True): - bind_cpu(self._rank_id, self._local_rank_size) - self.enable_table_merge = True if global_env.tf_device == TFDevice.NPU.value else False - # 两个通道的sparse look id,用于通讯的标识 - self.notify_hybrid_channel_sparse_id = [0, 0] - self.stat_on = (global_env.stat_on == Flag.TRUE.value) - - @property - def iterator_type(self): - return self._iterator_type - - @property - def local_rank_size(self): - return self._local_rank_size - - @property - def merged_multi_lookup(self): - return self._merged_multi_lookup - @property - def target_batch(self): - return self._target_batch + self._train_steps = kwargs.get("train_steps", -1) + self._eval_steps = kwargs.get("eval_steps", -1) + self._save_steps = kwargs.get("save_steps", -1) - @property - def is_last_round(self): - return self._is_last_round + self._if_load = kwargs.get("if_load", False) - @property - def run_times(self): - return self._run_times + self._use_static = not kwargs.get("use_dynamic", True) + self._use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False) - @property - def bool_gauge_set(self): - return self._bool_gauge_set + self._is_terminated = False - @property - def is_graph_modify_hook_running(self): - return self._is_graph_modify_hook_running + self._sparse_embed_config = SparseEmbedConfig() + self._feature_spec_config = FeatureSpecConfig() + self._hybrid_manager_config = HybridManagerConfig() + self._optimizer_config = OptimizerConfig() + self._train_params_config = TrainParamsConfig() @property def modify_graph(self): return self._modify_graph - @property - def sparse_dir(self): - return self._sparse_dir - - @property - def feature_spec_dict(self): - return self._feature_spec_dict - - @property - def table_name_set(self): - return self._table_name_set - - @property - def table_name_to_feature_spec(self): - return self._table_name_to_feature_spec - - @property - def table_instance_dict(self): - return self._table_instance_dict - - @property - def optimizer_instance(self): - return self._optimizer_instance - - @property - def is_frozen(self): - return self._is_frozen - - @property - def name_to_var_dict(self): - return self._name_to_var_dict - - @property - def use_mpi(self): - return self._use_mpi - - @property - def rank_size(self): - return self._rank_size - - @property - def rank_id(self): - return self._rank_id - - @property - def device_id(self): - 
if self._rank_id not in self._rank_to_device_dict: - raise KeyError(f"rank id not in rank_to_device_dict. {self._rank_id} {self._rank_to_device_dict}") - return self._rank_to_device_dict[self._rank_id] + @modify_graph.setter + def modify_graph(self, modify_graph): + self._modify_graph = modify_graph @property def train_steps(self): return self._train_steps + @train_steps.setter + def train_steps(self, step: int): + self._train_steps = step + @property def eval_steps(self): return self._eval_steps @@ -224,158 +98,52 @@ class ConfigInitializer: return self._if_load @property - def ascend_global_hashtable_collection(self): - return self._ascend_global_hashtable_collection + def use_static(self): + return self._use_static @property - def dangling_table(self): - return self._dangling_table + def use_dynamic_expansion(self): + return self._use_dynamic_expansion @property - def removing_var_list(self): - return self._removing_var_list + def sparse_embed_config(self): + return self._sparse_embed_config - @staticmethod - def get_instance(): - if ConfigInitializer._single_instance is None: - raise EnvironmentError(GET_CONFIG_INSTANCE_ERR_MSG) + @sparse_embed_config.setter + def sparse_embed_config(self, sparse_emb_config_instance): + self._sparse_embed_config = sparse_emb_config_instance - return ConfigInitializer._single_instance + @property + def feature_spec_config(self): + return self._feature_spec_config - @staticmethod - def set_instance(use_mpi, **kwargs): - if ConfigInitializer._single_instance is not None: - raise EnvironmentError("ConfigInitializer has been initialized once, twice initialization was forbidden.") + @feature_spec_config.setter + def feature_spec_config(self, feature_spec_config_instance): + self._feature_spec_config = feature_spec_config_instance - ConfigInitializer._single_instance = ConfigInitializer(use_mpi, **kwargs) + @property + def hybrid_manager_config(self): + return self._hybrid_manager_config - def check_mpi_params(self): - if self._rank_size < 1: - raise ValueError("The length of the mpi rank_size is less than 1.") - if self._rank_id < 0: - raise ValueError("The length of the mpi rank_id is less than 0.") + @hybrid_manager_config.setter + def hybrid_manager_config(self, hybrid_manager_config_instance): + self._hybrid_manager_config = hybrid_manager_config_instance - def terminate(self): - logger.info("python process run into terminate") - if self._is_terminated: - logger.warning("The initializer has already been released once, please do not release it again.") - return + @property + def optimizer_config(self): + return self._optimizer_config - if self._asc_manager is not None: - self.del_asc_manager() - logger.info("python process run terminate success") + @optimizer_config.setter + def optimizer_config(self, optimizer_config_instance): + self._optimizer_config = optimizer_config_instance - self._is_terminated = True - ConfigInitializer._single_instance = None - - def clear_same_table_feature_spec(self, table_name, is_training): - if self.table_name_to_feature_spec.get(table_name) is None or \ - self.table_name_to_feature_spec.get(table_name).get(is_training) is None: - raise KeyError("The table name `%s` does not exist in table_name_to_feature_spec, " - "please check whether the insert_feature_spec(...) 
is invoked.", table_name) - self.table_name_to_feature_spec.get(table_name)[is_training] = [] - logger.debug("The feature spec of the table name `%s` has been cleared.", table_name) - - def insert_feature_spec(self, feature, is_training): - self._feature_spec_dict[feature.name] = feature - if feature.table_name not in self._table_name_to_feature_spec: - self._table_name_to_feature_spec[feature.table_name] = {True: [], False: []} - self._table_name_to_feature_spec[feature.table_name][is_training].append(feature) - - def get_feature_spec(self, key): - return self._feature_spec_dict.get(key) - - def insert_training_mode_channel_id(self, is_training): - if is_training not in self._training_mode_channel_dict: - # mx_rec has 2 channel for data input. - # train_model bind to channel TRAIN_CHANNEL_ID - # eval_model bind to channel EVAL_CHANNEL_ID - self._training_mode_channel_dict[is_training] = TRAIN_CHANNEL_ID if is_training else EVAL_CHANNEL_ID - - def get_training_mode_channel_id(self, is_training): - return self._training_mode_channel_dict.get(is_training) - - def insert_dangling_table(self, name): - if name not in self._dangling_table: - self._dangling_table.append(name) - - def insert_removing_var_list(self, name): - if name not in self._removing_var_list: - self._removing_var_list.append(name) - - def insert_table_instance(self, name, key, instance): - if key in self._table_instance_dict: - raise KeyError(f"Given key {key} has been used.") - - if name in self._table_name_set: - raise ValueError(f"Duplicated hashtable name '{name}' was used.") - - logger.debug("Record one hash table, with name: %s, key: %s.", name, key) - self._table_name_set.add(name) - if name not in self._table_name_to_feature_spec: - self._table_name_to_feature_spec[name] = {True: [], False: []} - self._name_to_var_dict[name] = key - self._table_instance_dict[key] = instance - if self.stat_on: - logger.info("[StatInfo] current_table_num %s", len(self._table_instance_dict)) - - def insert_bool_gauge(self, name): - if not isinstance(name, str): - raise TypeError(f"bool gauge name '{name}' should be str.") - - self._bool_gauge_set.add(name) - - def get_table_instance(self, key): - if key not in self._table_instance_dict: - raise KeyError(f"Given key does not exist.") - - return self._table_instance_dict.get(key) - - def get_table_instance_by_name(self, table_name): - if table_name not in self._name_to_var_dict: - raise KeyError(f"Given table name does not exist.") - - key = self._name_to_var_dict.get(table_name) - return self._table_instance_dict.get(key) - - def insert_optimizer(self, optimizer): - self._optimizer_instance = optimizer - - def freeze(self): - self._is_frozen = True - - def unfreeze(self): - self._is_frozen = False - - def set_asc_manager(self, manager): - from mxrec_pybind import HybridMgmt - if not isinstance(manager, HybridMgmt): - raise ValueError(f"Given manager must be the instance of {HybridMgmt}, which is {type(manager)} " - f"type currently.") - self._asc_manager = manager - self.freeze() - - def get_asc_manager(self): - return self._asc_manager - - def del_asc_manager(self): - self.delete_initializers() - self._asc_manager.destroy() - self._asc_manager = None - self.unfreeze() - logger.debug("ASC manager has been destroyed.") - - @iterator_type.setter - def iterator_type(self, iterator_type): - if not isinstance(iterator_type, str): - raise TypeError(f"iterator_type `{iterator_type}` should be str.") - - self._iterator_type = iterator_type + @property + def train_params_config(self): + return 
self._train_params_config - @train_steps.setter - def train_steps(self, step: int): - check_step(step) - self._train_steps = step + @train_params_config.setter + def train_params_config(self, train_params_config_instance): + self._train_params_config = train_params_config_instance @eval_steps.setter def eval_steps(self, steps): @@ -389,354 +157,44 @@ class ConfigInitializer: @if_load.setter def if_load(self, flag): - if not isinstance(flag, bool): - raise TypeError(f"Flag if load should be a boolean.") - self._if_load = flag - @is_graph_modify_hook_running.setter - def is_graph_modify_hook_running(self, is_hook_running): - if not isinstance(is_hook_running, bool): - raise TypeError(f"is_hook_running should be a boolean.") - - self._is_graph_modify_hook_running = is_hook_running - - @modify_graph.setter - def modify_graph(self, is_modify_graph): - if not isinstance(is_modify_graph, bool): - raise TypeError(f"is_modify_graph should be a boolean.") - - self._modify_graph = is_modify_graph - - @sparse_dir.setter - def sparse_dir(self, sparse_dir): - if not isinstance(sparse_dir, str): - raise TypeError(f"sparse_dir should be str.") - - self._sparse_dir = sparse_dir - - @is_last_round.setter - def is_last_round(self, last_round): - if not isinstance(last_round, bool): - raise TypeError(f"last_round should be a boolean.") - - self._is_last_round = last_round - - @ascend_global_hashtable_collection.setter - def ascend_global_hashtable_collection(self, name): - self._ascend_global_hashtable_collection = name - - def get_initializer(self, is_training): - return self._initializer_dict.get(is_training) - - def set_initializer(self, is_training, initializer): - if not isinstance(is_training, bool): - raise ValueError(f"Given key must be a boolean, but got {is_training}.") - - self._initializer_dict[is_training] = initializer - - def insert_merged_multi_lookup(self, is_training, value=True): - if not isinstance(is_training, bool): - raise TypeError(f"Given key must be a boolean, but got {is_training} for `merged_multi_lookup`.") - - self._merged_multi_lookup[is_training] = value - - def get_merged_multi_lookup(self, is_training): - return self._merged_multi_lookup.get(is_training) - - def set_target_batch(self, is_training, batch): - if not isinstance(is_training, bool): - raise TypeError(f"Given key must be a boolean, but got {is_training} for `target_batch`.") - - self._target_batch[is_training] = batch - - def get_target_batch(self, is_training): - return self._target_batch.get(is_training) - - def delete_initializers(self): - self._initializer_dict = {} - + @staticmethod + def get_instance(): + if ConfigInitializer._single_instance is None: + raise EnvironmentError(GET_CONFIG_INSTANCE_ERR_MSG) -@para_checker_decorator(check_option_list=[ - ("name", ClassValidator, {"classes": (str, type(None))}), - ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length", "check_whitelist"]), -]) -def set_ascend_global_hashtable_collection(name=ASCEND_GLOBAL_HASHTABLE_COLLECTION): - ConfigInitializer.get_instance().ascend_global_hashtable_collection = name + return ConfigInitializer._single_instance + @staticmethod + def set_instance(**kwargs): + if ConfigInitializer._single_instance is not None: + raise EnvironmentError("ConfigInitializer has been initialized once, twice initialization was forbidden.") -def get_ascend_global_hashtable_collection(): - return ConfigInitializer.get_instance().ascend_global_hashtable_collection + ConfigInitializer._single_instance = 
ConfigInitializer(**kwargs) + def terminate(self): + logger.info("python process run into terminate") + if self._is_terminated: + logger.warning("The initializer has already been released once, please do not release it again.") + return -def check_step(param, min_value=-1): - if not isinstance(param, int): - raise TypeError("Given param must be an integer.") + if self._hybrid_manager_config.asc_manager is not None: + self._hybrid_manager_config.del_asc_manager() + logger.info("python process run terminate success") - if param < min_value: - raise ValueError(f"Valid value range is larger than or equals to {min_value}.") + self._is_terminated = True -def init(use_mpi, **kwargs): +def init(**kwargs): logger.info("The environment variables set for mxRec is: %s", json.dumps(dataclasses.asdict(global_env), ensure_ascii=False)) - ConfigInitializer.set_instance(use_mpi, **kwargs) + from mpi4py import MPI set_ascend_env() + ConfigInitializer.set_instance(**kwargs) atexit.register(terminate_config_initializer) -def get_is_graph_modify_hook_running(): - return ConfigInitializer.get_instance().is_graph_modify_hook_running - - -def set_is_graph_modify_hook_running(is_running): - ConfigInitializer.get_instance().is_graph_modify_hook_running = is_running - - -def get_bool_gauge_set(): - return ConfigInitializer.get_instance().bool_gauge_set - - -def insert_bool_gauge(name): - ConfigInitializer.get_instance().insert_bool_gauge(name) - - -def get_modify_graph(): - return ConfigInitializer.get_instance().modify_graph - - -def set_modify_graph(is_modify_graph): - ConfigInitializer.get_instance().modify_graph = is_modify_graph - - -def set_sparse_dir(sparse_dir): - ConfigInitializer.get_instance().sparse_dir = sparse_dir - - -def get_sparse_dir(): - return ConfigInitializer.get_instance().sparse_dir - - -def get_rank_size(): - return ConfigInitializer.get_instance().rank_size - - -def get_rank_id(): - return ConfigInitializer.get_instance().rank_id - - -def get_device_id(): - return ConfigInitializer.get_instance().device_id - - -def set_asc_manager(manager): - ConfigInitializer.get_instance().set_asc_manager(manager) - - -def get_asc_manager(): - return ConfigInitializer.get_instance().get_asc_manager() - - -def trigger_evict(): - if not is_asc_manager_initialized(): - raise RuntimeError("ASC manager does not exist.") - - if ConfigInitializer.get_instance().get_asc_manager().evict(): - logger.debug("Feature evict is triggered by ops.") - return True - logger.warning("Feature evict not success, skip this time!") - return False - - -def clear_channel(is_train_channel=False): - if not isinstance(is_train_channel, bool): - raise ValueError("Arg is_train_channel should be a boolean.") - channel_id = get_training_mode_channel_id(is_train_channel) - logger.info("clear channel: %s", channel_id) - - return ConfigInitializer.get_instance().host_pipeline_ops.clear_channel(channel_id) - - -def is_asc_manager_initialized(): - return ConfigInitializer.get_instance().get_asc_manager() is not None - - -def get_host_data(table_name): - if not is_asc_manager_initialized(): - raise RuntimeError("ASC manager does not exist.") - logger.debug("start to get host data.") - return ConfigInitializer.get_instance().get_asc_manager().send(table_name) - - -def send_host_data(key_offset_map): - if not is_asc_manager_initialized(): - raise RuntimeError("ASC manager does not exist.") - ConfigInitializer.get_instance().get_asc_manager().receive(key_offset_map) - logger.debug("Data has been send to the host pipeline.") - - -def 
save_host_data(root_dir): - if not is_asc_manager_initialized(): - raise RuntimeError("ASC manager does not exist.") - - ConfigInitializer.get_instance().get_asc_manager().save(root_dir) - logger.debug("Data from host pipeline has been saved.") - - -def restore_host_data(root_dir): - if not is_asc_manager_initialized(): - raise RuntimeError("ASC manager does not exist.") - - if not ConfigInitializer.get_instance().get_asc_manager().load(root_dir): - raise TypeError("Asc load data does not match usr setups, \ - please re-consider if you want to restore from this dir") - logger.debug("Data from host pipeline has been restored.") - - -def destroy_asc_manager(): - initializer = ConfigInitializer.get_instance() - if initializer.get_asc_manager() is not None: - logger.debug("start destroy asc manager...") - initializer.del_asc_manager() - else: - logger.warning("ASC manager does not exist, please check your code.") - - -def is_asc_frozen(): - return ConfigInitializer.get_instance().is_frozen - - -def export_table_name_set(): - return ConfigInitializer.get_instance().table_name_set - - -def get_host_pipeline_ops(): - return ConfigInitializer.host_pipeline_ops - - -def get_customized_ops(): - return ConfigInitializer.customized_ops - - -def get_train_steps(): - return ConfigInitializer.get_instance().train_steps - - -def get_eval_steps(): - return ConfigInitializer.get_instance().eval_steps - - -def get_save_steps(): - return ConfigInitializer.get_instance().save_steps - - -def set_train_steps(steps: int): - ConfigInitializer.get_instance().train_steps = steps - - -def set_eval_steps(steps: int): - ConfigInitializer.get_instance().eval_steps = steps - - -def set_save_steps(steps: int): - ConfigInitializer.get_instance().save_steps = steps - - -def get_table_instance(key): - return ConfigInitializer.get_instance().get_table_instance(key) - - -def get_table_instance_by_name(table_name): - return ConfigInitializer.get_instance().get_table_instance_by_name(table_name) - - -def insert_dangling_table(table_name): - ConfigInitializer.get_instance().insert_dangling_table(table_name) - - -def insert_removing_var_list(var_name): - ConfigInitializer.get_instance().insert_removing_var_list(var_name) - - -def insert_table_instance(name, key, instance): - ConfigInitializer.get_instance().insert_table_instance(name, key, instance) - - -def export_table_instances(): - return ConfigInitializer.get_instance().table_instance_dict - - -def export_table_num(): - return len(ConfigInitializer.get_instance().table_instance_dict) - - -def export_dangling_table(): - return ConfigInitializer.get_instance().dangling_table - - -def export_removing_var_list(): - return ConfigInitializer.get_instance().removing_var_list - - -def insert_optimizer(optimizer): - ConfigInitializer.get_instance().insert_optimizer(optimizer) - - -def export_optimizer(): - return ConfigInitializer.get_instance().optimizer_instance - - -def insert_feature_spec(feature, is_training): - ConfigInitializer.get_instance().insert_feature_spec(feature, is_training) - - -def get_feature_spec(key): - return ConfigInitializer.get_instance().get_feature_spec(key) - - -def insert_training_mode_channel_id(is_training): - ConfigInitializer.get_instance().insert_training_mode_channel_id(is_training) - - -def get_training_mode_channel_id(is_training): - return ConfigInitializer.get_instance().get_training_mode_channel_id(is_training) - - -def export_feature_spec(): - return ConfigInitializer.get_instance().feature_spec_dict - - 
-@para_checker_decorator(check_option_list=[ - ("if_load", ClassValidator, {"classes": (bool,)}) -]) -def set_if_load(if_load): - ConfigInitializer.get_instance().if_load = if_load - - -def get_if_load(): - return ConfigInitializer.get_instance().if_load - - -def get_use_static(): - return ConfigInitializer.get_instance().use_static - - -def get_stat_on(): - return ConfigInitializer.get_instance().stat_on - - -def get_use_hot(): - return ConfigInitializer.get_instance().use_hot - - -def get_enable_table_merge(): - return ConfigInitializer.get_instance().enable_table_merge - - -def get_use_dynamic_expansion(): - return ConfigInitializer.get_instance().use_dynamic_expansion - - def terminate_config_initializer(): try: ConfigInitializer.get_instance().terminate() @@ -744,263 +202,3 @@ def terminate_config_initializer(): if GET_CONFIG_INSTANCE_ERR_MSG not in str(err): raise err logger.warning(GET_CONFIG_INSTANCE_ERR_MSG) - - -def get_name_to_var_dict(): - return ConfigInitializer.get_instance().name_to_var_dict - - -@para_checker_decorator(check_option_list=[ - ("is_training", ClassValidator, {"classes": (bool,)}) -]) -def get_initializer(is_training): - return ConfigInitializer.get_instance().get_initializer(is_training) - - -def set_initializer(is_training, initializer): - ConfigInitializer.get_instance().set_initializer(is_training, initializer) - - -@para_checker_decorator(check_option_list=[ - ("name", ClassValidator, {"classes": (str, list)}), - ("name", OptionalStringValidator, {"min_len": 1, "max_len": 255}, ["check_string_length"]) -]) -def set_ascend_table_name_must_contain(name="merged"): - """ - 设置表名中必须包含的关键字 - Args: - name: 表名中必须包含的关键字 - - Returns: None - - """ - mx_rec.constants.constants.ASCEND_TABLE_NAME_MUST_CONTAIN = name - - -def insert_merged_multi_lookup(is_training: bool, value: bool = True): - """ - 记录自动改图模式下是否调用了合并lookup的函数. - Args: - is_training: 当前是否为训练模式,训练模式为True,否则为False - value: 是否调用了合并lookup的函数, 调用了为True,否则为False - Returns: None - """ - ConfigInitializer.get_instance().insert_merged_multi_lookup(is_training, value) - - -def get_merged_multi_lookup(is_training: bool) -> bool: - """ - 返回自动改图模式下是否调用了合并lookup函数的记录. - Args: - is_training: 当前是否为训练模式,训练模式为True,否则为False - Returns: 调用记录,调用了为True,否则为False - """ - return ConfigInitializer.get_instance().get_merged_multi_lookup(is_training) - - -def set_target_batch(is_training: bool, batch: dict): - """ - 记录自动改图模式下生成新数据集中的batch. - Args: - is_training: 当前是否为训练模式,训练模式为True,否则为False - batch: 数据集中的batch - Returns: None - """ - ConfigInitializer.get_instance().set_target_batch(is_training, batch) - - -@para_checker_decorator(check_option_list=[ - ("is_training", ClassValidator, {"classes": (bool, )}) -]) -def get_target_batch(is_training: bool) -> dict: - """ - 返回自动改图模式下生成新数据集中batch的记录. - Args: - is_training: 当前是否为训练模式,训练模式为True,否则为False - Returns: 新数据集中的batch - """ - return ConfigInitializer.get_instance().get_target_batch(is_training) - - -def get_iterator_type() -> str: - """ - 返回数据集的迭代器类型. - Returns: 数据集的迭代器类型 - """ - return ConfigInitializer.get_instance().iterator_type - - -def get_local_rank_size() -> int: - """ - 获取当前worker参与任务的进程数 - Returns: - """ - return ConfigInitializer.get_instance().local_rank_size - - -def set_iterator_type(iterator_type: str): - """ - 记录数据集的迭代器类型. 
- Args: - iterator_type: 数据集的迭代器类型 - Returns: None - """ - ConfigInitializer.get_instance().iterator_type = iterator_type - - -def get_table_name_to_feature_spec(table_name: str, is_training: bool): - """ - 获取同一张表的所有FeatureSpec - Args: - table_name: 表名 - is_training: 是否为训练模式 - Returns: FeatureSpec列表 - """ - - same_table_feature_spec_dict = ConfigInitializer.get_instance().table_name_to_feature_spec.get(table_name) - return same_table_feature_spec_dict.get(is_training) - - -def clear_same_table_feature_spec(table_name: str, is_training: bool): - """ - 将表对应的feature specs列表清空 - Args: - table_name: 表名 - is_training: 是否为训练模式 - Returns: None - """ - - ConfigInitializer.get_instance().clear_same_table_feature_spec(table_name, is_training) - - -def set_ascend_env(): - """ - 配置昇腾相关的参数和环境变量,生成hccl配置 - """ - rank = get_rank_id() - rank_size = get_rank_size() - - os.environ["MOX_USE_NPU"] = "1" - os.environ["FUSION_TENSOR_SIZE"] = "2000000000" - os.environ["MOX_USE_TF_ESTIMATOR"] = "0" - os.environ["MOX_USE_TDT"] = "1" - os.environ["HEARTBEAT"] = "1" - os.environ["CONITNUE_TRAIN"] = "true" - - os.environ["RANK_ID"] = str(rank) - - device_id = str(get_device_id()) - os.environ["DEVICE_ID"] = device_id - os.environ["ASCEND_DEVICE_ID"] = device_id - os.environ["DEVICE_INDEX"] = device_id - - if global_env.rank_table_file: - os.environ["RANK_SIZE"] = str(rank_size) - os.environ["HCCL_CONNECT_TIMEOUT"] = "1200" - - os.environ["JOB_ID"] = "10086" - os.environ["SOC_VERSION"] = "Ascend910" - os.environ["GE_AICPU_FLAG"] = "1" - os.environ["NEW_GE_FE_ID"] = "1" - os.environ["EXPERIMENTAL_DYNAMIC_PARTITION"] = "1" - os.environ["ENABLE_FORCE_V2_CONTROL"] = "1" - - logger.debug(f"Ascend env has been set.") - - -def get_available_cpu_num_and_range(): - """ - 获取当前环境可用的cpu数量和numa范围 - Returns: - - """ - cpu_available = os.sched_getaffinity(os.getpid()) # 获取可被绑定的核心 - - is_ok = True - cpu_pkg_id_file = "/sys/devices/system/cpu/cpu{}/topology/physical_package_id" - pkg_id2cpu_list = defaultdict(list) - for cpu in cpu_available: - f_path = cpu_pkg_id_file.format(cpu) - if not os.path.exists(f_path): - logger.warning("failed to get numa node of cpu: %s", cpu) - is_ok = False - break - - with open(f_path, "r", encoding="utf-8") as f_in: - # check whether file is valid - file_validator = FileValidator("cpu_topology_file", f_path) - # 1.check whether f_path is soft link - file_validator.check_not_soft_link() - # 2.check file size - file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE) - file_validator.check() - pkg_id = f_in.readline().strip() - pkg_id2cpu_list[pkg_id].append(cpu) - - def parse_range(cpu_list, cpu_range): - sorted_cpu_list = sorted(cpu_list) - pre_cpu = sorted_cpu_list[0] - cpu_range.append([pre_cpu]) - - for sorted_cpu in sorted_cpu_list[1:]: - if sorted_cpu - pre_cpu != 1: - cpu_range[-1].append(pre_cpu) - cpu_range.append([sorted_cpu]) - pre_cpu = sorted_cpu - - if len(cpu_range[-1]) == 1: - cpu_range[-1].append(pre_cpu) - - valid_cpu_range_list = [] - if is_ok: - logger.info("available numa node num: %s", len(pkg_id2cpu_list)) - for _, part_cpu_list in pkg_id2cpu_list.items(): - parse_range(part_cpu_list, valid_cpu_range_list) - else: - parse_range(list(cpu_available), valid_cpu_range_list) - return len(cpu_available), valid_cpu_range_list - - -def bind_cpu(rank_id: int, local_rank_size: int): - """ - 以均衡的方式为每个进程绑定CPU - :param rank_id:当前进程的rank_id - :param local_rank_size: 当前worker进程数 - :return: - """ - import math - - total_cpu, cpu_range_list = get_available_cpu_num_and_range() - - if 
local_rank_size <= 0:
-        logger.error(f"local rank size 's value less than or equal 0.")
-        return
-
-    avg_count = math.ceil(total_cpu / local_rank_size)
-    while True:
-        if avg_count == 0:
-            logger.warning(f"not enough cpu to bind. cpu num: %s, range: %s", total_cpu, cpu_range_list)
-            return
-
-        max_split = 0
-        for cpu_range in cpu_range_list:
-            max_split += (cpu_range[1] - cpu_range[0] + 1) // avg_count
-        if max_split >= local_rank_size:
-            break
-        avg_count -= 1
-
-    candidate_list = []
-    for cpu_range in cpu_range_list:
-        start = cpu_range[0]
-        splits = (cpu_range[1] - cpu_range[0] + 1) // avg_count
-        candidate_range = [list(range(start + i * avg_count, start + ((i + 1) * avg_count))) for i in range(splits)]
-        candidate_list.extend(candidate_range)
-
-    cpu_list = candidate_list[rank_id]
-
-    process = psutil.Process()
-    try:
-        process.cpu_affinity(cpu_list)
-    except IndexError:
-        logger.error("failed to bind cpu for rank %s: %s", rank_id, cpu_list)
-    logger.info("bind cpu for rank %s: %s", rank_id, cpu_list)
diff --git a/mx_rec/util/perf_factory/__init__.py b/mx_rec/util/perf_factory/__init__.py
new file mode 100644
index 00000000..d884330e
--- /dev/null
+++ b/mx_rec/util/perf_factory/__init__.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+
+__all__ = ["bind_cpu"]
\ No newline at end of file
diff --git a/mx_rec/util/perf_factory/bind_cpu.py b/mx_rec/util/perf_factory/bind_cpu.py
new file mode 100644
index 00000000..a9be6846
--- /dev/null
+++ b/mx_rec/util/perf_factory/bind_cpu.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+import math
+import os
+from collections import defaultdict
+import psutil
+
+from mx_rec.constants.constants import MIN_SIZE, MAX_CONFIG_SIZE
+from mx_rec.util.communication.hccl_ops import get_local_rank_size, get_rank_id
+from mx_rec.validator.validator import FileValidator
+from mx_rec.util.log import logger
+
+
+def get_available_cpu_num_and_range():
+    """
+    Get the number of CPUs usable by the current process and the NUMA ranges they span.
+    Returns: cpu count, list of [start, end] cpu ranges
+
+    """
+    cpu_available = os.sched_getaffinity(os.getpid())  # cores this process may be bound to
+
+    is_ok = True
+    cpu_pkg_id_file = "/sys/devices/system/cpu/cpu{}/topology/physical_package_id"
+    pkg_id2cpu_list = defaultdict(list)
+    for cpu in cpu_available:
+        f_path = cpu_pkg_id_file.format(cpu)
+        if not os.path.exists(f_path):
+            logger.warning("failed to get numa node of cpu: %s", cpu)
+            is_ok = False
+            break
+
+        with open(f_path, "r", encoding="utf-8") as f_in:
+            # check whether file is valid
+            file_validator = FileValidator("cpu_topology_file", f_path)
+            # 1.check whether f_path is soft link
+            file_validator.check_not_soft_link()
+            # 2.check file size
+            file_validator.check_file_size(MAX_CONFIG_SIZE, MIN_SIZE)
+            file_validator.check()
+            pkg_id = f_in.readline().strip()
+            pkg_id2cpu_list[pkg_id].append(cpu)
+
+    def parse_range(cpu_list, cpu_range):
+        sorted_cpu_list = sorted(cpu_list)
+        pre_cpu = sorted_cpu_list[0]
+        cpu_range.append([pre_cpu])
+
+        for sorted_cpu in sorted_cpu_list[1:]:
+            if sorted_cpu - pre_cpu != 1:
+                cpu_range[-1].append(pre_cpu)
+                cpu_range.append([sorted_cpu])
+            pre_cpu = sorted_cpu
+
+        if len(cpu_range[-1]) == 1:
+            cpu_range[-1].append(pre_cpu)
+
+    valid_cpu_range_list = []
+    if is_ok:
+        logger.info("available numa node num: %s", len(pkg_id2cpu_list))
+        for _, part_cpu_list in pkg_id2cpu_list.items():
+            parse_range(part_cpu_list, valid_cpu_range_list)
+    else:
+        parse_range(list(cpu_available), valid_cpu_range_list)
+    return len(cpu_available), valid_cpu_range_list
+
+
+def bind_cpu_task():
+    """
+    Bind each local process to a balanced share of the available CPUs.
+    """
+    total_cpu, cpu_range_list = get_available_cpu_num_and_range()
+    local_rank_size = get_local_rank_size()
+    if local_rank_size <= 0:
+        logger.error("local rank size is less than or equal to 0.")
+        return
+
+    avg_count = math.ceil(total_cpu / local_rank_size)
+    while True:
+        if avg_count == 0:
+            logger.warning("not enough cpu to bind. cpu num: %s, range: %s", total_cpu, cpu_range_list)
+            return
+
+        max_split = 0
+        for cpu_range in cpu_range_list:
+            max_split += (cpu_range[1] - cpu_range[0] + 1) // avg_count
+        if max_split >= local_rank_size:
+            break
+        avg_count -= 1
+
+    candidate_list = []
+    for cpu_range in cpu_range_list:
+        start = cpu_range[0]
+        splits = (cpu_range[1] - cpu_range[0] + 1) // avg_count
+        candidate_range = [list(range(start + i * avg_count, start + ((i + 1) * avg_count))) for i in range(splits)]
+        candidate_list.extend(candidate_range)
+
+    rank_id = get_rank_id()
+    cpu_list = candidate_list[rank_id]
+
+    process = psutil.Process()
+    try:
+        process.cpu_affinity(cpu_list)
+    except IndexError:
+        logger.error("failed to bind cpu for rank %s: %s", rank_id, cpu_list)
+        return
+    logger.info("bind cpu for rank %s: %s", rank_id, cpu_list)
+
+
+def bind_cpu(func):
+    def wrapper(*args, **kwargs):
+        result = func(*args, **kwargs)
+        if kwargs.get("bind_cpu", True):
+            bind_cpu_task()
+        return result
+
+    return wrapper
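+
+
+# Usage sketch (illustrative only): the decorator targets entry points that
+# accept a `bind_cpu` keyword, such as the init(**kwargs) added in this patch.
+# The decorated call below is an assumption for illustration, not part of the
+# module:
+#
+#     from mx_rec.util.perf_factory.bind_cpu import bind_cpu
+#
+#     @bind_cpu
+#     def init(**kwargs):
+#         ...  # regular initialization work
+#
+#     init(bind_cpu=True)  # runs init first, then bind_cpu_task()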
diff --git a/mx_rec/util/tf_version_adapter.py b/mx_rec/util/tf_version_adapter.py
index a0c60b72..49e210e8 100644
--- a/mx_rec/util/tf_version_adapter.py
+++ b/mx_rec/util/tf_version_adapter.py
@@ -31,3 +31,8 @@ if tf.__version__.startswith("1"):
     from npu_bridge.estimator.npu.npu_hook import NPUCheckpointSaverHook
 else:
     from npu_device.compat.v1.estimator.npu.npu_hook import NPUCheckpointSaverHook
+
+if tf.__version__.startswith("1"):
+    from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer
+else:
+    from npu_device.train.optimizer.npu_loss_scale_optimizer import NpuLossScaleOptimizer as NPULossScaleOptimizer
diff --git a/mx_rec/util/variable.py b/mx_rec/util/variable.py
index ff8f6989..2c9f49a9 100644
--- a/mx_rec/util/variable.py
+++ b/mx_rec/util/variable.py
@@ -16,21 +16,21 @@
 # ==============================================================================
 import tensorflow as tf
 
-from tensorflow.python.framework import ops
-from mx_rec.util.initialize import get_table_instance, get_ascend_global_hashtable_collection
+from mx_rec.util.initialize import ConfigInitializer
 
 
 def get_dense_and_sparse_variable():
     dense_variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)
-    sparse_variables = tf.compat.v1.get_collection(get_ascend_global_hashtable_collection())
+    sparse_variables = tf.compat.v1.get_collection(
+        ConfigInitializer.get_instance().train_params_config.ascend_global_hashtable_collection)
     return dense_variables, sparse_variables
 
 
 def check_and_get_config_via_var(variable, optimizer_type: str):
-    table_instance = get_table_instance(variable)
+    table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(variable)
 
-    if not table_instance.skip_emb_transfer and not table_instance.optimizer:
+    if not table_instance.is_hbm and not table_instance.optimizer:
         raise EnvironmentError(f"When ASC with DDR, you must pass the '{optimizer_type}' optimizer instances to the"
                                f" init method of SparseEmbedding.")
diff --git a/mx_rec/validator/emb_validator.py b/mx_rec/validator/emb_validator.py
new file mode 100644
index 00000000..c9d18f05
--- /dev/null
+++ b/mx_rec/validator/emb_validator.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+from typing import Union, Optional
+
+import tensorflow as tf
+
+from mx_rec.constants.constants import MAX_VOCABULARY_SIZE, MULTI_LOOKUP_TIMES
+from mx_rec.core.asc.feature_spec import FeatureSpec
+from mx_rec.util.communication.hccl_ops import get_rank_size
+from mx_rec.util.initialize import ConfigInitializer
+
+
+def check_emb_init_params(is_hbm: bool, embedding_size: tf.TensorShape):
+    """
+    Validate the initialization parameters of a sparse table.
+
+    Args:
+        is_hbm: whether the table runs in pure HBM mode
+        embedding_size: shape of one embedding row
+
+    Returns: None
+    """
+    if ConfigInitializer.get_instance().hybrid_manager_config.freeze:
+        raise EnvironmentError("Emb cache management has been established, you cannot build new hash table.")
+
+    if not is_hbm and ConfigInitializer.get_instance().use_dynamic_expansion:
+        raise ValueError("DDR/SSD mode do not support embedding dynamic expansion for now.")
+
+    if embedding_size.ndims != 1:
+        raise ValueError("Parameter 'embedding_size' can only be one dim shape.")
+
+
+def check_emb_lookup_params(table_params: dict, feature_spec: Union[tf.Tensor, FeatureSpec], send_count: Optional[int],
+                            is_training: bool):
+    """
+    Validate the parameters of one lookup on a sparse table.
+
+    Args:
+        table_params: dict of sparse table parameters
+        feature_spec: the tensor to look up, or its FeatureSpec wrapper
+        send_count: all2all communication parameter
+        is_training: whether the current pipeline is training or inference
+
+    Returns: None
+    """
+    # check FeatureSpec
+    if isinstance(feature_spec, FeatureSpec):
+        if not feature_spec.initialized:
+            raise RuntimeError("Feature Spec has not been initialized.")
+        if is_training not in feature_spec.pipeline_mode:
+            raise RuntimeError(f"You have not config feature for is training mode '{is_training}', please config "
+                               f"feature with func sparse_lookup at first.")
+
+    # check max vocabulary size
+    slice_device_vocabulary_size = table_params.get("slice_device_vocabulary_size")
+    slice_host_vocabulary_size = table_params.get("slice_host_vocabulary_size")
+    table_name = table_params.get("table_name")
+    if slice_host_vocabulary_size + slice_device_vocabulary_size > MAX_VOCABULARY_SIZE:
+        raise ValueError(f"Given device_vocabulary_size and host_vocabulary_size was too big for table "
+                         f"'{table_name}', in which slice_device_vocabulary_size was "
+                         f"{slice_device_vocabulary_size} and slice_host_vocabulary_size was "
+                         f"{slice_host_vocabulary_size}.")
+
+    if not ConfigInitializer.get_instance().use_static:
+        return
+
+    # check send count
+    if not (isinstance(send_count, int) and send_count > 0):
+        raise ValueError("Send count must be an integer larger than 0.")
+
+    if table_params.get("is_hbm") or ConfigInitializer.get_instance().use_dynamic_expansion:
+        return
+
+    # check vocabulary size with send count
+    rank_size = get_rank_size()
+    if slice_device_vocabulary_size < send_count * rank_size:
+        raise ValueError(f"Given device_vocabulary_size was too small for table '{table_name}', "
+                         f"in which slice_device_vocabulary_size was {slice_device_vocabulary_size} "
+                         f"and send_count({send_count}) * rank_size({rank_size}) was "
+                         f"{send_count * rank_size}.")
+
+    if slice_host_vocabulary_size < send_count * rank_size:
+        raise ValueError(f"Given host_vocabulary_size was too small for table '{table_name}', "
+                         f"in which slice_host_vocabulary_size was {slice_host_vocabulary_size} "
+                         f"and send_count({send_count}) * rank_size({rank_size}) was "
+                         f"{send_count * rank_size}.")
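+
+
+# Worked example (illustrative numbers, not part of the validator): with
+# rank_size = 8 and send_count = 1024, the static-shape checks above require
+# both slice_device_vocabulary_size and slice_host_vocabulary_size to be at
+# least send_count * rank_size = 8192 (unless the table is HBM-only or dynamic
+# expansion is enabled), while their sum must stay within MAX_VOCABULARY_SIZE.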
rank_size}.") + + +def check_emb_multi_lookup_times(lookup_times: int, table_name: str): + """ + 校验稀疏表一表多查的次数. + + Args: + lookup_times: 稀疏表lookup的次数 + table_name: 稀疏表名 + + Returns: None + """ + if lookup_times > MULTI_LOOKUP_TIMES: + raise RuntimeError(f"The number of multiple sparse lookup for a table ({table_name}) is " + f"{MULTI_LOOKUP_TIMES}, and current times is {lookup_times}.") diff --git a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp index ce176bfd..0ca41e29 100644 --- a/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp +++ b/src/core/ckpt_data_handler/host_emb_ckpt/host_emb_ckpt.cpp @@ -169,4 +169,4 @@ size_t HostEmbCkpt::GetEmbDataRows(string embName) transferData.attributeSize = transferData.attribute.size() * eightBytes; return embDataOuterSize; -} \ No newline at end of file +} diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index 9bdd5b04..c380aab2 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -19,6 +19,7 @@ See the License for the specific language governing permissions and #include "hybrid_mgmt/hybrid_mgmt_block.h" #include "utils/common.h" +#include "emb_table/embedding_mgmt.h" using namespace MxRec; @@ -61,54 +62,45 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara #ifndef GTEST EASY_FUNCTION(profiler::colors::Pink) TimeCost swapTimeCost; - auto it = embHashMaps.find(embName); - if (it == embHashMaps.end()) { - throw runtime_error("table not exist in embHashMaps"); - } - auto &embHashMap = it->second; - embHashMap.devOffset2KeyOld.clear(); - embHashMap.oldSwap.clear(); - embHashMap.maxOffsetOld = embHashMap.maxOffset; - - auto keepBatch = swapId; // 处理batch的次数,多个预取一起处理算一次 + std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); - // 找到所有key的偏移;dev和host需要交换的位置 - FindOffset(embName, keys, swapId, keepBatch, channelId); - LOG_DEBUG("FindOffset end"); - - // 调用刷新频次数据方法 - RefreshFreqInfoWithSwap(embName, embHashMap); + int32_t keepBatch = swapId; // 处理batch的次数,多个预取一起处理算一次 + vector swapPos; + vector lookUpVec = table->FindOffset(keys, swapId, channelId, swapPos); EASY_BLOCK("hostHashMaps->tdt") - std::copy(embHashMap.lookUpVec.begin(), embHashMap.lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut)); + std::copy(lookUpVec.begin(), lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut)); // 构造查询向量tensor - auto lookUpVecSize = static_cast(embHashMap.lookUpVec.size()); + int lookUpVecSize = static_cast(lookUpVec.size()); ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { lookUpVecSize })); auto lookupTensorData = ddrParam.tmpDataOut.back().flat(); for (int i = 0; i < lookUpVecSize; i++) { - lookupTensorData(i) = static_cast(embHashMap.lookUpVec[i]); + lookupTensorData(i) = static_cast(lookUpVec[i]); } - LOG_TRACE("lookupTensor, {}", VectorToString(embHashMap.lookUpVec)); + LOG_TRACE("lookupTensor, {}", VectorToString(lookUpVec)); // 构造交换向量tensor - auto swapSize = static_cast(embHashMap.swapPos.size()); + int swapSize = static_cast(swapPos.size()); ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { swapSize })); auto swapTensorData = ddrParam.tmpDataOut.back().flat(); for (int i = 0; i < swapSize; i++) { - swapTensorData(i) = static_cast(embHashMap.swapPos[i]); + swapTensorData(i) = static_cast(swapPos[i]); } if (swapSize > 0) { LOG_DEBUG("swap num: {}", swapSize); } - LOG_TRACE("swapTensor, {}", VectorToString(embHashMap.swapPos)); + + 
LOG_TRACE("swapTensor, {}", VectorToString(swapPos)); // 清空本次记录的查询偏移和交换偏移 - ClearLookupAndSwapOffset(embHashMap); - LOG_INFO("current ddr emb:{}, usage:{}/[{}+{}]", embName, embHashMap.maxOffset, - embHashMap.devVocabSize, embHashMap.hostVocabSize); + table->ClearLookupAndSwapOffset(); + + LOG_INFO("current ddr emb:{}, usage:{}/[{}+{}]", embName, table->GetMaxOffset(), + table->GetDevVocabSize(), table->GetHostVocabSize()); + ddrParam.tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto swapLen = ddrParam.tmpDataOut.back().flat(); swapLen(0) = swapSize; @@ -123,6 +115,7 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara #endif } + auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map { LOG_DEBUG(HYBRID_BLOCKING + " start GetHashMaps"); @@ -152,7 +145,7 @@ auto EmbHashMap::GetHashMaps() -> absl::flat_hash_map return embHashMapsOld; } // 此时需要回退2步,无法满足此条件,保存的东西错误,直接回退 - if (not rankInfo.noDDR) { + if (rankInfo.isDDR) { throw HybridMgmtBlockingException("EmbHashMap::GetHashMaps() "); } return embHashMapsOld; diff --git a/src/core/emb_hashmap/emb_hashmap.h b/src/core/emb_hashmap/emb_hashmap.h index 3ad51442..96a75e54 100644 --- a/src/core/emb_hashmap/emb_hashmap.h +++ b/src/core/emb_hashmap/emb_hashmap.h @@ -40,16 +40,6 @@ namespace MxRec { void LoadHashMap(absl::flat_hash_map& loadData); - const std::vector& GetMissingKeys(const string& embName) - { - return embHashMaps.at(embName).missingKeysHostPos; - } - - void ClearMissingKeys(const string& embName) - { - embHashMaps.at(embName).missingKeysHostPos.clear(); - } - void EvictDeleteEmb(const string& embName, const vector& keys); absl::flat_hash_map embHashMaps; diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp new file mode 100644 index 00000000..4a4b09a6 --- /dev/null +++ b/src/core/emb_table/embedding_ddr.cpp @@ -0,0 +1,567 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * Description: EmbeddingDDR, DDR-mode embedding table implementation
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#include "emb_table/embedding_ddr.h"
+#include "utils/logger.h"
+#include "utils/singleton.h"
+#include "host_emb/host_emb.h"
+#include "hd_transfer/hd_transfer.h"
+#include "file_system/file_system_handler.h"
+
+using namespace MxRec;
+
+constexpr int ELEMENT_NUM = 4;
+constexpr int CURRENT_UPDATE_IDX = 0;
+constexpr int HOST_VOCAB_SIZE_IDX = 1;
+constexpr int DEV_VOCAB_SIZE_IDX = 2;
+constexpr int MAX_OFFSET_IDX = 3;
+
+constexpr int EMB_INFO_ELEMENT_NUM = 3;
+constexpr int EMB_INFO_EXT_SIZE_IDX = 0;
+constexpr int EMB_INFO_DEV_VOCAB_SIZE_IDX = 1;
+constexpr int EMB_INFO_HOST_VOCAB_SIZE_IDX = 2;
+
+EmbeddingDDR::EmbeddingDDR()
+{
+}
+
+EmbeddingDDR::EmbeddingDDR(const EmbInfo& info, const RankInfo& rankInfo, int inSeed)
+    : EmbeddingTable(info, rankInfo, inSeed)
+{
+    LOG_INFO("Init DDR table [{}] devVocabSize = {} hostVocabSize = {}", name_, devVocabSize_, hostVocabSize_);
+    currentUpdatePos = 0;
+    devOffset2Key.resize(devVocabSize_);
+    devOffset2Batch.resize(devVocabSize_);
+    std::fill(devOffset2Batch.begin(), devOffset2Batch.end(), -1);
+    std::fill(devOffset2Key.begin(), devOffset2Key.end(), -1);
+}
+
+EmbeddingDDR::~EmbeddingDDR()
+{
+}
+
+void EmbeddingDDR::Key2Offset(std::vector<emb_key_t>& splitKey, int channel)
+{
+}
+
+int64_t EmbeddingDDR::capacity() const
+{
+    return capacity_;
+}
+
+std::vector<emb_key_t> EmbeddingDDR::FindOffset(const vector<emb_key_t>& keys,
+    size_t batchId, int channelId,
+    std::vector<size_t>& swapPos)
+{
+    devOffset2KeyOld.clear();
+    oldSwap.clear();
+    maxOffsetOld = maxOffset_;
+
+    UpdateBatchId(keys, batchId);
+    std::vector<emb_key_t> lookUpVec;
+    for (size_t i = 0; i < keys.size(); i++) {
+        emb_key_t key = keys[i];
+        if (key == INVALID_KEY_VALUE) {
+            lookUpVec.emplace_back(INVALID_KEY_VALUE);
+            continue;
+        }
+        emb_key_t offset = FindOffsetHelper(key, channelId);
+        if (offset == INVALID_KEY_VALUE) {
+            lookUpVec.emplace_back(INVALID_KEY_VALUE);
+            continue;
+        }
+        if (offset < devVocabSize_) {
+            // offset within HBM capacity: push it straight into the lookup vector and
+            // record both the key previously at this offset and the key now associated with it
+            lookUpVec.push_back(offset);
+            devOffset2KeyOld.emplace_back(offset, static_cast<emb_key_t>(devOffset2Key[offset]));
+            devOffset2Key[offset] = key;
+        } else {
+            // offset beyond HBM capacity: record the offset on the host emb and
+            // find an HBM offset to swap with
+            missingKeysHostPos_.emplace_back(offset - devVocabSize_);
+            offset = FindSwapPosOld(key, offset, batchId, swapPos);
+            lookUpVec.emplace_back(offset);
+        }
+    }
+    if (batchId == 0) {
+        LOG_INFO("max offset {}", maxOffset_);
+    }
+    LOG_TRACE("keyOffsetMap_, {}", MapToString(keyOffsetMap_));
+    return lookUpVec;
+}
+
+emb_key_t EmbeddingDDR::FindOffsetHelper(const emb_key_t& key, int channelId)
+{
+    const auto& iter = keyOffsetMap_.find(key);
+    emb_key_t offset = INVALID_KEY_VALUE;
+    if (iter != keyOffsetMap_.end()) {
+        offset = iter->second;
+        LOG_TRACE("devVocabSize, {} , offset , {}", devVocabSize_, offset);
+        if (offset >= devVocabSize_) {
+            ddr2HbmKeys.emplace_back(key);
+        }
+        return offset;
+    }
+    if (channelId != TRAIN_CHANNEL_ID) {
+        return offset;
+    }
+    if (evictPos_.size() != 0) {  // prefer reusing evicted HBM slots
+        offset = evictPos_.back();
+        keyOffsetMap_[key] = offset;
+        LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]",
+            key, offset, evictPos_.size());
+        evictPos_.pop_back();
+        LOG_DEBUG("dev evicted offset = {}", offset);
+        return offset;
+    }
+
+    if (evictHostPos_.size() != 0) {  // HBM slots exhausted, fall back to evicted host/DDR slots
+        offset = evictHostPos_.back();
+        keyOffsetMap_[key] = offset;
+        LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]",
+            key, offset, evictHostPos_.size());
+        evictHostPos_.pop_back();
+        LOG_DEBUG("host evicted offset = {}", offset);
+        return offset;
+    }
+    keyOffsetMap_[key] = maxOffset_;
+    offset = maxOffset_;
+    maxOffset_++;
+    if (maxOffset_ == devVocabSize_) {
+        LOG_INFO("start using host vocab!");
+    }
+    if (maxOffset_ > (hostVocabSize_ + devVocabSize_)) {
+        LOG_ERROR("hostVocabSize too small! dev:{} host:{}", devVocabSize_, hostVocabSize_);
+        throw runtime_error("hostVocabSize too small");
+    }
+    return offset;
+}
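+
+/*
+ * Offset-allocation sketch (illustrative, the numbers are assumptions): with
+ * devVocabSize_ = 4 and hostVocabSize_ = 8, FindOffsetHelper hands out offsets
+ * for previously unseen training keys in this order:
+ *   1. a slot from evictPos_ (a reclaimed HBM offset < 4), if any;
+ *   2. else a slot from evictHostPos_ (a reclaimed DDR offset >= 4);
+ *   3. else maxOffset_++, i.e. offsets 0..3 land in HBM, 4..11 in DDR, and
+ *      allocating offset 12 would throw "hostVocabSize too small".
+ * Non-training channels never allocate and return INVALID_KEY_VALUE instead.
+ */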
[{}]", + key, offset, evictHostPos_.size()); + evictHostPos_.pop_back(); + LOG_ERROR("host evicted offset = {}", offset); + return offset; + } + keyOffsetMap_[key] = maxOffset_; + offset = maxOffset_; + maxOffset_++; + if (maxOffset_ == devVocabSize_) { + LOG_INFO("start using host vocab!"); + } + if (maxOffset_ > (hostVocabSize_ + devVocabSize_)) { + LOG_ERROR("hostVocabSize too small! dev:{} host:{}", devVocabSize_, hostVocabSize_); + throw runtime_error("hostVocabSize too small"); + } + return offset; +} + +void EmbeddingDDR::UpdateBatchId(const vector& keys, size_t currentBatchId) +{ + for (size_t i = 0; i < keys.size(); i++) { + size_t offset; + emb_key_t key = keys[i]; + if (key == -1) { + continue; + } + const auto& iter = keyOffsetMap_.find(key); + if (iter != keyOffsetMap_.end()) { + offset = iter->second; + + LOG_TRACE("key will be used, {} , offset , {}", key, offset); + if (offset < devVocabSize_) { + // devOffset2Batch size equal to devVocabSize, unnecessary to check index boundary + devOffset2Batch[offset] = static_cast(currentBatchId); + } + } + } +} + +/// 利用devOffset2Batch上key最近使用的batchId,来选择需要淘汰的key,记录淘汰位置和device侧所需的keys +/// \param embName 表名 +/// \param key 输入特征 +/// \param hostOffset 全局偏移 +/// \param currentBatchId 已处理的batch数 +/// \param keepBatchId 处理batch的次数,多个预取一起处理算一次 +/// \return 是否找到需要交换的位置 +emb_key_t EmbeddingDDR::FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t batchId, + std::vector& swapPos) +{ + bool notFind = true; + emb_key_t offset = INVALID_KEY_VALUE; + while (notFind) { + // 找到本次预取之前的偏移(保证所有预取batch的key都在HBM中) + if (currentUpdatePos >= devOffset2Batch.size()) { + LOG_ERROR("outofrange {} >= {}", currentUpdatePos, devOffset2Batch.size()); + throw runtime_error("currentUpdatePos out of range"); + } + + if (devOffset2Batch[currentUpdatePos] < static_cast(batchId)) { + devOffset2Batch[currentUpdatePos] = static_cast(batchId); + swapPos.emplace_back(currentUpdatePos); // 记录需要被换出的HBM偏移 + offset = currentUpdatePos; + keyOffsetMap_[key] = currentUpdatePos; // 更新key对应的HBM偏移 + // 记录HBM偏移之前的key + devOffset2KeyOld.emplace_back(currentUpdatePos, devOffset2Key[currentUpdatePos]); + auto& oldKey = devOffset2Key[currentUpdatePos]; + oldSwap.emplace_back(oldKey, key); // 记录交换的两个key oldKey:HBM->DDR key:DDR->HBM + keyOffsetMap_[oldKey] = hostOffset; // 更新被替换的key的偏移 + oldKey = key; + notFind = false; + } + currentUpdatePos++; // 查找位置+1 + freeSize_--; // HBM可用空间-1 + + // 遍历完一遍整个HBM表后,从头开始遍历 + if (currentUpdatePos == devVocabSize_) { + currentUpdatePos = 0; + } + + /** + * currentUpdatePos已经绕了HBM一圈 + * 已经找完整个HBM空间,且没找到可用位置,表示HBM空间不足以放下整个batch(预取batch数)的key, + * 无法正常执行训练,故运行时错误退出 + */ + if (currentUpdatePos == currentUpdatePosStart && notFind) { + LOG_ERROR("devVocabSize is too small"); + throw runtime_error("devVocabSize is too small"); + } + } + return offset; +} + +/* +* 删除淘汰key的映射关系,并将其offset更新到evictPos,待后续复用 +*/ +void EmbeddingDDR::EvictDeleteEmb(const vector& keys) +{ + EASY_FUNCTION() + size_t keySize = keys.size(); + vector evictHBMKeys; + vector evictDDRKeys; + for (size_t i = 0; i < keySize; ++i) { + size_t offset; + emb_key_t key = keys[i]; + if (key == INVALID_KEY_VALUE) { + LOG_WARN("evict key equal -1!"); + continue; + } + const auto& iter = keyOffsetMap_.find(key); + if (iter == keyOffsetMap_.end()) { + // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 + continue; + } + offset = iter->second; + keyOffsetMap_.erase(iter); + LOG_TRACE("evict embName {}, offset {}", name_, offset); + + if (offset < devVocabSize_) { + // offset 
+/*
+ * Remove evicted keys from the key->offset map and push their offsets into
+ * evictPos for later reuse.
+ */
+void EmbeddingDDR::EvictDeleteEmb(const vector<emb_key_t>& keys)
+{
+    EASY_FUNCTION()
+    size_t keySize = keys.size();
+    vector<emb_key_t> evictHBMKeys;
+    vector<emb_key_t> evictDDRKeys;
+    for (size_t i = 0; i < keySize; ++i) {
+        size_t offset;
+        emb_key_t key = keys[i];
+        if (key == INVALID_KEY_VALUE) {
+            LOG_WARN("evict key equal -1!");
+            continue;
+        }
+        const auto& iter = keyOffsetMap_.find(key);
+        if (iter == keyOffsetMap_.end()) {
+            // eviction follows the history in keyProcess while the map entry is created in ParseKey;
+            // the two are asynchronous, so an evicted key may not exist in the hashmap yet
+            continue;
+        }
+        offset = iter->second;
+        keyOffsetMap_.erase(iter);
+        LOG_TRACE("evict embName {}, offset {}", name_, offset);
+
+        if (offset < devVocabSize_) {
+            // offset resides on the device
+            devOffset2Batch[offset] = -1;
+            devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]);
+            devOffset2Key[offset] = -1;
+            evictPos_.emplace_back(offset);
+            evictHBMKeys.emplace_back(key);
+        } else {
+            // offset resides on the host
+            evictHostPos_.emplace_back(offset);
+            evictDDRKeys.emplace_back(key);  // drop the mapping, re-init the host row, send the device evict position
+        }
+    }
+
+    LOG_INFO("ddr EvictDeleteEmb, emb: [{}], devEvictSize: {}, hostEvictSize: {}",
+        name_, evictPos_.size(), evictHostPos_.size());
+    LOG_TRACE("keyOffsetMap_, {}", MapToString(keyOffsetMap_));
+}
+
+/// Eviction in DDR mode: remove mapping entries, reinitialize host rows, send device evict positions.
+/// \param keys
+void EmbeddingDDR::EvictKeys(const vector<emb_key_t>& keys)
+{
+    EASY_FUNCTION()
+    for (const emb_key_t& key : keys) {
+        size_t offset;
+        if (key == INVALID_KEY_VALUE) {
+            LOG_WARN("evict key equal -1!");
+            continue;
+        }
+        const auto& iter = keyOffsetMap_.find(key);
+        if (iter == keyOffsetMap_.end()) {
+            continue;
+        }
+        // eviction follows the history in keyProcess while the map entry is created in ParseKey;
+        // the two are asynchronous, so an evicted key may not exist in the hashmap yet
+        offset = iter->second;
+        keyOffsetMap_.erase(iter);
+        LOG_TRACE("evict embName {}, offset {}", name_, offset);
+
+        if (offset < devVocabSize_) {
+            devOffset2Batch[offset] = INVALID_KEY_VALUE;
+            devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]);
+            devOffset2Key[offset] = INVALID_KEY_VALUE;
+            evictPos_.emplace_back(offset);
+        } else {
+            evictHostPos_.emplace_back(offset);
+        }
+    }
+}
+
+void EmbeddingDDR::ClearLookupAndSwapOffset()
+{
+    ddr2HbmKeys.clear();
+}
+
+void EmbeddingDDR::SetStartCount()
+{
+    currentUpdatePosStart = currentUpdatePos;
+    freeSize_ = devVocabSize_;
+}
+
+int EmbeddingDDR::Load(const string& savePath)
+{
+    LoadHashMap(savePath);
+    LoadDevOffset(savePath);
+    LoadCurrStat(savePath);
+    LoadEvictPos(savePath);
+    LoadEmbInfo(savePath);
+    LoadEmbData(savePath);
+    return 0;
+}
+
+int EmbeddingDDR::Save(const string& savePath)
+{
+    SaveHashMap(savePath);
+    SaveDevOffset(savePath);
+    SaveCurrStat(savePath);
+    SaveEvictPos(savePath);
+    SaveEmbInfo(savePath);
+    SaveEmbData(savePath);
+    return 0;
+}
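+
+/*
+ * On-disk layout written by Save() and read by Load(), one slice per rank
+ * (paths as composed by the helpers below):
+ *   <savePath>/HashTable/DDR/<name_>/embedding_hashmap/slice_<rankId_>.data
+ *   <savePath>/HashTable/DDR/<name_>/dev_offset_2_Batch_n_Key/slice_<rankId_>.data
+ *   <savePath>/HashTable/DDR/<name_>/embedding_current_status/slice_<rankId_>.data
+ *   <savePath>/HashTable/DDR/<name_>/evict_pos/slice_<rankId_>.data
+ *   <savePath>/HashTable/DDR/<name_>/embedding_info/slice_<rankId_>.data
+ *   <savePath>/HashTable/DDR/<name_>/embedding_data/slice_<rankId_>.data
+ */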
size = {} is too big", ss.str(), fileSize); + return -1; + } + + devOffset2Key.resize(fileSize / sizeof(emb_key_t)); + fileSystemPtr->Read(ss.str(), reinterpret_cast(devOffset2Key.data()), fileSize); + return 0; +} + +int EmbeddingDDR::LoadCurrStat(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_current_status/slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + size_t raw[ELEMENT_NUM] = {0}; + fileSystemPtr->Read(ss.str(), reinterpret_cast(raw), sizeof(raw)); + currentUpdatePos = raw[CURRENT_UPDATE_IDX]; + hostVocabSize_ = raw[HOST_VOCAB_SIZE_IDX]; + devVocabSize_ = raw[MAX_OFFSET_IDX]; + maxOffset_ = raw[MAX_OFFSET_IDX]; + return 0; +} + +int EmbeddingDDR::LoadEvictPos(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/evict_pos/slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + size_t fileSize = 0; + try { + fileSize = fileSystemPtr->GetFileSize(ss.str()); + } catch (exception& e) { + LOG_ERROR("open file {} failed:{}", ss.str(), strerror(errno)); + return -1; + } + if (fileSize >= FILE_MAX_SIZE) { + LOG_ERROR("File {} size = {} is too big", ss.str(), fileSize); + return -1; + } + evictPos_.resize(fileSize / sizeof(int64_t)); + + fileSystemPtr->Read(ss.str(), reinterpret_cast(evictPos_.data()), fileSize); + return 0; +} + +int EmbeddingDDR::LoadEmbInfo(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_info/slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + size_t raw[EMB_INFO_ELEMENT_NUM] = {0}; + fileSystemPtr->Read(ss.str(), reinterpret_cast(raw), sizeof(raw)); + extEmbSize_ = raw[EMB_INFO_EXT_SIZE_IDX]; + devVocabSize_ = raw[EMB_INFO_DEV_VOCAB_SIZE_IDX]; + hostVocabSize_ = raw[EMB_INFO_HOST_VOCAB_SIZE_IDX]; + return 0; +} + +int EmbeddingDDR::LoadEmbData(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_data/slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + HostEmb* hostEmbs = Singleton::GetInstance(); + HostEmbTable& table = hostEmbs->GetEmb(name_); + if (table.embData.empty()) { + LOG_ERROR("hostEmb data is empty"); + return -1; + } + fileSystemPtr->Read(ss.str(), table.embData); + return 0; +} + +int EmbeddingDDR::SaveHashMap(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_hashmap/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + vector raw; + for (const auto& it : keyOffsetMap_) { + raw.push_back(it.first); + raw.push_back(static_cast(it.second)); + } + fileSystemPtr->Write(ss.str(), reinterpret_cast(raw.data()), + static_cast(raw.size() * sizeof(int64_t))); + return 0; +} + +int EmbeddingDDR::SaveDevOffset(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ <<"/dev_offset_2_Batch_n_Key/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = 
fileSystemHandler->Create(ss.str()); + + fileSystemPtr->Write(ss.str(), reinterpret_cast(devOffset2Key.data()), + static_cast(devOffset2Key.size() * sizeof(emb_key_t))); + return 0; +} + +int EmbeddingDDR::SaveCurrStat(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_current_status/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + size_t raw[ELEMENT_NUM] = {0}; + raw[CURRENT_UPDATE_IDX] = currentUpdatePos; + raw[HOST_VOCAB_SIZE_IDX] = hostVocabSize_; + raw[DEV_VOCAB_SIZE_IDX] = devVocabSize_; + raw[MAX_OFFSET_IDX] = maxOffset_; + fileSystemPtr->Write(ss.str(), reinterpret_cast(raw), sizeof(raw)); + return 0; +} + +int EmbeddingDDR::SaveEvictPos(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/" << name_ << "/evict_pos/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + fileSystemPtr->Write(ss.str(), reinterpret_cast(evictPos_.data()), + static_cast(evictPos_.size() * sizeof(int64_t))); + return 0; +} + +int EmbeddingDDR::SaveEmbInfo(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_info/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + size_t raw[EMB_INFO_ELEMENT_NUM] = {}; + raw[EMB_INFO_EXT_SIZE_IDX] = extEmbSize_; + raw[EMB_INFO_DEV_VOCAB_SIZE_IDX] = devVocabSize_; + raw[EMB_INFO_HOST_VOCAB_SIZE_IDX] = hostVocabSize_; + fileSystemPtr->Write(ss.str(), reinterpret_cast(raw), sizeof(raw)); + return 0; +} + +int EmbeddingDDR::SaveEmbData(const string& savePath) +{ + stringstream ss; + ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_data/"; + MakeDir(ss.str()); + ss << "slice_" << rankId_ << ".data"; + + unique_ptr fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); + + HostEmb* hostEmbs = Singleton::GetInstance(); + HostEmbTable& table = hostEmbs->GetEmb(name_); + if (table.embData.empty()) { + LOG_ERROR("host embedding data is empty"); + return 0; + } + vector content; + for (vector& emb : table.embData) { + content.push_back(emb.data()); + } + size_t dataSize = table.embData[0].size(); + fileSystemPtr->Write(ss.str(), content, dataSize * sizeof(float)); + return 0; +} diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h new file mode 100644 index 00000000..8025b8c8 --- /dev/null +++ b/src/core/emb_table/embedding_ddr.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h
new file mode 100644
index 00000000..8025b8c8
--- /dev/null
+++ b/src/core/emb_table/embedding_ddr.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ * Description: emb table
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#ifndef MX_REC_EMBEDDING_DDR_H
+#define MX_REC_EMBEDDING_DDR_H
+
+#include "emb_table/embedding_table.h"
+
+namespace MxRec {
+
+class EmbeddingDDR : public EmbeddingTable {
+public:
+    EmbeddingDDR();
+
+    EmbeddingDDR(const EmbInfo& info, const RankInfo& rankInfo, int inSeed);
+
+    EmbeddingDDR& operator=(const EmbeddingDDR& table);
+
+    ~EmbeddingDDR();
+
+    virtual void Key2Offset(std::vector<emb_key_t>& splitKey, int channel);
+
+    virtual int64_t capacity() const;
+
+    virtual std::vector<emb_key_t> FindOffset(const vector<emb_key_t>& keys,
+        size_t batchId, int channelId,
+        std::vector<emb_key_t>& swapPos);
+
+    emb_key_t FindOffsetHelper(const emb_key_t& key, int channelId);
+
+    void UpdateBatchId(const vector<emb_key_t>& keys, size_t currentBatchId);
+
+    emb_key_t FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t batchId, std::vector<emb_key_t>& swapPos);
+
+    virtual void EvictKeys(const vector<emb_key_t>& keys);
+
+//    std::vector<emb_key_t> lookUpVec; // lookup results
+
+    virtual void ClearLookupAndSwapOffset();
+
+    void SetStartCount();
+
+    int Load(const string& savePath);
+
+    int Save(const string& savePath);
+
+GTEST_PRIVATE:
+
+    int LoadHashMap(const string& savePath);
+    int LoadDevOffset(const string& savePath);
+    int LoadCurrStat(const string& savePath);
+    int LoadEvictPos(const string& savePath);
+    int LoadEmbInfo(const string& savePath);
+    int LoadEmbData(const string& savePath);
+
+    int SaveHashMap(const string& savePath);
+    int SaveDevOffset(const string& savePath);
+    int SaveCurrStat(const string& savePath);
+    int SaveEvictPos(const string& savePath);
+    int SaveEmbInfo(const string& savePath);
+    int SaveEmbData(const string& savePath);
+
+    void EvictDeleteEmb(const vector<emb_key_t>& keys);
+
+    std::vector<emb_key_t> devOffset2Key;
+
+    size_t maxOffsetOld { 0 };
+    std::vector<int64_t> evictPosChange;
+    std::vector<int64_t> evictDevPosChange;
+    std::vector<std::pair<size_t, emb_key_t>> devOffset2KeyOld;
+    std::vector<std::pair<size_t, size_t>> oldSwap; // (old on dev, old on host)
+
+    /*
+     * Keys that already live in DDR and must be moved to HBM during an HBM/DDR
+     * swap (new keys excluded); used in SSD mode.
+     * (Unlike oldSwap, whose pair.second covers keys already in DDR plus new
+     * keys mapped to DDR before the swap.)
+     */
+    std::vector<emb_key_t> ddr2HbmKeys;
+    bool isSSDEnabled;
+    std::vector<int64_t> devOffset2Batch; // may contain -1
+
+    /**
+     * Current cursor used when scanning for a free slot in HBM.
+     * Valid range: [0, devVocabSize_].
+     **/
+    size_t currentUpdatePos;
+    size_t currentUpdatePosStart; // start position of the HBM free-slot scan
+};
+
+}
+
+#endif // MX_REC_EMBEDDING_DDR_H
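
[Editor's note — illustrative sketch, not part of the patch. The EmbeddingDDR declaration above pairs a host-resident table with a device cache: every key owns a host slot, the hot subset lives in HBM, and a scan cursor (currentUpdatePos) locates free device slots. The self-contained sketch below models only that two-tier idea; TwoTierCache and all its member names are invented, and the real class additionally tracks batch ids, eviction lists, and SSD spill.]

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    struct TwoTierCache {
        explicit TwoTierCache(size_t devSlots) : devOwner(devSlots, -1) {}

        // Returns the device slot for `key`, swapping it in if necessary;
        // records host slots that must be copied to the device in `swapIn`.
        size_t Lookup(int64_t key, std::vector<size_t>& swapIn)
        {
            auto it = key2dev.find(key);
            if (it != key2dev.end()) {
                return it->second;                     // already resident in HBM
            }
            size_t slot = cursor++ % devOwner.size();  // naive free-slot scan
            if (devOwner[slot] != -1) {
                key2dev.erase(devOwner[slot]);         // displace the previous owner
            }
            devOwner[slot] = key;
            key2dev[key] = slot;
            swapIn.push_back(HostSlot(key));           // schedule host -> device copy
            return slot;
        }

        size_t HostSlot(int64_t key)
        {
            auto it = key2host.find(key);
            if (it == key2host.end()) {
                it = key2host.emplace(key, key2host.size()).first;
            }
            return it->second;
        }

        std::unordered_map<int64_t, size_t> key2dev;
        std::unordered_map<int64_t, size_t> key2host;
        std::vector<int64_t> devOwner;  // which key owns each device slot, -1 = free
        size_t cursor { 0 };
    };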
diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp
new file mode 100644
index 00000000..e561b45c
--- /dev/null
+++ b/src/core/emb_table/embedding_dynamic.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: EmbeddingDynamic, the dynamically expanding HBM embedding table implementation
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#include "emb_table/embedding_dynamic.h"
+
+#include <new>
+#include <stdexcept>
+
+#include "utils/logger.h"
+#include "utils/singleton.h"
+
+using namespace MxRec;
+
+EmbeddingDynamic::EmbeddingDynamic()
+{
+}
+
+EmbeddingDynamic::EmbeddingDynamic(const EmbInfo& info, const RankInfo& rankInfo, int inSeed)
+    : EmbeddingTable(info, rankInfo, inSeed)
+{
+    if (isDynamic_) {
+        auto ret = aclrtSetDevice(static_cast<int32_t>(rankInfo.deviceId));
+        if (ret != ACL_ERROR_NONE) {
+            LOG_ERROR("Set device failed, device_id:{}, ret={}", rankInfo.deviceId, ret);
+            throw runtime_error("Acl set device failed!");
+        }
+        MallocEmbeddingBlock(BLOCK_EMB_NUM);
+    }
+}
+
+EmbeddingDynamic::~EmbeddingDynamic()
+{
+    for (auto& it: memoryList_) {
+        aclError ret = aclrtFree(it);
+        if (ret != ACL_SUCCESS) {
+            LOG_ERROR("aclrtFree failed, ret={}", ret);
+        }
+    }
+}
+
+void EmbeddingDynamic::Key2Offset(std::vector<emb_key_t>& keys, int channel)
+{
+    constexpr emb_key_t INVALID_DYNAMIC_EXPANSION_ADDR = 0; // the dynamic-expansion ops treat 0 as the invalid address
+    std::lock_guard<std::mutex> lk(mut_); // lock for PROCESS_THREAD
+    for (emb_key_t& key : keys) {
+        if (key == INVALID_KEY_VALUE) {
+            key = INVALID_DYNAMIC_EXPANSION_ADDR;
+            continue;
+        }
+        const auto& iter = keyOffsetMap_.find(key);
+        if (iter != keyOffsetMap_.end()) {
+            key = iter->second;
+            continue;
+        }
+        // new key: only the training channel may allocate
+        if (channel == TRAIN_CHANNEL_ID) {
+            int64_t addr = GetEmptyEmbeddingAddress();
+            keyOffsetMap_[key] = addr;
+            key = addr;
+            maxOffset_++;
+            continue;
+        }
+        key = INVALID_DYNAMIC_EXPANSION_ADDR;
+    }
+    LOG_DEBUG("current expansion emb:{}, usage:{}/{}", name_, maxOffset_, devVocabSize_);
+}
+
+int64_t EmbeddingDynamic::capacity() const
+{
+    return capacity_;
+}
+
+int64_t EmbeddingDynamic::GetEmptyEmbeddingAddress()
+{
+    if (embeddingList_.empty()) {
+        MallocEmbeddingBlock(BLOCK_EMB_NUM);
+    }
+    float *addr = embeddingList_.front();
+    embeddingList_.pop_front();
+    return reinterpret_cast<int64_t>(addr);
+}
+
+void EmbeddingDynamic::MallocEmbeddingBlock(int embNum)
+{
+    void *block = nullptr;
+    aclError ec = aclrtMalloc(&block, embNum * embSize_ * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
+    if (ec != 0) {
+        throw std::bad_alloc();
+    }
+    RandomInit(block, embNum);
+    memoryList_.push_back(block);
+    for (int i = 0; i < embNum; i++) {
+        float *embAddr = static_cast<float*>(block) + (i * embSize_);
+        embeddingList_.push_back(embAddr);
+    }
+    capacity_ += embNum;
+}
+
+void EmbeddingDynamic::RandomInit(void* addr, size_t embNum)
+{
+    LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed_, embInfo_.initializeInfos.size());
+    vector<float> hostmem(embNum * embSize_);
+    for (const auto& initializeInfo: as_const(embInfo_.initializeInfos)) {
+        for (size_t i = 0; i < embNum; ++i) {
+            initializeInfo.initializer->GenerateData(&hostmem[i * embSize_], embSize_);
+        }
+    }
+    LOG_INFO("Device GenerateEmbData End, seed:{}", seed_);
+
+    aclError ret = aclrtMemcpy(addr, embNum * embSize_ * sizeof(float),
+        hostmem.data(), embNum * embSize_ * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
+    if (ret != ACL_SUCCESS) {
+        LOG_ERROR("aclrtMemcpy failed, ret={}", ret);
+    }
+}
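
[Editor's note — illustrative sketch, not part of the patch. MallocEmbeddingBlock/GetEmptyEmbeddingAddress above implement a grow-only free list: device memory is requested in blocks of BLOCK_EMB_NUM embeddings and handed out slot by slot. A minimal host-memory analogue of that allocator is shown below; SlotPool and its member names are invented, and plain new[] stands in for aclrtMalloc.]

    #include <cstddef>
    #include <list>
    #include <vector>

    class SlotPool {
    public:
        explicit SlotPool(size_t embSize) : embSize_(embSize) {}
        ~SlotPool() { for (float* b : blocks_) delete[] b; }

        float* Acquire()
        {
            if (freeList_.empty()) {
                Grow(kBlockSlots);          // expand only when exhausted
            }
            float* slot = freeList_.front();
            freeList_.pop_front();
            return slot;
        }

        void Release(float* slot) { freeList_.push_back(slot); }

    private:
        static constexpr size_t kBlockSlots = 100000; // mirrors BLOCK_EMB_NUM

        void Grow(size_t slots)
        {
            float* block = new float[slots * embSize_](); // one big block
            blocks_.push_back(block);
            for (size_t i = 0; i < slots; ++i) {
                freeList_.push_back(block + i * embSize_); // carve into slots
            }
        }

        size_t embSize_;
        std::vector<float*> blocks_;
        std::list<float*> freeList_;
    };

Because slots are handed out as stable addresses rather than dense offsets, a dynamic table never needs to relocate existing embeddings when it grows; freed slots simply return to the list.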
diff --git a/src/core/emb_table/embedding_dynamic.h b/src/core/emb_table/embedding_dynamic.h
new file mode 100644
index 00000000..dcac7e74
--- /dev/null
+++ b/src/core/emb_table/embedding_dynamic.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: embedding table with dynamic expansion
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#ifndef MX_REC_EMBEDDING_DYNAMIC_H
+#define MX_REC_EMBEDDING_DYNAMIC_H
+
+#include "emb_table/embedding_table.h"
+
+namespace MxRec {
+
+/**
+ * Embedding table that supports dynamic expansion.
+ */
+class EmbeddingDynamic : public EmbeddingTable {
+public:
+    EmbeddingDynamic();
+
+    EmbeddingDynamic(const EmbInfo& info, const RankInfo& rankInfo, int inSeed);
+
+    ~EmbeddingDynamic();
+
+    virtual void Key2Offset(std::vector<emb_key_t>& keys, int channel);
+
+    virtual int64_t capacity() const;
+
+private:
+    constexpr static int BLOCK_EMB_NUM = 100000; // allocate 100k embeddings per expansion
+
+    void RandomInit(void* addr, size_t embNum);
+
+    int64_t GetEmptyEmbeddingAddress();
+
+    void MallocEmbeddingBlock(int embNum);
+
+    // free embedding slot addresses
+    list<float*> embeddingList_;
+    // allocated memory blocks
+    vector<void*> memoryList_;
+};
+}
+
+#endif // MX_REC_EMBEDDING_DYNAMIC_H
diff --git a/src/core/emb_table/embedding_mgmt.cpp b/src/core/emb_table/embedding_mgmt.cpp
new file mode 100644
index 00000000..baf3ead0
--- /dev/null
+++ b/src/core/emb_table/embedding_mgmt.cpp
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ * Description: EmbeddingMgmt management class
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#include "emb_table/embedding_mgmt.h"
+#include "emb_table/embedding_static.h"
+#include "emb_table/embedding_dynamic.h"
+#include "emb_table/embedding_ddr.h"
+#include "utils/logger.h"
+
+using namespace MxRec;
+
+EmbeddingMgmt::EmbeddingMgmt()
+{
+}
+
+void EmbeddingMgmt::Init(const RankInfo& rInfo, const vector<EmbInfo>& eInfos,
+    const vector<int64_t>& thresholdValues, int seed)
+{
+    for (size_t i = 0; i < eInfos.size(); ++i) {
+        if (rInfo.isDDR) {
+            embeddings[eInfos[i].name] = std::make_shared<EmbeddingDDR>(eInfos[i], rInfo, seed);
+            continue;
+        }
+        if (rInfo.useDynamicExpansion) {
+            embeddings[eInfos[i].name] = std::make_shared<EmbeddingDynamic>(eInfos[i], rInfo, seed);
+            continue;
+        }
+        embeddings[eInfos[i].name] = std::make_shared<EmbeddingStatic>(eInfos[i], rInfo, seed);
+    }
+}
+
+EmbeddingMgmt* EmbeddingMgmt::Instance()
+{
+    static EmbeddingMgmt mgmt;
+    return &mgmt;
+}
+
+void EmbeddingMgmt::Key2Offset(const std::string& name, std::vector<emb_key_t>& keys, int channel)
+{
+    embeddings[name]->Key2Offset(keys, channel);
+}
+
+size_t EmbeddingMgmt::GetMaxOffset(const std::string& name)
+{
+    return embeddings[name]->GetMaxOffset();
+}
+
+void EmbeddingMgmt::LoadMaxOffset(OffsetMemT& loadData)
+{
+    LOG_ERROR("load max offset");
+}
+
+void EmbeddingMgmt::LoadKeyOffsetMap(KeyOffsetMemT& loadData)
+{
+    LOG_ERROR("load key offset");
+}
+
+std::map<std::string, size_t> EmbeddingMgmt::GetMaxOffset()
+{
+    std::map<std::string, size_t> maxoffset;
+    for (auto &it: embeddings) {
+        maxoffset[it.first] = it.second->GetMaxOffset();
+    }
+    return maxoffset;
+}
+
+KeyOffsetMemT EmbeddingMgmt::GetKeyOffsetMap()
+{
+    KeyOffsetMemT keyOffsetMap;
+    for (auto &it: embeddings) {
+        keyOffsetMap[it.first] = it.second->GetKeyOffsetMap();
+    }
+    return keyOffsetMap;
+}
+
+void EmbeddingMgmt::EvictKeys(const string& name, const vector<emb_key_t>& keys)
+{
+    LOG_INFO("evict keys for {}", name);
+    if (keys.size() != 0) {
+        embeddings[name]->EvictKeys(keys);
+    }
+    embeddings[name]->EvictInitDeviceEmb();
+}
+
+void EmbeddingMgmt::EvictKeysCombine(const vector<emb_key_t>& keys)
+{
+    if (keys.size() != 0) {
+        for (auto& table: embeddings) {
+            table.second->EvictKeys(keys);
+        }
+    }
+    for (auto& table: embeddings) {
+        // re-initialize the freed device-side slots
+        table.second->EvictInitDeviceEmb();
+    }
+}
+
+int64_t EmbeddingMgmt::GetSize(const std::string &name)
+{
+    return embeddings[name]->size();
+}
+
+int64_t EmbeddingMgmt::GetCapacity(const std::string &name)
+{
+    return embeddings[name]->capacity();
+}
+
+void EmbeddingMgmt::FindOffset(const std::string& name, const vector<emb_key_t>& keys,
+    size_t currentBatchId, size_t keepBatchId, int channel)
+{
+    return embeddings[name]->FindOffset(keys, currentBatchId, keepBatchId, channel);
+}
+
+const std::vector<size_t>& EmbeddingMgmt::GetMissingKeys(const std::string& name)
+{
+    return embeddings[name]->GetMissingKeys();
+}
+
+void EmbeddingMgmt::ClearMissingKeys(const std::string& name)
+{
+    return embeddings[name]->ClearMissingKeys();
+}
+
+std::shared_ptr<EmbeddingTable> EmbeddingMgmt::GetTable(const string& name)
+{
+    auto it = embeddings.find(name);
+    if (it == embeddings.end()) {
+        LOG_ERROR("table not found");
+        return nullptr;
+    }
+    return std::dynamic_pointer_cast<EmbeddingTable>(it->second);
+}
+
+int EmbeddingMgmt::Load(const string& name, const string& filePath)
+{
+    return embeddings[name]->Load(filePath);
+}
+
+int EmbeddingMgmt::Save(const string& name, const string& filePath)
+{
+    return embeddings[name]->Save(filePath);
+}
+
+int EmbeddingMgmt::Save(const string& filePath)
+{
+    for (auto& tablePair: embeddings) {
+        tablePair.second->Save(filePath);
+    }
+    return 0;
+}
diff --git a/src/core/emb_table/embedding_mgmt.h b/src/core/emb_table/embedding_mgmt.h
new file mode 100644
index 00000000..00666113
--- /dev/null
+++ b/src/core/emb_table/embedding_mgmt.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ * Description: EmbeddingMgmt management class
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#ifndef MX_REC_EMBEDDING_MGMT_H
+#define MX_REC_EMBEDDING_MGMT_H
+
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include "utils/common.h"
+#include "emb_table/embedding_table.h"
+
+namespace MxRec {
+
+/**
+ * Embedding management class.
+ */
+class EmbeddingMgmt {
+public:
+
+    /**
+     * @param[in] rInfo rank information passed in from the Python side
+     * @param[in] eInfos embedding table information passed in from the Python side
+     */
+    void Init(const RankInfo& rInfo, const vector<EmbInfo>& eInfos,
+        const vector<int64_t>& thresholdValues = {}, int seed = 0);
+
+    /**
+     * Batch-look up keys in the named embedding table.
+     * @param[in] name embedding table name
+     * @param[in,out] keys keys to look up; replaced in place by the HBM offset or HBM address found
+     * @param[in] channel data channel, mainly distinguishing train from eval
+     */
+    void Key2Offset(const std::string& name, std::vector<emb_key_t>& keys, int channel);
+
+    void FindOffset(const std::string& name, const vector<emb_key_t>& keys,
+        size_t currentBatchId, size_t keepBatchId, int channel);
+
+    /**
+     * Evict keys from the named embedding table.
+     * @param[in] name embedding table name
+     * @param[in] keys keys to evict
+     */
+    void EvictKeys(const std::string& name, const vector<emb_key_t>& keys);
+
+    /**
+     * Evict keys from every embedding table.
+     * @param[in] keys keys to evict
+     */
+    void EvictKeysCombine(const vector<emb_key_t>& keys);
+
+    const std::vector<size_t>& GetMissingKeys(const std::string& name);
+
+    void ClearMissingKeys(const std::string& name);
+
+    void LoadMaxOffset(OffsetMemT& loadData);
+
+    void LoadKeyOffsetMap(KeyOffsetMemT& loadData);
+
+    size_t GetMaxOffset(const std::string& name);
+
+    int64_t GetSize(const std::string &name);
+
+    int64_t GetCapacity(const std::string &name);
+
+    std::map<std::string, size_t> GetMaxOffset();
+
+    KeyOffsetMemT GetKeyOffsetMap();
+
+    static EmbeddingMgmt* Instance();
+
+    std::shared_ptr<EmbeddingTable> GetTable(const string& name);
+
+    /**
+     * Load a single table.
+     */
+    int Load(const string& name, const string& filePath);
+
+    /**
+     * Save a single table.
+     */
+    int Save(const string& name, const string& filePath);
+
+    /**
+     * Save every table.
+     */
+    int Save(const string& filePath);
+
+private:
+
+    EmbeddingMgmt();
+
+    EmbeddingMgmt(const EmbeddingMgmt& mgmt) = delete;
+
+    std::unordered_map<std::string, std::shared_ptr<EmbeddingTable>> embeddings;
+};
+
+}
+
+#endif // MX_REC_EMBEDDING_MGMT_H
diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp
new file mode 100644
index 00000000..2a4f2705
--- /dev/null
+++ b/src/core/emb_table/embedding_static.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: EmbeddingStatic, the HBM-mode embedding table implementation
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#include "emb_table/embedding_static.h"
+#include "utils/logger.h"
+
+using namespace MxRec;
+
+EmbeddingStatic::EmbeddingStatic()
+{
+}
+
+EmbeddingStatic::EmbeddingStatic(const EmbInfo& info, const RankInfo& rankInfo, int inSeed)
+    : EmbeddingTable(info, rankInfo, inSeed)
+{
+}
+
+EmbeddingStatic::~EmbeddingStatic()
+{
+}
+
+void EmbeddingStatic::Key2Offset(std::vector<emb_key_t>& keys, int channel)
+{
+    std::lock_guard<std::mutex> lk(mut_); // lock for PROCESS_THREAD
+    for (emb_key_t& key : keys) {
+        if (key == INVALID_KEY_VALUE) {
+            continue;
+        }
+        const auto& iter = keyOffsetMap_.find(key);
+        if (iter != keyOffsetMap_.end()) {
+            key = iter->second;
+            continue;
+        }
+        if (evictPos_.size() != 0 && channel == TRAIN_CHANNEL_ID) {
+            // new key: reuse a slot freed by eviction first
+            size_t offset = evictPos_.back();
+            keyOffsetMap_[key] = offset;
+            key = offset;
+            evictPos_.pop_back();
+            continue;
+        }
+        // new key on a non-training channel: never allocate
+        if (channel != TRAIN_CHANNEL_ID) {
+            key = INVALID_KEY_VALUE;
+            continue;
+        }
+        keyOffsetMap_[key] = maxOffset_;
+        key = maxOffset_++;
+    }
+    if (maxOffset_ > devVocabSize_) {
+        LOG_ERROR("dev cache overflow {} > {}", maxOffset_, devVocabSize_);
+        throw std::runtime_error("dev cache overflow!");
+    }
+}
+
+int64_t EmbeddingStatic::capacity() const
+{
+    return this->devVocabSize_;
+}
diff --git a/src/core/emb_table/embedding_static.h b/src/core/emb_table/embedding_static.h
new file mode 100644
index 00000000..47c1ee41
--- /dev/null
+++ b/src/core/emb_table/embedding_static.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: emb table
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#ifndef MX_REC_EMBEDDING_STATIC_H
+#define MX_REC_EMBEDDING_STATIC_H
+
+#include "emb_table/embedding_table.h"
+
+namespace MxRec {
+
+/**
+ * Fixed-size embedding table: once allocated in HBM, its size cannot change.
+ */
+class EmbeddingStatic : public EmbeddingTable {
+public:
+    EmbeddingStatic();
+
+    EmbeddingStatic(const EmbInfo& info, const RankInfo& rankInfo, int inSeed);
+
+    ~EmbeddingStatic();
+
+    virtual void Key2Offset(std::vector<emb_key_t>& keys, int channel);
+
+    virtual int64_t capacity() const;
+};
+
+}
+
+#endif // MX_REC_EMBEDDING_STATIC_H
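
[Editor's note — illustrative sketch, not part of the patch. EmbeddingStatic::Key2Offset above resolves each key in three steps: hit the map, otherwise reuse an evicted slot (training channel only), otherwise append at maxOffset_. The self-contained function below models just that policy; Resolve and its parameter names are invented and the types are simplified.]

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    constexpr int64_t kInvalidKey = -1;

    int64_t Resolve(std::unordered_map<int64_t, int64_t>& map,
                    std::vector<int64_t>& freeSlots,
                    int64_t& maxOffset, int64_t key, bool isTrain)
    {
        if (key == kInvalidKey) return kInvalidKey;
        auto it = map.find(key);
        if (it != map.end()) return it->second;    // seen before
        if (!isTrain) return kInvalidKey;          // eval never allocates
        int64_t offset;
        if (!freeSlots.empty()) {                  // reuse an evicted slot first
            offset = freeSlots.back();
            freeSlots.pop_back();
        } else {
            offset = maxOffset++;                  // otherwise grow the table
        }
        map[key] = offset;
        return offset;
    }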
diff --git a/src/core/emb_table/embedding_table.cpp b/src/core/emb_table/embedding_table.cpp
new file mode 100644
index 00000000..d48cee03
--- /dev/null
+++ b/src/core/emb_table/embedding_table.cpp
@@ -0,0 +1,171 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+ * Description: emb table
+ * Author: MindX SDK
+ * Date: 2023/12/11
+ */
+
+#include "emb_table/embedding_table.h"
+#include "utils/logger.h"
+#include "utils/singleton.h"
+#include "hd_transfer/hd_transfer.h"
+#include "file_system/file_system_handler.h"
+
+using namespace MxRec;
+
+EmbeddingTable::EmbeddingTable()
+{
+}
+
+EmbeddingTable::EmbeddingTable(const EmbInfo& info, const RankInfo& rankInfo, int inSeed)
+    : name_(info.name), hostVocabSize_(info.hostVocabSize), devVocabSize_(info.devVocabSize),
+      freeSize_(0), maxOffset_(0), isDynamic_(rankInfo.useDynamicExpansion),
+      embSize_(info.embeddingSize), extEmbSize_(info.extEmbeddingSize),
+      embInfo_(info), seed_(inSeed), capacity_(0), rankId_(rankInfo.rankId)
+{
+    LOG_TRACE("table {} isDynamic = {} embeddingSize {} extSize {}",
+        name_, isDynamic_, embSize_, extEmbSize_);
+}
+
+EmbeddingTable::~EmbeddingTable()
+{
+}
+
+void EmbeddingTable::Key2Offset(std::vector<emb_key_t>& keys, int channel)
+{
+    return;
+}
+
+void EmbeddingTable::FindOffset(const vector<emb_key_t>& keys,
+    size_t currentBatchId, size_t keepBatchId, int channelId)
+{
+    return;
+}
+
+std::vector<emb_key_t> EmbeddingTable::FindOffset(const vector<emb_key_t>& keys,
+    size_t batchId, int channelId,
+    std::vector<emb_key_t>& swapPos)
+{
+    return {};
+}
+
+size_t EmbeddingTable::GetMaxOffset()
+{
+    return maxOffset_;
+}
+
+int64_t EmbeddingTable::capacity() const
+{
+    return static_cast<int64_t>(devVocabSize_);
+}
+
+size_t EmbeddingTable::size() const
+{
+    return maxOffset_;
+}
+
+void EmbeddingTable::EvictKeys(const std::vector<emb_key_t>& keys)
+{
+    std::lock_guard<std::mutex> lk(mut_); // lock for PROCESS_THREAD
+    size_t keySize = keys.size();
+    for (size_t i = 0; i < keySize; i++) {
+        emb_key_t key = keys[i];
+        if (key == INVALID_KEY_VALUE) {
+            LOG_WARN("evict key is INVALID_KEY_VALUE!");
+            continue;
+        }
+        const auto& iter = keyOffsetMap_.find(key);
+        if (iter == keyOffsetMap_.end()) { // not found
+            continue;
+        }
+        // record the freed position before erase() invalidates the iterator
+        evictPos_.emplace_back(iter->second);
+        LOG_TRACE("evict embName:{}, offset:{}", name_, iter->second);
+        keyOffsetMap_.erase(iter);
+    }
+    LOG_INFO("EvictKeys: table [{}] evict size on dev:{}", name_, evictPos_.size());
+}
+
+const std::vector<int64_t>& EmbeddingTable::GetEvictedKeys()
+{
+    return evictPos_;
+}
+
+const std::vector<int64_t>& EmbeddingTable::GetHostEvictedKeys()
+{
+    return evictHostPos_;
+}
+
+void EmbeddingTable::EvictInitDeviceEmb()
+{
+    if (evictPos_.size() > devVocabSize_) {
+        LOG_ERROR("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}",
+            name_, evictPos_.size(), devVocabSize_);
+        throw runtime_error(
+            Logger::Format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}",
+                name_, evictPos_.size(), devVocabSize_).c_str());
+    }
+
+    vector<Tensor> tmpDataOut;
+    Tensor tmpData = Vec2TensorI32(evictPos_);
+    tmpDataOut.emplace_back(tmpData);
+    tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 }));
+
+    auto evictLen = tmpDataOut.back().flat<int32_t>();
+    evictLen(0) = static_cast<int32_t>(evictPos_.size());
+
+    // send the evicted positions to the device side so it can re-initialize those embeddings
+    auto trans = Singleton<HdTransfer>::GetInstance();
+    trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, name_);
+
+    LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", name_, evictPos_.size());
+}
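
[Editor's note — illustrative sketch, not part of the patch. EvictKeys/GetEvictedKeys above form a two-step contract: eviction only drops host-side mappings and records the freed device positions; the caller then forwards those positions (as EvictInitDeviceEmb does) so the device re-initializes the rows. A hedged caller-side sketch, which assumes the element types reconstructed above and needs the real mxRec headers to compile:]

    void EvictAndReinit(MxRec::EmbeddingTable& table,
                        const std::vector<MxRec::emb_key_t>& staleKeys)
    {
        table.EvictKeys(staleKeys); // drops key->offset entries, records freed slots
        const std::vector<int64_t>& freed = table.GetEvictedKeys();
        (void)freed; // hand these to the EVICT transfer channel so the device
                     // resets those embedding rows (cf. EvictInitDeviceEmb)
    }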
send offsetSize:{}", name_, evictPos_.size()); +} + +absl::flat_hash_map EmbeddingTable::GetKeyOffsetMap() +{ + return keyOffsetMap_; +} + +void EmbeddingTable::ClearMissingKeys() +{ + missingKeysHostPos_.clear(); +} + +const std::vector& EmbeddingTable::GetMissingKeys() +{ + return missingKeysHostPos_; +} + +void EmbeddingTable::SetStartCount() +{ +} + +void EmbeddingTable::ClearLookupAndSwapOffset() +{ +} + +size_t EmbeddingTable::GetDevVocabSize() +{ + return devVocabSize_; +} + +size_t EmbeddingTable::GetHostVocabSize() +{ + return hostVocabSize_; +} + +int EmbeddingTable::Load(const string& filePath) +{ + return 0; +} + +int EmbeddingTable::Save(const string& filePath) +{ + return 0; +} + +void EmbeddingTable::MakeDir(const string& dirName) +{ + auto fileSystemHandler = make_unique(); + unique_ptr fileSystemPtr = fileSystemHandler->Create(dirName); + fileSystemPtr->CreateDir(dirName); +} diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h new file mode 100644 index 00000000..06baf7c3 --- /dev/null +++ b/src/core/emb_table/embedding_table.h @@ -0,0 +1,116 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Description: emb table + * Author: MindX SDK + * Date: 2023/12/11 + */ + +#ifndef MX_REC_EMBEDDING_TABLE_H +#define MX_REC_EMBEDDING_TABLE_H +#include +#include +#include + +#include "utils/common.h" + +namespace MxRec { + +class EmbeddingTable { +public: + EmbeddingTable(); + EmbeddingTable(const EmbInfo& info, const RankInfo& rankInfo, int inSeed); + virtual ~EmbeddingTable(); + + /** + * 从embedding表中查批量查找key + * @param[in,out] keys 待查找的key,输出为找到的HBM偏移或者HBM地址 + * @param[in] channel 数据通道,主要区分train和eval + */ + virtual void Key2Offset(std::vector& keys, int channel); + + /** + * DDR模式使用 + */ + virtual void FindOffset(const vector& keys, + size_t currentBatchId, size_t keepBatchId, int channelId); + + virtual std::vector FindOffset(const vector& keys, + size_t batchId, int channelId, + std::vector& swapPos); + + /** + * 淘汰key, 配合GetEvictedKeys一起使用GetEvictedKeys + * EvictKeys执行,通过GetEvictedKeys, GetEvictedKeys拿结果 + */ + virtual void EvictKeys(const std::vector& keys); + + /** + * 获取设备侧淘汰的key的偏移或者地址 + * @return HBM模式为偏移, 动态扩容时为地址 + */ + virtual const std::vector& GetEvictedKeys(); + + /** + * 获取host侧淘汰的key的偏移。只有Host侧扩容DDR使用 + * @return host侧淘汰key的偏移 + */ + virtual const std::vector& GetHostEvictedKeys(); + + virtual void EvictInitDeviceEmb(); + + size_t GetMaxOffset(); + + virtual int64_t capacity() const; + + virtual size_t size() const; + + void ClearMissingKeys(); + + virtual const std::vector& GetMissingKeys(); + + absl::flat_hash_map GetKeyOffsetMap(); + + virtual void SetStartCount(); + + virtual void ClearLookupAndSwapOffset(); + + virtual int Load(const string& savePath); + + virtual int Save(const string& savePath); + + size_t GetDevVocabSize(); + + size_t GetHostVocabSize(); + + static void MakeDir(const string& dirName); + +#ifdef NDEBUG +protected: +#endif + + EmbeddingTable& operator=(const EmbeddingTable& table) = delete; + + std::string name_; + size_t hostVocabSize_; + size_t devVocabSize_; + size_t freeSize_; + size_t maxOffset_; + bool isDynamic_; + absl::flat_hash_map keyOffsetMap_; + std::vector evictPos_; // 记录HBM内被淘汰的key + std::vector evictHostPos_; // 记录Host内淘汰列表 + std::mutex mut_; + std::vector initializeInfos_; + EmbInfo embInfo_; + size_t embSize_; + size_t extEmbSize_; + int seed_; + int64_t capacity_; + size_t rankId_; + + std::vector missingKeysHostPos_; // 用于记录当前batch在host上需要换出的偏移 
+}; + +} + +#endif // MX_REC_EMBEDDING_TABLE_H diff --git a/src/core/file_system/file_system.h b/src/core/file_system/file_system.h index 6b08b6f6..5af6985f 100644 --- a/src/core/file_system/file_system.h +++ b/src/core/file_system/file_system.h @@ -36,7 +36,14 @@ namespace MxRec { const vector& addressArr, int deviceId) = 0; virtual ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) = 0; - virtual ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) = 0; + + /** + * datasetSize为文件大小,文件大小除以fileContent的size,即为每条embedding的size + * param[in] 文件路径 + * param[out] fileContent 文件内容将读取到这个矩阵中 + * param[in] datasetSize 文件大小 + */ + virtual ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize = 0) = 0; virtual void ReadEmbedding(const string& filePath, const int& embeddingSize, vector& addressArr, int deviceId) = 0; diff --git a/src/core/file_system/local_file_system/local_file_system.cpp b/src/core/file_system/local_file_system/local_file_system.cpp index 36cd0c57..87e6b687 100644 --- a/src/core/file_system/local_file_system/local_file_system.cpp +++ b/src/core/file_system/local_file_system/local_file_system.cpp @@ -29,9 +29,22 @@ using namespace MxRec; void LocalFileSystem::CreateDir(const string& dirName) { - if (access(dirName.c_str(), F_OK) == -1) { - if (mkdir(dirName.c_str(), dirMode) == -1) { - LOG_DEBUG("Unable to create directory: {}", dirName); + constexpr int maxDepth = 100; + int guard = 0; + stringstream input(dirName); // 读取str到字符串流中 + stringstream ss; + string tmp; + // 按'/'分割,自动创建多级目录 + while (getline(input, tmp, '/')) { + guard++; + if (guard > maxDepth) { + throw runtime_error(StringFormat("create directory {} exceed max depth", dirName)); + } + ss << tmp << '/'; + int ret = mkdir(ss.str().c_str(), dirMode); + if (ret != 0 && errno != EEXIST) { + LOG_ERROR("Unable to create directory: {} ret:{} error info: {}", dirName, ret, strerror(errno)); + throw runtime_error(StringFormat("create directory {} failed: {}", dirName, strerror(errno))); } } } @@ -229,9 +242,23 @@ ssize_t LocalFileSystem::Read(const string& filePath, char* fileContent, size_t ssize_t LocalFileSystem::Read(const string& filePath, vector>& fileContent, size_t datasetSize) { - size_t embDataOuterSize = fileContent.capacity(); - auto onceReadByteSize { datasetSize / embDataOuterSize }; + int fd = open(filePath.c_str(), O_RDONLY); + if (fd < 0) { + throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str())); + } + if (datasetSize == 0) { + struct stat statbuf; + fstat(fd, &statbuf); + datasetSize = statbuf.st_size; + } + size_t embDataOuterSize = fileContent.size(); + if (embDataOuterSize == 0 || datasetSize == 0) { + close(fd); + throw runtime_error(StringFormat("output buffer or file size is empty")); + } + // datasetSize为文件大小, 文件大小除以fileContent的size,即为每条embedding的size + size_t onceReadByteSize = datasetSize / embDataOuterSize; size_t mapByteSize; size_t mapRowNum; CalculateMapSize(datasetSize, mapByteSize, mapRowNum, onceReadByteSize); @@ -239,12 +266,6 @@ ssize_t LocalFileSystem::Read(const string& filePath, vector>& fil off_t offset = 0; size_t remainBytes = datasetSize; ssize_t readBytesNum = 0; - - int fd = open(filePath.c_str(), O_RDONLY); - if (fd == -1) { - throw runtime_error(StringFormat("Failed to open read file: %s", filePath.c_str())); - } - for (size_t i = 0; i < embDataOuterSize; i += mapRowNum) { // 如果剩余字节数小于每次映射的字节数,则更新每次映射的字节数和行数 if (remainBytes < mapByteSize) { @@ -258,7 +279,6 @@ ssize_t 
LocalFileSystem::Read(const string& filePath, vector>& fil return -1; } readBytesNum += mapByteSize; - char* mappedData = static_cast(tempMappedData); // 处理映射的数据 diff --git a/src/core/file_system/local_file_system/local_file_system.h b/src/core/file_system/local_file_system/local_file_system.h index 06a8d18a..8a8527df 100644 --- a/src/core/file_system/local_file_system/local_file_system.h +++ b/src/core/file_system/local_file_system/local_file_system.h @@ -38,7 +38,8 @@ namespace MxRec { const vector& addressArr, int deviceId) override; ssize_t Read(const string& filePath, char* fileContent, size_t datasetSize) override; - ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize) override; + + ssize_t Read(const string& filePath, vector>& fileContent, size_t datasetSize = 0) override; void ReadEmbedding(const string& filePath, const int& embeddingSize, vector& addressArr, int deviceId) override; @@ -56,4 +57,4 @@ namespace MxRec { }; } -#endif // MX_REC_LOCAL_FILE_SYSTEM_H \ No newline at end of file +#endif // MX_REC_LOCAL_FILE_SYSTEM_H diff --git a/src/core/host_emb/host_emb.cpp b/src/core/host_emb/host_emb.cpp index 7f885fb0..ce0e0a78 100644 --- a/src/core/host_emb/host_emb.cpp +++ b/src/core/host_emb/host_emb.cpp @@ -49,6 +49,7 @@ void HostEmb::Initialize(const vector& embInfos, int seed) void HostEmb::EmbDataGenerator(const vector &initializeInfos, int seed, int vocabSize, int embeddingSize, vector> &embData) const { +#ifndef GTEST LOG_INFO(HOSTEMB + "GenerateEmbData Start, seed:{}, initializer num: {}", seed, initializeInfos.size()); embData.clear(); embData.resize(vocabSize, vector(embeddingSize)); @@ -60,6 +61,7 @@ void HostEmb::EmbDataGenerator(const vector &initializeInfos, in } } LOG_INFO(HOSTEMB + "GenerateEmbData End, seed:{}", seed); +#endif } /// 停止用于异步更新D2H emb的线程 @@ -91,6 +93,7 @@ void HostEmb::Join(int channelId) } } +#ifndef GTEST /// 从hdTransfer获取device侧返回的emb信息,并在host侧表的对应位置插入。 /// missingKeysHostPos为host侧需要发送的emb的位置,也就是淘汰的emb的插入位置 /// \param missingKeysHostPos 当前batch在host上需要换出的偏移 @@ -116,7 +119,7 @@ void HostEmb::UpdateEmb(const vector& missingKeysHostPos, int channelId, auto& embData = hostEmbs[embName].embData; LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}, " - "embData.size = {}", embName, missingKeysHostPos.size(), embeddingSize, embData.size()); + "embData.size = {} {}", embName, missingKeysHostPos.size(), embeddingSize, embData.size(), tensorPtr); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) \ shared(missingKeysHostPos, tensorPtr, embData, embeddingSize) @@ -158,7 +161,10 @@ void HostEmb::UpdateEmbV2(const vector& missingKeysHostPos, int channelI if (aclData == nullptr) { throw runtime_error("Acl get tensor data from dataset failed."); } - float* ptr = reinterpret_cast(acltdtGetDataAddrFromItem(aclData)); + float* ptr = static_cast(acltdtGetDataAddrFromItem(aclData)); + if (ptr == nullptr || missingKeysHostPos.size() == 0) { + return; + } size_t elementSize = acltdtGetDataSizeFromItem(aclData); size_t dimNum = acltdtGetDimNumFromItem(aclData); LOG_DEBUG(HOSTEMB + "embName:{}, UpdateEmb missingKeys len = {}, embeddingSize = {}," @@ -204,7 +210,7 @@ void HostEmb::GetH2DEmb(const vector& missingKeysHostPos, const string& auto& tmpTensor = h2dEmbOut.back(); auto tmpData = tmpTensor.flat(); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(missingKeysHostPos, emb, tmpData) - for (size_t i = 0; i < missingKeysHostPos.size(); i++) { + for (size_t i = 0; i < 
missingKeysHostPos.size(); ++i) { const auto& src = emb.embData[missingKeysHostPos[i]]; #pragma omp simd for (int j = 0; j < embeddingSize; j++) { @@ -230,19 +236,43 @@ void HostEmb::EmbPartGenerator(const vector &initializeInfos, ve { for (auto initializeInfo: initializeInfos) { LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); - for (size_t i = 0; i < offset.size(); i++) { + for (size_t i = 0; i < offset.size(); ++i) { initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), static_cast(embData[0].size())); } } } +void HostEmb::EmbPartGenerator(const vector &initializeInfos, vector> &embData, + const vector& offset) const +{ + for (auto initializeInfo: initializeInfos) { + LOG_INFO("Device GenerateEmbData ing. name {}", initializeInfo.name); + for (size_t i = 0; i < offset.size(); ++i) { + initializeInfo.initializer->GenerateData(embData.at(offset.at(i)).data(), + static_cast(embData[0].size())); + } + } +} +#endif + /// 利用initializer初始化emb淘汰的位置 /// \param embName 表名 /// \param offset 淘汰的偏移列表 void HostEmb::EvictInitEmb(const string& embName, const vector& offset) { +#ifndef GTEST + auto& hostEmb = GetEmb(embName); + EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); + LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); +#endif +} + +void HostEmb::EvictInitEmb(const string& embName, const vector& offset) +{ +#ifndef GTEST auto& hostEmb = GetEmb(embName); EmbPartGenerator(hostEmb.hostEmbInfo.initializeInfos, hostEmb.embData, offset); LOG_INFO(HOSTEMB + "ddr EvictInitEmb!host embName {}, init offsets size: {}", embName, offset.size()); +#endif } \ No newline at end of file diff --git a/src/core/host_emb/host_emb.h b/src/core/host_emb/host_emb.h index 3b7e6942..a9ff3786 100644 --- a/src/core/host_emb/host_emb.h +++ b/src/core/host_emb/host_emb.h @@ -50,6 +50,8 @@ namespace MxRec { void EvictInitEmb(const string& embName, const vector& offset); + void EvictInitEmb(const string& embName, const vector& offset); + HostEmbTable& GetEmb(const string& embName) { return hostEmbs.at(embName); @@ -65,7 +67,10 @@ namespace MxRec { vector>& embData) const; void EmbPartGenerator(const vector &initializeInfos, vector> &embData, const vector& offset) const; + + void EmbPartGenerator(const vector &initializeInfos, vector> &embData, + const vector& offset) const; }; } -#endif // MX_REC_HOSTEMB_H \ No newline at end of file +#endif // MX_REC_HOSTEMB_H diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index c35a7443..f5895604 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -21,6 +21,7 @@ See the License for the specific language governing permissions and #include "checkpoint/checkpoint.h" #include "key_process/key_process.h" #include "key_process/feature_admit_and_evict.h" +#include "emb_table/embedding_mgmt.h" using namespace MxRec; @@ -44,8 +45,8 @@ void HybridMgmt::InitRankInfo(RankInfo& rankInfo, const vector& embInfo } // 根据DDR的key数量,配置存储模式HBM/DDR - if (totHostVocabSize == 0) { - rankInfo.noDDR = true; + if (totHostVocabSize != 0) { + rankInfo.isDDR = true; } if (totalSsdVocabSize != 0) { rankInfo.isSSDEnabled = true; @@ -79,6 +80,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, } InitRankInfo(rankInfo, embInfos); + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos, thresholdValues, seed); GlogConfig::gStatOn = GlobalEnv::statOn; LOG_INFO(MGMT + "begin initialize, 
localRankSize:{}, localRankId:{}, rank:{}", @@ -100,7 +102,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, isRunning = true; // DDR模式,初始化hashmap和host emb - if (!rankInfo.noDDR) { + if (rankInfo.isDDR) { hostEmbs = Singleton::GetInstance(); hostHashMaps = make_unique(); hostEmbs->Initialize(embInfos, seed); @@ -124,7 +126,7 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, LOG_INFO(MGMT + "emb[{}] vocab size {}+{} sc:{}", info.name, info.devVocabSize, info.hostVocabSize, info.sendCount); } - LOG_INFO(MGMT + "end initialize, noDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.noDDR, + LOG_INFO(MGMT + "end initialize, isDDR:{}, maxStep:[{}, {}], rank:{}", rankInfo.isDDR, rankInfo.maxStep.at(TRAIN_CHANNEL_ID), rankInfo.maxStep.at(EVAL_CHANNEL_ID), rankInfo.rankId); #endif isInitialized = true; @@ -225,16 +227,16 @@ bool HybridMgmt::Save(const string savePath) CkptData saveData; Checkpoint saveCkpt; saveData.keyCountMap = KEY_PROCESS_INSTANCE->GetKeyCountMap(); - if (!mgmtRankInfo.noDDR) { + + if (mgmtRankInfo.isDDR) { // DDR模式保存host的emb表以及hashmap LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap"); - saveData.hostEmbs = hostEmbs->GetHostEmbs(); - saveData.embHashMaps = hostHashMaps->GetHashMaps(); + EmbeddingMgmt::Instance()->Save(savePath); } else { // HBM模式保存最大偏移(真正使用了多少vocab容量),特征到偏移的映射 LOG_DEBUG(MGMT + "Start host side save: no ddr mode hashmap"); - saveData.maxOffset = KEY_PROCESS_INSTANCE->GetMaxOffset(); - saveData.keyOffsetMap = KEY_PROCESS_INSTANCE->GetKeyOffsetMap(); + saveData.maxOffset = EmbeddingMgmt::Instance()->GetMaxOffset(); + saveData.keyOffsetMap = EmbeddingMgmt::Instance()->GetKeyOffsetMap(); } if (isSSDEnabled) { @@ -292,13 +294,13 @@ bool HybridMgmt::Load(const string& loadPath) loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures); // 检查DDR模式保存的模型和当前训练配置是否一致,不一致则退出 - if (!mgmtRankInfo.noDDR && !LoadMatchesDDRSetup(loadData)) { + if (mgmtRankInfo.isDDR && !LoadMatchesDDRSetup(loadData)) { KEY_PROCESS_INSTANCE->LoadSaveUnlock(); return false; } KEY_PROCESS_INSTANCE->LoadKeyCountMap(loadData.keyCountMap); - if (!mgmtRankInfo.noDDR) { + if (mgmtRankInfo.isDDR) { // DDR模式 将加载的hash map进行赋值 LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap"); hostHashMaps->LoadHashMap(loadData.embHashMaps); @@ -340,7 +342,7 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures) if (GlobalEnv::recordKeyCount) { loadFeatures.push_back(CkptFeatureType::KEY_COUNT_MAP); } - if (!mgmtRankInfo.noDDR) { + if (mgmtRankInfo.isDDR) { // DDR模式加载的类型为host的emb表以及hashmap loadFeatures.push_back(CkptFeatureType::HOST_EMB); loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP); @@ -399,7 +401,7 @@ void HybridMgmt::ReceiveHostMap(AllKeyOffsetMapT receiveKeyOffsetMap) maxOffset = keyOffsetMap.second.size(); } } - if (!mgmtRankInfo.noDDR) { + if (mgmtRankInfo.isDDR) { LOG_DEBUG(MGMT + "Start receive sparse data: ddr mode hashmap"); } else { LOG_DEBUG(MGMT + "Start receive sparse data: no ddr mode hashmap"); @@ -488,10 +490,10 @@ bool HybridMgmt::LoadMatchesDDRSetup(const CkptData& loadData) void HybridMgmt::Start() { #ifndef GTEST - if (mgmtRankInfo.noDDR) { - StartThreadForHBM(); - } else { + if (mgmtRankInfo.isDDR) { StartThreadForDDR(); + } else { + StartThreadForHBM(); } #endif } @@ -803,10 +805,11 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha LOG_ERROR("Failed to get embedding hash map with given name: {}", embName); return false; } - auto& embHashMap = 
hostHashMaps->embHashMaps.at(embName); + auto& embHashMap = hostHashMaps->embHashMaps.at(embName); // 计数初始化 - embHashMap.SetStartCount(); + std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); + table->SetStartCount(); // 获取查询向量 auto lookupKeys = KEY_PROCESS_INSTANCE->GetLookupKeys(batchId, embName, channelId); @@ -840,7 +843,9 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha vector offsetsOut; DDRParam ddrParam(tmpData, offsetsOut); TimeCost hostHashMapProcessTC; + hostHashMaps->Process(embName, lookupKeys, ddrParam, channelId); + LOG_DEBUG("channelId:{} batchId:{}, hostHashMapProcessTC(ms):{}", channelId, batchId, hostHashMapProcessTC.ElapsedMS()); @@ -923,7 +928,7 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) TimeCost h2dTC; // 发送host需要换出的emb for (const auto& embInfo: mgmtEmbInfo) { - auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); + const auto& missingKeys = EmbeddingMgmt::Instance()->GetMissingKeys(embInfo.name); vector h2dEmb; hostEmbs->GetH2DEmb(missingKeys, embInfo.name, h2dEmb); // order! hdTransfer->Send(TransferChannel::H2D, h2dEmb, channelId, embInfo.name, batchId); @@ -933,13 +938,9 @@ void HybridMgmt::EmbHDTrans(const int channelId, const int batchId) TimeCost d2hTC; // 接收device换出的emb,并更新到host上 for (const auto& embInfo: mgmtEmbInfo) { - const auto& missingKeys = hostHashMaps->GetMissingKeys(embInfo.name); - if (GlobalEnv::updateEmbV2) { - hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! - } else { - hostEmbs->UpdateEmb(missingKeys, channelId, embInfo.name); // order! - } - hostHashMaps->ClearMissingKeys(embInfo.name); + const auto& missingKeys = EmbeddingMgmt::Instance()->GetMissingKeys(embInfo.name); + hostEmbs->UpdateEmbV2(missingKeys, channelId, embInfo.name); // order! + EmbeddingMgmt::Instance()->ClearMissingKeys(embInfo.name); } LOG_DEBUG("channelId:{} batchId:{}, EmbHDTrans d2h end, d2hTC(ms):{}", channelId, batchId, d2hTC.ElapsedMS()); } @@ -970,19 +971,18 @@ bool HybridMgmt::Evict() return false; } - if (mgmtRankInfo.noDDR) { + if (!mgmtRankInfo.isDDR) { if (GlobalEnv::useCombineFaae) { - KEY_PROCESS_INSTANCE->EvictKeysCombine(evictKeyMap[COMBINE_HISTORY_NAME]); + EmbeddingMgmt::Instance()->EvictKeysCombine(evictKeyMap[COMBINE_HISTORY_NAME]); } else { for (const auto& evict : as_const(evictKeyMap)) { - KEY_PROCESS_INSTANCE->EvictKeys(evict.first, evict.second); + EmbeddingMgmt::Instance()->EvictKeys(evict.first, evict.second); } } } else { if (GlobalEnv::useCombineFaae) { for (auto& map : hostHashMaps->embHashMaps) { - EvictKeys(map.first, evictKeyMap[COMBINE_HISTORY_NAME]); - EvictSSDKeys(map.first, evictKeyMap[COMBINE_HISTORY_NAME]); + EmbeddingMgmt::Instance()->EvictKeys(map.first, evictKeyMap[COMBINE_HISTORY_NAME]); } } else { for (const auto& evict : as_const(evictKeyMap)) { @@ -1001,46 +1001,35 @@ bool HybridMgmt::Evict() /// \param keys void HybridMgmt::EvictKeys(const string& embName, const vector& keys) { -#ifndef GTEST - LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! 
evict keySize:{}", embName, keys.size()); - // 删除映射关系 - if (keys.size() != 0) { - hostHashMaps->EvictDeleteEmb(embName, keys); - } + std::shared_ptr table = EmbeddingMgmt::Instance()->GetTable(embName); - // 初始化host侧的emb - auto& evictOffset = hostHashMaps->GetEvictPos(embName); - vector evictOffset4Ddr; - if (hostHashMaps->embHashMaps.find(embName) == hostHashMaps->embHashMaps.end()) { - LOG_ERROR("Failed to get embedding hash map with given name: {}", embName); - return; - } - auto devVocabSize = hostHashMaps->embHashMaps.at(embName).devVocabSize; - for (auto& offsetInHostHashMap : evictOffset) { - evictOffset4Ddr.emplace_back(offsetInHostHashMap - devVocabSize); - } - if (!evictOffset4Ddr.empty()) { - LOG_DEBUG(MGMT + "ddr mode, delete emb: [{}]! evict size on host:{}", embName, evictOffset4Ddr.size()); - hostEmbs->EvictInitEmb(embName, evictOffset4Ddr); - } else { - LOG_INFO(MGMT + "ddr mode, evict size on host is empty"); - } + table->EvictKeys(keys); + + const vector& evictOffsetDev = table->GetEvictedKeys(); + const vector& evictOffsetHost = table->GetHostEvictedKeys(); - // 发送dev侧的淘汰pos,以便dev侧初始化emb - auto evictDevOffset = hostHashMaps->embHashMaps.at(embName).evictDevPos; - LOG_DEBUG(MGMT + "ddr mode, init dev emb: [{}]! evict size on dev :{}", embName, evictDevOffset.size()); + vector evictOffsetHostx(evictOffsetHost); + + size_t devVocabSize = table->GetDevVocabSize(); + for (int64_t& key: evictOffsetHostx) { + key -= static_cast(devVocabSize); + }; + + /* 淘汰Host侧 */ + if (!evictOffsetHost.empty()) { + hostEmbs->EvictInitEmb(embName, evictOffsetHost); + } vector tmpDataOut; - Tensor tmpData = Vec2TensorI32(evictDevOffset); + Tensor tmpData = Vec2TensorI32(evictOffsetDev); tmpDataOut.emplace_back(tmpData); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto evictLen = tmpDataOut.back().flat(); - auto evictSize = static_cast(evictDevOffset.size()); + auto evictSize = static_cast(evictOffsetDev.size()); evictLen(0) = evictSize; hdTransfer->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); -#endif } inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInfo& embHashMap, @@ -1130,18 +1119,13 @@ int64_t HybridMgmt::GetTableSize(const string& embName) const } if (mgmtRankInfo.useDynamicExpansion) { - int64_t size = KEY_PROCESS_INSTANCE->GetExpansionTableSize(embName); + int64_t size = EmbeddingMgmt::Instance()->GetSize(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] size:{}", embName, size); return size; } - if (mgmtRankInfo.noDDR) { - auto maxOffset = KEY_PROCESS_INSTANCE->GetMaxOffset(); - const auto& iter = maxOffset.find(embName); - if (iter == maxOffset.end()) { - LOG_ERROR(MGMT + "get maxOffset, wrong embName:{} ", embName); - return -1; - } - int64_t size = static_cast(maxOffset[embName]); + if (!mgmtRankInfo.isDDR) { + size_t maxOffset = EmbeddingMgmt::Instance()->GetMaxOffset(embName); + int64_t size = static_cast(maxOffset); LOG_INFO(MGMT + "HBM mode, get emb:[{}] size:{}", embName, size); return size; } @@ -1174,7 +1158,7 @@ int64_t HybridMgmt::GetTableCapacity(const string& embName) const } if (mgmtRankInfo.useDynamicExpansion) { - int64_t capacity = KEY_PROCESS_INSTANCE->GetExpansionTableCapacity(embName); + int64_t capacity = EmbeddingMgmt::Instance()->GetCapacity(embName); LOG_INFO(MGMT + "dynamic expansion mode, get emb:[{}] capacity:{}", embName, capacity); return capacity; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp index 
4ce54bb4..714edf9d 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt_block.cpp @@ -119,7 +119,7 @@ void HybridMgmtBlock::CheckValid(int channelId) lastRunChannelId, pythonBatchId[lastRunChannelId], hybridBatchId[lastRunChannelId]); } else if (pythonBatchId[lastRunChannelId] < hybridBatchId[lastRunChannelId]) { // 在通道切换时,上一个通道处理的数据超出了python侧的调用 - if (!rankInfo.noDDR and !WaitValid(lastRunChannelId)) { + if (rankInfo.isDDR and !WaitValid(lastRunChannelId)) { throw HybridMgmtBlockingException("when channel switch"); } } else { @@ -229,4 +229,4 @@ void HybridMgmtBlock::SetStepInterval(int trainStep, int evalStep) HybridMgmtBlock::~HybridMgmtBlock() { Destroy(); -} \ No newline at end of file +} diff --git a/src/core/key_process/key_process.cpp b/src/core/key_process/key_process.cpp index 03a0dbb7..2218799a 100644 --- a/src/core/key_process/key_process.cpp +++ b/src/core/key_process/key_process.cpp @@ -20,7 +20,7 @@ See the License for the specific language governing permissions and #include "utils/time_cost.h" #include "utils/config.h" #include "host_emb/host_emb.h" -#include "checkpoint/checkpoint.h" +#include "emb_table/embedding_mgmt.h" #include "hd_transfer/hd_transfer.h" #include "ock_ctr_common/include/error_code.h" @@ -86,7 +86,9 @@ bool KeyProcess::Initialize(const RankInfo& rInfo, const vector& eInfos LOG_INFO(KEY_PROCESS "scInfo:{}, localRankSize:{}, rankSize:{}, useStatic:{}, useHot:{}", MapToString(scInfo), rInfo.localRankSize, rInfo.rankSize, rInfo.useStatic, rInfo.useHot); +#ifndef GTEST Start(); +#endif return true; } @@ -135,7 +137,7 @@ void KeyProcess::InitHotEmbTotCount(const EmbInfo& info, const RankInfo& rInfo) OffsetMemT KeyProcess::GetMaxOffset() { - return maxOffset; + return EmbeddingMgmt::Instance()->GetMaxOffset(); } KeyOffsetMemT KeyProcess::GetKeyOffsetMap() @@ -155,14 +157,14 @@ FeatureAdmitAndEvict& KeyProcess::GetFeatAdmitAndEvict() void KeyProcess::LoadMaxOffset(OffsetMemT& loadData) { - maxOffset = std::move(loadData); + EmbeddingMgmt::Instance()->LoadMaxOffset(loadData); } /// 加载每张表key到offset的映射 /// \param loadData void KeyProcess::LoadKeyOffsetMap(KeyOffsetMemT& loadData) { - keyOffsetMap = std::move(loadData); + EmbeddingMgmt::Instance()->LoadKeyOffsetMap(loadData); } void KeyProcess::LoadKeyCountMap(KeyCountMemT& loadData) @@ -370,9 +372,9 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch // map key to offset directly by lookup keyOffsetMap (hashmap) RecordKeyCountMap(batch); - if (rankInfo.noDDR) { + if (!rankInfo.isDDR) { TimeCost key2OffsetTC; - Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); + EmbeddingMgmt::Instance()->Key2Offset(batch->name, uniqueInfo.all2AllInfo.keyRecv, channel); LOG_DEBUG("key2OffsetTC(ms):{}", key2OffsetTC.ElapsedMS()); } // Static all2all,need send count @@ -385,7 +387,7 @@ bool KeyProcess::KeyProcessTaskHelperWithFastUnique(unique_ptr& batch tensors->push_back(Vec2TensorI32(uniqueInfo.hotPos)); } - if (rankInfo.noDDR) { + if (!rankInfo.isDDR) { PushGlobalUniqueTensors(move(tensors), uniqueInfo.all2AllInfo.keyRecv, channel); tensors->push_back(rankInfo.useDynamicExpansion ? 
Vec2TensorI64(uniqueInfo.all2AllInfo.keyRecv) : Vec2TensorI32(uniqueInfo.all2AllInfo.keyRecv)); @@ -432,12 +434,8 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, // without host, just device, all embedding vectors were stored in device // map key to offset directly by lookup keyOffsetMap (hashmap) - if (rankInfo.noDDR) { - if (rankInfo.useDynamicExpansion) { - Key2OffsetDynamicExpansion(batch->name, lookupKeys, channel); - } else { - Key2Offset(batch->name, lookupKeys, channel); - } + if (!rankInfo.isDDR) { + EmbeddingMgmt::Instance()->Key2Offset(batch->name, lookupKeys, channel); } // Static all2all,need send count @@ -451,7 +449,7 @@ bool KeyProcess::KeyProcessTaskHelper(unique_ptr& batch, int channel, tensors->push_back(Vec2TensorI32(hotPos)); } - if (rankInfo.noDDR) { + if (!rankInfo.isDDR) { PushGlobalUniqueTensors(tensors, lookupKeys, channel); tensors->push_back(rankInfo.useDynamicExpansion ? Vec2TensorI64(lookupKeys) : Vec2TensorI32(lookupKeys)); } @@ -519,7 +517,7 @@ void KeyProcess::PushResult(unique_ptr& batch, unique_ptr lockGuard(mut); storage.push_front(move(tensors)); infoList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, storage.begin())); - if (!rankInfo.noDDR) { + if (rankInfo.isDDR) { lookupKeysList[batch->name][batch->channel].push(make_tuple(batch->batchId, batch->name, move(lookupKeys))); } lockGuard.unlock(); @@ -1321,8 +1319,9 @@ KeysT KeyProcess::GetLookupKeys(int batch, const string& embName, int channel) /// \param channel 通道索引(训练/推理) void KeyProcess::SendEos(int batchId, int channel) { - LOG_INFO("channelId:{} batchId:{}, SendEos start.", channel, batchId); #ifndef GTEST + LOG_INFO("channelId:{} batchId:{}, SendEos start.", channel, batchId); + auto trans = Singleton::GetInstance(); unordered_map transChannels = trans->GetTransChannel(); std::set usedChannelNames = trans->GetUsedTransChannel()[channel]; @@ -1342,11 +1341,12 @@ void KeyProcess::SendEos(int batchId, int channel) } LOG_INFO("channelId:{} batchId:{}, the embName:{} related channel SendEos end.", channel, batchId, emb.first); } -#endif + LOG_INFO("channelId:{} batchId:{}, SendEos end.", channel, batchId); isNeedSendEos[channel] = false; mpiAllReduceSend[channel] = 0; isNeedExit[channel] = true; +#endif } /// HBM模式下,从list中获取指定类型的tensor向量 @@ -1442,29 +1442,13 @@ int KeyProcess::GetMaxStep(int channelId) const void KeyProcess::EvictKeys(const string& embName, const vector& keys) // hbm { LOG_INFO(KEY_PROCESS "hbm funEvictCall: [{}]! 
keySize:{}", embName, keys.size()); - - // 删除映射关系 - if (keys.size() != 0) { - EvictDeleteDeviceEmb(embName, keys); - } - - // 初始化 dev - EvictInitDeviceEmb(embName, evictPosMap.at(embName)); + EmbeddingMgmt::Instance()->EvictKeys(embName, keys); } void KeyProcess::EvictKeysCombine(const vector& keys) // hbm { LOG_INFO(KEY_PROCESS "hbm combine funEvictCall, keySize:{}", keys.size()); - // 删除映射关系 - if (keys.size() != 0) { - for (const auto& map : keyOffsetMap) { - EvictDeleteDeviceEmb(map.first, keys); - } - } - for (const auto map : evictPosMap) { - // 初始化 dev - EvictInitDeviceEmb(map.first, map.second); - } + EmbeddingMgmt::Instance()->EvictKeysCombine(keys); } void KeyProcess::EvictDeleteDeviceEmb(const string& embName, const vector& keys) @@ -1505,7 +1489,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset embName, offset.size(), embInfos[embName].devVocabSize ).c_str()); } -#ifndef GTEST + vector tmpDataOut; Tensor tmpData = Vec2TensorI32(offset); tmpDataOut.emplace_back(tmpData); @@ -1518,7 +1502,7 @@ void KeyProcess::EvictInitDeviceEmb(const string& embName, vector offset // evict key发送给dev侧,dev侧初始化emb auto trans = Singleton::GetInstance(); trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); -#endif + LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", embName, offset.size()); } @@ -1537,24 +1521,12 @@ string KeyProcess::DumpSplitKeys(vector> &splitKeys) const int64_t KeyProcess::GetExpansionTableSize(const string& embName) { - const auto& iter = embeddingTableMap.find(embName); - if (iter == embeddingTableMap.end()) { - LOG_ERROR(KEY_PROCESS "GetExpansionEmbSize, wrong embName:{} ", embName); - return -1; - } - std::lock_guard lk(mut); // lock for PROCESS_THREAD - return iter->second.GetTableSize(); + return EmbeddingMgmt::Instance()->GetSize(embName); } int64_t KeyProcess::GetExpansionTableCapacity(const string& embName) { - const auto& iter = embeddingTableMap.find(embName); - if (iter == embeddingTableMap.end()) { - LOG_ERROR(KEY_PROCESS "GetExpansionEmbSize, wrong embName:{} ", embName); - return -1; - } - std::lock_guard lk(mut); // lock for PROCESS_THREAD - return iter->second.GetTableCapacity(); + return EmbeddingMgmt::Instance()->GetCapacity(embName); } void KeyProcess::RecordKeyCountMap(const unique_ptr& batch) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index e857eb5c..5c2d316e 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -95,7 +95,8 @@ namespace MxRec { constexpr int EOS_TIMEOUT = 5; constexpr size_t DEFAULT_RANDOM_SEED = 10086; - constexpr int INVALID_KEY_VALUE = -1; + // constexpr int INVALID_KEY_VALUE = -1; + constexpr int64_t INVALID_KEY_VALUE = -1; constexpr int ALLTOALLVC_ALIGN = 128; constexpr int PROFILING_START_BATCH_ID = 100; constexpr int PROFILING_END_BATCH_ID = 200; @@ -222,7 +223,7 @@ namespace MxRec { bool useHot {}; uint32_t option {}; int nBatch {}; - bool noDDR { false }; + bool isDDR { true }; bool isSSDEnabled { false }; bool useDynamicExpansion {false}; std::vector maxStep; diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp index cc18eb14..fd9fbf9d 100644 --- a/src/tests/checkpoint/checkpoint_test.cpp +++ b/src/tests/checkpoint/checkpoint_test.cpp @@ -338,48 +338,6 @@ TEST_F(CheckpointTest, HostEmbs) } } -TEST_F(CheckpointTest, EmbHashMaps) -{ - EmbHashMemT testEmbHashMaps; - EmbHashMemT validEmbHashMaps; - - SetEmbInfo(); - SetEmbHashMaps(testEmbHashMaps); - validEmbHashMaps = testEmbHashMaps; 
- - CkptData testSaveData; - CkptData validLoadData; - CkptData testLoadData; - - testSaveData.embHashMaps = std::move(testEmbHashMaps); - validLoadData.embHashMaps = std::move(validEmbHashMaps); - - Checkpoint testCkpt; - testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); - testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::EMB_HASHMAP }); - - EXPECT_EQ(validLoadData.embHashMaps.size(), testLoadData.embHashMaps.size()); - for (const auto& it : validLoadData.embHashMaps) { - EXPECT_EQ(1, testLoadData.embHashMaps.count(it.first)); - - const auto& hostHashMap = testLoadData.embHashMaps.at(it.first).hostHashMap; - const auto& devOffset2Batch = testLoadData.embHashMaps.at(it.first).devOffset2Batch; - const auto& devOffset2Key = testLoadData.embHashMaps.at(it.first).devOffset2Key; - const auto& currentUpdatePos = testLoadData.embHashMaps.at(it.first).currentUpdatePos; - const auto& hostVocabSize = testLoadData.embHashMaps.at(it.first).hostVocabSize; - const auto& devVocabSize = testLoadData.embHashMaps.at(it.first).devVocabSize; - - EXPECT_EQ(it.second.hostHashMap, hostHashMap); - - EXPECT_EQ(it.second.devOffset2Batch, devOffset2Batch); - EXPECT_EQ(it.second.devOffset2Key, devOffset2Key); - - EXPECT_EQ(it.second.currentUpdatePos, currentUpdatePos); - EXPECT_EQ(it.second.hostVocabSize, hostVocabSize); - EXPECT_EQ(it.second.devVocabSize, devVocabSize); - } -} - TEST_F(CheckpointTest, KeyOffsetMaps) { KeyOffsetMemT testKeyOffsetMaps; @@ -411,120 +369,6 @@ TEST_F(CheckpointTest, KeyOffsetMaps) } } -TEST_F(CheckpointTest, AllMgmt) -{ - OffsetMemT testMaxOffset; - OffsetMemT validMaxOffset; - KeyOffsetMemT testKeyOffsetMaps; - KeyOffsetMemT validKeyOffsetMaps; - - SetEmbInfo(); - SetMaxOffset(testMaxOffset); - validMaxOffset = testMaxOffset; - SetKeyOffsetMaps(testKeyOffsetMaps); - validKeyOffsetMaps = testKeyOffsetMaps; - - CkptData testSaveData; - CkptData validLoadData; - CkptData testLoadData; - - testSaveData.maxOffset = std::move(testMaxOffset); - validLoadData.maxOffset = std::move(validMaxOffset); - testSaveData.keyOffsetMap = std::move(testKeyOffsetMaps); - validLoadData.keyOffsetMap = std::move(validKeyOffsetMaps); - - Checkpoint testCkpt; - testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); - testCkpt.LoadModel(testPath, - testLoadData, - rankInfo, - testEmbInfos, - {CkptFeatureType::KEY_OFFSET_MAP }); - - EXPECT_EQ(validLoadData.maxOffset.size(), testLoadData.maxOffset.size()); - for (const auto& it : validLoadData.maxOffset) { - EXPECT_EQ(1, testLoadData.maxOffset.count(it.first)); - const auto& maxOffset = testLoadData.maxOffset.at(it.first); - EXPECT_EQ(it.second, maxOffset); - } - - EXPECT_EQ(validLoadData.keyOffsetMap.size(), testLoadData.keyOffsetMap.size()); - for (const auto& it : validLoadData.keyOffsetMap) { - EXPECT_EQ(1, testLoadData.keyOffsetMap.count(it.first)); - const auto& keyOffsetMap = testLoadData.keyOffsetMap.at(it.first); - const auto& validKeyOffsetMap = validLoadData.keyOffsetMap.at(it.first); - for (const auto& key: keyOffsetMap) { - EXPECT_EQ(validKeyOffsetMap.count(key.first), 1); - } - } -} - -TEST_F(CheckpointTest, FeatAdmitNEvict) -{ - Table2ThreshMemT testTrens2Thresh; - Table2ThreshMemT validTrens2Thresh; - AdmitAndEvictData testHistRec; - AdmitAndEvictData validHistRec; - - SetEmbInfo(); - SetTable2Threshold(testTrens2Thresh); - validTrens2Thresh = testTrens2Thresh; - bool isCombine = false; - - if (isCombine) { - SetHistRecCombine(testHistRec); - } else { - SetHistRec(testHistRec); - } - 
- validHistRec = testHistRec; - - CkptData testSaveData; - CkptData validLoadData; - CkptData testLoadData; - - testSaveData.table2Thresh = testTrens2Thresh; - testSaveData.histRec.timestamps = testHistRec.timestamps; - testSaveData.histRec.historyRecords = testHistRec.historyRecords; - validLoadData.table2Thresh = validTrens2Thresh; - validLoadData.histRec = validHistRec; - validLoadData.histRec.timestamps = validHistRec.timestamps; - validLoadData.histRec.historyRecords = validHistRec.historyRecords; - - Checkpoint testCkpt; - testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); - testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::FEAT_ADMIT_N_EVICT }); - - EXPECT_EQ(validLoadData.table2Thresh.size(), testLoadData.table2Thresh.size()); - EXPECT_EQ(validLoadData.histRec.historyRecords.size(), testLoadData.histRec.historyRecords.size()); - for (const auto& it : validLoadData.table2Thresh) { - EXPECT_EQ(1, testLoadData.table2Thresh.count(it.first)); - - const auto& table2Thresh = testLoadData.table2Thresh.at(it.first); - - EXPECT_EQ(it.second.tableName, table2Thresh.tableName); - EXPECT_EQ(it.second.countThreshold, table2Thresh.countThreshold); - EXPECT_EQ(it.second.timeThreshold, table2Thresh.timeThreshold); - } - - for (const auto& it : validLoadData.histRec.timestamps) { - EXPECT_EQ(1, testLoadData.histRec.timestamps.count(it.first)); - EXPECT_EQ(1, testLoadData.histRec.historyRecords.count(it.first)); - - const auto& timestamps = testLoadData.histRec.timestamps.at(it.first); - const auto& historyRecords = testLoadData.histRec.historyRecords.at(it.first); - const auto& validHistRec = validLoadData.histRec.historyRecords.at(it.first); - - EXPECT_EQ(it.second, timestamps); - for (const auto& validHR : validHistRec) { - const auto& testHR = historyRecords.at(validHR.first); - - EXPECT_EQ(validHR.second.count, testHR.count); - EXPECT_EQ(validHR.second.lastTime, testHR.lastTime); - } - } -} - TEST_F(CheckpointTest, KeyFreqMaps) { @@ -546,7 +390,6 @@ TEST_F(CheckpointTest, KeyFreqMaps) testSaveData.ddrKeyFreqMaps = std::move(testDDRKeyFreqMaps); testSaveData.excludeDDRKeyFreqMaps = std::move(testExcludeDDRKeyFreqMaps); validLoadData.ddrKeyFreqMaps = std::move(validDDRKeyFreqMaps); - Checkpoint testCkpt; testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::DDR_KEY_FREQ_MAP }); @@ -557,27 +400,4 @@ TEST_F(CheckpointTest, KeyFreqMaps) const auto& ddrKeyFreqMap = testLoadData.ddrKeyFreqMaps.at(it.first); EXPECT_EQ(it.second, ddrKeyFreqMap); } -} - -TEST_F(CheckpointTest, KeyCountMapCkpt) -{ - KeyCountMemT testKeyCountMaps; - KeyCountMemT validKeyCountMaps; - - SetEmbInfo(); - SetKeyCountMaps(testKeyCountMaps); - - validKeyCountMaps = testKeyCountMaps; - - CkptData testSaveData; - CkptData validLoadData; - CkptData testLoadData; - - testSaveData.keyCountMap = std::move(testKeyCountMaps); - validLoadData.keyCountMap = std::move(validKeyCountMaps); - - Checkpoint testCkpt; - testCkpt.SaveModel(testPath, testSaveData, rankInfo, testEmbInfos); - testCkpt.LoadModel(testPath, testLoadData, rankInfo, testEmbInfos, { CkptFeatureType::KEY_COUNT_MAP }); - EXPECT_EQ(validLoadData.keyCountMap.size(), testLoadData.keyCountMap.size()); } \ No newline at end of file diff --git a/src/tests/emb_hashmap/emb_hashmap_test.cpp b/src/tests/emb_hashmap/emb_hashmap_test.cpp index 960538d5..ac2f1583 100644 --- a/src/tests/emb_hashmap/emb_hashmap_test.cpp +++ 
b/src/tests/emb_hashmap/emb_hashmap_test.cpp @@ -71,6 +71,7 @@ TEST(EmbHashMap, TestFindOffset) string embTableName = "table1"; EmbHashMap hostHashMaps; RankInfo rankInfo; + rankInfo.isDDR = true; auto embInfo = GetEmbInfoList(); hostHashMaps.Init(rankInfo, embInfo, false); CacheManager cacheManager; @@ -123,6 +124,7 @@ TEST(EmbHashMap, TESTGetHashMaps) string embTableName = "table1"; EmbHashMap hostHashMaps; RankInfo rankInfo; + rankInfo.isDDR = true; auto embInfo = GetEmbInfoList(); hostHashMaps.Init(rankInfo, embInfo, false); CacheManager cacheManager; diff --git a/src/tests/emb_table/emb_table_test.cpp b/src/tests/emb_table/emb_table_test.cpp index 52a0c169..ccae5041 100644 --- a/src/tests/emb_table/emb_table_test.cpp +++ b/src/tests/emb_table/emb_table_test.cpp @@ -41,7 +41,7 @@ protected: rankInfo.localRankSize = 1; rankInfo.useStatic = true; rankInfo.localRankId = 0; - rankInfo.noDDR = false; + rankInfo.isDDR = true; rankInfo.maxStep = { 1, -1 }; rankInfo.deviceId = 0; // 初始化EmbeddingTable diff --git a/src/tests/emb_table/embedding_ddr_test.cpp b/src/tests/emb_table/embedding_ddr_test.cpp new file mode 100644 index 00000000..71245b59 --- /dev/null +++ b/src/tests/emb_table/embedding_ddr_test.cpp @@ -0,0 +1,200 @@ +/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and + limitations under the License. 
+==============================================================================*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils/common.h"
+#include "emb_table/emb_table.h"
+#include "emb_table/embedding_ddr.h"
+#include "host_emb/host_emb.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+using namespace tensorflow;
+
+class EmbeddingDDRTest : public testing::Test {
+protected:
+    EmbeddingDDRTest()
+    {
+        struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true);
+        std::vector vocabsize = {100};
+        std::vector<InitializeInfo> initializeInfos = {};
+        std::vector<std::string> ssdDataPath = {""};
+        vector maxStep = {1000};
+        embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath);
+        int rankId;
+        MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+        rankInfo_ = RankInfo(rankId, 0, 0, 1, maxStep);
+    }
+
+    void SetUp() {
+    }
+    void TearDown() {
+    }
+
+    static void SetUpTestCase()
+    {
+        if (access("test_dir", F_OK) == 0) {
+            system("rm -rf test_dir");
+        }
+    }
+
+    static void TearDownTestCase()
+    {
+        if (access("test_dir", F_OK) == 0) {
+            system("rm -rf test_dir");
+        }
+    }
+
+    EmbInfo embInfo_;
+    RankInfo rankInfo_;
+};
+
+TEST_F(EmbeddingDDRTest, SaveLoadBasic)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+    HostEmb* hostEmbs = Singleton<HostEmb>::GetInstance();
+    hostEmbs->Initialize(embInfos, 0);
+    HostEmbTable& table = hostEmbs->GetEmb("test1");
+
+    shared_ptr<EmbeddingDDR> ddr1 = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    shared_ptr<EmbeddingDDR> ddr2 = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+
+    // Use the current timestamp to build distinctive test values
+    ddr1->extEmbSize_ = time(nullptr);
+    ddr1->devVocabSize_ = time(nullptr);
+    ddr1->hostVocabSize_ = time(nullptr);
+    ddr1->currentUpdatePos = time(nullptr);
+    ddr1->maxOffset_ = time(nullptr);
+
+    vector devOffset2KeyTestData;
+    for (int i = 0; i < 10; ++i) {
+        devOffset2KeyTestData.push_back(static_cast(i));
+        ddr1->keyOffsetMap_[i] = i;
+        ddr1->evictPos_.push_back(i);
+    }
+
+    ddr1->devOffset2Key = devOffset2KeyTestData;
+
+    ddr1->Save("test_dir");
+    ddr2->Load("test_dir");
+
+    for (int i = 0; i < 10; ++i) {
+        EXPECT_EQ(ddr1->evictPos_[i], ddr2->evictPos_[i]);
+    }
+
+    EXPECT_EQ(ddr1->extEmbSize_, ddr2->extEmbSize_);
+    EXPECT_EQ(ddr1->devVocabSize_, ddr2->devVocabSize_);
+}
+
+/**
+ * Test saving and loading of the host-side embedding data
+ */
+TEST_F(EmbeddingDDRTest, SaveLoadEmbeddingData)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+    HostEmb* hostEmbs = Singleton<HostEmb>::GetInstance();
+    hostEmbs->Initialize(embInfos, 0);
+    HostEmbTable& table = hostEmbs->GetEmb("test1");
+
+    vector<float> tmp1 {1.1, 2.1, 3.1};
+    vector<float> tmp2 {1.2, 2.2, 3.2};
+    vector<float> tmp3 {1.3, 2.3, 3.3};
+    vector<vector<float>> testData;
+    testData.push_back(tmp1);
+    testData.push_back(tmp2);
+    testData.push_back(tmp3);
+
+    for (vector<float>& tmp : testData) {
+        table.embData.push_back(tmp);
+    }
+
+    shared_ptr<EmbeddingDDR> ddr1 = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    shared_ptr<EmbeddingDDR> ddr2 = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    ddr1->Save("test_dir");
+    // Zero the table out so that Load() has to restore the saved values
+    for (vector<float>& tmp: table.embData) {
+        for (float& t : tmp) {
+            t = 0;
+        }
+    }
+    ddr2->Load("test_dir");
+    for (size_t i = 0; i < table.embData.size(); ++i) {
+        for (size_t j = 0; j < table.embData[i].size(); ++j) {
+            EXPECT_EQ(testData[i][j], table.embData[i][j]);
+        }
+    }
+}
+
+/**
+ * Test basic offset lookup
+ */
+TEST_F(EmbeddingDDRTest, DDRBasic)
+{
+    shared_ptr<EmbeddingDDR> table = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    const size_t testNum = 100;
+    vector testKeys;
+    vector testSwap;
+    for (size_t i = 0; i < testNum; ++i) {
+        testKeys.push_back(i);
+    }
+    table->FindOffset(testKeys, 0, TRAIN_CHANNEL_ID, testSwap);
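+    // All 100 test keys fit into the device-resident vocabulary configured
+    // above, so FindOffset should hand each key a device offset directly and
+    // schedule no host<->device swap; both sizes are asserted right below.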
+    EXPECT_EQ(testKeys.size(), 100);
+    EXPECT_EQ(testSwap.size(), 0);
+}
+
+TEST_F(EmbeddingDDRTest, evict)
+{
+    shared_ptr<EmbeddingDDR> table = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    const size_t testNum = 100;
+    vector testKeys;
+    vector testSwap;
+    for (size_t i = 0; i < testNum; ++i) {
+        testKeys.push_back(i);
+    }
+    table->FindOffset(testKeys, 0, TRAIN_CHANNEL_ID, testSwap);
+    table->EvictKeys(testKeys);
+    EXPECT_EQ(table->evictPos_.size(), 100);
+    EXPECT_EQ(testKeys.size(), 100);
+    EXPECT_EQ(testSwap.size(), 0);
+}
+
+TEST_F(EmbeddingDDRTest, FindSwap)
+{
+    shared_ptr<EmbeddingDDR> table = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    const size_t testNum = 100;
+    vector testSwap;
+    table->FindSwapPosOld(0, 0, 0, testSwap);
+    EXPECT_EQ(testSwap.size(), 1);
+}
+
+TEST_F(EmbeddingDDRTest, EvictDeleteEmb)
+{
+    shared_ptr<EmbeddingDDR> table = std::make_shared<EmbeddingDDR>(embInfo_, rankInfo_, 0);
+    const size_t testNum = 100;
+    vector testKeys;
+    for (size_t i = 0; i < testNum; ++i) {
+        testKeys.push_back(i);
+    }
+    table->EvictDeleteEmb(testKeys);
+    EXPECT_EQ(testKeys.size(), 100);
+}
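The eviction tests above pin down a small contract for the DDR table: FindOffset hands out offsets for incoming keys, EvictKeys returns those offsets to a free list (evictPos_), and the Key2OffsetEvict case in the static-table tests further below shows freed slots being reused for new keys. The real EmbeddingDDR layers host/device tiers and swap bookkeeping on top of this, so the self-contained toy below is only a minimal sketch of the free-list contract; every name in it (ToyOffsetAllocator, Key2Offset, EvictKeys, evictPos_) is chosen to mirror the tests and is illustrative, not the mxRec API.

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>
    #include <unordered_map>
    #include <vector>

    // Toy key->offset allocator with an eviction free list. Illustrative only.
    class ToyOffsetAllocator {
    public:
        explicit ToyOffsetAllocator(size_t capacity) : capacity_(capacity) {}

        // Returns the offset for a key, reusing freed slots before allocating
        // fresh ones; throws when the table is full, mirroring the overflow
        // test in embedding_static_test.cpp.
        size_t Key2Offset(int64_t key)
        {
            auto it = keyOffsetMap_.find(key);
            if (it != keyOffsetMap_.end()) {
                return it->second;           // key already has a slot
            }
            size_t offset;
            if (!evictPos_.empty()) {
                offset = evictPos_.back();   // recycle an evicted slot
                evictPos_.pop_back();
            } else if (maxOffset_ < capacity_) {
                offset = maxOffset_++;       // allocate a fresh slot
            } else {
                throw std::runtime_error("table full");
            }
            keyOffsetMap_[key] = offset;
            return offset;
        }

        // Frees the slots of the given keys so later lookups can reuse them.
        void EvictKeys(const std::vector<int64_t>& keys)
        {
            for (int64_t key : keys) {
                auto it = keyOffsetMap_.find(key);
                if (it == keyOffsetMap_.end()) {
                    continue;
                }
                evictPos_.push_back(it->second);
                keyOffsetMap_.erase(it);
            }
        }

    private:
        size_t capacity_;
        size_t maxOffset_ = 0;
        std::unordered_map<int64_t, size_t> keyOffsetMap_;
        std::vector<size_t> evictPos_;
    };

    int main()
    {
        ToyOffsetAllocator table(100);
        for (int64_t k = 0; k < 100; ++k) {
            table.Key2Offset(k);                          // fill the table
        }
        table.EvictKeys({1, 2, 3});                       // free three slots
        std::cout << table.Key2Offset(200) << std::endl;  // prints 3: the most
                                                          // recently freed slot
        return 0;
    }

Filling the allocator, evicting a few keys, and then looking up a fresh key lands on a recycled slot, which is exactly the shape of the assertions exercised by the tests in this patch.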
diff --git a/src/tests/emb_table/embedding_mgmt_test.cpp b/src/tests/emb_table/embedding_mgmt_test.cpp
new file mode 100644
index 00000000..9374b078
--- /dev/null
+++ b/src/tests/emb_table/embedding_mgmt_test.cpp
@@ -0,0 +1,112 @@
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils/common.h"
+#include "emb_table/emb_table.h"
+#include "emb_table/embedding_mgmt.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+using namespace tensorflow;
+
+class EmbeddingMgmtTest : public testing::Test {
+protected:
+    EmbeddingMgmtTest()
+    {
+        struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true);
+        std::vector vocabsize = {100};
+        std::vector<InitializeInfo> initializeInfos = {};
+        std::vector<std::string> ssdDataPath = {""};
+        vector maxStep = {1000};
+        embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath);
+        int rankId;
+        MPI_Comm_rank(MPI_COMM_WORLD, &rankId);
+        rankInfo_ = RankInfo(rankId, 0, 0, 1, maxStep);
+        rankInfo_.isDDR = false;
+    }
+
+    void SetUp() {
+    }
+    void TearDown() {
+    }
+
+    static void SetUpTestCase()
+    {
+        if (access("test_dir", F_OK) == 0) {
+            system("rm -rf test_dir");
+        }
+    }
+
+    static void TearDownTestCase()
+    {
+        if (access("test_dir", F_OK) == 0) {
+            system("rm -rf test_dir");
+        }
+    }
+
+    EmbInfo embInfo_;
+    RankInfo rankInfo_;
+};
+
+TEST_F(EmbeddingMgmtTest, Init)
+{
+    const string tableName = "test1";
+    ThresholdValue thvalue(tableName, 0, 0, 0, false);
+    vector<EmbInfo> embInfos = {embInfo_};
+    vector<ThresholdValue> thresholds = {thvalue};
+    EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, thresholds, 0);
+
+    constexpr int testNum = 100;
+    vector testKeys;
+    for (size_t i = 0; i < testNum; ++i) {
+        testKeys.push_back(i);
+    }
+    EmbeddingMgmt::Instance()->Key2Offset(tableName, testKeys, TRAIN_CHANNEL_ID);
+    for (size_t i = 0; i < testNum; ++i) {
+        EXPECT_EQ(testKeys[i], i);
+    }
+    EXPECT_EQ(EmbeddingMgmt::Instance()->GetMaxOffset(tableName), testNum);
+}
+
+TEST_F(EmbeddingMgmtTest, GetAttributes)
+{
+    const string tableName = "test1";
+    ThresholdValue thvalue(tableName, 0, 0, 0, false);
+    vector<EmbInfo> embInfos = {embInfo_};
+    vector<ThresholdValue> thresholds = {thvalue};
+    EmbeddingMgmt::Instance()->Init(rankInfo_, embInfos, thresholds, 0);
+
+    constexpr int testNum = 100;
+    vector testKeys;
+    for (size_t i = 0; i < testNum; ++i) {
+        testKeys.push_back(i);
+    }
+    EmbeddingMgmt::Instance()->Key2Offset(tableName, testKeys, TRAIN_CHANNEL_ID);
+    for (size_t i = 0; i < testNum; ++i) {
+        EXPECT_EQ(testKeys[i], i);
+    }
+    EXPECT_EQ(EmbeddingMgmt::Instance()->GetMaxOffset(tableName), testNum);
+    EXPECT_EQ(EmbeddingMgmt::Instance()->GetSize(tableName), 100);
+    EXPECT_EQ(EmbeddingMgmt::Instance()->GetCapacity(tableName), 100);
+}
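The new fixtures in this patch clean up test_dir through suite-level hooks. GoogleTest resolves those hooks purely by name, so a misspelled override compiles cleanly but is silently never invoked; as a reminder of the contract the fixtures rely on, a minimal standalone sketch (plain GoogleTest, nothing mxRec-specific):

    #include <gtest/gtest.h>

    class LifecycleTest : public testing::Test {
    protected:
        // Invoked once before the first test of this suite; newer GoogleTest
        // releases also accept the spelling SetUpTestSuite.
        static void SetUpTestCase() { /* create state shared by the suite */ }

        // Invoked once after the last test of this suite.
        static void TearDownTestCase() { /* remove shared state */ }

        // Invoked before and after every individual TEST_F body.
        void SetUp() override {}
        void TearDown() override {}
    };

    TEST_F(LifecycleTest, RunsInsideTheLifecycle)
    {
        SUCCEED();
    }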
diff --git a/src/tests/emb_table/embedding_static_test.cpp b/src/tests/emb_table/embedding_static_test.cpp
new file mode 100644
index 00000000..27b9fb5f
--- /dev/null
+++ b/src/tests/emb_table/embedding_static_test.cpp
@@ -0,0 +1,149 @@
+/* Copyright 2024. Huawei Technologies Co.,Ltd. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "utils/common.h"
+#include "emb_table/emb_table.h"
+#include "emb_table/embedding_static.h"
+
+using namespace std;
+using namespace MxRec;
+using namespace testing;
+using namespace tensorflow;
+
+class EmbeddingStaticTest : public testing::Test {
+protected:
+    EmbeddingStaticTest()
+    {
+        struct EmbInfoParams embParam(string("test1"), 0, 1000, 2000, true, true);
+        std::vector vocabsize = {100};
+        std::vector<InitializeInfo> initializeInfos = {};
+        std::vector<std::string> ssdDataPath = {""};
+        vector maxStep = {1000};
+        embInfo_ = EmbInfo(embParam, vocabsize, initializeInfos, ssdDataPath);
+        rankInfo_ = RankInfo(0, 0, 0, 1, maxStep);
+    }
+
+    void SetUp() {
+    }
+    void TearDown() {
+    }
+    static void TearDownTestCase() {
+    }
+
+    EmbInfo embInfo_;
+    RankInfo rankInfo_;
+};
+
+/**
+ * Normal case: fill the table exactly to capacity
+ */
+TEST_F(EmbeddingStaticTest, Key2OffsetBasic)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+
+    shared_ptr<EmbeddingStatic> table = std::make_shared<EmbeddingStatic>(embInfo_, rankInfo_, 0);
+
+    vector tmp1;
+    for (size_t i = 0; i < 100; ++i) {
+        tmp1.push_back(i);
+    }
+
+    table->Key2Offset(tmp1, TRAIN_CHANNEL_ID);
+    for (size_t i = 0; i < tmp1.size(); ++i) {
+        EXPECT_EQ(tmp1[i], i);
+    }
+    EXPECT_EQ(table->size(), 100);
+    EXPECT_EQ(table->size(), table->GetMaxOffset());
+    EXPECT_EQ(table->capacity(), 100);
+}
+
+/**
+ * Boundary case: 101 keys exceed the capacity of 100
+ */
+TEST_F(EmbeddingStaticTest, Key2OffsetOverflow)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+
+    shared_ptr<EmbeddingStatic> table = std::make_shared<EmbeddingStatic>(embInfo_, rankInfo_, 0);
+
+    vector tmp1;
+    for (size_t i = 0; i < 101; ++i) {
+        tmp1.push_back(i);
+    }
+    int exp = 0;
+    try {
+        table->Key2Offset(tmp1, TRAIN_CHANNEL_ID);
+    } catch (exception& e) {
+        exp = 1;
+    }
+
+    EXPECT_EQ(table->capacity(), 100);
+    EXPECT_EQ(exp, 1);
+}
+
+/**
+ * Error case 1: lookups on the eval channel
+ */
+TEST_F(EmbeddingStaticTest, Key2OffsetEvalChannel)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+
+    shared_ptr<EmbeddingStatic> table = std::make_shared<EmbeddingStatic>(embInfo_, rankInfo_, 0);
+
+    vector testData;
+    for (size_t i = 0; i < 100; ++i) {
+        testData.push_back(i);
+    }
+    table->Key2Offset(testData, EVAL_CHANNEL_ID);
+    for (size_t i = 0; i < 100; ++i) {
+        EXPECT_EQ(testData[i], INVALID_KEY_VALUE);
+    }
+}
+
+/**
+ * Normal case: reuse evicted slots
+ */
+TEST_F(EmbeddingStaticTest, Key2OffsetEvict)
+{
+    vector<EmbInfo> embInfos = {embInfo_};
+    shared_ptr<EmbeddingStatic> table = std::make_shared<EmbeddingStatic>(embInfo_, rankInfo_, 0);
+
+    constexpr size_t tableNum = 100;
+    constexpr size_t testNum = 10;
+
+    vector testData;
+    for (size_t i = 0; i < tableNum; ++i) {
+        testData.push_back(i);
+    }
+    table->Key2Offset(testData, TRAIN_CHANNEL_ID);
+    // Evict every key
+    table->EvictKeys(testData);
+
+    vector newData;
+    for (size_t i = 0; i < testNum; ++i) {
+        newData.push_back(i);
+    }
+    table->Key2Offset(newData, TRAIN_CHANNEL_ID);
+    // Verify that the eviction actually happened
+    std::vector evictedKeys = table->GetEvictedKeys();
+    EXPECT_EQ(evictedKeys.size(), tableNum - testNum);
+}
diff --git a/src/tests/host_emb/host_emb_test.cpp b/src/tests/host_emb/host_emb_test.cpp
index 3bcc34f7..05a636d9 100644
--- a/src/tests/host_emb/host_emb_test.cpp
+++ b/src/tests/host_emb/host_emb_test.cpp
@@ -56,33 +56,6 @@ bool operator==(const vector& p1, const vector& p2)
     return true;
 }
 
-TEST(HostEmb, HostEmbUpdateTest) {
-    vector tensors;
-    Tensor tmpTensor(tensorflow::DT_FLOAT, { 32 });
-    auto tmpData = tmpTensor.flat();
-    for (int j = 0; j < 32; j++) {
-        tmpData(j) = 0.1f*j;
-    }
-    tensors.emplace_back(tmpTensor);
-
-    EMOCK(&HDTransfer::Recv).expects(exactly(1)).will(returnValue(tensors));
-
HostEmb h; - EmbInfo embInfo; - embInfo.name = "TestEmb"; - embInfo.devVocabSize = 100; - embInfo.hostVocabSize = 200; - embInfo.extEmbeddingSize = 32; - std::string name = "random_normal_initializer"; - InitializeInfo info(name, 0, embInfo.extEmbeddingSize, NormalInitializerInfo(0, 1, 7, 1.0)); - embInfo.initializeInfos.emplace_back(info); - vector embInfos {embInfo}; - h.Initialize(embInfos, 7); - vector missingKeysHostPos{199}; - h.UpdateEmb(missingKeysHostPos, TRAIN_CHANNEL_ID, embInfo.name); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][0], 0); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][31], 0.1f*31); -} - TEST(HostEmb, Tensor2Float) { shared_ptr>> lookups; @@ -131,54 +104,4 @@ TEST(HostEmb, DefaultConstructor) ASSERT_EQ(h.procThreadsForEval.size(), 0); } -TEST(HostEmb, InitializerAndEvict) -{ - HostEmb h; - EmbInfo embInfo; - embInfo.name = "TestEmb"; - embInfo.devVocabSize = 100; - embInfo.hostVocabSize = 200; - embInfo.extEmbeddingSize = 32; - std::string name = "constant_initializer"; - float initVal = 0.05f; - InitializeInfo info(name, 0, embInfo.extEmbeddingSize, ConstantInitializerInfo(initVal, 1.0)); - embInfo.initializeInfos.emplace_back(info); - vector embInfos {embInfo}; - h.Initialize(embInfos, 7); - - ASSERT_EQ(h.hostEmbs[embInfo.name].embData.size(), embInfo.hostVocabSize); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0].size(), embInfo.extEmbeddingSize); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0][0], initVal); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[0][embInfo.extEmbeddingSize-1], initVal); - - float initVal1 = 100.89f; - InitializeInfo info1(name, 0, embInfo.extEmbeddingSize, ConstantInitializerInfo(initVal1, 1.0)); - embInfo.initializeInfos.clear(); - embInfo.initializeInfos.emplace_back(info1); - vector offset{1, 199}; - h.hostEmbs[embInfo.name].hostEmbInfo = embInfo; - h.EvictInitEmb(embInfo.name, offset); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[1][0], initVal1); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][embInfo.extEmbeddingSize-1], initVal1); -} - -TEST(HostEmb, GetH2DEmb) -{ - HostEmb h; - EmbInfo embInfo; - embInfo.name = "TestEmb"; - embInfo.devVocabSize = 100; - embInfo.hostVocabSize = 200; - embInfo.extEmbeddingSize = 32; - std::string name = "random_normal_initializer"; - InitializeInfo info(name, 0, embInfo.extEmbeddingSize, NormalInitializerInfo(0, 1, 7, 1.0)); - embInfo.initializeInfos.emplace_back(info); - vector embInfos {embInfo}; - h.Initialize(embInfos, 7); - vector missingKeysHostPos{1, 199}; - vector h2dEmbOut; - h.GetH2DEmb(missingKeysHostPos, embInfo.name, h2dEmbOut); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[1][0], h2dEmbOut[0].flat()(0)); - ASSERT_EQ(h.hostEmbs[embInfo.name].embData[199][0], h2dEmbOut[0].flat()(32)); -} } \ No newline at end of file diff --git a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp index 2bb86c42..c51d4be2 100644 --- a/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp +++ b/src/tests/hybrid_mgmt/hybrid_mgmt_block_test.cpp @@ -68,31 +68,6 @@ TEST_F(HybridMgmtBlockTest, CountAndNotifyWake) } } -TEST_F(HybridMgmtBlockTest, CheckValid) -{ - hybridMgmtBlock = std::make_unique(); - hybridMgmtBlock->SetStepInterval(1, 1); - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 0; - hybridMgmtBlock->CheckValid(0); - hybridMgmtBlock->CheckValid(0); - - int step2 = 2; - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = step2; - hybridMgmtBlock->lastRunChannelId = 0; - try { - 
hybridMgmtBlock->CheckValid(1); - ASSERT_EQ(-1, 0); - } catch (HybridMgmtBlockingException e) { - LOG_INFO(HYBRID_BLOCKING + "sucess"); - ASSERT_EQ(0, 0); - } - hybridMgmtBlock->pythonBatchId[0] = 0; - hybridMgmtBlock->hybridBatchId[0] = 1; - hybridMgmtBlock->CheckValid(0); -} - TEST_F(HybridMgmtBlockTest, DoBlock) { hybridMgmtBlock = std::make_unique(); diff --git a/src/tests/key_process/key_process_test.cpp b/src/tests/key_process/key_process_test.cpp index 3b91e726..20be5486 100644 --- a/src/tests/key_process/key_process_test.cpp +++ b/src/tests/key_process/key_process_test.cpp @@ -14,8 +14,6 @@ See the License for the specific language governing permissions and ==============================================================================*/ #include - -#include #include #include #include @@ -24,6 +22,7 @@ See the License for the specific language governing permissions and #include "key_process/key_process.h" #include "ock_ctr_common/include/unique.h" #include "ock_ctr_common/include/error_code.h" +#include "emb_table/embedding_mgmt.h" using namespace std; using namespace MxRec; @@ -74,14 +73,14 @@ protected: rankInfo.useStatic = useStatic; rankInfo.localRankId = rankInfo.rankId % rankInfo.localRankSize; rankInfo.deviceId = rankInfo.localRankId; - rankInfo.noDDR = false; + rankInfo.isDDR = false; + rankInfo.useDynamicExpansion = false; rankInfo.maxStep = { 1, -1 }; rankInfo.useHot = false; // 初始化emb信息 GenEmbInfos(embNum, embInfos, fieldNums); splits = fieldNums; BuildExpect(); - GlobalMockObject::verify(); } // 使用该方法构造的数据需要使用掉,否则会影响其他用例 @@ -106,8 +105,8 @@ protected: batch->batchId = batchId; batch->channel = channel; LOG_DEBUG("[{}/{}]" KEY_PROCESS "PrepareBatch: batchQueueId: {}, {}[{}]{}, sampleSize:{}", - worldRank, worldSize, - batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size() + worldRank, worldSize, + batchQueueId, batch->name, batch->channel, batch->batchId, batch->sample.size() ); EmbBatchT temp; temp.sample = batch->sample; @@ -318,51 +317,13 @@ protected: void TearDown() { - GlobalMockObject::verify(); // delete } }; -int EMOCK_API mockStart(KeyProcess* obj) { - return 1; -} - -void EMOCK_API mockDestroy(KeyProcess* obj) { - // 等待线程主动处理结束,再isRunning = false - for (auto& i: obj->procThreads) { - i->join(); - } - obj->isRunning = false; - obj->procThreads.clear(); - return; -} - -void EMOCK_API mockEmptyDestroy(KeyProcess* obj) { - auto batchQueue = SingletonQueue::GetInstances(0); - do { - LOG_INFO("wait for thread running"); - this_thread::sleep_for(2s); - } while (batchQueue->TryPop() != nullptr); - // 通过Queue中数据是否取完了来判断是否需要退出;可能会出现取完却未处理完的情况、出数据不均衡 - obj->isRunning = false; - for (auto& i: obj->procThreads) { - i->join(); - } - return; -} - -void EMOCK_API mockInitExpansionEmb(EmbTable* obj, const EmbInfo& embInfos, const RankInfo&, int) -{ - obj->totalCapacity = embInfos.devVocabSize; - obj->embSize = embInfos.extEmbeddingSize; - obj->usedCapacity = 1; -} - TEST_F(KeyProcessTest, Initialize) { - EMOCK(&KeyProcess::Start) - .expects(exactly(1)) - .will(invoke(mockStart)); + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); ASSERT_EQ(process.rankInfo.rankId, rankInfo.rankId); @@ -378,44 +339,9 @@ TEST_F(KeyProcessTest, Initialize) ock::ctr::Factory::Create(factory); } -TEST_F(KeyProcessTest, InitializeHot) -{ - EMOCK(&KeyProcess::Start) - .expects(exactly(1)) - .will(invoke(mockStart)); - EMOCK(GetChipName) - .stubs() - .with(emock::any()) - 
.will(returnValue(string("910B"))); // 调用GetChipName时返回910B - EMOCK(&EmbTable::Init) - .stubs() - .will(invoke(mockInitExpansionEmb)); - rankInfo.useHot = true; - rankInfo.useDynamicExpansion = true; - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - ASSERT_EQ(process.isRunning, true); - ASSERT_EQ(process.hotEmbUpdateStep, GlobalEnv::hotEmbUpdateStep); - - for (const EmbInfo& info: embInfos) { - ASSERT_NE(process.hotEmbTotCount.find(info.name), process.hotEmbTotCount.end()); - } -} - -TEST_F(KeyProcessTest, GetExpansionTableSizeOrCapacity) -{ - EMOCK(&EmbTable::Init) - .stubs() - .will(invoke(mockInitExpansionEmb)); - - for (const EmbInfo& info: embInfos) { - process.embeddingTableMap[info.name].Init(info, rankInfo, 0); - ASSERT_EQ(process.GetExpansionTableSize(info.name), 1); - ASSERT_EQ(process.GetExpansionTableCapacity(info.name), info.devVocabSize); - } -} - TEST_F(KeyProcessTest, Start) { + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); setenv("keyProcessThreadNum", "2", 1); @@ -428,9 +354,6 @@ TEST_F(KeyProcessTest, Start) TEST_F(KeyProcessTest, HashSplit) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); int rankSize = 4; auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); @@ -451,9 +374,6 @@ TEST_F(KeyProcessTest, HashSplit) TEST_F(KeyProcessTest, HashSplitWithFAAE) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); int rankSize = 4; auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); @@ -480,9 +400,6 @@ TEST_F(KeyProcessTest, HashSplitWithFAAE) // 准入+动态shape下,有padding TEST_F(KeyProcessTest, PaddingHashSplitWithFAAE) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); int rankSize = 4; auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); @@ -499,7 +416,7 @@ TEST_F(KeyProcessTest, PaddingHashSplitWithFAAE) process.rankInfo.rankSize = rankSize; auto [splitKeys, restore, keyCount] = process.HashSplitWithFAAE(batch); LOG_INFO(KEY_PROCESS "HashSplitWithFAAE Padding, batch splitKeys: {}, keyCount: {}", VectorToString(splitKeys[0]), - VectorToString(keyCount[0])); + VectorToString(keyCount[0])); for (unsigned int i = 0; i < splitKeys.size(); ++i) { ASSERT_EQ(splitKeys[i].size(), ALLTOALLVC_ALIGN); @@ -509,12 +426,6 @@ TEST_F(KeyProcessTest, PaddingHashSplitWithFAAE) TEST_F(KeyProcessTest, HotHashSplit) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - EMOCK(&KeyProcess::Destroy) - .expects(exactly(1)) - .will(invoke(mockDestroy)); PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); LOG_INFO("CPU Core Num: %{}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 @@ -533,18 +444,16 @@ TEST_F(KeyProcessTest, HotHashSplit) }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < 1; ++id) { - // use lambda expression initialize thread + // use lambda expression initialize thread process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } + this_thread::sleep_for(10s); process.Destroy(); } TEST_F(KeyProcessTest, GetScAll) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 LOG_DEBUG(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); @@ -563,9 +472,6 @@ TEST_F(KeyProcessTest, GetScAll) TEST_F(KeyProcessTest, HandleRankExitScene) { - 
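+    // Rank-exit scenario: the EmbBatchT built below only drives the
+    // collective exchange of per-rank send counts (channel 0, no sample
+    // payload); the exchange must still complete once a rank has dropped out.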
EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); // 仅用于集合通信获取sendCount信息,构造EmbBatchT对象即可,通道传0,不用构造batch数据 @@ -594,9 +500,6 @@ TEST_F(KeyProcessTest, HandleRankExitScene) TEST_F(KeyProcessTest, GetScAllForUnique) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); vector keyScLocal(worldSize, worldRank + 1); // 用worldRank+1初始化发送数据量 LOG_INFO(KEY_PROCESS "rank {} keyScLocal: {}", worldRank, VectorToString(keyScLocal)); vector expectScAll(worldSize * worldSize); @@ -618,15 +521,12 @@ TEST_F(KeyProcessTest, GetScAllForUnique) // 非hot、非准入模式,固定batch输入,校验restore TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); auto queue = SingletonQueue::GetInstances(0); auto batch = queue->GetOne(); vector allBatchKeys = { { 1, 4, 23, 14, 16, 7, 2, 21, 21, 29 }, - { 5, 17, 26, 9, 27, 22, 27, 28, 15, 3 }, - { 10, 4, 22, 17, 24, 13, 24, 26, 29, 11 }, - { 14, 21, 18, 25, 21, 4, 20, 24, 13, 19 } }; + { 5, 17, 26, 9, 27, 22, 27, 28, 15, 3 }, + { 10, 4, 22, 17, 24, 13, 24, 26, 29, 11 }, + { 14, 21, 18, 25, 21, 4, 20, 24, 13, 19 } }; vector> allExpectSs = { { 0, 2, 5, 7, 9 }, { 0, 1, 4, 6 }, { 0, 2, 5, 8 }, { 0, 3, 6, 8 } }; vector> allExpectRestore = { { 2, 0, 7, 5, 1, 8, 6, 3, 3, 4 }, { 1, 2, 4, 3, 6, 5, 6, 0, 7, 8 }, @@ -634,7 +534,7 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) { 6, 3, 7, 4, 3, 0, 1, 2, 5, 8 } }; batch->sample = std::move(allBatchKeys[worldRank]); LOG_INFO(KEY_PROCESS "test BuildRestoreVec: rank {}, batchKeys {}", - worldRank, VectorToString(batch->sample)); + worldRank, VectorToString(batch->sample)); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); auto [splitKeys, restore] = process.HashSplit(batch); @@ -654,12 +554,6 @@ TEST_F(KeyProcessTest, BuildRestoreVec_4cpu) // hot模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、restore TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - EMOCK(&KeyProcess::Destroy) - .expects(exactly(1)) - .will(invoke(mockDestroy)); PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); LOG_INFO("CPU Core Num: {}", sysconf(_SC_NPROCESSORS_CONF)); // 查看CPU核数 @@ -676,8 +570,8 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) auto [lookupKeys, scAll, ss] = process.ProcessSplitKeys(batch, id, splitKeys); process.BuildRestoreVec(batch, ss, restore, hotPos.size()); LOG_INFO("rankid :{}, batchid: {}, lookupKeys: {}, scAll: {}, restore after build {}", - rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), - VectorToString(scAll), VectorToString(restore)); + rankInfo.rankId, batch->batchId, VectorToString(lookupKeys), + VectorToString(scAll), VectorToString(restore)); }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { @@ -685,19 +579,13 @@ TEST_F(KeyProcessTest, BuildRestoreVec_rebuilt) process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } - + this_thread::sleep_for(10s); process.Destroy(); } // 准入模式,batch随机数,ProcessSplitKeys后人为校验lookupKeys、scAll、count TEST_F(KeyProcessTest, GetCountRecv) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - EMOCK(&KeyProcess::Destroy) - .expects(exactly(1)) - .will(invoke(mockDestroy)); PrepareBatch(); process.m_featureAdmitAndEvict.m_isEnableFunction = true; for (size_t i = 0; i < embInfos.size(); i++) { @@ -725,66 
+613,45 @@ TEST_F(KeyProcessTest, GetCountRecv) }; // for clean code for (int channel = 0; channel < 1; ++channel) { for (int id = 0; id < KEY_PROCESS_THREAD; ++id) { - // use lambda expression initialize thread + // use lambda expression initialize thread process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } + this_thread::sleep_for(10s); process.Destroy(); } TEST_F(KeyProcessTest, Key2Offset) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); KeysT lookupKeys = { 4, 16, 28, 4, 24, 4, 20, 24 }; KeysT expectOffset = { 0, 1, 2, 0, 3, 0, 4, 3 }; ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); - process.Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID); + EmbeddingMgmt::Instance()->Key2Offset("emb0", lookupKeys, TRAIN_CHANNEL_ID); map tmp; - for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { + MxRec::KeyOffsetMemT kom = EmbeddingMgmt::Instance()->GetKeyOffsetMap(); + for (auto it = kom.begin(); it != kom.end(); ++it) { tmp.insert(pair(it->first, MapToString(it->second).c_str())); } LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", - VectorToString(lookupKeys), MapToString(tmp)); + VectorToString(lookupKeys), MapToString(tmp)); ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); KeysT lookupKeys2 = { 5, 17, 29, 5, 25, 5, 21, 25 }; KeysT expectOffset2 = { -1, -1, -1, -1, -1, -1, -1, -1 }; - process.Key2Offset("emb0", lookupKeys2, EVAL_CHANNEL_ID); + EmbeddingMgmt::Instance()->Key2Offset("emb0", lookupKeys2, EVAL_CHANNEL_ID); map tmp2; - for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { + MxRec::KeyOffsetMemT kom2 = EmbeddingMgmt::Instance()->GetKeyOffsetMap(); + for (auto it = kom2.begin(); it != kom2.end(); ++it) { tmp.insert(pair(it->first, MapToString(it->second).c_str())); } LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", - VectorToString(lookupKeys2), MapToString(tmp2).c_str()); + VectorToString(lookupKeys2), MapToString(tmp2).c_str()); ASSERT_THAT(lookupKeys2, ElementsAreArray(expectOffset2)); } -TEST_F(KeyProcessTest, Key2OffsetDynamicExpansion) -{ - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - KeysT lookupKeys = { 4, 16, 28, -1, 24, -1, 20, 24 }; - KeysT expectOffset = { 0, 0, 0, 0, 0, 0, 0, 0 }; - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - ASSERT_EQ(process.isRunning, true); - process.Key2OffsetDynamicExpansion("emb0", lookupKeys, EVAL_CHANNEL_ID); - - LOG_DEBUG(KEY_PROCESS "test Key2Offset: lookupKeys: {}, keyOffsetMap: {}", VectorToString(lookupKeys), [&] { - map tmp; - for (auto it = process.keyOffsetMap.begin(); it != process.keyOffsetMap.end(); ++it) { - tmp.insert(pair(it->first, MapToString(it->second).c_str())); - } - return MapToString(tmp); - }()); - - ASSERT_THAT(lookupKeys, ElementsAreArray(expectOffset)); -} - TEST_F(KeyProcessTest, GetUniqueConfig) { ock::ctr::UniqueConf uniqueConf; @@ -795,41 +662,22 @@ TEST_F(KeyProcessTest, GetUniqueConfig) process.GetUniqueConfig(uniqueConf); } -// 自动化测试用例 -// 边界值、重复度测试 -TEST_F(KeyProcessTest, ProcessPrefetchTask) -{ - EMOCK(&KeyProcess::Destroy) - .expects(exactly(1)) - .will(invoke(mockEmptyDestroy)); - PrepareBatch(); - rankInfo.noDDR = true; - GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::SUM_SAME_ID_GRADIENTS_AND_APPLY; - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - process.rankInfo.rankSize = 
worldSize; - process.rankInfo.localRankId = process.rankInfo.rankId % process.rankInfo.localRankSize; - ASSERT_EQ(process.isRunning, true); - // 所有线程处理完(训练结束)后调用 - process.Destroy(); - GlobalEnv::applyGradientsStrategy = ApplyGradientsStrategyOptions::DIRECT_APPLY; -} - // HBM端到端测试,动态shape,固定batch输入 TEST_F(KeyProcessTest, KeyProcessTaskHelper) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - rankInfo.noDDR = true; + rankInfo.isDDR = false; rankInfo.useStatic = false; rankInfo.useHot = false; + rankInfo.useDynamicExpansion = false; + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); int batchId = 0; int channelId = 0; auto batch = GenBatch(embInfos[0].name, batchId, channelId); // 测试一个表 - LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}", rankInfo.rankId, batch->batchId); + LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, batchSize: {}", + rankInfo.rankId, batch->batchId, batch->sample.size()); ASSERT_EQ(process.KeyProcessTaskHelper(batch, channelId, 0), true); // threadId = 0 auto infoVecs = process.GetInfoVec(batchId, embInfos[0].name, channelId, ProcessedInfo::RESTORE); @@ -859,19 +707,18 @@ TEST_F(KeyProcessTest, KeyProcessTaskHelper) process.SetEos(1, 1); ASSERT_EQ(process.GetInfoVec(batchId + 1, embInfos[0].name, channelId + 1, ProcessedInfo::RESTORE), nullptr); LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); - + this_thread::sleep_for(10s); process.Destroy(); } // DDR端到端测试,静态shape,固定batch输入 TEST_F(KeyProcessTest, KeyProcessTaskHelperDDR) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - rankInfo.noDDR = false; + rankInfo.isDDR = true; rankInfo.useStatic = true; rankInfo.useHot = false; + rankInfo.useDynamicExpansion = false; + EmbeddingMgmt::Instance()->Init(rankInfo, embInfos); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); ASSERT_EQ(process.isRunning, true); int batchId = 0; @@ -921,24 +768,25 @@ TEST_F(KeyProcessTest, KeyProcessTaskHelperDDR) process.SetEos(1, 1); ASSERT_EQ(process.GetLookupKeys(batchId + 1, embInfos[0].name, channelId + 1).empty(), true); LOG_INFO("KeyProcessTaskHelper, rankid: {}, batchid: {}, eos status success", rankInfo.rankId, batch->batchId); + this_thread::sleep_for(10s); process.Destroy(); } TEST_F(KeyProcessTest, InitializeUnique) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); ASSERT_EQ(ock::ctr::Factory::Create(factory), -1); ock::ctr::UniquePtr unique; ASSERT_EQ(factory->CreateUnique(unique), 0); PrepareBatch(); - unique_ptr batch; + + unique_ptr batch; batch = process.GetBatchData(0, 0); ock::ctr::UniqueConf uniqueConf; - process.rankInfo.rankSize = worldSize; - process.rankInfo.useStatic = true; + process.rankInfo. + rankSize = worldSize; + process.rankInfo. 
+ useStatic = true; bool uniqueInitialize = false; size_t preBatchSize = 0; process.InitializeUnique(uniqueConf, preBatchSize, uniqueInitialize, batch, unique); @@ -957,12 +805,6 @@ TEST_F(KeyProcessTest, GetKeySize) TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) { - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - EMOCK(&KeyProcess::Destroy) - .expects(exactly(1)) - .will(invoke(mockDestroy)); PrepareBatch(); ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); @@ -998,6 +840,7 @@ TEST_F(KeyProcessTest, ProcessBatchWithFastUnique) process.procThreads.emplace_back(std::make_unique(fn, channel, id)); } } + this_thread::sleep_for(10s); process.Destroy(); } @@ -1007,29 +850,3 @@ TEST_F(KeyProcessTest, LoadSaveLock) process.LoadSaveUnlock(); } -TEST_F(KeyProcessTest, EvictKeys) -{ - EMOCK(&KeyProcess::Start) - .stubs() - .will(invoke(mockStart)); - ASSERT_EQ(process.Initialize(rankInfo, embInfos), true); - ASSERT_EQ(process.isRunning, true); - absl::flat_hash_map flatTmp0 { {0, 0}, {4, 1}, {8, 2}, {12, 3} }; - absl::flat_hash_map flatTmp1 { {1, 0}, {5, 1}, {9, 2}, {13, 3} }; - absl::flat_hash_map flatTmp2 { {2, 0}, {6, 1}, {10, 2}, {14, 3} }; - absl::flat_hash_map flatTmp3 { {3, 0}, {7, 1}, {11, 2}, {15, 3} }; - vector> allHashMap {flatTmp0, flatTmp1, flatTmp2, flatTmp3}; - process.keyOffsetMap.emplace(embInfos[0].name, allHashMap[worldRank]); - process.evictPosMap.emplace(embInfos[0].name, vector{}); - - vector> allEvictKeys {{4, 8}, {1, 5}, {10, 14}, {3, 11}}; - vector> allEvictPos {{1, 2}, {0, 1}, {2, 3}, {0, 2}}; - process.EvictKeys(embInfos[0].name, allEvictKeys[worldRank]); - ASSERT_THAT(process.evictPosMap.at(embInfos[0].name), ElementsAreArray(allEvictPos[worldRank])); - - // 测试并表统计情况下的淘汰 - vector> allEvictKeysCom {{0}, {}, {18}, {}}; - vector> allEvictPosCom {{1, 2, 0}, {0, 1}, {2, 3}, {0, 2}}; - process.EvictKeysCombine(allEvictKeysCom[worldRank]); - ASSERT_THAT(process.evictPosMap.at(embInfos[0].name), ElementsAreArray(allEvictPosCom[worldRank])); -} diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py index 491578fd..01f2e4f3 100644 --- a/tests/mx_rec/core/mock_class.py +++ b/tests/mx_rec/core/mock_class.py @@ -14,11 +14,137 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 # ==============================================================================
+from dataclasses import dataclass
 import tensorflow as tf
 from tensorflow_core.python.training import slot_creator
 
+from mx_rec import ASCEND_GLOBAL_HASHTABLE_COLLECTION
 from mx_rec.optimizers.lazy_adam import CustomizedLazyAdam
+from mx_rec.util.config_utils.embedding_utils import SparseEmbedConfig
+from mx_rec.util.config_utils.feature_spec_utils import FeatureSpecConfig
+from mx_rec.util.config_utils.optimizer_utils import OptimizerConfig
+
+
+class MockHybridManagerConfig:
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+        self.freeze = kwargs.get("freeze", False)
+        self.asc_manager = kwargs.get("asc_manager", None)
+
+    def trigger_evict(self):
+        return self.kwargs.get("trigger_evict", True)
+
+    def set_asc_manager(self, cache):
+        pass
+
+    def save_host_data(self, root_dir):
+        pass
+
+    def get_host_data(self, table_name):
+        return self.kwargs.get("host_data", [])
+
+    def restore_host_data(self, path):
+        pass
+
+
+class MockSparseEmbedConfig:
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+        self.table_instance_dict = kwargs.get("table_instance_dict", {})
+        self.dangling_table = kwargs.get("dangling_table", [])
+        self.table_name_set = kwargs.get("table_name_set", set())
+
+    @staticmethod
+    def insert_dangling_table(table_name):
+        pass
+
+    @staticmethod
+    def insert_removing_var_list(var_name):
+        pass
+
+    @staticmethod
+    def insert_table_instance(name, key, instance):
+        pass
+
+    def get_table_instance(self, var):
+        return self.kwargs.get("var", None)
+
+    def get_table_instance_by_name(self, name):
+        return self.kwargs.get("var", None)
+
+
+class MockTrainParamsConfig:
+
+    def __init__(self, **kwargs):
+        def _get_training_mode_channel_id(is_training):
+            _dict = {True: 0, False: 1}
+            return _dict.get(is_training)
+
+        def _insert_training_mode_channel_id(is_training):
+            pass
+
+        def _get_merged_multi_lookup(is_training):
+            return kwargs.get('merged_multi_lookup', False)
+
+        def _insert_merged_multi_lookup(is_training, flag):
+            pass
+
+        def _set_initializer(is_training, initializer):
+            pass
+
+        def _set_target_batch(is_training, batch):
+            pass
+
+        self.ascend_global_hashtable_collection = kwargs.get("ascend_global_hashtable_collection",
+                                                             ASCEND_GLOBAL_HASHTABLE_COLLECTION)
+        self.is_graph_modify_hook_running = kwargs.get("is_graph_modify_hook_running", True)
+        self.bool_gauge_set = kwargs.get("bool_gauge_set", [])
+        self.iterator_type = kwargs.get("iterator_type", "")
+        self.sparse_dir = kwargs.get("sparse_dir", "")
+
+        self.get_training_mode_channel_id = _get_training_mode_channel_id
+        self.insert_training_mode_channel_id = _insert_training_mode_channel_id
+        self.get_merged_multi_lookup = _get_merged_multi_lookup
+        self.insert_merged_multi_lookup = _insert_merged_multi_lookup
+        self.set_initializer = _set_initializer
+        self.set_target_batch = _set_target_batch
+
+
+class MockConfigInitializer:
+    """
+    Mock of the original ConfigInitializer.
+    """
+
+    def __init__(self, **kwargs):
+        self.use_dynamic_expansion = kwargs.get("use_dynamic_expansion", False)
+        self.use_static = kwargs.get("use_static", False)
+        self.use_hot = kwargs.get("use_hot", True)
+        self.modify_graph = kwargs.get("modify_graph", True)
+        self.train_steps = kwargs.get("get_train_steps", -1)
+        self.eval_steps = kwargs.get("eval_steps", -1)
+        self.save_steps = kwargs.get("save_steps", -1)
+        self.if_load = kwargs.get("if_load", False)
+        self.iterator_type = kwargs.get("iterator_type", "MakeIterator")
+        self.sparse_dir = 
kwargs.get("sparse_dir", "") + + self.hybrid_manager_config = MockHybridManagerConfig(**kwargs) + self.sparse_embed_config = MockSparseEmbedConfig(**kwargs) + self.train_params_config = MockTrainParamsConfig(**kwargs) + self.optimizer_config = OptimizerConfig() + self.feature_spec_config = FeatureSpecConfig() + self.sparse_embed_cofnig = SparseEmbedConfig() + + def get_instance(self): + return self + + +class MockGlobalEnv: + + def __init__(self, **kwargs): + self.tf_device = kwargs.get("tf_device", 'NPU') class MockSparseEmbedding: @@ -28,6 +154,7 @@ class MockSparseEmbedding: def __init__(self, table_name="test_table", slice_device_vocabulary_size=10, embedding_size=5, init_param=1., emb_initializer=tf.zeros_initializer()): + self.is_hbm = True self.table_name = table_name self.slice_device_vocabulary_size = slice_device_vocabulary_size self.embedding_size = tf.TensorShape([embedding_size]) @@ -74,15 +201,12 @@ class MockHcclOps: self.all_to_all_v_c = _mock_all_to_all_v_c -class MockOptimizer(CustomizedLazyAdam): +class MockOptimizer: """ 用于mock optimizer """ def __init__(self): - super(MockOptimizer, self)._get_name(name="MockLazyAdam") - super(MockOptimizer, self).__init__(learning_rate=0.001, beta1=0.9, beta2=0.999, - epsilon=1e-8, use_locking=False, name="MockLazyAdam") self.slot_num = 2 def initialize_slots(self, var, table_instance): @@ -108,19 +232,19 @@ class MockOptimizer(CustomizedLazyAdam): return [initial_momentum_value, initial_velocity_value] def update_op(self, optimizer, g): - return super().update_op(optimizer, g) + pass def _apply_spare_duplicate_indices(self, grad, var): return self._apply_sparse(grad, var) def _apply_sparse(self, grad, var): - return super()._apply_sparse(grad, var) + pass def _resource_apply_sparse(self, grad, handle, indices): - return super()._resource_apply_sparse(grad, handle, indices) + pass def _apply_dense(self, grad, var): - return super()._apply_dense(grad, var) + pass def _apply_sparse_duplicate_indices(self, grad, var): return self._apply_sparse(grad, var) diff --git a/tests/mx_rec/core/test_build_graph.py b/tests/mx_rec/core/test_build_graph.py index 41706ce7..c15d851f 100644 --- a/tests/mx_rec/core/test_build_graph.py +++ b/tests/mx_rec/core/test_build_graph.py @@ -21,6 +21,7 @@ from unittest import mock import tensorflow as tf from mx_rec.util.global_env_conf import global_env +from tests.mx_rec.core.mock_class import MockConfigInitializer class TestGetRestoreVectorFunc(unittest.TestCase): @@ -30,13 +31,13 @@ class TestGetRestoreVectorFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) @@ -52,17 +53,6 @@ class TestGetRestoreVectorFunc(unittest.TestCase): with self.assertRaises(TypeError): get_restore_vector(self.config) - def test_get_restore_vector_case2(self): - """ - case2: HBM,emb_size小于1,抛出异常 - """ - - from mx_rec.core.asc.build_graph import 
get_restore_vector - - self.config["emb_size"] = 0 - with self.assertRaises(ValueError): - get_restore_vector(self.config) - def test_get_restore_vector_case3(self): """ case3: 非HBM,ext_emb_size不为int,抛出异常 @@ -70,28 +60,15 @@ class TestGetRestoreVectorFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_restore_vector - self.config["skip_emb_transfer"] = False + self.config["is_hbm"] = False self.config["ext_emb_size"] = "xxx" with self.assertRaises(TypeError): get_restore_vector(self.config) - def test_get_restore_vector_case4(self): - """ - case4: 非HBM,ext_emb_size小于1,抛出异常 - """ - - from mx_rec.core.asc.build_graph import get_restore_vector - - self.config["skip_emb_transfer"] = False - self.config["ext_emb_size"] = 0 - with self.assertRaises(ValueError): - get_restore_vector(self.config) - - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True)) + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") @mock.patch("mx_rec.core.asc.build_graph.mxrec_pybind.get_ub_hot_size") @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_case5(self, mock_get_next, mock_get_ub_hot_size): + def test_get_restore_vector_case5(self, mock_get_next, mock_get_ub_hot_size, build_graph_config_initializer): """ case5: HBM,静态shape,hot emb """ @@ -99,17 +76,19 @@ class TestGetRestoreVectorFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_restore_vector with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_next.return_value = [0, 1] mock_get_ub_hot_size.return_value = 8 restore_vector, hot_pos = get_restore_vector(self.config) self.assertEqual(restore_vector, 0) self.assertEqual(hot_pos, 1) - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True)) + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") @mock.patch("mx_rec.core.asc.build_graph.mxrec_pybind.get_ub_hot_size") @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_case6(self, mock_get_next, mock_get_ub_hot_size): + def test_get_restore_vector_case6(self, mock_get_next, mock_get_ub_hot_size, build_graph_config_initializer): """ case6: HBM,动态shape,hot emb """ @@ -117,64 +96,15 @@ class TestGetRestoreVectorFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_restore_vector with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_next.return_value = [0, 1] mock_get_ub_hot_size.return_value = 8 restore_vector, hot_pos = get_restore_vector(self.config) self.assertEqual(restore_vector, 0) self.assertEqual(hot_pos, 1) - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_case7(self, mock_get_next): - """ - case7: HBM,静态shape - """ - - from mx_rec.core.asc.build_graph import get_restore_vector - - with tf.Graph().as_default(): - mock_get_next.return_value = [0] - self.config["use_hot"] = False - restore_vector, hot_pos = get_restore_vector(self.config) - self.assertEqual(restore_vector, 0) - self.assertIsNone(hot_pos) - - 
@mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=False)) - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_case8(self, mock_get_next): - """ - case8: HBM,动态shape - """ - - from mx_rec.core.asc.build_graph import get_restore_vector - - with tf.Graph().as_default(): - mock_get_next.return_value = [0] - self.config["use_hot"] = False - restore_vector, hot_pos = get_restore_vector(self.config) - self.assertEqual(restore_vector, 0) - self.assertIsNone(hot_pos) - - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=False)) - @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_restore_vector_case9(self, mock_get_next): - """ - case9: 非HBM,动态shape - """ - - from mx_rec.core.asc.build_graph import get_restore_vector - - with tf.Graph().as_default(): - mock_get_next.return_value = [0] - self.config["skip_emb_transfer"] = False - self.config["use_hot"] = False - restore_vector, hot_pos = get_restore_vector(self.config) - self.assertEqual(restore_vector, 0) - self.assertIsNone(hot_pos) - class TestGetIdOffsetsFunc(unittest.TestCase): """ @@ -183,14 +113,14 @@ class TestGetIdOffsetsFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) @@ -233,14 +163,14 @@ class TestGetRestoreVectorSecondFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) @@ -265,14 +195,14 @@ class TestGetUniqueKeysFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) self.max_lookup_vec_size = self.config.get("send_count") * self.config.get("rank_size") def tearDown(self): # 恢复config - self.config = 
dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) @@ -311,13 +241,13 @@ class TestGetAll2allArgsFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) @@ -353,19 +283,18 @@ class TestGetSwapInfoFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True)) - def test_get_swap_info_case1(self): + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") + def test_get_swap_info_case1(self, build_graph_config_initializer): """ case1: 静态shape,HBM """ @@ -373,13 +302,15 @@ class TestGetSwapInfoFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_swap_info with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + swap_in = get_swap_info(self.config, None, None, None) self.assertIsInstance(swap_in[0], type(tf.no_op())) - @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True)) + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") @mock.patch("mx_rec.core.asc.build_graph.npu_ops.gen_npu_ops.get_next") - def test_get_swap_info_case2(self, mock_get_next): + def test_get_swap_info_case2(self, mock_get_next, build_graph_config_initializer): """ case2: 静态shape,非HBM,table传入非list,抛出异常 """ @@ -387,11 +318,14 @@ class TestGetSwapInfoFunc(unittest.TestCase): from mx_rec.core.asc.build_graph import get_swap_info with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_next.return_value = tf.ones(shape=[8, 8], dtype=tf.float32) swap_pos = tf.constant([8, 9], dtype=tf.int32) swap_len = tf.constant(2, dtype=tf.int32) table = tf.compat.v1.get_variable("test_table", 
shape=[10, 8], initializer=tf.ones_initializer()) - self.config["skip_emb_transfer"] = False + self.config["is_hbm"] = False with self.assertRaises(RuntimeError): get_swap_info(self.config, swap_len, swap_pos, table) @@ -403,26 +337,26 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): def setUp(self): # 默认动态扩容、hot emb、HBM - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) def tearDown(self): # 恢复config - self.config = dict(table_name="test_table", channel_id=0, skip_emb_transfer=True, emb_size=8, ext_emb_size=8, + self.config = dict(table_name="test_table", channel_id=0, is_hbm=True, emb_size=8, ext_emb_size=8, feat_cnt=8, batch_size=32, rank_size=8, send_count=1, device_id=0, use_hot=True, use_dynamic_expansion=True) global_env.apply_gradients_strategy = "direct_apply" @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=True), get_restore_vector=mock.MagicMock(return_value=[0, 0]), get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]), get_all2all_args=mock.MagicMock(return_value=0), get_swap_info=mock.MagicMock(return_value=0), get_restore_vector_second=mock.MagicMock(return_value=0), get_unique_keys=mock.MagicMock(return_value=0)) - def test_get_preprocessed_tensor_for_asc_case1(self): + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") + def test_get_preprocessed_tensor_for_asc_case1(self, build_graph_config_initializer): """ case1: 静态shape,全局unique """ @@ -431,20 +365,23 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply" with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + result = get_preprocessed_tensor_for_asc(None, self.config) self.assertIsNotNone(result.get("restore_vector")) self.assertIsNotNone(result.get("restore_vector_second")) self.assertIsNotNone(result.get("unique_keys")) @mock.patch.multiple("mx_rec.core.asc.build_graph", - get_use_static=mock.MagicMock(return_value=False), get_restore_vector=mock.MagicMock(return_value=[0, 0]), get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]), get_all2all_args=mock.MagicMock(return_value=0), get_swap_info=mock.MagicMock(return_value=0), get_restore_vector_second=mock.MagicMock(return_value=0), get_unique_keys=mock.MagicMock(return_value=0)) - def test_get_preprocessed_tensor_for_asc_case2(self): + @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer") + def test_get_preprocessed_tensor_for_asc_case2(self, build_graph_config_initializer): """ case2: 动态shape,全局unique """ @@ -453,20 +390,23 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase): global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply" with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer() + build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + result = get_preprocessed_tensor_for_asc(None, self.config) self.assertIsNotNone(result.get("restore_vector")) self.assertIsNotNone(result.get("restore_vector_second")) self.assertIsNotNone(result.get("unique_keys")) @mock.patch.multiple("mx_rec.core.asc.build_graph", - 
-                         get_use_static=mock.MagicMock(return_value=False),
                          get_restore_vector=mock.MagicMock(return_value=[0, 0]),
                          get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
                          get_all2all_args=mock.MagicMock(return_value=0),
                          get_swap_info=mock.MagicMock(return_value=0),
                          get_restore_vector_second=mock.MagicMock(return_value=0),
                          get_unique_keys=mock.MagicMock(return_value=0))
-    def test_get_preprocessed_tensor_for_asc_case3(self):
+    @mock.patch("mx_rec.core.asc.build_graph.ConfigInitializer")
+    def test_get_preprocessed_tensor_for_asc_case3(self, build_graph_config_initializer):
         """
         case3: dynamic shape, global unique, channel_id=1
         """
@@ -475,27 +415,10 @@ class TestGetPreProcessedTensorForAscFunc(unittest.TestCase):
 
         global_env.apply_gradients_strategy = "sum_same_id_gradients_and_apply"
 
         with tf.Graph().as_default():
-            self.config["channel_id"] = 1
-            result = get_preprocessed_tensor_for_asc(None, self.config)
-            self.assertIsNotNone(result.get("restore_vector"))
-            self.assertIsNone(result.get("restore_vector_second"))
+            mock_config_initializer = MockConfigInitializer()
+            build_graph_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
 
-    @mock.patch.multiple("mx_rec.core.asc.build_graph",
-                         get_use_static=mock.MagicMock(return_value=False),
-                         get_restore_vector=mock.MagicMock(return_value=[0, 0]),
-                         get_id_offsets=mock.MagicMock(return_value=[0, 0, 0]),
-                         get_all2all_args=mock.MagicMock(return_value=0),
-                         get_swap_info=mock.MagicMock(return_value=0),
-                         get_restore_vector_second=mock.MagicMock(return_value=0),
-                         get_unique_keys=mock.MagicMock(return_value=0))
-    def test_get_preprocessed_tensor_for_asc_case4(self):
-        """
-        case4: dynamic shape
-        """
-
-        from mx_rec.core.asc.build_graph import get_preprocessed_tensor_for_asc
-
-        with tf.Graph().as_default():
+            self.config["channel_id"] = 1
             result = get_preprocessed_tensor_for_asc(None, self.config)
             self.assertIsNotNone(result.get("restore_vector"))
             self.assertIsNone(result.get("restore_vector_second"))
diff --git a/tests/mx_rec/core/test_cpu.py b/tests/mx_rec/core/test_cpu.py
new file mode 100644
index 00000000..df653f54
--- /dev/null
+++ b/tests/mx_rec/core/test_cpu.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+# coding: UTF-8
+# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
+
+import os
+import sys
+import unittest
+from mx_rec.util.cpu import *
+
+
+class TestCPUBind(unittest.TestCase):
+
+    def test_pcie(self):
+        for root, dirs, _ in os.walk("/sys/bus/pci/devices/"):
+            for d in dirs:
+                dev = os.path.join(root, d)
+                numa = get_numa_by_pcie(d)
+                with open(os.path.join(dev, "numa_node")) as f:
+                    self.assertEqual(numa, int(f.read().strip()))
+
+    def test_bind_failed(self):
+        self.assertFalse(bind_cpu_by_device_logic_id(20))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
index f5284b8d..bf7d9240 100644
--- a/tests/mx_rec/core/test_embedding.py
+++ b/tests/mx_rec/core/test_embedding.py
@@ -15,7 +15,6 @@
 # limitations under the License.
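
Across the rewritten tests above, the old module-level getters (get_use_static, get_use_dynamic_expansion, ...) are replaced by a patched ConfigInitializer whose get_instance() returns a MockConfigInitializer imported from tests/mx_rec/core/mock_class.py. That helper's body is not shown in this patch, so the sketch below only illustrates the shape these call sites assume: the constructor flags are taken from the calls in the tests, while the internals (the feature_spec_config holder and get_instance() returning self) are assumptions for illustration.

class _MockFeatureSpecConfig:
    # Hypothetical holder; the tests in this patch only exercise insert_feature_spec().
    def __init__(self):
        self.feature_specs = []

    def insert_feature_spec(self, feature_spec, is_training):
        # Record the spec so hook-style code can iterate what was registered.
        self.feature_specs.append((feature_spec, is_training))


class MockConfigInitializer:
    # Sketch of the test double: it stores the flags it is built with and hands
    # itself out as the singleton that ConfigInitializer.get_instance() would
    # normally return.
    def __init__(self, use_static=False, use_dynamic_expansion=False,
                 modify_graph=True, trigger_evict=True):
        self.use_static = use_static
        self.use_dynamic_expansion = use_dynamic_expansion
        self.modify_graph = modify_graph
        self.trigger_evict = trigger_evict
        self.feature_spec_config = _MockFeatureSpecConfig()

    def get_instance(self):
        return self


if __name__ == "__main__":
    cfg = MockConfigInitializer(use_static=True)
    assert cfg.get_instance() is cfg and cfg.use_static
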
# ============================================================================== -import os import unittest from unittest import mock @@ -23,9 +22,10 @@ import tensorflow as tf from mx_rec.core.asc import FeatureSpec from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute -from mx_rec.core.embedding import SparseEmbedding -from mx_rec.constants.constants import All2allGradientsOp, ASCAnchorAttr -from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockOptimizer, MockHcclOps, MockAscManager +from mx_rec.core.emb.dynamic_sparse_embedding import HBMDynamicSparseEmbedding +from mx_rec.core.emb.sparse_embedding import HBMSparseEmbedding, ExternalStorageSparseEmbedding +from mx_rec.optimizers.gradient_descent import create_hash_optimizer +from tests.mx_rec.core.mock_class import MockConfigInitializer class TestCreateTableFunc(unittest.TestCase): @@ -33,592 +33,137 @@ class TestCreateTableFunc(unittest.TestCase): Test for 'mx_rec.core.embedding.create_table'. """ - @mock.patch.multiple("mx_rec.core.embedding", - fix_invalid_table_name=mock.MagicMock(return_value="table1"), - SparseEmbedding=mock.MagicMock(return_value=MockSparseEmbedding())) - def test_create_table(self): - """ - case: test create_table - """ - - from mx_rec.core.embedding import create_table - - test_table = create_table(key_dtype=tf.int64, - dim=tf.TensorShape([8]), - name='test_table', - emb_initializer=tf.compat.v1.truncated_normal_initializer()) - self.assertIsInstance(test_table, MockSparseEmbedding) - - -class TestSparseEmbeddingClass(unittest.TestCase): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding'. - """ - - def setUp(self): - key_dtype = tf.int64 - dim = 8 - name = 'test_table' - emb_initializer = tf.compat.v1.truncated_normal_initializer() - optimizer_list = [MockOptimizer()] - device_vocabulary_size = 1 - host_vocabulary_size = 2 - ssd_vocabulary_size = 3 - ssd_data_path = (os.getcwd(),) - is_save = True - init_param = 1. - all2all_gradients_op = All2allGradientsOp.SUM_GRADIENTS.value - - self.config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, - device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, - ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, - optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, - all2all_gradients_op=all2all_gradients_op) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_init(self): - """ - case: test create SparseEmbedding - - """ - - with tf.Graph().as_default(): - test_sparse_emb = SparseEmbedding(self.config) - self.assertIsInstance(test_sparse_emb, SparseEmbedding) - - -class TestGenerateLookupIdNotifyHybridFuncOfSparseEmbeddingClass(unittest.TestCase): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.generate_lookup_id_notify_hybrid'. 
- """ - - def test_generate_lookup_id_notify_hybrid(self): - """ - case: test generate_lookup_id_notify_hybrid - """ - - with tf.Graph().as_default(): - self.assertEqual(SparseEmbedding.generate_lookup_id_notify_hybrid(0).name, "d2h_notify_hybridmgmt_0") - - -class TestGetAnchorAttributeFuncOfSparseEmbeddingClass(unittest.TestCase): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.get_anchor_attribute'. - """ - - def test_get_anchor_attribute_case1(self): - """ - case1: 功能正常 - """ - - with tf.Graph().as_default(): - anchor_ids = tf.constant(1, dtype=tf.int64) - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = True - self.assertTrue(SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING)) - - def test_get_anchor_attribute_case2(self): - """ - case2: anchor_ids不是tensor,抛出异常 - """ - - with tf.Graph().as_default(): - anchor_ids = 1 - SparseEmbedding.anchor_tensor_specs[anchor_ids][ASCAnchorAttr.IS_TRAINING] = True - with self.assertRaises(TypeError): - SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING) - - def test_get_anchor_attribute_case3(self): - """ - case3: attr不是ASCAnchorAttr,抛出异常 - """ - - with tf.Graph().as_default(): - anchor_ids = tf.constant(1, dtype=tf.int64) - SparseEmbedding.anchor_tensor_specs[anchor_ids]["xxx"] = True - with self.assertRaises(ValueError): - SparseEmbedding.get_anchor_attribute(anchor_ids, "xxx") - - def test_get_anchor_attribute_case4(self): - """ - case4: 没有set直接get,抛出异常 - """ - - with tf.Graph().as_default(): - anchor_ids = tf.constant(1, dtype=tf.int64) - with self.assertRaises(KeyError): - SparseEmbedding.get_anchor_attribute(anchor_ids, ASCAnchorAttr.IS_TRAINING) - - -class TestGetOwnEmbFuncOfSparseEmbeddingClass(unittest.TestCase): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding._get_own_emb'. 
- """ - - @mock.patch.multiple("mx_rec.core.embedding", - get_rank_size=mock.MagicMock(return_value=1), get_rank_id=mock.MagicMock(return_value=0), - hccl_ops=MockHcclOps()) - def test_get_own_emb_case1(self): + get_device_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.embedding.ConfigInitializer") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") + @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") + def test_create_table_case1(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + emb_validator_config_initializer): """ - case1: rank=1,静态shape + case1: test create_table, 动态扩容 """ - with tf.Graph().as_default(): - src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") - all2all_args = 2 - emb_size = 1 - use_static = True - - # reshape_info为[2, 1] - own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) - self.assertListEqual(own_emb.shape.as_list(), [2, 1]) - - @mock.patch.multiple("mx_rec.core.embedding", - get_rank_size=mock.MagicMock(return_value=8), - get_rank_id=mock.MagicMock(return_value=0), - hccl_ops=MockHcclOps(shape=[2 * 8, 1])) - def test_get_own_emb_case2(self): - """ - case2: rank=8,静态shape - """ + from mx_rec.core.embedding import create_table with tf.Graph().as_default(): - src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") - all2all_args = 2 - emb_size = 1 - use_static = True - mock_shape = [all2all_args * 8, emb_size] + # mock + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=True) + embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) - self.assertListEqual(own_emb.shape.as_list(), mock_shape) + # test + test_table = create_table(key_dtype=tf.int64, + dim=8, + name='test_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer()) + self.assertIsInstance(test_table, HBMDynamicSparseEmbedding) - @mock.patch.multiple("mx_rec.core.embedding", + @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", get_rank_size=mock.MagicMock(return_value=8), get_rank_id=mock.MagicMock(return_value=0), - hccl_ops=MockHcclOps(shape=[2 * 8, 1])) - def test_get_own_emb_case3(self): + get_device_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.embedding.ConfigInitializer") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") + @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") + def test_create_table_case2(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + emb_validator_config_initializer): """ - case3: rank=8,动态shape + case2: test create_table, 非动态扩容,HBM """ - with tf.Graph().as_default(): - src_emb = tf.constant([2, 1], dtype=tf.float32, name="src_emb") - all2all_args = 2 - emb_size = 1 - use_static = False - mock_shape = [all2all_args * 8, emb_size] - - own_emb = SparseEmbedding._get_own_emb(src_emb, all2all_args, emb_size, use_static) - self.assertListEqual(own_emb.shape.as_list(), mock_shape) - - -class TestSizeFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.size'. 
- """ - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_asc_manager=mock.MagicMock(return_value=MockAscManager())) - def test_size(self): - """ - case: test size - """ + from mx_rec.core.embedding import create_table with tf.Graph().as_default(): - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.size(), 0) - - -class TestCapacityFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.capacity'. - """ - - def tearDown(self): - self.config["device_vocabulary_size"] = 1 - self.config["host_vocabulary_size"] = 2 - self.config["ssd_vocabulary_size"] = 3 - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=True), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_asc_manager=mock.MagicMock(return_value=MockAscManager())) - def test_capacity_case1(self): - """ - case1: 开启动态扩容,HBM + # mock + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - """ + # test + test_table = create_table(key_dtype=tf.int64, + dim=8, + name='test_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer()) + self.assertIsInstance(test_table, HBMSparseEmbedding) - with tf.Graph().as_default(): - self.config["host_vocabulary_size"] = 0 - self.config["ssd_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.capacity(), 1) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_capacity_case2(self): - """ - case2: 关闭动态扩容,HBM - """ - - with tf.Graph().as_default(): - self.config["host_vocabulary_size"] = 0 - self.config["ssd_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.capacity(), 1) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - 
get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_capacity_case3(self): + get_rank_id=mock.MagicMock(return_value=0), + get_device_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.embedding.ConfigInitializer") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") + @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") + @mock.patch("mx_rec.optimizers.gradient_descent.ConfigInitializer") + def test_create_table_case3(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + emb_validator_config_initializer, lazy_adam_config_initializer): """ - case3: 关闭动态扩容,DDR + case3: test create_table, 非动态扩容,DDR/SSD """ - with tf.Graph().as_default(): - self.config["ssd_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.capacity(), 3) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_capacity_case4(self): - """ - case4: 关闭动态扩容,SSD - """ + from mx_rec.core.embedding import create_table with tf.Graph().as_default(): - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.capacity(), 6) - + # mock + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + lazy_adam_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) -class TestGetDefaultLookupNameFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.get_default_lookup_name'. 
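
Read together, the three rewritten create_table cases in this hunk pin down how the new embedding class hierarchy is chosen: dynamic expansion yields HBMDynamicSparseEmbedding, plain device HBM yields HBMSparseEmbedding, and a non-zero host (DDR) or SSD vocabulary yields ExternalStorageSparseEmbedding. The dispatch below is only a distillation of those assertions, not the real create_table logic, whose branch order may differ:

# Hypothetical dispatch distilled from the three test assertions; the class
# names are the ones imported at the top of test_embedding.py.
def pick_embedding_class(use_dynamic_expansion, host_vocabulary_size=0, ssd_vocabulary_size=0):
    if use_dynamic_expansion:
        return "HBMDynamicSparseEmbedding"        # case1
    if host_vocabulary_size > 0 or ssd_vocabulary_size > 0:
        return "ExternalStorageSparseEmbedding"   # case3: DDR/SSD backed
    return "HBMSparseEmbedding"                   # case2: HBM only


assert pick_embedding_class(True) == "HBMDynamicSparseEmbedding"
assert pick_embedding_class(False) == "HBMSparseEmbedding"
assert pick_embedding_class(False, host_vocabulary_size=8) == "ExternalStorageSparseEmbedding"
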
- """ + # test + test_table = create_table(key_dtype=tf.int64, + dim=8, + name='test_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + host_vocabulary_size=8, + optimizer_list=[create_hash_optimizer(learning_rate=0.01)]) + self.assertIsInstance(test_table, ExternalStorageSparseEmbedding) - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_get_default_lookup_name(self): - """ - case: test get_default_lookup_name - """ - with tf.Graph().as_default(): - test_sparse_emb = SparseEmbedding(self.config) - self.assertEqual(test_sparse_emb.get_default_lookup_name(), "sparse_lookup_0") - - -class TestLookupForAscFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): +class TestSparseLookupFunc(unittest.TestCase): """ - Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc'. + Test for 'mx_rec.core.embedding.sparse_lookup'. """ - def tearDown(self): - self.config["device_vocabulary_size"] = 1 - self.config["host_vocabulary_size"] = 2 - self.config["ssd_vocabulary_size"] = 3 - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch.multiple("mx_rec.core.embedding.FeatureSpec", - set_feat_attribute=mock.MagicMock(return_value=None)) - def test_lookup_for_asc_case1(self): + get_rank_id=mock.MagicMock(return_value=0), + get_device_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.emb.sparse_embedding.get_preprocessed_tensor_for_asc") + @mock.patch("mx_rec.core.embedding.ConfigInitializer") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") + @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") + @mock.patch("mx_rec.core.emb.sparse_embedding.ConfigInitializer") + def test_sparse_lookup_case1(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + emb_validator_config_initializer, sparse_embedding_config_initializer, + mock_get_preprocessed_tensor_for_asc): """ - case1: test lookup_for_asc,静态shape + case1: test sparse_lookup + 表:非动态扩容,HBM + ids:FeatureSpec模式,动态shape """ - with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.constant(1, dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner - ids = tf.ones(shape=[2, 1], dtype=tf.int64, name="ids") - send_count = 1 - kwargs = {"is_train": True} - - lookup_res = test_sparse_emb.lookup_for_asc(ids, send_count, **kwargs) - with 
tf.Session() as sess: - self.assertEqual(sess.run(lookup_res), 1) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=True), - get_name_to_var_dict=mock.MagicMock(return_value={"test_table": 1}), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - clear_channel=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=False)) - @mock.patch.multiple("mx_rec.core.embedding.FeatureSpec", - set_feat_attribute=mock.MagicMock(return_value=None)) - def test_lookup_for_asc_case2(self): - """ - case2: test lookup_for_asc,动态shape,is_training=False - """ + from mx_rec.core.embedding import create_table, sparse_lookup with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.constant(1, dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner - ids = tf.ones(shape=[2, 1], dtype=tf.int64, name="ids") - send_count = 1 - kwargs = {"is_train": False} - - lookup_res = test_sparse_emb.lookup_for_asc(ids, send_count, **kwargs) - with tf.Session() as sess: - self.assertEqual(sess.run(lookup_res), 1) + # mock + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) -class TestLookupForAscWithFeatureSpecFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc_with_feature_spec'. 
- """ - - def tearDown(self): - self.config["device_vocabulary_size"] = 1 - self.config["host_vocabulary_size"] = 2 - self.config["ssd_vocabulary_size"] = 3 - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec") - def test_lookup_for_asc_with_feature_spec_case1(self, mock_get_table_name_to_feature_spec): - """ - case1: test lookup_for_asc_with_feature_spec,静态shape,len(same_table_feature_spec)=1 - """ - - with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.constant(1, dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner case1_feat = FeatureSpec("case1_feat", table_name="test_table") set_temporary_feature_spec_attribute(case1_feat, 1) - mock_get_table_name_to_feature_spec.return_value = [case1_feat] - send_count = 1 - kwargs = {"is_train": True} - - lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec(case1_feat, send_count, **kwargs) - with tf.Session() as sess: - self.assertEqual(sess.run(lookup_res), 1) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec") - def test_lookup_for_asc_with_feature_spec_case2(self, mock_get_table_name_to_feature_spec): - """ - case2: test lookup_for_asc_with_feature_spec,静态shape,len(same_table_feature_spec)=0,抛出异常 - """ - - with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.constant(1, dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner - case2_feat = FeatureSpec("case2_feat", table_name="test_table") - set_temporary_feature_spec_attribute(case2_feat, 1) - mock_get_table_name_to_feature_spec.return_value = [] - send_count = 1 - kwargs = {"is_train": True} - - with self.assertRaises(RuntimeError): - test_sparse_emb.lookup_for_asc_with_feature_spec(case2_feat, send_count, **kwargs) - - @mock.patch.multiple("mx_rec.core.embedding", - 
get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - clear_same_table_feature_spec=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec") - def test_lookup_for_asc_with_feature_spec_case3(self, mock_get_table_name_to_feature_spec): - """ - case3: test lookup_for_asc_with_feature_spec,静态shape,len(same_table_feature_spec)>1 - """ - - with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.ones(shape=[16, ], dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner - case3_feat = FeatureSpec("case3_feat", table_name="test_table") - case3_feat_multi = FeatureSpec("case3_feat_multi", table_name="test_table") - set_temporary_feature_spec_attribute(case3_feat, 1) - set_temporary_feature_spec_attribute(case3_feat_multi, 1) - case3_feat.split = 8 - case3_feat_multi.split = 8 - mock_get_table_name_to_feature_spec.return_value = [case3_feat, case3_feat_multi] - send_count = 1 - kwargs = {"is_train": True} - - test_sparse_emb.lookup_for_asc_with_feature_spec(case3_feat, send_count, **kwargs) - self.assertGreater(len(test_sparse_emb.lookup_result), 0) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - clear_same_table_feature_spec=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=False)) - @mock.patch("mx_rec.core.embedding.get_table_name_to_feature_spec") - def test_lookup_for_asc_with_feature_spec_case4(self, mock_get_table_name_to_feature_spec): - """ - case4: test lookup_for_asc_with_feature_spec,动态shape,len(same_table_feature_spec)>1 - """ - - with tf.Graph().as_default(): - def _mock_lookup_for_asc_with_feature_spec_inner(feature_spec, send_count, **kwargs): - return tf.ones(shape=[16, ], dtype=tf.int64) - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 100 * 8 - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec_inner = _mock_lookup_for_asc_with_feature_spec_inner - case4_feat = FeatureSpec("case4_feat", table_name="test_table") - case4_feat_multi = FeatureSpec("case4_feat_multi", table_name="test_table") - set_temporary_feature_spec_attribute(case4_feat, 1) - set_temporary_feature_spec_attribute(case4_feat_multi, 1) - case4_feat.split = 8 - case4_feat_multi.split = 8 - 
mock_get_table_name_to_feature_spec.return_value = [case4_feat, case4_feat_multi] - send_count = 1 - kwargs = { - "is_train": True, - "batch": { - "case4_feat": tf.ones(shape=[8, ], dtype=tf.int64), - "case4_feat_multi": tf.ones(shape=[8, ], dtype=tf.int64) - } - } - - test_sparse_emb.emb_size = 1 - test_sparse_emb.lookup_for_asc_with_feature_spec(case4_feat, send_count, **kwargs) - self.assertGreater(len(test_sparse_emb.lookup_result), 0) - - -class TestLookupForAscWithFeatureSpecInnerFuncOfSparseEmbeddingClass(TestSparseEmbeddingClass): - """ - Test for 'mx_rec.core.embedding.SparseEmbedding.lookup_for_asc_with_feature_spec_inner'. - """ - - def tearDown(self): - self.config["device_vocabulary_size"] = 1 - self.config["host_vocabulary_size"] = 2 - self.config["ssd_vocabulary_size"] = 3 - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - get_device_id=mock.MagicMock(return_value=0), - get_use_hot=mock.MagicMock(return_value=1), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - @mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc") - def test_lookup_for_asc_with_feature_spec_inner_case1(self, mock_get_preprocessed_tensor_for_asc): - """ - case1: test lookup_for_asc_with_feature_spec_inner,静态shape,关闭动态扩容 - """ - - with tf.Graph().as_default(): + case1_feat.dims = [8, 8] + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(case1_feat, True) + batch = {"case1_feat": tf.ones(shape=[8, 8], dtype=tf.int64)} mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64), @@ -629,89 +174,49 @@ class TestLookupForAscWithFeatureSpecInnerFuncOfSparseEmbeddingClass(TestSparseE "swap_in": [tf.no_op()] } - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - case1_feat = FeatureSpec("case1_feat", table_name="test_table") - set_temporary_feature_spec_attribute(case1_feat, 1) - case1_feat.dims = [8, 8] - send_count = 1 - kwargs = {"is_train": True} - - def _mock_get_own_emb(emb, all2all_args, emb_size, use_static): - return test_sparse_emb.variable + # test + test_table = create_table(key_dtype=tf.int64, + dim=8, + name='test_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=100 * 8) + self.assertIsInstance(test_table, HBMSparseEmbedding) - test_sparse_emb._get_own_emb = _mock_get_own_emb + res = sparse_lookup(test_table, case1_feat, batch=batch) + self.assertIsInstance(res, tf.Tensor) - lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case1_feat, send_count, **kwargs) - self.assertIsInstance(lookup_res, tf.Tensor) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), + 
@mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", get_rank_size=mock.MagicMock(return_value=8), - get_device_id=mock.MagicMock(return_value=0), - get_use_hot=mock.MagicMock(return_value=1), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=False)) - @mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc") - def test_lookup_for_asc_with_feature_spec_inner_case2(self, mock_get_preprocessed_tensor_for_asc): + get_rank_id=mock.MagicMock(return_value=0), + get_device_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + @mock.patch("mx_rec.core.emb.sparse_embedding.get_preprocessed_tensor_for_asc") + @mock.patch("mx_rec.core.embedding.ConfigInitializer") + @mock.patch("mx_rec.core.emb.base_sparse_embedding.ConfigInitializer") + @mock.patch("mx_rec.validator.emb_validator.ConfigInitializer") + @mock.patch("mx_rec.core.emb.sparse_embedding.ConfigInitializer") + def test_sparse_lookup_case2(self, embedding_config_initializer, base_sparse_embedding_config_initializer, + emb_validator_config_initializer, sparse_embedding_config_initializer, + mock_get_preprocessed_tensor_for_asc, feature_spec_config_initializer): """ - case2: test lookup_for_asc_with_feature_spec_inner,动态shape,关闭动态扩容 + case2: test sparse_lookup + 表:非动态扩容,HBM + ids:自动改图模式,动态shape """ - with tf.Graph().as_default(): - mock_get_preprocessed_tensor_for_asc.return_value = { - "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), - "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64), - "unique_keys": tf.ones(shape=[8, ], dtype=tf.int64), - "hot_pos": tf.ones(shape=[8, ], dtype=tf.int64), - "id_offsets": tf.ones(shape=[8, ], dtype=tf.int64), - "all2all_args": tf.ones(shape=[8, 8], dtype=tf.int64), - "swap_in": [tf.no_op()] - } - - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - case2_feat = FeatureSpec("case2_feat", table_name="test_table") - set_temporary_feature_spec_attribute(case2_feat, 1) - case2_feat.dims = [8, 8] - send_count = 1 - kwargs = {"is_train": True, "batch": {"case2_feat": tf.ones(shape=[8, 8], dtype=tf.int64)}} + from mx_rec.core.embedding import create_table, sparse_lookup - def _mock_get_own_emb(emb, all2all_args, emb_size, use_static): - return test_sparse_emb.variable - - test_sparse_emb._get_own_emb = _mock_get_own_emb - - lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case2_feat, send_count, **kwargs) - self.assertIsInstance(lookup_res, tf.Tensor) + with tf.Graph().as_default(): + # mock + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - get_device_id=mock.MagicMock(return_value=0), - get_use_hot=mock.MagicMock(return_value=1), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None), - get_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True)) - 
@mock.patch("mx_rec.core.embedding.get_preprocessed_tensor_for_asc") - def test_lookup_for_asc_with_feature_spec_inner_case3(self, mock_get_preprocessed_tensor_for_asc): - """ - case3: test lookup_for_asc_with_feature_spec_inner,静态shape,关闭动态扩容 - access_threshold > 0,覆盖 set_specific_value_for_non_valid_key() - """ + embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - with tf.Graph().as_default(): + case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64) mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), "restore_vector_second": tf.ones(shape=[8, ], dtype=tf.int64), @@ -722,122 +227,16 @@ class TestLookupForAscWithFeatureSpecInnerFuncOfSparseEmbeddingClass(TestSparseE "swap_in": [tf.no_op()] } - self.config["device_vocabulary_size"] = 100 * 8 - self.config["host_vocabulary_size"] = 0 - test_sparse_emb = SparseEmbedding(self.config) - case3_feat = FeatureSpec("case3_feat", table_name="test_table", access_threshold=10) - set_temporary_feature_spec_attribute(case3_feat, 1) - case3_feat.dims = [8, 8] - send_count = 1 - kwargs = {"is_train": True} - - def _mock_get_own_emb(emb, all2all_args, emb_size, use_static): - return test_sparse_emb.variable - - test_sparse_emb._get_own_emb = _mock_get_own_emb - - lookup_res = test_sparse_emb.lookup_for_asc_with_feature_spec_inner(case3_feat, send_count, **kwargs) - self.assertIsInstance(lookup_res, tf.Tensor) - - -class TestSparseLookupFunc(unittest.TestCase): - """ - Test for 'mx_rec.core.embedding.sparse_lookup'. - """ - - def setUp(self): - key_dtype = tf.int64 - dim = 8 - name = 'test_table' - emb_initializer = tf.compat.v1.truncated_normal_initializer() - optimizer_list = [MockOptimizer()] - device_vocabulary_size = 1 - host_vocabulary_size = 2 - ssd_vocabulary_size = 3 - ssd_data_path = (os.getcwd(),) - is_save = True - init_param = 1. 
- all2all_gradients_op = All2allGradientsOp.SUM_GRADIENTS.value - - self.config = dict(key_dtype=key_dtype, embedding_size=dim, table_name=name, emb_initializer=emb_initializer, - device_vocabulary_size=device_vocabulary_size, host_vocabulary_size=host_vocabulary_size, - ssd_vocabulary_size=ssd_vocabulary_size, ssd_data_path=ssd_data_path, - optimizer_list=optimizer_list, init_param=init_param, is_save=is_save, - all2all_gradients_op=all2all_gradients_op) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_sparse_lookup_case1(self): - """ - case1: test sparse_lookup,FeatureSpec模式 - """ - - from mx_rec.core.embedding import sparse_lookup - - def _mock_lookup_for_asc_with_feature_spec(ids, send_count, **kwargs): - return 0 - - with tf.Graph().as_default(): - case1_feat = FeatureSpec("case1_feat", table_name="test_table") - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc_with_feature_spec = _mock_lookup_for_asc_with_feature_spec - - self.assertEqual(sparse_lookup(test_sparse_emb, case1_feat), 0) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - set_modify_graph=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_sparse_lookup_case2(self): - """ - case2: test sparse_lookup,自动改图模式 - """ - - from mx_rec.core.embedding import sparse_lookup - - def _mock_lookup_for_asc(ids, send_count, **kwargs): - return 1 - - with tf.Graph().as_default(): - ids = tf.constant(1, tf.int64) - test_sparse_emb = SparseEmbedding(self.config) - test_sparse_emb.lookup_for_asc = _mock_lookup_for_asc - - self.assertEqual(sparse_lookup(test_sparse_emb, ids, modify_graph=True), 1) - - @mock.patch.multiple("mx_rec.core.embedding", - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - is_asc_frozen=mock.MagicMock(return_value=False), - get_name_to_var_dict=mock.MagicMock(return_value=None), - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="xxx"), - get_rank_size=mock.MagicMock(return_value=8), - insert_removing_var_list=mock.MagicMock(return_value=None), - set_modify_graph=mock.MagicMock(return_value=None), - insert_table_instance=mock.MagicMock(return_value=None)) - def test_sparse_lookup_case3(self): - """ - case3: test sparse_lookup,自动改图模式,没传入modify_graph参数,抛出异常 - """ - - from mx_rec.core.embedding import sparse_lookup - - with tf.Graph().as_default(): - ids = tf.constant(1, tf.int64) - test_sparse_emb = SparseEmbedding(self.config) + # test + test_table = create_table(key_dtype=tf.int64, + dim=8, + name='test_table', + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + device_vocabulary_size=100 * 8) + self.assertIsInstance(test_table, HBMSparseEmbedding) - with self.assertRaises(ValueError): - 
sparse_lookup(test_sparse_emb, ids) + res = sparse_lookup(test_table, case2_feat, modify_graph=True) + self.assertIsInstance(res, tf.Tensor) if __name__ == '__main__': diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py index 91cdbe65..2bac51fc 100644 --- a/tests/mx_rec/core/test_feature_process.py +++ b/tests/mx_rec/core/test_feature_process.py @@ -23,7 +23,7 @@ import tensorflow as tf from mx_rec.core.feature_process import EvictHook from mx_rec.core.asc.feature_spec import FeatureSpec -from tests.mx_rec.core.mock_class import MockSparseEmbedding +from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockConfigInitializer class TestEvictHookClass(unittest.TestCase): @@ -46,8 +46,7 @@ class TestEvictHookClass(unittest.TestCase): @mock.patch.multiple( "mx_rec.graph.patch", - get_modify_graph=mock.Mock(return_value=True), - get_is_graph_modify_hook_running=mock.Mock(return_value=True), + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), ) @mock.patch.multiple( "tensorflow.compat.v1.train.Saver", @@ -63,22 +62,23 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass): self.ori_var_assert = [[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]] self.evict_var_assert = [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]] - @mock.patch.multiple("mx_rec.core.feature_process", - trigger_evict=mock.MagicMock(return_value=True)) @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next") - @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name") - @mock.patch("mx_rec.core.feature_process.export_feature_spec") - def test_after_run_case1(self, mock_export_feature_spec, mock_get_table_instance_by_name, mock_get_next): + @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name") + @mock.patch("mx_rec.core.feature_process.ConfigInitializer") + def test_after_run_case1(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next): """ case1: evict_enable为True,python和C++侧正常触发淘汰 """ with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec( + FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True) + test_table = MockSparseEmbedding() mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)] mock_get_table_instance_by_name.return_value = test_table - mock_export_feature_spec.return_value = dict( - test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10)) evict_hook = EvictHook(evict_enable=True, evict_time_interval=1) with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess: @@ -99,22 +99,23 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass): self.assertListEqual(evict_variable[8:].tolist(), self.evict_var_assert) self.assertListEqual(evict_variable[:2].tolist(), self.ori_var_assert) - @mock.patch.multiple("mx_rec.core.feature_process", - trigger_evict=mock.MagicMock(return_value=False)) @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next") - @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name") - @mock.patch("mx_rec.core.feature_process.export_feature_spec") - def test_after_run_case2(self, mock_export_feature_spec, 
mock_get_table_instance_by_name, mock_get_next): + @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name") + @mock.patch("mx_rec.core.feature_process.ConfigInitializer") + def test_after_run_case2(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next): """ case2: evict_enable为True,C++侧异常 """ with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, trigger_evict=False) + feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec( + FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True) + test_table = MockSparseEmbedding() mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)] mock_get_table_instance_by_name.return_value = test_table - mock_export_feature_spec.return_value = dict( - test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10)) evict_hook = EvictHook(evict_enable=True, evict_time_interval=1) with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess: @@ -134,22 +135,23 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass): # 此时C++侧异常,不执行淘汰,因此evict_variable后两行还是1 self.assertListEqual(evict_variable[8:].tolist(), self.ori_var_assert) - @mock.patch.multiple("mx_rec.core.feature_process", - trigger_evict=mock.MagicMock(return_value=True)) @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next") - @mock.patch("mx_rec.core.feature_process.get_table_instance_by_name") - @mock.patch("mx_rec.core.feature_process.export_feature_spec") - def test_after_run_case3(self, mock_export_feature_spec, mock_get_table_instance_by_name, mock_get_next): + @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name") + @mock.patch("mx_rec.core.feature_process.ConfigInitializer") + def test_after_run_case3(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next): """ case3: evict_enable为False """ with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec( + FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True) + test_table = MockSparseEmbedding() mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)] mock_get_table_instance_by_name.return_value = test_table - mock_export_feature_spec.return_value = dict( - test_spec=FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10)) evict_hook = EvictHook(evict_enable=False, evict_time_interval=1) with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess: diff --git a/tests/mx_rec/core/test_feature_spec.py b/tests/mx_rec/core/test_feature_spec.py index bef5da5c..afe1c709 100644 --- a/tests/mx_rec/core/test_feature_spec.py +++ b/tests/mx_rec/core/test_feature_spec.py @@ -22,6 +22,7 @@ from functools import reduce import tensorflow as tf from mx_rec.core.asc.feature_spec import FeatureSpec +from tests.mx_rec.core.mock_class import MockConfigInitializer class TestFeatureSpecClass(unittest.TestCase): @@ -154,6 +155,10 @@ 
class TestSetFeatPosFuncOfFeatureSpecClass(TestFeatureSpecClass):
 
         self.assertEqual(case2_feature_spec.instance_count_train, set_times)
 
 
+@mock.patch.multiple(
+    "mx_rec.core.asc.feature_spec",
+    ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()),
+)
 class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
     """
     Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.insert_pipeline_mode'.
@@ -163,8 +168,6 @@ class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
         # before each test method runs, reset FeatureSpec's static members to their defaults
         super().setUp()
 
-    @mock.patch.multiple("mx_rec.core.asc.feature_spec",
-                         insert_training_mode_channel_id=mock.MagicMock(return_value=None))
     def test_insert_pipeline_mode_case1(self):
         """
         case1: mode is not a bool, raises an exception
@@ -175,8 +178,6 @@ class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
         with self.assertRaises(TypeError):
             case1_feature_spec.insert_pipeline_mode(mode)
 
-    @mock.patch.multiple("mx_rec.core.asc.feature_spec",
-                         insert_training_mode_channel_id=mock.MagicMock(return_value=None))
     def test_insert_pipeline_mode_case2(self):
         """
         case2: mode is False
@@ -187,8 +188,6 @@ class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
         case2_feature_spec.insert_pipeline_mode(mode)
         self.assertSetEqual(case2_feature_spec.pipeline_mode, {False})
 
-    @mock.patch.multiple("mx_rec.core.asc.feature_spec",
-                         insert_training_mode_channel_id=mock.MagicMock(return_value=None))
     def test_insert_pipeline_mode_case3(self):
         """
         case3: mode is True
@@ -199,8 +198,6 @@ class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
         case3_feature_spec.insert_pipeline_mode(mode)
         self.assertSetEqual(case3_feature_spec.pipeline_mode, {True})
 
-    @mock.patch.multiple("mx_rec.core.asc.feature_spec",
-                         insert_training_mode_channel_id=mock.MagicMock(return_value=None))
     def test_insert_pipeline_mode_case4(self):
         """
         case4: mode is True; setting it a second time after it was already set to True raises no error
@@ -213,6 +210,10 @@ class TestInsertPipelineModeFuncOfFeatureSpecClass(TestFeatureSpecClass):
         self.assertSetEqual(case4_feature_spec.pipeline_mode, {True})
 
 
+@mock.patch.multiple(
+    "mx_rec.graph.patch",
+    ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()),
+)
 class TestSetFeatAttributeFuncOfFeatureSpecClass(TestFeatureSpecClass):
     """
     Test for 'mx_rec.core.asc.feature_spec.FeatureSpec.set_feat_attribute'.
@@ -223,40 +224,43 @@ class TestSetFeatAttributeFuncOfFeatureSpecClass(TestFeatureSpecClass): # 每个测试方法执行前,将FeatureSpec的静态成员设为默认值 super().setUp() - @mock.patch.multiple("mx_rec.core.asc.feature_spec", - insert_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True), - insert_feature_spec=mock.MagicMock(return_value=None)) - def test_set_feat_attribute_case1(self): + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + def test_set_feat_attribute_case1(self, feature_spec_config_initializer): """ case1: 未初始化initialized成员,静态shape,tensor的rank为0,抛出异常 """ + + mock_config_initializer = MockConfigInitializer(use_static=True) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + case1_tensor = tf.ones([], tf.int32) case1_feature_spec = FeatureSpec("case1") with self.assertRaises(ValueError): case1_feature_spec.set_feat_attribute(case1_tensor, self.is_training) - @mock.patch.multiple("mx_rec.core.asc.feature_spec", - insert_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True), - insert_feature_spec=mock.MagicMock(return_value=None)) - def test_set_feat_attribute_case2(self): + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + def test_set_feat_attribute_case2(self, feature_spec_config_initializer): """ case2: 未初始化initialized成员,静态shape,tensor的rank等于1 """ + + mock_config_initializer = MockConfigInitializer(use_static=True) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + case2_tensor = tf.ones([32], tf.int32) case2_feature_spec = FeatureSpec("case2") case2_feature_spec.set_feat_attribute(case2_tensor, self.is_training) self.assertEqual(case2_feature_spec.feat_cnt, 1) - @mock.patch.multiple("mx_rec.core.asc.feature_spec", - insert_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True), - insert_feature_spec=mock.MagicMock(return_value=None)) - def test_set_feat_attribute_case3(self): + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + def test_set_feat_attribute_case3(self, feature_spec_config_initializer): """ case3: 未初始化initialized成员,静态shape,tensor的rank大于1 """ + + mock_config_initializer = MockConfigInitializer(use_static=True) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_tensor_shape = [32, 16, 2] test_tensor = tf.ones(test_tensor_shape, tf.int32) case3_feature_spec = FeatureSpec("case3") @@ -264,19 +268,19 @@ class TestSetFeatAttributeFuncOfFeatureSpecClass(TestFeatureSpecClass): reduce_shape = reduce(lambda x, y: x * y, test_tensor_shape[1:]) self.assertTrue(case3_feature_spec.initialized) self.assertListEqual(case3_feature_spec.dims, test_tensor_shape) - self.assertEqual(case3_feature_spec.rank, len(test_tensor_shape)) + self.assertEqual(case3_feature_spec.tensor_rank, len(test_tensor_shape)) self.assertEqual(case3_feature_spec.batch_size, test_tensor_shape[0]) self.assertEqual(case3_feature_spec.feat_cnt, reduce_shape) self.assertEqual(case3_feature_spec.split, test_tensor_shape[0] * reduce_shape) - @mock.patch.multiple("mx_rec.core.asc.feature_spec", - insert_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=False), - insert_feature_spec=mock.MagicMock(return_value=None)) - def test_set_feat_attribute_case4(self): + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + def 
test_set_feat_attribute_case4(self, feature_spec_config_initializer): """ case4: 未初始化initialized成员,动态shape,tensor的rank大于1 """ + mock_config_initializer = MockConfigInitializer(use_static=False) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_tensor_shape = [32, 3] test_tensor = tf.ones(test_tensor_shape, tf.int32) case4_feature_spec = FeatureSpec("case4") @@ -285,19 +289,19 @@ class TestSetFeatAttributeFuncOfFeatureSpecClass(TestFeatureSpecClass): with tf.Session() as sess: self.assertTrue(case4_feature_spec.initialized) self.assertEqual(sess.run(case4_feature_spec.dims), reduce_shape) - self.assertEqual(case4_feature_spec.rank, 1) + self.assertEqual(case4_feature_spec.tensor_rank, 1) self.assertEqual(sess.run(case4_feature_spec.split), reduce_shape) self.assertEqual(sess.run(case4_feature_spec.batch_size), reduce_shape) self.assertEqual(case4_feature_spec.feat_cnt, 1) - @mock.patch.multiple("mx_rec.core.asc.feature_spec", - insert_training_mode_channel_id=mock.MagicMock(return_value=None), - get_use_static=mock.MagicMock(return_value=True), - insert_feature_spec=mock.MagicMock(return_value=None)) - def test_set_feat_attribute_case5(self): + @mock.patch("mx_rec.core.asc.feature_spec.ConfigInitializer") + def test_set_feat_attribute_case5(self, feature_spec_config_initializer): """ case5: 静态shape,tensor的rank大于1,再次用不同tensor初始化initialized成员,抛出异常 """ + mock_config_initializer = MockConfigInitializer(use_static=True) + feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + case5_feature_spec = FeatureSpec("case5") case5_feature_spec.set_feat_attribute(tf.ones([32, 3], tf.int32), self.is_training) # 再次初始化 @@ -358,4 +362,4 @@ class TestSetTemporaryFeatureSpecAttributeFunc(unittest.TestCase): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/mx_rec/core/test_helper.py b/tests/mx_rec/core/test_helper.py index 3444b47b..98739775 100644 --- a/tests/mx_rec/core/test_helper.py +++ b/tests/mx_rec/core/test_helper.py @@ -25,7 +25,7 @@ import tensorflow as tf from mx_rec.core.asc.feature_spec import FeatureSpec from tests.mx_rec.core.generator_dataset import generate_dataset, Config -from tests.mx_rec.core.mock_class import MockHostPipeLineOps +from tests.mx_rec.core.mock_class import MockHostPipeLineOps, MockConfigInitializer class TestGetAscInsertFunc(unittest.TestCase): @@ -96,6 +96,10 @@ class TestGetAscInsertFunc(unittest.TestCase): self.assertTrue(callable(get_asc_insert_func(args_index_list=[], table_names=["xxx"]))) +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestGetAscInsertFuncInnerFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.helper.get_asc_insert_func_inner'. @@ -206,7 +210,8 @@ class TestDoInsertFunc(unittest.TestCase): insert_tensor = [] splits = [] table_names = [] - input_dict = dict(dump_graph=True) + input_dict = dict(is_training=True, dump_graph=True, timestamp=None, feature_spec_names=[], + auto_change_graph=True) out_batch = do_insert(args, insert_tensor, splits, table_names, input_dict) self.assertIsInstance(out_batch, dict) @@ -217,12 +222,11 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): Test for 'mx_rec.core.asc.helper.send_feature_id_request_async'. 
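
Returning to the set_feat_attribute cases above: for a static-shape id tensor, the attributes asserted there follow from simple arithmetic on the tensor shape. A worked instance with case3's [32, 16, 2] tensor (plain Python, mirroring the reduce() used in the test):

from functools import reduce

shape = [32, 16, 2]
tensor_rank = len(shape)                           # 3
batch_size = shape[0]                              # 32
feat_cnt = reduce(lambda x, y: x * y, shape[1:])   # 16 * 2 = 32 ids per sample
split = batch_size * feat_cnt                      # 32 * 32 = 1024 ids per batch
assert (tensor_rank, batch_size, feat_cnt, split) == (3, 32, 32, 1024)
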
""" - @mock.patch.multiple("mx_rec.core.asc.helper", - get_use_static=mock.MagicMock(return_value=True), - get_training_mode_channel_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.asc.helper.ConfigInitializer") @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request") - @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops") - def test_send_feature_id_request_async_case1(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request): + @mock.patch("mx_rec.core.asc.helper.import_host_pipeline_ops") + def test_send_feature_id_request_async_case1(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request, + helper_config_initializer): """ case1: 静态shape """ @@ -230,6 +234,9 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): from mx_rec.core.asc.helper import send_feature_id_request_async with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + helper_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps() feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)] mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]], @@ -241,12 +248,11 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): mock_res = send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict) self.assertEqual(mock_res, 0) - @mock.patch.multiple("mx_rec.core.asc.helper", - get_use_static=mock.MagicMock(return_value=False), - get_training_mode_channel_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.asc.helper.ConfigInitializer") @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request") - @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops") - def test_send_feature_id_request_async_case2(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request): + @mock.patch("mx_rec.core.asc.helper.import_host_pipeline_ops") + def test_send_feature_id_request_async_case2(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request, + helper_config_initializer): """ case2: 动态shape """ @@ -254,6 +260,9 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): from mx_rec.core.asc.helper import send_feature_id_request_async with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=False) + helper_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps() feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)] mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]], @@ -265,12 +274,11 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): mock_res = send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict) self.assertEqual(mock_res, 1) - @mock.patch.multiple("mx_rec.core.asc.helper", - get_use_static=mock.MagicMock(return_value=False), - get_training_mode_channel_id=mock.MagicMock(return_value=0)) + @mock.patch("mx_rec.core.asc.helper.ConfigInitializer") @mock.patch("mx_rec.core.asc.helper.merge_feature_id_request") - @mock.patch("mx_rec.core.asc.helper.get_host_pipeline_ops") - def test_send_feature_id_request_async_case3(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request): + @mock.patch("mx_rec.core.asc.helper.import_host_pipeline_ops") + def 
test_send_feature_id_request_async_case3(self, mock_get_host_pipeline_ops, mock_merge_feature_id_request, + helper_config_initializer): """ case3: split_list或tensorshape_split_list为空,抛出异常 """ @@ -278,6 +286,9 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): from mx_rec.core.asc.helper import send_feature_id_request_async with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=False) + helper_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_get_host_pipeline_ops.return_value = MockHostPipeLineOps() feature_id_list = [tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64)] mock_merge_feature_id_request.return_value = dict(output_feature_id_list=[feature_id_list[1]], @@ -290,6 +301,10 @@ class TestSendFeatureIdRequestAsyncFunc(unittest.TestCase): send_feature_id_request_async(feature_id_list, split_list, table_name_list, input_dict) +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestMergeFeatureIdRequestFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.helper.merge_feature_id_request'. @@ -309,9 +324,8 @@ class TestMergeFeatureIdRequestFunc(unittest.TestCase): with self.assertRaises(RuntimeError): merge_feature_id_request(feature_id_list, split_list, table_name_list) - @mock.patch.multiple("mx_rec.core.asc.helper", - get_modify_graph=mock.MagicMock(return_value=False)) - def test_merge_feature_id_request_case2(self): + @mock.patch("mx_rec.core.asc.helper.ConfigInitializer") + def test_merge_feature_id_request_case2(self, helper_config_initializer): """ case2: 非自动改图 """ @@ -319,6 +333,9 @@ class TestMergeFeatureIdRequestFunc(unittest.TestCase): from mx_rec.core.asc.helper import merge_feature_id_request with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(modify_graph=False) + helper_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + feature_id_list = [ tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64), @@ -367,9 +384,8 @@ class TestMergeFeatureIdRequestFunc(unittest.TestCase): self.assertEqual(sess.run(real_output_tensorshape_split), sess.run(except_output_tensorshape_split)) - @mock.patch.multiple("mx_rec.core.asc.helper", - get_modify_graph=mock.MagicMock(return_value=True)) - def test_merge_feature_id_request_case3(self): + @mock.patch("mx_rec.core.asc.helper.ConfigInitializer") + def test_merge_feature_id_request_case3(self, helper_config_initializer): """ case3: 自动改图 """ @@ -377,6 +393,9 @@ class TestMergeFeatureIdRequestFunc(unittest.TestCase): from mx_rec.core.asc.helper import merge_feature_id_request with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(modify_graph=True) + helper_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + feature_id_list = [ tf.constant([2, ], dtype=tf.int64), tf.constant([3, ], dtype=tf.int64), @@ -414,6 +433,10 @@ class TestMergeFeatureIdRequestFunc(unittest.TestCase): sess.run(except_output_tensorshape_split)) +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestExportReadEmbKeyV2OpFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.helper.export_read_emb_key_v2_op'. 
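Note on the rewritten decorators above: unittest.mock applies stacked @mock.patch decorators bottom-up, so the patch written closest to the def supplies the first mock argument after self, and the outermost patch arrives last. That is why helper_config_initializer, added as the topmost decorator, lands at the end of each test signature while mock_get_host_pipeline_ops stays first. A minimal, self-contained sketch of the ordering rule, using standard-library targets rather than mx_rec modules:

    import os
    import unittest
    from unittest import mock

    class DecoratorOrderExample(unittest.TestCase):
        @mock.patch("os.getcwd")  # outermost patch -> last mock argument
        @mock.patch("os.getpid")  # innermost patch -> first mock argument
        def test_order(self, mock_getpid, mock_getcwd):
            mock_getpid.return_value = 42
            mock_getcwd.return_value = "/patched"
            self.assertEqual(os.getpid(), 42)
            self.assertEqual(os.getcwd(), "/patched")
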
@@ -539,6 +562,10 @@ class TestExportReadEmbKeyV2OpFunc(unittest.TestCase): self.assertEqual(sess.run(output_batch[-1][0]), 0) +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestGetTargetTensorsWithArgsIndexesFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.helper.get_target_tensors_with_args_indexes'. @@ -581,6 +608,10 @@ class TestGetTargetTensorsWithArgsIndexesFunc(unittest.TestCase): self.assertEqual(sess.run(tf.shape(batch.get("new_label_0"))), batch_size) +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestGetTargetTensorsWithFeatureSpecFunc(unittest.TestCase): """ Test for 'mx_rec.core.asc.helper.get_target_tensors_with_feature_specs'. diff --git a/tests/mx_rec/core/test_manager.py b/tests/mx_rec/core/test_manager.py index 038ea60e..b8191249 100644 --- a/tests/mx_rec/core/test_manager.py +++ b/tests/mx_rec/core/test_manager.py @@ -21,26 +21,7 @@ from unittest import mock import tensorflow as tf from mx_rec.core.asc.feature_spec import FeatureSpec -from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockOptimizer, MockHybridMgmt - - -class TestCheckDanglingTableFunc(unittest.TestCase): - """ - Test for 'mx_rec.core.asc.manager.check_dangling_table'. - """ - - @mock.patch.multiple("mx_rec.core.asc.manager", - export_dangling_table=mock.MagicMock(return_value=[]), - export_table_instances=mock.MagicMock(return_value={}), - find_dangling_table=mock.MagicMock(return_value=["test_table"])) - def test_check_dangling_table(self): - """ - case: test check_dangling_table - """ - - from mx_rec.core.asc.manager import check_dangling_table - - self.assertListEqual(check_dangling_table(), ["test_table"]) +from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockOptimizer, MockHybridMgmt, MockConfigInitializer class TestGenerateTableInfoListFunc(unittest.TestCase): @@ -48,8 +29,8 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): Test for 'mx_rec.core.asc.manager.generate_table_info_list'. 
""" - @mock.patch("mx_rec.core.asc.manager.export_table_instances") - def test_generate_table_info_list_case1(self, mock_export_table_instances): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_generate_table_info_list_case1(self, merge_table_config_initializer): """ case1: 一张表开DDR,一张表没开DDR,抛出异常 """ @@ -57,11 +38,14 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): from mx_rec.core.asc.manager import generate_table_info_list with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer() + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_table1 = MockSparseEmbedding("test_table1") - test_table1.host_vocabulary_size = 1 + test_table1.is_hbm = False test_table2 = MockSparseEmbedding("test_table2") - test_table2.host_vocabulary_size = 0 - mock_export_table_instances.return_value = { + test_table2.is_hbm = True + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = { "test_table1": test_table1, "test_table2": test_table2 } @@ -70,11 +54,10 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): generate_table_info_list() @mock.patch.multiple("mx_rec.core.asc.manager", - export_optimizer=mock.MagicMock(return_value=None), should_skip=mock.MagicMock(return_value=True), check_dangling_table=mock.MagicMock(return_value=["test_table"])) - @mock.patch("mx_rec.core.asc.manager.export_table_instances") - def test_generate_table_info_list_case2(self, mock_export_table_instances): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_generate_table_info_list_case2(self, merge_table_config_initializer): """ case2: test_table是dangling_table,skip为True """ @@ -82,19 +65,21 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): from mx_rec.core.asc.manager import generate_table_info_list with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer() + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_table = MockSparseEmbedding("test_table") test_table.host_vocabulary_size = 1 - mock_export_table_instances.return_value = dict(test_table=test_table) + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table) table_info_list = generate_table_info_list() self.assertListEqual(table_info_list, []) @mock.patch.multiple("mx_rec.core.asc.manager", - export_optimizer=mock.MagicMock(return_value=None), should_skip=mock.MagicMock(return_value=True), check_dangling_table=mock.MagicMock(return_value=[])) - @mock.patch("mx_rec.core.asc.manager.export_table_instances") - def test_generate_table_info_list_case3(self, mock_export_table_instances): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_generate_table_info_list_case2(self, merge_table_config_initializer): """ case3: test_table不是dangling_table,skip为True """ @@ -102,24 +87,25 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): from mx_rec.core.asc.manager import generate_table_info_list with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer() + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_table = MockSparseEmbedding("test_table") test_table.host_vocabulary_size = 1 - mock_export_table_instances.return_value = dict(test_table=test_table) + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table) table_info_list = generate_table_info_list() 
self.assertListEqual(table_info_list, []) @mock.patch.multiple("mx_rec.core.asc.manager", - export_optimizer=mock.MagicMock(return_value=None), EmbInfoParams=mock.MagicMock(return_value=None), EmbInfo=mock.MagicMock(return_value="test_table_info"), matched_emb_initializer=mock.MagicMock(return_value=[]), matched_opt_slot_initializers=mock.MagicMock(return_value=[]), should_skip=mock.MagicMock(return_value=False), - get_use_static=mock.MagicMock(return_value=True), check_dangling_table=mock.MagicMock(return_value=[])) - @mock.patch("mx_rec.core.asc.manager.export_table_instances") - def test_generate_table_info_list_case3(self, mock_export_table_instances): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_generate_table_info_list_case3(self, merge_table_config_initializer): """ case4: 静态shape,test_table不是dangling_table,skip为False """ @@ -127,6 +113,9 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): from mx_rec.core.asc.manager import generate_table_info_list with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer(use_static=True) + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_table = MockSparseEmbedding("test_table") test_table.host_vocabulary_size = 8 test_table.send_count = 1 @@ -135,10 +124,10 @@ class TestGenerateTableInfoListFunc(unittest.TestCase): test_table.slice_ssd_vocabulary_size = 0 test_table.is_grad = True test_table.is_save = True - test_table.scalar_emb_size = 8 + test_table.emb_size = 8 test_table.ext_emb_size = 8 test_table.ssd_data_path = "" - mock_export_table_instances.return_value = dict(test_table=test_table) + mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table) table_info_list = generate_table_info_list() self.assertListEqual(table_info_list, ["test_table_info"]) @@ -162,7 +151,7 @@ class TestMatchedConstantInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") table_info.init_param = 1. - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer.value = 0 self.assertListEqual(matched_constant_initializer(table_info), []) @@ -186,7 +175,7 @@ class TestMatchedRandomNormalInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") table_info.init_param = 1. - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer.seed = None table_info.emb_initializer.mean = 1 table_info.emb_initializer.stddev = 1 @@ -206,7 +195,7 @@ class TestMatchedRandomNormalInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") table_info.init_param = 1. - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer.seed = 1 table_info.emb_initializer.mean = 1 table_info.emb_initializer.stddev = 1 @@ -232,7 +221,7 @@ class TestMatchedTruncatedNormalInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") table_info.init_param = 1. - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer.seed = None table_info.emb_initializer.mean = 1 table_info.emb_initializer.stddev = 1 @@ -252,7 +241,7 @@ class TestMatchedTruncatedNormalInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") table_info.init_param = 1. 
- table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer.seed = 1 table_info.emb_initializer.mean = 1 table_info.emb_initializer.stddev = 1 @@ -276,7 +265,7 @@ class TestMatchedEmbInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer = tf.constant_initializer() self.assertEqual(matched_emb_initializer(table_info), 1) @@ -292,7 +281,7 @@ class TestMatchedEmbInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer = tf.random_normal_initializer() self.assertEqual(matched_emb_initializer(table_info), 2) @@ -308,7 +297,7 @@ class TestMatchedEmbInitializerFunc(unittest.TestCase): with tf.Graph().as_default(): table_info = MockSparseEmbedding("test_table") - table_info.scalar_emb_size = 8 + table_info.emb_size = 8 table_info.emb_initializer = tf.truncated_normal_initializer() self.assertEqual(matched_emb_initializer(table_info), 3) @@ -330,7 +319,7 @@ class TestMatchedOptSlotInitializersFunc(unittest.TestCase): with tf.Graph().as_default(): table_instance = MockSparseEmbedding("test_table") - table_instance.scalar_emb_size = 8 + table_instance.emb_size = 8 table_instance.ext_emb_size = 8 table_instance.optimizer_instance_list = [MockOptimizer()] @@ -345,8 +334,8 @@ class TestGenerateThresholdListFunc(unittest.TestCase): @mock.patch.multiple("mx_rec.core.asc.manager", ThresholdValue=mock.MagicMock(return_value=0)) - @mock.patch("mx_rec.core.asc.manager.export_feature_spec") - def test_generate_threshold_list(self, mock_export_feature_spec): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_generate_threshold_list(self, manager_config_initializer): """ case: 有淘汰、准入 """ @@ -354,14 +343,15 @@ class TestGenerateThresholdListFunc(unittest.TestCase): from mx_rec.core.asc.manager import generate_threshold_list with tf.Graph().as_default(): + mock_config_initializer = MockConfigInitializer() + manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + test_feature_spec1 = FeatureSpec("test_feature_spec1", access_threshold=5, eviction_threshold=10, faae_coefficient=None) test_feature_spec2 = FeatureSpec("test_feature_spec2", access_threshold=5, faae_coefficient=None) - mock_export_feature_spec.return_value = { - "test_feature_spec1": test_feature_spec1, - "test_feature_spec2": test_feature_spec2 - } + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(test_feature_spec1, True) + mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(test_feature_spec2, True) self.assertListEqual(generate_threshold_list(), [0, 0]) @@ -375,25 +365,22 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): get_rank_id=mock.MagicMock(return_value=0), get_device_id=mock.MagicMock(return_value=0), get_rank_size=mock.MagicMock(return_value=0), - get_train_steps=mock.MagicMock(return_value=0), - get_eval_steps=mock.MagicMock(return_value=0), - get_save_steps=mock.MagicMock(return_value=0), - get_if_load=mock.MagicMock(return_value=False), - get_use_static=mock.MagicMock(return_value=True), - get_use_hot=mock.MagicMock(return_value=True), - get_use_dynamic_expansion=mock.MagicMock(return_value=True), USE_STATIC=mock.MagicMock(return_value=0), USE_HOT=mock.MagicMock(return_value=1), 
USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2), RankInfo=mock.MagicMock(return_value="mock_info"), HybridMgmt=mock.MagicMock(return_value=MockHybridMgmt(is_initialized=False))) - def test_initialize_emb_cache_case1(self): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_initialize_emb_cache_case1(self, manager_config_initializer): """ case1: 初始化失败 """ from mx_rec.core.asc.manager import initialize_emb_cache + mock_config_initializer = MockConfigInitializer(use_static=True, use_dynamic_expansion=True) + manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + with self.assertRaises(RuntimeError): initialize_emb_cache([], []) @@ -401,26 +388,22 @@ class TestInitializeEmbCacheFunc(unittest.TestCase): get_rank_id=mock.MagicMock(return_value=0), get_device_id=mock.MagicMock(return_value=0), get_rank_size=mock.MagicMock(return_value=0), - get_train_steps=mock.MagicMock(return_value=0), - get_eval_steps=mock.MagicMock(return_value=0), - get_save_steps=mock.MagicMock(return_value=0), - get_if_load=mock.MagicMock(return_value=False), - get_use_static=mock.MagicMock(return_value=True), - get_use_hot=mock.MagicMock(return_value=True), - get_use_dynamic_expansion=mock.MagicMock(return_value=True), USE_STATIC=mock.MagicMock(return_value=0), USE_HOT=mock.MagicMock(return_value=1), USE_DYNAMIC_EXPANSION=mock.MagicMock(return_value=2), - RankInfo=mock.MagicMock(return_value="mock_info"), - set_asc_manager=mock.MagicMock(return_value=None)) + RankInfo=mock.MagicMock(return_value="mock_info")) + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") @mock.patch("mx_rec.core.asc.manager.HybridMgmt") - def test_initialize_emb_cache_case2(self, mock_hybrid_mgmt): + def test_initialize_emb_cache_case2(self, mock_hybrid_mgmt, manager_config_initializer): """ case2: 初始化成功 """ from mx_rec.core.asc.manager import initialize_emb_cache + mock_config_initializer = MockConfigInitializer(use_static=True, use_dynamic_expansion=True) + manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + mock_mgmt = MockHybridMgmt(is_initialized=True) mock_hybrid_mgmt.return_value = mock_mgmt initialize_emb_cache([], []) @@ -448,21 +431,22 @@ class TestStartAscPipeLineFunc(unittest.TestCase): @mock.patch.multiple("mx_rec.core.asc.manager", generate_table_info_list=mock.MagicMock(return_value=["test_table"]), generate_threshold_list=mock.MagicMock(return_value=[]), - get_stat_on=mock.MagicMock(return_value=True), - is_asc_manager_initialized=mock.MagicMock(return_value=False), - export_table_num=mock.MagicMock(return_value=None), initialize_emb_cache=mock.MagicMock(return_value=None)) - def test_start_asc_pipeline_case2(self): + @mock.patch("mx_rec.core.asc.manager.ConfigInitializer") + def test_start_asc_pipeline_case2(self, manager_config_initializer): """ case2: table_info_list为["test_table"],stat_on为True """ from mx_rec.core.asc.manager import start_asc_pipeline + mock_config_initializer = MockConfigInitializer() + manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # 该函数无返回值,且内部调用的函数已经被mock了,无异常、无返回值 start_asc_pipeline() self.assertTrue(callable(start_asc_pipeline)) if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/mx_rec/core/test_merge_table.py b/tests/mx_rec/core/test_merge_table.py index 01d6736e..9e3094dc 100644 --- a/tests/mx_rec/core/test_merge_table.py +++ b/tests/mx_rec/core/test_merge_table.py @@ -21,7 +21,7 @@ from unittest 
import mock import tensorflow as tf import mx_rec.core.asc.merge_table -from tests.mx_rec.core.mock_class import MockSparseEmbedding +from tests.mx_rec.core.mock_class import MockSparseEmbedding, MockConfigInitializer, MockGlobalEnv class TestAffirmFunc(unittest.TestCase): @@ -83,7 +83,11 @@ class TestIsTrainTaskFunc(unittest.TestCase): Test for 'mx_rec.core.asc.merge_table.is_train_task'. """ - def setUp(self): + @mock.patch("mx_rec.graph.patch.ConfigInitializer") + def setUp(self, graph_patch_config_initializer): + mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False) + graph_patch_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # 在tensorflow默认图中添加op a = tf.constant(1) b = tf.constant(2) @@ -95,38 +99,35 @@ class TestIsTrainTaskFunc(unittest.TestCase): # 删除tensorflow默认图中添加的op tf.reset_default_graph() - @mock.patch.multiple("mx_rec.core.asc.merge_table", - get_bool_gauge_set=mock.MagicMock(return_value=[])) - def test_is_train_task_case1(self): + @mock.patch("mx_rec.core.asc.merge_table.ConfigInitializer") + @mock.patch("mx_rec.graph.patch.ConfigInitializer") + def test_is_train_task_case1(self, graph_patch_config_initializer, merge_table_config_initializer): """ case1: bool_gauge_set为[] """ from mx_rec.core.asc.merge_table import is_train_task + mock_config_initializer = MockConfigInitializer(bool_gauge_set=[]) + graph_patch_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + self.assertFalse(is_train_task()) @mock.patch.multiple("mx_rec.core.asc.merge_table", - get_bool_gauge_set=mock.MagicMock(return_value=["train"]), check_op=mock.MagicMock(return_value=True)) - def test_is_train_task_case2(self): + @mock.patch("mx_rec.core.asc.merge_table.ConfigInitializer") + @mock.patch("mx_rec.graph.patch.ConfigInitializer") + def test_is_train_task_case2(self, graph_patch_config_initializer, merge_table_config_initializer): """ case2: bool_gauge_set为["train"],且check_op为True """ from mx_rec.core.asc.merge_table import is_train_task - self.assertTrue(is_train_task()) - - @mock.patch.multiple("mx_rec.core.asc.merge_table", - get_bool_gauge_set=mock.MagicMock(return_value=["train"]), - check_op=mock.MagicMock(return_value=False)) - def test_is_train_task_case3(self): - """ - case3: bool_gauge_set为["train"],且check_op为False - """ - - from mx_rec.core.asc.merge_table import is_train_task + mock_config_initializer = MockConfigInitializer(bool_gauge_set=["train"]) + graph_patch_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) self.assertTrue(is_train_task()) @@ -136,7 +137,11 @@ class TestFindDanglingTableFunc(unittest.TestCase): Test for 'mx_rec.core.asc.merge_table.find_dangling_table'. 
""" - def setUp(self): + @mock.patch("mx_rec.graph.patch.ConfigInitializer") + def setUp(self, graph_patch_config_initializer): + mock_config_initializer = MockConfigInitializer() + graph_patch_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # 在tensorflow默认图中添加op a = tf.constant(1) b = tf.constant(2) @@ -161,7 +166,7 @@ class TestFindDanglingTableFunc(unittest.TestCase): @mock.patch.multiple("mx_rec.core.asc.merge_table", is_train_task=mock.MagicMock(return_value=True), - get_enable_table_merge=mock.MagicMock(return_value=False)) + global_env=mock.MagicMock(return_value=MockGlobalEnv(tf_device='GPU'))) def test_find_dangling_table_case2(self): """ case2: is_train_task为True,merge为False @@ -171,23 +176,6 @@ class TestFindDanglingTableFunc(unittest.TestCase): self.assertListEqual(find_dangling_table([]), []) - @mock.patch.multiple("mx_rec.core.asc.merge_table", - is_train_task=mock.MagicMock(return_value=True), - affirm=mock.MagicMock(return_value=True), - get_enable_table_merge=mock.MagicMock(return_value=True), - insert_dangling_table=mock.MagicMock(return_value=None)) - @mock.patch("mx_rec.core.asc.merge_table.export_table_instances") - def test_find_dangling_table_case3(self, mock_export_table_instances): - """ - case3: is_train_task为True,merge为True - """ - - from mx_rec.core.asc.merge_table import find_dangling_table - - mock_export_table_instances.return_value = {"table1": MockSparseEmbedding("table1")} - dangling_table = find_dangling_table(["table2"]) - self.assertListEqual(dangling_table, ["table2", "table1"]) - class TestShouldSkipFunc(unittest.TestCase): """ @@ -229,4 +217,4 @@ class TestShouldSkipFunc(unittest.TestCase): if __name__ == '__main__': - unittest.main() + unittest.main() \ No newline at end of file diff --git a/tests/mx_rec/data/test_dataset.py b/tests/mx_rec/data/test_dataset.py index d1047e17..b3884fac 100644 --- a/tests/mx_rec/data/test_dataset.py +++ b/tests/mx_rec/data/test_dataset.py @@ -16,13 +16,19 @@ # ============================================================================== import unittest +from unittest import mock import tensorflow as tf from tests.mx_rec.core.generator_dataset import generate_dataset, Config +from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.data.mock_class import MockEosOpsLib +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestEosDatasetClass(unittest.TestCase): """ Test for 'mx_rec.data.dataset.EosDataset'. 
diff --git a/tests/mx_rec/graph/test_acg_push_ops.py b/tests/mx_rec/graph/test_acg_push_ops.py index 362eb296..129b773f 100644 --- a/tests/mx_rec/graph/test_acg_push_ops.py +++ b/tests/mx_rec/graph/test_acg_push_ops.py @@ -18,7 +18,6 @@ from unittest import TestCase from unittest.mock import patch, Mock -import numpy as np import tensorflow as tf from tensorflow.core.framework import node_def_pb2 from tensorflow.python.data.ops.dataset_ops import DatasetV1 @@ -48,14 +47,13 @@ from mx_rec.graph.acg_push_ops import ( _replace_get_next_op, _patched_get_src_dataset, ) - +from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.graph.mock_dataset import gen_mock_dataset @patch.multiple( "mx_rec.graph.patch", - get_modify_graph=Mock(return_value=True), - get_is_graph_modify_hook_running=Mock(return_value=True), + ConfigInitializer=Mock(return_value=MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True)), ) @patch.multiple( "tensorflow.compat.v1.train.Saver", @@ -169,7 +167,7 @@ class GetDatasetOpTest(TestCase): tgt_dataset_op = _get_dataset_op(mock_graph, mock_get_next_op) self.assertEqual(tgt_dataset_op, expected) - + def test_err_invalid_get_next_op_type(self): mock_get_next_op = tf.zeros(shape=(3,)).op mock_graph = tf.compat.v1.get_default_graph() @@ -353,6 +351,10 @@ class GetMappingForSubgraphTest(TestCase): self.assertEqual(len(mock_tensor_mapping), 2) +@patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=Mock(return_value=MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True)), +) class FrozenVariableNodeToFuncConstNodeDefTest(TestCase): def tearDown(self) -> None: tf.compat.v1.reset_default_graph() diff --git a/tests/mx_rec/graph/test_merge_lookup.py b/tests/mx_rec/graph/test_merge_lookup.py index 19bf0686..84425dc2 100644 --- a/tests/mx_rec/graph/test_merge_lookup.py +++ b/tests/mx_rec/graph/test_merge_lookup.py @@ -24,6 +24,7 @@ import tensorflow as tf from tensorflow import Tensor import mx_rec.graph.merge_lookup as merge_lookup from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCAnchorAttr +from tests.mx_rec.core.mock_class import MockConfigInitializer def mock_get_anchor_attribute(anchor: Tensor, attr: ASCAnchorAttr) -> Union[bool, Mock]: @@ -34,7 +35,7 @@ def mock_get_anchor_attribute(anchor: Tensor, attr: ASCAnchorAttr) -> Union[bool if attr == ASCAnchorAttr.TABLE_INSTANCE: mock_table_instance = Mock() mock_table_instance.table_name = "mock_table_name" - mock_table_instance.lookup_name_dict = {True: ["lookup_1", "lookup_2"]} + mock_table_instance.multi_lookup_times = {True: 2} mock_table_instance.send_count = 4096 * 8 return mock_table_instance if attr == ASCAnchorAttr.FEATURE_SPEC: @@ -51,39 +52,37 @@ class DoMergeLookupTest(TestCase): @patch.multiple( "mx_rec.graph.merge_lookup", - get_modify_graph=Mock(return_value=True), - get_merged_multi_lookup=Mock(return_value=False), - get_use_static=Mock(return_value=False), replace_anchor_vec=Mock(), - insert_merged_multi_lookup=Mock(), ) - @patch.multiple("mx_rec.graph.merge_lookup.SparseEmbedding", get_anchor_attribute=mock_get_anchor_attribute) - def test_ok(self): + @patch.multiple("mx_rec.graph.merge_lookup.BaseSparseEmbedding", get_anchor_attribute=mock_get_anchor_attribute) + @patch("mx_rec.graph.merge_lookup.ConfigInitializer") + def test_ok(self, merge_lookup_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=False, use_static=False) + 
merge_lookup_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_cutting_point = tf.identity(tf.zeros(shape=(4096, 8))) tf.compat.v1.add_to_collection(ASCEND_SPARSE_LOOKUP_ENTRANCE, mock_cutting_point) merge_lookup.do_merge_lookup() - @patch.multiple( - "mx_rec.graph.merge_lookup", - get_modify_graph=Mock(return_value=False), - ) - def test_ok_disable_modify_graph(self): + @patch("mx_rec.graph.merge_lookup.ConfigInitializer") + def test_ok_disable_modify_graph(self, merge_lookup_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=False) + merge_lookup_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + merge_lookup.do_merge_lookup() - @patch.multiple( - "mx_rec.graph.merge_lookup", - get_modify_graph=Mock(return_value=True), - get_merged_multi_lookup=Mock(return_value=True), - ) - def test_ok_already_exec_merged_lookup(self): + @patch("mx_rec.graph.merge_lookup.ConfigInitializer") + def test_ok_already_exec_merged_lookup(self, merge_lookup_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=True) + merge_lookup_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + merge_lookup.do_merge_lookup() - @patch.multiple( - "mx_rec.graph.merge_lookup", - get_modify_graph=Mock(return_value=True), - get_merged_multi_lookup=Mock(return_value=False), - ) - def test_err_empty_cutting_point_list(self): + @patch("mx_rec.graph.merge_lookup.ConfigInitializer") + def test_err_empty_cutting_point_list(self, merge_lookup_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=False) + merge_lookup_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + with self.assertRaises(RuntimeError): merge_lookup.do_merge_lookup() diff --git a/tests/mx_rec/graph/test_modifier.py b/tests/mx_rec/graph/test_modifier.py index a648f5cf..a04088fe 100644 --- a/tests/mx_rec/graph/test_modifier.py +++ b/tests/mx_rec/graph/test_modifier.py @@ -15,7 +15,6 @@ # limitations under the License. 
# ============================================================================== -import os import unittest from collections import defaultdict from unittest import TestCase @@ -24,12 +23,15 @@ from typing import Union, Callable import tensorflow as tf from tensorflow import Tensor + from mx_rec.constants.constants import ( ASCEND_CUTTING_POINT_INITIALIZER, ASCEND_SPARSE_LOOKUP_ENTRANCE, ASCEND_TIMESTAMP, ASCAnchorAttr, ) +from mx_rec.core.asc import FeatureSpec +from mx_rec.graph.graph_typing import AnchorRecord from mx_rec.graph.modifier import ( GraphModifierHook, find_make_iterator_op, @@ -45,7 +47,7 @@ from mx_rec.graph.modifier import ( get_timestamp_index, modify_graph_for_asc, ) - +from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.graph.mock_dataset import gen_mock_dataset @@ -191,44 +193,6 @@ class FindTargetInstanceDatasetTest(TestCase): find_target_instance_dataset(None) -class GenerateGetNextOpSpecsTest(TestCase): - def tearDown(self) -> None: - tf.compat.v1.reset_default_graph() - - @patch.multiple("mx_rec.graph.merge_lookup.SparseEmbedding", get_anchor_attribute=Mock(return_value=True)) - def test_ok(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_ids = mock_batch.get("mock_ids") - mock_labels = mock_batch.get("mock_labels") - mock_cutting_point_list = [mock_ids, mock_labels] - - get_next_op = mock_ids.op - replacement_specs = defaultdict(dict) - passing_tensor_list = [mock_ids, mock_labels] - batch_tensor_index_list = [0, 1] - sub_cutting_point_list = [mock_ids, mock_labels] - sub_graph_def = tf.compat.v1.GraphDef() - input_name_list = [mock_ids.name, mock_labels.name] - output_name_list = [mock_ids.name, mock_labels.name] - is_training = True - - get_next_op_map = generate_get_next_op_specs(mock_cutting_point_list) - expected = defaultdict(dict) - expected[get_next_op] = { - "replacement_specs": replacement_specs, - "passing_tensor_list": passing_tensor_list, - "batch_tensor_index_list": batch_tensor_index_list, - "sub_cutting_point_list": sub_cutting_point_list, - "sub_graph_def": sub_graph_def, - "input_name_list": input_name_list, - "output_name_list": output_name_list, - "is_training": is_training, - } - self.assertEqual(get_next_op_map, expected) - - class GetSrcDatasetTest(TestCase): def tearDown(self) -> None: tf.compat.v1.reset_default_graph() @@ -246,28 +210,39 @@ class GetSrcDatasetTest(TestCase): self.assertEqual(src_dataset, mock_dataset) +@patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=Mock(return_value=MockConfigInitializer()), +) class GetTgtDatasetTest(TestCase): def tearDown(self) -> None: tf.compat.v1.reset_default_graph() @patch.multiple( "mx_rec.graph.modifier", - get_training_mode_channel_id=Mock(return_value=0), get_asc_insert_func=Mock(return_value=lambda x, y: x), ) - @patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) - def test_ok(self): + @patch.multiple("mx_rec.graph.modifier.BaseSparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_ok(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True) + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() 
mock_ids = mock_batch.get("mock_ids") mock_sub_cutting_point_list = [mock_ids] - mock_records = { - "sub_graph_def": tf.compat.v1.GraphDef(), - "input_name_list": [], - "output_name_list": [], - "batch_tensor_index_list": [], - } + mock_records = AnchorRecord( + defaultdict(), + [], + [], + [], + tf.compat.v1.GraphDef(), + [], + [], + True + ) tgt_dataset = get_tgt_dataset(mock_dataset, mock_sub_cutting_point_list, mock_records) new_iter = tgt_dataset.make_initializable_iterator() @@ -284,15 +259,14 @@ class ModifyGraphForAscTest(TestCase): @patch.multiple( "mx_rec.graph.modifier", - get_training_mode_channel_id=Mock(return_value=True), get_asc_insert_func=Mock(return_value=lambda x, y: x), - set_iterator_type=Mock(), - set_initializer=Mock(), - set_target_batch=Mock(), - get_merged_multi_lookup=Mock(return_value=True), ) - @patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) - def test_ok_train_mode(self): + @patch.multiple("mx_rec.graph.modifier.BaseSparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_ok_train_mode(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=True) + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() @@ -305,20 +279,19 @@ class ModifyGraphForAscTest(TestCase): @patch.multiple( "mx_rec.graph.modifier", - get_training_mode_channel_id=Mock(return_value=True), get_asc_insert_func=Mock(return_value=lambda x, y: x), - set_iterator_type=Mock(), - set_initializer=Mock(), - set_target_batch=Mock(), - get_merged_multi_lookup=Mock(return_value=True), do_merge_lookup=Mock(), - get_bool_gauge_set=Mock(return_value={"evaluate"}), - insert_merged_multi_lookup=Mock(), ) @patch.multiple( - "mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute(is_training=False) + "mx_rec.graph.modifier.BaseSparseEmbedding", + get_anchor_attribute=_gen_mock_get_anchor_attribute(is_training=False) ) - def test_ok_eval_mode(self): + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_ok_eval_mode(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, merged_multi_lookup=True, + bool_gauge_set={"evaluate"}) + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() @@ -331,16 +304,14 @@ class ModifyGraphForAscTest(TestCase): @patch.multiple( "mx_rec.graph.modifier", - get_training_mode_channel_id=Mock(return_value=True), get_asc_insert_func=Mock(return_value=lambda x, y: x), - set_iterator_type=Mock(), - set_initializer=Mock(), - set_target_batch=Mock(), - get_merged_multi_lookup=Mock(return_value=False), - insert_merged_multi_lookup=Mock(), ) - @patch.multiple("mx_rec.graph.modifier.SparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) - def test_err_not_clear_flag(self): + @patch.multiple("mx_rec.graph.modifier.BaseSparseEmbedding", get_anchor_attribute=_gen_mock_get_anchor_attribute()) + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_err_not_clear_flag(self, modifier_config_initializer): + mock_config_initializer = 
MockConfigInitializer(modify_graph=True, merged_multi_lookup=False) + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() @@ -357,17 +328,16 @@ class GetTimestampIndexTest(TestCase): def tearDown(self) -> None: tf.compat.v1.reset_default_graph() - @patch.multiple( - "mx_rec.graph.modifier", - insert_feature_spec=Mock(), - get_feature_spec=Mock(return_value=None), - ) @patch.multiple( "mx_rec.graph.modifier.FeatureSpec", include_timestamp=Mock(), index_key=Mock(return_value=2), ) - def test_ok(self): + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_ok(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer() + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() @@ -379,33 +349,10 @@ class GetTimestampIndexTest(TestCase): timestamp_index = get_timestamp_index(mock_get_next_op, is_training=True) self.assertEqual(timestamp_index, 2) - @patch.multiple( - "mx_rec.graph.modifier", - insert_feature_spec=Mock(), - get_feature_spec=Mock(), - ) - @patch.multiple( - "mx_rec.graph.modifier.FeatureSpec", - include_timestamp=Mock(), - index_key=Mock(return_value=0), - ) - def test_err_unmatched_timestamp_index(self): - mock_dataset = gen_mock_dataset() - mock_iterator = mock_dataset.make_initializable_iterator() - mock_batch = mock_iterator.get_next() - mock_timestamp = mock_batch.get("mock_timestamp") - mock_get_next_op = mock_timestamp.op - - tf.compat.v1.add_to_collection(ASCEND_TIMESTAMP, mock_timestamp) - - with self.assertRaises(ValueError): - get_timestamp_index(mock_get_next_op, is_training=True) - @patch.multiple( "mx_rec.graph.patch", - get_modify_graph=Mock(return_value=True), - get_is_graph_modify_hook_running=Mock(return_value=True), + ConfigInitializer=Mock(return_value=MockConfigInitializer()), ) @patch.multiple( "tensorflow.compat.v1.train.Saver", @@ -418,12 +365,15 @@ class GraphModifierHookTest(TestCase): @patch.multiple( "mx_rec.graph.modifier", - set_is_graph_modify_hook_running=Mock(), modify_graph_and_start_emb_cache=Mock(), start_asc_pipeline=Mock(), - get_iterator_type=Mock(return_value="MakeIterator"), ) - def test_ok(self): + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_ok(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True, + iterator_type="MakeIterator") + modifier_config_initializer.get_instance = Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() @@ -439,12 +389,15 @@ class GraphModifierHookTest(TestCase): @patch.multiple( "mx_rec.graph.modifier", - set_is_graph_modify_hook_running=Mock(), modify_graph_and_start_emb_cache=Mock(), start_asc_pipeline=Mock(), - get_iterator_type=Mock(return_value="InvalidIterator"), ) - def test_err_invalid_iterator_type(self): + @patch("mx_rec.graph.modifier.ConfigInitializer") + def test_err_invalid_iterator_type(self, modifier_config_initializer): + mock_config_initializer = MockConfigInitializer(modify_graph=True, is_graph_modify_hook_running=True, + iterator_type="InvalidIterator") + modifier_config_initializer.get_instance = 
Mock(return_value=mock_config_initializer) + mock_dataset = gen_mock_dataset() mock_iterator = mock_dataset.make_initializable_iterator() mock_batch = mock_iterator.get_next() diff --git a/tests/mx_rec/graph/test_utils.py b/tests/mx_rec/graph/test_utils.py index 0b810fac..5a4efffc 100644 --- a/tests/mx_rec/graph/test_utils.py +++ b/tests/mx_rec/graph/test_utils.py @@ -25,7 +25,7 @@ from unittest import TestCase import tensorflow as tf from tensorflow import Tensor, TensorSpec from mx_rec.constants.constants import ASCAnchorAttr -from mx_rec.core.embedding import SparseEmbedding +from mx_rec.core.emb.base_sparse_embedding import BaseSparseEmbedding from mx_rec.graph.utils import ( check_input_list, find_parent_op, @@ -160,7 +160,7 @@ class ReplaceAnchorVecTest(TestCase): anchor_vec = tf.identity(mock_cutting_point, name="anchor_vec") anchor_vec_output = tf.identity(anchor_vec, name="anchor_vec_output") - SparseEmbedding.anchor_tensor_specs[mock_cutting_point][mock_attribute] = anchor_vec + BaseSparseEmbedding.anchor_tensor_specs[mock_cutting_point][mock_attribute] = anchor_vec replace_anchor_vec(mock_cutting_point, mock_attribute, mock_anchor) self.assertEqual(anchor_vec_output.op.inputs[0], mock_anchor) diff --git a/tests/mx_rec/saver/sparse_embedding_mock.py b/tests/mx_rec/saver/sparse_embedding_mock.py index e46dd86c..03df1466 100644 --- a/tests/mx_rec/saver/sparse_embedding_mock.py +++ b/tests/mx_rec/saver/sparse_embedding_mock.py @@ -15,8 +15,6 @@ # limitations under the License. # ============================================================================== -import os - class SparseEmbeddingMock: """ @@ -28,6 +26,8 @@ class SparseEmbeddingMock: self.table_name = "test_table" self.slice_device_vocabulary_size = 10 self.scalar_emb_size = 4 + self.emb_size = 4 + self.is_hbm = host_vocab_size == 0 self.host_vocabulary_size = host_vocab_size self.optimizer = dict() self.use_dynamic_expansion = False diff --git a/tests/mx_rec/saver/test_saver.py b/tests/mx_rec/saver/test_saver.py index 4df0c025..c6055b57 100644 --- a/tests/mx_rec/saver/test_saver.py +++ b/tests/mx_rec/saver/test_saver.py @@ -23,11 +23,16 @@ import tensorflow as tf from mx_rec.saver.saver import Saver from mx_rec.constants.constants import ASCEND_GLOBAL_HASHTABLE_COLLECTION +from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.saver.sparse_embedding_mock import SparseEmbeddingMock table_instance = SparseEmbeddingMock() +@mock.patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class TestSaver(unittest.TestCase): """ Test the function of saving and loading sparse tables. 
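Note: the hunk below folds the mocks that used to decorate setUp into the test method itself. A @mock.patch decorator on setUp is only active while setUp executes; it is unwound before the test method runs, so any mock the test body still needs has to be applied on the test itself. A small, self-contained illustration of that lifetime, using a standard-library target:

    import os
    import unittest
    from unittest import mock

    class SetUpPatchLifetime(unittest.TestCase):
        @mock.patch("os.getcwd", return_value="/patched")
        def setUp(self, mock_getcwd):
            self.seen_in_setup = os.getcwd()  # patch is active here

        def test_patch_already_unwound(self):
            self.assertEqual(self.seen_in_setup, "/patched")
            self.assertNotEqual(os.getcwd(), "/patched")  # gone by test time
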
@@ -35,11 +40,16 @@ class TestSaver(unittest.TestCase): @mock.patch.multiple("mx_rec.saver.saver", get_rank_id=mock.MagicMock(return_value=0), - get_local_rank_size=mock.MagicMock(return_value=1), - get_ascend_global_hashtable_collection=mock.MagicMock( - return_value=ASCEND_GLOBAL_HASHTABLE_COLLECTION), - get_table_instance=mock.MagicMock(return_value=table_instance)) - def setUp(self): + get_local_rank_size=mock.MagicMock(return_value=1)) + @mock.patch("mx_rec.saver.saver.ConfigInitializer") + def test_save_and_load_is_consistent(self, saver_config_initializer): + mock_config_initializer = \ + MockConfigInitializer(var=table_instance, asc_manager=True, + use_dynamic_expansion=False, + host_data=[0, 1, 4, 6, 8], + ascend_global_hashtable_collection=ASCEND_GLOBAL_HASHTABLE_COLLECTION) + saver_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + self.table_name = "test_table" self.optim_m_name = "test_table/LazyAdam/m" self.optim_v_name = "test_table/LazyAdam/v" @@ -48,15 +58,6 @@ class TestSaver(unittest.TestCase): with self.graph.as_default(): self.saver = Saver() - @mock.patch.multiple("mx_rec.saver.saver", - set_sparse_dir=mock.MagicMock(), - is_asc_manager_initialized=mock.MagicMock(return_value=True), - save_host_data=mock.MagicMock(), - get_use_dynamic_expansion=mock.MagicMock(return_value=False), - get_table_instance_by_name=mock.MagicMock(return_value=table_instance), - get_host_data=mock.MagicMock(return_value=[0, 1, 4, 6, 8]), - restore_host_data=mock.MagicMock()) - def test_save_and_load_is_consistent(self): with tf.compat.v1.Session(graph=self.graph) as sess: embedding_directory = "./sparse-model/HashTable/HBM/test_table/embedding" data_file = os.path.join(embedding_directory, "slice_0.data") @@ -72,6 +73,7 @@ class TestSaver(unittest.TestCase): self.saver.restore(sess, "./model") load_embedding = sess.run(self.var)[:5, :] self.assertEqual(load_embedding.all(), origin_embedding.all()) + tf.io.gfile.rmtree("./sparse-model") def build_graph(self): self.graph = tf.compat.v1.Graph() diff --git a/tests/mx_rec/saver/test_sparse.py b/tests/mx_rec/saver/test_sparse.py index 34bb8419..8904578d 100644 --- a/tests/mx_rec/saver/test_sparse.py +++ b/tests/mx_rec/saver/test_sparse.py @@ -23,6 +23,7 @@ import tensorflow as tf import numpy as np from mx_rec.saver.saver import write_binary_data, generate_file_name +from tests.mx_rec.core.mock_class import MockConfigInitializer from tests.mx_rec.saver.sparse_embedding_mock import SparseEmbeddingMock from mx_rec.saver.sparse import export, set_upper_dir, check_table_param from mx_rec.constants.constants import DataAttr @@ -42,29 +43,6 @@ class TestSparseProcessor(unittest.TestCase): self.hbm_npy_path = None self.ddr_npy_path = None - @mock.patch.multiple("mx_rec.saver.sparse", - export_table_name_set=mock.MagicMock(return_value={"test_table"}), - get_sparse_dir=mock.MagicMock(return_value="./test_export_hbm/sparse-model"), - get_table_instance_by_name=mock.MagicMock(return_value=SparseEmbeddingMock())) - def test_export_interface_on_hbm_mode(self): - self.build_fake_hbm_save() - export() - self.assertTrue(os.path.exists(self.hbm_npy_path)) - tf.io.gfile.rmtree("./test_export_hbm") - - - @mock.patch.multiple("mx_rec.saver.sparse", - export_table_name_set=mock.MagicMock(return_value={"test_table"}), - get_sparse_dir=mock.MagicMock(return_value="./test_export_ddr/sparse-model"), - get_table_instance_by_name=mock.MagicMock( - return_value=SparseEmbeddingMock(host_vocab_size=10))) - def 
test_export_interface_on_ddr_mode(self): - self.build_fake_ddr_save() - export() - self.assertTrue(os.path.exists(self.ddr_npy_path)) - tf.io.gfile.rmtree("./test_export_ddr") - - def test_check_table_param(self): table_list = ["test_table_1", "test_table_0"] default_table_list = ["test_table_1", "test_table_2", "test_table_3"] diff --git a/tests/mx_rec/util/communication/test_hccl_mgmt.py b/tests/mx_rec/util/communication/test_hccl_mgmt.py index 3da291f1..0804d8fa 100644 --- a/tests/mx_rec/util/communication/test_hccl_mgmt.py +++ b/tests/mx_rec/util/communication/test_hccl_mgmt.py @@ -41,25 +41,6 @@ class HCCLMGMTTest(unittest.TestCase): """ global_env.rank_table_file = self.rank_table_file - @patch.multiple("mx_rec.util.communication.hccl_mgmt", - get_logic_id=mock.MagicMock(return_value=1)) - def test_parse_hccl_json_when_success(self): - with patch("builtins.open", mock_open(read_data="""{ - "server_count":"1", - "server_list":[ - { - "device":[ - { "device_id":"0", "device_ip":"xxx.xxx.xx.xxx", "rank_id":"0" } - ], - "server_id":"xxx.xxx.xx.xxx" - } - ], - "status":"completed", - "version":"1.0" - }""")) as mock_file: - rank_to_device_dict, local_rank_size = parse_hccl_json() - self.assertEqual(1, local_rank_size) - def test_parse_hccl_json_when_attribute_error(self): with patch("builtins.open", mock_open(read_data="""{ "server_count":"1", @@ -122,20 +103,6 @@ class HCCLMGMTTest(unittest.TestCase): with self.assertRaises(ValueError): rank_to_device_dict, local_rank_size = parse_hccl_json() - def test_set_hccl_info_without_json(self): - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0-7", "8", "0") - self.assertEqual(8, local_rank_size) - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0,1", "2", "0") - self.assertEqual(2, local_rank_size) - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0", "1", "0") - self.assertEqual(1, local_rank_size) - with self.assertRaises(ValueError): - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0", "8", "0") - with self.assertRaises(ValueError): - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("0-2", "8", "3") - with self.assertRaises(ValueError): - rank_to_device_dict, local_rank_size = set_hccl_info_without_json("17", "1", "1") - if __name__ == '__main__': unittest.main() diff --git a/tests/mx_rec/util/test_variable.py b/tests/mx_rec/util/test_variable.py index 6b5e0c0e..c72ed9dc 100644 --- a/tests/mx_rec/util/test_variable.py +++ b/tests/mx_rec/util/test_variable.py @@ -23,14 +23,19 @@ import tensorflow as tf from mx_rec.util.global_env_conf import global_env from mx_rec.util.variable import check_and_get_config_via_var from mx_rec.util.variable import get_dense_and_sparse_variable +from tests.mx_rec.core.mock_class import MockConfigInitializer class MockTableInstance: def __init__(self): - self.skip_emb_transfer = False + self.is_hbm = False self.optimizer = False +@patch.multiple( + "mx_rec.graph.patch", + ConfigInitializer=mock.Mock(return_value=MockConfigInitializer()), +) class VariableTest(unittest.TestCase): def setUp(self): """ @@ -53,9 +58,11 @@ class VariableTest(unittest.TestCase): global_env.cm_chief_device = self.cm_chief_device global_env.ascend_visible_devices = self.ascend_visible_devices - @patch.multiple("mx_rec.util.variable", - get_ascend_global_hashtable_collection=mock.MagicMock(return_value="sparse_hastable")) - def test_get_dense_and_sparse_variable(self): + @mock.patch("mx_rec.util.variable.ConfigInitializer") + def 
test_get_dense_and_sparse_variable(self, variable_config_initializer): + mock_config_initializer = MockConfigInitializer(ascend_global_hashtable_collection="sparse_hastable") + variable_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + dense_layer = tf.Variable([1, 2], trainable=True) sparse_emb = tf.Variable([4, 5], trainable=False) tf.compat.v1.add_to_collection("sparse_hastable", sparse_emb) @@ -69,19 +76,14 @@ class VariableTest(unittest.TestCase): self.assertTrue(result_run) tf.reset_default_graph() - @patch.multiple("mx_rec.util.variable", - get_table_instance=mock.MagicMock(return_value=MockTableInstance())) - def test_check_and_get_config_via_var_when_environment_error(self): + @mock.patch("mx_rec.util.variable.ConfigInitializer") + def test_check_and_get_config_via_var_when_environment_error(self, variable_config_initializer): + mock_config_initializer = MockConfigInitializer(var=MockTableInstance()) + variable_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + with self.assertRaises(EnvironmentError): self.assertEqual(MockTableInstance(), check_and_get_config_via_var("1", "optimize")) - def test_check_and_get_config_via_var_when_success(self): - table_instance = MockTableInstance() - table_instance.skip_emb_transfer = True - table_instance.optimizer = True - with patch("mx_rec.util.variable.get_table_instance") as mock_get_table_instance: - mock_get_table_instance.return_value = mock.MagicMock(table_instance) - self.assertEqual(mock_get_table_instance.return_value, check_and_get_config_via_var("1", "optimize")) if __name__ == '__main__': unittest.main() -- Gitee From c7a0e74da63074f41edb030824ef21dbade5bf51 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Sun, 18 Feb 2024 15:20:15 +0800 Subject: [PATCH 536/551] Match-id-6ebb46d9e69f68ba1c940d85fc3bf5af2176aaf1 --- mx_rec/core/emb/sparse_embedding.py | 14 +++++++++++++- mx_rec/core/embedding.py | 4 ++++ tests/mx_rec/core/test_embedding.py | 20 ++++++++++++++++---- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py index d8ce63b1..a7898db0 100644 --- a/mx_rec/core/emb/sparse_embedding.py +++ b/mx_rec/core/emb/sparse_embedding.py @@ -82,18 +82,30 @@ class HBMSparseEmbedding(SparseEmbedding): """ def __init__(self, config: dict): + self.emb_optimizer = EmbOptimizer(config.get("optimizer_list")) + self.emb_optimizer.check_optimizer_instance_list() + super(HBMSparseEmbedding, self).__init__(config) + @property + def optimizer(self): + return self.emb_optimizer.optimizer + + @property + def optimizer_instance_list(self): + return self.emb_optimizer.optimizer_instance_list + def capacity(self) -> int: return self._device_vocabulary_size def set_optimizer(self, key: str, state_dict: dict): - pass + self.emb_optimizer.set_optimizer(key, state_dict, self._table_name) def _build_optimizer_states(self): pass def _set_ext_emb_size(self): + self._ext_coefficient += len(self.emb_optimizer.optimizer_slot_info_list) self._ext_emb_size = self._emb_size * self._ext_coefficient logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size) diff --git a/mx_rec/core/embedding.py b/mx_rec/core/embedding.py index 0438f281..b7fc9470 100644 --- a/mx_rec/core/embedding.py +++ b/mx_rec/core/embedding.py @@ -166,6 +166,10 @@ def sparse_lookup(hashtable: BaseSparseEmbedding, logger.info("Lookup: The table name is %s, and the value of `is_grad` in this lookup (lookup name is %s) is %s.", 
hashtable.table_name, name, is_grad) + # 校验一表多查次数 + hashtable.increase_multi_lookup_times(is_train) + check_emb_multi_lookup_times(hashtable.multi_lookup_times.get(is_train), hashtable.table_name) + # 对于向上找没有IteratorGetNext的孤儿ids需要标记,以便于后续ACGPushOpsToDataset工作 if isinstance(ids, tf.Tensor): ids = tag_orphan_ids(ids) diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py index bf7d9240..e47b3afd 100644 --- a/tests/mx_rec/core/test_embedding.py +++ b/tests/mx_rec/core/test_embedding.py @@ -24,7 +24,7 @@ from mx_rec.core.asc import FeatureSpec from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute from mx_rec.core.emb.dynamic_sparse_embedding import HBMDynamicSparseEmbedding from mx_rec.core.emb.sparse_embedding import HBMSparseEmbedding, ExternalStorageSparseEmbedding -from mx_rec.optimizers.gradient_descent import create_hash_optimizer +from mx_rec.optimizers.gradient_descent import create_hash_optimizer, CustomizedGradientDescent from tests.mx_rec.core.mock_class import MockConfigInitializer @@ -84,11 +84,15 @@ class TestCreateTableFunc(unittest.TestCase): base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # prepare optimizer list + optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")] + # test test_table = create_table(key_dtype=tf.int64, dim=8, name='test_table', - emb_initializer=tf.compat.v1.truncated_normal_initializer()) + emb_initializer=tf.compat.v1.truncated_normal_initializer(), + optimizer_list=optimizer_list) self.assertIsInstance(test_table, HBMSparseEmbedding) @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", @@ -159,6 +163,9 @@ class TestSparseLookupFunc(unittest.TestCase): emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # prepare optimizer list + optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")] + case1_feat = FeatureSpec("case1_feat", table_name="test_table") set_temporary_feature_spec_attribute(case1_feat, 1) case1_feat.dims = [8, 8] @@ -179,7 +186,8 @@ class TestSparseLookupFunc(unittest.TestCase): dim=8, name='test_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), - device_vocabulary_size=100 * 8) + device_vocabulary_size=100 * 8, + optimizer_list=optimizer_list) self.assertIsInstance(test_table, HBMSparseEmbedding) res = sparse_lookup(test_table, case1_feat, batch=batch) @@ -216,6 +224,9 @@ class TestSparseLookupFunc(unittest.TestCase): sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) + # prepare optimizer list + optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")] + case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64) mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), @@ -232,7 +243,8 @@ class TestSparseLookupFunc(unittest.TestCase): dim=8, name='test_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), - device_vocabulary_size=100 * 8) + device_vocabulary_size=100 * 8, + optimizer_list=optimizer_list) 
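# --- Editor's note (illustrative sketch, not part of the patch): the
# sparse_embedding.py hunk in this patch makes HBMSparseEmbedding build an
# EmbOptimizer from config["optimizer_list"] and widen each stored row by one
# emb_size slot per optimizer state variable, which is why these tests now pass
# optimizer_list into create_table. A minimal model of the _set_ext_emb_size
# arithmetic; the base coefficient of 1 and the two-slot list (as an Adam-style
# optimizer might keep) are assumptions for illustration, not mxRec code:
emb_size = 8
ext_coefficient = 1                        # assumed base: the embedding row itself
optimizer_slot_info_list = ["m", "v"]      # assumed: two state slots per key
ext_coefficient += len(optimizer_slot_info_list)
ext_emb_size = emb_size * ext_coefficient
assert ext_emb_size == 24                  # 8 floats of weights + 2 * 8 floats of state
# A stateless optimizer such as plain gradient descent would contribute an
# empty slot list, leaving ext_emb_size equal to emb_size.
# --- End editor's note.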
self.assertIsInstance(test_table, HBMSparseEmbedding) res = sparse_lookup(test_table, case2_feat, modify_graph=True) -- Gitee From 3693dd936ab4fca9f24219a436b570c04ef7f7c9 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Tue, 20 Feb 2024 16:21:55 +0800 Subject: [PATCH 537/551] Match-id-92e720d487334ab47e1c12224156dcc0354ab045 --- src/core/utils/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/utils/common.h b/src/core/utils/common.h index 5c2d316e..1f5fa3e3 100644 --- a/src/core/utils/common.h +++ b/src/core/utils/common.h @@ -223,7 +223,7 @@ namespace MxRec { bool useHot {}; uint32_t option {}; int nBatch {}; - bool isDDR { true }; + bool isDDR { false }; bool isSSDEnabled { false }; bool useDynamicExpansion {false}; std::vector maxStep; -- Gitee From 1c632b3ed5fa9b6ec1f71b718a777a141207c8ae Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 21 Feb 2024 16:53:49 +0800 Subject: [PATCH 538/551] Match-id-53c68ad7502ae39abe73917a271fd49dfc02514a --- mx_rec/validator/validator.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/mx_rec/validator/validator.py b/mx_rec/validator/validator.py index af064a81..c9abde87 100644 --- a/mx_rec/validator/validator.py +++ b/mx_rec/validator/validator.py @@ -354,7 +354,19 @@ class NumValidator(Validator): value = value.as_list()[0] if isinstance(value, tf.Tensor): sess = tf.Session() if tf.__version__.startswith("1.") else tf.compat.v1.Session() - value = sess.run(value).item() + try: + value = sess.run(value).item() + except Exception as e: + # 当前仅支持数值类型Tensor和feed数值类型的tf.PlaceHolder,其它tensor可能会导致程序异常 + logger.warning("[Validator] Parameter %s is passed, and an exception occurred while getting the value " + "in the tensor: \n%s\n. Ensure that the passed parameter is a constant tensor or " + "a tf.PlaceHolder that feeds a constant value. 
Otherwise, an exception may occur.", + value, e) + + value = 0 if min_value is None else int(min_value) + if isinstance(self, FloatValidator): + value = 0.0 if min_value is None else float(min_value) + super(NumValidator, self).__init__(name, value) self.min_value = min_value -- Gitee From e797e5743a1ebe391519aa030337cb40df3f2bb8 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Wed, 21 Feb 2024 17:00:42 +0800 Subject: [PATCH 539/551] Match-id-54ae682e8d2105b69add60a3a69674a8929bf66c --- examples/demo/little_demo/run.sh | 2 +- examples/demo/little_demo_estimator/run.sh | 2 +- src/ops_tf/hybrid_dataset_ops.cpp | 30 +++++++++++----------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/demo/little_demo/run.sh b/examples/demo/little_demo/run.sh index 41e285f4..10d19233 100644 --- a/examples/demo/little_demo/run.sh +++ b/examples/demo/little_demo/run.sh @@ -83,7 +83,7 @@ export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH # 集合通信文件,格式请参考昇腾官网CANN文档,“准备资源配置文件”章节。 export JOB_ID=10086 # 训练任务使用的NPU卡数总数 -export MXREC_LOG_LEVEL="DEBUG" # 框架日志等级 +export MXREC_LOG_LEVEL="INFO" # 框架日志等级 export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL # 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档 export ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL diff --git a/examples/demo/little_demo_estimator/run.sh b/examples/demo/little_demo_estimator/run.sh index f4ae3b03..5fbc708f 100644 --- a/examples/demo/little_demo_estimator/run.sh +++ b/examples/demo/little_demo_estimator/run.sh @@ -78,7 +78,7 @@ export LD_LIBRARY_PATH=${so_path}:/usr/local/lib:$LD_LIBRARY_PATH # 集合通信文件,格式请参考昇腾官网CANN文档,“准备资源配置文件”章节。 export JOB_ID=10086 # 训练任务使用的NPU卡数总数 -export MXREC_LOG_LEVEL="DEBUG" # 框架日志等级 +export MXREC_LOG_LEVEL="INFO" # 框架日志等级 export TF_CPP_MIN_LOG_LEVEL=3 # tensorflow日志级别,3对应FATAL # 设置应用类日志的全局日志级别及各模块日志级别,具体请参考昇腾官网CANN文档 export ASCEND_GLOBAL_LOG_LEVEL=3 # “设置日志级别”章节0:debug, 1:info, 2:warning, 3:error, 4:NULL diff --git a/src/ops_tf/hybrid_dataset_ops.cpp b/src/ops_tf/hybrid_dataset_ops.cpp index 7d5cb1a4..74be3366 100644 --- a/src/ops_tf/hybrid_dataset_ops.cpp +++ b/src/ops_tf/hybrid_dataset_ops.cpp @@ -49,7 +49,7 @@ namespace MxRec { public: explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context) { - LOG_INFO("clear channel init"); + LOG_DEBUG("clear channel init"); OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId)); if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) { @@ -64,7 +64,7 @@ namespace MxRec { void Compute(OpKernelContextPtr context) override { - LOG_INFO("clear channel {}, context {}", channelId, context->step_id()); + LOG_DEBUG("clear channel {}, context {}", channelId, context->step_id()); HybridMgmtBlock* hybridMgmtBlock = Singleton::GetInstance(); hybridMgmtBlock->ResetAll(channelId); } @@ -77,7 +77,7 @@ namespace MxRec { public: explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context) { - LOG_INFO("SetThreshold init"); + LOG_DEBUG("SetThreshold init"); OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName)); OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName)); // sparse_lookup查询 } @@ -111,7 +111,7 @@ namespace MxRec { return; } } else { - LOG_WARN("SetThreshold failed, because feature admit-and-evict switch is closed"); + LOG_DEBUG("SetThreshold failed, because feature admit-and-evict switch is closed"); } Tensor* output = nullptr; @@ -130,8 +130,8 @@ namespace MxRec { LOG_ERROR("set threshold[{}] < 0 ", threshold); return 0; } - LOG_INFO("ParseThresholdAndCheck, 
emb_name:[{}], ids_name: [{}], threshold: [{}]", - embName, idsName, threshold); + LOG_DEBUG("ParseThresholdAndCheck, emb_name:[{}], ids_name: [{}], threshold: [{}]", + embName, idsName, threshold); return 1; } @@ -184,7 +184,7 @@ namespace MxRec { MAX_CHANNEL_NUM))); return; } - LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); + LOG_DEBUG(HYBRID_BLOCKING + " reset channel {}", channelId); hybridMgmtBlock->ResetAll(channelId); threadNum = GetThreadNumEnv(); @@ -209,7 +209,7 @@ namespace MxRec { out(0) = batchId; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - LOG_WARN("skip excess batch after {}/{}", batchId, maxStep); + LOG_DEBUG("skip excess batch after {}/{}", batchId, maxStep); return; } } @@ -248,7 +248,7 @@ namespace MxRec { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < embNames.size(); ++i) { if (!keyProcess->HasEmbName(embNames.at(i))) { - LOG_INFO("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); + LOG_DEBUG("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { tableUsed.push_back(true); @@ -305,7 +305,7 @@ namespace MxRec { // 前面8个字节、即占一个featureId位,是unix时间戳 auto src = reinterpret_cast(inputTensor.tensor_data().data()); std::copy(src, src + 1, ×tamp); - LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); + LOG_DEBUG("current batchId[{}] timestamp[{}]", batchId, timestamp); dataSize -= 1; if (timestamp <= 0) { @@ -373,7 +373,7 @@ namespace MxRec { MAX_CHANNEL_NUM))); return; } - LOG_INFO(HYBRID_BLOCKING + " reset channel {}", channelId); + LOG_DEBUG(HYBRID_BLOCKING + " reset channel {}", channelId); // 重置此数据通道中所有的步数 hybridMgmtBlock->ResetAll(channelId); @@ -400,7 +400,7 @@ namespace MxRec { out(0) = batchId; if (channelId == 1) { if (maxStep != -1 && batchId >= maxStep) { - LOG_WARN(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); + LOG_DEBUG(StringFormat("skip excess batch after {}/{}", batchId, maxStep)); return; } } @@ -434,7 +434,7 @@ namespace MxRec { auto keyProcess = Singleton::GetInstance(); for (size_t i = 0; i < splits.size(); ++i) { if (!keyProcess->HasEmbName(embNames.at(i))) { - LOG_INFO("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); + LOG_DEBUG("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i)); tableUsed.push_back(false); } else { tableUsed.push_back(true); @@ -491,7 +491,7 @@ namespace MxRec { // 前面8个字节、即占一个featureId位,是unix时间戳 auto src = reinterpret_cast(inputTensor.tensor_data().data()); std::copy(src, src + 1, ×tamp); - LOG_INFO("current batchId[{}] timestamp[{}]", batchId, timestamp); + LOG_DEBUG("current batchId[{}] timestamp[{}]", batchId, timestamp); dataSize -= 1; if (timestamp <= 0) { @@ -534,7 +534,7 @@ namespace MxRec { void Compute(OpKernelContextPtr context) override { - LOG_INFO("context {}", context->step_id()); + LOG_DEBUG("context {}", context->step_id()); std::cout << " Cust opp not installed!!" 
<< std::endl; } -- Gitee From a917e75bce94d7b499f12e9161c55dd7a7d36879 Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Thu, 22 Feb 2024 11:54:14 +0800 Subject: [PATCH 540/551] Match-id-669c84c24fae02a3ae50688393c9c0128b56dcf3 --- src/core/emb_hashmap/emb_hashmap.cpp | 2 + src/core/emb_table/embedding_ddr.cpp | 263 ++++++++++++++------- src/core/emb_table/embedding_ddr.h | 13 +- src/core/emb_table/embedding_dynamic.cpp | 10 +- src/core/emb_table/embedding_mgmt.cpp | 14 ++ src/core/emb_table/embedding_mgmt.h | 4 + src/core/emb_table/embedding_static.cpp | 22 +- src/core/emb_table/embedding_table.cpp | 75 ++++-- src/core/emb_table/embedding_table.h | 26 +- src/core/hybrid_mgmt/hybrid_mgmt.cpp | 14 +- src/core/hybrid_mgmt/hybrid_mgmt.h | 3 +- src/core/ssd_cache/cache_manager.cpp | 98 ++++---- src/core/ssd_cache/cache_manager.h | 28 ++- src/tests/emb_table/embedding_ddr_test.cpp | 16 +- src/tests/ssd_cache/cache_manager_test.cpp | 128 +++++----- 15 files changed, 457 insertions(+), 259 deletions(-) diff --git a/src/core/emb_hashmap/emb_hashmap.cpp b/src/core/emb_hashmap/emb_hashmap.cpp index c380aab2..977b2c0b 100644 --- a/src/core/emb_hashmap/emb_hashmap.cpp +++ b/src/core/emb_hashmap/emb_hashmap.cpp @@ -68,6 +68,8 @@ void EmbHashMap::Process(const string& embName, vector& keys, DDRPara vector swapPos; vector lookUpVec = table->FindOffset(keys, swapId, channelId, swapPos); + table->RefreshFreqInfoWithSwap(); + EASY_BLOCK("hostHashMaps->tdt") std::copy(lookUpVec.begin(), lookUpVec.end(), std::back_inserter(ddrParam.offsetsOut)); diff --git a/src/core/emb_table/embedding_ddr.cpp b/src/core/emb_table/embedding_ddr.cpp index 4a4b09a6..6234aff0 100644 --- a/src/core/emb_table/embedding_ddr.cpp +++ b/src/core/emb_table/embedding_ddr.cpp @@ -9,8 +9,8 @@ #include "utils/logger.h" #include "utils/singleton.h" #include "host_emb/host_emb.h" -#include "hd_transfer/hd_transfer.h" #include "file_system/file_system_handler.h" +#include "ssd_cache/cache_manager.h" using namespace MxRec; @@ -32,10 +32,10 @@ EmbeddingDDR::EmbeddingDDR() EmbeddingDDR::EmbeddingDDR(const EmbInfo& info, const RankInfo& rankInfo, int inSeed) : EmbeddingTable(info, rankInfo, inSeed) { - LOG_INFO("Init DDR table [{}] devVocabSize = {} hostVocabSize = {}", name_, devVocabSize_, hostVocabSize_); + LOG_INFO("Init DDR table [{}] devVocabSize = {} hostVocabSize = {}", name, devVocabSize, hostVocabSize); currentUpdatePos = 0; - devOffset2Key.resize(devVocabSize_); - devOffset2Batch.resize(devVocabSize_); + devOffset2Key.resize(devVocabSize); + devOffset2Batch.resize(devVocabSize); std::fill(devOffset2Batch.begin(), devOffset2Batch.end(), -1); std::fill(devOffset2Key.begin(), devOffset2Key.end(), -1); } @@ -59,7 +59,7 @@ std::vector EmbeddingDDR::FindOffset(const vector& keys, { devOffset2KeyOld.clear(); oldSwap.clear(); - maxOffsetOld = maxOffset_; + maxOffsetOld = maxOffset; UpdateBatchId(keys, batchId); std::vector lookUpVec; @@ -74,33 +74,34 @@ std::vector EmbeddingDDR::FindOffset(const vector& keys, lookUpVec.emplace_back(INVALID_KEY_VALUE); continue; } - if (offset < devVocabSize_) { + AddKeyFreqInfo(key, RecordType::NOT_DDR); + if (offset < devVocabSize) { // 偏移小于等于HBM容量:直接放入查询向量;更新偏移之前关联的key和当前关联的key lookUpVec.push_back(offset); devOffset2KeyOld.emplace_back(offset, static_cast(devOffset2Key[offset])); devOffset2Key[offset] = key; } else { // 偏移大于HBM容量:记录在host emb上的偏移;找到需要交换的HBM偏移 - missingKeysHostPos_.emplace_back(offset - devVocabSize_); + missingKeysHostPos_.emplace_back(offset - devVocabSize); offset = FindSwapPosOld(key, 
offset, batchId, swapPos); lookUpVec.emplace_back(offset); } } if (batchId == 0) { - LOG_INFO("max offset {}", maxOffset_); + LOG_INFO("max offset {}", maxOffset); } - LOG_TRACE("keyOffsetMap_, {}", MapToString(keyOffsetMap_)); + LOG_TRACE("keyOffsetMap, {}", MapToString(keyOffsetMap)); return lookUpVec; } emb_key_t EmbeddingDDR::FindOffsetHelper(const emb_key_t& key, int channelId) { - const auto& iter = keyOffsetMap_.find(key); + const auto& iter = keyOffsetMap.find(key); emb_key_t offset = INVALID_KEY_VALUE; - if (iter != keyOffsetMap_.end()) { + if (iter != keyOffsetMap.end()) { offset = iter->second; - LOG_TRACE("devVocabSize, {} , offset , {}", devVocabSize_, offset); - if (offset >= devVocabSize_) { + LOG_TRACE("devVocabSize, {} , offset , {}", devVocabSize, offset); + if (offset >= devVocabSize) { ddr2HbmKeys.emplace_back(key); } return offset; @@ -108,33 +109,33 @@ emb_key_t EmbeddingDDR::FindOffsetHelper(const emb_key_t& key, int channelId) if (channelId != TRAIN_CHANNEL_ID) { return offset; } - if (evictPos_.size() != 0) { // 优先复用hbm表 - offset = evictPos_.back(); - keyOffsetMap_[key] = offset; - LOG_TRACE("ddr mode, dev evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, evictPos_.size()); - evictPos_.pop_back(); + if (evictDevPos.size() != 0) { // 优先复用hbm表 + offset = evictDevPos.back(); + keyOffsetMap[key] = offset; + LOG_TRACE("ddr mode, dev evictDevPos is not null, key [{}] reuse offset [{}], evictSize [{}]", + key, offset, evictDevPos.size()); + evictDevPos.pop_back(); LOG_ERROR("dev evicted offset = {}", offset); return offset; } - if (evictHostPos_.size() != 0) { // hbm不足,再复用host/ddr表 - offset = evictHostPos_.back(); - keyOffsetMap_[key] = offset; + if (evictHostPos.size() != 0) { // hbm不足,再复用host/ddr表 + offset = evictHostPos.back(); + keyOffsetMap[key] = offset; LOG_TRACE("ddr mode, host evictPos is not null, key [{}] reuse offset [{}], evictSize [{}]", - key, offset, evictHostPos_.size()); - evictHostPos_.pop_back(); - LOG_ERROR("host evicted offset = {}", offset); + key, offset, evictHostPos.size()); + evictHostPos.pop_back(); + LOG_TRACE("host evicted offset = {}", offset); return offset; } - keyOffsetMap_[key] = maxOffset_; - offset = maxOffset_; - maxOffset_++; - if (maxOffset_ == devVocabSize_) { + keyOffsetMap[key] = maxOffset; + offset = maxOffset; + maxOffset++; + if (maxOffset == devVocabSize) { LOG_INFO("start using host vocab!"); } - if (maxOffset_ > (hostVocabSize_ + devVocabSize_)) { - LOG_ERROR("hostVocabSize too small! dev:{} host:{}", devVocabSize_, hostVocabSize_); + if (maxOffset > (hostVocabSize + devVocabSize)) { + LOG_ERROR("hostVocabSize too small! 
dev:{} host:{}", devVocabSize, hostVocabSize); throw runtime_error("hostVocabSize too small"); } return offset; @@ -148,12 +149,12 @@ void EmbeddingDDR::UpdateBatchId(const vector& keys, size_t currentBa if (key == -1) { continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter != keyOffsetMap_.end()) { + const auto& iter = keyOffsetMap.find(key); + if (iter != keyOffsetMap.end()) { offset = iter->second; LOG_TRACE("key will be used, {} , offset , {}", key, offset); - if (offset < devVocabSize_) { + if (offset < devVocabSize) { // devOffset2Batch size equal to devVocabSize, unnecessary to check index boundary devOffset2Batch[offset] = static_cast(currentBatchId); } @@ -184,12 +185,12 @@ emb_key_t EmbeddingDDR::FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t devOffset2Batch[currentUpdatePos] = static_cast(batchId); swapPos.emplace_back(currentUpdatePos); // 记录需要被换出的HBM偏移 offset = currentUpdatePos; - keyOffsetMap_[key] = currentUpdatePos; // 更新key对应的HBM偏移 + keyOffsetMap[key] = currentUpdatePos; // 更新key对应的HBM偏移 // 记录HBM偏移之前的key devOffset2KeyOld.emplace_back(currentUpdatePos, devOffset2Key[currentUpdatePos]); auto& oldKey = devOffset2Key[currentUpdatePos]; oldSwap.emplace_back(oldKey, key); // 记录交换的两个key oldKey:HBM->DDR key:DDR->HBM - keyOffsetMap_[oldKey] = hostOffset; // 更新被替换的key的偏移 + keyOffsetMap[oldKey] = hostOffset; // 更新被替换的key的偏移 oldKey = key; notFind = false; } @@ -197,7 +198,7 @@ emb_key_t EmbeddingDDR::FindSwapPosOld(emb_key_t key, size_t hostOffset, size_t freeSize_--; // HBM可用空间-1 // 遍历完一遍整个HBM表后,从头开始遍历 - if (currentUpdatePos == devVocabSize_) { + if (currentUpdatePos == devVocabSize) { currentUpdatePos = 0; } @@ -230,32 +231,36 @@ void EmbeddingDDR::EvictDeleteEmb(const vector& keys) LOG_WARN("evict key equal -1!"); continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter == keyOffsetMap_.end()) { + const auto& iter = keyOffsetMap.find(key); + if (iter == keyOffsetMap.end()) { // 淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 continue; } offset = iter->second; - keyOffsetMap_.erase(iter); - LOG_TRACE("evict embName {}, offset {}", name_, offset); + keyOffsetMap.erase(iter); + LOG_TRACE("evict embName {}, offset {}", name, offset); - if (offset < devVocabSize_) { + if (offset < devVocabSize) { // offset 在device中 devOffset2Batch[offset] = -1; devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]); devOffset2Key[offset] = -1; - evictPos_.emplace_back(offset); + evictDevPos.emplace_back(offset); evictHBMKeys.emplace_back(key); } else { // offset 在Host - evictHostPos_.emplace_back(offset); + evictHostPos.emplace_back(offset); evictDDRKeys.emplace_back(key); // 删除映射表、初始化host表、发送dev淘汰位置 } } + if (isSSDEnabled_) { + cacheManager_->RefreshFreqInfoCommon(name, evictHBMKeys, TransferType::HBM_2_EVICT); + cacheManager_->RefreshFreqInfoCommon(name, evictDDRKeys, TransferType::DDR_2_EVICT); + } LOG_INFO("ddr EvictDeleteEmb, emb: [{}], hostEvictSize: {}, devEvictSize: {}", - name_, evictPos_.size(), evictHostPos_.size()); - LOG_TRACE("keyOffsetMap_, {}", MapToString(keyOffsetMap_)); + name, evictHostPos.size(), evictDevPos.size()); + LOG_TRACE("keyOffsetMap, {}", MapToString(keyOffsetMap)); } /// DDR模式下的淘汰:删除映射表、初始化host表、发送dev淘汰位置 @@ -270,22 +275,22 @@ void EmbeddingDDR::EvictKeys(const vector& keys) LOG_WARN("evict key equal -1!"); continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter == keyOffsetMap_.end()) { + const auto& iter = keyOffsetMap.find(key); + if (iter == keyOffsetMap.end()) { continue; } // 
淘汰依据keyProcess中的history,hashmap映射关系创建于ParseKey;两者异步,造成淘汰的值在hashmap里可能未创建 offset = iter->second; - keyOffsetMap_.erase(iter); - LOG_TRACE("evict embName {}, offset {}", name_, offset); + keyOffsetMap.erase(iter); + LOG_TRACE("evict embName {}, offset {}", name, offset); - if (offset < devVocabSize_) { + if (offset < devVocabSize) { devOffset2Batch[offset] = INVALID_KEY_VALUE; devOffset2KeyOld.emplace_back(offset, devOffset2Key[offset]); devOffset2Key[offset] = INVALID_KEY_VALUE; - evictPos_.emplace_back(offset); + evictDevPos.emplace_back(offset); } else { - evictHostPos_.emplace_back(offset); + evictHostPos.emplace_back(offset); } } } @@ -298,7 +303,7 @@ void EmbeddingDDR::ClearLookupAndSwapOffset() void EmbeddingDDR::SetStartCount() { currentUpdatePosStart = currentUpdatePos; - freeSize_ = devVocabSize_; + freeSize_ = devVocabSize; } int EmbeddingDDR::Load(const string& savePath) @@ -324,7 +329,7 @@ int EmbeddingDDR::Save(const string& savePath) int EmbeddingDDR::LoadHashMap(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_hashmap/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/embedding_hashmap/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); @@ -348,7 +353,7 @@ int EmbeddingDDR::LoadHashMap(const string& savePath) } fileSystemPtr->Read(ss.str(), reinterpret_cast(buf), fileSize); for (int i = 0; i < fileSize / sizeof(int64_t); i = i + 2) { // key, offset进行pair对存储 - keyOffsetMap_[buf[i]] = buf[i + 1]; + keyOffsetMap[buf[i]] = buf[i + 1]; } free(static_cast(buf)); return 0; @@ -357,7 +362,7 @@ int EmbeddingDDR::LoadHashMap(const string& savePath) int EmbeddingDDR::LoadDevOffset(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/dev_offset_2_Batch_n_Key/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/dev_offset_2_Batch_n_Key/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); @@ -381,7 +386,7 @@ int EmbeddingDDR::LoadDevOffset(const string& savePath) int EmbeddingDDR::LoadCurrStat(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_current_status/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/embedding_current_status/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); @@ -389,16 +394,16 @@ int EmbeddingDDR::LoadCurrStat(const string& savePath) size_t raw[ELEMENT_NUM] = {0}; fileSystemPtr->Read(ss.str(), reinterpret_cast(raw), sizeof(raw)); currentUpdatePos = raw[CURRENT_UPDATE_IDX]; - hostVocabSize_ = raw[HOST_VOCAB_SIZE_IDX]; - devVocabSize_ = raw[MAX_OFFSET_IDX]; - maxOffset_ = raw[MAX_OFFSET_IDX]; + hostVocabSize = raw[HOST_VOCAB_SIZE_IDX]; + devVocabSize = raw[MAX_OFFSET_IDX]; + maxOffset = raw[MAX_OFFSET_IDX]; return 0; } int EmbeddingDDR::LoadEvictPos(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/evict_pos/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/evict_pos/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); @@ -414,16 +419,16 @@ int EmbeddingDDR::LoadEvictPos(const string& savePath) LOG_ERROR("File {} 
size = {} is too big", ss.str(), fileSize); return -1; } - evictPos_.resize(fileSize / sizeof(int64_t)); + evictDevPos.resize(fileSize / sizeof(int64_t)); - fileSystemPtr->Read(ss.str(), reinterpret_cast(evictPos_.data()), fileSize); + fileSystemPtr->Read(ss.str(), reinterpret_cast(evictDevPos.data()), fileSize); return 0; } int EmbeddingDDR::LoadEmbInfo(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_info/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/embedding_info/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); @@ -431,21 +436,21 @@ int EmbeddingDDR::LoadEmbInfo(const string& savePath) size_t raw[EMB_INFO_ELEMENT_NUM] = {0}; fileSystemPtr->Read(ss.str(), reinterpret_cast(raw), sizeof(raw)); extEmbSize_ = raw[EMB_INFO_EXT_SIZE_IDX]; - devVocabSize_ = raw[EMB_INFO_DEV_VOCAB_SIZE_IDX]; - hostVocabSize_ = raw[EMB_INFO_HOST_VOCAB_SIZE_IDX]; + devVocabSize = raw[EMB_INFO_DEV_VOCAB_SIZE_IDX]; + hostVocabSize = raw[EMB_INFO_HOST_VOCAB_SIZE_IDX]; return 0; } int EmbeddingDDR::LoadEmbData(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_data/slice_" << rankId_ << ".data"; + ss << savePath << "/HashTable/DDR/" << name <<"/embedding_data/slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); HostEmb* hostEmbs = Singleton::GetInstance(); - HostEmbTable& table = hostEmbs->GetEmb(name_); + HostEmbTable& table = hostEmbs->GetEmb(name); if (table.embData.empty()) { LOG_ERROR("hostEmb data is empty"); return -1; @@ -457,7 +462,7 @@ int EmbeddingDDR::LoadEmbData(const string& savePath) int EmbeddingDDR::SaveHashMap(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/embedding_hashmap/"; + ss << savePath << "/HashTable/DDR/" << name <<"/embedding_hashmap/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; @@ -465,7 +470,7 @@ int EmbeddingDDR::SaveHashMap(const string& savePath) unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); vector raw; - for (const auto& it : keyOffsetMap_) { + for (const auto& it : keyOffsetMap) { raw.push_back(it.first); raw.push_back(static_cast(it.second)); } @@ -477,7 +482,7 @@ int EmbeddingDDR::SaveHashMap(const string& savePath) int EmbeddingDDR::SaveDevOffset(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ <<"/dev_offset_2_Batch_n_Key/"; + ss << savePath << "/HashTable/DDR/" << name <<"/dev_offset_2_Batch_n_Key/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; @@ -492,7 +497,7 @@ int EmbeddingDDR::SaveDevOffset(const string& savePath) int EmbeddingDDR::SaveCurrStat(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_current_status/"; + ss << savePath << "/HashTable/DDR/"<< name <<"/embedding_current_status/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; @@ -501,9 +506,9 @@ int EmbeddingDDR::SaveCurrStat(const string& savePath) size_t raw[ELEMENT_NUM] = {0}; raw[CURRENT_UPDATE_IDX] = currentUpdatePos; - raw[HOST_VOCAB_SIZE_IDX] = hostVocabSize_; - raw[DEV_VOCAB_SIZE_IDX] = devVocabSize_; - raw[MAX_OFFSET_IDX] = maxOffset_; + raw[HOST_VOCAB_SIZE_IDX] = hostVocabSize; + raw[DEV_VOCAB_SIZE_IDX] = devVocabSize; + raw[MAX_OFFSET_IDX] = maxOffset; fileSystemPtr->Write(ss.str(), 
reinterpret_cast(raw), sizeof(raw)); return 0; } @@ -511,22 +516,22 @@ int EmbeddingDDR::SaveCurrStat(const string& savePath) int EmbeddingDDR::SaveEvictPos(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/" << name_ << "/evict_pos/"; + ss << savePath << "/HashTable/DDR/" << name << "/evict_pos/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; unique_ptr fileSystemHandler = make_unique(); unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); - fileSystemPtr->Write(ss.str(), reinterpret_cast(evictPos_.data()), - static_cast(evictPos_.size() * sizeof(int64_t))); + fileSystemPtr->Write(ss.str(), reinterpret_cast(evictDevPos.data()), + static_cast(evictDevPos.size() * sizeof(int64_t))); return 0; } int EmbeddingDDR::SaveEmbInfo(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_info/"; + ss << savePath << "/HashTable/DDR/"<< name <<"/embedding_info/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; @@ -535,8 +540,8 @@ int EmbeddingDDR::SaveEmbInfo(const string& savePath) size_t raw[EMB_INFO_ELEMENT_NUM] = {}; raw[EMB_INFO_EXT_SIZE_IDX] = extEmbSize_; - raw[EMB_INFO_DEV_VOCAB_SIZE_IDX] = devVocabSize_; - raw[EMB_INFO_HOST_VOCAB_SIZE_IDX] = hostVocabSize_; + raw[EMB_INFO_DEV_VOCAB_SIZE_IDX] = devVocabSize; + raw[EMB_INFO_HOST_VOCAB_SIZE_IDX] = hostVocabSize; fileSystemPtr->Write(ss.str(), reinterpret_cast(raw), sizeof(raw)); return 0; } @@ -544,7 +549,7 @@ int EmbeddingDDR::SaveEmbInfo(const string& savePath) int EmbeddingDDR::SaveEmbData(const string& savePath) { stringstream ss; - ss << savePath << "/HashTable/DDR/"<< name_ <<"/embedding_data/"; + ss << savePath << "/HashTable/DDR/"<< name <<"/embedding_data/"; MakeDir(ss.str()); ss << "slice_" << rankId_ << ".data"; @@ -552,7 +557,7 @@ int EmbeddingDDR::SaveEmbData(const string& savePath) unique_ptr fileSystemPtr = fileSystemHandler->Create(ss.str()); HostEmb* hostEmbs = Singleton::GetInstance(); - HostEmbTable& table = hostEmbs->GetEmb(name_); + HostEmbTable& table = hostEmbs->GetEmb(name); if (table.embData.empty()) { LOG_ERROR("host embedding data is empty"); return 0; @@ -565,3 +570,93 @@ int EmbeddingDDR::SaveEmbData(const string& savePath) fileSystemPtr->Write(ss.str(), content, dataSize * sizeof(float)); return 0; } + +void EmbeddingDDR::SetCacheManager(CacheManager *cm) +{ + cacheManager_ = cm; +} + +void EmbeddingDDR::AddKeyFreqInfo(const emb_key_t& key, RecordType type) +{ + if (!isSSDEnabled_) { + return; + } + cacheManager_->PutKey(name, key, type); +} + +void EmbeddingDDR::RefreshFreqInfoWithSwap() +{ + if (!isSSDEnabled_) { + return; + } + // 换入换出key列表,元素为pair: pair oldKey为从HBM移出的key, key为从DDR移出的key + LOG_DEBUG("RefreshFreqInfoWithSwap:oldSwap Size:{}", oldSwap.size()); + vector enterDDRKeys; + for (auto keyPair : oldSwap) { + enterDDRKeys.emplace_back(keyPair.first); + } + cacheManager_->RefreshFreqInfoCommon(name, enterDDRKeys, TransferType::HBM_2_DDR); + cacheManager_->RefreshFreqInfoCommon(name, ddr2HbmKeys, TransferType::DDR_2_HBM); + + AddCacheManagerTraceLog(); +} + +/// 记录日志:HBM和DDR换入换出后,比较hostHashMap中DDR内key和表对应的lfuCache对象中的key内容 +void EmbeddingDDR::AddCacheManagerTraceLog() const +{ + if (Logger::GetLevel() != Logger::TRACE) { + return; + } + auto& hostMap = keyOffsetMap; + auto& devSize = devVocabSize; + auto iter = cacheManager_->ddrKeyFreqMap.find(name); + if (iter == cacheManager_->ddrKeyFreqMap.end()) { + throw runtime_error("table not in ddrKeyFreqMap"); + } + auto &lfu = iter->second; + const auto& 
lfuTab = lfu.GetFreqTable(); + if (lfuTab.empty()) { + return; + } + size_t tableKeyInDdr = 0; + vector ddrKeys; // 获取hostHashMap中保存在DDR的key + for (const auto& item : hostMap) { + if (item.second < devSize) { + continue; + } + ddrKeys.emplace_back(item.first); + ++tableKeyInDdr; + } + vector lfuKeys; + for (const auto& it : lfuTab) { + lfuKeys.emplace_back(it.first); + } + std::sort(ddrKeys.begin(), ddrKeys.end()); + std::sort(lfuKeys.begin(), lfuKeys.end()); + std::string ddrKeysString = VectorToString(ddrKeys); + std::string lfuKeysString = VectorToString(lfuKeys); + if (ddrKeysString != lfuKeysString) { + LOG_ERROR("swap HBM with DDR step error, key string not equal, ddrKeysString:{}, lfuKeysString:{}", + ddrKeysString, lfuKeysString); + } else { + LOG_INFO("swap HBM with DDR step OK, table:{}, ddrKeysString == lfuKeysString, string length:{}", + name, lfuKeysString.length()); + } + + LOG_INFO("swap HBM with DDR step end, table:{}, tableKeyInDdr:{}, tableKeyInLfu:{}", + name, tableKeyInDdr, lfu.keyTable.size()); +} + +TableInfo EmbeddingDDR::GetTableInfo() +{ + TableInfo ti = { + .name=name, + .hostVocabSize=hostVocabSize, + .devVocabSize=devVocabSize, + .maxOffset=maxOffset, + .keyOffsetMap=keyOffsetMap, + .evictDevPos=evictDevPos, + .evictHostPos=evictHostPos, + }; + return ti; +} diff --git a/src/core/emb_table/embedding_ddr.h b/src/core/emb_table/embedding_ddr.h index 8025b8c8..9284cec8 100644 --- a/src/core/emb_table/embedding_ddr.h +++ b/src/core/emb_table/embedding_ddr.h @@ -48,6 +48,16 @@ public: int Save(const string& savePath); + void RefreshFreqInfoWithSwap(); + + void AddKeyFreqInfo(const emb_key_t& key, RecordType type); + + void SetCacheManager(CacheManager *cm); + + void AddCacheManagerTraceLog() const; + + TableInfo GetTableInfo(); + GTEST_PRIVATE: int LoadHashMap(const string& savePath); @@ -79,12 +89,11 @@ GTEST_PRIVATE: * (区别于oldSwap: pair.second为已存在于DDR key + 换入换出前映射到DDR的新key) */ std::vector ddr2HbmKeys; - bool isSSDEnabled; std::vector devOffset2Batch; // has -1 /** * 记录HBM上查找空位的当前位置 - * 值域为[0, devVocabSize_] + * 值域为[0, devVocabSize] **/ size_t currentUpdatePos; size_t currentUpdatePosStart; // 记录HBM上查找空位的起始位置 diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp index e561b45c..44fc2d71 100644 --- a/src/core/emb_table/embedding_dynamic.cpp +++ b/src/core/emb_table/embedding_dynamic.cpp @@ -51,22 +51,22 @@ void EmbeddingDynamic::Key2Offset(std::vector& keys, int channel) key = INVALID_DYNAMIC_EXPANSION_ADDR; continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter != keyOffsetMap_.end()) { + const auto& iter = keyOffsetMap.find(key); + if (iter != keyOffsetMap.end()) { key = iter->second; continue; } // 新值 if (channel == TRAIN_CHANNEL_ID) { int64_t addr = GetEmptyEmbeddingAddress(); - keyOffsetMap_[key] = addr; + keyOffsetMap[key] = addr; key = addr; - maxOffset_++; + maxOffset++; continue; } key = INVALID_DYNAMIC_EXPANSION_ADDR; } - LOG_DEBUG("current expansion emb:{}, usage:{}/{})", name_, maxOffset_, devVocabSize_); + LOG_DEBUG("current expansion emb:{}, usage:{}/{})", name, maxOffset, devVocabSize); } int64_t EmbeddingDynamic::capacity() const diff --git a/src/core/emb_table/embedding_mgmt.cpp b/src/core/emb_table/embedding_mgmt.cpp index baf3ead0..c40c66f9 100644 --- a/src/core/emb_table/embedding_mgmt.cpp +++ b/src/core/emb_table/embedding_mgmt.cpp @@ -150,3 +150,17 @@ int EmbeddingMgmt::Save(const string& filePath) tablePair.second->Save(filePath); } } + +void 
EmbeddingMgmt::SetCacheManagerForEmbTable(CacheManager* cacheManager) +{ + for (auto& table: embeddings) { + table.second->SetCacheManager(cacheManager); + } +} + +void EmbeddingMgmt::EnableSSD() +{ + for (auto& table: embeddings) { + table.second->EnableSSD(); + } +} \ No newline at end of file diff --git a/src/core/emb_table/embedding_mgmt.h b/src/core/emb_table/embedding_mgmt.h index 00666113..efd62774 100644 --- a/src/core/emb_table/embedding_mgmt.h +++ b/src/core/emb_table/embedding_mgmt.h @@ -90,6 +90,10 @@ public: */ int Save(const string& filePath); + void SetCacheManagerForEmbTable(CacheManager* cacheManager); + + void EnableSSD(); + private: EmbeddingMgmt(); diff --git a/src/core/emb_table/embedding_static.cpp b/src/core/emb_table/embedding_static.cpp index 2a4f2705..2b7a0c5a 100644 --- a/src/core/emb_table/embedding_static.cpp +++ b/src/core/emb_table/embedding_static.cpp @@ -30,17 +30,17 @@ void EmbeddingStatic::Key2Offset(std::vector& keys, int channel) if (key == INVALID_KEY_VALUE) { continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter != keyOffsetMap_.end()) { + const auto& iter = keyOffsetMap.find(key); + if (iter != keyOffsetMap.end()) { key = iter->second; continue; } - if (evictPos_.size() != 0 && channel == TRAIN_CHANNEL_ID) { + if (evictDevPos.size() != 0 && channel == TRAIN_CHANNEL_ID) { // 新值, emb有pos可复用 - size_t offset = evictPos_.back(); - keyOffsetMap_[key] = offset; + size_t offset = evictDevPos.back(); + keyOffsetMap[key] = offset; key = offset; - evictPos_.pop_back(); + evictDevPos.pop_back(); continue; } // 新值 @@ -48,16 +48,16 @@ void EmbeddingStatic::Key2Offset(std::vector& keys, int channel) key = INVALID_KEY_VALUE; continue; } - keyOffsetMap_[key] = maxOffset_; - key = maxOffset_++; + keyOffsetMap[key] = maxOffset; + key = maxOffset++; } - if (maxOffset_ > devVocabSize_) { - LOG_ERROR("dev cache overflow {} > {}", maxOffset_, devVocabSize_); + if (maxOffset > devVocabSize) { + LOG_ERROR("dev cache overflow {} > {}", maxOffset, devVocabSize); throw std::runtime_error("dev cache overflow!"); } } int64_t EmbeddingStatic::capacity() const { - return this->devVocabSize_; + return this->devVocabSize; } diff --git a/src/core/emb_table/embedding_table.cpp b/src/core/emb_table/embedding_table.cpp index d48cee03..1ec48b34 100644 --- a/src/core/emb_table/embedding_table.cpp +++ b/src/core/emb_table/embedding_table.cpp @@ -18,13 +18,13 @@ EmbeddingTable::EmbeddingTable() } EmbeddingTable::EmbeddingTable(const EmbInfo& info, const RankInfo& rankInfo, int inSeed) - : name_(info.name), hostVocabSize_(info.hostVocabSize), devVocabSize_(info.devVocabSize), - freeSize_(0), maxOffset_(0), isDynamic_(rankInfo.useDynamicExpansion), + : name(info.name), hostVocabSize(info.hostVocabSize), devVocabSize(info.devVocabSize), + freeSize_(0), maxOffset(0), isDynamic_(rankInfo.useDynamicExpansion), embSize_(info.embeddingSize), extEmbSize_(info.extEmbeddingSize), embInfo_(info), seed_(inSeed), rankId_(rankInfo.rankId) { LOG_TRACE("table {} isDynamic = {} embeddingSize {} extSize {}", - name_, isDynamic_, embSize_, extEmbSize_); + name, isDynamic_, embSize_, extEmbSize_); } EmbeddingTable::~EmbeddingTable() @@ -51,17 +51,17 @@ std::vector EmbeddingTable::FindOffset(const vector& keys, size_t EmbeddingTable::GetMaxOffset() { - return maxOffset_; + return maxOffset; } int64_t EmbeddingTable::capacity() const { - return static_cast(devVocabSize_); + return static_cast(devVocabSize); } size_t EmbeddingTable::size() const { - return maxOffset_; + return maxOffset; } void 
EmbeddingTable::EvictKeys(const std::vector& keys) @@ -74,55 +74,55 @@ void EmbeddingTable::EvictKeys(const std::vector& keys) LOG_WARN("evict key is INVALID_KEY_VALUE!"); continue; } - const auto& iter = keyOffsetMap_.find(key); - if (iter == keyOffsetMap_.end()) { // not found + const auto& iter = keyOffsetMap.find(key); + if (iter == keyOffsetMap.end()) { // not found continue; } - keyOffsetMap_.erase(iter); - evictPos_.emplace_back(iter->second); - LOG_TRACE("evict embName:{}, offset:{}", name_, iter->second); + keyOffsetMap.erase(iter); + evictDevPos.emplace_back(iter->second); + LOG_TRACE("evict embName:{}, offset:{}", name, iter->second); } - LOG_INFO("EvictKeys: table [{}] evict size on dev:{}", name_, evictPos_.size()); + LOG_INFO("EvictKeys: table [{}] evict size on dev:{}", name, evictDevPos.size()); } const std::vector& EmbeddingTable::GetEvictedKeys() { - return evictPos_; + return evictDevPos; } const std::vector& EmbeddingTable::GetHostEvictedKeys() { - return evictHostPos_; + return evictHostPos; } void EmbeddingTable::EvictInitDeviceEmb() { - if (evictPos_.size() > devVocabSize_) { + if (evictDevPos.size() > devVocabSize) { LOG_ERROR("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", - name_, evictPos_.size(), devVocabSize_); + name, evictDevPos.size(), devVocabSize); throw runtime_error( Logger::Format("{} overflow! init evict dev, evictOffset size {} bigger than dev vocabSize {}", - name_, evictPos_.size(), devVocabSize_).c_str()); + name, evictDevPos.size(), devVocabSize).c_str()); } vector tmpDataOut; - Tensor tmpData = Vec2TensorI32(evictPos_); + Tensor tmpData = Vec2TensorI32(evictDevPos); tmpDataOut.emplace_back(tmpData); tmpDataOut.emplace_back(Tensor(tensorflow::DT_INT32, { 1 })); auto evictLen = tmpDataOut.back().flat(); - evictLen(0) = static_cast(evictPos_.size()); + evictLen(0) = static_cast(evictDevPos.size()); // evict key发送给dev侧,dev侧初始化emb auto trans = Singleton::GetInstance(); - trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, name_); + trans->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, name); - LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! send offsetSize:{}", name_, evictPos_.size()); + LOG_INFO(KEY_PROCESS "hbm EvictInitDeviceEmb: [{}]! 
send offsetSize:{}", name, evictDevPos.size()); } absl::flat_hash_map EmbeddingTable::GetKeyOffsetMap() { - return keyOffsetMap_; + return keyOffsetMap; } void EmbeddingTable::ClearMissingKeys() @@ -145,12 +145,12 @@ void EmbeddingTable::ClearLookupAndSwapOffset() size_t EmbeddingTable::GetDevVocabSize() { - return devVocabSize_; + return devVocabSize; } size_t EmbeddingTable::GetHostVocabSize() { - return hostVocabSize_; + return hostVocabSize; } int EmbeddingTable::Load(const string& filePath) @@ -169,3 +169,30 @@ void EmbeddingTable::MakeDir(const string& dirName) unique_ptr fileSystemPtr = fileSystemHandler->Create(dirName); fileSystemPtr->CreateDir(dirName); } + +void EmbeddingTable::SetCacheManager(CacheManager *cm) +{ +} + +void EmbeddingTable::EnableSSD() +{ + isSSDEnabled_ = true; +} + +void EmbeddingTable::RefreshFreqInfoWithSwap() +{ +} + +TableInfo EmbeddingTable::GetTableInfo() +{ + TableInfo ti = { + .name=name, + .hostVocabSize=hostVocabSize, + .devVocabSize=devVocabSize, + .maxOffset=maxOffset, + .keyOffsetMap=keyOffsetMap, + .evictDevPos=evictDevPos, + .evictHostPos=evictHostPos, + }; + return ti; +} diff --git a/src/core/emb_table/embedding_table.h b/src/core/emb_table/embedding_table.h index 06baf7c3..4c24de45 100644 --- a/src/core/emb_table/embedding_table.h +++ b/src/core/emb_table/embedding_table.h @@ -12,6 +12,7 @@ #include #include "utils/common.h" +#include "ssd_cache/cache_manager.h" namespace MxRec { @@ -84,21 +85,30 @@ public: static void MakeDir(const string& dirName); + virtual void SetCacheManager(CacheManager* cacheManager); + + void EnableSSD(); + + virtual void RefreshFreqInfoWithSwap(); + + virtual TableInfo GetTableInfo(); + + std::string name; + size_t hostVocabSize; + size_t devVocabSize; + size_t maxOffset; + absl::flat_hash_map keyOffsetMap; + std::vector evictDevPos; // 记录HBM内被淘汰的key + std::vector evictHostPos; // 记录Host内淘汰列表 + #ifdef NDEBUG protected: #endif EmbeddingTable& operator=(const EmbeddingTable& table) = delete; - std::string name_; - size_t hostVocabSize_; - size_t devVocabSize_; size_t freeSize_; - size_t maxOffset_; bool isDynamic_; - absl::flat_hash_map keyOffsetMap_; - std::vector evictPos_; // 记录HBM内被淘汰的key - std::vector evictHostPos_; // 记录Host内淘汰列表 std::mutex mut_; std::vector initializeInfos_; EmbInfo embInfo_; @@ -109,6 +119,8 @@ protected: size_t rankId_; std::vector missingKeysHostPos_; // 用于记录当前batch在host上需要换出的偏移 + CacheManager* cacheManager_; + bool isSSDEnabled_ = false; }; } diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp index f5895604..fd5285f3 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp +++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp @@ -116,6 +116,9 @@ bool HybridMgmt::Initialize(RankInfo rankInfo, const vector& embInfos, cacheManager->Init(hostEmbs, mgmtEmbInfo); hostHashMaps->isSSDEnabled = this->isSSDEnabled; hostHashMaps->cacheManager = this->cacheManager; + // 启用SSD时,EmbeddingDDR依赖cacheManager + EmbeddingMgmt::Instance()->EnableSSD(); + EmbeddingMgmt::Instance()->SetCacheManagerForEmbTable(this->cacheManager); } isLoad = ifLoad; if (!isLoad) { @@ -836,7 +839,7 @@ bool HybridMgmt::ProcessEmbInfo(const std::string& embName, int batchId, int cha channelId, batchId, sendRestoreSyncTC.ElapsedMS()); // 调用SSD cache缓存处理流程 - PrepareDDRData(embName, embHashMap, lookupKeys, channelId, batchId); + PrepareDDRData(table, lookupKeys, channelId, batchId); // 计算查询向量;记录需要被换出的HBM偏移 vector tmpData; @@ -1032,20 +1035,21 @@ void HybridMgmt::EvictKeys(const string& embName, const vector& keys) 
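// --- Editor's note (illustrative sketch, not part of the patch): from this hunk
// on, HybridMgmt::PrepareDDRData hands CacheManager a TableInfo built by
// EmbeddingTable::GetTableInfo() instead of an EmbHashMapInfo. TableInfo
// (declared in the cache_manager.h hunk below) copies the table's metadata but
// aliases its mutable state (maxOffset, keyOffsetMap, evictDevPos, evictHostPos)
// through reference members, so CacheManager mutates the table in place without
// knowing its concrete type. A self-contained model of that reference-view
// pattern; View and Table are illustrative stand-ins, not mxRec types:
#include <cstddef>
#include <string>
#include <unordered_map>

struct View {                                        // plays the role of TableInfo
    std::string name;                                // copied metadata
    std::size_t& maxOffset;                          // references alias the owner
    std::unordered_map<long, std::size_t>& keyOffsetMap;
};

struct Table {                                       // plays the role of EmbeddingTable
    std::string name{"demo"};
    std::size_t maxOffset{0};
    std::unordered_map<long, std::size_t> keyOffsetMap;
    View GetView() { return View{name, maxOffset, keyOffsetMap}; }
};

int main()
{
    Table t;
    View v = t.GetView();
    v.keyOffsetMap[42] = v.maxOffset++;              // mutation lands on t itself
    return t.maxOffset == 1 ? 0 : 1;                 // the owner observes the update
}
// The view must not outlive its table (the references would dangle), which is
// consistent with the patch building the TableInfo per call instead of caching it.
// --- End editor's note.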
hdTransfer->Send(TransferChannel::EVICT, tmpDataOut, TRAIN_CHANNEL_ID, embName); } -inline void HybridMgmt::PrepareDDRData(const string& embTableName, EmbHashMapInfo& embHashMap, +inline void HybridMgmt::PrepareDDRData(std::shared_ptr table, const vector& keys, int channelId, int batchId) const { if (!isSSDEnabled) { return; } - LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData start.", channelId, batchId, embTableName); + LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData start.", channelId, batchId, table->name); TimeCost prepareDDRDataTc; - TransferRet ret = cacheManager->TransferDDREmbWithSSD(embTableName, embHashMap, keys, channelId); + TableInfo ti = table->GetTableInfo(); + TransferRet ret = cacheManager->TransferDDREmbWithSSD(ti, keys, channelId); if (ret != TransferRet::TRANSFER_OK) { HandlePrepareDDRDataRet(ret); } LOG_DEBUG("channelId:{} batchId:{}, embTableName:{}, PrepareDDRData end, prepareDDRDataTc(ms):{}", - channelId, batchId, embTableName, prepareDDRDataTc.ElapsedMS()); + channelId, batchId, table->name, prepareDDRDataTc.ElapsedMS()); } void HybridMgmt::EvictSSDKeys(const string& embName, const vector& keys) const diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.h b/src/core/hybrid_mgmt/hybrid_mgmt.h index ce55172e..243f9adc 100644 --- a/src/core/hybrid_mgmt/hybrid_mgmt.h +++ b/src/core/hybrid_mgmt/hybrid_mgmt.h @@ -30,6 +30,7 @@ See the License for the specific language governing permissions and #include "hd_transfer/hd_transfer.h" #include "ssd_cache/cache_manager.h" #include "hybrid_mgmt_block.h" +#include "emb_table/embedding_table.h" namespace MxRec { using namespace std; @@ -104,7 +105,7 @@ namespace MxRec { void EvictSSDKeys(const string& embName, const vector& keys) const; - void PrepareDDRData(const std::string& embTableName, EmbHashMapInfo& embHashMap, + void PrepareDDRData(std::shared_ptr table, const vector &keys, int channelId, int batchId) const; int GetStepFromPath(const string& loadPath) const; diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp index a75834df..c3611d91 100644 --- a/src/core/ssd_cache/cache_manager.cpp +++ b/src/core/ssd_cache/cache_manager.cpp @@ -24,11 +24,12 @@ See the License for the specific language governing permissions and using namespace MxRec; -inline void CacheManager::GetExternalKeys(EmbHashMapInfo &embHashMap, vector &externalKeys, - vector &internalKeys, const vector &keys) const +inline void CacheManager::GetExternalKeys(const absl::flat_hash_map &keyOffsetMap, + vector &externalKeys, vector &internalKeys, + const vector &keys) const { for (const emb_key_t key : keys) { - if (embHashMap.hostHashMap.find(key) == embHashMap.hostHashMap.end()) { + if (keyOffsetMap.find(key) == keyOffsetMap.end()) { externalKeys.emplace_back(key); } else { internalKeys.emplace_back(key); @@ -69,7 +70,7 @@ void CacheManager::HandleRepeatAndInvalidKey(const vector& originalKe /// \param originalKeys 当前批次key /// \param channelId 通道id /// \return 转移结果枚举 -TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, +TransferRet CacheManager::TransferDDREmbWithSSD(TableInfo& table, const vector& originalKeys, int channelId) { vector keys; // 去重和删除无效key @@ -77,22 +78,22 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, // 区分HBM+DDR内key,和HBM+DDR外的key(新key或保存在SSD中的key) vector externalKeys; vector internalKeys; - GetExternalKeys(embHashMap, externalKeys, internalKeys, keys); + 
GetExternalKeys(table.keyOffsetMap, externalKeys, internalKeys, keys); if (externalKeys.empty()) { return TransferRet::TRANSFER_OK; } // 判断剩余内存空间是否足够; 可用内存空间计算:HBM+DDR-已占用; 若是训练,再加DDR已淘汰; // SSD仅与DDR交互,不考虑HBM淘汰位置;由于maxOffset比实际使用大1,所以虽然从0开始也不用再减1 - size_t ddrAvailableSize = embHashMap.devVocabSize + embHashMap.hostVocabSize - embHashMap.maxOffset; + size_t ddrAvailableSize = table.devVocabSize + table.hostVocabSize - table.maxOffset; if (channelId == TRAIN_CHANNEL_ID) { - ddrAvailableSize += embHashMap.evictPos.size(); + ddrAvailableSize += table.evictHostPos.size(); } - LOG_DEBUG("TransferDDREmbWithSSD, maxOffset:{}, evictPos size:{}, ddrAvailableSize:{}", - embHashMap.maxOffset, embHashMap.evictPos.size(), ddrAvailableSize); - CreateSSDTableIfNotExist(embTableName); + LOG_DEBUG("TransferDDREmbWithSSD, maxOffset:{}, evictHostPos size:{}, ddrAvailableSize:{}", + table.maxOffset, table.evictHostPos.size(), ddrAvailableSize); + CreateSSDTableIfNotExist(table.name); // 调用ssdEngine查询当前批次key中保存在SSD中的key vector externalSSDKeys; - GetSSDKeys(embTableName, externalKeys, externalSSDKeys); + GetSSDKeys(table.name, externalKeys, externalSSDKeys); // 后续判断maxOffset是否超出范围时,maxOffset=devVocabSize+hostVocabSize时可用,此处包含等于 bool isDDRSpaceEnough = ddrAvailableSize >= externalKeys.size(); bool ddrSpaceEnoughOrEval = channelId != TRAIN_CHANNEL_ID || isDDRSpaceEnough; @@ -130,7 +131,7 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, // 训练场景检查SSD剩余空间 评估不考虑新key if (channelId == TRAIN_CHANNEL_ID) { size_t needSSDSize = externalKeys.size() - externalSSDKeys.size() - ddrAvailableSize; - const int64_t ssdAvailableSize = ssdEngine->GetTableAvailableSpace(embTableName); + const int64_t ssdAvailableSize = ssdEngine->GetTableAvailableSpace(table.name); if (int64_t(needSSDSize) > ssdAvailableSize) { LOG_ERROR("TransferDDREmbWithSSD: ssd available space is not enough to transfer DDR emb data. 
" "needSSDSize:{}, ssdAvailableSize:{}", needSSDSize, ssdAvailableSize); @@ -141,8 +142,8 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, // 从SSD获取emb数据并从SSD删除; 避免DDR->SSD时空间不够 vector> ssdEmbData; if (!externalSSDKeys.empty()) { - ssdEmbData = ssdEngine->FetchEmbeddings(embTableName, externalSSDKeys); - ssdEngine->DeleteEmbeddings(embTableName, externalSSDKeys); + ssdEmbData = ssdEngine->FetchEmbeddings(table.name, externalSSDKeys); + ssdEngine->DeleteEmbeddings(table.name, externalSSDKeys); } // 从ddr转移到ssd的key个数 @@ -155,18 +156,18 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, */ // 记录要从DDR转移到SSD的key对应的offset(相对值,需减去devVocabSize) vector ddrTransferPos; - TransferRet ddr2SsdRet = TransferDDREmb2SSD(embTableName, embHashMap, ddrSwapOutSize, internalKeys, ddrTransferPos); + TransferRet ddr2SsdRet = TransferDDREmb2SSD(table, ddrSwapOutSize, internalKeys, ddrTransferPos); if (ddr2SsdRet == TransferRet::DDR_SPACE_NOT_ENOUGH) { - ssdEngine->InsertEmbeddings(embTableName, externalSSDKeys, ssdEmbData); + ssdEngine->InsertEmbeddings(table.name, externalSSDKeys, ssdEmbData); return ddr2SsdRet; } - HandleDDRTransferPos(ddrTransferPos, externalSSDKeys, embHashMap); + HandleDDRTransferPos(ddrTransferPos, externalSSDKeys, table); /* * 转移SSD中保存的当前批次key的emb数据到DDR */ - return TransferSSDEmb2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos, ssdEmbData); + return TransferSSDEmb2DDR(table, externalSSDKeys, ddrTransferPos, ssdEmbData); } /// SSD数据转移到DDR中后刷新映射和频次信息 @@ -174,32 +175,32 @@ TransferRet CacheManager::TransferDDREmbWithSSD(const std::string& embTableName, /// \param embHashMap emb hash表 /// \param externalSSDKeys 存储在SSD中的key列表 /// \param ddrTransferPos -void CacheManager::RefreshRelateInfoWithSSD2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap, +void CacheManager::RefreshRelateInfoWithSSD2DDR(TableInfo& table, vector& externalSSDKeys, vector& ddrTransferPos) { for (size_t i = 0; i < externalSSDKeys.size(); ++i) { // 映射关系 ddrTransferPos是在ddrEmbHash中的位置,记录映射时需加上devVocabSize auto& key = externalSSDKeys[i]; - embHashMap.hostHashMap[key] = ddrTransferPos[i] + embHashMap.devVocabSize; + table.keyOffsetMap[key] = ddrTransferPos[i] + table.devVocabSize; // 频次 - ddrKeyFreqMap[embTableName].PutWithInit(key, excludeDDRKeyCountMap[embTableName][key]); - excludeDDRKeyCountMap[embTableName].erase(key); + ddrKeyFreqMap[table.name].PutWithInit(key, excludeDDRKeyCountMap[table.name][key]); + excludeDDRKeyCountMap[table.name].erase(key); } } -void CacheManager::GetDDREmbInfo(vector& keys, const std::string& embTableName, EmbHashMapInfo& embHashMap, +void CacheManager::GetDDREmbInfo(vector& keys, TableInfo& table, vector& ddrTransferPos, vector>& ddrEmbData) const { // 根据offset 获取对应Emb数据 for (auto& key : keys) { - ddrTransferPos.emplace_back(embHashMap.hostHashMap[key] - embHashMap.devVocabSize); + ddrTransferPos.emplace_back(table.keyOffsetMap[key] - table.devVocabSize); } LOG_TRACE("DDR keys:{}", VectorToString(keys)); LOG_TRACE("DDR key positions:{}", VectorToString(ddrTransferPos)); ddrEmbData.resize(keys.size()); - const auto& emb = hostEmbs->GetEmb(embTableName); + const auto& emb = hostEmbs->GetEmb(table.name); #pragma omp parallel for num_threads(MGMT_CPY_THREADS) default(none) shared(ddrTransferPos, emb, ddrEmbData) for (size_t i = 0; i < ddrTransferPos.size(); ++i) { auto& missingKeyPo = ddrTransferPos[i]; @@ -229,14 +230,14 @@ void CacheManager::UpdateDDREmbInfo(const std::string& embTableName, /// \param 
embHashMap emb map /// \param ddrSwapOutKeys 从DDR中转移到SSD中key列表 /// \param ddrSwapOutCounts 从DDR中转移到SSD中key频次数据 -void CacheManager::RefreshRelateInfoWithDDR2SSD(const string& embTableName, EmbHashMapInfo& embHashMap, +void CacheManager::RefreshRelateInfoWithDDR2SSD(TableInfo& table, vector& ddrSwapOutKeys, vector& ddrSwapOutCounts) { - auto& excludeFreqMap = excludeDDRKeyCountMap[embTableName]; + auto& excludeFreqMap = excludeDDRKeyCountMap[table.name]; for (size_t i = 0; i < ddrSwapOutKeys.size(); ++i) { auto& key = ddrSwapOutKeys[i]; - embHashMap.hostHashMap.erase(key); + table.keyOffsetMap.erase(key); excludeFreqMap[key] = ddrSwapOutCounts[i]; } } @@ -325,7 +326,7 @@ void CacheManager::PutKey(const string& embTableName, const emb_key_t& key, Reco /// \param externalSSDKeys SSD->DDR的key列表 /// \param embHashMap emb hash表 void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector& externalSSDKeys, - EmbHashMapInfo& embHashMap) + TableInfo& table) { if (ddrTransferPos.size() == externalSSDKeys.size()) { return; @@ -336,21 +337,21 @@ void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector externalSSDKeys.size()) { while (ddrTransferPos.size() > externalSSDKeys.size()) { - embHashMap.evictPos.emplace_back(ddrTransferPos.back() + embHashMap.devVocabSize); + table.evictHostPos.emplace_back(ddrTransferPos.back() + table.devVocabSize); ddrTransferPos.pop_back(); } return; } // 补齐offset - while (ddrTransferPos.size() < externalSSDKeys.size() && !embHashMap.evictPos.empty()) { - ddrTransferPos.emplace_back(embHashMap.evictPos.back() - embHashMap.devVocabSize); - embHashMap.evictPos.pop_back(); + while (ddrTransferPos.size() < externalSSDKeys.size() && !table.evictHostPos.empty()) { + ddrTransferPos.emplace_back(table.evictHostPos.back() - table.devVocabSize); + table.evictHostPos.pop_back(); } - auto allSize = embHashMap.devVocabSize + embHashMap.hostVocabSize; + auto allSize = table.devVocabSize + table.hostVocabSize; // 还不够继续使用maxOffset - while (ddrTransferPos.size() < externalSSDKeys.size() && embHashMap.maxOffset < allSize) { - auto nextPos = embHashMap.maxOffset++; - ddrTransferPos.emplace_back(nextPos - embHashMap.devVocabSize); + while (ddrTransferPos.size() < externalSSDKeys.size() && table.maxOffset < allSize) { + auto nextPos = table.maxOffset++; + ddrTransferPos.emplace_back(nextPos - table.devVocabSize); } LOG_DEBUG("HandleDDRTransferPos: handle end, pos len:{}, keys len:{}", ddrTransferPos.size(), externalSSDKeys.size()); @@ -366,7 +367,7 @@ void CacheManager::GetSSDKeys(const std::string& embTableName, vector } } -TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHashMapInfo& embHashMap, +TransferRet CacheManager::TransferDDREmb2SSD(TableInfo& table, int64_t ddrSwapOutSize, const vector& keys, vector& ddrTransferPos) { @@ -380,34 +381,33 @@ TransferRet CacheManager::TransferDDREmb2SSD(const string& embTableName, EmbHash // 获取DDR中指定数量的最低频次key,并获取相应emb数据,执行DDR换出到SSD vector ddrSwapOutKeys; vector ddrSwapOutCounts; - ddrKeyFreqMap[embTableName].GetAndDeleteLeastFreqKeyInfo(ddrSwapOutSize, keys, ddrSwapOutKeys, - ddrSwapOutCounts); + ddrKeyFreqMap[table.name].GetAndDeleteLeastFreqKeyInfo(ddrSwapOutSize, keys, ddrSwapOutKeys, ddrSwapOutCounts); if (static_cast(ddrSwapOutKeys.size()) != ddrSwapOutSize) { - auto keyTableSize = ddrKeyFreqMap[embTableName].keyTable.size(); + auto keyTableSize = ddrKeyFreqMap[table.name].keyTable.size(); // 获取的最低频次key数量和预期不一致,DDR空间不足,不能放置当前批次数据 LOG_ERROR("TransferDDREmbWithSSD, vector length is not equal, 
ddrSwapOutKeys size:{}, " "ddrSwapOutSize:{}, ddr lfu keyTable size:{}", ddrSwapOutKeys.size(), ddrSwapOutSize, keyTableSize); - RestoreLeastFreqInfo(embTableName, ddrSwapOutKeys, ddrSwapOutCounts); + RestoreLeastFreqInfo(table.name, ddrSwapOutKeys, ddrSwapOutCounts); return TransferRet::DDR_SPACE_NOT_ENOUGH; } LOG_DEBUG("TransferDDREmbWithSSD: get DDR embeddings and save to SSD, size:{}", ddrSwapOutKeys.size()); // 获取DDR中emb数据 vector> ddrEmbData; - GetDDREmbInfo(ddrSwapOutKeys, embTableName, embHashMap, ddrTransferPos, ddrEmbData); + GetDDREmbInfo(ddrSwapOutKeys, table, ddrTransferPos, ddrEmbData); // 调用SSDEngine接口,将DDR Emb数据保存到SSD - ssdEngine->InsertEmbeddings(embTableName, ddrSwapOutKeys, ddrEmbData); + ssdEngine->InsertEmbeddings(table.name, ddrSwapOutKeys, ddrEmbData); // 初始化DDR内被转移出去的位置 - hostEmbs->EvictInitEmb(embTableName, ddrTransferPos); + hostEmbs->EvictInitEmb(table.name, ddrTransferPos); // 更新记录的DDR中key频次信息 - RefreshRelateInfoWithDDR2SSD(embTableName, embHashMap, ddrSwapOutKeys, ddrSwapOutCounts); + RefreshRelateInfoWithDDR2SSD(table, ddrSwapOutKeys, ddrSwapOutCounts); LOG_DEBUG("TransferDDREmbWithSSD: ddr2SsdTc TimeCost(ms):{}", ddr2SsdTc.ElapsedMS()); return TransferRet::TRANSFER_OK; } -TransferRet CacheManager::TransferSSDEmb2DDR(const string& embTableName, EmbHashMapInfo& embHashMap, +TransferRet CacheManager::TransferSSDEmb2DDR(TableInfo& table, vector& externalSSDKeys, vector& ddrTransferPos, vector>& ssdEmbData) { @@ -422,8 +422,8 @@ TransferRet CacheManager::TransferSSDEmb2DDR(const string& embTableName, EmbHash return TransferRet::TRANSFER_ERROR; } // 将SSD emb存储到DDR中 刷新频次信息 - UpdateDDREmbInfo(embTableName, ddrTransferPos, ssdEmbData); - RefreshRelateInfoWithSSD2DDR(embTableName, embHashMap, externalSSDKeys, ddrTransferPos); + UpdateDDREmbInfo(table.name, ddrTransferPos, ssdEmbData); + RefreshRelateInfoWithSSD2DDR(table, externalSSDKeys, ddrTransferPos); LOG_DEBUG("TransferDDREmbWithSSD: ssd2DdrTc TimeCost(ms):{}", ssd2DdrTc.ElapsedMS()); return TransferRet::TRANSFER_OK; } diff --git a/src/core/ssd_cache/cache_manager.h b/src/core/ssd_cache/cache_manager.h index 26ca3682..e6ed6781 100644 --- a/src/core/ssd_cache/cache_manager.h +++ b/src/core/ssd_cache/cache_manager.h @@ -29,6 +29,17 @@ See the License for the specific language governing permissions and #include "utils/common.h" namespace MxRec { + + struct TableInfo { + std::string name; + size_t hostVocabSize; + size_t devVocabSize; + size_t& maxOffset; + absl::flat_hash_map& keyOffsetMap; + std::vector& evictDevPos; // 记录HBM内被淘汰的key + std::vector& evictHostPos; // 记录Host内淘汰列表 + }; + enum class TransferRet { TRANSFER_OK = 0, // 转移成功或无需处理 TRANSFER_ERROR, @@ -62,7 +73,7 @@ namespace MxRec { void SaveSSDEngine(int step); // 转换DDR和SSD数据 - TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + TransferRet TransferDDREmbWithSSD(TableInfo& table, const vector& originalKeys, int channelId); /* HBM与DDR换入换出时刷新频次信息 */ @@ -90,27 +101,27 @@ namespace MxRec { }; void GetDDREmbInfo(vector& keys, - const std::string& embTableName, EmbHashMapInfo& embHashMap, + TableInfo& table, vector& ddrTransferPos, vector>& ddrEmbData) const; void UpdateDDREmbInfo(const std::string& embTableName, vector& ddrTransferPos, vector>& ssdEmbData) const; - void RefreshRelateInfoWithDDR2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap, + void RefreshRelateInfoWithDDR2SSD(TableInfo& table, vector& ddrSwapOutKeys, vector& ddrSwapOutCounts); - void RefreshRelateInfoWithSSD2DDR(const std::string& 
     enum class TransferRet {
         TRANSFER_OK = 0,  // transfer succeeded, or nothing needed to be done
         TRANSFER_ERROR,
@@ -62,7 +73,7 @@
         void SaveSSDEngine(int step);

         // Exchange data between DDR and SSD
-        TransferRet TransferDDREmbWithSSD(const std::string& embTableName, EmbHashMapInfo& embHashMap,
+        TransferRet TransferDDREmbWithSSD(TableInfo& table,
             const vector<emb_key_t>& originalKeys, int channelId);

         /* Refresh frequency information when swapping between HBM and DDR */
@@ -90,27 +101,27 @@
         };

         void GetDDREmbInfo(vector<emb_key_t>& keys,
-            const std::string& embTableName, EmbHashMapInfo& embHashMap,
+            TableInfo& table,
             vector<size_t>& ddrTransferPos, vector<vector<float>>& ddrEmbData) const;

         void UpdateDDREmbInfo(const std::string& embTableName,
             vector<size_t>& ddrTransferPos, vector<vector<float>>& ssdEmbData) const;

-        void RefreshRelateInfoWithDDR2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap,
+        void RefreshRelateInfoWithDDR2SSD(TableInfo& table,
             vector<emb_key_t>& ddrSwapOutKeys, vector<size_t>& ddrSwapOutCounts);

-        void RefreshRelateInfoWithSSD2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap,
+        void RefreshRelateInfoWithSSD2DDR(TableInfo& table,
             vector<emb_key_t>& externalSSDKeys, vector<size_t>& ddrTransferPos);

         void GetSSDKeys(const std::string& embTableName, vector<emb_key_t>& externalKeys,
             vector<emb_key_t>& externalSSDKeys);

-        TransferRet TransferDDREmb2SSD(const std::string& embTableName, EmbHashMapInfo& embHashMap,
+        TransferRet TransferDDREmb2SSD(TableInfo& table,
             int64_t ddrSwapOutSize, const vector<emb_key_t>& keys,
             vector<size_t>& ddrTransferPos);

-        TransferRet TransferSSDEmb2DDR(const std::string& embTableName, EmbHashMapInfo& embHashMap,
+        TransferRet TransferSSDEmb2DDR(TableInfo& table,
             vector<emb_key_t>& externalSSDKeys, vector<size_t>& ddrTransferPos,
             vector<vector<float>>& ssdEmbData);
@@ -120,9 +131,10 @@
             vector<size_t>& ddrSwapOutCounts);

         static void HandleDDRTransferPos(vector<size_t>& ddrTransferPos, vector<emb_key_t>& externalSSDKeys,
-            EmbHashMapInfo& embHashMap);
+            TableInfo& table);

-        inline void GetExternalKeys(EmbHashMapInfo& embHashMap, vector<emb_key_t>& externalKeys,
+        inline void GetExternalKeys(const absl::flat_hash_map<emb_key_t, size_t> &keyOffsetMap,
+            vector<emb_key_t>& externalKeys,
             vector<emb_key_t>& internalKeys, const vector<emb_key_t>& keys) const;
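`GetExternalKeys` now takes only the key-to-offset map instead of the whole `EmbHashMapInfo`. Its role, implied by the call sites above, is to split a batch into keys already resident in HBM/DDR (internal) and keys that must come from SSD or be newly admitted (external). A minimal sketch under that reading; the `emb_key_t` alias is only to keep the example self-contained, the repo defines the real key type elsewhere:

    #include <cstdint>
    #include <vector>
    #include "absl/container/flat_hash_map.h"

    using emb_key_t = int64_t;  // assumption for this sketch

    // Keys present in keyOffsetMap are "internal" (resident); the rest are "external".
    inline void SplitKeys(const absl::flat_hash_map<emb_key_t, size_t>& keyOffsetMap,
                          const std::vector<emb_key_t>& keys,
                          std::vector<emb_key_t>& internalKeys,
                          std::vector<emb_key_t>& externalKeys)
    {
        for (const auto& key : keys) {
            (keyOffsetMap.count(key) != 0 ? internalKeys : externalKeys).push_back(key);
        }
    }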
         void AddDebugAndTraceLog(size_t batchKeySize, vector<emb_key_t>& externalKeys,
diff --git a/src/tests/emb_table/embedding_ddr_test.cpp b/src/tests/emb_table/embedding_ddr_test.cpp
index 71245b59..621bba60 100644
--- a/src/tests/emb_table/embedding_ddr_test.cpp
+++ b/src/tests/emb_table/embedding_ddr_test.cpp
@@ -81,16 +81,16 @@ TEST_F(EmbeddingDDRTest, SaveLoadBasic)
     // Build test data from the current time
     ddr1->extEmbSize_ = time(nullptr);
-    ddr1->devVocabSize_ = time(nullptr);
-    ddr1->hostVocabSize_ = time(nullptr);
+    ddr1->devVocabSize = time(nullptr);
+    ddr1->hostVocabSize = time(nullptr);
     ddr1->currentUpdatePos = time(nullptr);
-    ddr1->maxOffset_ = time(nullptr);
+    ddr1->maxOffset = time(nullptr);

     vector<emb_key_t> devOffset2KeyTestData;
     for (int i = 0; i < 10; ++i) {
         devOffset2KeyTestData.push_back(static_cast<emb_key_t>(i));
-        ddr1->keyOffsetMap_[i] = i;
-        ddr1->evictPos_.push_back(i);
+        ddr1->keyOffsetMap[i] = i;
+        ddr1->evictDevPos.push_back(i);
     }
     ddr1->devOffset2Key = devOffset2KeyTestData;
@@ -99,11 +99,11 @@
     ddr2->Load("test_dir");

     for (int i = 0; i < 10; ++i) {
-        EXPECT_EQ(ddr1->evictPos_[i], ddr2->evictPos_[i]);
+        EXPECT_EQ(ddr1->evictDevPos[i], ddr2->evictDevPos[i]);
     }
     EXPECT_EQ(ddr1->extEmbSize_, ddr2->extEmbSize_);
-    EXPECT_EQ(ddr1->devVocabSize_, ddr2->devVocabSize_);
+    EXPECT_EQ(ddr1->devVocabSize, ddr2->devVocabSize);
 }

 /**
@@ -173,7 +173,7 @@ TEST_F(EmbeddingDDRTest, evict)
     }
     table->FindOffset(testKeys, 0, TRAIN_CHANNEL_ID, testSwap);
     table->EvictKeys(testKeys);
-    EXPECT_EQ(table->evictPos_.size(), 100);
+    EXPECT_EQ(table->evictDevPos.size(), 100);
     EXPECT_EQ(testKeys.size(), 100);
     EXPECT_EQ(testSwap.size(), 0);
 }
diff --git a/src/tests/ssd_cache/cache_manager_test.cpp b/src/tests/ssd_cache/cache_manager_test.cpp
index 31c953fa..e4c9e939 100644
--- a/src/tests/ssd_cache/cache_manager_test.cpp
+++ b/src/tests/ssd_cache/cache_manager_test.cpp
@@ -22,6 +22,7 @@ See the License for the specific language governing permissions and
 #include "ssd_cache/lfu_cache.h"
 #include "ssd_cache/cache_manager.h"
 #include "utils/common.h"
+#include "emb_table/embedding_ddr.h"

 using namespace std;
 using namespace MxRec;
@@ -192,12 +193,16 @@ TEST_F(CacheManagerTest, IsKeyInSSD)

 TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalKey)
 {
-    EmbHashMapInfo embHashMapInfo;
+    EmbeddingDDR table;
+
     vector<emb_key_t> currentKeys = {55, 65, 75};
-    embHashMapInfo.hostHashMap[55] = 119;
-    embHashMapInfo.hostHashMap[65] = 118;
-    embHashMapInfo.hostHashMap[75] = 116;
-    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    table.keyOffsetMap[55] = 119;
+    table.keyOffsetMap[65] = 118;
+    table.keyOffsetMap[75] = 116;
+
+    TableInfo ti = table.GetTableInfo();
+
+    auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID);
     ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
     LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalKey end.");
 }
@@ -207,21 +212,24 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess)
     vector<emb_key_t> ssdKeys = {15, 25};
     vector<vector<float>> ssdKeyEmbInfo = {{1.5f}, {2.5f}};

-    // init EmbHashMapInfo
-    EmbHashMapInfo embHashMapInfo;
-    embHashMapInfo.devVocabSize = 20;
-    embHashMapInfo.hostVocabSize = 100;
-    embHashMapInfo.maxOffset = 118;  // 2 free slots left in DDR (relative positions 98, 99)
-    embHashMapInfo.evictPos.emplace_back(110);  // eviction list
+    // init EmbeddingDDR
+    EmbeddingDDR table;
+    table.name = embTableName;
+    table.devVocabSize = 20;
+    table.hostVocabSize = 100;
+    table.maxOffset = 118;
+    table.evictHostPos.emplace_back(110);  // eviction list
+
+    TableInfo ti = table.GetTableInfo();

     // Build the key->offset mappings already stored in DDR; DDR offsets in the map range over 20~119
-    embHashMapInfo.hostHashMap[9] = 117;  // relative position in DDR: 97
-    embHashMapInfo.hostHashMap[8] = 116;  // relative position in DDR: 96
-    embHashMapInfo.hostHashMap[6] = 114;  // relative position in DDR: 94
-    embHashMapInfo.hostHashMap[4] = 112;  // relative position in DDR: 92
-    embHashMapInfo.hostHashMap[3] = 111;  // relative position in DDR: 91
-    embHashMapInfo.hostHashMap[2] = 21;   // relative position in DDR: 1
-    embHashMapInfo.hostHashMap[1] = 20;   // relative position in DDR: 0
+    table.keyOffsetMap[9] = 117;  // relative position in DDR: 97
+    table.keyOffsetMap[8] = 116;  // relative position in DDR: 96
+    table.keyOffsetMap[6] = 114;  // relative position in DDR: 94
+    table.keyOffsetMap[4] = 112;  // relative position in DDR: 92
+    table.keyOffsetMap[3] = 111;  // relative position in DDR: 91
+    table.keyOffsetMap[2] = 21;   // relative position in DDR: 1
+    table.keyOffsetMap[1] = 20;   // relative position in DDR: 0

     // Verify the constructed data
     auto& embMap = cacheManager.hostEmbs->hostEmbs;
@@ -238,33 +246,31 @@ TEST_F(CacheManagerTest, TransferDDREmbWithSSDByAllProcess)
     ASSERT_FALSE(cacheManager.ssdEngine->IsKeyExist(embTableName, 8));
     ASSERT_TRUE(cacheManager.IsKeyInSSD(embTableName, 15));
-    LOG_INFO("check detail data before transfer ok.");
-
     // externalKeys: SSD(15, 25) + newKey(55, 65, 75)
     // Training scenario; by construction offsetAvailableSize = 20 + 100 - 118 + evictPos.size() = 3
     // Frequency data in cacheManager (low -> high): 9 8 6 4 3 2 1
     // Build a batch that exceeds the available SSD budget
     vector<emb_key_t> exceedKeys = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115};
-    auto spaceError1 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, exceedKeys, TRAIN_CHANNEL_ID);
+    auto spaceError1 = cacheManager.TransferDDREmbWithSSD(ti, exceedKeys, TRAIN_CHANNEL_ID);
     ASSERT_EQ(spaceError1, TransferRet::SSD_SPACE_NOT_ENOUGH);

     // Training + over the SSD budget + the current batch contains none of the keys stored on SSD
     vector<emb_key_t> keys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115, 125, 135};
-    auto spaceError2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, exceedKeys, TRAIN_CHANNEL_ID);
+    auto spaceError2 = cacheManager.TransferDDREmbWithSSD(ti, exceedKeys, TRAIN_CHANNEL_ID);
     ASSERT_EQ(spaceError2, TransferRet::SSD_SPACE_NOT_ENOUGH);

     // Storage of the current batch keys: SSD(15, 25) DDR(6, 4) newKey(55, 65, 75)
     vector<emb_key_t> currentKeys = {15, 25, 6, 4, 55, 65, 75};
     // 4 keys must move from DDR to SSD; low-frequency keys 6 and 4 are in the current batch and will
     // not be moved, so the keys transferred are 9, 8, 3, 2
-    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID);

     // Verify the data after processing
     ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
     ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON);   // data left in place within DDR
     ASSERT_TRUE(fabs(hostData[96][0] - 25.0f) < EPSILON);  // data moved from SSD into DDR
     ASSERT_TRUE(fabs(hostData[97][0] - 15.0f) < EPSILON);  // data moved from SSD into DDR
-    ASSERT_EQ(embHashMapInfo.evictPos.size(), 1);
-    ASSERT_EQ(embHashMapInfo.evictPos.back(), 110);
+    ASSERT_EQ(table.evictHostPos.size(), 1);
+    ASSERT_EQ(table.evictHostPos.back(), 110);

     // The lowest-frequency DDR keys (9, 8) with count 1 moved to SSD; the SSD keys (15, 25) moved
     // into DDR have counts (3, 5), so the DDR minimum frequency index should become 2
     ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 2);
@@ -277,26 +283,30 @@

 TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEmptyExternalSSDKey)
 {
     // Training + eval: DDR has enough free space and externalSSDKeys is empty
-    EmbHashMapInfo embHashMapInfo;
-    embHashMapInfo.devVocabSize = 20;
-    embHashMapInfo.hostVocabSize = 100;
-    embHashMapInfo.hostHashMap[6] = 114;  // relative position in DDR: 94
-    embHashMapInfo.hostHashMap[4] = 112;  // relative position in DDR: 92
+    EmbeddingDDR table;
+    table.name = embTableName;
+    table.devVocabSize = 20;
+    table.hostVocabSize = 100;
+    table.keyOffsetMap[6] = 114;  // relative position in DDR: 94
+    table.keyOffsetMap[4] = 112;  // relative position in DDR: 92
     // 3 free slots remain (2 left in DDR at relative positions 98, 99; 1 on the DDR eviction list)
-    embHashMapInfo.maxOffset = 118;
-    embHashMapInfo.evictPos.emplace_back(110);
+    table.maxOffset = 118;
+    table.evictHostPos.emplace_back(110);
+
+    TableInfo ti = table.GetTableInfo();
+
     vector<emb_key_t> currentKeys = {6, 4, 55, 65, 75};
-    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID);
     ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
-    auto retByEval = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, EVAL_CHANNEL_ID);
+    auto retByEval = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, EVAL_CHANNEL_ID);
     ASSERT_EQ(retByEval, TransferRet::TRANSFER_OK);

     // Eval scenario: DDR runs out of free space and externalSSDKeys is empty
     vector<emb_key_t> currentKeys2 = {6, 4, 55, 65, 75, 85, 95, 105, 115};
-    auto ret2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, EVAL_CHANNEL_ID);
+    auto ret2 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, EVAL_CHANNEL_ID);
     ASSERT_EQ(ret2, TransferRet::TRANSFER_OK);

     // Training scenario: returns "SSD space not enough"
-    auto ret3 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, TRAIN_CHANNEL_ID);
+    auto ret3 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, TRAIN_CHANNEL_ID);
     ASSERT_EQ(ret3, TransferRet::SSD_SPACE_NOT_ENOUGH);
     LOG_INFO("test TransferDDREmbWithSSDByEmptyExternalSSDKey end.");
 }
@@ -304,24 +314,28 @@

 TEST_F(CacheManagerTest, TransferDDREmbWithSSDByEval)
 {
     // Eval + DDR has enough free space + externalSSDKeys is empty
-    EmbHashMapInfo embHashMapInfo;
-    embHashMapInfo.devVocabSize = 20;
-    embHashMapInfo.hostVocabSize = 100;
-    embHashMapInfo.hostHashMap[9] = 117;  // relative position in DDR: 97
-    embHashMapInfo.hostHashMap[8] = 116;  // relative position in DDR: 96
-    embHashMapInfo.hostHashMap[6] = 114;  // relative position in DDR: 94
-    embHashMapInfo.hostHashMap[4] = 112;  // relative position in DDR: 92
+    EmbeddingDDR table;
+    table.name = embTableName;
+    table.devVocabSize = 20;
+    table.hostVocabSize = 100;
+    table.keyOffsetMap[9] = 117;  // relative position in DDR: 97
+    table.keyOffsetMap[8] = 116;  // relative position in DDR: 96
+    table.keyOffsetMap[6] = 114;  // relative position in DDR: 94
+    table.keyOffsetMap[4] = 112;  // relative position in DDR: 92
     // 3 free slots remain (2 left in DDR at relative positions 98, 99; 1 on the DDR eviction list)
-    embHashMapInfo.maxOffset = 118;
-    embHashMapInfo.evictPos.emplace_back(110);  // eviction list
+    table.maxOffset = 118;
+    table.evictHostPos.emplace_back(110);  // eviction list
+
+    TableInfo ti = table.GetTableInfo();
+
     vector<emb_key_t> currentKeys = {6, 4, 55, 65, 75};
-    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys, EVAL_CHANNEL_ID);
+    auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, EVAL_CHANNEL_ID);
     ASSERT_EQ(ret, TransferRet::TRANSFER_OK);
     LOG_INFO("test eval+space enough+externalSSDKeysEmpty ok.");

     // Eval + DDR has enough free space + externalSSDKeys is non-empty
     vector<emb_key_t> currentKeys2 = {15, 25, 6, 4, 55, 65, 75, 85, 95, 105, 115};
-    auto ret2 = cacheManager.TransferDDREmbWithSSD(embTableName, embHashMapInfo, currentKeys2, EVAL_CHANNEL_ID);
+    auto ret2 = cacheManager.TransferDDREmbWithSSD(ti, currentKeys2, EVAL_CHANNEL_ID);
     ASSERT_EQ(ret2, TransferRet::TRANSFER_OK);

     // Verify the data after processing
     const auto& it = cacheManager.hostEmbs->hostEmbs.find(embTableName);
@@ -329,7 +343,7 @@
     ASSERT_TRUE(fabs(hostData[94][0] - 6.0f) < EPSILON);   // data left in place within DDR
     ASSERT_TRUE(fabs(hostData[98][0] - 25.0f) < EPSILON);  // data moved from SSD into DDR
     ASSERT_TRUE(fabs(hostData[90][0] - 15.0f) < EPSILON);  // data moved from SSD into DDR
-    ASSERT_EQ(embHashMapInfo.evictPos.size(), 0);
+    ASSERT_EQ(table.evictHostPos.size(), 0);
     // The lowest-frequency DDR keys (9, 8) with count 1 moved to SSD; the SSD keys (15, 25) moved
     // into DDR have counts (3, 5)
     ASSERT_EQ(cacheManager.ddrKeyFreqMap[embTableName].minFreq, 1);
     ASSERT_FALSE(cacheManager.IsKeyInSSD(embTableName, 9));
@@ -341,15 +355,19 @@

 TEST_F(CacheManagerTest, TransferDDREmbWithSSDByDDRSpaceNotEnough)
 {
     // Make the total DDR space insufficient for the current batch
-    EmbHashMapInfo embHashMapInfo;
-    embHashMapInfo.devVocabSize = 20;
-    embHashMapInfo.hostVocabSize = 10;
-    embHashMapInfo.maxOffset = 30;
-    embHashMapInfo.hostHashMap[6] = 9;
-    embHashMapInfo.hostHashMap[4] = 8;
+    EmbeddingDDR table;
+    table.name = embTableName2;
+    table.devVocabSize = 20;
+    table.hostVocabSize = 10;
+    table.maxOffset = 30;
+    table.keyOffsetMap[6] = 9;
+    table.keyOffsetMap[4] = 8;
+
+    TableInfo ti = table.GetTableInfo();
+
     // keys size:10, ddr keys:2 externalKeys:8 externalSSDKeys:0
     vector<emb_key_t> currentKeys = {6, 4, 101, 102, 103, 104, 105, 106, 107, 108};
-    auto ret = cacheManager.TransferDDREmbWithSSD(embTableName2, embHashMapInfo, currentKeys, TRAIN_CHANNEL_ID);
+    auto ret = cacheManager.TransferDDREmbWithSSD(ti, currentKeys, TRAIN_CHANNEL_ID);
     ASSERT_EQ(ret, TransferRet::DDR_SPACE_NOT_ENOUGH);
     LOG_INFO("test train+ddr space not enough+externalSSDKeysEmpty ok.");
 }
-- Gitee
From c618c505cde4ce934c806e8d00d8e63708dc7b4a Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 23 Feb 2024 10:07:35 +0800
Subject: [PATCH 541/551] Match-id-a76db0396590ea54ea8989fc0858024c74f5a8c7
---
 mx_rec/core/asc/manager.py                    |  2 +-
 mx_rec/core/emb/sparse_embedding.py           | 14 +------------
 mx_rec/optimizers/adagrad.py                  | 19 +++++++++--------
 mx_rec/optimizers/ftrl.py                     | 21 +++++++++++--------
 mx_rec/optimizers/gradient_descent.py         |  5 ++++-
 mx_rec/optimizers/gradient_descent_by_addr.py |  1 +
 mx_rec/optimizers/lazy_adam.py                | 20 +++++++++++-------
 mx_rec/optimizers/lazy_adam_by_addr.py        |  1 +
 mx_rec/saver/saver.py                         | 17 ++++++++-------
 mx_rec/util/config_utils/optimizer_utils.py   | 16 ++++++++++++++
 tests/mx_rec/core/test_embedding.py           | 20 ++++--------------
 11 files changed, 71 insertions(+), 65 deletions(-)

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index 2f555a56..9c7552dc 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -43,7 +43,7 @@ def generate_table_info_list():
     for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items():
         # When dynamic expansion mode, ext_emb_size is set by optimizer
-        if optimizer is not None:
+        if ConfigInitializer.get_instance().use_dynamic_expansion:
             table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num)
             logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size)
         skip = should_skip(table_instance.table_name)
diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py
index a7898db0..d8ce63b1 100644
--- a/mx_rec/core/emb/sparse_embedding.py
+++ b/mx_rec/core/emb/sparse_embedding.py
@@ -82,30 +82,18 @@ class HBMSparseEmbedding(SparseEmbedding):
     """

     def __init__(self, config: dict):
-        self.emb_optimizer = EmbOptimizer(config.get("optimizer_list"))
-        self.emb_optimizer.check_optimizer_instance_list()
-
         super(HBMSparseEmbedding, self).__init__(config)

-    @property
-    def optimizer(self):
-        return self.emb_optimizer.optimizer
-
-    @property
-    def optimizer_instance_list(self):
-        return self.emb_optimizer.optimizer_instance_list
-
     def capacity(self) -> int:
         return self._device_vocabulary_size

     def set_optimizer(self, key: str, state_dict: dict):
-        self.emb_optimizer.set_optimizer(key, state_dict, self._table_name)
+        pass

     def _build_optimizer_states(self):
         pass

     def _set_ext_emb_size(self):
-        self._ext_coefficient += len(self.emb_optimizer.optimizer_slot_info_list)
         self._ext_emb_size = self._emb_size * self._ext_coefficient
         logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size)
diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py
index c7829145..8343a108 100644
--- a/mx_rec/optimizers/adagrad.py
+++ b/mx_rec/optimizers/adagrad.py
@@ -53,10 +53,12 @@ def create_hash_optimizer(learning_rate=0.001,
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedAdagrad(learning_rate=learning_rate,
-                             initial_accumulator_value=initial_accumulator_value,
-                             use_locking=use_locking,
-                             name=name)
+    optimizer = CustomizedAdagrad(learning_rate=learning_rate,
+                                  initial_accumulator_value=initial_accumulator_value,
+                                  use_locking=use_locking,
+                                  name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
@@ -68,6 +70,7 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
                  use_locking=False,
                  name="Adagrad"):
         self.optimizer_type = "Adagrad"
+        self.optim_param_list = ["accumulator"]
         super(CustomizedAdagrad, self)._get_name(name=name)
         super(CustomizedAdagrad, self).__init__(learning_rate=learning_rate,
                                                 initial_accumulator_value=initial_accumulator_value,
@@ -84,11 +87,9 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
         accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator")
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accumulator.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
-        if self._name in table_instance.optimizer:
-            raise EnvironmentError(f"Sparse optimizer named {self._name} already exists.")
-
-        table_instance.set_optimizer(self._name, {"accumulator": accumulator})
+        table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+                                                                                 {"accumulator": accumulator})
         return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}]

     def insert_slot(self, slot, named_slots_key, slot_name):
diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index a25471ac..d6561881 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -55,7 +55,9 @@ def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwarg
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs)
+    optimizer = CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
@@ -63,6 +65,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):

     def __init__(self, learning_rate, use_locking=False, name="Ftrl", **kwargs):
         self.optimizer_type = "ftrl"
+        self.optim_param_list = ["accum", "linear"]
         super(CustomizedFtrl, self)._get_name(name=name)
         super(CustomizedFtrl, self).__init__(
             learning_rate=learning_rate,
@@ -86,11 +89,9 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
-        if self._name in table_instance.optimizer:
-            raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.")
-
-        table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear})
+        table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+                                                                                 {"accum": accum, "linear": linear})
         return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self},
                 {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}]
@@ -248,6 +249,8 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
         # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
-
-        if self._name not in table_instance.optimizer:
-            table_instance.set_optimizer(self._name, {"accum": accum, "linear": linear})
+        table_instance = self.config_instance.sparse_embed_config.get_table_instance(each_var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name,
+                                                                                 self._name,
+                                                                                 {"accum": accum,
+                                                                                  "linear": linear})
diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py
index 598571c8..48ff71cd 100644
--- a/mx_rec/optimizers/gradient_descent.py
+++ b/mx_rec/optimizers/gradient_descent.py
@@ -41,7 +41,9 @@ def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescen
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name)
+    optimizer = CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo
@@ -49,6 +51,7 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo

     def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
         self.optimizer_type = "gradient_descent"
+        self.optim_param_list = []
         super(CustomizedGradientDescent, self)._get_name(name=name)
         super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking,
                                                         name=self.unique_name)
diff --git a/mx_rec/optimizers/gradient_descent_by_addr.py b/mx_rec/optimizers/gradient_descent_by_addr.py
index e2de8903..7a91da54 100644
--- a/mx_rec/optimizers/gradient_descent_by_addr.py
+++ b/mx_rec/optimizers/gradient_descent_by_addr.py
@@ -55,6 +55,7 @@ class CustomizedGradientDescentByAddr(gradient_descent.GradientDescentOptimizer,
     def __init__(self, learning_rate, weight_decay, use_locking=False, name="GradientDescentByAddr"):
         self.optimizer_type = "gradient_descent_by_addr"
         self.weight_decay = weight_decay
+        self.optim_param_list = []
         super(CustomizedGradientDescentByAddr, self)._get_name(name=name)
         super(CustomizedGradientDescentByAddr, self).__init__(learning_rate=learning_rate, use_locking=use_locking,
                                                               name=self.unique_name)
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index e4268fed..3155d8d8 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -58,7 +58,9 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name)
+    optimizer = CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
@@ -66,6 +68,7 @@
 class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):

     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False, name="LazyAdam"):
         self.optimizer_type = "LazyAdam"
+        self.optim_param_list = ["momentum", "velocity"]
         self.config_instance = ConfigInitializer.get_instance()
         super(CustomizedLazyAdam, self)._get_name(name=name)
         super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2,
@@ -83,11 +86,10 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
         self.config_instance.sparse_embed_config.insert_removing_var_list(momentum.name)
         self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
-        if self._name in table_instance.optimizer:
-            raise EnvironmentError(f"Sparse optimizer named {self._name} has exists.")
-
-        table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+        table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+                                                                                 {"momentum": momentum,
+                                                                                  "velocity": velocity})
         return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self},
                 {"slot": velocity, "named_slot_key": named_slot_key, "slot_name": "v", "optimizer": self}]
@@ -196,5 +198,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
         self.config_instance.sparse_embed_config.insert_removing_var_list(momentum.name)
         self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name)
-
-        if self._name not in table_instance.optimizer:
-            table_instance.set_optimizer(self._name, {"momentum": momentum, "velocity": velocity})
+        table_instance = self.config_instance.sparse_embed_config.get_table_instance(each_var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name,
+                                                                                 self._name, {"momentum": momentum,
+                                                                                              "velocity": velocity})
diff --git a/mx_rec/optimizers/lazy_adam_by_addr.py b/mx_rec/optimizers/lazy_adam_by_addr.py
index c338b592..cfe609e3 100644
--- a/mx_rec/optimizers/lazy_adam_by_addr.py
+++ b/mx_rec/optimizers/lazy_adam_by_addr.py
@@ -67,6 +67,7 @@ class CustomizedLazyAdamByAddress(adam.AdamOptimizer, CustomizedOptimizer):
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, use_locking=False,
                  name="LazyAdamByAddress"):
         self.optimizer_type = "LazyAdamByAddress"
+        self.optim_param_list = ["momentum", "velocity"]
         super(CustomizedLazyAdamByAddress, self)._get_name(name=name)
         super(CustomizedLazyAdamByAddress, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2,
                                                           epsilon=epsilon, use_locking=use_locking,
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index 1fa47827..de739acb 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -233,8 +233,9 @@ class Saver(object):
         with tf.compat.v1.variable_scope(table_name):
             sub_dict = self.save_op_dict[table_name]
             sub_dict[DataName.EMBEDDING.value] = var
-            if table_instance.optimizer:
-                sub_dict["optimizer"] = table_instance.optimizer
+            optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(table_name)
+            if optimizer:
+                sub_dict["optimizer"] = optimizer

     def _build_restore(self):
         for var in self.var_list:
@@ -249,14 +250,14 @@
                                                  name=DataName.EMBEDDING.value)
             assign_op = var.assign(variable)
             self.restore_fetch_list.append(assign_op)
+            optimizer = ConfigInitializer.get_instance().optimizer_config.get_optimizer_by_table_name(
+                table_instance.table_name)
+            if optimizer:
+                self._build_optimizer_restore(sub_placeholder_dict, table_instance, optimizer)

-            if table_instance.optimizer:
-                self._build_optimizer_restore(sub_placeholder_dict, table_instance)
-
-    def _build_optimizer_restore(self, sub_placeholder_dict, table_instance):
+    def _build_optimizer_restore(self, sub_placeholder_dict, table_instance, optimizer):
         sub_placeholder_dict["optimizer"] = optimizer_placeholder_dict = dict()
-        optimizer_states = table_instance.optimizer
-        for optimizer_name, optimizer_state_dict in optimizer_states.items():
+        for optimizer_name, optimizer_state_dict in optimizer.items():
             optimizer_placeholder_dict[optimizer_name] = sub_optimizer_placeholder_dict = \
                 dict([(state_key,
                        tf.compat.v1.placeholder(dtype=tf.float32,
                                                 shape=[table_instance.slice_device_vocabulary_size,
diff --git a/mx_rec/util/config_utils/optimizer_utils.py b/mx_rec/util/config_utils/optimizer_utils.py
index f1d6a4f9..db8175fc 100644
--- a/mx_rec/util/config_utils/optimizer_utils.py
+++ b/mx_rec/util/config_utils/optimizer_utils.py
@@ -6,6 +6,13 @@ class OptimizerConfig:

     def __init__(self):
         self._optimizer_instance = None
+        self._table_optimizer_dict = {}
+
+    @property
+    def optim_params_list(self):
+        if not self._optimizer_instance:
+            return []
+        return self._optimizer_instance.optim_param_list

     @property
     def optimizer_instance(self):
@@ -14,3 +21,12 @@ class OptimizerConfig:
     @optimizer_instance.setter
     def optimizer_instance(self, optimizer):
         self._optimizer_instance = optimizer
+
+    def set_optimize_for_table(self, table_name, optimizer_name, optimizer_dict):
+        if table_name in self._table_optimizer_dict:
+            raise EnvironmentError(f"sparse embedding table {table_name} has set optimizers.")
+        self._table_optimizer_dict[table_name] = {optimizer_name: optimizer_dict}
+
+    def get_optimizer_by_table_name(self, table_name):
+        return self._table_optimizer_dict.get(table_name)
+
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
index e47b3afd..bf7d9240 100644
--- a/tests/mx_rec/core/test_embedding.py
+++ b/tests/mx_rec/core/test_embedding.py
@@ -24,7 +24,7 @@ from mx_rec.core.asc import FeatureSpec
 from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute
 from mx_rec.core.emb.dynamic_sparse_embedding import HBMDynamicSparseEmbedding
 from mx_rec.core.emb.sparse_embedding import HBMSparseEmbedding, ExternalStorageSparseEmbedding
-from mx_rec.optimizers.gradient_descent import create_hash_optimizer, CustomizedGradientDescent
+from mx_rec.optimizers.gradient_descent import create_hash_optimizer
 from tests.mx_rec.core.mock_class import MockConfigInitializer
@@ -84,15 +84,11 @@ class TestCreateTableFunc(unittest.TestCase):
         base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
         emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

-        # prepare optimizer list
-        optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")]
-
         # test
         test_table = create_table(key_dtype=tf.int64,
                                   dim=8,
                                   name='test_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  optimizer_list=optimizer_list)
+                                  emb_initializer=tf.compat.v1.truncated_normal_initializer())

         self.assertIsInstance(test_table, HBMSparseEmbedding)
@mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding", @@ -163,9 +159,6 @@ class TestSparseLookupFunc(unittest.TestCase): emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - # prepare optimizer list - optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")] - case1_feat = FeatureSpec("case1_feat", table_name="test_table") set_temporary_feature_spec_attribute(case1_feat, 1) case1_feat.dims = [8, 8] @@ -186,8 +179,7 @@ class TestSparseLookupFunc(unittest.TestCase): dim=8, name='test_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), - device_vocabulary_size=100 * 8, - optimizer_list=optimizer_list) + device_vocabulary_size=100 * 8) self.assertIsInstance(test_table, HBMSparseEmbedding) res = sparse_lookup(test_table, case1_feat, batch=batch) @@ -224,9 +216,6 @@ class TestSparseLookupFunc(unittest.TestCase): sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer) - # prepare optimizer list - optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")] - case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64) mock_get_preprocessed_tensor_for_asc.return_value = { "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64), @@ -243,8 +232,7 @@ class TestSparseLookupFunc(unittest.TestCase): dim=8, name='test_table', emb_initializer=tf.compat.v1.truncated_normal_initializer(), - device_vocabulary_size=100 * 8, - optimizer_list=optimizer_list) + device_vocabulary_size=100 * 8) self.assertIsInstance(test_table, HBMSparseEmbedding) res = sparse_lookup(test_table, case2_feat, modify_graph=True) -- Gitee From fc36632f3f2b1a641433262516c0a739ed92c8df Mon Sep 17 00:00:00 2001 From: mxRecTeam Date: Fri, 23 Feb 2024 11:07:38 +0800 Subject: [PATCH 542/551] Match-id-738289e4cba314f507ec720f18ac73f08bc19f82 --- mx_rec/core/asc/manager.py | 31 +++++++++++-------- mx_rec/optimizers/ftrl.py | 9 +++++- mx_rec/optimizers/gradient_descent.py | 9 +++++- mx_rec/optimizers/lazy_adam.py | 9 +++++- src/core/emb_table/embedding_dynamic.cpp | 12 ++++---- src/core/emb_table/embedding_table.cpp | 3 +- tests/mx_rec/core/test_manager.py | 39 +++++++++++++++++------- 7 files changed, 77 insertions(+), 35 deletions(-) diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py index 2f555a56..c57dd07b 100644 --- a/mx_rec/core/asc/manager.py +++ b/mx_rec/core/asc/manager.py @@ -38,14 +38,17 @@ def generate_table_info_list(): f"of each table `{table_instance_dict.keys()}` is `{is_hbm_list}`.") optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance + if optimizer is None: + raise ValueError("Optimizer should be set in optimizer_config.") # generate table info dangling_table = check_dangling_table() for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items(): # When dynamic expansion mode, ext_emb_size is set by optimizer - if optimizer is not None: + if ConfigInitializer.get_instance().use_dynamic_expansion: table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num) logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size) + skip = should_skip(table_instance.table_name) if table_instance.table_name in 
             logger.info("skip table %s: %s which does not need to be provided to the EmbInfo.",
@@ -143,19 +146,21 @@ def matched_emb_initializer(tabel_info):
 def matched_opt_slot_initializers(table_instance):
     start_index = table_instance.emb_size
     slot_initializers = []
-    logger.debug("matched_opt_slot_initializers, scalar emb size:%s, optimizer_instance_list size:%s",
-                 table_instance.ext_emb_size, len(table_instance.optimizer_instance_list))
-    for optimizer in table_instance.optimizer_instance_list:
-        for slot_init_value in optimizer.get_slot_init_values():
-            slot_initializer = InitializeInfo(name="constant_initializer",
-                                              start=start_index,
-                                              len=table_instance.emb_size,
-                                              constant_initializer_info=ConstantInitializerInfo(
-                                                  constant_val=slot_init_value
-                                              ))
-            slot_initializers.append(slot_initializer)
-            start_index += table_instance.emb_size
+    optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
+    for slot_init_value in optimizer.get_slot_init_values():
+        slot_initializer = InitializeInfo(name="constant_initializer",
+                                          start=start_index,
+                                          len=table_instance.emb_size,
+                                          constant_initializer_info=ConstantInitializerInfo(
+                                              constant_val=slot_init_value
+                                          ))
+        slot_initializers.append(slot_initializer)
+        start_index += table_instance.emb_size
+
+    logger.debug("matched_opt_slot_initializers, ext emb size:%s, optimizer_instance_list size:%s, "
+                 "slot_initializers size:%s", table_instance.ext_emb_size,
+                 len(table_instance.optimizer_instance_list), len(slot_initializers))
     return slot_initializers
diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index a25471ac..0635779e 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -55,7 +55,9 @@ def create_hash_optimizer(learning_rate, use_locking=False, name="Ftrl", **kwarg
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs)
+    optimizer = CustomizedFtrl(learning_rate=learning_rate, use_locking=use_locking, name=name, **kwargs)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
@@ -76,6 +78,11 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
             linear_name=kwargs.get("linear_name", None),
             l2_shrinkage_regularization_strength=kwargs.get("l2_shrinkage_regularization_strength", 0.0)
         )
+        self._slot_num = 2
+
+    @property
+    def slot_num(self):
+        return self._slot_num

     def initialize_slots(self, var, table_instance):
         val = constant_op.constant(
diff --git a/mx_rec/optimizers/gradient_descent.py b/mx_rec/optimizers/gradient_descent.py
index 598571c8..d6d82a3e 100644
--- a/mx_rec/optimizers/gradient_descent.py
+++ b/mx_rec/optimizers/gradient_descent.py
@@ -41,7 +41,9 @@ def create_hash_optimizer(learning_rate, use_locking=False, name="GradientDescen
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name)
+    optimizer = CustomizedGradientDescent(learning_rate=learning_rate, use_locking=use_locking, name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo
@@ -52,6 +54,11 @@ class CustomizedGradientDescent(gradient_descent.GradientDescentOptimizer, Custo
         super(CustomizedGradientDescent, self)._get_name(name=name)
         super(CustomizedGradientDescent, self).__init__(learning_rate=learning_rate, use_locking=use_locking,
                                                         name=self.unique_name)
+        self._slot_num = 0
+
+    @property
+    def slot_num(self):
+        return self._slot_num

     def initialize_slots(self, var, table_instance):
         return []
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index e4268fed..b41d31f1 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -58,7 +58,9 @@ def create_hash_optimizer(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1
     if ConfigInitializer.get_instance().use_dynamic_expansion:
         raise ValueError("dynamic expansion mode is not compatible with the optimizer, please config dynamic "
                          "expansion mode and optimizer correctly")
-    return CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name)
+    optimizer = CustomizedLazyAdam(learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=epsilon, name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
@@ -70,6 +72,11 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
         super(CustomizedLazyAdam, self)._get_name(name=name)
         super(CustomizedLazyAdam, self).__init__(learning_rate=learning_rate, beta1=beta1, beta2=beta2,
                                                  epsilon=epsilon, use_locking=use_locking, name=self.unique_name)
+        self._slot_num = 2
+
+    @property
+    def slot_num(self):
+        return self._slot_num

     def initialize_slots(self, var, table_instance):
         # Create slots for the first and second moments.
diff --git a/src/core/emb_table/embedding_dynamic.cpp b/src/core/emb_table/embedding_dynamic.cpp
index 44fc2d71..e9c49ea9 100644
--- a/src/core/emb_table/embedding_dynamic.cpp
+++ b/src/core/emb_table/embedding_dynamic.cpp
@@ -87,14 +87,14 @@ int64_t EmbeddingDynamic::GetEmptyEmbeddingAddress()
 void EmbeddingDynamic::MallocEmbeddingBlock(int embNum)
 {
     void *block = nullptr;
-    aclError ec = aclrtMalloc(&block, embNum * embSize_ * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
+    aclError ec = aclrtMalloc(&block, embNum * extEmbSize_ * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
     if (ec != 0) {
         throw std::bad_alloc();
     }
     RandomInit(block, embNum);
     memoryList_.push_back(block);
     for (int i = 0; i < embNum; i++) {
-        float *embAddr = static_cast<float *>(block) + (i * embSize_);
+        float *embAddr = static_cast<float *>(block) + (i * extEmbSize_);
         embeddingList_.push_back(embAddr);
     }
     capacity_ += embNum;
@@ -103,16 +103,16 @@ void EmbeddingDynamic::MallocEmbeddingBlock(int embNum)
 void EmbeddingDynamic::RandomInit(void* addr, size_t embNum)
 {
     LOG_INFO("Device GenerateEmbData Start, seed:{}, initializer num: {}", seed_, embInfo_.initializeInfos.size());
-    vector<float> hostmem(embNum * embSize_);
+    vector<float> hostmem(embNum * extEmbSize_);
     for (const auto& initializeInfo: as_const(embInfo_.initializeInfos)) {
         for (size_t i = 0; i < embNum; ++i) {
-            initializeInfo.initializer->GenerateData(&hostmem[i * embSize_], embSize_);
+            initializeInfo.initializer->GenerateData(&hostmem[i * extEmbSize_], extEmbSize_);
         }
     }
     LOG_INFO("Device GenerateEmbData End, seed:{}", seed_);
-    aclError ret = aclrtMemcpy(addr, embNum * embSize_ * sizeof(float),
-        hostmem.data(), embNum * embSize_ * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
+    aclError ret = aclrtMemcpy(addr, embNum * extEmbSize_ * sizeof(float),
+        hostmem.data(), embNum * extEmbSize_ * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE);
     if (ret != ACL_SUCCESS) {
        LOG_ERROR("aclrtMemcpy failed, ret={}", ret);
    }
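With optimizer state held inline, every logical record in device memory now spans `extEmbSize_` floats rather than `embSize_`, which is why both the allocation and the host-side staging buffer above stride by `extEmbSize_`. A sketch of the assumed row layout, with `extEmbSize = embSize * (1 + slotNum)` matching `generate_table_info_list` in this patch (the struct and names are illustrative, not repo code):

    #include <cstddef>

    // Illustrative row layout: the embedding followed by one embSize-wide
    // block per optimizer slot (e.g. slotNum = 2 for LazyAdam: momentum, velocity).
    struct RowLayout {
        size_t embSize;   // trainable embedding width
        size_t slotNum;   // number of optimizer slots

        size_t ExtEmbSize() const { return embSize * (1 + slotNum); }
        size_t EmbOffset() const { return 0; }
        size_t SlotOffset(size_t k) const { return (1 + k) * embSize; }  // slot k, 0-based
    };

Under this layout, `RandomInit` fills whole rows, so the slot regions are covered by the constant initializers registered by `matched_opt_slot_initializers`.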
diff --git a/src/core/emb_table/embedding_table.cpp b/src/core/emb_table/embedding_table.cpp
index 1ec48b34..c7339367 100644
--- a/src/core/emb_table/embedding_table.cpp
+++ b/src/core/emb_table/embedding_table.cpp
@@ -23,8 +23,7 @@ EmbeddingTable::EmbeddingTable(const EmbInfo& info, const RankInfo& rankInfo, in
     embSize_(info.embeddingSize), extEmbSize_(info.extEmbeddingSize), embInfo_(info), seed_(inSeed),
     rankId_(rankInfo.rankId)
 {
-    LOG_TRACE("table {} isDynamic = {} embeddingSize {} extSize {}",
-        name, isDynamic_, embSize_, extEmbSize_);
+    LOG_INFO("table {} isDynamic = {} embeddingSize {} extSize {}", name, isDynamic_, embSize_, extEmbSize_);
 }

 EmbeddingTable::~EmbeddingTable()
diff --git a/tests/mx_rec/core/test_manager.py b/tests/mx_rec/core/test_manager.py
index b8191249..815ad843 100644
--- a/tests/mx_rec/core/test_manager.py
+++ b/tests/mx_rec/core/test_manager.py
@@ -30,7 +30,7 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
     """

     @mock.patch("mx_rec.core.asc.manager.ConfigInitializer")
-    def test_generate_table_info_list_case1(self, merge_table_config_initializer):
+    def test_generate_table_info_list_case1(self, manager_config_initializer):
         """
         case1: one table has DDR enabled and one does not; an exception is raised
         """

         with tf.Graph().as_default():
             mock_config_initializer = MockConfigInitializer()
-            merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
+            manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

             test_table1 = MockSparseEmbedding("test_table1")
             test_table1.is_hbm = False
@@ -57,7 +57,7 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
                          should_skip=mock.MagicMock(return_value=True),
                          check_dangling_table=mock.MagicMock(return_value=["test_table"]))
     @mock.patch("mx_rec.core.asc.manager.ConfigInitializer")
-    def test_generate_table_info_list_case2(self, merge_table_config_initializer):
+    def test_generate_table_info_list_case2(self, manager_config_initializer):
         """
         case2: test_table is a dangling table and skip is True
         """
@@ -66,12 +66,16 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
         with tf.Graph().as_default():
             mock_config_initializer = MockConfigInitializer()
-            merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
+            manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

             test_table = MockSparseEmbedding("test_table")
             test_table.host_vocabulary_size = 1
             mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table)
+            mock_opt = MockOptimizer()
+            manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt
+            test_table.optimizer_instance_list = [mock_opt]
+
             table_info_list = generate_table_info_list()
             self.assertListEqual(table_info_list, [])
@@ -79,7 +83,7 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
                          should_skip=mock.MagicMock(return_value=True),
                          check_dangling_table=mock.MagicMock(return_value=[]))
     @mock.patch("mx_rec.core.asc.manager.ConfigInitializer")
-    def test_generate_table_info_list_case2(self, merge_table_config_initializer):
+    def test_generate_table_info_list_case3(self, manager_config_initializer):
         """
         case3: test_table is not a dangling table and skip is True
         """
@@ -88,12 +92,16 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
         with tf.Graph().as_default():
             mock_config_initializer = MockConfigInitializer()
-            merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
+            manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

             test_table = MockSparseEmbedding("test_table")
             test_table.host_vocabulary_size = 1
             mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table)
+            mock_opt = MockOptimizer()
+            manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt
+            test_table.optimizer_instance_list = [mock_opt]
+
             table_info_list = generate_table_info_list()
             self.assertListEqual(table_info_list, [])
@@ -105,7 +113,7 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
                          should_skip=mock.MagicMock(return_value=False),
                          check_dangling_table=mock.MagicMock(return_value=[]))
     @mock.patch("mx_rec.core.asc.manager.ConfigInitializer")
-    def test_generate_table_info_list_case3(self, merge_table_config_initializer):
+    def test_generate_table_info_list_case4(self, manager_config_initializer):
         """
         case4: static shape; test_table is not a dangling table and skip is False
         """
@@ -114,7 +122,7 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
         with tf.Graph().as_default():
             mock_config_initializer = MockConfigInitializer(use_static=True)
-            merge_table_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
+            manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

             test_table = MockSparseEmbedding("test_table")
             test_table.host_vocabulary_size = 8
@@ -129,6 +137,10 @@ class TestGenerateTableInfoListFunc(unittest.TestCase):
             test_table.ssd_data_path = ""
             mock_config_initializer.get_instance().sparse_embed_config.table_instance_dict = dict(test_table=test_table)
+            mock_opt = MockOptimizer()
+            manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt
+            test_table.optimizer_instance_list = [mock_opt]
+
             table_info_list = generate_table_info_list()
             self.assertListEqual(table_info_list, ["test_table_info"])
@@ -310,7 +322,8 @@ class TestMatchedOptSlotInitializersFunc(unittest.TestCase):

     @mock.patch.multiple("mx_rec.core.asc.manager",
                          InitializeInfo=mock.MagicMock(return_value="slot_initializer"))
-    def test_matched_opt_slot_initializers(self):
+    @mock.patch("mx_rec.core.asc.manager.ConfigInitializer")
+    def test_matched_opt_slot_initializers(self, manager_config_initializer):
         """
         case: test matched_opt_slot_initializers
         """
@@ -318,10 +331,14 @@ class TestMatchedOptSlotInitializersFunc(unittest.TestCase):
         from mx_rec.core.asc.manager import matched_opt_slot_initializers

         with tf.Graph().as_default():
+            mock_config_initializer = MockConfigInitializer()
+            manager_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
             table_instance = MockSparseEmbedding("test_table")
             table_instance.emb_size = 8
-            table_instance.ext_emb_size = 8
-            table_instance.optimizer_instance_list = [MockOptimizer()]
+            table_instance.ext_emb_size = 24
+            mock_opt = MockOptimizer()
+            manager_config_initializer.get_instance().optimizer_config.optimizer_instance = mock_opt
+            table_instance.optimizer_instance_list = [mock_opt]

             slot_initializers = matched_opt_slot_initializers(table_instance)
             self.assertListEqual(slot_initializers, ["slot_initializer", "slot_initializer"])
-- Gitee
From e714392492323f5d7bfce1a83c7bb1df2b93c522 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 23 Feb 2024 11:13:27 +0800
Subject: [PATCH 543/551] Match-id-b606b068cdeb9aa0d6394be8c027b50fa4211e35
---
 mx_rec/core/emb/sparse_embedding.py       | 14 +------------
 mx_rec/core/feature_process.py            |  2 +-
 tests/mx_rec/core/mock_class.py           |  1 -
 tests/mx_rec/core/test_embedding.py       | 20 ++++--------------
 tests/mx_rec/core/test_feature_process.py | 25 +++++++++--------------
 5 files changed, 16 insertions(+), 46 deletions(-)

diff --git a/mx_rec/core/emb/sparse_embedding.py b/mx_rec/core/emb/sparse_embedding.py
index a7898db0..d8ce63b1 100644
--- a/mx_rec/core/emb/sparse_embedding.py
+++ b/mx_rec/core/emb/sparse_embedding.py
@@ -82,30 +82,18 @@ class HBMSparseEmbedding(SparseEmbedding):
     """

     def __init__(self, config: dict):
-        self.emb_optimizer = EmbOptimizer(config.get("optimizer_list"))
-        self.emb_optimizer.check_optimizer_instance_list()
-
         super(HBMSparseEmbedding, self).__init__(config)

-    @property
-    def optimizer(self):
-        return self.emb_optimizer.optimizer
-
-    @property
-    def optimizer_instance_list(self):
-        return self.emb_optimizer.optimizer_instance_list
-
     def capacity(self) -> int:
         return self._device_vocabulary_size

     def set_optimizer(self, key: str, state_dict: dict):
-        self.emb_optimizer.set_optimizer(key, state_dict, self._table_name)
+        pass

     def _build_optimizer_states(self):
         pass

     def _set_ext_emb_size(self):
-        self._ext_coefficient += len(self.emb_optimizer.optimizer_slot_info_list)
         self._ext_emb_size = self._emb_size * self._ext_coefficient
         logger.debug("Init table, ext_emb_size is set to be %s.", self._ext_emb_size)
diff --git a/mx_rec/core/feature_process.py b/mx_rec/core/feature_process.py
index 2b663ce2..3963f6d5 100644
--- a/mx_rec/core/feature_process.py
+++ b/mx_rec/core/feature_process.py
@@ -107,5 +107,5 @@ class EvictHook(tf.compat.v1.train.SessionRunHook):
         if feature_spec.eviction_threshold:
             logger.debug("_EvictHook - > check and get instance: table_names %s", feature_spec.table_name)
             self._hash_table_instance[feature_spec.table_name] = \
-                ConfigInitializer.get_instance().sparse_embed_cofnig.get_table_instance_by_name(
+                ConfigInitializer.get_instance().sparse_embed_config.get_table_instance_by_name(
                     feature_spec.table_name)
diff --git a/tests/mx_rec/core/mock_class.py b/tests/mx_rec/core/mock_class.py
index 01f2e4f3..2eab64c3 100644
--- a/tests/mx_rec/core/mock_class.py
+++ b/tests/mx_rec/core/mock_class.py
@@ -135,7 +135,6 @@ class MockConfigInitializer:
         self.train_params_config = MockTrainParamsConfig(**kwargs)
         self.optimizer_config = OptimizerConfig()
         self.feature_spec_config = FeatureSpecConfig()
-        self.sparse_embed_cofnig = SparseEmbedConfig()

     def get_instance(self):
         return self
diff --git a/tests/mx_rec/core/test_embedding.py b/tests/mx_rec/core/test_embedding.py
index e47b3afd..bf7d9240 100644
--- a/tests/mx_rec/core/test_embedding.py
+++ b/tests/mx_rec/core/test_embedding.py
@@ -24,7 +24,7 @@ from mx_rec.core.asc import FeatureSpec
 from mx_rec.core.asc.feature_spec import set_temporary_feature_spec_attribute
 from mx_rec.core.emb.dynamic_sparse_embedding import HBMDynamicSparseEmbedding
 from mx_rec.core.emb.sparse_embedding import HBMSparseEmbedding, ExternalStorageSparseEmbedding
-from mx_rec.optimizers.gradient_descent import create_hash_optimizer, CustomizedGradientDescent
+from mx_rec.optimizers.gradient_descent import create_hash_optimizer
 from tests.mx_rec.core.mock_class import MockConfigInitializer
@@ -84,15 +84,11 @@ class TestCreateTableFunc(unittest.TestCase):
         base_sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
         emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

-        # prepare optimizer list
-        optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")]
-
         # test
         test_table = create_table(key_dtype=tf.int64,
                                   dim=8,
                                   name='test_table',
-                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  optimizer_list=optimizer_list)
+                                  emb_initializer=tf.compat.v1.truncated_normal_initializer())

         self.assertIsInstance(test_table, HBMSparseEmbedding)

     @mock.patch.multiple("mx_rec.core.emb.base_sparse_embedding",
@@ -163,9 +159,6 @@ class TestSparseLookupFunc(unittest.TestCase):
         emb_validator_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
         sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

-        # prepare optimizer list
-        optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")]
-
         case1_feat = FeatureSpec("case1_feat", table_name="test_table")
         set_temporary_feature_spec_attribute(case1_feat, 1)
         case1_feat.dims = [8, 8]
@@ -186,8 +179,7 @@ class TestSparseLookupFunc(unittest.TestCase):
                                   dim=8,
                                   name='test_table',
                                   emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  device_vocabulary_size=100 * 8,
-                                  optimizer_list=optimizer_list)
+                                  device_vocabulary_size=100 * 8)

         self.assertIsInstance(test_table, HBMSparseEmbedding)
         res = sparse_lookup(test_table, case1_feat, batch=batch)
@@ -224,9 +216,6 @@ class TestSparseLookupFunc(unittest.TestCase):
         sparse_embedding_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
         feature_spec_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)

-        # prepare optimizer list
-        optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")]
-
         case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64)
         mock_get_preprocessed_tensor_for_asc.return_value = {
             "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64),
@@ -243,8 +232,7 @@ class TestSparseLookupFunc(unittest.TestCase):
                                   dim=8,
                                   name='test_table',
                                   emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                  device_vocabulary_size=100 * 8,
-                                  optimizer_list=optimizer_list)
+                                  device_vocabulary_size=100 * 8)

         self.assertIsInstance(test_table, HBMSparseEmbedding)
         res = sparse_lookup(test_table, case2_feat, modify_graph=True)
diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py
index 2bac51fc..b8bb0742 100644
--- a/tests/mx_rec/core/test_feature_process.py
+++ b/tests/mx_rec/core/test_feature_process.py
@@ -63,22 +63,20 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass):
         self.evict_var_assert = [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]]

     @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
-    @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name")
     @mock.patch("mx_rec.core.feature_process.ConfigInitializer")
-    def test_after_run_case1(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next):
+    def test_after_run_case1(self, feature_process_config_initializer, mock_get_next):
         """
         case1: evict_enable is True; eviction is triggered normally on both the Python and C++ sides
         """

         with tf.Graph().as_default():
-            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False)
+            test_table = MockSparseEmbedding()
+            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, var=test_table)
             feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
-        # prepare optimizer list
-        optimizer_list = [CustomizedGradientDescent(learning_rate=0.001, use_locking=False, name="GradientDescent")]
-
         case2_feat = tf.ones(shape=[8, 8], dtype=tf.int64)
         mock_get_preprocessed_tensor_for_asc.return_value = {
             "restore_vector": tf.ones(shape=[8, 8], dtype=tf.int64),
@@ -243,8 +232,7 @@ class TestSparseLookupFunc(unittest.TestCase):
                                  dim=8,
                                  name='test_table',
                                  emb_initializer=tf.compat.v1.truncated_normal_initializer(),
-                                 device_vocabulary_size=100 * 8,
-                                 optimizer_list=optimizer_list)
+                                 device_vocabulary_size=100 * 8)

         self.assertIsInstance(test_table, HBMSparseEmbedding)
         res = sparse_lookup(test_table, case2_feat, modify_graph=True)
diff --git a/tests/mx_rec/core/test_feature_process.py b/tests/mx_rec/core/test_feature_process.py
index 2bac51fc..b8bb0742 100644
--- a/tests/mx_rec/core/test_feature_process.py
+++ b/tests/mx_rec/core/test_feature_process.py
@@ -63,22 +63,20 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass):
         self.evict_var_assert = [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]]

     @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
-    @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name")
     @mock.patch("mx_rec.core.feature_process.ConfigInitializer")
-    def test_after_run_case1(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next):
+    def test_after_run_case1(self, feature_process_config_initializer, mock_get_next):
         """
        case1: evict_enable is True; eviction is triggered normally on both the Python and C++ sides
        """
         with tf.Graph().as_default():
-            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False)
+            test_table = MockSparseEmbedding()
+            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, var=test_table)
             feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
             mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(
                 FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True)
-            test_table = MockSparseEmbedding()
             mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
-            mock_get_table_instance_by_name.return_value = test_table

             evict_hook = EvictHook(evict_enable=True, evict_time_interval=1)
             with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:
@@ -100,22 +98,21 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass):
         self.assertListEqual(evict_variable[:2].tolist(), self.ori_var_assert)

     @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
-    @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name")
     @mock.patch("mx_rec.core.feature_process.ConfigInitializer")
-    def test_after_run_case2(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next):
+    def test_after_run_case2(self, feature_process_config_initializer, mock_get_next):
         """
        case2: evict_enable is True; the C++ side raises an exception
        """
         with tf.Graph().as_default():
-            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, trigger_evict=False)
+            test_table = MockSparseEmbedding()
+            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, trigger_evict=False,
+                                                            var=test_table)
             feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
             mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(
                 FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True)
-            test_table = MockSparseEmbedding()
             mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
-            mock_get_table_instance_by_name.return_value = test_table

             evict_hook = EvictHook(evict_enable=True, evict_time_interval=1)
             with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:
@@ -136,22 +133,20 @@ class TestAfterRunFuncOfEvictHookClass(TestEvictHookClass):
         self.assertListEqual(evict_variable[8:].tolist(), self.ori_var_assert)

     @mock.patch("mx_rec.core.feature_process.npu_ops.gen_npu_ops.get_next")
-    @mock.patch("mx_rec.util.config_utils.embedding_utils.SparseEmbedConfig.get_table_instance_by_name")
     @mock.patch("mx_rec.core.feature_process.ConfigInitializer")
-    def test_after_run_case3(self, feature_process_config_initializer, mock_get_table_instance_by_name, mock_get_next):
+    def test_after_run_case3(self, feature_process_config_initializer, mock_get_next):
        """
        case3: evict_enable is False
        """
        with tf.Graph().as_default():
-            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False)
+            test_table = MockSparseEmbedding()
+            mock_config_initializer = MockConfigInitializer(use_dynamic_expansion=False, var=test_table)
             feature_process_config_initializer.get_instance = mock.Mock(return_value=mock_config_initializer)
             mock_config_initializer.get_instance().feature_spec_config.insert_feature_spec(
                 FeatureSpec("test_spec", table_name="test_table", access_threshold=5, eviction_threshold=10), True)
-            test_table = MockSparseEmbedding()
             mock_get_next.return_value = [tf.constant([8, 9], dtype=tf.int32), tf.constant(2, dtype=tf.int32)]
-            mock_get_table_instance_by_name.return_value = test_table

             evict_hook = EvictHook(evict_enable=False, evict_time_interval=1)
             with tf.compat.v1.train.MonitoredSession(hooks=[evict_hook]) as sess:

-- Gitee
From b3e161ccdf481e772b6155811804aabfa699a3b5 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 23 Feb 2024 11:13:27 +0800
Subject: [PATCH 544/551] Match-id-d8b297c1b14d79bc3cce7d8478b80b5ebedf8614

---
 mx_rec/optimizers/adagrad.py                | 4 ++--
 mx_rec/optimizers/ftrl.py                   | 6 +++---
 mx_rec/optimizers/lazy_adam.py              | 4 ++--
 mx_rec/saver/saver.py                       | 2 +-
 mx_rec/util/config_utils/optimizer_utils.py | 2 +-
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mx_rec/optimizers/adagrad.py b/mx_rec/optimizers/adagrad.py
index 8343a108..2c44f1ee 100644
--- a/mx_rec/optimizers/adagrad.py
+++ b/mx_rec/optimizers/adagrad.py
@@ -87,8 +87,8 @@ class CustomizedAdagrad(adagrad.AdagradOptimizer, CustomizedOptimizer):
         accumulator = creat_one_single_slot(var, self._name + "/" + "accumulator")
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accumulator.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
-        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
+        ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, self._name,
                                                                                  {"accumulator": accumulator})
         return [{"slot": accumulator, "named_slot_key": named_slot_key, "slot_name": "acc", "optimizer": self}]

diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index f86da89d..40c0924a 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -95,7 +95,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
         named_slot_key = (var.op.graph, var.op.name)
         table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
-        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+        ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, self._name,
                                                                                  {"accum": accum, "linear": linear})
         return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self},
                 {"slot": linear, "named_slot_key": named_slot_key, "slot_name": "linear", "optimizer": self}]
@@ -254,8 +254,8 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
             # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
             ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
             ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
-            table_instance = self.config_instance.sparse_embed_config.get_table_instance(each_var)
-            ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name,
+            table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(each_var)
+            ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name,
                                                                                      self._name,
                                                                                      {"accum": accum, "linear": linear})

diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index ddd831d1..4141ba03 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -92,7 +92,7 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
         self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name)
         named_slot_key = (var.op.graph, var.op.name)
         table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
-        ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name, self._name,
+        ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, self._name,
                                                                                  {"momentum": momentum,
                                                                                   "velocity": velocity})
         return [{"slot": momentum, "named_slot_key": named_slot_key, "slot_name": "m", "optimizer": self},
@@ -204,6 +204,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
             self.config_instance.sparse_embed_config.insert_removing_var_list(velocity.name)

             table_instance = self.config_instance.sparse_embed_config.get_table_instance(each_var)
-            ConfigInitializer.get_instance().optimizer_config.set_optimize_for_table(table_instance.table_name,
+            ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name,
                                                                                      self._name,
                                                                                      {"momentum": momentum,
                                                                                       "velocity": velocity})
diff --git a/mx_rec/saver/saver.py b/mx_rec/saver/saver.py
index de739acb..d6a10d85 100644
--- a/mx_rec/saver/saver.py
+++ b/mx_rec/saver/saver.py
@@ -255,7 +255,7 @@ class Saver(object):
             if optimizer:
                 self._build_optimizer_restore(sub_placeholder_dict, table_instance, optimizer)

-    def _build_optimizer_restore(self, sub_placeholder_dict, optimizer):
+    def _build_optimizer_restore(self, sub_placeholder_dict, table_instance, optimizer):
         sub_placeholder_dict["optimizer"] = optimizer_placeholder_dict = dict()
         for optimizer_name, optimizer_state_dict in optimizer.items():
             optimizer_placeholder_dict[optimizer_name] = sub_optimizer_placeholder_dict = \
diff --git a/mx_rec/util/config_utils/optimizer_utils.py b/mx_rec/util/config_utils/optimizer_utils.py
index db8175fc..27ed2a45 100644
--- a/mx_rec/util/config_utils/optimizer_utils.py
+++ b/mx_rec/util/config_utils/optimizer_utils.py
@@ -22,7 +22,7 @@ class OptimizerConfig:
     def optimizer_instance(self, optimizer):
         self._optimizer_instance = optimizer

-    def set_optimize_for_table(self, table_name, optimizer_name, optimizer_dict):
+    def set_optimizer_for_table(self, table_name, optimizer_name, optimizer_dict):
         if table_name in self._table_optimizer_dict:
             raise EnvironmentError(f"sparse embedding table {table_name} has set optimizers.")
         self._table_optimizer_dict[table_name] = {optimizer_name: optimizer_dict}

-- Gitee

From 11d835ce0edb9693a72f4100672ad1a59517bb37 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Fri, 23 Feb 2024 15:50:23 +0800
Subject: [PATCH 545/551] Match-id-0c58c2ec4aeb15953ff787613106cb15711c135c

---
 mx_rec/optimizers/ftrl.py                   | 6 +-----
 mx_rec/optimizers/lazy_adam.py              | 3 ---
 mx_rec/util/config_utils/optimizer_utils.py | 2 --
 src/core/hybrid_mgmt/hybrid_mgmt.cpp        | 3 ++-
 4 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/mx_rec/optimizers/ftrl.py b/mx_rec/optimizers/ftrl.py
index 40c0924a..dd27519a 100644
--- a/mx_rec/optimizers/ftrl.py
+++ b/mx_rec/optimizers/ftrl.py
@@ -33,7 +33,6 @@ from tensorflow.python.training import slot_creator

 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.util.initialize import ConfigInitializer
-from mx_rec.util.variable import check_and_get_config_via_var
 from mx_rec.constants.constants import MAX_INT32
 from mx_rec.validator.validator import para_checker_decorator, ClassValidator, NumValidator, StringValidator, \
     FloatValidator
@@ -94,7 +93,7 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(accum.name)
         ConfigInitializer.get_instance().sparse_embed_config.insert_removing_var_list(linear.name)
         named_slot_key = (var.op.graph, var.op.name)
-        table_instance = self.config_instance.sparse_embed_config.get_table_instance(var)
+        table_instance = ConfigInitializer.get_instance().sparse_embed_config.get_table_instance(var)
         ConfigInitializer.get_instance().optimizer_config.set_optimizer_for_table(table_instance.table_name, self._name,
                                                                                   {"accum": accum, "linear": linear})
         return [{"slot": accum, "named_slot_key": named_slot_key, "slot_name": "accum", "optimizer": self},
@@ -246,9 +245,6 @@ class CustomizedFtrl(ftrl.FtrlOptimizer, CustomizedOptimizer):
             with ops.colocate_with(each_var):
                 val = constant_op.constant(
                     self._initial_accumulator_value, dtype=each_var.dtype, shape=each_var.get_shape())
-
-            table_instance = check_and_get_config_via_var(each_var, self.optimizer_type)
-
             accum = self._get_or_make_slot(each_var, val, "accum", accum_state_name)
             linear = self._zeros_slot(each_var, "linear", linear_state_name)
             # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
diff --git a/mx_rec/optimizers/lazy_adam.py b/mx_rec/optimizers/lazy_adam.py
index 4141ba03..ee287a49 100644
--- a/mx_rec/optimizers/lazy_adam.py
+++ b/mx_rec/optimizers/lazy_adam.py
@@ -32,7 +32,6 @@ from tensorflow.python.training import slot_creator

 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.util.initialize import ConfigInitializer
-from mx_rec.util.variable import check_and_get_config_via_var
 from mx_rec.constants.constants import MAX_INT32
 from mx_rec.validator.validator import para_checker_decorator, StringValidator, FloatValidator

@@ -195,8 +194,6 @@ class CustomizedLazyAdam(adam.AdamOptimizer, CustomizedOptimizer):
         m_state_name = self._name + "/" + "momentum"
         v_state_name = self._name + "/" + "velocity"
         for each_var in var_list:
-            table_instance = check_and_get_config_via_var(each_var, self.optimizer_type)
-
             momentum = self._zeros_slot(each_var, "m", m_state_name)
             velocity = self._zeros_slot(each_var, "v", v_state_name)
             # make sure sparse optimizer statements will not be saved and restored within tf checkpoint.
diff --git a/mx_rec/util/config_utils/optimizer_utils.py b/mx_rec/util/config_utils/optimizer_utils.py
index 27ed2a45..5dfc5c52 100644
--- a/mx_rec/util/config_utils/optimizer_utils.py
+++ b/mx_rec/util/config_utils/optimizer_utils.py
@@ -23,8 +23,6 @@ class OptimizerConfig:
         self._optimizer_instance = optimizer

     def set_optimizer_for_table(self, table_name, optimizer_name, optimizer_dict):
-        if table_name in self._table_optimizer_dict:
-            raise EnvironmentError(f"sparse embedding table {table_name} has set optimizers.")
         self._table_optimizer_dict[table_name] = {optimizer_name: optimizer_dict}

     def get_optimizer_by_table_name(self, table_name):
diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index fd5285f3..7867d8da 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -234,7 +234,8 @@ bool HybridMgmt::Save(const string savePath)
     if (mgmtRankInfo.isDDR) {
         // DDR mode: save the host-side emb tables and the hashmap
         LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap");
-        EmbeddingMgmt::Instance()->Save(savePath);
+        saveData.hostEmbs = hostEmbs->GetHostEmbs();
+        saveData.embHashMaps = hostHashMaps->GetHashMaps();
     } else {
         // HBM mode: save the max offset (how much vocab capacity is actually in use) and the feature-to-offset mapping
         LOG_DEBUG(MGMT + "Start host side save: no ddr mode hashmap");

-- Gitee
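Patches 544 and 545 taken together rename `OptimizerConfig.set_optimize_for_table` to `set_optimizer_for_table` and then drop the duplicate-registration guard, so registering slots for the same table twice now overwrites the earlier entry instead of raising `EnvironmentError`. The sketch below condenses the resulting class from `mx_rec/util/config_utils/optimizer_utils.py`; anything not visible in the hunks above (such as the constructor and the getter body) is an assumption for illustration.

```python
class OptimizerConfig:
    """Condensed sketch of the post-patch-545 bookkeeping (internals assumed)."""

    def __init__(self):
        self._optimizer_instance = None
        self._table_optimizer_dict = {}

    @property
    def optimizer_instance(self):
        return self._optimizer_instance

    @optimizer_instance.setter
    def optimizer_instance(self, optimizer):
        self._optimizer_instance = optimizer

    def set_optimizer_for_table(self, table_name, optimizer_name, optimizer_dict):
        # Patch 545 removed the "table already has optimizers" check, so a
        # second registration for the same table simply overwrites the first.
        self._table_optimizer_dict[table_name] = {optimizer_name: optimizer_dict}

    def get_optimizer_by_table_name(self, table_name):
        return self._table_optimizer_dict.get(table_name)


config = OptimizerConfig()
config.set_optimizer_for_table("user_table", "lazy_adam", {"momentum": "m1", "velocity": "v1"})
config.set_optimizer_for_table("user_table", "lazy_adam", {"momentum": "m2", "velocity": "v2"})
print(config.get_optimizer_by_table_name("user_table"))  # the second registration wins
```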
From 3c0e9f5a9619e2f1900eb507356d2aa0ff006e4b Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Mon, 26 Feb 2024 11:37:27 +0800
Subject: [PATCH 546/551] Match-id-05598089a83758ef49fcb46821e3e816a4d65fe4

---
 mx_rec/core/asc/manager.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mx_rec/core/asc/manager.py b/mx_rec/core/asc/manager.py
index c57dd07b..2829ab98 100644
--- a/mx_rec/core/asc/manager.py
+++ b/mx_rec/core/asc/manager.py
@@ -38,14 +38,12 @@ def generate_table_info_list():
                      f"of each table `{table_instance_dict.keys()}` is `{is_hbm_list}`.")

     optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
-    if optimizer is None:
-        raise ValueError("Optimizer should be set in optimizer_config.")

     # generate table info
     dangling_table = check_dangling_table()
     for _, table_instance in ConfigInitializer.get_instance().sparse_embed_config.table_instance_dict.items():
         # When dynamic expansion mode, ext_emb_size is set by optimizer
-        if ConfigInitializer.get_instance().use_dynamic_expansion:
+        if ConfigInitializer.get_instance().use_dynamic_expansion and optimizer:
             table_instance.ext_emb_size = table_instance.emb_size * (1 + optimizer.slot_num)
             logger.debug("ext_emb_size is reset to be %s for EmbInfo", table_instance.ext_emb_size)

@@ -148,6 +146,8 @@ def matched_opt_slot_initializers(table_instance):
     slot_initializers = []
     optimizer = ConfigInitializer.get_instance().optimizer_config.optimizer_instance
+    if not optimizer:
+        return slot_initializers
     for slot_init_value in optimizer.get_slot_init_values():
         slot_initializer = InitializeInfo(name="constant_initializer",
                                           start=start_index,

-- Gitee

From 9090699e7dc4eeea256f8d5ea96ac54d94633615 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 27 Feb 2024 09:23:32 +0800
Subject: [PATCH 547/551] Match-id-2ebc54997469ba9fefef6ab25491d2c9a6ef5187

---
 src/core/emb_table/embedding_mgmt.cpp    | 19 +++++++++++++++++--
 src/core/emb_table/embedding_mgmt.h      |  3 +++
 src/core/hybrid_mgmt/hybrid_mgmt.cpp     | 11 ++++++-----
 src/core/ssd_cache/cache_manager.cpp     | 10 ++++++----
 src/core/utils/common.h                  |  2 +-
 src/tests/checkpoint/checkpoint_test.cpp |  4 +++-
 6 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/src/core/emb_table/embedding_mgmt.cpp b/src/core/emb_table/embedding_mgmt.cpp
index c40c66f9..c062e147 100644
--- a/src/core/emb_table/embedding_mgmt.cpp
+++ b/src/core/emb_table/embedding_mgmt.cpp
@@ -134,9 +134,12 @@ std::shared_ptr EmbeddingMgmt::GetTable(const string& name)
     return std::dynamic_pointer_cast(it->second);
 }

-int EmbeddingMgmt::Load(const string& name, const string& filePath)
+int EmbeddingMgmt::Load(const string& filePath)
 {
-    return embeddings[name]->Load(filePath);
+    for (auto& tablePair: embeddings) {
+        tablePair.second->Load(filePath);
+    }
+    return 0;
 }

 int EmbeddingMgmt::Save(const string& name, const string& filePath)
@@ -163,4 +166,16 @@ void EmbeddingMgmt::EnableSSD()
     for (auto& table: embeddings) {
         table.second->EnableSSD();
     }
+}
+
+EmbHashMemT EmbeddingMgmt::GetEmbHashMaps()
+{
+    EmbHashMemT EmbHashMaps;
+    for (auto& tablePair: embeddings) {
+        EmbHashMaps[tablePair.first].hostHashMap = tablePair.second ->GetKeyOffsetMap();
+        EmbHashMaps[tablePair.first].devVocabSize = tablePair.second ->GetDevVocabSize();
+        EmbHashMaps[tablePair.first].hostVocabSize = tablePair.second ->GetHostVocabSize();
+        EmbHashMaps[tablePair.first].maxOffset = tablePair.second ->GetMaxOffset();
+    }
+    return EmbHashMaps;
 }
\ No newline at end of file
diff --git a/src/core/emb_table/embedding_mgmt.h b/src/core/emb_table/embedding_mgmt.h
index efd62774..4b1b8790 100644
--- a/src/core/emb_table/embedding_mgmt.h
+++ b/src/core/emb_table/embedding_mgmt.h
@@ -80,6 +80,8 @@ public:
      */
     int Load(const string& name, const string& filePath);

+    int Load(const string& filePath);
+
     /**
      * Save a single table
      */
@@ -94,6 +96,7 @@ public:

     void EnableSSD();

+    EmbHashMemT GetEmbHashMaps();
 private:
     EmbeddingMgmt();

diff --git a/src/core/hybrid_mgmt/hybrid_mgmt.cpp b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
index 7867d8da..c9933478 100644
--- a/src/core/hybrid_mgmt/hybrid_mgmt.cpp
+++ b/src/core/hybrid_mgmt/hybrid_mgmt.cpp
@@ -234,8 +234,7 @@ bool HybridMgmt::Save(const string savePath)
     if (mgmtRankInfo.isDDR) {
         // DDR mode: save the host-side emb tables and the hashmap
         LOG_DEBUG(MGMT + "Start host side save: ddr mode hashmap");
-        saveData.hostEmbs = hostEmbs->GetHostEmbs();
-        saveData.embHashMaps = hostHashMaps->GetHashMaps();
+        EmbeddingMgmt::Instance()->Save(savePath);
     } else {
         // HBM mode: save the max offset (how much vocab capacity is actually in use) and the feature-to-offset mapping
         LOG_DEBUG(MGMT + "Start host side save: no ddr mode hashmap");
@@ -293,6 +292,8 @@ bool HybridMgmt::Load(const string& loadPath)
     vector loadFeatures;
     SetFeatureTypeForLoad(loadFeatures);

+    EmbeddingMgmt::Instance()->Load(loadPath);
+
     loadData.hostEmbs = hostEmbs->GetHostEmbs();  // fetch the already-initialized host emb
     // perform the load
     loadCkpt.LoadModel(loadPath, loadData, mgmtRankInfo, mgmtEmbInfo, loadFeatures);
@@ -307,7 +308,8 @@ bool HybridMgmt::Load(const string& loadPath)
     if (mgmtRankInfo.isDDR) {
         // DDR mode: assign the loaded hash map
         LOG_DEBUG(MGMT + "Start host side load: ddr mode hashmap");
-        hostHashMaps->LoadHashMap(loadData.embHashMaps);
+        auto EmbHashMaps = EmbeddingMgmt::Instance()->GetEmbHashMaps();
+        hostHashMaps->LoadHashMap(EmbHashMaps);
     } else {
         // HBM mode: assign the loaded max offset (how much vocab capacity is actually in use) and feature-to-offset mapping
         LOG_DEBUG(MGMT + "Start host side load: no ddr mode hashmap");
@@ -348,8 +350,7 @@ void HybridMgmt::SetFeatureTypeForLoad(vector& loadFeatures)
     }
     if (mgmtRankInfo.isDDR) {
         // DDR mode loads the host emb tables and the hashmap
-        loadFeatures.push_back(CkptFeatureType::HOST_EMB);
-        loadFeatures.push_back(CkptFeatureType::EMB_HASHMAP);
+        LOG_DEBUG(MGMT + "set feature ddr");
     } else {
         // HBM mode loads the max offset (how much vocab capacity is actually in use) and the feature-to-offset mapping
         loadFeatures.push_back(CkptFeatureType::KEY_OFFSET_MAP);
diff --git a/src/core/ssd_cache/cache_manager.cpp b/src/core/ssd_cache/cache_manager.cpp
index c3611d91..69832f0a 100644
--- a/src/core/ssd_cache/cache_manager.cpp
+++ b/src/core/ssd_cache/cache_manager.cpp
@@ -87,8 +87,8 @@ TransferRet CacheManager::TransferDDREmbWithSSD(TableInfo& table,
     if (channelId == TRAIN_CHANNEL_ID) {
         ddrAvailableSize += table.evictHostPos.size();
     }
-    LOG_DEBUG("TransferDDREmbWithSSD, maxOffset:{}, evictHostPos size:{}, ddrAvailableSize:{}",
-              table.maxOffset, table.evictHostPos.size(), ddrAvailableSize);
+    LOG_DEBUG("TransferDDREmbWithSSD, table:{}, maxOffset:{}, evictHostPos size:{}, ddrAvailableSize:{}",
+              table.name, table.maxOffset, table.evictHostPos.size(), ddrAvailableSize);
     CreateSSDTableIfNotExist(table.name);

     // query ssdEngine for the keys of the current batch that are stored on SSD
@@ -193,7 +193,8 @@ void CacheManager::GetDDREmbInfo(vector& keys, TableInfo& table,
 {
     // fetch the corresponding emb data by offset
     for (auto& key : keys) {
-        ddrTransferPos.emplace_back(table.keyOffsetMap[key] - table.devVocabSize);
+        auto koCast = static_cast(table.keyOffsetMap[key]);
+        ddrTransferPos.emplace_back(koCast - table.devVocabSize);
     }

     LOG_TRACE("DDR keys:{}", VectorToString(keys));
@@ -337,7 +338,8 @@ void CacheManager::HandleDDRTransferPos(vector& ddrTransferPos, vector
 externalSSDKeys.size()) {
         while (ddrTransferPos.size() > externalSSDKeys.size()) {
-            table.evictHostPos.emplace_back(ddrTransferPos.back() + table.devVocabSize);
+            auto evictHostPos = ddrTransferPos.back() + table.devVocabSize;
+            table.evictHostPos.emplace_back(static_cast(evictHostPos));
             ddrTransferPos.pop_back();
         }
         return;
diff --git a/src/core/utils/common.h b/src/core/utils/common.h
index 1f5fa3e3..2c4ac8e2 100644
--- a/src/core/utils/common.h
+++ b/src/core/utils/common.h
@@ -453,7 +453,7 @@ namespace MxRec {
     };

     struct EmbHashMapInfo {
-        absl::flat_hash_map hostHashMap;  // the key's offset in HBM
+        absl::flat_hash_map hostHashMap;  // the key's offset in HBM
         std::vector devOffset2Batch;  // has -1
         std::vector devOffset2Key;
         size_t currentUpdatePos;
diff --git a/src/tests/checkpoint/checkpoint_test.cpp b/src/tests/checkpoint/checkpoint_test.cpp
index fd9fbf9d..71dce316 100644
--- a/src/tests/checkpoint/checkpoint_test.cpp
+++ b/src/tests/checkpoint/checkpoint_test.cpp
@@ -146,7 +146,9 @@ protected:
         for (const auto& testEmbInfo : testEmbInfos) {
             SetHashMapInfo(testHash, testDev2B, testDev2K);
-            embHashInfo.hostHashMap = std::move(testHash);
+            for (auto ko: testHash) {
+                embHashInfo.hostHashMap[ko.first] = static_cast(ko.second);
+            }
            embHashInfo.devOffset2Batch = move(testDev2B);
             embHashInfo.devOffset2Key = move(testDev2K);

-- Gitee

From b4a490e91bd6044a1ea9d1f5d263c838e4a465b6 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 27 Feb 2024 14:26:25 +0800
Subject: [PATCH 548/551] Match-id-c0ad9b340b20f2fcbd05d3a712f72d1a4684b754

---
 examples/demo/little_demo/config.py           |  4 +--
 examples/demo/little_demo/dataset.py          |  5 +--
 examples/demo/little_demo/main.py             | 35 +++++++++++--------
 examples/demo/little_demo/optimizer.py        |  5 +--
 .../demo/little_demo/random_data_generator.py |  5 +--
 examples/demo/little_demo/run_mode.py         | 30 ++++++++--------
 examples/demo/little_demo_estimator/config.py |  2 +-
 .../demo/little_demo_estimator/dataset.py     |  5 +--
 examples/demo/little_demo_estimator/main.py   | 24 +++++++------
 .../demo/little_demo_estimator/nn_optim.py    | 17 ++++-----
 .../demo/little_demo_estimator/nn_reader.py   |  8 +++--
 .../random_data_generator.py                  |  2 +-
 examples/dlrm/model/gradient_descent_w.py     | 15 ++++----
 13 files changed, 85 insertions(+), 72 deletions(-)

diff --git a/examples/demo/little_demo/config.py b/examples/demo/little_demo/config.py
index 0e098a9c..2cc48216 100644
--- a/examples/demo/little_demo/config.py
+++ b/examples/demo/little_demo/config.py
@@ -15,11 +15,11 @@
 # ==============================================================================

 import math
-import tensorflow as tf

+import tensorflow as tf
 from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

-from mx_rec.util.initialize import get_rank_size
+from mx_rec.util.communication.hccl_ops import get_rank_size


 class Config:
diff --git a/examples/demo/little_demo/dataset.py b/examples/demo/little_demo/dataset.py
index d5ede53f..42c74704 100644
--- a/examples/demo/little_demo/dataset.py
+++ b/examples/demo/little_demo/dataset.py
@@ -16,8 +16,9 @@

 import tensorflow as tf

+from mx_rec.util.communication.hccl_ops import get_rank_size, get_rank_id
+from mx_rec.util.ops import import_host_pipeline_ops
 from random_data_generator import get_data_generator, get_large_scale_data_generator
-from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops


 def generate_dataset(cfg, use_timestamp=False, batch_number=100):
@@ -45,7 +46,7 @@ def generate_dataset(cfg, use_timestamp=False, batch_number=100):


 def add_timestamp_func(batch):
-    host_pipeline_ops = get_host_pipeline_ops()
+    host_pipeline_ops = import_host_pipeline_ops()
     timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64))
     batch["timestamp"] = timestamp
     return batch
diff --git a/examples/demo/little_demo/main.py b/examples/demo/little_demo/main.py
index 3363c0fb..05c9acaa 100644
--- a/examples/demo/little_demo/main.py
+++ b/examples/demo/little_demo/main.py
@@ -22,21 +22,24 @@ import warnings
 from glob import glob

 import tensorflow as tf
-from config import Config
-from dataset import generate_dataset
-from optimizer import create_dense_and_sparse_optimizer
-from model import MyModel
-from run_mode import RunMode, UseMode

+from mx_rec.constants.constants import ASCEND_TIMESTAMP
 from mx_rec.core.asc.feature_spec import FeatureSpec
 from mx_rec.core.asc.helper import get_asc_insert_func
 from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import create_table, sparse_lookup
 from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
-from mx_rec.constants.constants import ASCEND_TIMESTAMP
-from mx_rec.util.initialize import get_rank_id, init, terminate_config_initializer, set_if_load, get_rank_size
-from mx_rec.util.variable import get_dense_and_sparse_variable
+from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.initialize import init, terminate_config_initializer
 from mx_rec.util.log import logger
+from mx_rec.util.variable import get_dense_and_sparse_variable
+
+from config import Config
+from dataset import generate_dataset
+from model import MyModel
+from optimizer import create_dense_and_sparse_optimizer
+from run_mode import RunMode, UseMode

 tf.compat.v1.disable_eager_execution()

@@ -201,6 +204,12 @@ if __name__ == "__main__":
     except ValueError as err:
         raise ValueError(f"please correctly config MULTI_LOOKUP_TIMES only int is supported.") from err

+    IF_LOAD = False
+    rank_id = get_rank_id()
+
+    file_list = glob(f"./saved-model/sparse-model-{rank_id}-*")
+    if file_list:
+        IF_LOAD = True
     # nbatch function needs to be used together with the prefetch and host_vocabulary_size != 0
     init(use_mpi=use_mpi,
          train_steps=TRAIN_STEPS,
@@ -208,13 +217,9 @@ if __name__ == "__main__":
          save_steps=SAVING_INTERVAL,
          use_dynamic=use_dynamic,
          use_hot=use_hot,
-         use_dynamic_expansion=use_dynamic_expansion)
-    IF_LOAD = False
-    rank_id = get_rank_id()
-    file_list = glob(f"./saved-model/sparse-model-{rank_id}-*")
-    if file_list:
-        IF_LOAD = True
-    set_if_load(IF_LOAD)
+         use_dynamic_expansion=use_dynamic_expansion,
+         bind_cpu=True,
+         if_load=IF_LOAD)

     cfg = Config()
     # multi lookup config, batch size: 32 * 128 = 4096
diff --git a/examples/demo/little_demo/optimizer.py b/examples/demo/little_demo/optimizer.py
index f6eeabdb..68580a29 100644
--- a/examples/demo/little_demo/optimizer.py
+++ b/examples/demo/little_demo/optimizer.py
@@ -18,13 +18,14 @@ import tensorflow as tf

 from mx_rec.optimizers.lazy_adam import create_hash_optimizer
 from mx_rec.optimizers.lazy_adam_by_addr import create_hash_optimizer_by_address
-from mx_rec.util.initialize import get_use_dynamic_expansion
 from mx_rec.util.log import logger
+from mx_rec.util.initialize import ConfigInitializer


 def create_dense_and_sparse_optimizer(cfg):
     dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate)
-    use_dynamic_expansion = get_use_dynamic_expansion()
+
+    use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion
     if use_dynamic_expansion:
         sparse_optimizer = create_hash_optimizer_by_address(learning_rate=cfg.learning_rate)
         logger.info("optimizer lazy_adam_by_addr")
diff --git a/examples/demo/little_demo/random_data_generator.py b/examples/demo/little_demo/random_data_generator.py
index 628700d4..8d55e74c 100644
--- a/examples/demo/little_demo/random_data_generator.py
+++ b/examples/demo/little_demo/random_data_generator.py
@@ -16,9 +16,10 @@

 import numpy as np

-from mx_rec.util.initialize import get_rank_id
+from mx_rec.util.communication.hccl_ops import get_rank_id
 from mx_rec.util.log import logger

+
 def get_data_generator(config, batch_number):
     rank_id = get_rank_id()

@@ -39,7 +40,7 @@ def get_data_generator(config, batch_number):
             i += 1

     logger.debug(f"================ end of data generator for {config.task_name} task | rank id {rank_id} "
-                f"================")
+                 f"================")
     return data_generator


diff --git a/examples/demo/little_demo/run_mode.py b/examples/demo/little_demo/run_mode.py
index c43ef4a6..4d244795 100644
--- a/examples/demo/little_demo/run_mode.py
+++ b/examples/demo/little_demo/run_mode.py
@@ -20,14 +20,14 @@ import os
 import tensorflow as tf

 from config import sess_config
-from mx_rec.util.initialize import get_initializer, get_rank_id, get_rank_size, clear_channel
 from mx_rec.util.variable import get_dense_and_sparse_variable
 from mx_rec.util.tf_version_adapter import hccl_ops
 from mx_rec.constants.constants import BaseEnum
 from mx_rec.graph.modifier import modify_graph_and_start_emb_cache
-from mx_rec.constants.constants import ApplyGradientsStrategy
-from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.log import logger
+from mx_rec.util.ops import import_host_pipeline_ops
+from mx_rec.util.initialize import ConfigInitializer
+from mx_rec.util.communication.hccl_ops import get_rank_id, get_rank_size


 class UseMode(BaseEnum):
@@ -60,11 +60,13 @@ class RunMode:

     def _infer(self):
         if not self.use_one_shot:
-            initializer = self.eval_iterator.initializer if not self.is_modify_graph else get_initializer(False)
+            initializer = self.eval_iterator.initializer if not self.is_modify_graph else \
+                ConfigInitializer.get_instance().train_params_config.get_initializer(False)
             self.session.run(initializer)
         else:
             logger.debug(f"use one shot iterator and modify graph is `{self.is_modify_graph}`.")
-        clear_channel(is_train_channel=False)
+        channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(False)
+        import_host_pipeline_ops().clear_channel(channel_id)

         for i in range(1, self.infer_steps + 1):
             logger.info("############### infer at step %d ################", i)
@@ -91,16 +93,11 @@ class RunMode:
             self.train_ops.append(dense_optimizer.apply_gradients(avg_grads))

             if bool(int(os.getenv("USE_DYNAMIC_EXPANSION", 0))):
-                from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
-                    ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+                from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS

                 train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
-                if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) ==
-                        ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY):
-                    train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
-                else:
-                    train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)

                 # do sparse optimization by addr
                 local_grads = tf.gradients(loss, train_emb_list)  # local_embedding
@@ -120,7 +117,8 @@ class RunMode:
             modify_graph_and_start_emb_cache(dump_graph=True)

         if not self.use_one_shot:
-            initializer = self.train_iterator.initializer if not self.is_modify_graph else get_initializer(True)
+            initializer = self.train_iterator.initializer if not self.is_modify_graph else \
+                ConfigInitializer.get_instance().train_params_config.get_initializer(True)
             self.session.run(initializer)
         else:
             logger.debug(f"use one shot iterator and modify graph is `{self.is_modify_graph}`.")
@@ -190,7 +188,7 @@ class RunMode:

     def change_threshold(self):
         thres_tensor = tf.constant(60, dtype=tf.int32)
-        set_threshold_op = ConfigInitializer.get_instance().host_pipeline_ops. \
-            set_threshold(thres_tensor, emb_name=self.table_list[0].table_name,
-                          ids_name=self.table_list[0].table_name + "_lookup")
+        set_threshold_op = import_host_pipeline_ops().set_threshold(thres_tensor,
+                                                                    emb_name=self.table_list[0].table_name,
+                                                                    ids_name=self.table_list[0].table_name + "_lookup")
         self.session.run([set_threshold_op])
diff --git a/examples/demo/little_demo_estimator/config.py b/examples/demo/little_demo_estimator/config.py
index beb18942..9a6cc8c0 100644
--- a/examples/demo/little_demo_estimator/config.py
+++ b/examples/demo/little_demo_estimator/config.py
@@ -18,7 +18,7 @@ import math

 import tensorflow as tf

-from mx_rec.util.initialize import get_rank_size
+from mx_rec.util.communication.hccl_ops import get_rank_size


 class Config:
diff --git a/examples/demo/little_demo_estimator/dataset.py b/examples/demo/little_demo_estimator/dataset.py
index c052eca1..a8c6c848 100644
--- a/examples/demo/little_demo_estimator/dataset.py
+++ b/examples/demo/little_demo_estimator/dataset.py
@@ -18,7 +18,8 @@ import tensorflow as tf

 from random_data_generator import get_data_generator, get_large_scale_data_generator

-from mx_rec.util.initialize import get_rank_size, get_rank_id, get_host_pipeline_ops
+from mx_rec.util.communication.hccl_ops import get_rank_size, get_rank_id
+from mx_rec.util.ops import import_host_pipeline_ops


 def generate_dataset(cfg, use_timestamp=False, batch_number=100):
@@ -46,7 +47,7 @@ def generate_dataset(cfg, use_timestamp=False, batch_number=100):


 def add_timestamp_func(batch):
-    host_pipeline_ops = get_host_pipeline_ops()
+    host_pipeline_ops = import_host_pipeline_ops()
     timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label_0'], tf.int64))
     batch["timestamp"] = timestamp
     return batch
diff --git a/examples/demo/little_demo_estimator/main.py b/examples/demo/little_demo_estimator/main.py
index 1a8852fb..a11b7016 100644
--- a/examples/demo/little_demo_estimator/main.py
+++ b/examples/demo/little_demo_estimator/main.py
@@ -19,7 +19,9 @@ import argparse
 import os

 import tensorflow as tf
-from mx_rec.util.initialize import init, get_rank_id, terminate_config_initializer
+
+from mx_rec.util.initialize import init, terminate_config_initializer
+from mx_rec.util.communication.hccl_ops import get_rank_id
 from mx_rec.core.asc.helper import FeatureSpec
 from mx_rec.graph.modifier import GraphModifierHook
 from mx_rec.graph.acg_push_ops import ACGPushOpsToDatasetHook
@@ -60,7 +62,7 @@ def main(params, cfg):
         hooks_list = [GraphModifierHook(modify_graph=params.modify_graph)]
     else:
         hooks_list = [ACGPushOpsToDatasetHook(dump_graph=True), GraphModifierHook(modify_graph=params.modify_graph)]
-    
+
     if params.use_timestamp:
         config_for_user_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold)
         config_for_item_table = dict(access_threshold=cfg.access_threshold, eviction_threshold=cfg.eviction_threshold)
@@ -138,19 +140,19 @@ def create_feature_spec_list(use_timestamp=False):

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--run_mode', type=str, default='train_and_evaluate') # run mode, configured in run.sh
+    parser.add_argument('--run_mode', type=str, default='train_and_evaluate')  # run mode, configured in run.sh
     parser.add_argument('--model_ckpt_dir', type=str, default='')
     parser.add_argument('--learning_rate', type=float, default=0.0008)
-    parser.add_argument('--use_timestamp', type=bool, default=False) # whether to enable feature admission and eviction
-    parser.add_argument('--modify_graph', type=bool, default=False) # whether to enable automatic graph modification
-    parser.add_argument('--use_multi_lookup', type=bool, default=True) # whether to look up one table multiple times
-    parser.add_argument('--multi_lookup_times', type=int, default=2) # number of lookups per table
-    parser.add_argument('--max_steps', type=int, default=200) # maximum number of training steps
-    parser.add_argument('--train_steps', type=int, default=100) # run eval after every train_steps training steps
-    parser.add_argument('--eval_steps', type=int, default=10) # number of steps per eval
+    parser.add_argument('--use_timestamp', type=bool, default=False)  # whether to enable feature admission and eviction
+    parser.add_argument('--modify_graph', type=bool, default=False)  # whether to enable automatic graph modification
+    parser.add_argument('--use_multi_lookup', type=bool, default=True)  # whether to look up one table multiple times
+    parser.add_argument('--multi_lookup_times', type=int, default=2)  # number of lookups per table
+    parser.add_argument('--max_steps', type=int, default=200)  # maximum number of training steps
+    parser.add_argument('--train_steps', type=int, default=100)  # run eval after every train_steps training steps
+    parser.add_argument('--eval_steps', type=int, default=10)  # number of steps per eval
     # save the model every this many steps; in train_and_evaluate mode eval also runs.
     # Note: if set to None, NPURunConfig uses a default of 100
     parser.add_argument('--save_checkpoints_steps', type=int, default=200)
-    parser.add_argument('--use_one_shot', type=bool, default=False) # whether to use a one-shot iterator
+    parser.add_argument('--use_one_shot', type=bool, default=False)  # whether to use a one-shot iterator

     args, unknowns = parser.parse_known_args()
     # get init configuration
diff --git a/examples/demo/little_demo_estimator/nn_optim.py b/examples/demo/little_demo_estimator/nn_optim.py
index 1bf1ea3c..4438627d 100644
--- a/examples/demo/little_demo_estimator/nn_optim.py
+++ b/examples/demo/little_demo_estimator/nn_optim.py
@@ -20,7 +20,8 @@ import os
 import tensorflow as tf

 from mx_rec.util.tf_version_adapter import hccl_ops
-from mx_rec.util.initialize import get_rank_size, get_use_dynamic_expansion
+from mx_rec.util.communication.hccl_ops import get_rank_size
+from mx_rec.util.initialize import ConfigInitializer
 from mx_rec.util.variable import get_dense_and_sparse_variable
 from mx_rec.optimizers.gradient_descent import create_hash_optimizer
 from mx_rec.optimizers.gradient_descent_by_addr import create_hash_optimizer_by_addr
@@ -29,7 +30,7 @@ from mx_rec.util.log import logger

 def get_dense_and_sparse_optimizer(cfg):
     dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.learning_rate)
-    if get_use_dynamic_expansion():
+    if ConfigInitializer.get_instance().use_dynamic_expansion:
         sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=cfg.learning_rate)
         logger.info("optimizer create_hash_optimizer_by_addr")
     else:
@@ -45,7 +46,7 @@ def get_train_op_list(losses, learning_rate):
     name = None

     dense_optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
-    use_dynamic_expansion = get_use_dynamic_expansion()
+    use_dynamic_expansion = ConfigInitializer.get_instance().use_dynamic_expansion
     if use_dynamic_expansion:
         sparse_optimizer = create_hash_optimizer_by_addr(learning_rate=learning_rate)
     else:
@@ -72,15 +73,11 @@ def get_train_op_list(losses, learning_rate):

         # do sparse optimization
         if use_dynamic_expansion:
-            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
-                ApplyGradientsStrategy, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+
             train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
-            if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) ==
-                    ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY):
-                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
-            else:
-                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)

             local_grads = tf.gradients(loss, train_emb_list)  # local_embedding
             grads_and_vars = [(grad, address) for grad, address in zip(local_grads, train_address_list)]
diff --git a/examples/demo/little_demo_estimator/nn_reader.py b/examples/demo/little_demo_estimator/nn_reader.py
index a13468eb..1ce4b094 100644
--- a/examples/demo/little_demo_estimator/nn_reader.py
+++ b/examples/demo/little_demo_estimator/nn_reader.py
@@ -15,8 +15,10 @@
 # limitations under the License.
 # ==============================================================================

-from mx_rec.util.initialize import get_rank_size, clear_channel
+from mx_rec.util.communication.hccl_ops import get_rank_size
+from mx_rec.util.ops import import_host_pipeline_ops
 from mx_rec.core.asc.helper import get_asc_insert_func
+from mx_rec.util.initialize import ConfigInitializer

 from dataset import generate_dataset
 from utils import FeatureSpecIns, create_feature_spec_list
@@ -35,7 +37,9 @@ def input_fn(params, create_fs_params, cfg, is_eval=False, use_one_shot=False):
     if is_eval:
         FeatureSpecIns.get_instance().set_eval_feature_spec_list(feature_spec_list)
         dataset = dataset.map(get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=False))
-        clear_channel(is_train_channel=False)
+        channel_id = ConfigInitializer.get_instance().train_params_config.get_training_mode_channel_id(
+            False)
+        import_host_pipeline_ops().clear_channel(channel_id)
     else:
         FeatureSpecIns.get_instance().set_train_feature_spec_list(feature_spec_list)
         dataset = dataset.map(get_asc_insert_func(tgt_key_specs=feature_spec_list, is_training=True))
diff --git a/examples/demo/little_demo_estimator/random_data_generator.py b/examples/demo/little_demo_estimator/random_data_generator.py
index 23acf22b..60313f01 100644
--- a/examples/demo/little_demo_estimator/random_data_generator.py
+++ b/examples/demo/little_demo_estimator/random_data_generator.py
@@ -17,7 +17,7 @@

 import numpy as np

-from mx_rec.util.initialize import get_rank_id
+from mx_rec.util.communication.hccl_ops import get_rank_id
 from mx_rec.util.log import logger


diff --git a/examples/dlrm/model/gradient_descent_w.py b/examples/dlrm/model/gradient_descent_w.py
index b66ec67d..f3ae78d7 100644
--- a/examples/dlrm/model/gradient_descent_w.py
+++ b/examples/dlrm/model/gradient_descent_w.py
@@ -25,13 +25,16 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.training import gradient_descent
 from mx_rec.optimizers.base import CustomizedOptimizer
 from mx_rec.util.log import logger
+from mx_rec.util.initialize import ConfigInitializer


 def create_hash_optimizer(learning_rate, weight_decay=0.0001, use_locking=False, name="GradientDescent"):
-    return CustomizedGradientDescentWithWeighDecay(learning_rate=learning_rate,
-                                                   weight_decay=weight_decay,
-                                                   use_locking=use_locking,
-                                                   name=name)
+    optimizer = CustomizedGradientDescentWithWeighDecay(learning_rate=learning_rate,
+                                                        weight_decay=weight_decay,
+                                                        use_locking=use_locking,
+                                                        name=name)
+    ConfigInitializer.get_instance().optimizer_config.optimizer_instance = optimizer
+    return optimizer


 class CustomizedGradientDescentWithWeighDecay(gradient_descent.GradientDescentOptimizer, CustomizedOptimizer):
@@ -64,8 +67,8 @@ class CustomizedGradientDescentWithWeighDecay(gradient_descent.GradientDescentOp
         if self.weight_decay is None:
             nd_value = grad.values * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
         else:
-            nd_value = (grad.values + math_ops.cast(self.weight_decay, var.dtype.base_dtype) * tf.gather(var, grad.indices)) * math_ops.cast(
-                self._learning_rate_tensor, var.dtype.base_dtype)
+            nd_value = (grad.values + math_ops.cast(self.weight_decay, var.dtype.base_dtype) *
+                        tf.gather(var, grad.indices)) * math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype)
         var_update_op = tf.scatter_nd_add(var, nd_indices, -nd_value, use_locking=self._use_locking)
         return var_update_op

-- Gitee
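The `gradient_descent_w.py` hunks in patch 548 register the optimizer with `ConfigInitializer` and re-wrap the weight-decay expression, but the arithmetic is unchanged: for every looked-up row, the applied step is `lr * (grad + weight_decay * row)`, scattered back with `tf.scatter_nd_add(var, indices, -step)`. Below is a small numpy sketch of that update rule — illustrative only, not the repo's TensorFlow implementation.

```python
import numpy as np


def sparse_sgd_with_weight_decay(table, rows, row_grads, learning_rate=0.001, weight_decay=0.0001):
    """Apply step = lr * (grad + weight_decay * row) to the gathered rows only."""
    table = table.copy()
    for idx, grad in zip(rows, row_grads):
        step = learning_rate * (grad + weight_decay * table[idx])
        table[idx] -= step  # mirrors tf.scatter_nd_add(var, nd_indices, -nd_value)
    return table


emb = np.ones((4, 3), dtype=np.float32)   # toy embedding table
touched = [0, 2]                          # rows gathered by this batch
grads = [np.full(3, 0.5, np.float32), np.full(3, -0.5, np.float32)]
print(sparse_sgd_with_weight_decay(emb, touched, grads))
```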
From 40988a193d5a62c02b0d1dd61635907740e402c4 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Tue, 27 Feb 2024 19:56:42 +0800
Subject: [PATCH 549/551] Match-id-1f9424dc6ab1893c49585ec3826b1a481f981621

---
 examples/dlrm/model/main_mxrec.py | 46 ++++++++++++++-----------------
 1 file changed, 21 insertions(+), 25 deletions(-)

diff --git a/examples/dlrm/model/main_mxrec.py b/examples/dlrm/model/main_mxrec.py
index 24a1a1b2..90b57184 100644
--- a/examples/dlrm/model/main_mxrec.py
+++ b/examples/dlrm/model/main_mxrec.py
@@ -14,11 +14,13 @@
 # limitations under the License.
 # ==============================================================================

+import os
 import time
 import warnings
 import random
 from glob import glob

+import tensorflow as tf
 from sklearn.metrics import roc_auc_score
 import numpy as np

@@ -27,9 +29,9 @@ from mx_rec.core.asc.manager import start_asc_pipeline
 from mx_rec.core.embedding import create_table, sparse_lookup
 from mx_rec.core.feature_process import EvictHook
 from mx_rec.graph.modifier import modify_graph_and_start_emb_cache, GraphModifierHook
-from mx_rec.constants.constants import ASCEND_TIMESTAMP, ApplyGradientsStrategy
-from mx_rec.util.initialize import get_rank_size, init, clear_channel, get_rank_id, set_if_load, \
-    terminate_config_initializer, get_host_pipeline_ops, get_initializer, get_target_batch
+from mx_rec.constants.constants import ASCEND_TIMESTAMP
+from mx_rec.util.initialize import ConfigInitializer, init, terminate_config_initializer
+from mx_rec.util.ops import import_host_pipeline_ops
 import mx_rec.util as mxrec_util
 from mx_rec.util.variable import get_dense_and_sparse_variable
 from mx_rec.util.log import logger
@@ -46,8 +48,7 @@ random.seed(shuffle_seed)


 def add_timestamp_func(batch):
-    host_pipeline_ops = get_host_pipeline_ops()
-    timestamp = host_pipeline_ops.return_timestamp(tf.cast(batch['label'], dtype=tf.int64))
+    timestamp = import_host_pipeline_ops().return_timestamp(tf.cast(batch['label'], dtype=tf.int64))
     # tf.constant(np.random.randint(1,1688109060,1)), tf.int64))
     batch["timestamp"] = timestamp
     return batch
@@ -139,11 +140,11 @@ def evaluate():
     print("read_test dataset")
     if not MODIFY_GRAPH_FLAG:
         eval_label = eval_model.get("label")
-        sess.run([eval_iterator.initializer, clear_channel(False)])
+        sess.run([eval_iterator.initializer])
     else:
         # In sess.run mode, running with the label from the original batch triggers a getnext
        # timeout error; the batch from the new dataset must be used instead
-        eval_label = get_target_batch(False).get("label")
-        sess.run([get_initializer(False), clear_channel(False)])
+        eval_label = ConfigInitializer.get_instance().train_params_config.get_target_batch(False).get("label")
+        sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)])
     log_loss_list = []
     pred_list = []
     label_list = []
@@ -174,9 +175,9 @@ def evaluate_fix(step):
     print("read_test dataset evaluate_fix")
     if not MODIFY_GRAPH_FLAG:
-        sess.run([eval_iterator.initializer, clear_channel(False)])
+        sess.run([eval_iterator.initializer])
     else:
-        sess.run([get_initializer(False), clear_channel(False)])
+        sess.run([ConfigInitializer.get_instance().train_params_config.get_initializer(False)])
     log_loss_list = []
     pred_list = []
     label_list = []
@@ -268,14 +269,14 @@ if __name__ == "__main__":
     use_dynamic = bool(int(os.getenv("USE_DYNAMIC", 0)))
     logger.info(f"USE_DYNAMIC:{use_dynamic}")

-    init(use_mpi, rank_id=rank_id, rank_size=rank_size, train_steps=train_steps, eval_steps=eval_steps,
+    init(train_steps=train_steps, eval_steps=eval_steps,
          use_dynamic=use_dynamic, use_dynamic_expansion=use_dynamic_expansion)
     IF_LOAD = False
-    rank_id = mxrec_util.initialize.get_rank_id()
+    rank_id = mxrec_util.communication.hccl_ops.get_rank_id()
     filelist = glob(f"./saved-model/sparse-model-{rank_id}-0")
     if filelist:
         IF_LOAD = True
-    set_if_load(IF_LOAD)
+    ConfigInitializer.get_instance().if_load = IF_LOAD

     cfg = Config()
     feature_spec_list_train = None
@@ -300,8 +301,8 @@ if __name__ == "__main__":
     sparse_optimizer_list = [sparse_optimizer for dense_optimizer, sparse_optimizer in optimizer_list]

     # note: variance_scaling_initializer only support HBM mode
-    emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed)\
-        if cfg.cache_mode != "HBM" or use_dynamic_expansion else\
+    emb_initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.05, seed=sparse_hashtable_seed) \
+        if cfg.cache_mode != "HBM" or use_dynamic_expansion else \
         tf.compat.v1.variance_scaling_initializer(mode="fan_avg", distribution='normal', seed=sparse_hashtable_seed)
     sparse_hashtable = create_table(
         key_dtype=cfg.key_type,
@@ -322,7 +323,7 @@ if __name__ == "__main__":

     dense_variables, sparse_variables = get_dense_and_sparse_variable()

-    rank_size = mxrec_util.initialize.get_rank_size()
+    rank_size = mxrec_util.communication.hccl_ops.get_rank_size()
     train_ops = []
     # multi task training
     for loss, (dense_optimizer, sparse_optimizer) in zip([train_model["loss"]], optimizer_list):
@@ -338,14 +339,9 @@ if __name__ == "__main__":
         train_ops.append(dense_optimizer.apply_gradients(avg_grads))

         if use_dynamic_expansion:
-            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_ID_OFFSET, \
-                ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS
+            from mx_rec.constants.constants import ASCEND_SPARSE_LOOKUP_LOCAL_EMB, ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS

-            if (ApplyGradientsStrategy.mapping(os.getenv("APPLY_GRADIENTS_STRATEGY")) ==
-                    ApplyGradientsStrategy.SUM_SAME_ID_GRADIENTS_AND_APPLY):
-                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
-            else:
-                train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_ID_OFFSET)
+            train_address_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_UNIQUE_KEYS)
             train_emb_list = tf.compat.v1.get_collection(ASCEND_SPARSE_LOOKUP_LOCAL_EMB)
             # do sparse optimization by addr
             sparse_grads = sparse_optimizer.compute_gradients(loss, train_emb_list)  # local_embedding
@@ -388,14 +384,14 @@ if __name__ == "__main__":
             if not MODIFY_GRAPH_FLAG:
                 sess.run(train_iterator.initializer)
             else:
-                sess.run(get_initializer(True))
+                sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True))
     else:
         sess = tf.compat.v1.Session(config=sess_config(dump_data=False))
         sess.run(tf.compat.v1.global_variables_initializer())
         if not MODIFY_GRAPH_FLAG:
             sess.run(train_iterator.initializer)
         else:
-            sess.run(get_initializer(True))
+            sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True))

     epoch = 0
     cost_sum = 0

-- Gitee
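Patch 549 above finishes migrating the DLRM example away from module-level helpers (`get_initializer`, `set_if_load`, `mxrec_util.initialize.get_rank_id`) to state reached through the `ConfigInitializer` singleton and `mx_rec.util.communication.hccl_ops`. The following is a toy sketch of the singleton shape these new call sites assume — the names follow the diff, but the internals are guesses for illustration only.

```python
class TrainParamsConfig:
    """Holds per-mode training artifacts; internals assumed for illustration."""

    def __init__(self):
        self._initializers = {}

    def set_initializer(self, is_training, initializer):
        self._initializers[is_training] = initializer

    def get_initializer(self, is_training):
        # old: sess.run(get_initializer(True))
        # new: sess.run(ConfigInitializer.get_instance().train_params_config.get_initializer(True))
        return self._initializers[is_training]


class ConfigInitializer:
    _instance = None

    def __init__(self):
        self.train_params_config = TrainParamsConfig()
        self.if_load = False  # replaces the old set_if_load(IF_LOAD) helper

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


cfg = ConfigInitializer.get_instance()
cfg.train_params_config.set_initializer(True, "train_iterator_initializer")
cfg.if_load = True
print(cfg.train_params_config.get_initializer(True), cfg.if_load)
```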
From 1a6c14a381f94196d64f2675749d4a33eac77526 Mon Sep 17 00:00:00 2001
From: mxRecTeam
Date: Thu, 29 Feb 2024 15:57:24 +0800
Subject: [PATCH 550/551] Match-id-29c13ce1039686831d878ecc848cc80291579d21

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 6a787dab..1a4ec1d7 100644
--- a/README.md
+++ b/README.md
@@ -73,8 +73,8 @@ securec is Huawei's open-source secure function library. After downloading it:
 To build multiple versions of the whl package, the build scripts install the matching tensorflow version inside a Python virtual environment. Users can adapt the build scripts to their own setup and point them at a specific tensorflow installation path. Build options:

 - build/build.sh: builds and packages both the tf1 and tf2 whl packages. Before running it, refer to build/build_tf1.sh and build/build_tf2.sh to create the corresponding virtual environments, install the matching tensorflow version in each, and adjust the activation commands accordingly.
-- build/build_tf1.sh: builds the tf1 whl package; on success the whl package is placed under the tf1_whl subdirectory. Before running it, create a tf1 virtual environment, install tensorflow 1.15.0 in it, and adjust the activation command.
-- build/build_tf2.sh: builds the tf2 whl package; on success the whl package is placed under the tf2_whl subdirectory. Before running it, create a tf2 virtual environment, install tensorflow 2.6.5 in it, and adjust the activation command.
+- build/build_tf1_with_opensource.sh: builds the tf1 whl package; on success the whl package is placed under the tf1_whl subdirectory. Before running it, create a tf1 virtual environment, install tensorflow 1.15.0 in it, and adjust the activation command.
+- build/build_tf2_with_opensource.sh: builds the tf2 whl package; on success the whl package is placed under the tf2_whl subdirectory. Before running it, create a tf2 virtual environment, install tensorflow 2.6.5 in it, and adjust the activation command.

 To use the dynamic expansion feature, enter the "./cust_op/cust_op_by_addr" directory and build and install the dynamic expansion operator package with the following commands.
 ```shell

-- Gitee

From 8c5e6ef4a171516987d31da29bd32225a55946b3 Mon Sep 17 00:00:00 2001
From: yxy1684 <2270320041@qq.com>
Date: Thu, 29 Feb 2024 18:30:54 +0800
Subject: Sync code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitee/ISSUE_TEMPLATE.zh-CN.md        | 13 -------
 .gitee/PULL_REQUEST_TEMPLATE.zh-CN.md | 53 ---------------------------
 README.en.md                          | 36 ------------------
 3 files changed, 102 deletions(-)
 delete mode 100644 .gitee/ISSUE_TEMPLATE.zh-CN.md
 delete mode 100644 .gitee/PULL_REQUEST_TEMPLATE.zh-CN.md
 delete mode 100644 README.en.md

diff --git a/.gitee/ISSUE_TEMPLATE.zh-CN.md b/.gitee/ISSUE_TEMPLATE.zh-CN.md
deleted file mode 100644
index f09d98dd..00000000
--- a/.gitee/ISSUE_TEMPLATE.zh-CN.md
+++ /dev/null
@@ -1,13 +0,0 @@
-### What caused this issue?
-
-
-
-### Steps to reproduce
-
-
-
-### Error message
-
-
-
-
diff --git a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md b/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md
deleted file mode 100644
index 0ed1c31c..00000000
--- a/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md
+++ /dev/null
@@ -1,53 +0,0 @@
-### I. Description (related Issue)
-
-
-
-### II. Suggested test window and test environment
-    Suggested test completion date: xxxx.xx.xx
-    Production release date: xxxx.xx.xx
-    Test environment: CI environment / stress-test environment
-    Test account:
-
-### III. Changes
- * 3.1 Related PR list
-
- * 3.2 Database and deployment notes
-    1. Regular update
-    2. Restart unicorn
-    3. Restart sidekiq
-    4. Migration tasks: state whether there are any; if none, write "none"
-    5. rake scripts: `bundle exec xxx RAILS_ENV = production`; if none, write "none"
-
- * 3.4 Other technical improvements (what was done, what changed)
-    - Refactored xxxx code
-    - Optimized the xxxx algorithm
-
-
- * 3.5 Deprecation notice (which fields or methods are deprecated?)
-
-
-
- * 3.6 Backward-incompatible changes (any changes that cannot be made backward compatible?)
-
-
-
-### IV. Developer self-test items (what was self-tested? were all smoke test cases covered?)
-    Self-test conclusion:
-
-
-### V. Test focus (points QA should be reminded to focus on, or might overlook)
-    Checkpoints:
-
-| Requirement name | Affects xx shared module? | Requires xx feature? | Does the upgrade depend on other sub-products? |
-|------|------------|----------|---------------|
-| xxx | No | Needed | No |
-| | | | |
-
-    Interface tests:
-
-    Performance tests:
-
-    Concurrency tests:
-
diff --git a/README.en.md b/README.en.md
deleted file mode 100644
index 15f60ae9..00000000
--- a/README.en.md
+++ /dev/null
@@ -1,36 +0,0 @@
-# mxRec
-
-#### Description
-{**When you're done, you can delete the content in this README and update the file with details for others getting started with your repository**}
-
-#### Software Architecture
-Software architecture description
-
-#### Installation
-
-1. xxxx
-2. xxxx
-3. xxxx
-
-#### Instructions
-
-1. xxxx
-2. xxxx
-3. xxxx
-
-#### Contribution
-
-1. Fork the repository
-2. Create Feat_xxx branch
-3. Commit your code
-4. Create Pull Request
-
-
-#### Gitee Feature
-
-1. You can use Readme\_XXX.md to support different languages, such as Readme\_en.md, Readme\_zh.md
-2. Gitee blog [blog.gitee.com](https://blog.gitee.com)
-3. Explore open source project [https://gitee.com/explore](https://gitee.com/explore)
-4. The most valuable open source project [GVP](https://gitee.com/gvp)
-5. The manual of Gitee [https://gitee.com/help](https://gitee.com/help)
-6. The most popular members [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)
-- Gitee
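Closing note on the series: after patches 543 through 546, the sparse optimizer is optional state on the shared config rather than a constructor argument of the embedding table. Two rules visible in the diffs above summarize the behavior: under dynamic expansion, `ext_emb_size = emb_size * (1 + slot_num)` when an optimizer is registered, and `matched_opt_slot_initializers` degenerates to an empty list when none is. A toy sketch of just those two rules — function signatures and field names here are illustrative, not the library API:

```python
def ext_emb_size(emb_size, optimizer=None, use_dynamic_expansion=True):
    """Extended width holds the embedding plus one slot block per optimizer slot."""
    if use_dynamic_expansion and optimizer:
        return emb_size * (1 + optimizer["slot_num"])
    return emb_size


def matched_opt_slot_initializers(optimizer=None):
    if not optimizer:
        return []  # patch 546: no registered optimizer -> no slot initializers
    return [("constant_initializer", value) for value in optimizer["slot_init_values"]]


lazy_adam = {"slot_num": 2, "slot_init_values": [0.0, 0.0]}  # momentum + velocity
print(ext_emb_size(8, lazy_adam))                # 24, matching the updated unit test
print(matched_opt_slot_initializers(lazy_adam))  # two initializers, as asserted above
print(matched_opt_slot_initializers(None))       # []
```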